2025-10-07 07:16:03 -03:00
|
|
|
use async_trait::async_trait;
|
|
|
|
|
use futures::StreamExt;
|
2025-11-11 22:31:19 -03:00
|
|
|
use log::info;
|
2025-10-07 07:16:03 -03:00
|
|
|
use serde_json::Value;
|
|
|
|
|
use tokio::sync::mpsc;
|
2025-11-02 18:36:21 -03:00
|
|
|
pub mod local;
|
2025-10-07 07:16:03 -03:00
|
|
|
#[async_trait]
|
|
|
|
|
pub trait LLMProvider: Send + Sync {
|
|
|
|
|
async fn generate(
|
|
|
|
|
&self,
|
|
|
|
|
prompt: &str,
|
|
|
|
|
config: &Value,
|
|
|
|
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
|
|
|
|
|
async fn generate_stream(
|
|
|
|
|
&self,
|
|
|
|
|
prompt: &str,
|
|
|
|
|
config: &Value,
|
|
|
|
|
tx: mpsc::Sender<String>,
|
|
|
|
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
|
2025-11-06 17:07:12 -03:00
|
|
|
async fn summarize(
|
|
|
|
|
&self,
|
|
|
|
|
text: &str,
|
|
|
|
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
|
|
|
|
let prompt = format!("Summarize the following conversation while preserving key details:\n\n{}", text);
|
|
|
|
|
self.generate(&prompt, &serde_json::json!({"max_tokens": 500}))
|
|
|
|
|
.await
|
|
|
|
|
}
|
2025-11-02 15:13:47 -03:00
|
|
|
async fn cancel_job(
|
|
|
|
|
&self,
|
|
|
|
|
session_id: &str,
|
|
|
|
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
|
2025-10-07 07:16:03 -03:00
|
|
|
}
|
|
|
|
|
pub struct OpenAIClient {
|
2025-10-11 12:29:03 -03:00
|
|
|
client: reqwest::Client,
|
|
|
|
|
api_key: String,
|
|
|
|
|
base_url: String,
|
2025-10-07 07:16:03 -03:00
|
|
|
}
|
|
|
|
|
#[async_trait]
|
|
|
|
|
impl LLMProvider for OpenAIClient {
|
|
|
|
|
async fn generate(
|
|
|
|
|
&self,
|
|
|
|
|
prompt: &str,
|
2025-11-11 22:31:19 -03:00
|
|
|
messages: &Value,
|
2025-10-07 07:16:03 -03:00
|
|
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
2025-11-11 22:31:19 -03:00
|
|
|
let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
|
2025-10-11 12:29:03 -03:00
|
|
|
let response = self
|
2025-10-07 07:16:03 -03:00
|
|
|
.client
|
2025-11-11 18:32:52 -03:00
|
|
|
.post(&format!("{}/v1/chat/completions/", self.base_url))
|
2025-10-11 12:29:03 -03:00
|
|
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
|
|
|
.json(&serde_json::json!({
|
|
|
|
|
"model": "gpt-3.5-turbo",
|
2025-11-11 22:31:19 -03:00
|
|
|
"messages": if messages.is_array() && !messages.as_array().unwrap().is_empty() {
|
|
|
|
|
messages
|
|
|
|
|
} else {
|
|
|
|
|
&default_messages
|
|
|
|
|
}
|
2025-10-11 12:29:03 -03:00
|
|
|
}))
|
|
|
|
|
.send()
|
|
|
|
|
.await?;
|
|
|
|
|
let result: Value = response.json().await?;
|
2025-10-15 22:24:04 -03:00
|
|
|
let raw_content = result["choices"][0]["message"]["content"]
|
2025-10-11 12:29:03 -03:00
|
|
|
.as_str()
|
2025-10-15 22:24:04 -03:00
|
|
|
.unwrap_or("");
|
|
|
|
|
let end_token = "final<|message|>";
|
|
|
|
|
let content = if let Some(pos) = raw_content.find(end_token) {
|
|
|
|
|
raw_content[(pos + end_token.len())..].to_string()
|
|
|
|
|
} else {
|
|
|
|
|
raw_content.to_string()
|
|
|
|
|
};
|
2025-10-11 12:29:03 -03:00
|
|
|
Ok(content)
|
2025-10-07 07:16:03 -03:00
|
|
|
}
|
|
|
|
|
async fn generate_stream(
|
|
|
|
|
&self,
|
|
|
|
|
prompt: &str,
|
2025-11-11 22:31:19 -03:00
|
|
|
messages: &Value,
|
2025-10-07 07:16:03 -03:00
|
|
|
tx: mpsc::Sender<String>,
|
|
|
|
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
2025-11-11 22:31:19 -03:00
|
|
|
let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
|
2025-10-11 12:29:03 -03:00
|
|
|
let response = self
|
2025-10-07 07:16:03 -03:00
|
|
|
.client
|
2025-11-04 23:11:33 -03:00
|
|
|
.post(&format!("{}/v1/chat/completions", self.base_url))
|
2025-10-11 12:29:03 -03:00
|
|
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
|
|
|
.json(&serde_json::json!({
|
|
|
|
|
"model": "gpt-3.5-turbo",
|
2025-11-11 22:31:19 -03:00
|
|
|
"messages": if messages.is_array() && !messages.as_array().unwrap().is_empty() {
|
|
|
|
|
info!("Using provided messages: {:?}", messages);
|
|
|
|
|
messages
|
|
|
|
|
} else {
|
|
|
|
|
&default_messages
|
|
|
|
|
},
|
2025-10-11 12:29:03 -03:00
|
|
|
"stream": true
|
|
|
|
|
}))
|
|
|
|
|
.send()
|
|
|
|
|
.await?;
|
|
|
|
|
let mut stream = response.bytes_stream();
|
|
|
|
|
let mut buffer = String::new();
|
|
|
|
|
while let Some(chunk) = stream.next().await {
|
|
|
|
|
let chunk = chunk?;
|
|
|
|
|
let chunk_str = String::from_utf8_lossy(&chunk);
|
|
|
|
|
for line in chunk_str.lines() {
|
|
|
|
|
if line.starts_with("data: ") && !line.contains("[DONE]") {
|
|
|
|
|
if let Ok(data) = serde_json::from_str::<Value>(&line[6..]) {
|
|
|
|
|
if let Some(content) = data["choices"][0]["delta"]["content"].as_str() {
|
|
|
|
|
buffer.push_str(content);
|
|
|
|
|
let _ = tx.send(content.to_string()).await;
|
|
|
|
|
}
|
2025-10-07 07:16:03 -03:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
2025-11-02 15:13:47 -03:00
|
|
|
async fn cancel_job(
|
|
|
|
|
&self,
|
|
|
|
|
_session_id: &str,
|
|
|
|
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
2025-10-07 07:16:03 -03:00
|
|
|
}
|
2025-11-11 21:13:12 -03:00
|
|
|
|
|
|
|
|
impl OpenAIClient {
|
|
|
|
|
pub fn new(api_key: String, base_url: Option<String>) -> Self {
|
|
|
|
|
Self {
|
|
|
|
|
client: reqwest::Client::new(),
|
|
|
|
|
api_key,
|
|
|
|
|
base_url: base_url.unwrap()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2025-11-11 22:31:19 -03:00
|
|
|
pub fn build_messages(system_prompt: &str, context_data: &str, history: &[(String, String)]) -> Value {
|
2025-11-11 21:13:12 -03:00
|
|
|
let mut messages = Vec::new();
|
2025-11-11 22:31:19 -03:00
|
|
|
if !system_prompt.is_empty() {
|
|
|
|
|
messages.push(serde_json::json!({
|
|
|
|
|
"role": "system",
|
|
|
|
|
"content": system_prompt
|
|
|
|
|
}));
|
2025-11-11 21:13:12 -03:00
|
|
|
}
|
2025-11-11 22:31:19 -03:00
|
|
|
if !context_data.is_empty() {
|
|
|
|
|
messages.push(serde_json::json!({
|
|
|
|
|
"role": "system",
|
|
|
|
|
"content": context_data
|
|
|
|
|
}));
|
|
|
|
|
}
|
|
|
|
|
for (role, content) in history {
|
|
|
|
|
messages.push(serde_json::json!({
|
|
|
|
|
"role": role,
|
|
|
|
|
"content": content
|
|
|
|
|
}));
|
2025-11-11 21:13:12 -03:00
|
|
|
}
|
2025-11-11 22:31:19 -03:00
|
|
|
serde_json::Value::Array(messages)
|
2025-11-11 21:13:12 -03:00
|
|
|
}
|
|
|
|
|
}
|