use std::{io, io::Write}; use futures::StreamExt; use reqwest::Client; use reqwest_eventsource::{Event, EventSource}; use serde::{Deserialize, Serialize}; #[derive(Debug, Serialize, Deserialize)] struct CompletionRequest { stream: bool, n_predict: i32, temperature: f32, stop: Vec, repeat_last_n: i32, repeat_penalty: f32, top_k: i32, top_p: f32, min_p: f32, tfs_z: f32, typical_p: f32, presence_penalty: f32, frequency_penalty: f32, mirostat: f32, mirostat_tau: f32, mirostat_eta: f32, grammar: String, n_probs: i32, image_data: Vec, ignore_eos: bool, cache_prompt: bool, api_key: String, slot_id: i32, prompt: String, } impl Default for CompletionRequest { fn default() -> Self { CompletionRequest { stream: true, n_predict: 2048, temperature: 0.7, stop: vec![ "<|im_end|>".to_string(), "".to_string(), "".to_string(), "Llama:".to_string(), "User:".to_string(), ], repeat_last_n: 64, repeat_penalty: 1.1, top_k: 40, top_p: 0.95, min_p: 0.05, tfs_z: 1.0, typical_p: 1.0, presence_penalty: 0.0, frequency_penalty: 0.0, mirostat: 0.0, mirostat_tau: 5.0, mirostat_eta: 0.1, grammar: "".to_string(), n_probs: 0, ignore_eos: false, image_data: vec![], cache_prompt: true, api_key: "".to_string(), slot_id: -1, prompt: concat!( "<|im_start|>system\n", "You are a helpful assistant.", "<|im_end|>\n", "<|im_start|>user\n", "What is quantum entanglement ?", "<|im_end|>\n", "<|im_start|>assistant\n" ) .to_string(), } } } #[derive(Debug, Serialize, Deserialize)] struct CompletionResponse { id_slot: i64, content: String, stop: bool, multimodal: Option, // Only in stop response at stream end // TODO Add all fields model: Option, // Path to model tokens_predicted: Option, tokens_evaluated: Option, tokens_cached: Option, timings: Option, } #[derive(Debug, Serialize, Deserialize)] struct CompletionTimings { prompt_n: i64, prompt_ms: f64, prompt_per_token_ms: f64, prompt_per_second: f64, predicted_n: i64, predicted_ms: f64, predicted_per_token_ms: f64, predicted_per_second: f64, } pub async fn do_completion_request() -> Result<(), Box> { let client = Client::new(); let request_body = CompletionRequest::default(); let request_builder = client .post("http://100.64.0.3:18080/completion") .header("Accept", "text/event-stream") .header("Content-Type", "application/json") .header("User-Agent", "llama_forge_rs") .json(&request_body); let mut es = EventSource::new(request_builder)?; while let Some(event) = es.next().await { match event { Ok(Event::Open) => tracing::debug!("Connection Open!"), Ok(Event::Message(event)) => match event.event.as_str() { "message" => { let data = event.data; let response: CompletionResponse = serde_json::from_str(&data).unwrap(); print!("{}", response.content); io::stdout().flush().unwrap(); } _ => { todo! {} } }, Err(err) => { tracing::debug!("Error: {}", err); es.close(); } } } Ok(()) }