mod logging;

use anyhow::anyhow;
use axum::{
    body::Body,
    http::{self, Request, Response},
    response::IntoResponse,
    routing::any,
    Router,
};
use futures;
use itertools::Itertools;
use reqwest::Client;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use serde::Deserialize;
use std::{
    collections::HashMap,
    net::SocketAddr,
    process::Stdio,
    sync::{Arc, Once},
};
use tokio::{
    net::TcpStream,
    process::{Child, Command},
    sync::Mutex,
    time::{sleep, Duration},
};
use tower_http;
use tracing::Level;

pub static INIT: Once = Once::new();

pub fn initialize_logger() {
    INIT.call_once(|| {
        let env_filter = tracing_subscriber::EnvFilter::builder()
            .with_default_directive(tracing::level_filters::LevelFilter::INFO.into())
            .from_env_lossy();
        tracing_subscriber::fmt()
            .compact()
            .with_env_filter(env_filter)
            .init();
    });
}

// TODO: Add validation to the config,
// e.g. check for ports that are assigned more than once.
#[derive(Clone, Debug, Deserialize)]
struct Config {
    hardware: Hardware,
    models: Vec<ModelConfig>,
}

impl Config {
    fn default_from_pwd_yml() -> Self {
        // Read and parse the YAML configuration from the working directory.
        let config_str =
            std::fs::read_to_string("config.yaml").expect("Failed to read config.yaml");

        serde_yaml::from_str::<Config>(&config_str)
            .expect("Failed to parse config.yaml")
            .pick_open_ports()
    }

    // TODO: split into a raw deserialized config and a "processed" config that always has a port.
    // FIXME: maybe just always initialize with a port?
    fn pick_open_ports(self) -> Self {
        let mut config = self.clone();
        for model in &mut config.models {
            if model.internal_port.is_none() {
                model.internal_port = Some(
                    openport::pick_random_unused_port()
                        .unwrap_or_else(|| panic!("No open port found for {:?}", model)),
                );
            }
        }
        config
    }
}

#[derive(Clone, Debug, Deserialize)]
struct Hardware {
    ram: String,
    vram: String,
}

#[derive(Clone, Debug, Deserialize)]
struct ModelConfig {
    #[allow(dead_code)]
    name: String,
    port: u16,
    internal_port: Option<u16>,
    env: HashMap<String, String>,
    args: HashMap<String, String>,
    vram_usage: String,
    ram_usage: String,
}
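// Illustrative `config.yaml` for the structures above. The model entry, paths,
// and sizes are made-up placeholders, not measured or recommended values; note
// that `args` values must be YAML strings, and `internal_port` may be omitted:
//
// hardware:
//   ram: 48G
//   vram: 24G
// models:
//   - name: example-7b
//     port: 8080
//     env:
//       CUDA_VISIBLE_DEVICES: "0"
//     args:
//       model: /models/example-7b.gguf
//       ctx-size: "4096"
//       flash-attn: "true"
//     ram_usage: 4G
//     vram_usage: 6G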
#[derive(Clone, Debug)]
struct LlamaInstance {
    config: ModelConfig,
    process: Arc<Mutex<Child>>,
    // busy: bool,
}

impl LlamaInstance {
    /// Retrieve a running instance from state or spawn a new one.
    pub async fn get_or_spawn_instance(
        model_config: &ModelConfig,
        state: SharedState,
    ) -> Result<LlamaInstance, AppError> {
        let model_ram_usage =
            parse_size(&model_config.ram_usage).expect("Invalid ram_usage in model config");
        let model_vram_usage =
            parse_size(&model_config.vram_usage).expect("Invalid vram_usage in model config");

        {
            // First, check if we already have an instance.
            let state_guard = state.lock().await;
            if let Some(instance) = state_guard.instances.get(&model_config.port) {
                return Ok(instance.clone());
            }
        }

        let mut state_guard = state.lock().await;
        tracing::info!(msg = "Shared state before spawn", ?state_guard);

        // If there are not enough free resources, stop running instances,
        // largest VRAM consumers first, until the new model fits.
        if (state_guard.used_ram + model_ram_usage > state_guard.total_ram)
            || (state_guard.used_vram + model_vram_usage > state_guard.total_vram)
        {
            let mut to_remove = Vec::new();
            let instances_by_size = state_guard
                .instances
                .iter()
                .sorted_by(|(_, a), (_, b)| {
                    parse_size(&b.config.vram_usage)
                        .unwrap_or(0)
                        .cmp(&parse_size(&a.config.vram_usage).unwrap_or(0))
                })
                .map(|(port, instance)| (*port, instance.clone()))
                .collect::<Vec<_>>();

            for (port, instance) in instances_by_size.iter() {
                tracing::info!("Stopping instance on port {}", port);
                let mut proc = instance.process.lock().await;
                proc.kill().await.ok();
                to_remove.push(*port);
                state_guard.used_ram = state_guard
                    .used_ram
                    .saturating_sub(parse_size(&instance.config.ram_usage).unwrap_or(0));
                state_guard.used_vram = state_guard
                    .used_vram
                    .saturating_sub(parse_size(&instance.config.vram_usage).unwrap_or(0));

                if state_guard.used_ram + model_ram_usage <= state_guard.total_ram
                    && state_guard.used_vram + model_vram_usage <= state_guard.total_vram
                {
                    tracing::info!("Freed enough resources");
                    break;
                }
            }

            for port in to_remove {
                state_guard.instances.remove(&port);
            }
        } else {
            tracing::info!("Sufficient resources available");
        }

        // Spawn a new instance and account for its resources.
        let instance = Self::spawn(model_config).await?;
        state_guard.used_ram += model_ram_usage;
        state_guard.used_vram += model_vram_usage;
        state_guard
            .instances
            .insert(model_config.port, instance.clone());

        sleep(Duration::from_millis(250)).await;
        instance.running_and_port_ready().await?;
        sleep(Duration::from_millis(250)).await;

        Ok(instance)
    }

    /// Spawn a new llama-server process based on the model configuration.
    /// An `args` value of "true" becomes a bare flag (`--foo`); anything else
    /// becomes a flag with an argument (`--foo value`).
    pub async fn spawn(model_config: &ModelConfig) -> Result<LlamaInstance, AppError> {
        let args = model_config
            .args
            .iter()
            .flat_map(|(k, v)| {
                if v == "true" {
                    vec![format!("--{}", k)]
                } else {
                    vec![format!("--{}", k), v.clone()]
                }
            })
            .collect::<Vec<_>>();

        let internal_port = model_config
            .internal_port
            .expect("Internal port must be set");

        let mut cmd = Command::new("llama-server");
        cmd.kill_on_drop(true)
            .envs(model_config.env.clone())
            .args(&args)
            .arg("--port")
            .arg(format!("{}", internal_port))
            .stdout(Stdio::null())
            .stderr(Stdio::null()); // Silence output; could be captured later via an API.

        tracing::info!("Starting llama-server with command: {:?}", cmd);
        let child = cmd.spawn().expect("Failed to start llama-server");

        Ok(LlamaInstance {
            config: model_config.clone(),
            process: Arc::new(Mutex::new(child)),
        })
    }

    pub async fn running_and_port_ready(&self) -> Result<(), anyhow::Error> {
        async fn is_running(instance: &LlamaInstance) -> Result<(), AppError> {
            let mut proc = instance.process.clone().lock_owned().await;
            match proc.try_wait()? {
                Some(exit_status) => {
                    tracing::error!("Llama instance exited: {:?}", exit_status);
                    Err(AppError::Unknown(anyhow!("Llama instance exited")))
                }
                None => Ok(()),
            }
        }

        async fn wait_for_port(port: u16) -> Result<(), anyhow::Error> {
            for _ in 0..10 {
                if TcpStream::connect(("127.0.0.1", port)).await.is_ok() {
                    return Ok(());
                }
                sleep(Duration::from_millis(500)).await;
            }
            Err(anyhow!("Timeout waiting for port"))
        }

        is_running(self).await?;
        wait_for_port(self.config.internal_port.expect("No port picked?")).await?;
        Ok(())
    }
}
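/// Book-keeping behind the proxy: the configured totals from [`Hardware`], the
/// RAM/VRAM currently attributed to running instances, and the instances
/// themselves, keyed by their public port. Eviction in
/// [`LlamaInstance::get_or_spawn_instance`] frees the largest VRAM consumers
/// first until the requested model fits.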
#[derive(Clone, Debug)]
pub struct InnerSharedState {
    total_ram: u64,
    total_vram: u64,
    used_ram: u64,
    used_vram: u64,
    instances: HashMap<u16, LlamaInstance>,
}

type SharedStateArc = Arc<Mutex<InnerSharedState>>;

/// TODO: migrate to dashmap + individual Arc<Mutex<T>> or RwLocks
#[derive(Clone, Debug, derive_more::Deref)]
pub struct SharedState(SharedStateArc);

impl SharedState {
    fn from_config(config: Config) -> Self {
        // Parse hardware resources.
        let total_ram = parse_size(&config.hardware.ram).expect("Invalid RAM size in config");
        let total_vram = parse_size(&config.hardware.vram).expect("Invalid VRAM size in config");

        // Initialize shared state.
        let shared_state = InnerSharedState {
            total_ram,
            total_vram,
            used_ram: 0,
            used_vram: 0,
            instances: HashMap::new(),
        };

        Self(Arc::new(Mutex::new(shared_state)))
    }
}

/// Build an Axum app for a given model.
fn create_app(model_config: &ModelConfig, state: SharedState) -> Router {
    Router::new()
        .route(
            "/",
            any({
                let state = state.clone();
                let model_config = model_config.clone();
                move |req| handle_request(req, model_config.clone(), state.clone())
            }),
        )
        .route(
            "/*path",
            any({
                let state = state.clone();
                let model_config = model_config.clone();
                move |req| handle_request(req, model_config.clone(), state.clone())
            }),
        )
        .layer(
            tower_http::trace::TraceLayer::new_for_http()
                .make_span_with(tower_http::trace::DefaultMakeSpan::new().include_headers(true))
                .on_request(tower_http::trace::DefaultOnRequest::new().level(Level::DEBUG))
                .on_response(tower_http::trace::DefaultOnResponse::new().level(Level::TRACE))
                .on_eos(tower_http::trace::DefaultOnEos::new().level(Level::DEBUG))
                .on_failure(tower_http::trace::DefaultOnFailure::new().level(Level::ERROR)),
        )
}

#[tokio::main]
async fn main() {
    initialize_logger();

    let config = Config::default_from_pwd_yml();
    let shared_state = SharedState::from_config(config.clone());

    // For each model, set up an axum server listening on the specified port.
    let mut handles = Vec::new();
    for model_config in config.models {
        let state = shared_state.clone();
        let handle = tokio::spawn(async move {
            let app = create_app(&model_config, state);
            let addr = SocketAddr::from(([0, 0, 0, 0], model_config.port));
            tracing::info!(msg = "Listening", ?model_config);
            let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
            axum::serve(listener, app.into_make_service())
                .await
                .unwrap();
        });
        handles.push(handle);
    }

    futures::future::join_all(handles).await;
}

use thiserror::Error;

#[derive(Error, Debug)]
pub enum AppError {
    #[error("Axum Error")]
    AxumError(#[from] axum::Error),
    #[error("Axum Http Error")]
    AxumHttpError(#[from] axum::http::Error),
    #[error("Reqwest Error")]
    ReqwestError(#[from] reqwest::Error),
    #[error("Reqwest Middleware Error")]
    ReqwestMiddlewareError(#[from] reqwest_middleware::Error),
    #[error("Client Error")]
    ClientError(#[from] hyper::Error),
    #[error("Io Error")]
    IoError(#[from] std::io::Error),
    #[error("Unknown error")]
    Unknown(#[from] anyhow::Error),
}

// Tell axum how to convert `AppError` into a response.
impl IntoResponse for AppError {
    fn into_response(self) -> axum::response::Response {
        tracing::error!("::AppError::into_response:: {:?}", self);
        (
            http::StatusCode::INTERNAL_SERVER_ERROR,
            format!("Something went wrong: {:?}", self),
        )
            .into_response()
    }
}

async fn proxy_request(
    req: Request<Body>,
    model_config: &ModelConfig,
) -> Result<Response<Body>, AppError> {
    let retry_policy = ExponentialBackoff::builder()
        .retry_bounds(
            std::time::Duration::from_millis(500),
            std::time::Duration::from_secs(8),
        )
        .jitter(reqwest_retry::Jitter::None)
        .base(2)
        .build_with_max_retries(8);
    let client: ClientWithMiddleware = ClientBuilder::new(Client::new())
        .with(RetryTransientMiddleware::new_with_policy(retry_policy))
        .build();

    let internal_port = model_config
        .internal_port
        .expect("Internal port must be set");
    let uri = format!(
        "http://127.0.0.1:{}{}",
        internal_port,
        req.uri()
            .path_and_query()
            .map(|pq| pq.as_str())
            .unwrap_or("")
    );

    let mut request_builder = client.request(req.method().clone(), &uri);
    // Forward all headers.
    for (name, value) in req.headers() {
        request_builder = request_builder.header(name, value);
    }

    // Buffer the request body (capped at 1 GiB) so the middleware can retry it.
    let max_size = parse_size("1G").unwrap() as usize;
    let body_bytes = axum::body::to_bytes(req.into_body(), max_size).await?;
    let request = request_builder.body(body_bytes).build()?;

    let response = client.execute(request).await?;

    let mut builder = Response::builder().status(response.status());
    for (name, value) in response.headers() {
        builder = builder.header(name, value);
    }

    // Stream the upstream response body back to the client.
    let byte_stream = response.bytes_stream();
    let body = Body::from_stream(byte_stream);
    Ok(builder.body(body)?)
}

async fn handle_request(
    req: Request<Body>,
    model_config: ModelConfig,
    state: SharedState,
) -> impl IntoResponse {
    let _instance = LlamaInstance::get_or_spawn_instance(&model_config, state).await?;
    let response = proxy_request(req, &model_config).await?;
    Ok::<Response<Body>, AppError>(response)
}

/// Parse a human-readable size such as "4G", "512mb" or "1.5GB" into bytes.
fn parse_size(size_str: &str) -> Option<u64> {
    let mut num = String::new();
    let mut unit = String::new();
    for c in size_str.chars() {
        if c.is_ascii_digit() || c == '.' {
            num.push(c);
        } else {
            unit.push(c);
        }
    }
    let num: f64 = num.parse().ok()?;
    let multiplier: u64 = match unit.to_lowercase().as_str() {
        "g" | "gb" => 1024 * 1024 * 1024,
        "m" | "mb" => 1024 * 1024,
        "k" | "kb" => 1024,
        // Unknown unit: report failure instead of panicking, matching the
        // Option return type.
        _ => return None,
    };
    Some((num * multiplier as f64) as u64)
}
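// A minimal sanity-check module for `parse_size`, added as an illustrative
// sketch; expected values follow from the binary (1024-based) multipliers above.
#[cfg(test)]
mod tests {
    use super::parse_size;

    #[test]
    fn parses_common_sizes() {
        assert_eq!(parse_size("1k"), Some(1024));
        assert_eq!(parse_size("1.5KB"), Some(1536));
        assert_eq!(parse_size("2M"), Some(2 * 1024 * 1024));
        assert_eq!(parse_size("1G"), Some(1024 * 1024 * 1024));
    }

    #[test]
    fn rejects_malformed_input() {
        assert_eq!(parse_size("abc"), None); // no digits at all
        assert_eq!(parse_size("10x"), None); // unknown unit
    }
}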