Compare commits
No commits in common. "bebdd35c6e5d5cbc9de44a260a82112b3cbb524b" and "d39459f2a9272dcb3ec6a40ca7ff6cca2ad6d5ff" have entirely different histories.
bebdd35c6e...d39459f2a9
13 changed files with 427 additions and 489 deletions
Cargo.lock (generated): 25 changed lines
@@ -1141,26 +1141,6 @@ dependencies = [
  "syn 2.0.98",
 ]
 
-[[package]]
-name = "derive_more"
-version = "2.0.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678"
-dependencies = [
- "derive_more-impl",
-]
-
-[[package]]
-name = "derive_more-impl"
-version = "2.0.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3"
-dependencies = [
- "proc-macro2",
- "quote",
- "syn 2.0.98",
-]
-
 [[package]]
 name = "digest"
 version = "0.10.7"
@@ -2970,7 +2950,7 @@ dependencies = [
  "chrono",
  "console_error_panic_hook",
  "dashmap",
- "derive_more 0.99.19",
+ "derive_more",
  "futures",
  "futures-util",
  "gloo-net 0.5.0",
@@ -3022,7 +3002,6 @@ version = "0.1.1"
 dependencies = [
  "anyhow",
  "axum",
- "derive_more 2.0.1",
  "futures",
  "hyper",
  "itertools 0.13.0",
@@ -4802,7 +4781,7 @@ checksum = "df320f1889ac4ba6bc0cdc9c9af7af4bd64bb927bccdf32d81140dc1f9be12fe"
 dependencies = [
  "bitflags 1.3.2",
  "cssparser",
- "derive_more 0.99.19",
+ "derive_more",
  "fxhash",
  "log",
  "matches",
README.md: 10 changed lines
@@ -1,15 +1,5 @@
 # Redvau.lt AI Monorepo
 
-## Short Term Todo
-- [x] Prepare proxy man for embedding
-- [-] Improve markdown rendering in forge chat
-- [ ] Embed proxy man & add simple ui in forge
-  - [ ] dumb embed process on startup
-  - [ ] View current instances/models/virtual endpoints
-  - [ ] Edit
-  - [ ] Add new (from configurable model folder)
-
-
 ## Current Repos
 
 #### llama-forge-rs:
@@ -14,7 +14,6 @@ crate-type = ["cdylib", "rlib"]
 
 [dependencies]
-wasm-bindgen = "=0.2.100"
 # TODO Update to 0.7
 leptos = { version = "0.6", features = [
  "serde",
  "nightly",
@@ -88,9 +87,10 @@ pulldown-cmark = { version = "0.12.2", features = ["serde"] }
 # qdrant-client = "1.11.2"
 # swiftide = "0.9.1"
 
-# TODO Add desktop/gui feature (or maybe server/headless only?)
+# TODO Add desktop/gui feature
 [features]
-default = ["ssr"]
+# default = ["ssr"]
+default = ["hydrate"]
 hydrate = ["leptos/hydrate", "leptos_meta/hydrate", "leptos_router/hydrate"]
 ssr = [
  "dep:async-broadcast",
@@ -29,4 +29,3 @@ reqwest-retry = "0.6.1"
 reqwest-middleware = { version = "0.3.3", features = ["charset", "http2", "json", "multipart", "rustls-tls"] }
 itertools = "0.13.0"
 openport = { version = "0.1.1", features = ["rand"] }
-derive_more = { version = "2.0.1", features = ["deref"] }
@@ -1,48 +0,0 @@
-use serde::Deserialize;
-use std::{collections::HashMap, fs};
-
-#[derive(Clone, Debug, Deserialize)]
-pub struct AppConfig {
-    pub system_resources: SystemResources,
-    pub model_specs: Vec<ModelSpec>,
-}
-
-impl AppConfig {
-    pub fn default_from_pwd_yml() -> Self {
-        let config_str = fs::read_to_string("config.yaml").expect("Failed to read config.yaml");
-        serde_yaml::from_str::<Self>(&config_str)
-            .expect("Failed to parse config.yaml")
-            .assign_internal_ports()
-    }
-
-    // Ensure every model has an internal port
-    pub fn assign_internal_ports(self) -> Self {
-        let mut config = self.clone();
-        for model in &mut config.model_specs {
-            if model.internal_port.is_none() {
-                model.internal_port = Some(
-                    openport::pick_random_unused_port()
-                        .expect(&format!("No open port found for {:?}", model)),
-                );
-            }
-        }
-        config
-    }
-}
-
-#[derive(Clone, Debug, Deserialize)]
-pub struct SystemResources {
-    pub ram: String,
-    pub vram: String,
-}
-
-#[derive(Clone, Debug, Deserialize)]
-pub struct ModelSpec {
-    pub name: String,
-    pub port: u16,
-    pub internal_port: Option<u16>,
-    pub env: HashMap<String, String>,
-    pub args: HashMap<String, String>,
-    pub vram_usage: String,
-    pub ram_usage: String,
-}
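For orientation, a minimal sketch of a config.yaml this deleted deserializer would have accepted. Only the field names come from the structs above; the model name, port, args and sizes are illustrative assumptions:

    use llama_proxy_man::config::AppConfig;

    fn main() {
        // internal_port is omitted on purpose: it is Option<u16>, so serde
        // defaults it to None and assign_internal_ports fills it in.
        let yaml = r#"
    system_resources:
      ram: "48G"
      vram: "24G"
    model_specs:
      - name: "example-7b"     # assumed model name
        port: 8080             # assumed public port
        env: {}
        args:
          ctx-size: "4096"     # assumed llama-server flag
        vram_usage: "8G"
        ram_usage: "2G"
    "#;
        let config: AppConfig = serde_yaml::from_str(yaml).expect("parse failed");
        let config = config.assign_internal_ports();
        assert!(config.model_specs[0].internal_port.is_some());
    }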
@@ -1,36 +0,0 @@
-use axum::{http, response::IntoResponse};
-use hyper;
-use reqwest;
-use reqwest_middleware;
-use std::io;
-use thiserror::Error;
-use anyhow::Error as AnyError;
-
-#[derive(Error, Debug)]
-pub enum AppError {
-    #[error("Axum Error")]
-    AxumError(#[from] axum::Error),
-    #[error("Axum Http Error")]
-    AxumHttpError(#[from] http::Error),
-    #[error("Reqwest Error")]
-    ReqwestError(#[from] reqwest::Error),
-    #[error("Reqwest Middleware Error")]
-    ReqwestMiddlewareError(#[from] reqwest_middleware::Error),
-    #[error("Client Error")]
-    ClientError(#[from] hyper::Error),
-    #[error("Io Error")]
-    IoError(#[from] io::Error),
-    #[error("Unknown error")]
-    Unknown(#[from] AnyError),
-}
-
-impl IntoResponse for AppError {
-    fn into_response(self) -> axum::response::Response {
-        tracing::error!("::AppError::into_response: {:?}", self);
-        (
-            http::StatusCode::INTERNAL_SERVER_ERROR,
-            format!("Something went wrong: {:?}", self),
-        )
-            .into_response()
-    }
-}
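Because AppError implements IntoResponse, an axum handler can bubble any of these variants up with `?` and the framework renders the 500 body. A minimal sketch; the handler name and health URL are assumptions, not from the repo:

    async fn upstream_health() -> Result<String, AppError> {
        // reqwest::Error converts into AppError::ReqwestError via #[from],
        // so plain `?` suffices at every await point.
        let body = reqwest::get("http://127.0.0.1:8080/health") // assumed URL
            .await?
            .text()
            .await?;
        Ok(body)
    }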
@@ -1,149 +0,0 @@
-use crate::{config::ModelSpec, error::AppError, state::AppState, util::parse_size};
-use anyhow::anyhow;
-use itertools::Itertools;
-use std::{process::Stdio, sync::Arc};
-use tokio::{
-    net::TcpStream,
-    process::{Child, Command},
-    sync::Mutex,
-    time::{sleep, Duration},
-};
-
-#[derive(Clone, Debug)]
-pub struct InferenceProcess {
-    pub spec: ModelSpec,
-    pub process: Arc<Mutex<Child>>,
-}
-
-impl InferenceProcess {
-    /// Retrieve a running process from state or spawn a new one.
-    pub async fn get_or_spawn(
-        spec: &ModelSpec,
-        state: AppState,
-    ) -> Result<InferenceProcess, AppError> {
-        let required_ram = parse_size(&spec.ram_usage).expect("Invalid ram_usage in model spec");
-        let required_vram = parse_size(&spec.vram_usage).expect("Invalid vram_usage in model spec");
-
-        {
-            let state_guard = state.lock().await;
-            if let Some(proc) = state_guard.processes.get(&spec.port) {
-                return Ok(proc.clone());
-            }
-        }
-
-        let mut state_guard = state.lock().await;
-        tracing::info!(msg = "App state before spawn", ?state_guard);
-
-        // If not enough resources, stop some running processes.
-        if (state_guard.used_ram + required_ram > state_guard.total_ram)
-            || (state_guard.used_vram + required_vram > state_guard.total_vram)
-        {
-            let mut to_remove = Vec::new();
-            let processes_by_usage = state_guard
-                .processes
-                .iter()
-                .sorted_by(|(_, a), (_, b)| {
-                    parse_size(&b.spec.vram_usage)
-                        .unwrap_or(0)
-                        .cmp(&parse_size(&a.spec.vram_usage).unwrap_or(0))
-                })
-                .map(|(port, proc)| (*port, proc.clone()))
-                .collect::<Vec<_>>();
-
-            for (port, proc) in processes_by_usage.iter() {
-                tracing::info!("Stopping process on port {}", port);
-                let mut lock = proc.process.lock().await;
-                lock.kill().await.ok();
-                to_remove.push(*port);
-                state_guard.used_ram = state_guard
-                    .used_ram
-                    .saturating_sub(parse_size(&proc.spec.ram_usage).unwrap_or(0));
-                state_guard.used_vram = state_guard
-                    .used_vram
-                    .saturating_sub(parse_size(&proc.spec.vram_usage).unwrap_or(0));
-                if state_guard.used_ram + required_ram <= state_guard.total_ram
-                    && state_guard.used_vram + required_vram <= state_guard.total_vram
-                {
-                    tracing::info!("Freed enough resources");
-                    break;
-                }
-            }
-            for port in to_remove {
-                state_guard.processes.remove(&port);
-            }
-        } else {
-            tracing::info!("Sufficient resources available");
-        }
-
-        let proc = Self::spawn(spec).await?;
-        state_guard.used_ram += required_ram;
-        state_guard.used_vram += required_vram;
-        state_guard.processes.insert(spec.port, proc.clone());
-
-        sleep(Duration::from_millis(250)).await;
-        proc.wait_until_ready().await?;
-        sleep(Duration::from_millis(250)).await;
-        Ok(proc)
-    }
-
-    /// Spawn a new inference process.
-    pub async fn spawn(spec: &ModelSpec) -> Result<InferenceProcess, AppError> {
-        let args = spec
-            .args
-            .iter()
-            .flat_map(|(k, v)| {
-                if v == "true" {
-                    vec![format!("--{}", k)]
-                } else {
-                    vec![format!("--{}", k), v.clone()]
-                }
-            })
-            .collect::<Vec<_>>();
-
-        let internal_port = spec.internal_port.expect("Internal port must be set");
-
-        let mut cmd = Command::new("llama-server");
-        cmd.kill_on_drop(true)
-            .envs(spec.env.clone())
-            .args(&args)
-            .arg("--port")
-            .arg(format!("{}", internal_port))
-            .stdout(Stdio::null())
-            .stderr(Stdio::null());
-
-        tracing::info!("Starting llama-server with command: {:?}", cmd);
-        let child = cmd.spawn().expect("Failed to start llama-server");
-
-        Ok(InferenceProcess {
-            spec: spec.clone(),
-            process: Arc::new(Mutex::new(child)),
-        })
-    }
-
-    pub async fn wait_until_ready(&self) -> Result<(), anyhow::Error> {
-        async fn check_running(proc: &InferenceProcess) -> Result<(), AppError> {
-            let mut lock = proc.process.clone().lock_owned().await;
-            match lock.try_wait()? {
-                Some(exit_status) => {
-                    tracing::error!("Inference process exited: {:?}", exit_status);
-                    Err(AppError::Unknown(anyhow!("Inference process exited")))
-                }
-                None => Ok(()),
-            }
-        }
-
-        async fn wait_for_port(port: u16) -> Result<(), anyhow::Error> {
-            for _ in 0..10 {
-                if TcpStream::connect(("127.0.0.1", port)).await.is_ok() {
-                    return Ok(());
-                }
-                sleep(Duration::from_millis(500)).await;
-            }
-            Err(anyhow!("Timeout waiting for port"))
-        }
-
-        check_running(self).await?;
-        wait_for_port(self.spec.internal_port.expect("Internal port must be set")).await?;
-        Ok(())
-    }
-}
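Since the args map drives the llama-server command line, here is a small self-contained check of the flat_map rule used in spawn above. The key/value pairs are toy inputs; real flag names depend on the llama-server build:

    use std::collections::HashMap;

    fn main() {
        let args = HashMap::from([
            ("ctx-size".to_string(), "4096".to_string()),
            ("flash-attn".to_string(), "true".to_string()),
        ]);
        let argv: Vec<String> = args
            .iter()
            .flat_map(|(k, v)| {
                if v == "true" {
                    vec![format!("--{}", k)] // "true" becomes a bare switch
                } else {
                    vec![format!("--{}", k), v.clone()] // otherwise --key value
                }
            })
            .collect();
        // HashMap iteration order is unspecified, but the resulting set is
        // {"--ctx-size", "4096", "--flash-attn"}.
        println!("{:?}", argv);
    }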
@@ -1,70 +0,0 @@
-pub mod config;
-pub mod error;
-pub mod inference_process;
-pub mod logging;
-pub mod proxy;
-pub mod state;
-pub mod util;
-
-use axum::{routing::any, Router};
-use config::{AppConfig, ModelSpec};
-use state::AppState;
-use std::net::SocketAddr;
-use tower_http::trace::{
-    DefaultMakeSpan, DefaultOnEos, DefaultOnFailure, DefaultOnRequest, DefaultOnResponse,
-    TraceLayer,
-};
-use tracing::Level;
-
-/// Creates an Axum application to handle inference requests for a specific model.
-pub fn create_app(spec: &ModelSpec, state: AppState) -> Router {
-    Router::new()
-        .route(
-            "/",
-            any({
-                let state = state.clone();
-                let spec = spec.clone();
-                move |req| proxy::handle_request(req, spec.clone(), state.clone())
-            }),
-        )
-        .route(
-            "/*path",
-            any({
-                let state = state.clone();
-                let spec = spec.clone();
-                move |req| proxy::handle_request(req, spec.clone(), state.clone())
-            }),
-        )
-        .layer(
-            TraceLayer::new_for_http()
-                .make_span_with(DefaultMakeSpan::new().include_headers(true))
-                .on_request(DefaultOnRequest::new().level(Level::DEBUG))
-                .on_response(DefaultOnResponse::new().level(Level::TRACE))
-                .on_eos(DefaultOnEos::new().level(Level::DEBUG))
-                .on_failure(DefaultOnFailure::new().level(Level::ERROR)),
-        )
-}
-
-/// Starts an inference server for each model defined in the config.
-pub async fn start_server(config: AppConfig) {
-    let state = AppState::from_config(config.clone());
-
-    let mut handles = Vec::new();
-    for spec in config.model_specs {
-        let state = state.clone();
-        let spec = spec.clone();
-
-        let handle = tokio::spawn(async move {
-            let app = create_app(&spec, state);
-            let addr = SocketAddr::from(([0, 0, 0, 0], spec.port));
-            tracing::info!(msg = "Listening", ?spec);
-            let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
-            axum::serve(listener, app.into_make_service())
-                .await
-                .unwrap();
-        });
-        handles.push(handle);
-    }
-
-    futures::future::join_all(handles).await;
-}
-
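This lib.rs is the entire embedding surface of the proxy: the left side's main.rs (visible in the main.rs diff below) boots it with just three calls:

    use llama_proxy_man::{config::AppConfig, logging, start_server};

    #[tokio::main]
    async fn main() {
        logging::initialize_logger();
        let config = AppConfig::default_from_pwd_yml();
        start_server(config).await;
    }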
@@ -5,8 +5,6 @@ use std::{
     task::{Context, Poll},
 };
 
-use std::sync::Once;
-
 use axum::{body::Body, http::Request};
 use pin_project_lite::pin_project;
 use tower::{Layer, Service};
@@ -83,18 +81,3 @@
         }
     }
 }
-
-static INIT: Once = Once::new();
-
-pub fn initialize_logger() {
-    INIT.call_once(|| {
-        let env_filter = tracing_subscriber::EnvFilter::builder()
-            .with_default_directive(tracing::level_filters::LevelFilter::INFO.into())
-            .from_env_lossy();
-
-        tracing_subscriber::fmt()
-            .compact()
-            .with_env_filter(env_filter)
-            .init();
-    });
-}
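Because the filter is built with from_env_lossy, log verbosity is controlled by the standard RUST_LOG variable and falls back to INFO when it is unset or malformed. For example (the package name passed to -p is an assumption based on the crate name):

    RUST_LOG=llama_proxy_man=debug cargo run -p llama-proxy-man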
@@ -1,11 +1,428 @@
-use llama_proxy_man::{config::AppConfig, logging, start_server};
-use tokio;
-
-#[tokio::main]
-async fn main() {
-    logging::initialize_logger();
-
-    let config = AppConfig::default_from_pwd_yml();
-
-    start_server(config).await;
-}
+mod logging;
+
+use anyhow::anyhow;
+use axum::{
+    body::Body,
+    http::{self, Request, Response},
+    response::IntoResponse,
+    routing::any,
+    Router,
+};
+use futures;
+use itertools::Itertools;
+use reqwest::Client;
+use serde::Deserialize;
+use std::{collections::HashMap, net::SocketAddr, process::Stdio, sync::Arc};
+use tokio::{
+    net::TcpStream,
+    process::{Child, Command},
+    sync::Mutex,
+    time::{sleep, Duration},
+};
+use tower_http::trace::{
+    DefaultMakeSpan, DefaultOnEos, DefaultOnFailure, DefaultOnRequest, DefaultOnResponse,
+    TraceLayer,
+};
+use tracing::Level;
+
+use std::sync::Once;
+
+pub static INIT: Once = Once::new();
+
+pub fn initialize_logger() {
+    INIT.call_once(|| {
+        let env_filter = tracing_subscriber::EnvFilter::builder()
+            .with_default_directive(tracing::level_filters::LevelFilter::INFO.into())
+            .from_env_lossy();
+
+        tracing_subscriber::fmt()
+            .compact()
+            .with_env_filter(env_filter)
+            .init();
+    });
+}
+
+#[derive(Clone, Debug, Deserialize)]
+struct Config {
+    hardware: Hardware,
+    models: Vec<ModelConfig>,
+}
+
+impl Config {
+    // TODO split up into raw deser config and "parsed"/"processed" config which always has a port
+    fn pick_open_ports(self) -> Self {
+        let mut config = self.clone();
+        for model in &mut config.models {
+            if model.internal_port.is_none() {
+                model.internal_port = Some(
+                    openport::pick_random_unused_port()
+                        .expect(format!("No open port found for {:?}", model).as_str()),
+                );
+            }
+        }
+        config
+    }
+}
+
+#[derive(Clone, Debug, Deserialize)]
+struct Hardware {
+    ram: String,
+    vram: String,
+}
+
+#[derive(Clone, Debug, Deserialize)]
+struct ModelConfig {
+    #[allow(dead_code)]
+    name: String,
+    port: u16,
+    internal_port: Option<u16>,
+    env: HashMap<String, String>,
+    args: HashMap<String, String>,
+    vram_usage: String,
+    ram_usage: String,
+}
+
+#[derive(Clone, Debug)]
+struct LlamaInstance {
+    config: ModelConfig,
+    process: Arc<Mutex<Child>>,
+    // busy: bool,
+}
+
+#[derive(Clone, Debug)]
+struct SharedState {
+    total_ram: u64,
+    total_vram: u64,
+    used_ram: u64,
+    used_vram: u64,
+    instances: HashMap<u16, LlamaInstance>,
+}
+
+#[tokio::main]
+async fn main() {
+    // TODO add autostart of models based on config
+    // abstract starting logic out of handler for this to allow separate calls to start
+    // maybe add to SharedState & LlamaInstance ?
+
+    initialize_logger();
+
+    // Read and parse the YAML configuration
+    let config_str = std::fs::read_to_string("config.yaml").expect("Failed to read config.yaml");
+    let config: Config = serde_yaml::from_str::<Config>(&config_str)
+        .expect("Failed to parse config.yaml")
+        .pick_open_ports();
+
+    // Parse hardware resources
+    let total_ram = parse_size(&config.hardware.ram).expect("Invalid RAM size in config");
+    let total_vram = parse_size(&config.hardware.vram).expect("Invalid VRAM size in config");
+
+    // Initialize shared state
+    let shared_state = Arc::new(Mutex::new(SharedState {
+        total_ram,
+        total_vram,
+        used_ram: 0,
+        used_vram: 0,
+        instances: HashMap::new(),
+    }));
+
+    // For each model, set up an axum server listening on the specified port
+    let mut handles = Vec::new();
+
+    for model_config in config.models {
+        let state = shared_state.clone();
+        let model_config = model_config.clone();
+
+        let handle = tokio::spawn(async move {
+            let model_config = model_config.clone();
+
+            let app = Router::new()
+                .route(
+                    "/",
+                    any({
+                        let state = state.clone();
+                        let model_config = model_config.clone();
+                        move |req| handle_request(req, model_config, state)
+                    }),
+                )
+                .route(
+                    "/*path",
+                    any({
+                        let state = state.clone();
+                        let model_config = model_config.clone();
+                        move |req| handle_request(req, model_config, state)
+                    }),
+                )
+                .layer(
+                    TraceLayer::new_for_http()
+                        .make_span_with(DefaultMakeSpan::new().include_headers(true))
+                        .on_request(DefaultOnRequest::new().level(Level::DEBUG))
+                        .on_response(DefaultOnResponse::new().level(Level::TRACE))
+                        .on_eos(DefaultOnEos::new().level(Level::DEBUG))
+                        .on_failure(DefaultOnFailure::new().level(Level::ERROR)),
+                );
+
+            let addr = SocketAddr::from(([0, 0, 0, 0], model_config.port));
+
+            tracing::info!(msg = "Listening", ?model_config);
+            let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
+
+            axum::serve(listener, app.into_make_service())
+                .await
+                .unwrap();
+        });
+
+        handles.push(handle);
+    }
+
+    futures::future::join_all(handles).await;
+}
+
+use thiserror::Error;
+
+#[derive(Error, Debug)]
+pub enum AppError {
+    #[error("Axum Error")]
+    AxumError(#[from] axum::Error),
+    #[error("Axum Http Error")]
+    AxumHttpError(#[from] axum::http::Error),
+    #[error("Reqwest Error")]
+    ReqwestError(#[from] reqwest::Error),
+    #[error("Reqwest Middleware Error")]
+    ReqwestMiddlewareError(#[from] reqwest_middleware::Error),
+    #[error("Client Error")]
+    ClientError(#[from] hyper::Error),
+    #[error("Unknown error")]
+    Unknown(#[from] anyhow::Error),
+}
+
+// Tell axum how to convert `AppError` into a response.
+impl IntoResponse for AppError {
+    fn into_response(self) -> axum::response::Response {
+        tracing::error!("::AppError::into_response:: {:?}", self);
+        (
+            http::StatusCode::INTERNAL_SERVER_ERROR,
+            format!("Something went wrong: {:?}", self),
+        )
+            .into_response()
+    }
+}
+
+async fn handle_request(
+    req: Request<Body>,
+    model_config: ModelConfig,
+    state: Arc<Mutex<SharedState>>,
+    // ) -> Result<Response<Body>, anyhow::Error> {
+) -> impl IntoResponse {
+    let model_ram_usage = parse_size(&model_config.ram_usage).expect("Invalid ram_usage");
+    let model_vram_usage = parse_size(&model_config.vram_usage).expect("Invalid vram_usage");
+
+    let instance = {
+        let mut state = state.lock().await;
+
+        // Check if instance is running
+        if let Some(instance) = state.instances.get_mut(&model_config.port) {
+            // instance.busy = true;
+            instance.to_owned()
+        } else {
+            // Check resources
+            tracing::info!(msg = "Current state", ?state);
+            if ((state.used_ram + model_ram_usage) > state.total_ram)
+                || ((state.used_vram + model_vram_usage) > state.total_vram)
+            {
+                // Stop other instances
+                let mut to_remove = Vec::new();
+                // TODO Actual smart stopping logic
+                // - search for smallest single model to stop to get enough room
+                // - if not possible search for smallest number of models to stop with lowest
+                //   amount of "overshoot"
+                let instances_by_size = state
+                    .instances
+                    .clone()
+                    .into_iter()
+                    .sorted_by(|(_, el_a), (_, el_b)| {
+                        Ord::cmp(
+                            &parse_size(el_b.config.vram_usage.as_str()),
+                            &parse_size(el_a.config.vram_usage.as_str()),
+                        )
+                    });
+                for (port, instance) in instances_by_size {
+                    // if !instance.busy {
+                    tracing::info!("Stopping instance on port {}", port);
+                    let mut process = instance.process.lock().await;
+                    process.kill().await.ok();
+                    to_remove.push(port);
+                    state.used_ram -= parse_size(&instance.config.ram_usage).unwrap_or(0);
+                    state.used_vram -= parse_size(&instance.config.vram_usage).unwrap_or(0);
+                    if state.used_ram + model_ram_usage <= state.total_ram
+                        && state.used_vram + model_vram_usage <= state.total_vram
+                    {
+                        tracing::info!("Should have enough ram now");
+                        break;
+                    }
+                    // }
+                }
+                for port in to_remove {
+                    tracing::info!("Removing instance on port {}", port);
+                    state.instances.remove(&port);
+                }
+            } else {
+                tracing::info!("Already enough res free");
+            }
+
+            // Start new instance
+            let args = model_config
+                .args
+                .iter()
+                .flat_map(|(k, v)| {
+                    if v == "true" {
+                        vec![format!("--{}", k)]
+                    } else {
+                        vec![format!("--{}", k), v.clone()]
+                    }
+                })
+                .collect::<Vec<_>>();
+
+            let mut cmd = Command::new("llama-server");
+            cmd.kill_on_drop(true);
+            cmd.envs(model_config.env.clone());
+            cmd.args(&args);
+            // TODO use openport crate via pick_random_unused_port for determining these
+            cmd.arg("--port");
+            cmd.arg(format!(
+                "{}",
+                model_config
+                    .internal_port
+                    .expect("Unexpected empty port, should've been picked")
+            ));
+            cmd.stdout(Stdio::null()).stderr(Stdio::null()); // TODO save output and allow retrieval via api
+
+            tracing::info!("Starting llama-server with {:?}", cmd);
+            let process = Arc::new(Mutex::new(
+                cmd.spawn().expect("Failed to start llama-server"),
+            ));
+
+            state.used_ram += model_ram_usage;
+            state.used_vram += model_vram_usage;
+
+            let instance = LlamaInstance {
+                config: model_config.clone(),
+                process,
+                // busy: true,
+            };
+            sleep(Duration::from_millis(500)).await;
+            state.instances.insert(model_config.port, instance.clone());
+
+            instance
+        }
+    };
+
+    // Wait for the instance to be ready
+    is_llama_instance_running(&instance).await?;
+    wait_for_port(
+        model_config
+            .internal_port
+            .expect("Unexpected empty port, should've been picked"),
+    )
+    .await?;
+
+    // Proxy the request
+    let retry_policy = reqwest_retry::policies::ExponentialBackoff::builder()
+        .retry_bounds(Duration::from_millis(500), Duration::from_secs(8))
+        .jitter(reqwest_retry::Jitter::None)
+        .base(2)
+        .build_with_max_retries(8);
+
+    let client = reqwest_middleware::ClientBuilder::new(Client::new())
+        .with(reqwest_retry::RetryTransientMiddleware::new_with_policy(
+            retry_policy,
+        ))
+        .build();
+
+    let uri = format!(
+        "http://127.0.0.1:{}{}",
+        model_config
+            .internal_port
+            .expect("Unexpected empty port, should've been picked"),
+        req.uri().path_and_query().map(|x| x.as_str()).unwrap_or("")
+    );
+
+    let mut request_builder = client.request(req.method().clone(), &uri);
+    // let mut request_builder = reqwest::RequestBuilder::from_parts(client, req.method().clone())
+
+    for (name, value) in req.headers() {
+        request_builder = request_builder.header(name, value);
+    }
+
+    let body_bytes = axum::body::to_bytes(
+        req.into_body(),
+        parse_size("1G").unwrap().try_into().unwrap(),
+    )
+    .await?;
+    let request = request_builder.body(body_bytes).build()?;
+    // tracing::info!("Proxying request to {}", uri);
+    let response = client.execute(request).await?;
+
+    // tracing::info!("Received response from {}", uri);
+    let mut builder = Response::builder().status(response.status());
+    for (name, value) in response.headers() {
+        builder = builder.header(name, value);
+    }
+
+    // let bytes = response.bytes().await?;
+
+    let byte_stream = response.bytes_stream();
+
+    let body = axum::body::Body::from_stream(byte_stream);
+
+    tracing::debug!("streaming response on port: {}", model_config.port);
+    let response = builder.body(body)?;
+    Ok::<axum::http::Response<Body>, AppError>(response)
+}
+
+async fn wait_for_port(port: u16) -> Result<(), anyhow::Error> {
+    for _ in 0..10 {
+        if TcpStream::connect(("127.0.0.1", port)).await.is_ok() {
+            return Ok(());
+        }
+        sleep(Duration::from_millis(500)).await;
+    }
+    Err(anyhow!("Timeout waiting for port"))
+}
+
+// reimplement wait_for_port using a LlamaInstance, also check if the process is still running
+async fn is_llama_instance_running(instance: &LlamaInstance) -> Result<(), anyhow::Error> {
+    match instance.process.to_owned().lock_owned().await.try_wait()? {
+        Some(exit_status) => {
+            let msg = "Llama instance exited";
+            tracing::error!(msg, ?exit_status);
+            Err(anyhow!(msg))
+        }
+        None => Ok(()),
+    }
+}
+
+fn parse_size(size_str: &str) -> Option<u64> {
+    let mut num = String::new();
+    let mut unit = String::new();
+
+    for c in size_str.chars() {
+        if c.is_digit(10) || c == '.' {
+            num.push(c);
+        } else {
+            unit.push(c);
+        }
+    }
+
+    let num: f64 = num.parse().ok()?;
+
+    let multiplier = match unit.to_lowercase().as_str() {
+        "g" | "gb" => 1024 * 1024 * 1024,
+        "m" | "mb" => 1024 * 1024,
+        "k" | "kb" => 1024,
+        _ => panic!("Invalid Size"),
+    };
+
+    let res = (num * multiplier as f64) as u64;
+    Some(res)
+}
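One detail of the proxy path above is worth working out: with retry_bounds(500ms, 8s), base(2), Jitter::None and build_with_max_retries(8), the retry delays should double from the lower bound until they hit the upper bound. This is my reading of reqwest-retry's ExponentialBackoff, so treat the exact schedule as an assumption:

    // Assumed worst-case backoff under the policy above:
    // attempt:  1     2    3    4    5    6    7    8
    // delay:    0.5s  1s   2s   4s   8s   8s   8s   8s   (capped at the 8s bound)
    // total wait before the handler gives up:
    //   0.5 + 1 + 2 + 4 + 4 * 8 = 39.5 seconds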
@@ -1,69 +0,0 @@
-use crate::{
-    config::ModelSpec, error::AppError, inference_process::InferenceProcess, state::AppState,
-    util::parse_size,
-};
-use axum::{
-    body::Body,
-    http::{Request, Response},
-    response::IntoResponse,
-};
-use reqwest::Client;
-use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
-use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
-
-pub async fn proxy_request(
-    req: Request<Body>,
-    spec: &ModelSpec,
-) -> Result<Response<Body>, AppError> {
-    let retry_policy = ExponentialBackoff::builder()
-        .retry_bounds(
-            std::time::Duration::from_millis(500),
-            std::time::Duration::from_secs(8),
-        )
-        .jitter(reqwest_retry::Jitter::None)
-        .base(2)
-        .build_with_max_retries(8);
-
-    let client: ClientWithMiddleware = ClientBuilder::new(Client::new())
-        .with(RetryTransientMiddleware::new_with_policy(retry_policy))
-        .build();
-
-    let internal_port = spec.internal_port.expect("Internal port must be set");
-    let uri = format!(
-        "http://127.0.0.1:{}{}",
-        internal_port,
-        req.uri()
-            .path_and_query()
-            .map(|pq| pq.as_str())
-            .unwrap_or("")
-    );
-
-    let mut request_builder = client.request(req.method().clone(), &uri);
-    for (name, value) in req.headers() {
-        request_builder = request_builder.header(name, value);
-    }
-
-    let max_size = parse_size("1G").unwrap() as usize;
-    let body_bytes = axum::body::to_bytes(req.into_body(), max_size).await?;
-    let request = request_builder.body(body_bytes).build()?;
-    let response = client.execute(request).await?;
-
-    let mut builder = Response::builder().status(response.status());
-    for (name, value) in response.headers() {
-        builder = builder.header(name, value);
-    }
-
-    let byte_stream = response.bytes_stream();
-    let body = Body::from_stream(byte_stream);
-    Ok(builder.body(body)?)
-}
-
-pub async fn handle_request(
-    req: Request<Body>,
-    spec: ModelSpec,
-    state: AppState,
-) -> impl IntoResponse {
-    let _ = InferenceProcess::get_or_spawn(&spec, state).await?;
-    let response = proxy_request(req, &spec).await?;
-    Ok::<Response<Body>, AppError>(response)
-}
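A note on Body::from_stream above: it relays upstream bytes as they arrive, which is what keeps llama-server's incremental token output streaming through the proxy; the commented-out response.bytes() path in the old main.rs buffered the whole reply instead. A minimal sketch of the two options, assuming axum 0.7 and reqwest with its stream feature enabled:

    use axum::body::Body;

    // Streaming: each upstream chunk is forwarded immediately.
    fn streaming_body(response: reqwest::Response) -> Body {
        Body::from_stream(response.bytes_stream())
    }

    // Buffering alternative: stalls the client until the upstream reply
    // is complete, breaking token-by-token output.
    async fn buffered_body(response: reqwest::Response) -> Result<Body, reqwest::Error> {
        Ok(Body::from(response.bytes().await?))
    }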
@@ -1,36 +0,0 @@
-use crate::{config::AppConfig, inference_process::InferenceProcess, util::parse_size};
-use std::{collections::HashMap, sync::Arc};
-use tokio::sync::Mutex;
-
-#[derive(Clone, Debug)]
-pub struct ResourceManager {
-    pub total_ram: u64,
-    pub total_vram: u64,
-    pub used_ram: u64,
-    pub used_vram: u64,
-    pub processes: HashMap<u16, InferenceProcess>,
-}
-
-pub type ResourceManagerHandle = Arc<Mutex<ResourceManager>>;
-
-#[derive(Clone, Debug, derive_more::Deref)]
-pub struct AppState(pub ResourceManagerHandle);
-
-impl AppState {
-    pub fn from_config(config: AppConfig) -> Self {
-        let total_ram =
-            parse_size(&config.system_resources.ram).expect("Invalid RAM size in config");
-        let total_vram =
-            parse_size(&config.system_resources.vram).expect("Invalid VRAM size in config");
-
-        let resource_manager = ResourceManager {
-            total_ram,
-            total_vram,
-            used_ram: 0,
-            used_vram: 0,
-            processes: HashMap::new(),
-        };
-
-        Self(Arc::new(Mutex::new(resource_manager)))
-    }
-}
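The derive_more::Deref on AppState is what lets the rest of the crate treat it as the Arc<Mutex<ResourceManager>> it wraps, which is how inference_process.rs can call state.lock().await directly. A small sketch of that in use:

    use llama_proxy_man::{config::AppConfig, state::AppState};

    #[tokio::main]
    async fn main() {
        let config = AppConfig::default_from_pwd_yml(); // reads ./config.yaml
        let state = AppState::from_config(config);
        // Deref kicks in here: AppState -> Arc<Mutex<ResourceManager>> -> lock().
        let guard = state.lock().await;
        println!("free ram budget: {}", guard.total_ram - guard.used_ram);
    }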
@@ -1,22 +0,0 @@
-pub fn parse_size(size_str: &str) -> Option<u64> {
-    let mut num = String::new();
-    let mut unit = String::new();
-
-    for c in size_str.chars() {
-        if c.is_digit(10) || c == '.' {
-            num.push(c);
-        } else {
-            unit.push(c);
-        }
-    }
-
-    let num: f64 = num.parse().ok()?;
-    let multiplier = match unit.to_lowercase().as_str() {
-        "g" | "gb" => 1024 * 1024 * 1024,
-        "m" | "mb" => 1024 * 1024,
-        "k" | "kb" => 1024,
-        _ => panic!("Invalid Size"),
-    };
-
-    Some((num * multiplier as f64) as u64)
-}
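A quick worked check of the parser, given the 1024-based multipliers above:

    use llama_proxy_man::util::parse_size;

    fn main() {
        assert_eq!(parse_size("1G"), Some(1024 * 1024 * 1024)); // 1073741824
        assert_eq!(parse_size("1.5k"), Some(1536));             // 1.5 * 1024
        assert_eq!(parse_size("xM"), None); // "" fails to parse as f64, so `?` bails
        // Caveat visible in the code: an unrecognized unit with a valid number
        // panics ("Invalid Size") rather than returning None, despite the
        // Option return type.
    }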