mod logging;

use anyhow::anyhow;
use axum::{
    body::Body,
    http::{self, Request, Response},
    response::IntoResponse,
    routing::any,
    Router,
};
use futures;
use itertools::Itertools;
use reqwest::Client;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use serde::Deserialize;
use std::{collections::HashMap, net::SocketAddr, process::Stdio, sync::Arc};
use tokio::{
    net::TcpStream,
    process::{Child, Command},
    sync::Mutex,
    time::{sleep, Duration},
};
use tower_http;
use tracing::Level;

use std::sync::Once;

pub static INIT: Once = Once::new();

pub fn initialize_logger() {
    INIT.call_once(|| {
        let env_filter = tracing_subscriber::EnvFilter::builder()
            .with_default_directive(tracing::level_filters::LevelFilter::INFO.into())
            .from_env_lossy();

        tracing_subscriber::fmt()
            .compact()
            .with_env_filter(env_filter)
            .init();
    });
}

// TODO: Add validation to the config
// - e.g. check for ports that are assigned more than once
#[derive(Clone, Debug, Deserialize)]
struct Config {
    hardware: Hardware,
    models: Vec<ModelConfig>,
}

impl Config {
    fn default_from_pwd_yml() -> Self {
        // Read and parse the YAML configuration
        let config_str =
            std::fs::read_to_string("config.yaml").expect("Failed to read config.yaml");

        serde_yaml::from_str::<Self>(&config_str)
            .expect("Failed to parse config.yaml")
            .pick_open_ports()
    }

    // TODO split up into a raw deserialized config and a "parsed"/"processed" config which always has a port
    // FIXME maybe just always initialize with a port?
    fn pick_open_ports(self) -> Self {
        let mut config = self.clone();
        for model in &mut config.models {
            if model.internal_port.is_none() {
                model.internal_port = Some(
                    openport::pick_random_unused_port()
                        .expect(format!("No open port found for {:?}", model).as_str()),
                );
            }
        }
        config
    }
}

#[derive(Clone, Debug, Deserialize)]
struct Hardware {
    /// Total RAM available to the scheduler, e.g. "32G"
    ram: String,
    /// Total VRAM available to the scheduler, e.g. "24G"
    vram: String,
}

#[derive(Clone, Debug, Deserialize)]
struct ModelConfig {
    #[allow(dead_code)]
    name: String,
    /// External port the proxy listens on for this model.
    port: u16,
    /// Port the llama-server child process listens on; picked automatically if absent.
    internal_port: Option<u16>,
    /// Environment variables passed to the llama-server process.
    env: HashMap<String, String>,
    /// CLI flags passed to llama-server ("true" values become bare flags).
    args: HashMap<String, String>,
    vram_usage: String,
    ram_usage: String,
}
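
// For reference, a hypothetical `config.yaml` that maps onto the structs above
// (values are illustrative, not defaults; `internal_port` may be omitted and is
// then picked automatically):
//
// hardware:
//   ram: "32G"
//   vram: "24G"
// models:
//   - name: "llama-8b"
//     port: 8080
//     env:
//       CUDA_VISIBLE_DEVICES: "0"
//     args:
//       model: "/models/llama-8b.gguf"
//       ctx-size: "8192"
//       flash-attn: "true"
//     vram_usage: "10G"
//     ram_usage: "2G"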

#[derive(Clone, Debug)]
struct LlamaInstance {
    config: ModelConfig,
    process: Arc<Mutex<Child>>,
    // busy: bool,
}

impl LlamaInstance {
    /// Retrieve a running instance from state or spawn a new one.
    pub async fn get_or_spawn_instance(
        model_config: &ModelConfig,
        state: SharedState,
    ) -> Result<LlamaInstance, AppError> {
        let model_ram_usage =
            parse_size(&model_config.ram_usage).expect("Invalid ram_usage in model config");
        let model_vram_usage =
            parse_size(&model_config.vram_usage).expect("Invalid vram_usage in model config");

        {
            // First, check if we already have an instance.
            let state_guard = state.lock().await;
            if let Some(instance) = state_guard.instances.get(&model_config.port) {
                return Ok(instance.clone());
            }
        }

        let mut state_guard = state.lock().await;
        tracing::info!(msg = "Shared state before spawn", ?state_guard);

        // If not enough free resources, stop running instances, largest VRAM
        // consumers first, until the new model fits.
        if (state_guard.used_ram + model_ram_usage > state_guard.total_ram)
            || (state_guard.used_vram + model_vram_usage > state_guard.total_vram)
        {
            let mut to_remove = Vec::new();
            let instances_by_size = state_guard
                .instances
                .iter()
                .sorted_by(|(_, a), (_, b)| {
                    parse_size(&b.config.vram_usage)
                        .unwrap_or(0)
                        .cmp(&parse_size(&a.config.vram_usage).unwrap_or(0))
                })
                .map(|(port, instance)| (*port, instance.clone()))
                .collect::<Vec<_>>();

            for (port, instance) in instances_by_size.iter() {
                tracing::info!("Stopping instance on port {}", port);
                let mut proc = instance.process.lock().await;
                proc.kill().await.ok();
                to_remove.push(*port);
                state_guard.used_ram = state_guard
                    .used_ram
                    .saturating_sub(parse_size(&instance.config.ram_usage).unwrap_or(0));
                state_guard.used_vram = state_guard
                    .used_vram
                    .saturating_sub(parse_size(&instance.config.vram_usage).unwrap_or(0));
                if state_guard.used_ram + model_ram_usage <= state_guard.total_ram
                    && state_guard.used_vram + model_vram_usage <= state_guard.total_vram
                {
                    tracing::info!("Freed enough resources");
                    break;
                }
            }
            for port in to_remove {
                state_guard.instances.remove(&port);
            }
        } else {
            tracing::info!("Sufficient resources available");
        }

        // Spawn a new instance and record its resource usage.
        let instance = Self::spawn(model_config).await?;
        state_guard.used_ram += model_ram_usage;
        state_guard.used_vram += model_vram_usage;
        state_guard
            .instances
            .insert(model_config.port, instance.clone());

        // Give the process a moment to start before probing its port.
        sleep(Duration::from_millis(250)).await;

        instance.running_and_port_ready().await?;

        sleep(Duration::from_millis(250)).await;

        Ok(instance)
    }

    /// Spawn a new llama-server process based on the model configuration.
    pub async fn spawn(model_config: &ModelConfig) -> Result<LlamaInstance, AppError> {
        // Flatten the args map into CLI flags: a "true" value becomes a bare
        // flag (e.g. `flash-attn: "true"` -> `--flash-attn`), anything else a
        // flag with a value (e.g. `ctx-size: "8192"` -> `--ctx-size 8192`).
        let args = model_config
            .args
            .iter()
            .flat_map(|(k, v)| {
                if v == "true" {
                    vec![format!("--{}", k)]
                } else {
                    vec![format!("--{}", k), v.clone()]
                }
            })
            .collect::<Vec<_>>();

        let internal_port = model_config
            .internal_port
            .expect("Internal port must be set");

        // Silence output for now – it could be captured later via an API.
        let mut cmd = Command::new("llama-server");
        cmd.kill_on_drop(true)
            .envs(model_config.env.clone())
            .args(&args)
            .arg("--port")
            .arg(internal_port.to_string())
            .stdout(Stdio::null())
            .stderr(Stdio::null());

        tracing::info!("Starting llama-server with command: {:?}", cmd);
        let child = cmd.spawn().expect("Failed to start llama-server");

        Ok(LlamaInstance {
            config: model_config.clone(),
            process: Arc::new(Mutex::new(child)),
        })
    }

    /// Verify the child process is still alive and its port accepts TCP connections.
    pub async fn running_and_port_ready(&self) -> Result<(), anyhow::Error> {
        async fn is_running(instance: &LlamaInstance) -> Result<(), AppError> {
            let mut proc = instance.process.clone().lock_owned().await;
            match proc.try_wait()? {
                Some(exit_status) => {
                    tracing::error!("Llama instance exited: {:?}", exit_status);
                    Err(AppError::Unknown(anyhow!("Llama instance exited")))
                }
                None => Ok(()),
            }
        }

        // Poll the port for up to ~5 seconds (10 attempts, 500 ms apart).
        async fn wait_for_port(port: u16) -> Result<(), anyhow::Error> {
            for _ in 0..10 {
                if TcpStream::connect(("127.0.0.1", port)).await.is_ok() {
                    return Ok(());
                }
                sleep(Duration::from_millis(500)).await;
            }
            Err(anyhow!("Timeout waiting for port"))
        }

        is_running(self).await?;
        wait_for_port(self.config.internal_port.expect("No port picked?")).await?;
        Ok(())
    }
}

#[derive(Clone, Debug)]
pub struct InnerSharedState {
    total_ram: u64,
    total_vram: u64,
    used_ram: u64,
    used_vram: u64,
    instances: HashMap<u16, LlamaInstance>,
}

type SharedStateArc = Arc<Mutex<InnerSharedState>>;

/// TODO migrate to dashmap + individual Arc<Mutex<T>> or RwLocks
#[derive(Clone, Debug, derive_more::Deref)]
pub struct SharedState(SharedStateArc);

impl SharedState {
    fn from_config(config: Config) -> Self {
        // Parse hardware resources
        let total_ram = parse_size(&config.hardware.ram).expect("Invalid RAM size in config");
        let total_vram = parse_size(&config.hardware.vram).expect("Invalid VRAM size in config");

        // Initialize shared state
        let shared_state = InnerSharedState {
            total_ram,
            total_vram,
            used_ram: 0,
            used_vram: 0,
            instances: HashMap::new(),
        };

        Self(Arc::new(Mutex::new(shared_state)))
    }
}

/// Build an Axum app for a given model.
fn create_app(model_config: &ModelConfig, state: SharedState) -> Router {
    Router::new()
        .route(
            "/",
            any({
                let state = state.clone();
                let model_config = model_config.clone();
                move |req| handle_request(req, model_config.clone(), state.clone())
            }),
        )
        .route(
            "/*path",
            any({
                let state = state.clone();
                let model_config = model_config.clone();
                move |req| handle_request(req, model_config.clone(), state.clone())
            }),
        )
        .layer(
            tower_http::trace::TraceLayer::new_for_http()
                .make_span_with(tower_http::trace::DefaultMakeSpan::new().include_headers(true))
                .on_request(tower_http::trace::DefaultOnRequest::new().level(Level::DEBUG))
                .on_response(tower_http::trace::DefaultOnResponse::new().level(Level::TRACE))
                .on_eos(tower_http::trace::DefaultOnEos::new().level(Level::DEBUG))
                .on_failure(tower_http::trace::DefaultOnFailure::new().level(Level::ERROR)),
        )
}
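
// Hypothetical usage once running (port from the illustrative config above):
// each public port proxies to its model's llama-server, spawning the process
// on demand, so the first request may be slow while the model loads:
//
//   curl http://localhost:8080/v1/models
//   curl http://localhost:8080/v1/chat/completions \
//     -H 'Content-Type: application/json' \
//     -d '{"messages": [{"role": "user", "content": "Hi"}]}'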

#[tokio::main]
async fn main() {
    initialize_logger();

    let config = Config::default_from_pwd_yml();

    let shared_state = SharedState::from_config(config.clone());

    // For each model, set up an axum server listening on the specified port.
    let mut handles = Vec::new();

    for model_config in config.models {
        let state = shared_state.clone();

        let handle = tokio::spawn(async move {
            let app = create_app(&model_config, state);

            let addr = SocketAddr::from(([0, 0, 0, 0], model_config.port));

            tracing::info!(msg = "Listening", ?model_config);
            let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();

            axum::serve(listener, app.into_make_service())
                .await
                .unwrap();
        });

        handles.push(handle);
    }

    futures::future::join_all(handles).await;
}

use thiserror::Error;

#[derive(Error, Debug)]
pub enum AppError {
    #[error("Axum Error")]
    AxumError(#[from] axum::Error),
    #[error("Axum Http Error")]
    AxumHttpError(#[from] axum::http::Error),
    #[error("Reqwest Error")]
    ReqwestError(#[from] reqwest::Error),
    #[error("Reqwest Middleware Error")]
    ReqwestMiddlewareError(#[from] reqwest_middleware::Error),
    #[error("Client Error")]
    ClientError(#[from] hyper::Error),
    #[error("Io Error")]
    IoError(#[from] std::io::Error),
    #[error("Unknown error")]
    Unknown(#[from] anyhow::Error),
}

// Tell axum how to convert `AppError` into a response.
impl IntoResponse for AppError {
    fn into_response(self) -> axum::response::Response {
        tracing::error!("::AppError::into_response:: {:?}", self);
        (
            http::StatusCode::INTERNAL_SERVER_ERROR,
            format!("Something went wrong: {:?}", self),
        )
            .into_response()
    }
}

async fn proxy_request(
    req: Request<Body>,
    model_config: &ModelConfig,
) -> Result<Response<Body>, AppError> {
    // Retry with exponential backoff (500 ms up to 8 s, at most 8 retries) so
    // requests made while the model is still starting up are not dropped.
    let retry_policy = ExponentialBackoff::builder()
        .retry_bounds(
            std::time::Duration::from_millis(500),
            std::time::Duration::from_secs(8),
        )
        .jitter(reqwest_retry::Jitter::None)
        .base(2)
        .build_with_max_retries(8);

    let client: ClientWithMiddleware = ClientBuilder::new(Client::new())
        .with(RetryTransientMiddleware::new_with_policy(retry_policy))
        .build();

    let internal_port = model_config
        .internal_port
        .expect("Internal port must be set");
    let uri = format!(
        "http://127.0.0.1:{}{}",
        internal_port,
        req.uri()
            .path_and_query()
            .map(|pq| pq.as_str())
            .unwrap_or("")
    );

    let mut request_builder = client.request(req.method().clone(), &uri);
    // Forward all headers.
    for (name, value) in req.headers() {
        request_builder = request_builder.header(name, value);
    }

    // Buffer the request body (capped at 1 GiB) before forwarding it.
    let max_size = parse_size("1G").unwrap() as usize;
    let body_bytes = axum::body::to_bytes(req.into_body(), max_size).await?;
    let request = request_builder.body(body_bytes).build()?;
    let response = client.execute(request).await?;

    // Stream the upstream response back to the client.
    let mut builder = Response::builder().status(response.status());
    for (name, value) in response.headers() {
        builder = builder.header(name, value);
    }

    let byte_stream = response.bytes_stream();
    let body = Body::from_stream(byte_stream);
    Ok(builder.body(body)?)
}

async fn handle_request(
    req: Request<Body>,
    model_config: ModelConfig,
    state: SharedState,
) -> impl IntoResponse {
    let _instance = LlamaInstance::get_or_spawn_instance(&model_config, state).await?;
    let response = proxy_request(req, &model_config).await?;

    Ok::<axum::http::Response<Body>, AppError>(response)
}

/// Parse a human-readable size like "10G", "512mb", or "1.5gb" into bytes
/// (binary units: 1K = 1024).
fn parse_size(size_str: &str) -> Option<u64> {
    let mut num = String::new();
    let mut unit = String::new();

    for c in size_str.chars() {
        if c.is_ascii_digit() || c == '.' {
            num.push(c);
        } else {
            unit.push(c);
        }
    }

    let num: f64 = num.parse().ok()?;

    let multiplier = match unit.to_lowercase().as_str() {
        "g" | "gb" => 1024 * 1024 * 1024,
        "m" | "mb" => 1024 * 1024,
        "k" | "kb" => 1024,
        // An unrecognized unit is a parse failure, not a panic.
        _ => return None,
    };

    let res = (num * multiplier as f64) as u64;
    Some(res)
}
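
// A minimal sanity check for `parse_size`; the cases and the binary
// (1024-based) interpretation mirror the implementation above.
#[cfg(test)]
mod tests {
    use super::parse_size;

    #[test]
    fn parse_size_handles_common_units() {
        assert_eq!(parse_size("1k"), Some(1024));
        assert_eq!(parse_size("512MB"), Some(512 * 1024 * 1024));
        assert_eq!(parse_size("10G"), Some(10 * 1024 * 1024 * 1024));
        // Fractional sizes are truncated to whole bytes.
        assert_eq!(parse_size("1.5G"), Some(1_610_612_736));
        // Unknown units and empty input are parse failures.
        assert_eq!(parse_size("10potato"), None);
        assert_eq!(parse_size(""), None);
    }
}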