server/src/main.rs@main
raw
1use std::{path::PathBuf, sync::Arc};
2
3use dashmap::DashMap;
4use http_body_util::{BodyExt as _, Full};
5use hyper::{
6 Method, Request, Response, StatusCode,
7 body::{Bytes, Incoming},
8 server::conn::http1,
9 service::service_fn,
10};
11use hyper_util::rt::{TokioIo, TokioTimer};
12use tokio::net::UnixListener;
13use tracing_subscriber::{layer::SubscriberExt as _, util::SubscriberInitExt as _};
14
15#[derive(Clone)]
16struct Data {
17 map: Arc<DashMap<String, Vec<u8>, ahash::RandomState>>,
18}
19
20#[tokio::main]
21async fn main() {
22 tracing_subscriber::registry()
23 .with(
24 tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
25 concat!(env!("CARGO_CRATE_NAME"), "=debug,tower_http=warn,axum=warn").into()
26 }),
27 )
28 .with(tracing_subscriber::fmt::layer().without_time())
29 .init();
30
31 let data = Data {
32 map: Arc::new(DashMap::with_hasher(ahash::RandomState::new())),
33 };
34
35 let path = PathBuf::from("/tmp/afterimage.sock");
36 let _ = tokio::fs::remove_file(&path).await;
37 tokio::fs::create_dir_all(path.parent().unwrap())
38 .await
39 .unwrap();
40
41 let uds = UnixListener::bind(path).expect("Failed to bind UNIX socket");
42
43 loop {
44 let (stream, _) = match uds.accept().await {
45 Ok(pair) => pair,
46 Err(err) => {
47 tracing::warn!(err=?err, "failed to accept unix socket connection");
48
49 continue;
50 }
51 };
52
53 let io = TokioIo::new(stream);
54
55 let service = service_fn({
56 let data = data.clone();
57
58 move |req| {
59 let data = data.clone();
60
61 handle(req, data)
62 }
63 });
64
65 tokio::task::spawn(async move {
66 if let Err(err) = http1::Builder::new()
67 .timer(TokioTimer::new())
68 .serve_connection(io, service)
69 .await
70 {
71 tracing::error!("Error serving connection: {:?}", err);
72 }
73 });
74 }
75}
76
77#[inline]
78async fn handle(req: Request<Incoming>, data: Data) -> Result<Response<Full<Bytes>>, hyper::Error> {
79 let method = req.method();
80 let Some(query) = req.uri().query() else {
81 return Ok(response(StatusCode::BAD_REQUEST, empty()));
82 };
83 let Some(key) = parse_query(query) else {
84 return Ok(response(StatusCode::BAD_REQUEST, empty()));
85 };
86
87 let res = match *method {
88 Method::GET => data.map.get(&key).map_or_else(
89 || response(StatusCode::NOT_FOUND, empty()),
90 |bytes| response(StatusCode::OK, full(&bytes[..])),
91 ),
92 Method::POST => {
93 let bytes = req.collect().await?.to_bytes();
94
95 data.map.insert(key, bytes.to_vec());
96
97 response(StatusCode::OK, full(&bytes[..]))
98 }
99 _ => response(StatusCode::BAD_REQUEST, empty()),
100 };
101
102 Ok(res)
103}
104
105#[inline]
106fn parse_query(query: &str) -> Option<String> {
107 #[derive(serde::Deserialize)]
108 struct Key {
109 key: String,
110 }
111
112 serde_path_to_error::deserialize(serde_urlencoded::Deserializer::new(form_urlencoded::parse(
113 query.as_bytes(),
114 )))
115 .ok()
116 .map(|key: Key| key.key)
117}
118
119#[inline]
120fn response(status: StatusCode, body: Full<Bytes>) -> Response<Full<Bytes>> {
121 Response::builder().status(status).body(body).unwrap()
122}
123
124#[inline]
125fn empty() -> Full<Bytes> {
126 Full::new(Bytes::from_static(&[]))
127}
128
129#[inline]
130fn full(bytes: &[u8]) -> Full<Bytes> {
131 Full::new(Bytes::from(bytes.to_vec()))
132}
133