use std::io::Error;
use std::net::SocketAddr;
use std::str::{from_utf8, FromStr};
use std::time::Duration;

use concread::cowcell::asynch::CowCellReadTxn;
use packed_struct::prelude::*;
use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream, UdpSocket};
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio::task::JoinHandle;
use tokio::time::timeout;

use crate::config::ConfigFile;
use crate::datastore::Command;
use crate::enums::{Agent, AgentState, PacketType, Rcode, RecordClass, RecordType};
use crate::reply::{reply_any, reply_builder, reply_nxdomain, Reply};
use crate::resourcerecord::{DNSCharString, InternalResourceRecord};
use crate::zones::ZoneRecord;
use crate::{Header, OpCode, Question, HEADER_BYTES, REPLY_TIMEOUT_MS, UDP_BUFFER_SIZE};
lazy_static! {
    /// The IPv4 loopback address `127.0.0.1`, parsed once on first access.
    static ref LOCALHOST: std::net::IpAddr =
        <std::net::IpAddr as FromStr>::from_str("127.0.0.1")
            .expect("Failed to parse localhost IP address");
}
/// Decides whether the reply `r` answers a CHAOS-class `shutdown` query.
///
/// Returns:
/// * `Ok(reply)`: the question is a CHAOS `shutdown` and `allowed_shutdown` is set.
///   `reply` is a clone of `r` with the CHAOS TXT "OK" record appended. Performing
///   the actual shutdown is left to the caller.
/// * `Err(Some(reply))`: the question is a CHAOS `shutdown` but `allowed_shutdown`
///   is false. `reply` is a clone of `r` with the CHAOS TXT "NO" record appended and
///   rcode `Refused`.
/// * `Err(None)`: anything else. This covers no question, a non-CHAOS class, a
///   different name, or a qname that is not valid UTF-8. The caller should use `r`
///   unchanged.
///
/// The name comparison ignores ASCII case because DNS names are case-insensitive.
async fn check_for_shutdown(r: &Reply, allowed_shutdown: bool) -> Result<Reply, Option<Reply>> {
    let question = match &r.question {
        Some(q) if q.qclass == RecordClass::Chaos => q,
        _ => return Err(None),
    };
    // The qname comes off the wire; a bad encoding must not panic the connection task.
    let qname = match from_utf8(&question.qname) {
        Ok(value) => value,
        Err(e) => {
            log::error!(
                "Failed to parse qname from {:?}, this shouldn't be able to happen! {e:?}",
                question.qname
            );
            return Err(None);
        }
    };
    if !qname.eq_ignore_ascii_case("shutdown") {
        return Err(None);
    }

    let mut chaos_reply = r.clone();
    if allowed_shutdown {
        log::info!("Got CHAOS shutdown, shutting down");
        chaos_reply.answers.push(CHAOS_OK.clone());
        Ok(chaos_reply)
    } else {
        log::warn!("Got CHAOS shutdown, ignoring!");
        chaos_reply.answers.push(CHAOS_NO.clone());
        chaos_reply.header.rcode = Rcode::Refused;
        Err(Some(chaos_reply))
    }
}
pub async fn udp_server(
config: CowCellReadTxn<ConfigFile>,
datastore_sender: mpsc::Sender<crate::datastore::Command>,
_agent_tx: broadcast::Sender<AgentState>,
) -> io::Result<()> {
let udp_sock = match UdpSocket::bind(
config
.dns_listener_address()
.expect("Failed to get DNS listener address on startup!"),
)
.await
{
Ok(value) => {
log::info!("Started UDP listener on {}:{}", config.address, config.port);
value
}
Err(error) => {
log::error!("Failed to start UDP listener: {:?}", error);
return Ok(());
}
};
let mut udp_buffer = [0; UDP_BUFFER_SIZE];
loop {
let (len, addr) = match udp_sock.recv_from(&mut udp_buffer).await {
Ok(value) => value,
Err(error) => {
log::error!("Error accepting connection via UDP: {:?}", error);
continue;
}
};
log::debug!("{:?} bytes received from {:?}", len, addr);
let udp_result = match timeout(
Duration::from_millis(REPLY_TIMEOUT_MS),
parse_query(
datastore_sender.clone(),
len,
&udp_buffer,
config.capture_packets,
),
)
.await
{
Ok(reply) => reply,
Err(_) => {
log::error!("Did not receive response from parse_query within 10 ms");
continue;
}
};
match udp_result {
Ok(mut r) => {
log::debug!("Result: {:?}", r);
let reply_bytes: Vec<u8> = match r.as_bytes().await {
Ok(value) => {
if value.len() > UDP_BUFFER_SIZE {
let mut response_bytes = value.to_vec();
response_bytes.truncate(UDP_BUFFER_SIZE);
r = r.check_set_truncated().await;
let r = r.as_bytes_udp().await;
r.unwrap_or(value)
} else {
value
}
}
Err(error) => {
log::error!("Failed to parse reply {:?} into bytes: {:?}", r, error);
continue;
}
};
log::trace!("reply_bytes: {:?}", reply_bytes);
let len = match udp_sock.send_to(&reply_bytes as &[u8], addr).await {
Ok(value) => value,
Err(err) => {
log::error!("Failed to send data back to {:?}: {:?}", addr, err);
return Ok(());
}
};
log::trace!("{:?} bytes sent", len);
}
Err(error) => log::error!("Error: {}", error),
}
}
}
pub async fn tcp_conn_handler(
stream: &mut TcpStream,
addr: SocketAddr,
datastore_sender: mpsc::Sender<Command>,
agent_tx: broadcast::Sender<AgentState>,
capture_packets: bool,
allowed_shutdown: bool,
) -> io::Result<()> {
let (mut reader, writer) = stream.split();
let msg_length: usize = reader.read_u16().await?.into();
log::debug!("msg_length={msg_length}");
let mut buf: Vec<u8> = vec![];
while buf.len() < msg_length {
let len = match reader.read_buf(&mut buf).await {
Ok(size) => size,
Err(error) => {
log::error!("Failed to read from TCP Stream: {:?}", error);
return Ok(());
}
};
if len > 0 {
log::debug!("Read {:?} bytes from TCP stream", len);
}
}
crate::utils::hexdump(buf.clone());
if buf.len() < msg_length {
log::warn!(
"Message length too short {}, wanted {}",
buf.len(),
msg_length + 2
);
} else {
log::info!("TCP Message length ftw!");
}
let buf = &buf[0..msg_length];
let result = match timeout(
Duration::from_millis(REPLY_TIMEOUT_MS),
parse_query(datastore_sender.clone(), msg_length, buf, capture_packets),
)
.await
{
Ok(reply) => reply,
Err(_) => {
log::error!("Did not receive response from parse_query within {REPLY_TIMEOUT_MS} ms");
return Ok(());
}
};
match result {
Ok(r) => {
log::debug!("TCP Result: {r:?}");
let r = match check_for_shutdown(&r, allowed_shutdown).await {
Err(reply) => match reply {
None => r,
Some(response) => response,
},
Ok(reply) => {
if let Err(error) = agent_tx.send(AgentState::Stopped {
agent: Agent::TCPServer,
}) {
eprintln!("Failed to send UDPServer shutdown message: {error:?}");
};
if let Err(error) = datastore_sender.send(Command::Shutdown).await {
eprintln!("Failed to send shutdown command to datastore.. {error:?}");
};
reply
}
};
let reply_bytes: Vec<u8> = match r.as_bytes().await {
Ok(value) => value,
Err(error) => {
log::error!("Failed to parse reply {:?} into bytes: {:?}", r, error);
return Ok(());
}
};
log::trace!("reply_bytes: {:?}", reply_bytes);
let reply_bytes = &reply_bytes as &[u8];
let response_length: u16 = reply_bytes.len() as u16;
let len = match writer.try_write(&response_length.to_be_bytes()) {
Ok(value) => value,
Err(err) => {
log::error!("Failed to send data back to {:?}: {:?}", addr, err);
return Ok(());
}
};
log::trace!("{:?} bytes sent", len);
let len = match writer.try_write(reply_bytes) {
Ok(value) => value,
Err(err) => {
log::error!("Failed to send data back to {:?}: {:?}", addr, err);
return Ok(());
}
};
log::trace!("{:?} bytes sent", len);
}
Err(error) => log::error!("Error: {}", error),
}
Ok(())
}
pub async fn tcp_server(
config: CowCellReadTxn<ConfigFile>,
tx: mpsc::Sender<crate::datastore::Command>,
agent_tx: broadcast::Sender<AgentState>,
) -> io::Result<()> {
let mut agent_rx = agent_tx.subscribe();
let tcpserver = match TcpListener::bind(
config
.dns_listener_address()
.expect("Failed to get DNS listener address on startup!"),
)
.await
{
Ok(value) => {
log::info!(
"Started TCP listener on {}",
config
.dns_listener_address()
.expect("Failed to get DNS listener address on startup!")
);
value
}
Err(error) => {
log::error!("Failed to start TCP Server: {:?}", error);
return Ok(());
}
};
let tcp_client_timeout = config.tcp_client_timeout;
let shutdown_ip_address_list = config.ip_allow_lists.shutdown.to_vec();
let capture_packets = config.capture_packets;
loop {
let (mut stream, addr) = match tcpserver.accept().await {
Ok(value) => value,
Err(error) => panic!("Couldn't get data from TcpStream: {:?}", error),
};
let allowed_shutdown = shutdown_ip_address_list.contains(&addr.ip());
log::debug!("TCP connection from {:?}", addr);
let loop_tx = tx.clone();
let loop_agent_tx = agent_tx.clone();
tokio::spawn(async move {
if timeout(
Duration::from_secs(tcp_client_timeout),
tcp_conn_handler(
&mut stream,
addr,
loop_tx,
loop_agent_tx,
capture_packets,
allowed_shutdown,
),
)
.await
.is_err()
{
log::warn!(
"TCP Connection from {addr:?} terminated after {} seconds.",
tcp_client_timeout
);
}
})
.await?;
if let Ok(agent_state) = agent_rx.try_recv() {
log::info!("Got agent state: {:?}", agent_state);
};
}
}
/// Parses a raw DNS query message and builds the reply for it.
///
/// `buf[..len]` is the wire-format message; any bytes after `len` are ignored.
/// When `capture_packets` is set, those request bytes are dumped as a
/// `ClientRequest` first. The header is unpacked from the first `HEADER_BYTES`
/// bytes, and the rest is handled by [`get_result`].
///
/// # Errors
/// Returns `Err` with a description in these cases:
/// * `len` exceeds `buf.len()`.
/// * The message is shorter than a DNS header. Checking this avoids a panic, and for
///   UDP it avoids reading a stale header left in the buffer by an earlier datagram.
/// * The header cannot be unpacked.
/// * [`get_result`] fails.
pub async fn parse_query(
    datastore: tokio::sync::mpsc::Sender<crate::datastore::Command>,
    len: usize,
    buf: &[u8],
    capture_packets: bool,
) -> Result<Reply, String> {
    let packet = match buf.get(..len) {
        Some(value) => value,
        None => {
            return Err(format!(
                "Packet length {len} exceeds buffer length {}",
                buf.len()
            ))
        }
    };
    if capture_packets {
        crate::packet_dumper::dump_bytes(
            packet.into(),
            crate::packet_dumper::DumpType::ClientRequest,
        )
        .await;
    }
    if len < HEADER_BYTES {
        return Err(format!(
            "Packet too short for a DNS header: got {len} bytes, need {HEADER_BYTES}"
        ));
    }
    let mut split_header: [u8; HEADER_BYTES] = [0; HEADER_BYTES];
    split_header.copy_from_slice(&packet[..HEADER_BYTES]);
    let header = match crate::Header::unpack(&split_header) {
        Ok(value) => value,
        Err(error) => {
            return Err(format!("Failed to parse header: {:?}", error));
        }
    };
    log::trace!("Buffer length: {}", len);
    log::trace!("Parsed header: {:?}", header);
    get_result(header, len, buf, datastore).await
}
lazy_static! {
    /// CHAOS-class TXT record "OK" with a zero TTL.
    static ref CHAOS_OK: InternalResourceRecord = InternalResourceRecord::TXT {
        class: RecordClass::Chaos,
        ttl: 0,
        txtdata: DNSCharString::from("OK"),
    };
    /// CHAOS-class TXT record "NO" with a zero TTL.
    static ref CHAOS_NO: InternalResourceRecord = InternalResourceRecord::TXT {
        class: RecordClass::Chaos,
        ttl: 0,
        txtdata: DNSCharString::from("NO"),
    };
}
/// Builds the reply for one DNS query message whose header is already parsed.
///
/// `buf[..len]` holds the whole message and `header` is its unpacked header. The
/// function proceeds as follows:
/// 1. A non-`Query` opcode is rejected with `Err`.
/// 2. The question is parsed from the bytes after the header. A parse failure gives
///    a `ServFail` reply; an unsupported qtype gives `NotImplemented`.
/// 3. A CHAOS-class question whose normalized name is `shutdown` is returned with
///    the request header and no answers, for the caller to handle.
/// 4. `ANY` questions are answered by [`reply_any`].
/// 5. Every other question is looked up with `Command::GetRecord` on `datastore`.
///    A `None` answer gives NXDOMAIN, and a dropped response channel gives `ServFail`.
///
/// # Errors
/// Returns `Err` in these cases:
/// * The opcode is not `Query`.
/// * `len` does not span a header plus body inside `buf`.
/// * `normalized_name` fails.
async fn get_result(
    header: Header,
    len: usize,
    buf: &[u8],
    datastore: mpsc::Sender<crate::datastore::Command>,
) -> Result<Reply, String> {
    log::trace!("called get_result(header={header}, len={len})");
    if header.opcode != OpCode::Query {
        return Err(format!("Invalid OPCODE, got {:?}", header.opcode));
    };

    // Slicing with `[HEADER_BYTES..len]` panics when a short datagram arrives
    // (len < HEADER_BYTES), which would take down the whole listener task.
    let question_bytes = match buf.get(HEADER_BYTES..len) {
        Some(value) => value,
        None => {
            return Err(format!(
                "Invalid message length {len}: need at least {HEADER_BYTES} bytes within a {} byte buffer",
                buf.len()
            ))
        }
    };
    let question = match Question::from_packets(question_bytes) {
        Ok(value) => {
            log::trace!("Parsed question: {:?}", value);
            value
        }
        Err(error) => {
            log::debug!("Failed to parse question: {} id={}", error, header.id);
            return reply_builder(header.id, Rcode::ServFail);
        }
    };

    if !question.qtype.supported() {
        log::debug!(
            "Unsupported request: {} {:?}, returning NotImplemented",
            from_utf8(&question.qname).unwrap_or("<unable to parse>"),
            question.qtype,
        );
        return reply_builder(header.id, Rcode::NotImplemented);
    }

    if question.qclass == RecordClass::Chaos && &question.normalized_name()? == "shutdown" {
        log::debug!("Got CHAOS shutdown!");
        return Ok(Reply {
            header,
            question: Some(question),
            answers: vec![],
            authorities: vec![],
            additional: vec![],
        });
    }

    if let RecordType::ANY {} = question.qtype {
        return reply_any(header.id, question);
    };

    let (tx_oneshot, rx_oneshot) = oneshot::channel();
    let ds_req: Command = Command::GetRecord {
        name: question.qname.clone(),
        rrtype: question.qtype,
        rclass: question.qclass,
        resp: tx_oneshot,
    };
    match datastore.send(ds_req).await {
        Ok(_) => log::trace!("Sent a request to the datastore!"),
        Err(error) => log::error!("Error sending to datastore: {:?}", error),
    };

    let record: ZoneRecord = match rx_oneshot.await {
        Ok(value) => match value {
            Some(zr) => {
                log::debug!("DS Response: {}", zr);
                zr
            }
            None => {
                log::debug!("No response from datastore");
                return reply_nxdomain(header.id);
            }
        },
        Err(error) => {
            log::error!("Failed to get response from datastore: {:?}", error);
            return reply_builder(header.id, Rcode::ServFail);
        }
    };

    Ok(Reply {
        header: Header {
            id: header.id,
            qr: PacketType::Answer,
            opcode: header.opcode,
            authoritative: true,
            truncated: false,
            recursion_desired: header.recursion_desired,
            recursion_available: header.recursion_available,
            z: false,
            ad: true,
            cd: false,
            rcode: Rcode::NoError,
            qdcount: 1,
            ancount: record.typerecords.len() as u16,
            nscount: 0,
            arcount: 0,
        },
        question: Some(question),
        answers: record.typerecords,
        authorities: vec![],
        additional: vec![],
    })
}
#[derive(Debug)]
pub struct Servers {
pub datastore: Option<JoinHandle<Result<(), String>>>,
pub udpserver: Option<JoinHandle<Result<(), Error>>>,
pub tcpserver: Option<JoinHandle<Result<(), Error>>>,
pub apiserver: Option<JoinHandle<Result<(), Error>>>,
pub agent_tx: broadcast::Sender<AgentState>,
}
impl Default for Servers {
fn default() -> Self {
let (agent_tx, _) = broadcast::channel(10000);
Self {
datastore: None,
udpserver: None,
tcpserver: None,
apiserver: None,
agent_tx,
}
}
}
impl Servers {
    /// Creates a `Servers` with no task handles that reports on `agent_tx`.
    pub fn build(agent_tx: broadcast::Sender<AgentState>) -> Self {
        Self {
            agent_tx,
            ..Default::default()
        }
    }

    /// Sets (or clears, with `None`) the API server task handle.
    pub fn with_apiserver(self, apiserver: Option<JoinHandle<Result<(), Error>>>) -> Self {
        Self { apiserver, ..self }
    }

    /// Sets the datastore task handle.
    pub fn with_datastore(self, datastore: JoinHandle<Result<(), String>>) -> Self {
        Self {
            datastore: Some(datastore),
            ..self
        }
    }

    /// Sets the TCP listener task handle.
    pub fn with_tcpserver(self, tcpserver: JoinHandle<Result<(), Error>>) -> Self {
        Self {
            tcpserver: Some(tcpserver),
            ..self
        }
    }

    /// Sets the UDP listener task handle.
    pub fn with_udpserver(self, udpserver: JoinHandle<Result<(), Error>>) -> Self {
        Self {
            udpserver: Some(udpserver),
            ..self
        }
    }

    /// Logs and broadcasts `AgentState::Stopped` for `agent`. A send failure (for
    /// example, no live receivers) is reported on stderr and otherwise ignored.
    fn send_shutdown(&self, agent: Agent) {
        log::info!("{agent:?} shut down");
        if let Err(error) = self.agent_tx.send(AgentState::Stopped { agent }) {
            eprintln!("Failed to send agent shutdown message: {error:?}");
        };
    }

    /// Returns `true` if ANY registered task has finished. Despite the name, this is
    /// not "all"; `false` is also returned when no handles are registered.
    ///
    /// For every finished task, a notice is printed and `AgentState::Stopped` is
    /// broadcast. This happens on each call, so repeated polling repeats the notices.
    pub fn all_finished(&self) -> bool {
        // Read each handle's state exactly once, so the shutdown notice and the
        // returned status cannot disagree if a task finishes mid-call.
        let statuses = [
            (
                self.apiserver.as_ref().map(|h| h.is_finished()),
                Agent::API,
                "Sending API Shutdown",
            ),
            (
                self.datastore.as_ref().map(|h| h.is_finished()),
                Agent::Datastore,
                "Sending Datastore Shutdown",
            ),
            (
                self.tcpserver.as_ref().map(|h| h.is_finished()),
                Agent::TCPServer,
                "Sending TCP Server Shutdown",
            ),
            (
                self.udpserver.as_ref().map(|h| h.is_finished()),
                Agent::UDPServer,
                "Sending UDP Server Shutdown",
            ),
        ];

        let mut any_finished = false;
        for (finished, agent, message) in statuses {
            if finished == Some(true) {
                println!("{message}");
                self.send_shutdown(agent);
                any_finished = true;
            }
        }
        any_finished
    }
}