// botserver/src/llm/mod.rs

use async_trait::async_trait;
use futures::StreamExt;
use log::{info, trace};
use serde_json::Value;
use tokio::sync::mpsc;
pub mod local;
#[async_trait]
pub trait LLMProvider: Send + Sync {
async fn generate(
&self,
prompt: &str,
config: &Value,
model: &str,
2025-11-22 12:26:16 -03:00
key: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
async fn generate_stream(
&self,
prompt: &str,
config: &Value,
tx: mpsc::Sender<String>,
model: &str,
2025-11-22 12:26:16 -03:00
key: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
async fn cancel_job(
&self,
session_id: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
}
2025-11-22 12:26:16 -03:00
#[derive(Debug)]
pub struct OpenAIClient {
2025-10-11 12:29:03 -03:00
client: reqwest::Client,
base_url: String,
}
#[async_trait]
impl LLMProvider for OpenAIClient {
async fn generate(
&self,
prompt: &str,
messages: &Value,
model: &str,
2025-11-22 12:26:16 -03:00
key: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
2025-11-22 12:26:16 -03:00
let response = self
.client
.post(&format!("{}/v1/chat/completions", self.base_url))
.header("Authorization", format!("Bearer {}", key))
2025-10-11 12:29:03 -03:00
.json(&serde_json::json!({
"model": model,
"messages": if messages.is_array() && !messages.as_array().unwrap().is_empty() {
messages
} else {
&default_messages
}
2025-10-11 12:29:03 -03:00
}))
.send()
.await?;
let result: Value = response.json().await?;
let raw_content = result["choices"][0]["message"]["content"]
2025-10-11 12:29:03 -03:00
.as_str()
.unwrap_or("");
let end_token = "final<|message|>";
let content = if let Some(pos) = raw_content.find(end_token) {
raw_content[(pos + end_token.len())..].to_string()
} else {
raw_content.to_string()
};
2025-10-11 12:29:03 -03:00
Ok(content)
}
async fn generate_stream(
&self,
prompt: &str,
messages: &Value,
tx: mpsc::Sender<String>,
model: &str,
2025-11-22 12:26:16 -03:00
key: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
2025-10-11 12:29:03 -03:00
let response = self
.client
.post(&format!("{}/v1/chat/completions", self.base_url))
.header("Authorization", format!("Bearer {}", key))
2025-10-11 12:29:03 -03:00
.json(&serde_json::json!({
"model": model,
"messages": if messages.is_array() && !messages.as_array().unwrap().is_empty() {
info!("Using provided messages: {:?}", messages);
messages
} else {
&default_messages
},
2025-10-11 12:29:03 -03:00
"stream": true
}))
.send()
.await?;
let status = response.status();
if status != reqwest::StatusCode::OK {
let error_text = response.text().await.unwrap_or_default();
trace!("LLM generate_stream error: {}", error_text);
return Err(format!("LLM request failed with status: {}", status).into());
}
2025-10-11 12:29:03 -03:00
let mut stream = response.bytes_stream();
let mut buffer = String::new();
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
let chunk_str = String::from_utf8_lossy(&chunk);
for line in chunk_str.lines() {
if line.starts_with("data: ") && !line.contains("[DONE]") {
if let Ok(data) = serde_json::from_str::<Value>(&line[6..]) {
if let Some(content) = data["choices"][0]["delta"]["content"].as_str() {
buffer.push_str(content);
let _ = tx.send(content.to_string()).await;
}
}
}
}
}
Ok(())
}
async fn cancel_job(
&self,
_session_id: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
Ok(())
}
}
impl OpenAIClient {
pub fn new(_api_key: String, base_url: Option<String>) -> Self {
Self {
client: reqwest::Client::new(),
2025-11-22 12:26:16 -03:00
base_url: base_url.unwrap(),
}
}
2025-11-22 12:26:16 -03:00
pub fn build_messages(
system_prompt: &str,
context_data: &str,
history: &[(String, String)],
) -> Value {
let mut messages = Vec::new();
if !system_prompt.is_empty() {
messages.push(serde_json::json!({
"role": "system",
"content": system_prompt
}));
}
if !context_data.is_empty() {
messages.push(serde_json::json!({
2025-11-22 12:26:16 -03:00
"role": "system",
"content": context_data
}));
}
for (role, content) in history {
messages.push(serde_json::json!({
"role": role,
"content": content
}));
}
serde_json::Value::Array(messages)
}
}