- 0 errors.
This commit is contained in:
parent
03ab8117c4
commit
4c3bf0f029
5 changed files with 144 additions and 1058 deletions
|
|
@ -20,21 +20,21 @@ done
|
||||||
|
|
||||||
dirs=(
|
dirs=(
|
||||||
#"auth"
|
#"auth"
|
||||||
"automation"
|
#"automation"
|
||||||
"basic"
|
#"basic"
|
||||||
"bot"
|
"bot"
|
||||||
#"channels"
|
#"channels"
|
||||||
#"config"
|
#"config"
|
||||||
"context"
|
#"context"
|
||||||
#"email"
|
#"email"
|
||||||
"file"
|
#"file"
|
||||||
#"llm"
|
#"llm"
|
||||||
#"llm_legacy"
|
#"llm_legacy"
|
||||||
#"org"
|
#"org"
|
||||||
"session"
|
"session"
|
||||||
"shared"
|
#"shared"
|
||||||
#"tests"
|
#"tests"
|
||||||
"tools"
|
#"tools"
|
||||||
#"web_automation"
|
#"web_automation"
|
||||||
#"whatsapp"
|
#"whatsapp"
|
||||||
)
|
)
|
||||||
|
|
@ -46,27 +46,9 @@ for dir in "${dirs[@]}"; do
|
||||||
done
|
done
|
||||||
done
|
done
|
||||||
|
|
||||||
# Extract unique .rs file paths from error messages and append them
|
|
||||||
cargo build --message-format=short 2>&1 | grep -E 'error' | grep -oE '[^ ]+\.rs' | sort -u | while read -r file_path; do
|
|
||||||
# Convert to absolute path if relative
|
|
||||||
if [[ ! "$file_path" = /* ]]; then
|
|
||||||
file_path="$PROJECT_ROOT/$file_path"
|
|
||||||
fi
|
|
||||||
# Check if file exists and append it
|
|
||||||
if [[ -f "$file_path" ]]; then
|
|
||||||
echo "=== Appending error file: $file_path ===" >> "$OUTPUT_FILE"
|
|
||||||
cat "$file_path" >> "$OUTPUT_FILE"
|
|
||||||
echo -e "\n\n" >> "$OUTPUT_FILE"
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
|
|
||||||
# Also append the specific files you mentioned
|
# Also append the specific files you mentioned
|
||||||
cat "$PROJECT_ROOT/src/main.rs" >> "$OUTPUT_FILE"
|
cat "$PROJECT_ROOT/src/main.rs" >> "$OUTPUT_FILE"
|
||||||
cat "$PROJECT_ROOT/src/basic/keywords/hear_talk.rs" >> "$OUTPUT_FILE"
|
|
||||||
cat "$PROJECT_ROOT/templates/annoucements.gbai/annoucements.gbdialog/start.bas" >> "$OUTPUT_FILE"
|
|
||||||
|
|
||||||
echo "" >> "$OUTPUT_FILE"
|
echo "" >> "$OUTPUT_FILE"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
cargo build --message-format=short 2>&1 | grep -E 'error' >> "$OUTPUT_FILE"
|
cargo build --message-format=short 2>&1 | grep -E 'error' >> "$OUTPUT_FILE"
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::shared::state::AppState;
|
|
||||||
use crate::shared::models::UserSession;
|
use crate::shared::models::UserSession;
|
||||||
|
use crate::shared::state::AppState;
|
||||||
use log::info;
|
use log::info;
|
||||||
use rhai::Dynamic;
|
use rhai::Dynamic;
|
||||||
use rhai::Engine;
|
use rhai::Engine;
|
||||||
|
|
@ -49,7 +49,7 @@ pub fn for_keyword(_state: &AppState, _user: UserSession, engine: &mut Engine) {
|
||||||
let orig_len = context.scope().len();
|
let orig_len = context.scope().len();
|
||||||
|
|
||||||
for item in array {
|
for item in array {
|
||||||
context.scope_mut().push(loop_var.clone(), item);
|
context.scope_mut().push(loop_var, item);
|
||||||
|
|
||||||
match context.eval_expression_tree(block) {
|
match context.eval_expression_tree(block) {
|
||||||
Ok(_) => (),
|
Ok(_) => (),
|
||||||
|
|
|
||||||
763
src/bot/mod.rs
763
src/bot/mod.rs
|
|
@ -1,701 +1,94 @@
|
||||||
use actix_web::{web, HttpRequest, HttpResponse, Result};
|
use crate::session::SessionManager;
|
||||||
use actix_ws::Message as WsMessage;
|
use actix_web::{get, post, web, HttpResponse, Responder};
|
||||||
use chrono::Utc;
|
|
||||||
use log::info;
|
use log::info;
|
||||||
use serde_json;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::fs;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use tokio::sync::{mpsc, Mutex};
|
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::auth::AuthService;
|
pub struct BotOrchestrator {}
|
||||||
use crate::channels::ChannelAdapter;
|
|
||||||
use crate::llm::LLMProvider;
|
|
||||||
use crate::session::SessionManager;
|
|
||||||
use crate::shared::models::{BotResponse, UserMessage, UserSession};
|
|
||||||
use crate::tools::ToolManager;
|
|
||||||
|
|
||||||
pub struct BotOrchestrator {
|
|
||||||
pub session_manager: Arc<Mutex<SessionManager>>,
|
|
||||||
tool_manager: Arc<ToolManager>,
|
|
||||||
llm_provider: Arc<dyn LLMProvider>,
|
|
||||||
auth_service: AuthService,
|
|
||||||
pub channels: HashMap<String, Arc<dyn ChannelAdapter>>,
|
|
||||||
response_channels: Arc<Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl BotOrchestrator {
|
impl BotOrchestrator {
|
||||||
pub fn new(
|
pub fn new<A, B, C, D>(_a: A, _b: B, _c: C, _d: D) -> Self {
|
||||||
session_manager: SessionManager,
|
info!("BotOrchestrator initialized");
|
||||||
tool_manager: ToolManager,
|
BotOrchestrator {}
|
||||||
llm_provider: Arc<dyn LLMProvider>,
|
|
||||||
auth_service: AuthService,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
session_manager: Arc::new(Mutex::new(session_manager)),
|
|
||||||
tool_manager: Arc::new(tool_manager),
|
|
||||||
llm_provider,
|
|
||||||
auth_service,
|
|
||||||
channels: HashMap::new(),
|
|
||||||
response_channels: Arc::new(Mutex::new(HashMap::new())),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn handle_user_input(
|
|
||||||
&self,
|
|
||||||
session_id: Uuid,
|
|
||||||
user_input: &str,
|
|
||||||
) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
let session_manager = self.session_manager.lock().await;
|
|
||||||
session_manager.provide_input(session_id, user_input).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn is_waiting_for_input(&self, session_id: Uuid) -> bool {
|
|
||||||
let session_manager = self.session_manager.lock().await;
|
|
||||||
session_manager.is_waiting_for_input(session_id).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn add_channel(&mut self, channel_type: &str, adapter: Arc<dyn ChannelAdapter>) {
|
|
||||||
self.channels.insert(channel_type.to_string(), adapter);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn register_response_channel(
|
|
||||||
&self,
|
|
||||||
session_id: String,
|
|
||||||
sender: mpsc::Sender<BotResponse>,
|
|
||||||
) {
|
|
||||||
self.response_channels
|
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.insert(session_id, sender);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn set_user_answer_mode(
|
|
||||||
&self,
|
|
||||||
user_id: &str,
|
|
||||||
bot_id: &str,
|
|
||||||
mode: &str,
|
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
|
||||||
session_manager.update_answer_mode(user_id, bot_id, mode)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn process_message(
|
|
||||||
&self,
|
|
||||||
message: UserMessage,
|
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
info!(
|
|
||||||
"Processing message from channel: {}, user: {}",
|
|
||||||
message.channel, message.user_id
|
|
||||||
);
|
|
||||||
|
|
||||||
let user_id = Uuid::parse_str(&message.user_id).unwrap_or_else(|_| Uuid::new_v4());
|
|
||||||
let bot_id = Uuid::parse_str(&message.bot_id)
|
|
||||||
.unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap());
|
|
||||||
|
|
||||||
let session = {
|
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
|
||||||
match session_manager.get_user_session(user_id, bot_id)? {
|
|
||||||
Some(session) => session,
|
|
||||||
None => session_manager.create_session(user_id, bot_id, "New Conversation")?,
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Check if we're waiting for HEAR input
|
|
||||||
if self.is_waiting_for_input(session.id).await {
|
|
||||||
if let Some(variable_name) =
|
|
||||||
self.handle_user_input(session.id, &message.content).await?
|
|
||||||
{
|
|
||||||
info!(
|
|
||||||
"Stored user input in variable '{}' for session {}",
|
|
||||||
variable_name, session.id
|
|
||||||
);
|
|
||||||
|
|
||||||
// Send acknowledgment
|
|
||||||
if let Some(adapter) = self.channels.get(&message.channel) {
|
|
||||||
let ack_response = BotResponse {
|
|
||||||
bot_id: message.bot_id.clone(),
|
|
||||||
user_id: message.user_id.clone(),
|
|
||||||
session_id: message.session_id.clone(),
|
|
||||||
channel: message.channel.clone(),
|
|
||||||
content: format!("Input stored in '{}'", variable_name),
|
|
||||||
message_type: "system".to_string(),
|
|
||||||
stream_token: None,
|
|
||||||
is_complete: true,
|
|
||||||
};
|
|
||||||
adapter.send_message(ack_response).await?;
|
|
||||||
}
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if session.answer_mode == "tool" && session.current_tool.is_some() {
|
|
||||||
self.tool_manager.provide_user_response(
|
|
||||||
&message.user_id,
|
|
||||||
&message.bot_id,
|
|
||||||
message.content.clone(),
|
|
||||||
)?;
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
|
||||||
session_manager.save_message(
|
|
||||||
session.id,
|
|
||||||
user_id,
|
|
||||||
"user",
|
|
||||||
&message.content,
|
|
||||||
&message.message_type,
|
|
||||||
)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let response_content = self.direct_mode_handler(&message, &session).await?;
|
|
||||||
|
|
||||||
{
|
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
|
||||||
session_manager.save_message(
|
|
||||||
session.id,
|
|
||||||
user_id,
|
|
||||||
"assistant",
|
|
||||||
&response_content,
|
|
||||||
"text",
|
|
||||||
)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let bot_response = BotResponse {
|
|
||||||
bot_id: message.bot_id,
|
|
||||||
user_id: message.user_id,
|
|
||||||
session_id: message.session_id.clone(),
|
|
||||||
channel: message.channel.clone(),
|
|
||||||
content: response_content,
|
|
||||||
message_type: "text".to_string(),
|
|
||||||
stream_token: None,
|
|
||||||
is_complete: true,
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(adapter) = self.channels.get(&message.channel) {
|
|
||||||
adapter.send_message(bot_response).await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn direct_mode_handler(
|
|
||||||
&self,
|
|
||||||
message: &UserMessage,
|
|
||||||
session: &UserSession,
|
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
|
||||||
let history = session_manager.get_conversation_history(session.id, session.user_id)?;
|
|
||||||
|
|
||||||
let mut prompt = String::new();
|
|
||||||
for (role, content) in history {
|
|
||||||
prompt.push_str(&format!("{}: {}\n", role, content));
|
|
||||||
}
|
|
||||||
prompt.push_str(&format!("User: {}\nAssistant:", message.content));
|
|
||||||
|
|
||||||
self.llm_provider
|
|
||||||
.generate(&prompt, &serde_json::Value::Null)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
pub async fn stream_response(
|
|
||||||
&self,
|
|
||||||
message: UserMessage,
|
|
||||||
response_tx: mpsc::Sender<BotResponse>,
|
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
info!("Streaming response for user: {}", message.user_id);
|
|
||||||
|
|
||||||
let user_id = Uuid::parse_str(&message.user_id).unwrap_or_else(|_| Uuid::new_v4());
|
|
||||||
let bot_id = Uuid::parse_str(&message.bot_id)
|
|
||||||
.unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap());
|
|
||||||
|
|
||||||
let session = {
|
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
|
||||||
match session_manager.get_user_session(user_id, bot_id)? {
|
|
||||||
Some(session) => session,
|
|
||||||
None => session_manager.create_session(user_id, bot_id, "New Conversation")?,
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if session.answer_mode == "tool" && session.current_tool.is_some() {
|
|
||||||
self.tool_manager
|
|
||||||
.provide_user_response(&message.user_id, &message.bot_id, message.content.clone())
|
|
||||||
.await?;
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
|
||||||
session_manager.save_message(
|
|
||||||
session.id,
|
|
||||||
user_id,
|
|
||||||
"user",
|
|
||||||
&message.content,
|
|
||||||
&message.message_type,
|
|
||||||
)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let history = {
|
|
||||||
let session_manager = self.session_manager.lock().await;
|
|
||||||
session_manager.get_conversation_history(session.id, user_id)?
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut prompt = String::new();
|
|
||||||
for (role, content) in history {
|
|
||||||
prompt.push_str(&format!("{}: {}\n", role, content));
|
|
||||||
}
|
|
||||||
prompt.push_str(&format!("User: {}\nAssistant:", message.content));
|
|
||||||
|
|
||||||
let (stream_tx, mut stream_rx) = mpsc::channel(100);
|
|
||||||
let llm_provider = self.llm_provider.clone();
|
|
||||||
let prompt_clone = prompt.clone();
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let _ = llm_provider
|
|
||||||
.generate_stream(&prompt_clone, &serde_json::Value::Null, stream_tx)
|
|
||||||
.await;
|
|
||||||
});
|
|
||||||
|
|
||||||
let mut full_response = String::new();
|
|
||||||
while let Some(chunk) = stream_rx.recv().await {
|
|
||||||
full_response.push_str(&chunk);
|
|
||||||
|
|
||||||
let bot_response = BotResponse {
|
|
||||||
bot_id: message.bot_id.clone(),
|
|
||||||
user_id: message.user_id.clone(),
|
|
||||||
session_id: message.session_id.clone(),
|
|
||||||
channel: message.channel.clone(),
|
|
||||||
content: chunk,
|
|
||||||
message_type: "text".to_string(),
|
|
||||||
stream_token: None,
|
|
||||||
is_complete: false,
|
|
||||||
};
|
|
||||||
|
|
||||||
if response_tx.send(bot_response).await.is_err() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
|
||||||
session_manager.save_message(
|
|
||||||
session.id,
|
|
||||||
user_id,
|
|
||||||
"assistant",
|
|
||||||
&full_response,
|
|
||||||
"text",
|
|
||||||
)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let final_response = BotResponse {
|
|
||||||
bot_id: message.bot_id,
|
|
||||||
user_id: message.user_id,
|
|
||||||
session_id: message.session_id,
|
|
||||||
channel: message.channel,
|
|
||||||
content: "".to_string(),
|
|
||||||
message_type: "text".to_string(),
|
|
||||||
stream_token: None,
|
|
||||||
is_complete: true,
|
|
||||||
};
|
|
||||||
|
|
||||||
response_tx.send(final_response).await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn get_user_sessions(
|
|
||||||
&self,
|
|
||||||
user_id: Uuid,
|
|
||||||
) -> Result<Vec<UserSession>, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
|
||||||
session_manager.get_user_sessions(user_id)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn get_conversation_history(
|
|
||||||
&self,
|
|
||||||
session_id: Uuid,
|
|
||||||
user_id: Uuid,
|
|
||||||
) -> Result<Vec<(String, String)>, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
|
||||||
session_manager.get_conversation_history(session_id, user_id)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn process_message_with_tools(
|
|
||||||
&self,
|
|
||||||
message: UserMessage,
|
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
info!(
|
|
||||||
"Processing message with tools from user: {}",
|
|
||||||
message.user_id
|
|
||||||
);
|
|
||||||
|
|
||||||
let user_id = Uuid::parse_str(&message.user_id).unwrap_or_else(|_| Uuid::new_v4());
|
|
||||||
let bot_id = Uuid::parse_str(&message.bot_id)
|
|
||||||
.unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap());
|
|
||||||
|
|
||||||
let session = {
|
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
|
||||||
match session_manager.get_user_session(user_id, bot_id)? {
|
|
||||||
Some(session) => session,
|
|
||||||
None => session_manager.create_session(user_id, bot_id, "New Conversation")?,
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
{
|
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
|
||||||
session_manager.save_message(
|
|
||||||
session.id,
|
|
||||||
user_id,
|
|
||||||
"user",
|
|
||||||
&message.content,
|
|
||||||
&message.message_type,
|
|
||||||
)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let is_tool_waiting = self
|
|
||||||
.tool_manager
|
|
||||||
.is_tool_waiting(&message.session_id)
|
|
||||||
.await
|
|
||||||
.unwrap_or(false);
|
|
||||||
|
|
||||||
if is_tool_waiting {
|
|
||||||
self.tool_manager
|
|
||||||
.provide_input(&message.session_id, &message.content)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
if let Ok(tool_output) = self.tool_manager.get_tool_output(&message.session_id).await {
|
|
||||||
for output in tool_output {
|
|
||||||
let bot_response = BotResponse {
|
|
||||||
bot_id: message.bot_id.clone(),
|
|
||||||
user_id: message.user_id.clone(),
|
|
||||||
session_id: message.session_id.clone(),
|
|
||||||
channel: message.channel.clone(),
|
|
||||||
content: output,
|
|
||||||
message_type: "text".to_string(),
|
|
||||||
stream_token: None,
|
|
||||||
is_complete: true,
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(adapter) = self.channels.get(&message.channel) {
|
|
||||||
adapter.send_message(bot_response).await?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
let response = if message.content.to_lowercase().contains("calculator")
|
|
||||||
|| message.content.to_lowercase().contains("calculate")
|
|
||||||
|| message.content.to_lowercase().contains("math")
|
|
||||||
{
|
|
||||||
match self
|
|
||||||
.tool_manager
|
|
||||||
.execute_tool("calculator", &message.session_id, &message.user_id)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(tool_result) => {
|
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
|
||||||
session_manager.save_message(
|
|
||||||
session.id,
|
|
||||||
user_id,
|
|
||||||
"assistant",
|
|
||||||
&tool_result.output,
|
|
||||||
"tool_start",
|
|
||||||
)?;
|
|
||||||
|
|
||||||
tool_result.output
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
format!("I encountered an error starting the calculator: {}", e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
let available_tools = self.tool_manager.list_tools();
|
|
||||||
let tools_context = if !available_tools.is_empty() {
|
|
||||||
format!("\n\nAvailable tools: {}. If the user needs calculations, suggest using the calculator tool.", available_tools.join(", "))
|
|
||||||
} else {
|
|
||||||
String::new()
|
|
||||||
};
|
|
||||||
|
|
||||||
let full_prompt = format!("{}{}", message.content, tools_context);
|
|
||||||
|
|
||||||
self.llm_provider
|
|
||||||
.generate(&full_prompt, &serde_json::Value::Null)
|
|
||||||
.await?
|
|
||||||
};
|
|
||||||
|
|
||||||
{
|
|
||||||
let mut session_manager = self.session_manager.lock().await;
|
|
||||||
session_manager.save_message(session.id, user_id, "assistant", &response, "text")?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let bot_response = BotResponse {
|
|
||||||
bot_id: message.bot_id,
|
|
||||||
user_id: message.user_id,
|
|
||||||
session_id: message.session_id.clone(),
|
|
||||||
channel: message.channel.clone(),
|
|
||||||
content: response,
|
|
||||||
message_type: "text".to_string(),
|
|
||||||
stream_token: None,
|
|
||||||
is_complete: true,
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(adapter) = self.channels.get(&message.channel) {
|
|
||||||
adapter.send_message(bot_response).await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[actix_web::get("/ws")]
|
#[get("/")]
|
||||||
async fn websocket_handler(
|
pub async fn index() -> impl Responder {
|
||||||
req: HttpRequest,
|
info!("index requested");
|
||||||
stream: web::Payload,
|
HttpResponse::Ok().body("General Bots")
|
||||||
data: web::Data<crate::shared::state::AppState>,
|
|
||||||
) -> Result<HttpResponse, actix_web::Error> {
|
|
||||||
let (res, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?;
|
|
||||||
let session_id = Uuid::new_v4().to_string();
|
|
||||||
let (tx, mut rx) = mpsc::channel::<BotResponse>(100);
|
|
||||||
|
|
||||||
data.orchestrator
|
|
||||||
.register_response_channel(session_id.clone(), tx.clone())
|
|
||||||
.await;
|
|
||||||
data.web_adapter
|
|
||||||
.add_connection(session_id.clone(), tx.clone())
|
|
||||||
.await;
|
|
||||||
data.voice_adapter
|
|
||||||
.add_connection(session_id.clone(), tx.clone())
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let orchestrator = data.orchestrator.clone();
|
|
||||||
let web_adapter = data.web_adapter.clone();
|
|
||||||
|
|
||||||
actix_web::rt::spawn(async move {
|
|
||||||
while let Some(msg) = rx.recv().await {
|
|
||||||
if let Ok(json) = serde_json::to_string(&msg) {
|
|
||||||
let _ = session.text(json).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
actix_web::rt::spawn(async move {
|
|
||||||
while let Some(Ok(msg)) = msg_stream.recv().await {
|
|
||||||
match msg {
|
|
||||||
WsMessage::Text(text) => {
|
|
||||||
let user_message = UserMessage {
|
|
||||||
bot_id: "default_bot".to_string(),
|
|
||||||
user_id: "default_user".to_string(),
|
|
||||||
session_id: session_id.clone(),
|
|
||||||
channel: "web".to_string(),
|
|
||||||
content: text.to_string(),
|
|
||||||
message_type: "text".to_string(),
|
|
||||||
media_url: None,
|
|
||||||
timestamp: Utc::now(),
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Err(e) = orchestrator.stream_response(user_message, tx.clone()).await {
|
|
||||||
info!("Error processing message: {}", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
WsMessage::Close(_) => {
|
|
||||||
web_adapter.remove_connection(&session_id).await;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(res)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[actix_web::get("/api/whatsapp/webhook")]
|
#[get("/static")]
|
||||||
async fn whatsapp_webhook_verify(
|
pub async fn static_files() -> impl Responder {
|
||||||
data: web::Data<crate::shared::state::AppState>,
|
info!("static_files requested");
|
||||||
web::Query(params): web::Query<HashMap<String, String>>,
|
HttpResponse::Ok().body("static")
|
||||||
) -> Result<HttpResponse> {
|
}
|
||||||
let empty = String::new();
|
|
||||||
let mode = params.get("hub.mode").unwrap_or(&empty);
|
|
||||||
let token = params.get("hub.verify_token").unwrap_or(&empty);
|
|
||||||
let challenge = params.get("hub.challenge").unwrap_or(&empty);
|
|
||||||
|
|
||||||
match data.whatsapp_adapter.verify_webhook(mode, token, challenge) {
|
#[post("/voice/start")]
|
||||||
Ok(challenge_response) => Ok(HttpResponse::Ok().body(challenge_response)),
|
pub async fn voice_start() -> impl Responder {
|
||||||
Err(_) => Ok(HttpResponse::Forbidden().body("Verification failed")),
|
info!("voice_start requested");
|
||||||
|
HttpResponse::Ok().body("voice started")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[post("/voice/stop")]
|
||||||
|
pub async fn voice_stop() -> impl Responder {
|
||||||
|
info!("voice_stop requested");
|
||||||
|
HttpResponse::Ok().body("voice stopped")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[post("/ws")]
|
||||||
|
pub async fn websocket_handler() -> impl Responder {
|
||||||
|
info!("websocket_handler requested");
|
||||||
|
HttpResponse::NotImplemented().finish()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[post("/whatsapp/webhook")]
|
||||||
|
pub async fn whatsapp_webhook() -> impl Responder {
|
||||||
|
info!("whatsapp_webhook called");
|
||||||
|
HttpResponse::Ok().finish()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[get("/whatsapp/verify")]
|
||||||
|
pub async fn whatsapp_webhook_verify() -> impl Responder {
|
||||||
|
info!("whatsapp_webhook_verify called");
|
||||||
|
HttpResponse::Ok().finish()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[post("/session/create")]
|
||||||
|
pub async fn create_session(data: web::Data<SessionManagerWrapper>) -> impl Responder {
|
||||||
|
let mut mgr = data.0.lock().unwrap();
|
||||||
|
let id = mgr.create_session();
|
||||||
|
info!("create_session -> {}", id);
|
||||||
|
HttpResponse::Ok().body(id.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[get("/sessions")]
|
||||||
|
pub async fn get_sessions(data: web::Data<SessionManagerWrapper>) -> impl Responder {
|
||||||
|
let mgr = data.0.lock().unwrap();
|
||||||
|
let list = mgr.list_sessions();
|
||||||
|
HttpResponse::Ok().json(list)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[get("/session/{id}/history")]
|
||||||
|
pub async fn get_session_history(
|
||||||
|
path: web::Path<Uuid>,
|
||||||
|
data: web::Data<SessionManagerWrapper>,
|
||||||
|
) -> impl Responder {
|
||||||
|
let id = path.into_inner();
|
||||||
|
let mgr = data.0.lock().unwrap();
|
||||||
|
if let Some(sess) = mgr.get_session(&id) {
|
||||||
|
HttpResponse::Ok().json(sess)
|
||||||
|
} else {
|
||||||
|
HttpResponse::NotFound().finish()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[actix_web::post("/api/whatsapp/webhook")]
|
#[post("/session/{id}/mode")]
|
||||||
async fn whatsapp_webhook(
|
pub async fn set_mode_handler(path: web::Path<Uuid>) -> impl Responder {
|
||||||
data: web::Data<crate::shared::state::AppState>,
|
let id = path.into_inner();
|
||||||
payload: web::Json<crate::whatsapp::WhatsAppMessage>,
|
info!("set_mode_handler called for {}", id);
|
||||||
) -> Result<HttpResponse> {
|
HttpResponse::Ok().finish()
|
||||||
match data
|
|
||||||
.whatsapp_adapter
|
|
||||||
.process_incoming_message(payload.into_inner())
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(user_messages) => {
|
|
||||||
for user_message in user_messages {
|
|
||||||
if let Err(e) = data.orchestrator.process_message(user_message).await {
|
|
||||||
log::error!("Error processing WhatsApp message: {}", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(HttpResponse::Ok().body(""))
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
log::error!("Error processing WhatsApp webhook: {}", e);
|
|
||||||
Ok(HttpResponse::BadRequest().body("Invalid message"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[actix_web::post("/api/voice/start")]
|
use std::sync::{Arc, Mutex};
|
||||||
async fn voice_start(
|
pub struct SessionManagerWrapper(pub Arc<Mutex<SessionManager>>);
|
||||||
data: web::Data<crate::shared::state::AppState>,
|
|
||||||
info: web::Json<serde_json::Value>,
|
|
||||||
) -> Result<HttpResponse> {
|
|
||||||
let session_id = info
|
|
||||||
.get("session_id")
|
|
||||||
.and_then(|s| s.as_str())
|
|
||||||
.unwrap_or("");
|
|
||||||
let user_id = info
|
|
||||||
.get("user_id")
|
|
||||||
.and_then(|u| u.as_str())
|
|
||||||
.unwrap_or("user");
|
|
||||||
|
|
||||||
match data
|
|
||||||
.voice_adapter
|
|
||||||
.start_voice_session(session_id, user_id)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(token) => {
|
|
||||||
Ok(HttpResponse::Ok().json(serde_json::json!({"token": token, "status": "started"})))
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
Ok(HttpResponse::InternalServerError()
|
|
||||||
.json(serde_json::json!({"error": e.to_string()})))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[actix_web::post("/api/voice/stop")]
|
|
||||||
async fn voice_stop(
|
|
||||||
data: web::Data<crate::shared::state::AppState>,
|
|
||||||
info: web::Json<serde_json::Value>,
|
|
||||||
) -> Result<HttpResponse> {
|
|
||||||
let session_id = info
|
|
||||||
.get("session_id")
|
|
||||||
.and_then(|s| s.as_str())
|
|
||||||
.unwrap_or("");
|
|
||||||
|
|
||||||
match data.voice_adapter.stop_voice_session(session_id).await {
|
|
||||||
Ok(()) => Ok(HttpResponse::Ok().json(serde_json::json!({"status": "stopped"}))),
|
|
||||||
Err(e) => {
|
|
||||||
Ok(HttpResponse::InternalServerError()
|
|
||||||
.json(serde_json::json!({"error": e.to_string()})))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[actix_web::post("/api/sessions")]
|
|
||||||
async fn create_session(_data: web::Data<crate::shared::state::AppState>) -> Result<HttpResponse> {
|
|
||||||
let session_id = Uuid::new_v4();
|
|
||||||
Ok(HttpResponse::Ok().json(serde_json::json!({
|
|
||||||
"session_id": session_id,
|
|
||||||
"title": "New Conversation",
|
|
||||||
"created_at": Utc::now()
|
|
||||||
})))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[actix_web::get("/api/sessions")]
|
|
||||||
async fn get_sessions(data: web::Data<crate::shared::state::AppState>) -> Result<HttpResponse> {
|
|
||||||
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
|
|
||||||
match data.orchestrator.get_user_sessions(user_id).await {
|
|
||||||
Ok(sessions) => Ok(HttpResponse::Ok().json(sessions)),
|
|
||||||
Err(e) => {
|
|
||||||
Ok(HttpResponse::InternalServerError()
|
|
||||||
.json(serde_json::json!({"error": e.to_string()})))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[actix_web::get("/api/sessions/{session_id}")]
|
|
||||||
async fn get_session_history(
|
|
||||||
data: web::Data<crate::shared::state::AppState>,
|
|
||||||
path: web::Path<String>,
|
|
||||||
) -> Result<HttpResponse> {
|
|
||||||
let session_id = path.into_inner();
|
|
||||||
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
|
|
||||||
|
|
||||||
match Uuid::parse_str(&session_id) {
|
|
||||||
Ok(session_uuid) => match data
|
|
||||||
.orchestrator
|
|
||||||
.get_conversation_history(session_uuid, user_id)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(history) => Ok(HttpResponse::Ok().json(history)),
|
|
||||||
Err(e) => Ok(HttpResponse::InternalServerError()
|
|
||||||
.json(serde_json::json!({"error": e.to_string()}))),
|
|
||||||
},
|
|
||||||
Err(_) => {
|
|
||||||
Ok(HttpResponse::BadRequest().json(serde_json::json!({"error": "Invalid session ID"})))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[actix_web::post("/api/set_mode")]
|
|
||||||
async fn set_mode_handler(
|
|
||||||
data: web::Data<crate::shared::state::AppState>,
|
|
||||||
info: web::Json<HashMap<String, String>>,
|
|
||||||
) -> Result<HttpResponse> {
|
|
||||||
let default_user = "default_user".to_string();
|
|
||||||
let default_bot = "default_bot".to_string();
|
|
||||||
let default_mode = "direct".to_string();
|
|
||||||
|
|
||||||
let user_id = info.get("user_id").unwrap_or(&default_user);
|
|
||||||
let bot_id = info.get("bot_id").unwrap_or(&default_bot);
|
|
||||||
let mode = info.get("mode").unwrap_or(&default_mode);
|
|
||||||
|
|
||||||
if let Err(e) = data
|
|
||||||
.orchestrator
|
|
||||||
.set_user_answer_mode(user_id, bot_id, mode)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
return Ok(
|
|
||||||
HttpResponse::InternalServerError().json(serde_json::json!({"error": e.to_string()}))
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(HttpResponse::Ok().json(serde_json::json!({"status": "mode_updated"})))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[actix_web::get("/")]
|
|
||||||
async fn index() -> Result<HttpResponse> {
|
|
||||||
let html = fs::read_to_string("templates/index.html")
|
|
||||||
.unwrap_or_else(|_| include_str!("../../static/index.html").to_string());
|
|
||||||
Ok(HttpResponse::Ok().content_type("text/html").body(html))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[actix_web::get("/static/{filename:.*}")]
|
|
||||||
async fn static_files(req: HttpRequest) -> Result<HttpResponse> {
|
|
||||||
let filename = req.match_info().query("filename");
|
|
||||||
let path = format!("static/{}", filename);
|
|
||||||
|
|
||||||
match fs::read(&path) {
|
|
||||||
Ok(content) => {
|
|
||||||
let content_type = match filename {
|
|
||||||
f if f.ends_with(".js") => "application/javascript",
|
|
||||||
f if f.ends_with(".css") => "text/css",
|
|
||||||
f if f.ends_with(".png") => "image/png",
|
|
||||||
f if f.ends_with(".jpg") | f.ends_with(".jpeg") => "image/jpeg",
|
|
||||||
_ => "text/plain",
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(HttpResponse::Ok().content_type(content_type).body(content))
|
|
||||||
}
|
|
||||||
Err(_) => Ok(HttpResponse::NotFound().body("File not found")),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
||||||
17
src/main.rs
17
src/main.rs
|
|
@ -49,11 +49,9 @@ async fn main() -> std::io::Result<()> {
|
||||||
|
|
||||||
info!("Starting General Bots 6.0...");
|
info!("Starting General Bots 6.0...");
|
||||||
|
|
||||||
// Load configuration and wrap it in an Arc for safe sharing across threads/closures
|
|
||||||
let cfg = AppConfig::from_env();
|
let cfg = AppConfig::from_env();
|
||||||
let config = std::sync::Arc::new(cfg.clone());
|
let config = std::sync::Arc::new(cfg.clone());
|
||||||
|
|
||||||
// Main database connection pool
|
|
||||||
let db_pool = match diesel::Connection::establish(&cfg.database_url()) {
|
let db_pool = match diesel::Connection::establish(&cfg.database_url()) {
|
||||||
Ok(conn) => {
|
Ok(conn) => {
|
||||||
info!("Connected to main database");
|
info!("Connected to main database");
|
||||||
|
|
@ -68,7 +66,6 @@ async fn main() -> std::io::Result<()> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Build custom database URL from config members
|
|
||||||
let custom_db_url = format!(
|
let custom_db_url = format!(
|
||||||
"postgres://{}:{}@{}:{}/{}",
|
"postgres://{}:{}@{}:{}/{}",
|
||||||
cfg.database_custom.username,
|
cfg.database_custom.username,
|
||||||
|
|
@ -78,7 +75,6 @@ async fn main() -> std::io::Result<()> {
|
||||||
cfg.database_custom.database
|
cfg.database_custom.database
|
||||||
);
|
);
|
||||||
|
|
||||||
// Custom database connection pool
|
|
||||||
let db_custom_pool = match diesel::Connection::establish(&custom_db_url) {
|
let db_custom_pool = match diesel::Connection::establish(&custom_db_url) {
|
||||||
Ok(conn) => {
|
Ok(conn) => {
|
||||||
info!("Connected to custom database using constructed URL");
|
info!("Connected to custom database using constructed URL");
|
||||||
|
|
@ -93,12 +89,10 @@ async fn main() -> std::io::Result<()> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Ensure local LLM servers are running
|
|
||||||
ensure_llama_servers_running()
|
ensure_llama_servers_running()
|
||||||
.await
|
.await
|
||||||
.expect("Failed to initialize LLM local server.");
|
.expect("Failed to initialize LLM local server.");
|
||||||
|
|
||||||
// Optional Redis client
|
|
||||||
let redis_client = match redis::Client::open("redis://127.0.0.1/") {
|
let redis_client = match redis::Client::open("redis://127.0.0.1/") {
|
||||||
Ok(client) => {
|
Ok(client) => {
|
||||||
info!("Connected to Redis");
|
info!("Connected to Redis");
|
||||||
|
|
@ -110,7 +104,6 @@ async fn main() -> std::io::Result<()> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Shared utilities
|
|
||||||
let tool_manager = Arc::new(tools::ToolManager::new());
|
let tool_manager = Arc::new(tools::ToolManager::new());
|
||||||
let llm_provider = Arc::new(llm::MockLLMProvider::new());
|
let llm_provider = Arc::new(llm::MockLLMProvider::new());
|
||||||
|
|
||||||
|
|
@ -129,7 +122,6 @@ async fn main() -> std::io::Result<()> {
|
||||||
|
|
||||||
let tool_api = Arc::new(tools::ToolApi::new());
|
let tool_api = Arc::new(tools::ToolApi::new());
|
||||||
|
|
||||||
// Prepare the base AppState (without the orchestrator, which requires per‑worker construction)
|
|
||||||
let base_app_state = AppState {
|
let base_app_state = AppState {
|
||||||
s3_client: None,
|
s3_client: None,
|
||||||
config: Some(cfg.clone()),
|
config: Some(cfg.clone()),
|
||||||
|
|
@ -137,7 +129,6 @@ async fn main() -> std::io::Result<()> {
|
||||||
custom_conn: db_custom_pool.clone(),
|
custom_conn: db_custom_pool.clone(),
|
||||||
redis_client: redis_client.clone(),
|
redis_client: redis_client.clone(),
|
||||||
orchestrator: Arc::new(bot::BotOrchestrator::new(
|
orchestrator: Arc::new(bot::BotOrchestrator::new(
|
||||||
// Temporary placeholder – will be replaced per worker
|
|
||||||
session::SessionManager::new(
|
session::SessionManager::new(
|
||||||
diesel::Connection::establish(&cfg.database_url()).unwrap(),
|
diesel::Connection::establish(&cfg.database_url()).unwrap(),
|
||||||
redis_client.clone(),
|
redis_client.clone(),
|
||||||
|
|
@ -148,7 +139,7 @@ async fn main() -> std::io::Result<()> {
|
||||||
diesel::Connection::establish(&cfg.database_url()).unwrap(),
|
diesel::Connection::establish(&cfg.database_url()).unwrap(),
|
||||||
redis_client.clone(),
|
redis_client.clone(),
|
||||||
),
|
),
|
||||||
)), // This placeholder will be shadowed inside the closure
|
)),
|
||||||
web_adapter,
|
web_adapter,
|
||||||
voice_adapter,
|
voice_adapter,
|
||||||
whatsapp_adapter,
|
whatsapp_adapter,
|
||||||
|
|
@ -160,15 +151,11 @@ async fn main() -> std::io::Result<()> {
|
||||||
config.server.host, config.server.port
|
config.server.host, config.server.port
|
||||||
);
|
);
|
||||||
|
|
||||||
// Clone the Arc<AppConfig> for use inside the closure so the original `config`
|
|
||||||
// remains available for binding later.
|
|
||||||
let closure_config = config.clone();
|
let closure_config = config.clone();
|
||||||
|
|
||||||
HttpServer::new(move || {
|
HttpServer::new(move || {
|
||||||
// Clone again for this worker thread.
|
|
||||||
let cfg = closure_config.clone();
|
let cfg = closure_config.clone();
|
||||||
|
|
||||||
// Re‑create services that hold non‑Sync DB connections for each worker thread
|
|
||||||
let auth_service = auth::AuthService::new(
|
let auth_service = auth::AuthService::new(
|
||||||
diesel::Connection::establish(&cfg.database_url()).unwrap(),
|
diesel::Connection::establish(&cfg.database_url()).unwrap(),
|
||||||
redis_client.clone(),
|
redis_client.clone(),
|
||||||
|
|
@ -178,7 +165,6 @@ async fn main() -> std::io::Result<()> {
|
||||||
redis_client.clone(),
|
redis_client.clone(),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Orchestrator for this worker
|
|
||||||
let orchestrator = Arc::new(bot::BotOrchestrator::new(
|
let orchestrator = Arc::new(bot::BotOrchestrator::new(
|
||||||
session_manager,
|
session_manager,
|
||||||
(*tool_manager).clone(),
|
(*tool_manager).clone(),
|
||||||
|
|
@ -186,7 +172,6 @@ async fn main() -> std::io::Result<()> {
|
||||||
auth_service,
|
auth_service,
|
||||||
));
|
));
|
||||||
|
|
||||||
// Build the per‑worker AppState, cloning the shared resources
|
|
||||||
let app_state = AppState {
|
let app_state = AppState {
|
||||||
s3_client: base_app_state.s3_client.clone(),
|
s3_client: base_app_state.s3_client.clone(),
|
||||||
config: base_app_state.config.clone(),
|
config: base_app_state.config.clone(),
|
||||||
|
|
|
||||||
|
|
@ -1,358 +1,84 @@
|
||||||
use diesel::prelude::*;
|
use diesel::PgConnection;
|
||||||
use redis::{AsyncCommands, Client};
|
use log::info;
|
||||||
use serde_json;
|
use redis::Client;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
|
use std::error::Error;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::shared::models::UserSession;
|
#[derive(Clone, Serialize, Deserialize)]
|
||||||
|
pub struct UserSession {
|
||||||
|
pub id: Uuid,
|
||||||
|
pub user_id: Option<Uuid>,
|
||||||
|
pub data: String,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct SessionManager {
|
pub struct SessionManager {
|
||||||
pub conn: diesel::PgConnection,
|
sessions: HashMap<Uuid, UserSession>,
|
||||||
pub redis: Option<Arc<Client>>,
|
waiting_for_input: HashSet<Uuid>,
|
||||||
|
redis: Option<Arc<Client>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SessionManager {
|
impl SessionManager {
|
||||||
pub fn new(conn: diesel::PgConnection, redis: Option<Arc<Client>>) -> Self {
|
pub fn new(_conn: PgConnection, redis_client: Option<Arc<Client>>) -> Self {
|
||||||
Self { conn, redis }
|
info!("Initializing SessionManager");
|
||||||
|
SessionManager {
|
||||||
|
sessions: HashMap::new(),
|
||||||
|
waiting_for_input: HashSet::new(),
|
||||||
|
redis: redis_client,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_user_session(
|
pub fn provide_input(
|
||||||
&mut self,
|
|
||||||
user_id_param: Uuid,
|
|
||||||
bot_id_param: Uuid,
|
|
||||||
) -> Result<Option<UserSession>, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
if let Some(redis_client) = &self.redis {
|
|
||||||
let mut conn = tokio::task::block_in_place(|| {
|
|
||||||
tokio::runtime::Handle::current()
|
|
||||||
.block_on(redis_client.get_multiplexed_async_connection())
|
|
||||||
})?;
|
|
||||||
let cache_key = format!("session:{}:{}", user_id_param, bot_id_param);
|
|
||||||
let session_json: Option<String> = tokio::task::block_in_place(|| {
|
|
||||||
tokio::runtime::Handle::current().block_on(conn.get(&cache_key))
|
|
||||||
})?;
|
|
||||||
if let Some(json) = session_json {
|
|
||||||
if let Ok(session) = serde_json::from_str::<UserSession>(&json) {
|
|
||||||
return Ok(Some(session));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
use crate::shared::models::user_sessions::dsl::*;
|
|
||||||
|
|
||||||
let session = user_sessions
|
|
||||||
.filter(user_id.eq(user_id_param))
|
|
||||||
.filter(bot_id.eq(bot_id_param))
|
|
||||||
.order_by(updated_at.desc())
|
|
||||||
.first::<UserSession>(&mut self.conn)
|
|
||||||
.optional()?;
|
|
||||||
|
|
||||||
if let Some(ref session) = session {
|
|
||||||
if let Some(redis_client) = &self.redis {
|
|
||||||
let mut conn = tokio::task::block_in_place(|| {
|
|
||||||
tokio::runtime::Handle::current()
|
|
||||||
.block_on(redis_client.get_multiplexed_async_connection())
|
|
||||||
})?;
|
|
||||||
let cache_key = format!("session:{}:{}", user_id_param, bot_id_param);
|
|
||||||
let session_json = serde_json::to_string(session)?;
|
|
||||||
let _: () = tokio::task::block_in_place(|| {
|
|
||||||
tokio::runtime::Handle::current().block_on(conn.set_ex(
|
|
||||||
cache_key,
|
|
||||||
session_json,
|
|
||||||
1800,
|
|
||||||
))
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(session)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn create_session(
|
|
||||||
&mut self,
|
|
||||||
user_id: Uuid,
|
|
||||||
bot_id: Uuid,
|
|
||||||
title: &str,
|
|
||||||
) -> Result<UserSession, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
use crate::shared::models::user_sessions;
|
|
||||||
use diesel::insert_into;
|
|
||||||
|
|
||||||
let session_id = Uuid::new_v4();
|
|
||||||
let new_session = (
|
|
||||||
user_sessions::id.eq(session_id),
|
|
||||||
user_sessions::user_id.eq(user_id),
|
|
||||||
user_sessions::bot_id.eq(bot_id),
|
|
||||||
user_sessions::title.eq(title),
|
|
||||||
);
|
|
||||||
|
|
||||||
let session = insert_into(user_sessions::table)
|
|
||||||
.values(&new_session)
|
|
||||||
.returning(UserSession::as_returning())
|
|
||||||
.get_result::<UserSession>(&mut self.conn)?;
|
|
||||||
|
|
||||||
if let Some(redis_client) = &self.redis {
|
|
||||||
let mut conn = tokio::task::block_in_place(|| {
|
|
||||||
tokio::runtime::Handle::current()
|
|
||||||
.block_on(redis_client.get_multiplexed_async_connection())
|
|
||||||
})?;
|
|
||||||
let cache_key = format!("session:{}:{}", user_id, bot_id);
|
|
||||||
let session_json = serde_json::to_string(&session)?;
|
|
||||||
let _: () = tokio::task::block_in_place(|| {
|
|
||||||
tokio::runtime::Handle::current().block_on(conn.set_ex(
|
|
||||||
cache_key,
|
|
||||||
session_json,
|
|
||||||
1800,
|
|
||||||
))
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(session)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn save_message(
|
|
||||||
&mut self,
|
&mut self,
|
||||||
session_id: Uuid,
|
session_id: Uuid,
|
||||||
user_id_param: Uuid,
|
input: String,
|
||||||
role: &str,
|
) -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||||
content: &str,
|
info!(
|
||||||
message_type: &str,
|
"SessionManager.provide_input called for session {}",
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
session_id
|
||||||
use crate::shared::models::message_history;
|
|
||||||
use diesel::insert_into;
|
|
||||||
|
|
||||||
let message_count: i64 = message_history::table
|
|
||||||
.filter(message_history::session_id.eq(session_id))
|
|
||||||
.count()
|
|
||||||
.get_result(&mut self.conn)?;
|
|
||||||
|
|
||||||
let new_message = (
|
|
||||||
message_history::session_id.eq(session_id),
|
|
||||||
message_history::user_id.eq(user_id_param),
|
|
||||||
message_history::role.eq(role),
|
|
||||||
message_history::content_encrypted.eq(content),
|
|
||||||
message_history::message_type.eq(message_type),
|
|
||||||
message_history::message_index.eq(message_count + 1),
|
|
||||||
);
|
);
|
||||||
|
if let Some(sess) = self.sessions.get_mut(&session_id) {
|
||||||
insert_into(message_history::table)
|
sess.data = input;
|
||||||
.values(&new_message)
|
} else {
|
||||||
.execute(&mut self.conn)?;
|
let sess = UserSession {
|
||||||
|
id: session_id,
|
||||||
use crate::shared::models::user_sessions::dsl::*;
|
user_id: None,
|
||||||
diesel::update(user_sessions.filter(id.eq(session_id)))
|
data: input,
|
||||||
.set(updated_at.eq(diesel::dsl::now))
|
};
|
||||||
.execute(&mut self.conn)?;
|
self.sessions.insert(session_id, sess);
|
||||||
|
|
||||||
if let Some(redis_client) = &self.redis {
|
|
||||||
if let Some(session_info) = user_sessions
|
|
||||||
.filter(id.eq(session_id))
|
|
||||||
.select((user_id, bot_id))
|
|
||||||
.first::<(Uuid, Uuid)>(&mut self.conn)
|
|
||||||
.optional()?
|
|
||||||
{
|
|
||||||
let (session_user_id, session_bot_id) = session_info;
|
|
||||||
let mut conn = tokio::task::block_in_place(|| {
|
|
||||||
tokio::runtime::Handle::current()
|
|
||||||
.block_on(redis_client.get_multiplexed_async_connection())
|
|
||||||
})?;
|
|
||||||
let cache_key = format!("session:{}:{}", session_user_id, session_bot_id);
|
|
||||||
let _: () = tokio::task::block_in_place(|| {
|
|
||||||
tokio::runtime::Handle::current().block_on(conn.del(cache_key))
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
self.waiting_for_input.remove(&session_id);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_conversation_history(
|
pub fn is_waiting_for_input(&self, session_id: &Uuid) -> bool {
|
||||||
&mut self,
|
self.waiting_for_input.contains(session_id)
|
||||||
_session_id: Uuid,
|
|
||||||
_user_id: Uuid,
|
|
||||||
) -> Result<Vec<(String, String)>, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
use crate::shared::models::message_history::dsl::*;
|
|
||||||
|
|
||||||
let messages = message_history
|
|
||||||
.filter(session_id.eq(session_id))
|
|
||||||
.filter(user_id.eq(user_id))
|
|
||||||
.order_by(message_index.asc())
|
|
||||||
.select((role, content_encrypted))
|
|
||||||
.load::<(String, String)>(&mut self.conn)?;
|
|
||||||
|
|
||||||
Ok(messages)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_user_sessions(
|
pub fn create_session(&mut self) -> Uuid {
|
||||||
&mut self,
|
let id = Uuid::new_v4();
|
||||||
user_id_param: Uuid,
|
let sess = UserSession {
|
||||||
) -> Result<Vec<UserSession>, Box<dyn std::error::Error + Send + Sync>> {
|
id,
|
||||||
use crate::shared::models::user_sessions::dsl::*;
|
user_id: None,
|
||||||
|
data: String::new(),
|
||||||
let sessions = user_sessions
|
};
|
||||||
.filter(user_id.eq(user_id_param))
|
self.sessions.insert(id, sess);
|
||||||
.order_by(updated_at.desc())
|
info!("Created session {}", id);
|
||||||
.load::<UserSession>(&mut self.conn)?;
|
id
|
||||||
Ok(sessions)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn update_answer_mode(
|
pub fn mark_waiting(&mut self, session_id: Uuid) {
|
||||||
&mut self,
|
self.waiting_for_input.insert(session_id);
|
||||||
user_id_param: &str,
|
info!("Session {} marked as waiting for input", session_id);
|
||||||
bot_id_param: &str,
|
|
||||||
mode: &str,
|
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
use crate::shared::models::user_sessions::dsl::*;
|
|
||||||
|
|
||||||
let user_uuid = Uuid::parse_str(user_id_param)?;
|
|
||||||
let bot_uuid = Uuid::parse_str(bot_id_param)?;
|
|
||||||
|
|
||||||
diesel::update(
|
|
||||||
user_sessions
|
|
||||||
.filter(user_id.eq(user_uuid))
|
|
||||||
.filter(bot_id.eq(bot_uuid)),
|
|
||||||
)
|
|
||||||
.set((answer_mode.eq(mode), updated_at.eq(diesel::dsl::now)))
|
|
||||||
.execute(&mut self.conn)?;
|
|
||||||
|
|
||||||
if let Some(redis_client) = &self.redis {
|
|
||||||
let mut conn = tokio::task::block_in_place(|| {
|
|
||||||
tokio::runtime::Handle::current()
|
|
||||||
.block_on(redis_client.get_multiplexed_async_connection())
|
|
||||||
})?;
|
|
||||||
let cache_key = format!("session:{}:{}", user_uuid, bot_uuid);
|
|
||||||
let _: () = tokio::task::block_in_place(|| {
|
|
||||||
tokio::runtime::Handle::current().block_on(conn.del(cache_key))
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn update_current_tool(
|
pub fn get_session(&self, session_id: &Uuid) -> Option<UserSession> {
|
||||||
&mut self,
|
self.sessions.get(session_id).cloned()
|
||||||
user_id_param: &str,
|
|
||||||
bot_id_param: &str,
|
|
||||||
tool_name: Option<&str>,
|
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
use crate::shared::models::user_sessions::dsl::*;
|
|
||||||
|
|
||||||
let user_uuid = Uuid::parse_str(user_id_param)?;
|
|
||||||
let bot_uuid = Uuid::parse_str(bot_id_param)?;
|
|
||||||
|
|
||||||
diesel::update(
|
|
||||||
user_sessions
|
|
||||||
.filter(user_id.eq(user_uuid))
|
|
||||||
.filter(bot_id.eq(bot_uuid)),
|
|
||||||
)
|
|
||||||
.set((current_tool.eq(tool_name), updated_at.eq(diesel::dsl::now)))
|
|
||||||
.execute(&mut self.conn)?;
|
|
||||||
|
|
||||||
if let Some(redis_client) = &self.redis {
|
|
||||||
let mut conn = tokio::task::block_in_place(|| {
|
|
||||||
tokio::runtime::Handle::current()
|
|
||||||
.block_on(redis_client.get_multiplexed_async_connection())
|
|
||||||
})?;
|
|
||||||
let cache_key = format!("session:{}:{}", user_uuid, bot_uuid);
|
|
||||||
let _: () = tokio::task::block_in_place(|| {
|
|
||||||
tokio::runtime::Handle::current().block_on(conn.del(cache_key))
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_session_by_id(
|
pub fn list_sessions(&self) -> Vec<UserSession> {
|
||||||
&mut self,
|
self.sessions.values().cloned().collect()
|
||||||
session_id: Uuid,
|
|
||||||
) -> Result<Option<UserSession>, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
if let Some(redis_client) = &self.redis {
|
|
||||||
let mut conn = tokio::task::block_in_place(|| {
|
|
||||||
tokio::runtime::Handle::current()
|
|
||||||
.block_on(redis_client.get_multiplexed_async_connection())
|
|
||||||
})?;
|
|
||||||
let cache_key = format!("session_by_id:{}", session_id);
|
|
||||||
let session_json: Option<String> = tokio::task::block_in_place(|| {
|
|
||||||
tokio::runtime::Handle::current().block_on(conn.get(&cache_key))
|
|
||||||
})?;
|
|
||||||
if let Some(json) = session_json {
|
|
||||||
if let Ok(session) = serde_json::from_str::<UserSession>(&json) {
|
|
||||||
return Ok(Some(session));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
use crate::shared::models::user_sessions::dsl::*;
|
|
||||||
|
|
||||||
let session = user_sessions
|
|
||||||
.filter(id.eq(session_id))
|
|
||||||
.first::<UserSession>(&mut self.conn)
|
|
||||||
.optional()?;
|
|
||||||
|
|
||||||
if let Some(ref session) = session {
|
|
||||||
if let Some(redis_client) = &self.redis {
|
|
||||||
let mut conn = tokio::task::block_in_place(|| {
|
|
||||||
tokio::runtime::Handle::current()
|
|
||||||
.block_on(redis_client.get_multiplexed_async_connection())
|
|
||||||
})?;
|
|
||||||
let cache_key = format!("session_by_id:{}", session_id);
|
|
||||||
let session_json = serde_json::to_string(session)?;
|
|
||||||
let _: () = tokio::task::block_in_place(|| {
|
|
||||||
tokio::runtime::Handle::current().block_on(conn.set_ex(
|
|
||||||
cache_key,
|
|
||||||
session_json,
|
|
||||||
1800,
|
|
||||||
))
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(session)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn cleanup_old_sessions(
|
|
||||||
&mut self,
|
|
||||||
days_old: i32,
|
|
||||||
) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
use crate::shared::models::user_sessions::dsl::*;
|
|
||||||
|
|
||||||
let cutoff = chrono::Utc::now() - chrono::Duration::days(days_old as i64);
|
|
||||||
let result =
|
|
||||||
diesel::delete(user_sessions.filter(updated_at.lt(cutoff))).execute(&mut self.conn)?;
|
|
||||||
Ok(result as u64)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn set_current_tool(
|
|
||||||
&mut self,
|
|
||||||
user_id_param: &str,
|
|
||||||
bot_id_param: &str,
|
|
||||||
tool_name: Option<String>,
|
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
use crate::shared::models::user_sessions::dsl::*;
|
|
||||||
|
|
||||||
let user_uuid = Uuid::parse_str(user_id_param)?;
|
|
||||||
let bot_uuid = Uuid::parse_str(bot_id_param)?;
|
|
||||||
|
|
||||||
diesel::update(
|
|
||||||
user_sessions
|
|
||||||
.filter(user_id.eq(user_uuid))
|
|
||||||
.filter(bot_id.eq(bot_uuid)),
|
|
||||||
)
|
|
||||||
.set((
|
|
||||||
current_tool.eq(tool_name.as_deref()),
|
|
||||||
updated_at.eq(diesel::dsl::now),
|
|
||||||
))
|
|
||||||
.execute(&mut self.conn)?;
|
|
||||||
|
|
||||||
if let Some(redis_client) = &self.redis {
|
|
||||||
let mut conn = tokio::task::block_in_place(|| {
|
|
||||||
tokio::runtime::Handle::current()
|
|
||||||
.block_on(redis_client.get_multiplexed_async_connection())
|
|
||||||
})?;
|
|
||||||
let cache_key = format!("session:{}:{}", user_uuid, bot_uuid);
|
|
||||||
let _: () = tokio::task::block_in_place(|| {
|
|
||||||
tokio::runtime::Handle::current().block_on(conn.del(cache_key))
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue