//! Streaming client for a llama.cpp-style `/completion` endpoint over
//! server-sent events.
use std::{io, io::Write};
|
|
|
|
use futures::StreamExt;
|
|
use reqwest::Client;
|
|
use reqwest_eventsource::{Event, EventSource};
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
/// Request body for the server's `/completion` endpoint.
///
/// Field names match the server's JSON API one-to-one, so the struct is
/// serialized directly with `serde_json` (see `do_completion_request`).
/// Semantics of the individual knobs follow the llama.cpp server API —
/// TODO confirm against the server version actually deployed.
#[derive(Debug, Serialize, Deserialize)]
struct CompletionRequest {
    // Stream tokens back as SSE events instead of a single response.
    stream: bool,
    // Maximum number of tokens to generate.
    n_predict: i32,
    // Sampling temperature.
    temperature: f32,
    // Strings that terminate generation when emitted.
    stop: Vec<String>,
    // Window of recent tokens considered for the repetition penalty.
    repeat_last_n: i32,
    // Penalty factor applied to repeated tokens.
    repeat_penalty: f32,
    // Top-k sampling cutoff.
    top_k: i32,
    // Nucleus (top-p) sampling cutoff.
    top_p: f32,
    // Minimum probability cutoff relative to the most likely token.
    min_p: f32,
    // Tail-free sampling parameter (1.0 = disabled).
    tfs_z: f32,
    // Locally typical sampling parameter (1.0 = disabled).
    typical_p: f32,
    // Presence penalty (OpenAI-style).
    presence_penalty: f32,
    // Frequency penalty (OpenAI-style).
    frequency_penalty: f32,
    // Mirostat mode selector (0.0 = off) — NOTE(review): sent as a float
    // although the server treats it as a mode switch; verify expected type.
    mirostat: f32,
    // Mirostat target entropy (tau).
    mirostat_tau: f32,
    // Mirostat learning rate (eta).
    mirostat_eta: f32,
    // GBNF grammar constraining the output; empty = unconstrained.
    grammar: String,
    // Number of token probabilities to return per step (0 = none).
    n_probs: i32,
    // Base64-encoded images for multimodal prompts.
    image_data: Vec<String>,
    // If true, the model is not allowed to stop on end-of-sequence.
    ignore_eos: bool,
    // Reuse the server-side KV cache for a shared prompt prefix.
    cache_prompt: bool,
    // Optional API key; empty string means unauthenticated.
    api_key: String,
    // Server slot to use; -1 lets the server pick.
    slot_id: i32,
    // The full prompt text (ChatML-formatted in the default).
    prompt: String,
}
|
|
|
|
impl Default for CompletionRequest {
|
|
fn default() -> Self {
|
|
CompletionRequest {
|
|
stream: true,
|
|
n_predict: 2048,
|
|
temperature: 0.7,
|
|
stop: vec![
|
|
"<|im_end|>".to_string(),
|
|
"<dummy00001>".to_string(),
|
|
"</s>".to_string(),
|
|
"Llama:".to_string(),
|
|
"User:".to_string(),
|
|
],
|
|
repeat_last_n: 64,
|
|
repeat_penalty: 1.1,
|
|
top_k: 40,
|
|
top_p: 0.95,
|
|
min_p: 0.05,
|
|
tfs_z: 1.0,
|
|
typical_p: 1.0,
|
|
presence_penalty: 0.0,
|
|
frequency_penalty: 0.0,
|
|
mirostat: 0.0,
|
|
mirostat_tau: 5.0,
|
|
mirostat_eta: 0.1,
|
|
grammar: "".to_string(),
|
|
n_probs: 0,
|
|
ignore_eos: false,
|
|
image_data: vec![],
|
|
cache_prompt: true,
|
|
api_key: "".to_string(),
|
|
slot_id: -1,
|
|
prompt: concat!(
|
|
"<|im_start|>system\n",
|
|
"You are a helpful assistant.",
|
|
"<|im_end|>\n",
|
|
"<|im_start|>user\n",
|
|
"What is quantum entanglement ?",
|
|
"<|im_end|>\n",
|
|
"<|im_start|>assistant\n"
|
|
)
|
|
.to_string(),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// One chunk of the server's streamed `/completion` response.
///
/// During streaming, each SSE `message` event carries a `content` fragment;
/// the final event sets `stop: true` and additionally carries the summary
/// fields, which is why those are `Option`s.
#[derive(Debug, Serialize, Deserialize)]
struct CompletionResponse {
    // Server slot that handled this request.
    id_slot: i64,
    // Generated text fragment for this chunk (printed as it arrives).
    content: String,
    // True on the final chunk of the stream.
    stop: bool,
    // Whether the model session is multimodal — TODO confirm when populated.
    multimodal: Option<bool>,
    // Only in stop response at stream end
    // TODO Add all fields
    model: Option<String>, // Path to model
    tokens_predicted: Option<i64>,
    tokens_evaluated: Option<i64>,
    // Tokens served from the prompt cache.
    tokens_cached: Option<i64>,
    // Detailed timing breakdown; present only on the final chunk.
    timings: Option<CompletionTimings>,
}
|
|
|
|
/// Timing statistics reported by the server with the final stream chunk.
///
/// `prompt_*` fields cover prompt evaluation (prefill); `predicted_*` fields
/// cover token generation (decode).
#[derive(Debug, Serialize, Deserialize)]
struct CompletionTimings {
    // Number of prompt tokens evaluated.
    prompt_n: i64,
    // Total prompt-evaluation wall time, in milliseconds.
    prompt_ms: f64,
    // Average milliseconds per prompt token.
    prompt_per_token_ms: f64,
    // Prompt tokens evaluated per second.
    prompt_per_second: f64,
    // Number of tokens generated.
    predicted_n: i64,
    // Total generation wall time, in milliseconds.
    predicted_ms: f64,
    // Average milliseconds per generated token.
    predicted_per_token_ms: f64,
    // Tokens generated per second.
    predicted_per_second: f64,
}
|
|
|
|
pub async fn do_completion_request() -> Result<(), Box<dyn std::error::Error>> {
|
|
let client = Client::new();
|
|
|
|
let request_body = CompletionRequest::default();
|
|
|
|
let request_builder = client
|
|
.post("http://100.64.0.3:18080/completion")
|
|
.header("Accept", "text/event-stream")
|
|
.header("Content-Type", "application/json")
|
|
.header("User-Agent", "llama_forge_rs")
|
|
.json(&request_body);
|
|
|
|
let mut es = EventSource::new(request_builder)?;
|
|
|
|
while let Some(event) = es.next().await {
|
|
match event {
|
|
Ok(Event::Open) => tracing::debug!("Connection Open!"),
|
|
Ok(Event::Message(event)) => match event.event.as_str() {
|
|
"message" => {
|
|
let data = event.data;
|
|
let response: CompletionResponse = serde_json::from_str(&data).unwrap();
|
|
print!("{}", response.content);
|
|
io::stdout().flush().unwrap();
|
|
}
|
|
_ => {
|
|
todo! {}
|
|
}
|
|
},
|
|
Err(err) => {
|
|
tracing::debug!("Error: {}", err);
|
|
es.close();
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|