use axum::body::{Body, Bytes};
use axum::extract::{Query, State};
use axum::http::{HeaderMap, StatusCode};
use axum::response::Response;
use axum::routing::{get, post};
use axum::Router;
use base64::{engine::general_purpose, Engine as _};
use packed_struct::PackedStruct;
use serde::{Deserialize, Serialize};
use std::str::from_utf8;
use crate::db::get_all_fzr_by_name;
use crate::enums::{Rcode, RecordClass, RecordType};
use crate::reply::Reply;
use crate::resourcerecord::InternalResourceRecord;
use crate::servers::parse_query;
use crate::web::GoatState;
use crate::{Header, Question, HEADER_BYTES};
/// A single question entry in a JSON DoH response.
///
/// Serialised as one element of the `Question` array of [`JSONResponse`].
#[derive(Debug, Serialize)]
pub struct JSONQuestion {
    /// The name that was queried.
    name: String,
    /// Numeric record type; emitted under the JSON key `type`.
    #[serde(rename = "type")]
    qtype: u16,
}
/// A single resource record in a JSON DoH response.
///
/// Serialised as one element of the `Answer` array of [`JSONResponse`].
#[derive(Debug, Serialize)]
pub struct JSONRecord {
    /// Owner name of the record.
    name: String,
    /// Numeric record type; emitted under the JSON key `type`.
    #[serde(rename = "type")]
    qtype: u16,
    /// Time-to-live of the record; emitted under the JSON key `TTL`.
    #[serde(rename = "TTL")]
    ttl: u32,
    /// Record data as text; the key is omitted entirely when this is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    data: Option<String>,
}
/// Body of an `application/dns-json` response.
///
/// Several fields are renamed on serialisation to short or capitalised JSON
/// keys; the key used for each is noted on the field.
#[derive(Debug, Default, Serialize)]
pub struct JSONResponse {
    /// Numeric response code (the GET handler fills this from an `Rcode`).
    status: u32,
    /// Truncation flag; JSON key `tc`.
    #[serde(rename = "tc")]
    truncated: bool,
    /// Recursion-desired flag; JSON key `rd`.
    #[serde(rename = "rd")]
    recursive_desired: bool,
    /// Recursion-available flag; JSON key `ra`.
    #[serde(rename = "ra")]
    recursion_available: bool,
    /// JSON key `ad`.
    // NOTE(review): presumably the DNS "authenticated data" flag — confirm.
    ad: bool,
    /// JSON key `cd`.
    #[serde(rename = "cd")]
    client_dnssec_disable: bool,
    /// The question(s) being answered; JSON key `Question`.
    #[serde(rename = "Question")]
    question: Vec<JSONQuestion>,
    /// The answer records; JSON key `Answer`.
    #[serde(rename = "Answer")]
    answer: Vec<JSONRecord>,
    /// Optional free-text comment; JSON key `Comment`, omitted when `None`.
    #[serde(rename = "Comment", skip_serializing_if = "Option::is_none")]
    comment: Option<String>,
    /// Optional error text; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
}
/// Query-string parameters accepted by the DoH `GET` endpoint.
///
/// The same struct is also used to carry the question extracted from a raw
/// DNS message by `parse_raw_http`.
///
/// Note that the field-level `#[serde(default)]` attributes use each field
/// type's own default, so a request without `type` deserialises `rrtype` as
/// `None`; the `"A"` fallback of this struct's `Default` impl only applies
/// when the struct is built through `Default::default()`.
#[derive(Debug, Deserialize)]
// `dnssec` and `cd` are parsed but not read by the handlers visible in this
// file, hence the lint allowance.
#[allow(dead_code)]
pub struct GetQueryString {
    /// Encoded raw DNS message (the `dns` parameter).
    dns: Option<String>,
    /// Name to look up (the `name` parameter).
    name: Option<String>,
    /// Record type as text; also accepted under the key `type`.
    #[serde(alias = "type", default)]
    rrtype: Option<String>,
    /// Also accepted under the key `do`; defaults to `false`.
    #[serde(alias = "do", default)]
    dnssec: bool,
    /// Defaults to `false` when absent.
    #[serde(default)]
    cd: bool,
    /// Message id; defaults to `0` when absent.
    #[serde(default)]
    id: u16,
}
impl Default for GetQueryString {
    /// Builds an empty query: no `dns` payload, no name, all flags cleared,
    /// id `0`, and the record type pre-set to `"A"`.
    fn default() -> Self {
        let rrtype = Some(String::from("A"));
        Self {
            rrtype,
            dns: None,
            name: None,
            id: 0,
            cd: false,
            dnssec: false,
        }
    }
}
/// The response format a client asked for via its `Accept` header.
#[derive(Debug)]
enum ResponseType {
    /// `application/dns-json` was requested.
    Json,
    /// `application/dns-message` (wire format) was requested.
    Raw,
    /// The header was missing or named an unsupported type.
    Invalid,
}
/// Extracts the question from a raw DNS message.
///
/// The first `HEADER_BYTES` bytes are unpacked as the message header and the
/// remainder is parsed as a single question. The result is returned as a
/// [`GetQueryString`] carrying the question name, the record type (as text),
/// and the header's `id` and `cd` values; every other field takes its
/// `Default` value.
///
/// If the question name is not valid UTF-8, its `Debug` rendering is used as
/// the name instead.
///
/// # Errors
///
/// Returns a description of the problem when `bytes` is shorter than a DNS
/// header, when the header cannot be unpacked, or when the question cannot be
/// parsed.
async fn parse_raw_http(bytes: Vec<u8>) -> Result<GetQueryString, String> {
    // The payload comes straight from the client, so it may be arbitrarily
    // short; slicing without this check would panic.
    if bytes.len() < HEADER_BYTES {
        return Err(format!(
            "Failed to parse header: expected at least {} bytes, got {}",
            HEADER_BYTES,
            bytes.len()
        ));
    }
    let (header_bytes, question_bytes) = bytes.split_at(HEADER_BYTES);

    let mut split_header: [u8; HEADER_BYTES] = [0; HEADER_BYTES];
    split_header.copy_from_slice(header_bytes);
    let header = crate::Header::unpack(&split_header)
        .map_err(|error| format!("Failed to parse header: {:?}", error))?;
    log::trace!("Parsed header: {:?}", header);

    let question = Question::from_packets(question_bytes)?;
    log::debug!("Question: {question:?}");

    let name = match from_utf8(&question.qname) {
        Ok(value) => value.to_string(),
        Err(_) => format!("{:?}", question.qname),
    };

    Ok(GetQueryString {
        dns: None,
        name: Some(name),
        rrtype: Some(question.qtype.to_string()),
        id: header.id,
        cd: header.cd,
        ..Default::default()
    })
}
/// Maps the request's `Accept` header onto a [`ResponseType`].
///
/// Only the exact values `application/dns-json` and
/// `application/dns-message` are recognised. A missing header, a value that
/// is not representable as a string, or any other value gives
/// [`ResponseType::Invalid`].
fn get_response_type_from_headers(headers: HeaderMap) -> ResponseType {
    let accept = headers
        .get("accept")
        .map(|raw| raw.to_str().unwrap_or(""));
    match accept {
        Some("application/dns-json") => ResponseType::Json,
        Some("application/dns-message") => ResponseType::Raw,
        Some(_) | None => ResponseType::Invalid,
    }
}
fn response_406() -> Response {
    axum::response::Response::builder()
        .status(StatusCode::from_u16(406).unwrap())
        .header("Cache-Control", "max-age=3600")
        .body(Body::empty())
        .unwrap()
}
fn response_500() -> Response<Body> {
    axum::response::Response::builder()
        .status(StatusCode::from_u16(500).unwrap())
        .header("Cache-Control", "max-age=1")
        .body(Body::empty())
        .unwrap()
}
pub async fn handle_get(
    State(state): State<GoatState>,
    headers: HeaderMap,
    Query(query): Query<GetQueryString>,
) -> Response {
    let response_type: ResponseType = get_response_type_from_headers(headers);
    if let ResponseType::Invalid = response_type {
        return response_406();
    }
    let mut qname: String = "".to_string();
    let mut rrtype: String = "A".to_string();
    let mut id: u16 = 0;
    if let Some(dns) = query.dns {
        let bytes = match general_purpose::STANDARD.decode(dns) {
            Ok(val) => val,
            Err(err) => {
                log::debug!("Failed to parse DoH GET RAW: {err:?}");
                return response_500(); }
        };
        let query = parse_raw_http(bytes.clone()).await.unwrap();
        qname = query.name.unwrap();
        rrtype = query.rrtype.unwrap();
        id = query.id;
    } else if query.name.is_some() {
        qname = query.name.unwrap();
        rrtype = query.rrtype.unwrap_or("A".to_string());
    }
    let records = match get_all_fzr_by_name(
        &mut state.read().await.connpool.begin().await.unwrap(),
        &qname.clone(),
        RecordType::from(rrtype.clone()) as u16,
    )
    .await
    {
        Ok(value) => value,
        Err(error) => {
            log::error!("Failed to query {qname}/{}: {error:?}", rrtype);
            return response_500(); }
    };
    log::trace!("Completed record request...");
    let ttl = records.iter().map(|r| r.ttl).min();
    let ttl = match ttl {
        Some(val) => val.to_owned(),
        None => {
            log::trace!("Failed to get minimum TTL from query, using 1");
            1
        }
    };
    log::trace!("Returned records: {records:?}");
    match response_type {
        ResponseType::Invalid => response_500(),
        ResponseType::Json => {
            let answer = records
                .iter()
                .map(|rec| JSONRecord {
                    name: rec.name.clone(),
                    qtype: RecordType::from(rec.rrtype.clone()) as u16,
                    ttl: rec.ttl.to_owned(),
                    data: Some(rec.rdata.clone()),
                })
                .collect();
            let reply = JSONResponse {
                answer,
                status: Rcode::NoError as u32,
                truncated: false,
                recursive_desired: false,
                recursion_available: false,
                ad: false,
                client_dnssec_disable: false,
                question: vec![JSONQuestion {
                    name: qname,
                    qtype: RecordType::from(rrtype) as u16,
                }],
                ..Default::default()
            };
            let response = serde_json::to_string(&reply).unwrap();
            let response_builder = axum::response::Response::builder()
                .status(StatusCode::OK)
                .header("Content-type", "application/dns-json")
                .header("Cache-Control", format!("max-age={ttl}"));
            response_builder.body(Body::from(response)).unwrap()
        }
        ResponseType::Raw => {
            let answers: Vec<InternalResourceRecord> = records
                .iter()
                .filter_map(|r| {
                    let rec: Option<InternalResourceRecord> = match r.to_owned().try_into() {
                        Ok(val) => Some(val),
                        Err(_) => None,
                    };
                    rec
                })
                .collect();
            let reply = Reply {
                header: Header {
                    id,
                    qr: crate::enums::PacketType::Answer,
                    opcode: crate::enums::OpCode::Query,
                    authoritative: true, truncated: false,
                    recursion_desired: false,
                    recursion_available: false,
                    z: false,
                    ad: false, cd: false, rcode: crate::enums::Rcode::NoError,
                    qdcount: 1,
                    ancount: records.len() as u16,
                    nscount: 0,
                    arcount: 0,
                },
                question: Some(Question {
                    qname: qname.into(),
                    qtype: RecordType::from(rrtype),
                    qclass: RecordClass::Internet,
                }),
                answers,
                authorities: vec![], additional: vec![],  };
            match reply.as_bytes().await {
                Ok(value) => axum::response::Response::builder()
                    .status(StatusCode::OK)
                    .header("Content-type", "application/dns-message")
                    .header("Cache-Control", format!("max-age={ttl}"))
                    .body(Body::from(value))
                    .unwrap(),
                Err(err) => {
                    log::error!("Failed to turn DoH GET request into bytes: {err:?}");
                    response_500()
                }
            }
        }
    }
}
/// Handles a DNS-over-HTTPS `POST` request.
///
/// The request body is passed to `parse_query` as a raw DNS message and the
/// resulting reply is returned in wire format. Only
/// `Accept: application/dns-message` is served; a JSON or unrecognised
/// `Accept` value is answered with `406`.
///
/// A serialised reply longer than 65535 bytes is re-serialised with the
/// header's truncated flag set and then cut to 65535 bytes.
///
/// `Cache-Control: max-age` is the smallest TTL among the reply's answers, or
/// `1` when there are none. Failures to process the query or to serialise the
/// reply are logged and answered with an empty `500`.
pub async fn handle_post(
    State(state): State<GoatState>,
    headers: HeaderMap,
    body: Bytes,
) -> Response {
    /// Upper bound on the size of the message returned to the client.
    const MAX_MESSAGE_BYTES: usize = 65535;

    match get_response_type_from_headers(headers) {
        ResponseType::Raw => {}
        ResponseType::Json | ResponseType::Invalid => return response_406(),
    }

    // Copy what is needed out of the shared state and release the read guard
    // before the query is processed, rather than holding it across awaits.
    let (datastore, capture_packets) = {
        let state_reader = state.read().await;
        let datastore = state_reader.tx.clone();
        let capture_packets = state_reader.config.capture_packets;
        (datastore, capture_packets)
    };

    let mut reply = match parse_query(datastore, body.len(), &body, capture_packets).await {
        Ok(reply) => reply,
        Err(err) => {
            log::error!("Failed to parse DoH POST query: {err:?}");
            return response_500();
        }
    };

    let mut bytes = match reply.as_bytes().await {
        Ok(value) => value,
        Err(error) => {
            log::error!("Failed to turn DoH POST response into bytes! {error:?}");
            return response_500();
        }
    };
    if bytes.len() > MAX_MESSAGE_BYTES {
        reply.header.truncated = true;
        bytes = match reply.as_bytes().await {
            Ok(value) => value,
            Err(error) => {
                log::error!("Failed to turn DoH POST response into bytes! {error:?}");
                return response_500();
            }
        };
        // `truncate` only ever shortens, so a re-serialised message that came
        // out under the limit is never padded with zero bytes.
        bytes.truncate(MAX_MESSAGE_BYTES);
    }

    let ttl = reply
        .answers
        .iter()
        .map(|a| a.ttl())
        .min()
        .map(|ttl| ttl.to_owned())
        .unwrap_or(1);

    axum::response::Response::builder()
        .status(StatusCode::OK)
        .header("Content-type", "application/dns-message")
        .header("Cache-Control", format!("max-age={ttl}"))
        .body(Body::from(bytes))
        .unwrap()
}
/// Builds the DoH router: `GET /` is served by [`handle_get`] and `POST /`
/// by [`handle_post`].
pub fn new() -> Router<GoatState> {
    let with_get = Router::new().route("/", get(handle_get));
    with_get.route("/", post(handle_post))
}