Refactor to Arc<AppState> for shared state
- Migrate core services to store Arc<AppState> and use locks - Centralize state in AppState with Arc-wrapped managers - Update handlers to pass Arc<AppState> via web::Data - Add Default for AppState and initialize components in main - Update debug.json program path from gbserver to botserver
This commit is contained in:
parent
e14908fc0b
commit
e00eb40d43
10 changed files with 384 additions and 280 deletions
|
|
@ -5,7 +5,7 @@
|
||||||
"command": "cargo",
|
"command": "cargo",
|
||||||
"args": ["build"]
|
"args": ["build"]
|
||||||
},
|
},
|
||||||
"program": "$ZED_WORKTREE_ROOT/target/debug/gbserver",
|
"program": "$ZED_WORKTREE_ROOT/target/debug/botserver",
|
||||||
|
|
||||||
"sourceLanguages": ["rust"],
|
"sourceLanguages": ["rust"],
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,8 @@ echo "Consolidated LLM Context" > "$OUTPUT_FILE"
|
||||||
prompts=(
|
prompts=(
|
||||||
"../../prompts/dev/shared.md"
|
"../../prompts/dev/shared.md"
|
||||||
"../../Cargo.toml"
|
"../../Cargo.toml"
|
||||||
#"../../prompts/dev/fix.md"
|
"../../prompts/dev/fix.md"
|
||||||
"../../prompts/dev/generation.md"
|
#"../../prompts/dev/generation.md"
|
||||||
)
|
)
|
||||||
|
|
||||||
for file in "${prompts[@]}"; do
|
for file in "${prompts[@]}"; do
|
||||||
|
|
@ -23,18 +23,18 @@ dirs=(
|
||||||
#"automation"
|
#"automation"
|
||||||
#"basic"
|
#"basic"
|
||||||
"bot"
|
"bot"
|
||||||
"channels"
|
#"channels"
|
||||||
#"config"
|
#"config"
|
||||||
"context"
|
#"context"
|
||||||
#"email"
|
#"email"
|
||||||
#"file"
|
#"file"
|
||||||
"llm"
|
#"llm"
|
||||||
#"llm_legacy"
|
#"llm_legacy"
|
||||||
#"org"
|
#"org"
|
||||||
"session"
|
#"session"
|
||||||
"shared"
|
"shared"
|
||||||
#"tests"
|
#"tests"
|
||||||
"tools"
|
#"tools"
|
||||||
#"web_automation"
|
#"web_automation"
|
||||||
#"whatsapp"
|
#"whatsapp"
|
||||||
)
|
)
|
||||||
|
|
@ -47,10 +47,10 @@ done
|
||||||
|
|
||||||
# Also append the specific files you mentioned
|
# Also append the specific files you mentioned
|
||||||
cat "$PROJECT_ROOT/src/main.rs" >> "$OUTPUT_FILE"
|
cat "$PROJECT_ROOT/src/main.rs" >> "$OUTPUT_FILE"
|
||||||
cat "$PROJECT_ROOT/src/basic/keywords/hear_talk.rs" >> "$OUTPUT_FILE"
|
# cat "$PROJECT_ROOT/src/basic/keywords/hear_talk.rs" >> "$OUTPUT_FILE"
|
||||||
|
# cat "$PROJECT_ROOT/src/basic/keywords/get.rs" >> "$OUTPUT_FILE"
|
||||||
echo "This BASIC file will run as soon as the conversation is created. " >> "$OUTPUT_FILE"
|
cat "$PROJECT_ROOT/src/basic/keywords/llm_keyword.rs" >> "$OUTPUT_FILE"
|
||||||
cat "$PROJECT_ROOT/templates/annoucements.gbai/annoucements.gbdialog/start.bas" >> "$OUTPUT_FILE"
|
# cat "$PROJECT_ROOT/src/basic/mod.rs" >> "$OUTPUT_FILE"
|
||||||
|
|
||||||
|
|
||||||
echo "" >> "$OUTPUT_FILE"
|
echo "" >> "$OUTPUT_FILE"
|
||||||
|
|
|
||||||
|
|
@ -10,12 +10,12 @@ use tokio::time::Duration;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
pub struct AutomationService {
|
pub struct AutomationService {
|
||||||
state: AppState,
|
state: Arc<AppState>,
|
||||||
scripts_dir: String,
|
scripts_dir: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AutomationService {
|
impl AutomationService {
|
||||||
pub fn new(state: AppState, scripts_dir: &str) -> Self {
|
pub fn new(state: Arc<AppState>, scripts_dir: &str) -> Self {
|
||||||
Self {
|
Self {
|
||||||
state,
|
state,
|
||||||
scripts_dir: scripts_dir.to_string(),
|
scripts_dir: scripts_dir.to_string(),
|
||||||
|
|
@ -61,13 +61,11 @@ impl AutomationService {
|
||||||
|
|
||||||
async fn check_table_changes(&self, automations: &[Automation], since: DateTime<Utc>) {
|
async fn check_table_changes(&self, automations: &[Automation], since: DateTime<Utc>) {
|
||||||
for automation in automations {
|
for automation in automations {
|
||||||
// Resolve the trigger kind, disambiguating the `from_i32` call.
|
|
||||||
let trigger_kind = match crate::shared::models::TriggerKind::from_i32(automation.kind) {
|
let trigger_kind = match crate::shared::models::TriggerKind::from_i32(automation.kind) {
|
||||||
Some(k) => k,
|
Some(k) => k,
|
||||||
None => continue,
|
None => continue,
|
||||||
};
|
};
|
||||||
|
|
||||||
// We're only interested in table‑change triggers.
|
|
||||||
if !matches!(
|
if !matches!(
|
||||||
trigger_kind,
|
trigger_kind,
|
||||||
TriggerKind::TableUpdate | TriggerKind::TableInsert | TriggerKind::TableDelete
|
TriggerKind::TableUpdate | TriggerKind::TableInsert | TriggerKind::TableDelete
|
||||||
|
|
@ -75,50 +73,41 @@ impl AutomationService {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Table name must be present.
|
|
||||||
let table = match &automation.target {
|
let table = match &automation.target {
|
||||||
Some(t) => t,
|
Some(t) => t,
|
||||||
None => continue,
|
None => continue,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Choose the appropriate timestamp column.
|
|
||||||
let column = match trigger_kind {
|
let column = match trigger_kind {
|
||||||
TriggerKind::TableInsert => "created_at",
|
TriggerKind::TableInsert => "created_at",
|
||||||
_ => "updated_at",
|
_ => "updated_at",
|
||||||
};
|
};
|
||||||
|
|
||||||
// Build a simple COUNT(*) query; alias the column so Diesel can map it directly to i64.
|
|
||||||
let query = format!(
|
let query = format!(
|
||||||
"SELECT COUNT(*) as count FROM {} WHERE {} > $1",
|
"SELECT COUNT(*) as count FROM {} WHERE {} > $1",
|
||||||
table, column
|
table, column
|
||||||
);
|
);
|
||||||
|
|
||||||
// Acquire a connection for this query.
|
|
||||||
let mut conn_guard = self.state.conn.lock().unwrap();
|
let mut conn_guard = self.state.conn.lock().unwrap();
|
||||||
let conn = &mut *conn_guard;
|
let conn = &mut *conn_guard;
|
||||||
|
|
||||||
// Define a struct to capture the query result
|
|
||||||
#[derive(diesel::QueryableByName)]
|
#[derive(diesel::QueryableByName)]
|
||||||
struct CountResult {
|
struct CountResult {
|
||||||
#[diesel(sql_type = diesel::sql_types::BigInt)]
|
#[diesel(sql_type = diesel::sql_types::BigInt)]
|
||||||
count: i64,
|
count: i64,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute the query, retrieving a plain i64 count.
|
|
||||||
let count_result = diesel::sql_query(&query)
|
let count_result = diesel::sql_query(&query)
|
||||||
.bind::<diesel::sql_types::Timestamp, _>(since.naive_utc())
|
.bind::<diesel::sql_types::Timestamp, _>(since.naive_utc())
|
||||||
.get_result::<CountResult>(conn);
|
.get_result::<CountResult>(conn);
|
||||||
|
|
||||||
match count_result {
|
match count_result {
|
||||||
Ok(result) if result.count > 0 => {
|
Ok(result) if result.count > 0 => {
|
||||||
// Release the lock before awaiting asynchronous work.
|
|
||||||
drop(conn_guard);
|
drop(conn_guard);
|
||||||
self.execute_action(&automation.param).await;
|
self.execute_action(&automation.param).await;
|
||||||
self.update_last_triggered(automation.id).await;
|
self.update_last_triggered(automation.id).await;
|
||||||
}
|
}
|
||||||
Ok(_result) => {
|
Ok(_result) => {}
|
||||||
// No relevant rows changed; continue to the next automation.
|
|
||||||
}
|
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Error checking changes for table '{}': {}", table, e);
|
error!("Error checking changes for table '{}': {}", table, e);
|
||||||
}
|
}
|
||||||
|
|
@ -215,7 +204,7 @@ impl AutomationService {
|
||||||
updated_at: Utc::now(),
|
updated_at: Utc::now(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let script_service = ScriptService::new(&self.state, user_session);
|
let script_service = ScriptService::new(Arc::clone(&self.state), user_session);
|
||||||
let ast = match script_service.compile(&script_content) {
|
let ast = match script_service.compile(&script_content) {
|
||||||
Ok(ast) => ast,
|
Ok(ast) => ast,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
|
|
||||||
|
|
@ -1,59 +1,45 @@
|
||||||
use log::info;
|
|
||||||
use crate::shared::state::AppState;
|
|
||||||
use crate::shared::models::UserSession;
|
use crate::shared::models::UserSession;
|
||||||
|
use crate::shared::state::AppState;
|
||||||
|
use log::info;
|
||||||
use reqwest::{self, Client};
|
use reqwest::{self, Client};
|
||||||
use rhai::{Dynamic, Engine};
|
use rhai::{Dynamic, Engine};
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
|
|
||||||
pub fn get_keyword(_state: &AppState, _user: UserSession, engine: &mut Engine) {
|
pub fn get_keyword(state: &AppState, _user: UserSession, engine: &mut Engine) {
|
||||||
|
let state_clone = state.clone();
|
||||||
|
|
||||||
engine
|
engine
|
||||||
.register_custom_syntax(
|
.register_custom_syntax(&["GET", "$expr$"], false, move |context, inputs| {
|
||||||
&["GET", "$expr$"],
|
let url = context.eval_expression_tree(&inputs[0])?;
|
||||||
false,
|
let url_str = url.to_string();
|
||||||
move |context, inputs| {
|
|
||||||
let url = context.eval_expression_tree(&inputs[0])?;
|
|
||||||
let url_str = url.to_string();
|
|
||||||
|
|
||||||
if url_str.contains("..") {
|
if url_str.contains("..") {
|
||||||
return Err("URL contains invalid path traversal sequences like '..'.".into());
|
return Err("URL contains invalid path traversal sequences like '..'.".into());
|
||||||
}
|
}
|
||||||
|
|
||||||
let modified_url = if url_str.starts_with("/") {
|
let state_for_async = state_clone.clone();
|
||||||
let work_root = std::env::var("WORK_ROOT").unwrap_or_else(|_| "./work".to_string());
|
let url_for_async = url_str.clone();
|
||||||
let full_path = std::path::Path::new(&work_root)
|
|
||||||
.join(url_str.trim_start_matches('/'))
|
|
||||||
.to_string_lossy()
|
|
||||||
.into_owned();
|
|
||||||
|
|
||||||
let base_url = "file://";
|
if url_str.starts_with("https://") {
|
||||||
format!("{}{}", base_url, full_path)
|
info!("HTTPS GET request: {}", url_for_async);
|
||||||
} else {
|
|
||||||
url_str.to_string()
|
|
||||||
};
|
|
||||||
|
|
||||||
if modified_url.starts_with("https://") {
|
let fut = execute_get(&url_for_async);
|
||||||
info!("HTTPS GET request: {}", modified_url);
|
let result =
|
||||||
|
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))
|
||||||
|
.map_err(|e| format!("HTTP request failed: {}", e))?;
|
||||||
|
|
||||||
let fut = execute_get(&modified_url);
|
Ok(Dynamic::from(result))
|
||||||
let result =
|
} else {
|
||||||
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))
|
info!("Local file GET request from bucket: {}", url_for_async);
|
||||||
.map_err(|e| format!("HTTP request failed: {}", e))?;
|
|
||||||
|
|
||||||
Ok(Dynamic::from(result))
|
let fut = get_from_bucket(&state_for_async, &url_for_async);
|
||||||
} else if modified_url.starts_with("file://") {
|
let result =
|
||||||
let file_path = modified_url.trim_start_matches("file://");
|
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))
|
||||||
match std::fs::read_to_string(file_path) {
|
.map_err(|e| format!("Bucket GET failed: {}", e))?;
|
||||||
Ok(content) => Ok(Dynamic::from(content)),
|
|
||||||
Err(e) => Err(format!("Failed to read file: {}", e).into()),
|
Ok(Dynamic::from(result))
|
||||||
}
|
}
|
||||||
} else {
|
})
|
||||||
Err(
|
|
||||||
format!("GET request failed: URL must begin with 'https://' or 'file://'")
|
|
||||||
.into(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -69,3 +55,29 @@ pub async fn execute_get(url: &str) -> Result<String, Box<dyn Error + Send + Syn
|
||||||
|
|
||||||
Ok(content)
|
Ok(content)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn get_from_bucket(
|
||||||
|
state: &AppState,
|
||||||
|
file_path: &str,
|
||||||
|
) -> Result<String, Box<dyn Error + Send + Sync>> {
|
||||||
|
info!("Getting file from bucket: {}", file_path);
|
||||||
|
|
||||||
|
if let Some(s3_client) = &state.s3_client {
|
||||||
|
let bucket_name =
|
||||||
|
std::env::var("DEFAULT_BUCKET").unwrap_or_else(|_| "default-bucket".to_string());
|
||||||
|
|
||||||
|
let response = s3_client
|
||||||
|
.get_object()
|
||||||
|
.bucket(&bucket_name)
|
||||||
|
.key(file_path)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let data = response.body.collect().await?;
|
||||||
|
let content = String::from_utf8(data.into_bytes().to_vec())?;
|
||||||
|
|
||||||
|
Ok(content)
|
||||||
|
} else {
|
||||||
|
Err("S3 client not configured".into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,11 @@
|
||||||
use crate::shared::models::UserSession;
|
|
||||||
use crate::shared::state::AppState;
|
use crate::shared::state::AppState;
|
||||||
|
use crate::{channels::ChannelAdapter, shared::models::UserSession};
|
||||||
use log::info;
|
use log::info;
|
||||||
use rhai::{Dynamic, Engine, EvalAltResult};
|
use rhai::{Dynamic, Engine, EvalAltResult};
|
||||||
|
|
||||||
pub fn hear_keyword(_state: &AppState, user: UserSession, engine: &mut Engine) {
|
pub fn hear_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||||
let session_id = user.id;
|
let session_id = user.id;
|
||||||
|
let cache = state.redis_client.clone();
|
||||||
|
|
||||||
engine
|
engine
|
||||||
.register_custom_syntax(&["HEAR", "$ident$"], true, move |_context, inputs| {
|
.register_custom_syntax(&["HEAR", "$ident$"], true, move |_context, inputs| {
|
||||||
|
|
@ -18,19 +19,35 @@ pub fn hear_keyword(_state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||||
variable_name
|
variable_name
|
||||||
);
|
);
|
||||||
|
|
||||||
// Spawn a background task to handle the input‑waiting logic.
|
let cache_clone = cache.clone();
|
||||||
// The actual waiting implementation should be added here.
|
let session_id_clone = session_id;
|
||||||
|
let var_name_clone = variable_name.clone();
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
log::debug!(
|
log::debug!(
|
||||||
"HEAR: Starting async task for session {} and variable '{}'",
|
"HEAR: Starting async task for session {} and variable '{}'",
|
||||||
session_id,
|
session_id_clone,
|
||||||
variable_name
|
var_name_clone
|
||||||
);
|
);
|
||||||
// TODO: implement actual waiting logic here without using the orchestrator
|
|
||||||
// For now, just log that we would wait for input
|
if let Some(cache_client) = &cache_clone {
|
||||||
|
let mut conn = match cache_client.get_multiplexed_async_connection().await {
|
||||||
|
Ok(conn) => conn,
|
||||||
|
Err(e) => {
|
||||||
|
log::error!("Failed to connect to cache: {}", e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let key = format!("hear:{}:{}", session_id_clone, var_name_clone);
|
||||||
|
let _: Result<(), _> = redis::cmd("SET")
|
||||||
|
.arg(&key)
|
||||||
|
.arg("waiting")
|
||||||
|
.query_async(&mut conn)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// Interrupt the current Rhai evaluation flow until the user input is received.
|
|
||||||
Err(Box::new(EvalAltResult::ErrorRuntime(
|
Err(Box::new(EvalAltResult::ErrorRuntime(
|
||||||
"Waiting for user input".into(),
|
"Waiting for user input".into(),
|
||||||
rhai::Position::NONE,
|
rhai::Position::NONE,
|
||||||
|
|
@ -40,7 +57,6 @@ pub fn hear_keyword(_state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn talk_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
pub fn talk_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||||
// Import the BotResponse type directly to satisfy diagnostics.
|
|
||||||
use crate::shared::models::BotResponse;
|
use crate::shared::models::BotResponse;
|
||||||
|
|
||||||
let state_clone = state.clone();
|
let state_clone = state.clone();
|
||||||
|
|
@ -63,13 +79,11 @@ pub fn talk_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||||
is_complete: true,
|
is_complete: true,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Send response through a channel or queue instead of accessing orchestrator directly
|
let state_for_spawn = state_clone.clone();
|
||||||
let _state_for_spawn = state_clone.clone();
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
// Use a more thread-safe approach to send the message
|
if let Err(e) = state_for_spawn.web_adapter.send_message(response).await {
|
||||||
// This avoids capturing the orchestrator directly which isn't Send + Sync
|
log::error!("Failed to send TALK message: {}", e);
|
||||||
// TODO: Implement proper response handling once response_sender field is added to AppState
|
}
|
||||||
log::debug!("TALK: Would send response: {:?}", response);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
Ok(Dynamic::UNIT)
|
Ok(Dynamic::UNIT)
|
||||||
|
|
@ -78,7 +92,7 @@ pub fn talk_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_context_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
pub fn set_context_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
|
||||||
let state_clone = state.clone();
|
let cache = state.redis_client.clone();
|
||||||
|
|
||||||
engine
|
engine
|
||||||
.register_custom_syntax(
|
.register_custom_syntax(
|
||||||
|
|
@ -91,14 +105,14 @@ pub fn set_context_keyword(state: &AppState, user: UserSession, engine: &mut Eng
|
||||||
|
|
||||||
let redis_key = format!("context:{}:{}", user.user_id, user.id);
|
let redis_key = format!("context:{}:{}", user.user_id, user.id);
|
||||||
|
|
||||||
let state_for_redis = state_clone.clone();
|
let cache_clone = cache.clone();
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Some(redis_client) = &state_for_redis.redis_client {
|
if let Some(cache_client) = &cache_clone {
|
||||||
let mut conn = match redis_client.get_multiplexed_async_connection().await {
|
let mut conn = match cache_client.get_multiplexed_async_connection().await {
|
||||||
Ok(conn) => conn,
|
Ok(conn) => conn,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
log::error!("Failed to connect to Redis: {}", e);
|
log::error!("Failed to connect to cache: {}", e);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,10 @@
|
||||||
use crate::shared::models::UserSession;
|
use crate::shared::models::UserSession;
|
||||||
use crate::shared::state::AppState;
|
use crate::shared::state::AppState;
|
||||||
use crate::shared::utils::call_llm;
|
|
||||||
use log::info;
|
use log::info;
|
||||||
use rhai::{Dynamic, Engine};
|
use rhai::{Dynamic, Engine};
|
||||||
|
|
||||||
pub fn llm_keyword(state: &AppState, _user: UserSession, engine: &mut Engine) {
|
pub fn llm_keyword(state: &AppState, _user: UserSession, engine: &mut Engine) {
|
||||||
let ai_config = state.config.clone().unwrap().ai.clone();
|
let state_clone = state.clone();
|
||||||
|
|
||||||
engine
|
engine
|
||||||
.register_custom_syntax(&["LLM", "$expr$"], false, move |context, inputs| {
|
.register_custom_syntax(&["LLM", "$expr$"], false, move |context, inputs| {
|
||||||
|
|
@ -14,10 +13,18 @@ pub fn llm_keyword(state: &AppState, _user: UserSession, engine: &mut Engine) {
|
||||||
|
|
||||||
info!("LLM processing text: {}", text_str);
|
info!("LLM processing text: {}", text_str);
|
||||||
|
|
||||||
let fut = call_llm(&text_str, &ai_config);
|
let state_inner = state_clone.clone();
|
||||||
|
let fut = async move {
|
||||||
|
let prompt = text_str;
|
||||||
|
state_inner
|
||||||
|
.llm_provider
|
||||||
|
.generate(&prompt, &serde_json::Value::Null)
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("LLM call failed: {}", e))
|
||||||
|
};
|
||||||
|
|
||||||
let result =
|
let result =
|
||||||
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))
|
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))?;
|
||||||
.map_err(|e| format!("LLM call failed: {}", e))?;
|
|
||||||
|
|
||||||
Ok(Dynamic::from(result))
|
Ok(Dynamic::from(result))
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,10 @@
|
||||||
pub mod keywords;
|
use crate::shared::models::UserSession;
|
||||||
|
use crate::shared::state::AppState;
|
||||||
|
use log::info;
|
||||||
|
use rhai::{Dynamic, Engine, EvalAltResult};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[cfg(feature = "email")]
|
pub mod keywords;
|
||||||
use self::keywords::create_draft_keyword;
|
|
||||||
|
|
||||||
use self::keywords::create_site::create_site_keyword;
|
use self::keywords::create_site::create_site_keyword;
|
||||||
use self::keywords::find::find_keyword;
|
use self::keywords::find::find_keyword;
|
||||||
|
|
@ -17,43 +20,54 @@ use self::keywords::print::print_keyword;
|
||||||
use self::keywords::set::set_keyword;
|
use self::keywords::set::set_keyword;
|
||||||
use self::keywords::set_schedule::set_schedule_keyword;
|
use self::keywords::set_schedule::set_schedule_keyword;
|
||||||
use self::keywords::wait::wait_keyword;
|
use self::keywords::wait::wait_keyword;
|
||||||
use crate::shared::models::UserSession;
|
|
||||||
use crate::shared::state::AppState;
|
#[cfg(feature = "email")]
|
||||||
use log::info;
|
use self::keywords::create_draft_keyword;
|
||||||
use rhai::{Dynamic, Engine, EvalAltResult};
|
|
||||||
|
#[cfg(feature = "web_automation")]
|
||||||
|
use self::keywords::get_website::get_website_keyword;
|
||||||
|
|
||||||
pub struct ScriptService {
|
pub struct ScriptService {
|
||||||
engine: Engine,
|
engine: Engine,
|
||||||
|
state: Arc<AppState>,
|
||||||
|
user: UserSession,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ScriptService {
|
impl ScriptService {
|
||||||
pub fn new(state: &AppState, user: UserSession) -> Self {
|
pub fn new(state: Arc<AppState>, user: UserSession) -> Self {
|
||||||
let mut engine = Engine::new();
|
let mut engine = Engine::new();
|
||||||
|
|
||||||
engine.set_allow_anonymous_fn(true);
|
engine.set_allow_anonymous_fn(true);
|
||||||
engine.set_allow_looping(true);
|
engine.set_allow_looping(true);
|
||||||
|
|
||||||
#[cfg(feature = "email")]
|
#[cfg(feature = "email")]
|
||||||
create_draft_keyword(state, user.clone(), &mut engine);
|
create_draft_keyword(&state, user.clone(), &mut engine);
|
||||||
|
|
||||||
create_site_keyword(state, user.clone(), &mut engine);
|
create_site_keyword(&state, user.clone(), &mut engine);
|
||||||
find_keyword(state, user.clone(), &mut engine);
|
find_keyword(&state, user.clone(), &mut engine);
|
||||||
for_keyword(state, user.clone(), &mut engine);
|
for_keyword(&state, user.clone(), &mut engine);
|
||||||
first_keyword(&mut engine);
|
first_keyword(&mut engine);
|
||||||
last_keyword(&mut engine);
|
last_keyword(&mut engine);
|
||||||
format_keyword(&mut engine);
|
format_keyword(&mut engine);
|
||||||
llm_keyword(state, user.clone(), &mut engine);
|
llm_keyword(&state, user.clone(), &mut engine);
|
||||||
get_keyword(state, user.clone(), &mut engine);
|
get_keyword(&state, user.clone(), &mut engine);
|
||||||
set_keyword(state, user.clone(), &mut engine);
|
set_keyword(&state, user.clone(), &mut engine);
|
||||||
wait_keyword(state, user.clone(), &mut engine);
|
wait_keyword(&state, user.clone(), &mut engine);
|
||||||
print_keyword(state, user.clone(), &mut engine);
|
print_keyword(&state, user.clone(), &mut engine);
|
||||||
on_keyword(state, user.clone(), &mut engine);
|
on_keyword(&state, user.clone(), &mut engine);
|
||||||
set_schedule_keyword(state, user.clone(), &mut engine);
|
set_schedule_keyword(&state, user.clone(), &mut engine);
|
||||||
hear_keyword(state, user.clone(), &mut engine);
|
hear_keyword(&state, user.clone(), &mut engine);
|
||||||
talk_keyword(state, user.clone(), &mut engine);
|
talk_keyword(&state, user.clone(), &mut engine);
|
||||||
set_context_keyword(state, user.clone(), &mut engine);
|
set_context_keyword(&state, user.clone(), &mut engine);
|
||||||
|
|
||||||
ScriptService { engine }
|
#[cfg(feature = "web_automation")]
|
||||||
|
get_website_keyword(&state, user.clone(), &mut engine);
|
||||||
|
|
||||||
|
ScriptService {
|
||||||
|
engine,
|
||||||
|
state,
|
||||||
|
user,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn preprocess_basic_script(&self, script: &str) -> String {
|
fn preprocess_basic_script(&self, script: &str) -> String {
|
||||||
|
|
|
||||||
236
src/bot/mod.rs
236
src/bot/mod.rs
|
|
@ -6,40 +6,20 @@ use serde_json;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::{mpsc, Mutex};
|
use tokio::sync::mpsc;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::auth::AuthService;
|
|
||||||
use crate::channels::ChannelAdapter;
|
use crate::channels::ChannelAdapter;
|
||||||
use crate::llm::LLMProvider;
|
|
||||||
use crate::session::SessionManager;
|
|
||||||
use crate::shared::models::{BotResponse, UserMessage, UserSession};
|
use crate::shared::models::{BotResponse, UserMessage, UserSession};
|
||||||
use crate::tools::ToolManager;
|
use crate::shared::state::AppState;
|
||||||
|
|
||||||
pub struct BotOrchestrator {
|
pub struct BotOrchestrator {
|
||||||
pub session_manager: Arc<Mutex<SessionManager>>,
|
pub state: Arc<AppState>,
|
||||||
tool_manager: Arc<ToolManager>,
|
|
||||||
llm_provider: Arc<dyn LLMProvider>,
|
|
||||||
auth_service: Arc<Mutex<AuthService>>,
|
|
||||||
pub channels: HashMap<String, Arc<dyn ChannelAdapter>>,
|
|
||||||
response_channels: Arc<Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BotOrchestrator {
|
impl BotOrchestrator {
|
||||||
pub fn new(
|
pub fn new(state: Arc<AppState>) -> Self {
|
||||||
session_manager: SessionManager,
|
Self { state }
|
||||||
tool_manager: ToolManager,
|
|
||||||
llm_provider: Arc<dyn LLMProvider>,
|
|
||||||
auth_service: AuthService,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
session_manager: Arc::new(Mutex::new(session_manager)),
|
|
||||||
tool_manager: Arc::new(tool_manager),
|
|
||||||
llm_provider,
|
|
||||||
auth_service: Arc::new(Mutex::new(auth_service)),
|
|
||||||
channels: HashMap::new(),
|
|
||||||
response_channels: Arc::new(Mutex::new(HashMap::new())),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn handle_user_input(
|
pub async fn handle_user_input(
|
||||||
|
|
@ -47,18 +27,22 @@ impl BotOrchestrator {
|
||||||
session_id: Uuid,
|
session_id: Uuid,
|
||||||
user_input: &str,
|
user_input: &str,
|
||||||
) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
let mut session_manager = self.state.session_manager.lock().await;
|
||||||
session_manager.provide_input(session_id, user_input.to_string())?;
|
session_manager.provide_input(session_id, user_input.to_string())?;
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn is_waiting_for_input(&self, session_id: Uuid) -> bool {
|
pub async fn is_waiting_for_input(&self, session_id: Uuid) -> bool {
|
||||||
let session_manager = self.session_manager.lock().await;
|
let session_manager = self.state.session_manager.lock().await;
|
||||||
session_manager.is_waiting_for_input(&session_id)
|
session_manager.is_waiting_for_input(&session_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_channel(&mut self, channel_type: &str, adapter: Arc<dyn ChannelAdapter>) {
|
pub fn add_channel(&self, channel_type: &str, adapter: Arc<dyn ChannelAdapter>) {
|
||||||
self.channels.insert(channel_type.to_string(), adapter);
|
self.state
|
||||||
|
.channels
|
||||||
|
.lock()
|
||||||
|
.unwrap()
|
||||||
|
.insert(channel_type.to_string(), adapter);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn register_response_channel(
|
pub async fn register_response_channel(
|
||||||
|
|
@ -66,7 +50,8 @@ impl BotOrchestrator {
|
||||||
session_id: String,
|
session_id: String,
|
||||||
sender: mpsc::Sender<BotResponse>,
|
sender: mpsc::Sender<BotResponse>,
|
||||||
) {
|
) {
|
||||||
self.response_channels
|
self.state
|
||||||
|
.response_channels
|
||||||
.lock()
|
.lock()
|
||||||
.await
|
.await
|
||||||
.insert(session_id, sender);
|
.insert(session_id, sender);
|
||||||
|
|
@ -78,7 +63,7 @@ impl BotOrchestrator {
|
||||||
bot_id: &str,
|
bot_id: &str,
|
||||||
mode: i32,
|
mode: i32,
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
let mut session_manager = self.state.session_manager.lock().await;
|
||||||
session_manager.update_answer_mode(user_id, bot_id, mode)?;
|
session_manager.update_answer_mode(user_id, bot_id, mode)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
@ -97,10 +82,15 @@ impl BotOrchestrator {
|
||||||
.unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap());
|
.unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap());
|
||||||
|
|
||||||
let session = {
|
let session = {
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
let mut session_manager = self.state.session_manager.lock().await;
|
||||||
match session_manager.get_user_session(user_id, bot_id)? {
|
match session_manager.get_user_session(user_id, bot_id)? {
|
||||||
Some(session) => session,
|
Some(session) => session,
|
||||||
None => session_manager.create_session(user_id, bot_id, "New Conversation")?,
|
None => {
|
||||||
|
let new_session =
|
||||||
|
session_manager.create_session(user_id, bot_id, "New Conversation")?;
|
||||||
|
Self::run_start_script(&new_session, Arc::clone(&self.state)).await;
|
||||||
|
new_session
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -113,7 +103,7 @@ impl BotOrchestrator {
|
||||||
variable_name, session.id
|
variable_name, session.id
|
||||||
);
|
);
|
||||||
|
|
||||||
if let Some(adapter) = self.channels.get(&message.channel) {
|
if let Some(adapter) = self.state.channels.lock().unwrap().get(&message.channel) {
|
||||||
let ack_response = BotResponse {
|
let ack_response = BotResponse {
|
||||||
bot_id: message.bot_id.clone(),
|
bot_id: message.bot_id.clone(),
|
||||||
user_id: message.user_id.clone(),
|
user_id: message.user_id.clone(),
|
||||||
|
|
@ -131,7 +121,7 @@ impl BotOrchestrator {
|
||||||
}
|
}
|
||||||
|
|
||||||
if session.answer_mode == 1 && session.current_tool.is_some() {
|
if session.answer_mode == 1 && session.current_tool.is_some() {
|
||||||
self.tool_manager.provide_user_response(
|
self.state.tool_manager.provide_user_response(
|
||||||
&message.user_id,
|
&message.user_id,
|
||||||
&message.bot_id,
|
&message.bot_id,
|
||||||
message.content.clone(),
|
message.content.clone(),
|
||||||
|
|
@ -140,7 +130,7 @@ impl BotOrchestrator {
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
let mut session_manager = self.state.session_manager.lock().await;
|
||||||
session_manager.save_message(
|
session_manager.save_message(
|
||||||
session.id,
|
session.id,
|
||||||
user_id,
|
user_id,
|
||||||
|
|
@ -153,7 +143,7 @@ impl BotOrchestrator {
|
||||||
let response_content = self.direct_mode_handler(&message, &session).await?;
|
let response_content = self.direct_mode_handler(&message, &session).await?;
|
||||||
|
|
||||||
{
|
{
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
let mut session_manager = self.state.session_manager.lock().await;
|
||||||
session_manager.save_message(session.id, user_id, 2, &response_content, 1)?;
|
session_manager.save_message(session.id, user_id, 2, &response_content, 1)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -168,7 +158,7 @@ impl BotOrchestrator {
|
||||||
is_complete: true,
|
is_complete: true,
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(adapter) = self.channels.get(&message.channel) {
|
if let Some(adapter) = self.state.channels.lock().unwrap().get(&message.channel) {
|
||||||
adapter.send_message(bot_response).await?;
|
adapter.send_message(bot_response).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -180,21 +170,19 @@ impl BotOrchestrator {
|
||||||
message: &UserMessage,
|
message: &UserMessage,
|
||||||
session: &UserSession,
|
session: &UserSession,
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
// Retrieve conversation history while holding a mutable lock on the session manager.
|
|
||||||
let history = {
|
let history = {
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
let mut session_manager = self.state.session_manager.lock().await;
|
||||||
session_manager.get_conversation_history(session.id, session.user_id)?
|
session_manager.get_conversation_history(session.id, session.user_id)?
|
||||||
};
|
};
|
||||||
|
|
||||||
// Build the prompt from the conversation history.
|
|
||||||
let mut prompt = String::new();
|
let mut prompt = String::new();
|
||||||
for (role, content) in history {
|
for (role, content) in history {
|
||||||
prompt.push_str(&format!("{}: {}\n", role, content));
|
prompt.push_str(&format!("{}: {}\n", role, content));
|
||||||
}
|
}
|
||||||
prompt.push_str(&format!("User: {}\nAssistant:", message.content));
|
prompt.push_str(&format!("User: {}\nAssistant:", message.content));
|
||||||
|
|
||||||
// Generate the assistant's response using the LLM provider.
|
self.state
|
||||||
self.llm_provider
|
.llm_provider
|
||||||
.generate(&prompt, &serde_json::Value::Null)
|
.generate(&prompt, &serde_json::Value::Null)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
@ -206,31 +194,31 @@ impl BotOrchestrator {
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
info!("Streaming response for user: {}", message.user_id);
|
info!("Streaming response for user: {}", message.user_id);
|
||||||
|
|
||||||
// Parse identifiers, falling back to safe defaults.
|
|
||||||
let mut user_id = Uuid::parse_str(&message.user_id).unwrap_or_else(|_| Uuid::new_v4());
|
let mut user_id = Uuid::parse_str(&message.user_id).unwrap_or_else(|_| Uuid::new_v4());
|
||||||
let bot_id = Uuid::parse_str(&message.bot_id).unwrap_or_else(|_| Uuid::nil());
|
let bot_id = Uuid::parse_str(&message.bot_id).unwrap_or_else(|_| Uuid::nil());
|
||||||
let mut auth = self.auth_service.lock().await;
|
let mut auth = self.state.auth_service.lock().await;
|
||||||
let user_exists = auth.get_user_by_id(user_id)?;
|
let user_exists = auth.get_user_by_id(user_id)?;
|
||||||
|
|
||||||
if user_exists.is_none() {
|
if user_exists.is_none() {
|
||||||
// User does not exist, invoke Authentication service to create them
|
|
||||||
user_id = auth.create_user("anonymous1", "anonymous@local", "password")?;
|
user_id = auth.create_user("anonymous1", "anonymous@local", "password")?;
|
||||||
} else {
|
} else {
|
||||||
user_id = user_exists.unwrap().id;
|
user_id = user_exists.unwrap().id;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Retrieve an existing session or create a new one.
|
|
||||||
let session = {
|
let session = {
|
||||||
let mut sm = self.session_manager.lock().await;
|
let mut sm = self.state.session_manager.lock().await;
|
||||||
match sm.get_user_session(user_id, bot_id)? {
|
match sm.get_user_session(user_id, bot_id)? {
|
||||||
Some(sess) => sess,
|
Some(sess) => sess,
|
||||||
None => sm.create_session(user_id, bot_id, "New Conversation")?,
|
None => {
|
||||||
|
let new_session = sm.create_session(user_id, bot_id, "New Conversation")?;
|
||||||
|
Self::run_start_script(&new_session, Arc::clone(&self.state)).await;
|
||||||
|
new_session
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// If the session is awaiting tool input, forward the user's answer to the tool manager.
|
|
||||||
if session.answer_mode == 1 && session.current_tool.is_some() {
|
if session.answer_mode == 1 && session.current_tool.is_some() {
|
||||||
self.tool_manager.provide_user_response(
|
self.state.tool_manager.provide_user_response(
|
||||||
&message.user_id,
|
&message.user_id,
|
||||||
&message.bot_id,
|
&message.bot_id,
|
||||||
message.content.clone(),
|
message.content.clone(),
|
||||||
|
|
@ -238,9 +226,8 @@ impl BotOrchestrator {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Persist the incoming user message.
|
|
||||||
{
|
{
|
||||||
let mut sm = self.session_manager.lock().await;
|
let mut sm = self.state.session_manager.lock().await;
|
||||||
sm.save_message(
|
sm.save_message(
|
||||||
session.id,
|
session.id,
|
||||||
user_id,
|
user_id,
|
||||||
|
|
@ -250,9 +237,8 @@ impl BotOrchestrator {
|
||||||
)?;
|
)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build the prompt from the conversation history.
|
|
||||||
let prompt = {
|
let prompt = {
|
||||||
let mut sm = self.session_manager.lock().await;
|
let mut sm = self.state.session_manager.lock().await;
|
||||||
let history = sm.get_conversation_history(session.id, user_id)?;
|
let history = sm.get_conversation_history(session.id, user_id)?;
|
||||||
let mut p = String::new();
|
let mut p = String::new();
|
||||||
for (role, content) in history {
|
for (role, content) in history {
|
||||||
|
|
@ -262,11 +248,9 @@ impl BotOrchestrator {
|
||||||
p
|
p
|
||||||
};
|
};
|
||||||
|
|
||||||
// Set up a channel for the streaming LLM output.
|
|
||||||
let (stream_tx, mut stream_rx) = mpsc::channel::<String>(100);
|
let (stream_tx, mut stream_rx) = mpsc::channel::<String>(100);
|
||||||
let llm = self.llm_provider.clone();
|
let llm = self.state.llm_provider.clone();
|
||||||
|
|
||||||
// Spawn the LLM streaming task.
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(e) = llm
|
if let Err(e) = llm
|
||||||
.generate_stream(&prompt, &serde_json::Value::Null, stream_tx)
|
.generate_stream(&prompt, &serde_json::Value::Null, stream_tx)
|
||||||
|
|
@ -276,7 +260,6 @@ impl BotOrchestrator {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// Forward each chunk to the client as it arrives.
|
|
||||||
let mut full_response = String::new();
|
let mut full_response = String::new();
|
||||||
while let Some(chunk) = stream_rx.recv().await {
|
while let Some(chunk) = stream_rx.recv().await {
|
||||||
full_response.push_str(&chunk);
|
full_response.push_str(&chunk);
|
||||||
|
|
@ -293,18 +276,15 @@ impl BotOrchestrator {
|
||||||
};
|
};
|
||||||
|
|
||||||
if response_tx.send(partial).await.is_err() {
|
if response_tx.send(partial).await.is_err() {
|
||||||
// Receiver has been dropped; stop streaming.
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save the complete assistant reply.
|
|
||||||
{
|
{
|
||||||
let mut sm = self.session_manager.lock().await;
|
let mut sm = self.state.session_manager.lock().await;
|
||||||
sm.save_message(session.id, user_id, 2, &full_response, 1)?;
|
sm.save_message(session.id, user_id, 2, &full_response, 1)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Notify the client that the stream is finished.
|
|
||||||
let final_msg = BotResponse {
|
let final_msg = BotResponse {
|
||||||
bot_id: message.bot_id,
|
bot_id: message.bot_id,
|
||||||
user_id: message.user_id,
|
user_id: message.user_id,
|
||||||
|
|
@ -324,7 +304,7 @@ impl BotOrchestrator {
|
||||||
&self,
|
&self,
|
||||||
user_id: Uuid,
|
user_id: Uuid,
|
||||||
) -> Result<Vec<UserSession>, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<Vec<UserSession>, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
let mut session_manager = self.state.session_manager.lock().await;
|
||||||
session_manager.get_user_sessions(user_id)
|
session_manager.get_user_sessions(user_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -333,7 +313,7 @@ impl BotOrchestrator {
|
||||||
session_id: Uuid,
|
session_id: Uuid,
|
||||||
user_id: Uuid,
|
user_id: Uuid,
|
||||||
) -> Result<Vec<(String, String)>, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<Vec<(String, String)>, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
let mut session_manager = self.state.session_manager.lock().await;
|
||||||
session_manager.get_conversation_history(session_id, user_id)
|
session_manager.get_conversation_history(session_id, user_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -351,15 +331,20 @@ impl BotOrchestrator {
|
||||||
.unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap());
|
.unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap());
|
||||||
|
|
||||||
let session = {
|
let session = {
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
let mut session_manager = self.state.session_manager.lock().await;
|
||||||
match session_manager.get_user_session(user_id, bot_id)? {
|
match session_manager.get_user_session(user_id, bot_id)? {
|
||||||
Some(session) => session,
|
Some(session) => session,
|
||||||
None => session_manager.create_session(user_id, bot_id, "New Conversation")?,
|
None => {
|
||||||
|
let new_session =
|
||||||
|
session_manager.create_session(user_id, bot_id, "New Conversation")?;
|
||||||
|
Self::run_start_script(&new_session, Arc::clone(&self.state)).await;
|
||||||
|
new_session
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
{
|
{
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
let mut session_manager = self.state.session_manager.lock().await;
|
||||||
session_manager.save_message(
|
session_manager.save_message(
|
||||||
session.id,
|
session.id,
|
||||||
user_id,
|
user_id,
|
||||||
|
|
@ -370,17 +355,24 @@ impl BotOrchestrator {
|
||||||
}
|
}
|
||||||
|
|
||||||
let is_tool_waiting = self
|
let is_tool_waiting = self
|
||||||
|
.state
|
||||||
.tool_manager
|
.tool_manager
|
||||||
.is_tool_waiting(&message.session_id)
|
.is_tool_waiting(&message.session_id)
|
||||||
.await
|
.await
|
||||||
.unwrap_or(false);
|
.unwrap_or(false);
|
||||||
|
|
||||||
if is_tool_waiting {
|
if is_tool_waiting {
|
||||||
self.tool_manager
|
self.state
|
||||||
|
.tool_manager
|
||||||
.provide_input(&message.session_id, &message.content)
|
.provide_input(&message.session_id, &message.content)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
if let Ok(tool_output) = self.tool_manager.get_tool_output(&message.session_id).await {
|
if let Ok(tool_output) = self
|
||||||
|
.state
|
||||||
|
.tool_manager
|
||||||
|
.get_tool_output(&message.session_id)
|
||||||
|
.await
|
||||||
|
{
|
||||||
for output in tool_output {
|
for output in tool_output {
|
||||||
let bot_response = BotResponse {
|
let bot_response = BotResponse {
|
||||||
bot_id: message.bot_id.clone(),
|
bot_id: message.bot_id.clone(),
|
||||||
|
|
@ -393,7 +385,8 @@ impl BotOrchestrator {
|
||||||
is_complete: true,
|
is_complete: true,
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(adapter) = self.channels.get(&message.channel) {
|
if let Some(adapter) = self.state.channels.lock().unwrap().get(&message.channel)
|
||||||
|
{
|
||||||
adapter.send_message(bot_response).await?;
|
adapter.send_message(bot_response).await?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -406,12 +399,13 @@ impl BotOrchestrator {
|
||||||
|| message.content.to_lowercase().contains("math")
|
|| message.content.to_lowercase().contains("math")
|
||||||
{
|
{
|
||||||
match self
|
match self
|
||||||
|
.state
|
||||||
.tool_manager
|
.tool_manager
|
||||||
.execute_tool("calculator", &message.session_id, &message.user_id)
|
.execute_tool("calculator", &message.session_id, &message.user_id)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(tool_result) => {
|
Ok(tool_result) => {
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
let mut session_manager = self.state.session_manager.lock().await;
|
||||||
session_manager.save_message(session.id, user_id, 2, &tool_result.output, 2)?;
|
session_manager.save_message(session.id, user_id, 2, &tool_result.output, 2)?;
|
||||||
|
|
||||||
tool_result.output
|
tool_result.output
|
||||||
|
|
@ -421,7 +415,7 @@ impl BotOrchestrator {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
let available_tools = self.tool_manager.list_tools();
|
let available_tools = self.state.tool_manager.list_tools();
|
||||||
let tools_context = if !available_tools.is_empty() {
|
let tools_context = if !available_tools.is_empty() {
|
||||||
format!("\n\nAvailable tools: {}. If the user needs calculations, suggest using the calculator tool.", available_tools.join(", "))
|
format!("\n\nAvailable tools: {}. If the user needs calculations, suggest using the calculator tool.", available_tools.join(", "))
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -430,13 +424,14 @@ impl BotOrchestrator {
|
||||||
|
|
||||||
let full_prompt = format!("{}{}", message.content, tools_context);
|
let full_prompt = format!("{}{}", message.content, tools_context);
|
||||||
|
|
||||||
self.llm_provider
|
self.state
|
||||||
|
.llm_provider
|
||||||
.generate(&full_prompt, &serde_json::Value::Null)
|
.generate(&full_prompt, &serde_json::Value::Null)
|
||||||
.await?
|
.await?
|
||||||
};
|
};
|
||||||
|
|
||||||
{
|
{
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
let mut session_manager = self.state.session_manager.lock().await;
|
||||||
session_manager.save_message(session.id, user_id, 2, &response, 1)?;
|
session_manager.save_message(session.id, user_id, 2, &response, 1)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -451,25 +446,63 @@ impl BotOrchestrator {
|
||||||
is_complete: true,
|
is_complete: true,
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(adapter) = self.channels.get(&message.channel) {
|
if let Some(adapter) = self.state.channels.lock().unwrap().get(&message.channel) {
|
||||||
adapter.send_message(bot_response).await?;
|
adapter.send_message(bot_response).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn run_start_script(session: &UserSession, state: Arc<AppState>) {
|
||||||
|
let start_script = r#"
|
||||||
|
TALK "Welcome to General Bots!"
|
||||||
|
HEAR name
|
||||||
|
TALK "Hello, " + name
|
||||||
|
|
||||||
|
text = GET "default.pdf"
|
||||||
|
SET CONTEXT text
|
||||||
|
|
||||||
|
resume = LLM "Build a resume from " + text
|
||||||
|
"#;
|
||||||
|
|
||||||
|
info!("Running start.bas for session: {}", session.id);
|
||||||
|
|
||||||
|
let session_clone = session.clone();
|
||||||
|
let state_clone = state.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let state_for_run = state_clone.clone();
|
||||||
|
if let Err(e) = crate::basic::ScriptService::new(state_clone, session_clone.clone())
|
||||||
|
.compile(start_script)
|
||||||
|
.and_then(|ast| {
|
||||||
|
crate::basic::ScriptService::new(state_for_run, session_clone.clone()).run(&ast)
|
||||||
|
})
|
||||||
|
{
|
||||||
|
log::error!("Failed to run start.bas: {}", e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for BotOrchestrator {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
state: Arc::new(AppState::default()),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[actix_web::get("/ws")]
|
#[actix_web::get("/ws")]
|
||||||
async fn websocket_handler(
|
async fn websocket_handler(
|
||||||
req: HttpRequest,
|
req: HttpRequest,
|
||||||
stream: web::Payload,
|
stream: web::Payload,
|
||||||
data: web::Data<crate::shared::state::AppState>,
|
data: web::Data<AppState>,
|
||||||
) -> Result<HttpResponse, actix_web::Error> {
|
) -> Result<HttpResponse, actix_web::Error> {
|
||||||
let (res, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?;
|
let (res, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?;
|
||||||
let session_id = Uuid::new_v4().to_string();
|
let session_id = Uuid::new_v4().to_string();
|
||||||
let (tx, mut rx) = mpsc::channel::<BotResponse>(100);
|
let (tx, mut rx) = mpsc::channel::<BotResponse>(100);
|
||||||
|
|
||||||
data.orchestrator
|
let orchestrator = BotOrchestrator::new(Arc::clone(&data));
|
||||||
|
orchestrator
|
||||||
.register_response_channel(session_id.clone(), tx.clone())
|
.register_response_channel(session_id.clone(), tx.clone())
|
||||||
.await;
|
.await;
|
||||||
data.web_adapter
|
data.web_adapter
|
||||||
|
|
@ -479,7 +512,6 @@ async fn websocket_handler(
|
||||||
.add_connection(session_id.clone(), tx.clone())
|
.add_connection(session_id.clone(), tx.clone())
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let orchestrator = data.orchestrator.clone();
|
|
||||||
let web_adapter = data.web_adapter.clone();
|
let web_adapter = data.web_adapter.clone();
|
||||||
|
|
||||||
actix_web::rt::spawn(async move {
|
actix_web::rt::spawn(async move {
|
||||||
|
|
@ -523,7 +555,7 @@ async fn websocket_handler(
|
||||||
|
|
||||||
#[actix_web::get("/api/whatsapp/webhook")]
|
#[actix_web::get("/api/whatsapp/webhook")]
|
||||||
async fn whatsapp_webhook_verify(
|
async fn whatsapp_webhook_verify(
|
||||||
data: web::Data<crate::shared::state::AppState>,
|
data: web::Data<AppState>,
|
||||||
web::Query(params): web::Query<HashMap<String, String>>,
|
web::Query(params): web::Query<HashMap<String, String>>,
|
||||||
) -> Result<HttpResponse> {
|
) -> Result<HttpResponse> {
|
||||||
let empty = String::new();
|
let empty = String::new();
|
||||||
|
|
@ -539,7 +571,7 @@ async fn whatsapp_webhook_verify(
|
||||||
|
|
||||||
#[actix_web::post("/api/whatsapp/webhook")]
|
#[actix_web::post("/api/whatsapp/webhook")]
|
||||||
async fn whatsapp_webhook(
|
async fn whatsapp_webhook(
|
||||||
data: web::Data<crate::shared::state::AppState>,
|
data: web::Data<AppState>,
|
||||||
payload: web::Json<crate::whatsapp::WhatsAppMessage>,
|
payload: web::Json<crate::whatsapp::WhatsAppMessage>,
|
||||||
) -> Result<HttpResponse> {
|
) -> Result<HttpResponse> {
|
||||||
match data
|
match data
|
||||||
|
|
@ -549,7 +581,8 @@ async fn whatsapp_webhook(
|
||||||
{
|
{
|
||||||
Ok(user_messages) => {
|
Ok(user_messages) => {
|
||||||
for user_message in user_messages {
|
for user_message in user_messages {
|
||||||
if let Err(e) = data.orchestrator.process_message(user_message).await {
|
let orchestrator = BotOrchestrator::new(Arc::clone(&data));
|
||||||
|
if let Err(e) = orchestrator.process_message(user_message).await {
|
||||||
log::error!("Error processing WhatsApp message: {}", e);
|
log::error!("Error processing WhatsApp message: {}", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -564,7 +597,7 @@ async fn whatsapp_webhook(
|
||||||
|
|
||||||
#[actix_web::post("/api/voice/start")]
|
#[actix_web::post("/api/voice/start")]
|
||||||
async fn voice_start(
|
async fn voice_start(
|
||||||
data: web::Data<crate::shared::state::AppState>,
|
data: web::Data<AppState>,
|
||||||
info: web::Json<serde_json::Value>,
|
info: web::Json<serde_json::Value>,
|
||||||
) -> Result<HttpResponse> {
|
) -> Result<HttpResponse> {
|
||||||
let session_id = info
|
let session_id = info
|
||||||
|
|
@ -593,7 +626,7 @@ async fn voice_start(
|
||||||
|
|
||||||
#[actix_web::post("/api/voice/stop")]
|
#[actix_web::post("/api/voice/stop")]
|
||||||
async fn voice_stop(
|
async fn voice_stop(
|
||||||
data: web::Data<crate::shared::state::AppState>,
|
data: web::Data<AppState>,
|
||||||
info: web::Json<serde_json::Value>,
|
info: web::Json<serde_json::Value>,
|
||||||
) -> Result<HttpResponse> {
|
) -> Result<HttpResponse> {
|
||||||
let session_id = info
|
let session_id = info
|
||||||
|
|
@ -611,7 +644,7 @@ async fn voice_stop(
|
||||||
}
|
}
|
||||||
|
|
||||||
#[actix_web::post("/api/sessions")]
|
#[actix_web::post("/api/sessions")]
|
||||||
async fn create_session(_data: web::Data<crate::shared::state::AppState>) -> Result<HttpResponse> {
|
async fn create_session(_data: web::Data<AppState>) -> Result<HttpResponse> {
|
||||||
let session_id = Uuid::new_v4();
|
let session_id = Uuid::new_v4();
|
||||||
Ok(HttpResponse::Ok().json(serde_json::json!({
|
Ok(HttpResponse::Ok().json(serde_json::json!({
|
||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
|
|
@ -621,9 +654,10 @@ async fn create_session(_data: web::Data<crate::shared::state::AppState>) -> Res
|
||||||
}
|
}
|
||||||
|
|
||||||
#[actix_web::get("/api/sessions")]
|
#[actix_web::get("/api/sessions")]
|
||||||
async fn get_sessions(data: web::Data<crate::shared::state::AppState>) -> Result<HttpResponse> {
|
async fn get_sessions(data: web::Data<AppState>) -> Result<HttpResponse> {
|
||||||
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
|
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
|
||||||
match data.orchestrator.get_user_sessions(user_id).await {
|
let orchestrator = BotOrchestrator::new(Arc::clone(&data));
|
||||||
|
match orchestrator.get_user_sessions(user_id).await {
|
||||||
Ok(sessions) => Ok(HttpResponse::Ok().json(sessions)),
|
Ok(sessions) => Ok(HttpResponse::Ok().json(sessions)),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
Ok(HttpResponse::InternalServerError()
|
Ok(HttpResponse::InternalServerError()
|
||||||
|
|
@ -634,22 +668,24 @@ async fn get_sessions(data: web::Data<crate::shared::state::AppState>) -> Result
|
||||||
|
|
||||||
#[actix_web::get("/api/sessions/{session_id}")]
|
#[actix_web::get("/api/sessions/{session_id}")]
|
||||||
async fn get_session_history(
|
async fn get_session_history(
|
||||||
data: web::Data<crate::shared::state::AppState>,
|
data: web::Data<AppState>,
|
||||||
path: web::Path<String>,
|
path: web::Path<String>,
|
||||||
) -> Result<HttpResponse> {
|
) -> Result<HttpResponse> {
|
||||||
let session_id = path.into_inner();
|
let session_id = path.into_inner();
|
||||||
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
|
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
|
||||||
|
|
||||||
match Uuid::parse_str(&session_id) {
|
match Uuid::parse_str(&session_id) {
|
||||||
Ok(session_uuid) => match data
|
Ok(session_uuid) => {
|
||||||
.orchestrator
|
let orchestrator = BotOrchestrator::new(Arc::clone(&data));
|
||||||
.get_conversation_history(session_uuid, user_id)
|
match orchestrator
|
||||||
.await
|
.get_conversation_history(session_uuid, user_id)
|
||||||
{
|
.await
|
||||||
Ok(history) => Ok(HttpResponse::Ok().json(history)),
|
{
|
||||||
Err(e) => Ok(HttpResponse::InternalServerError()
|
Ok(history) => Ok(HttpResponse::Ok().json(history)),
|
||||||
.json(serde_json::json!({"error": e.to_string()}))),
|
Err(e) => Ok(HttpResponse::InternalServerError()
|
||||||
},
|
.json(serde_json::json!({"error": e.to_string()}))),
|
||||||
|
}
|
||||||
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
Ok(HttpResponse::BadRequest().json(serde_json::json!({"error": "Invalid session ID"})))
|
Ok(HttpResponse::BadRequest().json(serde_json::json!({"error": "Invalid session ID"})))
|
||||||
}
|
}
|
||||||
|
|
@ -658,7 +694,7 @@ async fn get_session_history(
|
||||||
|
|
||||||
#[actix_web::post("/api/set_mode")]
|
#[actix_web::post("/api/set_mode")]
|
||||||
async fn set_mode_handler(
|
async fn set_mode_handler(
|
||||||
data: web::Data<crate::shared::state::AppState>,
|
data: web::Data<AppState>,
|
||||||
info: web::Json<HashMap<String, String>>,
|
info: web::Json<HashMap<String, String>>,
|
||||||
) -> Result<HttpResponse> {
|
) -> Result<HttpResponse> {
|
||||||
let default_user = "default_user".to_string();
|
let default_user = "default_user".to_string();
|
||||||
|
|
@ -671,8 +707,8 @@ async fn set_mode_handler(
|
||||||
|
|
||||||
let mode = mode_str.parse::<i32>().unwrap_or(0);
|
let mode = mode_str.parse::<i32>().unwrap_or(0);
|
||||||
|
|
||||||
if let Err(e) = data
|
let orchestrator = BotOrchestrator::new(Arc::clone(&data));
|
||||||
.orchestrator
|
if let Err(e) = orchestrator
|
||||||
.set_user_answer_mode(user_id, bot_id, mode)
|
.set_user_answer_mode(user_id, bot_id, mode)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
|
|
|
||||||
78
src/main.rs
78
src/main.rs
|
|
@ -5,6 +5,7 @@ use actix_web::middleware::Logger;
|
||||||
use actix_web::{web, App, HttpServer};
|
use actix_web::{web, App, HttpServer};
|
||||||
use dotenvy::dotenv;
|
use dotenvy::dotenv;
|
||||||
use log::info;
|
use log::info;
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
mod auth;
|
mod auth;
|
||||||
|
|
@ -105,7 +106,7 @@ async fn main() -> std::io::Result<()> {
|
||||||
};
|
};
|
||||||
|
|
||||||
let tool_manager = Arc::new(tools::ToolManager::new());
|
let tool_manager = Arc::new(tools::ToolManager::new());
|
||||||
let llm_provider = Arc::new(llm::OpenAIClient::new(
|
let llm_provider = Arc::new(crate::llm::OpenAIClient::new(
|
||||||
"empty".to_string(),
|
"empty".to_string(),
|
||||||
Some("http://localhost:8081".to_string()),
|
Some("http://localhost:8081".to_string()),
|
||||||
));
|
));
|
||||||
|
|
@ -125,69 +126,40 @@ async fn main() -> std::io::Result<()> {
|
||||||
|
|
||||||
let tool_api = Arc::new(tools::ToolApi::new());
|
let tool_api = Arc::new(tools::ToolApi::new());
|
||||||
|
|
||||||
let base_app_state = AppState {
|
let session_manager = Arc::new(tokio::sync::Mutex::new(session::SessionManager::new(
|
||||||
|
diesel::Connection::establish(&cfg.database_url()).unwrap(),
|
||||||
|
redis_client.clone(),
|
||||||
|
)));
|
||||||
|
|
||||||
|
let auth_service = Arc::new(tokio::sync::Mutex::new(auth::AuthService::new(
|
||||||
|
diesel::Connection::establish(&cfg.database_url()).unwrap(),
|
||||||
|
redis_client.clone(),
|
||||||
|
)));
|
||||||
|
|
||||||
|
let app_state = Arc::new(AppState {
|
||||||
s3_client: None,
|
s3_client: None,
|
||||||
config: Some(cfg.clone()),
|
config: Some(cfg.clone()),
|
||||||
conn: db_pool.clone(),
|
conn: db_pool.clone(),
|
||||||
custom_conn: db_custom_pool.clone(),
|
custom_conn: db_custom_pool.clone(),
|
||||||
redis_client: redis_client.clone(),
|
redis_client: redis_client.clone(),
|
||||||
orchestrator: Arc::new(bot::BotOrchestrator::new(
|
session_manager: session_manager.clone(),
|
||||||
session::SessionManager::new(
|
tool_manager: tool_manager.clone(),
|
||||||
diesel::Connection::establish(&cfg.database_url()).unwrap(),
|
llm_provider: llm_provider.clone(),
|
||||||
redis_client.clone(),
|
auth_service: auth_service.clone(),
|
||||||
),
|
channels: Arc::new(Mutex::new(HashMap::new())),
|
||||||
(*tool_manager).clone(),
|
response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
|
||||||
llm_provider.clone(),
|
web_adapter: web_adapter.clone(),
|
||||||
auth::AuthService::new(
|
voice_adapter: voice_adapter.clone(),
|
||||||
diesel::Connection::establish(&cfg.database_url()).unwrap(),
|
whatsapp_adapter: whatsapp_adapter.clone(),
|
||||||
redis_client.clone(),
|
tool_api: tool_api.clone(),
|
||||||
),
|
});
|
||||||
)),
|
|
||||||
web_adapter,
|
|
||||||
voice_adapter,
|
|
||||||
whatsapp_adapter,
|
|
||||||
tool_api,
|
|
||||||
};
|
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
"Starting server on {}:{}",
|
"Starting server on {}:{}",
|
||||||
config.server.host, config.server.port
|
config.server.host, config.server.port
|
||||||
);
|
);
|
||||||
|
|
||||||
let closure_config = config.clone();
|
|
||||||
|
|
||||||
HttpServer::new(move || {
|
HttpServer::new(move || {
|
||||||
let cfg = closure_config.clone();
|
|
||||||
|
|
||||||
let auth_service = auth::AuthService::new(
|
|
||||||
diesel::Connection::establish(&cfg.database_url()).unwrap(),
|
|
||||||
redis_client.clone(),
|
|
||||||
);
|
|
||||||
let session_manager = session::SessionManager::new(
|
|
||||||
diesel::Connection::establish(&cfg.database_url()).unwrap(),
|
|
||||||
redis_client.clone(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let orchestrator = Arc::new(bot::BotOrchestrator::new(
|
|
||||||
session_manager,
|
|
||||||
(*tool_manager).clone(),
|
|
||||||
llm_provider.clone(),
|
|
||||||
auth_service,
|
|
||||||
));
|
|
||||||
|
|
||||||
let app_state = AppState {
|
|
||||||
s3_client: base_app_state.s3_client.clone(),
|
|
||||||
config: base_app_state.config.clone(),
|
|
||||||
conn: base_app_state.conn.clone(),
|
|
||||||
custom_conn: base_app_state.custom_conn.clone(),
|
|
||||||
redis_client: base_app_state.redis_client.clone(),
|
|
||||||
orchestrator,
|
|
||||||
web_adapter: base_app_state.web_adapter.clone(),
|
|
||||||
voice_adapter: base_app_state.voice_adapter.clone(),
|
|
||||||
whatsapp_adapter: base_app_state.whatsapp_adapter.clone(),
|
|
||||||
tool_api: base_app_state.tool_api.clone(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let cors = Cors::default()
|
let cors = Cors::default()
|
||||||
.allow_any_origin()
|
.allow_any_origin()
|
||||||
.allow_any_method()
|
.allow_any_method()
|
||||||
|
|
@ -200,7 +172,7 @@ async fn main() -> std::io::Result<()> {
|
||||||
.wrap(cors)
|
.wrap(cors)
|
||||||
.wrap(Logger::default())
|
.wrap(Logger::default())
|
||||||
.wrap(Logger::new("HTTP REQUEST: %a %{User-Agent}i"))
|
.wrap(Logger::new("HTTP REQUEST: %a %{User-Agent}i"))
|
||||||
.app_data(web::Data::new(app_state_clone));
|
.app_data(web::Data::from(app_state_clone));
|
||||||
|
|
||||||
app = app
|
app = app
|
||||||
.service(upload_file)
|
.service(upload_file)
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,33 @@
|
||||||
use crate::bot::BotOrchestrator;
|
use crate::auth::AuthService;
|
||||||
use crate::channels::{VoiceAdapter, WebChannelAdapter};
|
use crate::channels::{ChannelAdapter, VoiceAdapter, WebChannelAdapter};
|
||||||
use crate::config::AppConfig;
|
use crate::config::AppConfig;
|
||||||
use crate::tools::ToolApi;
|
use crate::llm::LLMProvider;
|
||||||
|
use crate::session::SessionManager;
|
||||||
|
use crate::tools::{ToolApi, ToolManager};
|
||||||
use crate::whatsapp::WhatsAppAdapter;
|
use crate::whatsapp::WhatsAppAdapter;
|
||||||
use diesel::PgConnection;
|
use diesel::{Connection, PgConnection};
|
||||||
use redis::Client;
|
use redis::Client;
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::Mutex;
|
use std::sync::Mutex;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
|
use crate::shared::models::BotResponse;
|
||||||
|
|
||||||
pub struct AppState {
|
pub struct AppState {
|
||||||
pub s3_client: Option<aws_sdk_s3::Client>,
|
pub s3_client: Option<aws_sdk_s3::Client>,
|
||||||
pub config: Option<AppConfig>,
|
pub config: Option<AppConfig>,
|
||||||
pub conn: Arc<Mutex<PgConnection>>,
|
pub conn: Arc<Mutex<PgConnection>>,
|
||||||
pub custom_conn: Arc<Mutex<PgConnection>>,
|
pub custom_conn: Arc<Mutex<PgConnection>>,
|
||||||
|
|
||||||
pub redis_client: Option<Arc<Client>>,
|
pub redis_client: Option<Arc<Client>>,
|
||||||
pub orchestrator: Arc<BotOrchestrator>,
|
|
||||||
|
pub session_manager: Arc<tokio::sync::Mutex<SessionManager>>,
|
||||||
|
pub tool_manager: Arc<ToolManager>,
|
||||||
|
pub llm_provider: Arc<dyn LLMProvider>,
|
||||||
|
pub auth_service: Arc<tokio::sync::Mutex<AuthService>>,
|
||||||
|
pub channels: Arc<Mutex<HashMap<String, Arc<dyn ChannelAdapter>>>>,
|
||||||
|
pub response_channels: Arc<tokio::sync::Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
|
||||||
|
|
||||||
pub web_adapter: Arc<WebChannelAdapter>,
|
pub web_adapter: Arc<WebChannelAdapter>,
|
||||||
pub voice_adapter: Arc<VoiceAdapter>,
|
pub voice_adapter: Arc<VoiceAdapter>,
|
||||||
pub whatsapp_adapter: Arc<WhatsAppAdapter>,
|
pub whatsapp_adapter: Arc<WhatsAppAdapter>,
|
||||||
|
|
@ -30,7 +42,12 @@ impl Clone for AppState {
|
||||||
conn: Arc::clone(&self.conn),
|
conn: Arc::clone(&self.conn),
|
||||||
custom_conn: Arc::clone(&self.custom_conn),
|
custom_conn: Arc::clone(&self.custom_conn),
|
||||||
redis_client: self.redis_client.clone(),
|
redis_client: self.redis_client.clone(),
|
||||||
orchestrator: Arc::clone(&self.orchestrator),
|
session_manager: Arc::clone(&self.session_manager),
|
||||||
|
tool_manager: Arc::clone(&self.tool_manager),
|
||||||
|
llm_provider: Arc::clone(&self.llm_provider),
|
||||||
|
auth_service: Arc::clone(&self.auth_service),
|
||||||
|
channels: Arc::clone(&self.channels),
|
||||||
|
response_channels: Arc::clone(&self.response_channels),
|
||||||
web_adapter: Arc::clone(&self.web_adapter),
|
web_adapter: Arc::clone(&self.web_adapter),
|
||||||
voice_adapter: Arc::clone(&self.voice_adapter),
|
voice_adapter: Arc::clone(&self.voice_adapter),
|
||||||
whatsapp_adapter: Arc::clone(&self.whatsapp_adapter),
|
whatsapp_adapter: Arc::clone(&self.whatsapp_adapter),
|
||||||
|
|
@ -38,3 +55,46 @@ impl Clone for AppState {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Default for AppState {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
s3_client: None,
|
||||||
|
config: None,
|
||||||
|
conn: Arc::new(Mutex::new(
|
||||||
|
diesel::PgConnection::establish("postgres://localhost/test").unwrap(),
|
||||||
|
)),
|
||||||
|
custom_conn: Arc::new(Mutex::new(
|
||||||
|
diesel::PgConnection::establish("postgres://localhost/test").unwrap(),
|
||||||
|
)),
|
||||||
|
redis_client: None,
|
||||||
|
session_manager: Arc::new(tokio::sync::Mutex::new(SessionManager::new(
|
||||||
|
diesel::PgConnection::establish("postgres://localhost/test").unwrap(),
|
||||||
|
None,
|
||||||
|
))),
|
||||||
|
tool_manager: Arc::new(ToolManager::new()),
|
||||||
|
llm_provider: Arc::new(crate::llm::OpenAIClient::new(
|
||||||
|
"empty".to_string(),
|
||||||
|
Some("http://localhost:8081".to_string()),
|
||||||
|
)),
|
||||||
|
auth_service: Arc::new(tokio::sync::Mutex::new(AuthService::new(
|
||||||
|
diesel::PgConnection::establish("postgres://localhost/test").unwrap(),
|
||||||
|
None,
|
||||||
|
))),
|
||||||
|
channels: Arc::new(Mutex::new(HashMap::new())),
|
||||||
|
response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
|
||||||
|
web_adapter: Arc::new(WebChannelAdapter::new()),
|
||||||
|
voice_adapter: Arc::new(VoiceAdapter::new(
|
||||||
|
"https://livekit.example.com".to_string(),
|
||||||
|
"api_key".to_string(),
|
||||||
|
"api_secret".to_string(),
|
||||||
|
)),
|
||||||
|
whatsapp_adapter: Arc::new(WhatsAppAdapter::new(
|
||||||
|
"whatsapp_token".to_string(),
|
||||||
|
"phone_number_id".to_string(),
|
||||||
|
"verify_token".to_string(),
|
||||||
|
)),
|
||||||
|
tool_api: Arc::new(ToolApi::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue