Compare commits

4 commits: d39459f2a9 ... bebdd35c6e

| Author | SHA1 | Date |
| --- | --- | --- |
|  | bebdd35c6e |  |
|  | ad0cd12877 |  |
|  | e3a3ec3826 |  |
|  | 8c062ece28 |  |

13 changed files with 489 additions and 427 deletions
Cargo.lock (generated, 25 changes)

@@ -1141,6 +1141,26 @@ dependencies = [
  "syn 2.0.98",
 ]
 
+[[package]]
+name = "derive_more"
+version = "2.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678"
+dependencies = [
+ "derive_more-impl",
+]
+
+[[package]]
+name = "derive_more-impl"
+version = "2.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.98",
+]
+
 [[package]]
 name = "digest"
 version = "0.10.7"

@@ -2950,7 +2970,7 @@ dependencies = [
  "chrono",
  "console_error_panic_hook",
  "dashmap",
- "derive_more",
+ "derive_more 0.99.19",
  "futures",
  "futures-util",
  "gloo-net 0.5.0",

@@ -3002,6 +3022,7 @@ version = "0.1.1"
 dependencies = [
  "anyhow",
  "axum",
+ "derive_more 2.0.1",
  "futures",
  "hyper",
  "itertools 0.13.0",

@@ -4781,7 +4802,7 @@ checksum = "df320f1889ac4ba6bc0cdc9c9af7af4bd64bb927bccdf32d81140dc1f9be12fe"
 dependencies = [
  "bitflags 1.3.2",
  "cssparser",
- "derive_more",
+ "derive_more 0.99.19",
  "fxhash",
  "log",
  "matches",
README.md (10 changes)

@@ -1,5 +1,15 @@
 # Redvau.lt AI Monorepo
+
+## Short Term Todo
+- [x] Prepare proxy man for embedding
+- [-] Improve markdown rendering in forge chat
+- [ ] Embed proxy man & add simple ui in forge
+- [ ] dumb embed process on startup
+- [ ] View current instances/models/virtual endpoints
+- [ ] Edit
+- [ ] Add new (from configurable model folder)
+
 
 ## Current Repos
 
 #### llama-forge-rs:
llama-forge-rs/Cargo.toml

@@ -14,6 +14,7 @@ crate-type = ["cdylib", "rlib"]
 
 [dependencies]
 wasm-bindgen = "=0.2.100"
+# TODO Update to 0.7
 leptos = { version = "0.6", features = [
     "serde",
     "nightly",

@@ -87,10 +88,9 @@ pulldown-cmark = { version = "0.12.2", features = ["serde"] }
 # qdrant-client = "1.11.2"
 # swiftide = "0.9.1"
 
-# TODO Add desktop/gui feature
+# TODO Add desktop/gui feature (or maybe server/headless only?)
 [features]
-# default = ["ssr"]
-default = ["hydrate"]
+default = ["ssr"]
 hydrate = ["leptos/hydrate", "leptos_meta/hydrate", "leptos_router/hydrate"]
 ssr = [
     "dep:async-broadcast",
llama_proxy_man/Cargo.toml

@@ -29,3 +29,4 @@ reqwest-retry = "0.6.1"
 reqwest-middleware = { version = "0.3.3", features = ["charset", "http2", "json", "multipart", "rustls-tls"] }
 itertools = "0.13.0"
 openport = { version = "0.1.1", features = ["rand"] }
+derive_more = { version = "2.0.1", features = ["deref"] }
llama_proxy_man/src/config.rs (new file, 48 lines)

use serde::Deserialize;
use std::{collections::HashMap, fs};

#[derive(Clone, Debug, Deserialize)]
pub struct AppConfig {
    pub system_resources: SystemResources,
    pub model_specs: Vec<ModelSpec>,
}

impl AppConfig {
    pub fn default_from_pwd_yml() -> Self {
        let config_str = fs::read_to_string("config.yaml").expect("Failed to read config.yaml");
        serde_yaml::from_str::<Self>(&config_str)
            .expect("Failed to parse config.yaml")
            .assign_internal_ports()
    }

    // Ensure every model has an internal port
    pub fn assign_internal_ports(self) -> Self {
        let mut config = self.clone();
        for model in &mut config.model_specs {
            if model.internal_port.is_none() {
                model.internal_port = Some(
                    openport::pick_random_unused_port()
                        .expect(&format!("No open port found for {:?}", model)),
                );
            }
        }
        config
    }
}

#[derive(Clone, Debug, Deserialize)]
pub struct SystemResources {
    pub ram: String,
    pub vram: String,
}

#[derive(Clone, Debug, Deserialize)]
pub struct ModelSpec {
    pub name: String,
    pub port: u16,
    pub internal_port: Option<u16>,
    pub env: HashMap<String, String>,
    pub args: HashMap<String, String>,
    pub vram_usage: String,
    pub ram_usage: String,
}
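For reference, a minimal config.yaml in the shape AppConfig deserializes. Only the field names come from the structs above; the model name, path, and sizes are made-up examples:

system_resources:
  ram: "48G"
  vram: "24G"
model_specs:
  - name: "example-model"           # illustrative value
    port: 8080                      # public port the proxy listens on
    # internal_port omitted: assign_internal_ports() picks a random free port
    env: {}
    args:
      model: "/models/example.gguf" # hypothetical llama-server argument
      flash-attn: "true"            # "true" becomes a bare --flash-attn flag at spawn time
    vram_usage: "20G"
    ram_usage: "4G"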
llama_proxy_man/src/error.rs (new file, 36 lines)

use axum::{http, response::IntoResponse};
use hyper;
use reqwest;
use reqwest_middleware;
use std::io;
use thiserror::Error;
use anyhow::Error as AnyError;

#[derive(Error, Debug)]
pub enum AppError {
    #[error("Axum Error")]
    AxumError(#[from] axum::Error),
    #[error("Axum Http Error")]
    AxumHttpError(#[from] http::Error),
    #[error("Reqwest Error")]
    ReqwestError(#[from] reqwest::Error),
    #[error("Reqwest Middleware Error")]
    ReqwestMiddlewareError(#[from] reqwest_middleware::Error),
    #[error("Client Error")]
    ClientError(#[from] hyper::Error),
    #[error("Io Error")]
    IoError(#[from] io::Error),
    #[error("Unknown error")]
    Unknown(#[from] AnyError),
}

impl IntoResponse for AppError {
    fn into_response(self) -> axum::response::Response {
        tracing::error!("::AppError::into_response: {:?}", self);
        (
            http::StatusCode::INTERNAL_SERVER_ERROR,
            format!("Something went wrong: {:?}", self),
        )
            .into_response()
    }
}
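Since every variant carries a #[from] conversion, handlers can surface failures with the ? operator and let the IntoResponse impl above render them as a 500. A hypothetical handler, purely for illustration (not part of this change):

use llama_proxy_man::error::AppError;

// An io::Error escaping via `?` converts into AppError::IoError through
// #[from]; axum then renders it with the IntoResponse impl above.
async fn read_config() -> Result<String, AppError> {
    let text = std::fs::read_to_string("config.yaml")?;
    Ok(text)
}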
llama_proxy_man/src/inference_process.rs (new file, 149 lines)

use crate::{config::ModelSpec, error::AppError, state::AppState, util::parse_size};
use anyhow::anyhow;
use itertools::Itertools;
use std::{process::Stdio, sync::Arc};
use tokio::{
    net::TcpStream,
    process::{Child, Command},
    sync::Mutex,
    time::{sleep, Duration},
};

#[derive(Clone, Debug)]
pub struct InferenceProcess {
    pub spec: ModelSpec,
    pub process: Arc<Mutex<Child>>,
}

impl InferenceProcess {
    /// Retrieve a running process from state or spawn a new one.
    pub async fn get_or_spawn(
        spec: &ModelSpec,
        state: AppState,
    ) -> Result<InferenceProcess, AppError> {
        let required_ram = parse_size(&spec.ram_usage).expect("Invalid ram_usage in model spec");
        let required_vram = parse_size(&spec.vram_usage).expect("Invalid vram_usage in model spec");

        {
            let state_guard = state.lock().await;
            if let Some(proc) = state_guard.processes.get(&spec.port) {
                return Ok(proc.clone());
            }
        }

        let mut state_guard = state.lock().await;
        tracing::info!(msg = "App state before spawn", ?state_guard);

        // If not enough resources, stop some running processes.
        if (state_guard.used_ram + required_ram > state_guard.total_ram)
            || (state_guard.used_vram + required_vram > state_guard.total_vram)
        {
            let mut to_remove = Vec::new();
            let processes_by_usage = state_guard
                .processes
                .iter()
                .sorted_by(|(_, a), (_, b)| {
                    parse_size(&b.spec.vram_usage)
                        .unwrap_or(0)
                        .cmp(&parse_size(&a.spec.vram_usage).unwrap_or(0))
                })
                .map(|(port, proc)| (*port, proc.clone()))
                .collect::<Vec<_>>();

            for (port, proc) in processes_by_usage.iter() {
                tracing::info!("Stopping process on port {}", port);
                let mut lock = proc.process.lock().await;
                lock.kill().await.ok();
                to_remove.push(*port);
                state_guard.used_ram = state_guard
                    .used_ram
                    .saturating_sub(parse_size(&proc.spec.ram_usage).unwrap_or(0));
                state_guard.used_vram = state_guard
                    .used_vram
                    .saturating_sub(parse_size(&proc.spec.vram_usage).unwrap_or(0));
                if state_guard.used_ram + required_ram <= state_guard.total_ram
                    && state_guard.used_vram + required_vram <= state_guard.total_vram
                {
                    tracing::info!("Freed enough resources");
                    break;
                }
            }
            for port in to_remove {
                state_guard.processes.remove(&port);
            }
        } else {
            tracing::info!("Sufficient resources available");
        }

        let proc = Self::spawn(spec).await?;
        state_guard.used_ram += required_ram;
        state_guard.used_vram += required_vram;
        state_guard.processes.insert(spec.port, proc.clone());

        sleep(Duration::from_millis(250)).await;
        proc.wait_until_ready().await?;
        sleep(Duration::from_millis(250)).await;
        Ok(proc)
    }

    /// Spawn a new inference process.
    pub async fn spawn(spec: &ModelSpec) -> Result<InferenceProcess, AppError> {
        let args = spec
            .args
            .iter()
            .flat_map(|(k, v)| {
                if v == "true" {
                    vec![format!("--{}", k)]
                } else {
                    vec![format!("--{}", k), v.clone()]
                }
            })
            .collect::<Vec<_>>();

        let internal_port = spec.internal_port.expect("Internal port must be set");

        let mut cmd = Command::new("llama-server");
        cmd.kill_on_drop(true)
            .envs(spec.env.clone())
            .args(&args)
            .arg("--port")
            .arg(format!("{}", internal_port))
            .stdout(Stdio::null())
            .stderr(Stdio::null());

        tracing::info!("Starting llama-server with command: {:?}", cmd);
        let child = cmd.spawn().expect("Failed to start llama-server");

        Ok(InferenceProcess {
            spec: spec.clone(),
            process: Arc::new(Mutex::new(child)),
        })
    }

    pub async fn wait_until_ready(&self) -> Result<(), anyhow::Error> {
        async fn check_running(proc: &InferenceProcess) -> Result<(), AppError> {
            let mut lock = proc.process.clone().lock_owned().await;
            match lock.try_wait()? {
                Some(exit_status) => {
                    tracing::error!("Inference process exited: {:?}", exit_status);
                    Err(AppError::Unknown(anyhow!("Inference process exited")))
                }
                None => Ok(()),
            }
        }

        async fn wait_for_port(port: u16) -> Result<(), anyhow::Error> {
            for _ in 0..10 {
                if TcpStream::connect(("127.0.0.1", port)).await.is_ok() {
                    return Ok(());
                }
                sleep(Duration::from_millis(500)).await;
            }
            Err(anyhow!("Timeout waiting for port"))
        }

        check_running(self).await?;
        wait_for_port(self.spec.internal_port.expect("Internal port must be set")).await?;
        Ok(())
    }
}
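The args handling in spawn treats a value of "true" as a boolean flag and everything else as a flag with a value. A standalone sketch of that expansion, with made-up key names:

use std::collections::HashMap;

fn main() {
    // Keys/values as they might appear under `args:` in config.yaml (illustrative).
    let args = HashMap::from([
        ("ctx-size".to_string(), "4096".to_string()),
        ("flash-attn".to_string(), "true".to_string()),
    ]);
    // Same rule as `InferenceProcess::spawn`: "true" yields a bare flag.
    let cli: Vec<String> = args
        .iter()
        .flat_map(|(k, v)| {
            if v == "true" {
                vec![format!("--{}", k)]
            } else {
                vec![format!("--{}", k), v.clone()]
            }
        })
        .collect();
    // Prints e.g. ["--flash-attn", "--ctx-size", "4096"]; HashMap order varies.
    println!("{:?}", cli);
}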
llama_proxy_man/src/lib.rs (new file, 70 lines)

pub mod config;
pub mod error;
pub mod inference_process;
pub mod logging;
pub mod proxy;
pub mod state;
pub mod util;

use axum::{routing::any, Router};
use config::{AppConfig, ModelSpec};
use state::AppState;
use std::net::SocketAddr;
use tower_http::trace::{
    DefaultMakeSpan, DefaultOnEos, DefaultOnFailure, DefaultOnRequest, DefaultOnResponse,
    TraceLayer,
};
use tracing::Level;

/// Creates an Axum application to handle inference requests for a specific model.
pub fn create_app(spec: &ModelSpec, state: AppState) -> Router {
    Router::new()
        .route(
            "/",
            any({
                let state = state.clone();
                let spec = spec.clone();
                move |req| proxy::handle_request(req, spec.clone(), state.clone())
            }),
        )
        .route(
            "/*path",
            any({
                let state = state.clone();
                let spec = spec.clone();
                move |req| proxy::handle_request(req, spec.clone(), state.clone())
            }),
        )
        .layer(
            TraceLayer::new_for_http()
                .make_span_with(DefaultMakeSpan::new().include_headers(true))
                .on_request(DefaultOnRequest::new().level(Level::DEBUG))
                .on_response(DefaultOnResponse::new().level(Level::TRACE))
                .on_eos(DefaultOnEos::new().level(Level::DEBUG))
                .on_failure(DefaultOnFailure::new().level(Level::ERROR)),
        )
}

/// Starts an inference server for each model defined in the config.
pub async fn start_server(config: AppConfig) {
    let state = AppState::from_config(config.clone());

    let mut handles = Vec::new();
    for spec in config.model_specs {
        let state = state.clone();
        let spec = spec.clone();

        let handle = tokio::spawn(async move {
            let app = create_app(&spec, state);
            let addr = SocketAddr::from(([0, 0, 0, 0], spec.port));
            tracing::info!(msg = "Listening", ?spec);
            let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
            axum::serve(listener, app.into_make_service())
                .await
                .unwrap();
        });
        handles.push(handle);
    }

    futures::future::join_all(handles).await;
}
llama_proxy_man/src/logging.rs

@@ -5,6 +5,8 @@ use std::{
     task::{Context, Poll},
 };
 
+use std::sync::Once;
+
 use axum::{body::Body, http::Request};
 use pin_project_lite::pin_project;
 use tower::{Layer, Service};

@@ -81,3 +83,18 @@
         }
     }
 }
+
+static INIT: Once = Once::new();
+
+pub fn initialize_logger() {
+    INIT.call_once(|| {
+        let env_filter = tracing_subscriber::EnvFilter::builder()
+            .with_default_directive(tracing::level_filters::LevelFilter::INFO.into())
+            .from_env_lossy();
+
+        tracing_subscriber::fmt()
+            .compact()
+            .with_env_filter(env_filter)
+            .init();
+    });
+}
llama_proxy_man/src/main.rs

@@ -1,428 +1,11 @@
-mod logging;
-
-use anyhow::anyhow;
-use axum::{
-    body::Body,
-    http::{self, Request, Response},
-    response::IntoResponse,
-    routing::any,
-    Router,
-};
-use futures;
-use itertools::Itertools;
-use reqwest::Client;
-use serde::Deserialize;
-use std::{collections::HashMap, net::SocketAddr, process::Stdio, sync::Arc};
-use tokio::{
-    net::TcpStream,
-    process::{Child, Command},
-    sync::Mutex,
-    time::{sleep, Duration},
-};
-use tower_http::trace::{
-    DefaultMakeSpan, DefaultOnEos, DefaultOnFailure, DefaultOnRequest, DefaultOnResponse,
-    TraceLayer,
-};
-use tracing::Level;
-
-use std::sync::Once;
-
-pub static INIT: Once = Once::new();
-
-pub fn initialize_logger() {
-    INIT.call_once(|| {
-        let env_filter = tracing_subscriber::EnvFilter::builder()
-            .with_default_directive(tracing::level_filters::LevelFilter::INFO.into())
-            .from_env_lossy();
-
-        tracing_subscriber::fmt()
-            .compact()
-            .with_env_filter(env_filter)
-            .init();
-    });
-}
-
-#[derive(Clone, Debug, Deserialize)]
-struct Config {
-    hardware: Hardware,
-    models: Vec<ModelConfig>,
-}
-
-impl Config {
-    // TODO split up into raw deser config and "parsed"/"processed" config which always has a port
-    fn pick_open_ports(self) -> Self {
-        let mut config = self.clone();
-        for model in &mut config.models {
-            if model.internal_port.is_none() {
-                model.internal_port = Some(
-                    openport::pick_random_unused_port()
-                        .expect(format!("No open port found for {:?}", model).as_str()),
-                );
-            }
-        }
-        config
-    }
-}
-
-#[derive(Clone, Debug, Deserialize)]
-struct Hardware {
-    ram: String,
-    vram: String,
-}
-
-#[derive(Clone, Debug, Deserialize)]
-struct ModelConfig {
-    #[allow(dead_code)]
-    name: String,
-    port: u16,
-    internal_port: Option<u16>,
-    env: HashMap<String, String>,
-    args: HashMap<String, String>,
-    vram_usage: String,
-    ram_usage: String,
-}
-
-#[derive(Clone, Debug)]
-struct LlamaInstance {
-    config: ModelConfig,
-    process: Arc<Mutex<Child>>,
-    // busy: bool,
-}
-
-#[derive(Clone, Debug)]
-struct SharedState {
-    total_ram: u64,
-    total_vram: u64,
-    used_ram: u64,
-    used_vram: u64,
-    instances: HashMap<u16, LlamaInstance>,
-}
-
+use llama_proxy_man::{config::AppConfig, logging, start_server};
+use tokio;
 #[tokio::main]
 async fn main() {
+    // TODO add autostart of models based on config
+    // abstract starting logic out of handler for this to allow seperate calls to start
+    // maybe add to SharedState & LLamaInstance ?
+    logging::initialize_logger();
-    initialize_logger();
-    // Read and parse the YAML configuration
-    let config_str = std::fs::read_to_string("config.yaml").expect("Failed to read config.yaml");
-    let config: Config = serde_yaml::from_str::<Config>(&config_str)
-        .expect("Failed to parse config.yaml")
-        .pick_open_ports();
+    let config = AppConfig::default_from_pwd_yml();
-
-    // Parse hardware resources
-    let total_ram = parse_size(&config.hardware.ram).expect("Invalid RAM size in config");
-    let total_vram = parse_size(&config.hardware.vram).expect("Invalid VRAM size in config");
-
-    // Initialize shared state
-    let shared_state = Arc::new(Mutex::new(SharedState {
-        total_ram,
-        total_vram,
-        used_ram: 0,
-        used_vram: 0,
-        instances: HashMap::new(),
-    }));
-
-    // For each model, set up an axum server listening on the specified port
-    let mut handles = Vec::new();
-
-    for model_config in config.models {
-        let state = shared_state.clone();
-        let model_config = model_config.clone();
-
-        let handle = tokio::spawn(async move {
-            let model_config = model_config.clone();
-
-            let app = Router::new()
-                .route(
-                    "/",
-                    any({
-                        let state = state.clone();
-                        let model_config = model_config.clone();
-                        move |req| handle_request(req, model_config, state)
-                    }),
-                )
-                .route(
-                    "/*path",
-                    any({
-                        let state = state.clone();
-                        let model_config = model_config.clone();
-                        move |req| handle_request(req, model_config, state)
-                    }),
-                )
-                .layer(
-                    TraceLayer::new_for_http()
-                        .make_span_with(DefaultMakeSpan::new().include_headers(true))
-                        .on_request(DefaultOnRequest::new().level(Level::DEBUG))
-                        .on_response(DefaultOnResponse::new().level(Level::TRACE))
-                        .on_eos(DefaultOnEos::new().level(Level::DEBUG))
-                        .on_failure(DefaultOnFailure::new().level(Level::ERROR)),
-                );
-
-            let addr = SocketAddr::from(([0, 0, 0, 0], model_config.port));
-
-            tracing::info!(msg = "Listening", ?model_config);
-            let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
-
-            axum::serve(listener, app.into_make_service())
-                .await
-                .unwrap();
-        });
-
-        handles.push(handle);
-    }
-
-    futures::future::join_all(handles).await;
-}
-
-use thiserror::Error;
-
-#[derive(Error, Debug)]
-pub enum AppError {
-    #[error("Axum Error")]
-    AxumError(#[from] axum::Error),
-    #[error("Axum Http Error")]
-    AxumHttpError(#[from] axum::http::Error),
-    #[error("Reqwest Error")]
-    ReqwestError(#[from] reqwest::Error),
-    #[error("Reqwest Middleware Error")]
-    ReqwestMiddlewareError(#[from] reqwest_middleware::Error),
-    #[error("Client Error")]
-    ClientError(#[from] hyper::Error),
-    #[error("Unknown error")]
-    Unknown(#[from] anyhow::Error),
-}
-
-// Tell axum how to convert `AppError` into a response.
-impl IntoResponse for AppError {
-    fn into_response(self) -> axum::response::Response {
-        tracing::error!("::AppError::into_response:: {:?}", self);
-        (
-            http::StatusCode::INTERNAL_SERVER_ERROR,
-            format!("Something went wrong: {:?}", self),
-        )
-            .into_response()
-    }
-}
-
-async fn handle_request(
-    req: Request<Body>,
-    model_config: ModelConfig,
-    state: Arc<Mutex<SharedState>>,
-    // ) -> Result<Response<Body>, anyhow::Error> {
-) -> impl IntoResponse {
-    let model_ram_usage = parse_size(&model_config.ram_usage).expect("Invalid ram_usage");
-    let model_vram_usage = parse_size(&model_config.vram_usage).expect("Invalid vram_usage");
-
-    let instance = {
-        let mut state = state.lock().await;
-
-        // Check if instance is running
-        if let Some(instance) = state.instances.get_mut(&model_config.port) {
-            // instance.busy = true;
-            instance.to_owned()
-        } else {
-            // Check resources
-            tracing::info!(msg = "Current state", ?state);
-            if ((state.used_ram + model_ram_usage) > state.total_ram)
-                || ((state.used_vram + model_vram_usage) > state.total_vram)
-            {
-                // Stop other instances
-                let mut to_remove = Vec::new();
-                // TODO Actual smart stopping logic
-                // - search for smallest single model to stop to get enough room
-                // - if not possible search for smallest number of models to stop with lowest
-                // amount of "overshot
-                let instances_by_size =
-                    state
-                        .instances
-                        .clone()
-                        .into_iter()
-                        .sorted_by(|(_, el_a), (_, el_b)| {
-                            Ord::cmp(
-                                &parse_size(el_b.config.vram_usage.as_str()),
-                                &parse_size(el_a.config.vram_usage.as_str()),
-                            )
-                        });
-                for (port, instance) in instances_by_size {
-                    // if !instance.busy {
-                    tracing::info!("Stopping instance on port {}", port);
-                    let mut process = instance.process.lock().await;
-                    process.kill().await.ok();
-                    to_remove.push(port);
-                    state.used_ram -= parse_size(&instance.config.ram_usage).unwrap_or(0);
-                    state.used_vram -= parse_size(&instance.config.vram_usage).unwrap_or(0);
-                    if state.used_ram + model_ram_usage <= state.total_ram
-                        && state.used_vram + model_vram_usage <= state.total_vram
-                    {
-                        tracing::info!("Should have enough ram now");
-                        break;
-                    }
-                    // }
-                }
-                for port in to_remove {
-                    tracing::info!("Removing instance on port {}", port);
-                    state.instances.remove(&port);
-                }
-            } else {
-                tracing::info!("Already enough res free");
-            }
-
-            // Start new instance
-            let args = model_config
-                .args
-                .iter()
-                .flat_map(|(k, v)| {
-                    if v == "true" {
-                        vec![format!("--{}", k)]
-                    } else {
-                        vec![format!("--{}", k), v.clone()]
-                    }
-                })
-                .collect::<Vec<_>>();
-
-            let mut cmd = Command::new("llama-server");
-            cmd.kill_on_drop(true);
-            cmd.envs(model_config.env.clone());
-            cmd.args(&args);
-            // TODO use openport crate via pick_random_unused_port for determining these
-            cmd.arg("--port");
-            cmd.arg(format!(
-                "{}",
-                model_config
-                    .internal_port
-                    .expect("Unexpected empty port, should've been picked")
-            ));
-            cmd.stdout(Stdio::null()).stderr(Stdio::null()); // TODO save output and allow retrieval via api
-
-            tracing::info!("Starting llama-server with {:?}", cmd);
-            let process = Arc::new(Mutex::new(
-                cmd.spawn().expect("Failed to start llama-server"),
-            ));
-
-            state.used_ram += model_ram_usage;
-            state.used_vram += model_vram_usage;
-
-            let instance = LlamaInstance {
-                config: model_config.clone(),
-                process,
-                // busy: true,
-            };
-            sleep(Duration::from_millis(500)).await;
-            state.instances.insert(model_config.port, instance.clone());
-
-            instance
-        }
-    };
-
-    // Wait for the instance to be ready
-    is_llama_instance_running(&instance).await?;
-    wait_for_port(
-        model_config
-            .internal_port
-            .expect("Unexpected empty port, should've been picked"),
-    )
-    .await?;
-
-    // Proxy the request
-    let retry_policy = reqwest_retry::policies::ExponentialBackoff::builder()
-        .retry_bounds(Duration::from_millis(500), Duration::from_secs(8))
-        .jitter(reqwest_retry::Jitter::None)
-        .base(2)
-        .build_with_max_retries(8);
-
-    let client = reqwest_middleware::ClientBuilder::new(Client::new())
-        .with(reqwest_retry::RetryTransientMiddleware::new_with_policy(
-            retry_policy,
-        ))
-        .build();
-
-    let uri = format!(
-        "http://127.0.0.1:{}{}",
-        model_config
-            .internal_port
-            .expect("Unexpected empty port, should've been picked"),
-        req.uri().path_and_query().map(|x| x.as_str()).unwrap_or("")
-    );
-
-    let mut request_builder = client.request(req.method().clone(), &uri);
-    // let mut request_builder = reqwest::RequestBuilder::from_parts(client, req.method().clone())
-
-    for (name, value) in req.headers() {
-        request_builder = request_builder.header(name, value);
-    }
-
-    let body_bytes = axum::body::to_bytes(
-        req.into_body(),
-        parse_size("1G").unwrap().try_into().unwrap(),
-    )
-    .await?;
-    let request = request_builder.body(body_bytes).build()?;
-    // tracing::info!("Proxying request to {}", uri);
-    let response = client.execute(request).await?;
-
-    // tracing::info!("Received response from {}", uri);
-    let mut builder = Response::builder().status(response.status());
-    for (name, value) in response.headers() {
-        builder = builder.header(name, value);
-    }
-
-    // let bytes = response.bytes().await?;
-
-    let byte_stream = response.bytes_stream();
-
-    let body = axum::body::Body::from_stream(byte_stream);
-
-    tracing::debug!("streaming response on port: {}", model_config.port);
-    let response = builder.body(body)?;
-    Ok::<axum::http::Response<Body>, AppError>(response)
-}
-
-async fn wait_for_port(port: u16) -> Result<(), anyhow::Error> {
-    for _ in 0..10 {
-        if TcpStream::connect(("127.0.0.1", port)).await.is_ok() {
-            return Ok(());
-        }
-        sleep(Duration::from_millis(500)).await;
-    }
-    Err(anyhow!("Timeout waiting for port"))
-}
-
-// reimplement wait_for_port using a LlamaInstance, also check if the process is still running
-async fn is_llama_instance_running(instance: &LlamaInstance) -> Result<(), anyhow::Error> {
-    match instance.process.to_owned().lock_owned().await.try_wait()? {
-        Some(exit_status) => {
-            let msg = "Llama instance exited";
-            tracing::error!(msg, ?exit_status);
-            Err(anyhow!(msg))
-        }
-        None => Ok(()),
-    }
-}
-
-fn parse_size(size_str: &str) -> Option<u64> {
-    let mut num = String::new();
-    let mut unit = String::new();
-
-    for c in size_str.chars() {
-        if c.is_digit(10) || c == '.' {
-            num.push(c);
-        } else {
-            unit.push(c);
-        }
-    }
-
-    let num: f64 = num.parse().ok()?;
-
-    let multiplier = match unit.to_lowercase().as_str() {
-        "g" | "gb" => 1024 * 1024 * 1024,
-        "m" | "mb" => 1024 * 1024,
-        "k" | "kb" => 1024,
-        _ => panic!("Invalid Size"),
-    };
-
-    let res = (num * multiplier as f64) as u64;
-    Some(res)
-}
+    start_server(config).await;
 }
llama_proxy_man/src/proxy.rs (new file, 69 lines)

use crate::{
    config::ModelSpec, error::AppError, inference_process::InferenceProcess, state::AppState,
    util::parse_size,
};
use axum::{
    body::Body,
    http::{Request, Response},
    response::IntoResponse,
};
use reqwest::Client;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};

pub async fn proxy_request(
    req: Request<Body>,
    spec: &ModelSpec,
) -> Result<Response<Body>, AppError> {
    let retry_policy = ExponentialBackoff::builder()
        .retry_bounds(
            std::time::Duration::from_millis(500),
            std::time::Duration::from_secs(8),
        )
        .jitter(reqwest_retry::Jitter::None)
        .base(2)
        .build_with_max_retries(8);

    let client: ClientWithMiddleware = ClientBuilder::new(Client::new())
        .with(RetryTransientMiddleware::new_with_policy(retry_policy))
        .build();

    let internal_port = spec.internal_port.expect("Internal port must be set");
    let uri = format!(
        "http://127.0.0.1:{}{}",
        internal_port,
        req.uri()
            .path_and_query()
            .map(|pq| pq.as_str())
            .unwrap_or("")
    );

    let mut request_builder = client.request(req.method().clone(), &uri);
    for (name, value) in req.headers() {
        request_builder = request_builder.header(name, value);
    }

    let max_size = parse_size("1G").unwrap() as usize;
    let body_bytes = axum::body::to_bytes(req.into_body(), max_size).await?;
    let request = request_builder.body(body_bytes).build()?;
    let response = client.execute(request).await?;

    let mut builder = Response::builder().status(response.status());
    for (name, value) in response.headers() {
        builder = builder.header(name, value);
    }

    let byte_stream = response.bytes_stream();
    let body = Body::from_stream(byte_stream);
    Ok(builder.body(body)?)
}

pub async fn handle_request(
    req: Request<Body>,
    spec: ModelSpec,
    state: AppState,
) -> impl IntoResponse {
    let _ = InferenceProcess::get_or_spawn(&spec, state).await?;
    let response = proxy_request(req, &spec).await?;
    Ok::<Response<Body>, AppError>(response)
}
llama_proxy_man/src/state.rs (new file, 36 lines)

use crate::{config::AppConfig, inference_process::InferenceProcess, util::parse_size};
use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex;

#[derive(Clone, Debug)]
pub struct ResourceManager {
    pub total_ram: u64,
    pub total_vram: u64,
    pub used_ram: u64,
    pub used_vram: u64,
    pub processes: HashMap<u16, InferenceProcess>,
}

pub type ResourceManagerHandle = Arc<Mutex<ResourceManager>>;

#[derive(Clone, Debug, derive_more::Deref)]
pub struct AppState(pub ResourceManagerHandle);

impl AppState {
    pub fn from_config(config: AppConfig) -> Self {
        let total_ram =
            parse_size(&config.system_resources.ram).expect("Invalid RAM size in config");
        let total_vram =
            parse_size(&config.system_resources.vram).expect("Invalid VRAM size in config");

        let resource_manager = ResourceManager {
            total_ram,
            total_vram,
            used_ram: 0,
            used_vram: 0,
            processes: HashMap::new(),
        };

        Self(Arc::new(Mutex::new(resource_manager)))
    }
}
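Because AppState derives derive_more::Deref over its Arc<Mutex<ResourceManager>> field, callers can lock the state directly. A small usage sketch, assuming a config.yaml is present in the working directory:

use llama_proxy_man::{config::AppConfig, state::AppState};

#[tokio::main]
async fn main() {
    let state = AppState::from_config(AppConfig::default_from_pwd_yml());
    // Deref exposes the inner Arc<Mutex<ResourceManager>>, which is what
    // lets inference_process.rs call `state.lock().await` on an AppState.
    let guard = state.lock().await;
    println!("total_ram = {} bytes", guard.total_ram);
}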
llama_proxy_man/src/util.rs (new file, 22 lines)

pub fn parse_size(size_str: &str) -> Option<u64> {
    let mut num = String::new();
    let mut unit = String::new();

    for c in size_str.chars() {
        if c.is_digit(10) || c == '.' {
            num.push(c);
        } else {
            unit.push(c);
        }
    }

    let num: f64 = num.parse().ok()?;
    let multiplier = match unit.to_lowercase().as_str() {
        "g" | "gb" => 1024 * 1024 * 1024,
        "m" | "mb" => 1024 * 1024,
        "k" | "kb" => 1024,
        _ => panic!("Invalid Size"),
    };

    Some((num * multiplier as f64) as u64)
}
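parse_size accepts case-insensitive g/gb, m/mb, and k/kb suffixes and multiplies by powers of 1024; note that an unrecognized unit panics rather than returning None. A few examples as assertions:

use llama_proxy_man::util::parse_size;

fn main() {
    assert_eq!(parse_size("1k"), Some(1024));
    assert_eq!(parse_size("10MB"), Some(10 * 1024 * 1024));
    assert_eq!(parse_size("1.5G"), Some(1_610_612_736)); // 1.5 * 1024^3
    assert_eq!(parse_size("gb"), None); // no numeric part
    // parse_size("5TB") would panic with "Invalid Size".
}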