Refactor to Arc<AppState> for shared state

- Migrate core services to store Arc<AppState> and use locks
- Centralize state in AppState with Arc-wrapped managers
- Update handlers to pass Arc<AppState> via web::Data
- Add Default for AppState and initialize components in main
- Update debug.json program path from gbserver to botserver
This commit is contained in:
Rodrigo Rodriguez (Pragmatismo) 2025-10-12 20:12:49 -03:00
parent e14908fc0b
commit e00eb40d43
10 changed files with 384 additions and 280 deletions

View file

@ -5,7 +5,7 @@
"command": "cargo", "command": "cargo",
"args": ["build"] "args": ["build"]
}, },
"program": "$ZED_WORKTREE_ROOT/target/debug/gbserver", "program": "$ZED_WORKTREE_ROOT/target/debug/botserver",
"sourceLanguages": ["rust"], "sourceLanguages": ["rust"],
"request": "launch", "request": "launch",

View file

@ -9,8 +9,8 @@ echo "Consolidated LLM Context" > "$OUTPUT_FILE"
prompts=( prompts=(
"../../prompts/dev/shared.md" "../../prompts/dev/shared.md"
"../../Cargo.toml" "../../Cargo.toml"
#"../../prompts/dev/fix.md" "../../prompts/dev/fix.md"
"../../prompts/dev/generation.md" #"../../prompts/dev/generation.md"
) )
for file in "${prompts[@]}"; do for file in "${prompts[@]}"; do
@ -23,18 +23,18 @@ dirs=(
#"automation" #"automation"
#"basic" #"basic"
"bot" "bot"
"channels" #"channels"
#"config" #"config"
"context" #"context"
#"email" #"email"
#"file" #"file"
"llm" #"llm"
#"llm_legacy" #"llm_legacy"
#"org" #"org"
"session" #"session"
"shared" "shared"
#"tests" #"tests"
"tools" #"tools"
#"web_automation" #"web_automation"
#"whatsapp" #"whatsapp"
) )
@ -47,10 +47,10 @@ done
# Also append the specific files you mentioned # Also append the specific files you mentioned
cat "$PROJECT_ROOT/src/main.rs" >> "$OUTPUT_FILE" cat "$PROJECT_ROOT/src/main.rs" >> "$OUTPUT_FILE"
cat "$PROJECT_ROOT/src/basic/keywords/hear_talk.rs" >> "$OUTPUT_FILE" # cat "$PROJECT_ROOT/src/basic/keywords/hear_talk.rs" >> "$OUTPUT_FILE"
# cat "$PROJECT_ROOT/src/basic/keywords/get.rs" >> "$OUTPUT_FILE"
echo "This BASIC file will run as soon as the conversation is created. " >> "$OUTPUT_FILE" cat "$PROJECT_ROOT/src/basic/keywords/llm_keyword.rs" >> "$OUTPUT_FILE"
cat "$PROJECT_ROOT/templates/annoucements.gbai/annoucements.gbdialog/start.bas" >> "$OUTPUT_FILE" # cat "$PROJECT_ROOT/src/basic/mod.rs" >> "$OUTPUT_FILE"
echo "" >> "$OUTPUT_FILE" echo "" >> "$OUTPUT_FILE"

View file

@ -10,12 +10,12 @@ use tokio::time::Duration;
use uuid::Uuid; use uuid::Uuid;
pub struct AutomationService { pub struct AutomationService {
state: AppState, state: Arc<AppState>,
scripts_dir: String, scripts_dir: String,
} }
impl AutomationService { impl AutomationService {
pub fn new(state: AppState, scripts_dir: &str) -> Self { pub fn new(state: Arc<AppState>, scripts_dir: &str) -> Self {
Self { Self {
state, state,
scripts_dir: scripts_dir.to_string(), scripts_dir: scripts_dir.to_string(),
@ -61,13 +61,11 @@ impl AutomationService {
async fn check_table_changes(&self, automations: &[Automation], since: DateTime<Utc>) { async fn check_table_changes(&self, automations: &[Automation], since: DateTime<Utc>) {
for automation in automations { for automation in automations {
// Resolve the trigger kind, disambiguating the `from_i32` call.
let trigger_kind = match crate::shared::models::TriggerKind::from_i32(automation.kind) { let trigger_kind = match crate::shared::models::TriggerKind::from_i32(automation.kind) {
Some(k) => k, Some(k) => k,
None => continue, None => continue,
}; };
// We're only interested in tablechange triggers.
if !matches!( if !matches!(
trigger_kind, trigger_kind,
TriggerKind::TableUpdate | TriggerKind::TableInsert | TriggerKind::TableDelete TriggerKind::TableUpdate | TriggerKind::TableInsert | TriggerKind::TableDelete
@ -75,50 +73,41 @@ impl AutomationService {
continue; continue;
} }
// Table name must be present.
let table = match &automation.target { let table = match &automation.target {
Some(t) => t, Some(t) => t,
None => continue, None => continue,
}; };
// Choose the appropriate timestamp column.
let column = match trigger_kind { let column = match trigger_kind {
TriggerKind::TableInsert => "created_at", TriggerKind::TableInsert => "created_at",
_ => "updated_at", _ => "updated_at",
}; };
// Build a simple COUNT(*) query; alias the column so Diesel can map it directly to i64.
let query = format!( let query = format!(
"SELECT COUNT(*) as count FROM {} WHERE {} > $1", "SELECT COUNT(*) as count FROM {} WHERE {} > $1",
table, column table, column
); );
// Acquire a connection for this query.
let mut conn_guard = self.state.conn.lock().unwrap(); let mut conn_guard = self.state.conn.lock().unwrap();
let conn = &mut *conn_guard; let conn = &mut *conn_guard;
// Define a struct to capture the query result
#[derive(diesel::QueryableByName)] #[derive(diesel::QueryableByName)]
struct CountResult { struct CountResult {
#[diesel(sql_type = diesel::sql_types::BigInt)] #[diesel(sql_type = diesel::sql_types::BigInt)]
count: i64, count: i64,
} }
// Execute the query, retrieving a plain i64 count.
let count_result = diesel::sql_query(&query) let count_result = diesel::sql_query(&query)
.bind::<diesel::sql_types::Timestamp, _>(since.naive_utc()) .bind::<diesel::sql_types::Timestamp, _>(since.naive_utc())
.get_result::<CountResult>(conn); .get_result::<CountResult>(conn);
match count_result { match count_result {
Ok(result) if result.count > 0 => { Ok(result) if result.count > 0 => {
// Release the lock before awaiting asynchronous work.
drop(conn_guard); drop(conn_guard);
self.execute_action(&automation.param).await; self.execute_action(&automation.param).await;
self.update_last_triggered(automation.id).await; self.update_last_triggered(automation.id).await;
} }
Ok(_result) => { Ok(_result) => {}
// No relevant rows changed; continue to the next automation.
}
Err(e) => { Err(e) => {
error!("Error checking changes for table '{}': {}", table, e); error!("Error checking changes for table '{}': {}", table, e);
} }
@ -215,7 +204,7 @@ impl AutomationService {
updated_at: Utc::now(), updated_at: Utc::now(),
}; };
let script_service = ScriptService::new(&self.state, user_session); let script_service = ScriptService::new(Arc::clone(&self.state), user_session);
let ast = match script_service.compile(&script_content) { let ast = match script_service.compile(&script_content) {
Ok(ast) => ast, Ok(ast) => ast,
Err(e) => { Err(e) => {

View file

@ -1,16 +1,15 @@
use log::info;
use crate::shared::state::AppState;
use crate::shared::models::UserSession; use crate::shared::models::UserSession;
use crate::shared::state::AppState;
use log::info;
use reqwest::{self, Client}; use reqwest::{self, Client};
use rhai::{Dynamic, Engine}; use rhai::{Dynamic, Engine};
use std::error::Error; use std::error::Error;
pub fn get_keyword(_state: &AppState, _user: UserSession, engine: &mut Engine) { pub fn get_keyword(state: &AppState, _user: UserSession, engine: &mut Engine) {
let state_clone = state.clone();
engine engine
.register_custom_syntax( .register_custom_syntax(&["GET", "$expr$"], false, move |context, inputs| {
&["GET", "$expr$"],
false,
move |context, inputs| {
let url = context.eval_expression_tree(&inputs[0])?; let url = context.eval_expression_tree(&inputs[0])?;
let url_str = url.to_string(); let url_str = url.to_string();
@ -18,42 +17,29 @@ pub fn get_keyword(_state: &AppState, _user: UserSession, engine: &mut Engine) {
return Err("URL contains invalid path traversal sequences like '..'.".into()); return Err("URL contains invalid path traversal sequences like '..'.".into());
} }
let modified_url = if url_str.starts_with("/") { let state_for_async = state_clone.clone();
let work_root = std::env::var("WORK_ROOT").unwrap_or_else(|_| "./work".to_string()); let url_for_async = url_str.clone();
let full_path = std::path::Path::new(&work_root)
.join(url_str.trim_start_matches('/'))
.to_string_lossy()
.into_owned();
let base_url = "file://"; if url_str.starts_with("https://") {
format!("{}{}", base_url, full_path) info!("HTTPS GET request: {}", url_for_async);
} else {
url_str.to_string()
};
if modified_url.starts_with("https://") { let fut = execute_get(&url_for_async);
info!("HTTPS GET request: {}", modified_url);
let fut = execute_get(&modified_url);
let result = let result =
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut)) tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))
.map_err(|e| format!("HTTP request failed: {}", e))?; .map_err(|e| format!("HTTP request failed: {}", e))?;
Ok(Dynamic::from(result)) Ok(Dynamic::from(result))
} else if modified_url.starts_with("file://") {
let file_path = modified_url.trim_start_matches("file://");
match std::fs::read_to_string(file_path) {
Ok(content) => Ok(Dynamic::from(content)),
Err(e) => Err(format!("Failed to read file: {}", e).into()),
}
} else { } else {
Err( info!("Local file GET request from bucket: {}", url_for_async);
format!("GET request failed: URL must begin with 'https://' or 'file://'")
.into(), let fut = get_from_bucket(&state_for_async, &url_for_async);
) let result =
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))
.map_err(|e| format!("Bucket GET failed: {}", e))?;
Ok(Dynamic::from(result))
} }
}, })
)
.unwrap(); .unwrap();
} }
@ -69,3 +55,29 @@ pub async fn execute_get(url: &str) -> Result<String, Box<dyn Error + Send + Syn
Ok(content) Ok(content)
} }
pub async fn get_from_bucket(
state: &AppState,
file_path: &str,
) -> Result<String, Box<dyn Error + Send + Sync>> {
info!("Getting file from bucket: {}", file_path);
if let Some(s3_client) = &state.s3_client {
let bucket_name =
std::env::var("DEFAULT_BUCKET").unwrap_or_else(|_| "default-bucket".to_string());
let response = s3_client
.get_object()
.bucket(&bucket_name)
.key(file_path)
.send()
.await?;
let data = response.body.collect().await?;
let content = String::from_utf8(data.into_bytes().to_vec())?;
Ok(content)
} else {
Err("S3 client not configured".into())
}
}

View file

@ -1,10 +1,11 @@
use crate::shared::models::UserSession;
use crate::shared::state::AppState; use crate::shared::state::AppState;
use crate::{channels::ChannelAdapter, shared::models::UserSession};
use log::info; use log::info;
use rhai::{Dynamic, Engine, EvalAltResult}; use rhai::{Dynamic, Engine, EvalAltResult};
pub fn hear_keyword(_state: &AppState, user: UserSession, engine: &mut Engine) { pub fn hear_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
let session_id = user.id; let session_id = user.id;
let cache = state.redis_client.clone();
engine engine
.register_custom_syntax(&["HEAR", "$ident$"], true, move |_context, inputs| { .register_custom_syntax(&["HEAR", "$ident$"], true, move |_context, inputs| {
@ -18,19 +19,35 @@ pub fn hear_keyword(_state: &AppState, user: UserSession, engine: &mut Engine) {
variable_name variable_name
); );
// Spawn a background task to handle the inputwaiting logic. let cache_clone = cache.clone();
// The actual waiting implementation should be added here. let session_id_clone = session_id;
let var_name_clone = variable_name.clone();
tokio::spawn(async move { tokio::spawn(async move {
log::debug!( log::debug!(
"HEAR: Starting async task for session {} and variable '{}'", "HEAR: Starting async task for session {} and variable '{}'",
session_id, session_id_clone,
variable_name var_name_clone
); );
// TODO: implement actual waiting logic here without using the orchestrator
// For now, just log that we would wait for input if let Some(cache_client) = &cache_clone {
let mut conn = match cache_client.get_multiplexed_async_connection().await {
Ok(conn) => conn,
Err(e) => {
log::error!("Failed to connect to cache: {}", e);
return;
}
};
let key = format!("hear:{}:{}", session_id_clone, var_name_clone);
let _: Result<(), _> = redis::cmd("SET")
.arg(&key)
.arg("waiting")
.query_async(&mut conn)
.await;
}
}); });
// Interrupt the current Rhai evaluation flow until the user input is received.
Err(Box::new(EvalAltResult::ErrorRuntime( Err(Box::new(EvalAltResult::ErrorRuntime(
"Waiting for user input".into(), "Waiting for user input".into(),
rhai::Position::NONE, rhai::Position::NONE,
@ -40,7 +57,6 @@ pub fn hear_keyword(_state: &AppState, user: UserSession, engine: &mut Engine) {
} }
pub fn talk_keyword(state: &AppState, user: UserSession, engine: &mut Engine) { pub fn talk_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
// Import the BotResponse type directly to satisfy diagnostics.
use crate::shared::models::BotResponse; use crate::shared::models::BotResponse;
let state_clone = state.clone(); let state_clone = state.clone();
@ -63,13 +79,11 @@ pub fn talk_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
is_complete: true, is_complete: true,
}; };
// Send response through a channel or queue instead of accessing orchestrator directly let state_for_spawn = state_clone.clone();
let _state_for_spawn = state_clone.clone();
tokio::spawn(async move { tokio::spawn(async move {
// Use a more thread-safe approach to send the message if let Err(e) = state_for_spawn.web_adapter.send_message(response).await {
// This avoids capturing the orchestrator directly which isn't Send + Sync log::error!("Failed to send TALK message: {}", e);
// TODO: Implement proper response handling once response_sender field is added to AppState }
log::debug!("TALK: Would send response: {:?}", response);
}); });
Ok(Dynamic::UNIT) Ok(Dynamic::UNIT)
@ -78,7 +92,7 @@ pub fn talk_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
} }
pub fn set_context_keyword(state: &AppState, user: UserSession, engine: &mut Engine) { pub fn set_context_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
let state_clone = state.clone(); let cache = state.redis_client.clone();
engine engine
.register_custom_syntax( .register_custom_syntax(
@ -91,14 +105,14 @@ pub fn set_context_keyword(state: &AppState, user: UserSession, engine: &mut Eng
let redis_key = format!("context:{}:{}", user.user_id, user.id); let redis_key = format!("context:{}:{}", user.user_id, user.id);
let state_for_redis = state_clone.clone(); let cache_clone = cache.clone();
tokio::spawn(async move { tokio::spawn(async move {
if let Some(redis_client) = &state_for_redis.redis_client { if let Some(cache_client) = &cache_clone {
let mut conn = match redis_client.get_multiplexed_async_connection().await { let mut conn = match cache_client.get_multiplexed_async_connection().await {
Ok(conn) => conn, Ok(conn) => conn,
Err(e) => { Err(e) => {
log::error!("Failed to connect to Redis: {}", e); log::error!("Failed to connect to cache: {}", e);
return; return;
} }
}; };

View file

@ -1,11 +1,10 @@
use crate::shared::models::UserSession; use crate::shared::models::UserSession;
use crate::shared::state::AppState; use crate::shared::state::AppState;
use crate::shared::utils::call_llm;
use log::info; use log::info;
use rhai::{Dynamic, Engine}; use rhai::{Dynamic, Engine};
pub fn llm_keyword(state: &AppState, _user: UserSession, engine: &mut Engine) { pub fn llm_keyword(state: &AppState, _user: UserSession, engine: &mut Engine) {
let ai_config = state.config.clone().unwrap().ai.clone(); let state_clone = state.clone();
engine engine
.register_custom_syntax(&["LLM", "$expr$"], false, move |context, inputs| { .register_custom_syntax(&["LLM", "$expr$"], false, move |context, inputs| {
@ -14,10 +13,18 @@ pub fn llm_keyword(state: &AppState, _user: UserSession, engine: &mut Engine) {
info!("LLM processing text: {}", text_str); info!("LLM processing text: {}", text_str);
let fut = call_llm(&text_str, &ai_config); let state_inner = state_clone.clone();
let fut = async move {
let prompt = text_str;
state_inner
.llm_provider
.generate(&prompt, &serde_json::Value::Null)
.await
.map_err(|e| format!("LLM call failed: {}", e))
};
let result = let result =
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut)) tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))?;
.map_err(|e| format!("LLM call failed: {}", e))?;
Ok(Dynamic::from(result)) Ok(Dynamic::from(result))
}) })

View file

@ -1,7 +1,10 @@
pub mod keywords; use crate::shared::models::UserSession;
use crate::shared::state::AppState;
use log::info;
use rhai::{Dynamic, Engine, EvalAltResult};
use std::sync::Arc;
#[cfg(feature = "email")] pub mod keywords;
use self::keywords::create_draft_keyword;
use self::keywords::create_site::create_site_keyword; use self::keywords::create_site::create_site_keyword;
use self::keywords::find::find_keyword; use self::keywords::find::find_keyword;
@ -17,43 +20,54 @@ use self::keywords::print::print_keyword;
use self::keywords::set::set_keyword; use self::keywords::set::set_keyword;
use self::keywords::set_schedule::set_schedule_keyword; use self::keywords::set_schedule::set_schedule_keyword;
use self::keywords::wait::wait_keyword; use self::keywords::wait::wait_keyword;
use crate::shared::models::UserSession;
use crate::shared::state::AppState; #[cfg(feature = "email")]
use log::info; use self::keywords::create_draft_keyword;
use rhai::{Dynamic, Engine, EvalAltResult};
#[cfg(feature = "web_automation")]
use self::keywords::get_website::get_website_keyword;
pub struct ScriptService { pub struct ScriptService {
engine: Engine, engine: Engine,
state: Arc<AppState>,
user: UserSession,
} }
impl ScriptService { impl ScriptService {
pub fn new(state: &AppState, user: UserSession) -> Self { pub fn new(state: Arc<AppState>, user: UserSession) -> Self {
let mut engine = Engine::new(); let mut engine = Engine::new();
engine.set_allow_anonymous_fn(true); engine.set_allow_anonymous_fn(true);
engine.set_allow_looping(true); engine.set_allow_looping(true);
#[cfg(feature = "email")] #[cfg(feature = "email")]
create_draft_keyword(state, user.clone(), &mut engine); create_draft_keyword(&state, user.clone(), &mut engine);
create_site_keyword(state, user.clone(), &mut engine); create_site_keyword(&state, user.clone(), &mut engine);
find_keyword(state, user.clone(), &mut engine); find_keyword(&state, user.clone(), &mut engine);
for_keyword(state, user.clone(), &mut engine); for_keyword(&state, user.clone(), &mut engine);
first_keyword(&mut engine); first_keyword(&mut engine);
last_keyword(&mut engine); last_keyword(&mut engine);
format_keyword(&mut engine); format_keyword(&mut engine);
llm_keyword(state, user.clone(), &mut engine); llm_keyword(&state, user.clone(), &mut engine);
get_keyword(state, user.clone(), &mut engine); get_keyword(&state, user.clone(), &mut engine);
set_keyword(state, user.clone(), &mut engine); set_keyword(&state, user.clone(), &mut engine);
wait_keyword(state, user.clone(), &mut engine); wait_keyword(&state, user.clone(), &mut engine);
print_keyword(state, user.clone(), &mut engine); print_keyword(&state, user.clone(), &mut engine);
on_keyword(state, user.clone(), &mut engine); on_keyword(&state, user.clone(), &mut engine);
set_schedule_keyword(state, user.clone(), &mut engine); set_schedule_keyword(&state, user.clone(), &mut engine);
hear_keyword(state, user.clone(), &mut engine); hear_keyword(&state, user.clone(), &mut engine);
talk_keyword(state, user.clone(), &mut engine); talk_keyword(&state, user.clone(), &mut engine);
set_context_keyword(state, user.clone(), &mut engine); set_context_keyword(&state, user.clone(), &mut engine);
ScriptService { engine } #[cfg(feature = "web_automation")]
get_website_keyword(&state, user.clone(), &mut engine);
ScriptService {
engine,
state,
user,
}
} }
fn preprocess_basic_script(&self, script: &str) -> String { fn preprocess_basic_script(&self, script: &str) -> String {

View file

@ -6,40 +6,20 @@ use serde_json;
use std::collections::HashMap; use std::collections::HashMap;
use std::fs; use std::fs;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::{mpsc, Mutex}; use tokio::sync::mpsc;
use uuid::Uuid; use uuid::Uuid;
use crate::auth::AuthService;
use crate::channels::ChannelAdapter; use crate::channels::ChannelAdapter;
use crate::llm::LLMProvider;
use crate::session::SessionManager;
use crate::shared::models::{BotResponse, UserMessage, UserSession}; use crate::shared::models::{BotResponse, UserMessage, UserSession};
use crate::tools::ToolManager; use crate::shared::state::AppState;
pub struct BotOrchestrator { pub struct BotOrchestrator {
pub session_manager: Arc<Mutex<SessionManager>>, pub state: Arc<AppState>,
tool_manager: Arc<ToolManager>,
llm_provider: Arc<dyn LLMProvider>,
auth_service: Arc<Mutex<AuthService>>,
pub channels: HashMap<String, Arc<dyn ChannelAdapter>>,
response_channels: Arc<Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
} }
impl BotOrchestrator { impl BotOrchestrator {
pub fn new( pub fn new(state: Arc<AppState>) -> Self {
session_manager: SessionManager, Self { state }
tool_manager: ToolManager,
llm_provider: Arc<dyn LLMProvider>,
auth_service: AuthService,
) -> Self {
Self {
session_manager: Arc::new(Mutex::new(session_manager)),
tool_manager: Arc::new(tool_manager),
llm_provider,
auth_service: Arc::new(Mutex::new(auth_service)),
channels: HashMap::new(),
response_channels: Arc::new(Mutex::new(HashMap::new())),
}
} }
pub async fn handle_user_input( pub async fn handle_user_input(
@ -47,18 +27,22 @@ impl BotOrchestrator {
session_id: Uuid, session_id: Uuid,
user_input: &str, user_input: &str,
) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
let mut session_manager = self.session_manager.lock().await; let mut session_manager = self.state.session_manager.lock().await;
session_manager.provide_input(session_id, user_input.to_string())?; session_manager.provide_input(session_id, user_input.to_string())?;
Ok(None) Ok(None)
} }
pub async fn is_waiting_for_input(&self, session_id: Uuid) -> bool { pub async fn is_waiting_for_input(&self, session_id: Uuid) -> bool {
let session_manager = self.session_manager.lock().await; let session_manager = self.state.session_manager.lock().await;
session_manager.is_waiting_for_input(&session_id) session_manager.is_waiting_for_input(&session_id)
} }
pub fn add_channel(&mut self, channel_type: &str, adapter: Arc<dyn ChannelAdapter>) { pub fn add_channel(&self, channel_type: &str, adapter: Arc<dyn ChannelAdapter>) {
self.channels.insert(channel_type.to_string(), adapter); self.state
.channels
.lock()
.unwrap()
.insert(channel_type.to_string(), adapter);
} }
pub async fn register_response_channel( pub async fn register_response_channel(
@ -66,7 +50,8 @@ impl BotOrchestrator {
session_id: String, session_id: String,
sender: mpsc::Sender<BotResponse>, sender: mpsc::Sender<BotResponse>,
) { ) {
self.response_channels self.state
.response_channels
.lock() .lock()
.await .await
.insert(session_id, sender); .insert(session_id, sender);
@ -78,7 +63,7 @@ impl BotOrchestrator {
bot_id: &str, bot_id: &str,
mode: i32, mode: i32,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let mut session_manager = self.session_manager.lock().await; let mut session_manager = self.state.session_manager.lock().await;
session_manager.update_answer_mode(user_id, bot_id, mode)?; session_manager.update_answer_mode(user_id, bot_id, mode)?;
Ok(()) Ok(())
} }
@ -97,10 +82,15 @@ impl BotOrchestrator {
.unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap()); .unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap());
let session = { let session = {
let mut session_manager = self.session_manager.lock().await; let mut session_manager = self.state.session_manager.lock().await;
match session_manager.get_user_session(user_id, bot_id)? { match session_manager.get_user_session(user_id, bot_id)? {
Some(session) => session, Some(session) => session,
None => session_manager.create_session(user_id, bot_id, "New Conversation")?, None => {
let new_session =
session_manager.create_session(user_id, bot_id, "New Conversation")?;
Self::run_start_script(&new_session, Arc::clone(&self.state)).await;
new_session
}
} }
}; };
@ -113,7 +103,7 @@ impl BotOrchestrator {
variable_name, session.id variable_name, session.id
); );
if let Some(adapter) = self.channels.get(&message.channel) { if let Some(adapter) = self.state.channels.lock().unwrap().get(&message.channel) {
let ack_response = BotResponse { let ack_response = BotResponse {
bot_id: message.bot_id.clone(), bot_id: message.bot_id.clone(),
user_id: message.user_id.clone(), user_id: message.user_id.clone(),
@ -131,7 +121,7 @@ impl BotOrchestrator {
} }
if session.answer_mode == 1 && session.current_tool.is_some() { if session.answer_mode == 1 && session.current_tool.is_some() {
self.tool_manager.provide_user_response( self.state.tool_manager.provide_user_response(
&message.user_id, &message.user_id,
&message.bot_id, &message.bot_id,
message.content.clone(), message.content.clone(),
@ -140,7 +130,7 @@ impl BotOrchestrator {
} }
{ {
let mut session_manager = self.session_manager.lock().await; let mut session_manager = self.state.session_manager.lock().await;
session_manager.save_message( session_manager.save_message(
session.id, session.id,
user_id, user_id,
@ -153,7 +143,7 @@ impl BotOrchestrator {
let response_content = self.direct_mode_handler(&message, &session).await?; let response_content = self.direct_mode_handler(&message, &session).await?;
{ {
let mut session_manager = self.session_manager.lock().await; let mut session_manager = self.state.session_manager.lock().await;
session_manager.save_message(session.id, user_id, 2, &response_content, 1)?; session_manager.save_message(session.id, user_id, 2, &response_content, 1)?;
} }
@ -168,7 +158,7 @@ impl BotOrchestrator {
is_complete: true, is_complete: true,
}; };
if let Some(adapter) = self.channels.get(&message.channel) { if let Some(adapter) = self.state.channels.lock().unwrap().get(&message.channel) {
adapter.send_message(bot_response).await?; adapter.send_message(bot_response).await?;
} }
@ -180,21 +170,19 @@ impl BotOrchestrator {
message: &UserMessage, message: &UserMessage,
session: &UserSession, session: &UserSession,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
// Retrieve conversation history while holding a mutable lock on the session manager.
let history = { let history = {
let mut session_manager = self.session_manager.lock().await; let mut session_manager = self.state.session_manager.lock().await;
session_manager.get_conversation_history(session.id, session.user_id)? session_manager.get_conversation_history(session.id, session.user_id)?
}; };
// Build the prompt from the conversation history.
let mut prompt = String::new(); let mut prompt = String::new();
for (role, content) in history { for (role, content) in history {
prompt.push_str(&format!("{}: {}\n", role, content)); prompt.push_str(&format!("{}: {}\n", role, content));
} }
prompt.push_str(&format!("User: {}\nAssistant:", message.content)); prompt.push_str(&format!("User: {}\nAssistant:", message.content));
// Generate the assistant's response using the LLM provider. self.state
self.llm_provider .llm_provider
.generate(&prompt, &serde_json::Value::Null) .generate(&prompt, &serde_json::Value::Null)
.await .await
} }
@ -206,31 +194,31 @@ impl BotOrchestrator {
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
info!("Streaming response for user: {}", message.user_id); info!("Streaming response for user: {}", message.user_id);
// Parse identifiers, falling back to safe defaults.
let mut user_id = Uuid::parse_str(&message.user_id).unwrap_or_else(|_| Uuid::new_v4()); let mut user_id = Uuid::parse_str(&message.user_id).unwrap_or_else(|_| Uuid::new_v4());
let bot_id = Uuid::parse_str(&message.bot_id).unwrap_or_else(|_| Uuid::nil()); let bot_id = Uuid::parse_str(&message.bot_id).unwrap_or_else(|_| Uuid::nil());
let mut auth = self.auth_service.lock().await; let mut auth = self.state.auth_service.lock().await;
let user_exists = auth.get_user_by_id(user_id)?; let user_exists = auth.get_user_by_id(user_id)?;
if user_exists.is_none() { if user_exists.is_none() {
// User does not exist, invoke Authentication service to create them
user_id = auth.create_user("anonymous1", "anonymous@local", "password")?; user_id = auth.create_user("anonymous1", "anonymous@local", "password")?;
} else { } else {
user_id = user_exists.unwrap().id; user_id = user_exists.unwrap().id;
} }
// Retrieve an existing session or create a new one.
let session = { let session = {
let mut sm = self.session_manager.lock().await; let mut sm = self.state.session_manager.lock().await;
match sm.get_user_session(user_id, bot_id)? { match sm.get_user_session(user_id, bot_id)? {
Some(sess) => sess, Some(sess) => sess,
None => sm.create_session(user_id, bot_id, "New Conversation")?, None => {
let new_session = sm.create_session(user_id, bot_id, "New Conversation")?;
Self::run_start_script(&new_session, Arc::clone(&self.state)).await;
new_session
}
} }
}; };
// If the session is awaiting tool input, forward the user's answer to the tool manager.
if session.answer_mode == 1 && session.current_tool.is_some() { if session.answer_mode == 1 && session.current_tool.is_some() {
self.tool_manager.provide_user_response( self.state.tool_manager.provide_user_response(
&message.user_id, &message.user_id,
&message.bot_id, &message.bot_id,
message.content.clone(), message.content.clone(),
@ -238,9 +226,8 @@ impl BotOrchestrator {
return Ok(()); return Ok(());
} }
// Persist the incoming user message.
{ {
let mut sm = self.session_manager.lock().await; let mut sm = self.state.session_manager.lock().await;
sm.save_message( sm.save_message(
session.id, session.id,
user_id, user_id,
@ -250,9 +237,8 @@ impl BotOrchestrator {
)?; )?;
} }
// Build the prompt from the conversation history.
let prompt = { let prompt = {
let mut sm = self.session_manager.lock().await; let mut sm = self.state.session_manager.lock().await;
let history = sm.get_conversation_history(session.id, user_id)?; let history = sm.get_conversation_history(session.id, user_id)?;
let mut p = String::new(); let mut p = String::new();
for (role, content) in history { for (role, content) in history {
@ -262,11 +248,9 @@ impl BotOrchestrator {
p p
}; };
// Set up a channel for the streaming LLM output.
let (stream_tx, mut stream_rx) = mpsc::channel::<String>(100); let (stream_tx, mut stream_rx) = mpsc::channel::<String>(100);
let llm = self.llm_provider.clone(); let llm = self.state.llm_provider.clone();
// Spawn the LLM streaming task.
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = llm if let Err(e) = llm
.generate_stream(&prompt, &serde_json::Value::Null, stream_tx) .generate_stream(&prompt, &serde_json::Value::Null, stream_tx)
@ -276,7 +260,6 @@ impl BotOrchestrator {
} }
}); });
// Forward each chunk to the client as it arrives.
let mut full_response = String::new(); let mut full_response = String::new();
while let Some(chunk) = stream_rx.recv().await { while let Some(chunk) = stream_rx.recv().await {
full_response.push_str(&chunk); full_response.push_str(&chunk);
@ -293,18 +276,15 @@ impl BotOrchestrator {
}; };
if response_tx.send(partial).await.is_err() { if response_tx.send(partial).await.is_err() {
// Receiver has been dropped; stop streaming.
break; break;
} }
} }
// Save the complete assistant reply.
{ {
let mut sm = self.session_manager.lock().await; let mut sm = self.state.session_manager.lock().await;
sm.save_message(session.id, user_id, 2, &full_response, 1)?; sm.save_message(session.id, user_id, 2, &full_response, 1)?;
} }
// Notify the client that the stream is finished.
let final_msg = BotResponse { let final_msg = BotResponse {
bot_id: message.bot_id, bot_id: message.bot_id,
user_id: message.user_id, user_id: message.user_id,
@ -324,7 +304,7 @@ impl BotOrchestrator {
&self, &self,
user_id: Uuid, user_id: Uuid,
) -> Result<Vec<UserSession>, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<Vec<UserSession>, Box<dyn std::error::Error + Send + Sync>> {
let mut session_manager = self.session_manager.lock().await; let mut session_manager = self.state.session_manager.lock().await;
session_manager.get_user_sessions(user_id) session_manager.get_user_sessions(user_id)
} }
@ -333,7 +313,7 @@ impl BotOrchestrator {
session_id: Uuid, session_id: Uuid,
user_id: Uuid, user_id: Uuid,
) -> Result<Vec<(String, String)>, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<Vec<(String, String)>, Box<dyn std::error::Error + Send + Sync>> {
let mut session_manager = self.session_manager.lock().await; let mut session_manager = self.state.session_manager.lock().await;
session_manager.get_conversation_history(session_id, user_id) session_manager.get_conversation_history(session_id, user_id)
} }
@ -351,15 +331,20 @@ impl BotOrchestrator {
.unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap()); .unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap());
let session = { let session = {
let mut session_manager = self.session_manager.lock().await; let mut session_manager = self.state.session_manager.lock().await;
match session_manager.get_user_session(user_id, bot_id)? { match session_manager.get_user_session(user_id, bot_id)? {
Some(session) => session, Some(session) => session,
None => session_manager.create_session(user_id, bot_id, "New Conversation")?, None => {
let new_session =
session_manager.create_session(user_id, bot_id, "New Conversation")?;
Self::run_start_script(&new_session, Arc::clone(&self.state)).await;
new_session
}
} }
}; };
{ {
let mut session_manager = self.session_manager.lock().await; let mut session_manager = self.state.session_manager.lock().await;
session_manager.save_message( session_manager.save_message(
session.id, session.id,
user_id, user_id,
@ -370,17 +355,24 @@ impl BotOrchestrator {
} }
let is_tool_waiting = self let is_tool_waiting = self
.state
.tool_manager .tool_manager
.is_tool_waiting(&message.session_id) .is_tool_waiting(&message.session_id)
.await .await
.unwrap_or(false); .unwrap_or(false);
if is_tool_waiting { if is_tool_waiting {
self.tool_manager self.state
.tool_manager
.provide_input(&message.session_id, &message.content) .provide_input(&message.session_id, &message.content)
.await?; .await?;
if let Ok(tool_output) = self.tool_manager.get_tool_output(&message.session_id).await { if let Ok(tool_output) = self
.state
.tool_manager
.get_tool_output(&message.session_id)
.await
{
for output in tool_output { for output in tool_output {
let bot_response = BotResponse { let bot_response = BotResponse {
bot_id: message.bot_id.clone(), bot_id: message.bot_id.clone(),
@ -393,7 +385,8 @@ impl BotOrchestrator {
is_complete: true, is_complete: true,
}; };
if let Some(adapter) = self.channels.get(&message.channel) { if let Some(adapter) = self.state.channels.lock().unwrap().get(&message.channel)
{
adapter.send_message(bot_response).await?; adapter.send_message(bot_response).await?;
} }
} }
@ -406,12 +399,13 @@ impl BotOrchestrator {
|| message.content.to_lowercase().contains("math") || message.content.to_lowercase().contains("math")
{ {
match self match self
.state
.tool_manager .tool_manager
.execute_tool("calculator", &message.session_id, &message.user_id) .execute_tool("calculator", &message.session_id, &message.user_id)
.await .await
{ {
Ok(tool_result) => { Ok(tool_result) => {
let mut session_manager = self.session_manager.lock().await; let mut session_manager = self.state.session_manager.lock().await;
session_manager.save_message(session.id, user_id, 2, &tool_result.output, 2)?; session_manager.save_message(session.id, user_id, 2, &tool_result.output, 2)?;
tool_result.output tool_result.output
@ -421,7 +415,7 @@ impl BotOrchestrator {
} }
} }
} else { } else {
let available_tools = self.tool_manager.list_tools(); let available_tools = self.state.tool_manager.list_tools();
let tools_context = if !available_tools.is_empty() { let tools_context = if !available_tools.is_empty() {
format!("\n\nAvailable tools: {}. If the user needs calculations, suggest using the calculator tool.", available_tools.join(", ")) format!("\n\nAvailable tools: {}. If the user needs calculations, suggest using the calculator tool.", available_tools.join(", "))
} else { } else {
@ -430,13 +424,14 @@ impl BotOrchestrator {
let full_prompt = format!("{}{}", message.content, tools_context); let full_prompt = format!("{}{}", message.content, tools_context);
self.llm_provider self.state
.llm_provider
.generate(&full_prompt, &serde_json::Value::Null) .generate(&full_prompt, &serde_json::Value::Null)
.await? .await?
}; };
{ {
let mut session_manager = self.session_manager.lock().await; let mut session_manager = self.state.session_manager.lock().await;
session_manager.save_message(session.id, user_id, 2, &response, 1)?; session_manager.save_message(session.id, user_id, 2, &response, 1)?;
} }
@ -451,25 +446,63 @@ impl BotOrchestrator {
is_complete: true, is_complete: true,
}; };
if let Some(adapter) = self.channels.get(&message.channel) { if let Some(adapter) = self.state.channels.lock().unwrap().get(&message.channel) {
adapter.send_message(bot_response).await?; adapter.send_message(bot_response).await?;
} }
Ok(()) Ok(())
} }
async fn run_start_script(session: &UserSession, state: Arc<AppState>) {
let start_script = r#"
TALK "Welcome to General Bots!"
HEAR name
TALK "Hello, " + name
text = GET "default.pdf"
SET CONTEXT text
resume = LLM "Build a resume from " + text
"#;
info!("Running start.bas for session: {}", session.id);
let session_clone = session.clone();
let state_clone = state.clone();
tokio::spawn(async move {
let state_for_run = state_clone.clone();
if let Err(e) = crate::basic::ScriptService::new(state_clone, session_clone.clone())
.compile(start_script)
.and_then(|ast| {
crate::basic::ScriptService::new(state_for_run, session_clone.clone()).run(&ast)
})
{
log::error!("Failed to run start.bas: {}", e);
}
});
}
}
impl Default for BotOrchestrator {
fn default() -> Self {
Self {
state: Arc::new(AppState::default()),
}
}
} }
#[actix_web::get("/ws")] #[actix_web::get("/ws")]
async fn websocket_handler( async fn websocket_handler(
req: HttpRequest, req: HttpRequest,
stream: web::Payload, stream: web::Payload,
data: web::Data<crate::shared::state::AppState>, data: web::Data<AppState>,
) -> Result<HttpResponse, actix_web::Error> { ) -> Result<HttpResponse, actix_web::Error> {
let (res, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?; let (res, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?;
let session_id = Uuid::new_v4().to_string(); let session_id = Uuid::new_v4().to_string();
let (tx, mut rx) = mpsc::channel::<BotResponse>(100); let (tx, mut rx) = mpsc::channel::<BotResponse>(100);
data.orchestrator let orchestrator = BotOrchestrator::new(Arc::clone(&data));
orchestrator
.register_response_channel(session_id.clone(), tx.clone()) .register_response_channel(session_id.clone(), tx.clone())
.await; .await;
data.web_adapter data.web_adapter
@ -479,7 +512,6 @@ async fn websocket_handler(
.add_connection(session_id.clone(), tx.clone()) .add_connection(session_id.clone(), tx.clone())
.await; .await;
let orchestrator = data.orchestrator.clone();
let web_adapter = data.web_adapter.clone(); let web_adapter = data.web_adapter.clone();
actix_web::rt::spawn(async move { actix_web::rt::spawn(async move {
@ -523,7 +555,7 @@ async fn websocket_handler(
#[actix_web::get("/api/whatsapp/webhook")] #[actix_web::get("/api/whatsapp/webhook")]
async fn whatsapp_webhook_verify( async fn whatsapp_webhook_verify(
data: web::Data<crate::shared::state::AppState>, data: web::Data<AppState>,
web::Query(params): web::Query<HashMap<String, String>>, web::Query(params): web::Query<HashMap<String, String>>,
) -> Result<HttpResponse> { ) -> Result<HttpResponse> {
let empty = String::new(); let empty = String::new();
@ -539,7 +571,7 @@ async fn whatsapp_webhook_verify(
#[actix_web::post("/api/whatsapp/webhook")] #[actix_web::post("/api/whatsapp/webhook")]
async fn whatsapp_webhook( async fn whatsapp_webhook(
data: web::Data<crate::shared::state::AppState>, data: web::Data<AppState>,
payload: web::Json<crate::whatsapp::WhatsAppMessage>, payload: web::Json<crate::whatsapp::WhatsAppMessage>,
) -> Result<HttpResponse> { ) -> Result<HttpResponse> {
match data match data
@ -549,7 +581,8 @@ async fn whatsapp_webhook(
{ {
Ok(user_messages) => { Ok(user_messages) => {
for user_message in user_messages { for user_message in user_messages {
if let Err(e) = data.orchestrator.process_message(user_message).await { let orchestrator = BotOrchestrator::new(Arc::clone(&data));
if let Err(e) = orchestrator.process_message(user_message).await {
log::error!("Error processing WhatsApp message: {}", e); log::error!("Error processing WhatsApp message: {}", e);
} }
} }
@ -564,7 +597,7 @@ async fn whatsapp_webhook(
#[actix_web::post("/api/voice/start")] #[actix_web::post("/api/voice/start")]
async fn voice_start( async fn voice_start(
data: web::Data<crate::shared::state::AppState>, data: web::Data<AppState>,
info: web::Json<serde_json::Value>, info: web::Json<serde_json::Value>,
) -> Result<HttpResponse> { ) -> Result<HttpResponse> {
let session_id = info let session_id = info
@ -593,7 +626,7 @@ async fn voice_start(
#[actix_web::post("/api/voice/stop")] #[actix_web::post("/api/voice/stop")]
async fn voice_stop( async fn voice_stop(
data: web::Data<crate::shared::state::AppState>, data: web::Data<AppState>,
info: web::Json<serde_json::Value>, info: web::Json<serde_json::Value>,
) -> Result<HttpResponse> { ) -> Result<HttpResponse> {
let session_id = info let session_id = info
@ -611,7 +644,7 @@ async fn voice_stop(
} }
#[actix_web::post("/api/sessions")] #[actix_web::post("/api/sessions")]
async fn create_session(_data: web::Data<crate::shared::state::AppState>) -> Result<HttpResponse> { async fn create_session(_data: web::Data<AppState>) -> Result<HttpResponse> {
let session_id = Uuid::new_v4(); let session_id = Uuid::new_v4();
Ok(HttpResponse::Ok().json(serde_json::json!({ Ok(HttpResponse::Ok().json(serde_json::json!({
"session_id": session_id, "session_id": session_id,
@ -621,9 +654,10 @@ async fn create_session(_data: web::Data<crate::shared::state::AppState>) -> Res
} }
#[actix_web::get("/api/sessions")] #[actix_web::get("/api/sessions")]
async fn get_sessions(data: web::Data<crate::shared::state::AppState>) -> Result<HttpResponse> { async fn get_sessions(data: web::Data<AppState>) -> Result<HttpResponse> {
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap(); let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
match data.orchestrator.get_user_sessions(user_id).await { let orchestrator = BotOrchestrator::new(Arc::clone(&data));
match orchestrator.get_user_sessions(user_id).await {
Ok(sessions) => Ok(HttpResponse::Ok().json(sessions)), Ok(sessions) => Ok(HttpResponse::Ok().json(sessions)),
Err(e) => { Err(e) => {
Ok(HttpResponse::InternalServerError() Ok(HttpResponse::InternalServerError()
@ -634,22 +668,24 @@ async fn get_sessions(data: web::Data<crate::shared::state::AppState>) -> Result
#[actix_web::get("/api/sessions/{session_id}")] #[actix_web::get("/api/sessions/{session_id}")]
async fn get_session_history( async fn get_session_history(
data: web::Data<crate::shared::state::AppState>, data: web::Data<AppState>,
path: web::Path<String>, path: web::Path<String>,
) -> Result<HttpResponse> { ) -> Result<HttpResponse> {
let session_id = path.into_inner(); let session_id = path.into_inner();
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap(); let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
match Uuid::parse_str(&session_id) { match Uuid::parse_str(&session_id) {
Ok(session_uuid) => match data Ok(session_uuid) => {
.orchestrator let orchestrator = BotOrchestrator::new(Arc::clone(&data));
match orchestrator
.get_conversation_history(session_uuid, user_id) .get_conversation_history(session_uuid, user_id)
.await .await
{ {
Ok(history) => Ok(HttpResponse::Ok().json(history)), Ok(history) => Ok(HttpResponse::Ok().json(history)),
Err(e) => Ok(HttpResponse::InternalServerError() Err(e) => Ok(HttpResponse::InternalServerError()
.json(serde_json::json!({"error": e.to_string()}))), .json(serde_json::json!({"error": e.to_string()}))),
}, }
}
Err(_) => { Err(_) => {
Ok(HttpResponse::BadRequest().json(serde_json::json!({"error": "Invalid session ID"}))) Ok(HttpResponse::BadRequest().json(serde_json::json!({"error": "Invalid session ID"})))
} }
@ -658,7 +694,7 @@ async fn get_session_history(
#[actix_web::post("/api/set_mode")] #[actix_web::post("/api/set_mode")]
async fn set_mode_handler( async fn set_mode_handler(
data: web::Data<crate::shared::state::AppState>, data: web::Data<AppState>,
info: web::Json<HashMap<String, String>>, info: web::Json<HashMap<String, String>>,
) -> Result<HttpResponse> { ) -> Result<HttpResponse> {
let default_user = "default_user".to_string(); let default_user = "default_user".to_string();
@ -671,8 +707,8 @@ async fn set_mode_handler(
let mode = mode_str.parse::<i32>().unwrap_or(0); let mode = mode_str.parse::<i32>().unwrap_or(0);
if let Err(e) = data let orchestrator = BotOrchestrator::new(Arc::clone(&data));
.orchestrator if let Err(e) = orchestrator
.set_user_answer_mode(user_id, bot_id, mode) .set_user_answer_mode(user_id, bot_id, mode)
.await .await
{ {

View file

@ -5,6 +5,7 @@ use actix_web::middleware::Logger;
use actix_web::{web, App, HttpServer}; use actix_web::{web, App, HttpServer};
use dotenvy::dotenv; use dotenvy::dotenv;
use log::info; use log::info;
use std::collections::HashMap;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
mod auth; mod auth;
@ -105,7 +106,7 @@ async fn main() -> std::io::Result<()> {
}; };
let tool_manager = Arc::new(tools::ToolManager::new()); let tool_manager = Arc::new(tools::ToolManager::new());
let llm_provider = Arc::new(llm::OpenAIClient::new( let llm_provider = Arc::new(crate::llm::OpenAIClient::new(
"empty".to_string(), "empty".to_string(),
Some("http://localhost:8081".to_string()), Some("http://localhost:8081".to_string()),
)); ));
@ -125,69 +126,40 @@ async fn main() -> std::io::Result<()> {
let tool_api = Arc::new(tools::ToolApi::new()); let tool_api = Arc::new(tools::ToolApi::new());
let base_app_state = AppState { let session_manager = Arc::new(tokio::sync::Mutex::new(session::SessionManager::new(
diesel::Connection::establish(&cfg.database_url()).unwrap(),
redis_client.clone(),
)));
let auth_service = Arc::new(tokio::sync::Mutex::new(auth::AuthService::new(
diesel::Connection::establish(&cfg.database_url()).unwrap(),
redis_client.clone(),
)));
let app_state = Arc::new(AppState {
s3_client: None, s3_client: None,
config: Some(cfg.clone()), config: Some(cfg.clone()),
conn: db_pool.clone(), conn: db_pool.clone(),
custom_conn: db_custom_pool.clone(), custom_conn: db_custom_pool.clone(),
redis_client: redis_client.clone(), redis_client: redis_client.clone(),
orchestrator: Arc::new(bot::BotOrchestrator::new( session_manager: session_manager.clone(),
session::SessionManager::new( tool_manager: tool_manager.clone(),
diesel::Connection::establish(&cfg.database_url()).unwrap(), llm_provider: llm_provider.clone(),
redis_client.clone(), auth_service: auth_service.clone(),
), channels: Arc::new(Mutex::new(HashMap::new())),
(*tool_manager).clone(), response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
llm_provider.clone(), web_adapter: web_adapter.clone(),
auth::AuthService::new( voice_adapter: voice_adapter.clone(),
diesel::Connection::establish(&cfg.database_url()).unwrap(), whatsapp_adapter: whatsapp_adapter.clone(),
redis_client.clone(), tool_api: tool_api.clone(),
), });
)),
web_adapter,
voice_adapter,
whatsapp_adapter,
tool_api,
};
info!( info!(
"Starting server on {}:{}", "Starting server on {}:{}",
config.server.host, config.server.port config.server.host, config.server.port
); );
let closure_config = config.clone();
HttpServer::new(move || { HttpServer::new(move || {
let cfg = closure_config.clone();
let auth_service = auth::AuthService::new(
diesel::Connection::establish(&cfg.database_url()).unwrap(),
redis_client.clone(),
);
let session_manager = session::SessionManager::new(
diesel::Connection::establish(&cfg.database_url()).unwrap(),
redis_client.clone(),
);
let orchestrator = Arc::new(bot::BotOrchestrator::new(
session_manager,
(*tool_manager).clone(),
llm_provider.clone(),
auth_service,
));
let app_state = AppState {
s3_client: base_app_state.s3_client.clone(),
config: base_app_state.config.clone(),
conn: base_app_state.conn.clone(),
custom_conn: base_app_state.custom_conn.clone(),
redis_client: base_app_state.redis_client.clone(),
orchestrator,
web_adapter: base_app_state.web_adapter.clone(),
voice_adapter: base_app_state.voice_adapter.clone(),
whatsapp_adapter: base_app_state.whatsapp_adapter.clone(),
tool_api: base_app_state.tool_api.clone(),
};
let cors = Cors::default() let cors = Cors::default()
.allow_any_origin() .allow_any_origin()
.allow_any_method() .allow_any_method()
@ -200,7 +172,7 @@ async fn main() -> std::io::Result<()> {
.wrap(cors) .wrap(cors)
.wrap(Logger::default()) .wrap(Logger::default())
.wrap(Logger::new("HTTP REQUEST: %a %{User-Agent}i")) .wrap(Logger::new("HTTP REQUEST: %a %{User-Agent}i"))
.app_data(web::Data::new(app_state_clone)); .app_data(web::Data::from(app_state_clone));
app = app app = app
.service(upload_file) .service(upload_file)

View file

@ -1,21 +1,33 @@
use crate::bot::BotOrchestrator; use crate::auth::AuthService;
use crate::channels::{VoiceAdapter, WebChannelAdapter}; use crate::channels::{ChannelAdapter, VoiceAdapter, WebChannelAdapter};
use crate::config::AppConfig; use crate::config::AppConfig;
use crate::tools::ToolApi; use crate::llm::LLMProvider;
use crate::session::SessionManager;
use crate::tools::{ToolApi, ToolManager};
use crate::whatsapp::WhatsAppAdapter; use crate::whatsapp::WhatsAppAdapter;
use diesel::PgConnection; use diesel::{Connection, PgConnection};
use redis::Client; use redis::Client;
use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::sync::Mutex; use std::sync::Mutex;
use tokio::sync::mpsc;
use crate::shared::models::BotResponse;
pub struct AppState { pub struct AppState {
pub s3_client: Option<aws_sdk_s3::Client>, pub s3_client: Option<aws_sdk_s3::Client>,
pub config: Option<AppConfig>, pub config: Option<AppConfig>,
pub conn: Arc<Mutex<PgConnection>>, pub conn: Arc<Mutex<PgConnection>>,
pub custom_conn: Arc<Mutex<PgConnection>>, pub custom_conn: Arc<Mutex<PgConnection>>,
pub redis_client: Option<Arc<Client>>, pub redis_client: Option<Arc<Client>>,
pub orchestrator: Arc<BotOrchestrator>,
pub session_manager: Arc<tokio::sync::Mutex<SessionManager>>,
pub tool_manager: Arc<ToolManager>,
pub llm_provider: Arc<dyn LLMProvider>,
pub auth_service: Arc<tokio::sync::Mutex<AuthService>>,
pub channels: Arc<Mutex<HashMap<String, Arc<dyn ChannelAdapter>>>>,
pub response_channels: Arc<tokio::sync::Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
pub web_adapter: Arc<WebChannelAdapter>, pub web_adapter: Arc<WebChannelAdapter>,
pub voice_adapter: Arc<VoiceAdapter>, pub voice_adapter: Arc<VoiceAdapter>,
pub whatsapp_adapter: Arc<WhatsAppAdapter>, pub whatsapp_adapter: Arc<WhatsAppAdapter>,
@ -30,7 +42,12 @@ impl Clone for AppState {
conn: Arc::clone(&self.conn), conn: Arc::clone(&self.conn),
custom_conn: Arc::clone(&self.custom_conn), custom_conn: Arc::clone(&self.custom_conn),
redis_client: self.redis_client.clone(), redis_client: self.redis_client.clone(),
orchestrator: Arc::clone(&self.orchestrator), session_manager: Arc::clone(&self.session_manager),
tool_manager: Arc::clone(&self.tool_manager),
llm_provider: Arc::clone(&self.llm_provider),
auth_service: Arc::clone(&self.auth_service),
channels: Arc::clone(&self.channels),
response_channels: Arc::clone(&self.response_channels),
web_adapter: Arc::clone(&self.web_adapter), web_adapter: Arc::clone(&self.web_adapter),
voice_adapter: Arc::clone(&self.voice_adapter), voice_adapter: Arc::clone(&self.voice_adapter),
whatsapp_adapter: Arc::clone(&self.whatsapp_adapter), whatsapp_adapter: Arc::clone(&self.whatsapp_adapter),
@ -38,3 +55,46 @@ impl Clone for AppState {
} }
} }
} }
impl Default for AppState {
fn default() -> Self {
Self {
s3_client: None,
config: None,
conn: Arc::new(Mutex::new(
diesel::PgConnection::establish("postgres://localhost/test").unwrap(),
)),
custom_conn: Arc::new(Mutex::new(
diesel::PgConnection::establish("postgres://localhost/test").unwrap(),
)),
redis_client: None,
session_manager: Arc::new(tokio::sync::Mutex::new(SessionManager::new(
diesel::PgConnection::establish("postgres://localhost/test").unwrap(),
None,
))),
tool_manager: Arc::new(ToolManager::new()),
llm_provider: Arc::new(crate::llm::OpenAIClient::new(
"empty".to_string(),
Some("http://localhost:8081".to_string()),
)),
auth_service: Arc::new(tokio::sync::Mutex::new(AuthService::new(
diesel::PgConnection::establish("postgres://localhost/test").unwrap(),
None,
))),
channels: Arc::new(Mutex::new(HashMap::new())),
response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
web_adapter: Arc::new(WebChannelAdapter::new()),
voice_adapter: Arc::new(VoiceAdapter::new(
"https://livekit.example.com".to_string(),
"api_key".to_string(),
"api_secret".to_string(),
)),
whatsapp_adapter: Arc::new(WhatsAppAdapter::new(
"whatsapp_token".to_string(),
"phone_number_id".to_string(),
"verify_token".to_string(),
)),
tool_api: Arc::new(ToolApi::new()),
}
}
}