use async_trait::async_trait;
use futures::StreamExt;
use log::{info, trace};
use serde_json::Value;
use tokio::sync::mpsc;

pub mod cache;
pub mod episodic_memory;
pub mod llm_models;
pub mod local;
pub mod observability;

pub use llm_models::get_handler;

#[async_trait]
pub trait LLMProvider: Send + Sync {
    async fn generate(
        &self,
        prompt: &str,
        messages: &Value,
        model: &str,
        key: &str,
    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;

    async fn generate_stream(
        &self,
        prompt: &str,
        messages: &Value,
        tx: mpsc::Sender<String>,
        model: &str,
        key: &str,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;

    async fn cancel_job(
        &self,
        session_id: &str,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
}

#[derive(Debug)]
pub struct OpenAIClient {
    client: reqwest::Client,
    base_url: String,
}

impl OpenAIClient {
    pub fn new(_api_key: String, base_url: Option<String>) -> Self {
        Self {
            client: reqwest::Client::new(),
            base_url: base_url.unwrap_or_else(|| "https://api.openai.com".to_string()),
        }
    }

    /// Builds an OpenAI-style `messages` array from an optional system prompt,
    /// optional context data, and a `(role, content)` chat history.
    pub fn build_messages(
        system_prompt: &str,
        context_data: &str,
        history: &[(String, String)],
    ) -> Value {
        let mut messages = Vec::new();
        if !system_prompt.is_empty() {
            messages.push(serde_json::json!({ "role": "system", "content": system_prompt }));
        }
        if !context_data.is_empty() {
            messages.push(serde_json::json!({ "role": "system", "content": context_data }));
        }
        for (role, content) in history {
            messages.push(serde_json::json!({ "role": role, "content": content }));
        }
        Value::Array(messages)
    }
}

#[async_trait]
impl LLMProvider for OpenAIClient {
    async fn generate(
        &self,
        prompt: &str,
        messages: &Value,
        model: &str,
        key: &str,
    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
        // Fall back to a single user message when no usable history is supplied.
        let default_messages = serde_json::json!([{ "role": "user", "content": prompt }]);
        let messages = if messages.as_array().is_some_and(|m| !m.is_empty()) {
            messages
        } else {
            &default_messages
        };

        let response = self
            .client
            .post(format!("{}/v1/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", key))
            .json(&serde_json::json!({
                "model": model,
                "messages": messages
            }))
            .send()
            .await?;

        let result: Value = response.json().await?;
        let raw_content = result["choices"][0]["message"]["content"]
            .as_str()
            .unwrap_or("");

        // Apply the model-specific post-processing handler to the raw completion.
        let handler = get_handler(model);
        Ok(handler.process_content(raw_content))
    }

    async fn generate_stream(
        &self,
        prompt: &str,
        messages: &Value,
        tx: mpsc::Sender<String>,
        model: &str,
        key: &str,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let default_messages = serde_json::json!([{ "role": "user", "content": prompt }]);
        let messages = if messages.as_array().is_some_and(|m| !m.is_empty()) {
            info!("Using provided messages: {:?}", messages);
            messages
        } else {
            &default_messages
        };

        let response = self
            .client
            .post(format!("{}/v1/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", key))
            .json(&serde_json::json!({
                "model": model,
                "messages": messages,
                "stream": true
            }))
            .send()
            .await?;

        let status = response.status();
        if status != reqwest::StatusCode::OK {
            let error_text = response.text().await.unwrap_or_default();
            trace!("LLM generate_stream error: {}", error_text);
            return Err(format!("LLM request failed with status: {}", status).into());
        }

        let handler = get_handler(model);
        let mut stream = response.bytes_stream();
        let mut buffer = String::new();
        while let Some(chunk) = stream.next().await {
            let chunk = chunk?;
            let chunk_str = String::from_utf8_lossy(&chunk);
            // NOTE: this line-based parse assumes each SSE `data:` line arrives
            // within a single chunk; an event split across chunk boundaries is dropped.
            for line in chunk_str.lines() {
                if line.starts_with("data: ") && !line.contains("[DONE]") {
                    if let Ok(data) = serde_json::from_str::<Value>(&line[6..]) {
                        if let Some(content) = data["choices"][0]["delta"]["content"].as_str() {
                            buffer.push_str(content);
                            let processed = handler.process_content(content);
                            if !processed.is_empty() {
                                // A closed receiver just ends delivery; it is not an error.
                                let _ = tx.send(processed).await;
                            }
                        }
                    }
                }
            }
        }
        Ok(())
    }

    async fn cancel_job(
        &self,
        _session_id: &str,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        // The OpenAI chat completions API exposes no cancellation endpoint,
        // so this is a no-op for this provider.
        Ok(())
    }
}

pub fn start_llm_services(state: &std::sync::Arc<crate::AppState>) {
    // `crate::AppState` is assumed here as the shared application state type.
    episodic_memory::start_episodic_memory_scheduler(std::sync::Arc::clone(state));
    info!("LLM services started (episodic memory scheduler)");
}
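
// A minimal usage sketch, not part of the module proper: it shows how
// `build_messages` and `generate_stream` fit together, assuming tokio's test
// macro is available and `OPENAI_API_KEY` is set in the environment. The model
// name "gpt-4o-mini" and the channel capacity of 32 are illustrative choices,
// not values prescribed by this crate.
#[cfg(test)]
mod usage_sketch {
    use super::*;

    #[tokio::test]
    #[ignore = "requires network access and a valid API key"]
    async fn streams_tokens() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let key = std::env::var("OPENAI_API_KEY")?;
        let client = OpenAIClient::new(key.clone(), None);
        let messages = OpenAIClient::build_messages(
            "You are a terse assistant.",
            "",
            &[("user".to_string(), "Say hello.".to_string())],
        );

        // Stream tokens over an mpsc channel; the sender is dropped when the
        // request future finishes, which closes the receiver loop below.
        let (tx, mut rx) = mpsc::channel::<String>(32);
        let request = client.generate_stream("Say hello.", &messages, tx, "gpt-4o-mini", &key);
        let printer = async {
            while let Some(token) = rx.recv().await {
                print!("{token}");
            }
        };
        let (result, _) = tokio::join!(request, printer);
        result
    }
}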