Refactor llama_proxy_man
This commit is contained in: parent 8c062ece28 · commit e3a3ec3826
3 changed files with 277 additions and 218 deletions
Cargo.lock (generated) · 25 lines changed
@@ -1141,6 +1141,26 @@ dependencies = [
  "syn 2.0.98",
 ]
 
+[[package]]
+name = "derive_more"
+version = "2.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678"
+dependencies = [
+ "derive_more-impl",
+]
+
+[[package]]
+name = "derive_more-impl"
+version = "2.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.98",
+]
+
 [[package]]
 name = "digest"
 version = "0.10.7"
@@ -2950,7 +2970,7 @@ dependencies = [
  "chrono",
  "console_error_panic_hook",
  "dashmap",
- "derive_more",
+ "derive_more 0.99.19",
  "futures",
  "futures-util",
  "gloo-net 0.5.0",
@@ -3002,6 +3022,7 @@ version = "0.1.1"
 dependencies = [
  "anyhow",
  "axum",
+ "derive_more 2.0.1",
  "futures",
  "hyper",
  "itertools 0.13.0",
@@ -4781,7 +4802,7 @@ checksum = "df320f1889ac4ba6bc0cdc9c9af7af4bd64bb927bccdf32d81140dc1f9be12fe"
 dependencies = [
  "bitflags 1.3.2",
  "cssparser",
- "derive_more",
+ "derive_more 0.99.19",
  "fxhash",
  "log",
  "matches",
@@ -29,3 +29,4 @@ reqwest-retry = "0.6.1"
 reqwest-middleware = { version = "0.3.3", features = ["charset", "http2", "json", "multipart", "rustls-tls"] }
 itertools = "0.13.0"
 openport = { version = "0.1.1", features = ["rand"] }
+derive_more = { version = "2.0.1", features = ["deref"] }
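Note: the `derive_more = { version = "2.0.1", features = ["deref"] }` line above (presumably in the crate's Cargo.toml; the filename is not shown on this page) supports the `#[derive(derive_more::Deref)]` newtype introduced later in this commit. A minimal, self-contained sketch of the pattern (names here are illustrative, not from the repository):

    use std::sync::Arc;
    use tokio::sync::Mutex;

    // Deriving Deref on a single-field newtype generates
    // `impl Deref<Target = Arc<Mutex<u64>>>`, so the wrapper can be used
    // anywhere the inner Arc<Mutex<..>> could be.
    #[derive(Clone, derive_more::Deref)]
    struct State(Arc<Mutex<u64>>);

    #[tokio::main]
    async fn main() {
        let state = State(Arc::new(Mutex::new(0)));
        // Auto-deref reaches through State -> Arc -> Mutex.
        *state.lock().await += 1;
    }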
@@ -11,6 +11,8 @@ use axum::{
 use futures;
 use itertools::Itertools;
 use reqwest::Client;
+use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
+use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
 use serde::Deserialize;
 use std::{collections::HashMap, net::SocketAddr, process::Stdio, sync::Arc};
 use tokio::{
@@ -19,10 +21,7 @@ use tokio::{
     sync::Mutex,
     time::{sleep, Duration},
 };
-use tower_http::trace::{
-    DefaultMakeSpan, DefaultOnEos, DefaultOnFailure, DefaultOnRequest, DefaultOnResponse,
-    TraceLayer,
-};
+use tower_http;
 use tracing::Level;
 
 use std::sync::Once;
@@ -42,6 +41,8 @@ pub fn initialize_logger() {
     });
 }
 
+// TODO Add validation to the config
+// - e.g. check double-taken ports etc
 #[derive(Clone, Debug, Deserialize)]
 struct Config {
     hardware: Hardware,
@@ -49,7 +50,18 @@ struct Config {
 }
 
 impl Config {
+    fn default_from_pwd_yml() -> Self {
+        // Read and parse the YAML configuration
+        let config_str =
+            std::fs::read_to_string("config.yaml").expect("Failed to read config.yaml");
+
+        serde_yaml::from_str::<Self>(&config_str)
+            .expect("Failed to parse config.yaml")
+            .pick_open_ports()
+    }
+
     // TODO split up into raw deser config and "parsed"/"processed" config which always has a port
+    // FIXME maybe just always initialize with a port ?
     fn pick_open_ports(self) -> Self {
         let mut config = self.clone();
         for model in &mut config.models {
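Note: the config schema itself is not part of this diff. Judging only from the fields the code reads (`hardware.ram`, `hardware.vram`, and per-model `port`, `internal_port`, `ram_usage`, `vram_usage`, `env`, `args`), a `config.yaml` consumed by `default_from_pwd_yml` plausibly looks like the following hypothetical sketch:

    # Hypothetical example; field names inferred from this diff, values invented.
    hardware:
      ram: 48G
      vram: 24G
    models:
      - port: 8080          # external port the proxy serves on
        # internal_port is filled in by pick_open_ports() when omitted
        ram_usage: 8G       # parsed by parse_size()
        vram_usage: 10G
        env: {}
        args:
          model: /models/example.gguf  # becomes --model /models/example.gguf
          flash-attn: "true"           # a "true" value becomes a bare --flash-attn flag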
@@ -89,8 +101,154 @@ struct LlamaInstance {
     // busy: bool,
 }
 
+impl LlamaInstance {
+    /// Retrieve a running instance from state or spawn a new one.
+    pub async fn get_or_spawn_instance(
+        model_config: &ModelConfig,
+        state: SharedState,
+    ) -> Result<LlamaInstance, AppError> {
+        let model_ram_usage =
+            parse_size(&model_config.ram_usage).expect("Invalid ram_usage in model config");
+        let model_vram_usage =
+            parse_size(&model_config.vram_usage).expect("Invalid vram_usage in model config");
+
+        {
+            // First, check if we already have an instance.
+            let state_guard = state.lock().await;
+            if let Some(instance) = state_guard.instances.get(&model_config.port) {
+                return Ok(instance.clone());
+            }
+        }
+
+        let mut state_guard = state.lock().await;
+        tracing::info!(msg = "Shared state before spawn", ?state_guard);
+
+        // If not enough free resources, stop some running instances.
+        if (state_guard.used_ram + model_ram_usage > state_guard.total_ram)
+            || (state_guard.used_vram + model_vram_usage > state_guard.total_vram)
+        {
+            let mut to_remove = Vec::new();
+            let instances_by_size = state_guard
+                .instances
+                .iter()
+                .sorted_by(|(_, a), (_, b)| {
+                    parse_size(&b.config.vram_usage)
+                        .unwrap_or(0)
+                        .cmp(&parse_size(&a.config.vram_usage).unwrap_or(0))
+                })
+                .map(|(port, instance)| (*port, instance.clone()))
+                .collect::<Vec<_>>();
+
+            for (port, instance) in instances_by_size.iter() {
+                tracing::info!("Stopping instance on port {}", port);
+                let mut proc = instance.process.lock().await;
+                proc.kill().await.ok();
+                to_remove.push(*port);
+                state_guard.used_ram = state_guard
+                    .used_ram
+                    .saturating_sub(parse_size(&instance.config.ram_usage).unwrap_or(0));
+                state_guard.used_vram = state_guard
+                    .used_vram
+                    .saturating_sub(parse_size(&instance.config.vram_usage).unwrap_or(0));
+                if state_guard.used_ram + model_ram_usage <= state_guard.total_ram
+                    && state_guard.used_vram + model_vram_usage <= state_guard.total_vram
+                {
+                    tracing::info!("Freed enough resources");
+                    break;
+                }
+            }
+            for port in to_remove {
+                state_guard.instances.remove(&port);
+            }
+        } else {
+            tracing::info!("Sufficient resources available");
+        }
+
+        // Spawn a new instance.
+        let instance = Self::spawn(&model_config).await?;
+        state_guard.used_ram += model_ram_usage;
+        state_guard.used_vram += model_vram_usage;
+        state_guard
+            .instances
+            .insert(model_config.port, instance.clone());
+
+        sleep(Duration::from_millis(250)).await;
+
+        instance.running_and_port_ready().await?;
+
+        sleep(Duration::from_millis(250)).await;
+
+        Ok(instance)
+    }
+
+    /// Spawn a new llama-server process based on the model configuration.
+    pub async fn spawn(model_config: &ModelConfig) -> Result<LlamaInstance, AppError> {
+        let args = model_config
+            .args
+            .iter()
+            .flat_map(|(k, v)| {
+                if v == "true" {
+                    vec![format!("--{}", k)]
+                } else {
+                    vec![format!("--{}", k), v.clone()]
+                }
+            })
+            .collect::<Vec<_>>();
+
+        let internal_port = model_config
+            .internal_port
+            .expect("Internal port must be set");
+
+        let mut cmd = Command::new("llama-server");
+        cmd.kill_on_drop(true)
+            .envs(model_config.env.clone())
+            .args(&args)
+            .arg("--port")
+            .arg(format!("{}", internal_port))
+            .stdout(Stdio::null())
+            .stderr(Stdio::null());
+        // Silence output – could be captured later via an API.
+
+        tracing::info!("Starting llama-server with command: {:?}", cmd);
+        let child = cmd.spawn().expect("Failed to start llama-server");
+
+        Ok(LlamaInstance {
+            config: model_config.clone(),
+            process: Arc::new(Mutex::new(child)),
+        })
+    }
+
+    pub async fn running_and_port_ready(&self) -> Result<(), anyhow::Error> {
+        async fn is_running(instance: &LlamaInstance) -> Result<(), AppError> {
+            let mut proc = instance.process.clone().lock_owned().await;
+            match proc.try_wait()? {
+                Some(exit_status) => {
+                    tracing::error!("Llama instance exited: {:?}", exit_status);
+                    Err(AppError::Unknown(anyhow!("Llama instance exited")))
+                }
+                None => Ok(()),
+            }
+        }
+
+        async fn wait_for_port(port: u16) -> Result<(), anyhow::Error> {
+            for _ in 0..10 {
+                if TcpStream::connect(("127.0.0.1", port)).await.is_ok() {
+                    return Ok(());
+                }
+                sleep(Duration::from_millis(500)).await;
+            }
+            Err(anyhow!("Timeout waiting for port"))
+        }
+
+        is_running(self).await?;
+        wait_for_port(self.config.internal_port.expect("No port picked?")).await?;
+        Ok(())
+    }
+}
+
 #[derive(Clone, Debug)]
-struct SharedState {
+pub struct InnerSharedState {
     total_ram: u64,
     total_vram: u64,
     used_ram: u64,
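Note: the eviction logic in `get_or_spawn_instance` above stops running instances in descending VRAM order until the requested model fits. A standalone sketch of just that ordering step, using the same `Itertools::sorted_by` call shape as the diff (data is invented):

    use itertools::Itertools;

    fn main() {
        // (port, parsed VRAM bytes) pairs standing in for running instances.
        let instances = vec![(8081u16, 4_000u64), (8082, 12_000), (8083, 8_000)];

        // b.cmp(a) sorts descending, so the largest instance comes first
        // and is the first candidate to be stopped.
        let by_size: Vec<(u16, u64)> = instances
            .into_iter()
            .sorted_by(|(_, a), (_, b)| b.cmp(a))
            .collect();

        assert_eq!(by_size[0].0, 8082);
    }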
@@ -98,31 +256,67 @@ struct SharedState {
     instances: HashMap<u16, LlamaInstance>,
 }
 
-#[tokio::main]
-async fn main() {
-    // TODO add autostart of models based on config
-    // abstract starting logic out of handler for this to allow separate calls to start
-    // maybe add to SharedState & LLamaInstance ?
-
-    initialize_logger();
-    // Read and parse the YAML configuration
-    let config_str = std::fs::read_to_string("config.yaml").expect("Failed to read config.yaml");
-    let config: Config = serde_yaml::from_str::<Config>(&config_str)
-        .expect("Failed to parse config.yaml")
-        .pick_open_ports();
+type SharedStateArc = Arc<Mutex<InnerSharedState>>;
+
+/// TODO migrate to dashmap + individual Arc<Mutex<T>> or Rwlocks
+#[derive(Clone, Debug, derive_more::Deref)]
+pub struct SharedState(SharedStateArc);
+
+impl SharedState {
+    fn from_config(config: Config) -> Self {
         // Parse hardware resources
         let total_ram = parse_size(&config.hardware.ram).expect("Invalid RAM size in config");
         let total_vram = parse_size(&config.hardware.vram).expect("Invalid VRAM size in config");
 
         // Initialize shared state
-        let shared_state = Arc::new(Mutex::new(SharedState {
+        let shared_state = InnerSharedState {
            total_ram,
            total_vram,
            used_ram: 0,
            used_vram: 0,
            instances: HashMap::new(),
-        }));
+        };
+
+        Self(Arc::new(Mutex::new(shared_state)))
+    }
+}
+
+/// Build an Axum app for a given model.
+fn create_app(model_config: &ModelConfig, state: SharedState) -> Router {
+    Router::new()
+        .route(
+            "/",
+            any({
+                let state = state.clone();
+                let model_config = model_config.clone();
+                move |req| handle_request(req, model_config.clone(), state.clone())
+            }),
+        )
+        .route(
+            "/*path",
+            any({
+                let state = state.clone();
+                let model_config = model_config.clone();
+                move |req| handle_request(req, model_config.clone(), state.clone())
+            }),
+        )
+        .layer(
+            tower_http::trace::TraceLayer::new_for_http()
+                .make_span_with(tower_http::trace::DefaultMakeSpan::new().include_headers(true))
+                .on_request(tower_http::trace::DefaultOnRequest::new().level(Level::DEBUG))
+                .on_response(tower_http::trace::DefaultOnResponse::new().level(Level::TRACE))
+                .on_eos(tower_http::trace::DefaultOnEos::new().level(Level::DEBUG))
+                .on_failure(tower_http::trace::DefaultOnFailure::new().level(Level::ERROR)),
+        )
+}
+
+#[tokio::main]
+async fn main() {
+    initialize_logger();
+
+    let config = Config::default_from_pwd_yml();
+
+    let shared_state = SharedState::from_config(config.clone());
+
     // For each model, set up an axum server listening on the specified port
     let mut handles = Vec::new();
@@ -134,31 +328,7 @@ async fn main() {
         let handle = tokio::spawn(async move {
             let model_config = model_config.clone();
-
-            let app = Router::new()
-                .route(
-                    "/",
-                    any({
-                        let state = state.clone();
-                        let model_config = model_config.clone();
-                        move |req| handle_request(req, model_config, state)
-                    }),
-                )
-                .route(
-                    "/*path",
-                    any({
-                        let state = state.clone();
-                        let model_config = model_config.clone();
-                        move |req| handle_request(req, model_config, state)
-                    }),
-                )
-                .layer(
-                    TraceLayer::new_for_http()
-                        .make_span_with(DefaultMakeSpan::new().include_headers(true))
-                        .on_request(DefaultOnRequest::new().level(Level::DEBUG))
-                        .on_response(DefaultOnResponse::new().level(Level::TRACE))
-                        .on_eos(DefaultOnEos::new().level(Level::DEBUG))
-                        .on_failure(DefaultOnFailure::new().level(Level::ERROR)),
-                );
-
+            let app = create_app(&model_config, state);
             let addr = SocketAddr::from(([0, 0, 0, 0], model_config.port));
@@ -190,6 +360,8 @@ pub enum AppError {
     ReqwestMiddlewareError(#[from] reqwest_middleware::Error),
     #[error("Client Error")]
     ClientError(#[from] hyper::Error),
+    #[error("Io Error")]
+    IoError(#[from] std::io::Error),
     #[error("Unknown error")]
     Unknown(#[from] anyhow::Error),
 }
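Note: the new `IoError` variant is what lets `?` propagate `std::io::Error` out of the handlers: `#[from]` (presumably via `thiserror`, judging by the `#[error(...)]` attributes) generates a `From<std::io::Error>` impl. A minimal sketch of that mechanism, independent of this crate:

    use thiserror::Error;

    #[derive(Debug, Error)]
    enum AppError {
        #[error("Io Error")]
        IoError(#[from] std::io::Error),
    }

    // `?` converts the io::Error into AppError via the generated From impl.
    fn read_config() -> Result<String, AppError> {
        Ok(std::fs::read_to_string("config.yaml")?)
    }

    fn main() {
        if let Err(e) = read_config() {
            eprintln!("{e}"); // prints "Io Error"
        }
    }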
@@ -206,202 +378,67 @@ impl IntoResponse for AppError {
     }
 }
 
-async fn handle_request(
-    req: Request<Body>,
-    model_config: ModelConfig,
-    state: Arc<Mutex<SharedState>>,
-// ) -> Result<Response<Body>, anyhow::Error> {
-) -> impl IntoResponse {
-    let model_ram_usage = parse_size(&model_config.ram_usage).expect("Invalid ram_usage");
-    let model_vram_usage = parse_size(&model_config.vram_usage).expect("Invalid vram_usage");
-
-    let instance = {
-        let mut state = state.lock().await;
-
-        // Check if instance is running
-        if let Some(instance) = state.instances.get_mut(&model_config.port) {
-            // instance.busy = true;
-            instance.to_owned()
-        } else {
-            // Check resources
-            tracing::info!(msg = "Current state", ?state);
-            if ((state.used_ram + model_ram_usage) > state.total_ram)
-                || ((state.used_vram + model_vram_usage) > state.total_vram)
-            {
-                // Stop other instances
-                let mut to_remove = Vec::new();
-                // TODO Actual smart stopping logic
-                // - search for smallest single model to stop to get enough room
-                // - if not possible search for smallest number of models to stop with lowest
-                //   amount of "overshoot"
-                let instances_by_size =
-                    state
-                        .instances
-                        .clone()
-                        .into_iter()
-                        .sorted_by(|(_, el_a), (_, el_b)| {
-                            Ord::cmp(
-                                &parse_size(el_b.config.vram_usage.as_str()),
-                                &parse_size(el_a.config.vram_usage.as_str()),
-                            )
-                        });
-                for (port, instance) in instances_by_size {
-                    // if !instance.busy {
-                    tracing::info!("Stopping instance on port {}", port);
-                    let mut process = instance.process.lock().await;
-                    process.kill().await.ok();
-                    to_remove.push(port);
-                    state.used_ram -= parse_size(&instance.config.ram_usage).unwrap_or(0);
-                    state.used_vram -= parse_size(&instance.config.vram_usage).unwrap_or(0);
-                    if state.used_ram + model_ram_usage <= state.total_ram
-                        && state.used_vram + model_vram_usage <= state.total_vram
-                    {
-                        tracing::info!("Should have enough ram now");
-                        break;
-                    }
-                    // }
-                }
-                for port in to_remove {
-                    tracing::info!("Removing instance on port {}", port);
-                    state.instances.remove(&port);
-                }
-            } else {
-                tracing::info!("Already enough res free");
-            }
-
-            // Start new instance
-            let args = model_config
-                .args
-                .iter()
-                .flat_map(|(k, v)| {
-                    if v == "true" {
-                        vec![format!("--{}", k)]
-                    } else {
-                        vec![format!("--{}", k), v.clone()]
-                    }
-                })
-                .collect::<Vec<_>>();
-
-            let mut cmd = Command::new("llama-server");
-            cmd.kill_on_drop(true);
-            cmd.envs(model_config.env.clone());
-            cmd.args(&args);
-            // TODO use openport crate via pick_random_unused_port for determining these
-            cmd.arg("--port");
-            cmd.arg(format!(
-                "{}",
-                model_config
-                    .internal_port
-                    .expect("Unexpected empty port, should've been picked")
-            ));
-            cmd.stdout(Stdio::null()).stderr(Stdio::null()); // TODO save output and allow retrieval via api
-
-            tracing::info!("Starting llama-server with {:?}", cmd);
-            let process = Arc::new(Mutex::new(
-                cmd.spawn().expect("Failed to start llama-server"),
-            ));
-
-            state.used_ram += model_ram_usage;
-            state.used_vram += model_vram_usage;
-
-            let instance = LlamaInstance {
-                config: model_config.clone(),
-                process,
-                // busy: true,
-            };
-            sleep(Duration::from_millis(500)).await;
-            state.instances.insert(model_config.port, instance.clone());
-
-            instance
-        }
-    };
-
-    // Wait for the instance to be ready
-    is_llama_instance_running(&instance).await?;
-    wait_for_port(
-        model_config
-            .internal_port
-            .expect("Unexpected empty port, should've been picked"),
-    )
-    .await?;
-
-    // Proxy the request
-    let retry_policy = reqwest_retry::policies::ExponentialBackoff::builder()
-        .retry_bounds(Duration::from_millis(500), Duration::from_secs(8))
+async fn proxy_request(
+    req: Request<Body>,
+    model_config: &ModelConfig,
+) -> Result<Response<Body>, AppError> {
+    let retry_policy = ExponentialBackoff::builder()
+        .retry_bounds(
+            std::time::Duration::from_millis(500),
+            std::time::Duration::from_secs(8),
+        )
         .jitter(reqwest_retry::Jitter::None)
         .base(2)
         .build_with_max_retries(8);
 
-    let client = reqwest_middleware::ClientBuilder::new(Client::new())
-        .with(reqwest_retry::RetryTransientMiddleware::new_with_policy(
-            retry_policy,
-        ))
+    let client: ClientWithMiddleware = ClientBuilder::new(Client::new())
+        .with(RetryTransientMiddleware::new_with_policy(retry_policy))
         .build();
 
+    let internal_port = model_config
+        .internal_port
+        .expect("Internal port must be set");
     let uri = format!(
         "http://127.0.0.1:{}{}",
-        model_config
-            .internal_port
-            .expect("Unexpected empty port, should've been picked"),
-        req.uri().path_and_query().map(|x| x.as_str()).unwrap_or("")
+        internal_port,
+        req.uri()
+            .path_and_query()
+            .map(|pq| pq.as_str())
+            .unwrap_or("")
     );
 
     let mut request_builder = client.request(req.method().clone(), &uri);
-    // let mut request_builder = reqwest::RequestBuilder::from_parts(client, req.method().clone())
+    // Forward all headers.
     for (name, value) in req.headers() {
         request_builder = request_builder.header(name, value);
     }
 
-    let body_bytes = axum::body::to_bytes(
-        req.into_body(),
-        parse_size("1G").unwrap().try_into().unwrap(),
-    )
-    .await?;
+    let max_size = parse_size("1G").unwrap() as usize;
+    let body_bytes = axum::body::to_bytes(req.into_body(), max_size).await?;
     let request = request_builder.body(body_bytes).build()?;
-    // tracing::info!("Proxying request to {}", uri);
     let response = client.execute(request).await?;
 
-    // tracing::info!("Received response from {}", uri);
     let mut builder = Response::builder().status(response.status());
     for (name, value) in response.headers() {
         builder = builder.header(name, value);
     }
 
-    // let bytes = response.bytes().await?;
-
     let byte_stream = response.bytes_stream();
-    let body = axum::body::Body::from_stream(byte_stream);
-    tracing::debug!("streaming response on port: {}", model_config.port);
-    let response = builder.body(body)?;
+    let body = Body::from_stream(byte_stream);
+    Ok(builder.body(body)?)
+}
+
+async fn handle_request(
+    req: Request<Body>,
+    model_config: ModelConfig,
+    state: SharedState,
+) -> impl IntoResponse {
+    let _instance = LlamaInstance::get_or_spawn_instance(&model_config, state).await?;
+    let response = proxy_request(req, &model_config).await?;
+
     Ok::<axum::http::Response<Body>, AppError>(response)
 }
 
-async fn wait_for_port(port: u16) -> Result<(), anyhow::Error> {
-    for _ in 0..10 {
-        if TcpStream::connect(("127.0.0.1", port)).await.is_ok() {
-            return Ok(());
-        }
-        sleep(Duration::from_millis(500)).await;
-    }
-    Err(anyhow!("Timeout waiting for port"))
-}
-
-// reimplement wait_for_port using a LlamaInstance, also check if the process is still running
-async fn is_llama_instance_running(instance: &LlamaInstance) -> Result<(), anyhow::Error> {
-    match instance.process.to_owned().lock_owned().await.try_wait()? {
-        Some(exit_status) => {
-            let msg = "Llama instance exited";
-            tracing::error!(msg, ?exit_status);
-            Err(anyhow!(msg))
-        }
-        None => Ok(()),
-    }
-}
-
 fn parse_size(size_str: &str) -> Option<u64> {
     let mut num = String::new();
     let mut unit = String::new();