redvault-ai/llama_proxy_man/src/config.rs

use std::{collections::HashMap, fs};
use figment::{
providers::{Env, Format, Json, Toml, Yaml},
Figment,
};
use serde::Deserialize;

#[derive(Clone, Debug, Deserialize)]
pub struct AppConfig {
    pub system_resources: SystemResources,
    pub model_specs: Vec<ModelSpec>,
}

impl AppConfig {
    /// Build the config by layering sources with figment: `config.toml`, then
    /// `config.yaml`, then `LLAMA_FORGE_`-prefixed environment variables, with
    /// later `merge`d sources taking precedence. `Cargo.json` is `join`ed, so
    /// it only fills in keys that are still missing.
    pub fn default_figment() -> Self {
        let config: Result<Self, _> = Figment::new()
            .merge(Toml::file("config.toml"))
            .merge(Yaml::file("config.yaml"))
            .merge(Env::prefixed("LLAMA_FORGE_"))
            .join(Json::file("Cargo.json"))
            .extract();
        tracing::info!(?config);
        config
            .expect("Failed to extract figment config")
            .assign_internal_ports()
    }

    /// Load the config directly from `./config.yaml` in the working directory.
    pub fn default_from_pwd_yml() -> Self {
        let config_str = fs::read_to_string("./config.yaml").expect("Failed to read config.yaml");
        serde_yaml::from_str::<Self>(&config_str)
            .expect("Failed to parse config.yaml")
            .assign_internal_ports()
    }

    /// Ensure every model has an internal port, picking a random unused one
    /// for any spec that leaves it unset.
    pub fn assign_internal_ports(mut self) -> Self {
        for model in &mut self.model_specs {
            if model.internal_port.is_none() {
                model.internal_port = Some(
                    openport::pick_random_unused_port()
                        .unwrap_or_else(|| panic!("No open port found for {:?}", model)),
                );
            }
        }
        self
    }
}
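
// A minimal usage sketch (not part of this file): how a hypothetical caller
// might load the layered config at startup. `tracing_subscriber` and the
// `main` below are assumptions for illustration only.
//
// fn main() {
//     tracing_subscriber::fmt::init();
//     let config = AppConfig::default_figment();
//     for spec in &config.model_specs {
//         // After assign_internal_ports, every model has a backend port.
//         println!("{} -> :{}", spec.name, spec.internal_port.unwrap());
//     }
// }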

#[derive(Clone, Debug, Deserialize)]
pub struct SystemResources {
    pub ram: String,
    pub vram: String,
}

#[derive(Clone, Debug, Deserialize)]
pub struct ModelSpec {
    pub name: String,
    pub port: u16,
    /// Backend port; filled in by `assign_internal_ports` when unset.
    pub internal_port: Option<u16>,
    pub env: HashMap<String, String>,
    pub args: HashMap<String, String>,
    pub vram_usage: String,
    pub ram_usage: String,
}
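
// A hypothetical `config.yaml` matching these structs (field names follow the
// serde defaults above; the concrete values and arg/env keys are made up for
// illustration):
//
// system_resources:
//   ram: "48G"
//   vram: "24G"
// model_specs:
//   - name: "example-model"
//     port: 8080
//     # internal_port may be omitted; assign_internal_ports fills it in
//     env:
//       CUDA_VISIBLE_DEVICES: "0"
//     args:
//       ctx-size: "4096"
//     vram_usage: "20G"
//     ram_usage: "8G"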