use std::{path::PathBuf, sync::Arc};

use dashmap::DashMap;
use http_body_util::{BodyExt as _, Full};
use hyper::{
    Method, Request, Response, StatusCode,
    body::{Bytes, Incoming},
    server::conn::http1,
    service::service_fn,
};
use hyper_util::rt::{TokioIo, TokioTimer};
use tokio::net::UnixListener;
use tracing_subscriber::{layer::SubscriberExt as _, util::SubscriberInitExt as _};

/// Shared handle to the in-memory key/value store.
///
/// Cloning only bumps the `Arc` reference count, so every connection and
/// request handler operates on the same underlying map.
#[derive(Clone)]
struct Data {
    // Keys come from the `key` query parameter; values are raw POST bodies.
    map: Arc<DashMap<String, Vec<u8>, ahash::RandomState>>,
}

/// Runs the key/value server on the UNIX domain socket `/tmp/afterimage.sock`,
/// serving each accepted connection with HTTP/1.1 on its own spawned task.
///
/// # Panics
///
/// Panics if the socket's parent directory cannot be created or the socket
/// cannot be bound.
#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
                concat!(env!("CARGO_CRATE_NAME"), "=debug,tower_http=warn,axum=warn").into()
            }),
        )
        .with(tracing_subscriber::fmt::layer().without_time())
        .init();

    let data = Data {
        map: Arc::new(DashMap::with_hasher(ahash::RandomState::new())),
    };

    let path = PathBuf::from("/tmp/afterimage.sock");

    // Remove a socket file left behind by a previous run so `bind` can succeed.
    // A missing file is the normal case; anything else (e.g. permissions) is
    // worth surfacing because `bind` is about to fail because of it.
    if let Err(err) = tokio::fs::remove_file(&path).await {
        if err.kind() != std::io::ErrorKind::NotFound {
            tracing::warn!(err=?err, "failed to remove stale unix socket");
        }
    }
    if let Some(dir) = path.parent() {
        tokio::fs::create_dir_all(dir)
            .await
            .expect("Failed to create socket directory");
    }

    let uds = UnixListener::bind(&path).expect("Failed to bind UNIX socket");

    loop {
        let (stream, _) = match uds.accept().await {
            Ok(pair) => pair,
            Err(err) => {
                tracing::warn!(err=?err, "failed to accept unix socket connection");

                // Some accept errors (e.g. running out of file descriptors) persist
                // until another connection closes; back off briefly rather than
                // retrying in a tight loop.
                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                continue;
            }
        };

        let io = TokioIo::new(stream);

        // Each request gets its own clone of the shared handle.
        let service = service_fn({
            let data = data.clone();

            move |req| {
                let data = data.clone();

                handle(req, data)
            }
        });

        tokio::task::spawn(async move {
            if let Err(err) = http1::Builder::new()
                .timer(TokioTimer::new())
                .serve_connection(io, service)
                .await
            {
                tracing::error!("Error serving connection: {:?}", err);
            }
        });
    }
}

#[inline]
async fn handle(req: Request<Incoming>, data: Data) -> Result<Response<Full<Bytes>>, hyper::Error> {
    let method = req.method();
    let Some(query) = req.uri().query() else {
        return Ok(response(StatusCode::BAD_REQUEST, empty()));
    };
    let Some(key) = parse_query(query) else {
        return Ok(response(StatusCode::BAD_REQUEST, empty()));
    };

    let res = match *method {
        Method::GET => data.map.get(&key).map_or_else(
            || response(StatusCode::NOT_FOUND, empty()),
            |bytes| response(StatusCode::OK, full(&bytes[..])),
        ),
        Method::POST => {
            let bytes = req.collect().await?.to_bytes();

            data.map.insert(key, bytes.to_vec());

            response(StatusCode::OK, full(&bytes[..]))
        }
        _ => response(StatusCode::BAD_REQUEST, empty()),
    };

    Ok(res)
}

/// Extracts the value of the `key` parameter from a URL query string.
///
/// Returns `None` when `key` is absent or appears more than once. Other
/// parameters are ignored.
#[inline]
fn parse_query(query: &str) -> Option<String> {
    #[derive(serde::Deserialize)]
    struct Key {
        key: String,
    }

    // `serde_urlencoded::from_str` drives the same `form_urlencoded`-backed
    // deserializer the original built by hand. The former
    // `serde_path_to_error` wrapper only enriched the error, and the error is
    // discarded here anyway.
    serde_urlencoded::from_str::<Key>(query)
        .ok()
        .map(|parsed| parsed.key)
}

/// Wraps `body` in a response with the given `status`. All other parts of the
/// response keep their defaults.
#[inline]
fn response(status: StatusCode, body: Full<Bytes>) -> Response<Full<Bytes>> {
    let mut res = Response::new(body);
    *res.status_mut() = status;
    res
}

/// A zero-length response body.
#[inline]
fn empty() -> Full<Bytes> {
    Full::new(Bytes::new())
}

/// Builds a response body that owns its own copy of `bytes`.
#[inline]
fn full(bytes: &[u8]) -> Full<Bytes> {
    let owned = Bytes::copy_from_slice(bytes);
    Full::new(owned)
}
