Refactor llama_proxy_man

Tristan D. 2025-02-10 23:22:31 +01:00
parent 8c062ece28
commit e3a3ec3826
Signed by: tristan
SSH key fingerprint: SHA256:3RU4RLOoM8oAjFU19f1W6t8uouZbA7GWkaSW6rjp1k8
3 changed files with 277 additions and 218 deletions

Cargo.lock (generated, 25 lines changed)

@@ -1141,6 +1141,26 @@ dependencies = [
"syn 2.0.98",
]
[[package]]
name = "derive_more"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678"
dependencies = [
"derive_more-impl",
]
[[package]]
name = "derive_more-impl"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.98",
]
[[package]]
name = "digest"
version = "0.10.7"
@@ -2950,7 +2970,7 @@ dependencies = [
"chrono",
"console_error_panic_hook",
"dashmap",
"derive_more",
"derive_more 0.99.19",
"futures",
"futures-util",
"gloo-net 0.5.0",
@@ -3002,6 +3022,7 @@ version = "0.1.1"
dependencies = [
"anyhow",
"axum",
"derive_more 2.0.1",
"futures",
"hyper",
"itertools 0.13.0",
@@ -4781,7 +4802,7 @@ checksum = "df320f1889ac4ba6bc0cdc9c9af7af4bd64bb927bccdf32d81140dc1f9be12fe"
dependencies = [
"bitflags 1.3.2",
"cssparser",
"derive_more",
"derive_more 0.99.19",
"fxhash",
"log",
"matches",

Cargo.toml

@@ -29,3 +29,4 @@ reqwest-retry = "0.6.1"
reqwest-middleware = { version = "0.3.3", features = ["charset", "http2", "json", "multipart", "rustls-tls"] }
itertools = "0.13.0"
openport = { version = "0.1.1", features = ["rand"] }
derive_more = { version = "2.0.1", features = ["deref"] }
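The `deref` feature is pulled in so the refactor can wrap the shared state in a newtype and still lock the inner `Arc<Mutex<...>>` directly. A minimal sketch of that pattern, assuming a placeholder `u64` payload instead of the real `InnerSharedState`:

use std::sync::Arc;
use tokio::sync::Mutex;

// Newtype over the shared handle, mirroring SharedState(SharedStateArc) later in this diff.
#[derive(Clone, Debug, derive_more::Deref)]
pub struct SharedState(Arc<Mutex<u64>>);

async fn bump(state: SharedState) {
    // Deref exposes the inner Arc<Mutex<_>>, so callers can call .lock() without going through .0.
    let mut guard = state.lock().await;
    *guard += 1;
}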

src/main.rs

@@ -11,6 +11,8 @@ use axum::{
use futures;
use itertools::Itertools;
use reqwest::Client;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use serde::Deserialize;
use std::{collections::HashMap, net::SocketAddr, process::Stdio, sync::Arc};
use tokio::{
@@ -19,10 +21,7 @@ use tokio::{
sync::Mutex,
time::{sleep, Duration},
};
use tower_http::trace::{
DefaultMakeSpan, DefaultOnEos, DefaultOnFailure, DefaultOnRequest, DefaultOnResponse,
TraceLayer,
};
use tower_http;
use tracing::Level;
use std::sync::Once;
@@ -42,6 +41,8 @@ pub fn initialize_logger() {
});
}
// TODO Add validation to the config
// - e.g. check for ports that are assigned twice, etc.
#[derive(Clone, Debug, Deserialize)]
struct Config {
hardware: Hardware,
@@ -49,7 +50,18 @@ struct Config {
}
impl Config {
fn default_from_pwd_yml() -> Self {
// Read and parse the YAML configuration
let config_str =
std::fs::read_to_string("config.yaml").expect("Failed to read config.yaml");
serde_yaml::from_str::<Self>(&config_str)
.expect("Failed to parse config.yaml")
.pick_open_ports()
}
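For orientation, the rest of the diff only reads a handful of fields off the parsed config, so `config.yaml` is expected to have roughly the following shape. The struct bodies and the sample YAML below are inferred from those usages, not copied from the repository:

use std::collections::HashMap;
use serde::Deserialize;

// Field names inferred from how this commit uses the config; the real definitions may differ.
#[derive(Clone, Debug, Deserialize)]
struct Config {
    hardware: Hardware,
    models: Vec<ModelConfig>,
}

#[derive(Clone, Debug, Deserialize)]
struct Hardware {
    ram: String,  // parsed with parse_size, e.g. "32G"
    vram: String, // e.g. "24G"
}

#[derive(Clone, Debug, Deserialize)]
struct ModelConfig {
    port: u16,                     // external port the proxy listens on
    internal_port: Option<u16>,    // filled in later by pick_open_ports
    ram_usage: String,
    vram_usage: String,
    env: HashMap<String, String>,
    args: HashMap<String, String>, // forwarded to llama-server as --key value
}

fn main() {
    let yaml = r#"
hardware:
  ram: "32G"
  vram: "24G"
models:
  - port: 8080
    ram_usage: "8G"
    vram_usage: "6G"
    env: {}
    args:
      model: "/models/example.gguf"
      ctx-size: "4096"
"#;
    let cfg: Config = serde_yaml::from_str(yaml).expect("parse example config");
    println!("{cfg:?}");
}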
// TODO split up into raw deser config and "parsed"/"processed" config which always has a port
// FIXME maybe just always initialize with a port ?
fn pick_open_ports(self) -> Self {
let mut config = self.clone();
for model in &mut config.models {
@@ -89,8 +101,154 @@ struct LlamaInstance {
// busy: bool,
}
impl LlamaInstance {
/// Retrieve a running instance from state or spawn a new one.
pub async fn get_or_spawn_instance(
model_config: &ModelConfig,
state: SharedState,
) -> Result<LlamaInstance, AppError> {
let model_ram_usage =
parse_size(&model_config.ram_usage).expect("Invalid ram_usage in model config");
let model_vram_usage =
parse_size(&model_config.vram_usage).expect("Invalid vram_usage in model config");
{
// First, check if we already have an instance.
let state_guard = state.lock().await;
if let Some(instance) = state_guard.instances.get(&model_config.port) {
return Ok(instance.clone());
}
}
let mut state_guard = state.lock().await;
tracing::info!(msg = "Shared state before spawn", ?state_guard);
// If not enough free resources, stop some running instances.
if (state_guard.used_ram + model_ram_usage > state_guard.total_ram)
|| (state_guard.used_vram + model_vram_usage > state_guard.total_vram)
{
let mut to_remove = Vec::new();
let instances_by_size = state_guard
.instances
.iter()
.sorted_by(|(_, a), (_, b)| {
parse_size(&b.config.vram_usage)
.unwrap_or(0)
.cmp(&parse_size(&a.config.vram_usage).unwrap_or(0))
})
.map(|(port, instance)| (*port, instance.clone()))
.collect::<Vec<_>>();
for (port, instance) in instances_by_size.iter() {
tracing::info!("Stopping instance on port {}", port);
let mut proc = instance.process.lock().await;
proc.kill().await.ok();
to_remove.push(*port);
state_guard.used_ram = state_guard
.used_ram
.saturating_sub(parse_size(&instance.config.ram_usage).unwrap_or(0));
state_guard.used_vram = state_guard
.used_vram
.saturating_sub(parse_size(&instance.config.vram_usage).unwrap_or(0));
if state_guard.used_ram + model_ram_usage <= state_guard.total_ram
&& state_guard.used_vram + model_vram_usage <= state_guard.total_vram
{
tracing::info!("Freed enough resources");
break;
}
}
for port in to_remove {
state_guard.instances.remove(&port);
}
} else {
tracing::info!("Sufficient resources available");
}
// Spawn a new instance.
let instance = Self::spawn(&model_config).await?;
state_guard.used_ram += model_ram_usage;
state_guard.used_vram += model_vram_usage;
state_guard
.instances
.insert(model_config.port, instance.clone());
sleep(Duration::from_millis(250)).await;
instance.running_and_port_ready().await?;
sleep(Duration::from_millis(250)).await;
Ok(instance)
}
/// Spawn a new llama-server process based on the model configuration.
pub async fn spawn(model_config: &ModelConfig) -> Result<LlamaInstance, AppError> {
let args = model_config
.args
.iter()
.flat_map(|(k, v)| {
if v == "true" {
vec![format!("--{}", k)]
} else {
vec![format!("--{}", k), v.clone()]
}
})
.collect::<Vec<_>>();
let internal_port = model_config
.internal_port
.expect("Internal port must be set");
let mut cmd = Command::new("llama-server");
cmd.kill_on_drop(true)
.envs(model_config.env.clone())
.args(&args)
.arg("--port")
.arg(format!("{}", internal_port))
.stdout(Stdio::null())
.stderr(Stdio::null());
// Output is silenced for now; it could be captured and exposed via an API later.
tracing::info!("Starting llama-server with command: {:?}", cmd);
let child = cmd.spawn().expect("Failed to start llama-server");
Ok(LlamaInstance {
config: model_config.clone(),
process: Arc::new(Mutex::new(child)),
})
}
pub async fn running_and_port_ready(&self) -> Result<(), anyhow::Error> {
async fn is_running(instance: &LlamaInstance) -> Result<(), AppError> {
let mut proc = instance.process.clone().lock_owned().await;
match proc.try_wait()? {
Some(exit_status) => {
tracing::error!("Llama instance exited: {:?}", exit_status);
Err(AppError::Unknown(anyhow!("Llama instance exited")))
}
None => Ok(()),
}
}
async fn wait_for_port(port: u16) -> Result<(), anyhow::Error> {
for _ in 0..10 {
if TcpStream::connect(("127.0.0.1", port)).await.is_ok() {
return Ok(());
}
sleep(Duration::from_millis(500)).await;
}
Err(anyhow!("Timeout waiting for port"))
}
is_running(self).await?;
wait_for_port(self.config.internal_port.expect("No port picked?")).await?;
Ok(())
}
}
#[derive(Clone, Debug)]
struct SharedState {
pub struct InnerSharedState {
total_ram: u64,
total_vram: u64,
used_ram: u64,
@@ -98,31 +256,67 @@ struct SharedState {
instances: HashMap<u16, LlamaInstance>,
}
type SharedStateArc = Arc<Mutex<InnerSharedState>>;
/// TODO migrate to dashmap + individual Arc<Mutex<T>> or Rwlocks
#[derive(Clone, Debug, derive_more::Deref)]
pub struct SharedState(SharedStateArc);
impl SharedState {
fn from_config(config: Config) -> Self {
// Parse hardware resources
let total_ram = parse_size(&config.hardware.ram).expect("Invalid RAM size in config");
let total_vram = parse_size(&config.hardware.vram).expect("Invalid VRAM size in config");
// Initialize shared state
let shared_state = InnerSharedState {
total_ram,
total_vram,
used_ram: 0,
used_vram: 0,
instances: HashMap::new(),
};
Self(Arc::new(Mutex::new(shared_state)))
}
}
/// Build an Axum app for a given model.
fn create_app(model_config: &ModelConfig, state: SharedState) -> Router {
Router::new()
.route(
"/",
any({
let state = state.clone();
let model_config = model_config.clone();
move |req| handle_request(req, model_config.clone(), state.clone())
}),
)
.route(
"/*path",
any({
let state = state.clone();
let model_config = model_config.clone();
move |req| handle_request(req, model_config.clone(), state.clone())
}),
)
.layer(
tower_http::trace::TraceLayer::new_for_http()
.make_span_with(tower_http::trace::DefaultMakeSpan::new().include_headers(true))
.on_request(tower_http::trace::DefaultOnRequest::new().level(Level::DEBUG))
.on_response(tower_http::trace::DefaultOnResponse::new().level(Level::TRACE))
.on_eos(tower_http::trace::DefaultOnEos::new().level(Level::DEBUG))
.on_failure(tower_http::trace::DefaultOnFailure::new().level(Level::ERROR)),
)
}
#[tokio::main]
async fn main() {
// TODO add autostart of models based on config
// abstract the starting logic out of the handler to allow separate calls to start
// maybe add to SharedState & LlamaInstance ?
initialize_logger();
// Read and parse the YAML configuration
let config_str = std::fs::read_to_string("config.yaml").expect("Failed to read config.yaml");
let config: Config = serde_yaml::from_str::<Config>(&config_str)
.expect("Failed to parse config.yaml")
.pick_open_ports();
// Parse hardware resources
let total_ram = parse_size(&config.hardware.ram).expect("Invalid RAM size in config");
let total_vram = parse_size(&config.hardware.vram).expect("Invalid VRAM size in config");
// Initialize shared state
let shared_state = Arc::new(Mutex::new(SharedState {
total_ram,
total_vram,
used_ram: 0,
used_vram: 0,
instances: HashMap::new(),
}));
let config = Config::default_from_pwd_yml();
let shared_state = SharedState::from_config(config.clone());
// For each model, set up an axum server listening on the specified port
let mut handles = Vec::new();
@@ -134,31 +328,7 @@ async fn main() {
let handle = tokio::spawn(async move {
let model_config = model_config.clone();
let app = Router::new()
.route(
"/",
any({
let state = state.clone();
let model_config = model_config.clone();
move |req| handle_request(req, model_config, state)
}),
)
.route(
"/*path",
any({
let state = state.clone();
let model_config = model_config.clone();
move |req| handle_request(req, model_config, state)
}),
)
.layer(
TraceLayer::new_for_http()
.make_span_with(DefaultMakeSpan::new().include_headers(true))
.on_request(DefaultOnRequest::new().level(Level::DEBUG))
.on_response(DefaultOnResponse::new().level(Level::TRACE))
.on_eos(DefaultOnEos::new().level(Level::DEBUG))
.on_failure(DefaultOnFailure::new().level(Level::ERROR)),
);
let app = create_app(&model_config, state);
let addr = SocketAddr::from(([0, 0, 0, 0], model_config.port));
@@ -190,6 +360,8 @@ pub enum AppError {
ReqwestMiddlewareError(#[from] reqwest_middleware::Error),
#[error("Client Error")]
ClientError(#[from] hyper::Error),
#[error("Io Error")]
IoError(#[from] std::io::Error),
#[error("Unknown error")] #[error("Unknown error")]
Unknown(#[from] anyhow::Error), Unknown(#[from] anyhow::Error),
} }
@@ -206,202 +378,67 @@ impl IntoResponse for AppError {
}
}
async fn handle_request(
req: Request<Body>,
model_config: ModelConfig,
state: Arc<Mutex<SharedState>>,
// ) -> Result<Response<Body>, anyhow::Error> {
) -> impl IntoResponse {
let model_ram_usage = parse_size(&model_config.ram_usage).expect("Invalid ram_usage");
let model_vram_usage = parse_size(&model_config.vram_usage).expect("Invalid vram_usage");
let instance = {
let mut state = state.lock().await;
// Check if instance is running
if let Some(instance) = state.instances.get_mut(&model_config.port) {
// instance.busy = true;
instance.to_owned()
} else {
// Check resources
tracing::info!(msg = "Current state", ?state);
if ((state.used_ram + model_ram_usage) > state.total_ram)
|| ((state.used_vram + model_vram_usage) > state.total_vram)
{
// Stop other instances
let mut to_remove = Vec::new();
// TODO Actual smart stopping logic
// - search for smallest single model to stop to get enough room
// - if not possible search for smallest number of models to stop with lowest
// amount of "overshot
let instances_by_size =
state
.instances
.clone()
.into_iter()
.sorted_by(|(_, el_a), (_, el_b)| {
Ord::cmp(
&parse_size(el_b.config.vram_usage.as_str()),
&parse_size(el_a.config.vram_usage.as_str()),
)
});
for (port, instance) in instances_by_size {
// if !instance.busy {
tracing::info!("Stopping instance on port {}", port);
let mut process = instance.process.lock().await;
process.kill().await.ok();
to_remove.push(port);
state.used_ram -= parse_size(&instance.config.ram_usage).unwrap_or(0);
state.used_vram -= parse_size(&instance.config.vram_usage).unwrap_or(0);
if state.used_ram + model_ram_usage <= state.total_ram
&& state.used_vram + model_vram_usage <= state.total_vram
{
tracing::info!("Should have enough ram now");
break;
}
// }
}
for port in to_remove {
tracing::info!("Removing instance on port {}", port);
state.instances.remove(&port);
}
} else {
tracing::info!("Already enough res free");
}
// Start new instance
let args = model_config
.args
.iter()
.flat_map(|(k, v)| {
if v == "true" {
vec![format!("--{}", k)]
} else {
vec![format!("--{}", k), v.clone()]
}
})
.collect::<Vec<_>>();
let mut cmd = Command::new("llama-server");
cmd.kill_on_drop(true);
cmd.envs(model_config.env.clone());
cmd.args(&args);
// TODO use openport crate via pick_random_unused_port for determining these
cmd.arg("--port");
cmd.arg(format!(
"{}",
model_config
.internal_port
.expect("Unexpected empty port, should've been picked")
));
cmd.stdout(Stdio::null()).stderr(Stdio::null()); // TODO save output and allow retrieval via api
tracing::info!("Starting llama-server with {:?}", cmd);
let process = Arc::new(Mutex::new(
cmd.spawn().expect("Failed to start llama-server"),
));
state.used_ram += model_ram_usage;
state.used_vram += model_vram_usage;
let instance = LlamaInstance {
config: model_config.clone(),
process,
// busy: true,
};
sleep(Duration::from_millis(500)).await;
state.instances.insert(model_config.port, instance.clone());
instance
}
};
// Wait for the instance to be ready
is_llama_instance_running(&instance).await?;
wait_for_port(
model_config
.internal_port
.expect("Unexpected empty port, should've been picked"),
)
.await?;
// Proxy the request
let retry_policy = reqwest_retry::policies::ExponentialBackoff::builder()
.retry_bounds(Duration::from_millis(500), Duration::from_secs(8))

async fn proxy_request(
req: Request<Body>,
model_config: &ModelConfig,
) -> Result<Response<Body>, AppError> {
let retry_policy = ExponentialBackoff::builder()
.retry_bounds(
std::time::Duration::from_millis(500),
std::time::Duration::from_secs(8),
)
.jitter(reqwest_retry::Jitter::None)
.base(2)
.build_with_max_retries(8);
let client = reqwest_middleware::ClientBuilder::new(Client::new())
.with(reqwest_retry::RetryTransientMiddleware::new_with_policy(
retry_policy,
))
let client: ClientWithMiddleware = ClientBuilder::new(Client::new())
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
.build();
let internal_port = model_config
.internal_port
.expect("Internal port must be set");
let uri = format!(
"http://127.0.0.1:{}{}",
model_config
.internal_port
.expect("Unexpected empty port, should've been picked"),
req.uri().path_and_query().map(|x| x.as_str()).unwrap_or("")
internal_port,
req.uri()
.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or("")
);
let mut request_builder = client.request(req.method().clone(), &uri);
// let mut request_builder = reqwest::RequestBuilder::from_parts(client, req.method().clone())
// Forward all headers.
for (name, value) in req.headers() {
request_builder = request_builder.header(name, value);
}
let body_bytes = axum::body::to_bytes(
req.into_body(),
parse_size("1G").unwrap().try_into().unwrap(),
)
.await?;
let max_size = parse_size("1G").unwrap() as usize;
let body_bytes = axum::body::to_bytes(req.into_body(), max_size).await?;
let request = request_builder.body(body_bytes).build()?;
// tracing::info!("Proxying request to {}", uri);
let response = client.execute(request).await?;
// tracing::info!("Received response from {}", uri);
let mut builder = Response::builder().status(response.status());
for (name, value) in response.headers() {
builder = builder.header(name, value);
}
// let bytes = response.bytes().await?;
let byte_stream = response.bytes_stream();
let body = axum::body::Body::from_stream(byte_stream);
let response = builder.body(body)?;
let body = Body::from_stream(byte_stream);
Ok(builder.body(body)?)
}

async fn handle_request(
req: Request<Body>,
model_config: ModelConfig,
state: SharedState,
) -> impl IntoResponse {
let _instance = LlamaInstance::get_or_spawn_instance(&model_config, state).await?;
let response = proxy_request(req, &model_config).await?;
tracing::debug!("streaming response on port: {}", model_config.port);
Ok::<axum::http::Response<Body>, AppError>(response)
}
async fn wait_for_port(port: u16) -> Result<(), anyhow::Error> {
for _ in 0..10 {
if TcpStream::connect(("127.0.0.1", port)).await.is_ok() {
return Ok(());
}
sleep(Duration::from_millis(500)).await;
}
Err(anyhow!("Timeout waiting for port"))
}
// reimplement wait_for_port using a LlamaInstance, also check if the process is still running
async fn is_llama_instance_running(instance: &LlamaInstance) -> Result<(), anyhow::Error> {
match instance.process.to_owned().lock_owned().await.try_wait()? {
Some(exit_status) => {
let msg = "Llama instance exited";
tracing::error!(msg, ?exit_status);
Err(anyhow!(msg))
}
None => Ok(()),
}
}
fn parse_size(size_str: &str) -> Option<u64> {
let mut num = String::new();
let mut unit = String::new();