From ad0cd12877690121b293446ca94212ebd75e0d36 Mon Sep 17 00:00:00 2001
From: Tristan Druyen
Date: Mon, 10 Feb 2025 23:40:27 +0100
Subject: [PATCH] refactor(proxy_man): Org & modularize app

- Introduce new modules: `config.rs`, `error.rs`, `inference_process.rs`,
  `proxy.rs`, `state.rs`, `util.rs`
- Update `logging.rs` to include static initialization of logging
- Add `lib.rs` for inclusion in other projects
- Refactor `main.rs` to use the new modules and improve code structure
---
 llama_proxy_man/src/config.rs            |  48 +++
 llama_proxy_man/src/error.rs             |  36 ++
 llama_proxy_man/src/inference_process.rs | 149 ++++++++
 llama_proxy_man/src/lib.rs               |  70 ++++
 llama_proxy_man/src/logging.rs           |  17 +
 llama_proxy_man/src/main.rs              | 464 +----------------------
 llama_proxy_man/src/proxy.rs             |  69 ++++
 llama_proxy_man/src/state.rs             |  36 ++
 llama_proxy_man/src/util.rs              |  22 ++
 9 files changed, 452 insertions(+), 459 deletions(-)
 create mode 100644 llama_proxy_man/src/config.rs
 create mode 100644 llama_proxy_man/src/error.rs
 create mode 100644 llama_proxy_man/src/inference_process.rs
 create mode 100644 llama_proxy_man/src/lib.rs
 create mode 100644 llama_proxy_man/src/proxy.rs
 create mode 100644 llama_proxy_man/src/state.rs
 create mode 100644 llama_proxy_man/src/util.rs

diff --git a/llama_proxy_man/src/config.rs b/llama_proxy_man/src/config.rs
new file mode 100644
index 0000000..186cf16
--- /dev/null
+++ b/llama_proxy_man/src/config.rs
@@ -0,0 +1,48 @@
+use serde::Deserialize;
+use std::{collections::HashMap, fs};
+
+#[derive(Clone, Debug, Deserialize)]
+pub struct AppConfig {
+    pub system_resources: SystemResources,
+    pub model_specs: Vec<ModelSpec>,
+}
+
+impl AppConfig {
+    pub fn default_from_pwd_yml() -> Self {
+        let config_str = fs::read_to_string("config.yaml").expect("Failed to read config.yaml");
+        serde_yaml::from_str::<AppConfig>(&config_str)
+            .expect("Failed to parse config.yaml")
+            .assign_internal_ports()
+    }
+
+    // Ensure every model has an internal port
+    pub fn assign_internal_ports(self) -> Self {
+        let mut config = self.clone();
+        for model in &mut config.model_specs {
+            if model.internal_port.is_none() {
+                model.internal_port = Some(
+                    openport::pick_random_unused_port()
+                        .expect(&format!("No open port found for {:?}", model)),
+                );
+            }
+        }
+        config
+    }
+}
+
+#[derive(Clone, Debug, Deserialize)]
+pub struct SystemResources {
+    pub ram: String,
+    pub vram: String,
+}
+
+#[derive(Clone, Debug, Deserialize)]
+pub struct ModelSpec {
+    pub name: String,
+    pub port: u16,
+    pub internal_port: Option<u16>,
+    pub env: HashMap<String, String>,
+    pub args: HashMap<String, String>,
+    pub vram_usage: String,
+    pub ram_usage: String,
+}
diff --git a/llama_proxy_man/src/error.rs b/llama_proxy_man/src/error.rs
new file mode 100644
index 0000000..01c7fae
--- /dev/null
+++ b/llama_proxy_man/src/error.rs
@@ -0,0 +1,36 @@
+use axum::{http, response::IntoResponse};
+use hyper;
+use reqwest;
+use reqwest_middleware;
+use std::io;
+use thiserror::Error;
+use anyhow::Error as AnyError;
+
+#[derive(Error, Debug)]
+pub enum AppError {
+    #[error("Axum Error")]
+    AxumError(#[from] axum::Error),
+    #[error("Axum Http Error")]
+    AxumHttpError(#[from] http::Error),
+    #[error("Reqwest Error")]
+    ReqwestError(#[from] reqwest::Error),
+    #[error("Reqwest Middleware Error")]
+    ReqwestMiddlewareError(#[from] reqwest_middleware::Error),
+    #[error("Client Error")]
+    ClientError(#[from] hyper::Error),
+    #[error("Io Error")]
+    IoError(#[from] io::Error),
+    #[error("Unknown error")]
+    Unknown(#[from] AnyError),
+}
+
+impl IntoResponse for AppError {
+    fn into_response(self) -> axum::response::Response {
+        tracing::error!("::AppError::into_response: {:?}", self);
+        (
+            http::StatusCode::INTERNAL_SERVER_ERROR,
+            format!("Something went wrong: {:?}", self),
+        )
+            .into_response()
+    }
+}
diff --git a/llama_proxy_man/src/inference_process.rs b/llama_proxy_man/src/inference_process.rs
new file mode 100644
index 0000000..4671d51
--- /dev/null
+++ b/llama_proxy_man/src/inference_process.rs
@@ -0,0 +1,149 @@
+use crate::{config::ModelSpec, error::AppError, state::AppState, util::parse_size};
+use anyhow::anyhow;
+use itertools::Itertools;
+use std::{process::Stdio, sync::Arc};
+use tokio::{
+    net::TcpStream,
+    process::{Child, Command},
+    sync::Mutex,
+    time::{sleep, Duration},
+};
+
+#[derive(Clone, Debug)]
+pub struct InferenceProcess {
+    pub spec: ModelSpec,
+    pub process: Arc<Mutex<Child>>,
+}
+
+impl InferenceProcess {
+    /// Retrieve a running process from state or spawn a new one.
+    pub async fn get_or_spawn(
+        spec: &ModelSpec,
+        state: AppState,
+    ) -> Result<InferenceProcess, AppError> {
+        let required_ram = parse_size(&spec.ram_usage).expect("Invalid ram_usage in model spec");
+        let required_vram = parse_size(&spec.vram_usage).expect("Invalid vram_usage in model spec");
+
+        {
+            let state_guard = state.lock().await;
+            if let Some(proc) = state_guard.processes.get(&spec.port) {
+                return Ok(proc.clone());
+            }
+        }
+
+        let mut state_guard = state.lock().await;
+        tracing::info!(msg = "App state before spawn", ?state_guard);
+
+        // If not enough resources, stop some running processes.
+        if (state_guard.used_ram + required_ram > state_guard.total_ram)
+            || (state_guard.used_vram + required_vram > state_guard.total_vram)
+        {
+            let mut to_remove = Vec::new();
+            let processes_by_usage = state_guard
+                .processes
+                .iter()
+                .sorted_by(|(_, a), (_, b)| {
+                    parse_size(&b.spec.vram_usage)
+                        .unwrap_or(0)
+                        .cmp(&parse_size(&a.spec.vram_usage).unwrap_or(0))
+                })
+                .map(|(port, proc)| (*port, proc.clone()))
+                .collect::<Vec<_>>();
+
+            for (port, proc) in processes_by_usage.iter() {
+                tracing::info!("Stopping process on port {}", port);
+                let mut lock = proc.process.lock().await;
+                lock.kill().await.ok();
+                to_remove.push(*port);
+                state_guard.used_ram = state_guard
+                    .used_ram
+                    .saturating_sub(parse_size(&proc.spec.ram_usage).unwrap_or(0));
+                state_guard.used_vram = state_guard
+                    .used_vram
+                    .saturating_sub(parse_size(&proc.spec.vram_usage).unwrap_or(0));
+                if state_guard.used_ram + required_ram <= state_guard.total_ram
+                    && state_guard.used_vram + required_vram <= state_guard.total_vram
+                {
+                    tracing::info!("Freed enough resources");
+                    break;
+                }
+            }
+            for port in to_remove {
+                state_guard.processes.remove(&port);
+            }
+        } else {
+            tracing::info!("Sufficient resources available");
+        }
+
+        let proc = Self::spawn(spec).await?;
+        state_guard.used_ram += required_ram;
+        state_guard.used_vram += required_vram;
+        state_guard.processes.insert(spec.port, proc.clone());
+
+        sleep(Duration::from_millis(250)).await;
+        proc.wait_until_ready().await?;
+        sleep(Duration::from_millis(250)).await;
+        Ok(proc)
+    }
+
+    /// Spawn a new inference process.
+    pub async fn spawn(spec: &ModelSpec) -> Result<InferenceProcess, AppError> {
+        let args = spec
+            .args
+            .iter()
+            .flat_map(|(k, v)| {
+                if v == "true" {
+                    vec![format!("--{}", k)]
+                } else {
+                    vec![format!("--{}", k), v.clone()]
+                }
+            })
+            .collect::<Vec<String>>();
+
+        let internal_port = spec.internal_port.expect("Internal port must be set");
+
+        let mut cmd = Command::new("llama-server");
+        cmd.kill_on_drop(true)
+            .envs(spec.env.clone())
+            .args(&args)
+            .arg("--port")
+            .arg(format!("{}", internal_port))
+            .stdout(Stdio::null())
+            .stderr(Stdio::null());
+
+        tracing::info!("Starting llama-server with command: {:?}", cmd);
+        let child = cmd.spawn().expect("Failed to start llama-server");
+
+        Ok(InferenceProcess {
+            spec: spec.clone(),
+            process: Arc::new(Mutex::new(child)),
+        })
+    }
+
+    pub async fn wait_until_ready(&self) -> Result<(), anyhow::Error> {
+        async fn check_running(proc: &InferenceProcess) -> Result<(), AppError> {
+            let mut lock = proc.process.clone().lock_owned().await;
+            match lock.try_wait()? {
+                Some(exit_status) => {
+                    tracing::error!("Inference process exited: {:?}", exit_status);
+                    Err(AppError::Unknown(anyhow!("Inference process exited")))
+                }
+                None => Ok(()),
+            }
+        }
+
+        async fn wait_for_port(port: u16) -> Result<(), anyhow::Error> {
+            for _ in 0..10 {
+                if TcpStream::connect(("127.0.0.1", port)).await.is_ok() {
+                    return Ok(());
+                }
+                sleep(Duration::from_millis(500)).await;
+            }
+            Err(anyhow!("Timeout waiting for port"))
+        }
+
+        check_running(self).await?;
+        wait_for_port(self.spec.internal_port.expect("Internal port must be set")).await?;
+        Ok(())
+    }
+}
diff --git a/llama_proxy_man/src/lib.rs b/llama_proxy_man/src/lib.rs
new file mode 100644
index 0000000..4a1b7e2
--- /dev/null
+++ b/llama_proxy_man/src/lib.rs
@@ -0,0 +1,70 @@
+pub mod config;
+pub mod error;
+pub mod inference_process;
+pub mod logging;
+pub mod proxy;
+pub mod state;
+pub mod util;
+
+use axum::{routing::any, Router};
+use config::{AppConfig, ModelSpec};
+use state::AppState;
+use std::net::SocketAddr;
+use tower_http::trace::{
+    DefaultMakeSpan, DefaultOnEos, DefaultOnFailure, DefaultOnRequest, DefaultOnResponse,
+    TraceLayer,
+};
+use tracing::Level;
+
+/// Creates an Axum application to handle inference requests for a specific model.
+pub fn create_app(spec: &ModelSpec, state: AppState) -> Router {
+    Router::new()
+        .route(
+            "/",
+            any({
+                let state = state.clone();
+                let spec = spec.clone();
+                move |req| proxy::handle_request(req, spec.clone(), state.clone())
+            }),
+        )
+        .route(
+            "/*path",
+            any({
+                let state = state.clone();
+                let spec = spec.clone();
+                move |req| proxy::handle_request(req, spec.clone(), state.clone())
+            }),
+        )
+        .layer(
+            TraceLayer::new_for_http()
+                .make_span_with(DefaultMakeSpan::new().include_headers(true))
+                .on_request(DefaultOnRequest::new().level(Level::DEBUG))
+                .on_response(DefaultOnResponse::new().level(Level::TRACE))
+                .on_eos(DefaultOnEos::new().level(Level::DEBUG))
+                .on_failure(DefaultOnFailure::new().level(Level::ERROR)),
+        )
+}
+
+/// Starts an inference server for each model defined in the config.
+pub async fn start_server(config: AppConfig) {
+    let state = AppState::from_config(config.clone());
+
+    let mut handles = Vec::new();
+    for spec in config.model_specs {
+        let state = state.clone();
+        let spec = spec.clone();
+
+        let handle = tokio::spawn(async move {
+            let app = create_app(&spec, state);
+            let addr = SocketAddr::from(([0, 0, 0, 0], spec.port));
+            tracing::info!(msg = "Listening", ?spec);
+            let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
+            axum::serve(listener, app.into_make_service())
+                .await
+                .unwrap();
+        });
+        handles.push(handle);
+    }
+
+    futures::future::join_all(handles).await;
+}
diff --git a/llama_proxy_man/src/logging.rs b/llama_proxy_man/src/logging.rs
index 33e45fc..9c93185 100644
--- a/llama_proxy_man/src/logging.rs
+++ b/llama_proxy_man/src/logging.rs
@@ -5,6 +5,8 @@ use std::{
     task::{Context, Poll},
 };
 
+use std::sync::Once;
+
 use axum::{body::Body, http::Request};
 use pin_project_lite::pin_project;
 use tower::{Layer, Service};
@@ -81,3 +83,18 @@ where
         }
     }
 }
+
+static INIT: Once = Once::new();
+
+pub fn initialize_logger() {
+    INIT.call_once(|| {
+        let env_filter = tracing_subscriber::EnvFilter::builder()
+            .with_default_directive(tracing::level_filters::LevelFilter::INFO.into())
+            .from_env_lossy();
+
+        tracing_subscriber::fmt()
+            .compact()
+            .with_env_filter(env_filter)
+            .init();
+    });
+}
diff --git a/llama_proxy_man/src/main.rs b/llama_proxy_man/src/main.rs
index 533ab03..9e98ffd 100644
--- a/llama_proxy_man/src/main.rs
+++ b/llama_proxy_man/src/main.rs
@@ -1,465 +1,11 @@
-mod logging;
-
-use anyhow::anyhow;
-use axum::{
-    body::Body,
-    http::{self, Request, Response},
-    response::IntoResponse,
-    routing::any,
-    Router,
-};
-use futures;
-use itertools::Itertools;
-use reqwest::Client;
-use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
-use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
-use serde::Deserialize;
-use std::{collections::HashMap, net::SocketAddr, process::Stdio, sync::Arc};
-use tokio::{
-    net::TcpStream,
-    process::{Child, Command},
-    sync::Mutex,
-    time::{sleep, Duration},
-};
-use tower_http;
-use tracing::Level;
-
-use std::sync::Once;
-
-pub static INIT: Once = Once::new();
-
-pub fn initialize_logger() {
-    INIT.call_once(|| {
-        let env_filter = tracing_subscriber::EnvFilter::builder()
-            .with_default_directive(tracing::level_filters::LevelFilter::INFO.into())
-            .from_env_lossy();
-
-        tracing_subscriber::fmt()
-            .compact()
-            .with_env_filter(env_filter)
-            .init();
-    });
-}
-
-// TODO Add valiation to the config
-// - e.g. check double taken ports etc
-#[derive(Clone, Debug, Deserialize)]
-struct Config {
-    hardware: Hardware,
-    models: Vec<ModelConfig>,
-}
-
-impl Config {
-    fn default_from_pwd_yml() -> Self {
-        // Read and parse the YAML configuration
-        let config_str =
-            std::fs::read_to_string("config.yaml").expect("Failed to read config.yaml");
-
-        serde_yaml::from_str::<Config>(&config_str)
-            .expect("Failed to parse config.yaml")
-            .pick_open_ports()
-    }
-
-    // TODO split up into raw deser config and "parsed"/"processed" config which always has a port
-    // FIXME maybe just always initialize with a port ?
-    fn pick_open_ports(self) -> Self {
-        let mut config = self.clone();
-        for model in &mut config.models {
-            if model.internal_port.is_none() {
-                model.internal_port = Some(
-                    openport::pick_random_unused_port()
-                        .expect(format!("No open port found for {:?}", model).as_str()),
-                );
-            }
-        }
-        config
-    }
-}
-
-#[derive(Clone, Debug, Deserialize)]
-struct Hardware {
-    ram: String,
-    vram: String,
-}
-
-#[derive(Clone, Debug, Deserialize)]
-struct ModelConfig {
-    #[allow(dead_code)]
-    name: String,
-    port: u16,
-    internal_port: Option<u16>,
-    env: HashMap<String, String>,
-    args: HashMap<String, String>,
-    vram_usage: String,
-    ram_usage: String,
-}
-
-#[derive(Clone, Debug)]
-struct LlamaInstance {
-    config: ModelConfig,
-    process: Arc<Mutex<Child>>,
-    // busy: bool,
-}
-
-impl LlamaInstance {
-    /// Retrieve a running instance from state or spawn a new one.
-    pub async fn get_or_spawn_instance(
-        model_config: &ModelConfig,
-        state: SharedState,
-    ) -> Result<LlamaInstance, AppError> {
-        let model_ram_usage =
-            parse_size(&model_config.ram_usage).expect("Invalid ram_usage in model config");
-        let model_vram_usage =
-            parse_size(&model_config.vram_usage).expect("Invalid vram_usage in model config");
-
-        {
-            // First, check if we already have an instance.
-            let state_guard = state.lock().await;
-            if let Some(instance) = state_guard.instances.get(&model_config.port) {
-                return Ok(instance.clone());
-            }
-        }
-
-        let mut state_guard = state.lock().await;
-        tracing::info!(msg = "Shared state before spawn", ?state_guard);
-
-        // If not enough free resources, stop some running instances.
-        if (state_guard.used_ram + model_ram_usage > state_guard.total_ram)
-            || (state_guard.used_vram + model_vram_usage > state_guard.total_vram)
-        {
-            let mut to_remove = Vec::new();
-            let instances_by_size = state_guard
-                .instances
-                .iter()
-                .sorted_by(|(_, a), (_, b)| {
-                    parse_size(&b.config.vram_usage)
-                        .unwrap_or(0)
-                        .cmp(&parse_size(&a.config.vram_usage).unwrap_or(0))
-                })
-                .map(|(port, instance)| (*port, instance.clone()))
-                .collect::<Vec<_>>();
-
-            for (port, instance) in instances_by_size.iter() {
-                tracing::info!("Stopping instance on port {}", port);
-                let mut proc = instance.process.lock().await;
-                proc.kill().await.ok();
-                to_remove.push(*port);
-                state_guard.used_ram = state_guard
-                    .used_ram
-                    .saturating_sub(parse_size(&instance.config.ram_usage).unwrap_or(0));
-                state_guard.used_vram = state_guard
-                    .used_vram
-                    .saturating_sub(parse_size(&instance.config.vram_usage).unwrap_or(0));
-                if state_guard.used_ram + model_ram_usage <= state_guard.total_ram
-                    && state_guard.used_vram + model_vram_usage <= state_guard.total_vram
-                {
-                    tracing::info!("Freed enough resources");
-                    break;
-                }
-            }
-            for port in to_remove {
-                state_guard.instances.remove(&port);
-            }
-        } else {
-            tracing::info!("Sufficient resources available");
-        }
-
-        // Spawn a new instance.
-        let instance = Self::spawn(&model_config).await?;
-        state_guard.used_ram += model_ram_usage;
-        state_guard.used_vram += model_vram_usage;
-        state_guard
-            .instances
-            .insert(model_config.port, instance.clone());
-
-        sleep(Duration::from_millis(250)).await;
-
-        instance.running_and_port_ready().await?;
-
-        sleep(Duration::from_millis(250)).await;
-
-        Ok(instance)
-    }
-
-    /// Spawn a new llama-server process based on the model configuration.
-    pub async fn spawn(model_config: &ModelConfig) -> Result<LlamaInstance, AppError> {
-        let args = model_config
-            .args
-            .iter()
-            .flat_map(|(k, v)| {
-                if v == "true" {
-                    vec![format!("--{}", k)]
-                } else {
-                    vec![format!("--{}", k), v.clone()]
-                }
-            })
-            .collect::<Vec<String>>();
-
-        let internal_port = model_config
-            .internal_port
-            .expect("Internal port must be set");
-
-        let mut cmd = Command::new("llama-server");
-        cmd.kill_on_drop(true)
-            .envs(model_config.env.clone())
-            .args(&args)
-            .arg("--port")
-            .arg(format!("{}", internal_port))
-            .stdout(Stdio::null())
-            .stderr(Stdio::null());
-
-        // Silence output – could be captured later via an API.
-
-        tracing::info!("Starting llama-server with command: {:?}", cmd);
-        let child = cmd.spawn().expect("Failed to start llama-server");
-
-        Ok(LlamaInstance {
-            config: model_config.clone(),
-            process: Arc::new(Mutex::new(child)),
-        })
-    }
-
-    pub async fn running_and_port_ready(&self) -> Result<(), anyhow::Error> {
-        async fn is_running(instance: &LlamaInstance) -> Result<(), AppError> {
-            let mut proc = instance.process.clone().lock_owned().await;
-            match proc.try_wait()? {
-                Some(exit_status) => {
-                    tracing::error!("Llama instance exited: {:?}", exit_status);
-                    Err(AppError::Unknown(anyhow!("Llama instance exited")))
-                }
-                None => Ok(()),
-            }
-        }
-
-        async fn wait_for_port(port: u16) -> Result<(), anyhow::Error> {
-            for _ in 0..10 {
-                if TcpStream::connect(("127.0.0.1", port)).await.is_ok() {
-                    return Ok(());
-                }
-                sleep(Duration::from_millis(500)).await;
-            }
-            Err(anyhow!("Timeout waiting for port"))
-        }
-
-        is_running(self).await?;
-        wait_for_port(self.config.internal_port.expect("No port picked?")).await?;
-        Ok(())
-    }
-}
-
-#[derive(Clone, Debug)]
-pub struct InnerSharedState {
-    total_ram: u64,
-    total_vram: u64,
-    used_ram: u64,
-    used_vram: u64,
-    instances: HashMap<u16, LlamaInstance>,
-}
-
-type SharedStateArc = Arc<Mutex<InnerSharedState>>;
-
-/// TODO migrate to dashmap + individual Arc<Mutex<T>> or Rwlocks
-#[derive(Clone, Debug, derive_more::Deref)]
-pub struct SharedState(SharedStateArc);
-
-impl SharedState {
-    fn from_config(config: Config) -> Self {
-        // Parse hardware resources
-        let total_ram = parse_size(&config.hardware.ram).expect("Invalid RAM size in config");
-        let total_vram = parse_size(&config.hardware.vram).expect("Invalid VRAM size in config");
-
-        // Initialize shared state
-        let shared_state = InnerSharedState {
-            total_ram,
-            total_vram,
-            used_ram: 0,
-            used_vram: 0,
-            instances: HashMap::new(),
-        };
-
-        Self(Arc::new(Mutex::new(shared_state)))
-    }
-}
-
-/// Build an Axum app for a given model.
-fn create_app(model_config: &ModelConfig, state: SharedState) -> Router {
-    Router::new()
-        .route(
-            "/",
-            any({
-                let state = state.clone();
-                let model_config = model_config.clone();
-                move |req| handle_request(req, model_config.clone(), state.clone())
-            }),
-        )
-        .route(
-            "/*path",
-            any({
-                let state = state.clone();
-                let model_config = model_config.clone();
-                move |req| handle_request(req, model_config.clone(), state.clone())
-            }),
-        )
-        .layer(
-            tower_http::trace::TraceLayer::new_for_http()
-                .make_span_with(tower_http::trace::DefaultMakeSpan::new().include_headers(true))
-                .on_request(tower_http::trace::DefaultOnRequest::new().level(Level::DEBUG))
-                .on_response(tower_http::trace::DefaultOnResponse::new().level(Level::TRACE))
-                .on_eos(tower_http::trace::DefaultOnEos::new().level(Level::DEBUG))
-                .on_failure(tower_http::trace::DefaultOnFailure::new().level(Level::ERROR)),
-        )
-}
+use llama_proxy_man::{config::AppConfig, logging, start_server};
+use tokio;
 
 #[tokio::main]
 async fn main() {
-    initialize_logger();
+    logging::initialize_logger();
 
-    let config = Config::default_from_pwd_yml();
+    let config = AppConfig::default_from_pwd_yml();
 
-    let shared_state = SharedState::from_config(config.clone());
-
-    // For each model, set up an axum server listening on the specified port
-    let mut handles = Vec::new();
-
-    for model_config in config.models {
-        let state = shared_state.clone();
-        let model_config = model_config.clone();
-
-        let handle = tokio::spawn(async move {
-            let model_config = model_config.clone();
-
-            let app = create_app(&model_config, state);
-
-            let addr = SocketAddr::from(([0, 0, 0, 0], model_config.port));
-
-            tracing::info!(msg = "Listening", ?model_config);
-            let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
-
-            axum::serve(listener, app.into_make_service())
-                .await
-                .unwrap();
-        });
-
-        handles.push(handle);
-    }
-
-    futures::future::join_all(handles).await;
-}
-
-use thiserror::Error;
-
-#[derive(Error, Debug)]
-pub enum AppError {
-    #[error("Axum Error")]
-    AxumError(#[from] axum::Error),
-    #[error("Axum Http Error")]
-    AxumHttpError(#[from] axum::http::Error),
-    #[error("Reqwest Error")]
-    ReqwestError(#[from] reqwest::Error),
-    #[error("Reqwest Middleware Error")]
-    ReqwestMiddlewareError(#[from] reqwest_middleware::Error),
-    #[error("Client Error")]
-    ClientError(#[from] hyper::Error),
-    #[error("Io Error")]
-    IoError(#[from] std::io::Error),
-    #[error("Unknown error")]
-    Unknown(#[from] anyhow::Error),
-}
-
-// Tell axum how to convert `AppError` into a response.
-impl IntoResponse for AppError {
-    fn into_response(self) -> axum::response::Response {
-        tracing::error!("::AppError::into_response:: {:?}", self);
-        (
-            http::StatusCode::INTERNAL_SERVER_ERROR,
-            format!("Something went wrong: {:?}", self),
-        )
-            .into_response()
-    }
-}
-
-async fn proxy_request(
-    req: Request<Body>,
-    model_config: &ModelConfig,
-) -> Result<Response<Body>, AppError> {
-    let retry_policy = ExponentialBackoff::builder()
-        .retry_bounds(
-            std::time::Duration::from_millis(500),
-            std::time::Duration::from_secs(8),
-        )
-        .jitter(reqwest_retry::Jitter::None)
-        .base(2)
-        .build_with_max_retries(8);
-
-    let client: ClientWithMiddleware = ClientBuilder::new(Client::new())
-        .with(RetryTransientMiddleware::new_with_policy(retry_policy))
-        .build();
-
-    let internal_port = model_config
-        .internal_port
-        .expect("Internal port must be set");
-    let uri = format!(
-        "http://127.0.0.1:{}{}",
-        internal_port,
-        req.uri()
-            .path_and_query()
-            .map(|pq| pq.as_str())
-            .unwrap_or("")
-    );
-
-    let mut request_builder = client.request(req.method().clone(), &uri);
-    // Forward all headers.
-    for (name, value) in req.headers() {
-        request_builder = request_builder.header(name, value);
-    }
-
-    let max_size = parse_size("1G").unwrap() as usize;
-    let body_bytes = axum::body::to_bytes(req.into_body(), max_size).await?;
-    let request = request_builder.body(body_bytes).build()?;
-    let response = client.execute(request).await?;
-
-    let mut builder = Response::builder().status(response.status());
-    for (name, value) in response.headers() {
-        builder = builder.header(name, value);
-    }
-
-    let byte_stream = response.bytes_stream();
-    let body = Body::from_stream(byte_stream);
-    Ok(builder.body(body)?)
-}
-
-async fn handle_request(
-    req: Request<Body>,
-    model_config: ModelConfig,
-    state: SharedState,
-) -> impl IntoResponse {
-    let _instance = LlamaInstance::get_or_spawn_instance(&model_config, state).await?;
-    let response = proxy_request(req, &model_config).await?;
-
-    Ok::<Response<Body>, AppError>(response)
-}
-
-fn parse_size(size_str: &str) -> Option<u64> {
-    let mut num = String::new();
-    let mut unit = String::new();
-
-    for c in size_str.chars() {
-        if c.is_digit(10) || c == '.' {
-            num.push(c);
-        } else {
-            unit.push(c);
-        }
-    }
-
-    let num: f64 = num.parse().ok()?;
-
-    let multiplier = match unit.to_lowercase().as_str() {
-        "g" | "gb" => 1024 * 1024 * 1024,
-        "m" | "mb" => 1024 * 1024,
-        "k" | "kb" => 1024,
-        _ => panic!("Invalid Size"),
-    };
-
-    let res = (num * multiplier as f64) as u64;
-    Some(res)
+    start_server(config).await;
 }
diff --git a/llama_proxy_man/src/proxy.rs b/llama_proxy_man/src/proxy.rs
new file mode 100644
index 0000000..18217d2
--- /dev/null
+++ b/llama_proxy_man/src/proxy.rs
@@ -0,0 +1,69 @@
+use crate::{
+    config::ModelSpec, error::AppError, inference_process::InferenceProcess, state::AppState,
+    util::parse_size,
+};
+use axum::{
+    body::Body,
+    http::{Request, Response},
+    response::IntoResponse,
+};
+use reqwest::Client;
+use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
+use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
+
+pub async fn proxy_request(
+    req: Request<Body>,
+    spec: &ModelSpec,
+) -> Result<Response<Body>, AppError> {
+    let retry_policy = ExponentialBackoff::builder()
+        .retry_bounds(
+            std::time::Duration::from_millis(500),
+            std::time::Duration::from_secs(8),
+        )
+        .jitter(reqwest_retry::Jitter::None)
+        .base(2)
+        .build_with_max_retries(8);
+
+    let client: ClientWithMiddleware = ClientBuilder::new(Client::new())
+        .with(RetryTransientMiddleware::new_with_policy(retry_policy))
+        .build();
+
+    let internal_port = spec.internal_port.expect("Internal port must be set");
+    let uri = format!(
+        "http://127.0.0.1:{}{}",
+        internal_port,
+        req.uri()
+            .path_and_query()
+            .map(|pq| pq.as_str())
+            .unwrap_or("")
+    );
+
+    let mut request_builder = client.request(req.method().clone(), &uri);
+    for (name, value) in req.headers() {
+        request_builder = request_builder.header(name, value);
+    }
+
+    let max_size = parse_size("1G").unwrap() as usize;
+    let body_bytes = axum::body::to_bytes(req.into_body(), max_size).await?;
+    let request = request_builder.body(body_bytes).build()?;
+    let response = client.execute(request).await?;
+
+    let mut builder = Response::builder().status(response.status());
+    for (name, value) in response.headers() {
+        builder = builder.header(name, value);
+    }
+
+    let byte_stream = response.bytes_stream();
+    let body = Body::from_stream(byte_stream);
+    Ok(builder.body(body)?)
+}
+
+pub async fn handle_request(
+    req: Request<Body>,
+    spec: ModelSpec,
+    state: AppState,
+) -> impl IntoResponse {
+    let _ = InferenceProcess::get_or_spawn(&spec, state).await?;
+    let response = proxy_request(req, &spec).await?;
+    Ok::<Response<Body>, AppError>(response)
+}
diff --git a/llama_proxy_man/src/state.rs b/llama_proxy_man/src/state.rs
new file mode 100644
index 0000000..f0dd664
--- /dev/null
+++ b/llama_proxy_man/src/state.rs
@@ -0,0 +1,36 @@
+use crate::{config::AppConfig, inference_process::InferenceProcess, util::parse_size};
+use std::{collections::HashMap, sync::Arc};
+use tokio::sync::Mutex;
+
+#[derive(Clone, Debug)]
+pub struct ResourceManager {
+    pub total_ram: u64,
+    pub total_vram: u64,
+    pub used_ram: u64,
+    pub used_vram: u64,
+    pub processes: HashMap<u16, InferenceProcess>,
+}
+
+pub type ResourceManagerHandle = Arc<Mutex<ResourceManager>>;
+
+#[derive(Clone, Debug, derive_more::Deref)]
+pub struct AppState(pub ResourceManagerHandle);
+
+impl AppState {
+    pub fn from_config(config: AppConfig) -> Self {
+        let total_ram =
+            parse_size(&config.system_resources.ram).expect("Invalid RAM size in config");
+        let total_vram =
+            parse_size(&config.system_resources.vram).expect("Invalid VRAM size in config");
+
+        let resource_manager = ResourceManager {
+            total_ram,
+            total_vram,
+            used_ram: 0,
+            used_vram: 0,
+            processes: HashMap::new(),
+        };
+
+        Self(Arc::new(Mutex::new(resource_manager)))
+    }
+}
diff --git a/llama_proxy_man/src/util.rs b/llama_proxy_man/src/util.rs
new file mode 100644
index 0000000..de1c849
--- /dev/null
+++ b/llama_proxy_man/src/util.rs
@@ -0,0 +1,22 @@
+pub fn parse_size(size_str: &str) -> Option<u64> {
+    let mut num = String::new();
+    let mut unit = String::new();
+
+    for c in size_str.chars() {
+        if c.is_digit(10) || c == '.' {
+            num.push(c);
+        } else {
+            unit.push(c);
+        }
+    }
+
+    let num: f64 = num.parse().ok()?;
+    let multiplier = match unit.to_lowercase().as_str() {
+        "g" | "gb" => 1024 * 1024 * 1024,
+        "m" | "mb" => 1024 * 1024,
+        "k" | "kb" => 1024,
+        _ => panic!("Invalid Size"),
+    };
+
+    Some((num * multiplier as f64) as u64)
+}
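
For reference, `AppConfig::default_from_pwd_yml` expects a `config.yaml` in the working directory whose keys mirror the `Deserialize` structs in `config.rs` (no serde rename attributes are used, so the keys are the snake_case field names). The sketch below parses an inline YAML document into an `AppConfig`; it assumes `llama_proxy_man` and `serde_yaml` are available as dependencies, and the model name, arguments, and size values are invented placeholders rather than part of this patch:

// Sketch only: illustrates the config.yaml shape accepted by config.rs.
// All concrete values below are placeholders.
use llama_proxy_man::config::AppConfig;

fn main() {
    let yaml = r#"
system_resources:
  ram: "48G"
  vram: "24G"
model_specs:
  - name: "example-model"
    port: 8080
    internal_port: 18080
    env:
      CUDA_VISIBLE_DEVICES: "0"
    args:
      model: "/path/to/model.gguf"
      ctx-size: "8192"
      flash-attn: "true"
    vram_usage: "20G"
    ram_usage: "4G"
"#;

    // Deserializes into the same structs the proxy manager uses at startup.
    let config: AppConfig = serde_yaml::from_str(yaml).expect("example config should parse");
    println!("first model: {}", config.model_specs[0].name);
}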
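
Because `lib.rs` now exposes the modules and `start_server`, the proxy manager can also be embedded in another project instead of being run through the `main.rs` binary. A minimal sketch, assuming `llama_proxy_man` is added as a path or git dependency and a `config.yaml` is present in the working directory:

// Embedding sketch: run the proxy manager as a background task next to other work.
use llama_proxy_man::{config::AppConfig, logging, start_server};

#[tokio::main]
async fn main() {
    // Safe to call more than once; guarded internally by `Once`.
    logging::initialize_logger();

    // Reads ./config.yaml and assigns an internal port to every model.
    let config = AppConfig::default_from_pwd_yml();

    // start_server spawns one Axum listener per model and resolves when they all finish.
    let proxy = tokio::spawn(start_server(config));

    // ... the host application can do its own work here ...

    proxy.await.expect("proxy manager task panicked");
}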