refactor(proxy_man): Org & modularize app
- Introduce new modules: `config.rs`, `error.rs`, `inference_process.rs`, `proxy.rs`, `state.rs`, `util.rs` - Update `logging.rs` to include static initialization of logging - Add lib.rs for including in other projects - Refactor `main.rs` to use new modules and improve code structure
This commit is contained in:
parent
e3a3ec3826
commit
ad0cd12877
9 changed files with 452 additions and 459 deletions
48
llama_proxy_man/src/config.rs
Normal file
48
llama_proxy_man/src/config.rs
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
use serde::Deserialize;
|
||||||
|
use std::{collections::HashMap, fs};
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize)]
|
||||||
|
pub struct AppConfig {
|
||||||
|
pub system_resources: SystemResources,
|
||||||
|
pub model_specs: Vec<ModelSpec>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AppConfig {
|
||||||
|
pub fn default_from_pwd_yml() -> Self {
|
||||||
|
let config_str = fs::read_to_string("config.yaml").expect("Failed to read config.yaml");
|
||||||
|
serde_yaml::from_str::<Self>(&config_str)
|
||||||
|
.expect("Failed to parse config.yaml")
|
||||||
|
.assign_internal_ports()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure every model has an internal port
|
||||||
|
pub fn assign_internal_ports(self) -> Self {
|
||||||
|
let mut config = self.clone();
|
||||||
|
for model in &mut config.model_specs {
|
||||||
|
if model.internal_port.is_none() {
|
||||||
|
model.internal_port = Some(
|
||||||
|
openport::pick_random_unused_port()
|
||||||
|
.expect(&format!("No open port found for {:?}", model)),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
config
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize)]
|
||||||
|
pub struct SystemResources {
|
||||||
|
pub ram: String,
|
||||||
|
pub vram: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize)]
|
||||||
|
pub struct ModelSpec {
|
||||||
|
pub name: String,
|
||||||
|
pub port: u16,
|
||||||
|
pub internal_port: Option<u16>,
|
||||||
|
pub env: HashMap<String, String>,
|
||||||
|
pub args: HashMap<String, String>,
|
||||||
|
pub vram_usage: String,
|
||||||
|
pub ram_usage: String,
|
||||||
|
}
|
36
llama_proxy_man/src/error.rs
Normal file
36
llama_proxy_man/src/error.rs
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
use axum::{http, response::IntoResponse};
|
||||||
|
use hyper;
|
||||||
|
use reqwest;
|
||||||
|
use reqwest_middleware;
|
||||||
|
use std::io;
|
||||||
|
use thiserror::Error;
|
||||||
|
use anyhow::Error as AnyError;
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum AppError {
|
||||||
|
#[error("Axum Error")]
|
||||||
|
AxumError(#[from] axum::Error),
|
||||||
|
#[error("Axum Http Error")]
|
||||||
|
AxumHttpError(#[from] http::Error),
|
||||||
|
#[error("Reqwest Error")]
|
||||||
|
ReqwestError(#[from] reqwest::Error),
|
||||||
|
#[error("Reqwest Middleware Error")]
|
||||||
|
ReqwestMiddlewareError(#[from] reqwest_middleware::Error),
|
||||||
|
#[error("Client Error")]
|
||||||
|
ClientError(#[from] hyper::Error),
|
||||||
|
#[error("Io Error")]
|
||||||
|
IoError(#[from] io::Error),
|
||||||
|
#[error("Unknown error")]
|
||||||
|
Unknown(#[from] AnyError),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse for AppError {
|
||||||
|
fn into_response(self) -> axum::response::Response {
|
||||||
|
tracing::error!("::AppError::into_response: {:?}", self);
|
||||||
|
(
|
||||||
|
http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Something went wrong: {:?}", self),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
}
|
149
llama_proxy_man/src/inference_process.rs
Normal file
149
llama_proxy_man/src/inference_process.rs
Normal file
|
@ -0,0 +1,149 @@
|
||||||
|
use crate::{config::ModelSpec, error::AppError, state::AppState, util::parse_size};
|
||||||
|
use anyhow::anyhow;
|
||||||
|
use itertools::Itertools;
|
||||||
|
use std::{process::Stdio, sync::Arc};
|
||||||
|
use tokio::{
|
||||||
|
net::TcpStream,
|
||||||
|
process::{Child, Command},
|
||||||
|
sync::Mutex,
|
||||||
|
time::{sleep, Duration},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct InferenceProcess {
|
||||||
|
pub spec: ModelSpec,
|
||||||
|
pub process: Arc<Mutex<Child>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl InferenceProcess {
|
||||||
|
/// Retrieve a running process from state or spawn a new one.
|
||||||
|
pub async fn get_or_spawn(
|
||||||
|
spec: &ModelSpec,
|
||||||
|
state: AppState,
|
||||||
|
) -> Result<InferenceProcess, AppError> {
|
||||||
|
let required_ram = parse_size(&spec.ram_usage).expect("Invalid ram_usage in model spec");
|
||||||
|
let required_vram = parse_size(&spec.vram_usage).expect("Invalid vram_usage in model spec");
|
||||||
|
|
||||||
|
{
|
||||||
|
let state_guard = state.lock().await;
|
||||||
|
if let Some(proc) = state_guard.processes.get(&spec.port) {
|
||||||
|
return Ok(proc.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut state_guard = state.lock().await;
|
||||||
|
tracing::info!(msg = "App state before spawn", ?state_guard);
|
||||||
|
|
||||||
|
// If not enough resources, stop some running processes.
|
||||||
|
if (state_guard.used_ram + required_ram > state_guard.total_ram)
|
||||||
|
|| (state_guard.used_vram + required_vram > state_guard.total_vram)
|
||||||
|
{
|
||||||
|
let mut to_remove = Vec::new();
|
||||||
|
let processes_by_usage = state_guard
|
||||||
|
.processes
|
||||||
|
.iter()
|
||||||
|
.sorted_by(|(_, a), (_, b)| {
|
||||||
|
parse_size(&b.spec.vram_usage)
|
||||||
|
.unwrap_or(0)
|
||||||
|
.cmp(&parse_size(&a.spec.vram_usage).unwrap_or(0))
|
||||||
|
})
|
||||||
|
.map(|(port, proc)| (*port, proc.clone()))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
for (port, proc) in processes_by_usage.iter() {
|
||||||
|
tracing::info!("Stopping process on port {}", port);
|
||||||
|
let mut lock = proc.process.lock().await;
|
||||||
|
lock.kill().await.ok();
|
||||||
|
to_remove.push(*port);
|
||||||
|
state_guard.used_ram = state_guard
|
||||||
|
.used_ram
|
||||||
|
.saturating_sub(parse_size(&proc.spec.ram_usage).unwrap_or(0));
|
||||||
|
state_guard.used_vram = state_guard
|
||||||
|
.used_vram
|
||||||
|
.saturating_sub(parse_size(&proc.spec.vram_usage).unwrap_or(0));
|
||||||
|
if state_guard.used_ram + required_ram <= state_guard.total_ram
|
||||||
|
&& state_guard.used_vram + required_vram <= state_guard.total_vram
|
||||||
|
{
|
||||||
|
tracing::info!("Freed enough resources");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for port in to_remove {
|
||||||
|
state_guard.processes.remove(&port);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tracing::info!("Sufficient resources available");
|
||||||
|
}
|
||||||
|
|
||||||
|
let proc = Self::spawn(spec).await?;
|
||||||
|
state_guard.used_ram += required_ram;
|
||||||
|
state_guard.used_vram += required_vram;
|
||||||
|
state_guard.processes.insert(spec.port, proc.clone());
|
||||||
|
|
||||||
|
sleep(Duration::from_millis(250)).await;
|
||||||
|
proc.wait_until_ready().await?;
|
||||||
|
sleep(Duration::from_millis(250)).await;
|
||||||
|
Ok(proc)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Spawn a new inference process.
|
||||||
|
pub async fn spawn(spec: &ModelSpec) -> Result<InferenceProcess, AppError> {
|
||||||
|
let args = spec
|
||||||
|
.args
|
||||||
|
.iter()
|
||||||
|
.flat_map(|(k, v)| {
|
||||||
|
if v == "true" {
|
||||||
|
vec![format!("--{}", k)]
|
||||||
|
} else {
|
||||||
|
vec![format!("--{}", k), v.clone()]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let internal_port = spec.internal_port.expect("Internal port must be set");
|
||||||
|
|
||||||
|
let mut cmd = Command::new("llama-server");
|
||||||
|
cmd.kill_on_drop(true)
|
||||||
|
.envs(spec.env.clone())
|
||||||
|
.args(&args)
|
||||||
|
.arg("--port")
|
||||||
|
.arg(format!("{}", internal_port))
|
||||||
|
.stdout(Stdio::null())
|
||||||
|
.stderr(Stdio::null());
|
||||||
|
|
||||||
|
tracing::info!("Starting llama-server with command: {:?}", cmd);
|
||||||
|
let child = cmd.spawn().expect("Failed to start llama-server");
|
||||||
|
|
||||||
|
Ok(InferenceProcess {
|
||||||
|
spec: spec.clone(),
|
||||||
|
process: Arc::new(Mutex::new(child)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn wait_until_ready(&self) -> Result<(), anyhow::Error> {
|
||||||
|
async fn check_running(proc: &InferenceProcess) -> Result<(), AppError> {
|
||||||
|
let mut lock = proc.process.clone().lock_owned().await;
|
||||||
|
match lock.try_wait()? {
|
||||||
|
Some(exit_status) => {
|
||||||
|
tracing::error!("Inference process exited: {:?}", exit_status);
|
||||||
|
Err(AppError::Unknown(anyhow!("Inference process exited")))
|
||||||
|
}
|
||||||
|
None => Ok(()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn wait_for_port(port: u16) -> Result<(), anyhow::Error> {
|
||||||
|
for _ in 0..10 {
|
||||||
|
if TcpStream::connect(("127.0.0.1", port)).await.is_ok() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
sleep(Duration::from_millis(500)).await;
|
||||||
|
}
|
||||||
|
Err(anyhow!("Timeout waiting for port"))
|
||||||
|
}
|
||||||
|
|
||||||
|
check_running(self).await?;
|
||||||
|
wait_for_port(self.spec.internal_port.expect("Internal port must be set")).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
70
llama_proxy_man/src/lib.rs
Normal file
70
llama_proxy_man/src/lib.rs
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
pub mod config;
|
||||||
|
pub mod error;
|
||||||
|
pub mod inference_process;
|
||||||
|
pub mod logging;
|
||||||
|
pub mod proxy;
|
||||||
|
pub mod state;
|
||||||
|
pub mod util;
|
||||||
|
|
||||||
|
use axum::{routing::any, Router};
|
||||||
|
use config::{AppConfig, ModelSpec};
|
||||||
|
use state::AppState;
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use tower_http::trace::{
|
||||||
|
DefaultMakeSpan, DefaultOnEos, DefaultOnFailure, DefaultOnRequest, DefaultOnResponse,
|
||||||
|
TraceLayer,
|
||||||
|
};
|
||||||
|
use tracing::Level;
|
||||||
|
|
||||||
|
/// Creates an Axum application to handle inference requests for a specific model.
|
||||||
|
pub fn create_app(spec: &ModelSpec, state: AppState) -> Router {
|
||||||
|
Router::new()
|
||||||
|
.route(
|
||||||
|
"/",
|
||||||
|
any({
|
||||||
|
let state = state.clone();
|
||||||
|
let spec = spec.clone();
|
||||||
|
move |req| proxy::handle_request(req, spec.clone(), state.clone())
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/*path",
|
||||||
|
any({
|
||||||
|
let state = state.clone();
|
||||||
|
let spec = spec.clone();
|
||||||
|
move |req| proxy::handle_request(req, spec.clone(), state.clone())
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.layer(
|
||||||
|
TraceLayer::new_for_http()
|
||||||
|
.make_span_with(DefaultMakeSpan::new().include_headers(true))
|
||||||
|
.on_request(DefaultOnRequest::new().level(Level::DEBUG))
|
||||||
|
.on_response(DefaultOnResponse::new().level(Level::TRACE))
|
||||||
|
.on_eos(DefaultOnEos::new().level(Level::DEBUG))
|
||||||
|
.on_failure(DefaultOnFailure::new().level(Level::ERROR)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Starts an inference server for each model defined in the config.
|
||||||
|
pub async fn start_server(config: AppConfig) {
|
||||||
|
let state = AppState::from_config(config.clone());
|
||||||
|
|
||||||
|
let mut handles = Vec::new();
|
||||||
|
for spec in config.model_specs {
|
||||||
|
let state = state.clone();
|
||||||
|
let spec = spec.clone();
|
||||||
|
|
||||||
|
let handle = tokio::spawn(async move {
|
||||||
|
let app = create_app(&spec, state);
|
||||||
|
let addr = SocketAddr::from(([0, 0, 0, 0], spec.port));
|
||||||
|
tracing::info!(msg = "Listening", ?spec);
|
||||||
|
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
|
||||||
|
axum::serve(listener, app.into_make_service())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
});
|
||||||
|
handles.push(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
futures::future::join_all(handles).await;
|
||||||
|
}
|
|
@ -5,6 +5,8 @@ use std::{
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use std::sync::Once;
|
||||||
|
|
||||||
use axum::{body::Body, http::Request};
|
use axum::{body::Body, http::Request};
|
||||||
use pin_project_lite::pin_project;
|
use pin_project_lite::pin_project;
|
||||||
use tower::{Layer, Service};
|
use tower::{Layer, Service};
|
||||||
|
@ -81,3 +83,18 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static INIT: Once = Once::new();
|
||||||
|
|
||||||
|
pub fn initialize_logger() {
|
||||||
|
INIT.call_once(|| {
|
||||||
|
let env_filter = tracing_subscriber::EnvFilter::builder()
|
||||||
|
.with_default_directive(tracing::level_filters::LevelFilter::INFO.into())
|
||||||
|
.from_env_lossy();
|
||||||
|
|
||||||
|
tracing_subscriber::fmt()
|
||||||
|
.compact()
|
||||||
|
.with_env_filter(env_filter)
|
||||||
|
.init();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
|
@ -1,465 +1,11 @@
|
||||||
mod logging;
|
use llama_proxy_man::{config::AppConfig, logging, start_server};
|
||||||
|
use tokio;
|
||||||
use anyhow::anyhow;
|
|
||||||
use axum::{
|
|
||||||
body::Body,
|
|
||||||
http::{self, Request, Response},
|
|
||||||
response::IntoResponse,
|
|
||||||
routing::any,
|
|
||||||
Router,
|
|
||||||
};
|
|
||||||
use futures;
|
|
||||||
use itertools::Itertools;
|
|
||||||
use reqwest::Client;
|
|
||||||
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
|
|
||||||
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
|
|
||||||
use serde::Deserialize;
|
|
||||||
use std::{collections::HashMap, net::SocketAddr, process::Stdio, sync::Arc};
|
|
||||||
use tokio::{
|
|
||||||
net::TcpStream,
|
|
||||||
process::{Child, Command},
|
|
||||||
sync::Mutex,
|
|
||||||
time::{sleep, Duration},
|
|
||||||
};
|
|
||||||
use tower_http;
|
|
||||||
use tracing::Level;
|
|
||||||
|
|
||||||
use std::sync::Once;
|
|
||||||
|
|
||||||
pub static INIT: Once = Once::new();
|
|
||||||
|
|
||||||
pub fn initialize_logger() {
|
|
||||||
INIT.call_once(|| {
|
|
||||||
let env_filter = tracing_subscriber::EnvFilter::builder()
|
|
||||||
.with_default_directive(tracing::level_filters::LevelFilter::INFO.into())
|
|
||||||
.from_env_lossy();
|
|
||||||
|
|
||||||
tracing_subscriber::fmt()
|
|
||||||
.compact()
|
|
||||||
.with_env_filter(env_filter)
|
|
||||||
.init();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO Add valiation to the config
|
|
||||||
// - e.g. check double taken ports etc
|
|
||||||
#[derive(Clone, Debug, Deserialize)]
|
|
||||||
struct Config {
|
|
||||||
hardware: Hardware,
|
|
||||||
models: Vec<ModelConfig>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Config {
|
|
||||||
fn default_from_pwd_yml() -> Self {
|
|
||||||
// Read and parse the YAML configuration
|
|
||||||
let config_str =
|
|
||||||
std::fs::read_to_string("config.yaml").expect("Failed to read config.yaml");
|
|
||||||
|
|
||||||
serde_yaml::from_str::<Self>(&config_str)
|
|
||||||
.expect("Failed to parse config.yaml")
|
|
||||||
.pick_open_ports()
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO split up into raw deser config and "parsed"/"processed" config which always has a port
|
|
||||||
// FIXME maybe just always initialize with a port ?
|
|
||||||
fn pick_open_ports(self) -> Self {
|
|
||||||
let mut config = self.clone();
|
|
||||||
for model in &mut config.models {
|
|
||||||
if model.internal_port.is_none() {
|
|
||||||
model.internal_port = Some(
|
|
||||||
openport::pick_random_unused_port()
|
|
||||||
.expect(format!("No open port found for {:?}", model).as_str()),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
config
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize)]
|
|
||||||
struct Hardware {
|
|
||||||
ram: String,
|
|
||||||
vram: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize)]
|
|
||||||
struct ModelConfig {
|
|
||||||
#[allow(dead_code)]
|
|
||||||
name: String,
|
|
||||||
port: u16,
|
|
||||||
internal_port: Option<u16>,
|
|
||||||
env: HashMap<String, String>,
|
|
||||||
args: HashMap<String, String>,
|
|
||||||
vram_usage: String,
|
|
||||||
ram_usage: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
|
||||||
struct LlamaInstance {
|
|
||||||
config: ModelConfig,
|
|
||||||
process: Arc<Mutex<Child>>,
|
|
||||||
// busy: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LlamaInstance {
|
|
||||||
/// Retrieve a running instance from state or spawn a new one.
|
|
||||||
pub async fn get_or_spawn_instance(
|
|
||||||
model_config: &ModelConfig,
|
|
||||||
state: SharedState,
|
|
||||||
) -> Result<LlamaInstance, AppError> {
|
|
||||||
let model_ram_usage =
|
|
||||||
parse_size(&model_config.ram_usage).expect("Invalid ram_usage in model config");
|
|
||||||
let model_vram_usage =
|
|
||||||
parse_size(&model_config.vram_usage).expect("Invalid vram_usage in model config");
|
|
||||||
|
|
||||||
{
|
|
||||||
// First, check if we already have an instance.
|
|
||||||
let state_guard = state.lock().await;
|
|
||||||
if let Some(instance) = state_guard.instances.get(&model_config.port) {
|
|
||||||
return Ok(instance.clone());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut state_guard = state.lock().await;
|
|
||||||
tracing::info!(msg = "Shared state before spawn", ?state_guard);
|
|
||||||
|
|
||||||
// If not enough free resources, stop some running instances.
|
|
||||||
if (state_guard.used_ram + model_ram_usage > state_guard.total_ram)
|
|
||||||
|| (state_guard.used_vram + model_vram_usage > state_guard.total_vram)
|
|
||||||
{
|
|
||||||
let mut to_remove = Vec::new();
|
|
||||||
let instances_by_size = state_guard
|
|
||||||
.instances
|
|
||||||
.iter()
|
|
||||||
.sorted_by(|(_, a), (_, b)| {
|
|
||||||
parse_size(&b.config.vram_usage)
|
|
||||||
.unwrap_or(0)
|
|
||||||
.cmp(&parse_size(&a.config.vram_usage).unwrap_or(0))
|
|
||||||
})
|
|
||||||
.map(|(port, instance)| (*port, instance.clone()))
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
for (port, instance) in instances_by_size.iter() {
|
|
||||||
tracing::info!("Stopping instance on port {}", port);
|
|
||||||
let mut proc = instance.process.lock().await;
|
|
||||||
proc.kill().await.ok();
|
|
||||||
to_remove.push(*port);
|
|
||||||
state_guard.used_ram = state_guard
|
|
||||||
.used_ram
|
|
||||||
.saturating_sub(parse_size(&instance.config.ram_usage).unwrap_or(0));
|
|
||||||
state_guard.used_vram = state_guard
|
|
||||||
.used_vram
|
|
||||||
.saturating_sub(parse_size(&instance.config.vram_usage).unwrap_or(0));
|
|
||||||
if state_guard.used_ram + model_ram_usage <= state_guard.total_ram
|
|
||||||
&& state_guard.used_vram + model_vram_usage <= state_guard.total_vram
|
|
||||||
{
|
|
||||||
tracing::info!("Freed enough resources");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for port in to_remove {
|
|
||||||
state_guard.instances.remove(&port);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
tracing::info!("Sufficient resources available");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Spawn a new instance.
|
|
||||||
let instance = Self::spawn(&model_config).await?;
|
|
||||||
state_guard.used_ram += model_ram_usage;
|
|
||||||
state_guard.used_vram += model_vram_usage;
|
|
||||||
state_guard
|
|
||||||
.instances
|
|
||||||
.insert(model_config.port, instance.clone());
|
|
||||||
|
|
||||||
sleep(Duration::from_millis(250)).await;
|
|
||||||
|
|
||||||
instance.running_and_port_ready().await?;
|
|
||||||
|
|
||||||
sleep(Duration::from_millis(250)).await;
|
|
||||||
|
|
||||||
Ok(instance)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Spawn a new llama-server process based on the model configuration.
|
|
||||||
pub async fn spawn(model_config: &ModelConfig) -> Result<LlamaInstance, AppError> {
|
|
||||||
let args = model_config
|
|
||||||
.args
|
|
||||||
.iter()
|
|
||||||
.flat_map(|(k, v)| {
|
|
||||||
if v == "true" {
|
|
||||||
vec![format!("--{}", k)]
|
|
||||||
} else {
|
|
||||||
vec![format!("--{}", k), v.clone()]
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
let internal_port = model_config
|
|
||||||
.internal_port
|
|
||||||
.expect("Internal port must be set");
|
|
||||||
|
|
||||||
let mut cmd = Command::new("llama-server");
|
|
||||||
cmd.kill_on_drop(true)
|
|
||||||
.envs(model_config.env.clone())
|
|
||||||
.args(&args)
|
|
||||||
.arg("--port")
|
|
||||||
.arg(format!("{}", internal_port))
|
|
||||||
.stdout(Stdio::null())
|
|
||||||
.stderr(Stdio::null());
|
|
||||||
|
|
||||||
// Silence output – could be captured later via an API.
|
|
||||||
|
|
||||||
tracing::info!("Starting llama-server with command: {:?}", cmd);
|
|
||||||
let child = cmd.spawn().expect("Failed to start llama-server");
|
|
||||||
|
|
||||||
Ok(LlamaInstance {
|
|
||||||
config: model_config.clone(),
|
|
||||||
process: Arc::new(Mutex::new(child)),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn running_and_port_ready(&self) -> Result<(), anyhow::Error> {
|
|
||||||
async fn is_running(instance: &LlamaInstance) -> Result<(), AppError> {
|
|
||||||
let mut proc = instance.process.clone().lock_owned().await;
|
|
||||||
match proc.try_wait()? {
|
|
||||||
Some(exit_status) => {
|
|
||||||
tracing::error!("Llama instance exited: {:?}", exit_status);
|
|
||||||
Err(AppError::Unknown(anyhow!("Llama instance exited")))
|
|
||||||
}
|
|
||||||
None => Ok(()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn wait_for_port(port: u16) -> Result<(), anyhow::Error> {
|
|
||||||
for _ in 0..10 {
|
|
||||||
if TcpStream::connect(("127.0.0.1", port)).await.is_ok() {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
sleep(Duration::from_millis(500)).await;
|
|
||||||
}
|
|
||||||
Err(anyhow!("Timeout waiting for port"))
|
|
||||||
}
|
|
||||||
|
|
||||||
is_running(self).await?;
|
|
||||||
wait_for_port(self.config.internal_port.expect("No port picked?")).await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
|
||||||
pub struct InnerSharedState {
|
|
||||||
total_ram: u64,
|
|
||||||
total_vram: u64,
|
|
||||||
used_ram: u64,
|
|
||||||
used_vram: u64,
|
|
||||||
instances: HashMap<u16, LlamaInstance>,
|
|
||||||
}
|
|
||||||
|
|
||||||
type SharedStateArc = Arc<Mutex<InnerSharedState>>;
|
|
||||||
|
|
||||||
/// TODO migrate to dashmap + individual Arc<Mutex<T>> or Rwlocks
|
|
||||||
#[derive(Clone, Debug, derive_more::Deref)]
|
|
||||||
pub struct SharedState(SharedStateArc);
|
|
||||||
|
|
||||||
impl SharedState {
|
|
||||||
fn from_config(config: Config) -> Self {
|
|
||||||
// Parse hardware resources
|
|
||||||
let total_ram = parse_size(&config.hardware.ram).expect("Invalid RAM size in config");
|
|
||||||
let total_vram = parse_size(&config.hardware.vram).expect("Invalid VRAM size in config");
|
|
||||||
|
|
||||||
// Initialize shared state
|
|
||||||
let shared_state = InnerSharedState {
|
|
||||||
total_ram,
|
|
||||||
total_vram,
|
|
||||||
used_ram: 0,
|
|
||||||
used_vram: 0,
|
|
||||||
instances: HashMap::new(),
|
|
||||||
};
|
|
||||||
|
|
||||||
Self(Arc::new(Mutex::new(shared_state)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Build an Axum app for a given model.
|
|
||||||
fn create_app(model_config: &ModelConfig, state: SharedState) -> Router {
|
|
||||||
Router::new()
|
|
||||||
.route(
|
|
||||||
"/",
|
|
||||||
any({
|
|
||||||
let state = state.clone();
|
|
||||||
let model_config = model_config.clone();
|
|
||||||
move |req| handle_request(req, model_config.clone(), state.clone())
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
.route(
|
|
||||||
"/*path",
|
|
||||||
any({
|
|
||||||
let state = state.clone();
|
|
||||||
let model_config = model_config.clone();
|
|
||||||
move |req| handle_request(req, model_config.clone(), state.clone())
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
.layer(
|
|
||||||
tower_http::trace::TraceLayer::new_for_http()
|
|
||||||
.make_span_with(tower_http::trace::DefaultMakeSpan::new().include_headers(true))
|
|
||||||
.on_request(tower_http::trace::DefaultOnRequest::new().level(Level::DEBUG))
|
|
||||||
.on_response(tower_http::trace::DefaultOnResponse::new().level(Level::TRACE))
|
|
||||||
.on_eos(tower_http::trace::DefaultOnEos::new().level(Level::DEBUG))
|
|
||||||
.on_failure(tower_http::trace::DefaultOnFailure::new().level(Level::ERROR)),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
initialize_logger();
|
logging::initialize_logger();
|
||||||
|
|
||||||
let config = Config::default_from_pwd_yml();
|
let config = AppConfig::default_from_pwd_yml();
|
||||||
|
|
||||||
let shared_state = SharedState::from_config(config.clone());
|
start_server(config).await;
|
||||||
|
|
||||||
// For each model, set up an axum server listening on the specified port
|
|
||||||
let mut handles = Vec::new();
|
|
||||||
|
|
||||||
for model_config in config.models {
|
|
||||||
let state = shared_state.clone();
|
|
||||||
let model_config = model_config.clone();
|
|
||||||
|
|
||||||
let handle = tokio::spawn(async move {
|
|
||||||
let model_config = model_config.clone();
|
|
||||||
|
|
||||||
let app = create_app(&model_config, state);
|
|
||||||
|
|
||||||
let addr = SocketAddr::from(([0, 0, 0, 0], model_config.port));
|
|
||||||
|
|
||||||
tracing::info!(msg = "Listening", ?model_config);
|
|
||||||
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
|
|
||||||
|
|
||||||
axum::serve(listener, app.into_make_service())
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
});
|
|
||||||
|
|
||||||
handles.push(handle);
|
|
||||||
}
|
|
||||||
|
|
||||||
futures::future::join_all(handles).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
use thiserror::Error;
|
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
|
||||||
pub enum AppError {
|
|
||||||
#[error("Axum Error")]
|
|
||||||
AxumError(#[from] axum::Error),
|
|
||||||
#[error("Axum Http Error")]
|
|
||||||
AxumHttpError(#[from] axum::http::Error),
|
|
||||||
#[error("Reqwest Error")]
|
|
||||||
ReqwestError(#[from] reqwest::Error),
|
|
||||||
#[error("Reqwest Middleware Error")]
|
|
||||||
ReqwestMiddlewareError(#[from] reqwest_middleware::Error),
|
|
||||||
#[error("Client Error")]
|
|
||||||
ClientError(#[from] hyper::Error),
|
|
||||||
#[error("Io Error")]
|
|
||||||
IoError(#[from] std::io::Error),
|
|
||||||
#[error("Unknown error")]
|
|
||||||
Unknown(#[from] anyhow::Error),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tell axum how to convert `AppError` into a response.
|
|
||||||
impl IntoResponse for AppError {
|
|
||||||
fn into_response(self) -> axum::response::Response {
|
|
||||||
tracing::error!("::AppError::into_response:: {:?}", self);
|
|
||||||
(
|
|
||||||
http::StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
format!("Something went wrong: {:?}", self),
|
|
||||||
)
|
|
||||||
.into_response()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn proxy_request(
|
|
||||||
req: Request<Body>,
|
|
||||||
model_config: &ModelConfig,
|
|
||||||
) -> Result<Response<Body>, AppError> {
|
|
||||||
let retry_policy = ExponentialBackoff::builder()
|
|
||||||
.retry_bounds(
|
|
||||||
std::time::Duration::from_millis(500),
|
|
||||||
std::time::Duration::from_secs(8),
|
|
||||||
)
|
|
||||||
.jitter(reqwest_retry::Jitter::None)
|
|
||||||
.base(2)
|
|
||||||
.build_with_max_retries(8);
|
|
||||||
|
|
||||||
let client: ClientWithMiddleware = ClientBuilder::new(Client::new())
|
|
||||||
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
|
|
||||||
.build();
|
|
||||||
|
|
||||||
let internal_port = model_config
|
|
||||||
.internal_port
|
|
||||||
.expect("Internal port must be set");
|
|
||||||
let uri = format!(
|
|
||||||
"http://127.0.0.1:{}{}",
|
|
||||||
internal_port,
|
|
||||||
req.uri()
|
|
||||||
.path_and_query()
|
|
||||||
.map(|pq| pq.as_str())
|
|
||||||
.unwrap_or("")
|
|
||||||
);
|
|
||||||
|
|
||||||
let mut request_builder = client.request(req.method().clone(), &uri);
|
|
||||||
// Forward all headers.
|
|
||||||
for (name, value) in req.headers() {
|
|
||||||
request_builder = request_builder.header(name, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
let max_size = parse_size("1G").unwrap() as usize;
|
|
||||||
let body_bytes = axum::body::to_bytes(req.into_body(), max_size).await?;
|
|
||||||
let request = request_builder.body(body_bytes).build()?;
|
|
||||||
let response = client.execute(request).await?;
|
|
||||||
|
|
||||||
let mut builder = Response::builder().status(response.status());
|
|
||||||
for (name, value) in response.headers() {
|
|
||||||
builder = builder.header(name, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
let byte_stream = response.bytes_stream();
|
|
||||||
let body = Body::from_stream(byte_stream);
|
|
||||||
Ok(builder.body(body)?)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn handle_request(
|
|
||||||
req: Request<Body>,
|
|
||||||
model_config: ModelConfig,
|
|
||||||
state: SharedState,
|
|
||||||
) -> impl IntoResponse {
|
|
||||||
let _instance = LlamaInstance::get_or_spawn_instance(&model_config, state).await?;
|
|
||||||
let response = proxy_request(req, &model_config).await?;
|
|
||||||
|
|
||||||
Ok::<axum::http::Response<Body>, AppError>(response)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parse_size(size_str: &str) -> Option<u64> {
|
|
||||||
let mut num = String::new();
|
|
||||||
let mut unit = String::new();
|
|
||||||
|
|
||||||
for c in size_str.chars() {
|
|
||||||
if c.is_digit(10) || c == '.' {
|
|
||||||
num.push(c);
|
|
||||||
} else {
|
|
||||||
unit.push(c);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let num: f64 = num.parse().ok()?;
|
|
||||||
|
|
||||||
let multiplier = match unit.to_lowercase().as_str() {
|
|
||||||
"g" | "gb" => 1024 * 1024 * 1024,
|
|
||||||
"m" | "mb" => 1024 * 1024,
|
|
||||||
"k" | "kb" => 1024,
|
|
||||||
_ => panic!("Invalid Size"),
|
|
||||||
};
|
|
||||||
|
|
||||||
let res = (num * multiplier as f64) as u64;
|
|
||||||
Some(res)
|
|
||||||
}
|
}
|
||||||
|
|
69
llama_proxy_man/src/proxy.rs
Normal file
69
llama_proxy_man/src/proxy.rs
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
use crate::{
|
||||||
|
config::ModelSpec, error::AppError, inference_process::InferenceProcess, state::AppState,
|
||||||
|
util::parse_size,
|
||||||
|
};
|
||||||
|
use axum::{
|
||||||
|
body::Body,
|
||||||
|
http::{Request, Response},
|
||||||
|
response::IntoResponse,
|
||||||
|
};
|
||||||
|
use reqwest::Client;
|
||||||
|
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
|
||||||
|
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
|
||||||
|
|
||||||
|
pub async fn proxy_request(
|
||||||
|
req: Request<Body>,
|
||||||
|
spec: &ModelSpec,
|
||||||
|
) -> Result<Response<Body>, AppError> {
|
||||||
|
let retry_policy = ExponentialBackoff::builder()
|
||||||
|
.retry_bounds(
|
||||||
|
std::time::Duration::from_millis(500),
|
||||||
|
std::time::Duration::from_secs(8),
|
||||||
|
)
|
||||||
|
.jitter(reqwest_retry::Jitter::None)
|
||||||
|
.base(2)
|
||||||
|
.build_with_max_retries(8);
|
||||||
|
|
||||||
|
let client: ClientWithMiddleware = ClientBuilder::new(Client::new())
|
||||||
|
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
let internal_port = spec.internal_port.expect("Internal port must be set");
|
||||||
|
let uri = format!(
|
||||||
|
"http://127.0.0.1:{}{}",
|
||||||
|
internal_port,
|
||||||
|
req.uri()
|
||||||
|
.path_and_query()
|
||||||
|
.map(|pq| pq.as_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut request_builder = client.request(req.method().clone(), &uri);
|
||||||
|
for (name, value) in req.headers() {
|
||||||
|
request_builder = request_builder.header(name, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
let max_size = parse_size("1G").unwrap() as usize;
|
||||||
|
let body_bytes = axum::body::to_bytes(req.into_body(), max_size).await?;
|
||||||
|
let request = request_builder.body(body_bytes).build()?;
|
||||||
|
let response = client.execute(request).await?;
|
||||||
|
|
||||||
|
let mut builder = Response::builder().status(response.status());
|
||||||
|
for (name, value) in response.headers() {
|
||||||
|
builder = builder.header(name, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
let byte_stream = response.bytes_stream();
|
||||||
|
let body = Body::from_stream(byte_stream);
|
||||||
|
Ok(builder.body(body)?)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn handle_request(
|
||||||
|
req: Request<Body>,
|
||||||
|
spec: ModelSpec,
|
||||||
|
state: AppState,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
let _ = InferenceProcess::get_or_spawn(&spec, state).await?;
|
||||||
|
let response = proxy_request(req, &spec).await?;
|
||||||
|
Ok::<Response<Body>, AppError>(response)
|
||||||
|
}
|
36
llama_proxy_man/src/state.rs
Normal file
36
llama_proxy_man/src/state.rs
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
use crate::{config::AppConfig, inference_process::InferenceProcess, util::parse_size};
|
||||||
|
use std::{collections::HashMap, sync::Arc};
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct ResourceManager {
|
||||||
|
pub total_ram: u64,
|
||||||
|
pub total_vram: u64,
|
||||||
|
pub used_ram: u64,
|
||||||
|
pub used_vram: u64,
|
||||||
|
pub processes: HashMap<u16, InferenceProcess>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type ResourceManagerHandle = Arc<Mutex<ResourceManager>>;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, derive_more::Deref)]
|
||||||
|
pub struct AppState(pub ResourceManagerHandle);
|
||||||
|
|
||||||
|
impl AppState {
|
||||||
|
pub fn from_config(config: AppConfig) -> Self {
|
||||||
|
let total_ram =
|
||||||
|
parse_size(&config.system_resources.ram).expect("Invalid RAM size in config");
|
||||||
|
let total_vram =
|
||||||
|
parse_size(&config.system_resources.vram).expect("Invalid VRAM size in config");
|
||||||
|
|
||||||
|
let resource_manager = ResourceManager {
|
||||||
|
total_ram,
|
||||||
|
total_vram,
|
||||||
|
used_ram: 0,
|
||||||
|
used_vram: 0,
|
||||||
|
processes: HashMap::new(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Self(Arc::new(Mutex::new(resource_manager)))
|
||||||
|
}
|
||||||
|
}
|
22
llama_proxy_man/src/util.rs
Normal file
22
llama_proxy_man/src/util.rs
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
pub fn parse_size(size_str: &str) -> Option<u64> {
|
||||||
|
let mut num = String::new();
|
||||||
|
let mut unit = String::new();
|
||||||
|
|
||||||
|
for c in size_str.chars() {
|
||||||
|
if c.is_digit(10) || c == '.' {
|
||||||
|
num.push(c);
|
||||||
|
} else {
|
||||||
|
unit.push(c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let num: f64 = num.parse().ok()?;
|
||||||
|
let multiplier = match unit.to_lowercase().as_str() {
|
||||||
|
"g" | "gb" => 1024 * 1024 * 1024,
|
||||||
|
"m" | "mb" => 1024 * 1024,
|
||||||
|
"k" | "kb" => 1024,
|
||||||
|
_ => panic!("Invalid Size"),
|
||||||
|
};
|
||||||
|
|
||||||
|
Some((num * multiplier as f64) as u64)
|
||||||
|
}
|
Loading…
Add table
Reference in a new issue