Refactor llama_proxy_man

This commit is contained in:
Tristan D. 2025-02-10 23:22:31 +01:00
parent 8c062ece28
commit e3a3ec3826
Signed by: tristan
SSH key fingerprint: SHA256:3RU4RLOoM8oAjFU19f1W6t8uouZbA7GWkaSW6rjp1k8
3 changed files with 277 additions and 218 deletions

25
Cargo.lock generated
View file

@ -1141,6 +1141,26 @@ dependencies = [
"syn 2.0.98",
]
[[package]]
name = "derive_more"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678"
dependencies = [
"derive_more-impl",
]
[[package]]
name = "derive_more-impl"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.98",
]
[[package]]
name = "digest"
version = "0.10.7"
@ -2950,7 +2970,7 @@ dependencies = [
"chrono",
"console_error_panic_hook",
"dashmap",
"derive_more",
"derive_more 0.99.19",
"futures",
"futures-util",
"gloo-net 0.5.0",
@ -3002,6 +3022,7 @@ version = "0.1.1"
dependencies = [
"anyhow",
"axum",
"derive_more 2.0.1",
"futures",
"hyper",
"itertools 0.13.0",
@ -4781,7 +4802,7 @@ checksum = "df320f1889ac4ba6bc0cdc9c9af7af4bd64bb927bccdf32d81140dc1f9be12fe"
dependencies = [
"bitflags 1.3.2",
"cssparser",
"derive_more",
"derive_more 0.99.19",
"fxhash",
"log",
"matches",

View file

@ -29,3 +29,4 @@ reqwest-retry = "0.6.1"
reqwest-middleware = { version = "0.3.3", features = ["charset", "http2", "json", "multipart", "rustls-tls"] }
itertools = "0.13.0"
openport = { version = "0.1.1", features = ["rand"] }
derive_more = { version = "2.0.1", features = ["deref"] }

View file

@ -11,6 +11,8 @@ use axum::{
use futures;
use itertools::Itertools;
use reqwest::Client;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use serde::Deserialize;
use std::{collections::HashMap, net::SocketAddr, process::Stdio, sync::Arc};
use tokio::{
@ -19,10 +21,7 @@ use tokio::{
sync::Mutex,
time::{sleep, Duration},
};
use tower_http::trace::{
DefaultMakeSpan, DefaultOnEos, DefaultOnFailure, DefaultOnRequest, DefaultOnResponse,
TraceLayer,
};
use tower_http;
use tracing::Level;
use std::sync::Once;
@ -42,6 +41,8 @@ pub fn initialize_logger() {
});
}
// TODO Add valiation to the config
// - e.g. check double taken ports etc
#[derive(Clone, Debug, Deserialize)]
struct Config {
hardware: Hardware,
@ -49,7 +50,18 @@ struct Config {
}
impl Config {
fn default_from_pwd_yml() -> Self {
// Read and parse the YAML configuration
let config_str =
std::fs::read_to_string("config.yaml").expect("Failed to read config.yaml");
serde_yaml::from_str::<Self>(&config_str)
.expect("Failed to parse config.yaml")
.pick_open_ports()
}
// TODO split up into raw deser config and "parsed"/"processed" config which always has a port
// FIXME maybe just always initialize with a port ?
fn pick_open_ports(self) -> Self {
let mut config = self.clone();
for model in &mut config.models {
@ -89,8 +101,154 @@ struct LlamaInstance {
// busy: bool,
}
impl LlamaInstance {
/// Retrieve a running instance from state or spawn a new one.
pub async fn get_or_spawn_instance(
model_config: &ModelConfig,
state: SharedState,
) -> Result<LlamaInstance, AppError> {
let model_ram_usage =
parse_size(&model_config.ram_usage).expect("Invalid ram_usage in model config");
let model_vram_usage =
parse_size(&model_config.vram_usage).expect("Invalid vram_usage in model config");
{
// First, check if we already have an instance.
let state_guard = state.lock().await;
if let Some(instance) = state_guard.instances.get(&model_config.port) {
return Ok(instance.clone());
}
}
let mut state_guard = state.lock().await;
tracing::info!(msg = "Shared state before spawn", ?state_guard);
// If not enough free resources, stop some running instances.
if (state_guard.used_ram + model_ram_usage > state_guard.total_ram)
|| (state_guard.used_vram + model_vram_usage > state_guard.total_vram)
{
let mut to_remove = Vec::new();
let instances_by_size = state_guard
.instances
.iter()
.sorted_by(|(_, a), (_, b)| {
parse_size(&b.config.vram_usage)
.unwrap_or(0)
.cmp(&parse_size(&a.config.vram_usage).unwrap_or(0))
})
.map(|(port, instance)| (*port, instance.clone()))
.collect::<Vec<_>>();
for (port, instance) in instances_by_size.iter() {
tracing::info!("Stopping instance on port {}", port);
let mut proc = instance.process.lock().await;
proc.kill().await.ok();
to_remove.push(*port);
state_guard.used_ram = state_guard
.used_ram
.saturating_sub(parse_size(&instance.config.ram_usage).unwrap_or(0));
state_guard.used_vram = state_guard
.used_vram
.saturating_sub(parse_size(&instance.config.vram_usage).unwrap_or(0));
if state_guard.used_ram + model_ram_usage <= state_guard.total_ram
&& state_guard.used_vram + model_vram_usage <= state_guard.total_vram
{
tracing::info!("Freed enough resources");
break;
}
}
for port in to_remove {
state_guard.instances.remove(&port);
}
} else {
tracing::info!("Sufficient resources available");
}
// Spawn a new instance.
let instance = Self::spawn(&model_config).await?;
state_guard.used_ram += model_ram_usage;
state_guard.used_vram += model_vram_usage;
state_guard
.instances
.insert(model_config.port, instance.clone());
sleep(Duration::from_millis(250)).await;
instance.running_and_port_ready().await?;
sleep(Duration::from_millis(250)).await;
Ok(instance)
}
/// Spawn a new llama-server process based on the model configuration.
pub async fn spawn(model_config: &ModelConfig) -> Result<LlamaInstance, AppError> {
let args = model_config
.args
.iter()
.flat_map(|(k, v)| {
if v == "true" {
vec![format!("--{}", k)]
} else {
vec![format!("--{}", k), v.clone()]
}
})
.collect::<Vec<_>>();
let internal_port = model_config
.internal_port
.expect("Internal port must be set");
let mut cmd = Command::new("llama-server");
cmd.kill_on_drop(true)
.envs(model_config.env.clone())
.args(&args)
.arg("--port")
.arg(format!("{}", internal_port))
.stdout(Stdio::null())
.stderr(Stdio::null());
// Silence output could be captured later via an API.
tracing::info!("Starting llama-server with command: {:?}", cmd);
let child = cmd.spawn().expect("Failed to start llama-server");
Ok(LlamaInstance {
config: model_config.clone(),
process: Arc::new(Mutex::new(child)),
})
}
pub async fn running_and_port_ready(&self) -> Result<(), anyhow::Error> {
async fn is_running(instance: &LlamaInstance) -> Result<(), AppError> {
let mut proc = instance.process.clone().lock_owned().await;
match proc.try_wait()? {
Some(exit_status) => {
tracing::error!("Llama instance exited: {:?}", exit_status);
Err(AppError::Unknown(anyhow!("Llama instance exited")))
}
None => Ok(()),
}
}
async fn wait_for_port(port: u16) -> Result<(), anyhow::Error> {
for _ in 0..10 {
if TcpStream::connect(("127.0.0.1", port)).await.is_ok() {
return Ok(());
}
sleep(Duration::from_millis(500)).await;
}
Err(anyhow!("Timeout waiting for port"))
}
is_running(self).await?;
wait_for_port(self.config.internal_port.expect("No port picked?")).await?;
Ok(())
}
}
#[derive(Clone, Debug)]
struct SharedState {
pub struct InnerSharedState {
total_ram: u64,
total_vram: u64,
used_ram: u64,
@ -98,31 +256,67 @@ struct SharedState {
instances: HashMap<u16, LlamaInstance>,
}
#[tokio::main]
async fn main() {
// TODO add autostart of models based on config
// abstract starting logic out of handler for this to allow seperate calls to start
// maybe add to SharedState & LLamaInstance ?
type SharedStateArc = Arc<Mutex<InnerSharedState>>;
initialize_logger();
// Read and parse the YAML configuration
let config_str = std::fs::read_to_string("config.yaml").expect("Failed to read config.yaml");
let config: Config = serde_yaml::from_str::<Config>(&config_str)
.expect("Failed to parse config.yaml")
.pick_open_ports();
/// TODO migrate to dashmap + individual Arc<Mutex<T>> or Rwlocks
#[derive(Clone, Debug, derive_more::Deref)]
pub struct SharedState(SharedStateArc);
impl SharedState {
fn from_config(config: Config) -> Self {
// Parse hardware resources
let total_ram = parse_size(&config.hardware.ram).expect("Invalid RAM size in config");
let total_vram = parse_size(&config.hardware.vram).expect("Invalid VRAM size in config");
// Initialize shared state
let shared_state = Arc::new(Mutex::new(SharedState {
let shared_state = InnerSharedState {
total_ram,
total_vram,
used_ram: 0,
used_vram: 0,
instances: HashMap::new(),
}));
};
Self(Arc::new(Mutex::new(shared_state)))
}
}
/// Build an Axum app for a given model.
fn create_app(model_config: &ModelConfig, state: SharedState) -> Router {
Router::new()
.route(
"/",
any({
let state = state.clone();
let model_config = model_config.clone();
move |req| handle_request(req, model_config.clone(), state.clone())
}),
)
.route(
"/*path",
any({
let state = state.clone();
let model_config = model_config.clone();
move |req| handle_request(req, model_config.clone(), state.clone())
}),
)
.layer(
tower_http::trace::TraceLayer::new_for_http()
.make_span_with(tower_http::trace::DefaultMakeSpan::new().include_headers(true))
.on_request(tower_http::trace::DefaultOnRequest::new().level(Level::DEBUG))
.on_response(tower_http::trace::DefaultOnResponse::new().level(Level::TRACE))
.on_eos(tower_http::trace::DefaultOnEos::new().level(Level::DEBUG))
.on_failure(tower_http::trace::DefaultOnFailure::new().level(Level::ERROR)),
)
}
#[tokio::main]
async fn main() {
initialize_logger();
let config = Config::default_from_pwd_yml();
let shared_state = SharedState::from_config(config.clone());
// For each model, set up an axum server listening on the specified port
let mut handles = Vec::new();
@ -134,31 +328,7 @@ async fn main() {
let handle = tokio::spawn(async move {
let model_config = model_config.clone();
let app = Router::new()
.route(
"/",
any({
let state = state.clone();
let model_config = model_config.clone();
move |req| handle_request(req, model_config, state)
}),
)
.route(
"/*path",
any({
let state = state.clone();
let model_config = model_config.clone();
move |req| handle_request(req, model_config, state)
}),
)
.layer(
TraceLayer::new_for_http()
.make_span_with(DefaultMakeSpan::new().include_headers(true))
.on_request(DefaultOnRequest::new().level(Level::DEBUG))
.on_response(DefaultOnResponse::new().level(Level::TRACE))
.on_eos(DefaultOnEos::new().level(Level::DEBUG))
.on_failure(DefaultOnFailure::new().level(Level::ERROR)),
);
let app = create_app(&model_config, state);
let addr = SocketAddr::from(([0, 0, 0, 0], model_config.port));
@ -190,6 +360,8 @@ pub enum AppError {
ReqwestMiddlewareError(#[from] reqwest_middleware::Error),
#[error("Client Error")]
ClientError(#[from] hyper::Error),
#[error("Io Error")]
IoError(#[from] std::io::Error),
#[error("Unknown error")]
Unknown(#[from] anyhow::Error),
}
@ -206,202 +378,67 @@ impl IntoResponse for AppError {
}
}
async fn handle_request(
async fn proxy_request(
req: Request<Body>,
model_config: ModelConfig,
state: Arc<Mutex<SharedState>>,
// ) -> Result<Response<Body>, anyhow::Error> {
) -> impl IntoResponse {
let model_ram_usage = parse_size(&model_config.ram_usage).expect("Invalid ram_usage");
let model_vram_usage = parse_size(&model_config.vram_usage).expect("Invalid vram_usage");
let instance = {
let mut state = state.lock().await;
// Check if instance is running
if let Some(instance) = state.instances.get_mut(&model_config.port) {
// instance.busy = true;
instance.to_owned()
} else {
// Check resources
tracing::info!(msg = "Current state", ?state);
if ((state.used_ram + model_ram_usage) > state.total_ram)
|| ((state.used_vram + model_vram_usage) > state.total_vram)
{
// Stop other instances
let mut to_remove = Vec::new();
// TODO Actual smart stopping logic
// - search for smallest single model to stop to get enough room
// - if not possible search for smallest number of models to stop with lowest
// amount of "overshot
let instances_by_size =
state
.instances
.clone()
.into_iter()
.sorted_by(|(_, el_a), (_, el_b)| {
Ord::cmp(
&parse_size(el_b.config.vram_usage.as_str()),
&parse_size(el_a.config.vram_usage.as_str()),
model_config: &ModelConfig,
) -> Result<Response<Body>, AppError> {
let retry_policy = ExponentialBackoff::builder()
.retry_bounds(
std::time::Duration::from_millis(500),
std::time::Duration::from_secs(8),
)
});
for (port, instance) in instances_by_size {
// if !instance.busy {
tracing::info!("Stopping instance on port {}", port);
let mut process = instance.process.lock().await;
process.kill().await.ok();
to_remove.push(port);
state.used_ram -= parse_size(&instance.config.ram_usage).unwrap_or(0);
state.used_vram -= parse_size(&instance.config.vram_usage).unwrap_or(0);
if state.used_ram + model_ram_usage <= state.total_ram
&& state.used_vram + model_vram_usage <= state.total_vram
{
tracing::info!("Should have enough ram now");
break;
}
// }
}
for port in to_remove {
tracing::info!("Removing instance on port {}", port);
state.instances.remove(&port);
}
} else {
tracing::info!("Already enough res free");
}
// Start new instance
let args = model_config
.args
.iter()
.flat_map(|(k, v)| {
if v == "true" {
vec![format!("--{}", k)]
} else {
vec![format!("--{}", k), v.clone()]
}
})
.collect::<Vec<_>>();
let mut cmd = Command::new("llama-server");
cmd.kill_on_drop(true);
cmd.envs(model_config.env.clone());
cmd.args(&args);
// TODO use openport crate via pick_random_unused_port for determining these
cmd.arg("--port");
cmd.arg(format!(
"{}",
model_config
.internal_port
.expect("Unexpected empty port, should've been picked")
));
cmd.stdout(Stdio::null()).stderr(Stdio::null()); // TODO save output and allow retrieval via api
tracing::info!("Starting llama-server with {:?}", cmd);
let process = Arc::new(Mutex::new(
cmd.spawn().expect("Failed to start llama-server"),
));
state.used_ram += model_ram_usage;
state.used_vram += model_vram_usage;
let instance = LlamaInstance {
config: model_config.clone(),
process,
// busy: true,
};
sleep(Duration::from_millis(500)).await;
state.instances.insert(model_config.port, instance.clone());
instance
}
};
// Wait for the instance to be ready
is_llama_instance_running(&instance).await?;
wait_for_port(
model_config
.internal_port
.expect("Unexpected empty port, should've been picked"),
)
.await?;
// Proxy the request
let retry_policy = reqwest_retry::policies::ExponentialBackoff::builder()
.retry_bounds(Duration::from_millis(500), Duration::from_secs(8))
.jitter(reqwest_retry::Jitter::None)
.base(2)
.build_with_max_retries(8);
let client = reqwest_middleware::ClientBuilder::new(Client::new())
.with(reqwest_retry::RetryTransientMiddleware::new_with_policy(
retry_policy,
))
let client: ClientWithMiddleware = ClientBuilder::new(Client::new())
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
.build();
let internal_port = model_config
.internal_port
.expect("Internal port must be set");
let uri = format!(
"http://127.0.0.1:{}{}",
model_config
.internal_port
.expect("Unexpected empty port, should've been picked"),
req.uri().path_and_query().map(|x| x.as_str()).unwrap_or("")
internal_port,
req.uri()
.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or("")
);
let mut request_builder = client.request(req.method().clone(), &uri);
// let mut request_builder = reqwest::RequestBuilder::from_parts(client, req.method().clone())
// Forward all headers.
for (name, value) in req.headers() {
request_builder = request_builder.header(name, value);
}
let body_bytes = axum::body::to_bytes(
req.into_body(),
parse_size("1G").unwrap().try_into().unwrap(),
)
.await?;
let max_size = parse_size("1G").unwrap() as usize;
let body_bytes = axum::body::to_bytes(req.into_body(), max_size).await?;
let request = request_builder.body(body_bytes).build()?;
// tracing::info!("Proxying request to {}", uri);
let response = client.execute(request).await?;
// tracing::info!("Received response from {}", uri);
let mut builder = Response::builder().status(response.status());
for (name, value) in response.headers() {
builder = builder.header(name, value);
}
// let bytes = response.bytes().await?;
let byte_stream = response.bytes_stream();
let body = Body::from_stream(byte_stream);
Ok(builder.body(body)?)
}
let body = axum::body::Body::from_stream(byte_stream);
async fn handle_request(
req: Request<Body>,
model_config: ModelConfig,
state: SharedState,
) -> impl IntoResponse {
let _instance = LlamaInstance::get_or_spawn_instance(&model_config, state).await?;
let response = proxy_request(req, &model_config).await?;
tracing::debug!("streaming response on port: {}", model_config.port);
let response = builder.body(body)?;
Ok::<axum::http::Response<Body>, AppError>(response)
}
async fn wait_for_port(port: u16) -> Result<(), anyhow::Error> {
for _ in 0..10 {
if TcpStream::connect(("127.0.0.1", port)).await.is_ok() {
return Ok(());
}
sleep(Duration::from_millis(500)).await;
}
Err(anyhow!("Timeout waiting for port"))
}
// reimplement wait_for_port using a LlamaInstance, also check if the process is still running
async fn is_llama_instance_running(instance: &LlamaInstance) -> Result<(), anyhow::Error> {
match instance.process.to_owned().lock_owned().await.try_wait()? {
Some(exit_status) => {
let msg = "Llama instance exited";
tracing::error!(msg, ?exit_status);
Err(anyhow!(msg))
}
None => Ok(()),
}
}
fn parse_size(size_str: &str) -> Option<u64> {
let mut num = String::new();
let mut unit = String::new();