Refactor to Arc<AppState> for shared state
- Migrate core services to store Arc<AppState> and use locks - Centralize state in AppState with Arc-wrapped managers - Update handlers to pass Arc<AppState> via web::Data - Add Default for AppState and initialize components in main - Update debug.json program path from gbserver to botserver
This commit is contained in:
parent
e14908fc0b
commit
e00eb40d43
10 changed files with 384 additions and 280 deletions
|
|
@ -5,7 +5,7 @@
|
|||
"command": "cargo",
|
||||
"args": ["build"]
|
||||
},
|
||||
"program": "$ZED_WORKTREE_ROOT/target/debug/gbserver",
|
||||
"program": "$ZED_WORKTREE_ROOT/target/debug/botserver",
|
||||
|
||||
"sourceLanguages": ["rust"],
|
||||
"request": "launch",
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ echo "Consolidated LLM Context" > "$OUTPUT_FILE"
|
|||
prompts=(
|
||||
"../../prompts/dev/shared.md"
|
||||
"../../Cargo.toml"
|
||||
#"../../prompts/dev/fix.md"
|
||||
"../../prompts/dev/generation.md"
|
||||
"../../prompts/dev/fix.md"
|
||||
#"../../prompts/dev/generation.md"
|
||||
)
|
||||
|
||||
for file in "${prompts[@]}"; do
|
||||
|
|
@ -23,18 +23,18 @@ dirs=(
|
|||
#"automation"
|
||||
#"basic"
|
||||
"bot"
|
||||
"channels"
|
||||
#"channels"
|
||||
#"config"
|
||||
"context"
|
||||
#"context"
|
||||
#"email"
|
||||
#"file"
|
||||
"llm"
|
||||
#"llm"
|
||||
#"llm_legacy"
|
||||
#"org"
|
||||
"session"
|
||||
#"session"
|
||||
"shared"
|
||||
#"tests"
|
||||
"tools"
|
||||
#"tools"
|
||||
#"web_automation"
|
||||
#"whatsapp"
|
||||
)
|
||||
|
|
@ -47,10 +47,10 @@ done
|
|||
|
||||
# Also append the specific files you mentioned
|
||||
cat "$PROJECT_ROOT/src/main.rs" >> "$OUTPUT_FILE"
|
||||
cat "$PROJECT_ROOT/src/basic/keywords/hear_talk.rs" >> "$OUTPUT_FILE"
|
||||
|
||||
echo "This BASIC file will run as soon as the conversation is created. " >> "$OUTPUT_FILE"
|
||||
cat "$PROJECT_ROOT/templates/annoucements.gbai/annoucements.gbdialog/start.bas" >> "$OUTPUT_FILE"
|
||||
# cat "$PROJECT_ROOT/src/basic/keywords/hear_talk.rs" >> "$OUTPUT_FILE"
|
||||
# cat "$PROJECT_ROOT/src/basic/keywords/get.rs" >> "$OUTPUT_FILE"
|
||||
cat "$PROJECT_ROOT/src/basic/keywords/llm_keyword.rs" >> "$OUTPUT_FILE"
|
||||
# cat "$PROJECT_ROOT/src/basic/mod.rs" >> "$OUTPUT_FILE"
|
||||
|
||||
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
|
|
|
|||
|
|
@ -10,12 +10,12 @@ use tokio::time::Duration;
|
|||
use uuid::Uuid;
|
||||
|
||||
pub struct AutomationService {
|
||||
state: AppState,
|
||||
state: Arc<AppState>,
|
||||
scripts_dir: String,
|
||||
}
|
||||
|
||||
impl AutomationService {
|
||||
pub fn new(state: AppState, scripts_dir: &str) -> Self {
|
||||
pub fn new(state: Arc<AppState>, scripts_dir: &str) -> Self {
|
||||
Self {
|
||||
state,
|
||||
scripts_dir: scripts_dir.to_string(),
|
||||
|
|
@ -61,13 +61,11 @@ impl AutomationService {
|
|||
|
||||
async fn check_table_changes(&self, automations: &[Automation], since: DateTime<Utc>) {
|
||||
for automation in automations {
|
||||
// Resolve the trigger kind, disambiguating the `from_i32` call.
|
||||
let trigger_kind = match crate::shared::models::TriggerKind::from_i32(automation.kind) {
|
||||
Some(k) => k,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
// We're only interested in table‑change triggers.
|
||||
if !matches!(
|
||||
trigger_kind,
|
||||
TriggerKind::TableUpdate | TriggerKind::TableInsert | TriggerKind::TableDelete
|
||||
|
|
@ -75,50 +73,41 @@ impl AutomationService {
|
|||
continue;
|
||||
}
|
||||
|
||||
// Table name must be present.
|
||||
let table = match &automation.target {
|
||||
Some(t) => t,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
// Choose the appropriate timestamp column.
|
||||
let column = match trigger_kind {
|
||||
TriggerKind::TableInsert => "created_at",
|
||||
_ => "updated_at",
|
||||
};
|
||||
|
||||
// Build a simple COUNT(*) query; alias the column so Diesel can map it directly to i64.
|
||||
let query = format!(
|
||||
"SELECT COUNT(*) as count FROM {} WHERE {} > $1",
|
||||
table, column
|
||||
);
|
||||
|
||||
// Acquire a connection for this query.
|
||||
let mut conn_guard = self.state.conn.lock().unwrap();
|
||||
let conn = &mut *conn_guard;
|
||||
|
||||
// Define a struct to capture the query result
|
||||
#[derive(diesel::QueryableByName)]
|
||||
struct CountResult {
|
||||
#[diesel(sql_type = diesel::sql_types::BigInt)]
|
||||
count: i64,
|
||||
}
|
||||
|
||||
// Execute the query, retrieving a plain i64 count.
|
||||
let count_result = diesel::sql_query(&query)
|
||||
.bind::<diesel::sql_types::Timestamp, _>(since.naive_utc())
|
||||
.get_result::<CountResult>(conn);
|
||||
|
||||
match count_result {
|
||||
Ok(result) if result.count > 0 => {
|
||||
// Release the lock before awaiting asynchronous work.
|
||||
drop(conn_guard);
|
||||
self.execute_action(&automation.param).await;
|
||||
self.update_last_triggered(automation.id).await;
|
||||
}
|
||||
Ok(_result) => {
|
||||
// No relevant rows changed; continue to the next automation.
|
||||
}
|
||||
Ok(_result) => {}
|
||||
Err(e) => {
|
||||
error!("Error checking changes for table '{}': {}", table, e);
|
||||
}
|
||||
|
|
@ -215,7 +204,7 @@ impl AutomationService {
|
|||
updated_at: Utc::now(),
|
||||
};
|
||||
|
||||
let script_service = ScriptService::new(&self.state, user_session);
|
||||
let script_service = ScriptService::new(Arc::clone(&self.state), user_session);
|
||||
let ast = match script_service.compile(&script_content) {
|
||||
Ok(ast) => ast,
|
||||
Err(e) => {
|
||||
|
|
|
|||
|
|
@ -1,59 +1,45 @@
|
|||
use log::info;
|
||||
use crate::shared::state::AppState;
|
||||
use crate::shared::models::UserSession;
|
||||
use crate::shared::state::AppState;
|
||||
use log::info;
|
||||
use reqwest::{self, Client};
|
||||
use rhai::{Dynamic, Engine};
|
||||
use std::error::Error;
|
||||
|
||||
pub fn get_keyword(_state: &AppState, _user: UserSession, engine: &mut Engine) {
|
||||
pub fn get_keyword(state: &AppState, _user: UserSession, engine: &mut Engine) {
|
||||
let state_clone = state.clone();
|
||||
|
||||
engine
|
||||
.register_custom_syntax(
|
||||
&["GET", "$expr$"],
|
||||
false,
|
||||
move |context, inputs| {
|
||||
let url = context.eval_expression_tree(&inputs[0])?;
|
||||
let url_str = url.to_string();
|
||||
.register_custom_syntax(&["GET", "$expr$"], false, move |context, inputs| {
|
||||
let url = context.eval_expression_tree(&inputs[0])?;
|
||||
let url_str = url.to_string();
|
||||
|
||||
if url_str.contains("..") {
|
||||
return Err("URL contains invalid path traversal sequences like '..'.".into());
|
||||
}
|
||||
if url_str.contains("..") {
|
||||
return Err("URL contains invalid path traversal sequences like '..'.".into());
|
||||
}
|
||||
|
||||
let modified_url = if url_str.starts_with("/") {
|
||||
let work_root = std::env::var("WORK_ROOT").unwrap_or_else(|_| "./work".to_string());
|
||||
let full_path = std::path::Path::new(&work_root)
|
||||
.join(url_str.trim_start_matches('/'))
|
||||
.to_string_lossy()
|
||||
.into_owned();
|
||||
let state_for_async = state_clone.clone();
|
||||
let url_for_async = url_str.clone();
|
||||
|
||||
let base_url = "file://";
|
||||
format!("{}{}", base_url, full_path)
|
||||
} else {
|
||||
url_str.to_string()
|
||||
};
|
||||
if url_str.starts_with("https://") {
|
||||
info!("HTTPS GET request: {}", url_for_async);
|
||||
|
||||
if modified_url.starts_with("https://") {
|
||||
info!("HTTPS GET request: {}", modified_url);
|
||||
let fut = execute_get(&url_for_async);
|
||||
let result =
|
||||
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))
|
||||
.map_err(|e| format!("HTTP request failed: {}", e))?;
|
||||
|
||||
let fut = execute_get(&modified_url);
|
||||
let result =
|
||||
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))
|
||||
.map_err(|e| format!("HTTP request failed: {}", e))?;
|
||||
Ok(Dynamic::from(result))
|
||||
} else {
|
||||
info!("Local file GET request from bucket: {}", url_for_async);
|
||||
|
||||
Ok(Dynamic::from(result))
|
||||
} else if modified_url.starts_with("file://") {
|
||||
let file_path = modified_url.trim_start_matches("file://");
|
||||
match std::fs::read_to_string(file_path) {
|
||||
Ok(content) => Ok(Dynamic::from(content)),
|
||||
Err(e) => Err(format!("Failed to read file: {}", e).into()),
|
||||
}
|
||||
} else {
|
||||
Err(
|
||||
format!("GET request failed: URL must begin with 'https://' or 'file://'")
|
||||
.into(),
|
||||
)
|
||||
}
|
||||
},
|
||||
)
|
||||
let fut = get_from_bucket(&state_for_async, &url_for_async);
|
||||
let result =
|
||||
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))
|
||||
.map_err(|e| format!("Bucket GET failed: {}", e))?;
|
||||
|
||||
Ok(Dynamic::from(result))
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
|
|
@ -69,3 +55,29 @@ pub async fn execute_get(url: &str) -> Result<String, Box<dyn Error + Send + Syn
|
|||
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
pub async fn get_from_bucket(
|
||||
state: &AppState,
|
||||
file_path: &str,
|
||||
) -> Result<String, Box<dyn Error + Send + Sync>> {
|
||||
info!("Getting file from bucket: {}", file_path);
|
||||
|
||||
if let Some(s3_client) = &state.s3_client {
|
||||
let bucket_name =
|
||||
std::env::var("DEFAULT_BUCKET").unwrap_or_else(|_| "default-bucket".to_string());
|
||||
|
||||
let response = s3_client
|
||||
.get_object()
|
||||
.bucket(&bucket_name)
|
||||
.key(file_path)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let data = response.body.collect().await?;
|
||||
let content = String::from_utf8(data.into_bytes().to_vec())?;
|
||||
|
||||
Ok(content)
|
||||
} else {
|
||||
Err("S3 client not configured".into())
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
use crate::shared::models::UserSession;
|
||||
use crate::shared::state::AppState;
|
||||
use crate::{channels::ChannelAdapter, shared::models::UserSession};
|
||||
use log::info;
|
||||
use rhai::{Dynamic, Engine, EvalAltResult};
|
||||
|
||||
pub fn hear_keyword(_state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||
pub fn hear_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||
let session_id = user.id;
|
||||
let cache = state.redis_client.clone();
|
||||
|
||||
engine
|
||||
.register_custom_syntax(&["HEAR", "$ident$"], true, move |_context, inputs| {
|
||||
|
|
@ -18,19 +19,35 @@ pub fn hear_keyword(_state: &AppState, user: UserSession, engine: &mut Engine) {
|
|||
variable_name
|
||||
);
|
||||
|
||||
// Spawn a background task to handle the input‑waiting logic.
|
||||
// The actual waiting implementation should be added here.
|
||||
let cache_clone = cache.clone();
|
||||
let session_id_clone = session_id;
|
||||
let var_name_clone = variable_name.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
log::debug!(
|
||||
"HEAR: Starting async task for session {} and variable '{}'",
|
||||
session_id,
|
||||
variable_name
|
||||
session_id_clone,
|
||||
var_name_clone
|
||||
);
|
||||
// TODO: implement actual waiting logic here without using the orchestrator
|
||||
// For now, just log that we would wait for input
|
||||
|
||||
if let Some(cache_client) = &cache_clone {
|
||||
let mut conn = match cache_client.get_multiplexed_async_connection().await {
|
||||
Ok(conn) => conn,
|
||||
Err(e) => {
|
||||
log::error!("Failed to connect to cache: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let key = format!("hear:{}:{}", session_id_clone, var_name_clone);
|
||||
let _: Result<(), _> = redis::cmd("SET")
|
||||
.arg(&key)
|
||||
.arg("waiting")
|
||||
.query_async(&mut conn)
|
||||
.await;
|
||||
}
|
||||
});
|
||||
|
||||
// Interrupt the current Rhai evaluation flow until the user input is received.
|
||||
Err(Box::new(EvalAltResult::ErrorRuntime(
|
||||
"Waiting for user input".into(),
|
||||
rhai::Position::NONE,
|
||||
|
|
@ -40,7 +57,6 @@ pub fn hear_keyword(_state: &AppState, user: UserSession, engine: &mut Engine) {
|
|||
}
|
||||
|
||||
pub fn talk_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||
// Import the BotResponse type directly to satisfy diagnostics.
|
||||
use crate::shared::models::BotResponse;
|
||||
|
||||
let state_clone = state.clone();
|
||||
|
|
@ -63,13 +79,11 @@ pub fn talk_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
|||
is_complete: true,
|
||||
};
|
||||
|
||||
// Send response through a channel or queue instead of accessing orchestrator directly
|
||||
let _state_for_spawn = state_clone.clone();
|
||||
let state_for_spawn = state_clone.clone();
|
||||
tokio::spawn(async move {
|
||||
// Use a more thread-safe approach to send the message
|
||||
// This avoids capturing the orchestrator directly which isn't Send + Sync
|
||||
// TODO: Implement proper response handling once response_sender field is added to AppState
|
||||
log::debug!("TALK: Would send response: {:?}", response);
|
||||
if let Err(e) = state_for_spawn.web_adapter.send_message(response).await {
|
||||
log::error!("Failed to send TALK message: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Dynamic::UNIT)
|
||||
|
|
@ -78,7 +92,7 @@ pub fn talk_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
|||
}
|
||||
|
||||
pub fn set_context_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||
let state_clone = state.clone();
|
||||
let cache = state.redis_client.clone();
|
||||
|
||||
engine
|
||||
.register_custom_syntax(
|
||||
|
|
@ -91,14 +105,14 @@ pub fn set_context_keyword(state: &AppState, user: UserSession, engine: &mut Eng
|
|||
|
||||
let redis_key = format!("context:{}:{}", user.user_id, user.id);
|
||||
|
||||
let state_for_redis = state_clone.clone();
|
||||
let cache_clone = cache.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Some(redis_client) = &state_for_redis.redis_client {
|
||||
let mut conn = match redis_client.get_multiplexed_async_connection().await {
|
||||
if let Some(cache_client) = &cache_clone {
|
||||
let mut conn = match cache_client.get_multiplexed_async_connection().await {
|
||||
Ok(conn) => conn,
|
||||
Err(e) => {
|
||||
log::error!("Failed to connect to Redis: {}", e);
|
||||
log::error!("Failed to connect to cache: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
use crate::shared::models::UserSession;
|
||||
use crate::shared::state::AppState;
|
||||
use crate::shared::utils::call_llm;
|
||||
use log::info;
|
||||
use rhai::{Dynamic, Engine};
|
||||
|
||||
pub fn llm_keyword(state: &AppState, _user: UserSession, engine: &mut Engine) {
|
||||
let ai_config = state.config.clone().unwrap().ai.clone();
|
||||
let state_clone = state.clone();
|
||||
|
||||
engine
|
||||
.register_custom_syntax(&["LLM", "$expr$"], false, move |context, inputs| {
|
||||
|
|
@ -14,10 +13,18 @@ pub fn llm_keyword(state: &AppState, _user: UserSession, engine: &mut Engine) {
|
|||
|
||||
info!("LLM processing text: {}", text_str);
|
||||
|
||||
let fut = call_llm(&text_str, &ai_config);
|
||||
let state_inner = state_clone.clone();
|
||||
let fut = async move {
|
||||
let prompt = text_str;
|
||||
state_inner
|
||||
.llm_provider
|
||||
.generate(&prompt, &serde_json::Value::Null)
|
||||
.await
|
||||
.map_err(|e| format!("LLM call failed: {}", e))
|
||||
};
|
||||
|
||||
let result =
|
||||
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))
|
||||
.map_err(|e| format!("LLM call failed: {}", e))?;
|
||||
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))?;
|
||||
|
||||
Ok(Dynamic::from(result))
|
||||
})
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
pub mod keywords;
|
||||
use crate::shared::models::UserSession;
|
||||
use crate::shared::state::AppState;
|
||||
use log::info;
|
||||
use rhai::{Dynamic, Engine, EvalAltResult};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[cfg(feature = "email")]
|
||||
use self::keywords::create_draft_keyword;
|
||||
pub mod keywords;
|
||||
|
||||
use self::keywords::create_site::create_site_keyword;
|
||||
use self::keywords::find::find_keyword;
|
||||
|
|
@ -17,43 +20,54 @@ use self::keywords::print::print_keyword;
|
|||
use self::keywords::set::set_keyword;
|
||||
use self::keywords::set_schedule::set_schedule_keyword;
|
||||
use self::keywords::wait::wait_keyword;
|
||||
use crate::shared::models::UserSession;
|
||||
use crate::shared::state::AppState;
|
||||
use log::info;
|
||||
use rhai::{Dynamic, Engine, EvalAltResult};
|
||||
|
||||
#[cfg(feature = "email")]
|
||||
use self::keywords::create_draft_keyword;
|
||||
|
||||
#[cfg(feature = "web_automation")]
|
||||
use self::keywords::get_website::get_website_keyword;
|
||||
|
||||
pub struct ScriptService {
|
||||
engine: Engine,
|
||||
state: Arc<AppState>,
|
||||
user: UserSession,
|
||||
}
|
||||
|
||||
impl ScriptService {
|
||||
pub fn new(state: &AppState, user: UserSession) -> Self {
|
||||
pub fn new(state: Arc<AppState>, user: UserSession) -> Self {
|
||||
let mut engine = Engine::new();
|
||||
|
||||
engine.set_allow_anonymous_fn(true);
|
||||
engine.set_allow_looping(true);
|
||||
|
||||
#[cfg(feature = "email")]
|
||||
create_draft_keyword(state, user.clone(), &mut engine);
|
||||
create_draft_keyword(&state, user.clone(), &mut engine);
|
||||
|
||||
create_site_keyword(state, user.clone(), &mut engine);
|
||||
find_keyword(state, user.clone(), &mut engine);
|
||||
for_keyword(state, user.clone(), &mut engine);
|
||||
create_site_keyword(&state, user.clone(), &mut engine);
|
||||
find_keyword(&state, user.clone(), &mut engine);
|
||||
for_keyword(&state, user.clone(), &mut engine);
|
||||
first_keyword(&mut engine);
|
||||
last_keyword(&mut engine);
|
||||
format_keyword(&mut engine);
|
||||
llm_keyword(state, user.clone(), &mut engine);
|
||||
get_keyword(state, user.clone(), &mut engine);
|
||||
set_keyword(state, user.clone(), &mut engine);
|
||||
wait_keyword(state, user.clone(), &mut engine);
|
||||
print_keyword(state, user.clone(), &mut engine);
|
||||
on_keyword(state, user.clone(), &mut engine);
|
||||
set_schedule_keyword(state, user.clone(), &mut engine);
|
||||
hear_keyword(state, user.clone(), &mut engine);
|
||||
talk_keyword(state, user.clone(), &mut engine);
|
||||
set_context_keyword(state, user.clone(), &mut engine);
|
||||
llm_keyword(&state, user.clone(), &mut engine);
|
||||
get_keyword(&state, user.clone(), &mut engine);
|
||||
set_keyword(&state, user.clone(), &mut engine);
|
||||
wait_keyword(&state, user.clone(), &mut engine);
|
||||
print_keyword(&state, user.clone(), &mut engine);
|
||||
on_keyword(&state, user.clone(), &mut engine);
|
||||
set_schedule_keyword(&state, user.clone(), &mut engine);
|
||||
hear_keyword(&state, user.clone(), &mut engine);
|
||||
talk_keyword(&state, user.clone(), &mut engine);
|
||||
set_context_keyword(&state, user.clone(), &mut engine);
|
||||
|
||||
ScriptService { engine }
|
||||
#[cfg(feature = "web_automation")]
|
||||
get_website_keyword(&state, user.clone(), &mut engine);
|
||||
|
||||
ScriptService {
|
||||
engine,
|
||||
state,
|
||||
user,
|
||||
}
|
||||
}
|
||||
|
||||
fn preprocess_basic_script(&self, script: &str) -> String {
|
||||
|
|
|
|||
236
src/bot/mod.rs
236
src/bot/mod.rs
|
|
@ -6,40 +6,20 @@ use serde_json;
|
|||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use tokio::sync::mpsc;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::auth::AuthService;
|
||||
use crate::channels::ChannelAdapter;
|
||||
use crate::llm::LLMProvider;
|
||||
use crate::session::SessionManager;
|
||||
use crate::shared::models::{BotResponse, UserMessage, UserSession};
|
||||
use crate::tools::ToolManager;
|
||||
use crate::shared::state::AppState;
|
||||
|
||||
pub struct BotOrchestrator {
|
||||
pub session_manager: Arc<Mutex<SessionManager>>,
|
||||
tool_manager: Arc<ToolManager>,
|
||||
llm_provider: Arc<dyn LLMProvider>,
|
||||
auth_service: Arc<Mutex<AuthService>>,
|
||||
pub channels: HashMap<String, Arc<dyn ChannelAdapter>>,
|
||||
response_channels: Arc<Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
|
||||
pub state: Arc<AppState>,
|
||||
}
|
||||
|
||||
impl BotOrchestrator {
|
||||
pub fn new(
|
||||
session_manager: SessionManager,
|
||||
tool_manager: ToolManager,
|
||||
llm_provider: Arc<dyn LLMProvider>,
|
||||
auth_service: AuthService,
|
||||
) -> Self {
|
||||
Self {
|
||||
session_manager: Arc::new(Mutex::new(session_manager)),
|
||||
tool_manager: Arc::new(tool_manager),
|
||||
llm_provider,
|
||||
auth_service: Arc::new(Mutex::new(auth_service)),
|
||||
channels: HashMap::new(),
|
||||
response_channels: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
pub fn new(state: Arc<AppState>) -> Self {
|
||||
Self { state }
|
||||
}
|
||||
|
||||
pub async fn handle_user_input(
|
||||
|
|
@ -47,18 +27,22 @@ impl BotOrchestrator {
|
|||
session_id: Uuid,
|
||||
user_input: &str,
|
||||
) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let mut session_manager = self.session_manager.lock().await;
|
||||
let mut session_manager = self.state.session_manager.lock().await;
|
||||
session_manager.provide_input(session_id, user_input.to_string())?;
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
pub async fn is_waiting_for_input(&self, session_id: Uuid) -> bool {
|
||||
let session_manager = self.session_manager.lock().await;
|
||||
let session_manager = self.state.session_manager.lock().await;
|
||||
session_manager.is_waiting_for_input(&session_id)
|
||||
}
|
||||
|
||||
pub fn add_channel(&mut self, channel_type: &str, adapter: Arc<dyn ChannelAdapter>) {
|
||||
self.channels.insert(channel_type.to_string(), adapter);
|
||||
pub fn add_channel(&self, channel_type: &str, adapter: Arc<dyn ChannelAdapter>) {
|
||||
self.state
|
||||
.channels
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(channel_type.to_string(), adapter);
|
||||
}
|
||||
|
||||
pub async fn register_response_channel(
|
||||
|
|
@ -66,7 +50,8 @@ impl BotOrchestrator {
|
|||
session_id: String,
|
||||
sender: mpsc::Sender<BotResponse>,
|
||||
) {
|
||||
self.response_channels
|
||||
self.state
|
||||
.response_channels
|
||||
.lock()
|
||||
.await
|
||||
.insert(session_id, sender);
|
||||
|
|
@ -78,7 +63,7 @@ impl BotOrchestrator {
|
|||
bot_id: &str,
|
||||
mode: i32,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let mut session_manager = self.session_manager.lock().await;
|
||||
let mut session_manager = self.state.session_manager.lock().await;
|
||||
session_manager.update_answer_mode(user_id, bot_id, mode)?;
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -97,10 +82,15 @@ impl BotOrchestrator {
|
|||
.unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap());
|
||||
|
||||
let session = {
|
||||
let mut session_manager = self.session_manager.lock().await;
|
||||
let mut session_manager = self.state.session_manager.lock().await;
|
||||
match session_manager.get_user_session(user_id, bot_id)? {
|
||||
Some(session) => session,
|
||||
None => session_manager.create_session(user_id, bot_id, "New Conversation")?,
|
||||
None => {
|
||||
let new_session =
|
||||
session_manager.create_session(user_id, bot_id, "New Conversation")?;
|
||||
Self::run_start_script(&new_session, Arc::clone(&self.state)).await;
|
||||
new_session
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -113,7 +103,7 @@ impl BotOrchestrator {
|
|||
variable_name, session.id
|
||||
);
|
||||
|
||||
if let Some(adapter) = self.channels.get(&message.channel) {
|
||||
if let Some(adapter) = self.state.channels.lock().unwrap().get(&message.channel) {
|
||||
let ack_response = BotResponse {
|
||||
bot_id: message.bot_id.clone(),
|
||||
user_id: message.user_id.clone(),
|
||||
|
|
@ -131,7 +121,7 @@ impl BotOrchestrator {
|
|||
}
|
||||
|
||||
if session.answer_mode == 1 && session.current_tool.is_some() {
|
||||
self.tool_manager.provide_user_response(
|
||||
self.state.tool_manager.provide_user_response(
|
||||
&message.user_id,
|
||||
&message.bot_id,
|
||||
message.content.clone(),
|
||||
|
|
@ -140,7 +130,7 @@ impl BotOrchestrator {
|
|||
}
|
||||
|
||||
{
|
||||
let mut session_manager = self.session_manager.lock().await;
|
||||
let mut session_manager = self.state.session_manager.lock().await;
|
||||
session_manager.save_message(
|
||||
session.id,
|
||||
user_id,
|
||||
|
|
@ -153,7 +143,7 @@ impl BotOrchestrator {
|
|||
let response_content = self.direct_mode_handler(&message, &session).await?;
|
||||
|
||||
{
|
||||
let mut session_manager = self.session_manager.lock().await;
|
||||
let mut session_manager = self.state.session_manager.lock().await;
|
||||
session_manager.save_message(session.id, user_id, 2, &response_content, 1)?;
|
||||
}
|
||||
|
||||
|
|
@ -168,7 +158,7 @@ impl BotOrchestrator {
|
|||
is_complete: true,
|
||||
};
|
||||
|
||||
if let Some(adapter) = self.channels.get(&message.channel) {
|
||||
if let Some(adapter) = self.state.channels.lock().unwrap().get(&message.channel) {
|
||||
adapter.send_message(bot_response).await?;
|
||||
}
|
||||
|
||||
|
|
@ -180,21 +170,19 @@ impl BotOrchestrator {
|
|||
message: &UserMessage,
|
||||
session: &UserSession,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// Retrieve conversation history while holding a mutable lock on the session manager.
|
||||
let history = {
|
||||
let mut session_manager = self.session_manager.lock().await;
|
||||
let mut session_manager = self.state.session_manager.lock().await;
|
||||
session_manager.get_conversation_history(session.id, session.user_id)?
|
||||
};
|
||||
|
||||
// Build the prompt from the conversation history.
|
||||
let mut prompt = String::new();
|
||||
for (role, content) in history {
|
||||
prompt.push_str(&format!("{}: {}\n", role, content));
|
||||
}
|
||||
prompt.push_str(&format!("User: {}\nAssistant:", message.content));
|
||||
|
||||
// Generate the assistant's response using the LLM provider.
|
||||
self.llm_provider
|
||||
self.state
|
||||
.llm_provider
|
||||
.generate(&prompt, &serde_json::Value::Null)
|
||||
.await
|
||||
}
|
||||
|
|
@ -206,31 +194,31 @@ impl BotOrchestrator {
|
|||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
info!("Streaming response for user: {}", message.user_id);
|
||||
|
||||
// Parse identifiers, falling back to safe defaults.
|
||||
let mut user_id = Uuid::parse_str(&message.user_id).unwrap_or_else(|_| Uuid::new_v4());
|
||||
let bot_id = Uuid::parse_str(&message.bot_id).unwrap_or_else(|_| Uuid::nil());
|
||||
let mut auth = self.auth_service.lock().await;
|
||||
let mut auth = self.state.auth_service.lock().await;
|
||||
let user_exists = auth.get_user_by_id(user_id)?;
|
||||
|
||||
if user_exists.is_none() {
|
||||
// User does not exist, invoke Authentication service to create them
|
||||
user_id = auth.create_user("anonymous1", "anonymous@local", "password")?;
|
||||
} else {
|
||||
user_id = user_exists.unwrap().id;
|
||||
}
|
||||
|
||||
// Retrieve an existing session or create a new one.
|
||||
let session = {
|
||||
let mut sm = self.session_manager.lock().await;
|
||||
let mut sm = self.state.session_manager.lock().await;
|
||||
match sm.get_user_session(user_id, bot_id)? {
|
||||
Some(sess) => sess,
|
||||
None => sm.create_session(user_id, bot_id, "New Conversation")?,
|
||||
None => {
|
||||
let new_session = sm.create_session(user_id, bot_id, "New Conversation")?;
|
||||
Self::run_start_script(&new_session, Arc::clone(&self.state)).await;
|
||||
new_session
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// If the session is awaiting tool input, forward the user's answer to the tool manager.
|
||||
if session.answer_mode == 1 && session.current_tool.is_some() {
|
||||
self.tool_manager.provide_user_response(
|
||||
self.state.tool_manager.provide_user_response(
|
||||
&message.user_id,
|
||||
&message.bot_id,
|
||||
message.content.clone(),
|
||||
|
|
@ -238,9 +226,8 @@ impl BotOrchestrator {
|
|||
return Ok(());
|
||||
}
|
||||
|
||||
// Persist the incoming user message.
|
||||
{
|
||||
let mut sm = self.session_manager.lock().await;
|
||||
let mut sm = self.state.session_manager.lock().await;
|
||||
sm.save_message(
|
||||
session.id,
|
||||
user_id,
|
||||
|
|
@ -250,9 +237,8 @@ impl BotOrchestrator {
|
|||
)?;
|
||||
}
|
||||
|
||||
// Build the prompt from the conversation history.
|
||||
let prompt = {
|
||||
let mut sm = self.session_manager.lock().await;
|
||||
let mut sm = self.state.session_manager.lock().await;
|
||||
let history = sm.get_conversation_history(session.id, user_id)?;
|
||||
let mut p = String::new();
|
||||
for (role, content) in history {
|
||||
|
|
@ -262,11 +248,9 @@ impl BotOrchestrator {
|
|||
p
|
||||
};
|
||||
|
||||
// Set up a channel for the streaming LLM output.
|
||||
let (stream_tx, mut stream_rx) = mpsc::channel::<String>(100);
|
||||
let llm = self.llm_provider.clone();
|
||||
let llm = self.state.llm_provider.clone();
|
||||
|
||||
// Spawn the LLM streaming task.
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = llm
|
||||
.generate_stream(&prompt, &serde_json::Value::Null, stream_tx)
|
||||
|
|
@ -276,7 +260,6 @@ impl BotOrchestrator {
|
|||
}
|
||||
});
|
||||
|
||||
// Forward each chunk to the client as it arrives.
|
||||
let mut full_response = String::new();
|
||||
while let Some(chunk) = stream_rx.recv().await {
|
||||
full_response.push_str(&chunk);
|
||||
|
|
@ -293,18 +276,15 @@ impl BotOrchestrator {
|
|||
};
|
||||
|
||||
if response_tx.send(partial).await.is_err() {
|
||||
// Receiver has been dropped; stop streaming.
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Save the complete assistant reply.
|
||||
{
|
||||
let mut sm = self.session_manager.lock().await;
|
||||
let mut sm = self.state.session_manager.lock().await;
|
||||
sm.save_message(session.id, user_id, 2, &full_response, 1)?;
|
||||
}
|
||||
|
||||
// Notify the client that the stream is finished.
|
||||
let final_msg = BotResponse {
|
||||
bot_id: message.bot_id,
|
||||
user_id: message.user_id,
|
||||
|
|
@ -324,7 +304,7 @@ impl BotOrchestrator {
|
|||
&self,
|
||||
user_id: Uuid,
|
||||
) -> Result<Vec<UserSession>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let mut session_manager = self.session_manager.lock().await;
|
||||
let mut session_manager = self.state.session_manager.lock().await;
|
||||
session_manager.get_user_sessions(user_id)
|
||||
}
|
||||
|
||||
|
|
@ -333,7 +313,7 @@ impl BotOrchestrator {
|
|||
session_id: Uuid,
|
||||
user_id: Uuid,
|
||||
) -> Result<Vec<(String, String)>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let mut session_manager = self.session_manager.lock().await;
|
||||
let mut session_manager = self.state.session_manager.lock().await;
|
||||
session_manager.get_conversation_history(session_id, user_id)
|
||||
}
|
||||
|
||||
|
|
@ -351,15 +331,20 @@ impl BotOrchestrator {
|
|||
.unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap());
|
||||
|
||||
let session = {
|
||||
let mut session_manager = self.session_manager.lock().await;
|
||||
let mut session_manager = self.state.session_manager.lock().await;
|
||||
match session_manager.get_user_session(user_id, bot_id)? {
|
||||
Some(session) => session,
|
||||
None => session_manager.create_session(user_id, bot_id, "New Conversation")?,
|
||||
None => {
|
||||
let new_session =
|
||||
session_manager.create_session(user_id, bot_id, "New Conversation")?;
|
||||
Self::run_start_script(&new_session, Arc::clone(&self.state)).await;
|
||||
new_session
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
{
|
||||
let mut session_manager = self.session_manager.lock().await;
|
||||
let mut session_manager = self.state.session_manager.lock().await;
|
||||
session_manager.save_message(
|
||||
session.id,
|
||||
user_id,
|
||||
|
|
@ -370,17 +355,24 @@ impl BotOrchestrator {
|
|||
}
|
||||
|
||||
let is_tool_waiting = self
|
||||
.state
|
||||
.tool_manager
|
||||
.is_tool_waiting(&message.session_id)
|
||||
.await
|
||||
.unwrap_or(false);
|
||||
|
||||
if is_tool_waiting {
|
||||
self.tool_manager
|
||||
self.state
|
||||
.tool_manager
|
||||
.provide_input(&message.session_id, &message.content)
|
||||
.await?;
|
||||
|
||||
if let Ok(tool_output) = self.tool_manager.get_tool_output(&message.session_id).await {
|
||||
if let Ok(tool_output) = self
|
||||
.state
|
||||
.tool_manager
|
||||
.get_tool_output(&message.session_id)
|
||||
.await
|
||||
{
|
||||
for output in tool_output {
|
||||
let bot_response = BotResponse {
|
||||
bot_id: message.bot_id.clone(),
|
||||
|
|
@ -393,7 +385,8 @@ impl BotOrchestrator {
|
|||
is_complete: true,
|
||||
};
|
||||
|
||||
if let Some(adapter) = self.channels.get(&message.channel) {
|
||||
if let Some(adapter) = self.state.channels.lock().unwrap().get(&message.channel)
|
||||
{
|
||||
adapter.send_message(bot_response).await?;
|
||||
}
|
||||
}
|
||||
|
|
@ -406,12 +399,13 @@ impl BotOrchestrator {
|
|||
|| message.content.to_lowercase().contains("math")
|
||||
{
|
||||
match self
|
||||
.state
|
||||
.tool_manager
|
||||
.execute_tool("calculator", &message.session_id, &message.user_id)
|
||||
.await
|
||||
{
|
||||
Ok(tool_result) => {
|
||||
let mut session_manager = self.session_manager.lock().await;
|
||||
let mut session_manager = self.state.session_manager.lock().await;
|
||||
session_manager.save_message(session.id, user_id, 2, &tool_result.output, 2)?;
|
||||
|
||||
tool_result.output
|
||||
|
|
@ -421,7 +415,7 @@ impl BotOrchestrator {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
let available_tools = self.tool_manager.list_tools();
|
||||
let available_tools = self.state.tool_manager.list_tools();
|
||||
let tools_context = if !available_tools.is_empty() {
|
||||
format!("\n\nAvailable tools: {}. If the user needs calculations, suggest using the calculator tool.", available_tools.join(", "))
|
||||
} else {
|
||||
|
|
@ -430,13 +424,14 @@ impl BotOrchestrator {
|
|||
|
||||
let full_prompt = format!("{}{}", message.content, tools_context);
|
||||
|
||||
self.llm_provider
|
||||
self.state
|
||||
.llm_provider
|
||||
.generate(&full_prompt, &serde_json::Value::Null)
|
||||
.await?
|
||||
};
|
||||
|
||||
{
|
||||
let mut session_manager = self.session_manager.lock().await;
|
||||
let mut session_manager = self.state.session_manager.lock().await;
|
||||
session_manager.save_message(session.id, user_id, 2, &response, 1)?;
|
||||
}
|
||||
|
||||
|
|
@ -451,25 +446,63 @@ impl BotOrchestrator {
|
|||
is_complete: true,
|
||||
};
|
||||
|
||||
if let Some(adapter) = self.channels.get(&message.channel) {
|
||||
if let Some(adapter) = self.state.channels.lock().unwrap().get(&message.channel) {
|
||||
adapter.send_message(bot_response).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_start_script(session: &UserSession, state: Arc<AppState>) {
|
||||
let start_script = r#"
|
||||
TALK "Welcome to General Bots!"
|
||||
HEAR name
|
||||
TALK "Hello, " + name
|
||||
|
||||
text = GET "default.pdf"
|
||||
SET CONTEXT text
|
||||
|
||||
resume = LLM "Build a resume from " + text
|
||||
"#;
|
||||
|
||||
info!("Running start.bas for session: {}", session.id);
|
||||
|
||||
let session_clone = session.clone();
|
||||
let state_clone = state.clone();
|
||||
tokio::spawn(async move {
|
||||
let state_for_run = state_clone.clone();
|
||||
if let Err(e) = crate::basic::ScriptService::new(state_clone, session_clone.clone())
|
||||
.compile(start_script)
|
||||
.and_then(|ast| {
|
||||
crate::basic::ScriptService::new(state_for_run, session_clone.clone()).run(&ast)
|
||||
})
|
||||
{
|
||||
log::error!("Failed to run start.bas: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for BotOrchestrator {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
state: Arc::new(AppState::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[actix_web::get("/ws")]
|
||||
async fn websocket_handler(
|
||||
req: HttpRequest,
|
||||
stream: web::Payload,
|
||||
data: web::Data<crate::shared::state::AppState>,
|
||||
data: web::Data<AppState>,
|
||||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
let (res, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?;
|
||||
let session_id = Uuid::new_v4().to_string();
|
||||
let (tx, mut rx) = mpsc::channel::<BotResponse>(100);
|
||||
|
||||
data.orchestrator
|
||||
let orchestrator = BotOrchestrator::new(Arc::clone(&data));
|
||||
orchestrator
|
||||
.register_response_channel(session_id.clone(), tx.clone())
|
||||
.await;
|
||||
data.web_adapter
|
||||
|
|
@ -479,7 +512,6 @@ async fn websocket_handler(
|
|||
.add_connection(session_id.clone(), tx.clone())
|
||||
.await;
|
||||
|
||||
let orchestrator = data.orchestrator.clone();
|
||||
let web_adapter = data.web_adapter.clone();
|
||||
|
||||
actix_web::rt::spawn(async move {
|
||||
|
|
@ -523,7 +555,7 @@ async fn websocket_handler(
|
|||
|
||||
#[actix_web::get("/api/whatsapp/webhook")]
|
||||
async fn whatsapp_webhook_verify(
|
||||
data: web::Data<crate::shared::state::AppState>,
|
||||
data: web::Data<AppState>,
|
||||
web::Query(params): web::Query<HashMap<String, String>>,
|
||||
) -> Result<HttpResponse> {
|
||||
let empty = String::new();
|
||||
|
|
@ -539,7 +571,7 @@ async fn whatsapp_webhook_verify(
|
|||
|
||||
#[actix_web::post("/api/whatsapp/webhook")]
|
||||
async fn whatsapp_webhook(
|
||||
data: web::Data<crate::shared::state::AppState>,
|
||||
data: web::Data<AppState>,
|
||||
payload: web::Json<crate::whatsapp::WhatsAppMessage>,
|
||||
) -> Result<HttpResponse> {
|
||||
match data
|
||||
|
|
@ -549,7 +581,8 @@ async fn whatsapp_webhook(
|
|||
{
|
||||
Ok(user_messages) => {
|
||||
for user_message in user_messages {
|
||||
if let Err(e) = data.orchestrator.process_message(user_message).await {
|
||||
let orchestrator = BotOrchestrator::new(Arc::clone(&data));
|
||||
if let Err(e) = orchestrator.process_message(user_message).await {
|
||||
log::error!("Error processing WhatsApp message: {}", e);
|
||||
}
|
||||
}
|
||||
|
|
@ -564,7 +597,7 @@ async fn whatsapp_webhook(
|
|||
|
||||
#[actix_web::post("/api/voice/start")]
|
||||
async fn voice_start(
|
||||
data: web::Data<crate::shared::state::AppState>,
|
||||
data: web::Data<AppState>,
|
||||
info: web::Json<serde_json::Value>,
|
||||
) -> Result<HttpResponse> {
|
||||
let session_id = info
|
||||
|
|
@ -593,7 +626,7 @@ async fn voice_start(
|
|||
|
||||
#[actix_web::post("/api/voice/stop")]
|
||||
async fn voice_stop(
|
||||
data: web::Data<crate::shared::state::AppState>,
|
||||
data: web::Data<AppState>,
|
||||
info: web::Json<serde_json::Value>,
|
||||
) -> Result<HttpResponse> {
|
||||
let session_id = info
|
||||
|
|
@ -611,7 +644,7 @@ async fn voice_stop(
|
|||
}
|
||||
|
||||
#[actix_web::post("/api/sessions")]
|
||||
async fn create_session(_data: web::Data<crate::shared::state::AppState>) -> Result<HttpResponse> {
|
||||
async fn create_session(_data: web::Data<AppState>) -> Result<HttpResponse> {
|
||||
let session_id = Uuid::new_v4();
|
||||
Ok(HttpResponse::Ok().json(serde_json::json!({
|
||||
"session_id": session_id,
|
||||
|
|
@ -621,9 +654,10 @@ async fn create_session(_data: web::Data<crate::shared::state::AppState>) -> Res
|
|||
}
|
||||
|
||||
#[actix_web::get("/api/sessions")]
|
||||
async fn get_sessions(data: web::Data<crate::shared::state::AppState>) -> Result<HttpResponse> {
|
||||
async fn get_sessions(data: web::Data<AppState>) -> Result<HttpResponse> {
|
||||
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
|
||||
match data.orchestrator.get_user_sessions(user_id).await {
|
||||
let orchestrator = BotOrchestrator::new(Arc::clone(&data));
|
||||
match orchestrator.get_user_sessions(user_id).await {
|
||||
Ok(sessions) => Ok(HttpResponse::Ok().json(sessions)),
|
||||
Err(e) => {
|
||||
Ok(HttpResponse::InternalServerError()
|
||||
|
|
@ -634,22 +668,24 @@ async fn get_sessions(data: web::Data<crate::shared::state::AppState>) -> Result
|
|||
|
||||
#[actix_web::get("/api/sessions/{session_id}")]
|
||||
async fn get_session_history(
|
||||
data: web::Data<crate::shared::state::AppState>,
|
||||
data: web::Data<AppState>,
|
||||
path: web::Path<String>,
|
||||
) -> Result<HttpResponse> {
|
||||
let session_id = path.into_inner();
|
||||
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
|
||||
|
||||
match Uuid::parse_str(&session_id) {
|
||||
Ok(session_uuid) => match data
|
||||
.orchestrator
|
||||
.get_conversation_history(session_uuid, user_id)
|
||||
.await
|
||||
{
|
||||
Ok(history) => Ok(HttpResponse::Ok().json(history)),
|
||||
Err(e) => Ok(HttpResponse::InternalServerError()
|
||||
.json(serde_json::json!({"error": e.to_string()}))),
|
||||
},
|
||||
Ok(session_uuid) => {
|
||||
let orchestrator = BotOrchestrator::new(Arc::clone(&data));
|
||||
match orchestrator
|
||||
.get_conversation_history(session_uuid, user_id)
|
||||
.await
|
||||
{
|
||||
Ok(history) => Ok(HttpResponse::Ok().json(history)),
|
||||
Err(e) => Ok(HttpResponse::InternalServerError()
|
||||
.json(serde_json::json!({"error": e.to_string()}))),
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
Ok(HttpResponse::BadRequest().json(serde_json::json!({"error": "Invalid session ID"})))
|
||||
}
|
||||
|
|
@ -658,7 +694,7 @@ async fn get_session_history(
|
|||
|
||||
#[actix_web::post("/api/set_mode")]
|
||||
async fn set_mode_handler(
|
||||
data: web::Data<crate::shared::state::AppState>,
|
||||
data: web::Data<AppState>,
|
||||
info: web::Json<HashMap<String, String>>,
|
||||
) -> Result<HttpResponse> {
|
||||
let default_user = "default_user".to_string();
|
||||
|
|
@ -671,8 +707,8 @@ async fn set_mode_handler(
|
|||
|
||||
let mode = mode_str.parse::<i32>().unwrap_or(0);
|
||||
|
||||
if let Err(e) = data
|
||||
.orchestrator
|
||||
let orchestrator = BotOrchestrator::new(Arc::clone(&data));
|
||||
if let Err(e) = orchestrator
|
||||
.set_user_answer_mode(user_id, bot_id, mode)
|
||||
.await
|
||||
{
|
||||
|
|
|
|||
78
src/main.rs
78
src/main.rs
|
|
@ -5,6 +5,7 @@ use actix_web::middleware::Logger;
|
|||
use actix_web::{web, App, HttpServer};
|
||||
use dotenvy::dotenv;
|
||||
use log::info;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
mod auth;
|
||||
|
|
@ -105,7 +106,7 @@ async fn main() -> std::io::Result<()> {
|
|||
};
|
||||
|
||||
let tool_manager = Arc::new(tools::ToolManager::new());
|
||||
let llm_provider = Arc::new(llm::OpenAIClient::new(
|
||||
let llm_provider = Arc::new(crate::llm::OpenAIClient::new(
|
||||
"empty".to_string(),
|
||||
Some("http://localhost:8081".to_string()),
|
||||
));
|
||||
|
|
@ -125,69 +126,40 @@ async fn main() -> std::io::Result<()> {
|
|||
|
||||
let tool_api = Arc::new(tools::ToolApi::new());
|
||||
|
||||
let base_app_state = AppState {
|
||||
let session_manager = Arc::new(tokio::sync::Mutex::new(session::SessionManager::new(
|
||||
diesel::Connection::establish(&cfg.database_url()).unwrap(),
|
||||
redis_client.clone(),
|
||||
)));
|
||||
|
||||
let auth_service = Arc::new(tokio::sync::Mutex::new(auth::AuthService::new(
|
||||
diesel::Connection::establish(&cfg.database_url()).unwrap(),
|
||||
redis_client.clone(),
|
||||
)));
|
||||
|
||||
let app_state = Arc::new(AppState {
|
||||
s3_client: None,
|
||||
config: Some(cfg.clone()),
|
||||
conn: db_pool.clone(),
|
||||
custom_conn: db_custom_pool.clone(),
|
||||
redis_client: redis_client.clone(),
|
||||
orchestrator: Arc::new(bot::BotOrchestrator::new(
|
||||
session::SessionManager::new(
|
||||
diesel::Connection::establish(&cfg.database_url()).unwrap(),
|
||||
redis_client.clone(),
|
||||
),
|
||||
(*tool_manager).clone(),
|
||||
llm_provider.clone(),
|
||||
auth::AuthService::new(
|
||||
diesel::Connection::establish(&cfg.database_url()).unwrap(),
|
||||
redis_client.clone(),
|
||||
),
|
||||
)),
|
||||
web_adapter,
|
||||
voice_adapter,
|
||||
whatsapp_adapter,
|
||||
tool_api,
|
||||
};
|
||||
session_manager: session_manager.clone(),
|
||||
tool_manager: tool_manager.clone(),
|
||||
llm_provider: llm_provider.clone(),
|
||||
auth_service: auth_service.clone(),
|
||||
channels: Arc::new(Mutex::new(HashMap::new())),
|
||||
response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
|
||||
web_adapter: web_adapter.clone(),
|
||||
voice_adapter: voice_adapter.clone(),
|
||||
whatsapp_adapter: whatsapp_adapter.clone(),
|
||||
tool_api: tool_api.clone(),
|
||||
});
|
||||
|
||||
info!(
|
||||
"Starting server on {}:{}",
|
||||
config.server.host, config.server.port
|
||||
);
|
||||
|
||||
let closure_config = config.clone();
|
||||
|
||||
HttpServer::new(move || {
|
||||
let cfg = closure_config.clone();
|
||||
|
||||
let auth_service = auth::AuthService::new(
|
||||
diesel::Connection::establish(&cfg.database_url()).unwrap(),
|
||||
redis_client.clone(),
|
||||
);
|
||||
let session_manager = session::SessionManager::new(
|
||||
diesel::Connection::establish(&cfg.database_url()).unwrap(),
|
||||
redis_client.clone(),
|
||||
);
|
||||
|
||||
let orchestrator = Arc::new(bot::BotOrchestrator::new(
|
||||
session_manager,
|
||||
(*tool_manager).clone(),
|
||||
llm_provider.clone(),
|
||||
auth_service,
|
||||
));
|
||||
|
||||
let app_state = AppState {
|
||||
s3_client: base_app_state.s3_client.clone(),
|
||||
config: base_app_state.config.clone(),
|
||||
conn: base_app_state.conn.clone(),
|
||||
custom_conn: base_app_state.custom_conn.clone(),
|
||||
redis_client: base_app_state.redis_client.clone(),
|
||||
orchestrator,
|
||||
web_adapter: base_app_state.web_adapter.clone(),
|
||||
voice_adapter: base_app_state.voice_adapter.clone(),
|
||||
whatsapp_adapter: base_app_state.whatsapp_adapter.clone(),
|
||||
tool_api: base_app_state.tool_api.clone(),
|
||||
};
|
||||
|
||||
let cors = Cors::default()
|
||||
.allow_any_origin()
|
||||
.allow_any_method()
|
||||
|
|
@ -200,7 +172,7 @@ async fn main() -> std::io::Result<()> {
|
|||
.wrap(cors)
|
||||
.wrap(Logger::default())
|
||||
.wrap(Logger::new("HTTP REQUEST: %a %{User-Agent}i"))
|
||||
.app_data(web::Data::new(app_state_clone));
|
||||
.app_data(web::Data::from(app_state_clone));
|
||||
|
||||
app = app
|
||||
.service(upload_file)
|
||||
|
|
|
|||
|
|
@ -1,21 +1,33 @@
|
|||
use crate::bot::BotOrchestrator;
|
||||
use crate::channels::{VoiceAdapter, WebChannelAdapter};
|
||||
use crate::auth::AuthService;
|
||||
use crate::channels::{ChannelAdapter, VoiceAdapter, WebChannelAdapter};
|
||||
use crate::config::AppConfig;
|
||||
use crate::tools::ToolApi;
|
||||
use crate::llm::LLMProvider;
|
||||
use crate::session::SessionManager;
|
||||
use crate::tools::{ToolApi, ToolManager};
|
||||
use crate::whatsapp::WhatsAppAdapter;
|
||||
use diesel::PgConnection;
|
||||
use diesel::{Connection, PgConnection};
|
||||
use redis::Client;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::shared::models::BotResponse;
|
||||
|
||||
pub struct AppState {
|
||||
pub s3_client: Option<aws_sdk_s3::Client>,
|
||||
pub config: Option<AppConfig>,
|
||||
pub conn: Arc<Mutex<PgConnection>>,
|
||||
pub custom_conn: Arc<Mutex<PgConnection>>,
|
||||
|
||||
pub redis_client: Option<Arc<Client>>,
|
||||
pub orchestrator: Arc<BotOrchestrator>,
|
||||
|
||||
pub session_manager: Arc<tokio::sync::Mutex<SessionManager>>,
|
||||
pub tool_manager: Arc<ToolManager>,
|
||||
pub llm_provider: Arc<dyn LLMProvider>,
|
||||
pub auth_service: Arc<tokio::sync::Mutex<AuthService>>,
|
||||
pub channels: Arc<Mutex<HashMap<String, Arc<dyn ChannelAdapter>>>>,
|
||||
pub response_channels: Arc<tokio::sync::Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
|
||||
|
||||
pub web_adapter: Arc<WebChannelAdapter>,
|
||||
pub voice_adapter: Arc<VoiceAdapter>,
|
||||
pub whatsapp_adapter: Arc<WhatsAppAdapter>,
|
||||
|
|
@ -30,7 +42,12 @@ impl Clone for AppState {
|
|||
conn: Arc::clone(&self.conn),
|
||||
custom_conn: Arc::clone(&self.custom_conn),
|
||||
redis_client: self.redis_client.clone(),
|
||||
orchestrator: Arc::clone(&self.orchestrator),
|
||||
session_manager: Arc::clone(&self.session_manager),
|
||||
tool_manager: Arc::clone(&self.tool_manager),
|
||||
llm_provider: Arc::clone(&self.llm_provider),
|
||||
auth_service: Arc::clone(&self.auth_service),
|
||||
channels: Arc::clone(&self.channels),
|
||||
response_channels: Arc::clone(&self.response_channels),
|
||||
web_adapter: Arc::clone(&self.web_adapter),
|
||||
voice_adapter: Arc::clone(&self.voice_adapter),
|
||||
whatsapp_adapter: Arc::clone(&self.whatsapp_adapter),
|
||||
|
|
@ -38,3 +55,46 @@ impl Clone for AppState {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AppState {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
s3_client: None,
|
||||
config: None,
|
||||
conn: Arc::new(Mutex::new(
|
||||
diesel::PgConnection::establish("postgres://localhost/test").unwrap(),
|
||||
)),
|
||||
custom_conn: Arc::new(Mutex::new(
|
||||
diesel::PgConnection::establish("postgres://localhost/test").unwrap(),
|
||||
)),
|
||||
redis_client: None,
|
||||
session_manager: Arc::new(tokio::sync::Mutex::new(SessionManager::new(
|
||||
diesel::PgConnection::establish("postgres://localhost/test").unwrap(),
|
||||
None,
|
||||
))),
|
||||
tool_manager: Arc::new(ToolManager::new()),
|
||||
llm_provider: Arc::new(crate::llm::OpenAIClient::new(
|
||||
"empty".to_string(),
|
||||
Some("http://localhost:8081".to_string()),
|
||||
)),
|
||||
auth_service: Arc::new(tokio::sync::Mutex::new(AuthService::new(
|
||||
diesel::PgConnection::establish("postgres://localhost/test").unwrap(),
|
||||
None,
|
||||
))),
|
||||
channels: Arc::new(Mutex::new(HashMap::new())),
|
||||
response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
|
||||
web_adapter: Arc::new(WebChannelAdapter::new()),
|
||||
voice_adapter: Arc::new(VoiceAdapter::new(
|
||||
"https://livekit.example.com".to_string(),
|
||||
"api_key".to_string(),
|
||||
"api_secret".to_string(),
|
||||
)),
|
||||
whatsapp_adapter: Arc::new(WhatsAppAdapter::new(
|
||||
"whatsapp_token".to_string(),
|
||||
"phone_number_id".to_string(),
|
||||
"verify_token".to_string(),
|
||||
)),
|
||||
tool_api: Arc::new(ToolApi::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue