From e3a3ec38267351976547b9cd1456605d8488ea17 Mon Sep 17 00:00:00 2001
From: Tristan Druyen
Date: Mon, 10 Feb 2025 23:22:31 +0100
Subject: [PATCH] Refactor llama_proxy_man

---
 Cargo.lock                  |  25 +-
 llama_proxy_man/Cargo.toml  |   1 +
 llama_proxy_man/src/main.rs | 469 +++++++++++++++++++-----------------
 3 files changed, 277 insertions(+), 218 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 2af41fc..ac5cac1 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1141,6 +1141,26 @@ dependencies = [
  "syn 2.0.98",
 ]
 
+[[package]]
+name = "derive_more"
+version = "2.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678"
+dependencies = [
+ "derive_more-impl",
+]
+
+[[package]]
+name = "derive_more-impl"
+version = "2.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.98",
+]
+
 [[package]]
 name = "digest"
 version = "0.10.7"
@@ -2950,7 +2970,7 @@ dependencies = [
  "chrono",
  "console_error_panic_hook",
  "dashmap",
- "derive_more",
+ "derive_more 0.99.19",
  "futures",
  "futures-util",
  "gloo-net 0.5.0",
@@ -3002,6 +3022,7 @@ version = "0.1.1"
 dependencies = [
  "anyhow",
  "axum",
+ "derive_more 2.0.1",
  "futures",
  "hyper",
  "itertools 0.13.0",
@@ -4781,7 +4802,7 @@ checksum = "df320f1889ac4ba6bc0cdc9c9af7af4bd64bb927bccdf32d81140dc1f9be12fe"
 dependencies = [
  "bitflags 1.3.2",
  "cssparser",
- "derive_more",
+ "derive_more 0.99.19",
  "fxhash",
  "log",
  "matches",
diff --git a/llama_proxy_man/Cargo.toml b/llama_proxy_man/Cargo.toml
index f1e8ecf..5d9df28 100644
--- a/llama_proxy_man/Cargo.toml
+++ b/llama_proxy_man/Cargo.toml
@@ -29,3 +29,4 @@ reqwest-retry = "0.6.1"
 reqwest-middleware = { version = "0.3.3", features = ["charset", "http2", "json", "multipart", "rustls-tls"] }
 itertools = "0.13.0"
 openport = { version = "0.1.1", features = ["rand"] }
+derive_more = { version = "2.0.1", features = ["deref"] }
diff --git a/llama_proxy_man/src/main.rs b/llama_proxy_man/src/main.rs
index c37820b..533ab03 100644
--- a/llama_proxy_man/src/main.rs
+++ b/llama_proxy_man/src/main.rs
@@ -11,6 +11,8 @@ use axum::{
 use futures;
 use itertools::Itertools;
 use reqwest::Client;
+use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
+use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
 use serde::Deserialize;
 use std::{collections::HashMap, net::SocketAddr, process::Stdio, sync::Arc};
 use tokio::{
@@ -19,10 +21,7 @@ use tokio::{
     sync::Mutex,
     time::{sleep, Duration},
 };
-use tower_http::trace::{
-    DefaultMakeSpan, DefaultOnEos, DefaultOnFailure, DefaultOnRequest, DefaultOnResponse,
-    TraceLayer,
-};
+use tower_http;
 use tracing::Level;
 
 use std::sync::Once;
@@ -42,6 +41,8 @@ pub fn initialize_logger() {
     });
 }
 
+// TODO Add validation to the config
+//      - e.g. reject configs where two models claim the same port
 #[derive(Clone, Debug, Deserialize)]
 struct Config {
     hardware: Hardware,
@@ -49,7 +50,18 @@ struct Config {
 }
 
 impl Config {
+    fn default_from_pwd_yml() -> Self {
+        // Read and parse the YAML configuration
+        let config_str =
+            std::fs::read_to_string("config.yaml").expect("Failed to read config.yaml");
+
+        serde_yaml::from_str::<Config>(&config_str)
+            .expect("Failed to parse config.yaml")
+            .pick_open_ports()
+    }
+
+    // TODO split up into raw deser config and "parsed"/"processed" config which always has a port
+    // FIXME maybe just always initialize with a port?
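+    // For reference, Config/Hardware/ModelConfig imply a config.yaml of roughly
+    // this shape; field names are taken from the structs, values are invented:
+    //
+    //   hardware:
+    //     ram: 48G
+    //     vram: 24G
+    //   models:
+    //     - port: 8080
+    //       ram_usage: 8G
+    //       vram_usage: 6G
+    //       env: {}
+    //       args:
+    //         ctx-size: "4096"
+    //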
     fn pick_open_ports(self) -> Self {
         let mut config = self.clone();
         for model in &mut config.models {
@@ -89,8 +101,154 @@ struct LlamaInstance {
     // busy: bool,
 }
 
+impl LlamaInstance {
+    /// Retrieve a running instance from state or spawn a new one.
+    pub async fn get_or_spawn_instance(
+        model_config: &ModelConfig,
+        state: SharedState,
+    ) -> Result<Self, AppError> {
+        let model_ram_usage =
+            parse_size(&model_config.ram_usage).expect("Invalid ram_usage in model config");
+        let model_vram_usage =
+            parse_size(&model_config.vram_usage).expect("Invalid vram_usage in model config");
+
+        {
+            // First, check if we already have an instance.
+            let state_guard = state.lock().await;
+            if let Some(instance) = state_guard.instances.get(&model_config.port) {
+                return Ok(instance.clone());
+            }
+        }
+
+        let mut state_guard = state.lock().await;
+        tracing::info!(msg = "Shared state before spawn", ?state_guard);
+
+        // If not enough free resources, stop some running instances.
+        if (state_guard.used_ram + model_ram_usage > state_guard.total_ram)
+            || (state_guard.used_vram + model_vram_usage > state_guard.total_vram)
+        {
+            let mut to_remove = Vec::new();
+            let instances_by_size = state_guard
+                .instances
+                .iter()
+                .sorted_by(|(_, a), (_, b)| {
+                    parse_size(&b.config.vram_usage)
+                        .unwrap_or(0)
+                        .cmp(&parse_size(&a.config.vram_usage).unwrap_or(0))
+                })
+                .map(|(port, instance)| (*port, instance.clone()))
+                .collect::<Vec<_>>();
+
+            for (port, instance) in instances_by_size.iter() {
+                tracing::info!("Stopping instance on port {}", port);
+                let mut proc = instance.process.lock().await;
+                proc.kill().await.ok();
+                to_remove.push(*port);
+                state_guard.used_ram = state_guard
+                    .used_ram
+                    .saturating_sub(parse_size(&instance.config.ram_usage).unwrap_or(0));
+                state_guard.used_vram = state_guard
+                    .used_vram
+                    .saturating_sub(parse_size(&instance.config.vram_usage).unwrap_or(0));
+                if state_guard.used_ram + model_ram_usage <= state_guard.total_ram
+                    && state_guard.used_vram + model_vram_usage <= state_guard.total_vram
+                {
+                    tracing::info!("Freed enough resources");
+                    break;
+                }
+            }
+            for port in to_remove {
+                state_guard.instances.remove(&port);
+            }
+        } else {
+            tracing::info!("Sufficient resources available");
+        }
+
+        // Spawn a new instance.
+        let instance = Self::spawn(&model_config).await?;
+        state_guard.used_ram += model_ram_usage;
+        state_guard.used_vram += model_vram_usage;
+        state_guard
+            .instances
+            .insert(model_config.port, instance.clone());
+
+        sleep(Duration::from_millis(250)).await;
+
+        instance.running_and_port_ready().await?;
+
+        sleep(Duration::from_millis(250)).await;
+
+        Ok(instance)
+    }
+
+    /// Spawn a new llama-server process based on the model configuration.
+    pub async fn spawn(model_config: &ModelConfig) -> Result<Self, AppError> {
+        let args = model_config
+            .args
+            .iter()
+            .flat_map(|(k, v)| {
+                if v == "true" {
+                    vec![format!("--{}", k)]
+                } else {
+                    vec![format!("--{}", k), v.clone()]
+                }
+            })
+            .collect::<Vec<String>>();
+
+        let internal_port = model_config
+            .internal_port
+            .expect("Internal port must be set");
+
+        let mut cmd = Command::new("llama-server");
+        cmd.kill_on_drop(true)
+            .envs(model_config.env.clone())
+            .args(&args)
+            .arg("--port")
+            .arg(format!("{}", internal_port))
+            .stdout(Stdio::null())
+            .stderr(Stdio::null());
+
+        // Silence output – could be captured later via an API.
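+        // For example, an args map of {"ctx-size": "4096", "flash-attn": "true"}
+        // expands to ["--ctx-size", "4096", "--flash-attn"]; a literal "true"
+        // value marks a bare, value-less flag.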
+
+        tracing::info!("Starting llama-server with command: {:?}", cmd);
+        let child = cmd.spawn().expect("Failed to start llama-server");
+
+        Ok(LlamaInstance {
+            config: model_config.clone(),
+            process: Arc::new(Mutex::new(child)),
+        })
+    }
+
+    pub async fn running_and_port_ready(&self) -> Result<(), anyhow::Error> {
+        async fn is_running(instance: &LlamaInstance) -> Result<(), AppError> {
+            let mut proc = instance.process.clone().lock_owned().await;
+            match proc.try_wait()? {
+                Some(exit_status) => {
+                    tracing::error!("Llama instance exited: {:?}", exit_status);
+                    Err(AppError::Unknown(anyhow!("Llama instance exited")))
+                }
+                None => Ok(()),
+            }
+        }
+
+        async fn wait_for_port(port: u16) -> Result<(), anyhow::Error> {
+            for _ in 0..10 {
+                if TcpStream::connect(("127.0.0.1", port)).await.is_ok() {
+                    return Ok(());
+                }
+                sleep(Duration::from_millis(500)).await;
+            }
+            Err(anyhow!("Timeout waiting for port"))
+        }
+
+        is_running(self).await?;
+        wait_for_port(self.config.internal_port.expect("No port picked?")).await?;
+        Ok(())
+    }
+}
+
 #[derive(Clone, Debug)]
-struct SharedState {
+pub struct InnerSharedState {
     total_ram: u64,
     total_vram: u64,
     used_ram: u64,
@@ -98,31 +256,67 @@ struct SharedState {
     instances: HashMap<u16, LlamaInstance>,
 }
 
+type SharedStateArc = Arc<Mutex<InnerSharedState>>;
+
+/// TODO migrate to dashmap + individual Arc<Mutex<T>> or RwLocks
+#[derive(Clone, Debug, derive_more::Deref)]
+pub struct SharedState(SharedStateArc);
+
+impl SharedState {
+    fn from_config(config: Config) -> Self {
+        // Parse hardware resources
+        let total_ram = parse_size(&config.hardware.ram).expect("Invalid RAM size in config");
+        let total_vram = parse_size(&config.hardware.vram).expect("Invalid VRAM size in config");
+
+        // Initialize shared state
+        let shared_state = InnerSharedState {
+            total_ram,
+            total_vram,
+            used_ram: 0,
+            used_vram: 0,
+            instances: HashMap::new(),
+        };
+
+        Self(Arc::new(Mutex::new(shared_state)))
+    }
+}
+
+/// Build an Axum app for a given model.
+fn create_app(model_config: &ModelConfig, state: SharedState) -> Router {
+    Router::new()
+        .route(
+            "/",
+            any({
+                let state = state.clone();
+                let model_config = model_config.clone();
+                move |req| handle_request(req, model_config.clone(), state.clone())
+            }),
+        )
+        .route(
+            "/*path",
+            any({
+                let state = state.clone();
+                let model_config = model_config.clone();
+                move |req| handle_request(req, model_config.clone(), state.clone())
+            }),
+        )
+        .layer(
+            tower_http::trace::TraceLayer::new_for_http()
+                .make_span_with(tower_http::trace::DefaultMakeSpan::new().include_headers(true))
+                .on_request(tower_http::trace::DefaultOnRequest::new().level(Level::DEBUG))
+                .on_response(tower_http::trace::DefaultOnResponse::new().level(Level::TRACE))
+                .on_eos(tower_http::trace::DefaultOnEos::new().level(Level::DEBUG))
+                .on_failure(tower_http::trace::DefaultOnFailure::new().level(Level::ERROR)),
+        )
+}
+
 #[tokio::main]
 async fn main() {
-    // TODO add autostart of models based on config
-    //     abstract starting logic out of handler for this to allow seperate calls to start
-    //     maybe add to SharedState & LLamaInstance ?
-    initialize_logger();
-    // Read and parse the YAML configuration
-    let config_str = std::fs::read_to_string("config.yaml").expect("Failed to read config.yaml");
-    let config: Config = serde_yaml::from_str::<Config>(&config_str)
-        .expect("Failed to parse config.yaml")
-        .pick_open_ports();
-    // Parse hardware resources
-    let total_ram = parse_size(&config.hardware.ram).expect("Invalid RAM size in config");
-    let total_vram = parse_size(&config.hardware.vram).expect("Invalid VRAM size in config");
+    initialize_logger();
 
-    // Initialize shared state
-    let shared_state = Arc::new(Mutex::new(SharedState {
-        total_ram,
-        total_vram,
-        used_ram: 0,
-        used_vram: 0,
-        instances: HashMap::new(),
-    }));
+    let config = Config::default_from_pwd_yml();
+
+    let shared_state = SharedState::from_config(config.clone());
 
     // For each model, set up an axum server listening on the specified port
     let mut handles = Vec::new();
@@ -134,31 +328,7 @@ async fn main() {
 
         let handle = tokio::spawn(async move {
             let model_config = model_config.clone();
-            let app = Router::new()
-                .route(
-                    "/",
-                    any({
-                        let state = state.clone();
-                        let model_config = model_config.clone();
-                        move |req| handle_request(req, model_config, state)
-                    }),
-                )
-                .route(
-                    "/*path",
-                    any({
-                        let state = state.clone();
-                        let model_config = model_config.clone();
-                        move |req| handle_request(req, model_config, state)
-                    }),
-                )
-                .layer(
-                    TraceLayer::new_for_http()
-                        .make_span_with(DefaultMakeSpan::new().include_headers(true))
-                        .on_request(DefaultOnRequest::new().level(Level::DEBUG))
-                        .on_response(DefaultOnResponse::new().level(Level::TRACE))
-                        .on_eos(DefaultOnEos::new().level(Level::DEBUG))
-                        .on_failure(DefaultOnFailure::new().level(Level::ERROR)),
-                );
+            let app = create_app(&model_config, state);
 
             let addr = SocketAddr::from(([0, 0, 0, 0], model_config.port));
 
@@ -190,6 +360,8 @@ pub enum AppError {
     ReqwestMiddlewareError(#[from] reqwest_middleware::Error),
     #[error("Client Error")]
     ClientError(#[from] hyper::Error),
+    #[error("Io Error")]
+    IoError(#[from] std::io::Error),
     #[error("Unknown error")]
     Unknown(#[from] anyhow::Error),
 }
@@ -206,202 +378,67 @@ impl IntoResponse for AppError {
     }
 }
 
-async fn handle_request(
+async fn proxy_request(
     req: Request,
-    model_config: ModelConfig,
-    state: Arc<Mutex<SharedState>>,
-    // ) -> Result<Response<Body>, anyhow::Error> {
-) -> impl IntoResponse {
-    let model_ram_usage = parse_size(&model_config.ram_usage).expect("Invalid ram_usage");
-    let model_vram_usage = parse_size(&model_config.vram_usage).expect("Invalid vram_usage");
-
-    let instance = {
-        let mut state = state.lock().await;
-
-        // Check if instance is running
-        if let Some(instance) = state.instances.get_mut(&model_config.port) {
-            // instance.busy = true;
-            instance.to_owned()
-        } else {
-            // Check resources
-            tracing::info!(msg = "Current state", ?state);
-            if ((state.used_ram + model_ram_usage) > state.total_ram)
-                || ((state.used_vram + model_vram_usage) > state.total_vram)
-            {
-                // Stop other instances
-                let mut to_remove = Vec::new();
-                // TODO Actual smart stopping logic
-                // - search for smallest single model to stop to get enough room
-                // - if not possible search for smallest number of models to stop with lowest
-                //   amount of "overshot
-                let instances_by_size =
-                    state
-                        .instances
-                        .clone()
-                        .into_iter()
-                        .sorted_by(|(_, el_a), (_, el_b)| {
-                            Ord::cmp(
-                                &parse_size(el_b.config.vram_usage.as_str()),
-                                &parse_size(el_a.config.vram_usage.as_str()),
-                            )
-                        });
-                for (port, instance) in instances_by_size {
-                    // if !instance.busy {
-                    tracing::info!("Stopping instance on port {}", port);
-                    let mut process = instance.process.lock().await;
-                    process.kill().await.ok();
-                    to_remove.push(port);
-                    state.used_ram -= parse_size(&instance.config.ram_usage).unwrap_or(0);
-                    state.used_vram -= parse_size(&instance.config.vram_usage).unwrap_or(0);
-                    if state.used_ram + model_ram_usage <= state.total_ram
-                        && state.used_vram + model_vram_usage <= state.total_vram
-                    {
-                        tracing::info!("Should have enough ram now");
-                        break;
-                    }
-                    // }
-                }
-                for port in to_remove {
-                    tracing::info!("Removing instance on port {}", port);
-                    state.instances.remove(&port);
-                }
-            } else {
-                tracing::info!("Already enough res free");
-            }
-
-            // Start new instance
-            let args = model_config
-                .args
-                .iter()
-                .flat_map(|(k, v)| {
-                    if v == "true" {
-                        vec![format!("--{}", k)]
-                    } else {
-                        vec![format!("--{}", k), v.clone()]
-                    }
-                })
-                .collect::<Vec<String>>();
-
-            let mut cmd = Command::new("llama-server");
-            cmd.kill_on_drop(true);
-            cmd.envs(model_config.env.clone());
-            cmd.args(&args);
-            // TODO use openport crate via pick_random_unused_port for determining these
-            cmd.arg("--port");
-            cmd.arg(format!(
-                "{}",
-                model_config
-                    .internal_port
-                    .expect("Unexpected empty port, should've been picked")
-            ));
-            cmd.stdout(Stdio::null()).stderr(Stdio::null()); // TODO save output and allow retrieval via api
-
-            tracing::info!("Starting llama-server with {:?}", cmd);
-            let process = Arc::new(Mutex::new(
-                cmd.spawn().expect("Failed to start llama-server"),
-            ));
-
-            state.used_ram += model_ram_usage;
-            state.used_vram += model_vram_usage;
-
-            let instance = LlamaInstance {
-                config: model_config.clone(),
-                process,
-                // busy: true,
-            };
-            sleep(Duration::from_millis(500)).await;
-            state.instances.insert(model_config.port, instance.clone());
-
-            instance
-        }
-    };
-
-    // Wait for the instance to be ready
-    is_llama_instance_running(&instance).await?;
-    wait_for_port(
-        model_config
-            .internal_port
-            .expect("Unexpected empty port, should've been picked"),
-    )
-    .await?;
-
-    // Proxy the request
-    let retry_policy = reqwest_retry::policies::ExponentialBackoff::builder()
-        .retry_bounds(Duration::from_millis(500), Duration::from_secs(8))
+    model_config: &ModelConfig,
+) -> Result<Response<Body>, AppError> {
+    let retry_policy = ExponentialBackoff::builder()
+        .retry_bounds(
+            std::time::Duration::from_millis(500),
+            std::time::Duration::from_secs(8),
+        )
         .jitter(reqwest_retry::Jitter::None)
         .base(2)
         .build_with_max_retries(8);
-    let client = reqwest_middleware::ClientBuilder::new(Client::new())
-        .with(reqwest_retry::RetryTransientMiddleware::new_with_policy(
-            retry_policy,
-        ))
+    let client: ClientWithMiddleware = ClientBuilder::new(Client::new())
+        .with(RetryTransientMiddleware::new_with_policy(retry_policy))
        .build();
 
+    let internal_port = model_config
+        .internal_port
+        .expect("Internal port must be set");
     let uri = format!(
         "http://127.0.0.1:{}{}",
-        model_config
-            .internal_port
-            .expect("Unexpected empty port, should've been picked"),
-        req.uri().path_and_query().map(|x| x.as_str()).unwrap_or("")
+        internal_port,
+        req.uri()
+            .path_and_query()
+            .map(|pq| pq.as_str())
+            .unwrap_or("")
     );
 
     let mut request_builder = client.request(req.method().clone(), &uri);
-    // let mut request_builder = reqwest::RequestBuilder::from_parts(client, req.method().clone())
-
+    // Forward all headers.
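+    // (Headers are copied verbatim; the body is forwarded separately below,
+    // capped at the 1G limit computed via parse_size.)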
     for (name, value) in req.headers() {
         request_builder = request_builder.header(name, value);
     }
 
-    let body_bytes = axum::body::to_bytes(
-        req.into_body(),
-        parse_size("1G").unwrap().try_into().unwrap(),
-    )
-    .await?;
+    let max_size = parse_size("1G").unwrap() as usize;
+    let body_bytes = axum::body::to_bytes(req.into_body(), max_size).await?;
 
     let request = request_builder.body(body_bytes).build()?;
 
-    // tracing::info!("Proxying request to {}", uri);
     let response = client.execute(request).await?;
-    // tracing::info!("Received response from {}", uri);
 
     let mut builder = Response::builder().status(response.status());
     for (name, value) in response.headers() {
         builder = builder.header(name, value);
     }
 
-    // let bytes = response.bytes().await?;
-
     let byte_stream = response.bytes_stream();
+    let body = Body::from_stream(byte_stream);
+    Ok(builder.body(body)?)
+}
 
-    let body = axum::body::Body::from_stream(byte_stream);
+async fn handle_request(
+    req: Request,
+    model_config: ModelConfig,
+    state: SharedState,
+) -> impl IntoResponse {
+    let _instance = LlamaInstance::get_or_spawn_instance(&model_config, state).await?;
+    let response = proxy_request(req, &model_config).await?;
 
-    tracing::debug!("streaming response on port: {}", model_config.port);
-    let response = builder.body(body)?;
     Ok::<Response<Body>, AppError>(response)
 }
 
-async fn wait_for_port(port: u16) -> Result<(), anyhow::Error> {
-    for _ in 0..10 {
-        if TcpStream::connect(("127.0.0.1", port)).await.is_ok() {
-            return Ok(());
-        }
-        sleep(Duration::from_millis(500)).await;
-    }
-    Err(anyhow!("Timeout waiting for port"))
-}
-
-// reimplement wait_for_port using a LlamaInstance, also check if the process is still running
-async fn is_llama_instance_running(instance: &LlamaInstance) -> Result<(), anyhow::Error> {
-    match instance.process.to_owned().lock_owned().await.try_wait()? {
-        Some(exit_status) => {
-            let msg = "Llama instance exited";
-            tracing::error!(msg, ?exit_status);
-            Err(anyhow!(msg))
-        }
-
-        None => Ok(()),
-    }
-}
-
 fn parse_size(size_str: &str) -> Option<u64> {
     let mut num = String::new();
     let mut unit = String::new();