use std::io::Error;
use std::net::SocketAddr;
use std::str::{from_utf8, FromStr};
use std::time::Duration;

use concread::cowcell::asynch::CowCellReadTxn;
use packed_struct::prelude::*;
use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream, UdpSocket};
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio::task::JoinHandle;
use tokio::time::timeout;

use crate::config::ConfigFile;
use crate::datastore::Command;
use crate::enums::{Agent, AgentState, PacketType, Rcode, RecordClass, RecordType};
use crate::reply::{reply_any, reply_builder, reply_nxdomain, Reply};
use crate::resourcerecord::{DNSCharString, InternalResourceRecord};
use crate::zones::ZoneRecord;
use crate::{Header, OpCode, Question, HEADER_BYTES, REPLY_TIMEOUT_MS, UDP_BUFFER_SIZE};
lazy_static! {
    /// The IPv4 loopback address (`127.0.0.1`), parsed once on first access.
    static ref LOCALHOST: std::net::IpAddr =
        <std::net::IpAddr as FromStr>::from_str("127.0.0.1")
            .expect("Failed to parse localhost IP address");
}
/// Decide how to answer a CHAOS-class `shutdown` query.
///
/// Returns:
/// * `Ok(reply)`: `r` asks for CHAOS `shutdown` and `allowed_shutdown` is true. `reply` is a
///   copy of `r` with the `CHAOS_OK` TXT record appended.
/// * `Err(Some(reply))`: `r` asks for CHAOS `shutdown` but the client isn't allowed. `reply`
///   is a copy of `r` with the `CHAOS_NO` TXT record appended and rcode `Refused`.
/// * `Err(None)`: anything else (no question, not CHAOS class, a different name, or a qname
///   that isn't valid UTF-8). The caller should send `r` as-is.
async fn check_for_shutdown(r: &Reply, allowed_shutdown: bool) -> Result<Reply, Option<Reply>> {
    let q = match &r.question {
        Some(q) if q.qclass == RecordClass::Chaos => q,
        _ => return Err(None),
    };
    // The qname comes off the wire; a bad encoding must not panic the connection task.
    let qname = match from_utf8(&q.qname) {
        Ok(name) => name,
        Err(e) => {
            log::error!(
                "Failed to parse qname from {:?}, this shouldn't be able to happen! {e:?}",
                q.qname
            );
            return Err(None);
        }
    };
    // DNS names compare case-insensitively (RFC 4343); this keeps the check in line with
    // get_result, which matches on the normalized name.
    if !qname.eq_ignore_ascii_case("shutdown") {
        return Err(None);
    }
    let mut chaos_reply = r.clone();
    if allowed_shutdown {
        log::info!("Got CHAOS shutdown, shutting down");
        chaos_reply.answers.push(CHAOS_OK.clone());
        Ok(chaos_reply)
    } else {
        log::warn!("Got CHAOS shutdown, ignoring!");
        chaos_reply.answers.push(CHAOS_NO.clone());
        chaos_reply.header.rcode = Rcode::Refused;
        Err(Some(chaos_reply))
    }
}
pub async fn udp_server(
    config: CowCellReadTxn<ConfigFile>,
    datastore_sender: mpsc::Sender<crate::datastore::Command>,
    _agent_tx: broadcast::Sender<AgentState>,
) -> io::Result<()> {
    let udp_sock = match UdpSocket::bind(
        config
            .dns_listener_address()
            .expect("Failed to get DNS listener address on startup!"),
    )
    .await
    {
        Ok(value) => {
            log::info!("Started UDP listener on {}:{}", config.address, config.port);
            value
        }
        Err(error) => {
            log::error!("Failed to start UDP listener: {:?}", error);
            return Ok(());
        }
    };
    let mut udp_buffer = [0; UDP_BUFFER_SIZE];
    loop {
        let (len, addr) = match udp_sock.recv_from(&mut udp_buffer).await {
            Ok(value) => value,
            Err(error) => {
                log::error!("Error accepting connection via UDP: {:?}", error);
                continue;
            }
        };
        log::debug!("{:?} bytes received from {:?}", len, addr);
        let udp_result = match timeout(
            Duration::from_millis(REPLY_TIMEOUT_MS),
            parse_query(
                datastore_sender.clone(),
                len,
                &udp_buffer,
                config.capture_packets,
            ),
        )
        .await
        {
            Ok(reply) => reply,
            Err(_) => {
                log::error!("Did not receive response from parse_query within 10 ms");
                continue;
            }
        };
        match udp_result {
            Ok(mut r) => {
                log::debug!("Result: {:?}", r);
                let reply_bytes: Vec<u8> = match r.as_bytes().await {
                    Ok(value) => {
                        if value.len() > UDP_BUFFER_SIZE {
                            let mut response_bytes = value.to_vec();
                            response_bytes.truncate(UDP_BUFFER_SIZE);
                            r = r.check_set_truncated().await;
                            let r = r.as_bytes_udp().await;
                            r.unwrap_or(value)
                        } else {
                            value
                        }
                    }
                    Err(error) => {
                        log::error!("Failed to parse reply {:?} into bytes: {:?}", r, error);
                        continue;
                    }
                };
                log::trace!("reply_bytes: {:?}", reply_bytes);
                let len = match udp_sock.send_to(&reply_bytes as &[u8], addr).await {
                    Ok(value) => value,
                    Err(err) => {
                        log::error!("Failed to send data back to {:?}: {:?}", addr, err);
                        return Ok(());
                    }
                };
                log::trace!("{:?} bytes sent", len);
            }
            Err(error) => log::error!("Error: {}", error),
        }
    }
}
pub async fn tcp_conn_handler(
    stream: &mut TcpStream,
    addr: SocketAddr,
    datastore_sender: mpsc::Sender<Command>,
    agent_tx: broadcast::Sender<AgentState>,
    capture_packets: bool,
    allowed_shutdown: bool,
) -> io::Result<()> {
    let (mut reader, writer) = stream.split();
    let msg_length: usize = reader.read_u16().await?.into();
    log::debug!("msg_length={msg_length}");
    let mut buf: Vec<u8> = vec![];
    while buf.len() < msg_length {
        let len = match reader.read_buf(&mut buf).await {
            Ok(size) => size,
            Err(error) => {
                log::error!("Failed to read from TCP Stream: {:?}", error);
                return Ok(());
            }
        };
        if len > 0 {
            log::debug!("Read {:?} bytes from TCP stream", len);
        }
    }
    crate::utils::hexdump(buf.clone());
    if buf.len() < msg_length {
        log::warn!(
            "Message length too short {}, wanted {}",
            buf.len(),
            msg_length + 2
        );
    } else {
        log::info!("TCP Message length ftw!");
    }
    let buf = &buf[0..msg_length];
    let result = match timeout(
        Duration::from_millis(REPLY_TIMEOUT_MS),
        parse_query(datastore_sender.clone(), msg_length, buf, capture_packets),
    )
    .await
    {
        Ok(reply) => reply,
        Err(_) => {
            log::error!("Did not receive response from parse_query within {REPLY_TIMEOUT_MS} ms");
            return Ok(());
        }
    };
    match result {
        Ok(r) => {
            log::debug!("TCP Result: {r:?}");
            let r = match check_for_shutdown(&r, allowed_shutdown).await {
                Err(reply) => match reply {
                    None => r,
                    Some(response) => response,
                },
                Ok(reply) => {
                    if let Err(error) = agent_tx.send(AgentState::Stopped {
                        agent: Agent::TCPServer,
                    }) {
                        eprintln!("Failed to send UDPServer shutdown message: {error:?}");
                    };
                    if let Err(error) = datastore_sender.send(Command::Shutdown).await {
                        eprintln!("Failed to send shutdown command to datastore.. {error:?}");
                    };
                    reply
                }
            };
            let reply_bytes: Vec<u8> = match r.as_bytes().await {
                Ok(value) => value,
                Err(error) => {
                    log::error!("Failed to parse reply {:?} into bytes: {:?}", r, error);
                    return Ok(());
                }
            };
            log::trace!("reply_bytes: {:?}", reply_bytes);
            let reply_bytes = &reply_bytes as &[u8];
            let response_length: u16 = reply_bytes.len() as u16;
            let len = match writer.try_write(&response_length.to_be_bytes()) {
                Ok(value) => value,
                Err(err) => {
                    log::error!("Failed to send data back to {:?}: {:?}", addr, err);
                    return Ok(());
                }
            };
            log::trace!("{:?} bytes sent", len);
            let len = match writer.try_write(reply_bytes) {
                Ok(value) => value,
                Err(err) => {
                    log::error!("Failed to send data back to {:?}: {:?}", addr, err);
                    return Ok(());
                }
            };
            log::trace!("{:?} bytes sent", len);
        }
        Err(error) => log::error!("Error: {}", error),
    }
    Ok(())
}
pub async fn tcp_server(
    config: CowCellReadTxn<ConfigFile>,
    tx: mpsc::Sender<crate::datastore::Command>,
    agent_tx: broadcast::Sender<AgentState>,
    ) -> io::Result<()> {
    let mut agent_rx = agent_tx.subscribe();
    let tcpserver = match TcpListener::bind(
        config
            .dns_listener_address()
            .expect("Failed to get DNS listener address on startup!"),
    )
    .await
    {
        Ok(value) => {
            log::info!(
                "Started TCP listener on {}",
                config
                    .dns_listener_address()
                    .expect("Failed to get DNS listener address on startup!")
            );
            value
        }
        Err(error) => {
            log::error!("Failed to start TCP Server: {:?}", error);
            return Ok(());
        }
    };
    let tcp_client_timeout = config.tcp_client_timeout;
    let shutdown_ip_address_list = config.ip_allow_lists.shutdown.to_vec();
    let capture_packets = config.capture_packets;
    loop {
        let (mut stream, addr) = match tcpserver.accept().await {
            Ok(value) => value,
            Err(error) => panic!("Couldn't get data from TcpStream: {:?}", error),
        };
        let allowed_shutdown = shutdown_ip_address_list.contains(&addr.ip());
        log::debug!("TCP connection from {:?}", addr);
        let loop_tx = tx.clone();
        let loop_agent_tx = agent_tx.clone();
        tokio::spawn(async move {
            if timeout(
                Duration::from_secs(tcp_client_timeout),
                tcp_conn_handler(
                    &mut stream,
                    addr,
                    loop_tx,
                    loop_agent_tx,
                    capture_packets,
                    allowed_shutdown,
                ),
            )
            .await
            .is_err()
            {
                log::warn!(
                    "TCP Connection from {addr:?} terminated after {} seconds.",
                    tcp_client_timeout
                );
            }
        })
        .await?;
        if let Ok(agent_state) = agent_rx.try_recv() {
            log::info!("Got agent state: {:?}", agent_state);
        };
    }
}
/// Parse a raw DNS query in `buf[..len]` and build the reply for it.
///
/// If `capture_packets` is set, the query bytes are first passed to
/// `packet_dumper::dump_bytes` as a `ClientRequest`. The header is then unpacked and the rest
/// of the work is delegated to `get_result`.
///
/// # Errors
/// Returns `Err` with a description when any of these holds:
/// * `len` exceeds `buf.len()`;
/// * the query is shorter than a DNS header (`HEADER_BYTES`);
/// * the header cannot be unpacked;
/// * `get_result` fails.
///
/// These inputs used to panic on out-of-range slices. For UDP a zero-length or
/// header-truncated datagram could take down the listener task.
pub async fn parse_query(
    datastore: tokio::sync::mpsc::Sender<crate::datastore::Command>,
    len: usize,
    buf: &[u8],
    capture_packets: bool,
) -> Result<Reply, String> {
    if len > buf.len() {
        return Err(format!(
            "Query length {len} exceeds buffer size {}",
            buf.len()
        ));
    }
    let packet = &buf[..len];
    if capture_packets {
        crate::packet_dumper::dump_bytes(
            packet.into(),
            crate::packet_dumper::DumpType::ClientRequest,
        )
        .await;
    }
    if len < HEADER_BYTES {
        return Err(format!(
            "Query too short: {len} bytes, need at least {HEADER_BYTES} for the header"
        ));
    }
    let mut split_header: [u8; HEADER_BYTES] = [0; HEADER_BYTES];
    split_header.copy_from_slice(&packet[..HEADER_BYTES]);
    let header = match crate::Header::unpack(&split_header) {
        Ok(value) => value,
        Err(error) => {
            return Err(format!("Failed to parse header: {:?}", error));
        }
    };
    log::trace!("Buffer length: {}", len);
    log::trace!("Parsed header: {:?}", header);
    get_result(header, len, buf, datastore).await
}
lazy_static! {
    /// CHAOS-class TXT answer `"OK"` (TTL 0), appended to a permitted `shutdown` reply.
    static ref CHAOS_OK: InternalResourceRecord = InternalResourceRecord::TXT {
        class: RecordClass::Chaos,
        ttl: 0,
        txtdata: DNSCharString::from("OK"),
    };
    /// CHAOS-class TXT answer `"NO"` (TTL 0), appended to a refused `shutdown` reply.
    static ref CHAOS_NO: InternalResourceRecord = InternalResourceRecord::TXT {
        class: RecordClass::Chaos,
        ttl: 0,
        txtdata: DNSCharString::from("NO"),
    };
}
async fn get_result(
    header: Header,
    len: usize,
    buf: &[u8],
    datastore: mpsc::Sender<crate::datastore::Command>,
) -> Result<Reply, String> {
    log::trace!("called get_result(header={header}, len={len})");
    if header.opcode != OpCode::Query {
        return Err(format!("Invalid OPCODE, got {:?}", header.opcode));
    };
    let question = match Question::from_packets(&buf[HEADER_BYTES..len]) {
        Ok(value) => {
            log::trace!("Parsed question: {:?}", value);
            value
        }
        Err(error) => {
            log::debug!("Failed to parse question: {} id={}", error, header.id);
            return reply_builder(header.id, Rcode::ServFail);
        }
    };
    if !question.qtype.supported() {
        log::debug!(
            "Unsupported request: {} {:?}, returning NotImplemented",
            from_utf8(&question.qname).unwrap_or("<unable to parse>"),
            question.qtype,
        );
        return reply_builder(header.id, Rcode::NotImplemented);
    }
    #[allow(clippy::collapsible_if)]
    if question.qclass == RecordClass::Chaos {
        if &question.normalized_name()? == "shutdown" {
            log::debug!("Got CHAOS shutdown!");
            return Ok(Reply {
                header,
                question: Some(question),
                answers: vec![],
                authorities: vec![],
                additional: vec![],
            });
        }
    }
    if let RecordType::ANY {} = question.qtype {
        return reply_any(header.id, question);
    };
    let (tx_oneshot, rx_oneshot) = oneshot::channel();
    let ds_req: Command = Command::GetRecord {
        name: question.qname.clone(),
        rrtype: question.qtype,
        rclass: question.qclass,
        resp: tx_oneshot,
    };
    match datastore.send(ds_req).await {
        Ok(_) => log::trace!("Sent a request to the datastore!"),
        Err(error) => log::error!("Error sending to datastore: {:?}", error),
    };
    let record: ZoneRecord = match rx_oneshot.await {
        Ok(value) => match value {
            Some(zr) => {
                log::debug!("DS Response: {}", zr);
                zr
            }
            None => {
                log::debug!("No response from datastore");
                return reply_nxdomain(header.id);
            }
        },
        Err(error) => {
            log::error!("Failed to get response from datastore: {:?}", error);
            return reply_builder(header.id, Rcode::ServFail);
        }
    };
    Ok(Reply {
        header: Header {
            id: header.id,
            qr: PacketType::Answer,
            opcode: header.opcode,
            authoritative: true,
            truncated: false, recursion_desired: header.recursion_desired,
            recursion_available: header.recursion_available, z: false,
            ad: true, cd: false, rcode: Rcode::NoError,
            qdcount: 1,
            ancount: record.typerecords.len() as u16,
            nscount: 0,
            arcount: 0,
        },
        question: Some(question),
        answers: record.typerecords,
        authorities: vec![], additional: vec![],
    })
}
/// Join handles for the long-running server tasks, plus the broadcast sender used to announce
/// agent state changes (see `Servers::all_finished`, which sends `AgentState::Stopped` here).
///
/// Every handle is `None` by default and is filled in through the `with_*` builder methods.
#[derive(Debug)]
pub struct Servers {
    /// Datastore task; reported as `Agent::Datastore` when finished.
    pub datastore: Option<JoinHandle<Result<(), String>>>,
    /// UDP server task; reported as `Agent::UDPServer` when finished.
    pub udpserver: Option<JoinHandle<Result<(), Error>>>,
    /// TCP server task; reported as `Agent::TCPServer` when finished.
    pub tcpserver: Option<JoinHandle<Result<(), Error>>>,
    /// API server task; reported as `Agent::API` when finished.
    pub apiserver: Option<JoinHandle<Result<(), Error>>>,
    /// Sender on which `AgentState::Stopped` notifications are broadcast.
    pub agent_tx: broadcast::Sender<AgentState>,
}
impl Default for Servers {
    fn default() -> Self {
        let (agent_tx, _) = broadcast::channel(10000);
        Self {
            datastore: None,
            udpserver: None,
            tcpserver: None,
            apiserver: None,
            agent_tx,
        }
    }
}
impl Servers {
    /// Create an empty set of handles that reports agent state on `agent_tx`.
    pub fn build(agent_tx: broadcast::Sender<AgentState>) -> Self {
        Self {
            agent_tx,
            ..Default::default()
        }
    }

    /// Set (or clear, with `None`) the API server handle.
    pub fn with_apiserver(self, apiserver: Option<JoinHandle<Result<(), Error>>>) -> Self {
        Self { apiserver, ..self }
    }

    /// Set the datastore task handle.
    pub fn with_datastore(self, datastore: JoinHandle<Result<(), String>>) -> Self {
        Self {
            datastore: Some(datastore),
            ..self
        }
    }

    /// Set the TCP server task handle.
    pub fn with_tcpserver(self, tcpserver: JoinHandle<Result<(), Error>>) -> Self {
        Self {
            tcpserver: Some(tcpserver),
            ..self
        }
    }

    /// Set the UDP server task handle.
    pub fn with_udpserver(self, udpserver: JoinHandle<Result<(), Error>>) -> Self {
        Self {
            udpserver: Some(udpserver),
            ..self
        }
    }

    /// Log and broadcast `AgentState::Stopped` for `agent`; a send failure goes to stderr.
    fn send_shutdown(&self, agent: Agent) {
        log::info!("{agent:?} shut down");
        if let Err(error) = self.agent_tx.send(AgentState::Stopped { agent }) {
            eprintln!("Failed to send agent shutdown message: {error:?}");
        };
    }

    /// Check one optional handle exactly once.
    ///
    /// If the task has finished, prints `Sending {label} Shutdown` and broadcasts `Stopped`
    /// for `agent`. Returns whether it had finished; `None` counts as not finished.
    fn report_if_finished<T>(
        &self,
        handle: Option<&JoinHandle<T>>,
        agent: Agent,
        label: &str,
    ) -> bool {
        let finished = handle.map_or(false, |server| server.is_finished());
        if finished {
            println!("Sending {label} Shutdown");
            self.send_shutdown(agent);
        }
        finished
    }

    /// Returns `true` if *any* tracked task has finished (despite the name, not all of them).
    ///
    /// Every finished task gets a `Stopped` broadcast on each call. Each handle's
    /// `is_finished()` is sampled once, so the returned value always matches the
    /// notifications sent.
    pub fn all_finished(&self) -> bool {
        // Evaluate every check eagerly (no `||` short-circuit) so each finished task is reported.
        let api = self.report_if_finished(self.apiserver.as_ref(), Agent::API, "API");
        let datastore =
            self.report_if_finished(self.datastore.as_ref(), Agent::Datastore, "Datastore");
        let tcp = self.report_if_finished(self.tcpserver.as_ref(), Agent::TCPServer, "TCP Server");
        let udp = self.report_if_finished(self.udpserver.as_ref(), Agent::UDPServer, "UDP Server");
        api || datastore || tcp || udp
    }
}