// redvault-ai/llama_proxy_man/src/main.rs

mod logging;

use anyhow::anyhow;
use axum::{
    body::Body,
    http::{self, Request, Response},
    response::IntoResponse,
    routing::any,
    Router,
};
use itertools::Itertools;
use reqwest::Client;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use serde::Deserialize;
use std::{
    collections::HashMap,
    net::SocketAddr,
    process::Stdio,
    sync::{Arc, Once},
};
use thiserror::Error;
use tokio::{
    net::TcpStream,
    process::{Child, Command},
    sync::Mutex,
    time::{sleep, Duration},
};
use tracing::Level;

pub static INIT: Once = Once::new();
pub fn initialize_logger() {
    INIT.call_once(|| {
        let env_filter = tracing_subscriber::EnvFilter::builder()
            .with_default_directive(tracing::level_filters::LevelFilter::INFO.into())
            .from_env_lossy();
        tracing_subscriber::fmt()
            .compact()
            .with_env_filter(env_filter)
            .init();
    });
}

// TODO: Add validation to the config,
// e.g. check for ports that are assigned more than once.
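/// Proxy configuration, deserialized from `config.yaml` in the working
/// directory. A minimal sketch of the expected shape; the model name, path,
/// and sizes below are illustrative placeholders, not values shipped with
/// the project:
///
/// ```yaml
/// hardware:
///   ram: 48G
///   vram: 24G
/// models:
///   - name: example-model
///     port: 18080
///     # internal_port is optional; a random free port is picked when omitted
///     env:
///       CUDA_VISIBLE_DEVICES: "0"
///     args:
///       model: /path/to/model.gguf
///       flash-attn: "true"
///     vram_usage: 20G
///     ram_usage: 2G
/// ```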
#[derive(Clone, Debug, Deserialize)]
struct Config {
    hardware: Hardware,
    models: Vec<ModelConfig>,
}

impl Config {
    fn default_from_pwd_yml() -> Self {
        // Read and parse the YAML configuration
        let config_str =
            std::fs::read_to_string("config.yaml").expect("Failed to read config.yaml");
        serde_yaml::from_str::<Self>(&config_str)
            .expect("Failed to parse config.yaml")
            .pick_open_ports()
    }

    // TODO: split up into a raw deserialized config and a "parsed"/"processed"
    // config which always has a port.
    // FIXME: maybe just always initialize with a port?
    fn pick_open_ports(self) -> Self {
        let mut config = self.clone();
        for model in &mut config.models {
            if model.internal_port.is_none() {
                model.internal_port = Some(
                    openport::pick_random_unused_port()
                        .unwrap_or_else(|| panic!("No open port found for {model:?}")),
                );
            }
        }
        config
    }
}

#[derive(Clone, Debug, Deserialize)]
struct Hardware {
    ram: String,
    vram: String,
}

#[derive(Clone, Debug, Deserialize)]
struct ModelConfig {
    #[allow(dead_code)]
    name: String,
    port: u16,
    internal_port: Option<u16>,
    env: HashMap<String, String>,
    args: HashMap<String, String>,
    vram_usage: String,
    ram_usage: String,
}

#[derive(Clone, Debug)]
struct LlamaInstance {
    config: ModelConfig,
    process: Arc<Mutex<Child>>,
    // busy: bool,
}

impl LlamaInstance {
    /// Retrieve a running instance from state or spawn a new one.
    pub async fn get_or_spawn_instance(
        model_config: &ModelConfig,
        state: SharedState,
    ) -> Result<LlamaInstance, AppError> {
        let model_ram_usage =
            parse_size(&model_config.ram_usage).expect("Invalid ram_usage in model config");
        let model_vram_usage =
            parse_size(&model_config.vram_usage).expect("Invalid vram_usage in model config");

        {
            // First, check if we already have an instance.
            let state_guard = state.lock().await;
            if let Some(instance) = state_guard.instances.get(&model_config.port) {
                return Ok(instance.clone());
            }
        }

        let mut state_guard = state.lock().await;
        // Re-check after re-acquiring the lock: another task may have spawned
        // the same instance between the check above and this point.
        if let Some(instance) = state_guard.instances.get(&model_config.port) {
            return Ok(instance.clone());
        }
        tracing::info!(msg = "Shared state before spawn", ?state_guard);

        // If not enough free resources, stop running instances (largest VRAM
        // consumers first) until the new model fits.
        if (state_guard.used_ram + model_ram_usage > state_guard.total_ram)
            || (state_guard.used_vram + model_vram_usage > state_guard.total_vram)
        {
            let mut to_remove = Vec::new();
            let instances_by_size = state_guard
                .instances
                .iter()
                .sorted_by(|(_, a), (_, b)| {
                    parse_size(&b.config.vram_usage)
                        .unwrap_or(0)
                        .cmp(&parse_size(&a.config.vram_usage).unwrap_or(0))
                })
                .map(|(port, instance)| (*port, instance.clone()))
                .collect::<Vec<_>>();
            for (port, instance) in instances_by_size.iter() {
                tracing::info!("Stopping instance on port {}", port);
                let mut proc = instance.process.lock().await;
                proc.kill().await.ok();
                to_remove.push(*port);
                state_guard.used_ram = state_guard
                    .used_ram
                    .saturating_sub(parse_size(&instance.config.ram_usage).unwrap_or(0));
                state_guard.used_vram = state_guard
                    .used_vram
                    .saturating_sub(parse_size(&instance.config.vram_usage).unwrap_or(0));
                if state_guard.used_ram + model_ram_usage <= state_guard.total_ram
                    && state_guard.used_vram + model_vram_usage <= state_guard.total_vram
                {
                    tracing::info!("Freed enough resources");
                    break;
                }
            }
            for port in to_remove {
                state_guard.instances.remove(&port);
            }
        } else {
            tracing::info!("Sufficient resources available");
        }

        // Spawn a new instance.
        let instance = Self::spawn(model_config).await?;
        state_guard.used_ram += model_ram_usage;
        state_guard.used_vram += model_vram_usage;
        state_guard
            .instances
            .insert(model_config.port, instance.clone());

        sleep(Duration::from_millis(250)).await;
        instance.running_and_port_ready().await?;
        sleep(Duration::from_millis(250)).await;
        Ok(instance)
    }

    /// Spawn a new llama-server process based on the model configuration.
    pub async fn spawn(model_config: &ModelConfig) -> Result<LlamaInstance, AppError> {
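        // Flatten the args map into CLI flags: a value of "true" becomes a
        // bare flag, anything else a flag-value pair. For example (with
        // hypothetical values), {"ctx-size": "4096", "flash-attn": "true"}
        // turns into ["--ctx-size", "4096", "--flash-attn"].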
        let args = model_config
            .args
            .iter()
            .flat_map(|(k, v)| {
                if v == "true" {
                    vec![format!("--{}", k)]
                } else {
                    vec![format!("--{}", k), v.clone()]
                }
            })
            .collect::<Vec<_>>();
        let internal_port = model_config
            .internal_port
            .expect("Internal port must be set");

        let mut cmd = Command::new("llama-server");
        cmd.kill_on_drop(true)
            .envs(model_config.env.clone())
            .args(&args)
            .arg("--port")
            .arg(format!("{}", internal_port))
            // Output is silenced for now; it could be captured and exposed
            // via an API later.
            .stdout(Stdio::null())
            .stderr(Stdio::null());

        tracing::info!("Starting llama-server with command: {:?}", cmd);
        let child = cmd.spawn()?; // IO errors map to AppError::IoError

        Ok(LlamaInstance {
            config: model_config.clone(),
            process: Arc::new(Mutex::new(child)),
        })
    }

    pub async fn running_and_port_ready(&self) -> Result<(), anyhow::Error> {
        async fn is_running(instance: &LlamaInstance) -> Result<(), AppError> {
            let mut proc = instance.process.clone().lock_owned().await;
            match proc.try_wait()? {
                Some(exit_status) => {
                    tracing::error!("Llama instance exited: {:?}", exit_status);
                    Err(AppError::Unknown(anyhow!("Llama instance exited")))
                }
                None => Ok(()),
            }
        }

        // Poll for up to ~5s (10 tries, 500ms apart) until the server
        // accepts TCP connections.
        async fn wait_for_port(port: u16) -> Result<(), anyhow::Error> {
            for _ in 0..10 {
                if TcpStream::connect(("127.0.0.1", port)).await.is_ok() {
                    return Ok(());
                }
                sleep(Duration::from_millis(500)).await;
            }
            Err(anyhow!("Timeout waiting for port"))
        }

        is_running(self).await?;
        wait_for_port(self.config.internal_port.expect("No port picked?")).await?;
        Ok(())
    }
}
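
/// Global resource accounting shared by every per-model server. All size
/// fields are in bytes, as produced by `parse_size`.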
#[derive(Clone, Debug)]
pub struct InnerSharedState {
    total_ram: u64,
    total_vram: u64,
    used_ram: u64,
    used_vram: u64,
    instances: HashMap<u16, LlamaInstance>,
}

type SharedStateArc = Arc<Mutex<InnerSharedState>>;

/// TODO: migrate to dashmap + individual Arc<Mutex<T>> or RwLocks
#[derive(Clone, Debug, derive_more::Deref)]
pub struct SharedState(SharedStateArc);

impl SharedState {
    fn from_config(config: Config) -> Self {
        // Parse hardware resources
        let total_ram = parse_size(&config.hardware.ram).expect("Invalid RAM size in config");
        let total_vram = parse_size(&config.hardware.vram).expect("Invalid VRAM size in config");

        // Initialize shared state
        let shared_state = InnerSharedState {
            total_ram,
            total_vram,
            used_ram: 0,
            used_vram: 0,
            instances: HashMap::new(),
        };
        Self(Arc::new(Mutex::new(shared_state)))
    }
}

/// Build an Axum app for a given model.
fn create_app(model_config: &ModelConfig, state: SharedState) -> Router {
    Router::new()
        .route(
            "/",
            any({
                let state = state.clone();
                let model_config = model_config.clone();
                move |req| handle_request(req, model_config.clone(), state.clone())
            }),
        )
        .route(
            "/*path",
            any({
                let state = state.clone();
                let model_config = model_config.clone();
                move |req| handle_request(req, model_config.clone(), state.clone())
            }),
        )
        .layer(
            tower_http::trace::TraceLayer::new_for_http()
                .make_span_with(tower_http::trace::DefaultMakeSpan::new().include_headers(true))
                .on_request(tower_http::trace::DefaultOnRequest::new().level(Level::DEBUG))
                .on_response(tower_http::trace::DefaultOnResponse::new().level(Level::TRACE))
                .on_eos(tower_http::trace::DefaultOnEos::new().level(Level::DEBUG))
                .on_failure(tower_http::trace::DefaultOnFailure::new().level(Level::ERROR)),
        )
}

#[tokio::main]
async fn main() {
    initialize_logger();
    let config = Config::default_from_pwd_yml();
    let shared_state = SharedState::from_config(config.clone());

    // For each model, set up an axum server listening on the specified port.
    let mut handles = Vec::new();
    for model_config in config.models {
        let state = shared_state.clone();
        let handle = tokio::spawn(async move {
            let app = create_app(&model_config, state);
            let addr = SocketAddr::from(([0, 0, 0, 0], model_config.port));
            tracing::info!(msg = "Listening", ?model_config);
            let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
            axum::serve(listener, app.into_make_service())
                .await
                .unwrap();
        });
        handles.push(handle);
    }
    futures::future::join_all(handles).await;
}

#[derive(Error, Debug)]
pub enum AppError {
    #[error("Axum Error")]
    AxumError(#[from] axum::Error),
    #[error("Axum Http Error")]
    AxumHttpError(#[from] axum::http::Error),
    #[error("Reqwest Error")]
    ReqwestError(#[from] reqwest::Error),
    #[error("Reqwest Middleware Error")]
    ReqwestMiddlewareError(#[from] reqwest_middleware::Error),
    #[error("Client Error")]
    ClientError(#[from] hyper::Error),
    #[error("Io Error")]
    IoError(#[from] std::io::Error),
    #[error("Unknown error")]
    Unknown(#[from] anyhow::Error),
}

// Tell axum how to convert `AppError` into a response.
impl IntoResponse for AppError {
    fn into_response(self) -> axum::response::Response {
        tracing::error!("::AppError::into_response:: {:?}", self);
        (
            http::StatusCode::INTERNAL_SERVER_ERROR,
            format!("Something went wrong: {:?}", self),
        )
            .into_response()
    }
}
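
/// Forward a request to the llama-server instance on the model's internal
/// port, streaming the response body back to the caller.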
async fn proxy_request(
    req: Request<Body>,
    model_config: &ModelConfig,
) -> Result<Response<Body>, AppError> {
    // Retry transient failures with exponential backoff (500ms..8s, up to
    // 8 retries), which covers the window where the backend is still
    // starting up.
    let retry_policy = ExponentialBackoff::builder()
        .retry_bounds(
            std::time::Duration::from_millis(500),
            std::time::Duration::from_secs(8),
        )
        .jitter(reqwest_retry::Jitter::None)
        .base(2)
        .build_with_max_retries(8);
    // Note: a fresh client is built per request; it could be shared instead.
    let client: ClientWithMiddleware = ClientBuilder::new(Client::new())
        .with(RetryTransientMiddleware::new_with_policy(retry_policy))
        .build();

    let internal_port = model_config
        .internal_port
        .expect("Internal port must be set");
    let uri = format!(
        "http://127.0.0.1:{}{}",
        internal_port,
        req.uri()
            .path_and_query()
            .map(|pq| pq.as_str())
            .unwrap_or("")
    );

    let mut request_builder = client.request(req.method().clone(), &uri);
    // Forward all headers.
    for (name, value) in req.headers() {
        request_builder = request_builder.header(name, value);
    }

    // Buffer the request body, capped at 1 GiB.
    let max_size = parse_size("1G").unwrap() as usize;
    let body_bytes = axum::body::to_bytes(req.into_body(), max_size).await?;
    let request = request_builder.body(body_bytes).build()?;
    let response = client.execute(request).await?;

    let mut builder = Response::builder().status(response.status());
    for (name, value) in response.headers() {
        builder = builder.header(name, value);
    }

    // Stream the backend's response body instead of buffering it.
    let byte_stream = response.bytes_stream();
    let body = Body::from_stream(byte_stream);
    Ok(builder.body(body)?)
}
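
/// Handler for all proxied routes: make sure the model's llama-server is
/// up, then forward the request.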
async fn handle_request(
    req: Request<Body>,
    model_config: ModelConfig,
    state: SharedState,
) -> impl IntoResponse {
    let _instance = LlamaInstance::get_or_spawn_instance(&model_config, state).await?;
    let response = proxy_request(req, &model_config).await?;
    Ok::<axum::http::Response<Body>, AppError>(response)
}
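
/// Parse a human-readable size such as "512M" or "24G" into bytes, using
/// binary multipliers (1K = 1024). Returns `None` for malformed input or an
/// unrecognized unit.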
fn parse_size(size_str: &str) -> Option<u64> {
    let mut num = String::new();
    let mut unit = String::new();
    for c in size_str.chars() {
        if c.is_ascii_digit() || c == '.' {
            num.push(c);
        } else {
            unit.push(c);
        }
    }
    let num: f64 = num.parse().ok()?;
    let multiplier: u64 = match unit.to_lowercase().as_str() {
        "g" | "gb" => 1024 * 1024 * 1024,
        "m" | "mb" => 1024 * 1024,
        "k" | "kb" => 1024,
        // An unknown unit is a parse failure, not a panic.
        _ => return None,
    };
    Some((num * multiplier as f64) as u64)
}
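
// A minimal sanity-check sketch for `parse_size`, assuming the binary
// (1024-based) multipliers implemented above.
#[cfg(test)]
mod tests {
    use super::parse_size;

    #[test]
    fn parses_binary_sizes() {
        assert_eq!(parse_size("1k"), Some(1024));
        assert_eq!(parse_size("2MB"), Some(2 * 1024 * 1024));
        assert_eq!(parse_size("1.5G"), Some(1_610_612_736));
        // Unknown units are parse failures rather than panics.
        assert_eq!(parse_size("12X"), None);
    }
}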