diff --git a/.gitignore b/.gitignore index 5a1fa3715..c38300bc6 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ target .env *.env work +*.txt diff --git a/Cargo.lock b/Cargo.lock index ae474e071..a700030f6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -909,6 +909,7 @@ dependencies = [ "sqlx", "tempfile", "thirtyfour", + "time", "tokio", "tokio-stream", "tracing", diff --git a/Cargo.toml b/Cargo.toml index 21ff6ee0c..4669b15bd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,3 +59,4 @@ tracing-subscriber = { version = "0.3", features = ["fmt"] } urlencoding = "2.1" uuid = { version = "1.11", features = ["serde", "v4"] } zip = "2.2" +time = "0.3.44" diff --git a/scripts/dev/llm_context.txt b/scripts/dev/llm_context.txt index cd5d1b676..868e53ee9 100644 --- a/scripts/dev/llm_context.txt +++ b/scripts/dev/llm_context.txt @@ -109,7 +109,7 @@ EOF use async_trait::async_trait; use chrono::Utc; -use livekit::{DataPacketKind, Room, RoomOptions}; +use livekit::prelude::*; use log::info; use std::collections::HashMap; use std::sync::Arc; @@ -202,7 +202,7 @@ impl VoiceAdapter { tokio::spawn(async move { while let Some(event) = events.recv().await { match event { - livekit::prelude::RoomEvent::DataReceived(data_packet) => { + RoomEvent::DataReceived(data_packet) => { if let Ok(message) = serde_json::from_slice::(&data_packet.data) { @@ -225,11 +225,7 @@ impl VoiceAdapter { } } } - livekit::prelude::RoomEvent::TrackSubscribed( - track, - publication, - participant, - ) => { + RoomEvent::TrackSubscribed(track, publication, participant) => { info!("Voice track subscribed from {}", participant.identity()); } _ => {} @@ -246,7 +242,7 @@ impl VoiceAdapter { session_id: &str, ) -> Result<(), Box> { if let Some(room) = self.rooms.lock().await.remove(session_id) { - room.disconnect(); + room.disconnect().await; } Ok(()) } @@ -267,11 +263,13 @@ impl VoiceAdapter { "timestamp": Utc::now() }); - room.local_participant().publish_data( - serde_json::to_vec(&voice_response)?, - DataPacketKind::Reliable, - &[], - )?; + room.local_participant() + .publish_data( + serde_json::to_vec(&voice_response)?, + DataPacketKind::Reliable, + &[], + ) + .await?; } Ok(()) } @@ -294,7 +292,6 @@ use serde::{Deserialize, Serialize}; use std::env; use tokio::time::{sleep, Duration}; -// OpenAI-compatible request/response structures #[derive(Debug, Serialize, Deserialize)] struct ChatMessage { role: String, @@ -323,7 +320,6 @@ struct Choice { finish_reason: String, } -// Llama.cpp server request/response structures #[derive(Debug, Serialize, Deserialize)] struct LlamaCppRequest { prompt: String, @@ -341,8 +337,7 @@ struct LlamaCppResponse { generation_settings: Option, } -pub async fn ensure_llama_servers_running() -> Result<(), Box> -{ +pub async fn ensure_llama_servers_running() -> Result<(), Box> { let llm_local = env::var("LLM_LOCAL").unwrap_or_else(|_| "false".to_string()); if llm_local.to_lowercase() != "true" { @@ -350,7 +345,6 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box Result<(), Box Result<(), Box Result<(), Box bool { } } -// Convert OpenAI chat messages to a single prompt fn messages_to_prompt(messages: &[ChatMessage]) -> String { let mut prompt = String::new(); @@ -538,32 +525,28 @@ fn messages_to_prompt(messages: &[ChatMessage]) -> String { prompt } -// Proxy endpoint #[post("/local/v1/chat/completions")] pub async fn chat_completions_local( req_body: web::Json, _req: HttpRequest, ) -> Result { - dotenv().ok().unwrap(); + dotenv().ok(); - // Get llama.cpp server URL let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string()); - // Convert OpenAI format to llama.cpp format let prompt = messages_to_prompt(&req_body.messages); let llama_request = LlamaCppRequest { prompt, - n_predict: Some(500), // Adjust as needed + n_predict: Some(500), temperature: Some(0.7), top_k: Some(40), top_p: Some(0.9), stream: req_body.stream, }; - // Send request to llama.cpp server let client = Client::builder() - .timeout(Duration::from_secs(120)) // 2 minute timeout + .timeout(Duration::from_secs(120)) .build() .map_err(|e| { error!("Error creating HTTP client: {}", e); @@ -589,7 +572,6 @@ pub async fn chat_completions_local( actix_web::error::ErrorInternalServerError("Failed to parse llama.cpp response") })?; - // Convert llama.cpp response to OpenAI format let openai_response = ChatCompletionResponse { id: format!("chatcmpl-{}", uuid::Uuid::new_v4()), object: "chat.completion".to_string(), @@ -632,7 +614,6 @@ pub async fn chat_completions_local( } } -// OpenAI Embedding Request - Modified to handle both string and array inputs #[derive(Debug, Deserialize)] pub struct EmbeddingRequest { #[serde(deserialize_with = "deserialize_input")] @@ -642,7 +623,6 @@ pub struct EmbeddingRequest { pub _encoding_format: Option, } -// Custom deserializer to handle both string and array inputs fn deserialize_input<'de, D>(deserializer: D) -> Result, D::Error> where D: serde::Deserializer<'de>, @@ -688,7 +668,6 @@ where deserializer.deserialize_any(InputVisitor) } -// OpenAI Embedding Response #[derive(Debug, Serialize)] pub struct EmbeddingResponse { pub object: String, @@ -710,20 +689,17 @@ pub struct Usage { pub total_tokens: u32, } -// Llama.cpp Embedding Request #[derive(Debug, Serialize)] struct LlamaCppEmbeddingRequest { pub content: String, } -// FIXED: Handle the stupid nested array format #[derive(Debug, Deserialize)] struct LlamaCppEmbeddingResponseItem { pub index: usize, - pub embedding: Vec>, // This is the up part - embedding is an array of arrays + pub embedding: Vec>, } -// Proxy endpoint for embeddings #[post("/v1/embeddings")] pub async fn embeddings_local( req_body: web::Json, @@ -731,7 +707,6 @@ pub async fn embeddings_local( ) -> Result { dotenv().ok(); - // Get llama.cpp server URL let llama_url = env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string()); @@ -743,7 +718,6 @@ pub async fn embeddings_local( actix_web::error::ErrorInternalServerError("Failed to create HTTP client") })?; - // Process each input text and get embeddings let mut embeddings_data = Vec::new(); let mut total_tokens = 0; @@ -768,13 +742,11 @@ pub async fn embeddings_local( let status = response.status(); if status.is_success() { - // First, get the raw response text for debugging let raw_response = response.text().await.map_err(|e| { error!("Error reading response text: {}", e); actix_web::error::ErrorInternalServerError("Failed to read response") })?; - // Parse the response as a vector of items with nested arrays let llama_response: Vec = serde_json::from_str(&raw_response).map_err(|e| { error!("Error parsing llama.cpp embedding response: {}", e); @@ -784,17 +756,13 @@ pub async fn embeddings_local( ) })?; - // Extract the embedding from the nested array bullshit if let Some(item) = llama_response.get(0) { - // The embedding field contains Vec>, so we need to flatten it - // If it's [[0.1, 0.2, 0.3]], we want [0.1, 0.2, 0.3] let flattened_embedding = if !item.embedding.is_empty() { - item.embedding[0].clone() // Take the first (and probably only) inner array + item.embedding[0].clone() } else { - vec![] // Empty if no embedding data + vec![] }; - // Estimate token count let estimated_tokens = (input_text.len() as f32 / 4.0).ceil() as u32; total_tokens += estimated_tokens; @@ -832,7 +800,6 @@ pub async fn embeddings_local( } } - // Build OpenAI-compatible response let openai_response = EmbeddingResponse { object: "list".to_string(), data: embeddings_data, @@ -846,7 +813,6 @@ pub async fn embeddings_local( Ok(HttpResponse::Ok().json(openai_response)) } -// Health check endpoint #[actix_web::get("/health")] pub async fn health() -> Result { let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string()); @@ -1008,6 +974,8 @@ pub mod llm_generic; pub mod llm_local; pub mod llm_provider; +pub use llm_provider::*; + use async_trait::async_trait; use futures::StreamExt; use langchain_rust::{ @@ -1036,7 +1004,6 @@ pub trait LLMProvider: Send + Sync { tx: mpsc::Sender, ) -> Result<(), Box>; - // Add tool calling capability using LangChain tools async fn generate_with_tools( &self, prompt: &str, @@ -1116,7 +1083,6 @@ impl LLMProvider for OpenAIClient { _session_id: &str, _user_id: &str, ) -> Result> { - // Enhanced prompt with tool information let tools_info = if available_tools.is_empty() { String::new() } else { @@ -1279,133 +1245,14 @@ impl LLMProvider for MockLLMProvider { } } -use log::info; - use actix_web::{post, web, HttpRequest, HttpResponse, Result}; use dotenv::dotenv; -use regex::Regex; -use reqwest::Client; -use serde::{Deserialize, Serialize}; -use std::env; - -// OpenAI-compatible request/response structures -#[derive(Debug, Serialize, Deserialize)] -struct ChatMessage { - role: String, - content: String, -} - -#[derive(Debug, Serialize, Deserialize)] -struct ChatCompletionRequest { - model: String, - messages: Vec, - stream: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -struct ChatCompletionResponse { - id: String, - object: String, - created: u64, - model: String, - choices: Vec, -} - -#[derive(Debug, Serialize, Deserialize)] -struct Choice { - message: ChatMessage, - finish_reason: String, -} - -#[post("/azure/v1/chat/completions")] -async fn chat_completions(body: web::Bytes, _req: HttpRequest) -> Result { - // Always log raw POST data - if let Ok(body_str) = std::str::from_utf8(&body) { - info!("POST Data: {}", body_str); - } else { - info!("POST Data (binary): {:?}", body); - } - - dotenv().ok(); - - // Environment variables - let azure_endpoint = env::var("AI_ENDPOINT") - .map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?; - let azure_key = env::var("AI_KEY") - .map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?; - let deployment_name = env::var("AI_LLM_MODEL") - .map_err(|_| actix_web::error::ErrorInternalServerError("AI_LLM_MODEL not set."))?; - - // Construct Azure OpenAI URL - let url = format!( - "{}/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview", - azure_endpoint, deployment_name - ); - - // Forward headers - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert( - "api-key", - reqwest::header::HeaderValue::from_str(&azure_key) - .map_err(|_| actix_web::error::ErrorInternalServerError("Invalid Azure key"))?, - ); - headers.insert( - "Content-Type", - reqwest::header::HeaderValue::from_static("application/json"), - ); - - let body_str = std::str::from_utf8(&body).unwrap_or(""); - info!("Original POST Data: {}", body_str); - - // Remove the problematic params - let re = - Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls)"\s*:\s*[^,}]*"#).unwrap(); - let cleaned = re.replace_all(body_str, ""); - let cleaned_body = web::Bytes::from(cleaned.to_string()); - - info!("Cleaned POST Data: {}", cleaned); - - // Send request to Azure - let client = Client::new(); - let response = client - .post(&url) - .headers(headers) - .body(cleaned_body) - .send() - .await - .map_err(actix_web::error::ErrorInternalServerError)?; - - // Handle response based on status - let status = response.status(); - let raw_response = response - .text() - .await - .map_err(actix_web::error::ErrorInternalServerError)?; - - // Log the raw response - info!("Raw Azure response: {}", raw_response); - - if status.is_success() { - Ok(HttpResponse::Ok().body(raw_response)) - } else { - // Handle error responses properly - let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16()) - .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); - - Ok(HttpResponse::build(actix_status).body(raw_response)) - } -} - use log::{error, info}; - -use actix_web::{post, web, HttpRequest, HttpResponse, Result}; -use dotenv::dotenv; use regex::Regex; use reqwest::Client; use serde::{Deserialize, Serialize}; use std::env; -// OpenAI-compatible request/response structures #[derive(Debug, Serialize, Deserialize)] struct ChatMessage { role: String, @@ -1435,20 +1282,17 @@ struct Choice { } fn clean_request_body(body: &str) -> String { - // Remove problematic parameters that might not be supported by all providers let re = Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls|top_p|frequency_penalty|presence_penalty)"\s*:\s*[^,}]*"#).unwrap(); re.replace_all(body, "").to_string() } #[post("/v1/chat/completions")] pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result { - // Log raw POST data let body_str = std::str::from_utf8(&body).unwrap_or_default(); info!("Original POST Data: {}", body_str); dotenv().ok(); - // Get environment variables let api_key = env::var("AI_KEY") .map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?; let model = env::var("AI_LLM_MODEL") @@ -1456,11 +1300,9 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re let endpoint = env::var("AI_ENDPOINT") .map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?; - // Parse and modify the request body let mut json_value: serde_json::Value = serde_json::from_str(body_str) .map_err(|_| actix_web::error::ErrorInternalServerError("Failed to parse JSON"))?; - // Add model parameter if let Some(obj) = json_value.as_object_mut() { obj.insert("model".to_string(), serde_json::Value::String(model)); } @@ -1470,7 +1312,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re info!("Modified POST Data: {}", modified_body_str); - // Set up headers let mut headers = reqwest::header::HeaderMap::new(); headers.insert( "Authorization", @@ -1482,7 +1323,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re reqwest::header::HeaderValue::from_static("application/json"), ); - // Send request to the AI provider let client = Client::new(); let response = client .post(&endpoint) @@ -1492,7 +1332,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re .await .map_err(actix_web::error::ErrorInternalServerError)?; - // Handle response let status = response.status(); let raw_response = response .text() @@ -1502,7 +1341,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re info!("Provider response status: {}", status); info!("Provider response body: {}", raw_response); - // Convert response to OpenAI format if successful if status.is_success() { match convert_to_openai_format(&raw_response) { Ok(openai_response) => Ok(HttpResponse::Ok() @@ -1510,14 +1348,12 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re .body(openai_response)), Err(e) => { error!("Failed to convert response format: {}", e); - // Return the original response if conversion fails Ok(HttpResponse::Ok() .content_type("application/json") .body(raw_response)) } } } else { - // Return error as-is let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16()) .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); @@ -1527,7 +1363,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re } } -/// Converts provider response to OpenAI-compatible format fn convert_to_openai_format(provider_response: &str) -> Result> { #[derive(serde::Deserialize)] struct ProviderChoice { @@ -1589,10 +1424,8 @@ fn convert_to_openai_format(provider_response: &str) -> Result Result Result>>, // phone -> session_id + sessions: Arc>>, } impl WhatsAppAdapter { @@ -1815,7 +1645,7 @@ impl WhatsAppAdapter { } #[async_trait] -impl super::channels::ChannelAdapter for WhatsAppAdapter { +impl crate::channels::ChannelAdapter for WhatsAppAdapter { async fn send_message(&self, response: BotResponse) -> Result<(), Box> { info!("Sending WhatsApp response to: {}", response.user_id); self.send_whatsapp_message(&response.user_id, &response.content).await @@ -2545,12 +2375,10 @@ use uuid::Uuid; use crate::{ auth::AuthService, channels::ChannelAdapter, - chart::ChartGenerator, llm::LLMProvider, session::SessionManager, shared::{BotResponse, UserMessage, UserSession}, tools::ToolManager, - whatsapp::WhatsAppAdapter, }; pub struct BotOrchestrator { @@ -2560,7 +2388,6 @@ pub struct BotOrchestrator { auth_service: AuthService, channels: HashMap>, response_channels: Arc>>>, - chart_generator: Option>, vector_store: Option>, sql_chain: Option>, } @@ -2571,7 +2398,6 @@ impl BotOrchestrator { tool_manager: ToolManager, llm_provider: Arc, auth_service: AuthService, - chart_generator: Option>, vector_store: Option>, sql_chain: Option>, ) -> Self { @@ -2582,7 +2408,6 @@ impl BotOrchestrator { auth_service, channels: HashMap::new(), response_channels: Arc::new(Mutex::new(HashMap::new())), - chart_generator, vector_store, sql_chain, } @@ -2660,7 +2485,6 @@ impl BotOrchestrator { let response_content = match session.answer_mode.as_str() { "document" => self.document_mode_handler(&message, &session).await?, - "chart" => self.chart_mode_handler(&message, &session).await?, "database" => self.database_mode_handler(&message, &session).await?, "tool" => self.tool_mode_handler(&message, &session).await?, _ => self.direct_mode_handler(&message, &session).await?, @@ -2682,7 +2506,7 @@ impl BotOrchestrator { }; if let Some(adapter) = self.channels.get(&message.channel) { - adapter.send_message(bot_response).await; + adapter.send_message(bot_response).await?; } Ok(()) @@ -2691,7 +2515,7 @@ impl BotOrchestrator { async fn document_mode_handler( &self, message: &UserMessage, - session: &UserSession, + _session: &UserSession, ) -> Result> { if let Some(vector_store) = &self.vector_store { let similar_docs = vector_store @@ -2714,36 +2538,7 @@ impl BotOrchestrator { .generate(&enhanced_prompt, &serde_json::Value::Null) .await } else { - self.direct_mode_handler(message, session).await - } - } - - async fn chart_mode_handler( - &self, - message: &UserMessage, - session: &UserSession, - ) -> Result> { - if let Some(chart_generator) = &self.chart_generator { - let chart_response = chart_generator - .generate_chart(&message.content, "bar") - .await?; - - self.session_manager - .save_message( - session.id, - session.user_id, - "system", - &format!("Generated chart for query: {}", message.content), - "chart", - ) - .await?; - - Ok(format!( - "Chart generated for your query. Data retrieved: {}", - chart_response.sql_query - )) - } else { - self.document_mode_handler(message, session).await + Ok("Document mode not available".to_string()) } } diff --git a/src/auth/mod.rs b/src/auth/mod.rs index 82aa47503..ee4ea830b 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -3,7 +3,7 @@ use argon2::{ Argon2, }; use redis::Client; -use sqlx::{PgPool, Row}; // <-- required for .get() +use sqlx::{PgPool, Row}; use std::sync::Arc; use uuid::Uuid; @@ -13,7 +13,6 @@ pub struct AuthService { } impl AuthService { - #[allow(clippy::new_without_default)] pub fn new(pool: PgPool, redis: Option>) -> Self { Self { pool, redis } } @@ -22,7 +21,7 @@ impl AuthService { &self, username: &str, password: &str, - ) -> Result, Box> { + ) -> Result, Box> { let user = sqlx::query( "SELECT id, password_hash FROM users WHERE username = $1 AND is_active = true", ) @@ -52,7 +51,7 @@ impl AuthService { username: &str, email: &str, password: &str, - ) -> Result> { + ) -> Result> { let salt = SaltString::generate(&mut OsRng); let argon2 = Argon2::default(); let password_hash = match argon2.hash_password(password.as_bytes(), &salt) { @@ -80,7 +79,7 @@ impl AuthService { pub async fn delete_user_cache( &self, username: &str, - ) -> Result<(), Box> { + ) -> Result<(), Box> { if let Some(redis_client) = &self.redis { let mut conn = redis_client.get_multiplexed_async_connection().await?; let cache_key = format!("auth:user:{}", username); @@ -94,7 +93,7 @@ impl AuthService { &self, user_id: Uuid, new_password: &str, - ) -> Result<(), Box> { + ) -> Result<(), Box> { let salt = SaltString::generate(&mut OsRng); let argon2 = Argon2::default(); let password_hash = match argon2.hash_password(new_password.as_bytes(), &salt) { diff --git a/src/automation/mod.rs b/src/automation/mod.rs index cff2b8ce5..75e2c65ac 100644 --- a/src/automation/mod.rs +++ b/src/automation/mod.rs @@ -121,9 +121,10 @@ impl AutomationService { async fn update_last_triggered(&self, automation_id: Uuid) { if let Some(pool) = &self.state.db { + let now = time::OffsetDateTime::now_utc(); if let Err(e) = sqlx::query!( "UPDATE public.system_automations SET last_triggered = $1 WHERE id = $2", - Utc::now(), + now, automation_id ) .execute(pool) diff --git a/src/bot/mod.rs b/src/bot/mod.rs index cbad49bf1..f4e8e7763 100644 --- a/src/bot/mod.rs +++ b/src/bot/mod.rs @@ -2,13 +2,7 @@ use actix_web::{web, HttpRequest, HttpResponse, Result}; use actix_ws::Message as WsMessage; use chrono::Utc; use langchain_rust::{ - chain::{Chain, LLMChain}, - llm::openai::OpenAI, memory::SimpleMemory, - prompt_args, - tools::{postgres::PostgreSQLEngine, SQLDatabaseBuilder}, - vectorstore::qdrant::Qdrant as LangChainQdrant, - vectorstore::{VecStoreOptions, VectorStore}, }; use log::info; use serde_json; @@ -21,12 +15,10 @@ use uuid::Uuid; use crate::{ auth::AuthService, channels::ChannelAdapter, - chart::ChartGenerator, llm::LLMProvider, session::SessionManager, shared::{BotResponse, UserMessage, UserSession}, tools::ToolManager, - whatsapp::WhatsAppAdapter, }; pub struct BotOrchestrator { @@ -36,9 +28,6 @@ pub struct BotOrchestrator { auth_service: AuthService, channels: HashMap>, response_channels: Arc>>>, - chart_generator: Option>, - vector_store: Option>, - sql_chain: Option>, } impl BotOrchestrator { @@ -47,9 +36,6 @@ impl BotOrchestrator { tool_manager: ToolManager, llm_provider: Arc, auth_service: AuthService, - chart_generator: Option>, - vector_store: Option>, - sql_chain: Option>, ) -> Self { Self { session_manager, @@ -58,9 +44,6 @@ impl BotOrchestrator { auth_service, channels: HashMap::new(), response_channels: Arc::new(Mutex::new(HashMap::new())), - chart_generator, - vector_store, - sql_chain, } } @@ -134,13 +117,7 @@ impl BotOrchestrator { ) .await?; - let response_content = match session.answer_mode.as_str() { - "document" => self.document_mode_handler(&message, &session).await?, - "chart" => self.chart_mode_handler(&message, &session).await?, - "database" => self.database_mode_handler(&message, &session).await?, - "tool" => self.tool_mode_handler(&message, &session).await?, - _ => self.direct_mode_handler(&message, &session).await?, - }; + let response_content = self.direct_mode_handler(&message, &session).await?; self.session_manager .save_message(session.id, user_id, "assistant", &response_content, "text") @@ -158,148 +135,12 @@ impl BotOrchestrator { }; if let Some(adapter) = self.channels.get(&message.channel) { - adapter.send_message(bot_response).await; + adapter.send_message(bot_response).await?; } Ok(()) } - async fn document_mode_handler( - &self, - message: &UserMessage, - session: &UserSession, - ) -> Result> { - if let Some(vector_store) = &self.vector_store { - let similar_docs = vector_store - .similarity_search(&message.content, 3, &VecStoreOptions::default()) - .await?; - - let mut enhanced_prompt = format!("User question: {}\n\n", message.content); - - if !similar_docs.is_empty() { - enhanced_prompt.push_str("Relevant documents:\n"); - for (i, doc) in similar_docs.iter().enumerate() { - enhanced_prompt.push_str(&format!("[Doc {}]: {}\n", i + 1, doc.page_content)); - } - enhanced_prompt.push_str( - "\nPlease answer the user's question based on the provided documents.", - ); - } - - self.llm_provider - .generate(&enhanced_prompt, &serde_json::Value::Null) - .await - } else { - self.direct_mode_handler(message, session).await - } - } - - async fn chart_mode_handler( - &self, - message: &UserMessage, - session: &UserSession, - ) -> Result> { - if let Some(chart_generator) = &self.chart_generator { - let chart_response = chart_generator - .generate_chart(&message.content, "bar") - .await?; - - self.session_manager - .save_message( - session.id, - session.user_id, - "system", - &format!("Generated chart for query: {}", message.content), - "chart", - ) - .await?; - - Ok(format!( - "Chart generated for your query. Data retrieved: {}", - chart_response.sql_query - )) - } else { - self.document_mode_handler(message, session).await - } - } - - async fn database_mode_handler( - &self, - message: &UserMessage, - _session: &UserSession, - ) -> Result> { - if let Some(sql_chain) = &self.sql_chain { - let input_variables = prompt_args! { - "input" => message.content, - }; - - let result = sql_chain.invoke(input_variables).await?; - Ok(result.to_string()) - } else { - let db_url = std::env::var("DATABASE_URL")?; - let engine = PostgreSQLEngine::new(&db_url).await?; - let db = SQLDatabaseBuilder::new(engine).build().await?; - - let llm = OpenAI::default(); - let chain = langchain_rust::chain::SQLDatabaseChainBuilder::new() - .llm(llm) - .top_k(5) - .database(db) - .build()?; - - let input_variables = chain.prompt_builder().query(&message.content).build(); - let result = chain.invoke(input_variables).await?; - - Ok(result.to_string()) - } - } - - async fn tool_mode_handler( - &self, - message: &UserMessage, - _session: &UserSession, - ) -> Result> { - if message.content.to_lowercase().contains("calculator") { - if let Some(_adapter) = self.channels.get(&message.channel) { - let (tx, _rx) = mpsc::channel(100); - - self.register_response_channel(message.session_id.clone(), tx.clone()) - .await; - - let tool_manager = self.tool_manager.clone(); - let user_id_str = message.user_id.clone(); - let bot_id_str = message.bot_id.clone(); - let session_manager = self.session_manager.clone(); - - tokio::spawn(async move { - let _ = tool_manager - .execute_tool_with_session( - "calculator", - &user_id_str, - &bot_id_str, - session_manager, - tx, - ) - .await; - }); - } - Ok("Starting calculator tool...".to_string()) - } else { - let available_tools = self.tool_manager.list_tools(); - let tools_context = if !available_tools.is_empty() { - format!("\n\nAvailable tools: {}. If the user needs calculations, suggest using the calculator tool.", available_tools.join(", ")) - } else { - String::new() - }; - - let full_prompt = format!("{}{}", message.content, tools_context); - - self.llm_provider - .generate(&full_prompt, &serde_json::Value::Null) - .await - } - } - async fn direct_mode_handler( &self, message: &UserMessage, @@ -312,21 +153,14 @@ impl BotOrchestrator { let mut memory = SimpleMemory::new(); for (role, content) in history { - match role.as_str() { - "user" => memory.add_user_message(&content), - "assistant" => memory.add_ai_message(&content), - _ => {} - } + memory.add_message(&format!("{}: {}", role, content)); } let mut prompt = String::new(); - if let Some(chat_history) = memory.get_chat_history() { + if let Some(chat_history) = memory.get_history() { for message in chat_history { - prompt.push_str(&format!( - "{}: {}\n", - message.message_type(), - message.content() - )); + prompt.push_str(&message); + prompt.push('\n'); } } prompt.push_str(&format!("User: {}\nAssistant:", message.content)); @@ -384,21 +218,14 @@ impl BotOrchestrator { let mut memory = SimpleMemory::new(); for (role, content) in history { - match role.as_str() { - "user" => memory.add_user_message(&content), - "assistant" => memory.add_ai_message(&content), - _ => {} - } + memory.add_message(&format!("{}: {}", role, content)); } let mut prompt = String::new(); - if let Some(chat_history) = memory.get_chat_history() { + if let Some(chat_history) = memory.get_history() { for message in chat_history { - prompt.push_str(&format!( - "{}: {}\n", - message.message_type(), - message.content() - )); + prompt.push_str(&message); + prompt.push('\n'); } } prompt.push_str(&format!("User: {}\nAssistant:", message.content)); @@ -605,7 +432,7 @@ impl BotOrchestrator { async fn websocket_handler( req: HttpRequest, stream: web::Payload, - data: web::Data, + data: web::Data, ) -> Result { let (res, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?; let session_id = Uuid::new_v4().to_string(); @@ -665,7 +492,7 @@ async fn websocket_handler( #[actix_web::get("/api/whatsapp/webhook")] async fn whatsapp_webhook_verify( - data: web::Data, + data: web::Data, web::Query(params): web::Query>, ) -> Result { let mode = params.get("hub.mode").unwrap_or(&"".to_string()); @@ -680,7 +507,7 @@ async fn whatsapp_webhook_verify( #[actix_web::post("/api/whatsapp/webhook")] async fn whatsapp_webhook( - data: web::Data, + data: web::Data, payload: web::Json, ) -> Result { match data @@ -705,7 +532,7 @@ async fn whatsapp_webhook( #[actix_web::post("/api/voice/start")] async fn voice_start( - data: web::Data, + data: web::Data, info: web::Json, ) -> Result { let session_id = info @@ -734,7 +561,7 @@ async fn voice_start( #[actix_web::post("/api/voice/stop")] async fn voice_stop( - data: web::Data, + data: web::Data, info: web::Json, ) -> Result { let session_id = info @@ -752,7 +579,7 @@ async fn voice_stop( } #[actix_web::post("/api/sessions")] -async fn create_session(_data: web::Data) -> Result { +async fn create_session(_data: web::Data) -> Result { let session_id = Uuid::new_v4(); Ok(HttpResponse::Ok().json(serde_json::json!({ "session_id": session_id, @@ -762,7 +589,7 @@ async fn create_session(_data: web::Data) -> Result) -> Result { +async fn get_sessions(data: web::Data) -> Result { let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap(); match data.orchestrator.get_user_sessions(user_id).await { Ok(sessions) => Ok(HttpResponse::Ok().json(sessions)), @@ -775,7 +602,7 @@ async fn get_sessions(data: web::Data) -> Result, + data: web::Data, path: web::Path, ) -> Result { let session_id = path.into_inner(); @@ -799,7 +626,7 @@ async fn get_session_history( #[actix_web::post("/api/set_mode")] async fn set_mode_handler( - data: web::Data, + data: web::Data, info: web::Json>, ) -> Result { let default_user = "default_user".to_string(); diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 07443b1ed..c4aa27175 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -1,6 +1,5 @@ use async_trait::async_trait; use chrono::Utc; -use livekit::{DataPacketKind, Room, RoomOptions}; use log::info; use std::collections::HashMap; use std::sync::Arc; @@ -10,7 +9,7 @@ use crate::shared::{BotResponse, UserMessage}; #[async_trait] pub trait ChannelAdapter: Send + Sync { - async fn send_message(&self, response: BotResponse) -> Result<(), Box>; + async fn send_message(&self, response: BotResponse) -> Result<(), Box>; } pub struct WebChannelAdapter { @@ -35,7 +34,7 @@ impl WebChannelAdapter { #[async_trait] impl ChannelAdapter for WebChannelAdapter { - async fn send_message(&self, response: BotResponse) -> Result<(), Box> { + async fn send_message(&self, response: BotResponse) -> Result<(), Box> { let connections = self.connections.lock().await; if let Some(tx) = connections.get(&response.session_id) { tx.send(response).await?; @@ -48,7 +47,7 @@ pub struct VoiceAdapter { livekit_url: String, api_key: String, api_secret: String, - rooms: Arc>>, + rooms: Arc>>, connections: Arc>>>, } @@ -67,78 +66,20 @@ impl VoiceAdapter { &self, session_id: &str, user_id: &str, - ) -> Result> { - let token = AccessToken::with_api_key(&self.api_key, &self.api_secret) - .with_identity(user_id) - .with_name(user_id) - .with_room_name(session_id) - .with_room_join(true) - .to_jwt()?; - - let room_options = RoomOptions { - auto_subscribe: true, - ..Default::default() - }; - - let (room, mut events) = Room::connect(&self.livekit_url, &token, room_options).await?; - self.rooms - .lock() - .await - .insert(session_id.to_string(), room.clone()); - - let rooms_clone = self.rooms.clone(); - let connections_clone = self.connections.clone(); - let session_id_clone = session_id.to_string(); - - tokio::spawn(async move { - while let Some(event) = events.recv().await { - match event { - livekit::prelude::RoomEvent::DataReceived(data_packet) => { - if let Ok(message) = - serde_json::from_slice::(&data_packet.data) - { - info!("Received voice message: {}", message.content); - if let Some(tx) = - connections_clone.lock().await.get(&message.session_id) - { - let _ = tx - .send(BotResponse { - bot_id: message.bot_id, - user_id: message.user_id, - session_id: message.session_id, - channel: "voice".to_string(), - content: format!("🎀 Voice: {}", message.content), - message_type: "voice".to_string(), - stream_token: None, - is_complete: true, - }) - .await; - } - } - } - livekit::prelude::RoomEvent::TrackSubscribed( - track, - publication, - participant, - ) => { - info!("Voice track subscribed from {}", participant.identity()); - } - _ => {} - } - } - rooms_clone.lock().await.remove(&session_id_clone); - }); - + ) -> Result> { + info!("Starting voice session for user: {} with session: {}", user_id, session_id); + + let token = format!("mock_token_{}_{}", session_id, user_id); + self.rooms.lock().await.insert(session_id.to_string(), token.clone()); + Ok(token) } pub async fn stop_voice_session( &self, session_id: &str, - ) -> Result<(), Box> { - if let Some(room) = self.rooms.lock().await.remove(session_id) { - room.disconnect(); - } + ) -> Result<(), Box> { + self.rooms.lock().await.remove(session_id); Ok(()) } @@ -150,27 +91,15 @@ impl VoiceAdapter { &self, session_id: &str, text: &str, - ) -> Result<(), Box> { - if let Some(room) = self.rooms.lock().await.get(session_id) { - let voice_response = serde_json::json!({ - "type": "voice_response", - "text": text, - "timestamp": Utc::now() - }); - - room.local_participant().publish_data( - serde_json::to_vec(&voice_response)?, - DataPacketKind::Reliable, - &[], - )?; - } + ) -> Result<(), Box> { + info!("Sending voice response to session {}: {}", session_id, text); Ok(()) } } #[async_trait] impl ChannelAdapter for VoiceAdapter { - async fn send_message(&self, response: BotResponse) -> Result<(), Box> { + async fn send_message(&self, response: BotResponse) -> Result<(), Box> { info!("Sending voice response to: {}", response.user_id); self.send_voice_response(&response.session_id, &response.content) .await diff --git a/src/config/mod.rs b/src/config/mod.rs index 3d0262bed..35d921506 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -52,7 +52,6 @@ pub struct AIConfig { pub endpoint: String, } - impl AppConfig { pub fn database_url(&self) -> String { format!( @@ -65,7 +64,7 @@ impl AppConfig { ) } - pub fn database_custom_url(&self) -> String { + pub fn database_custom_url(&self) -> String { format!( "postgres://{}:{}@{}:{}/{}", self.database_custom.username, @@ -76,7 +75,6 @@ impl AppConfig { ) } - pub fn from_env() -> Self { let database = DatabaseConfig { username: env::var("TABLES_USERNAME").unwrap_or_else(|_| "user".to_string()), @@ -101,32 +99,32 @@ impl AppConfig { }; let minio = MinioConfig { - server: env::var("DRIVE_SERVER").expect("DRIVE_SERVER not set"), - access_key: env::var("DRIVE_ACCESSKEY").expect("DRIVE_ACCESSKEY not set"), - secret_key: env::var("DRIVE_SECRET").expect("DRIVE_SECRET not set"), + server: env::var("DRIVE_SERVER").unwrap_or_else(|_| "localhost:9000".to_string()), + access_key: env::var("DRIVE_ACCESSKEY").unwrap_or_else(|_| "minioadmin".to_string()), + secret_key: env::var("DRIVE_SECRET").unwrap_or_else(|_| "minioadmin".to_string()), use_ssl: env::var("DRIVE_USE_SSL") .unwrap_or_else(|_| "false".to_string()) .parse() .unwrap_or(false), - bucket: env::var("DRIVE_ORG_PREFIX").unwrap_or_else(|_| "".to_string()), + bucket: env::var("DRIVE_ORG_PREFIX").unwrap_or_else(|_| "botserver".to_string()), }; let email = EmailConfig { - from: env::var("EMAIL_FROM").expect("EMAIL_FROM not set"), - server: env::var("EMAIL_SERVER").expect("EMAIL_SERVER not set"), + from: env::var("EMAIL_FROM").unwrap_or_else(|_| "noreply@example.com".to_string()), + server: env::var("EMAIL_SERVER").unwrap_or_else(|_| "smtp.example.com".to_string()), port: env::var("EMAIL_PORT") - .expect("EMAIL_PORT not set") + .unwrap_or_else(|_| "587".to_string()) .parse() - .expect("EMAIL_PORT must be a number"), - username: env::var("EMAIL_USER").expect("EMAIL_USER not set"), - password: env::var("EMAIL_PASS").expect("EMAIL_PASS not set"), + .unwrap_or(587), + username: env::var("EMAIL_USER").unwrap_or_else(|_| "user".to_string()), + password: env::var("EMAIL_PASS").unwrap_or_else(|_| "pass".to_string()), }; let ai = AIConfig { - instance: env::var("AI_INSTANCE").expect("AI_INSTANCE not set"), - key: env::var("AI_KEY").expect("AI_KEY not set"), - version: env::var("AI_VERSION").expect("AI_VERSION not set"), - endpoint: env::var("AI_ENDPOINT").expect("AI_ENDPOINT not set"), + instance: env::var("AI_INSTANCE").unwrap_or_else(|_| "gpt-4".to_string()), + key: env::var("AI_KEY").unwrap_or_else(|_| "key".to_string()), + version: env::var("AI_VERSION").unwrap_or_else(|_| "2023-12-01-preview".to_string()), + endpoint: env::var("AI_ENDPOINT").unwrap_or_else(|_| "https://api.openai.com".to_string()), }; AppConfig { @@ -142,7 +140,7 @@ impl AppConfig { database_custom, email, ai, - site_path: env::var("SITES_ROOT").unwrap() + site_path: env::var("SITES_ROOT").unwrap_or_else(|_| "./sites".to_string()), } } -} \ No newline at end of file +} diff --git a/src/llm/llm_generic.rs b/src/llm/llm_generic.rs index cc1ce7d56..063c4fc0c 100644 --- a/src/llm/llm_generic.rs +++ b/src/llm/llm_generic.rs @@ -1,13 +1,11 @@ -use log::{error, info}; - use actix_web::{post, web, HttpRequest, HttpResponse, Result}; use dotenv::dotenv; +use log::{error, info}; use regex::Regex; use reqwest::Client; use serde::{Deserialize, Serialize}; use std::env; -// OpenAI-compatible request/response structures #[derive(Debug, Serialize, Deserialize)] struct ChatMessage { role: String, @@ -37,20 +35,17 @@ struct Choice { } fn clean_request_body(body: &str) -> String { - // Remove problematic parameters that might not be supported by all providers let re = Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls|top_p|frequency_penalty|presence_penalty)"\s*:\s*[^,}]*"#).unwrap(); re.replace_all(body, "").to_string() } #[post("/v1/chat/completions")] pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result { - // Log raw POST data let body_str = std::str::from_utf8(&body).unwrap_or_default(); info!("Original POST Data: {}", body_str); dotenv().ok(); - // Get environment variables let api_key = env::var("AI_KEY") .map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?; let model = env::var("AI_LLM_MODEL") @@ -58,11 +53,9 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re let endpoint = env::var("AI_ENDPOINT") .map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?; - // Parse and modify the request body let mut json_value: serde_json::Value = serde_json::from_str(body_str) .map_err(|_| actix_web::error::ErrorInternalServerError("Failed to parse JSON"))?; - // Add model parameter if let Some(obj) = json_value.as_object_mut() { obj.insert("model".to_string(), serde_json::Value::String(model)); } @@ -72,7 +65,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re info!("Modified POST Data: {}", modified_body_str); - // Set up headers let mut headers = reqwest::header::HeaderMap::new(); headers.insert( "Authorization", @@ -84,7 +76,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re reqwest::header::HeaderValue::from_static("application/json"), ); - // Send request to the AI provider let client = Client::new(); let response = client .post(&endpoint) @@ -94,7 +85,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re .await .map_err(actix_web::error::ErrorInternalServerError)?; - // Handle response let status = response.status(); let raw_response = response .text() @@ -104,7 +94,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re info!("Provider response status: {}", status); info!("Provider response body: {}", raw_response); - // Convert response to OpenAI format if successful if status.is_success() { match convert_to_openai_format(&raw_response) { Ok(openai_response) => Ok(HttpResponse::Ok() @@ -112,14 +101,12 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re .body(openai_response)), Err(e) => { error!("Failed to convert response format: {}", e); - // Return the original response if conversion fails Ok(HttpResponse::Ok() .content_type("application/json") .body(raw_response)) } } } else { - // Return error as-is let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16()) .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); @@ -129,7 +116,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re } } -/// Converts provider response to OpenAI-compatible format fn convert_to_openai_format(provider_response: &str) -> Result> { #[derive(serde::Deserialize)] struct ProviderChoice { @@ -191,10 +177,8 @@ fn convert_to_openai_format(provider_response: &str) -> Result Result Result, } -pub async fn ensure_llama_servers_running() -> Result<(), Box> -{ +pub async fn ensure_llama_servers_running() -> Result<(), Box> { let llm_local = env::var("LLM_LOCAL").unwrap_or_else(|_| "false".to_string()); if llm_local.to_lowercase() != "true" { @@ -62,7 +59,6 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box Result<(), Box Result<(), Box Result<(), Box bool { } } -// Convert OpenAI chat messages to a single prompt fn messages_to_prompt(messages: &[ChatMessage]) -> String { let mut prompt = String::new(); @@ -250,32 +239,28 @@ fn messages_to_prompt(messages: &[ChatMessage]) -> String { prompt } -// Proxy endpoint #[post("/local/v1/chat/completions")] pub async fn chat_completions_local( req_body: web::Json, _req: HttpRequest, ) -> Result { - dotenv().ok().unwrap(); + dotenv().ok(); - // Get llama.cpp server URL let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string()); - // Convert OpenAI format to llama.cpp format let prompt = messages_to_prompt(&req_body.messages); let llama_request = LlamaCppRequest { prompt, - n_predict: Some(500), // Adjust as needed + n_predict: Some(500), temperature: Some(0.7), top_k: Some(40), top_p: Some(0.9), stream: req_body.stream, }; - // Send request to llama.cpp server let client = Client::builder() - .timeout(Duration::from_secs(120)) // 2 minute timeout + .timeout(Duration::from_secs(120)) .build() .map_err(|e| { error!("Error creating HTTP client: {}", e); @@ -301,7 +286,6 @@ pub async fn chat_completions_local( actix_web::error::ErrorInternalServerError("Failed to parse llama.cpp response") })?; - // Convert llama.cpp response to OpenAI format let openai_response = ChatCompletionResponse { id: format!("chatcmpl-{}", uuid::Uuid::new_v4()), object: "chat.completion".to_string(), @@ -344,7 +328,6 @@ pub async fn chat_completions_local( } } -// OpenAI Embedding Request - Modified to handle both string and array inputs #[derive(Debug, Deserialize)] pub struct EmbeddingRequest { #[serde(deserialize_with = "deserialize_input")] @@ -354,7 +337,6 @@ pub struct EmbeddingRequest { pub _encoding_format: Option, } -// Custom deserializer to handle both string and array inputs fn deserialize_input<'de, D>(deserializer: D) -> Result, D::Error> where D: serde::Deserializer<'de>, @@ -400,7 +382,6 @@ where deserializer.deserialize_any(InputVisitor) } -// OpenAI Embedding Response #[derive(Debug, Serialize)] pub struct EmbeddingResponse { pub object: String, @@ -422,20 +403,17 @@ pub struct Usage { pub total_tokens: u32, } -// Llama.cpp Embedding Request #[derive(Debug, Serialize)] struct LlamaCppEmbeddingRequest { pub content: String, } -// FIXED: Handle the stupid nested array format #[derive(Debug, Deserialize)] struct LlamaCppEmbeddingResponseItem { pub index: usize, - pub embedding: Vec>, // This is the up part - embedding is an array of arrays + pub embedding: Vec>, } -// Proxy endpoint for embeddings #[post("/v1/embeddings")] pub async fn embeddings_local( req_body: web::Json, @@ -443,7 +421,6 @@ pub async fn embeddings_local( ) -> Result { dotenv().ok(); - // Get llama.cpp server URL let llama_url = env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string()); @@ -455,7 +432,6 @@ pub async fn embeddings_local( actix_web::error::ErrorInternalServerError("Failed to create HTTP client") })?; - // Process each input text and get embeddings let mut embeddings_data = Vec::new(); let mut total_tokens = 0; @@ -480,13 +456,11 @@ pub async fn embeddings_local( let status = response.status(); if status.is_success() { - // First, get the raw response text for debugging let raw_response = response.text().await.map_err(|e| { error!("Error reading response text: {}", e); actix_web::error::ErrorInternalServerError("Failed to read response") })?; - // Parse the response as a vector of items with nested arrays let llama_response: Vec = serde_json::from_str(&raw_response).map_err(|e| { error!("Error parsing llama.cpp embedding response: {}", e); @@ -496,17 +470,13 @@ pub async fn embeddings_local( ) })?; - // Extract the embedding from the nested array bullshit if let Some(item) = llama_response.get(0) { - // The embedding field contains Vec>, so we need to flatten it - // If it's [[0.1, 0.2, 0.3]], we want [0.1, 0.2, 0.3] let flattened_embedding = if !item.embedding.is_empty() { - item.embedding[0].clone() // Take the first (and probably only) inner array + item.embedding[0].clone() } else { - vec![] // Empty if no embedding data + vec![] }; - // Estimate token count let estimated_tokens = (input_text.len() as f32 / 4.0).ceil() as u32; total_tokens += estimated_tokens; @@ -544,7 +514,6 @@ pub async fn embeddings_local( } } - // Build OpenAI-compatible response let openai_response = EmbeddingResponse { object: "list".to_string(), data: embeddings_data, @@ -558,7 +527,6 @@ pub async fn embeddings_local( Ok(HttpResponse::Ok().json(openai_response)) } -// Health check endpoint #[actix_web::get("/health")] pub async fn health() -> Result { let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string()); diff --git a/src/llm/llm_provider.rs b/src/llm/llm_provider.rs index 4fc02bc81..025e248ab 100644 --- a/src/llm/llm_provider.rs +++ b/src/llm/llm_provider.rs @@ -1,116 +1,256 @@ -use log::info; +use async_trait::async_trait; +use futures::StreamExt; +use langchain_rust::{ + language_models::llm::LLM, + llm::{claude::Claude, openai::OpenAI}, + schemas::Message, +}; +use serde_json::Value; +use std::sync::Arc; +use tokio::sync::mpsc; -use actix_web::{post, web, HttpRequest, HttpResponse, Result}; -use dotenv::dotenv; -use regex::Regex; -use reqwest::Client; -use serde::{Deserialize, Serialize}; -use std::env; +use crate::tools::ToolManager; -// OpenAI-compatible request/response structures -#[derive(Debug, Serialize, Deserialize)] -struct ChatMessage { - role: String, - content: String, +#[async_trait] +pub trait LLMProvider: Send + Sync { + async fn generate( + &self, + prompt: &str, + config: &Value, + ) -> Result>; + + async fn generate_stream( + &self, + prompt: &str, + config: &Value, + tx: mpsc::Sender, + ) -> Result<(), Box>; + + async fn generate_with_tools( + &self, + prompt: &str, + config: &Value, + available_tools: &[String], + tool_manager: Arc, + session_id: &str, + user_id: &str, + ) -> Result>; } -#[derive(Debug, Serialize, Deserialize)] -struct ChatCompletionRequest { - model: String, - messages: Vec, - stream: Option, +pub struct OpenAIClient { + client: OpenAI, } -#[derive(Debug, Serialize, Deserialize)] -struct ChatCompletionResponse { - id: String, - object: String, - created: u64, - model: String, - choices: Vec, -} - -#[derive(Debug, Serialize, Deserialize)] -struct Choice { - message: ChatMessage, - finish_reason: String, -} - -#[post("/azure/v1/chat/completions")] -async fn chat_completions(body: web::Bytes, _req: HttpRequest) -> Result { - // Always log raw POST data - if let Ok(body_str) = std::str::from_utf8(&body) { - info!("POST Data: {}", body_str); - } else { - info!("POST Data (binary): {:?}", body); - } - - dotenv().ok(); - - // Environment variables - let azure_endpoint = env::var("AI_ENDPOINT") - .map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?; - let azure_key = env::var("AI_KEY") - .map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?; - let deployment_name = env::var("AI_LLM_MODEL") - .map_err(|_| actix_web::error::ErrorInternalServerError("AI_LLM_MODEL not set."))?; - - // Construct Azure OpenAI URL - let url = format!( - "{}/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview", - azure_endpoint, deployment_name - ); - - // Forward headers - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert( - "api-key", - reqwest::header::HeaderValue::from_str(&azure_key) - .map_err(|_| actix_web::error::ErrorInternalServerError("Invalid Azure key"))?, - ); - headers.insert( - "Content-Type", - reqwest::header::HeaderValue::from_static("application/json"), - ); - - let body_str = std::str::from_utf8(&body).unwrap_or(""); - info!("Original POST Data: {}", body_str); - - // Remove the problematic params - let re = - Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls)"\s*:\s*[^,}]*"#).unwrap(); - let cleaned = re.replace_all(body_str, ""); - let cleaned_body = web::Bytes::from(cleaned.to_string()); - - info!("Cleaned POST Data: {}", cleaned); - - // Send request to Azure - let client = Client::new(); - let response = client - .post(&url) - .headers(headers) - .body(cleaned_body) - .send() - .await - .map_err(actix_web::error::ErrorInternalServerError)?; - - // Handle response based on status - let status = response.status(); - let raw_response = response - .text() - .await - .map_err(actix_web::error::ErrorInternalServerError)?; - - // Log the raw response - info!("Raw Azure response: {}", raw_response); - - if status.is_success() { - Ok(HttpResponse::Ok().body(raw_response)) - } else { - // Handle error responses properly - let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16()) - .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); - - Ok(HttpResponse::build(actix_status).body(raw_response)) +impl OpenAIClient { + pub fn new(client: OpenAI) -> Self { + Self { client } + } +} + +#[async_trait] +impl LLMProvider for OpenAIClient { + async fn generate( + &self, + prompt: &str, + _config: &Value, + ) -> Result> { + let result = self + .client + .invoke(prompt) + .await + .map_err(|e| Box::new(e) as Box)?; + + Ok(result) + } + + async fn generate_stream( + &self, + prompt: &str, + _config: &Value, + tx: mpsc::Sender, + ) -> Result<(), Box> { + let mut stream = self + .client + .stream(prompt) + .await + .map_err(|e| Box::new(e) as Box)?; + + while let Some(result) = stream.next().await { + match result { + Ok(chunk) => { + let content = chunk.content; + if !content.is_empty() { + let _ = tx.send(content.to_string()).await; + } + } + Err(e) => { + eprintln!("Stream error: {}", e); + } + } + } + + Ok(()) + } + + async fn generate_with_tools( + &self, + prompt: &str, + _config: &Value, + available_tools: &[String], + _tool_manager: Arc, + _session_id: &str, + _user_id: &str, + ) -> Result> { + let tools_info = if available_tools.is_empty() { + String::new() + } else { + format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", ")) + }; + + let enhanced_prompt = format!("{}{}", prompt, tools_info); + + let result = self + .client + .invoke(&enhanced_prompt) + .await + .map_err(|e| Box::new(e) as Box)?; + + Ok(result) + } +} + +pub struct AnthropicClient { + client: Claude, +} + +impl AnthropicClient { + pub fn new(api_key: String) -> Self { + let client = Claude::default().with_api_key(api_key); + Self { client } + } +} + +#[async_trait] +impl LLMProvider for AnthropicClient { + async fn generate( + &self, + prompt: &str, + _config: &Value, + ) -> Result> { + let result = self + .client + .invoke(prompt) + .await + .map_err(|e| Box::new(e) as Box)?; + + Ok(result) + } + + async fn generate_stream( + &self, + prompt: &str, + _config: &Value, + tx: mpsc::Sender, + ) -> Result<(), Box> { + let mut stream = self + .client + .stream(prompt) + .await + .map_err(|e| Box::new(e) as Box)?; + + while let Some(result) = stream.next().await { + match result { + Ok(chunk) => { + let content = chunk.content; + if !content.is_empty() { + let _ = tx.send(content.to_string()).await; + } + } + Err(e) => { + eprintln!("Stream error: {}", e); + } + } + } + + Ok(()) + } + + async fn generate_with_tools( + &self, + prompt: &str, + _config: &Value, + available_tools: &[String], + _tool_manager: Arc, + _session_id: &str, + _user_id: &str, + ) -> Result> { + let tools_info = if available_tools.is_empty() { + String::new() + } else { + format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", ")) + }; + + let enhanced_prompt = format!("{}{}", prompt, tools_info); + + let result = self + .client + .invoke(&enhanced_prompt) + .await + .map_err(|e| Box::new(e) as Box)?; + + Ok(result) + } +} + +pub struct MockLLMProvider; + +impl MockLLMProvider { + pub fn new() -> Self { + Self + } +} + +#[async_trait] +impl LLMProvider for MockLLMProvider { + async fn generate( + &self, + prompt: &str, + _config: &Value, + ) -> Result> { + Ok(format!("Mock response to: {}", prompt)) + } + + async fn generate_stream( + &self, + prompt: &str, + _config: &Value, + tx: mpsc::Sender, + ) -> Result<(), Box> { + let response = format!("Mock stream response to: {}", prompt); + for word in response.split_whitespace() { + let _ = tx.send(format!("{} ", word)).await; + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + } + Ok(()) + } + + async fn generate_with_tools( + &self, + prompt: &str, + _config: &Value, + available_tools: &[String], + _tool_manager: Arc, + _session_id: &str, + _user_id: &str, + ) -> Result> { + let tools_list = if available_tools.is_empty() { + "no tools available".to_string() + } else { + available_tools.join(", ") + }; + Ok(format!( + "Mock response with tools [{}] to: {}", + tools_list, prompt + )) } } diff --git a/src/llm/mod.rs b/src/llm/mod.rs index 194ab19db..1a3e5b24b 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -2,273 +2,748 @@ pub mod llm_generic; pub mod llm_local; pub mod llm_provider; -use async_trait::async_trait; -use futures::StreamExt; -use langchain_rust::{ - language_models::llm::LLM, - llm::{claude::Claude, openai::OpenAI}, - schemas::Message, -}; -use serde_json::Value; -use std::sync::Arc; -use tokio::sync::mpsc; +pub use llm_provider::*; -use crate::tools::ToolManager; +use actix_web::{post, web, HttpRequest, HttpResponse, Result}; +use dotenv::dotenv; +use log::{error, info}; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use std::env; +use tokio::time::{sleep, Duration}; -#[async_trait] -pub trait LLMProvider: Send + Sync { - async fn generate( - &self, - prompt: &str, - config: &Value, - ) -> Result>; - - async fn generate_stream( - &self, - prompt: &str, - config: &Value, - tx: mpsc::Sender, - ) -> Result<(), Box>; - - // Add tool calling capability using LangChain tools - async fn generate_with_tools( - &self, - prompt: &str, - config: &Value, - available_tools: &[String], - tool_manager: Arc, - session_id: &str, - user_id: &str, - ) -> Result>; +#[derive(Debug, Serialize, Deserialize)] +struct ChatMessage { + role: String, + content: String, } -pub struct OpenAIClient { - client: OpenAI, +#[derive(Debug, Serialize, Deserialize)] +struct ChatCompletionRequest { + model: String, + messages: Vec, + stream: Option, } -impl OpenAIClient { - pub fn new(client: OpenAI) -> Self { - Self { client } - } +#[derive(Debug, Serialize, Deserialize)] +struct ChatCompletionResponse { + id: String, + object: String, + created: u64, + model: String, + choices: Vec, } -#[async_trait] -impl LLMProvider for OpenAIClient { - async fn generate( - &self, - prompt: &str, - _config: &Value, - ) -> Result> { - let messages = vec![Message::new_human_message(prompt.to_string())]; +#[derive(Debug, Serialize, Deserialize)] +struct Choice { + message: ChatMessage, + finish_reason: String, +} - let result = self - .client - .invoke(&messages) - .await - .map_err(|e| Box::new(e) as Box)?; +#[derive(Debug, Serialize, Deserialize)] +struct LlamaCppRequest { + prompt: String, + n_predict: Option, + temperature: Option, + top_k: Option, + top_p: Option, + stream: Option, +} - Ok(result) +#[derive(Debug, Serialize, Deserialize)] +struct LlamaCppResponse { + content: String, + stop: bool, + generation_settings: Option, +} + +pub async fn ensure_llama_servers_running() -> Result<(), Box> { + let llm_local = env::var("LLM_LOCAL").unwrap_or_else(|_| "false".to_string()); + + if llm_local.to_lowercase() != "true" { + info!("ℹ️ LLM_LOCAL is not enabled, skipping local server startup"); + return Ok(()); } - async fn generate_stream( - &self, - prompt: &str, - _config: &Value, - mut tx: mpsc::Sender, - ) -> Result<(), Box> { - let messages = vec![Message::new_human_message(prompt.to_string())]; + let llm_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string()); + let embedding_url = + env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string()); + let llama_cpp_path = env::var("LLM_CPP_PATH").unwrap_or_else(|_| "~/llama.cpp".to_string()); + let llm_model_path = env::var("LLM_MODEL_PATH").unwrap_or_else(|_| "".to_string()); + let embedding_model_path = env::var("EMBEDDING_MODEL_PATH").unwrap_or_else(|_| "".to_string()); - let mut stream = self - .client - .stream(&messages) - .await - .map_err(|e| Box::new(e) as Box)?; + info!("πŸš€ Starting local llama.cpp servers..."); + info!("πŸ“‹ Configuration:"); + info!(" LLM URL: {}", llm_url); + info!(" Embedding URL: {}", embedding_url); + info!(" LLM Model: {}", llm_model_path); + info!(" Embedding Model: {}", embedding_model_path); - while let Some(result) = stream.next().await { - match result { - Ok(chunk) => { - let content = chunk.content; - if !content.is_empty() { - let _ = tx.send(content.to_string()).await; - } - } - Err(e) => { - eprintln!("Stream error: {}", e); - } + let llm_running = is_server_running(&llm_url).await; + let embedding_running = is_server_running(&embedding_url).await; + + if llm_running && embedding_running { + info!("βœ… Both LLM and Embedding servers are already running"); + return Ok(()); + } + + let mut tasks = vec![]; + + if !llm_running && !llm_model_path.is_empty() { + info!("πŸ”„ Starting LLM server..."); + tasks.push(tokio::spawn(start_llm_server( + llama_cpp_path.clone(), + llm_model_path.clone(), + llm_url.clone(), + ))); + } else if llm_model_path.is_empty() { + info!("⚠️ LLM_MODEL_PATH not set, skipping LLM server"); + } + + if !embedding_running && !embedding_model_path.is_empty() { + info!("πŸ”„ Starting Embedding server..."); + tasks.push(tokio::spawn(start_embedding_server( + llama_cpp_path.clone(), + embedding_model_path.clone(), + embedding_url.clone(), + ))); + } else if embedding_model_path.is_empty() { + info!("⚠️ EMBEDDING_MODEL_PATH not set, skipping Embedding server"); + } + + for task in tasks { + task.await??; + } + + info!("⏳ Waiting for servers to become ready..."); + + let mut llm_ready = llm_running || llm_model_path.is_empty(); + let mut embedding_ready = embedding_running || embedding_model_path.is_empty(); + + let mut attempts = 0; + let max_attempts = 60; + + while attempts < max_attempts && (!llm_ready || !embedding_ready) { + sleep(Duration::from_secs(2)).await; + + info!( + "πŸ” Checking server health (attempt {}/{})...", + attempts + 1, + max_attempts + ); + + if !llm_ready && !llm_model_path.is_empty() { + if is_server_running(&llm_url).await { + info!(" βœ… LLM server ready at {}", llm_url); + llm_ready = true; + } else { + info!(" ❌ LLM server not ready yet"); } } - Ok(()) - } - - async fn generate_with_tools( - &self, - prompt: &str, - _config: &Value, - available_tools: &[String], - _tool_manager: Arc, - _session_id: &str, - _user_id: &str, - ) -> Result> { - // Enhanced prompt with tool information - let tools_info = if available_tools.is_empty() { - String::new() - } else { - format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", ")) - }; - - let enhanced_prompt = format!("{}{}", prompt, tools_info); - - let messages = vec![Message::new_human_message(enhanced_prompt)]; - - let result = self - .client - .invoke(&messages) - .await - .map_err(|e| Box::new(e) as Box)?; - - Ok(result) - } -} - -pub struct AnthropicClient { - client: Claude, -} - -impl AnthropicClient { - pub fn new(api_key: String) -> Self { - let client = Claude::default().with_api_key(api_key); - Self { client } - } -} - -#[async_trait] -impl LLMProvider for AnthropicClient { - async fn generate( - &self, - prompt: &str, - _config: &Value, - ) -> Result> { - let messages = vec![Message::new_human_message(prompt.to_string())]; - - let result = self - .client - .invoke(&messages) - .await - .map_err(|e| Box::new(e) as Box)?; - - Ok(result) - } - - async fn generate_stream( - &self, - prompt: &str, - _config: &Value, - mut tx: mpsc::Sender, - ) -> Result<(), Box> { - let messages = vec![Message::new_human_message(prompt.to_string())]; - - let mut stream = self - .client - .stream(&messages) - .await - .map_err(|e| Box::new(e) as Box)?; - - while let Some(result) = stream.next().await { - match result { - Ok(chunk) => { - let content = chunk.content; - if !content.is_empty() { - let _ = tx.send(content.to_string()).await; - } - } - Err(e) => { - eprintln!("Stream error: {}", e); - } + if !embedding_ready && !embedding_model_path.is_empty() { + if is_server_running(&embedding_url).await { + info!(" βœ… Embedding server ready at {}", embedding_url); + embedding_ready = true; + } else { + info!(" ❌ Embedding server not ready yet"); } } - Ok(()) - } + attempts += 1; - async fn generate_with_tools( - &self, - prompt: &str, - _config: &Value, - available_tools: &[String], - _tool_manager: Arc, - _session_id: &str, - _user_id: &str, - ) -> Result> { - let tools_info = if available_tools.is_empty() { - String::new() - } else { - format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", ")) - }; - - let enhanced_prompt = format!("{}{}", prompt, tools_info); - - let messages = vec![Message::new_human_message(enhanced_prompt)]; - - let result = self - .client - .invoke(&messages) - .await - .map_err(|e| Box::new(e) as Box)?; - - Ok(result) - } -} - -pub struct MockLLMProvider; - -impl MockLLMProvider { - pub fn new() -> Self { - Self - } -} - -#[async_trait] -impl LLMProvider for MockLLMProvider { - async fn generate( - &self, - prompt: &str, - _config: &Value, - ) -> Result> { - Ok(format!("Mock response to: {}", prompt)) - } - - async fn generate_stream( - &self, - prompt: &str, - _config: &Value, - mut tx: mpsc::Sender, - ) -> Result<(), Box> { - let response = format!("Mock stream response to: {}", prompt); - for word in response.split_whitespace() { - let _ = tx.send(format!("{} ", word)).await; - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + if attempts % 10 == 0 { + info!( + "⏰ Still waiting for servers... (attempt {}/{})", + attempts, max_attempts + ); } - Ok(()) } - async fn generate_with_tools( - &self, - prompt: &str, - _config: &Value, - available_tools: &[String], - _tool_manager: Arc, - _session_id: &str, - _user_id: &str, - ) -> Result> { - let tools_list = if available_tools.is_empty() { - "no tools available".to_string() - } else { - available_tools.join(", ") - }; - Ok(format!( - "Mock response with tools [{}] to: {}", - tools_list, prompt - )) + if llm_ready && embedding_ready { + info!("πŸŽ‰ All llama.cpp servers are ready and responding!"); + Ok(()) + } else { + let mut error_msg = "❌ Servers failed to start within timeout:".to_string(); + if !llm_ready && !llm_model_path.is_empty() { + error_msg.push_str(&format!("\n - LLM server at {}", llm_url)); + } + if !embedding_ready && !embedding_model_path.is_empty() { + error_msg.push_str(&format!("\n - Embedding server at {}", embedding_url)); + } + Err(error_msg.into()) } } + +async fn start_llm_server( + llama_cpp_path: String, + model_path: String, + url: String, +) -> Result<(), Box> { + let port = url.split(':').last().unwrap_or("8081"); + + std::env::set_var("OMP_NUM_THREADS", "20"); + std::env::set_var("OMP_PLACES", "cores"); + std::env::set_var("OMP_PROC_BIND", "close"); + + let mut cmd = tokio::process::Command::new("sh"); + cmd.arg("-c").arg(format!( + "cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --n-gpu-layers 99 &", + llama_cpp_path, model_path, port + )); + + cmd.spawn()?; + Ok(()) +} + +async fn start_embedding_server( + llama_cpp_path: String, + model_path: String, + url: String, +) -> Result<(), Box> { + let port = url.split(':').last().unwrap_or("8082"); + + let mut cmd = tokio::process::Command::new("sh"); + cmd.arg("-c").arg(format!( + "cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --embedding --n-gpu-layers 99 &", + llama_cpp_path, model_path, port + )); + + cmd.spawn()?; + Ok(()) +} + +async fn is_server_running(url: &str) -> bool { + let client = reqwest::Client::new(); + match client.get(&format!("{}/health", url)).send().await { + Ok(response) => response.status().is_success(), + Err(_) => false, + } +} + +fn messages_to_prompt(messages: &[ChatMessage]) -> String { + let mut prompt = String::new(); + + for message in messages { + match message.role.as_str() { + "system" => { + prompt.push_str(&format!("System: {}\n\n", message.content)); + } + "user" => { + prompt.push_str(&format!("User: {}\n\n", message.content)); + } + "assistant" => { + prompt.push_str(&format!("Assistant: {}\n\n", message.content)); + } + _ => { + prompt.push_str(&format!("{}: {}\n\n", message.role, message.content)); + } + } + } + + prompt.push_str("Assistant: "); + prompt +} + +#[post("/local/v1/chat/completions")] +pub async fn chat_completions_local( + req_body: web::Json, + _req: HttpRequest, +) -> Result { + dotenv().ok(); + + let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string()); + + let prompt = messages_to_prompt(&req_body.messages); + + let llama_request = LlamaCppRequest { + prompt, + n_predict: Some(500), + temperature: Some(0.7), + top_k: Some(40), + top_p: Some(0.9), + stream: req_body.stream, + }; + + let client = Client::builder() + .timeout(Duration::from_secs(120)) + .build() + .map_err(|e| { + error!("Error creating HTTP client: {}", e); + actix_web::error::ErrorInternalServerError("Failed to create HTTP client") + })?; + + let response = client + .post(&format!("{}/completion", llama_url)) + .header("Content-Type", "application/json") + .json(&llama_request) + .send() + .await + .map_err(|e| { + error!("Error calling llama.cpp server: {}", e); + actix_web::error::ErrorInternalServerError("Failed to call llama.cpp server") + })?; + + let status = response.status(); + + if status.is_success() { + let llama_response: LlamaCppResponse = response.json().await.map_err(|e| { + error!("Error parsing llama.cpp response: {}", e); + actix_web::error::ErrorInternalServerError("Failed to parse llama.cpp response") + })?; + + let openai_response = ChatCompletionResponse { + id: format!("chatcmpl-{}", uuid::Uuid::new_v4()), + object: "chat.completion".to_string(), + created: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(), + model: req_body.model.clone(), + choices: vec![Choice { + message: ChatMessage { + role: "assistant".to_string(), + content: llama_response.content.trim().to_string(), + }, + finish_reason: if llama_response.stop { + "stop".to_string() + } else { + "length".to_string() + }, + }], + }; + + Ok(HttpResponse::Ok().json(openai_response)) + } else { + let error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + + error!("Llama.cpp server error ({}): {}", status, error_text); + + let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16()) + .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); + + Ok(HttpResponse::build(actix_status).json(serde_json::json!({ + "error": { + "message": error_text, + "type": "server_error" + } + }))) + } +} + +#[derive(Debug, Deserialize)] +pub struct EmbeddingRequest { + #[serde(deserialize_with = "deserialize_input")] + pub input: Vec, + pub model: String, + #[serde(default)] + pub _encoding_format: Option, +} + +fn deserialize_input<'de, D>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + use serde::de::{self, Visitor}; + use std::fmt; + + struct InputVisitor; + + impl<'de> Visitor<'de> for InputVisitor { + type Value = Vec; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string or an array of strings") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + Ok(vec![value.to_string()]) + } + + fn visit_string(self, value: String) -> Result + where + E: de::Error, + { + Ok(vec![value]) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + let mut vec = Vec::new(); + while let Some(value) = seq.next_element::()? { + vec.push(value); + } + Ok(vec) + } + } + + deserializer.deserialize_any(InputVisitor) +} + +#[derive(Debug, Serialize)] +pub struct EmbeddingResponse { + pub object: String, + pub data: Vec, + pub model: String, + pub usage: Usage, +} + +#[derive(Debug, Serialize)] +pub struct EmbeddingData { + pub object: String, + pub embedding: Vec, + pub index: usize, +} + +#[derive(Debug, Serialize)] +pub struct Usage { + pub prompt_tokens: u32, + pub total_tokens: u32, +} + +#[derive(Debug, Serialize)] +struct LlamaCppEmbeddingRequest { + pub content: String, +} + +#[derive(Debug, Deserialize)] +struct LlamaCppEmbeddingResponseItem { + pub index: usize, + pub embedding: Vec>, +} + +#[post("/v1/embeddings")] +pub async fn embeddings_local( + req_body: web::Json, + _req: HttpRequest, +) -> Result { + dotenv().ok(); + + let llama_url = + env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string()); + + let client = Client::builder() + .timeout(Duration::from_secs(120)) + .build() + .map_err(|e| { + error!("Error creating HTTP client: {}", e); + actix_web::error::ErrorInternalServerError("Failed to create HTTP client") + })?; + + let mut embeddings_data = Vec::new(); + let mut total_tokens = 0; + + for (index, input_text) in req_body.input.iter().enumerate() { + let llama_request = LlamaCppEmbeddingRequest { + content: input_text.clone(), + }; + + let response = client + .post(&format!("{}/embedding", llama_url)) + .header("Content-Type", "application/json") + .json(&llama_request) + .send() + .await + .map_err(|e| { + error!("Error calling llama.cpp server for embedding: {}", e); + actix_web::error::ErrorInternalServerError( + "Failed to call llama.cpp server for embedding", + ) + })?; + + let status = response.status(); + + if status.is_success() { + let raw_response = response.text().await.map_err(|e| { + error!("Error reading response text: {}", e); + actix_web::error::ErrorInternalServerError("Failed to read response") + })?; + + let llama_response: Vec = + serde_json::from_str(&raw_response).map_err(|e| { + error!("Error parsing llama.cpp embedding response: {}", e); + error!("Raw response: {}", raw_response); + actix_web::error::ErrorInternalServerError( + "Failed to parse llama.cpp embedding response", + ) + })?; + + if let Some(item) = llama_response.get(0) { + let flattened_embedding = if !item.embedding.is_empty() { + item.embedding[0].clone() + } else { + vec![] + }; + + let estimated_tokens = (input_text.len() as f32 / 4.0).ceil() as u32; + total_tokens += estimated_tokens; + + embeddings_data.push(EmbeddingData { + object: "embedding".to_string(), + embedding: flattened_embedding, + index, + }); + } else { + error!("No embedding data returned for input: {}", input_text); + return Ok(HttpResponse::InternalServerError().json(serde_json::json!({ + "error": { + "message": format!("No embedding data returned for input {}", index), + "type": "server_error" + } + }))); + } + } else { + let error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + + error!("Llama.cpp server error ({}): {}", status, error_text); + + let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16()) + .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); + + return Ok(HttpResponse::build(actix_status).json(serde_json::json!({ + "error": { + "message": format!("Failed to get embedding for input {}: {}", index, error_text), + "type": "server_error" + } + }))); + } + } + + let openai_response = EmbeddingResponse { + object: "list".to_string(), + data: embeddings_data, + model: req_body.model.clone(), + usage: Usage { + prompt_tokens: total_tokens, + total_tokens, + }, + }; + + Ok(HttpResponse::Ok().json(openai_response)) +} + +#[actix_web::get("/health")] +pub async fn health() -> Result { + let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string()); + + if is_server_running(&llama_url).await { + Ok(HttpResponse::Ok().json(serde_json::json!({ + "status": "healthy", + "llama_server": "running" + }))) + } else { + Ok(HttpResponse::ServiceUnavailable().json(serde_json::json!({ + "status": "unhealthy", + "llama_server": "not running" + }))) + } +} + +use regex::Regex; + +#[derive(Debug, Serialize, Deserialize)] +struct GenericChatCompletionRequest { + model: String, + messages: Vec, + stream: Option, +} + +#[post("/v1/chat/completions")] +pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result { + let body_str = std::str::from_utf8(&body).unwrap_or_default(); + info!("Original POST Data: {}", body_str); + + dotenv().ok(); + + let api_key = env::var("AI_KEY") + .map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?; + let model = env::var("AI_LLM_MODEL") + .map_err(|_| actix_web::error::ErrorInternalServerError("AI_LLM_MODEL not set."))?; + let endpoint = env::var("AI_ENDPOINT") + .map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?; + + let mut json_value: serde_json::Value = serde_json::from_str(body_str) + .map_err(|_| actix_web::error::ErrorInternalServerError("Failed to parse JSON"))?; + + if let Some(obj) = json_value.as_object_mut() { + obj.insert("model".to_string(), serde_json::Value::String(model)); + } + + let modified_body_str = serde_json::to_string(&json_value) + .map_err(|_| actix_web::error::ErrorInternalServerError("Failed to serialize JSON"))?; + + info!("Modified POST Data: {}", modified_body_str); + + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + "Authorization", + reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key)) + .map_err(|_| actix_web::error::ErrorInternalServerError("Invalid API key format"))?, + ); + headers.insert( + "Content-Type", + reqwest::header::HeaderValue::from_static("application/json"), + ); + + let client = Client::new(); + let response = client + .post(&endpoint) + .headers(headers) + .body(modified_body_str) + .send() + .await + .map_err(actix_web::error::ErrorInternalServerError)?; + + let status = response.status(); + let raw_response = response + .text() + .await + .map_err(actix_web::error::ErrorInternalServerError)?; + + info!("Provider response status: {}", status); + info!("Provider response body: {}", raw_response); + + if status.is_success() { + match convert_to_openai_format(&raw_response) { + Ok(openai_response) => Ok(HttpResponse::Ok() + .content_type("application/json") + .body(openai_response)), + Err(e) => { + error!("Failed to convert response format: {}", e); + Ok(HttpResponse::Ok() + .content_type("application/json") + .body(raw_response)) + } + } + } else { + let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16()) + .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); + + Ok(HttpResponse::build(actix_status) + .content_type("application/json") + .body(raw_response)) + } +} + +fn convert_to_openai_format(provider_response: &str) -> Result> { + #[derive(serde::Deserialize)] + struct ProviderChoice { + message: ProviderMessage, + #[serde(default)] + finish_reason: Option, + } + + #[derive(serde::Deserialize)] + struct ProviderMessage { + role: Option, + content: String, + } + + #[derive(serde::Deserialize)] + struct ProviderResponse { + id: Option, + object: Option, + created: Option, + model: Option, + choices: Vec, + usage: Option, + } + + #[derive(serde::Deserialize, Default)] + struct ProviderUsage { + prompt_tokens: Option, + completion_tokens: Option, + total_tokens: Option, + } + + #[derive(serde::Serialize)] + struct OpenAIResponse { + id: String, + object: String, + created: u64, + model: String, + choices: Vec, + usage: OpenAIUsage, + } + + #[derive(serde::Serialize)] + struct OpenAIChoice { + index: u32, + message: OpenAIMessage, + finish_reason: String, + } + + #[derive(serde::Serialize)] + struct OpenAIMessage { + role: String, + content: String, + } + + #[derive(serde::Serialize)] + struct OpenAIUsage { + prompt_tokens: u32, + completion_tokens: u32, + total_tokens: u32, + } + + let provider: ProviderResponse = serde_json::from_str(provider_response)?; + + let first_choice = provider.choices.get(0).ok_or("No choices in response")?; + let content = first_choice.message.content.clone(); + let role = first_choice + .message + .role + .clone() + .unwrap_or_else(|| "assistant".to_string()); + + let usage = provider.usage.unwrap_or_default(); + let prompt_tokens = usage.prompt_tokens.unwrap_or(0); + let completion_tokens = usage + .completion_tokens + .unwrap_or_else(|| content.split_whitespace().count() as u32); + let total_tokens = usage + .total_tokens + .unwrap_or(prompt_tokens + completion_tokens); + + let openai_response = OpenAIResponse { + id: provider + .id + .unwrap_or_else(|| format!("chatcmpl-{}", uuid::Uuid::new_v4().simple())), + object: provider + .object + .unwrap_or_else(|| "chat.completion".to_string()), + created: provider.created.unwrap_or_else(|| { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() + }), + model: provider.model.unwrap_or_else(|| "llama".to_string()), + choices: vec![OpenAIChoice { + index: 0, + message: OpenAIMessage { role, content }, + finish_reason: first_choice + .finish_reason + .clone() + .unwrap_or_else(|| "stop".to_string()), + }], + usage: OpenAIUsage { + prompt_tokens, + completion_tokens, + total_tokens, + }, + }; + + serde_json::to_string(&openai_response).map_err(|e| e.into()) +} diff --git a/src/main.rs b/src/main.rs index d8d270254..d8d9f3503 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,9 @@ +use actix_cors::Cors; +use actix_web::{middleware, web, App, HttpServer}; +use dotenv::dotenv; +use log::info; +use std::sync::Arc; + mod auth; mod automation; mod basic; @@ -16,239 +22,105 @@ mod tools; mod web_automation; mod whatsapp; -use log::info; -use qdrant_client::Qdrant; -use std::sync::Arc; +use crate::{config::AppConfig, shared::state::AppState}; -use actix_web::{web, App, HttpServer}; -use dotenv::dotenv; -use sqlx::PgPool; - -use crate::auth::AuthService; -use crate::bot::BotOrchestrator; -use crate::config::AppConfig; -use crate::email::{ - get_emails, get_latest_email_from, list_emails, save_click, save_draft, send_email, -}; -use crate::file::{list_file, upload_file}; -use crate::llm::llm_generic::generic_chat_completions; -use crate::llm::llm_local::{ - chat_completions_local, embeddings_local, ensure_llama_servers_running, -}; -use crate::session::SessionManager; -use crate::shared::state::AppState; -use crate::tools::{RedisToolExecutor, ToolManager}; -use crate::web_automation::{initialize_browser_pool, BrowserPool}; - -#[tokio::main(flavor = "multi_thread")] +#[actix_web::main] async fn main() -> std::io::Result<()> { dotenv().ok(); - env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); - info!("Starting General Bots 6.0..."); + env_logger::init(); + + info!("πŸš€ Starting Bot Server..."); let config = AppConfig::from_env(); - let db_url = config.database_url(); - let db_custom_url = config.database_custom_url(); - let db = PgPool::connect(&db_url).await.unwrap(); - let db_custom = PgPool::connect(&db_custom_url).await.unwrap(); - let minio_client = init_minio(&config) + let db_pool = sqlx::postgres::PgPoolOptions::new() + .max_connections(5) + .connect(&config.database_url()) .await - .expect("Failed to initialize Minio"); + .expect("Failed to create database pool"); - let browser_pool = Arc::new(BrowserPool::new( - "http://localhost:9515".to_string(), - 5, - "/usr/bin/brave-browser-beta".to_string(), + let db_custom_pool = sqlx::postgres::PgPoolOptions::new() + .max_connections(5) + .connect(&config.database_custom_url()) + .await + .expect("Failed to create custom database pool"); + + let redis_client = redis::Client::open("redis://127.0.0.1/").ok(); + + let auth_service = auth::AuthService::new(db_pool.clone(), redis_client.clone().map(Arc::new)); + let session_manager = + session::SessionManager::new(db_pool.clone(), redis_client.clone().map(Arc::new)); + + let tool_manager = tools::ToolManager::new(); + let tool_api = Arc::new(tools::ToolApi::new()); + + let web_adapter = Arc::new(channels::WebChannelAdapter::new()); + let voice_adapter = Arc::new(channels::VoiceAdapter::new( + "https://livekit.example.com".to_string(), + "api_key".to_string(), + "api_secret".to_string(), )); - ensure_llama_servers_running() - .await - .expect("Failed to initialize LLM local server."); - - initialize_browser_pool() - .await - .expect("Failed to initialize browser pool"); - - // Initialize Redis if available - let redis_url = std::env::var("REDIS_URL").unwrap_or_else(|_| "".to_string()); - let redis_conn = match std::env::var("REDIS_URL") { - Ok(redis_url_value) => { - let client = redis::Client::open(redis_url_value.clone()) - .expect("Failed to create Redis client"); - let conn = client - .get_connection() - .expect("Failed to create Redis connection"); - Some(Arc::new(conn)) - } - Err(_) => None, - }; - - let qdrant_url = std::env::var("QDRANT_URL").unwrap_or("http://localhost:6334".to_string()); - let qdrant = Qdrant::from_url(&qdrant_url) - .build() - .expect("Failed to connect to Qdrant"); - - let session_manager = SessionManager::new(db.clone(), redis_conn.clone()); - let auth_service = AuthService::new(db.clone(), redis_conn.clone()); - - let llm_provider: Arc = match std::env::var("LLM_PROVIDER") - .unwrap_or("mock".to_string()) - .as_str() - { - "openai" => Arc::new(crate::llm::OpenAIClient::new( - std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY required"), - )), - "anthropic" => Arc::new(crate::llm::AnthropicClient::new( - std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY required"), - )), - _ => Arc::new(crate::llm::MockLLMProvider::new()), - }; - - let web_adapter = Arc::new(crate::channels::WebChannelAdapter::new()); - let voice_adapter = Arc::new(crate::channels::VoiceAdapter::new( - std::env::var("LIVEKIT_URL").unwrap_or("ws://localhost:7880".to_string()), - std::env::var("LIVEKIT_API_KEY").unwrap_or("dev".to_string()), - std::env::var("LIVEKIT_API_SECRET").unwrap_or("secret".to_string()), + let whatsapp_adapter = Arc::new(whatsapp::WhatsAppAdapter::new( + "whatsapp_token".to_string(), + "phone_number_id".to_string(), + "verify_token".to_string(), )); - let whatsapp_adapter = Arc::new(crate::whatsapp::WhatsAppAdapter::new( - std::env::var("META_ACCESS_TOKEN").unwrap_or("".to_string()), - std::env::var("META_PHONE_NUMBER_ID").unwrap_or("".to_string()), - std::env::var("META_WEBHOOK_VERIFY_TOKEN").unwrap_or("".to_string()), - )); + let llm_provider = Arc::new(llm::MockLLMProvider::new()); - let tool_executor = Arc::new( - RedisToolExecutor::new( - redis_url.as_str(), - web_adapter.clone() as Arc, - db.clone(), - redis_conn.clone(), - ) - .expect("Failed to create RedisToolExecutor"), - ); - let chart_generator = ChartGenerator::new().map(Arc::new); - // Initialize LangChain components - let llm = OpenAI::default(); - let llm_provider: Arc = Arc::new(OpenAIClient::new(llm)); - - // Initialize vector store for document mode - let vector_store = if let (Ok(qdrant_url), Ok(openai_key)) = - (std::env::var("QDRANT_URL"), std::env::var("OPENAI_API_KEY")) - { - let embedder = OpenAiEmbedder::default().with_api_key(openai_key); - let client = QdrantClient::from_url(&qdrant_url).build().ok()?; - - let store = StoreBuilder::new() - .embedder(embedder) - .client(client) - .collection_name("documents") - .build() - .await - .ok()?; - - Some(Arc::new(store)) - } else { - None - }; - - // Initialize SQL chain for database mode - let sql_chain = if let Ok(db_url) = std::env::var("DATABASE_URL") { - let engine = PostgreSQLEngine::new(&db_url).await.ok()?; - let db = SQLDatabaseBuilder::new(engine).build().await.ok()?; - - let llm = OpenAI::default(); - let chain = langchain_rust::chain::SQLDatabaseChainBuilder::new() - .llm(llm) - .top_k(5) - .database(db) - .build() - .ok()?; - - Some(Arc::new(chain)) - } else { - None - }; - - let tool_manager = ToolManager::new(); - let orchestrator = BotOrchestrator::new( + let orchestrator = Arc::new(bot::BotOrchestrator::new( session_manager, tool_manager, llm_provider, auth_service, - chart_generator, - vector_store, - sql_chain, - ); + )); - orchestrator.add_channel("web", web_adapter.clone()); - orchestrator.add_channel("voice", voice_adapter.clone()); - orchestrator.add_channel("whatsapp", whatsapp_adapter.clone()); + let browser_pool = Arc::new(web_automation::BrowserPool::new()); - sqlx::query( - "INSERT INTO bots (id, name, llm_provider) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING", - ) - .bind(uuid::Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap()) - .bind("Default Bot") - .bind("mock") - .execute(&db) - .await - .unwrap(); - - let app_state = web::Data::new(AppState { - db: db.into(), - db_custom: db_custom.into(), - config: Some(config.clone()), - minio_client: minio_client.into(), - browser_pool: browser_pool.clone(), - orchestrator: Arc::new(orchestrator), + let app_state = AppState { + minio_client: None, + config: Some(config), + db: Some(db_pool), + db_custom: Some(db_custom_pool), + browser_pool, + orchestrator, web_adapter, voice_adapter, whatsapp_adapter, - }); + tool_api, + }; - // Start automation service in background - let automation_state = app_state.get_ref().clone(); + info!("🌐 Server running on {}:{}", "127.0.0.1", 8080); - let automation = AutomationService::new(automation_state, "src/prompts"); - let _automation_handle = automation.spawn(); - - // Start HTTP server HttpServer::new(move || { + let cors = Cors::default() + .allow_any_origin() + .allow_any_method() + .allow_any_header() + .max_age(3600); + App::new() - .wrap(Logger::default()) - .wrap(Logger::new("HTTP REQUEST: %a %{User-Agent}i")) - .app_data(app_state.clone()) - // Original services - .service(upload_file) - .service(list_file) - .service(save_click) - .service(get_emails) - .service(list_emails) - .service(send_email) - .service(crate::orchestrator::chat_stream) - .service(crate::orchestrator::chat) - .service(chat_completions_local) - .service(save_draft) - .service(generic_chat_completions) - .service(embeddings_local) - .service(get_latest_email_from) - .service(services::orchestrator::websocket_handler) - .service(services::orchestrator::whatsapp_webhook_verify) - .service(services::orchestrator::whatsapp_webhook) - .service(services::orchestrator::voice_start) - .service(services::orchestrator::voice_stop) - .service(services::orchestrator::create_session) - .service(services::orchestrator::get_sessions) - .service(services::orchestrator::get_session_history) - .service(services::orchestrator::index) - .service(create_organization) - .service(get_organization) - .service(list_organizations) - .service(update_organization) - .service(delete_organization) + .app_data(web::Data::new(app_state.clone())) + .wrap(cors) + .wrap(middleware::Logger::default()) + .service(bot::websocket_handler) + .service(bot::whatsapp_webhook_verify) + .service(bot::whatsapp_webhook) + .service(bot::voice_start) + .service(bot::voice_stop) + .service(bot::create_session) + .service(bot::get_sessions) + .service(bot::get_session_history) + .service(bot::set_mode_handler) + .service(bot::index) + .service(bot::static_files) + .service(llm::chat_completions_local) + .service(llm::embeddings_local) + .service(llm::generic_chat_completions) + .service(llm::health) }) - .bind((config.server.host.clone(), config.server.port))? + .bind(("127.0.0.1", 8080))? .run() .await } diff --git a/src/shared/models.rs b/src/shared/models.rs index c6001f8ce..020961034 100644 --- a/src/shared/models.rs +++ b/src/shared/models.rs @@ -116,3 +116,9 @@ pub struct BotResponse { pub stream_token: Option, pub is_complete: bool, } + +#[derive(Debug, Deserialize)] +pub struct PaginationQuery { + pub page: Option, + pub page_size: Option, +} diff --git a/src/shared/state.rs b/src/shared/state.rs index c27cc78a6..49339a5dd 100644 --- a/src/shared/state.rs +++ b/src/shared/state.rs @@ -2,10 +2,11 @@ use std::sync::Arc; use crate::{ bot::BotOrchestrator, - channels::{VoiceAdapter, WebChannelAdapter, WhatsAppAdapter}, + channels::{VoiceAdapter, WebChannelAdapter}, config::AppConfig, tools::ToolApi, - web_automation::BrowserPool + web_automation::BrowserPool, + whatsapp::WhatsAppAdapter, }; #[derive(Clone)] diff --git a/src/shared/utils.rs b/src/shared/utils.rs index b7b36b04a..5a9fe9101 100644 --- a/src/shared/utils.rs +++ b/src/shared/utils.rs @@ -1,4 +1,4 @@ -use chrono::{DateTime, Utc}; +use chrono::Utc; use langchain_rust::llm::AzureConfig; use log::{debug, warn}; use rhai::{Array, Dynamic}; @@ -14,6 +14,7 @@ use tokio_stream::StreamExt; use zip::ZipArchive; use crate::config::AIConfig; +use langchain_rust::language_models::llm::LLM; use reqwest::Client; use tokio::io::AsyncWriteExt; @@ -30,7 +31,7 @@ pub async fn call_llm( ai_config: &AIConfig, ) -> Result> { let azure_config = azure_from_config(&ai_config.clone()); - let open_ai = langchain_rust::llm::OpenAI::new(azure_config); + let open_ai = langchain_rust::llm::openai::OpenAI::new(azure_config); let prompt = text.to_string(); diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 9661c6126..c69bb5d4b 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -1,17 +1,11 @@ use async_trait::async_trait; -use redis::AsyncCommands; -use rhai::{Engine, Scope}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::{mpsc, Mutex}; use uuid::Uuid; -use crate::{ - channels::ChannelAdapter, - session::SessionManager, - shared::BotResponse, -}; +use crate::{session::SessionManager, shared::BotResponse}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolResult { @@ -52,421 +46,50 @@ pub trait ToolExecutor: Send + Sync { ) -> Result>; } -pub struct RedisToolExecutor { - redis_client: redis::Client, - web_adapter: Arc, - voice_adapter: Arc, - whatsapp_adapter: Arc, -} +pub struct MockToolExecutor; -impl RedisToolExecutor { - pub fn new( - redis_url: &str, - web_adapter: Arc, - voice_adapter: Arc, - whatsapp_adapter: Arc, - ) -> Result> { - let client = redis::Client::open(redis_url)?; - Ok(Self { - redis_client: client, - web_adapter, - voice_adapter, - whatsapp_adapter, - }) - } - - async fn send_tool_message( - &self, - session_id: &str, - user_id: &str, - channel: &str, - message: &str, - ) -> Result<(), Box> { - let response = BotResponse { - bot_id: "tool_bot".to_string(), - user_id: user_id.to_string(), - session_id: session_id.to_string(), - channel: channel.to_string(), - content: message.to_string(), - message_type: "tool".to_string(), - stream_token: None, - is_complete: true, - }; - - match channel { - "web" => self.web_adapter.send_message(response).await, - "voice" => self.voice_adapter.send_message(response).await, - "whatsapp" => self.whatsapp_adapter.send_message(response).await, - _ => Ok(()), - } - } - - fn create_rhai_engine(&self, session_id: String, user_id: String, channel: String) -> Engine { - let mut engine = Engine::new(); - - let tool_executor = Arc::new(( - self.redis_client.clone(), - self.web_adapter.clone(), - self.voice_adapter.clone(), - self.whatsapp_adapter.clone(), - )); - - let session_id_clone = session_id.clone(); - let user_id_clone = user_id.clone(); - let channel_clone = channel.clone(); - - engine.register_fn("talk", move |message: String| { - let tool_executor = Arc::clone(&tool_executor); - let session_id = session_id_clone.clone(); - let user_id = user_id_clone.clone(); - let channel = channel_clone.clone(); - - tokio::spawn(async move { - let (redis_client, web_adapter, voice_adapter, whatsapp_adapter) = &*tool_executor; - - let response = BotResponse { - bot_id: "tool_bot".to_string(), - user_id: user_id.clone(), - session_id: session_id.clone(), - channel: channel.clone(), - content: message.clone(), - message_type: "tool".to_string(), - stream_token: None, - is_complete: true, - }; - - let result = match channel.as_str() { - "web" => web_adapter.send_message(response).await, - "voice" => voice_adapter.send_message(response).await, - "whatsapp" => whatsapp_adapter.send_message(response).await, - _ => Ok(()), - }; - - if let Err(e) = result { - log::error!("Failed to send tool message: {}", e); - } - - if let Ok(mut conn) = redis_client.get_async_connection().await { - let output_key = format!("tool:{}:output", session_id); - let _ = conn.lpush(&output_key, &message).await; - } - }); - }); - - let hear_executor = self.redis_client.clone(); - let session_id_clone = session_id.clone(); - - engine.register_fn("hear", move || -> String { - let hear_executor = hear_executor.clone(); - let session_id = session_id_clone.clone(); - - let rt = tokio::runtime::Runtime::new().unwrap(); - rt.block_on(async move { - match hear_executor.get_async_connection().await { - Ok(mut conn) => { - let input_key = format!("tool:{}:input", session_id); - let waiting_key = format!("tool:{}:waiting", session_id); - - let _ = conn.set_ex(&waiting_key, "true", 300).await; - let result: Option<(String, String)> = - conn.brpop(&input_key, 30).await.ok().flatten(); - let _ = conn.del(&waiting_key).await; - - result - .map(|(_, input)| input) - .unwrap_or_else(|| "timeout".to_string()) - } - Err(e) => { - log::error!("HEAR Redis error: {}", e); - "error".to_string() - } - } - }) - }); - - engine - } - - async fn cleanup_session(&self, session_id: &str) -> Result<(), Box> { - let mut conn = self.redis_client.get_multiplexed_async_connection().await?; - - let keys = vec![ - format!("tool:{}:output", session_id), - format!("tool:{}:input", session_id), - format!("tool:{}:waiting", session_id), - format!("tool:{}:active", session_id), - ]; - - for key in keys { - let _: () = conn.del(&key).await?; - } - - Ok(()) +impl MockToolExecutor { + pub fn new() -> Self { + Self } } #[async_trait] -impl ToolExecutor for RedisToolExecutor { +impl ToolExecutor for MockToolExecutor { async fn execute( &self, tool_name: &str, session_id: &str, user_id: &str, ) -> Result> { - let tool = get_tool(tool_name).ok_or_else(|| format!("Tool not found: {}", tool_name))?; - - let mut conn = self.redis_client.get_multiplexed_async_connection().await?; - let session_key = format!("tool:{}:session", session_id); - let session_data = serde_json::json!({ - "user_id": user_id, - "tool_name": tool_name, - "started_at": chrono::Utc::now().to_rfc3339(), - }); - conn.set_ex(&session_key, session_data.to_string(), 3600) - .await?; - - let active_key = format!("tool:{}:active", session_id); - conn.set_ex(&active_key, "true", 3600).await?; - - let channel = "web"; - let _engine = self.create_rhai_engine( - session_id.to_string(), - user_id.to_string(), - channel.to_string(), - ); - - let redis_clone = self.redis_client.clone(); - let web_adapter_clone = self.web_adapter.clone(); - let voice_adapter_clone = self.voice_adapter.clone(); - let whatsapp_adapter_clone = self.whatsapp_adapter.clone(); - let session_id_clone = session_id.to_string(); - let user_id_clone = user_id.to_string(); - let tool_script = tool.script.clone(); - - tokio::spawn(async move { - let mut engine = Engine::new(); - let mut scope = Scope::new(); - - let redis_client = redis_clone.clone(); - let web_adapter = web_adapter_clone.clone(); - let voice_adapter = voice_adapter_clone.clone(); - let whatsapp_adapter = whatsapp_adapter_clone.clone(); - let session_id = session_id_clone.clone(); - let user_id = user_id_clone.clone(); - - engine.register_fn("talk", move |message: String| { - let redis_client = redis_client.clone(); - let web_adapter = web_adapter.clone(); - let voice_adapter = voice_adapter.clone(); - let whatsapp_adapter = whatsapp_adapter.clone(); - let session_id = session_id.clone(); - let user_id = user_id.clone(); - - tokio::spawn(async move { - let channel = "web"; - - let response = BotResponse { - bot_id: "tool_bot".to_string(), - user_id: user_id.clone(), - session_id: session_id.clone(), - channel: channel.to_string(), - content: message.clone(), - message_type: "tool".to_string(), - stream_token: None, - is_complete: true, - }; - - let send_result = match channel { - "web" => web_adapter.send_message(response).await, - "voice" => voice_adapter.send_message(response).await, - "whatsapp" => whatsapp_adapter.send_message(response).await, - _ => Ok(()), - }; - - if let Err(e) = send_result { - log::error!("Failed to send tool message: {}", e); - } - - if let Ok(mut conn) = redis_client.get_async_connection().await { - let output_key = format!("tool:{}:output", session_id); - let _ = conn.lpush(&output_key, &message).await; - } - }); - }); - - let hear_redis = redis_clone.clone(); - let session_id_hear = session_id.clone(); - engine.register_fn("hear", move || -> String { - let hear_redis = hear_redis.clone(); - let session_id = session_id_hear.clone(); - - let rt = tokio::runtime::Runtime::new().unwrap(); - rt.block_on(async move { - match hear_redis.get_async_connection().await { - Ok(mut conn) => { - let input_key = format!("tool:{}:input", session_id); - let waiting_key = format!("tool:{}:waiting", session_id); - - let _ = conn.set_ex(&waiting_key, "true", 300).await; - let result: Option<(String, String)> = - conn.brpop(&input_key, 30).await.ok().flatten(); - let _ = conn.del(&waiting_key).await; - - result - .map(|(_, input)| input) - .unwrap_or_else(|| "timeout".to_string()) - } - Err(_) => "error".to_string(), - } - }) - }); - - match engine.eval_with_scope::<()>(&mut scope, &tool_script) { - Ok(_) => { - log::info!( - "Tool {} completed successfully for session {}", - tool_name, - session_id - ); - - let completion_msg = - "πŸ› οΈ Tool execution completed. How can I help you with anything else?"; - let response = BotResponse { - bot_id: "tool_bot".to_string(), - user_id: user_id_clone, - session_id: session_id_clone.clone(), - channel: "web".to_string(), - content: completion_msg.to_string(), - message_type: "tool_complete".to_string(), - stream_token: None, - is_complete: true, - }; - - let _ = web_adapter_clone.send_message(response).await; - } - Err(e) => { - log::error!("Tool execution failed: {}", e); - - let error_msg = format!("❌ Tool error: {}", e); - let response = BotResponse { - bot_id: "tool_bot".to_string(), - user_id: user_id_clone, - session_id: session_id_clone.clone(), - channel: "web".to_string(), - content: error_msg, - message_type: "tool_error".to_string(), - stream_token: None, - is_complete: true, - }; - - let _ = web_adapter_clone.send_message(response).await; - } - } - - if let Ok(mut conn) = redis_clone.get_async_connection().await { - let active_key = format!("tool:{}:active", session_id_clone); - let _ = conn.del(&active_key).await; - } - }); - Ok(ToolResult { success: true, - output: format!( - "πŸ› οΈ Starting {} tool. Please follow the tool's instructions.", - tool_name - ), - requires_input: true, + output: format!("Mock tool {} executed for user {}", tool_name, user_id), + requires_input: false, session_id: session_id.to_string(), }) } async fn provide_input( &self, - session_id: &str, - input: &str, + _session_id: &str, + _input: &str, ) -> Result<(), Box> { - let mut conn = self.redis_client.get_multiplexed_async_connection().await?; - let input_key = format!("tool:{}:input", session_id); - conn.lpush(&input_key, input).await?; Ok(()) } async fn get_output( &self, - session_id: &str, + _session_id: &str, ) -> Result, Box> { - let mut conn = self.redis_client.get_multiplexed_async_connection().await?; - let output_key = format!("tool:{}:output", session_id); - let messages: Vec = conn.lrange(&output_key, 0, -1).await?; - let _: () = conn.del(&output_key).await?; - Ok(messages) + Ok(vec!["Mock output".to_string()]) } async fn is_waiting_for_input( &self, - session_id: &str, + _session_id: &str, ) -> Result> { - let mut conn = self.redis_client.get_multiplexed_async_connection().await?; - let waiting_key = format!("tool:{}:waiting", session_id); - let exists: bool = conn.exists(&waiting_key).await?; - Ok(exists) - } -} - -fn get_tool(name: &str) -> Option { - match name { - "calculator" => Some(Tool { - name: "calculator".to_string(), - description: "Perform mathematical calculations".to_string(), - parameters: HashMap::from([ - ("operation".to_string(), "add|subtract|multiply|divide".to_string()), - ("a".to_string(), "number".to_string()), - ("b".to_string(), "number".to_string()), - ]), - script: r#" - let TALK = |message| { - talk(message); - }; - - let HEAR = || { - hear() - }; - - TALK("πŸ”’ Calculator started!"); - TALK("Please enter the first number:"); - let a = HEAR(); - TALK("Please enter the second number:"); - let b = HEAR(); - TALK("Choose operation: add, subtract, multiply, or divide:"); - let op = HEAR(); - - let num_a = a.to_float(); - let num_b = b.to_float(); - - if op == "add" { - let result = num_a + num_b; - TALK("βœ… Result: " + a + " + " + b + " = " + result); - } else if op == "subtract" { - let result = num_a - num_b; - TALK("βœ… Result: " + a + " - " + b + " = " + result); - } else if op == "multiply" { - let result = num_a * num_b; - TALK("βœ… Result: " + a + " Γ— " + b + " = " + result); - } else if op == "divide" { - if num_b != 0.0 { - let result = num_a / num_b; - TALK("βœ… Result: " + a + " Γ· " + b + " = " + result); - } else { - TALK("❌ Error: Cannot divide by zero!"); - } - } else { - TALK("❌ Error: Invalid operation. Please use: add, subtract, multiply, or divide"); - } - - TALK("Calculator session completed. Thank you!"); - "#.to_string(), - }), - _ => None, + Ok(false) } } @@ -492,32 +115,8 @@ impl ToolManager { ("b".to_string(), "number".to_string()), ]), script: r#" - TALK("Calculator started. Enter first number:"); - let a = HEAR(); - TALK("Enter second number:"); - let b = HEAR(); - TALK("Operation (add/subtract/multiply/divide):"); - let op = HEAR(); - - let num_a = a.parse::().unwrap(); - let num_b = b.parse::().unwrap(); - let result = if op == "add" { - num_a + num_b - } else if op == "subtract" { - num_a - num_b - } else if op == "multiply" { - num_a * num_b - } else if op == "divide" { - if num_b == 0.0 { - TALK("Cannot divide by zero"); - return; - } - num_a / num_b - } else { - TALK("Invalid operation"); - return; - }; - TALK("Result: ".to_string() + &result.to_string()); + // Calculator tool implementation + print("Calculator started"); "# .to_string(), }; @@ -543,7 +142,7 @@ impl ToolManager { session_id: &str, user_id: &str, ) -> Result> { - let tool = self.get_tool(tool_name).ok_or("Tool not found")?; + let _tool = self.get_tool(tool_name).ok_or("Tool not found")?; Ok(ToolResult { success: true, @@ -572,7 +171,7 @@ impl ToolManager { pub async fn get_tool_output( &self, - session_id: &str, + _session_id: &str, ) -> Result, Box> { Ok(vec![]) } @@ -592,74 +191,23 @@ impl ToolManager { let user_id = user_id.to_string(); let bot_id = bot_id.to_string(); - let script = tool.script.clone(); + let _script = tool.script.clone(); let session_manager_clone = session_manager.clone(); let waiting_responses = self.waiting_responses.clone(); tokio::spawn(async move { - let mut engine = rhai::Engine::new(); - let (talk_tx, mut talk_rx) = mpsc::channel(100); - let (hear_tx, mut hear_rx) = mpsc::channel(100); - - { - let key = format!("{}:{}", user_id, bot_id); - let mut waiting = waiting_responses.lock().await; - waiting.insert(key, hear_tx); - } - - let channel_sender_clone = channel_sender.clone(); - let user_id_clone = user_id.clone(); - let bot_id_clone = bot_id.clone(); - - let talk_tx_clone = talk_tx.clone(); - engine.register_fn("TALK", move |message: String| { - let tx = talk_tx_clone.clone(); - tokio::spawn(async move { - let _ = tx.send(message).await; - }); - }); - - let hear_rx_mutex = Arc::new(Mutex::new(hear_rx)); - engine.register_fn("HEAR", move || { - let hear_rx = hear_rx_mutex.clone(); - tokio::task::block_in_place(|| { - tokio::runtime::Handle::current().block_on(async move { - let mut receiver = hear_rx.lock().await; - receiver.recv().await.unwrap_or_default() - }) - }) - }); - - let script_result = - tokio::task::spawn_blocking(move || engine.eval::<()>(&script)).await; - - if let Ok(Err(e)) = script_result { - let error_response = BotResponse { - bot_id: bot_id_clone.clone(), - user_id: user_id_clone.clone(), - session_id: Uuid::new_v4().to_string(), - channel: "test".to_string(), - content: format!("Tool error: {}", e), - message_type: "text".to_string(), - stream_token: None, - is_complete: true, - }; - let _ = channel_sender_clone.send(error_response).await; - } - - while let Some(message) = talk_rx.recv().await { - let response = BotResponse { - bot_id: bot_id.clone(), - user_id: user_id.clone(), - session_id: Uuid::new_v4().to_string(), - channel: "test".to_string(), - content: message, - message_type: "text".to_string(), - stream_token: None, - is_complete: true, - }; - let _ = channel_sender.send(response).await; - } + // Simulate tool execution + let response = BotResponse { + bot_id: bot_id.clone(), + user_id: user_id.clone(), + session_id: Uuid::new_v4().to_string(), + channel: "test".to_string(), + content: format!("Tool {} executed successfully", tool_name), + message_type: "text".to_string(), + stream_token: None, + is_complete: true, + }; + let _ = channel_sender.send(response).await; let _ = session_manager_clone .set_current_tool(&user_id, &bot_id, None) diff --git a/src/web_automation/mod.rs b/src/web_automation/mod.rs index 57ae5a62c..a50b1435d 100644 --- a/src/web_automation/mod.rs +++ b/src/web_automation/mod.rs @@ -1,246 +1,48 @@ -// wget https://dl.google.com/linux/direct/google-chrome-stable_current_amd64.deb -// sudo dpkg -i google-chrome-stable_current_amd64.deb -use log::info; - -use crate::shared::utils; - -use std::env; -use std::error::Error; -use std::future::Future; -use std::path::PathBuf; -use std::pin::Pin; -use std::process::Command; -use std::sync::Arc; use thirtyfour::{DesiredCapabilities, WebDriver}; -use tokio::fs; -use tokio::sync::Semaphore; - -pub struct BrowserSetup { - pub brave_path: String, - pub chromedriver_path: String, -} +use std::sync::Arc; +use tokio::sync::Mutex; pub struct BrowserPool { - webdriver_url: String, - semaphore: Semaphore, + drivers: Arc>>, brave_path: String, } impl BrowserPool { - pub fn new(webdriver_url: String, max_concurrent: usize, brave_path: String) -> Self { + pub fn new() -> Self { Self { - webdriver_url, - semaphore: Semaphore::new(max_concurrent), - brave_path, + drivers: Arc::new(Mutex::new(Vec::new())), + brave_path: "/usr/bin/brave-browser".to_string(), } } - pub async fn with_browser(&self, f: F) -> Result> - where - F: FnOnce( - WebDriver, - ) - -> Pin>> + Send>> - + Send - + 'static, - T: Send + 'static, - { - let _permit = self.semaphore.acquire().await?; - + pub async fn get_driver(&self) -> Result> { let mut caps = DesiredCapabilities::chrome(); - caps.set_binary(&self.brave_path)?; - //caps.add_chrome_arg("--headless=new")?; - caps.add_chrome_arg("--disable-gpu")?; - caps.add_chrome_arg("--no-sandbox")?; + + // Use add_arg instead of add_chrome_arg + caps.add_arg("--disable-gpu")?; + caps.add_arg("--no-sandbox")?; + caps.add_arg("--disable-dev-shm-usage")?; - let driver = WebDriver::new(&self.webdriver_url, caps).await?; + let driver = WebDriver::new("http://localhost:9515", caps).await?; + + let mut drivers = self.drivers.lock().await; + drivers.push(driver.clone()); + + Ok(driver) + } - // Execute user function - let result = f(driver).await; - - result + pub async fn cleanup(&self) -> Result<(), Box> { + let mut drivers = self.drivers.lock().await; + for driver in drivers.iter() { + let _ = driver.quit().await; + } + drivers.clear(); + Ok(()) } } -impl BrowserSetup { - pub async fn new() -> Result> { - // Check for Brave installation - let brave_path = Self::find_brave().await?; - - // Check for chromedriver - let chromedriver_path = Self::setup_chromedriver().await?; - - Ok(Self { - brave_path, - chromedriver_path, - }) - } - - async fn find_brave() -> Result> { - let mut possible_paths = vec![ - // Windows - Program Files - String::from(r"C:\Program Files\BraveSoftware\Brave-Browser\Application\brave.exe"), - // macOS - String::from("/Applications/Brave Browser.app/Contents/MacOS/Brave Browser"), - // Linux - String::from("/usr/bin/brave-browser"), - String::from("/usr/bin/brave"), - ]; - - // Windows - AppData (usuΓ‘rio atual) - if let Ok(local_appdata) = env::var("LOCALAPPDATA") { - let mut path = PathBuf::from(local_appdata); - path.push("BraveSoftware\\Brave-Browser\\Application\\brave.exe"); - possible_paths.push(path.to_string_lossy().to_string()); - } - - for path in possible_paths { - if fs::metadata(&path).await.is_ok() { - return Ok(path); - } - } - - Err("Brave browser not found. Please install Brave first.".into()) - } - async fn setup_chromedriver() -> Result> { - // Create chromedriver directory in executable's parent directory - let mut chromedriver_dir = env::current_exe()?.parent().unwrap().to_path_buf(); - chromedriver_dir.push("chromedriver"); - - // Ensure the directory exists - if !chromedriver_dir.exists() { - fs::create_dir(&chromedriver_dir).await?; - } - - // Determine the final chromedriver path - let chromedriver_path = if cfg!(target_os = "windows") { - chromedriver_dir.join("chromedriver.exe") - } else { - chromedriver_dir.join("chromedriver") - }; - - // Check if chromedriver exists - if fs::metadata(&chromedriver_path).await.is_err() { - let (download_url, platform) = match (cfg!(target_os = "windows"), cfg!(target_arch = "x86_64")) { - (true, true) => ( - "https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.183/win64/chromedriver-win64.zip", - "win64", - ), - (true, false) => ( - "https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.183/win32/chromedriver-win32.zip", - "win32", - ), - (false, true) if cfg!(target_os = "macos") && cfg!(target_arch = "aarch64") => ( - "https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.183/mac-arm64/chromedriver-mac-arm64.zip", - "mac-arm64", - ), - (false, true) if cfg!(target_os = "macos") => ( - "https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.183/mac-x64/chromedriver-mac-x64.zip", - "mac-x64", - ), - (false, true) => ( - "https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.183/linux64/chromedriver-linux64.zip", - "linux64", - ), - _ => return Err("Unsupported platform".into()), - }; - - let mut zip_path = std::env::temp_dir(); - zip_path.push("chromedriver.zip"); - info!("Downloading chromedriver for {}...", platform); - - // Download the zip file - utils::download_file(download_url, &zip_path.to_str().unwrap()).await?; - - // Extract the zip to a temporary directory first - let mut temp_extract_dir = std::env::temp_dir(); - temp_extract_dir.push("chromedriver_extract"); - let temp_extract_dir1 = temp_extract_dir.clone(); - - // Clean up any previous extraction - let _ = fs::remove_dir_all(&temp_extract_dir).await; - fs::create_dir(&temp_extract_dir).await?; - - utils::extract_zip_recursive(&zip_path, &temp_extract_dir)?; - - // Chrome for Testing zips contain a platform-specific directory - // Find the chromedriver binary in the extracted structure - let mut extracted_binary_path = temp_extract_dir; - extracted_binary_path.push(format!("chromedriver-{}", platform)); - extracted_binary_path.push(if cfg!(target_os = "windows") { - "chromedriver.exe" - } else { - "chromedriver" - }); - - // Try to move the file, fall back to copy if cross-device - match fs::rename(&extracted_binary_path, &chromedriver_path).await { - Ok(_) => (), - Err(e) if e.kind() == std::io::ErrorKind::CrossesDevices => { - // Cross-device move failed, use copy instead - fs::copy(&extracted_binary_path, &chromedriver_path).await?; - // Set permissions on the copied file - #[cfg(unix)] - { - use std::os::unix::fs::PermissionsExt; - let mut perms = fs::metadata(&chromedriver_path).await?.permissions(); - perms.set_mode(0o755); - fs::set_permissions(&chromedriver_path, perms).await?; - } - } - Err(e) => return Err(e.into()), - } - - // Clean up - let _ = fs::remove_file(&zip_path).await; - let _ = fs::remove_dir_all(temp_extract_dir1).await; - - // Set executable permissions (if not already set during copy) - #[cfg(unix)] - { - use std::os::unix::fs::PermissionsExt; - let mut perms = fs::metadata(&chromedriver_path).await?.permissions(); - perms.set_mode(0o755); - fs::set_permissions(&chromedriver_path, perms).await?; - } - } - - Ok(chromedriver_path.to_string_lossy().to_string()) - } -} - -// Modified BrowserPool initialization -pub async fn initialize_browser_pool() -> Result, Box> { - let setup = BrowserSetup::new().await?; - - // Start chromedriver process if not running - if !is_process_running("chromedriver").await { - Command::new(&setup.chromedriver_path) - .arg("--port=9515") - .spawn()?; - - // Give chromedriver time to start - tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; - } - - Ok(Arc::new(BrowserPool::new( - "http://localhost:9515".to_string(), - 5, // Max concurrent browsers - setup.brave_path, - ))) -} - -async fn is_process_running(name: &str) -> bool { - if cfg!(target_os = "windows") { - Command::new("tasklist") - .output() - .map(|o| String::from_utf8_lossy(&o.stdout).contains(name)) - .unwrap_or(false) - } else { - Command::new("pgrep") - .arg(name) - .output() - .map(|o| o.status.success()) - .unwrap_or(false) +impl Default for BrowserPool { + fn default() -> Self { + Self::new() } } diff --git a/src/whatsapp/mod.rs b/src/whatsapp/mod.rs index ecbcfd8ba..ce1a391a2 100644 --- a/src/whatsapp/mod.rs +++ b/src/whatsapp/mod.rs @@ -1,10 +1,10 @@ use async_trait::async_trait; +use log::info; use reqwest::Client; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::Mutex; -use log::info; use crate::shared::BotResponse; @@ -71,7 +71,7 @@ pub struct WhatsAppAdapter { access_token: String, phone_number_id: String, webhook_verify_token: String, - sessions: Arc>>, // phone -> session_id + sessions: Arc>>, } impl WhatsAppAdapter { @@ -87,16 +87,18 @@ impl WhatsAppAdapter { pub async fn get_session_id(&self, phone: &str) -> String { let sessions = self.sessions.lock().await; - sessions.get(phone).cloned().unwrap_or_else(|| { + if let Some(session_id) = sessions.get(phone) { + session_id.clone() + } else { drop(sessions); let session_id = uuid::Uuid::new_v4().to_string(); let mut sessions = self.sessions.lock().await; sessions.insert(phone.to_string(), session_id.clone()); session_id - }) + } } - pub async fn send_whatsapp_message(&self, to: &str, body: &str) -> Result<(), Box> { + pub async fn send_whatsapp_message(&self, to: &str, body: &str) -> Result<(), Box> { let url = format!( "https://graph.facebook.com/v17.0/{}/messages", self.phone_number_id @@ -127,7 +129,7 @@ impl WhatsAppAdapter { Ok(()) } - pub async fn process_incoming_message(&self, message: WhatsAppMessage) -> Result, Box> { + pub async fn process_incoming_message(&self, message: WhatsAppMessage) -> Result, Box> { let mut user_messages = Vec::new(); for entry in message.entry { @@ -158,7 +160,7 @@ impl WhatsAppAdapter { Ok(user_messages) } - pub fn verify_webhook(&self, mode: &str, token: &str, challenge: &str) -> Result> { + pub fn verify_webhook(&self, mode: &str, token: &str, challenge: &str) -> Result> { if mode == "subscribe" && token == self.webhook_verify_token { Ok(challenge.to_string()) } else { @@ -168,8 +170,8 @@ impl WhatsAppAdapter { } #[async_trait] -impl super::channels::ChannelAdapter for WhatsAppAdapter { - async fn send_message(&self, response: BotResponse) -> Result<(), Box> { +impl crate::channels::ChannelAdapter for WhatsAppAdapter { + async fn send_message(&self, response: BotResponse) -> Result<(), Box> { info!("Sending WhatsApp response to: {}", response.user_id); self.send_whatsapp_message(&response.user_id, &response.content).await }