Refactor llama_proxy_man
parent 8c062ece28
commit e3a3ec3826
3 changed files with 277 additions and 218 deletions
Cargo.lock (generated, 25 lines changed)
@@ -1141,6 +1141,26 @@ dependencies = [
 "syn 2.0.98",
]

[[package]]
name = "derive_more"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678"
dependencies = [
 "derive_more-impl",
]

[[package]]
name = "derive_more-impl"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3"
dependencies = [
 "proc-macro2",
 "quote",
 "syn 2.0.98",
]

[[package]]
name = "digest"
version = "0.10.7"
@@ -2950,7 +2970,7 @@ dependencies = [
 "chrono",
 "console_error_panic_hook",
 "dashmap",
 "derive_more",
 "derive_more 0.99.19",
 "futures",
 "futures-util",
 "gloo-net 0.5.0",
@@ -3002,6 +3022,7 @@ version = "0.1.1"
dependencies = [
 "anyhow",
 "axum",
 "derive_more 2.0.1",
 "futures",
 "hyper",
 "itertools 0.13.0",
@@ -4781,7 +4802,7 @@ checksum = "df320f1889ac4ba6bc0cdc9c9af7af4bd64bb927bccdf32d81140dc1f9be12fe"
dependencies = [
 "bitflags 1.3.2",
 "cssparser",
 "derive_more",
 "derive_more 0.99.19",
 "fxhash",
 "log",
 "matches",
@@ -29,3 +29,4 @@ reqwest-retry = "0.6.1"
reqwest-middleware = { version = "0.3.3", features = ["charset", "http2", "json", "multipart", "rustls-tls"] }
itertools = "0.13.0"
openport = { version = "0.1.1", features = ["rand"] }
derive_more = { version = "2.0.1", features = ["deref"] }
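The new `derive_more` dependency is pulled in with only the `deref` feature, which the source changes below use to wrap the shared state in a newtype (`SharedState(SharedStateArc)`). As a rough standalone illustration of that pattern (the type names here other than `derive_more::Deref` are hypothetical, not taken from the commit), deriving `Deref` on a single-field tuple struct forwards calls to the inner value:

```rust
use std::sync::Arc;

// Hypothetical stand-in for the real state type; only the newtype-plus-Deref
// pattern is taken from the diff itself.
#[derive(Clone, derive_more::Deref)]
struct Shared(Arc<Vec<u8>>);

fn main() {
    let s = Shared(Arc::new(vec![1, 2, 3]));
    // Deref targets the inner Arc<Vec<u8>>, so Vec methods resolve directly.
    println!("{}", s.len());
}
```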
@@ -11,6 +11,8 @@ use axum::{
use futures;
use itertools::Itertools;
use reqwest::Client;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use serde::Deserialize;
use std::{collections::HashMap, net::SocketAddr, process::Stdio, sync::Arc};
use tokio::{
@@ -19,10 +21,7 @@ use tokio::{
    sync::Mutex,
    time::{sleep, Duration},
};
use tower_http::trace::{
    DefaultMakeSpan, DefaultOnEos, DefaultOnFailure, DefaultOnRequest, DefaultOnResponse,
    TraceLayer,
};
use tower_http;
use tracing::Level;

use std::sync::Once;
@@ -42,6 +41,8 @@ pub fn initialize_logger() {
    });
}

// TODO Add validation to the config
// - e.g. check for double-assigned ports etc
#[derive(Clone, Debug, Deserialize)]
struct Config {
    hardware: Hardware,
@@ -49,7 +50,18 @@ struct Config {
}

impl Config {
    fn default_from_pwd_yml() -> Self {
        // Read and parse the YAML configuration
        let config_str =
            std::fs::read_to_string("config.yaml").expect("Failed to read config.yaml");

        serde_yaml::from_str::<Self>(&config_str)
            .expect("Failed to parse config.yaml")
            .pick_open_ports()
    }

    // TODO split up into raw deser config and "parsed"/"processed" config which always has a port
    // FIXME maybe just always initialize with a port ?
    fn pick_open_ports(self) -> Self {
        let mut config = self.clone();
        for model in &mut config.models {
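For orientation, the fields this loader touches (hardware.ram / hardware.vram plus per-model port, internal_port, ram_usage, vram_usage, args, env) suggest a config.yaml shaped roughly like the sketch below. The concrete values and model path are made up, and internal_port can presumably be left out since pick_open_ports() fills it in; the TODO in the code hints this should eventually use openport's pick_random_unused_port.

```yaml
# Hypothetical example; key names follow the struct fields used in the diff,
# values and the model path are invented.
hardware:
  ram: 48G
  vram: 24G
models:
  - port: 8080            # external port the proxy listens on
    # internal_port is picked automatically by pick_open_ports()
    ram_usage: 4G
    vram_usage: 8G
    env: {}
    args:
      model: /models/example-7b.Q4_K_M.gguf
      ctx-size: "4096"
      flash-attn: "true"  # "true" values become bare --flash-attn flags
```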
@@ -89,8 +101,154 @@ struct LlamaInstance {
    // busy: bool,
}

impl LlamaInstance {
    /// Retrieve a running instance from state or spawn a new one.
    pub async fn get_or_spawn_instance(
        model_config: &ModelConfig,
        state: SharedState,
    ) -> Result<LlamaInstance, AppError> {
        let model_ram_usage =
            parse_size(&model_config.ram_usage).expect("Invalid ram_usage in model config");
        let model_vram_usage =
            parse_size(&model_config.vram_usage).expect("Invalid vram_usage in model config");

        {
            // First, check if we already have an instance.
            let state_guard = state.lock().await;
            if let Some(instance) = state_guard.instances.get(&model_config.port) {
                return Ok(instance.clone());
            }
        }

        let mut state_guard = state.lock().await;
        tracing::info!(msg = "Shared state before spawn", ?state_guard);

        // If not enough free resources, stop some running instances.
        if (state_guard.used_ram + model_ram_usage > state_guard.total_ram)
            || (state_guard.used_vram + model_vram_usage > state_guard.total_vram)
        {
            let mut to_remove = Vec::new();
            let instances_by_size = state_guard
                .instances
                .iter()
                .sorted_by(|(_, a), (_, b)| {
                    parse_size(&b.config.vram_usage)
                        .unwrap_or(0)
                        .cmp(&parse_size(&a.config.vram_usage).unwrap_or(0))
                })
                .map(|(port, instance)| (*port, instance.clone()))
                .collect::<Vec<_>>();

            for (port, instance) in instances_by_size.iter() {
                tracing::info!("Stopping instance on port {}", port);
                let mut proc = instance.process.lock().await;
                proc.kill().await.ok();
                to_remove.push(*port);
                state_guard.used_ram = state_guard
                    .used_ram
                    .saturating_sub(parse_size(&instance.config.ram_usage).unwrap_or(0));
                state_guard.used_vram = state_guard
                    .used_vram
                    .saturating_sub(parse_size(&instance.config.vram_usage).unwrap_or(0));
                if state_guard.used_ram + model_ram_usage <= state_guard.total_ram
                    && state_guard.used_vram + model_vram_usage <= state_guard.total_vram
                {
                    tracing::info!("Freed enough resources");
                    break;
                }
            }
            for port in to_remove {
                state_guard.instances.remove(&port);
            }
        } else {
            tracing::info!("Sufficient resources available");
        }

        // Spawn a new instance.
        let instance = Self::spawn(&model_config).await?;
        state_guard.used_ram += model_ram_usage;
        state_guard.used_vram += model_vram_usage;
        state_guard
            .instances
            .insert(model_config.port, instance.clone());

        sleep(Duration::from_millis(250)).await;

        instance.running_and_port_ready().await?;

        sleep(Duration::from_millis(250)).await;

        Ok(instance)
    }

    /// Spawn a new llama-server process based on the model configuration.
    pub async fn spawn(model_config: &ModelConfig) -> Result<LlamaInstance, AppError> {
        let args = model_config
            .args
            .iter()
            .flat_map(|(k, v)| {
                if v == "true" {
                    vec![format!("--{}", k)]
                } else {
                    vec![format!("--{}", k), v.clone()]
                }
            })
            .collect::<Vec<_>>();

        let internal_port = model_config
            .internal_port
            .expect("Internal port must be set");

        let mut cmd = Command::new("llama-server");
        cmd.kill_on_drop(true)
            .envs(model_config.env.clone())
            .args(&args)
            .arg("--port")
            .arg(format!("{}", internal_port))
            .stdout(Stdio::null())
            .stderr(Stdio::null());

        // Silence output – could be captured later via an API.

        tracing::info!("Starting llama-server with command: {:?}", cmd);
        let child = cmd.spawn().expect("Failed to start llama-server");

        Ok(LlamaInstance {
            config: model_config.clone(),
            process: Arc::new(Mutex::new(child)),
        })
    }

    pub async fn running_and_port_ready(&self) -> Result<(), anyhow::Error> {
        async fn is_running(instance: &LlamaInstance) -> Result<(), AppError> {
            let mut proc = instance.process.clone().lock_owned().await;
            match proc.try_wait()? {
                Some(exit_status) => {
                    tracing::error!("Llama instance exited: {:?}", exit_status);
                    Err(AppError::Unknown(anyhow!("Llama instance exited")))
                }
                None => Ok(()),
            }
        }

        async fn wait_for_port(port: u16) -> Result<(), anyhow::Error> {
            for _ in 0..10 {
                if TcpStream::connect(("127.0.0.1", port)).await.is_ok() {
                    return Ok(());
                }
                sleep(Duration::from_millis(500)).await;
            }
            Err(anyhow!("Timeout waiting for port"))
        }

        is_running(self).await?;
        wait_for_port(self.config.internal_port.expect("No port picked?")).await?;
        Ok(())
    }
}

#[derive(Clone, Debug)]
struct SharedState {
pub struct InnerSharedState {
    total_ram: u64,
    total_vram: u64,
    used_ram: u64,
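The argument handling in `spawn` (and in the old handler further down) turns the per-model `args` map into CLI flags, treating the string value "true" as a bare flag. A self-contained sketch of that mapping, using a hypothetical argument map:

```rust
use std::collections::HashMap;

fn main() {
    // Hypothetical args map; in the proxy it comes from the model's config.
    let args = HashMap::from([
        ("ctx-size".to_string(), "4096".to_string()),
        ("flash-attn".to_string(), "true".to_string()),
    ]);

    // Same flat_map logic as in LlamaInstance::spawn: "true" becomes a bare
    // --flag, anything else becomes a --key value pair.
    let argv = args
        .iter()
        .flat_map(|(k, v)| {
            if v == "true" {
                vec![format!("--{}", k)]
            } else {
                vec![format!("--{}", k), v.clone()]
            }
        })
        .collect::<Vec<_>>();

    // HashMap iteration order is unspecified; one possible result is
    // ["--ctx-size", "4096", "--flash-attn"].
    println!("{:?}", argv);
}
```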
@@ -98,31 +256,67 @@ struct SharedState {
    instances: HashMap<u16, LlamaInstance>,
}

type SharedStateArc = Arc<Mutex<InnerSharedState>>;

/// TODO migrate to dashmap + individual Arc<Mutex<T>> or Rwlocks
#[derive(Clone, Debug, derive_more::Deref)]
pub struct SharedState(SharedStateArc);

impl SharedState {
    fn from_config(config: Config) -> Self {
        // Parse hardware resources
        let total_ram = parse_size(&config.hardware.ram).expect("Invalid RAM size in config");
        let total_vram = parse_size(&config.hardware.vram).expect("Invalid VRAM size in config");

        // Initialize shared state
        let shared_state = InnerSharedState {
            total_ram,
            total_vram,
            used_ram: 0,
            used_vram: 0,
            instances: HashMap::new(),
        };

        Self(Arc::new(Mutex::new(shared_state)))
    }
}

/// Build an Axum app for a given model.
fn create_app(model_config: &ModelConfig, state: SharedState) -> Router {
    Router::new()
        .route(
            "/",
            any({
                let state = state.clone();
                let model_config = model_config.clone();
                move |req| handle_request(req, model_config.clone(), state.clone())
            }),
        )
        .route(
            "/*path",
            any({
                let state = state.clone();
                let model_config = model_config.clone();
                move |req| handle_request(req, model_config.clone(), state.clone())
            }),
        )
        .layer(
            tower_http::trace::TraceLayer::new_for_http()
                .make_span_with(tower_http::trace::DefaultMakeSpan::new().include_headers(true))
                .on_request(tower_http::trace::DefaultOnRequest::new().level(Level::DEBUG))
                .on_response(tower_http::trace::DefaultOnResponse::new().level(Level::TRACE))
                .on_eos(tower_http::trace::DefaultOnEos::new().level(Level::DEBUG))
                .on_failure(tower_http::trace::DefaultOnFailure::new().level(Level::ERROR)),
        )
}

#[tokio::main]
async fn main() {
    // TODO add autostart of models based on config
    // abstract starting logic out of handler for this to allow separate calls to start
    // maybe add to SharedState & LlamaInstance ?

    initialize_logger();
    // Read and parse the YAML configuration
    let config_str = std::fs::read_to_string("config.yaml").expect("Failed to read config.yaml");
    let config: Config = serde_yaml::from_str::<Config>(&config_str)
        .expect("Failed to parse config.yaml")
        .pick_open_ports();

    // Parse hardware resources
    let total_ram = parse_size(&config.hardware.ram).expect("Invalid RAM size in config");
    let total_vram = parse_size(&config.hardware.vram).expect("Invalid VRAM size in config");
    let config = Config::default_from_pwd_yml();

    // Initialize shared state
    let shared_state = Arc::new(Mutex::new(SharedState {
        total_ram,
        total_vram,
        used_ram: 0,
        used_vram: 0,
        instances: HashMap::new(),
    }));
    let shared_state = SharedState::from_config(config.clone());

    // For each model, set up an axum server listening on the specified port
    let mut handles = Vec::new();
@@ -134,31 +328,7 @@ async fn main() {
        let handle = tokio::spawn(async move {
            let model_config = model_config.clone();

            let app = Router::new()
                .route(
                    "/",
                    any({
                        let state = state.clone();
                        let model_config = model_config.clone();
                        move |req| handle_request(req, model_config, state)
                    }),
                )
                .route(
                    "/*path",
                    any({
                        let state = state.clone();
                        let model_config = model_config.clone();
                        move |req| handle_request(req, model_config, state)
                    }),
                )
                .layer(
                    TraceLayer::new_for_http()
                        .make_span_with(DefaultMakeSpan::new().include_headers(true))
                        .on_request(DefaultOnRequest::new().level(Level::DEBUG))
                        .on_response(DefaultOnResponse::new().level(Level::TRACE))
                        .on_eos(DefaultOnEos::new().level(Level::DEBUG))
                        .on_failure(DefaultOnFailure::new().level(Level::ERROR)),
                );
            let app = create_app(&model_config, state);

            let addr = SocketAddr::from(([0, 0, 0, 0], model_config.port));
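The hunk ends right after the listen address is built; the actual bind-and-serve call sits outside the shown context. With the axum 0.7-style API this code already uses elsewhere (Body::from_stream, axum::body::to_bytes), the per-model task presumably finishes with something along these lines; the function name and error handling below are assumptions, not part of the commit:

```rust
// Sketch only: how an axum Router is typically served on `addr`.
// `app` and `addr` correspond to the values built in the spawned task above.
async fn serve(app: axum::Router, addr: std::net::SocketAddr) -> anyhow::Result<()> {
    let listener = tokio::net::TcpListener::bind(addr).await?;
    tracing::info!("listening on {}", addr);
    axum::serve(listener, app).await?;
    Ok(())
}
```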
@@ -190,6 +360,8 @@ pub enum AppError {
    ReqwestMiddlewareError(#[from] reqwest_middleware::Error),
    #[error("Client Error")]
    ClientError(#[from] hyper::Error),
    #[error("Io Error")]
    IoError(#[from] std::io::Error),
    #[error("Unknown error")]
    Unknown(#[from] anyhow::Error),
}
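The `IntoResponse` impl for `AppError` is only touched at its closing braces in the next hunk, so its body is not visible here. A common shape for such an impl looks like the following; this is purely illustrative and not the project's actual code:

```rust
use axum::{http::StatusCode, response::{IntoResponse, Response}};

// Illustrative only: map any AppError variant to a 500 with its Debug output.
impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Something went wrong: {:?}", self),
        )
            .into_response()
    }
}
```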
@@ -206,202 +378,67 @@ impl IntoResponse for AppError {
    }
}

async fn handle_request(
async fn proxy_request(
    req: Request<Body>,
    model_config: ModelConfig,
    state: Arc<Mutex<SharedState>>,
// ) -> Result<Response<Body>, anyhow::Error> {
) -> impl IntoResponse {
    let model_ram_usage = parse_size(&model_config.ram_usage).expect("Invalid ram_usage");
    let model_vram_usage = parse_size(&model_config.vram_usage).expect("Invalid vram_usage");

    let instance = {
        let mut state = state.lock().await;

        // Check if instance is running
        if let Some(instance) = state.instances.get_mut(&model_config.port) {
            // instance.busy = true;
            instance.to_owned()
        } else {
            // Check resources
            tracing::info!(msg = "Current state", ?state);
            if ((state.used_ram + model_ram_usage) > state.total_ram)
                || ((state.used_vram + model_vram_usage) > state.total_vram)
            {
                // Stop other instances
                let mut to_remove = Vec::new();
                // TODO Actual smart stopping logic
                // - search for smallest single model to stop to get enough room
                // - if not possible search for smallest number of models to stop with lowest
                //   amount of "overshoot"
                let instances_by_size =
                    state
                        .instances
                        .clone()
                        .into_iter()
                        .sorted_by(|(_, el_a), (_, el_b)| {
                            Ord::cmp(
                                &parse_size(el_b.config.vram_usage.as_str()),
                                &parse_size(el_a.config.vram_usage.as_str()),
                            )
                        });
                for (port, instance) in instances_by_size {
                    // if !instance.busy {
                    tracing::info!("Stopping instance on port {}", port);
                    let mut process = instance.process.lock().await;
                    process.kill().await.ok();
                    to_remove.push(port);
                    state.used_ram -= parse_size(&instance.config.ram_usage).unwrap_or(0);
                    state.used_vram -= parse_size(&instance.config.vram_usage).unwrap_or(0);
                    if state.used_ram + model_ram_usage <= state.total_ram
                        && state.used_vram + model_vram_usage <= state.total_vram
                    {
                        tracing::info!("Should have enough ram now");
                        break;
                    }
                    // }
                }
                for port in to_remove {
                    tracing::info!("Removing instance on port {}", port);
                    state.instances.remove(&port);
                }
            } else {
                tracing::info!("Already enough res free");
            }

            // Start new instance
            let args = model_config
                .args
                .iter()
                .flat_map(|(k, v)| {
                    if v == "true" {
                        vec![format!("--{}", k)]
                    } else {
                        vec![format!("--{}", k), v.clone()]
                    }
                })
                .collect::<Vec<_>>();

            let mut cmd = Command::new("llama-server");
            cmd.kill_on_drop(true);
            cmd.envs(model_config.env.clone());
            cmd.args(&args);
            // TODO use openport crate via pick_random_unused_port for determining these
            cmd.arg("--port");
            cmd.arg(format!(
                "{}",
                model_config
                    .internal_port
                    .expect("Unexpected empty port, should've been picked")
            ));
            cmd.stdout(Stdio::null()).stderr(Stdio::null()); // TODO save output and allow retrieval via api

            tracing::info!("Starting llama-server with {:?}", cmd);
            let process = Arc::new(Mutex::new(
                cmd.spawn().expect("Failed to start llama-server"),
            ));

            state.used_ram += model_ram_usage;
            state.used_vram += model_vram_usage;

            let instance = LlamaInstance {
                config: model_config.clone(),
                process,
                // busy: true,
            };
            sleep(Duration::from_millis(500)).await;
            state.instances.insert(model_config.port, instance.clone());

            instance
        }
    };

    // Wait for the instance to be ready
    is_llama_instance_running(&instance).await?;
    wait_for_port(
        model_config
            .internal_port
            .expect("Unexpected empty port, should've been picked"),
    )
    .await?;

    // Proxy the request
    let retry_policy = reqwest_retry::policies::ExponentialBackoff::builder()
        .retry_bounds(Duration::from_millis(500), Duration::from_secs(8))
    model_config: &ModelConfig,
) -> Result<Response<Body>, AppError> {
    let retry_policy = ExponentialBackoff::builder()
        .retry_bounds(
            std::time::Duration::from_millis(500),
            std::time::Duration::from_secs(8),
        )
        .jitter(reqwest_retry::Jitter::None)
        .base(2)
        .build_with_max_retries(8);

    let client = reqwest_middleware::ClientBuilder::new(Client::new())
        .with(reqwest_retry::RetryTransientMiddleware::new_with_policy(
            retry_policy,
        ))
    let client: ClientWithMiddleware = ClientBuilder::new(Client::new())
        .with(RetryTransientMiddleware::new_with_policy(retry_policy))
        .build();

    let internal_port = model_config
        .internal_port
        .expect("Internal port must be set");
    let uri = format!(
        "http://127.0.0.1:{}{}",
        model_config
            .internal_port
            .expect("Unexpected empty port, should've been picked"),
        req.uri().path_and_query().map(|x| x.as_str()).unwrap_or("")
        internal_port,
        req.uri()
            .path_and_query()
            .map(|pq| pq.as_str())
            .unwrap_or("")
    );

    let mut request_builder = client.request(req.method().clone(), &uri);
    // let mut request_builder = reqwest::RequestBuilder::from_parts(client, req.method().clone())

    // Forward all headers.
    for (name, value) in req.headers() {
        request_builder = request_builder.header(name, value);
    }

    let body_bytes = axum::body::to_bytes(
        req.into_body(),
        parse_size("1G").unwrap().try_into().unwrap(),
    )
    .await?;
    let max_size = parse_size("1G").unwrap() as usize;
    let body_bytes = axum::body::to_bytes(req.into_body(), max_size).await?;
    let request = request_builder.body(body_bytes).build()?;
    // tracing::info!("Proxying request to {}", uri);
    let response = client.execute(request).await?;

    // tracing::info!("Received response from {}", uri);
    let mut builder = Response::builder().status(response.status());
    for (name, value) in response.headers() {
        builder = builder.header(name, value);
    }

    // let bytes = response.bytes().await?;

    let byte_stream = response.bytes_stream();
    let body = Body::from_stream(byte_stream);
    Ok(builder.body(body)?)
}

    let body = axum::body::Body::from_stream(byte_stream);
async fn handle_request(
    req: Request<Body>,
    model_config: ModelConfig,
    state: SharedState,
) -> impl IntoResponse {
    let _instance = LlamaInstance::get_or_spawn_instance(&model_config, state).await?;
    let response = proxy_request(req, &model_config).await?;

    tracing::debug!("streaming response on port: {}", model_config.port);
    let response = builder.body(body)?;
    Ok::<axum::http::Response<Body>, AppError>(response)
}

async fn wait_for_port(port: u16) -> Result<(), anyhow::Error> {
    for _ in 0..10 {
        if TcpStream::connect(("127.0.0.1", port)).await.is_ok() {
            return Ok(());
        }
        sleep(Duration::from_millis(500)).await;
    }
    Err(anyhow!("Timeout waiting for port"))
}

// reimplement wait_for_port using a LlamaInstance, also check if the process is still running
async fn is_llama_instance_running(instance: &LlamaInstance) -> Result<(), anyhow::Error> {
    match instance.process.to_owned().lock_owned().await.try_wait()? {
        Some(exit_status) => {
            let msg = "Llama instance exited";
            tracing::error!(msg, ?exit_status);
            Err(anyhow!(msg))
        }

        None => Ok(()),
    }
}

fn parse_size(size_str: &str) -> Option<u64> {
    let mut num = String::new();
    let mut unit = String::new();
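The diff is cut off inside `parse_size`, which the rest of the code relies on to turn strings like "4G" or "512M" into byte counts. A hedged sketch of what such a helper plausibly looks like follows; this is a reconstruction for illustration, not the committed implementation:

```rust
// Plausible reconstruction of parse_size; the real function may differ.
fn parse_size(size_str: &str) -> Option<u64> {
    let mut num = String::new();
    let mut unit = String::new();
    for c in size_str.trim().chars() {
        if c.is_ascii_digit() || c == '.' {
            num.push(c);
        } else {
            unit.push(c);
        }
    }
    let value: f64 = num.parse().ok()?;
    let multiplier = match unit.trim().to_ascii_uppercase().as_str() {
        "" | "B" => 1u64,
        "K" | "KB" => 1024,
        "M" | "MB" => 1024 * 1024,
        "G" | "GB" => 1024 * 1024 * 1024,
        "T" | "TB" => 1024u64.pow(4),
        _ => return None,
    };
    Some((value * multiplier as f64) as u64)
}

fn main() {
    assert_eq!(parse_size("1G"), Some(1024 * 1024 * 1024));
    assert_eq!(parse_size("512M"), Some(512 * 1024 * 1024));
    println!("ok");
}
```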