refactor(proxy_man): Org & modularize app

- Introduce new modules: `config.rs`, `error.rs`, `inference_process.rs`, `proxy.rs`, `state.rs`, `util.rs`
- Update `logging.rs` to include static initialization of logging
- Add `lib.rs` so the crate can be used as a library by other projects
- Refactor `main.rs` to use the new modules and improve code structure
Tristan D. 2025-02-10 23:40:27 +01:00
parent e3a3ec3826
commit ad0cd12877
Signed by: tristan
SSH key fingerprint: SHA256:3RU4RLOoM8oAjFU19f1W6t8uouZbA7GWkaSW6rjp1k8
9 changed files with 452 additions and 459 deletions

config.rs

@@ -0,0 +1,48 @@
use serde::Deserialize;
use std::{collections::HashMap, fs};
#[derive(Clone, Debug, Deserialize)]
pub struct AppConfig {
pub system_resources: SystemResources,
pub model_specs: Vec<ModelSpec>,
}
impl AppConfig {
pub fn default_from_pwd_yml() -> Self {
let config_str = fs::read_to_string("config.yaml").expect("Failed to read config.yaml");
serde_yaml::from_str::<Self>(&config_str)
.expect("Failed to parse config.yaml")
.assign_internal_ports()
}
// Ensure every model has an internal port
pub fn assign_internal_ports(self) -> Self {
let mut config = self.clone();
for model in &mut config.model_specs {
if model.internal_port.is_none() {
model.internal_port = Some(
openport::pick_random_unused_port()
.expect(&format!("No open port found for {:?}", model)),
);
}
}
config
}
}
#[derive(Clone, Debug, Deserialize)]
pub struct SystemResources {
pub ram: String,
pub vram: String,
}
#[derive(Clone, Debug, Deserialize)]
pub struct ModelSpec {
pub name: String,
pub port: u16,
pub internal_port: Option<u16>,
pub env: HashMap<String, String>,
pub args: HashMap<String, String>,
pub vram_usage: String,
pub ram_usage: String,
}
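For illustration only (not part of the commit), the sketch below shows a config.yaml shape these structs deserialize; the model name, path, port, and size values are invented, and `_example_config` is a hypothetical helper:

// Hypothetical sketch: a config.yaml that AppConfig would accept.
// All concrete values below are invented for the example.
fn _example_config() -> AppConfig {
    let yaml = r#"
system_resources:
  ram: "48G"
  vram: "24G"
model_specs:
  - name: "llama-example"
    port: 8080
    env:
      CUDA_VISIBLE_DEVICES: "0"
    args:
      model: "/models/example.gguf"
      ctx-size: "8192"
      flash-attn: "true"
    vram_usage: "9G"
    ram_usage: "2G"
"#;
    // internal_port is omitted above; assign_internal_ports() fills it with a free port.
    serde_yaml::from_str::<AppConfig>(yaml)
        .expect("example config should parse")
        .assign_internal_ports()
}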

error.rs

@@ -0,0 +1,36 @@
use axum::{http, response::IntoResponse};
use hyper;
use reqwest;
use reqwest_middleware;
use std::io;
use thiserror::Error;
use anyhow::Error as AnyError;
#[derive(Error, Debug)]
pub enum AppError {
#[error("Axum Error")]
AxumError(#[from] axum::Error),
#[error("Axum Http Error")]
AxumHttpError(#[from] http::Error),
#[error("Reqwest Error")]
ReqwestError(#[from] reqwest::Error),
#[error("Reqwest Middleware Error")]
ReqwestMiddlewareError(#[from] reqwest_middleware::Error),
#[error("Client Error")]
ClientError(#[from] hyper::Error),
#[error("Io Error")]
IoError(#[from] io::Error),
#[error("Unknown error")]
Unknown(#[from] AnyError),
}
impl IntoResponse for AppError {
fn into_response(self) -> axum::response::Response {
tracing::error!("::AppError::into_response: {:?}", self);
(
http::StatusCode::INTERNAL_SERVER_ERROR,
format!("Something went wrong: {:?}", self),
)
.into_response()
}
}

inference_process.rs

@@ -0,0 +1,149 @@
use crate::{config::ModelSpec, error::AppError, state::AppState, util::parse_size};
use anyhow::anyhow;
use itertools::Itertools;
use std::{process::Stdio, sync::Arc};
use tokio::{
net::TcpStream,
process::{Child, Command},
sync::Mutex,
time::{sleep, Duration},
};
#[derive(Clone, Debug)]
pub struct InferenceProcess {
pub spec: ModelSpec,
pub process: Arc<Mutex<Child>>,
}
impl InferenceProcess {
/// Retrieve a running process from state or spawn a new one.
pub async fn get_or_spawn(
spec: &ModelSpec,
state: AppState,
) -> Result<InferenceProcess, AppError> {
let required_ram = parse_size(&spec.ram_usage).expect("Invalid ram_usage in model spec");
let required_vram = parse_size(&spec.vram_usage).expect("Invalid vram_usage in model spec");
{
let state_guard = state.lock().await;
if let Some(proc) = state_guard.processes.get(&spec.port) {
return Ok(proc.clone());
}
}
let mut state_guard = state.lock().await;
tracing::info!(msg = "App state before spawn", ?state_guard);
// If not enough resources, stop some running processes.
if (state_guard.used_ram + required_ram > state_guard.total_ram)
|| (state_guard.used_vram + required_vram > state_guard.total_vram)
{
let mut to_remove = Vec::new();
let processes_by_usage = state_guard
.processes
.iter()
.sorted_by(|(_, a), (_, b)| {
parse_size(&b.spec.vram_usage)
.unwrap_or(0)
.cmp(&parse_size(&a.spec.vram_usage).unwrap_or(0))
})
.map(|(port, proc)| (*port, proc.clone()))
.collect::<Vec<_>>();
for (port, proc) in processes_by_usage.iter() {
tracing::info!("Stopping process on port {}", port);
let mut lock = proc.process.lock().await;
lock.kill().await.ok();
to_remove.push(*port);
state_guard.used_ram = state_guard
.used_ram
.saturating_sub(parse_size(&proc.spec.ram_usage).unwrap_or(0));
state_guard.used_vram = state_guard
.used_vram
.saturating_sub(parse_size(&proc.spec.vram_usage).unwrap_or(0));
if state_guard.used_ram + required_ram <= state_guard.total_ram
&& state_guard.used_vram + required_vram <= state_guard.total_vram
{
tracing::info!("Freed enough resources");
break;
}
}
for port in to_remove {
state_guard.processes.remove(&port);
}
} else {
tracing::info!("Sufficient resources available");
}
let proc = Self::spawn(spec).await?;
state_guard.used_ram += required_ram;
state_guard.used_vram += required_vram;
state_guard.processes.insert(spec.port, proc.clone());
sleep(Duration::from_millis(250)).await;
proc.wait_until_ready().await?;
sleep(Duration::from_millis(250)).await;
Ok(proc)
}
/// Spawn a new inference process.
pub async fn spawn(spec: &ModelSpec) -> Result<InferenceProcess, AppError> {
let args = spec
.args
.iter()
.flat_map(|(k, v)| {
if v == "true" {
vec![format!("--{}", k)]
} else {
vec![format!("--{}", k), v.clone()]
}
})
.collect::<Vec<_>>();
let internal_port = spec.internal_port.expect("Internal port must be set");
let mut cmd = Command::new("llama-server");
cmd.kill_on_drop(true)
.envs(spec.env.clone())
.args(&args)
.arg("--port")
.arg(format!("{}", internal_port))
.stdout(Stdio::null())
.stderr(Stdio::null());
tracing::info!("Starting llama-server with command: {:?}", cmd);
let child = cmd.spawn().expect("Failed to start llama-server");
Ok(InferenceProcess {
spec: spec.clone(),
process: Arc::new(Mutex::new(child)),
})
}
pub async fn wait_until_ready(&self) -> Result<(), anyhow::Error> {
async fn check_running(proc: &InferenceProcess) -> Result<(), AppError> {
let mut lock = proc.process.clone().lock_owned().await;
match lock.try_wait()? {
Some(exit_status) => {
tracing::error!("Inference process exited: {:?}", exit_status);
Err(AppError::Unknown(anyhow!("Inference process exited")))
}
None => Ok(()),
}
}
async fn wait_for_port(port: u16) -> Result<(), anyhow::Error> {
for _ in 0..10 {
if TcpStream::connect(("127.0.0.1", port)).await.is_ok() {
return Ok(());
}
sleep(Duration::from_millis(500)).await;
}
Err(anyhow!("Timeout waiting for port"))
}
check_running(self).await?;
wait_for_port(self.spec.internal_port.expect("Internal port must be set")).await?;
Ok(())
}
}

lib.rs

@@ -0,0 +1,70 @@
pub mod config;
pub mod error;
pub mod inference_process;
pub mod logging;
pub mod proxy;
pub mod state;
pub mod util;
use axum::{routing::any, Router};
use config::{AppConfig, ModelSpec};
use state::AppState;
use std::net::SocketAddr;
use tower_http::trace::{
DefaultMakeSpan, DefaultOnEos, DefaultOnFailure, DefaultOnRequest, DefaultOnResponse,
TraceLayer,
};
use tracing::Level;
/// Creates an Axum application to handle inference requests for a specific model.
pub fn create_app(spec: &ModelSpec, state: AppState) -> Router {
Router::new()
.route(
"/",
any({
let state = state.clone();
let spec = spec.clone();
move |req| proxy::handle_request(req, spec.clone(), state.clone())
}),
)
.route(
"/*path",
any({
let state = state.clone();
let spec = spec.clone();
move |req| proxy::handle_request(req, spec.clone(), state.clone())
}),
)
.layer(
TraceLayer::new_for_http()
.make_span_with(DefaultMakeSpan::new().include_headers(true))
.on_request(DefaultOnRequest::new().level(Level::DEBUG))
.on_response(DefaultOnResponse::new().level(Level::TRACE))
.on_eos(DefaultOnEos::new().level(Level::DEBUG))
.on_failure(DefaultOnFailure::new().level(Level::ERROR)),
)
}
/// Starts an inference server for each model defined in the config.
pub async fn start_server(config: AppConfig) {
let state = AppState::from_config(config.clone());
let mut handles = Vec::new();
for spec in config.model_specs {
let state = state.clone();
let spec = spec.clone();
let handle = tokio::spawn(async move {
let app = create_app(&spec, state);
let addr = SocketAddr::from(([0, 0, 0, 0], spec.port));
tracing::info!(msg = "Listening", ?spec);
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
axum::serve(listener, app.into_make_service())
.await
.unwrap();
});
handles.push(handle);
}
futures::future::join_all(handles).await;
}
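Since `lib.rs` is meant to be used from other projects, a minimal consumer looks like the sketch below; it simply mirrors the new `main.rs` later in this commit (the crate name `llama_proxy_man` is taken from that file):

// Minimal embedding sketch, mirroring the new main.rs in this commit.
use llama_proxy_man::{config::AppConfig, logging, start_server};

#[tokio::main]
async fn main() {
    logging::initialize_logger();
    let config = AppConfig::default_from_pwd_yml();
    start_server(config).await;
}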

logging.rs

@@ -5,6 +5,8 @@ use std::{
task::{Context, Poll},
};
use std::sync::Once;
use axum::{body::Body, http::Request};
use pin_project_lite::pin_project;
use tower::{Layer, Service};
@@ -81,3 +83,18 @@ where
}
}
}
static INIT: Once = Once::new();
pub fn initialize_logger() {
INIT.call_once(|| {
let env_filter = tracing_subscriber::EnvFilter::builder()
.with_default_directive(tracing::level_filters::LevelFilter::INFO.into())
.from_env_lossy();
tracing_subscriber::fmt()
.compact()
.with_env_filter(env_filter)
.init();
});
}
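With `from_env_lossy`, the filter should still honor the standard `RUST_LOG` environment variable (for example `RUST_LOG=llama_proxy_man=debug`), falling back to the INFO default directive when the variable is unset and skipping directives it cannot parse.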

main.rs

@@ -1,465 +1,11 @@
mod logging;
use llama_proxy_man::{config::AppConfig, logging, start_server};
use tokio;
use anyhow::anyhow;
use axum::{
body::Body,
http::{self, Request, Response},
response::IntoResponse,
routing::any,
Router,
};
use futures;
use itertools::Itertools;
use reqwest::Client;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use serde::Deserialize;
use std::{collections::HashMap, net::SocketAddr, process::Stdio, sync::Arc};
use tokio::{
net::TcpStream,
process::{Child, Command},
sync::Mutex,
time::{sleep, Duration},
};
use tower_http;
use tracing::Level;
use std::sync::Once;
pub static INIT: Once = Once::new();
pub fn initialize_logger() {
INIT.call_once(|| {
let env_filter = tracing_subscriber::EnvFilter::builder()
.with_default_directive(tracing::level_filters::LevelFilter::INFO.into())
.from_env_lossy();
tracing_subscriber::fmt()
.compact()
.with_env_filter(env_filter)
.init();
});
}
// TODO Add valiation to the config
// - e.g. check double taken ports etc
#[derive(Clone, Debug, Deserialize)]
struct Config {
hardware: Hardware,
models: Vec<ModelConfig>,
}
impl Config {
fn default_from_pwd_yml() -> Self {
// Read and parse the YAML configuration
let config_str =
std::fs::read_to_string("config.yaml").expect("Failed to read config.yaml");
serde_yaml::from_str::<Self>(&config_str)
.expect("Failed to parse config.yaml")
.pick_open_ports()
}
// TODO split up into raw deser config and "parsed"/"processed" config which always has a port
// FIXME maybe just always initialize with a port ?
fn pick_open_ports(self) -> Self {
let mut config = self.clone();
for model in &mut config.models {
if model.internal_port.is_none() {
model.internal_port = Some(
openport::pick_random_unused_port()
.expect(format!("No open port found for {:?}", model).as_str()),
);
}
}
config
}
}
#[derive(Clone, Debug, Deserialize)]
struct Hardware {
ram: String,
vram: String,
}
#[derive(Clone, Debug, Deserialize)]
struct ModelConfig {
#[allow(dead_code)]
name: String,
port: u16,
internal_port: Option<u16>,
env: HashMap<String, String>,
args: HashMap<String, String>,
vram_usage: String,
ram_usage: String,
}
#[derive(Clone, Debug)]
struct LlamaInstance {
config: ModelConfig,
process: Arc<Mutex<Child>>,
// busy: bool,
}
impl LlamaInstance {
/// Retrieve a running instance from state or spawn a new one.
pub async fn get_or_spawn_instance(
model_config: &ModelConfig,
state: SharedState,
) -> Result<LlamaInstance, AppError> {
let model_ram_usage =
parse_size(&model_config.ram_usage).expect("Invalid ram_usage in model config");
let model_vram_usage =
parse_size(&model_config.vram_usage).expect("Invalid vram_usage in model config");
{
// First, check if we already have an instance.
let state_guard = state.lock().await;
if let Some(instance) = state_guard.instances.get(&model_config.port) {
return Ok(instance.clone());
}
}
let mut state_guard = state.lock().await;
tracing::info!(msg = "Shared state before spawn", ?state_guard);
// If not enough free resources, stop some running instances.
if (state_guard.used_ram + model_ram_usage > state_guard.total_ram)
|| (state_guard.used_vram + model_vram_usage > state_guard.total_vram)
{
let mut to_remove = Vec::new();
let instances_by_size = state_guard
.instances
.iter()
.sorted_by(|(_, a), (_, b)| {
parse_size(&b.config.vram_usage)
.unwrap_or(0)
.cmp(&parse_size(&a.config.vram_usage).unwrap_or(0))
})
.map(|(port, instance)| (*port, instance.clone()))
.collect::<Vec<_>>();
for (port, instance) in instances_by_size.iter() {
tracing::info!("Stopping instance on port {}", port);
let mut proc = instance.process.lock().await;
proc.kill().await.ok();
to_remove.push(*port);
state_guard.used_ram = state_guard
.used_ram
.saturating_sub(parse_size(&instance.config.ram_usage).unwrap_or(0));
state_guard.used_vram = state_guard
.used_vram
.saturating_sub(parse_size(&instance.config.vram_usage).unwrap_or(0));
if state_guard.used_ram + model_ram_usage <= state_guard.total_ram
&& state_guard.used_vram + model_vram_usage <= state_guard.total_vram
{
tracing::info!("Freed enough resources");
break;
}
}
for port in to_remove {
state_guard.instances.remove(&port);
}
} else {
tracing::info!("Sufficient resources available");
}
// Spawn a new instance.
let instance = Self::spawn(&model_config).await?;
state_guard.used_ram += model_ram_usage;
state_guard.used_vram += model_vram_usage;
state_guard
.instances
.insert(model_config.port, instance.clone());
sleep(Duration::from_millis(250)).await;
instance.running_and_port_ready().await?;
sleep(Duration::from_millis(250)).await;
Ok(instance)
}
/// Spawn a new llama-server process based on the model configuration.
pub async fn spawn(model_config: &ModelConfig) -> Result<LlamaInstance, AppError> {
let args = model_config
.args
.iter()
.flat_map(|(k, v)| {
if v == "true" {
vec![format!("--{}", k)]
} else {
vec![format!("--{}", k), v.clone()]
}
})
.collect::<Vec<_>>();
let internal_port = model_config
.internal_port
.expect("Internal port must be set");
let mut cmd = Command::new("llama-server");
cmd.kill_on_drop(true)
.envs(model_config.env.clone())
.args(&args)
.arg("--port")
.arg(format!("{}", internal_port))
.stdout(Stdio::null())
.stderr(Stdio::null());
// Silence output could be captured later via an API.
tracing::info!("Starting llama-server with command: {:?}", cmd);
let child = cmd.spawn().expect("Failed to start llama-server");
Ok(LlamaInstance {
config: model_config.clone(),
process: Arc::new(Mutex::new(child)),
})
}
pub async fn running_and_port_ready(&self) -> Result<(), anyhow::Error> {
async fn is_running(instance: &LlamaInstance) -> Result<(), AppError> {
let mut proc = instance.process.clone().lock_owned().await;
match proc.try_wait()? {
Some(exit_status) => {
tracing::error!("Llama instance exited: {:?}", exit_status);
Err(AppError::Unknown(anyhow!("Llama instance exited")))
}
None => Ok(()),
}
}
async fn wait_for_port(port: u16) -> Result<(), anyhow::Error> {
for _ in 0..10 {
if TcpStream::connect(("127.0.0.1", port)).await.is_ok() {
return Ok(());
}
sleep(Duration::from_millis(500)).await;
}
Err(anyhow!("Timeout waiting for port"))
}
is_running(self).await?;
wait_for_port(self.config.internal_port.expect("No port picked?")).await?;
Ok(())
}
}
#[derive(Clone, Debug)]
pub struct InnerSharedState {
total_ram: u64,
total_vram: u64,
used_ram: u64,
used_vram: u64,
instances: HashMap<u16, LlamaInstance>,
}
type SharedStateArc = Arc<Mutex<InnerSharedState>>;
/// TODO migrate to dashmap + individual Arc<Mutex<T>> or Rwlocks
#[derive(Clone, Debug, derive_more::Deref)]
pub struct SharedState(SharedStateArc);
impl SharedState {
fn from_config(config: Config) -> Self {
// Parse hardware resources
let total_ram = parse_size(&config.hardware.ram).expect("Invalid RAM size in config");
let total_vram = parse_size(&config.hardware.vram).expect("Invalid VRAM size in config");
// Initialize shared state
let shared_state = InnerSharedState {
total_ram,
total_vram,
used_ram: 0,
used_vram: 0,
instances: HashMap::new(),
};
Self(Arc::new(Mutex::new(shared_state)))
}
}
/// Build an Axum app for a given model.
fn create_app(model_config: &ModelConfig, state: SharedState) -> Router {
Router::new()
.route(
"/",
any({
let state = state.clone();
let model_config = model_config.clone();
move |req| handle_request(req, model_config.clone(), state.clone())
}),
)
.route(
"/*path",
any({
let state = state.clone();
let model_config = model_config.clone();
move |req| handle_request(req, model_config.clone(), state.clone())
}),
)
.layer(
tower_http::trace::TraceLayer::new_for_http()
.make_span_with(tower_http::trace::DefaultMakeSpan::new().include_headers(true))
.on_request(tower_http::trace::DefaultOnRequest::new().level(Level::DEBUG))
.on_response(tower_http::trace::DefaultOnResponse::new().level(Level::TRACE))
.on_eos(tower_http::trace::DefaultOnEos::new().level(Level::DEBUG))
.on_failure(tower_http::trace::DefaultOnFailure::new().level(Level::ERROR)),
)
}
#[tokio::main]
async fn main() {
initialize_logger();
let config = Config::default_from_pwd_yml();
let shared_state = SharedState::from_config(config.clone());
logging::initialize_logger();
let config = AppConfig::default_from_pwd_yml();
start_server(config).await;
// For each model, set up an axum server listening on the specified port
let mut handles = Vec::new();
for model_config in config.models {
let state = shared_state.clone();
let model_config = model_config.clone();
let handle = tokio::spawn(async move {
let model_config = model_config.clone();
let app = create_app(&model_config, state);
let addr = SocketAddr::from(([0, 0, 0, 0], model_config.port));
tracing::info!(msg = "Listening", ?model_config);
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
axum::serve(listener, app.into_make_service())
.await
.unwrap();
});
handles.push(handle);
}
futures::future::join_all(handles).await;
}
use thiserror::Error;
#[derive(Error, Debug)]
pub enum AppError {
#[error("Axum Error")]
AxumError(#[from] axum::Error),
#[error("Axum Http Error")]
AxumHttpError(#[from] axum::http::Error),
#[error("Reqwest Error")]
ReqwestError(#[from] reqwest::Error),
#[error("Reqwest Middleware Error")]
ReqwestMiddlewareError(#[from] reqwest_middleware::Error),
#[error("Client Error")]
ClientError(#[from] hyper::Error),
#[error("Io Error")]
IoError(#[from] std::io::Error),
#[error("Unknown error")]
Unknown(#[from] anyhow::Error),
}
// Tell axum how to convert `AppError` into a response.
impl IntoResponse for AppError {
fn into_response(self) -> axum::response::Response {
tracing::error!("::AppError::into_response:: {:?}", self);
(
http::StatusCode::INTERNAL_SERVER_ERROR,
format!("Something went wrong: {:?}", self),
)
.into_response()
}
}
async fn proxy_request(
req: Request<Body>,
model_config: &ModelConfig,
) -> Result<Response<Body>, AppError> {
let retry_policy = ExponentialBackoff::builder()
.retry_bounds(
std::time::Duration::from_millis(500),
std::time::Duration::from_secs(8),
)
.jitter(reqwest_retry::Jitter::None)
.base(2)
.build_with_max_retries(8);
let client: ClientWithMiddleware = ClientBuilder::new(Client::new())
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
.build();
let internal_port = model_config
.internal_port
.expect("Internal port must be set");
let uri = format!(
"http://127.0.0.1:{}{}",
internal_port,
req.uri()
.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or("")
);
let mut request_builder = client.request(req.method().clone(), &uri);
// Forward all headers.
for (name, value) in req.headers() {
request_builder = request_builder.header(name, value);
}
let max_size = parse_size("1G").unwrap() as usize;
let body_bytes = axum::body::to_bytes(req.into_body(), max_size).await?;
let request = request_builder.body(body_bytes).build()?;
let response = client.execute(request).await?;
let mut builder = Response::builder().status(response.status());
for (name, value) in response.headers() {
builder = builder.header(name, value);
}
let byte_stream = response.bytes_stream();
let body = Body::from_stream(byte_stream);
Ok(builder.body(body)?)
}
async fn handle_request(
req: Request<Body>,
model_config: ModelConfig,
state: SharedState,
) -> impl IntoResponse {
let _instance = LlamaInstance::get_or_spawn_instance(&model_config, state).await?;
let response = proxy_request(req, &model_config).await?;
Ok::<axum::http::Response<Body>, AppError>(response)
}
fn parse_size(size_str: &str) -> Option<u64> {
let mut num = String::new();
let mut unit = String::new();
for c in size_str.chars() {
if c.is_digit(10) || c == '.' {
num.push(c);
} else {
unit.push(c);
}
}
let num: f64 = num.parse().ok()?;
let multiplier = match unit.to_lowercase().as_str() {
"g" | "gb" => 1024 * 1024 * 1024,
"m" | "mb" => 1024 * 1024,
"k" | "kb" => 1024,
_ => panic!("Invalid Size"),
};
let res = (num * multiplier as f64) as u64;
Some(res)
}

proxy.rs

@@ -0,0 +1,69 @@
use crate::{
config::ModelSpec, error::AppError, inference_process::InferenceProcess, state::AppState,
util::parse_size,
};
use axum::{
body::Body,
http::{Request, Response},
response::IntoResponse,
};
use reqwest::Client;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
pub async fn proxy_request(
req: Request<Body>,
spec: &ModelSpec,
) -> Result<Response<Body>, AppError> {
let retry_policy = ExponentialBackoff::builder()
.retry_bounds(
std::time::Duration::from_millis(500),
std::time::Duration::from_secs(8),
)
.jitter(reqwest_retry::Jitter::None)
.base(2)
.build_with_max_retries(8);
let client: ClientWithMiddleware = ClientBuilder::new(Client::new())
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
.build();
let internal_port = spec.internal_port.expect("Internal port must be set");
let uri = format!(
"http://127.0.0.1:{}{}",
internal_port,
req.uri()
.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or("")
);
let mut request_builder = client.request(req.method().clone(), &uri);
for (name, value) in req.headers() {
request_builder = request_builder.header(name, value);
}
let max_size = parse_size("1G").unwrap() as usize;
let body_bytes = axum::body::to_bytes(req.into_body(), max_size).await?;
let request = request_builder.body(body_bytes).build()?;
let response = client.execute(request).await?;
let mut builder = Response::builder().status(response.status());
for (name, value) in response.headers() {
builder = builder.header(name, value);
}
let byte_stream = response.bytes_stream();
let body = Body::from_stream(byte_stream);
Ok(builder.body(body)?)
}
pub async fn handle_request(
req: Request<Body>,
spec: ModelSpec,
state: AppState,
) -> impl IntoResponse {
let _ = InferenceProcess::get_or_spawn(&spec, state).await?;
let response = proxy_request(req, &spec).await?;
Ok::<Response<Body>, AppError>(response)
}

state.rs

@@ -0,0 +1,36 @@
use crate::{config::AppConfig, inference_process::InferenceProcess, util::parse_size};
use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex;
#[derive(Clone, Debug)]
pub struct ResourceManager {
pub total_ram: u64,
pub total_vram: u64,
pub used_ram: u64,
pub used_vram: u64,
pub processes: HashMap<u16, InferenceProcess>,
}
pub type ResourceManagerHandle = Arc<Mutex<ResourceManager>>;
#[derive(Clone, Debug, derive_more::Deref)]
pub struct AppState(pub ResourceManagerHandle);
impl AppState {
pub fn from_config(config: AppConfig) -> Self {
let total_ram =
parse_size(&config.system_resources.ram).expect("Invalid RAM size in config");
let total_vram =
parse_size(&config.system_resources.vram).expect("Invalid VRAM size in config");
let resource_manager = ResourceManager {
total_ram,
total_vram,
used_ram: 0,
used_vram: 0,
processes: HashMap::new(),
};
Self(Arc::new(Mutex::new(resource_manager)))
}
}

util.rs

@@ -0,0 +1,22 @@
pub fn parse_size(size_str: &str) -> Option<u64> {
let mut num = String::new();
let mut unit = String::new();
for c in size_str.chars() {
if c.is_digit(10) || c == '.' {
num.push(c);
} else {
unit.push(c);
}
}
let num: f64 = num.parse().ok()?;
let multiplier = match unit.to_lowercase().as_str() {
"g" | "gb" => 1024 * 1024 * 1024,
"m" | "mb" => 1024 * 1024,
"k" | "kb" => 1024,
_ => panic!("Invalid Size"),
};
Some((num * multiplier as f64) as u64)
}
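As a quick usage reference (not part of the commit), the units are binary, so for example:

// Sketch: expected behaviour of parse_size as written above.
fn _parse_size_examples() {
    assert_eq!(parse_size("512m"), Some(512 * 1024 * 1024));
    assert_eq!(parse_size("8G"), Some(8 * 1024 * 1024 * 1024));
    assert_eq!(parse_size("1.5gb"), Some((1.5 * 1024.0 * 1024.0 * 1024.0) as u64));
    // Note: a missing or unrecognized unit currently panics ("Invalid Size")
    // instead of returning None.
}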