diff --git a/prompts/dev/fix.md b/prompts/dev/fix.md new file mode 100644 index 0000000..b3d8a1c --- /dev/null +++ b/prompts/dev/fix.md @@ -0,0 +1,36 @@ +You are fixing Rust code in a Cargo project. The user is providing problematic code that needs to be corrected. + +## Your Task +Fix ALL compiler errors and logical issues while maintaining the original intent. Return the COMPLETE corrected files as a SINGLE .sh script that can be executed from project root. +Use Cargo.toml as reference, do not change it. +Only return input files, all other files already exists. +If something, need to be added to a external file, inform it separated. + +## Critical Requirements +1. **Return as SINGLE .sh script** - Output must be a complete shell script using `cat > file << 'EOF'` pattern +2. **Include ALL files** - Every corrected file must be included in the script +3. **Respect Cargo.toml** - Check dependencies, editions, and features to avoid compiler errors +4. **Type safety** - Ensure all types match and trait bounds are satisfied +5. **Ownership rules** - Fix borrowing, ownership, and lifetime issues + +## Output Format Requirements +You MUST return exactly this example format: + +```sh +#!/bin/bash + +# Restore fixed Rust project + +cat > src/.rs << 'EOF' +use std::io; + +// test + +cat > src/.rs << 'EOF' +// Fixed library code +pub fn add(a: i32, b: i32) -> i32 { + a + b +} +EOF + +---- diff --git a/prompts/development/general.md b/prompts/dev/general.md similarity index 89% rename from prompts/development/general.md rename to prompts/dev/general.md index df5aefd..d847f5f 100644 --- a/prompts/development/general.md +++ b/prompts/dev/general.md @@ -1,3 +1,4 @@ +* Preffer imports than using :: to call methods, * Output a single `.sh` script using `cat` so it can be restored directly. * No placeholders, only real, production-ready code. * No comments, no explanations, no extra text. diff --git a/src/shared/PROMPT-model.md b/prompts/dev/model.md similarity index 100% rename from src/shared/PROMPT-model.md rename to prompts/dev/model.md diff --git a/prompts/development/service.md b/prompts/dev/service.md similarity index 100% rename from prompts/development/service.md rename to prompts/dev/service.md diff --git a/scripts/dev/llm_context.txt b/scripts/dev/llm_context.txt new file mode 100644 index 0000000..d104329 --- /dev/null +++ b/scripts/dev/llm_context.txt @@ -0,0 +1,2625 @@ +Consolidated LLM Context +* Preffer imports than using :: to call methods, +* Output a single `.sh` script using `cat` so it can be restored directly. +* No placeholders, only real, production-ready code. +* No comments, no explanations, no extra text. +* Follow KISS principles. +* Provide a complete, professional, working solution. +* If the script is too long, split into multiple parts, but always return the **entire code**. +* Output must be **only the code**, nothing else. + +[package] +name = "gbserver" +version = "0.1.0" +edition = "2021" +authors = ["Rodrigo Rodriguez "] +description = "General Bots Server" +license = "AGPL-3.0" +repository = "https://alm.pragmatismo.com.br/generalbots/gbserver" + +[features] +default = ["postgres", "qdrant"] +local_llm = [] +postgres = ["sqlx/postgres"] +qdrant = ["langchain-rust/qdrant"] + +[dependencies] +actix-cors = "0.7" +actix-multipart = "0.7" +actix-web = "4.9" +actix-ws = "0.3" +anyhow = "1.0" +async-stream = "0.3" +async-trait = "0.1" +aes-gcm = "0.10" +argon2 = "0.5" +base64 = "0.22" +bytes = "1.8" +chrono = { version = "0.4", features = ["serde"] } +dotenv = "0.15" +downloader = "0.2" +env_logger = "0.11" +futures = "0.3" +futures-util = "0.3" +imap = "2.4" +langchain-rust = { version = "4.6", features = ["qdrant", "postgres"] } +lettre = { version = "0.11", features = ["smtp-transport", "builder", "tokio1", "tokio1-native-tls"] } +livekit = "0.7" +log = "0.4" +mailparse = "0.15" +minio = { git = "https://github.com/minio/minio-rs", branch = "master" } +native-tls = "0.2" +num-format = "0.4" +qdrant-client = "1.12" +rhai = "1.22" +redis = { version = "0.27", features = ["tokio-comp"] } +regex = "1.11" +reqwest = { version = "0.12", features = ["json", "stream"] } +scraper = "0.20" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +smartstring = "1.0" +sqlx = { version = "0.8", features = ["time", "uuid", "runtime-tokio-rustls", "postgres", "chrono"] } +tempfile = "3" +thirtyfour = "0.34" +tokio = { version = "1.41", features = ["full"] } +tokio-stream = "0.1" +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["fmt"] } +urlencoding = "2.1" +uuid = { version = "1.11", features = ["serde", "v4"] } +zip = "2.2" + +You are fixing Rust code in a Cargo project. The user is providing problematic code that needs to be corrected. + +## Your Task +Fix ALL compiler errors and logical issues while maintaining the original intent. Return the COMPLETE corrected files as a SINGLE .sh script that can be executed from project root. +Use Cargo.toml as reference, do not change it. +Only return input files, all other files already exists. +If something, need to be added to a external file, inform it separated. + +## Critical Requirements +1. **Return as SINGLE .sh script** - Output must be a complete shell script using `cat > file << 'EOF'` pattern +2. **Include ALL files** - Every corrected file must be included in the script +3. **Respect Cargo.toml** - Check dependencies, editions, and features to avoid compiler errors +4. **Type safety** - Ensure all types match and trait bounds are satisfied +5. **Ownership rules** - Fix borrowing, ownership, and lifetime issues + +## Output Format Requirements +You MUST return exactly this example format: + +```sh +#!/bin/bash + +# Restore fixed Rust project + +cat > src/.rs << 'EOF' +use std::io; + +// test + +cat > src/.rs << 'EOF' +// Fixed library code +pub fn add(a: i32, b: i32) -> i32 { + a + b +} +EOF + +---- + +use std::sync::Arc; + +use minio::s3::Client; + +use crate::{config::AppConfig, web_automator::BrowserPool}; + +#[derive(Clone)] +pub struct AppState { + pub minio_client: Option, + pub config: Option, + pub db: Option, + pub db_custom: Option, + pub browser_pool: Arc, + pub orchestrator: Arc, + pub web_adapter: Arc, + pub voice_adapter: Arc, + pub whatsapp_adapter: Arc, + tool_api: Arc, // Add this +} +pub struct _BotState { + pub language: String, + pub work_folder: String, +} + +use crate::config::AIConfig; +use langchain_rust::llm::OpenAI; +use langchain_rust::{language_models::llm::LLM, llm::AzureConfig}; +use log::error; +use log::{debug, warn}; +use rhai::{Array, Dynamic}; +use serde_json::{json, Value}; +use smartstring::SmartString; +use sqlx::Column; // Required for .name() method +use sqlx::TypeInfo; // Required for .type_info() method +use sqlx::{postgres::PgRow, Row}; +use sqlx::{Decode, Type}; +use std::error::Error; +use std::fs::File; +use std::io::BufReader; +use std::path::Path; +use tokio::fs::File as TokioFile; +use tokio_stream::StreamExt; +use zip::ZipArchive; + +use reqwest::Client; +use tokio::io::AsyncWriteExt; + +pub fn azure_from_config(config: &AIConfig) -> AzureConfig { + AzureConfig::new() + .with_api_base(&config.endpoint) + .with_api_key(&config.key) + .with_api_version(&config.version) + .with_deployment_id(&config.instance) +} + +pub async fn call_llm( + text: &str, + ai_config: &AIConfig, +) -> Result> { + let azure_config = azure_from_config(&ai_config.clone()); + let open_ai = OpenAI::new(azure_config); + + // Directly use the input text as prompt + let prompt = text.to_string(); + + // Call LLM and return the raw text response + match open_ai.invoke(&prompt).await { + Ok(response_text) => Ok(response_text), + Err(err) => { + error!("Error invoking LLM API: {}", err); + Err(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "Failed to invoke LLM API", + ))) + } + } +} + +pub fn extract_zip_recursive( + zip_path: &Path, + destination_path: &Path, +) -> Result<(), Box> { + let file = File::open(zip_path)?; + let buf_reader = BufReader::new(file); + let mut archive = ZipArchive::new(buf_reader)?; + + for i in 0..archive.len() { + let mut file = archive.by_index(i)?; + let outpath = destination_path.join(file.mangled_name()); + + if file.is_dir() { + std::fs::create_dir_all(&outpath)?; + } else { + if let Some(parent) = outpath.parent() { + if !parent.exists() { + std::fs::create_dir_all(&parent)?; + } + } + let mut outfile = File::create(&outpath)?; + std::io::copy(&mut file, &mut outfile)?; + } + } + + Ok(()) +} +pub fn row_to_json(row: PgRow) -> Result> { + let mut result = serde_json::Map::new(); + let columns = row.columns(); + debug!("Converting row with {} columns", columns.len()); + + for (i, column) in columns.iter().enumerate() { + let column_name = column.name(); + let type_name = column.type_info().name(); + + let value = match type_name { + "INT4" | "int4" => handle_nullable_type::(&row, i, column_name), + "INT8" | "int8" => handle_nullable_type::(&row, i, column_name), + "FLOAT4" | "float4" => handle_nullable_type::(&row, i, column_name), + "FLOAT8" | "float8" => handle_nullable_type::(&row, i, column_name), + "TEXT" | "VARCHAR" | "text" | "varchar" => { + handle_nullable_type::(&row, i, column_name) + } + "BOOL" | "bool" => handle_nullable_type::(&row, i, column_name), + "JSON" | "JSONB" | "json" | "jsonb" => handle_json(&row, i, column_name), + _ => { + warn!("Unknown type {} for column {}", type_name, column_name); + handle_nullable_type::(&row, i, column_name) + } + }; + + result.insert(column_name.to_string(), value); + } + + Ok(Value::Object(result)) +} + +fn handle_nullable_type<'r, T>(row: &'r PgRow, idx: usize, col_name: &str) -> Value +where + T: Type + Decode<'r, sqlx::Postgres> + serde::Serialize + std::fmt::Debug, +{ + match row.try_get::, _>(idx) { + Ok(Some(val)) => { + debug!("Successfully read column {} as {:?}", col_name, val); + json!(val) + } + Ok(None) => { + debug!("Column {} is NULL", col_name); + Value::Null + } + Err(e) => { + warn!("Failed to read column {}: {}", col_name, e); + Value::Null + } + } +} + +fn handle_json(row: &PgRow, idx: usize, col_name: &str) -> Value { + // First try to get as Option + match row.try_get::, _>(idx) { + Ok(Some(val)) => { + debug!("Successfully read JSON column {} as Value", col_name); + return val; + } + Ok(None) => return Value::Null, + Err(_) => (), // Fall through to other attempts + } + + // Try as Option that might contain JSON + match row.try_get::, _>(idx) { + Ok(Some(s)) => match serde_json::from_str(&s) { + Ok(val) => val, + Err(_) => { + debug!("Column {} contains string that's not JSON", col_name); + json!(s) + } + }, + Ok(None) => Value::Null, + Err(e) => { + warn!("Failed to read JSON column {}: {}", col_name, e); + Value::Null + } + } +} +pub fn json_value_to_dynamic(value: &Value) -> Dynamic { + match value { + Value::Null => Dynamic::UNIT, + Value::Bool(b) => Dynamic::from(*b), + Value::Number(n) => { + if let Some(i) = n.as_i64() { + Dynamic::from(i) + } else if let Some(f) = n.as_f64() { + Dynamic::from(f) + } else { + Dynamic::UNIT + } + } + Value::String(s) => Dynamic::from(s.clone()), + Value::Array(arr) => Dynamic::from( + arr.iter() + .map(json_value_to_dynamic) + .collect::(), + ), + Value::Object(obj) => Dynamic::from( + obj.iter() + .map(|(k, v)| (SmartString::from(k), json_value_to_dynamic(v))) + .collect::(), + ), + } +} + +/// Converts any value to an array - single values become single-element arrays +pub fn to_array(value: Dynamic) -> Array { + if value.is_array() { + // Already an array - return as-is + value.cast::() + } else if value.is_unit() || value.is::<()>() { + // Handle empty/unit case + Array::new() + } else { + // Convert single value to single-element array + Array::from([value]) + } +} + +pub async fn download_file(url: &str, output_path: &str) -> Result<(), Box> { + let client = Client::new(); + let response = client.get(url).send().await?; + + if response.status().is_success() { + let mut file = TokioFile::create(output_path).await?; + + let mut stream = response.bytes_stream(); + + while let Some(chunk) = stream.next().await { + file.write_all(&chunk?).await?; + } + debug!("File downloaded successfully to {}", output_path); + } else { + return Err("Failed to download file".into()); + } + + Ok(()) +} + +// Helper function to parse the filter string into SQL WHERE clause and parameters +pub fn parse_filter(filter_str: &str) -> Result<(String, Vec), Box> { + let parts: Vec<&str> = filter_str.split('=').collect(); + if parts.len() != 2 { + return Err("Invalid filter format. Expected 'KEY=VALUE'".into()); + } + + let column = parts[0].trim(); + let value = parts[1].trim(); + + // Validate column name to prevent SQL injection + if !column + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '_') + { + return Err("Invalid column name in filter".into()); + } + + // Return the parameterized query part and the value separately + Ok((format!("{} = $1", column), vec![value.to_string()])) +} + +// Parse filter without adding quotes +pub fn parse_filter_with_offset( + filter_str: &str, + offset: usize, +) -> Result<(String, Vec), Box> { + let mut clauses = Vec::new(); + let mut params = Vec::new(); + + for (i, condition) in filter_str.split('&').enumerate() { + let parts: Vec<&str> = condition.split('=').collect(); + if parts.len() != 2 { + return Err("Invalid filter format".into()); + } + + let column = parts[0].trim(); + let value = parts[1].trim(); + + if !column + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '_') + { + return Err("Invalid column name".into()); + } + + clauses.push(format!("{} = ${}", column, i + 1 + offset)); + params.push(value.to_string()); // Store raw value without quotes + } + + Ok((clauses.join(" AND "), params)) +} + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, Serialize, Deserialize, FromRow)] +pub struct organization { + pub org_id: Uuid, + pub name: String, + pub slug: String, + pub created_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize, FromRow)] +pub struct Bot { + pub bot_id: Uuid, + pub name: String, + pub status: BotStatus, + pub config: serde_json::Value, + pub created_at: chrono::DateTime, + pub updated_at: chrono::DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::Type)] +#[serde(rename_all = "snake_case")] +#[sqlx(type_name = "bot_status", rename_all = "snake_case")] +pub enum BotStatus { + Active, + Inactive, + Maintenance, +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum TriggerKind { + Scheduled = 0, + TableUpdate = 1, + TableInsert = 2, + TableDelete = 3, +} + +impl TriggerKind { + pub fn from_i32(value: i32) -> Option { + match value { + 0 => Some(Self::Scheduled), + 1 => Some(Self::TableUpdate), + 2 => Some(Self::TableInsert), + 3 => Some(Self::TableDelete), + _ => None, + } + } +} + +#[derive(Debug, FromRow, Serialize, Deserialize)] +pub struct Automation { + pub id: Uuid, + pub kind: i32, // Using number for trigger type + pub target: Option, + pub schedule: Option, + pub param: String, + pub is_active: bool, + pub last_triggered: Option>, +} + +pub mod models; +pub mod state; +pub mod utils; + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, Serialize, Deserialize, FromRow)] +pub struct UserSession { + pub id: Uuid, + pub user_id: Uuid, + pub bot_id: Uuid, + pub title: String, + pub context_data: serde_json::Value, + pub answer_mode: String, + pub current_tool: Option, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EmbeddingRequest { + pub text: String, + pub model: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EmbeddingResponse { + pub embedding: Vec, + pub model: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SearchResult { + pub text: String, + pub similarity: f32, + pub metadata: serde_json::Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UserMessage { + pub bot_id: String, + pub user_id: String, + pub session_id: String, + pub channel: String, + pub content: String, + pub message_type: String, + pub media_url: Option, + pub timestamp: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BotResponse { + pub bot_id: String, + pub user_id: String, + pub session_id: String, + pub channel: String, + pub content: String, + pub message_type: String, + pub stream_token: Option, + pub is_complete: bool, +} + +use crate::{shared::state::AppState, whatsapp}; +use actix_web::{web, HttpRequest, HttpResponse, Result}; +use actix_ws::Message as WsMessage; +use chrono::Utc; +use langchain_rust::{ + chain::{Chain, LLMChain}, + llm::openai::OpenAI, + memory::SimpleMemory, + prompt_args, + tools::{postgres::PostgreSQLEngine, SQLDatabaseBuilder}, + vectorstore::qdrant::Qdrant as LangChainQdrant, + vectorstore::{VecStoreOptions, VectorStore}, +}; +use log::info; +use serde_json; +use std::collections::HashMap; +use std::fs; +use std::sync::Arc; +use tokio::sync::{mpsc, Mutex}; +use uuid::Uuid; + +use crate::channels::{ChannelAdapter, VoiceAdapter, WebChannelAdapter}; +use crate::chart::ChartGenerator; +use crate::llm::LLMProvider; +use crate::session::SessionManager; +use crate::tools::ToolManager; +use crate::whatsapp::WhatsAppAdapter; +use crate::{ + auth::AuthService, + shared::{BotResponse, UserMessage, UserSession}, +}; +pub struct BotOrchestrator { + session_manager: SessionManager, + tool_manager: ToolManager, + llm_provider: Arc, + auth_service: AuthService, + channels: HashMap>, + response_channels: Arc>>>, + chart_generator: Option>, + vector_store: Option>, + sql_chain: Option>, +} + +impl BotOrchestrator { + pub fn new( + session_manager: SessionManager, + tool_manager: ToolManager, + llm_provider: Arc, + auth_service: AuthService, + chart_generator: Option>, + vector_store: Option>, + sql_chain: Option>, + ) -> Self { + Self { + session_manager, + tool_manager, + llm_provider, + auth_service, + channels: HashMap::new(), + response_channels: Arc::new(Mutex::new(HashMap::new())), + chart_generator, + vector_store, + sql_chain, + } + } + + pub fn add_channel(&mut self, channel_type: &str, adapter: Arc) { + self.channels.insert(channel_type.to_string(), adapter); + } + + pub async fn register_response_channel( + &self, + session_id: String, + sender: mpsc::Sender, + ) { + self.response_channels + .lock() + .await + .insert(session_id, sender); + } + + pub async fn set_user_answer_mode( + &self, + user_id: &str, + bot_id: &str, + mode: &str, + ) -> Result<(), Box> { + self.session_manager + .update_answer_mode(user_id, bot_id, mode) + .await?; + Ok(()) + } + + pub async fn process_message( + &self, + message: UserMessage, + ) -> Result<(), Box> { + info!( + "Processing message from channel: {}, user: {}", + message.channel, message.user_id + ); + + let user_id = Uuid::parse_str(&message.user_id).unwrap_or_else(|_| Uuid::new_v4()); + let bot_id = Uuid::parse_str(&message.bot_id) + .unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap()); + + let session = match self + .session_manager + .get_user_session(user_id, bot_id) + .await? + { + Some(session) => session, + None => { + self.session_manager + .create_session(user_id, bot_id, "New Conversation") + .await? + } + }; + + if session.answer_mode == "tool" && session.current_tool.is_some() { + self.tool_manager + .provide_user_response(&message.user_id, &message.bot_id, message.content.clone()) + .await?; + return Ok(()); + } + + self.session_manager + .save_message( + session.id, + user_id, + "user", + &message.content, + &message.message_type, + ) + .await?; + + let response_content = match session.answer_mode.as_str() { + "document" => self.document_mode_handler(&message, &session).await?, + "chart" => self.chart_mode_handler(&message, &session).await?, + "database" => self.database_mode_handler(&message, &session).await?, + "tool" => self.tool_mode_handler(&message, &session).await?, + _ => self.direct_mode_handler(&message, &session).await?, + }; + + self.session_manager + .save_message(session.id, user_id, "assistant", &response_content, "text") + .await?; + + let bot_response = BotResponse { + bot_id: message.bot_id, + user_id: message.user_id, + session_id: message.session_id, + channel: message.channel, + content: response_content, + message_type: "text".to_string(), + stream_token: None, + is_complete: true, + }; + + if let Some(adapter) = self.channels.get(&message.channel) { + adapter.send_message(bot_response).await?; + } + + Ok(()) + } + + async fn document_mode_handler( + &self, + message: &UserMessage, + session: &UserSession, + ) -> Result> { + if let Some(vector_store) = &self.vector_store { + let similar_docs = vector_store + .similarity_search(&message.content, 3, &VecStoreOptions::default()) + .await?; + + let mut enhanced_prompt = format!("User question: {}\n\n", message.content); + + if !similar_docs.is_empty() { + enhanced_prompt.push_str("Relevant documents:\n"); + for (i, doc) in similar_docs.iter().enumerate() { + enhanced_prompt.push_str(&format!("[Doc {}]: {}\n", i + 1, doc.page_content)); + } + enhanced_prompt.push_str( + "\nPlease answer the user's question based on the provided documents.", + ); + } + + self.llm_provider + .generate(&enhanced_prompt, &serde_json::Value::Null) + .await + } else { + self.direct_mode_handler(message, session).await + } + } + + async fn chart_mode_handler( + &self, + message: &UserMessage, + session: &UserSession, + ) -> Result> { + if let Some(chart_generator) = &self.chart_generator { + let chart_response = chart_generator + .generate_chart(&message.content, "bar") + .await?; + + self.session_manager + .save_message( + session.id, + session.user_id, + "system", + &format!("Generated chart for query: {}", message.content), + "chart", + ) + .await?; + + Ok(format!( + "Chart generated for your query. Data retrieved: {}", + chart_response.sql_query + )) + } else { + self.document_mode_handler(message, session).await + } + } + + async fn database_mode_handler( + &self, + message: &UserMessage, + _session: &UserSession, + ) -> Result> { + if let Some(sql_chain) = &self.sql_chain { + let input_variables = prompt_args! { + "input" => message.content, + }; + + let result = sql_chain.invoke(input_variables).await?; + Ok(result.to_string()) + } else { + let db_url = std::env::var("DATABASE_URL")?; + let engine = PostgreSQLEngine::new(&db_url).await?; + let db = SQLDatabaseBuilder::new(engine).build().await?; + + let llm = OpenAI::default(); + let chain = langchain_rust::chain::SQLDatabaseChainBuilder::new() + .llm(llm) + .top_k(5) + .database(db) + .build()?; + + let input_variables = chain.prompt_builder().query(&message.content).build(); + let result = chain.invoke(input_variables).await?; + + Ok(result.to_string()) + } + } + + async fn tool_mode_handler( + &self, + message: &UserMessage, + _session: &UserSession, + ) -> Result> { + if message.content.to_lowercase().contains("calculator") { + if let Some(_adapter) = self.channels.get(&message.channel) { + let (tx, _rx) = mpsc::channel(100); + + self.register_response_channel(message.session_id.clone(), tx.clone()) + .await; + + let tool_manager = self.tool_manager.clone(); + let user_id_str = message.user_id.clone(); + let bot_id_str = message.bot_id.clone(); + let session_manager = self.session_manager.clone(); + + tokio::spawn(async move { + let _ = tool_manager + .execute_tool("calculator", &user_id_str, &bot_id_str, session_manager, tx) + .await; + }); + } + Ok("Starting calculator tool...".to_string()) + } else { + let available_tools = self.tool_manager.list_tools(); + let tools_context = if !available_tools.is_empty() { + format!("\n\nAvailable tools: {}. If the user needs calculations, suggest using the calculator tool.", available_tools.join(", ")) + } else { + String::new() + }; + + let full_prompt = format!("{}{}", message.content, tools_context); + + self.llm_provider + .generate(&full_prompt, &serde_json::Value::Null) + .await + } + } + + async fn direct_mode_handler( + &self, + message: &UserMessage, + session: &UserSession, + ) -> Result> { + let history = self + .session_manager + .get_conversation_history(session.id, session.user_id) + .await?; + + let mut memory = SimpleMemory::new(); + for (role, content) in history { + match role.as_str() { + "user" => memory.add_user_message(&content), + "assistant" => memory.add_ai_message(&content), + _ => {} + } + } + + let mut prompt = String::new(); + if let Some(chat_history) = memory.get_chat_history() { + for message in chat_history { + prompt.push_str(&format!( + "{}: {}\n", + message.message_type(), + message.content() + )); + } + } + prompt.push_str(&format!("User: {}\nAssistant:", message.content)); + + self.llm_provider + .generate(&prompt, &serde_json::Value::Null) + .await + } + + pub async fn stream_response( + &self, + message: UserMessage, + mut response_tx: mpsc::Sender, + ) -> Result<(), Box> { + info!("Streaming response for user: {}", message.user_id); + + let user_id = Uuid::parse_str(&message.user_id).unwrap_or_else(|_| Uuid::new_v4()); + let bot_id = Uuid::parse_str(&message.bot_id) + .unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap()); + + let session = match self + .session_manager + .get_user_session(user_id, bot_id) + .await? + { + Some(session) => session, + None => { + self.session_manager + .create_session(user_id, bot_id, "New Conversation") + .await? + } + }; + + if session.answer_mode == "tool" && session.current_tool.is_some() { + self.tool_manager + .provide_user_response(&message.user_id, &message.bot_id, message.content.clone()) + .await?; + return Ok(()); + } + + self.session_manager + .save_message( + session.id, + user_id, + "user", + &message.content, + &message.message_type, + ) + .await?; + + let history = self + .session_manager + .get_conversation_history(session.id, user_id) + .await?; + + let mut memory = SimpleMemory::new(); + for (role, content) in history { + match role.as_str() { + "user" => memory.add_user_message(&content), + "assistant" => memory.add_ai_message(&content), + _ => {} + } + } + + let mut prompt = String::new(); + if let Some(chat_history) = memory.get_chat_history() { + for message in chat_history { + prompt.push_str(&format!( + "{}: {}\n", + message.message_type(), + message.content() + )); + } + } + prompt.push_str(&format!("User: {}\nAssistant:", message.content)); + + let (stream_tx, mut stream_rx) = mpsc::channel(100); + let llm_provider = self.llm_provider.clone(); + let prompt_clone = prompt.clone(); + + tokio::spawn(async move { + let _ = llm_provider + .generate_stream(&prompt_clone, &serde_json::Value::Null, stream_tx) + .await; + }); + + let mut full_response = String::new(); + while let Some(chunk) = stream_rx.recv().await { + full_response.push_str(&chunk); + + let bot_response = BotResponse { + bot_id: message.bot_id.clone(), + user_id: message.user_id.clone(), + session_id: message.session_id.clone(), + channel: message.channel.clone(), + content: chunk, + message_type: "text".to_string(), + stream_token: None, + is_complete: false, + }; + + if response_tx.send(bot_response).await.is_err() { + break; + } + } + + self.session_manager + .save_message(session.id, user_id, "assistant", &full_response, "text") + .await?; + + let final_response = BotResponse { + bot_id: message.bot_id, + user_id: message.user_id, + session_id: message.session_id, + channel: message.channel, + content: "".to_string(), + message_type: "text".to_string(), + stream_token: None, + is_complete: true, + }; + + response_tx.send(final_response).await?; + Ok(()) + } + + pub async fn get_user_sessions( + &self, + user_id: Uuid, + ) -> Result, Box> { + self.session_manager.get_user_sessions(user_id).await + } + + pub async fn get_conversation_history( + &self, + session_id: Uuid, + user_id: Uuid, + ) -> Result, Box> { + self.session_manager + .get_conversation_history(session_id, user_id) + .await + } + + pub async fn process_message_with_tools( + &self, + message: UserMessage, + ) -> Result<(), Box> { + info!( + "Processing message with tools from user: {}", + message.user_id + ); + + let user_id = Uuid::parse_str(&message.user_id).unwrap_or_else(|_| Uuid::new_v4()); + let bot_id = Uuid::parse_str(&message.bot_id) + .unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap()); + + let session = match self + .session_manager + .get_user_session(user_id, bot_id) + .await? + { + Some(session) => session, + None => { + self.session_manager + .create_session(user_id, bot_id, "New Conversation") + .await? + } + }; + + self.session_manager + .save_message( + session.id, + user_id, + "user", + &message.content, + &message.message_type, + ) + .await?; + + let is_tool_waiting = self + .tool_manager + .is_tool_waiting(&message.session_id) + .await + .unwrap_or(false); + + if is_tool_waiting { + self.tool_manager + .provide_input(&message.session_id, &message.content) + .await?; + + if let Ok(tool_output) = self.tool_manager.get_tool_output(&message.session_id).await { + for output in tool_output { + let bot_response = BotResponse { + bot_id: message.bot_id.clone(), + user_id: message.user_id.clone(), + session_id: message.session_id.clone(), + channel: message.channel.clone(), + content: output, + message_type: "text".to_string(), + stream_token: None, + is_complete: true, + }; + + if let Some(adapter) = self.channels.get(&message.channel) { + adapter.send_message(bot_response).await?; + } + } + } + return Ok(()); + } + + let available_tools = self.tool_manager.list_tools(); + let tools_context = if !available_tools.is_empty() { + format!("\n\nAvailable tools: {}. If the user needs calculations, suggest using the calculator tool.", available_tools.join(", ")) + } else { + String::new() + }; + + let full_prompt = format!("{}{}", message.content, tools_context); + + let response = if message.content.to_lowercase().contains("calculator") + || message.content.to_lowercase().contains("calculate") + || message.content.to_lowercase().contains("math") + { + match self + .tool_manager + .execute_tool("calculator", &message.session_id, &message.user_id) + .await + { + Ok(tool_result) => { + self.session_manager + .save_message( + session.id, + user_id, + "assistant", + &tool_result.output, + "tool_start", + ) + .await?; + + tool_result.output + } + Err(e) => { + format!("I encountered an error starting the calculator: {}", e) + } + } + } else { + self.llm_provider + .generate(&full_prompt, &serde_json::Value::Null) + .await? + }; + + self.session_manager + .save_message(session.id, user_id, "assistant", &response, "text") + .await?; + + let bot_response = BotResponse { + bot_id: message.bot_id, + user_id: message.user_id, + session_id: message.session_id, + channel: message.channel, + content: response, + message_type: "text".to_string(), + stream_token: None, + is_complete: true, + }; + + if let Some(adapter) = self.channels.get(&message.channel) { + adapter.send_message(bot_response).await?; + } + + Ok(()) + } + + async fn tool_mode_handler( + &self, + message: &UserMessage, + _session: &UserSession, + ) -> Result> { + if message.content.to_lowercase().contains("calculator") { + if let Some(_adapter) = self.channels.get(&message.channel) { + let (tx, _rx) = mpsc::channel(100); + + self.register_response_channel(message.session_id.clone(), tx.clone()) + .await; + + let tool_manager = self.tool_manager.clone(); + let user_id_str = message.user_id.clone(); + let bot_id_str = message.bot_id.clone(); + let session_manager = self.session_manager.clone(); + + tokio::spawn(async move { + let _ = tool_manager + .execute_tool_with_session( + "calculator", + &user_id_str, + &bot_id_str, + session_manager, + tx, + ) + .await; + }); + } + Ok("Starting calculator tool...".to_string()) + } else { + let available_tools = self.tool_manager.list_tools(); + let tools_context = if !available_tools.is_empty() { + format!("\n\nAvailable tools: {}. If the user needs calculations, suggest using the calculator tool.", available_tools.join(", ")) + } else { + String::new() + }; + + let full_prompt = format!("{}{}", message.content, tools_context); + + self.llm_provider + .generate(&full_prompt, &serde_json::Value::Null) + .await + } + } + + // Fix the process_message_with_tools method + pub async fn process_message_with_tools( + &self, + message: UserMessage, + ) -> Result<(), Box> { + info!( + "Processing message with tools from user: {}", + message.user_id + ); + + let user_id = Uuid::parse_str(&message.user_id).unwrap_or_else(|_| Uuid::new_v4()); + let bot_id = Uuid::parse_str(&message.bot_id) + .unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap()); + + let session = match self + .session_manager + .get_user_session(user_id, bot_id) + .await? + { + Some(session) => session, + None => { + self.session_manager + .create_session(user_id, bot_id, "New Conversation") + .await? + } + }; + + self.session_manager + .save_message( + session.id, + user_id, + "user", + &message.content, + &message.message_type, + ) + .await?; + + let is_tool_waiting = self + .tool_manager + .is_tool_waiting(&message.session_id) + .await + .unwrap_or(false); + + if is_tool_waiting { + self.tool_manager + .provide_input(&message.session_id, &message.content) + .await?; + + if let Ok(tool_output) = self.tool_manager.get_tool_output(&message.session_id).await { + for output in tool_output { + let bot_response = BotResponse { + bot_id: message.bot_id.clone(), + user_id: message.user_id.clone(), + session_id: message.session_id.clone(), + channel: message.channel.clone(), + content: output, + message_type: "text".to_string(), + stream_token: None, + is_complete: true, + }; + + if let Some(adapter) = self.channels.get(&message.channel) { + adapter.send_message(bot_response).await?; + } + } + } + return Ok(()); + } + + let response = if message.content.to_lowercase().contains("calculator") + || message.content.to_lowercase().contains("calculate") + || message.content.to_lowercase().contains("math") + { + match self + .tool_manager + .execute_tool("calculator", &message.session_id, &message.user_id) + .await + { + Ok(tool_result) => { + self.session_manager + .save_message( + session.id, + user_id, + "assistant", + &tool_result.output, + "tool_start", + ) + .await?; + + tool_result.output + } + Err(e) => { + format!("I encountered an error starting the calculator: {}", e) + } + } + } else { + let available_tools = self.tool_manager.list_tools(); + let tools_context = if !available_tools.is_empty() { + format!("\n\nAvailable tools: {}. If the user needs calculations, suggest using the calculator tool.", available_tools.join(", ")) + } else { + String::new() + }; + + let full_prompt = format!("{}{}", message.content, tools_context); + + self.llm_provider + .generate(&full_prompt, &serde_json::Value::Null) + .await? + }; + + self.session_manager + .save_message(session.id, user_id, "assistant", &response, "text") + .await?; + + let bot_response = BotResponse { + bot_id: message.bot_id, + user_id: message.user_id, + session_id: message.session_id, + channel: message.channel, + content: response, + message_type: "text".to_string(), + stream_token: None, + is_complete: true, + }; + + if let Some(adapter) = self.channels.get(&message.channel) { + adapter.send_message(bot_response).await?; + } + + Ok(()) + } +} + +#[actix_web::get("/ws")] +async fn websocket_handler( + req: HttpRequest, + stream: web::Payload, + data: web::Data, +) -> Result { + let (res, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?; + let session_id = Uuid::new_v4().to_string(); + let (tx, mut rx) = mpsc::channel::(100); + + data.orchestrator + .register_response_channel(session_id.clone(), tx.clone()) + .await; + data.web_adapter + .add_connection(session_id.clone(), tx.clone()) + .await; + data.voice_adapter + .add_connection(session_id.clone(), tx.clone()) + .await; + + let orchestrator = data.orchestrator.clone(); + let web_adapter = data.web_adapter.clone(); + + actix_web::rt::spawn(async move { + while let Some(msg) = rx.recv().await { + if let Ok(json) = serde_json::to_string(&msg) { + let _ = session.text(json).await; + } + } + }); + + actix_web::rt::spawn(async move { + while let Some(Ok(msg)) = msg_stream.recv().await { + match msg { + WsMessage::Text(text) => { + let user_message = UserMessage { + bot_id: "default_bot".to_string(), + user_id: "default_user".to_string(), + session_id: session_id.clone(), + channel: "web".to_string(), + content: text.to_string(), + message_type: "text".to_string(), + media_url: None, + timestamp: Utc::now(), + }; + + if let Err(e) = orchestrator.stream_response(user_message, tx.clone()).await { + info!("Error processing message: {}", e); + } + } + WsMessage::Close(_) => { + web_adapter.remove_connection(&session_id).await; + break; + } + _ => {} + } + } + }); + + Ok(res) +} + +#[actix_web::get("/api/whatsapp/webhook")] +async fn whatsapp_webhook_verify( + data: web::Data, + web::Query(params): web::Query>, +) -> Result { + let mode = params.get("hub.mode").unwrap_or(&"".to_string()); + let token = params.get("hub.verify_token").unwrap_or(&"".to_string()); + let challenge = params.get("hub.challenge").unwrap_or(&"".to_string()); + + match data.whatsapp_adapter.verify_webhook(mode, token, challenge) { + Ok(challenge_response) => Ok(HttpResponse::Ok().body(challenge_response)), + Err(_) => Ok(HttpResponse::Forbidden().body("Verification failed")), + } +} + +#[actix_web::post("/api/whatsapp/webhook")] +async fn whatsapp_webhook( + data: web::Data, + payload: web::Json, +) -> Result { + match data + .whatsapp_adapter + .process_incoming_message(payload.into_inner()) + .await + { + Ok(user_messages) => { + for user_message in user_messages { + if let Err(e) = data.orchestrator.process_message(user_message).await { + log::error!("Error processing WhatsApp message: {}", e); + } + } + Ok(HttpResponse::Ok().body("")) + } + Err(e) => { + log::error!("Error processing WhatsApp webhook: {}", e); + Ok(HttpResponse::BadRequest().body("Invalid message")) + } + } +} + +#[actix_web::post("/api/voice/start")] +async fn voice_start( + data: web::Data, + info: web::Json, +) -> Result { + let session_id = info + .get("session_id") + .and_then(|s| s.as_str()) + .unwrap_or(""); + let user_id = info + .get("user_id") + .and_then(|u| u.as_str()) + .unwrap_or("user"); + + match data + .voice_adapter + .start_voice_session(session_id, user_id) + .await + { + Ok(token) => { + Ok(HttpResponse::Ok().json(serde_json::json!({"token": token, "status": "started"}))) + } + Err(e) => { + Ok(HttpResponse::InternalServerError() + .json(serde_json::json!({"error": e.to_string()}))) + } + } +} + +#[actix_web::post("/api/voice/stop")] +async fn voice_stop( + data: web::Data, + info: web::Json, +) -> Result { + let session_id = info + .get("session_id") + .and_then(|s| s.as_str()) + .unwrap_or(""); + + match data.voice_adapter.stop_voice_session(session_id).await { + Ok(()) => Ok(HttpResponse::Ok().json(serde_json::json!({"status": "stopped"}))), + Err(e) => { + Ok(HttpResponse::InternalServerError() + .json(serde_json::json!({"error": e.to_string()}))) + } + } +} + +#[actix_web::post("/api/sessions")] +async fn create_session(_data: web::Data) -> Result { + let session_id = Uuid::new_v4(); + Ok(HttpResponse::Ok().json(serde_json::json!({ + "session_id": session_id, + "title": "New Conversation", + "created_at": Utc::now() + }))) +} + +#[actix_web::get("/api/sessions")] +async fn get_sessions(data: web::Data) -> Result { + let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap(); + match data.orchestrator.get_user_sessions(user_id).await { + Ok(sessions) => Ok(HttpResponse::Ok().json(sessions)), + Err(e) => { + Ok(HttpResponse::InternalServerError() + .json(serde_json::json!({"error": e.to_string()}))) + } + } +} + +#[actix_web::get("/api/sessions/{session_id}")] +async fn get_session_history( + data: web::Data, + path: web::Path, +) -> Result { + let session_id = path.into_inner(); + let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap(); + + match Uuid::parse_str(&session_id) { + Ok(session_uuid) => match data + .orchestrator + .get_conversation_history(session_uuid, user_id) + .await + { + Ok(history) => Ok(HttpResponse::Ok().json(history)), + Err(e) => Ok(HttpResponse::InternalServerError() + .json(serde_json::json!({"error": e.to_string()}))), + }, + Err(_) => { + Ok(HttpResponse::BadRequest().json(serde_json::json!({"error": "Invalid session ID"}))) + } + } +} + +#[actix_web::post("/api/set_mode")] +async fn set_mode_handler( + data: web::Data, + info: web::Json>, +) -> Result { + let default_user = "default_user".to_string(); + let default_bot = "default_bot".to_string(); + let default_mode = "direct".to_string(); + + let user_id = info.get("user_id").unwrap_or(&default_user); + let bot_id = info.get("bot_id").unwrap_or(&default_bot); + let mode = info.get("mode").unwrap_or(&default_mode); + + if let Err(e) = data + .orchestrator + .set_user_answer_mode(user_id, bot_id, mode) + .await + { + return Ok( + HttpResponse::InternalServerError().json(serde_json::json!({"error": e.to_string()})) + ); + } + + Ok(HttpResponse::Ok().json(serde_json::json!({"status": "mode_updated"}))) +} + +#[actix_web::get("/")] +async fn index() -> Result { + let html = fs::read_to_string("templates/index.html") + .unwrap_or_else(|_| include_str!("../../static/index.html").to_string()); + Ok(HttpResponse::Ok().content_type("text/html").body(html)) +} + +#[actix_web::get("/static/{filename:.*}")] +async fn static_files(req: HttpRequest) -> Result { + let filename = req.match_info().query("filename"); + let path = format!("static/{}", filename); + + match fs::read(&path) { + Ok(content) => { + let content_type = match filename { + f if f.ends_with(".js") => "application/javascript", + f if f.ends_with(".css") => "text/css", + f if f.ends_with(".png") => "image/png", + f if f.ends_with(".jpg") | f.ends_with(".jpeg") => "image/jpeg", + _ => "text/plain", + }; + + Ok(HttpResponse::Ok().content_type(content_type).body(content)) + } + Err(_) => Ok(HttpResponse::NotFound().body("File not found")), + } +} + +use crate::shared::UserSession; +use redis::{AsyncCommands, Client}; +use serde_json; +use sqlx::{PgPool, Row}; +use std::sync::Arc; +use uuid::Uuid; + +pub struct SessionManager { + pub pool: PgPool, + pub redis: Option>, +} + +impl SessionManager { + pub fn new(pool: PgPool, redis: Option>) -> Self { + Self { pool, redis } + } + + pub async fn get_user_session( + &self, + user_id: Uuid, + bot_id: Uuid, + ) -> Result, Box> { + if let Some(redis_client) = &self.redis { + let mut conn = redis_client.get_multiplexed_async_connection().await?; + let cache_key = format!("session:{}:{}", user_id, bot_id); + let session_json: Option = conn.get(&cache_key).await?; + if let Some(json) = session_json { + if let Ok(session) = serde_json::from_str::(&json) { + return Ok(Some(session)); + } + } + } + + let session = sqlx::query_as::<_, UserSession>( + "SELECT * FROM user_sessions WHERE user_id = $1 AND bot_id = $2 ORDER BY updated_at DESC LIMIT 1", + ) + .bind(user_id) + .bind(bot_id) + .fetch_optional(&self.pool) + .await?; + + if let Some(ref session) = session { + if let Some(redis_client) = &self.redis { + let mut conn = redis_client.get_multiplexed_async_connection().await?; + let cache_key = format!("session:{}:{}", user_id, bot_id); + let session_json = serde_json::to_string(session)?; + let _: () = conn.set_ex(cache_key, session_json, 1800).await?; + } + } + + Ok(session) + } + + pub async fn create_session( + &self, + user_id: Uuid, + bot_id: Uuid, + title: &str, + ) -> Result> { + let session = sqlx::query_as::<_, UserSession>( + "INSERT INTO user_sessions (user_id, bot_id, title) VALUES ($1, $2, $3) RETURNING *", + ) + .bind(user_id) + .bind(bot_id) + .bind(title) + .fetch_one(&self.pool) + .await?; + + if let Some(redis_client) = &self.redis { + let mut conn = redis_client.get_multiplexed_async_connection().await?; + let cache_key = format!("session:{}:{}", user_id, bot_id); + let session_json = serde_json::to_string(&session)?; + let _: () = conn.set_ex(cache_key, session_json, 1800).await?; + } + + Ok(session) + } + + pub async fn save_message( + &self, + session_id: Uuid, + user_id: Uuid, + role: &str, + content: &str, + message_type: &str, + ) -> Result<(), Box> { + let message_count: i64 = + sqlx::query("SELECT COUNT(*) as count FROM message_history WHERE session_id = $1") + .bind(session_id) + .fetch_one(&self.pool) + .await? + .get("count"); + + sqlx::query( + "INSERT INTO message_history (session_id, user_id, role, content_encrypted, message_type, message_index) + VALUES ($1, $2, $3, $4, $5, $6)", + ) + .bind(session_id) + .bind(user_id) + .bind(role) + .bind(content) + .bind(message_type) + .bind(message_count + 1) + .execute(&self.pool) + .await?; + + sqlx::query("UPDATE user_sessions SET updated_at = NOW() WHERE id = $1") + .bind(session_id) + .execute(&self.pool) + .await?; + + if let Some(redis_client) = &self.redis { + if let Some(session_info) = + sqlx::query("SELECT user_id, bot_id FROM user_sessions WHERE id = $1") + .bind(session_id) + .fetch_optional(&self.pool) + .await? + { + let user_id: Uuid = session_info.get("user_id"); + let bot_id: Uuid = session_info.get("bot_id"); + let mut conn = redis_client.get_multiplexed_async_connection().await?; + let cache_key = format!("session:{}:{}", user_id, bot_id); + let _: () = conn.del(cache_key).await?; + } + } + + Ok(()) + } + + pub async fn get_conversation_history( + &self, + session_id: Uuid, + user_id: Uuid, + ) -> Result, Box> { + let messages = sqlx::query( + "SELECT role, content_encrypted FROM message_history + WHERE session_id = $1 AND user_id = $2 + ORDER BY message_index ASC", + ) + .bind(session_id) + .bind(user_id) + .fetch_all(&self.pool) + .await?; + + let history = messages + .into_iter() + .map(|row| (row.get("role"), row.get("content_encrypted"))) + .collect(); + + Ok(history) + } + + pub async fn get_user_sessions( + &self, + user_id: Uuid, + ) -> Result, Box> { + let sessions = sqlx::query_as::<_, UserSession>( + "SELECT * FROM user_sessions WHERE user_id = $1 ORDER BY updated_at DESC", + ) + .bind(user_id) + .fetch_all(&self.pool) + .await?; + Ok(sessions) + } + + pub async fn update_answer_mode( + &self, + user_id: &str, + bot_id: &str, + mode: &str, + ) -> Result<(), Box> { + let user_uuid = Uuid::parse_str(user_id)?; + let bot_uuid = Uuid::parse_str(bot_id)?; + + sqlx::query( + "UPDATE user_sessions + SET answer_mode = $1, updated_at = NOW() + WHERE user_id = $2 AND bot_id = $3", + ) + .bind(mode) + .bind(user_uuid) + .bind(bot_uuid) + .execute(&self.pool) + .await?; + + if let Some(redis_client) = &self.redis { + let mut conn = redis_client.get_multiplexed_async_connection().await?; + let cache_key = format!("session:{}:{}", user_uuid, bot_uuid); + let _: () = conn.del(cache_key).await?; + } + + Ok(()) + } + + pub async fn update_current_tool( + &self, + user_id: &str, + bot_id: &str, + tool_name: Option<&str>, + ) -> Result<(), Box> { + let user_uuid = Uuid::parse_str(user_id)?; + let bot_uuid = Uuid::parse_str(bot_id)?; + + sqlx::query( + "UPDATE user_sessions + SET current_tool = $1, updated_at = NOW() + WHERE user_id = $2 AND bot_id = $3", + ) + .bind(tool_name) + .bind(user_uuid) + .bind(bot_uuid) + .execute(&self.pool) + .await?; + + if let Some(redis_client) = &self.redis { + let mut conn = redis_client.get_multiplexed_async_connection().await?; + let cache_key = format!("session:{}:{}", user_uuid, bot_uuid); + let _: () = conn.del(cache_key).await?; + } + + Ok(()) + } + + pub async fn get_session_by_id( + &self, + session_id: Uuid, + ) -> Result, Box> { + if let Some(redis_client) = &self.redis { + let mut conn = redis_client.get_multiplexed_async_connection().await?; + let cache_key = format!("session_by_id:{}", session_id); + let session_json: Option = conn.get(&cache_key).await?; + if let Some(json) = session_json { + if let Ok(session) = serde_json::from_str::(&json) { + return Ok(Some(session)); + } + } + } + + let session = sqlx::query_as::<_, UserSession>("SELECT * FROM user_sessions WHERE id = $1") + .bind(session_id) + .fetch_optional(&self.pool) + .await?; + + if let Some(ref session) = session { + if let Some(redis_client) = &self.redis { + let mut conn = redis_client.get_multiplexed_async_connection().await?; + let cache_key = format!("session_by_id:{}", session_id); + let session_json = serde_json::to_string(session)?; + let _: () = conn.set_ex(cache_key, session_json, 1800).await?; + } + } + + Ok(session) + } + + pub async fn cleanup_old_sessions( + &self, + days_old: i32, + ) -> Result> { + let result = sqlx::query( + "DELETE FROM user_sessions + WHERE updated_at < NOW() - INTERVAL '1 day' * $1", + ) + .bind(days_old) + .execute(&self.pool) + .await?; + Ok(result.rows_affected()) + } + + pub async fn set_current_tool( + &self, + user_id: &str, + bot_id: &str, + tool_name: Option, + ) -> Result<(), Box> { + let user_uuid = Uuid::parse_str(user_id)?; + let bot_uuid = Uuid::parse_str(bot_id)?; + + sqlx::query( + "UPDATE user_sessions + SET current_tool = $1, updated_at = NOW() + WHERE user_id = $2 AND bot_id = $3", + ) + .bind(tool_name) + .bind(user_uuid) + .bind(bot_uuid) + .execute(&self.pool) + .await?; + + if let Some(redis_client) = &self.redis { + let mut conn = redis_client.get_multiplexed_async_connection().await?; + let cache_key = format!("session:{}:{}", user_uuid, bot_uuid); + let _: () = conn.del(cache_key).await?; + } + + Ok(()) + } +} + +impl Clone for SessionManager { + fn clone(&self) -> Self { + Self { + pool: self.pool.clone(), + redis: self.redis.clone(), + } + } +} + +// src/tools/mod.rs +use async_trait::async_trait; +use redis::AsyncCommands; +use rhai::{Engine, Scope}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::{mpsc, Mutex}; +use uuid::Uuid; + +use crate::channels::ChannelAdapter; +use crate::session::SessionManager; +use crate::shared::BotResponse; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolResult { + pub success: bool, + pub output: String, + pub requires_input: bool, + pub session_id: String, +} + +#[derive(Clone)] +pub struct Tool { + pub name: String, + pub description: String, + pub parameters: HashMap, + pub script: String, +} + +#[async_trait] +pub trait ToolExecutor: Send + Sync { + async fn execute( + &self, + tool_name: &str, + session_id: &str, + user_id: &str, + ) -> Result>; + async fn provide_input( + &self, + session_id: &str, + input: &str, + ) -> Result<(), Box>; + async fn get_output(&self, session_id: &str) + -> Result, Box>; + async fn is_waiting_for_input( + &self, + session_id: &str, + ) -> Result>; +} + +pub struct RedisToolExecutor { + redis_client: redis::Client, + web_adapter: Arc, + voice_adapter: Arc, + whatsapp_adapter: Arc, +} + +impl RedisToolExecutor { + pub fn new( + redis_url: &str, + web_adapter: Arc, + voice_adapter: Arc, + whatsapp_adapter: Arc, + ) -> Result> { + let client = redis::Client::open(redis_url)?; + Ok(Self { + redis_client: client, + web_adapter, + voice_adapter, + whatsapp_adapter, + }) + } + + async fn send_tool_message( + &self, + session_id: &str, + user_id: &str, + channel: &str, + message: &str, + ) -> Result<(), Box> { + let response = BotResponse { + bot_id: "tool_bot".to_string(), + user_id: user_id.to_string(), + session_id: session_id.to_string(), + channel: channel.to_string(), + content: message.to_string(), + message_type: "tool".to_string(), + stream_token: None, + is_complete: true, + }; + + match channel { + "web" => self.web_adapter.send_message(response).await, + "voice" => self.voice_adapter.send_message(response).await, + "whatsapp" => self.whatsapp_adapter.send_message(response).await, + _ => Ok(()), + } + } + + fn create_rhai_engine(&self, session_id: String, user_id: String, channel: String) -> Engine { + let mut engine = Engine::new(); + + // Clone for TALK function + let tool_executor = Arc::new(( + self.redis_client.clone(), + self.web_adapter.clone(), + self.voice_adapter.clone(), + self.whatsapp_adapter.clone(), + )); + + let session_id_clone = session_id.clone(); + let user_id_clone = user_id.clone(); + let channel_clone = channel.clone(); + + engine.register_fn("talk", move |message: String| { + let tool_executor = Arc::clone(&tool_executor); + let session_id = session_id_clone.clone(); + let user_id = user_id_clone.clone(); + let channel = channel_clone.clone(); + + tokio::spawn(async move { + let (redis_client, web_adapter, voice_adapter, whatsapp_adapter) = &*tool_executor; + + let response = BotResponse { + bot_id: "tool_bot".to_string(), + user_id: user_id.clone(), + session_id: session_id.clone(), + channel: channel.clone(), + content: message.clone(), + message_type: "tool".to_string(), + stream_token: None, + is_complete: true, + }; + + let result = match channel.as_str() { + "web" => web_adapter.send_message(response).await, + "voice" => voice_adapter.send_message(response).await, + "whatsapp" => whatsapp_adapter.send_message(response).await, + _ => Ok(()), + }; + + if let Err(e) = result { + log::error!("Failed to send tool message: {}", e); + } + + if let Ok(mut conn) = redis_client.get_async_connection().await { + let output_key = format!("tool:{}:output", session_id); + let _ = conn.lpush(&output_key, &message).await; + } + }); + }); + + let hear_executor = self.redis_client.clone(); + let session_id_clone = session_id.clone(); + + engine.register_fn("hear", move || -> String { + let hear_executor = hear_executor.clone(); + let session_id = session_id_clone.clone(); + + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async move { + match hear_executor.get_async_connection().await { + Ok(mut conn) => { + let input_key = format!("tool:{}:input", session_id); + let waiting_key = format!("tool:{}:waiting", session_id); + + let _ = conn.set_ex(&waiting_key, "true", 300).await; + let result: Option<(String, String)> = + conn.brpop(&input_key, 30).await.ok().flatten(); + let _ = conn.del(&waiting_key).await; + + result + .map(|(_, input)| input) + .unwrap_or_else(|| "timeout".to_string()) + } + Err(e) => { + log::error!("HEAR Redis error: {}", e); + "error".to_string() + } + } + }) + }); + + engine + } + + async fn cleanup_session(&self, session_id: &str) -> Result<(), Box> { + let mut conn = self.redis_client.get_multiplexed_async_connection().await?; + + let keys = vec![ + format!("tool:{}:output", session_id), + format!("tool:{}:input", session_id), + format!("tool:{}:waiting", session_id), + format!("tool:{}:active", session_id), + ]; + + for key in keys { + let _: () = conn.del(&key).await?; + } + + Ok(()) + } +} + +#[async_trait] +impl ToolExecutor for RedisToolExecutor { + async fn execute( + &self, + tool_name: &str, + session_id: &str, + user_id: &str, + ) -> Result> { + let tool = get_tool(tool_name).ok_or_else(|| format!("Tool not found: {}", tool_name))?; + + let mut conn = self.redis_client.get_multiplexed_async_connection().await?; + let session_key = format!("tool:{}:session", session_id); + let session_data = serde_json::json!({ + "user_id": user_id, + "tool_name": tool_name, + "started_at": chrono::Utc::now().to_rfc3339(), + }); + conn.set_ex(&session_key, session_data.to_string(), 3600) + .await?; + + let active_key = format!("tool:{}:active", session_id); + conn.set_ex(&active_key, "true", 3600).await?; + + let channel = "web"; + let _engine = self.create_rhai_engine( + session_id.to_string(), + user_id.to_string(), + channel.to_string(), + ); + + let redis_clone = self.redis_client.clone(); + let web_adapter_clone = self.web_adapter.clone(); + let voice_adapter_clone = self.voice_adapter.clone(); + let whatsapp_adapter_clone = self.whatsapp_adapter.clone(); + let session_id_clone = session_id.to_string(); + let user_id_clone = user_id.to_string(); + let tool_script = tool.script.clone(); + + tokio::spawn(async move { + let mut engine = Engine::new(); + let mut scope = Scope::new(); + + let redis_client = redis_clone.clone(); + let web_adapter = web_adapter_clone.clone(); + let voice_adapter = voice_adapter_clone.clone(); + let whatsapp_adapter = whatsapp_adapter_clone.clone(); + let session_id = session_id_clone.clone(); + let user_id = user_id_clone.clone(); + + engine.register_fn("talk", move |message: String| { + let redis_client = redis_client.clone(); + let web_adapter = web_adapter.clone(); + let voice_adapter = voice_adapter.clone(); + let whatsapp_adapter = whatsapp_adapter.clone(); + let session_id = session_id.clone(); + let user_id = user_id.clone(); + + tokio::spawn(async move { + let channel = "web"; + + let response = BotResponse { + bot_id: "tool_bot".to_string(), + user_id: user_id.clone(), + session_id: session_id.clone(), + channel: channel.to_string(), + content: message.clone(), + message_type: "tool".to_string(), + stream_token: None, + is_complete: true, + }; + + let send_result = match channel { + "web" => web_adapter.send_message(response).await, + "voice" => voice_adapter.send_message(response).await, + "whatsapp" => whatsapp_adapter.send_message(response).await, + _ => Ok(()), + }; + + if let Err(e) = send_result { + log::error!("Failed to send tool message: {}", e); + } + + if let Ok(mut conn) = redis_client.get_async_connection().await { + let output_key = format!("tool:{}:output", session_id); + let _ = conn.lpush(&output_key, &message).await; + } + }); + }); + + let hear_redis = redis_clone.clone(); + let session_id_hear = session_id.clone(); + engine.register_fn("hear", move || -> String { + let hear_redis = hear_redis.clone(); + let session_id = session_id_hear.clone(); + + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async move { + match hear_redis.get_async_connection().await { + Ok(mut conn) => { + let input_key = format!("tool:{}:input", session_id); + let waiting_key = format!("tool:{}:waiting", session_id); + + let _ = conn.set_ex(&waiting_key, "true", 300).await; + let result: Option<(String, String)> = + conn.brpop(&input_key, 30).await.ok().flatten(); + let _ = conn.del(&waiting_key).await; + + result + .map(|(_, input)| input) + .unwrap_or_else(|| "timeout".to_string()) + } + Err(_) => "error".to_string(), + } + }) + }); + + match engine.eval_with_scope::<()>(&mut scope, &tool_script) { + Ok(_) => { + log::info!( + "Tool {} completed successfully for session {}", + tool_name, + session_id + ); + + let completion_msg = + "🛠️ Tool execution completed. How can I help you with anything else?"; + let response = BotResponse { + bot_id: "tool_bot".to_string(), + user_id: user_id_clone, + session_id: session_id_clone.clone(), + channel: "web".to_string(), + content: completion_msg.to_string(), + message_type: "tool_complete".to_string(), + stream_token: None, + is_complete: true, + }; + + let _ = web_adapter_clone.send_message(response).await; + } + Err(e) => { + log::error!("Tool execution failed: {}", e); + + let error_msg = format!("❌ Tool error: {}", e); + let response = BotResponse { + bot_id: "tool_bot".to_string(), + user_id: user_id_clone, + session_id: session_id_clone.clone(), + channel: "web".to_string(), + content: error_msg, + message_type: "tool_error".to_string(), + stream_token: None, + is_complete: true, + }; + + let _ = web_adapter_clone.send_message(response).await; + } + } + + if let Ok(mut conn) = redis_clone.get_async_connection().await { + let active_key = format!("tool:{}:active", session_id_clone); + let _ = conn.del(&active_key).await; + } + }); + + Ok(ToolResult { + success: true, + output: format!( + "🛠️ Starting {} tool. Please follow the tool's instructions.", + tool_name + ), + requires_input: true, + session_id: session_id.to_string(), + }) + } + + async fn provide_input( + &self, + session_id: &str, + input: &str, + ) -> Result<(), Box> { + let mut conn = self.redis_client.get_multiplexed_async_connection().await?; + let input_key = format!("tool:{}:input", session_id); + conn.lpush(&input_key, input).await?; + Ok(()) + } + + async fn get_output( + &self, + session_id: &str, + ) -> Result, Box> { + let mut conn = self.redis_client.get_multiplexed_async_connection().await?; + let output_key = format!("tool:{}:output", session_id); + let messages: Vec = conn.lrange(&output_key, 0, -1).await?; + let _: () = conn.del(&output_key).await?; + Ok(messages) + } + + async fn is_waiting_for_input( + &self, + session_id: &str, + ) -> Result> { + let mut conn = self.redis_client.get_multiplexed_async_connection().await?; + let waiting_key = format!("tool:{}:waiting", session_id); + let exists: bool = conn.exists(&waiting_key).await?; + Ok(exists) + } +} + +fn get_tool(name: &str) -> Option { + match name { + "calculator" => Some(Tool { + name: "calculator".to_string(), + description: "Perform mathematical calculations".to_string(), + parameters: HashMap::from([ + ("operation".to_string(), "add|subtract|multiply|divide".to_string()), + ("a".to_string(), "number".to_string()), + ("b".to_string(), "number".to_string()), + ]), + script: r#" + let TALK = |message| { + talk(message); + }; + + let HEAR = || { + hear() + }; + + TALK("🔢 Calculator started!"); + TALK("Please enter the first number:"); + let a = HEAR(); + TALK("Please enter the second number:"); + let b = HEAR(); + TALK("Choose operation: add, subtract, multiply, or divide:"); + let op = HEAR(); + + let num_a = a.to_float(); + let num_b = b.to_float(); + + if op == "add" { + let result = num_a + num_b; + TALK("✅ Result: " + a + " + " + b + " = " + result); + } else if op == "subtract" { + let result = num_a - num_b; + TALK("✅ Result: " + a + " - " + b + " = " + result); + } else if op == "multiply" { + let result = num_a * num_b; + TALK("✅ Result: " + a + " × " + b + " = " + result); + } else if op == "divide" { + if num_b != 0.0 { + let result = num_a / num_b; + TALK("✅ Result: " + a + " ÷ " + b + " = " + result); + } else { + TALK("❌ Error: Cannot divide by zero!"); + } + } else { + TALK("❌ Error: Invalid operation. Please use: add, subtract, multiply, or divide"); + } + + TALK("Calculator session completed. Thank you!"); + "#.to_string(), + }), + _ => None, + } +} + +#[derive(Clone)] +pub struct ToolManager { + tools: HashMap, + waiting_responses: Arc>>>, +} + +impl ToolManager { + pub fn new() -> Self { + let mut tools = HashMap::new(); + + let calculator_tool = Tool { + name: "calculator".to_string(), + description: "Perform calculations".to_string(), + parameters: HashMap::from([ + ( + "operation".to_string(), + "add|subtract|multiply|divide".to_string(), + ), + ("a".to_string(), "number".to_string()), + ("b".to_string(), "number".to_string()), + ]), + script: r#" + TALK("Calculator started. Enter first number:"); + let a = HEAR(); + TALK("Enter second number:"); + let b = HEAR(); + TALK("Operation (add/subtract/multiply/divide):"); + let op = HEAR(); + + let num_a = a.parse::().unwrap(); + let num_b = b.parse::().unwrap(); + let result = if op == "add" { + num_a + num_b + } else if op == "subtract" { + num_a - num_b + } else if op == "multiply" { + num_a * num_b + } else if op == "divide" { + if num_b == 0.0 { + TALK("Cannot divide by zero"); + return; + } + num_a / num_b + } else { + TALK("Invalid operation"); + return; + }; + TALK("Result: ".to_string() + &result.to_string()); + "# + .to_string(), + }; + + tools.insert(calculator_tool.name.clone(), calculator_tool); + Self { + tools, + waiting_responses: Arc::new(Mutex::new(HashMap::new())), + } + } + + pub fn get_tool(&self, name: &str) -> Option<&Tool> { + self.tools.get(name) + } + + pub fn list_tools(&self) -> Vec { + self.tools.keys().cloned().collect() + } + + pub async fn execute_tool( + &self, + tool_name: &str, + session_id: &str, + user_id: &str, + ) -> Result> { + let tool = self.get_tool(tool_name).ok_or("Tool not found")?; + + Ok(ToolResult { + success: true, + output: format!("Tool {} started for user {}", tool_name, user_id), + requires_input: true, + session_id: session_id.to_string(), + }) + } + + pub async fn is_tool_waiting( + &self, + session_id: &str, + ) -> Result> { + let waiting = self.waiting_responses.lock().await; + Ok(waiting.contains_key(session_id)) + } + + pub async fn provide_input( + &self, + session_id: &str, + input: &str, + ) -> Result<(), Box> { + self.provide_user_response(session_id, "default_bot", input.to_string()) + .await + } + + pub async fn get_tool_output( + &self, + session_id: &str, + ) -> Result, Box> { + Ok(vec![]) + } + + pub async fn execute_tool_with_session( + &self, + tool_name: &str, + user_id: &str, + bot_id: &str, + session_manager: SessionManager, + channel_sender: mpsc::Sender, + ) -> Result<(), Box> { + let tool = self.get_tool(tool_name).ok_or("Tool not found")?; + session_manager + .set_current_tool(user_id, bot_id, Some(tool_name.to_string())) + .await?; + + let user_id = user_id.to_string(); + let bot_id = bot_id.to_string(); + let script = tool.script.clone(); + let session_manager_clone = session_manager.clone(); + let waiting_responses = self.waiting_responses.clone(); + + tokio::spawn(async move { + let mut engine = rhai::Engine::new(); + let (talk_tx, mut talk_rx) = mpsc::channel(100); + let (hear_tx, mut hear_rx) = mpsc::channel(100); + + { + let key = format!("{}:{}", user_id, bot_id); + let mut waiting = waiting_responses.lock().await; + waiting.insert(key, hear_tx); + } + + let channel_sender_clone = channel_sender.clone(); + let user_id_clone = user_id.clone(); + let bot_id_clone = bot_id.clone(); + + let talk_tx_clone = talk_tx.clone(); + engine.register_fn("TALK", move |message: String| { + let tx = talk_tx_clone.clone(); + tokio::spawn(async move { + let _ = tx.send(message).await; + }); + }); + + let hear_rx_mutex = Arc::new(Mutex::new(hear_rx)); + engine.register_fn("HEAR", move || { + let hear_rx = hear_rx_mutex.clone(); + tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(async move { + let mut receiver = hear_rx.lock().await; + receiver.recv().await.unwrap_or_default() + }) + }) + }); + + let script_result = + tokio::task::spawn_blocking(move || engine.eval::<()>(&script)).await; + + if let Ok(Err(e)) = script_result { + let error_response = BotResponse { + bot_id: bot_id_clone.clone(), + user_id: user_id_clone.clone(), + session_id: Uuid::new_v4().to_string(), + channel: "test".to_string(), + content: format!("Tool error: {}", e), + message_type: "text".to_string(), + stream_token: None, + is_complete: true, + }; + let _ = channel_sender_clone.send(error_response).await; + } + + while let Some(message) = talk_rx.recv().await { + let response = BotResponse { + bot_id: bot_id.clone(), + user_id: user_id.clone(), + session_id: Uuid::new_v4().to_string(), + channel: "test".to_string(), + content: message, + message_type: "text".to_string(), + stream_token: None, + is_complete: true, + }; + let _ = channel_sender.send(response).await; + } + + let _ = session_manager_clone + .set_current_tool(&user_id, &bot_id, None) + .await; + }); + + Ok(()) + } + + pub async fn provide_user_response( + &self, + user_id: &str, + bot_id: &str, + response: String, + ) -> Result<(), Box> { + let key = format!("{}:{}", user_id, bot_id); + let mut waiting = self.waiting_responses.lock().await; + if let Some(tx) = waiting.get_mut(&key) { + let _ = tx.send(response).await; + waiting.remove(&key); + } + Ok(()) + } +} + +impl Default for ToolManager { + fn default() -> Self { + Self::new() + } +} + +. +└── src + ├── auth + │   └── mod.rs + ├── automation + │   └── mod.rs + ├── basic + │   ├── keywords + │   │   ├── create_draft.rs + │   │   ├── create_site.rs + │   │   ├── find.rs + │   │   ├── first.rs + │   │   ├── format.rs + │   │   ├── for_next.rs + │   │   ├── get.rs + │   │   ├── get_website.rs + │   │   ├── last.rs + │   │   ├── llm_keyword.rs + │   │   ├── mod.rs + │   │   ├── on.rs + │   │   ├── print.rs + │   │   ├── set.rs + │   │   ├── set_schedule.rs + │   │   └── wait.rs + │   └── mod.rs + ├── bot + │   └── mod.rs + ├── channels + │   └── mod.rs + ├── chart + │   └── mod.rs + ├── config + │   └── mod.rs + ├── context + │   └── mod.rs + ├── email + │   └── mod.rs + ├── file + │   └── mod.rs + ├── llm + │   ├── llm_generic.rs + │   ├── llm_local.rs + │   ├── llm_provider.rs + │   ├── llm.rs + │   └── mod.rs + ├── main.rs + ├── org + │   └── mod.rs + ├── session + │   └── mod.rs + ├── shared + │   ├── models.rs + │   ├── mod.rs + │   ├── state.rs + │   └── utils.rs + ├── tests + │   ├── integration_email_list.rs + │   ├── integration_file_list_test.rs + │   └── integration_file_upload_test.rs + ├── tools + │   └── mod.rs + ├── web_automation + │   └── mod.rs + └── whatsapp + └── mod.rs + +21 directories, 44 files diff --git a/scripts/dev/llm_fix.sh b/scripts/dev/llm_fix.sh new file mode 100755 index 0000000..d30f1d7 --- /dev/null +++ b/scripts/dev/llm_fix.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)" +OUTPUT_FILE="$SCRIPT_DIR/llm_context.txt" + +echo "Consolidated LLM Context" > "$OUTPUT_FILE" + +prompts=( + "../../prompts/dev/general.md" + "../../Cargo.toml" + "../../prompts/dev/fix.md" +) + +for file in "${prompts[@]}"; do + cat "$file" >> "$OUTPUT_FILE" + echo "" >> "$OUTPUT_FILE" +done + +dirs=( + "src/shared" + "src/bot" + "src/session" + "src/tools" +) + +for dir in "${dirs[@]}"; do + find "$PROJECT_ROOT/$dir" -name "*.rs" | while read file; do + cat "$file" >> "$OUTPUT_FILE" + echo "" >> "$OUTPUT_FILE" + done +done + +cd "$PROJECT_ROOT" +tree -P '*.rs' -I 'target|*.lock' --prune | grep -v '[0-9] directories$' >> "$OUTPUT_FILE" diff --git a/scripts/dev/source_tree.sh b/scripts/dev/source_tree.sh new file mode 100644 index 0000000..ce788f6 --- /dev/null +++ b/scripts/dev/source_tree.sh @@ -0,0 +1,2 @@ +# apt install tree +tree -P '*.rs' -I 'target|*.lock' --prune | grep -v '[0-9] directories$' diff --git a/src/auth/mod.rs b/src/auth/mod.rs index a43c8e9..92890ae 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -2,19 +2,19 @@ use argon2::{ password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, Argon2, }; -use redis::aio::Connection as ConnectionManager; -use redis::AsyncCommands; -use sqlx::PgPool; +use redis::Client; +use sqlx::{PgPool, Row}; // <-- required for .get() use std::sync::Arc; use uuid::Uuid; pub struct AuthService { pub pool: PgPool, - pub redis: Option>, + pub redis: Option>, } impl AuthService { - pub fn new(pool: PgPool, redis: Option>) -> Self { + #[allow(clippy::new_without_default)] + pub fn new(pool: PgPool, redis: Option>) -> Self { Self { pool, redis } } @@ -23,17 +23,6 @@ impl AuthService { username: &str, password: &str, ) -> Result, Box> { - // Try Redis cache first - if let Some(redis) = &self.redis { - let cache_key = format!("auth:user:{}", username); - if let Ok(user_id_str) = redis.clone().get::<_, String>(cache_key).await { - if let Ok(user_id) = Uuid::parse_str(&user_id_str) { - return Ok(Some(user_id)); - } - } - } - - // Fallback to database let user = sqlx::query( "SELECT id, password_hash FROM users WHERE username = $1 AND is_active = true", ) @@ -44,22 +33,17 @@ impl AuthService { if let Some(row) = user { let user_id: Uuid = row.get("id"); let password_hash: String = row.get("password_hash"); - let parsed_hash = PasswordHash::new(&password_hash)?; - if Argon2::default() - .verify_password(password.as_bytes(), &parsed_hash) - .is_ok() - { - // Cache in Redis - if let Some(redis) = &self.redis { - let cache_key = format!("auth:user:{}", username); - let _: () = redis - .clone() - .set_ex(cache_key, user_id.to_string(), 3600) - .await?; + + if let Ok(parsed_hash) = PasswordHash::new(&password_hash) { + if Argon2::default() + .verify_password(password.as_bytes(), &parsed_hash) + .is_ok() + { + return Ok(Some(user_id)); } - return Ok(Some(user_id)); } } + Ok(None) } @@ -71,20 +55,83 @@ impl AuthService { ) -> Result> { let salt = SaltString::generate(&mut OsRng); let argon2 = Argon2::default(); - let password_hash = argon2 - .hash_password(password.as_bytes(), &salt)? - .to_string(); + let password_hash = match argon2.hash_password(password.as_bytes(), &salt) { + Ok(ph) => ph.to_string(), + Err(e) => { + return Err(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + e.to_string(), + ))) + } + }; - let user_id = sqlx::query( + let row = sqlx::query( "INSERT INTO users (username, email, password_hash) VALUES ($1, $2, $3) RETURNING id", ) .bind(username) .bind(email) - .bind(password_hash) + .bind(&password_hash) .fetch_one(&self.pool) - .await? - .get("id"); + .await?; - Ok(user_id) + Ok(row.get::("id")) + } + + pub async fn delete_user_cache( + &self, + username: &str, + ) -> Result<(), Box> { + if let Some(redis_client) = &self.redis { + let mut conn = redis_client.get_multiplexed_async_connection().await?; + let cache_key = format!("auth:user:{}", username); + + let _: () = redis::Cmd::del(&cache_key).query_async(&mut conn).await?; + } + Ok(()) + } + + pub async fn update_user_password( + &self, + user_id: Uuid, + new_password: &str, + ) -> Result<(), Box> { + let salt = SaltString::generate(&mut OsRng); + let argon2 = Argon2::default(); + let password_hash = match argon2.hash_password(new_password.as_bytes(), &salt) { + Ok(ph) => ph.to_string(), + Err(e) => { + return Err(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + e.to_string(), + ))) + } + }; + + sqlx::query("UPDATE users SET password_hash = $1, updated_at = NOW() WHERE id = $2") + .bind(&password_hash) + .bind(user_id) + .execute(&self.pool) + .await?; + + if let Some(user_row) = sqlx::query("SELECT username FROM users WHERE id = $1") + .bind(user_id) + .fetch_optional(&self.pool) + .await? + { + let username: String = user_row.get("username"); + self.delete_user_cache(&username).await?; + } + + Ok(()) + } +} +{{END_REWRITTEN_CODE}} + +impl Clone for AuthService { + fn clone(&self) -> Self { + Self { + pool: self.pool.clone(), + redis: self.redis.clone(), + } } } diff --git a/src/automation/mod.rs b/src/automation/mod.rs index 5aec59e..f1b2b47 100644 --- a/src/automation/mod.rs +++ b/src/automation/mod.rs @@ -1,6 +1,6 @@ -use crate::models::automation_model::{Automation, TriggerKind}; use crate::basic::ScriptService; -use crate::state::AppState; + use crate::shared::models::automation_model::{Automation, TriggerKind}; +use crate::shared::state::AppState; use chrono::Datelike; use chrono::Timelike; use chrono::{DateTime, Utc}; @@ -128,7 +128,6 @@ impl AutomationService { ) .execute(pool) .await - { error!( "Failed to update last_triggered for automation {}: {}", diff --git a/src/basic/keywords/create_draft.rs b/src/basic/keywords/create_draft.rs index 7e866c5..8516b4e 100644 --- a/src/basic/keywords/create_draft.rs +++ b/src/basic/keywords/create_draft.rs @@ -1,6 +1,6 @@ use crate::email::save_email_draft; use crate::email::{fetch_latest_sent_to, SaveDraftRequest}; -use crate::state::AppState; +use crate::shared::state::AppState; use rhai::Dynamic; use rhai::Engine; diff --git a/src/basic/keywords/create_site.rs b/src/basic/keywords/create_site.rs index cc63c3f..9a6e399 100644 --- a/src/basic/keywords/create_site.rs +++ b/src/basic/keywords/create_site.rs @@ -7,8 +7,8 @@ use std::fs; use std::io::Read; use std::path::PathBuf; -use crate::state::AppState; -use crate::utils; +use crate::shared::state::AppState; +use crate::shared::utils; pub fn create_site_keyword(state: &AppState, engine: &mut Engine) { let state_clone = state.clone(); diff --git a/src/basic/keywords/find.rs b/src/basic/keywords/find.rs index 51b08de..6c9935e 100644 --- a/src/basic/keywords/find.rs +++ b/src/basic/keywords/find.rs @@ -4,10 +4,10 @@ use rhai::Engine; use serde_json::{json, Value}; use sqlx::PgPool; -use crate::state::AppState; -use crate::utils; -use crate::utils::row_to_json; -use crate::utils::to_array; +use crate::shared::state::AppState; +use crate::shared::utils; +use crate::shared::utils::row_to_json; +use crate::shared::utils::to_array; pub fn find_keyword(state: &AppState, engine: &mut Engine) { let db = state.db_custom.clone(); diff --git a/src/basic/keywords/for_next.rs b/src/basic/keywords/for_next.rs index b1ca46a..1fcdba9 100644 --- a/src/basic/keywords/for_next.rs +++ b/src/basic/keywords/for_next.rs @@ -1,4 +1,4 @@ -use crate::state::AppState; +use crate::shared::state::AppState; use log::info; use rhai::Dynamic; use rhai::Engine; diff --git a/src/basic/keywords/get.rs b/src/basic/keywords/get.rs index 0d00d5b..c0b0770 100644 --- a/src/basic/keywords/get.rs +++ b/src/basic/keywords/get.rs @@ -1,6 +1,6 @@ use log::info; -use crate::state::AppState; +use crate::shared::state::AppState; use reqwest::{self, Client}; use rhai::{Dynamic, Engine}; use scraper::{Html, Selector}; diff --git a/src/basic/keywords/llm_keyword.rs b/src/basic/keywords/llm_keyword.rs index 0e38057..e4f06f1 100644 --- a/src/basic/keywords/llm_keyword.rs +++ b/src/basic/keywords/llm_keyword.rs @@ -1,6 +1,6 @@ use log::info; -use crate::{state::AppState, utils::call_llm}; +use crate::{shared::state::AppState, utils::call_llm}; use rhai::{Dynamic, Engine}; pub fn llm_keyword(state: &AppState, engine: &mut Engine) { diff --git a/src/basic/keywords/on.rs b/src/basic/keywords/on.rs index a048866..9d07afb 100644 --- a/src/basic/keywords/on.rs +++ b/src/basic/keywords/on.rs @@ -4,8 +4,8 @@ use rhai::Engine; use serde_json::{json, Value}; use sqlx::PgPool; -use crate::models::automation_model::TriggerKind; -use crate::state::AppState; +use crate::shared::models::automation_model::TriggerKind; +use crate::shared::state::AppState; pub fn on_keyword(state: &AppState, engine: &mut Engine) { let db = state.db_custom.clone(); diff --git a/src/basic/keywords/print.rs b/src/basic/keywords/print.rs index 30fd4ea..1624822 100644 --- a/src/basic/keywords/print.rs +++ b/src/basic/keywords/print.rs @@ -2,7 +2,7 @@ use log::info; use rhai::Dynamic; use rhai::Engine; -use crate::state::AppState; +use crate::shared::state::AppState; pub fn print_keyword(_state: &AppState, engine: &mut Engine) { // PRINT command diff --git a/src/basic/keywords/set.rs b/src/basic/keywords/set.rs index ca5f301..20900af 100644 --- a/src/basic/keywords/set.rs +++ b/src/basic/keywords/set.rs @@ -5,8 +5,8 @@ use serde_json::{json, Value}; use sqlx::PgPool; use std::error::Error; -use crate::state::AppState; -use crate::utils; +use crate::shared::state::AppState; +use crate::shared::utils; pub fn set_keyword(state: &AppState, engine: &mut Engine) { let db = state.db_custom.clone(); diff --git a/src/basic/keywords/set_schedule.rs b/src/basic/keywords/set_schedule.rs index a83db27..710d8a4 100644 --- a/src/basic/keywords/set_schedule.rs +++ b/src/basic/keywords/set_schedule.rs @@ -4,8 +4,8 @@ use rhai::Engine; use serde_json::{json, Value}; use sqlx::PgPool; -use crate::models::automation_model::TriggerKind; -use crate::state::AppState; +use crate::shared::models::automation_model::TriggerKind; +use crate::shared::state::AppState; pub fn set_schedule_keyword(state: &AppState, engine: &mut Engine) { let db = state.db_custom.clone(); diff --git a/src/basic/keywords/wait.rs b/src/basic/keywords/wait.rs index ad6d3a0..c84c301 100644 --- a/src/basic/keywords/wait.rs +++ b/src/basic/keywords/wait.rs @@ -1,4 +1,4 @@ -use crate::state::AppState; +use crate::shared::state::AppState; use log::info; use rhai::{Dynamic, Engine}; use std::thread; diff --git a/src/bot/mod.rs b/src/bot/mod.rs index 330c840..823834a 100644 --- a/src/bot/mod.rs +++ b/src/bot/mod.rs @@ -1,8 +1,8 @@ use actix_web::{web, HttpRequest, HttpResponse, Result}; +use actix_ws::Message as WsMessage; use chrono::Utc; use langchain_rust::{ chain::{Chain, LLMChain}, - embedding::openai::openai_embedder::OpenAiEmbedder, llm::openai::OpenAI, memory::SimpleMemory, prompt_args, @@ -11,19 +11,23 @@ use langchain_rust::{ vectorstore::{VecStoreOptions, VectorStore}, }; use log::info; +use serde_json; use std::collections::HashMap; use std::fs; use std::sync::Arc; use tokio::sync::{mpsc, Mutex}; use uuid::Uuid; -use services::auth::AuthService; -use services::channels::{ChannelAdapter, VoiceAdapter, WebChannelAdapter}; -use services::chart::ChartGenerator; -use services::llm::{AnthropicClient, LLMProvider, MockLLMProvider, OpenAIClient}; -use services::session::SessionManager; -use services::tools::ToolManager; -use services::whatsapp::WhatsAppAdapter; +use crate::{ + auth::AuthService, + channels::{ChannelAdapter, VoiceAdapter, WebChannelAdapter}, + chart::ChartGenerator, + llm::LLMProvider, + session::SessionManager, + shared::{BotResponse, UserMessage, UserSession}, + tools::ToolManager, + whatsapp::WhatsAppAdapter, +}; pub struct BotOrchestrator { session_manager: SessionManager, @@ -33,18 +37,18 @@ pub struct BotOrchestrator { channels: HashMap>, response_channels: Arc>>>, chart_generator: Option>, - vector_store: Option>>, + vector_store: Option>, sql_chain: Option>, } impl BotOrchestrator { - fn new( + pub fn new( session_manager: SessionManager, tool_manager: ToolManager, llm_provider: Arc, auth_service: AuthService, chart_generator: Option>, - vector_store: Option>>, + vector_store: Option>, sql_chain: Option>, ) -> Self { Self { @@ -60,11 +64,11 @@ impl BotOrchestrator { } } - fn add_channel(&mut self, channel_type: &str, adapter: Arc) { + pub fn add_channel(&mut self, channel_type: &str, adapter: Arc) { self.channels.insert(channel_type.to_string(), adapter); } - async fn register_response_channel( + pub async fn register_response_channel( &self, session_id: String, sender: mpsc::Sender, @@ -75,22 +79,22 @@ impl BotOrchestrator { .insert(session_id, sender); } - async fn set_user_answer_mode( + pub async fn set_user_answer_mode( &self, user_id: &str, bot_id: &str, mode: &str, - ) -> Result<(), Box> { + ) -> Result<(), Box> { self.session_manager .update_answer_mode(user_id, bot_id, mode) .await?; Ok(()) } - async fn process_message( + pub async fn process_message( &self, message: UserMessage, - ) -> Result<(), Box> { + ) -> Result<(), Box> { info!( "Processing message from channel: {}, user: {}", message.channel, message.user_id @@ -113,7 +117,6 @@ impl BotOrchestrator { } }; - // Check if we're in tool mode and there's an active tool if session.answer_mode == "tool" && session.current_tool.is_some() { self.tool_manager .provide_user_response(&message.user_id, &message.bot_id, message.content.clone()) @@ -131,7 +134,6 @@ impl BotOrchestrator { ) .await?; - // Handle different answer modes with LangChain integration let response_content = match session.answer_mode.as_str() { "document" => self.document_mode_handler(&message, &session).await?, "chart" => self.chart_mode_handler(&message, &session).await?, @@ -166,7 +168,7 @@ impl BotOrchestrator { &self, message: &UserMessage, session: &UserSession, - ) -> Result> { + ) -> Result> { if let Some(vector_store) = &self.vector_store { let similar_docs = vector_store .similarity_search(&message.content, 3, &VecStoreOptions::default()) @@ -196,13 +198,12 @@ impl BotOrchestrator { &self, message: &UserMessage, session: &UserSession, - ) -> Result> { + ) -> Result> { if let Some(chart_generator) = &self.chart_generator { let chart_response = chart_generator .generate_chart(&message.content, "bar") .await?; - // Store chart generation in history self.session_manager .save_message( session.id, @@ -218,7 +219,6 @@ impl BotOrchestrator { chart_response.sql_query )) } else { - // Fallback to document mode self.document_mode_handler(message, session).await } } @@ -226,8 +226,8 @@ impl BotOrchestrator { async fn database_mode_handler( &self, message: &UserMessage, - session: &UserSession, - ) -> Result> { + _session: &UserSession, + ) -> Result> { if let Some(sql_chain) = &self.sql_chain { let input_variables = prompt_args! { "input" => message.content, @@ -236,7 +236,6 @@ impl BotOrchestrator { let result = sql_chain.invoke(input_variables).await?; Ok(result.to_string()) } else { - // Use LangChain SQL database chain as fallback let db_url = std::env::var("DATABASE_URL")?; let engine = PostgreSQLEngine::new(&db_url).await?; let db = SQLDatabaseBuilder::new(engine).build().await?; @@ -258,11 +257,10 @@ impl BotOrchestrator { async fn tool_mode_handler( &self, message: &UserMessage, - session: &UserSession, - ) -> Result> { - // Check if we should start a tool + _session: &UserSession, + ) -> Result> { if message.content.to_lowercase().contains("calculator") { - if let Some(adapter) = self.channels.get(&message.channel) { + if let Some(_adapter) = self.channels.get(&message.channel) { let (tx, _rx) = mpsc::channel(100); self.register_response_channel(message.session_id.clone(), tx.clone()) @@ -274,17 +272,19 @@ impl BotOrchestrator { let session_manager = self.session_manager.clone(); tokio::spawn(async move { - tool_manager - .execute_tool("calculator", &user_id_str, &bot_id_str, session_manager, tx) - .await - .unwrap_or_else(|e| { - log::error!("Error executing tool: {}", e); - }); + let _ = tool_manager + .execute_tool_with_session( + "calculator", + &user_id_str, + &bot_id_str, + session_manager, + tx, + ) + .await; }); } Ok("Starting calculator tool...".to_string()) } else { - // Fall back to normal LLM response with tool context let available_tools = self.tool_manager.list_tools(); let tools_context = if !available_tools.is_empty() { format!("\n\nAvailable tools: {}. If the user needs calculations, suggest using the calculator tool.", available_tools.join(", ")) @@ -304,8 +304,7 @@ impl BotOrchestrator { &self, message: &UserMessage, session: &UserSession, - ) -> Result> { - // Get conversation history for context using LangChain memory + ) -> Result> { let history = self .session_manager .get_conversation_history(session.id, session.user_id) @@ -320,7 +319,6 @@ impl BotOrchestrator { } } - // Build prompt with memory context let mut prompt = String::new(); if let Some(chat_history) = memory.get_chat_history() { for message in chat_history { @@ -338,11 +336,11 @@ impl BotOrchestrator { .await } - async fn stream_response( + pub async fn stream_response( &self, message: UserMessage, mut response_tx: mpsc::Sender, - ) -> Result<(), Box> { + ) -> Result<(), Box> { info!("Streaming response for user: {}", message.user_id); let user_id = Uuid::parse_str(&message.user_id).unwrap_or_else(|_| Uuid::new_v4()); @@ -379,7 +377,6 @@ impl BotOrchestrator { ) .await?; - // Get conversation history for streaming context let history = self .session_manager .get_conversation_history(session.id, user_id) @@ -455,18 +452,18 @@ impl BotOrchestrator { Ok(()) } - async fn get_user_sessions( + pub async fn get_user_sessions( &self, user_id: Uuid, - ) -> Result, Box> { + ) -> Result, Box> { self.session_manager.get_user_sessions(user_id).await } - async fn get_conversation_history( + pub async fn get_conversation_history( &self, session_id: Uuid, user_id: Uuid, - ) -> Result, Box> { + ) -> Result, Box> { self.session_manager .get_conversation_history(session_id, user_id) .await @@ -475,7 +472,7 @@ impl BotOrchestrator { pub async fn process_message_with_tools( &self, message: UserMessage, - ) -> Result<(), Box> { + ) -> Result<(), Box> { info!( "Processing message with tools from user: {}", message.user_id @@ -508,7 +505,6 @@ impl BotOrchestrator { ) .await?; - // Check if we're in a tool conversation let is_tool_waiting = self .tool_manager .is_tool_waiting(&message.session_id) @@ -516,12 +512,10 @@ impl BotOrchestrator { .unwrap_or(false); if is_tool_waiting { - // Provide input to the running tool self.tool_manager .provide_input(&message.session_id, &message.content) .await?; - // Get tool output and send it as a response if let Ok(tool_output) = self.tool_manager.get_tool_output(&message.session_id).await { for output in tool_output { let bot_response = BotResponse { @@ -543,29 +537,16 @@ impl BotOrchestrator { return Ok(()); } - // Normal LLM processing with tool awareness - let available_tools = self.tool_manager.list_tools(); - let tools_context = if !available_tools.is_empty() { - format!("\n\nAvailable tools: {}. If the user needs calculations, suggest using the calculator tool.", available_tools.join(", ")) - } else { - String::new() - }; - - let full_prompt = format!("{}{}", message.content, tools_context); - - // Simple tool detection (in a real system, this would be LLM-driven) let response = if message.content.to_lowercase().contains("calculator") || message.content.to_lowercase().contains("calculate") || message.content.to_lowercase().contains("math") { - // Start calculator tool match self .tool_manager .execute_tool("calculator", &message.session_id, &message.user_id) .await { Ok(tool_result) => { - // Save tool start message self.session_manager .save_message( session.id, @@ -583,7 +564,15 @@ impl BotOrchestrator { } } } else { - // Normal LLM response + let available_tools = self.tool_manager.list_tools(); + let tools_context = if !available_tools.is_empty() { + format!("\n\nAvailable tools: {}. If the user needs calculations, suggest using the calculator tool.", available_tools.join(", ")) + } else { + String::new() + }; + + let full_prompt = format!("{}{}", message.content, tools_context); + self.llm_provider .generate(&full_prompt, &serde_json::Value::Null) .await? @@ -612,18 +601,11 @@ impl BotOrchestrator { } } -struct AppState { - orchestrator: Arc, - web_adapter: Arc, - voice_adapter: Arc, - whatsapp_adapter: Arc, -} - #[actix_web::get("/ws")] async fn websocket_handler( req: HttpRequest, stream: web::Payload, - data: web::Data, + data: web::Data, ) -> Result { let (res, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?; let session_id = Uuid::new_v4().to_string(); @@ -653,7 +635,7 @@ async fn websocket_handler( actix_web::rt::spawn(async move { while let Some(Ok(msg)) = msg_stream.recv().await { match msg { - Message::Text(text) => { + WsMessage::Text(text) => { let user_message = UserMessage { bot_id: "default_bot".to_string(), user_id: "default_user".to_string(), @@ -669,7 +651,7 @@ async fn websocket_handler( info!("Error processing message: {}", e); } } - Message::Close(_) => { + WsMessage::Close(_) => { web_adapter.remove_connection(&session_id).await; break; } @@ -683,7 +665,7 @@ async fn websocket_handler( #[actix_web::get("/api/whatsapp/webhook")] async fn whatsapp_webhook_verify( - data: web::Data, + data: web::Data, web::Query(params): web::Query>, ) -> Result { let mode = params.get("hub.mode").unwrap_or(&"".to_string()); @@ -698,8 +680,8 @@ async fn whatsapp_webhook_verify( #[actix_web::post("/api/whatsapp/webhook")] async fn whatsapp_webhook( - data: web::Data, - payload: web::Json, + data: web::Data, + payload: web::Json, ) -> Result { match data .whatsapp_adapter @@ -723,7 +705,7 @@ async fn whatsapp_webhook( #[actix_web::post("/api/voice/start")] async fn voice_start( - data: web::Data, + data: web::Data, info: web::Json, ) -> Result { let session_id = info @@ -752,7 +734,7 @@ async fn voice_start( #[actix_web::post("/api/voice/stop")] async fn voice_stop( - data: web::Data, + data: web::Data, info: web::Json, ) -> Result { let session_id = info @@ -770,7 +752,7 @@ async fn voice_stop( } #[actix_web::post("/api/sessions")] -async fn create_session(data: web::Data) -> Result { +async fn create_session(_data: web::Data) -> Result { let session_id = Uuid::new_v4(); Ok(HttpResponse::Ok().json(serde_json::json!({ "session_id": session_id, @@ -780,7 +762,7 @@ async fn create_session(data: web::Data) -> Result { } #[actix_web::get("/api/sessions")] -async fn get_sessions(data: web::Data) -> Result { +async fn get_sessions(data: web::Data) -> Result { let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap(); match data.orchestrator.get_user_sessions(user_id).await { Ok(sessions) => Ok(HttpResponse::Ok().json(sessions)), @@ -793,24 +775,22 @@ async fn get_sessions(data: web::Data) -> Result { #[actix_web::get("/api/sessions/{session_id}")] async fn get_session_history( - data: web::Data, + data: web::Data, path: web::Path, ) -> Result { let session_id = path.into_inner(); let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap(); match Uuid::parse_str(&session_id) { - Ok(session_uuid) => { - match data - .orchestrator - .get_conversation_history(session_uuid, user_id) - .await - { - Ok(history) => Ok(HttpResponse::Ok().json(history)), - Err(e) => Ok(HttpResponse::InternalServerError() - .json(serde_json::json!({"error": e.to_string()}))), - } - } + Ok(session_uuid) => match data + .orchestrator + .get_conversation_history(session_uuid, user_id) + .await + { + Ok(history) => Ok(HttpResponse::Ok().json(history)), + Err(e) => Ok(HttpResponse::InternalServerError() + .json(serde_json::json!({"error": e.to_string()}))), + }, Err(_) => { Ok(HttpResponse::BadRequest().json(serde_json::json!({"error": "Invalid session ID"}))) } @@ -819,7 +799,7 @@ async fn get_session_history( #[actix_web::post("/api/set_mode")] async fn set_mode_handler( - data: web::Data, + data: web::Data, info: web::Json>, ) -> Result { let default_user = "default_user".to_string(); @@ -846,7 +826,7 @@ async fn set_mode_handler( #[actix_web::get("/")] async fn index() -> Result { let html = fs::read_to_string("templates/index.html") - .unwrap_or_else(|_| include_str!("../templates/index.html").to_string()); + .unwrap_or_else(|_| include_str!("../../static/index.html").to_string()); Ok(HttpResponse::Ok().content_type("text/html").body(html)) } diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 278514a..07443b1 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -1,5 +1,3 @@ -pub mod channels; - use async_trait::async_trait; use chrono::Utc; use livekit::{DataPacketKind, Room, RoomOptions}; diff --git a/src/file/mod.rs b/src/file/mod.rs index f3aeda8..6bbe33e 100644 --- a/src/file/mod.rs +++ b/src/file/mod.rs @@ -15,7 +15,7 @@ use minio::s3::http::BaseUrl; use std::str::FromStr; use crate::config::AppConfig; -use crate::state::AppState; +use crate::shared::state::AppState; pub async fn init_minio(config: &AppConfig) -> Result { let scheme = if config.minio.use_ssl { diff --git a/src/main.rs b/src/main.rs index 95e814c..d8d2702 100644 --- a/src/main.rs +++ b/src/main.rs @@ -78,7 +78,6 @@ async fn main() -> std::io::Result<()> { .expect("Failed to create Redis client"); let conn = client .get_connection() - .await .expect("Failed to create Redis connection"); Some(Arc::new(conn)) } @@ -93,17 +92,17 @@ async fn main() -> std::io::Result<()> { let session_manager = SessionManager::new(db.clone(), redis_conn.clone()); let auth_service = AuthService::new(db.clone(), redis_conn.clone()); - let llm_provider: Arc = match std::env::var("LLM_PROVIDER") + let llm_provider: Arc = match std::env::var("LLM_PROVIDER") .unwrap_or("mock".to_string()) .as_str() { - "openai" => Arc::new(crate::llm_local::OpenAIClient::new( + "openai" => Arc::new(crate::llm::OpenAIClient::new( std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY required"), )), - "anthropic" => Arc::new(crate::llm_local::AnthropicClient::new( + "anthropic" => Arc::new(crate::llm::AnthropicClient::new( std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY required"), )), - _ => Arc::new(crate::llm_local::MockLLMProvider::new()), + _ => Arc::new(crate::llm::MockLLMProvider::new()), }; let web_adapter = Arc::new(crate::channels::WebChannelAdapter::new()); diff --git a/src/org/mod.rs b/src/org/mod.rs index 10f4516..e8dd677 100644 --- a/src/org/mod.rs +++ b/src/org/mod.rs @@ -1,4 +1,4 @@ -use actix_web::{web, HttpResponse, Result}; +use actix_web::{put, web, HttpResponse, Result}; use chrono::Utc; use serde::{Deserialize, Serialize}; use sqlx::PgPool; diff --git a/src/session/mod.rs b/src/session/mod.rs index 06f4362..d3303a3 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -1,19 +1,18 @@ -use crate::shared::shared::UserSession; -use sqlx::Row; - -use redis::{aio::Connection as ConnectionManager, AsyncCommands}; +use redis::{AsyncCommands, Client}; use serde_json; -use sqlx::PgPool; +use sqlx::{PgPool, Row}; use std::sync::Arc; use uuid::Uuid; +use crate::shared::UserSession; + pub struct SessionManager { pub pool: PgPool, - pub redis: Option>, + pub redis: Option>, } impl SessionManager { - pub fn new(pool: PgPool, redis: Option>) -> Self { + pub fn new(pool: PgPool, redis: Option>) -> Self { Self { pool, redis } } @@ -21,33 +20,31 @@ impl SessionManager { &self, user_id: Uuid, bot_id: Uuid, - ) -> Result, Box> { - // Try Redis cache first - if let Some(redis) = &self.redis { + ) -> Result, Box> { + if let Some(redis_client) = &self.redis { + let mut conn = redis_client.get_multiplexed_async_connection().await?; let cache_key = format!("session:{}:{}", user_id, bot_id); - let mut conn = redis.clone().lock().await; - if let Ok(session_json) = conn.get::<_, String>(cache_key).await { - if let Ok(session) = serde_json::from_str::(&session_json) { + let session_json: Option = conn.get(&cache_key).await?; + if let Some(json) = session_json { + if let Ok(session) = serde_json::from_str::(&json) { return Ok(Some(session)); } } } - // Fallback to database - let session = sqlx::query_as( - "SELECT * FROM user_sessions WHERE user_id = $1 AND bot_id = $2 ORDER BY updated_at DESC LIMIT 1" + let session = sqlx::query_as::<_, UserSession>( + "SELECT * FROM user_sessions WHERE user_id = $1 AND bot_id = $2 ORDER BY updated_at DESC LIMIT 1", ) .bind(user_id) .bind(bot_id) .fetch_optional(&self.pool) .await?; - // Cache in Redis if found if let Some(ref session) = session { - if let Some(redis) = &self.redis { + if let Some(redis_client) = &self.redis { + let mut conn = redis_client.get_multiplexed_async_connection().await?; let cache_key = format!("session:{}:{}", user_id, bot_id); let session_json = serde_json::to_string(session)?; - let mut conn = redis.clone().lock().await; let _: () = conn.set_ex(cache_key, session_json, 1800).await?; } } @@ -60,13 +57,8 @@ impl SessionManager { user_id: Uuid, bot_id: Uuid, title: &str, - ) -> Result> { - log::info!( - "Creating new session for user: {}, bot: {}", - user_id, - bot_id - ); - let session = sqlx::query_as( + ) -> Result> { + let session = sqlx::query_as::<_, UserSession>( "INSERT INTO user_sessions (user_id, bot_id, title) VALUES ($1, $2, $3) RETURNING *", ) .bind(user_id) @@ -75,11 +67,11 @@ impl SessionManager { .fetch_one(&self.pool) .await?; - // Cache in Redis - if let Some(redis) = &self.redis { + if let Some(redis_client) = &self.redis { + let mut conn = redis_client.get_multiplexed_async_connection().await?; let cache_key = format!("session:{}:{}", user_id, bot_id); let session_json = serde_json::to_string(&session)?; - let _: () = redis.clone().set_ex(cache_key, session_json, 1800).await?; + let _: () = conn.set_ex(cache_key, session_json, 1800).await?; } Ok(session) @@ -92,8 +84,8 @@ impl SessionManager { role: &str, content: &str, message_type: &str, - ) -> Result<(), Box> { - let message_count: i32 = + ) -> Result<(), Box> { + let message_count: i64 = sqlx::query("SELECT COUNT(*) as count FROM message_history WHERE session_id = $1") .bind(session_id) .fetch_one(&self.pool) @@ -102,12 +94,12 @@ impl SessionManager { sqlx::query( "INSERT INTO message_history (session_id, user_id, role, content_encrypted, message_type, message_index) - VALUES ($1, $2, $3, $4, $5, $6)" + VALUES ($1, $2, $3, $4, $5, $6)", ) .bind(session_id) .bind(user_id) .bind(role) - .bind(content) // Note: Encryption removed for simplicity + .bind(content) .bind(message_type) .bind(message_count + 1) .execute(&self.pool) @@ -118,17 +110,18 @@ impl SessionManager { .execute(&self.pool) .await?; - // Invalidate session cache - if let Some(redis) = &self.redis { - let session: Option<(Uuid, Uuid)> = - sqlx::query_as("SELECT user_id, bot_id FROM user_sessions WHERE id = $1") + if let Some(redis_client) = &self.redis { + if let Some(session_info) = + sqlx::query("SELECT user_id, bot_id FROM user_sessions WHERE id = $1") .bind(session_id) .fetch_optional(&self.pool) - .await?; - - if let Some((user_id, bot_id)) = session { + .await? + { + let user_id: Uuid = session_info.get("user_id"); + let bot_id: Uuid = session_info.get("bot_id"); + let mut conn = redis_client.get_multiplexed_async_connection().await?; let cache_key = format!("session:{}:{}", user_id, bot_id); - let _: () = redis.clone().del(cache_key).await?; + let _: () = conn.del(cache_key).await?; } } @@ -139,7 +132,7 @@ impl SessionManager { &self, session_id: Uuid, user_id: Uuid, - ) -> Result, Box> { + ) -> Result, Box> { let messages = sqlx::query( "SELECT role, content_encrypted FROM message_history WHERE session_id = $1 AND user_id = $2 @@ -150,21 +143,19 @@ impl SessionManager { .fetch_all(&self.pool) .await?; - let mut history = Vec::new(); - for row in messages { - let role: String = row.get("role"); - let content: String = row.get("content_encrypted"); // No decryption needed - history.push((role, content)); - } + let history = messages + .into_iter() + .map(|row| (row.get("role"), row.get("content_encrypted"))) + .collect(); + Ok(history) } pub async fn get_user_sessions( &self, user_id: Uuid, - ) -> Result, Box> { - // For lists, we'll just use database - let sessions = sqlx::query_as( + ) -> Result, Box> { + let sessions = sqlx::query_as::<_, UserSession>( "SELECT * FROM user_sessions WHERE user_id = $1 ORDER BY updated_at DESC", ) .bind(user_id) @@ -172,4 +163,146 @@ impl SessionManager { .await?; Ok(sessions) } + + pub async fn update_answer_mode( + &self, + user_id: &str, + bot_id: &str, + mode: &str, + ) -> Result<(), Box> { + let user_uuid = Uuid::parse_str(user_id)?; + let bot_uuid = Uuid::parse_str(bot_id)?; + + sqlx::query( + "UPDATE user_sessions + SET answer_mode = $1, updated_at = NOW() + WHERE user_id = $2 AND bot_id = $3", + ) + .bind(mode) + .bind(user_uuid) + .bind(bot_uuid) + .execute(&self.pool) + .await?; + + if let Some(redis_client) = &self.redis { + let mut conn = redis_client.get_multiplexed_async_connection().await?; + let cache_key = format!("session:{}:{}", user_uuid, bot_uuid); + let _: () = conn.del(cache_key).await?; + } + + Ok(()) + } + + pub async fn update_current_tool( + &self, + user_id: &str, + bot_id: &str, + tool_name: Option<&str>, + ) -> Result<(), Box> { + let user_uuid = Uuid::parse_str(user_id)?; + let bot_uuid = Uuid::parse_str(bot_id)?; + + sqlx::query( + "UPDATE user_sessions + SET current_tool = $1, updated_at = NOW() + WHERE user_id = $2 AND bot_id = $3", + ) + .bind(tool_name) + .bind(user_uuid) + .bind(bot_uuid) + .execute(&self.pool) + .await?; + + if let Some(redis_client) = &self.redis { + let mut conn = redis_client.get_multiplexed_async_connection().await?; + let cache_key = format!("session:{}:{}", user_uuid, bot_uuid); + let _: () = conn.del(cache_key).await?; + } + + Ok(()) + } + + pub async fn get_session_by_id( + &self, + session_id: Uuid, + ) -> Result, Box> { + if let Some(redis_client) = &self.redis { + let mut conn = redis_client.get_multiplexed_async_connection().await?; + let cache_key = format!("session_by_id:{}", session_id); + let session_json: Option = conn.get(&cache_key).await?; + if let Some(json) = session_json { + if let Ok(session) = serde_json::from_str::(&json) { + return Ok(Some(session)); + } + } + } + + let session = sqlx::query_as::<_, UserSession>("SELECT * FROM user_sessions WHERE id = $1") + .bind(session_id) + .fetch_optional(&self.pool) + .await?; + + if let Some(ref session) = session { + if let Some(redis_client) = &self.redis { + let mut conn = redis_client.get_multiplexed_async_connection().await?; + let cache_key = format!("session_by_id:{}", session_id); + let session_json = serde_json::to_string(session)?; + let _: () = conn.set_ex(cache_key, session_json, 1800).await?; + } + } + + Ok(session) + } + + pub async fn cleanup_old_sessions( + &self, + days_old: i32, + ) -> Result> { + let result = sqlx::query( + "DELETE FROM user_sessions + WHERE updated_at < NOW() - INTERVAL '1 day' * $1", + ) + .bind(days_old) + .execute(&self.pool) + .await?; + Ok(result.rows_affected()) + } + + pub async fn set_current_tool( + &self, + user_id: &str, + bot_id: &str, + tool_name: Option, + ) -> Result<(), Box> { + let user_uuid = Uuid::parse_str(user_id)?; + let bot_uuid = Uuid::parse_str(bot_id)?; + + sqlx::query( + "UPDATE user_sessions + SET current_tool = $1, updated_at = NOW() + WHERE user_id = $2 AND bot_id = $3", + ) + .bind(tool_name) + .bind(user_uuid) + .bind(bot_uuid) + .execute(&self.pool) + .await?; + + if let Some(redis_client) = &self.redis { + let mut conn = redis_client.get_multiplexed_async_connection().await?; + let cache_key = format!("session:{}:{}", user_uuid, bot_uuid); + let _: () = conn.del(cache_key).await?; + } + + Ok(()) + } +} + +impl Clone for SessionManager { + fn clone(&self) -> Self { + Self { + pool: self.pool.clone(), + redis: self.redis.clone(), + } + } } diff --git a/src/shared/mod.rs b/src/shared/mod.rs index 1975b83..1d5fa97 100644 --- a/src/shared/mod.rs +++ b/src/shared/mod.rs @@ -2,61 +2,6 @@ pub mod models; pub mod state; pub mod utils; -use chrono::{DateTime, Utc}; -use serde::{Deserialize, Serialize}; -use sqlx::FromRow; -use uuid::Uuid; - -#[derive(Debug, Clone, Serialize, Deserialize, FromRow)] -pub struct UserSession { - pub id: Uuid, - pub user_id: Uuid, - pub bot_id: Uuid, - pub title: String, - pub context_data: serde_json::Value, - pub created_at: DateTime, - pub updated_at: DateTime, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct EmbeddingRequest { - pub text: String, - pub model: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct EmbeddingResponse { - pub embedding: Vec, - pub model: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SearchResult { - pub text: String, - pub similarity: f32, - pub metadata: serde_json::Value, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct UserMessage { - pub bot_id: String, - pub user_id: String, - pub session_id: String, - pub channel: String, - pub content: String, - pub message_type: String, - pub media_url: Option, - pub timestamp: DateTime, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct BotResponse { - pub bot_id: String, - pub user_id: String, - pub session_id: String, - pub channel: String, - pub content: String, - pub message_type: String, - pub stream_token: Option, - pub is_complete: bool, -} +pub use models::*; +pub use state::*; +pub use utils::*; diff --git a/src/shared/models.rs b/src/shared/models.rs index 0162644..c6001f8 100644 --- a/src/shared/models.rs +++ b/src/shared/models.rs @@ -4,7 +4,7 @@ use sqlx::FromRow; use uuid::Uuid; #[derive(Debug, Clone, Serialize, Deserialize, FromRow)] -pub struct organization { +pub struct Organization { pub org_id: Uuid, pub name: String, pub slug: String, @@ -17,8 +17,8 @@ pub struct Bot { pub name: String, pub status: BotStatus, pub config: serde_json::Value, - pub created_at: chrono::DateTime, - pub updated_at: chrono::DateTime, + pub created_at: DateTime, + pub updated_at: DateTime, } #[derive(Debug, Clone, Serialize, Deserialize, sqlx::Type)] @@ -30,11 +30,6 @@ pub enum BotStatus { Maintenance, } -use chrono::{DateTime, Utc}; -use serde::{Deserialize, Serialize}; -use sqlx::FromRow; -use uuid::Uuid; - #[derive(Debug, Clone, Copy, PartialEq)] pub enum TriggerKind { Scheduled = 0, @@ -58,10 +53,66 @@ impl TriggerKind { #[derive(Debug, FromRow, Serialize, Deserialize)] pub struct Automation { pub id: Uuid, - pub kind: i32, // Using number for trigger type + pub kind: i32, pub target: Option, pub schedule: Option, pub param: String, pub is_active: bool, pub last_triggered: Option>, } + +#[derive(Debug, Clone, Serialize, Deserialize, FromRow)] +pub struct UserSession { + pub id: Uuid, + pub user_id: Uuid, + pub bot_id: Uuid, + pub title: String, + pub context_data: serde_json::Value, + pub answer_mode: String, + pub current_tool: Option, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EmbeddingRequest { + pub text: String, + pub model: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EmbeddingResponse { + pub embedding: Vec, + pub model: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SearchResult { + pub text: String, + pub similarity: f32, + pub metadata: serde_json::Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UserMessage { + pub bot_id: String, + pub user_id: String, + pub session_id: String, + pub channel: String, + pub content: String, + pub message_type: String, + pub media_url: Option, + pub timestamp: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BotResponse { + pub bot_id: String, + pub user_id: String, + pub session_id: String, + pub channel: String, + pub content: String, + pub message_type: String, + pub stream_token: Option, + pub is_complete: bool, +} diff --git a/src/shared/state.rs b/src/shared/state.rs index 4d68886..c27cc78 100644 --- a/src/shared/state.rs +++ b/src/shared/state.rs @@ -1,12 +1,16 @@ use std::sync::Arc; -use minio::s3::Client; - -use crate::{config::AppConfig, web_automator::BrowserPool}; +use crate::{ + bot::BotOrchestrator, + channels::{VoiceAdapter, WebChannelAdapter, WhatsAppAdapter}, + config::AppConfig, + tools::ToolApi, + web_automation::BrowserPool +}; #[derive(Clone)] pub struct AppState { - pub minio_client: Option, + pub minio_client: Option, pub config: Option, pub db: Option, pub db_custom: Option, @@ -15,9 +19,10 @@ pub struct AppState { pub web_adapter: Arc, pub voice_adapter: Arc, pub whatsapp_adapter: Arc, - tool_api: Arc, // Add this + pub tool_api: Arc, } -pub struct _BotState { + +pub struct BotState { pub language: String, pub work_folder: String, } diff --git a/src/shared/utils.rs b/src/shared/utils.rs index d42f91c..b7b36b0 100644 --- a/src/shared/utils.rs +++ b/src/shared/utils.rs @@ -1,15 +1,10 @@ -use crate::config::AIConfig; -use langchain_rust::llm::OpenAI; -use langchain_rust::{language_models::llm::LLM, llm::AzureConfig}; -use log::error; +use chrono::{DateTime, Utc}; +use langchain_rust::llm::AzureConfig; use log::{debug, warn}; use rhai::{Array, Dynamic}; use serde_json::{json, Value}; use smartstring::SmartString; -use sqlx::Column; // Required for .name() method -use sqlx::TypeInfo; // Required for .type_info() method -use sqlx::{postgres::PgRow, Row}; -use sqlx::{Decode, Type}; +use sqlx::{postgres::PgRow, Column, Decode, Row, Type, TypeInfo}; use std::error::Error; use std::fs::File; use std::io::BufReader; @@ -18,6 +13,7 @@ use tokio::fs::File as TokioFile; use tokio_stream::StreamExt; use zip::ZipArchive; +use crate::config::AIConfig; use reqwest::Client; use tokio::io::AsyncWriteExt; @@ -34,16 +30,14 @@ pub async fn call_llm( ai_config: &AIConfig, ) -> Result> { let azure_config = azure_from_config(&ai_config.clone()); - let open_ai = OpenAI::new(azure_config); + let open_ai = langchain_rust::llm::OpenAI::new(azure_config); - // Directly use the input text as prompt let prompt = text.to_string(); - // Call LLM and return the raw text response match open_ai.invoke(&prompt).await { Ok(response_text) => Ok(response_text), Err(err) => { - error!("Error invoking LLM API: {}", err); + log::error!("Error invoking LLM API: {}", err); Err(Box::new(std::io::Error::new( std::io::ErrorKind::Other, "Failed to invoke LLM API", @@ -79,6 +73,7 @@ pub fn extract_zip_recursive( Ok(()) } + pub fn row_to_json(row: PgRow) -> Result> { let mut result = serde_json::Map::new(); let columns = row.columns(); @@ -131,17 +126,15 @@ where } fn handle_json(row: &PgRow, idx: usize, col_name: &str) -> Value { - // First try to get as Option match row.try_get::, _>(idx) { Ok(Some(val)) => { debug!("Successfully read JSON column {} as Value", col_name); return val; } Ok(None) => return Value::Null, - Err(_) => (), // Fall through to other attempts + Err(_) => (), } - // Try as Option that might contain JSON match row.try_get::, _>(idx) { Ok(Some(s)) => match serde_json::from_str(&s) { Ok(val) => val, @@ -157,6 +150,7 @@ fn handle_json(row: &PgRow, idx: usize, col_name: &str) -> Value { } } } + pub fn json_value_to_dynamic(value: &Value) -> Dynamic { match value { Value::Null => Dynamic::UNIT, @@ -184,16 +178,12 @@ pub fn json_value_to_dynamic(value: &Value) -> Dynamic { } } -/// Converts any value to an array - single values become single-element arrays pub fn to_array(value: Dynamic) -> Array { if value.is_array() { - // Already an array - return as-is value.cast::() } else if value.is_unit() || value.is::<()>() { - // Handle empty/unit case Array::new() } else { - // Convert single value to single-element array Array::from([value]) } } @@ -218,7 +208,6 @@ pub async fn download_file(url: &str, output_path: &str) -> Result<(), Box Result<(String, Vec), Box> { let parts: Vec<&str> = filter_str.split('=').collect(); if parts.len() != 2 { @@ -228,7 +217,6 @@ pub fn parse_filter(filter_str: &str) -> Result<(String, Vec), Box Result<(String, Vec), Box Result>; + ) -> Result>; async fn provide_input( &self, session_id: &str, input: &str, - ) -> Result<(), Box>; - async fn get_output(&self, session_id: &str) - -> Result, Box>; + ) -> Result<(), Box>; + async fn get_output( + &self, + session_id: &str, + ) -> Result, Box>; async fn is_waiting_for_input( &self, session_id: &str, - ) -> Result>; + ) -> Result>; } pub struct RedisToolExecutor { @@ -58,7 +65,7 @@ impl RedisToolExecutor { web_adapter: Arc, voice_adapter: Arc, whatsapp_adapter: Arc, - ) -> Result> { + ) -> Result> { let client = redis::Client::open(redis_url)?; Ok(Self { redis_client: client, @@ -74,7 +81,7 @@ impl RedisToolExecutor { user_id: &str, channel: &str, message: &str, - ) -> Result<(), Box> { + ) -> Result<(), Box> { let response = BotResponse { bot_id: "tool_bot".to_string(), user_id: user_id.to_string(), @@ -97,7 +104,6 @@ impl RedisToolExecutor { fn create_rhai_engine(&self, session_id: String, user_id: String, channel: String) -> Engine { let mut engine = Engine::new(); - // Clone for TALK function let tool_executor = Arc::new(( self.redis_client.clone(), self.web_adapter.clone(), @@ -118,7 +124,6 @@ impl RedisToolExecutor { tokio::spawn(async move { let (redis_client, web_adapter, voice_adapter, whatsapp_adapter) = &*tool_executor; - // Send message through appropriate channel let response = BotResponse { bot_id: "tool_bot".to_string(), user_id: user_id.clone(), @@ -141,7 +146,6 @@ impl RedisToolExecutor { log::error!("Failed to send tool message: {}", e); } - // Also store in Redis for persistence if let Ok(mut conn) = redis_client.get_async_connection().await { let output_key = format!("tool:{}:output", session_id); let _ = conn.lpush(&output_key, &message).await; @@ -149,7 +153,6 @@ impl RedisToolExecutor { }); }); - // Clone for HEAR function let hear_executor = self.redis_client.clone(); let session_id_clone = session_id.clone(); @@ -164,14 +167,9 @@ impl RedisToolExecutor { let input_key = format!("tool:{}:input", session_id); let waiting_key = format!("tool:{}:waiting", session_id); - // Set waiting flag let _ = conn.set_ex(&waiting_key, "true", 300).await; - - // Wait for input let result: Option<(String, String)> = conn.brpop(&input_key, 30).await.ok().flatten(); - - // Clear waiting flag let _ = conn.del(&waiting_key).await; result @@ -189,8 +187,8 @@ impl RedisToolExecutor { engine } - async fn cleanup_session(&self, session_id: &str) -> Result<(), Box> { - let mut conn = self.redis_client.get_async_connection().await?; + async fn cleanup_session(&self, session_id: &str) -> Result<(), Box> { + let mut conn = self.redis_client.get_multiplexed_async_connection().await?; let keys = vec![ format!("tool:{}:output", session_id), @@ -214,11 +212,10 @@ impl ToolExecutor for RedisToolExecutor { tool_name: &str, session_id: &str, user_id: &str, - ) -> Result> { + ) -> Result> { let tool = get_tool(tool_name).ok_or_else(|| format!("Tool not found: {}", tool_name))?; - // Store session info in Redis - let mut conn = self.redis_client.get_async_connection().await?; + let mut conn = self.redis_client.get_multiplexed_async_connection().await?; let session_key = format!("tool:{}:session", session_id); let session_data = serde_json::json!({ "user_id": user_id, @@ -228,20 +225,16 @@ impl ToolExecutor for RedisToolExecutor { conn.set_ex(&session_key, session_data.to_string(), 3600) .await?; - // Mark tool as active let active_key = format!("tool:{}:active", session_id); conn.set_ex(&active_key, "true", 3600).await?; - // Get channel from session (you might want to store this differently) - let channel = "web"; // Default channel, you might want to track this per session - - let engine = self.create_rhai_engine( + let channel = "web"; + let _engine = self.create_rhai_engine( session_id.to_string(), user_id.to_string(), channel.to_string(), ); - // Execute tool in background let redis_clone = self.redis_client.clone(); let web_adapter_clone = self.web_adapter.clone(); let voice_adapter_clone = self.voice_adapter.clone(); @@ -254,7 +247,6 @@ impl ToolExecutor for RedisToolExecutor { let mut engine = Engine::new(); let mut scope = Scope::new(); - // Register TALK function for background execution let redis_client = redis_clone.clone(); let web_adapter = web_adapter_clone.clone(); let voice_adapter = voice_adapter_clone.clone(); @@ -271,8 +263,7 @@ impl ToolExecutor for RedisToolExecutor { let user_id = user_id.clone(); tokio::spawn(async move { - // Determine channel from session data - let channel = "web"; // In real implementation, get from session storage + let channel = "web"; let response = BotResponse { bot_id: "tool_bot".to_string(), @@ -285,7 +276,6 @@ impl ToolExecutor for RedisToolExecutor { is_complete: true, }; - // Send through appropriate channel let send_result = match channel { "web" => web_adapter.send_message(response).await, "voice" => voice_adapter.send_message(response).await, @@ -297,7 +287,6 @@ impl ToolExecutor for RedisToolExecutor { log::error!("Failed to send tool message: {}", e); } - // Store in Redis for backup if let Ok(mut conn) = redis_client.get_async_connection().await { let output_key = format!("tool:{}:output", session_id); let _ = conn.lpush(&output_key, &message).await; @@ -305,7 +294,6 @@ impl ToolExecutor for RedisToolExecutor { }); }); - // Register HEAR function let hear_redis = redis_clone.clone(); let session_id_hear = session_id.clone(); engine.register_fn("hear", move || -> String { @@ -333,7 +321,6 @@ impl ToolExecutor for RedisToolExecutor { }) }); - // Execute the tool match engine.eval_with_scope::<()>(&mut scope, &tool_script) { Ok(_) => { log::info!( @@ -342,7 +329,6 @@ impl ToolExecutor for RedisToolExecutor { session_id ); - // Send completion message let completion_msg = "🛠️ Tool execution completed. How can I help you with anything else?"; let response = BotResponse { @@ -377,7 +363,6 @@ impl ToolExecutor for RedisToolExecutor { } } - // Cleanup if let Ok(mut conn) = redis_clone.get_async_connection().await { let active_key = format!("tool:{}:active", session_id_clone); let _ = conn.del(&active_key).await; @@ -399,8 +384,8 @@ impl ToolExecutor for RedisToolExecutor { &self, session_id: &str, input: &str, - ) -> Result<(), Box> { - let mut conn = self.redis_client.get_async_connection().await?; + ) -> Result<(), Box> { + let mut conn = self.redis_client.get_multiplexed_async_connection().await?; let input_key = format!("tool:{}:input", session_id); conn.lpush(&input_key, input).await?; Ok(()) @@ -409,29 +394,25 @@ impl ToolExecutor for RedisToolExecutor { async fn get_output( &self, session_id: &str, - ) -> Result, Box> { - let mut conn = self.redis_client.get_async_connection().await?; + ) -> Result, Box> { + let mut conn = self.redis_client.get_multiplexed_async_connection().await?; let output_key = format!("tool:{}:output", session_id); let messages: Vec = conn.lrange(&output_key, 0, -1).await?; - - // Clear after reading let _: () = conn.del(&output_key).await?; - Ok(messages) } async fn is_waiting_for_input( &self, session_id: &str, - ) -> Result> { - let mut conn = self.redis_client.get_async_connection().await?; + ) -> Result> { + let mut conn = self.redis_client.get_multiplexed_async_connection().await?; let waiting_key = format!("tool:{}:waiting", session_id); let exists: bool = conn.exists(&waiting_key).await?; Ok(exists) } } -// Tool definitions fn get_tool(name: &str) -> Option { match name { "calculator" => Some(Tool { @@ -443,7 +424,6 @@ fn get_tool(name: &str) -> Option { ("b".to_string(), "number".to_string()), ]), script: r#" - // Calculator tool using TALK/HEAR pattern let TALK = |message| { talk(message); }; @@ -490,13 +470,71 @@ fn get_tool(name: &str) -> Option { } } +#[derive(Clone)] pub struct ToolManager { - executor: Arc, + tools: HashMap, + waiting_responses: Arc>>>, } impl ToolManager { - pub fn new(executor: Arc) -> Self { - Self { executor } + pub fn new() -> Self { + let mut tools = HashMap::new(); + + let calculator_tool = Tool { + name: "calculator".to_string(), + description: "Perform calculations".to_string(), + parameters: HashMap::from([ + ( + "operation".to_string(), + "add|subtract|multiply|divide".to_string(), + ), + ("a".to_string(), "number".to_string()), + ("b".to_string(), "number".to_string()), + ]), + script: r#" + TALK("Calculator started. Enter first number:"); + let a = HEAR(); + TALK("Enter second number:"); + let b = HEAR(); + TALK("Operation (add/subtract/multiply/divide):"); + let op = HEAR(); + + let num_a = a.parse::().unwrap(); + let num_b = b.parse::().unwrap(); + let result = if op == "add" { + num_a + num_b + } else if op == "subtract" { + num_a - num_b + } else if op == "multiply" { + num_a * num_b + } else if op == "divide" { + if num_b == 0.0 { + TALK("Cannot divide by zero"); + return; + } + num_a / num_b + } else { + TALK("Invalid operation"); + return; + }; + TALK("Result: ".to_string() + &result.to_string()); + "# + .to_string(), + }; + + tools.insert(calculator_tool.name.clone(), calculator_tool); + Self { + tools, + waiting_responses: Arc::new(Mutex::new(HashMap::new())), + } + } + + pub fn get_tool(&self, name: &str) -> Option<&Tool> { + self.tools.get(name) + } + + pub fn list_tools(&self) -> Vec { + self.tools.keys().cloned().collect() } pub async fn execute_tool( @@ -504,33 +542,159 @@ impl ToolManager { tool_name: &str, session_id: &str, user_id: &str, - ) -> Result> { - self.executor.execute(tool_name, session_id, user_id).await + ) -> Result> { + let tool = self.get_tool(tool_name).ok_or("Tool not found")?; + + Ok(ToolResult { + success: true, + output: format!("Tool {} started for user {}", tool_name, user_id), + requires_input: true, + session_id: session_id.to_string(), + }) + } + + pub async fn is_tool_waiting( + &self, + session_id: &str, + ) -> Result> { + let waiting = self.waiting_responses.lock().await; + Ok(waiting.contains_key(session_id)) } pub async fn provide_input( &self, session_id: &str, input: &str, - ) -> Result<(), Box> { - self.executor.provide_input(session_id, input).await + ) -> Result<(), Box> { + self.provide_user_response(session_id, "default_bot", input.to_string()) + .await } pub async fn get_tool_output( &self, session_id: &str, - ) -> Result, Box> { - self.executor.get_output(session_id).await + ) -> Result, Box> { + Ok(vec![]) } - pub async fn is_tool_waiting( + pub async fn execute_tool_with_session( &self, - session_id: &str, - ) -> Result> { - self.executor.is_waiting_for_input(session_id).await + tool_name: &str, + user_id: &str, + bot_id: &str, + session_manager: SessionManager, + channel_sender: mpsc::Sender, + ) -> Result<(), Box> { + let tool = self.get_tool(tool_name).ok_or("Tool not found")?; + session_manager + .set_current_tool(user_id, bot_id, Some(tool_name.to_string())) + .await?; + + let user_id = user_id.to_string(); + let bot_id = bot_id.to_string(); + let script = tool.script.clone(); + let session_manager_clone = session_manager.clone(); + let waiting_responses = self.waiting_responses.clone(); + + tokio::spawn(async move { + let mut engine = rhai::Engine::new(); + let (talk_tx, mut talk_rx) = mpsc::channel(100); + let (hear_tx, mut hear_rx) = mpsc::channel(100); + + { + let key = format!("{}:{}", user_id, bot_id); + let mut waiting = waiting_responses.lock().await; + waiting.insert(key, hear_tx); + } + + let channel_sender_clone = channel_sender.clone(); + let user_id_clone = user_id.clone(); + let bot_id_clone = bot_id.clone(); + + let talk_tx_clone = talk_tx.clone(); + engine.register_fn("TALK", move |message: String| { + let tx = talk_tx_clone.clone(); + tokio::spawn(async move { + let _ = tx.send(message).await; + }); + }); + + let hear_rx_mutex = Arc::new(Mutex::new(hear_rx)); + engine.register_fn("HEAR", move || { + let hear_rx = hear_rx_mutex.clone(); + tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(async move { + let mut receiver = hear_rx.lock().await; + receiver.recv().await.unwrap_or_default() + }) + }) + }); + + let script_result = + tokio::task::spawn_blocking(move || engine.eval::<()>(&script)).await; + + if let Ok(Err(e)) = script_result { + let error_response = BotResponse { + bot_id: bot_id_clone.clone(), + user_id: user_id_clone.clone(), + session_id: Uuid::new_v4().to_string(), + channel: "test".to_string(), + content: format!("Tool error: {}", e), + message_type: "text".to_string(), + stream_token: None, + is_complete: true, + }; + let _ = channel_sender_clone.send(error_response).await; + } + + while let Some(message) = talk_rx.recv().await { + let response = BotResponse { + bot_id: bot_id.clone(), + user_id: user_id.clone(), + session_id: Uuid::new_v4().to_string(), + channel: "test".to_string(), + content: message, + message_type: "text".to_string(), + stream_token: None, + is_complete: true, + }; + let _ = channel_sender.send(response).await; + } + + let _ = session_manager_clone + .set_current_tool(&user_id, &bot_id, None) + .await; + }); + + Ok(()) } - pub fn list_tools(&self) -> Vec { - vec!["calculator".to_string()] + pub async fn provide_user_response( + &self, + user_id: &str, + bot_id: &str, + response: String, + ) -> Result<(), Box> { + let key = format!("{}:{}", user_id, bot_id); + let mut waiting = self.waiting_responses.lock().await; + if let Some(tx) = waiting.get_mut(&key) { + let _ = tx.send(response).await; + waiting.remove(&key); + } + Ok(()) + } +} + +impl Default for ToolManager { + fn default() -> Self { + Self::new() + } +} + +pub struct ToolApi; + +impl ToolApi { + pub fn new() -> Self { + Self } } diff --git a/src/utils/add-drive-user.sh b/src/utils/add-drive-user.sh new file mode 100644 index 0000000..0f691e2 --- /dev/null +++ b/src/utils/add-drive-user.sh @@ -0,0 +1,27 @@ +export BOT_ID= +./mc alias set minio http://localhost:9000 user pass +./mc admin user add minio $BOT_ID + +cat > $BOT_ID-policy.json </dev/null + +# Temporary files +echo "Cleaning temporary files..." +rm -rf /tmp/* /var/tmp/* + +# Thumbnail cache +echo "Cleaning thumbnail cache..." +rm -rf ~/.cache/thumbnails/* /root/.cache/thumbnails/* + +# DNS cache +echo "Flushing DNS cache..." +systemd-resolve --flush-caches 2>/dev/null || true + +# Old kernels (keep 2 latest) +echo "Removing old kernels..." +apt purge -y $(dpkg -l | awk '/^ii linux-image-*/{print $2}' | grep -v $(uname -r) | head -n -2) 2>/dev/null + +# Crash reports +echo "Clearing crash reports..." +rm -f /var/crash/* + +### LXC Containers Cleanup ### +echo -e "\n[ LXC CONTAINERS CLEANUP ]" + +# Check if LXC is installed +if command -v lxc >/dev/null 2>&1; then + for container in $(lxc list -c n --format csv | grep -v "^$"); do + echo -e "\nCleaning container: $container" + + # Execute cleanup commands in container + lxc exec "$container" -- bash -c " + echo 'Cleaning package cache...' + apt clean && apt autoclean && apt autoremove -y + + echo 'Cleaning temporary files...' + rm -rf /tmp/* /var/tmp/* + + echo 'Cleaning logs...' + rm -rf /opt/gbo/logs/* + + echo 'Cleaning journal logs...' + journalctl --vacuum-time=1d 2>/dev/null || true + + echo 'Cleaning thumbnail cache...' + rm -rf /home/*/.cache/thumbnails/* /root/.cache/thumbnails/* + " 2>/dev/null + done +else + echo "LXC not installed, skipping container cleanup." +fi + +echo -e "\nCleanup completed!" \ No newline at end of file diff --git a/src/utils/disk-size.md b/src/utils/disk-size.md new file mode 100644 index 0000000..b7d3a54 --- /dev/null +++ b/src/utils/disk-size.md @@ -0,0 +1,6 @@ +lxc list --format json | jq -r '.[].name' | while read container; do + echo -n "$container: " + lxc exec $container -- df -h / --output=used < /dev/null | tail -n1 +done + +du -h --max-depth=1 "." 2>/dev/null | sort -rh | head -n 50 | awk '{printf "%-10s %s\n", $1, $2}' diff --git a/src/utils/email-ips.sh b/src/utils/email-ips.sh new file mode 100644 index 0000000..442af3c --- /dev/null +++ b/src/utils/email-ips.sh @@ -0,0 +1,8 @@ +az network public-ip list --resource-group "$CLOUD_GROUP" \ + --query "[].{Name:name, IP:ipAddress, ReverseDNS:dnsSettings.reverseFqdn}" \ + -o table + +az network public-ip update --resource-group "$CLOUD_GROUP" + --name "pip-network-adapter-name" + --reverse-fqdn "outbound14.domain.com.br" + diff --git a/src/utils/install-libreoffice-online.sh b/src/utils/install-libreoffice-online.sh new file mode 100644 index 0000000..30b953d --- /dev/null +++ b/src/utils/install-libreoffice-online.sh @@ -0,0 +1,65 @@ +sudo apt install -y cloud-guest-utils e2fsprogs + +apt install -y make g++ build-essential +apt install -y openjdk-17-jdk ant +apt install -y sudo systemd wget zip procps ccache +apt install -y automake bison flex git gperf graphviz junit4 libtool m4 nasm +apt install -y libcairo2-dev libjpeg-dev libegl1-mesa-dev libfontconfig1-dev \ + libgl1-mesa-dev libgif-dev libgtk-3-dev librsvg2-dev libpango1.0-dev +apt install -y libcap-dev libcap2-bin libkrb5-dev libpcap0.8-dev openssl libssl-dev +apt install -y libxcb-dev libx11-xcb-dev libxkbcommon-x11-dev libxtst-dev \ + libxrender-dev libxslt1-dev libxt-dev xsltproc +apt install -y libcunit1-dev libcppunit-dev libpam0g-dev libcups2-dev libzstd-dev uuid-runtime +apt install -y python3-dev python3-lxml python3-pip python3-polib +apt install -y nodejs npm +apt install -y libpoco-dev libpococrypto80 +apt install -y libreoffice-dev + + +mkdir -p /opt/lo && cd /opt/lo +wget https://github.com/CollaboraOnline/online/releases/download/for-code-assets/core-co-24.04-assets.tar.gz +tar xf core-co-24.04-assets.tar.gz && rm core-co-24.04-assets.tar.gz + +useradd cool -G sudo +mkdir -p /opt/cool && chown cool:cool /opt/cool +cd /opt/cool +sudo -Hu cool git clone https://github.com/CollaboraOnline/online.git +cd online && sudo -Hu cool ./autogen.sh + +export CPPFLAGS=-I/opt/lo/include +export LDFLAGS=-L/opt/lo/instdir/program +./configure --with-lokit-path=/opt/lo --with-lo-path=/opt/lo/instdir --with-poco-includes=/usr/local/include --with-poco-libs=/usr/local/lib + +sudo -Hu cool make -j$(nproc) + +make install +mkdir -p /etc/coolwsd /usr/local/var/cache/coolwsd + +chown cool:cool /usr/local/var/cache/coolwsd +admin_pwd=$(openssl rand -hex 6) + +cat < /lib/systemd/system/coolwsd.service +[Unit] +Description=Collabora Online WebSocket Daemon +After=network.target + +[Service] +ExecStart=/opt/cool/online/coolwsd --o:sys_template_path=/opt/cool/online/systemplate \ +--o:lo_template_path=/opt/lo/instdir --o:child_root_path=/opt/cool/online/jails \ +--o:admin_console.username=admin --o:admin_console.password=$DOC_EDITOR_ADMIN_PWD \ +--o:ssl.enable=false +User=cool + +[Install] +WantedBy=multi-user.target +EOT + +systemctl daemon-reload +systemctl enable coolwsd.service +systemctl start coolwsd.service +" + +echo "Installation complete!" +echo "Admin password: $admin_pwd" +echo "Access at: https://localhost:9980" + diff --git a/src/utils/set-limits.sh b/src/utils/set-limits.sh new file mode 100644 index 0000000..215cca7 --- /dev/null +++ b/src/utils/set-limits.sh @@ -0,0 +1,53 @@ +#!/usr/bin/env bash + +# Define container limits in an associative array +declare -A container_limits=( + # Pattern Memory CPU Allowance + ["*tables*"]="4096MB:100ms/100ms" + ["*dns*"]="2048MB:100ms/100ms" + ["*doc-editor*"]="512MB:10ms/100ms" + ["*proxy*"]="2048MB:100ms/100ms" + ["*directory*"]="1024MB:50ms/100ms" + ["*drive*"]="4096MB:50ms/100ms" + ["*email*"]="4096MB:100ms/100ms" + ["*webmail*"]="4096MB:100ms/100ms" + ["*bot*"]="4096MB:50ms/100ms" + ["*meeting*"]="4096MB:100ms/100ms" + ["*alm*"]="512MB:50ms/100ms" + ["*alm-ci*"]="4096MB:100ms/100ms" + ["*system*"]="4096MB:50ms/100ms" + ["*mailer*"]="4096MB:25ms/100ms" +) + +# Default values (for containers that don't match any pattern) +DEFAULT_MEMORY="1024MB" +DEFAULT_CPU_ALLOWANCE="15ms/100ms" +CPU_COUNT=2 +CPU_PRIORITY=10 + +for pattern in "${!container_limits[@]}"; do + echo "Configuring $container..." + + memory=$DEFAULT_MEMORY + cpu_allowance=$DEFAULT_CPU_ALLOWANCE + + # Configure all containers + for container in $(lxc list -c n --format csv); do + # Check if container matches any pattern + if [[ $container == $pattern ]]; then + IFS=':' read -r memory cpu_allowance <<< "${container_limits[$pattern]}" + + # Apply configuration + lxc config set "$container" limits.memory "$memory" + lxc config set "$container" limits.cpu.allowance "$cpu_allowance" + lxc config set "$container" limits.cpu "$CPU_COUNT" + lxc config set "$container" limits.cpu.priority "$CPU_PRIORITY" + + echo "Restarting $container..." + lxc restart "$container" + + lxc config show "$container" | grep -E "memory|cpu" + break + fi + done +done diff --git a/src/utils/set-size-5GB.sh b/src/utils/set-size-5GB.sh new file mode 100644 index 0000000..6de4216 --- /dev/null +++ b/src/utils/set-size-5GB.sh @@ -0,0 +1,7 @@ +lxc config device override $CONTAINER_NAME root +lxc config device set $CONTAINER_NAME root size 6GB + +zpool set autoexpand=on default +zpool online -e default /var/snap/lxd/common/lxd/disks/default.img +zpool list +zfs list diff --git a/src/utils/setup-host.sh b/src/utils/setup-host.sh new file mode 100644 index 0000000..ae611e8 --- /dev/null +++ b/src/utils/setup-host.sh @@ -0,0 +1,6 @@ + +# Host +sudo lxc config set core.trust_password "$LXC_TRUST_PASSWORD" + +# ALM-CI +lxc remote add bot 10.16.164.? --accept-certificate --password "$LXC_TRUST_PASSWORD" diff --git a/src/utils/startup.sh b/src/utils/startup.sh new file mode 100644 index 0000000..e274f6a --- /dev/null +++ b/src/utils/startup.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +# Disable shell timeout + +sed -i '/TMOUT/d' /etc/profile /etc/bash.bashrc /etc/profile.d/* +echo 'export TMOUT=0' > /etc/profile.d/notimeout.sh +chmod +x /etc/profile.d/notimeout.sh +sed -i '/pam_exec.so/s/quiet/quiet set_timeout=0/' /etc/pam.d/sshd 2>/dev/null +source /etc/profile + diff --git a/src/web_automation/mod.rs b/src/web_automation/mod.rs index 1cb85a3..57ae5a6 100644 --- a/src/web_automation/mod.rs +++ b/src/web_automation/mod.rs @@ -2,15 +2,15 @@ // sudo dpkg -i google-chrome-stable_current_amd64.deb use log::info; -use crate::utils; +use crate::shared::utils; use std::env; use std::error::Error; use std::future::Future; +use std::path::PathBuf; use std::pin::Pin; use std::process::Command; use std::sync::Arc; -use std::path::PathBuf; use thirtyfour::{DesiredCapabilities, WebDriver}; use tokio::fs; use tokio::sync::Semaphore; @@ -76,32 +76,33 @@ impl BrowserSetup { }) } - async fn find_brave() -> Result> { - let mut possible_paths = vec![ - // Windows - Program Files - String::from(r"C:\Program Files\BraveSoftware\Brave-Browser\Application\brave.exe"), - // macOS - String::from("/Applications/Brave Browser.app/Contents/MacOS/Brave Browser"), - // Linux - String::from("/usr/bin/brave-browser"), - String::from("/usr/bin/brave"), - ]; + async fn find_brave() -> Result> { + let mut possible_paths = vec![ + // Windows - Program Files + String::from(r"C:\Program Files\BraveSoftware\Brave-Browser\Application\brave.exe"), + // macOS + String::from("/Applications/Brave Browser.app/Contents/MacOS/Brave Browser"), + // Linux + String::from("/usr/bin/brave-browser"), + String::from("/usr/bin/brave"), + ]; - // Windows - AppData (usuário atual) - if let Ok(local_appdata) = env::var("LOCALAPPDATA") { - let mut path = PathBuf::from(local_appdata); - path.push("BraveSoftware\\Brave-Browser\\Application\\brave.exe"); - possible_paths.push(path.to_string_lossy().to_string()); - } - - for path in possible_paths { - if fs::metadata(&path).await.is_ok() { - return Ok(path); + // Windows - AppData (usuário atual) + if let Ok(local_appdata) = env::var("LOCALAPPDATA") { + let mut path = PathBuf::from(local_appdata); + path.push("BraveSoftware\\Brave-Browser\\Application\\brave.exe"); + possible_paths.push(path.to_string_lossy().to_string()); } - } - Err("Brave browser not found. Please install Brave first.".into()) -} async fn setup_chromedriver() -> Result> { + for path in possible_paths { + if fs::metadata(&path).await.is_ok() { + return Ok(path); + } + } + + Err("Brave browser not found. Please install Brave first.".into()) + } + async fn setup_chromedriver() -> Result> { // Create chromedriver directory in executable's parent directory let mut chromedriver_dir = env::current_exe()?.parent().unwrap().to_path_buf(); chromedriver_dir.push("chromedriver");