redvault-ai/llama_forge_rs/src/server/backends/llama_completion.rs

use std::io::{self, Write};

use futures::StreamExt;
use reqwest::Client;
use reqwest_eventsource::{Event, EventSource};
use serde::{Deserialize, Serialize};

/// Request body for the llama.cpp server's `/completion` endpoint.
#[derive(Debug, Serialize, Deserialize)]
struct CompletionRequest {
    stream: bool,
    n_predict: i32,
    temperature: f32,
    stop: Vec<String>,
    repeat_last_n: i32,
    repeat_penalty: f32,
    top_k: i32,
    top_p: f32,
    min_p: f32,
    tfs_z: f32,
    typical_p: f32,
    presence_penalty: f32,
    frequency_penalty: f32,
    mirostat: f32,
    mirostat_tau: f32,
    mirostat_eta: f32,
    grammar: String,
    n_probs: i32,
    image_data: Vec<String>,
    ignore_eos: bool,
    cache_prompt: bool,
    api_key: String,
    slot_id: i32,
    prompt: String,
}

impl Default for CompletionRequest {
    fn default() -> Self {
        CompletionRequest {
            stream: true,
            n_predict: 2048,
            temperature: 0.7,
            stop: vec![
                "<|im_end|>".to_string(),
                "<dummy00001>".to_string(),
                "</s>".to_string(),
                "Llama:".to_string(),
                "User:".to_string(),
            ],
            repeat_last_n: 64,
            repeat_penalty: 1.1,
            top_k: 40,
            top_p: 0.95,
            min_p: 0.05,
            tfs_z: 1.0,
            typical_p: 1.0,
            presence_penalty: 0.0,
            frequency_penalty: 0.0,
            mirostat: 0.0,
            mirostat_tau: 5.0,
            mirostat_eta: 0.1,
            grammar: "".to_string(),
            n_probs: 0,
            image_data: vec![],
            ignore_eos: false,
            cache_prompt: true,
            api_key: "".to_string(),
            slot_id: -1,
            // ChatML-formatted prompt with a hard-coded example question.
            prompt: concat!(
                "<|im_start|>system\n",
                "You are a helpful assistant.",
                "<|im_end|>\n",
                "<|im_start|>user\n",
                "What is quantum entanglement?",
                "<|im_end|>\n",
                "<|im_start|>assistant\n"
            )
            .to_string(),
        }
    }
}
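
// Hypothetical convenience helper (not part of the original file): struct update
// syntax overrides only the prompt while inheriting every sampling parameter
// from the defaults above.
#[allow(dead_code)]
fn request_for_prompt(prompt: impl Into<String>) -> CompletionRequest {
    CompletionRequest {
        prompt: prompt.into(),
        ..CompletionRequest::default()
    }
}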

/// One streamed chunk from the `/completion` endpoint.
#[derive(Debug, Serialize, Deserialize)]
struct CompletionResponse {
    id_slot: i64,
    content: String,
    stop: bool,
    multimodal: Option<bool>,
    // The fields below only appear in the final chunk at stream end (`stop == true`).
    // TODO: Add all fields
    model: Option<String>, // Path to the model file
    tokens_predicted: Option<i64>,
    tokens_evaluated: Option<i64>,
    tokens_cached: Option<i64>,
    timings: Option<CompletionTimings>,
}

/// Prompt-processing and generation timings reported in the final chunk.
#[derive(Debug, Serialize, Deserialize)]
struct CompletionTimings {
    prompt_n: i64,
    prompt_ms: f64,
    prompt_per_token_ms: f64,
    prompt_per_second: f64,
    predicted_n: i64,
    predicted_ms: f64,
    predicted_per_token_ms: f64,
    predicted_per_second: f64,
}
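
/// Streams tokens from a llama.cpp-compatible `/completion` endpoint over SSE,
/// printing each chunk to stdout as it arrives.
///
/// A minimal driver sketch, assuming a tokio runtime (reqwest's async executor)
/// and an initialized tracing subscriber; the backend address is hard-coded below:
///
/// ```ignore
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     tracing_subscriber::fmt::init(); // surface the tracing::debug! logs
///     do_completion_request().await
/// }
/// ```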
pub async fn do_completion_request() -> Result<(), Box<dyn std::error::Error>> {
    let client = Client::new();
    let request_body = CompletionRequest::default();
    let request_builder = client
        .post("http://100.64.0.3:18080/completion")
        .header("Accept", "text/event-stream")
        .header("Content-Type", "application/json")
        .header("User-Agent", "llama_forge_rs")
        .json(&request_body);
    let mut es = EventSource::new(request_builder)?;
    while let Some(event) = es.next().await {
        match event {
            Ok(Event::Open) => tracing::debug!("Connection Open!"),
            Ok(Event::Message(message)) => match message.event.as_str() {
                "message" => {
                    // Each SSE message carries one JSON-encoded completion chunk.
                    let response: CompletionResponse = serde_json::from_str(&message.data)?;
                    print!("{}", response.content);
                    io::stdout().flush()?;
                }
                other => {
                    // Log and skip event types we don't handle instead of panicking.
                    tracing::warn!("Unhandled event type: {}", other);
                }
            },
            Err(err) => {
                // reqwest_eventsource also surfaces normal stream termination here.
                tracing::debug!("Error: {}", err);
                es.close();
            }
        }
    }
    Ok(())
}
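
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal round-trip check on the chunk shape this module expects; the
    // JSON payload here is a hypothetical example, not a captured server response.
    #[test]
    fn deserializes_streaming_chunk() {
        let data = r#"{"id_slot":0,"content":"Hello","stop":false}"#;
        let response: CompletionResponse =
            serde_json::from_str(data).expect("chunk should deserialize");
        assert_eq!(response.content, "Hello");
        assert!(!response.stop);
        assert!(response.timings.is_none());
    }

    // End-to-end sketch, assuming tokio with the `macros` feature and a
    // llama.cpp server listening on the hard-coded address above; ignored by
    // default so `cargo test` passes without a live backend.
    #[tokio::test]
    #[ignore = "requires a running llama.cpp server"]
    async fn streams_a_completion() {
        do_completion_request()
            .await
            .expect("request should stream to completion");
    }
}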