redvault-ai/llama_forge_rs/src/server/middleware.rs
2025-02-20 02:13:21 +01:00

85 lines
2.1 KiB
Rust

use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use axum::{body::Body, http::Request};
use pin_project_lite::pin_project;
use tower::{Layer, Service};
use tracing::Span;
use uuid::Uuid; // Make sure to include `uuid` crate in your Cargo.toml
/// Tower [`Layer`] that wraps a service in [`LoggingService`], giving every request a
/// UUIDv7 correlation id and a debug-level tracing span.
#[derive(Clone, Debug, Default)]
pub struct LoggingLayer;
/// Wraps any inner service `S` in a [`LoggingService`].
impl<S> Layer<S> for LoggingLayer {
    type Service = LoggingService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        // The layer carries no configuration, so wrapping is just moving `inner` in.
        LoggingService::<S> { inner }
    }
}
/// Service produced by [`LoggingLayer`]: forwards every request to `inner` and emits
/// debug-level tracing output around it.
#[derive(Debug, Clone)]
pub struct LoggingService<T> {
    // The wrapped service that actually handles the requests.
    inner: T,
}
/// Tags each request with a fresh UUIDv7 and a `request` debug span (carrying the id, HTTP
/// method and URI), then returns a [`LoggingServiceFuture`] that logs on completion.
impl<T> Service<Request<Body>> for LoggingService<T>
where
    T: Service<Request<Body>>,
{
    type Error = T::Error;
    type Future = LoggingServiceFuture<T::Future>;
    type Response = T::Response;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Readiness is delegated unchanged; this layer adds no back-pressure of its own.
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        // Per-request correlation id (UUID version 7).
        let request_uuid = Uuid::now_v7();
        let span =
            tracing::debug_span!("request", ?request_uuid, method = ?req.method(), uri = ?req.uri());

        // Enter the span while logging and while invoking the inner service, so the
        // "request started" event and any synchronous work done by `inner.call` are
        // attributed to this request. The guard is released before this function returns.
        let inner = span.in_scope(|| {
            tracing::debug!(msg = "request started", uuid = ?request_uuid);
            self.inner.call(req)
        });

        LoggingServiceFuture {
            inner,
            uuid: request_uuid,
            span,
        }
    }
}
pin_project! {
    /// Future returned by [`LoggingService`]: resolves to the inner service's output and
    /// emits a "request finished" debug event when it does.
    #[derive(Clone, Debug)]
    pub struct LoggingServiceFuture<T> {
        #[pin]
        inner: T,
        // Correlation id of this request; `Uuid` is `Copy`, so no `Arc` is needed.
        uuid: Uuid,
        // Span created for this request; `Span` is already a cheaply clonable handle.
        span: Span,
    }
}
impl<T> Future for LoggingServiceFuture<T>
where
    T: Future,
{
    type Output = T::Output;

    /// Polls the inner future inside the request span and logs "request finished" once it
    /// resolves; the inner output is passed through unchanged.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        // Enter the span for the duration of this poll so that events emitted by the inner
        // future, and the completion event below, are attributed to this request. The guard
        // lives only within this synchronous call, never across a suspension point.
        let _entered = this.span.enter();
        match this.inner.poll(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(output) => {
                tracing::debug!(msg = "request finished", uuid = ?this.uuid);
                Poll::Ready(output)
            }
        }
    }
}