Update attendance, keywords, calendar, compliance, console, core, drive, email, llm, msteams, security, and tasks modules

This commit is contained in:
Rodrigo Rodriguez (Pragmatismo) 2025-12-24 09:29:27 -03:00
parent 883c6d07e1
commit 7cbfe43319
33 changed files with 525 additions and 2412 deletions

View file

@ -2,7 +2,7 @@
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use chrono::{DateTime, Duration, Local, NaiveTime, Utc}; use chrono::{DateTime, Duration, Local, Utc};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;

View file

@ -1,40 +1,3 @@
use crate::core::config::ConfigManager; use crate::core::config::ConfigManager;
use crate::shared::models::UserSession; use crate::shared::models::UserSession;
use crate::shared::state::AppState; use crate::shared::state::AppState;
@ -46,18 +9,14 @@ use axum::{
}; };
use chrono::Utc; use chrono::Utc;
use diesel::prelude::*; use diesel::prelude::*;
use log::{debug, error, info, warn}; use log::{error, info, warn};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
use uuid::Uuid; use uuid::Uuid;
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
pub struct LlmAssistConfig { pub struct LlmAssistConfig {
pub tips_enabled: bool, pub tips_enabled: bool,
pub polish_enabled: bool, pub polish_enabled: bool,
@ -74,7 +33,6 @@ pub struct LlmAssistConfig {
} }
impl LlmAssistConfig { impl LlmAssistConfig {
pub fn from_config(bot_id: Uuid, work_path: &str) -> Self { pub fn from_config(bot_id: Uuid, work_path: &str) -> Self {
let config_path = PathBuf::from(work_path) let config_path = PathBuf::from(work_path)
.join(format!("{}.gbai", bot_id)) .join(format!("{}.gbai", bot_id))
@ -143,7 +101,6 @@ impl LlmAssistConfig {
config config
} }
pub fn any_enabled(&self) -> bool { pub fn any_enabled(&self) -> bool {
self.tips_enabled self.tips_enabled
|| self.polish_enabled || self.polish_enabled
@ -153,9 +110,6 @@ impl LlmAssistConfig {
} }
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct TipRequest { pub struct TipRequest {
pub session_id: Uuid, pub session_id: Uuid,
@ -165,7 +119,6 @@ pub struct TipRequest {
pub history: Vec<ConversationMessage>, pub history: Vec<ConversationMessage>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct PolishRequest { pub struct PolishRequest {
pub session_id: Uuid, pub session_id: Uuid,
@ -179,7 +132,6 @@ fn default_tone() -> String {
"professional".to_string() "professional".to_string()
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct SmartRepliesRequest { pub struct SmartRepliesRequest {
pub session_id: Uuid, pub session_id: Uuid,
@ -187,13 +139,11 @@ pub struct SmartRepliesRequest {
pub history: Vec<ConversationMessage>, pub history: Vec<ConversationMessage>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct SummaryRequest { pub struct SummaryRequest {
pub session_id: Uuid, pub session_id: Uuid,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct SentimentRequest { pub struct SentimentRequest {
pub session_id: Uuid, pub session_id: Uuid,
@ -202,7 +152,6 @@ pub struct SentimentRequest {
pub history: Vec<ConversationMessage>, pub history: Vec<ConversationMessage>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationMessage { pub struct ConversationMessage {
pub role: String, pub role: String,
@ -210,7 +159,6 @@ pub struct ConversationMessage {
pub timestamp: Option<String>, pub timestamp: Option<String>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
pub struct TipResponse { pub struct TipResponse {
pub success: bool, pub success: bool,
@ -219,7 +167,6 @@ pub struct TipResponse {
pub error: Option<String>, pub error: Option<String>,
} }
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]
pub struct AttendantTip { pub struct AttendantTip {
pub tip_type: TipType, pub tip_type: TipType,
@ -228,11 +175,9 @@ pub struct AttendantTip {
pub priority: i32, pub priority: i32,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum TipType { pub enum TipType {
Intent, Intent,
Action, Action,
@ -246,7 +191,6 @@ pub enum TipType {
General, General,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
pub struct PolishResponse { pub struct PolishResponse {
pub success: bool, pub success: bool,
@ -257,7 +201,6 @@ pub struct PolishResponse {
pub error: Option<String>, pub error: Option<String>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
pub struct SmartRepliesResponse { pub struct SmartRepliesResponse {
pub success: bool, pub success: bool,
@ -266,7 +209,6 @@ pub struct SmartRepliesResponse {
pub error: Option<String>, pub error: Option<String>,
} }
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]
pub struct SmartReply { pub struct SmartReply {
pub text: String, pub text: String,
@ -275,7 +217,6 @@ pub struct SmartReply {
pub category: String, pub category: String,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
pub struct SummaryResponse { pub struct SummaryResponse {
pub success: bool, pub success: bool,
@ -284,7 +225,6 @@ pub struct SummaryResponse {
pub error: Option<String>, pub error: Option<String>,
} }
#[derive(Debug, Clone, Serialize, Default)] #[derive(Debug, Clone, Serialize, Default)]
pub struct ConversationSummary { pub struct ConversationSummary {
pub brief: String, pub brief: String,
@ -297,7 +237,6 @@ pub struct ConversationSummary {
pub duration_minutes: i32, pub duration_minutes: i32,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
pub struct SentimentResponse { pub struct SentimentResponse {
pub success: bool, pub success: bool,
@ -306,7 +245,6 @@ pub struct SentimentResponse {
pub error: Option<String>, pub error: Option<String>,
} }
#[derive(Debug, Clone, Serialize, Default)] #[derive(Debug, Clone, Serialize, Default)]
pub struct SentimentAnalysis { pub struct SentimentAnalysis {
pub overall: String, pub overall: String,
@ -317,16 +255,12 @@ pub struct SentimentAnalysis {
pub emoji: String, pub emoji: String,
} }
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]
pub struct Emotion { pub struct Emotion {
pub name: String, pub name: String,
pub intensity: f32, pub intensity: f32,
} }
async fn execute_llm_with_context( async fn execute_llm_with_context(
state: &Arc<AppState>, state: &Arc<AppState>,
bot_id: Uuid, bot_id: Uuid,
@ -351,7 +285,6 @@ async fn execute_llm_with_context(
.unwrap_or_default() .unwrap_or_default()
}); });
let messages = serde_json::json!([ let messages = serde_json::json!([
{ {
"role": "system", "role": "system",
@ -368,29 +301,24 @@ async fn execute_llm_with_context(
.generate(user_prompt, &messages, &model, &key) .generate(user_prompt, &messages, &model, &key)
.await?; .await?;
let handler = crate::llm::llm_models::get_handler(&model); let handler = crate::llm::llm_models::get_handler(&model);
let processed = handler.process_content(&response); let processed = handler.process_content(&response);
Ok(processed) Ok(processed)
} }
fn get_bot_system_prompt(bot_id: Uuid, work_path: &str) -> String { fn get_bot_system_prompt(bot_id: Uuid, work_path: &str) -> String {
let config = LlmAssistConfig::from_config(bot_id, work_path); let config = LlmAssistConfig::from_config(bot_id, work_path);
if let Some(prompt) = config.bot_system_prompt { if let Some(prompt) = config.bot_system_prompt {
return prompt; return prompt;
} }
let start_bas_path = PathBuf::from(work_path) let start_bas_path = PathBuf::from(work_path)
.join(format!("{}.gbai", bot_id)) .join(format!("{}.gbai", bot_id))
.join(format!("{}.gbdialog", bot_id)) .join(format!("{}.gbdialog", bot_id))
.join("start.bas"); .join("start.bas");
if let Ok(content) = std::fs::read_to_string(&start_bas_path) { if let Ok(content) = std::fs::read_to_string(&start_bas_path) {
let mut description_lines = Vec::new(); let mut description_lines = Vec::new();
for line in content.lines() { for line in content.lines() {
let trimmed = line.trim(); let trimmed = line.trim();
@ -406,21 +334,15 @@ fn get_bot_system_prompt(bot_id: Uuid, work_path: &str) -> String {
} }
} }
"You are a professional customer service assistant. Be helpful, empathetic, and solution-oriented. Maintain a friendly but professional tone.".to_string() "You are a professional customer service assistant. Be helpful, empathetic, and solution-oriented. Maintain a friendly but professional tone.".to_string()
} }
pub async fn generate_tips( pub async fn generate_tips(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
Json(request): Json<TipRequest>, Json(request): Json<TipRequest>,
) -> (StatusCode, Json<TipResponse>) { ) -> (StatusCode, Json<TipResponse>) {
info!("Generating tips for session {}", request.session_id); info!("Generating tips for session {}", request.session_id);
let session_result = get_session(&state, request.session_id).await; let session_result = get_session(&state, request.session_id).await;
let session = match session_result { let session = match session_result {
Ok(s) => s, Ok(s) => s,
@ -436,7 +358,6 @@ pub async fn generate_tips(
} }
}; };
let work_path = std::env::var("WORK_PATH").unwrap_or_else(|_| "./work".to_string()); let work_path = std::env::var("WORK_PATH").unwrap_or_else(|_| "./work".to_string());
let config = LlmAssistConfig::from_config(session.bot_id, &work_path); let config = LlmAssistConfig::from_config(session.bot_id, &work_path);
@ -451,7 +372,6 @@ pub async fn generate_tips(
); );
} }
let history_context = request let history_context = request
.history .history
.iter() .iter()
@ -497,7 +417,6 @@ Provide tips for the attendant."#,
match execute_llm_with_context(&state, session.bot_id, &system_prompt, &user_prompt).await { match execute_llm_with_context(&state, session.bot_id, &system_prompt, &user_prompt).await {
Ok(response) => { Ok(response) => {
let tips = parse_tips_response(&response); let tips = parse_tips_response(&response);
( (
StatusCode::OK, StatusCode::OK,
@ -523,8 +442,6 @@ Provide tips for the attendant."#,
} }
} }
pub async fn polish_message( pub async fn polish_message(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
Json(request): Json<PolishRequest>, Json(request): Json<PolishRequest>,
@ -621,8 +538,6 @@ Respond in JSON format:
} }
} }
pub async fn generate_smart_replies( pub async fn generate_smart_replies(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
Json(request): Json<SmartRepliesRequest>, Json(request): Json<SmartRepliesRequest>,
@ -677,7 +592,7 @@ The service has this personality: {}
Generate exactly 3 reply suggestions that: Generate exactly 3 reply suggestions that:
1. Are contextually appropriate 1. Are contextually appropriate
2. Sound natural and human (not robotic) 2. Sound natural and human (not robotic)
3. Vary in approach (one empathetic, one solution-focused, one follow-up) 3. Vary in approach (one empathetic, one solution-focused, one follow_up)
4. Are ready to send (no placeholders like [name]) 4. Are ready to send (no placeholders like [name])
Respond in JSON format: Respond in JSON format:
@ -692,10 +607,10 @@ Respond in JSON format:
); );
let user_prompt = format!( let user_prompt = format!(
r#"Conversation: r"Conversation:
{} {}
Generate 3 reply options for the attendant."#, Generate 3 reply options for the attendant.",
history_context history_context
); );
@ -725,8 +640,6 @@ Generate 3 reply options for the attendant."#,
} }
} }
pub async fn generate_summary( pub async fn generate_summary(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
Path(session_id): Path<Uuid>, Path(session_id): Path<Uuid>,
@ -762,7 +675,6 @@ pub async fn generate_summary(
); );
} }
let history = load_conversation_history(&state, session_id).await; let history = load_conversation_history(&state, session_id).await;
if history.is_empty() { if history.is_empty() {
@ -806,9 +718,9 @@ Respond in JSON format:
); );
let user_prompt = format!( let user_prompt = format!(
r#"Summarize this conversation: r"Summarize this conversation:
{}"#, {}",
history_text history_text
); );
@ -817,7 +729,6 @@ Respond in JSON format:
let mut summary = parse_summary_response(&response); let mut summary = parse_summary_response(&response);
summary.message_count = history.len() as i32; summary.message_count = history.len() as i32;
if let (Some(first_ts), Some(last_ts)) = ( if let (Some(first_ts), Some(last_ts)) = (
history.first().and_then(|m| m.timestamp.as_ref()), history.first().and_then(|m| m.timestamp.as_ref()),
history.last().and_then(|m| m.timestamp.as_ref()), history.last().and_then(|m| m.timestamp.as_ref()),
@ -857,8 +768,6 @@ Respond in JSON format:
} }
} }
pub async fn analyze_sentiment( pub async fn analyze_sentiment(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
Json(request): Json<SentimentRequest>, Json(request): Json<SentimentRequest>,
@ -884,7 +793,6 @@ pub async fn analyze_sentiment(
let config = LlmAssistConfig::from_config(session.bot_id, &work_path); let config = LlmAssistConfig::from_config(session.bot_id, &work_path);
if !config.sentiment_enabled { if !config.sentiment_enabled {
let sentiment = analyze_sentiment_keywords(&request.message); let sentiment = analyze_sentiment_keywords(&request.message);
return ( return (
StatusCode::OK, StatusCode::OK,
@ -959,8 +867,6 @@ Analyze the customer's sentiment."#,
} }
} }
pub async fn get_llm_config( pub async fn get_llm_config(
State(_state): State<Arc<AppState>>, State(_state): State<Arc<AppState>>,
Path(bot_id): Path<Uuid>, Path(bot_id): Path<Uuid>,
@ -981,9 +887,6 @@ pub async fn get_llm_config(
) )
} }
pub async fn process_attendant_command( pub async fn process_attendant_command(
state: &Arc<AppState>, state: &Arc<AppState>,
attendant_phone: &str, attendant_phone: &str,
@ -1020,7 +923,6 @@ pub async fn process_attendant_command(
} }
async fn handle_queue_command(state: &Arc<AppState>) -> Result<String, String> { async fn handle_queue_command(state: &Arc<AppState>) -> Result<String, String> {
let conn = state.conn.clone(); let conn = state.conn.clone();
let result = tokio::task::spawn_blocking(move || { let result = tokio::task::spawn_blocking(move || {
let mut db_conn = conn.get().map_err(|e| e.to_string())?; let mut db_conn = conn.get().map_err(|e| e.to_string())?;
@ -1174,8 +1076,6 @@ async fn handle_status_command(
} }
}; };
let conn = state.conn.clone(); let conn = state.conn.clone();
let phone = attendant_phone.to_string(); let phone = attendant_phone.to_string();
let status_val = status_value.to_string(); let status_val = status_value.to_string();
@ -1185,8 +1085,6 @@ async fn handle_status_command(
use crate::shared::models::schema::user_sessions; use crate::shared::models::schema::user_sessions;
let sessions: Vec<UserSession> = user_sessions::table let sessions: Vec<UserSession> = user_sessions::table
.filter( .filter(
user_sessions::context_data user_sessions::context_data
@ -1243,7 +1141,6 @@ async fn handle_transfer_command(
let target = args.join(" "); let target = args.join(" ");
let target_clean = target.trim_start_matches('@').to_string(); let target_clean = target.trim_start_matches('@').to_string();
let conn = state.conn.clone(); let conn = state.conn.clone();
let target_attendant = target_clean.clone(); let target_attendant = target_clean.clone();
@ -1252,13 +1149,11 @@ async fn handle_transfer_command(
use crate::shared::models::schema::user_sessions; use crate::shared::models::schema::user_sessions;
let session: UserSession = user_sessions::table let session: UserSession = user_sessions::table
.find(session_id) .find(session_id)
.first(&mut db_conn) .first(&mut db_conn)
.map_err(|e| format!("Session not found: {}", e))?; .map_err(|e| format!("Session not found: {}", e))?;
let mut ctx = session.context_data.clone(); let mut ctx = session.context_data.clone();
let previous_attendant = ctx let previous_attendant = ctx
.get("assigned_to_phone") .get("assigned_to_phone")
@ -1271,11 +1166,9 @@ async fn handle_transfer_command(
ctx["transferred_at"] = serde_json::json!(Utc::now().to_rfc3339()); ctx["transferred_at"] = serde_json::json!(Utc::now().to_rfc3339());
ctx["status"] = serde_json::json!("pending_transfer"); ctx["status"] = serde_json::json!("pending_transfer");
ctx["assigned_to_phone"] = serde_json::Value::Null; ctx["assigned_to_phone"] = serde_json::Value::Null;
ctx["assigned_to"] = serde_json::Value::Null; ctx["assigned_to"] = serde_json::Value::Null;
ctx["needs_human"] = serde_json::json!(true); ctx["needs_human"] = serde_json::json!(true);
diesel::update(user_sessions::table.filter(user_sessions::id.eq(session_id))) diesel::update(user_sessions::table.filter(user_sessions::id.eq(session_id)))
@ -1347,7 +1240,6 @@ async fn handle_tips_command(
) -> Result<String, String> { ) -> Result<String, String> {
let session_id = current_session.ok_or("No active conversation. Use /take first.")?; let session_id = current_session.ok_or("No active conversation. Use /take first.")?;
let history = load_conversation_history(state, session_id).await; let history = load_conversation_history(state, session_id).await;
if history.is_empty() { if history.is_empty() {
@ -1369,7 +1261,6 @@ async fn handle_tips_command(
history, history,
}; };
let (_, Json(tip_response)) = generate_tips(State(state.clone()), Json(request)).await; let (_, Json(tip_response)) = generate_tips(State(state.clone()), Json(request)).await;
if tip_response.tips.is_empty() { if tip_response.tips.is_empty() {
@ -1446,7 +1337,8 @@ async fn handle_replies_command(
history, history,
}; };
let (_, Json(replies_response)) = generate_smart_replies(State(state.clone()), Json(request)).await; let (_, Json(replies_response)) =
generate_smart_replies(State(state.clone()), Json(request)).await;
if replies_response.replies.is_empty() { if replies_response.replies.is_empty() {
return Ok(" No reply suggestions available.".to_string()); return Ok(" No reply suggestions available.".to_string());
@ -1474,7 +1366,8 @@ async fn handle_summary_command(
) -> Result<String, String> { ) -> Result<String, String> {
let session_id = current_session.ok_or("No active conversation")?; let session_id = current_session.ok_or("No active conversation")?;
let (_, Json(summary_response)) = generate_summary(State(state.clone()), Path(session_id)).await; let (_, Json(summary_response)) =
generate_summary(State(state.clone()), Path(session_id)).await;
if !summary_response.success { if !summary_response.success {
return Err(summary_response return Err(summary_response
@ -1527,7 +1420,7 @@ async fn handle_summary_command(
} }
fn get_help_text() -> String { fn get_help_text() -> String {
r#" *Attendant Commands* r#"*Attendant Commands*
*Queue Management:* *Queue Management:*
`/queue` - View waiting conversations `/queue` - View waiting conversations
@ -1549,9 +1442,6 @@ _Portuguese: /fila, /pegar, /transferir, /resolver, /dicas, /polir, /respostas,
.to_string() .to_string()
} }
async fn get_session(state: &Arc<AppState>, session_id: Uuid) -> Result<UserSession, String> { async fn get_session(state: &Arc<AppState>, session_id: Uuid) -> Result<UserSession, String> {
let conn = state.conn.clone(); let conn = state.conn.clone();
@ -1569,7 +1459,6 @@ async fn get_session(state: &Arc<AppState>, session_id: Uuid) -> Result<UserSess
.map_err(|e| format!("Task error: {}", e))? .map_err(|e| format!("Task error: {}", e))?
} }
async fn load_conversation_history( async fn load_conversation_history(
state: &Arc<AppState>, state: &Arc<AppState>,
session_id: Uuid, session_id: Uuid,
@ -1616,9 +1505,7 @@ async fn load_conversation_history(
result result
} }
fn parse_tips_response(response: &str) -> Vec<AttendantTip> { fn parse_tips_response(response: &str) -> Vec<AttendantTip> {
let json_str = extract_json(response); let json_str = extract_json(response);
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&json_str) { if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&json_str) {
@ -1653,7 +1540,6 @@ fn parse_tips_response(response: &str) -> Vec<AttendantTip> {
} }
} }
if !response.trim().is_empty() { if !response.trim().is_empty() {
vec![AttendantTip { vec![AttendantTip {
tip_type: TipType::General, tip_type: TipType::General,
@ -1666,7 +1552,6 @@ fn parse_tips_response(response: &str) -> Vec<AttendantTip> {
} }
} }
fn parse_polish_response(response: &str, original: &str) -> (String, Vec<String>) { fn parse_polish_response(response: &str, original: &str) -> (String, Vec<String>) {
let json_str = extract_json(response); let json_str = extract_json(response);
@ -1690,14 +1575,12 @@ fn parse_polish_response(response: &str, original: &str) -> (String, Vec<String>
return (polished, changes); return (polished, changes);
} }
( (
response.trim().to_string(), response.trim().to_string(),
vec!["Message improved".to_string()], vec!["Message improved".to_string()],
) )
} }
fn parse_smart_replies_response(response: &str) -> Vec<SmartReply> { fn parse_smart_replies_response(response: &str) -> Vec<SmartReply> {
let json_str = extract_json(response); let json_str = extract_json(response);
@ -1731,7 +1614,6 @@ fn parse_smart_replies_response(response: &str) -> Vec<SmartReply> {
generate_fallback_replies() generate_fallback_replies()
} }
fn parse_summary_response(response: &str) -> ConversationSummary { fn parse_summary_response(response: &str) -> ConversationSummary {
let json_str = extract_json(response); let json_str = extract_json(response);
@ -1789,7 +1671,6 @@ fn parse_summary_response(response: &str) -> ConversationSummary {
} }
} }
fn parse_sentiment_response(response: &str) -> SentimentAnalysis { fn parse_sentiment_response(response: &str) -> SentimentAnalysis {
let json_str = extract_json(response); let json_str = extract_json(response);
@ -1839,9 +1720,7 @@ fn parse_sentiment_response(response: &str) -> SentimentAnalysis {
SentimentAnalysis::default() SentimentAnalysis::default()
} }
fn extract_json(response: &str) -> String { fn extract_json(response: &str) -> String {
if let Some(start) = response.find('{') { if let Some(start) = response.find('{') {
if let Some(end) = response.rfind('}') { if let Some(end) = response.rfind('}') {
if end > start { if end > start {
@ -1850,7 +1729,6 @@ fn extract_json(response: &str) -> String {
} }
} }
if let Some(start) = response.find('[') { if let Some(start) = response.find('[') {
if let Some(end) = response.rfind(']') { if let Some(end) = response.rfind(']') {
if end > start { if end > start {
@ -1862,12 +1740,10 @@ fn extract_json(response: &str) -> String {
response.to_string() response.to_string()
} }
fn generate_fallback_tips(message: &str) -> Vec<AttendantTip> { fn generate_fallback_tips(message: &str) -> Vec<AttendantTip> {
let msg_lower = message.to_lowercase(); let msg_lower = message.to_lowercase();
let mut tips = Vec::new(); let mut tips = Vec::new();
if msg_lower.contains("urgent") if msg_lower.contains("urgent")
|| msg_lower.contains("asap") || msg_lower.contains("asap")
|| msg_lower.contains("immediately") || msg_lower.contains("immediately")
@ -1881,7 +1757,6 @@ fn generate_fallback_tips(message: &str) -> Vec<AttendantTip> {
}); });
} }
if msg_lower.contains("frustrated") if msg_lower.contains("frustrated")
|| msg_lower.contains("angry") || msg_lower.contains("angry")
|| msg_lower.contains("ridiculous") || msg_lower.contains("ridiculous")
@ -1895,7 +1770,6 @@ fn generate_fallback_tips(message: &str) -> Vec<AttendantTip> {
}); });
} }
if message.contains('?') { if message.contains('?') {
tips.push(AttendantTip { tips.push(AttendantTip {
tip_type: TipType::Intent, tip_type: TipType::Intent,
@ -1905,7 +1779,6 @@ fn generate_fallback_tips(message: &str) -> Vec<AttendantTip> {
}); });
} }
if msg_lower.contains("problem") if msg_lower.contains("problem")
|| msg_lower.contains("issue") || msg_lower.contains("issue")
|| msg_lower.contains("not working") || msg_lower.contains("not working")
@ -1919,7 +1792,6 @@ fn generate_fallback_tips(message: &str) -> Vec<AttendantTip> {
}); });
} }
if msg_lower.contains("thank") if msg_lower.contains("thank")
|| msg_lower.contains("great") || msg_lower.contains("great")
|| msg_lower.contains("perfect") || msg_lower.contains("perfect")
@ -1934,7 +1806,6 @@ fn generate_fallback_tips(message: &str) -> Vec<AttendantTip> {
}); });
} }
if tips.is_empty() { if tips.is_empty() {
tips.push(AttendantTip { tips.push(AttendantTip {
tip_type: TipType::General, tip_type: TipType::General,
@ -1947,7 +1818,6 @@ fn generate_fallback_tips(message: &str) -> Vec<AttendantTip> {
tips tips
} }
fn generate_fallback_replies() -> Vec<SmartReply> { fn generate_fallback_replies() -> Vec<SmartReply> {
vec![ vec![
SmartReply { SmartReply {
@ -1971,7 +1841,6 @@ fn generate_fallback_replies() -> Vec<SmartReply> {
] ]
} }
fn analyze_sentiment_keywords(message: &str) -> SentimentAnalysis { fn analyze_sentiment_keywords(message: &str) -> SentimentAnalysis {
let msg_lower = message.to_lowercase(); let msg_lower = message.to_lowercase();
@ -2097,5 +1966,3 @@ fn analyze_sentiment_keywords(message: &str) -> SentimentAnalysis {
emoji: emoji.to_string(), emoji: emoji.to_string(),
} }
} }

View file

@ -1,40 +1,8 @@
pub mod drive; pub mod drive;
pub mod keyword_services; pub mod keyword_services;
pub mod llm_assist; pub mod llm_assist;
pub mod queue; pub mod queue;
pub use drive::{AttendanceDriveConfig, AttendanceDriveService, RecordMetadata, SyncResult}; pub use drive::{AttendanceDriveConfig, AttendanceDriveService, RecordMetadata, SyncResult};
pub use keyword_services::{ pub use keyword_services::{
AttendanceCommand, AttendanceRecord, AttendanceResponse, AttendanceService, KeywordConfig, AttendanceCommand, AttendanceRecord, AttendanceResponse, AttendanceService, KeywordConfig,
@ -58,14 +26,13 @@ use crate::shared::state::{AppState, AttendantNotification};
use axum::{ use axum::{
extract::{ extract::{
ws::{Message, WebSocket, WebSocketUpgrade}, ws::{Message, WebSocket, WebSocketUpgrade},
Path, Query, State, Query, State,
}, },
http::StatusCode, http::StatusCode,
response::IntoResponse, response::IntoResponse,
routing::{get, post}, routing::{get, post},
Json, Router, Json, Router,
}; };
use botlib::MessageType;
use chrono::Utc; use chrono::Utc;
use diesel::prelude::*; use diesel::prelude::*;
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
@ -76,10 +43,8 @@ use std::sync::Arc;
use tokio::sync::broadcast; use tokio::sync::broadcast;
use uuid::Uuid; use uuid::Uuid;
pub fn configure_attendance_routes() -> Router<Arc<AppState>> { pub fn configure_attendance_routes() -> Router<Arc<AppState>> {
Router::new() Router::new()
.route("/api/attendance/queue", get(queue::list_queue)) .route("/api/attendance/queue", get(queue::list_queue))
.route("/api/attendance/attendants", get(queue::list_attendants)) .route("/api/attendance/attendants", get(queue::list_attendants))
.route("/api/attendance/assign", post(queue::assign_conversation)) .route("/api/attendance/assign", post(queue::assign_conversation))
@ -92,11 +57,8 @@ pub fn configure_attendance_routes() -> Router<Arc<AppState>> {
post(queue::resolve_conversation), post(queue::resolve_conversation),
) )
.route("/api/attendance/insights", get(queue::get_insights)) .route("/api/attendance/insights", get(queue::get_insights))
.route("/api/attendance/respond", post(attendant_respond)) .route("/api/attendance/respond", post(attendant_respond))
.route("/ws/attendant", get(attendant_websocket_handler)) .route("/ws/attendant", get(attendant_websocket_handler))
.route("/api/attendance/llm/tips", post(llm_assist::generate_tips)) .route("/api/attendance/llm/tips", post(llm_assist::generate_tips))
.route( .route(
"/api/attendance/llm/polish", "/api/attendance/llm/polish",
@ -120,7 +82,6 @@ pub fn configure_attendance_routes() -> Router<Arc<AppState>> {
) )
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct AttendantRespondRequest { pub struct AttendantRespondRequest {
pub session_id: String, pub session_id: String,
@ -128,7 +89,6 @@ pub struct AttendantRespondRequest {
pub attendant_id: String, pub attendant_id: String,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
pub struct AttendantRespondResponse { pub struct AttendantRespondResponse {
pub success: bool, pub success: bool,
@ -137,7 +97,6 @@ pub struct AttendantRespondResponse {
pub error: Option<String>, pub error: Option<String>,
} }
pub async fn attendant_respond( pub async fn attendant_respond(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
Json(request): Json<AttendantRespondRequest>, Json(request): Json<AttendantRespondRequest>,
@ -161,7 +120,6 @@ pub async fn attendant_respond(
} }
}; };
let conn = state.conn.clone(); let conn = state.conn.clone();
let session_result = tokio::task::spawn_blocking(move || { let session_result = tokio::task::spawn_blocking(move || {
let mut db_conn = conn.get().ok()?; let mut db_conn = conn.get().ok()?;
@ -189,26 +147,22 @@ pub async fn attendant_respond(
} }
}; };
let channel = session let channel = session
.context_data .context_data
.get("channel") .get("channel")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.unwrap_or("web"); .unwrap_or("web");
let recipient = session let recipient = session
.context_data .context_data
.get("phone") .get("phone")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.unwrap_or(""); .unwrap_or("");
if let Err(e) = save_message_to_history(&state, &session, &request.message, "attendant").await { if let Err(e) = save_message_to_history(&state, &session, &request.message, "attendant").await {
error!("Failed to save attendant message: {}", e); error!("Failed to save attendant message: {}", e);
} }
match channel { match channel {
"whatsapp" => { "whatsapp" => {
if recipient.is_empty() { if recipient.is_empty() {
@ -240,7 +194,6 @@ pub async fn attendant_respond(
match adapter.send_message(response).await { match adapter.send_message(response).await {
Ok(_) => { Ok(_) => {
broadcast_attendant_action(&state, &session, &request, "attendant_response") broadcast_attendant_action(&state, &session, &request, "attendant_response")
.await; .await;
@ -264,7 +217,6 @@ pub async fn attendant_respond(
} }
} }
"web" | _ => { "web" | _ => {
let sent = if let Some(tx) = state let sent = if let Some(tx) = state
.response_channels .response_channels
.lock() .lock()
@ -282,15 +234,14 @@ pub async fn attendant_respond(
is_complete: true, is_complete: true,
suggestions: vec![], suggestions: vec![],
context_name: None, context_name: None,
context_length: 0, context_length: 0,
context_max_length: 0, context_max_length: 0,
}; };
tx.send(response).await.is_ok() tx.send(response).await.is_ok()
} else { } else {
false false
}; };
broadcast_attendant_action(&state, &session, &request, "attendant_response").await; broadcast_attendant_action(&state, &session, &request, "attendant_response").await;
if sent { if sent {
@ -303,7 +254,6 @@ pub async fn attendant_respond(
}), }),
) )
} else { } else {
( (
StatusCode::OK, StatusCode::OK,
Json(AttendantRespondResponse { Json(AttendantRespondResponse {
@ -317,7 +267,6 @@ pub async fn attendant_respond(
} }
} }
async fn save_message_to_history( async fn save_message_to_history(
state: &Arc<AppState>, state: &Arc<AppState>,
session: &UserSession, session: &UserSession,
@ -356,7 +305,6 @@ async fn save_message_to_history(
Ok(()) Ok(())
} }
async fn broadcast_attendant_action( async fn broadcast_attendant_action(
state: &Arc<AppState>, state: &Arc<AppState>,
session: &UserSession, session: &UserSession,
@ -396,7 +344,6 @@ async fn broadcast_attendant_action(
} }
} }
pub async fn attendant_websocket_handler( pub async fn attendant_websocket_handler(
ws: WebSocketUpgrade, ws: WebSocketUpgrade,
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
@ -422,13 +369,11 @@ pub async fn attendant_websocket_handler(
.into_response() .into_response()
} }
async fn handle_attendant_websocket(socket: WebSocket, state: Arc<AppState>, attendant_id: String) { async fn handle_attendant_websocket(socket: WebSocket, state: Arc<AppState>, attendant_id: String) {
let (mut sender, mut receiver) = socket.split(); let (mut sender, mut receiver) = socket.split();
info!("Attendant WebSocket connected: {}", attendant_id); info!("Attendant WebSocket connected: {}", attendant_id);
let welcome = serde_json::json!({ let welcome = serde_json::json!({
"type": "connected", "type": "connected",
"attendant_id": attendant_id, "attendant_id": attendant_id,
@ -447,7 +392,6 @@ async fn handle_attendant_websocket(socket: WebSocket, state: Arc<AppState>, att
} }
} }
let mut broadcast_rx = if let Some(broadcast_tx) = state.attendant_broadcast.as_ref() { let mut broadcast_rx = if let Some(broadcast_tx) = state.attendant_broadcast.as_ref() {
broadcast_tx.subscribe() broadcast_tx.subscribe()
} else { } else {
@ -455,14 +399,11 @@ async fn handle_attendant_websocket(socket: WebSocket, state: Arc<AppState>, att
return; return;
}; };
let attendant_id_clone = attendant_id.clone(); let attendant_id_clone = attendant_id.clone();
let mut send_task = tokio::spawn(async move { let mut send_task = tokio::spawn(async move {
loop { loop {
match broadcast_rx.recv().await { match broadcast_rx.recv().await {
Ok(notification) => { Ok(notification) => {
let should_send = notification.assigned_to.is_none() let should_send = notification.assigned_to.is_none()
|| notification.assigned_to.as_ref() == Some(&attendant_id_clone); || notification.assigned_to.as_ref() == Some(&attendant_id_clone);
@ -493,7 +434,6 @@ async fn handle_attendant_websocket(socket: WebSocket, state: Arc<AppState>, att
} }
}); });
let state_clone = state.clone(); let state_clone = state.clone();
let attendant_id_for_recv = attendant_id.clone(); let attendant_id_for_recv = attendant_id.clone();
let mut recv_task = tokio::spawn(async move { let mut recv_task = tokio::spawn(async move {
@ -505,7 +445,6 @@ async fn handle_attendant_websocket(socket: WebSocket, state: Arc<AppState>, att
attendant_id_for_recv, text attendant_id_for_recv, text
); );
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) { if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) {
handle_attendant_message(&state_clone, &attendant_id_for_recv, parsed) handle_attendant_message(&state_clone, &attendant_id_for_recv, parsed)
.await; .await;
@ -513,7 +452,6 @@ async fn handle_attendant_websocket(socket: WebSocket, state: Arc<AppState>, att
} }
Message::Ping(data) => { Message::Ping(data) => {
debug!("Received ping from attendant {}", attendant_id_for_recv); debug!("Received ping from attendant {}", attendant_id_for_recv);
} }
Message::Close(_) => { Message::Close(_) => {
info!( info!(
@ -527,7 +465,6 @@ async fn handle_attendant_websocket(socket: WebSocket, state: Arc<AppState>, att
} }
}); });
tokio::select! { tokio::select! {
_ = (&mut send_task) => { _ = (&mut send_task) => {
recv_task.abort(); recv_task.abort();
@ -540,7 +477,6 @@ async fn handle_attendant_websocket(socket: WebSocket, state: Arc<AppState>, att
info!("Attendant WebSocket disconnected: {}", attendant_id); info!("Attendant WebSocket disconnected: {}", attendant_id);
} }
async fn handle_attendant_message( async fn handle_attendant_message(
state: &Arc<AppState>, state: &Arc<AppState>,
attendant_id: &str, attendant_id: &str,
@ -553,24 +489,19 @@ async fn handle_attendant_message(
match msg_type { match msg_type {
"status_update" => { "status_update" => {
if let Some(status) = message.get("status").and_then(|v| v.as_str()) { if let Some(status) = message.get("status").and_then(|v| v.as_str()) {
info!("Attendant {} status update: {}", attendant_id, status); info!("Attendant {} status update: {}", attendant_id, status);
} }
} }
"typing" => { "typing" => {
if let Some(session_id) = message.get("session_id").and_then(|v| v.as_str()) { if let Some(session_id) = message.get("session_id").and_then(|v| v.as_str()) {
debug!( debug!(
"Attendant {} typing in session {}", "Attendant {} typing in session {}",
attendant_id, session_id attendant_id, session_id
); );
} }
} }
"read" => { "read" => {
if let Some(session_id) = message.get("session_id").and_then(|v| v.as_str()) { if let Some(session_id) = message.get("session_id").and_then(|v| v.as_str()) {
debug!( debug!(
"Attendant {} marked session {} as read", "Attendant {} marked session {} as read",
@ -579,7 +510,6 @@ async fn handle_attendant_message(
} }
} }
"respond" => { "respond" => {
if let (Some(session_id), Some(content)) = ( if let (Some(session_id), Some(content)) = (
message.get("session_id").and_then(|v| v.as_str()), message.get("session_id").and_then(|v| v.as_str()),
message.get("content").and_then(|v| v.as_str()), message.get("content").and_then(|v| v.as_str()),
@ -589,14 +519,12 @@ async fn handle_attendant_message(
attendant_id, session_id attendant_id, session_id
); );
let request = AttendantRespondRequest { let request = AttendantRespondRequest {
session_id: session_id.to_string(), session_id: session_id.to_string(),
message: content.to_string(), message: content.to_string(),
attendant_id: attendant_id.to_string(), attendant_id: attendant_id.to_string(),
}; };
if let Ok(uuid) = Uuid::parse_str(session_id) { if let Ok(uuid) = Uuid::parse_str(session_id) {
let conn = state.conn.clone(); let conn = state.conn.clone();
if let Some(session) = tokio::task::spawn_blocking(move || { if let Some(session) = tokio::task::spawn_blocking(move || {
@ -611,11 +539,9 @@ async fn handle_attendant_message(
.ok() .ok()
.flatten() .flatten()
{ {
let _ = let _ =
save_message_to_history(state, &session, content, "attendant").await; save_message_to_history(state, &session, content, "attendant").await;
let channel = session let channel = session
.context_data .context_data
.get("channel") .get("channel")
@ -639,14 +565,13 @@ async fn handle_attendant_message(
is_complete: true, is_complete: true,
suggestions: vec![], suggestions: vec![],
context_name: None, context_name: None,
context_length: 0, context_length: 0,
context_max_length: 0, context_max_length: 0,
}; };
let _ = adapter.send_message(response).await; let _ = adapter.send_message(response).await;
} }
} }
broadcast_attendant_action(state, &session, &request, "attendant_response") broadcast_attendant_action(state, &session, &request, "attendant_response")
.await; .await;
} }

View file

@ -1,22 +1,3 @@
use crate::shared::models::UserSession; use crate::shared::models::UserSession;
use crate::shared::state::AppState; use crate::shared::state::AppState;
use diesel::prelude::*; use diesel::prelude::*;
@ -27,10 +8,8 @@ use serde::{Deserialize, Serialize};
use std::sync::Arc; use std::sync::Arc;
use uuid::Uuid; use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ReflectionType { pub enum ReflectionType {
ConversationQuality, ConversationQuality,
ResponseAccuracy, ResponseAccuracy,
@ -66,7 +45,6 @@ impl From<&str> for ReflectionType {
} }
impl ReflectionType { impl ReflectionType {
pub fn prompt_template(&self) -> String { pub fn prompt_template(&self) -> String {
match self { match self {
ReflectionType::ConversationQuality => { ReflectionType::ConversationQuality => {
@ -190,10 +168,8 @@ Provide your analysis in JSON format:
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReflectionConfig { pub struct ReflectionConfig {
pub enabled: bool, pub enabled: bool,
pub interval: u32, pub interval: u32,
@ -224,7 +200,6 @@ impl Default for ReflectionConfig {
} }
impl ReflectionConfig { impl ReflectionConfig {
pub fn from_bot_config(state: &AppState, bot_id: Uuid) -> Self { pub fn from_bot_config(state: &AppState, bot_id: Uuid) -> Self {
let mut config = Self::default(); let mut config = Self::default();
@ -278,10 +253,8 @@ impl ReflectionConfig {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReflectionResult { pub struct ReflectionResult {
pub id: Uuid, pub id: Uuid,
pub bot_id: Uuid, pub bot_id: Uuid,
@ -325,7 +298,6 @@ impl ReflectionResult {
} }
} }
pub fn from_llm_response( pub fn from_llm_response(
bot_id: Uuid, bot_id: Uuid,
session_id: Uuid, session_id: Uuid,
@ -337,16 +309,13 @@ impl ReflectionResult {
result.raw_response = response.to_string(); result.raw_response = response.to_string();
result.messages_analyzed = messages_analyzed; result.messages_analyzed = messages_analyzed;
if let Ok(json) = serde_json::from_str::<serde_json::Value>(response) { if let Ok(json) = serde_json::from_str::<serde_json::Value>(response) {
result.score = json result.score = json
.get("score") .get("score")
.or_else(|| json.get("overall_score")) .or_else(|| json.get("overall_score"))
.and_then(|v| v.as_f64()) .and_then(|v| v.as_f64())
.unwrap_or(5.0) as f32; .unwrap_or(5.0) as f32;
if let Some(insights) = json.get("key_insights").or_else(|| json.get("insights")) { if let Some(insights) = json.get("key_insights").or_else(|| json.get("insights")) {
if let Some(arr) = insights.as_array() { if let Some(arr) = insights.as_array() {
result.insights = arr result.insights = arr
@ -356,7 +325,6 @@ impl ReflectionResult {
} }
} }
if let Some(improvements) = json if let Some(improvements) = json
.get("improvements") .get("improvements")
.or_else(|| json.get("critical_improvements")) .or_else(|| json.get("critical_improvements"))
@ -370,7 +338,6 @@ impl ReflectionResult {
} }
} }
if let Some(patterns) = json if let Some(patterns) = json
.get("positive_patterns") .get("positive_patterns")
.or_else(|| json.get("strengths")) .or_else(|| json.get("strengths"))
@ -384,7 +351,6 @@ impl ReflectionResult {
} }
} }
if let Some(concerns) = json if let Some(concerns) = json
.get("weaknesses") .get("weaknesses")
.or_else(|| json.get("concerns")) .or_else(|| json.get("concerns"))
@ -399,7 +365,6 @@ impl ReflectionResult {
} }
} }
} else { } else {
warn!("Reflection response was not valid JSON, extracting from plain text"); warn!("Reflection response was not valid JSON, extracting from plain text");
result.insights = extract_insights_from_text(response); result.insights = extract_insights_from_text(response);
result.score = 5.0; result.score = 5.0;
@ -408,12 +373,10 @@ impl ReflectionResult {
result result
} }
pub fn needs_improvement(&self, threshold: f32) -> bool { pub fn needs_improvement(&self, threshold: f32) -> bool {
self.score < threshold self.score < threshold
} }
pub fn summary(&self) -> String { pub fn summary(&self) -> String {
format!( format!(
"Reflection Score: {:.1}/10\n\ "Reflection Score: {:.1}/10\n\
@ -430,11 +393,9 @@ impl ReflectionResult {
} }
} }
pub fn extract_insights_from_text(text: &str) -> Vec<String> {
fn extract_insights_from_text(text: &str) -> Vec<String> {
let mut insights = Vec::new(); let mut insights = Vec::new();
for line in text.lines() { for line in text.lines() {
let trimmed = line.trim(); let trimmed = line.trim();
if trimmed.starts_with(|c: char| c.is_ascii_digit()) if trimmed.starts_with(|c: char| c.is_ascii_digit())
@ -453,7 +414,6 @@ fn extract_insights_from_text(text: &str) -> Vec<String> {
} }
} }
if insights.is_empty() { if insights.is_empty() {
for sentence in text.split(|c| c == '.' || c == '!' || c == '?') { for sentence in text.split(|c| c == '.' || c == '!' || c == '?') {
let trimmed = sentence.trim(); let trimmed = sentence.trim();
@ -467,7 +427,6 @@ fn extract_insights_from_text(text: &str) -> Vec<String> {
insights insights
} }
pub struct ReflectionEngine { pub struct ReflectionEngine {
state: Arc<AppState>, state: Arc<AppState>,
config: ReflectionConfig, config: ReflectionConfig,
@ -501,7 +460,6 @@ impl ReflectionEngine {
} }
} }
pub async fn reflect( pub async fn reflect(
&self, &self,
session_id: Uuid, session_id: Uuid,
@ -511,7 +469,6 @@ impl ReflectionEngine {
return Err("Reflection is not enabled for this bot".to_string()); return Err("Reflection is not enabled for this bot".to_string());
} }
let history = self.get_recent_history(session_id, 20).await?; let history = self.get_recent_history(session_id, 20).await?;
if history.is_empty() { if history.is_empty() {
@ -520,13 +477,10 @@ impl ReflectionEngine {
let messages_count = history.len(); let messages_count = history.len();
let prompt = self.build_reflection_prompt(&reflection_type, &history)?; let prompt = self.build_reflection_prompt(&reflection_type, &history)?;
let response = self.call_llm_for_reflection(&prompt).await?; let response = self.call_llm_for_reflection(&prompt).await?;
let result = ReflectionResult::from_llm_response( let result = ReflectionResult::from_llm_response(
self.bot_id, self.bot_id,
session_id, session_id,
@ -535,10 +489,8 @@ impl ReflectionEngine {
messages_count, messages_count,
); );
self.store_reflection(&result).await?; self.store_reflection(&result).await?;
if self.config.auto_apply && result.needs_improvement(self.config.improvement_threshold) { if self.config.auto_apply && result.needs_improvement(self.config.improvement_threshold) {
self.apply_improvements(&result).await?; self.apply_improvements(&result).await?;
} }
@ -551,7 +503,6 @@ impl ReflectionEngine {
Ok(result) Ok(result)
} }
async fn get_recent_history( async fn get_recent_history(
&self, &self,
session_id: Uuid, session_id: Uuid,
@ -595,7 +546,6 @@ impl ReflectionEngine {
Ok(history) Ok(history)
} }
fn build_reflection_prompt( fn build_reflection_prompt(
&self, &self,
reflection_type: &ReflectionType, reflection_type: &ReflectionType,
@ -607,7 +557,6 @@ impl ReflectionEngine {
reflection_type.prompt_template() reflection_type.prompt_template()
}; };
let conversation = history let conversation = history
.iter() .iter()
.map(|m| { .map(|m| {
@ -629,9 +578,7 @@ impl ReflectionEngine {
Ok(prompt) Ok(prompt)
} }
async fn call_llm_for_reflection(&self, prompt: &str) -> Result<String, String> { async fn call_llm_for_reflection(&self, prompt: &str) -> Result<String, String> {
let (llm_url, llm_model, llm_key) = self.get_llm_config().await?; let (llm_url, llm_model, llm_key) = self.get_llm_config().await?;
let client = reqwest::Client::new(); let client = reqwest::Client::new();
@ -680,7 +627,6 @@ impl ReflectionEngine {
Ok(content) Ok(content)
} }
async fn get_llm_config(&self) -> Result<(String, String, String), String> { async fn get_llm_config(&self) -> Result<(String, String, String), String> {
let mut conn = self let mut conn = self
.state .state
@ -720,7 +666,6 @@ impl ReflectionEngine {
Ok((llm_url, llm_model, llm_key)) Ok((llm_url, llm_model, llm_key))
} }
async fn store_reflection(&self, result: &ReflectionResult) -> Result<(), String> { async fn store_reflection(&self, result: &ReflectionResult) -> Result<(), String> {
let mut conn = self let mut conn = self
.state .state
@ -759,9 +704,7 @@ impl ReflectionEngine {
Ok(()) Ok(())
} }
async fn apply_improvements(&self, result: &ReflectionResult) -> Result<(), String> { async fn apply_improvements(&self, result: &ReflectionResult) -> Result<(), String> {
let mut conn = self let mut conn = self
.state .state
.conn .conn
@ -796,7 +739,6 @@ impl ReflectionEngine {
Ok(()) Ok(())
} }
pub async fn get_insights(&self, limit: usize) -> Result<Vec<ReflectionResult>, String> { pub async fn get_insights(&self, limit: usize) -> Result<Vec<ReflectionResult>, String> {
let mut conn = self let mut conn = self
.state .state
@ -867,13 +809,11 @@ impl ReflectionEngine {
Ok(results) Ok(results)
} }
pub async fn should_reflect(&self, session_id: Uuid) -> bool { pub async fn should_reflect(&self, session_id: Uuid) -> bool {
if !self.config.enabled { if !self.config.enabled {
return false; return false;
} }
if let Ok(mut conn) = self.state.conn.get() { if let Ok(mut conn) = self.state.conn.get() {
#[derive(QueryableByName)] #[derive(QueryableByName)]
struct CountRow { struct CountRow {
@ -897,7 +837,6 @@ impl ReflectionEngine {
} }
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
struct ConversationMessage { struct ConversationMessage {
pub role: String, pub role: String,
@ -905,14 +844,12 @@ struct ConversationMessage {
pub timestamp: chrono::DateTime<chrono::Utc>, pub timestamp: chrono::DateTime<chrono::Utc>,
} }
pub fn register_reflection_keywords(state: Arc<AppState>, user: UserSession, engine: &mut Engine) { pub fn register_reflection_keywords(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
set_bot_reflection_keyword(state.clone(), user.clone(), engine); set_bot_reflection_keyword(state.clone(), user.clone(), engine);
reflect_on_keyword(state.clone(), user.clone(), engine); reflect_on_keyword(state.clone(), user.clone(), engine);
get_reflection_insights_keyword(state.clone(), user.clone(), engine); get_reflection_insights_keyword(state.clone(), user.clone(), engine);
} }
pub fn set_bot_reflection_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) { pub fn set_bot_reflection_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
let state_clone = Arc::clone(&state); let state_clone = Arc::clone(&state);
let user_clone = user.clone(); let user_clone = user.clone();
@ -964,7 +901,6 @@ pub fn set_bot_reflection_keyword(state: Arc<AppState>, user: UserSession, engin
.expect("Failed to register SET BOT REFLECTION syntax"); .expect("Failed to register SET BOT REFLECTION syntax");
} }
pub fn reflect_on_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) { pub fn reflect_on_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
let state_clone = Arc::clone(&state); let state_clone = Arc::clone(&state);
let user_clone = user.clone(); let user_clone = user.clone();
@ -1016,7 +952,6 @@ pub fn reflect_on_keyword(state: Arc<AppState>, user: UserSession, engine: &mut
.expect("Failed to register REFLECT ON syntax"); .expect("Failed to register REFLECT ON syntax");
} }
pub fn get_reflection_insights_keyword( pub fn get_reflection_insights_keyword(
state: Arc<AppState>, state: Arc<AppState>,
user: UserSession, user: UserSession,
@ -1045,7 +980,6 @@ pub fn get_reflection_insights_keyword(
}); });
} }
async fn set_reflection_enabled( async fn set_reflection_enabled(
state: &AppState, state: &AppState,
bot_id: Uuid, bot_id: Uuid,
@ -1077,5 +1011,3 @@ async fn set_reflection_enabled(
if enabled { "enabled" } else { "disabled" } if enabled { "enabled" } else { "disabled" }
)) ))
} }

View file

@ -1,25 +1,3 @@
use crate::shared::models::UserSession; use crate::shared::models::UserSession;
use crate::shared::state::AppState; use crate::shared::state::AppState;
use diesel::prelude::*; use diesel::prelude::*;
@ -33,10 +11,8 @@ use std::time::Duration;
use tokio::time::timeout; use tokio::time::timeout;
use uuid::Uuid; use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum SandboxRuntime { pub enum SandboxRuntime {
LXC, LXC,
Docker, Docker,
@ -63,7 +39,6 @@ impl From<&str> for SandboxRuntime {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum CodeLanguage { pub enum CodeLanguage {
Python, Python,
@ -108,10 +83,8 @@ impl CodeLanguage {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SandboxConfig { pub struct SandboxConfig {
pub enabled: bool, pub enabled: bool,
pub runtime: SandboxRuntime, pub runtime: SandboxRuntime,
@ -151,7 +124,6 @@ impl Default for SandboxConfig {
} }
impl SandboxConfig { impl SandboxConfig {
pub fn from_bot_config(state: &AppState, bot_id: Uuid) -> Self { pub fn from_bot_config(state: &AppState, bot_id: Uuid) -> Self {
let mut config = Self::default(); let mut config = Self::default();
@ -218,10 +190,8 @@ impl SandboxConfig {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionResult { pub struct ExecutionResult {
pub stdout: String, pub stdout: String,
pub stderr: String, pub stderr: String,
@ -289,7 +259,6 @@ impl ExecutionResult {
} }
} }
pub struct CodeSandbox { pub struct CodeSandbox {
config: SandboxConfig, config: SandboxConfig,
session_id: Uuid, session_id: Uuid,
@ -309,7 +278,6 @@ impl CodeSandbox {
Self { config, session_id } Self { config, session_id }
} }
pub async fn execute(&self, code: &str, language: CodeLanguage) -> ExecutionResult { pub async fn execute(&self, code: &str, language: CodeLanguage) -> ExecutionResult {
if !self.config.enabled { if !self.config.enabled {
return ExecutionResult::error("Sandbox execution is disabled"); return ExecutionResult::error("Sandbox execution is disabled");
@ -339,20 +307,16 @@ impl CodeSandbox {
} }
} }
async fn execute_lxc( async fn execute_lxc(
&self, &self,
code: &str, code: &str,
language: &CodeLanguage, language: &CodeLanguage,
) -> Result<ExecutionResult, String> { ) -> Result<ExecutionResult, String> {
let container_name = format!("gb-sandbox-{}", Uuid::new_v4()); let container_name = format!("gb-sandbox-{}", Uuid::new_v4());
std::fs::create_dir_all(&self.config.work_dir) std::fs::create_dir_all(&self.config.work_dir)
.map_err(|e| format!("Failed to create work dir: {}", e))?; .map_err(|e| format!("Failed to create work dir: {}", e))?;
let code_file = format!( let code_file = format!(
"{}/{}.{}", "{}/{}.{}",
self.config.work_dir, self.config.work_dir,
@ -362,7 +326,6 @@ impl CodeSandbox {
std::fs::write(&code_file, code) std::fs::write(&code_file, code)
.map_err(|e| format!("Failed to write code file: {}", e))?; .map_err(|e| format!("Failed to write code file: {}", e))?;
let mut cmd = Command::new("lxc-execute"); let mut cmd = Command::new("lxc-execute");
cmd.arg("-n") cmd.arg("-n")
.arg(&container_name) .arg(&container_name)
@ -372,7 +335,6 @@ impl CodeSandbox {
.arg(language.interpreter()) .arg(language.interpreter())
.arg(&code_file); .arg(&code_file);
cmd.env( cmd.env(
"LXC_CGROUP_MEMORY_LIMIT", "LXC_CGROUP_MEMORY_LIMIT",
format!("{}M", self.config.memory_limit_mb), format!("{}M", self.config.memory_limit_mb),
@ -382,7 +344,6 @@ impl CodeSandbox {
format!("{}", self.config.cpu_limit_percent * 1000), format!("{}", self.config.cpu_limit_percent * 1000),
); );
let timeout_duration = Duration::from_secs(self.config.timeout_seconds); let timeout_duration = Duration::from_secs(self.config.timeout_seconds);
let output = timeout(timeout_duration, async { let output = timeout(timeout_duration, async {
tokio::process::Command::new("lxc-execute") tokio::process::Command::new("lxc-execute")
@ -398,7 +359,6 @@ impl CodeSandbox {
}) })
.await; .await;
let _ = std::fs::remove_file(&code_file); let _ = std::fs::remove_file(&code_file);
match output { match output {
@ -422,20 +382,17 @@ impl CodeSandbox {
} }
} }
async fn execute_docker( async fn execute_docker(
&self, &self,
code: &str, code: &str,
language: &CodeLanguage, language: &CodeLanguage,
) -> Result<ExecutionResult, String> { ) -> Result<ExecutionResult, String> {
let image = match language { let image = match language {
CodeLanguage::Python => "python:3.11-slim", CodeLanguage::Python => "python:3.11-slim",
CodeLanguage::JavaScript => "node:20-slim", CodeLanguage::JavaScript => "node:20-slim",
CodeLanguage::Bash => "alpine:latest", CodeLanguage::Bash => "alpine:latest",
}; };
let args = vec![ let args = vec![
"run".to_string(), "run".to_string(),
"--rm".to_string(), "--rm".to_string(),
@ -459,7 +416,6 @@ impl CodeSandbox {
code.to_string(), code.to_string(),
]; ];
let timeout_duration = Duration::from_secs(self.config.timeout_seconds); let timeout_duration = Duration::from_secs(self.config.timeout_seconds);
let output = timeout(timeout_duration, async { let output = timeout(timeout_duration, async {
tokio::process::Command::new("docker") tokio::process::Command::new("docker")
@ -490,42 +446,34 @@ impl CodeSandbox {
} }
} }
async fn execute_firecracker( async fn execute_firecracker(
&self, &self,
code: &str, code: &str,
language: &CodeLanguage, language: &CodeLanguage,
) -> Result<ExecutionResult, String> { ) -> Result<ExecutionResult, String> {
warn!("Firecracker runtime not yet implemented, falling back to process isolation"); warn!("Firecracker runtime not yet implemented, falling back to process isolation");
self.execute_process(code, language).await self.execute_process(code, language).await
} }
async fn execute_process( async fn execute_process(
&self, &self,
code: &str, code: &str,
language: &CodeLanguage, language: &CodeLanguage,
) -> Result<ExecutionResult, String> { ) -> Result<ExecutionResult, String> {
let temp_dir = format!("{}/{}", self.config.work_dir, Uuid::new_v4()); let temp_dir = format!("{}/{}", self.config.work_dir, Uuid::new_v4());
std::fs::create_dir_all(&temp_dir) std::fs::create_dir_all(&temp_dir)
.map_err(|e| format!("Failed to create temp dir: {}", e))?; .map_err(|e| format!("Failed to create temp dir: {}", e))?;
let code_file = format!("{}/code.{}", temp_dir, language.file_extension()); let code_file = format!("{}/code.{}", temp_dir, language.file_extension());
std::fs::write(&code_file, code) std::fs::write(&code_file, code)
.map_err(|e| format!("Failed to write code file: {}", e))?; .map_err(|e| format!("Failed to write code file: {}", e))?;
let (cmd_name, cmd_args): (&str, Vec<&str>) = match language { let (cmd_name, cmd_args): (&str, Vec<&str>) = match language {
CodeLanguage::Python => ("python3", vec![&code_file]), CodeLanguage::Python => ("python3", vec![&code_file]),
CodeLanguage::JavaScript => ("node", vec![&code_file]), CodeLanguage::JavaScript => ("node", vec![&code_file]),
CodeLanguage::Bash => ("bash", vec![&code_file]), CodeLanguage::Bash => ("bash", vec![&code_file]),
}; };
let timeout_duration = Duration::from_secs(self.config.timeout_seconds); let timeout_duration = Duration::from_secs(self.config.timeout_seconds);
let output = timeout(timeout_duration, async { let output = timeout(timeout_duration, async {
tokio::process::Command::new(cmd_name) tokio::process::Command::new(cmd_name)
@ -540,7 +488,6 @@ impl CodeSandbox {
}) })
.await; .await;
let _ = std::fs::remove_dir_all(&temp_dir); let _ = std::fs::remove_dir_all(&temp_dir);
match output { match output {
@ -564,7 +511,6 @@ impl CodeSandbox {
} }
} }
pub async fn execute_file(&self, file_path: &str, language: CodeLanguage) -> ExecutionResult { pub async fn execute_file(&self, file_path: &str, language: CodeLanguage) -> ExecutionResult {
match std::fs::read_to_string(file_path) { match std::fs::read_to_string(file_path) {
Ok(code) => self.execute(&code, language).await, Ok(code) => self.execute(&code, language).await,
@ -573,7 +519,6 @@ impl CodeSandbox {
} }
} }
pub fn register_sandbox_keywords(state: Arc<AppState>, user: UserSession, engine: &mut Engine) { pub fn register_sandbox_keywords(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
run_python_keyword(state.clone(), user.clone(), engine); run_python_keyword(state.clone(), user.clone(), engine);
run_javascript_keyword(state.clone(), user.clone(), engine); run_javascript_keyword(state.clone(), user.clone(), engine);
@ -581,7 +526,6 @@ pub fn register_sandbox_keywords(state: Arc<AppState>, user: UserSession, engine
run_file_keyword(state.clone(), user.clone(), engine); run_file_keyword(state.clone(), user.clone(), engine);
} }
pub fn run_python_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) { pub fn run_python_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
let state_clone = Arc::clone(&state); let state_clone = Arc::clone(&state);
let user_clone = user.clone(); let user_clone = user.clone();
@ -627,7 +571,6 @@ pub fn run_python_keyword(state: Arc<AppState>, user: UserSession, engine: &mut
.expect("Failed to register RUN PYTHON syntax"); .expect("Failed to register RUN PYTHON syntax");
} }
pub fn run_javascript_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) { pub fn run_javascript_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
let state_clone = Arc::clone(&state); let state_clone = Arc::clone(&state);
let user_clone = user.clone(); let user_clone = user.clone();
@ -672,7 +615,6 @@ pub fn run_javascript_keyword(state: Arc<AppState>, user: UserSession, engine: &
) )
.expect("Failed to register RUN JAVASCRIPT syntax"); .expect("Failed to register RUN JAVASCRIPT syntax");
let state_clone2 = Arc::clone(&state); let state_clone2 = Arc::clone(&state);
let user_clone2 = user.clone(); let user_clone2 = user.clone();
@ -711,7 +653,6 @@ pub fn run_javascript_keyword(state: Arc<AppState>, user: UserSession, engine: &
.expect("Failed to register RUN JS syntax"); .expect("Failed to register RUN JS syntax");
} }
pub fn run_bash_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) { pub fn run_bash_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
let state_clone = Arc::clone(&state); let state_clone = Arc::clone(&state);
let user_clone = user.clone(); let user_clone = user.clone();
@ -753,7 +694,6 @@ pub fn run_bash_keyword(state: Arc<AppState>, user: UserSession, engine: &mut En
.expect("Failed to register RUN BASH syntax"); .expect("Failed to register RUN BASH syntax");
} }
pub fn run_file_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) { pub fn run_file_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
let state_clone = Arc::clone(&state); let state_clone = Arc::clone(&state);
let user_clone = user.clone(); let user_clone = user.clone();
@ -802,7 +742,6 @@ pub fn run_file_keyword(state: Arc<AppState>, user: UserSession, engine: &mut En
) )
.expect("Failed to register RUN PYTHON WITH FILE syntax"); .expect("Failed to register RUN PYTHON WITH FILE syntax");
let state_clone2 = Arc::clone(&state); let state_clone2 = Arc::clone(&state);
let user_clone2 = user.clone(); let user_clone2 = user.clone();
@ -847,11 +786,8 @@ pub fn run_file_keyword(state: Arc<AppState>, user: UserSession, engine: &mut En
.expect("Failed to register RUN JAVASCRIPT WITH FILE syntax"); .expect("Failed to register RUN JAVASCRIPT WITH FILE syntax");
} }
pub fn generate_python_lxc_config() -> String { pub fn generate_python_lxc_config() -> String {
r#" r"
# LXC configuration for Python sandbox # LXC configuration for Python sandbox
lxc.include = /usr/share/lxc/config/common.conf lxc.include = /usr/share/lxc/config/common.conf
lxc.arch = linux64 lxc.arch = linux64
@ -879,13 +815,12 @@ lxc.mount.auto = proc:mixed sys:ro
lxc.mount.entry = /usr/bin/python3 usr/bin/python3 none ro,bind 0 0 lxc.mount.entry = /usr/bin/python3 usr/bin/python3 none ro,bind 0 0
lxc.mount.entry = /usr/lib/python3 usr/lib/python3 none ro,bind 0 0 lxc.mount.entry = /usr/lib/python3 usr/lib/python3 none ro,bind 0 0
lxc.mount.entry = tmpfs tmp tmpfs defaults 0 0 lxc.mount.entry = tmpfs tmp tmpfs defaults 0 0
"# "
.to_string() .to_string()
} }
pub fn generate_node_lxc_config() -> String { pub fn generate_node_lxc_config() -> String {
r#" r"
# LXC configuration for Node.js sandbox # LXC configuration for Node.js sandbox
lxc.include = /usr/share/lxc/config/common.conf lxc.include = /usr/share/lxc/config/common.conf
lxc.arch = linux64 lxc.arch = linux64
@ -913,8 +848,6 @@ lxc.mount.auto = proc:mixed sys:ro
lxc.mount.entry = /usr/bin/node usr/bin/node none ro,bind 0 0 lxc.mount.entry = /usr/bin/node usr/bin/node none ro,bind 0 0
lxc.mount.entry = /usr/lib/node_modules usr/lib/node_modules none ro,bind 0 0 lxc.mount.entry = /usr/lib/node_modules usr/lib/node_modules none ro,bind 0 0
lxc.mount.entry = tmpfs tmp tmpfs defaults 0 0 lxc.mount.entry = tmpfs tmp tmpfs defaults 0 0
"# "
.to_string() .to_string()
} }

View file

@ -1,8 +1,3 @@
use crate::llm::LLMProvider; use crate::llm::LLMProvider;
use crate::shared::models::UserSession; use crate::shared::models::UserSession;
use crate::shared::state::AppState; use crate::shared::state::AppState;
@ -16,7 +11,6 @@ use std::io::Read;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
pub fn create_site_keyword(state: &AppState, user: UserSession, engine: &mut Engine) { pub fn create_site_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
let state_clone = state.clone(); let state_clone = state.clone();
let user_clone = user.clone(); let user_clone = user.clone();
@ -58,12 +52,6 @@ pub fn create_site_keyword(state: &AppState, user: UserSession, engine: &mut Eng
.unwrap(); .unwrap();
} }
async fn create_site( async fn create_site(
config: crate::config::AppConfig, config: crate::config::AppConfig,
s3: Option<std::sync::Arc<aws_sdk_s3::Client>>, s3: Option<std::sync::Arc<aws_sdk_s3::Client>>,
@ -83,20 +71,16 @@ async fn create_site(
alias_str, template_dir_str alias_str, template_dir_str
); );
let base_path = PathBuf::from(&config.site_path); let base_path = PathBuf::from(&config.site_path);
let template_path = base_path.join(&template_dir_str); let template_path = base_path.join(&template_dir_str);
let combined_content = load_templates(&template_path)?; let combined_content = load_templates(&template_path)?;
let generated_html = generate_html_from_prompt(llm, &combined_content, &prompt_str).await?; let generated_html = generate_html_from_prompt(llm, &combined_content, &prompt_str).await?;
let drive_path = format!("apps/{}", alias_str); let drive_path = format!("apps/{}", alias_str);
store_to_drive(&s3, &bucket, &bot_id, &drive_path, &generated_html).await?; store_to_drive(&s3, &bucket, &bot_id, &drive_path, &generated_html).await?;
let serve_path = base_path.join(&alias_str); let serve_path = base_path.join(&alias_str);
sync_to_serve_path(&serve_path, &generated_html, &template_path).await?; sync_to_serve_path(&serve_path, &generated_html, &template_path).await?;
@ -108,7 +92,6 @@ async fn create_site(
Ok(format!("/apps/{}", alias_str)) Ok(format!("/apps/{}", alias_str))
} }
fn load_templates(template_path: &PathBuf) -> Result<String, Box<dyn Error + Send + Sync>> { fn load_templates(template_path: &PathBuf) -> Result<String, Box<dyn Error + Send + Sync>> {
let mut combined_content = String::new(); let mut combined_content = String::new();
@ -141,7 +124,6 @@ fn load_templates(template_path: &PathBuf) -> Result<String, Box<dyn Error + Sen
Ok(combined_content) Ok(combined_content)
} }
async fn generate_html_from_prompt( async fn generate_html_from_prompt(
llm: Option<Arc<dyn LLMProvider>>, llm: Option<Arc<dyn LLMProvider>>,
templates: &str, templates: &str,
@ -209,7 +191,6 @@ OUTPUT: Complete index.html file only, no explanations."#,
Ok(html) Ok(html)
} }
fn extract_html_from_response(response: &str) -> String { fn extract_html_from_response(response: &str) -> String {
let trimmed = response.trim(); let trimmed = response.trim();
@ -234,7 +215,6 @@ fn extract_html_from_response(response: &str) -> String {
trimmed.to_string() trimmed.to_string()
} }
fn generate_placeholder_html(prompt: &str) -> String { fn generate_placeholder_html(prompt: &str) -> String {
format!( format!(
r##"<!DOCTYPE html> r##"<!DOCTYPE html>
@ -278,7 +258,6 @@ fn generate_placeholder_html(prompt: &str) -> String {
) )
} }
async fn store_to_drive( async fn store_to_drive(
s3: &Option<std::sync::Arc<aws_sdk_s3::Client>>, s3: &Option<std::sync::Arc<aws_sdk_s3::Client>>,
bucket: &str, bucket: &str,
@ -304,7 +283,6 @@ async fn store_to_drive(
.await .await
.map_err(|e| format!("Failed to store to drive: {}", e))?; .map_err(|e| format!("Failed to store to drive: {}", e))?;
let schema_key = format!("{}.gbdrive/{}/schema.json", bot_id, drive_path); let schema_key = format!("{}.gbdrive/{}/schema.json", bot_id, drive_path);
let schema = r#"{"tables": {}, "version": 1}"#; let schema = r#"{"tables": {}, "version": 1}"#;
@ -321,23 +299,19 @@ async fn store_to_drive(
Ok(()) Ok(())
} }
async fn sync_to_serve_path( async fn sync_to_serve_path(
serve_path: &PathBuf, serve_path: &PathBuf,
html_content: &str, html_content: &str,
template_path: &PathBuf, template_path: &PathBuf,
) -> Result<(), Box<dyn Error + Send + Sync>> { ) -> Result<(), Box<dyn Error + Send + Sync>> {
fs::create_dir_all(serve_path).map_err(|e| format!("Failed to create serve path: {}", e))?; fs::create_dir_all(serve_path).map_err(|e| format!("Failed to create serve path: {}", e))?;
let index_path = serve_path.join("index.html"); let index_path = serve_path.join("index.html");
fs::write(&index_path, html_content) fs::write(&index_path, html_content)
.map_err(|e| format!("Failed to write index.html: {}", e))?; .map_err(|e| format!("Failed to write index.html: {}", e))?;
info!("Written: {:?}", index_path); info!("Written: {:?}", index_path);
let template_assets = template_path.join("_assets"); let template_assets = template_path.join("_assets");
let serve_assets = serve_path.join("_assets"); let serve_assets = serve_path.join("_assets");
@ -345,26 +319,21 @@ async fn sync_to_serve_path(
copy_dir_recursive(&template_assets, &serve_assets)?; copy_dir_recursive(&template_assets, &serve_assets)?;
info!("Copied assets to: {:?}", serve_assets); info!("Copied assets to: {:?}", serve_assets);
} else { } else {
fs::create_dir_all(&serve_assets) fs::create_dir_all(&serve_assets)
.map_err(|e| format!("Failed to create assets dir: {}", e))?; .map_err(|e| format!("Failed to create assets dir: {}", e))?;
let htmx_path = serve_assets.join("htmx.min.js"); let htmx_path = serve_assets.join("htmx.min.js");
if !htmx_path.exists() { if !htmx_path.exists() {
fs::write(&htmx_path, "/* HTMX - include from CDN or bundle */") fs::write(&htmx_path, "/* HTMX - include from CDN or bundle */")
.map_err(|e| format!("Failed to write htmx: {}", e))?; .map_err(|e| format!("Failed to write htmx: {}", e))?;
} }
let styles_path = serve_assets.join("styles.css"); let styles_path = serve_assets.join("styles.css");
if !styles_path.exists() { if !styles_path.exists() {
fs::write(&styles_path, DEFAULT_STYLES) fs::write(&styles_path, DEFAULT_STYLES)
.map_err(|e| format!("Failed to write styles: {}", e))?; .map_err(|e| format!("Failed to write styles: {}", e))?;
} }
let app_js_path = serve_assets.join("app.js"); let app_js_path = serve_assets.join("app.js");
if !app_js_path.exists() { if !app_js_path.exists() {
fs::write(&app_js_path, DEFAULT_APP_JS) fs::write(&app_js_path, DEFAULT_APP_JS)
@ -372,7 +341,6 @@ async fn sync_to_serve_path(
} }
} }
let schema_path = serve_path.join("schema.json"); let schema_path = serve_path.join("schema.json");
fs::write(&schema_path, r#"{"tables": {}, "version": 1}"#) fs::write(&schema_path, r#"{"tables": {}, "version": 1}"#)
.map_err(|e| format!("Failed to write schema.json: {}", e))?; .map_err(|e| format!("Failed to write schema.json: {}", e))?;
@ -380,7 +348,6 @@ async fn sync_to_serve_path(
Ok(()) Ok(())
} }
fn copy_dir_recursive(src: &PathBuf, dst: &PathBuf) -> Result<(), Box<dyn Error + Send + Sync>> { fn copy_dir_recursive(src: &PathBuf, dst: &PathBuf) -> Result<(), Box<dyn Error + Send + Sync>> {
fs::create_dir_all(dst).map_err(|e| format!("Failed to create dir {:?}: {}", dst, e))?; fs::create_dir_all(dst).map_err(|e| format!("Failed to create dir {:?}: {}", dst, e))?;
@ -400,8 +367,7 @@ fn copy_dir_recursive(src: &PathBuf, dst: &PathBuf) -> Result<(), Box<dyn Error
Ok(()) Ok(())
} }
const DEFAULT_STYLES: &str = r"
const DEFAULT_STYLES: &str = r#"
:root { :root {
--primary: #0ea5e9; --primary: #0ea5e9;
--success: #22c55e; --success: #22c55e;
@ -494,10 +460,9 @@ th, td {
.htmx-request .htmx-indicator { .htmx-request .htmx-indicator {
opacity: 1; opacity: 1;
} }
"#; ";
const DEFAULT_APP_JS: &str = r"
const DEFAULT_APP_JS: &str = r#"
function toast(message, type = 'info') { function toast(message, type = 'info') {
const el = document.createElement('div'); const el = document.createElement('div');
@ -525,4 +490,4 @@ function openModal(id) {
function closeModal(id) { function closeModal(id) {
document.getElementById(id)?.classList.remove('active'); document.getElementById(id)?.classList.remove('active');
} }
"#; ";

View file

@ -1,52 +1,11 @@
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use rhai::{Array, Dynamic, Engine, Map}; use rhai::{Array, Dynamic, Engine, Map};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tracing::info; use tracing::info;
use uuid::Uuid; use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Episode { pub struct Episode {
pub id: Uuid, pub id: Uuid,
pub user_id: Uuid, pub user_id: Uuid,
@ -79,10 +38,8 @@ pub struct Episode {
pub metadata: serde_json::Value, pub metadata: serde_json::Value,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActionItem { pub struct ActionItem {
pub description: String, pub description: String,
pub assignee: Option<String>, pub assignee: Option<String>,
@ -94,7 +51,6 @@ pub struct ActionItem {
pub completed: bool, pub completed: bool,
} }
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum Priority { pub enum Priority {
@ -110,10 +66,8 @@ impl Default for Priority {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Sentiment { pub struct Sentiment {
pub score: f64, pub score: f64,
pub label: SentimentLabel, pub label: SentimentLabel,
@ -121,7 +75,6 @@ pub struct Sentiment {
pub confidence: f64, pub confidence: f64,
} }
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum SentimentLabel { pub enum SentimentLabel {
@ -148,7 +101,6 @@ impl Default for Sentiment {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum ResolutionStatus { pub enum ResolutionStatus {
@ -165,10 +117,8 @@ impl Default for ResolutionStatus {
} }
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct EpisodicMemoryConfig { pub struct EpisodicMemoryConfig {
pub enabled: bool, pub enabled: bool,
pub threshold: usize, pub threshold: usize,
@ -198,7 +148,6 @@ impl Default for EpisodicMemoryConfig {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationMessage { pub struct ConversationMessage {
pub id: Uuid, pub id: Uuid,
@ -207,19 +156,16 @@ pub struct ConversationMessage {
pub timestamp: DateTime<Utc>, pub timestamp: DateTime<Utc>,
} }
#[derive(Debug)] #[derive(Debug)]
pub struct EpisodicMemoryManager { pub struct EpisodicMemoryManager {
config: EpisodicMemoryConfig, config: EpisodicMemoryConfig,
} }
impl EpisodicMemoryManager { impl EpisodicMemoryManager {
pub fn new(config: EpisodicMemoryConfig) -> Self { pub fn new(config: EpisodicMemoryConfig) -> Self {
EpisodicMemoryManager { config } EpisodicMemoryManager { config }
} }
pub fn from_config(config_map: &std::collections::HashMap<String, String>) -> Self { pub fn from_config(config_map: &std::collections::HashMap<String, String>) -> Self {
let config = EpisodicMemoryConfig { let config = EpisodicMemoryConfig {
enabled: config_map enabled: config_map
@ -254,22 +200,18 @@ impl EpisodicMemoryManager {
EpisodicMemoryManager::new(config) EpisodicMemoryManager::new(config)
} }
pub fn should_summarize(&self, message_count: usize) -> bool { pub fn should_summarize(&self, message_count: usize) -> bool {
self.config.enabled && self.config.auto_summarize && message_count >= self.config.threshold self.config.enabled && self.config.auto_summarize && message_count >= self.config.threshold
} }
pub fn get_history_to_keep(&self) -> usize { pub fn get_history_to_keep(&self) -> usize {
self.config.history self.config.history
} }
pub fn get_threshold(&self) -> usize { pub fn get_threshold(&self) -> usize {
self.config.threshold self.config.threshold
} }
pub fn generate_summary_prompt(&self, messages: &[ConversationMessage]) -> String { pub fn generate_summary_prompt(&self, messages: &[ConversationMessage]) -> String {
let formatted_messages = messages let formatted_messages = messages
.iter() .iter()
@ -309,7 +251,6 @@ Respond with valid JSON only:
) )
} }
pub fn parse_summary_response( pub fn parse_summary_response(
&self, &self,
response: &str, response: &str,
@ -318,7 +259,6 @@ Respond with valid JSON only:
bot_id: Uuid, bot_id: Uuid,
session_id: Uuid, session_id: Uuid,
) -> Result<Episode, String> { ) -> Result<Episode, String> {
let json_str = extract_json(response)?; let json_str = extract_json(response)?;
let parsed: serde_json::Value = let parsed: serde_json::Value =
@ -418,22 +358,18 @@ Respond with valid JSON only:
}) })
} }
pub fn get_retention_cutoff(&self) -> DateTime<Utc> { pub fn get_retention_cutoff(&self) -> DateTime<Utc> {
Utc::now() - Duration::days(self.config.retention_days as i64) Utc::now() - Duration::days(self.config.retention_days as i64)
} }
} }
pub fn extract_json(response: &str) -> Result<String, String> {
fn extract_json(response: &str) -> Result<String, String> {
if let Some(start) = response.find("```json") { if let Some(start) = response.find("```json") {
if let Some(end) = response[start + 7..].find("```") { if let Some(end) = response[start + 7..].find("```") {
return Ok(response[start + 7..start + 7 + end].trim().to_string()); return Ok(response[start + 7..start + 7 + end].trim().to_string());
} }
} }
if let Some(start) = response.find("```") { if let Some(start) = response.find("```") {
let after_start = start + 3; let after_start = start + 3;
@ -446,7 +382,6 @@ fn extract_json(response: &str) -> Result<String, String> {
} }
} }
if let Some(start) = response.find('{') { if let Some(start) = response.find('{') {
if let Some(end) = response.rfind('}') { if let Some(end) = response.rfind('}') {
if end > start { if end > start {
@ -458,7 +393,6 @@ fn extract_json(response: &str) -> Result<String, String> {
Err("No JSON found in response".to_string()) Err("No JSON found in response".to_string())
} }
impl Episode { impl Episode {
pub fn to_dynamic(&self) -> Dynamic { pub fn to_dynamic(&self) -> Dynamic {
let mut map = Map::new(); let mut map = Map::new();
@ -531,12 +465,7 @@ impl Episode {
} }
} }
pub fn register_episodic_memory_keywords(engine: &mut Engine) { pub fn register_episodic_memory_keywords(engine: &mut Engine) {
engine.register_fn("episode_summary", |episode: Map| -> String { engine.register_fn("episode_summary", |episode: Map| -> String {
episode episode
.get("summary") .get("summary")
@ -584,7 +513,6 @@ pub fn register_episodic_memory_keywords(engine: &mut Engine) {
info!("Episodic memory keywords registered"); info!("Episodic memory keywords registered");
} }
pub const EPISODIC_MEMORY_SCHEMA: &str = r#" pub const EPISODIC_MEMORY_SCHEMA: &str = r#"
-- Conversation episodes (summaries) -- Conversation episodes (summaries)
CREATE TABLE IF NOT EXISTS conversation_episodes ( CREATE TABLE IF NOT EXISTS conversation_episodes (
@ -619,9 +547,8 @@ CREATE INDEX IF NOT EXISTS idx_episodes_summary_fts ON conversation_episodes
USING GIN(to_tsvector('english', summary)); USING GIN(to_tsvector('english', summary));
"#; "#;
pub mod sql { pub mod sql {
pub const INSERT_EPISODE: &str = r#" pub const INSERT_EPISODE: &str = r"
INSERT INTO conversation_episodes ( INSERT INTO conversation_episodes (
id, user_id, bot_id, session_id, summary, key_topics, decisions, id, user_id, bot_id, session_id, summary, key_topics, decisions,
action_items, sentiment, resolution, message_count, message_ids, action_items, sentiment, resolution, message_count, message_ids,
@ -629,22 +556,22 @@ pub mod sql {
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16 $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16
) )
"#; ";
pub const GET_EPISODES_BY_USER: &str = r#" pub const GET_EPISODES_BY_USER: &str = r"
SELECT * FROM conversation_episodes SELECT * FROM conversation_episodes
WHERE user_id = $1 WHERE user_id = $1
ORDER BY created_at DESC ORDER BY created_at DESC
LIMIT $2 LIMIT $2
"#; ";
pub const GET_EPISODES_BY_SESSION: &str = r#" pub const GET_EPISODES_BY_SESSION: &str = r"
SELECT * FROM conversation_episodes SELECT * FROM conversation_episodes
WHERE session_id = $1 WHERE session_id = $1
ORDER BY created_at DESC ORDER BY created_at DESC
"#; ";
pub const SEARCH_EPISODES: &str = r#" pub const SEARCH_EPISODES: &str = r"
SELECT * FROM conversation_episodes SELECT * FROM conversation_episodes
WHERE user_id = $1 WHERE user_id = $1
AND ( AND (
@ -653,19 +580,19 @@ pub mod sql {
) )
ORDER BY created_at DESC ORDER BY created_at DESC
LIMIT $4 LIMIT $4
"#; ";
pub const DELETE_OLD_EPISODES: &str = r#" pub const DELETE_OLD_EPISODES: &str = r"
DELETE FROM conversation_episodes DELETE FROM conversation_episodes
WHERE created_at < $1 WHERE created_at < $1
"#; ";
pub const COUNT_USER_EPISODES: &str = r#" pub const COUNT_USER_EPISODES: &str = r"
SELECT COUNT(*) FROM conversation_episodes SELECT COUNT(*) FROM conversation_episodes
WHERE user_id = $1 WHERE user_id = $1
"#; ";
pub const DELETE_OLDEST_EPISODES: &str = r#" pub const DELETE_OLDEST_EPISODES: &str = r"
DELETE FROM conversation_episodes DELETE FROM conversation_episodes
WHERE id IN ( WHERE id IN (
SELECT id FROM conversation_episodes SELECT id FROM conversation_episodes
@ -673,5 +600,5 @@ pub mod sql {
ORDER BY created_at ASC ORDER BY created_at ASC
LIMIT $2 LIMIT $2
) )
"#; ";
} }

View file

@ -3,7 +3,7 @@ use rhai::Engine;
pub fn first_keyword(engine: &mut Engine) { pub fn first_keyword(engine: &mut Engine) {
engine engine
.register_custom_syntax(&["FIRST", "$expr$"], false, { .register_custom_syntax(["FIRST", "$expr$"], false, {
move |context, inputs| { move |context, inputs| {
let input_string = context.eval_expression_tree(&inputs[0])?; let input_string = context.eval_expression_tree(&inputs[0])?;
let input_str = input_string.to_string(); let input_str = input_string.to_string();

View file

@ -1,29 +1,3 @@
use crate::shared::message_types::MessageType; use crate::shared::message_types::MessageType;
use crate::shared::models::{BotResponse, UserSession}; use crate::shared::models::{BotResponse, UserSession};
use crate::shared::state::AppState; use crate::shared::state::AppState;
@ -34,8 +8,7 @@ use serde::{Deserialize, Serialize};
use std::sync::Arc; use std::sync::Arc;
use uuid::Uuid; use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum InputType { pub enum InputType {
Any, Any,
Email, Email,
@ -67,81 +40,74 @@ pub enum InputType {
} }
impl InputType { impl InputType {
#[must_use]
pub fn error_message(&self) -> String { pub fn error_message(&self) -> String {
match self { match self {
InputType::Any => "".to_string(), Self::Any => String::new(),
InputType::Email => { Self::Email => {
"Please enter a valid email address (e.g., user@example.com)".to_string() "Please enter a valid email address (e.g., user@example.com)".to_string()
} }
InputType::Date => { Self::Date => "Please enter a valid date (e.g., 25/12/2024 or 2024-12-25)".to_string(),
"Please enter a valid date (e.g., 25/12/2024 or 2024-12-25)".to_string() Self::Name => "Please enter a valid name (letters and spaces only)".to_string(),
} Self::Integer => "Please enter a valid whole number".to_string(),
InputType::Name => "Please enter a valid name (letters and spaces only)".to_string(), Self::Float => "Please enter a valid number".to_string(),
InputType::Integer => "Please enter a valid whole number".to_string(), Self::Boolean => "Please answer yes or no".to_string(),
InputType::Float => "Please enter a valid number".to_string(), Self::Hour => "Please enter a valid time (e.g., 14:30 or 2:30 PM)".to_string(),
InputType::Boolean => "Please answer yes or no".to_string(), Self::Money => "Please enter a valid amount (e.g., 100.00 or R$ 100,00)".to_string(),
InputType::Hour => "Please enter a valid time (e.g., 14:30 or 2:30 PM)".to_string(), Self::Mobile => "Please enter a valid mobile number".to_string(),
InputType::Money => { Self::Zipcode => "Please enter a valid ZIP/postal code".to_string(),
"Please enter a valid amount (e.g., 100.00 or R$ 100,00)".to_string() Self::Language => "Please enter a valid language code (e.g., en, pt, es)".to_string(),
} Self::Cpf => "Please enter a valid CPF (11 digits)".to_string(),
InputType::Mobile => "Please enter a valid mobile number".to_string(), Self::Cnpj => "Please enter a valid CNPJ (14 digits)".to_string(),
InputType::Zipcode => "Please enter a valid ZIP/postal code".to_string(), Self::QrCode => "Please send an image containing a QR code".to_string(),
InputType::Language => { Self::Login => "Please complete the authentication process".to_string(),
"Please enter a valid language code (e.g., en, pt, es)".to_string() Self::Menu(options) => format!("Please select one of: {}", options.join(", ")),
} Self::File => "Please upload a file".to_string(),
InputType::Cpf => "Please enter a valid CPF (11 digits)".to_string(), Self::Image => "Please send an image".to_string(),
InputType::Cnpj => "Please enter a valid CNPJ (14 digits)".to_string(), Self::Audio => "Please send an audio file or voice message".to_string(),
InputType::QrCode => "Please send an image containing a QR code".to_string(), Self::Video => "Please send a video".to_string(),
InputType::Login => "Please complete the authentication process".to_string(), Self::Document => "Please send a document (PDF, Word, etc.)".to_string(),
InputType::Menu(options) => format!("Please select one of: {}", options.join(", ")), Self::Url => "Please enter a valid URL".to_string(),
InputType::File => "Please upload a file".to_string(), Self::Uuid => "Please enter a valid UUID".to_string(),
InputType::Image => "Please send an image".to_string(), Self::Color => "Please enter a valid color (e.g., #FF0000 or red)".to_string(),
InputType::Audio => "Please send an audio file or voice message".to_string(), Self::CreditCard => "Please enter a valid credit card number".to_string(),
InputType::Video => "Please send a video".to_string(), Self::Password => "Please enter a password (minimum 8 characters)".to_string(),
InputType::Document => "Please send a document (PDF, Word, etc.)".to_string(),
InputType::Url => "Please enter a valid URL".to_string(),
InputType::Uuid => "Please enter a valid UUID".to_string(),
InputType::Color => "Please enter a valid color (e.g., #FF0000 or red)".to_string(),
InputType::CreditCard => "Please enter a valid credit card number".to_string(),
InputType::Password => "Please enter a password (minimum 8 characters)".to_string(),
} }
} }
#[must_use]
pub fn from_str(s: &str) -> Self { pub fn from_str(s: &str) -> Self {
match s.to_uppercase().as_str() { match s.to_uppercase().as_str() {
"EMAIL" => InputType::Email, "EMAIL" => Self::Email,
"DATE" => InputType::Date, "DATE" => Self::Date,
"NAME" => InputType::Name, "NAME" => Self::Name,
"INTEGER" | "INT" | "NUMBER" => InputType::Integer, "INTEGER" | "INT" | "NUMBER" => Self::Integer,
"FLOAT" | "DECIMAL" | "DOUBLE" => InputType::Float, "FLOAT" | "DECIMAL" | "DOUBLE" => Self::Float,
"BOOLEAN" | "BOOL" => InputType::Boolean, "BOOLEAN" | "BOOL" => Self::Boolean,
"HOUR" | "TIME" => InputType::Hour, "HOUR" | "TIME" => Self::Hour,
"MONEY" | "CURRENCY" | "AMOUNT" => InputType::Money, "MONEY" | "CURRENCY" | "AMOUNT" => Self::Money,
"MOBILE" | "PHONE" | "TELEPHONE" => InputType::Mobile, "MOBILE" | "PHONE" | "TELEPHONE" => Self::Mobile,
"ZIPCODE" | "ZIP" | "CEP" | "POSTALCODE" => InputType::Zipcode, "ZIPCODE" | "ZIP" | "CEP" | "POSTALCODE" => Self::Zipcode,
"LANGUAGE" | "LANG" => InputType::Language, "LANGUAGE" | "LANG" => Self::Language,
"CPF" => InputType::Cpf, "CPF" => Self::Cpf,
"CNPJ" => InputType::Cnpj, "CNPJ" => Self::Cnpj,
"QRCODE" | "QR" => InputType::QrCode, "QRCODE" | "QR" => Self::QrCode,
"LOGIN" | "AUTH" => InputType::Login, "LOGIN" | "AUTH" => Self::Login,
"FILE" => InputType::File, "FILE" => Self::File,
"IMAGE" | "PHOTO" | "PICTURE" => InputType::Image, "IMAGE" | "PHOTO" | "PICTURE" => Self::Image,
"AUDIO" | "VOICE" | "SOUND" => InputType::Audio, "AUDIO" | "VOICE" | "SOUND" => Self::Audio,
"VIDEO" => InputType::Video, "VIDEO" => Self::Video,
"DOCUMENT" | "DOC" | "PDF" => InputType::Document, "DOCUMENT" | "DOC" | "PDF" => Self::Document,
"URL" | "LINK" => InputType::Url, "URL" | "LINK" => Self::Url,
"UUID" | "GUID" => InputType::Uuid, "UUID" | "GUID" => Self::Uuid,
"COLOR" | "COLOUR" => InputType::Color, "COLOR" | "COLOUR" => Self::Color,
"CREDITCARD" | "CARD" => InputType::CreditCard, "CREDITCARD" | "CARD" => Self::CreditCard,
"PASSWORD" | "PASS" | "SECRET" => InputType::Password, "PASSWORD" | "PASS" | "SECRET" => Self::Password,
_ => InputType::Any, _ => Self::Any,
} }
} }
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ValidationResult { pub struct ValidationResult {
pub is_valid: bool, pub is_valid: bool,
@ -151,6 +117,7 @@ pub struct ValidationResult {
} }
impl ValidationResult { impl ValidationResult {
#[must_use]
pub fn valid(value: String) -> Self { pub fn valid(value: String) -> Self {
Self { Self {
is_valid: true, is_valid: true,
@ -160,6 +127,7 @@ impl ValidationResult {
} }
} }
#[must_use]
pub fn valid_with_metadata(value: String, metadata: serde_json::Value) -> Self { pub fn valid_with_metadata(value: String, metadata: serde_json::Value) -> Self {
Self { Self {
is_valid: true, is_valid: true,
@ -169,6 +137,7 @@ impl ValidationResult {
} }
} }
#[must_use]
pub fn invalid(error: String) -> Self { pub fn invalid(error: String) -> Self {
Self { Self {
is_valid: false, is_valid: false,
@ -179,26 +148,20 @@ impl ValidationResult {
} }
} }
pub fn hear_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) { pub fn hear_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
register_hear_basic(state.clone(), user.clone(), engine); register_hear_basic(state.clone(), user.clone(), engine);
register_hear_as_type(state.clone(), user.clone(), engine); register_hear_as_type(state.clone(), user.clone(), engine);
register_hear_as_menu(state.clone(), user.clone(), engine); register_hear_as_menu(state.clone(), user.clone(), engine);
} }
fn register_hear_basic(state: Arc<AppState>, user: UserSession, engine: &mut Engine) { fn register_hear_basic(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
let session_id = user.id; let session_id = user.id;
let state_clone = Arc::clone(&state); let state_clone = Arc::clone(&state);
engine engine
.register_custom_syntax(&["HEAR", "$ident$"], true, move |_context, inputs| { .register_custom_syntax(&["HEAR", "$ident$"], true, move |_context, inputs| {
let variable_name = inputs[0] let variable_name = inputs[0]
.get_string_value() .get_string_value()
.expect("Expected identifier as string") .expect("Expected identifier as string")
@ -223,10 +186,9 @@ fn register_hear_basic(state: Arc<AppState>, user: UserSession, engine: &mut Eng
let mut session_manager = state_for_spawn.session_manager.lock().await; let mut session_manager = state_for_spawn.session_manager.lock().await;
session_manager.mark_waiting(session_id_clone); session_manager.mark_waiting(session_id_clone);
if let Some(redis_client) = &state_for_spawn.cache { if let Some(redis_client) = &state_for_spawn.cache {
if let Ok(mut conn) = redis_client.get_multiplexed_async_connection().await { if let Ok(mut conn) = redis_client.get_multiplexed_async_connection().await {
let key = format!("hear:{}:{}", session_id_clone, var_name_clone); let key = format!("hear:{session_id_clone}:{var_name_clone}");
let wait_data = serde_json::json!({ let wait_data = serde_json::json!({
"variable": var_name_clone, "variable": var_name_clone,
"type": "any", "type": "any",
@ -252,7 +214,6 @@ fn register_hear_basic(state: Arc<AppState>, user: UserSession, engine: &mut Eng
.unwrap(); .unwrap();
} }
fn register_hear_as_type(state: Arc<AppState>, user: UserSession, engine: &mut Engine) { fn register_hear_as_type(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
let session_id = user.id; let session_id = user.id;
let state_clone = Arc::clone(&state); let state_clone = Arc::clone(&state);
@ -262,7 +223,6 @@ fn register_hear_as_type(state: Arc<AppState>, user: UserSession, engine: &mut E
&["HEAR", "$ident$", "AS", "$ident$"], &["HEAR", "$ident$", "AS", "$ident$"],
true, true,
move |_context, inputs| { move |_context, inputs| {
let variable_name = inputs[0] let variable_name = inputs[0]
.get_string_value() .get_string_value()
.expect("Expected identifier for variable") .expect("Expected identifier for variable")
@ -274,11 +234,7 @@ fn register_hear_as_type(state: Arc<AppState>, user: UserSession, engine: &mut E
let _input_type = InputType::from_str(&type_name); let _input_type = InputType::from_str(&type_name);
trace!( trace!("HEAR {variable_name} AS {type_name} - waiting for validated input");
"HEAR {} AS {} - waiting for validated input",
variable_name,
type_name
);
let state_for_spawn = Arc::clone(&state_clone); let state_for_spawn = Arc::clone(&state_clone);
let session_id_clone = session_id; let session_id_clone = session_id;
@ -289,11 +245,10 @@ fn register_hear_as_type(state: Arc<AppState>, user: UserSession, engine: &mut E
let mut session_manager = state_for_spawn.session_manager.lock().await; let mut session_manager = state_for_spawn.session_manager.lock().await;
session_manager.mark_waiting(session_id_clone); session_manager.mark_waiting(session_id_clone);
if let Some(redis_client) = &state_for_spawn.cache { if let Some(redis_client) = &state_for_spawn.cache {
if let Ok(mut conn) = redis_client.get_multiplexed_async_connection().await if let Ok(mut conn) = redis_client.get_multiplexed_async_connection().await
{ {
let key = format!("hear:{}:{}", session_id_clone, var_name_clone); let key = format!("hear:{session_id_clone}:{var_name_clone}");
let wait_data = serde_json::json!({ let wait_data = serde_json::json!({
"variable": var_name_clone, "variable": var_name_clone,
"type": type_clone.to_lowercase(), "type": type_clone.to_lowercase(),
@ -321,44 +276,34 @@ fn register_hear_as_type(state: Arc<AppState>, user: UserSession, engine: &mut E
.unwrap(); .unwrap();
} }
fn register_hear_as_menu(state: Arc<AppState>, user: UserSession, engine: &mut Engine) { fn register_hear_as_menu(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
let session_id = user.id; let session_id = user.id;
let state_clone = Arc::clone(&state); let state_clone = Arc::clone(&state);
engine engine
.register_custom_syntax( .register_custom_syntax(
&["HEAR", "$ident$", "AS", "$expr$"], &["HEAR", "$ident$", "AS", "$expr$"],
true, true,
move |context, inputs| { move |context, inputs| {
let variable_name = inputs[0] let variable_name = inputs[0]
.get_string_value() .get_string_value()
.expect("Expected identifier for variable") .expect("Expected identifier for variable")
.to_lowercase(); .to_lowercase();
let options_expr = context.eval_expression_tree(&inputs[1])?; let options_expr = context.eval_expression_tree(&inputs[1])?;
let options_str = options_expr.to_string(); let options_str = options_expr.to_string();
let input_type = InputType::from_str(&options_str); let input_type = InputType::from_str(&options_str);
if input_type != InputType::Any { if input_type != InputType::Any {
return Err(Box::new(EvalAltResult::ErrorRuntime( return Err(Box::new(EvalAltResult::ErrorRuntime(
"Use HEAR AS TYPE syntax".into(), "Use HEAR AS TYPE syntax".into(),
rhai::Position::NONE, rhai::Position::NONE,
))); )));
} }
let options: Vec<String> = if options_str.starts_with('[') { let options: Vec<String> = if options_str.starts_with('[') {
serde_json::from_str(&options_str).unwrap_or_default() serde_json::from_str(&options_str).unwrap_or_default()
} else { } else {
options_str options_str
.split(',') .split(',')
.map(|s| s.trim().trim_matches('"').to_string()) .map(|s| s.trim().trim_matches('"').to_string())
@ -384,11 +329,10 @@ fn register_hear_as_menu(state: Arc<AppState>, user: UserSession, engine: &mut E
let mut session_manager = state_for_spawn.session_manager.lock().await; let mut session_manager = state_for_spawn.session_manager.lock().await;
session_manager.mark_waiting(session_id_clone); session_manager.mark_waiting(session_id_clone);
if let Some(redis_client) = &state_for_spawn.cache { if let Some(redis_client) = &state_for_spawn.cache {
if let Ok(mut conn) = redis_client.get_multiplexed_async_connection().await if let Ok(mut conn) = redis_client.get_multiplexed_async_connection().await
{ {
let key = format!("hear:{}:{}", session_id_clone, var_name_clone); let key = format!("hear:{session_id_clone}:{var_name_clone}");
let wait_data = serde_json::json!({ let wait_data = serde_json::json!({
"variable": var_name_clone, "variable": var_name_clone,
"type": "menu", "type": "menu",
@ -404,9 +348,8 @@ fn register_hear_as_menu(state: Arc<AppState>, user: UserSession, engine: &mut E
.query_async(&mut conn) .query_async(&mut conn)
.await; .await;
let suggestions_key = let suggestions_key =
format!("suggestions:{}:{}", session_id_clone, session_id_clone); format!("suggestions:{session_id_clone}:{session_id_clone}");
for opt in &options_clone { for opt in &options_clone {
let suggestion = serde_json::json!({ let suggestion = serde_json::json!({
"text": opt, "text": opt,
@ -431,15 +374,19 @@ fn register_hear_as_menu(state: Arc<AppState>, user: UserSession, engine: &mut E
.unwrap(); .unwrap();
} }
#[must_use]
pub fn validate_input(input: &str, input_type: &InputType) -> ValidationResult { pub fn validate_input(input: &str, input_type: &InputType) -> ValidationResult {
let trimmed = input.trim(); let trimmed = input.trim();
match input_type { match input_type {
InputType::Any => ValidationResult::valid(trimmed.to_string()), InputType::Any
| InputType::QrCode
| InputType::File
| InputType::Image
| InputType::Audio
| InputType::Video
| InputType::Document
| InputType::Login => ValidationResult::valid(trimmed.to_string()),
InputType::Email => validate_email(trimmed), InputType::Email => validate_email(trimmed),
InputType::Date => validate_date(trimmed), InputType::Date => validate_date(trimmed),
InputType::Name => validate_name(trimmed), InputType::Name => validate_name(trimmed),
@ -458,18 +405,7 @@ pub fn validate_input(input: &str, input_type: &InputType) -> ValidationResult {
InputType::Color => validate_color(trimmed), InputType::Color => validate_color(trimmed),
InputType::CreditCard => validate_credit_card(trimmed), InputType::CreditCard => validate_credit_card(trimmed),
InputType::Password => validate_password(trimmed), InputType::Password => validate_password(trimmed),
InputType::Menu(options) => validate_menu(trimmed, options), InputType::Menu(options) => validate_menu(trimmed, options),
InputType::QrCode
| InputType::File
| InputType::Image
| InputType::Audio
| InputType::Video
| InputType::Document => ValidationResult::valid(trimmed.to_string()),
InputType::Login => ValidationResult::valid(trimmed.to_string()),
} }
} }
@ -486,32 +422,23 @@ fn validate_email(input: &str) -> ValidationResult {
} }
fn validate_date(input: &str) -> ValidationResult { fn validate_date(input: &str) -> ValidationResult {
let formats = [ let formats = [
"%d/%m/%Y", "%d/%m/%Y", "%d-%m-%Y", "%Y-%m-%d", "%Y/%m/%d", "%d.%m.%Y", "%m/%d/%Y", "%d %b %Y",
"%d-%m-%Y",
"%Y-%m-%d",
"%Y/%m/%d",
"%d.%m.%Y",
"%m/%d/%Y",
"%d %b %Y",
"%d %B %Y", "%d %B %Y",
]; ];
for format in &formats { for format in &formats {
if let Ok(date) = chrono::NaiveDate::parse_from_str(input, format) { if let Ok(date) = chrono::NaiveDate::parse_from_str(input, format) {
return ValidationResult::valid_with_metadata( return ValidationResult::valid_with_metadata(
date.format("%Y-%m-%d").to_string(), date.format("%Y-%m-%d").to_string(),
serde_json::json!({ serde_json::json!({
"original": input, "original": input,
"parsed_format": format "parsed_format": *format
}), }),
); );
} }
} }
let lower = input.to_lowercase(); let lower = input.to_lowercase();
let today = chrono::Local::now().date_naive(); let today = chrono::Local::now().date_naive();
@ -537,7 +464,6 @@ fn validate_date(input: &str) -> ValidationResult {
} }
fn validate_name(input: &str) -> ValidationResult { fn validate_name(input: &str) -> ValidationResult {
let name_regex = Regex::new(r"^[\p{L}\s\-']+$").unwrap(); let name_regex = Regex::new(r"^[\p{L}\s\-']+$").unwrap();
if input.len() < 2 { if input.len() < 2 {
@ -549,7 +475,6 @@ fn validate_name(input: &str) -> ValidationResult {
} }
if name_regex.is_match(input) { if name_regex.is_match(input) {
let normalized = input let normalized = input
.split_whitespace() .split_whitespace()
.map(|word| { .map(|word| {
@ -568,11 +493,10 @@ fn validate_name(input: &str) -> ValidationResult {
} }
fn validate_integer(input: &str) -> ValidationResult { fn validate_integer(input: &str) -> ValidationResult {
let cleaned = input let cleaned = input
.replace(",", "") .replace(',', "")
.replace(".", "") .replace('.', "")
.replace(" ", "") .replace(' ', "")
.trim() .trim()
.to_string(); .to_string();
@ -586,7 +510,6 @@ fn validate_integer(input: &str) -> ValidationResult {
} }
fn validate_float(input: &str) -> ValidationResult { fn validate_float(input: &str) -> ValidationResult {
let cleaned = input.replace(" ", "").replace(",", ".").trim().to_string(); let cleaned = input.replace(" ", "").replace(",", ".").trim().to_string();
match cleaned.parse::<f64>() { match cleaned.parse::<f64>() {
@ -644,7 +567,6 @@ fn validate_boolean(input: &str) -> ValidationResult {
} }
fn validate_hour(input: &str) -> ValidationResult { fn validate_hour(input: &str) -> ValidationResult {
let time_24_regex = Regex::new(r"^([01]?\d|2[0-3]):([0-5]\d)$").unwrap(); let time_24_regex = Regex::new(r"^([01]?\d|2[0-3]):([0-5]\d)$").unwrap();
if let Some(caps) = time_24_regex.captures(input) { if let Some(caps) = time_24_regex.captures(input) {
let hour: u32 = caps[1].parse().unwrap(); let hour: u32 = caps[1].parse().unwrap();
@ -655,7 +577,6 @@ fn validate_hour(input: &str) -> ValidationResult {
); );
} }
let time_12_regex = let time_12_regex =
Regex::new(r"^(1[0-2]|0?[1-9]):([0-5]\d)\s*(AM|PM|am|pm|a\.m\.|p\.m\.)$").unwrap(); Regex::new(r"^(1[0-2]|0?[1-9]):([0-5]\d)\s*(AM|PM|am|pm|a\.m\.|p\.m\.)$").unwrap();
if let Some(caps) = time_12_regex.captures(input) { if let Some(caps) = time_12_regex.captures(input) {
@ -679,7 +600,6 @@ fn validate_hour(input: &str) -> ValidationResult {
} }
fn validate_money(input: &str) -> ValidationResult { fn validate_money(input: &str) -> ValidationResult {
let cleaned = input let cleaned = input
.replace("R$", "") .replace("R$", "")
.replace("$", "") .replace("$", "")
@ -690,21 +610,16 @@ fn validate_money(input: &str) -> ValidationResult {
.trim() .trim()
.to_string(); .to_string();
let normalized = if cleaned.contains(',') && cleaned.contains('.') { let normalized = if cleaned.contains(',') && cleaned.contains('.') {
let last_comma = cleaned.rfind(',').unwrap_or(0); let last_comma = cleaned.rfind(',').unwrap_or(0);
let last_dot = cleaned.rfind('.').unwrap_or(0); let last_dot = cleaned.rfind('.').unwrap_or(0);
if last_comma > last_dot { if last_comma > last_dot {
cleaned.replace('.', "").replace(',', ".")
cleaned.replace(".", "").replace(",", ".")
} else { } else {
cleaned.replace(",", "") cleaned.replace(",", "")
} }
} else if cleaned.contains(',') { } else if cleaned.contains(',') {
cleaned.replace(",", ".") cleaned.replace(",", ".")
} else { } else {
cleaned cleaned
@ -720,26 +635,19 @@ fn validate_money(input: &str) -> ValidationResult {
} }
fn validate_mobile(input: &str) -> ValidationResult { fn validate_mobile(input: &str) -> ValidationResult {
let digits: String = input.chars().filter(|c| c.is_ascii_digit()).collect(); let digits: String = input.chars().filter(|c| c.is_ascii_digit()).collect();
if digits.len() < 10 || digits.len() > 15 { if digits.len() < 10 || digits.len() > 15 {
return ValidationResult::invalid(InputType::Mobile.error_message()); return ValidationResult::invalid(InputType::Mobile.error_message());
} }
let formatted = if digits.len() == 11 && digits.starts_with('9') { let formatted = if digits.len() == 11 && digits.starts_with('9') {
format!("({}) {}-{}", &digits[0..2], &digits[2..7], &digits[7..11]) format!("({}) {}-{}", &digits[0..2], &digits[2..7], &digits[7..11])
} else if digits.len() == 11 { } else if digits.len() == 11 {
format!("({}) {}-{}", &digits[0..2], &digits[2..7], &digits[7..11]) format!("({}) {}-{}", &digits[0..2], &digits[2..7], &digits[7..11])
} else if digits.len() == 10 { } else if digits.len() == 10 {
format!("({}) {}-{}", &digits[0..3], &digits[3..6], &digits[6..10]) format!("({}) {}-{}", &digits[0..3], &digits[3..6], &digits[6..10])
} else { } else {
format!("+{}", digits) format!("+{}", digits)
}; };
@ -750,13 +658,11 @@ fn validate_mobile(input: &str) -> ValidationResult {
} }
fn validate_zipcode(input: &str) -> ValidationResult { fn validate_zipcode(input: &str) -> ValidationResult {
let cleaned: String = input let cleaned: String = input
.chars() .chars()
.filter(|c| c.is_ascii_alphanumeric()) .filter(|c| c.is_ascii_alphanumeric())
.collect(); .collect();
if cleaned.len() == 8 && cleaned.chars().all(|c| c.is_ascii_digit()) { if cleaned.len() == 8 && cleaned.chars().all(|c| c.is_ascii_digit()) {
let formatted = format!("{}-{}", &cleaned[0..5], &cleaned[5..8]); let formatted = format!("{}-{}", &cleaned[0..5], &cleaned[5..8]);
return ValidationResult::valid_with_metadata( return ValidationResult::valid_with_metadata(
@ -765,10 +671,9 @@ fn validate_zipcode(input: &str) -> ValidationResult {
); );
} }
if (cleaned.len() == 5 || cleaned.len() == 9) && cleaned.chars().all(|c| c.is_ascii_digit()) { if (cleaned.len() == 5 || cleaned.len() == 9) && cleaned.chars().all(|c| c.is_ascii_digit()) {
let formatted = if cleaned.len() == 9 { let formatted = if cleaned.len() == 9 {
format!("{}-{}", &cleaned[0..5], &cleaned[5..9]) format!("{}-{}", &cleaned[0..5], &cleaned[5..8])
} else { } else {
cleaned.clone() cleaned.clone()
}; };
@ -778,7 +683,6 @@ fn validate_zipcode(input: &str) -> ValidationResult {
); );
} }
let uk_regex = Regex::new(r"^[A-Z]{1,2}\d[A-Z\d]?\s?\d[A-Z]{2}$").unwrap(); let uk_regex = Regex::new(r"^[A-Z]{1,2}\d[A-Z\d]?\s?\d[A-Z]{2}$").unwrap();
if uk_regex.is_match(&cleaned.to_uppercase()) { if uk_regex.is_match(&cleaned.to_uppercase()) {
return ValidationResult::valid_with_metadata( return ValidationResult::valid_with_metadata(
@ -793,7 +697,6 @@ fn validate_zipcode(input: &str) -> ValidationResult {
fn validate_language(input: &str) -> ValidationResult { fn validate_language(input: &str) -> ValidationResult {
let lower = input.to_lowercase().trim().to_string(); let lower = input.to_lowercase().trim().to_string();
let languages = [ let languages = [
("en", "english", "inglês", "ingles"), ("en", "english", "inglês", "ingles"),
("pt", "portuguese", "português", "portugues"), ("pt", "portuguese", "português", "portugues"),
@ -827,7 +730,6 @@ fn validate_language(input: &str) -> ValidationResult {
} }
} }
if lower.len() == 2 && lower.chars().all(|c| c.is_ascii_lowercase()) { if lower.len() == 2 && lower.chars().all(|c| c.is_ascii_lowercase()) {
return ValidationResult::valid(lower); return ValidationResult::valid(lower);
} }
@ -836,21 +738,17 @@ fn validate_language(input: &str) -> ValidationResult {
} }
fn validate_cpf(input: &str) -> ValidationResult { fn validate_cpf(input: &str) -> ValidationResult {
let digits: String = input.chars().filter(|c| c.is_ascii_digit()).collect(); let digits: String = input.chars().filter(|c| c.is_ascii_digit()).collect();
if digits.len() != 11 { if digits.len() != 11 {
return ValidationResult::invalid(InputType::Cpf.error_message()); return ValidationResult::invalid(InputType::Cpf.error_message());
} }
if digits.chars().all(|c| c == digits.chars().next().unwrap()) { if digits.chars().all(|c| c == digits.chars().next().unwrap()) {
return ValidationResult::invalid("Invalid CPF".to_string()); return ValidationResult::invalid("Invalid CPF".to_string());
} }
let digits_vec: Vec<u32> = digits.chars().filter_map(|c| c.to_digit(10)).collect();
let digits_vec: Vec<u32> = digits.chars().map(|c| c.to_digit(10).unwrap()).collect();
let sum1: u32 = digits_vec[0..9] let sum1: u32 = digits_vec[0..9]
.iter() .iter()
@ -864,7 +762,6 @@ fn validate_cpf(input: &str) -> ValidationResult {
return ValidationResult::invalid("Invalid CPF".to_string()); return ValidationResult::invalid("Invalid CPF".to_string());
} }
let sum2: u32 = digits_vec[0..10] let sum2: u32 = digits_vec[0..10]
.iter() .iter()
.enumerate() .enumerate()
@ -877,7 +774,6 @@ fn validate_cpf(input: &str) -> ValidationResult {
return ValidationResult::invalid("Invalid CPF".to_string()); return ValidationResult::invalid("Invalid CPF".to_string());
} }
let formatted = format!( let formatted = format!(
"{}.{}.{}-{}", "{}.{}.{}-{}",
&digits[0..3], &digits[0..3],
@ -893,16 +789,13 @@ fn validate_cpf(input: &str) -> ValidationResult {
} }
fn validate_cnpj(input: &str) -> ValidationResult { fn validate_cnpj(input: &str) -> ValidationResult {
let digits: String = input.chars().filter(|c| c.is_ascii_digit()).collect(); let digits: String = input.chars().filter(|c| c.is_ascii_digit()).collect();
if digits.len() != 14 { if digits.len() != 14 {
return ValidationResult::invalid(InputType::Cnpj.error_message()); return ValidationResult::invalid(InputType::Cnpj.error_message());
} }
let digits_vec: Vec<u32> = digits.chars().filter_map(|c| c.to_digit(10)).collect();
let digits_vec: Vec<u32> = digits.chars().map(|c| c.to_digit(10).unwrap()).collect();
let weights1 = [5, 4, 3, 2, 9, 8, 7, 6, 5, 4, 3, 2]; let weights1 = [5, 4, 3, 2, 9, 8, 7, 6, 5, 4, 3, 2];
let sum1: u32 = digits_vec[0..12] let sum1: u32 = digits_vec[0..12]
@ -917,7 +810,6 @@ fn validate_cnpj(input: &str) -> ValidationResult {
return ValidationResult::invalid("Invalid CNPJ".to_string()); return ValidationResult::invalid("Invalid CNPJ".to_string());
} }
let weights2 = [6, 5, 4, 3, 2, 9, 8, 7, 6, 5, 4, 3, 2]; let weights2 = [6, 5, 4, 3, 2, 9, 8, 7, 6, 5, 4, 3, 2];
let sum2: u32 = digits_vec[0..13] let sum2: u32 = digits_vec[0..13]
.iter() .iter()
@ -931,7 +823,6 @@ fn validate_cnpj(input: &str) -> ValidationResult {
return ValidationResult::invalid("Invalid CNPJ".to_string()); return ValidationResult::invalid("Invalid CNPJ".to_string());
} }
let formatted = format!( let formatted = format!(
"{}.{}.{}/{}-{}", "{}.{}.{}/{}-{}",
&digits[0..2], &digits[0..2],
@ -948,9 +839,8 @@ fn validate_cnpj(input: &str) -> ValidationResult {
} }
fn validate_url(input: &str) -> ValidationResult { fn validate_url(input: &str) -> ValidationResult {
let url_str = if !input.starts_with("http://") && !input.starts_with("https://") { let url_str = if !input.starts_with("http://") && !input.starts_with("https://") {
format!("https://{}", input) format!("https://{input}")
} else { } else {
input.to_string() input.to_string()
}; };
@ -976,7 +866,6 @@ fn validate_uuid(input: &str) -> ValidationResult {
fn validate_color(input: &str) -> ValidationResult { fn validate_color(input: &str) -> ValidationResult {
let lower = input.to_lowercase().trim().to_string(); let lower = input.to_lowercase().trim().to_string();
let named_colors = [ let named_colors = [
("red", "#FF0000"), ("red", "#FF0000"),
("green", "#00FF00"), ("green", "#00FF00"),
@ -1003,7 +892,6 @@ fn validate_color(input: &str) -> ValidationResult {
} }
} }
let hex_regex = Regex::new(r"^#?([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$").unwrap(); let hex_regex = Regex::new(r"^#?([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$").unwrap();
if let Some(caps) = hex_regex.captures(&lower) { if let Some(caps) = hex_regex.captures(&lower) {
let hex = caps[1].to_uppercase(); let hex = caps[1].to_uppercase();
@ -1017,7 +905,6 @@ fn validate_color(input: &str) -> ValidationResult {
return ValidationResult::valid(format!("#{}", full_hex)); return ValidationResult::valid(format!("#{}", full_hex));
} }
let rgb_regex = let rgb_regex =
Regex::new(r"^rgb\s*\(\s*(\d{1,3})\s*,\s*(\d{1,3})\s*,\s*(\d{1,3})\s*\)$").unwrap(); Regex::new(r"^rgb\s*\(\s*(\d{1,3})\s*,\s*(\d{1,3})\s*,\s*(\d{1,3})\s*\)$").unwrap();
if let Some(caps) = rgb_regex.captures(&lower) { if let Some(caps) = rgb_regex.captures(&lower) {
@ -1031,14 +918,12 @@ fn validate_color(input: &str) -> ValidationResult {
} }
fn validate_credit_card(input: &str) -> ValidationResult { fn validate_credit_card(input: &str) -> ValidationResult {
let digits: String = input.chars().filter(|c| c.is_ascii_digit()).collect(); let digits: String = input.chars().filter(|c| c.is_ascii_digit()).collect();
if digits.len() < 13 || digits.len() > 19 { if digits.len() < 13 || digits.len() > 19 {
return ValidationResult::invalid(InputType::CreditCard.error_message()); return ValidationResult::invalid(InputType::CreditCard.error_message());
} }
let mut sum = 0; let mut sum = 0;
let mut double = false; let mut double = false;
@ -1058,7 +943,6 @@ fn validate_credit_card(input: &str) -> ValidationResult {
return ValidationResult::invalid("Invalid card number".to_string()); return ValidationResult::invalid("Invalid card number".to_string());
} }
let card_type = if digits.starts_with('4') { let card_type = if digits.starts_with('4') {
"Visa" "Visa"
} else if digits.starts_with("51") } else if digits.starts_with("51")
@ -1078,7 +962,6 @@ fn validate_credit_card(input: &str) -> ValidationResult {
"Unknown" "Unknown"
}; };
let masked = format!( let masked = format!(
"{} **** **** {}", "{} **** **** {}",
&digits[0..4], &digits[0..4],
@ -1113,7 +996,6 @@ fn validate_password(input: &str) -> ValidationResult {
_ => "weak", _ => "weak",
}; };
ValidationResult::valid_with_metadata( ValidationResult::valid_with_metadata(
"[PASSWORD SET]".to_string(), "[PASSWORD SET]".to_string(),
serde_json::json!({ serde_json::json!({
@ -1126,7 +1008,6 @@ fn validate_password(input: &str) -> ValidationResult {
fn validate_menu(input: &str, options: &[String]) -> ValidationResult { fn validate_menu(input: &str, options: &[String]) -> ValidationResult {
let lower_input = input.to_lowercase().trim().to_string(); let lower_input = input.to_lowercase().trim().to_string();
for (i, opt) in options.iter().enumerate() { for (i, opt) in options.iter().enumerate() {
if opt.to_lowercase() == lower_input { if opt.to_lowercase() == lower_input {
return ValidationResult::valid_with_metadata( return ValidationResult::valid_with_metadata(
@ -1136,7 +1017,6 @@ fn validate_menu(input: &str, options: &[String]) -> ValidationResult {
} }
} }
if let Ok(num) = lower_input.parse::<usize>() { if let Ok(num) = lower_input.parse::<usize>() {
if num >= 1 && num <= options.len() { if num >= 1 && num <= options.len() {
let selected = &options[num - 1]; let selected = &options[num - 1];
@ -1147,7 +1027,6 @@ fn validate_menu(input: &str, options: &[String]) -> ValidationResult {
} }
} }
let matches: Vec<&String> = options let matches: Vec<&String> = options
.iter() .iter()
.filter(|opt| opt.to_lowercase().contains(&lower_input)) .filter(|opt| opt.to_lowercase().contains(&lower_input))
@ -1161,11 +1040,10 @@ fn validate_menu(input: &str, options: &[String]) -> ValidationResult {
); );
} }
ValidationResult::invalid(format!("Please select one of: {}", options.join(", "))) let opts = options.join(", ");
ValidationResult::invalid(format!("Please select one of: {opts}"))
} }
pub async fn execute_talk( pub async fn execute_talk(
state: Arc<AppState>, state: Arc<AppState>,
user_session: UserSession, user_session: UserSession,
@ -1173,7 +1051,6 @@ pub async fn execute_talk(
) -> Result<BotResponse, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<BotResponse, Box<dyn std::error::Error + Send + Sync>> {
let mut suggestions = Vec::new(); let mut suggestions = Vec::new();
if let Some(redis_client) = &state.cache { if let Some(redis_client) = &state.cache {
if let Ok(mut conn) = redis_client.get_multiplexed_async_connection().await { if let Ok(mut conn) = redis_client.get_multiplexed_async_connection().await {
let redis_key = format!("suggestions:{}:{}", user_session.user_id, user_session.id); let redis_key = format!("suggestions:{}:{}", user_session.user_id, user_session.id);
@ -1212,7 +1089,6 @@ pub async fn execute_talk(
let user_id = user_session.id.to_string(); let user_id = user_session.id.to_string();
let response_clone = response.clone(); let response_clone = response.clone();
let web_adapter = Arc::clone(&state.web_adapter); let web_adapter = Arc::clone(&state.web_adapter);
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = web_adapter if let Err(e) = web_adapter
@ -1249,9 +1125,6 @@ pub fn talk_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine
.unwrap(); .unwrap();
} }
pub async fn process_hear_input( pub async fn process_hear_input(
state: &AppState, state: &AppState,
session_id: Uuid, session_id: Uuid,
@ -1259,10 +1132,9 @@ pub async fn process_hear_input(
input: &str, input: &str,
attachments: Option<Vec<crate::shared::models::Attachment>>, attachments: Option<Vec<crate::shared::models::Attachment>>,
) -> Result<(String, Option<serde_json::Value>), String> { ) -> Result<(String, Option<serde_json::Value>), String> {
let wait_data = if let Some(redis_client) = &state.cache { let wait_data = if let Some(redis_client) = &state.cache {
if let Ok(mut conn) = redis_client.get_multiplexed_async_connection().await { if let Ok(mut conn) = redis_client.get_multiplexed_async_connection().await {
let key = format!("hear:{}:{}", session_id, variable_name); let key = format!("hear:{session_id}:{variable_name}");
let data: Result<String, _> = redis::cmd("GET").arg(&key).query_async(&mut conn).await; let data: Result<String, _> = redis::cmd("GET").arg(&key).query_async(&mut conn).await;
@ -1293,14 +1165,12 @@ pub async fn process_hear_input(
.collect::<Vec<_>>() .collect::<Vec<_>>()
}); });
let validation_type = if let Some(opts) = options { let validation_type = if let Some(opts) = options {
InputType::Menu(opts) InputType::Menu(opts)
} else { } else {
InputType::from_str(input_type) InputType::from_str(input_type)
}; };
match validation_type { match validation_type {
InputType::Image | InputType::QrCode => { InputType::Image | InputType::QrCode => {
if let Some(atts) = &attachments { if let Some(atts) = &attachments {
@ -1309,7 +1179,6 @@ pub async fn process_hear_input(
.find(|a| a.mime_type.as_deref().unwrap_or("").starts_with("image/")) .find(|a| a.mime_type.as_deref().unwrap_or("").starts_with("image/"))
{ {
if validation_type == InputType::QrCode { if validation_type == InputType::QrCode {
return process_qrcode(state, &img.url).await; return process_qrcode(state, &img.url).await;
} }
return Ok(( return Ok((
@ -1326,7 +1195,6 @@ pub async fn process_hear_input(
.iter() .iter()
.find(|a| a.mime_type.as_deref().unwrap_or("").starts_with("audio/")) .find(|a| a.mime_type.as_deref().unwrap_or("").starts_with("audio/"))
{ {
return process_audio_to_text(state, &audio.url).await; return process_audio_to_text(state, &audio.url).await;
} }
} }
@ -1338,7 +1206,6 @@ pub async fn process_hear_input(
.iter() .iter()
.find(|a| a.mime_type.as_deref().unwrap_or("").starts_with("video/")) .find(|a| a.mime_type.as_deref().unwrap_or("").starts_with("video/"))
{ {
return process_video_description(state, &video.url).await; return process_video_description(state, &video.url).await;
} }
} }
@ -1358,14 +1225,12 @@ pub async fn process_hear_input(
_ => {} _ => {}
} }
let result = validate_input(input, &validation_type); let result = validate_input(input, &validation_type);
if result.is_valid { if result.is_valid {
if let Some(redis_client) = &state.cache { if let Some(redis_client) = &state.cache {
if let Ok(mut conn) = redis_client.get_multiplexed_async_connection().await { if let Ok(mut conn) = redis_client.get_multiplexed_async_connection().await {
let key = format!("hear:{}:{}", session_id, variable_name); let key = format!("hear:{session_id}:{variable_name}");
let _: Result<(), _> = redis::cmd("DEL").arg(&key).query_async(&mut conn).await; let _: Result<(), _> = redis::cmd("DEL").arg(&key).query_async(&mut conn).await;
} }
} }
@ -1378,12 +1243,10 @@ pub async fn process_hear_input(
} }
} }
async fn process_qrcode( async fn process_qrcode(
state: &AppState, state: &AppState,
image_url: &str, image_url: &str,
) -> Result<(String, Option<serde_json::Value>), String> { ) -> Result<(String, Option<serde_json::Value>), String> {
let botmodels_url = { let botmodels_url = {
let config_url = state.conn.get().ok().and_then(|mut conn| { let config_url = state.conn.get().ok().and_then(|mut conn| {
use crate::shared::models::schema::bot_memories::dsl::*; use crate::shared::models::schema::bot_memories::dsl::*;
@ -1401,7 +1264,6 @@ async fn process_qrcode(
let client = reqwest::Client::new(); let client = reqwest::Client::new();
let image_data = client let image_data = client
.get(image_url) .get(image_url)
.send() .send()
@ -1409,11 +1271,10 @@ async fn process_qrcode(
.map_err(|e| format!("Failed to download image: {}", e))? .map_err(|e| format!("Failed to download image: {}", e))?
.bytes() .bytes()
.await .await
.map_err(|e| format!("Failed to read image: {}", e))?; .map_err(|e| format!("Failed to fetch image: {e}"))?;
let response = client let response = client
.post(format!("{}/api/v1/vision/qrcode", botmodels_url)) .post(format!("{botmodels_url}/api/v1/vision/qrcode"))
.header("Content-Type", "application/octet-stream") .header("Content-Type", "application/octet-stream")
.body(image_data.to_vec()) .body(image_data.to_vec())
.send() .send()
@ -1424,7 +1285,7 @@ async fn process_qrcode(
let result: serde_json::Value = response let result: serde_json::Value = response
.json() .json()
.await .await
.map_err(|e| format!("Failed to parse response: {}", e))?; .map_err(|e| format!("Failed to read image: {e}"))?;
if let Some(qr_data) = result.get("data").and_then(|d| d.as_str()) { if let Some(qr_data) = result.get("data").and_then(|d| d.as_str()) {
Ok(( Ok((
@ -1442,7 +1303,6 @@ async fn process_qrcode(
} }
} }
async fn process_audio_to_text( async fn process_audio_to_text(
_state: &AppState, _state: &AppState,
audio_url: &str, audio_url: &str,
@ -1452,7 +1312,6 @@ async fn process_audio_to_text(
let client = reqwest::Client::new(); let client = reqwest::Client::new();
let audio_data = client let audio_data = client
.get(audio_url) .get(audio_url)
.send() .send()
@ -1460,11 +1319,10 @@ async fn process_audio_to_text(
.map_err(|e| format!("Failed to download audio: {}", e))? .map_err(|e| format!("Failed to download audio: {}", e))?
.bytes() .bytes()
.await .await
.map_err(|e| format!("Failed to read audio: {}", e))?; .map_err(|e| format!("Failed to read audio: {e}"))?;
let response = client let response = client
.post(format!("{}/api/v1/speech/to-text", botmodels_url)) .post(format!("{botmodels_url}/api/v1/speech/to-text"))
.header("Content-Type", "application/octet-stream") .header("Content-Type", "application/octet-stream")
.body(audio_data.to_vec()) .body(audio_data.to_vec())
.send() .send()
@ -1494,7 +1352,6 @@ async fn process_audio_to_text(
} }
} }
async fn process_video_description( async fn process_video_description(
_state: &AppState, _state: &AppState,
video_url: &str, video_url: &str,
@ -1504,7 +1361,6 @@ async fn process_video_description(
let client = reqwest::Client::new(); let client = reqwest::Client::new();
let video_data = client let video_data = client
.get(video_url) .get(video_url)
.send() .send()
@ -1512,16 +1368,15 @@ async fn process_video_description(
.map_err(|e| format!("Failed to download video: {}", e))? .map_err(|e| format!("Failed to download video: {}", e))?
.bytes() .bytes()
.await .await
.map_err(|e| format!("Failed to read video: {}", e))?; .map_err(|e| format!("Failed to fetch video: {e}"))?;
let response = client let response = client
.post(format!("{}/api/v1/vision/describe-video", botmodels_url)) .post(format!("{botmodels_url}/api/v1/vision/describe-video"))
.header("Content-Type", "application/octet-stream") .header("Content-Type", "application/octet-stream")
.body(video_data.to_vec()) .body(video_data.to_vec())
.send() .send()
.await .await
.map_err(|e| format!("Failed to call botmodels: {}", e))?; .map_err(|e| format!("Failed to read video: {e}"))?;
if response.status().is_success() { if response.status().is_success() {
let result: serde_json::Value = response let result: serde_json::Value = response

View file

@ -1,57 +1,3 @@
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use rhai::{Array, Dynamic, Engine, Map}; use rhai::{Array, Dynamic, Engine, Map};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -59,10 +5,8 @@ use std::collections::HashMap;
use tracing::info; use tracing::info;
use uuid::Uuid; use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApprovalRequest { pub struct ApprovalRequest {
pub id: Uuid, pub id: Uuid,
pub bot_id: Uuid, pub bot_id: Uuid,
@ -106,11 +50,9 @@ pub struct ApprovalRequest {
pub comments: Option<String>, pub comments: Option<String>,
} }
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum ApprovalStatus { pub enum ApprovalStatus {
Pending, Pending,
Approved, Approved,
@ -132,7 +74,6 @@ impl Default for ApprovalStatus {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum ApprovalDecision { pub enum ApprovalDecision {
@ -143,7 +84,6 @@ pub enum ApprovalDecision {
RequestInfo, RequestInfo,
} }
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum ApprovalChannel { pub enum ApprovalChannel {
@ -176,10 +116,8 @@ impl std::fmt::Display for ApprovalChannel {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApprovalChain { pub struct ApprovalChain {
pub name: String, pub name: String,
pub bot_id: Uuid, pub bot_id: Uuid,
@ -195,10 +133,8 @@ pub struct ApprovalChain {
pub created_at: DateTime<Utc>, pub created_at: DateTime<Utc>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApprovalLevel { pub struct ApprovalLevel {
pub level: u32, pub level: u32,
pub channel: ApprovalChannel, pub channel: ApprovalChannel,
@ -216,10 +152,8 @@ pub struct ApprovalLevel {
pub required_approvals: u32, pub required_approvals: u32,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApprovalAuditEntry { pub struct ApprovalAuditEntry {
pub id: Uuid, pub id: Uuid,
pub request_id: Uuid, pub request_id: Uuid,
@ -237,7 +171,6 @@ pub struct ApprovalAuditEntry {
pub user_agent: Option<String>, pub user_agent: Option<String>,
} }
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum AuditAction { pub enum AuditAction {
@ -253,10 +186,8 @@ pub enum AuditAction {
ContextUpdated, ContextUpdated,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ApprovalConfig { pub struct ApprovalConfig {
pub enabled: bool, pub enabled: bool,
pub default_timeout: u64, pub default_timeout: u64,
@ -289,19 +220,16 @@ impl Default for ApprovalConfig {
} }
} }
#[derive(Debug)] #[derive(Debug)]
pub struct ApprovalManager { pub struct ApprovalManager {
config: ApprovalConfig, config: ApprovalConfig,
} }
impl ApprovalManager { impl ApprovalManager {
pub fn new(config: ApprovalConfig) -> Self { pub fn new(config: ApprovalConfig) -> Self {
ApprovalManager { config } ApprovalManager { config }
} }
pub fn from_config(config_map: &HashMap<String, String>) -> Self { pub fn from_config(config_map: &HashMap<String, String>) -> Self {
let config = ApprovalConfig { let config = ApprovalConfig {
enabled: config_map enabled: config_map
@ -331,7 +259,6 @@ impl ApprovalManager {
ApprovalManager::new(config) ApprovalManager::new(config)
} }
pub fn create_request( pub fn create_request(
&self, &self,
bot_id: Uuid, bot_id: Uuid,
@ -373,12 +300,10 @@ impl ApprovalManager {
} }
} }
pub fn is_expired(&self, request: &ApprovalRequest) -> bool { pub fn is_expired(&self, request: &ApprovalRequest) -> bool {
Utc::now() > request.expires_at Utc::now() > request.expires_at
} }
pub fn should_send_reminder(&self, request: &ApprovalRequest) -> bool { pub fn should_send_reminder(&self, request: &ApprovalRequest) -> bool {
if request.status != ApprovalStatus::Pending { if request.status != ApprovalStatus::Pending {
return false; return false;
@ -398,7 +323,6 @@ impl ApprovalManager {
since_last.num_seconds() >= self.config.reminder_interval as i64 since_last.num_seconds() >= self.config.reminder_interval as i64
} }
pub fn generate_approval_url(&self, request_id: Uuid, action: &str, token: &str) -> String { pub fn generate_approval_url(&self, request_id: Uuid, action: &str, token: &str) -> String {
let base_url = self let base_url = self
.config .config
@ -412,7 +336,6 @@ impl ApprovalManager {
) )
} }
pub fn generate_email_content(&self, request: &ApprovalRequest, token: &str) -> EmailContent { pub fn generate_email_content(&self, request: &ApprovalRequest, token: &str) -> EmailContent {
let approve_url = self.generate_approval_url(request.id, "approve", token); let approve_url = self.generate_approval_url(request.id, "approve", token);
let reject_url = self.generate_approval_url(request.id, "reject", token); let reject_url = self.generate_approval_url(request.id, "reject", token);
@ -423,7 +346,7 @@ impl ApprovalManager {
); );
let body = format!( let body = format!(
r#" r"
An approval is requested for: An approval is requested for:
Type: {} Type: {}
@ -438,7 +361,7 @@ To approve, click: {}
To reject, click: {} To reject, click: {}
If you have questions, reply to this email. If you have questions, reply to this email.
"#, ",
request.approval_type, request.approval_type,
request.message, request.message,
serde_json::to_string_pretty(&request.context).unwrap_or_default(), serde_json::to_string_pretty(&request.context).unwrap_or_default(),
@ -454,7 +377,6 @@ If you have questions, reply to this email.
} }
} }
pub fn process_decision( pub fn process_decision(
&self, &self,
request: &mut ApprovalRequest, request: &mut ApprovalRequest,
@ -475,7 +397,6 @@ If you have questions, reply to this email.
}; };
} }
pub fn handle_timeout(&self, request: &mut ApprovalRequest) { pub fn handle_timeout(&self, request: &mut ApprovalRequest) {
if let Some(default_action) = &request.default_action { if let Some(default_action) = &request.default_action {
request.decision = Some(default_action.clone()); request.decision = Some(default_action.clone());
@ -491,14 +412,11 @@ If you have questions, reply to this email.
} }
} }
pub fn evaluate_condition( pub fn evaluate_condition(
&self, &self,
condition: &str, condition: &str,
context: &serde_json::Value, context: &serde_json::Value,
) -> Result<bool, String> { ) -> Result<bool, String> {
let parts: Vec<&str> = condition.split_whitespace().collect(); let parts: Vec<&str> = condition.split_whitespace().collect();
if parts.len() != 3 { if parts.len() != 3 {
return Err(format!("Invalid condition format: {}", condition)); return Err(format!("Invalid condition format: {}", condition));
@ -531,7 +449,6 @@ If you have questions, reply to this email.
} }
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct EmailContent { pub struct EmailContent {
pub subject: String, pub subject: String,
@ -539,7 +456,6 @@ pub struct EmailContent {
pub html_body: Option<String>, pub html_body: Option<String>,
} }
impl ApprovalRequest { impl ApprovalRequest {
pub fn to_dynamic(&self) -> Dynamic { pub fn to_dynamic(&self) -> Dynamic {
let mut map = Map::new(); let mut map = Map::new();
@ -589,7 +505,6 @@ impl ApprovalRequest {
} }
} }
fn json_to_dynamic(value: &serde_json::Value) -> Dynamic { fn json_to_dynamic(value: &serde_json::Value) -> Dynamic {
match value { match value {
serde_json::Value::Null => Dynamic::UNIT, serde_json::Value::Null => Dynamic::UNIT,
@ -618,10 +533,7 @@ fn json_to_dynamic(value: &serde_json::Value) -> Dynamic {
} }
} }
pub fn register_approval_keywords(engine: &mut Engine) { pub fn register_approval_keywords(engine: &mut Engine) {
engine.register_fn("approval_is_approved", |request: Map| -> bool { engine.register_fn("approval_is_approved", |request: Map| -> bool {
request request
.get("status") .get("status")
@ -678,7 +590,6 @@ pub fn register_approval_keywords(engine: &mut Engine) {
info!("Approval keywords registered"); info!("Approval keywords registered");
} }
pub const APPROVAL_SCHEMA: &str = r#" pub const APPROVAL_SCHEMA: &str = r#"
-- Approval requests -- Approval requests
CREATE TABLE IF NOT EXISTS approval_requests ( CREATE TABLE IF NOT EXISTS approval_requests (
@ -758,9 +669,8 @@ CREATE INDEX IF NOT EXISTS idx_approval_tokens_token ON approval_tokens(token);
CREATE INDEX IF NOT EXISTS idx_approval_tokens_request_id ON approval_tokens(request_id); CREATE INDEX IF NOT EXISTS idx_approval_tokens_request_id ON approval_tokens(request_id);
"#; "#;
pub mod sql { pub mod sql {
pub const INSERT_REQUEST: &str = r#" pub const INSERT_REQUEST: &str = r"
INSERT INTO approval_requests ( INSERT INTO approval_requests (
id, bot_id, session_id, initiated_by, approval_type, status, id, bot_id, session_id, initiated_by, approval_type, status,
channel, recipient, context, message, timeout_seconds, channel, recipient, context, message, timeout_seconds,
@ -769,9 +679,9 @@ pub mod sql {
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17 $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17
) )
"#; ";
pub const UPDATE_REQUEST: &str = r#" pub const UPDATE_REQUEST: &str = r"
UPDATE approval_requests UPDATE approval_requests
SET status = $2, SET status = $2,
decision = $3, decision = $3,
@ -779,71 +689,71 @@ pub mod sql {
decided_at = $5, decided_at = $5,
comments = $6 comments = $6
WHERE id = $1 WHERE id = $1
"#; ";
pub const GET_REQUEST: &str = r#" pub const GET_REQUEST: &str = r"
SELECT * FROM approval_requests WHERE id = $1 SELECT * FROM approval_requests WHERE id = $1
"#; ";
pub const GET_PENDING_REQUESTS: &str = r#" pub const GET_PENDING_REQUESTS: &str = r"
SELECT * FROM approval_requests SELECT * FROM approval_requests
WHERE status = 'pending' WHERE status = 'pending'
AND expires_at > NOW() AND expires_at > NOW()
ORDER BY created_at ASC ORDER BY created_at ASC
"#; ";
pub const GET_EXPIRED_REQUESTS: &str = r#" pub const GET_EXPIRED_REQUESTS: &str = r"
SELECT * FROM approval_requests SELECT * FROM approval_requests
WHERE status = 'pending' WHERE status = 'pending'
AND expires_at <= NOW() AND expires_at <= NOW()
"#; ";
pub const GET_REQUESTS_BY_SESSION: &str = r#" pub const GET_REQUESTS_BY_SESSION: &str = r"
SELECT * FROM approval_requests SELECT * FROM approval_requests
WHERE session_id = $1 WHERE session_id = $1
ORDER BY created_at DESC ORDER BY created_at DESC
"#; ";
pub const UPDATE_REMINDERS: &str = r#" pub const UPDATE_REMINDERS: &str = r"
UPDATE approval_requests UPDATE approval_requests
SET reminders_sent = reminders_sent || $2::jsonb SET reminders_sent = reminders_sent || $2::jsonb
WHERE id = $1 WHERE id = $1
"#; ";
pub const INSERT_AUDIT: &str = r#" pub const INSERT_AUDIT: &str = r"
INSERT INTO approval_audit_log ( INSERT INTO approval_audit_log (
id, request_id, action, actor, details, timestamp, ip_address, user_agent id, request_id, action, actor, details, timestamp, ip_address, user_agent
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8 $1, $2, $3, $4, $5, $6, $7, $8
) )
"#; ";
pub const GET_AUDIT_LOG: &str = r#" pub const GET_AUDIT_LOG: &str = r"
SELECT * FROM approval_audit_log SELECT * FROM approval_audit_log
WHERE request_id = $1 WHERE request_id = $1
ORDER BY timestamp ASC ORDER BY timestamp ASC
"#; ";
pub const INSERT_TOKEN: &str = r#" pub const INSERT_TOKEN: &str = r"
INSERT INTO approval_tokens ( INSERT INTO approval_tokens (
id, request_id, token, action, expires_at, created_at id, request_id, token, action, expires_at, created_at
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $6 $1, $2, $3, $4, $5, $6
) )
"#; ";
pub const GET_TOKEN: &str = r#" pub const GET_TOKEN: &str = r"
SELECT * FROM approval_tokens SELECT * FROM approval_tokens
WHERE token = $1 AND used = false AND expires_at > NOW() WHERE token = $1 AND used = false AND expires_at > NOW()
"#; ";
pub const USE_TOKEN: &str = r#" pub const USE_TOKEN: &str = r"
UPDATE approval_tokens UPDATE approval_tokens
SET used = true, used_at = NOW() SET used = true, used_at = NOW()
WHERE token = $1 WHERE token = $1
"#; ";
pub const INSERT_CHAIN: &str = r#" pub const INSERT_CHAIN: &str = r"
INSERT INTO approval_chains ( INSERT INTO approval_chains (
id, name, bot_id, levels, stop_on_reject, require_all, description, created_at id, name, bot_id, levels, stop_on_reject, require_all, description, created_at
) VALUES ( ) VALUES (
@ -855,10 +765,10 @@ pub mod sql {
stop_on_reject = $5, stop_on_reject = $5,
require_all = $6, require_all = $6,
description = $7 description = $7
"#; ";
pub const GET_CHAIN: &str = r#" pub const GET_CHAIN: &str = r"
SELECT * FROM approval_chains SELECT * FROM approval_chains
WHERE bot_id = $1 AND name = $2 WHERE bot_id = $1 AND name = $2
"#; ";
} }

View file

@ -1,50 +1,3 @@
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use rhai::{Array, Dynamic, Engine, Map}; use rhai::{Array, Dynamic, Engine, Map};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -52,10 +5,8 @@ use std::collections::HashMap;
use tracing::info; use tracing::info;
use uuid::Uuid; use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KgEntity { pub struct KgEntity {
pub id: Uuid, pub id: Uuid,
pub bot_id: Uuid, pub bot_id: Uuid,
@ -77,7 +28,6 @@ pub struct KgEntity {
pub updated_at: DateTime<Utc>, pub updated_at: DateTime<Utc>,
} }
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum EntitySource { pub enum EntitySource {
@ -93,10 +43,8 @@ impl Default for EntitySource {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KgRelationship { pub struct KgRelationship {
pub id: Uuid, pub id: Uuid,
pub bot_id: Uuid, pub bot_id: Uuid,
@ -118,10 +66,8 @@ pub struct KgRelationship {
pub created_at: DateTime<Utc>, pub created_at: DateTime<Utc>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedEntity { pub struct ExtractedEntity {
pub name: String, pub name: String,
pub canonical_name: String, pub canonical_name: String,
@ -137,10 +83,8 @@ pub struct ExtractedEntity {
pub properties: serde_json::Value, pub properties: serde_json::Value,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedRelationship { pub struct ExtractedRelationship {
pub from_entity: String, pub from_entity: String,
pub to_entity: String, pub to_entity: String,
@ -152,10 +96,8 @@ pub struct ExtractedRelationship {
pub evidence: String, pub evidence: String,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractionResult { pub struct ExtractionResult {
pub entities: Vec<ExtractedEntity>, pub entities: Vec<ExtractedEntity>,
pub relationships: Vec<ExtractedRelationship>, pub relationships: Vec<ExtractedRelationship>,
@ -163,10 +105,8 @@ pub struct ExtractionResult {
pub metadata: ExtractionMetadata, pub metadata: ExtractionMetadata,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractionMetadata { pub struct ExtractionMetadata {
pub model: String, pub model: String,
pub processing_time_ms: u64, pub processing_time_ms: u64,
@ -176,10 +116,8 @@ pub struct ExtractionMetadata {
pub text_length: usize, pub text_length: usize,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphQueryResult { pub struct GraphQueryResult {
pub entities: Vec<KgEntity>, pub entities: Vec<KgEntity>,
pub relationships: Vec<KgRelationship>, pub relationships: Vec<KgRelationship>,
@ -189,10 +127,8 @@ pub struct GraphQueryResult {
pub confidence: f64, pub confidence: f64,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct KnowledgeGraphConfig { pub struct KnowledgeGraphConfig {
pub enabled: bool, pub enabled: bool,
pub backend: String, pub backend: String,
@ -233,19 +169,16 @@ impl Default for KnowledgeGraphConfig {
} }
} }
#[derive(Debug)] #[derive(Debug)]
pub struct KnowledgeGraphManager { pub struct KnowledgeGraphManager {
config: KnowledgeGraphConfig, config: KnowledgeGraphConfig,
} }
impl KnowledgeGraphManager { impl KnowledgeGraphManager {
pub fn new(config: KnowledgeGraphConfig) -> Self { pub fn new(config: KnowledgeGraphConfig) -> Self {
KnowledgeGraphManager { config } KnowledgeGraphManager { config }
} }
pub fn from_config(config_map: &HashMap<String, String>) -> Self { pub fn from_config(config_map: &HashMap<String, String>) -> Self {
let config = KnowledgeGraphConfig { let config = KnowledgeGraphConfig {
enabled: config_map enabled: config_map
@ -284,7 +217,6 @@ impl KnowledgeGraphManager {
KnowledgeGraphManager::new(config) KnowledgeGraphManager::new(config)
} }
pub fn generate_extraction_prompt(&self, text: &str) -> String { pub fn generate_extraction_prompt(&self, text: &str) -> String {
let entity_types = self.config.entity_types.join(", "); let entity_types = self.config.entity_types.join(", ");
@ -320,7 +252,6 @@ Respond with valid JSON only:
) )
} }
pub fn generate_query_prompt(&self, query: &str, context: &str) -> String { pub fn generate_query_prompt(&self, query: &str, context: &str) -> String {
format!( format!(
r#"Answer this question using the knowledge graph context. r#"Answer this question using the knowledge graph context.
@ -336,7 +267,6 @@ If the information is not available, say so clearly.
) )
} }
pub fn parse_extraction_response( pub fn parse_extraction_response(
&self, &self,
response: &str, response: &str,
@ -401,12 +331,10 @@ If the information is not available, say so clearly.
}) })
} }
pub fn should_extract(&self) -> bool { pub fn should_extract(&self) -> bool {
self.config.enabled && self.config.extract_entities self.config.enabled && self.config.extract_entities
} }
pub fn is_valid_entity_type(&self, entity_type: &str) -> bool { pub fn is_valid_entity_type(&self, entity_type: &str) -> bool {
self.config self.config
.entity_types .entity_types
@ -415,16 +343,13 @@ If the information is not available, say so clearly.
} }
} }
fn extract_json(response: &str) -> Result<String, String> { fn extract_json(response: &str) -> Result<String, String> {
if let Some(start) = response.find("```json") { if let Some(start) = response.find("```json") {
if let Some(end) = response[start + 7..].find("```") { if let Some(end) = response[start + 7..].find("```") {
return Ok(response[start + 7..start + 7 + end].trim().to_string()); return Ok(response[start + 7..start + 7 + end].trim().to_string());
} }
} }
if let Some(start) = response.find("```") { if let Some(start) = response.find("```") {
let after_start = start + 3; let after_start = start + 3;
let json_start = response[after_start..] let json_start = response[after_start..]
@ -436,7 +361,6 @@ fn extract_json(response: &str) -> Result<String, String> {
} }
} }
if let Some(start) = response.find('{') { if let Some(start) = response.find('{') {
if let Some(end) = response.rfind('}') { if let Some(end) = response.rfind('}') {
if end > start { if end > start {
@ -448,7 +372,6 @@ fn extract_json(response: &str) -> Result<String, String> {
Err("No JSON found in response".to_string()) Err("No JSON found in response".to_string())
} }
impl KgEntity { impl KgEntity {
pub fn to_dynamic(&self) -> Dynamic { pub fn to_dynamic(&self) -> Dynamic {
let mut map = Map::new(); let mut map = Map::new();
@ -478,7 +401,6 @@ impl KgEntity {
} }
} }
impl KgRelationship { impl KgRelationship {
pub fn to_dynamic(&self) -> Dynamic { pub fn to_dynamic(&self) -> Dynamic {
let mut map = Map::new(); let mut map = Map::new();
@ -507,7 +429,6 @@ impl KgRelationship {
} }
} }
fn json_to_dynamic(value: &serde_json::Value) -> Dynamic { fn json_to_dynamic(value: &serde_json::Value) -> Dynamic {
match value { match value {
serde_json::Value::Null => Dynamic::UNIT, serde_json::Value::Null => Dynamic::UNIT,
@ -536,10 +457,7 @@ fn json_to_dynamic(value: &serde_json::Value) -> Dynamic {
} }
} }
pub fn register_knowledge_graph_keywords(engine: &mut Engine) { pub fn register_knowledge_graph_keywords(engine: &mut Engine) {
engine.register_fn("entity_name", |entity: Map| -> String { engine.register_fn("entity_name", |entity: Map| -> String {
entity entity
.get("entity_name") .get("entity_name")
@ -576,7 +494,6 @@ pub fn register_knowledge_graph_keywords(engine: &mut Engine) {
info!("Knowledge graph keywords registered"); info!("Knowledge graph keywords registered");
} }
pub const KNOWLEDGE_GRAPH_SCHEMA: &str = r#" pub const KNOWLEDGE_GRAPH_SCHEMA: &str = r#"
-- Knowledge graph entities -- Knowledge graph entities
CREATE TABLE IF NOT EXISTS kg_entities ( CREATE TABLE IF NOT EXISTS kg_entities (
@ -627,9 +544,8 @@ CREATE INDEX IF NOT EXISTS idx_kg_entities_name_fts ON kg_entities
USING GIN(to_tsvector('english', entity_name)); USING GIN(to_tsvector('english', entity_name));
"#; "#;
pub mod sql { pub mod sql {
pub const INSERT_ENTITY: &str = r#" pub const INSERT_ENTITY: &str = r"
INSERT INTO kg_entities ( INSERT INTO kg_entities (
id, bot_id, entity_type, entity_name, aliases, properties, id, bot_id, entity_type, entity_name, aliases, properties,
confidence, source, created_at, updated_at confidence, source, created_at, updated_at
@ -643,9 +559,9 @@ pub mod sql {
confidence = GREATEST(kg_entities.confidence, $7), confidence = GREATEST(kg_entities.confidence, $7),
updated_at = $10 updated_at = $10
RETURNING id RETURNING id
"#; ";
pub const INSERT_RELATIONSHIP: &str = r#" pub const INSERT_RELATIONSHIP: &str = r"
INSERT INTO kg_relationships ( INSERT INTO kg_relationships (
id, bot_id, from_entity_id, to_entity_id, relationship_type, id, bot_id, from_entity_id, to_entity_id, relationship_type,
properties, confidence, bidirectional, source, created_at properties, confidence, bidirectional, source, created_at
@ -657,9 +573,9 @@ pub mod sql {
properties = kg_relationships.properties || $6, properties = kg_relationships.properties || $6,
confidence = GREATEST(kg_relationships.confidence, $7) confidence = GREATEST(kg_relationships.confidence, $7)
RETURNING id RETURNING id
"#; ";
pub const GET_ENTITY_BY_NAME: &str = r#" pub const GET_ENTITY_BY_NAME: &str = r"
SELECT * FROM kg_entities SELECT * FROM kg_entities
WHERE bot_id = $1 WHERE bot_id = $1
AND ( AND (
@ -667,13 +583,13 @@ pub mod sql {
OR aliases @> $3::jsonb OR aliases @> $3::jsonb
) )
LIMIT 1 LIMIT 1
"#; ";
pub const GET_ENTITY_BY_ID: &str = r#" pub const GET_ENTITY_BY_ID: &str = r"
SELECT * FROM kg_entities WHERE id = $1 SELECT * FROM kg_entities WHERE id = $1
"#; ";
pub const SEARCH_ENTITIES: &str = r#" pub const SEARCH_ENTITIES: &str = r"
SELECT * FROM kg_entities SELECT * FROM kg_entities
WHERE bot_id = $1 WHERE bot_id = $1
AND ( AND (
@ -682,16 +598,16 @@ pub mod sql {
) )
ORDER BY confidence DESC ORDER BY confidence DESC
LIMIT $4 LIMIT $4
"#; ";
pub const GET_ENTITIES_BY_TYPE: &str = r#" pub const GET_ENTITIES_BY_TYPE: &str = r"
SELECT * FROM kg_entities SELECT * FROM kg_entities
WHERE bot_id = $1 AND entity_type = $2 WHERE bot_id = $1 AND entity_type = $2
ORDER BY entity_name ORDER BY entity_name
LIMIT $3 LIMIT $3
"#; ";
pub const GET_RELATED_ENTITIES: &str = r#" pub const GET_RELATED_ENTITIES: &str = r"
SELECT e.*, r.relationship_type, r.confidence as rel_confidence SELECT e.*, r.relationship_type, r.confidence as rel_confidence
FROM kg_entities e FROM kg_entities e
JOIN kg_relationships r ON ( JOIN kg_relationships r ON (
@ -701,9 +617,9 @@ pub mod sql {
WHERE r.bot_id = $2 WHERE r.bot_id = $2
ORDER BY r.confidence DESC ORDER BY r.confidence DESC
LIMIT $3 LIMIT $3
"#; ";
pub const GET_RELATED_BY_TYPE: &str = r#" pub const GET_RELATED_BY_TYPE: &str = r"
SELECT e.*, r.relationship_type, r.confidence as rel_confidence SELECT e.*, r.relationship_type, r.confidence as rel_confidence
FROM kg_entities e FROM kg_entities e
JOIN kg_relationships r ON ( JOIN kg_relationships r ON (
@ -713,17 +629,17 @@ pub mod sql {
WHERE r.bot_id = $2 AND r.relationship_type = $3 WHERE r.bot_id = $2 AND r.relationship_type = $3
ORDER BY r.confidence DESC ORDER BY r.confidence DESC
LIMIT $4 LIMIT $4
"#; ";
pub const GET_RELATIONSHIP: &str = r#" pub const GET_RELATIONSHIP: &str = r"
SELECT * FROM kg_relationships SELECT * FROM kg_relationships
WHERE bot_id = $1 WHERE bot_id = $1
AND from_entity_id = $2 AND from_entity_id = $2
AND to_entity_id = $3 AND to_entity_id = $3
AND relationship_type = $4 AND relationship_type = $4
"#; ";
pub const GET_ALL_RELATIONSHIPS_FOR_ENTITY: &str = r#" pub const GET_ALL_RELATIONSHIPS_FOR_ENTITY: &str = r"
SELECT r.*, SELECT r.*,
e1.entity_name as from_name, e1.entity_type as from_type, e1.entity_name as from_name, e1.entity_type as from_type,
e2.entity_name as to_name, e2.entity_type as to_type e2.entity_name as to_name, e2.entity_type as to_type
@ -733,42 +649,41 @@ pub mod sql {
WHERE r.bot_id = $1 WHERE r.bot_id = $1
AND (r.from_entity_id = $2 OR r.to_entity_id = $2) AND (r.from_entity_id = $2 OR r.to_entity_id = $2)
ORDER BY r.confidence DESC ORDER BY r.confidence DESC
"#; ";
pub const DELETE_ENTITY: &str = r#" pub const DELETE_ENTITY: &str = r"
DELETE FROM kg_entities WHERE id = $1 AND bot_id = $2 DELETE FROM kg_entities WHERE id = $1 AND bot_id = $2
"#; ";
pub const DELETE_RELATIONSHIP: &str = r#" pub const DELETE_RELATIONSHIP: &str = r"
DELETE FROM kg_relationships WHERE id = $1 AND bot_id = $2 DELETE FROM kg_relationships WHERE id = $1 AND bot_id = $2
"#; ";
pub const COUNT_ENTITIES: &str = r#" pub const COUNT_ENTITIES: &str = r"
SELECT COUNT(*) FROM kg_entities WHERE bot_id = $1 SELECT COUNT(*) FROM kg_entities WHERE bot_id = $1
"#; ";
pub const COUNT_RELATIONSHIPS: &str = r#" pub const COUNT_RELATIONSHIPS: &str = r"
SELECT COUNT(*) FROM kg_relationships WHERE bot_id = $1 SELECT COUNT(*) FROM kg_relationships WHERE bot_id = $1
"#; ";
pub const GET_ENTITY_TYPES: &str = r#" pub const GET_ENTITY_TYPES: &str = r"
SELECT DISTINCT entity_type, COUNT(*) as count SELECT DISTINCT entity_type, COUNT(*) as count
FROM kg_entities FROM kg_entities
WHERE bot_id = $1 WHERE bot_id = $1
GROUP BY entity_type GROUP BY entity_type
ORDER BY count DESC ORDER BY count DESC
"#; ";
pub const GET_RELATIONSHIP_TYPES: &str = r#" pub const GET_RELATIONSHIP_TYPES: &str = r"
SELECT DISTINCT relationship_type, COUNT(*) as count SELECT DISTINCT relationship_type, COUNT(*) as count
FROM kg_relationships FROM kg_relationships
WHERE bot_id = $1 WHERE bot_id = $1
GROUP BY relationship_type GROUP BY relationship_type
ORDER BY count DESC ORDER BY count DESC
"#; ";
pub const FIND_PATH: &str = r"
pub const FIND_PATH: &str = r#"
WITH RECURSIVE path_finder AS ( WITH RECURSIVE path_finder AS (
-- Base case: start from source entity -- Base case: start from source entity
SELECT SELECT
@ -799,10 +714,9 @@ pub mod sql {
WHERE to_entity_id = $3 WHERE to_entity_id = $3
ORDER BY depth ORDER BY depth
LIMIT 1 LIMIT 1
"#; ";
} }
pub mod relationship_types { pub mod relationship_types {
pub const WORKS_ON: &str = "works_on"; pub const WORKS_ON: &str = "works_on";
pub const REPORTS_TO: &str = "reports_to"; pub const REPORTS_TO: &str = "reports_to";
@ -820,7 +734,6 @@ pub mod relationship_types {
pub const ALIAS_OF: &str = "alias_of"; pub const ALIAS_OF: &str = "alias_of";
} }
pub mod entity_types { pub mod entity_types {
pub const PERSON: &str = "person"; pub const PERSON: &str = "person";
pub const ORGANIZATION: &str = "organization"; pub const ORGANIZATION: &str = "organization";

View file

@ -3,7 +3,7 @@ use rhai::Engine;
pub fn last_keyword(engine: &mut Engine) { pub fn last_keyword(engine: &mut Engine) {
engine engine
.register_custom_syntax(&["LAST", "(", "$expr$", ")"], false, { .register_custom_syntax(["LAST", "(", "$expr$", ")"], false, {
move |context, inputs| { move |context, inputs| {
let input_string = context.eval_expression_tree(&inputs[0])?; let input_string = context.eval_expression_tree(&inputs[0])?;
let input_str = input_string.to_string(); let input_str = input_string.to_string();

View file

@ -28,14 +28,6 @@
| | | |
\*****************************************************************************/ \*****************************************************************************/
use crate::core::config::ConfigManager; use crate::core::config::ConfigManager;
use crate::shared::models::UserSession; use crate::shared::models::UserSession;
use crate::shared::state::AppState; use crate::shared::state::AppState;
@ -45,7 +37,6 @@ use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use uuid::Uuid; use uuid::Uuid;
pub fn register_llm_macros(state: Arc<AppState>, user: UserSession, engine: &mut Engine) { pub fn register_llm_macros(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
register_calculate_keyword(state.clone(), user.clone(), engine); register_calculate_keyword(state.clone(), user.clone(), engine);
register_validate_keyword(state.clone(), user.clone(), engine); register_validate_keyword(state.clone(), user.clone(), engine);
@ -53,7 +44,6 @@ pub fn register_llm_macros(state: Arc<AppState>, user: UserSession, engine: &mut
register_summarize_keyword(state, user, engine); register_summarize_keyword(state, user, engine);
} }
async fn call_llm( async fn call_llm(
state: &AppState, state: &AppState,
prompt: &str, prompt: &str,
@ -75,7 +65,6 @@ async fn call_llm(
Ok(processed) Ok(processed)
} }
fn run_llm_with_timeout( fn run_llm_with_timeout(
state: Arc<AppState>, state: Arc<AppState>,
prompt: String, prompt: String,
@ -117,8 +106,6 @@ fn run_llm_with_timeout(
} }
} }
pub fn register_calculate_keyword(state: Arc<AppState>, _user: UserSession, engine: &mut Engine) { pub fn register_calculate_keyword(state: Arc<AppState>, _user: UserSession, engine: &mut Engine) {
let state_clone = Arc::clone(&state); let state_clone = Arc::clone(&state);
@ -158,7 +145,7 @@ fn build_calculate_prompt(formula: &str, variables: &Dynamic) -> String {
}; };
format!( format!(
r#"You are a precise calculator. Evaluate the following expression. r"You are a precise calculator. Evaluate the following expression.
Formula: {} Formula: {}
Variables: {} Variables: {}
@ -169,7 +156,7 @@ Instructions:
3. Return ONLY the final result (number, boolean, or text) 3. Return ONLY the final result (number, boolean, or text)
4. No explanations, just the result 4. No explanations, just the result
Result:"#, Result:",
formula, vars_str formula, vars_str
) )
} }
@ -198,8 +185,6 @@ fn parse_calculate_result(result: &str) -> Result<Dynamic, Box<rhai::EvalAltResu
Ok(Dynamic::from(trimmed.to_string())) Ok(Dynamic::from(trimmed.to_string()))
} }
pub fn register_validate_keyword(state: Arc<AppState>, _user: UserSession, engine: &mut Engine) { pub fn register_validate_keyword(state: Arc<AppState>, _user: UserSession, engine: &mut Engine) {
let state_clone = Arc::clone(&state); let state_clone = Arc::clone(&state);
@ -306,8 +291,6 @@ fn parse_validate_result(result: &str) -> Result<Dynamic, Box<rhai::EvalAltResul
Ok(Dynamic::from(map)) Ok(Dynamic::from(map))
} }
pub fn register_translate_keyword(state: Arc<AppState>, _user: UserSession, engine: &mut Engine) { pub fn register_translate_keyword(state: Arc<AppState>, _user: UserSession, engine: &mut Engine) {
let state_clone = Arc::clone(&state); let state_clone = Arc::clone(&state);
@ -336,20 +319,18 @@ pub fn register_translate_keyword(state: Arc<AppState>, _user: UserSession, engi
fn build_translate_prompt(text: &str, language: &str) -> String { fn build_translate_prompt(text: &str, language: &str) -> String {
format!( format!(
r#"Translate the following text to {}. r"Translate the following text to {}.
Text: Text:
{} {}
Return ONLY the translated text, no explanations: Return ONLY the translated text, no explanations:
Translation:"#, Translation:",
language, text language, text
) )
} }
pub fn register_summarize_keyword(state: Arc<AppState>, _user: UserSession, engine: &mut Engine) { pub fn register_summarize_keyword(state: Arc<AppState>, _user: UserSession, engine: &mut Engine) {
let state_clone = Arc::clone(&state); let state_clone = Arc::clone(&state);
@ -369,14 +350,14 @@ pub fn register_summarize_keyword(state: Arc<AppState>, _user: UserSession, engi
fn build_summarize_prompt(text: &str) -> String { fn build_summarize_prompt(text: &str) -> String {
format!( format!(
r#"Summarize the following text concisely. r"Summarize the following text concisely.
Text: Text:
{} {}
Return ONLY the summary: Return ONLY the summary:
Summary:"#, Summary:",
text text
) )
} }

View file

@ -126,7 +126,7 @@ pub fn parse_folder_path(path: &str) -> (FolderProvider, Option<String>, String)
(FolderProvider::Local, None, path.to_string()) (FolderProvider::Local, None, path.to_string())
} }
fn detect_provider_from_email(email: &str) -> FolderProvider { pub fn detect_provider_from_email(email: &str) -> FolderProvider {
let lower = email.to_lowercase(); let lower = email.to_lowercase();
if lower.ends_with("@gmail.com") || lower.contains("google") { if lower.ends_with("@gmail.com") || lower.contains("google") {
FolderProvider::GDrive FolderProvider::GDrive
@ -285,7 +285,7 @@ fn register_on_change_with_events(state: &AppState, user: UserSession, engine: &
.unwrap(); .unwrap();
} }
fn sanitize_path_for_filename(path: &str) -> String { pub fn sanitize_path_for_filename(path: &str) -> String {
path.replace('/', "_") path.replace('/', "_")
.replace('\\', "_") .replace('\\', "_")
.replace(':', "_") .replace(':', "_")

View file

@ -1,16 +1,3 @@
use crate::shared::models::UserSession; use crate::shared::models::UserSession;
use crate::shared::state::AppState; use crate::shared::state::AppState;
use log::{info, trace}; use log::{info, trace};
@ -21,7 +8,6 @@ use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use uuid::Uuid; use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ContentType { pub enum ContentType {
Video, Video,
@ -39,10 +25,8 @@ pub enum ContentType {
} }
impl ContentType { impl ContentType {
pub fn from_extension(ext: &str) -> Self { pub fn from_extension(ext: &str) -> Self {
match ext.to_lowercase().as_str() { match ext.to_lowercase().as_str() {
"mp4" | "webm" | "ogg" | "mov" | "avi" | "mkv" | "m4v" => ContentType::Video, "mp4" | "webm" | "ogg" | "mov" | "avi" | "mkv" | "m4v" => ContentType::Video,
"mp3" | "wav" | "flac" | "aac" | "m4a" | "wma" => ContentType::Audio, "mp3" | "wav" | "flac" | "aac" | "m4a" | "wma" => ContentType::Audio,
@ -68,7 +52,6 @@ impl ContentType {
} }
} }
pub fn from_mime(mime: &str) -> Self { pub fn from_mime(mime: &str) -> Self {
if mime.starts_with("video/") { if mime.starts_with("video/") {
ContentType::Video ContentType::Video
@ -97,7 +80,6 @@ impl ContentType {
} }
} }
pub fn player_component(&self) -> &'static str { pub fn player_component(&self) -> &'static str {
match self { match self {
ContentType::Video => "video-player", ContentType::Video => "video-player",
@ -116,7 +98,6 @@ impl ContentType {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize, Default)] #[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PlayOptions { pub struct PlayOptions {
pub autoplay: bool, pub autoplay: bool,
@ -137,7 +118,6 @@ pub struct PlayOptions {
} }
impl PlayOptions { impl PlayOptions {
pub fn from_string(options_str: &str) -> Self { pub fn from_string(options_str: &str) -> Self {
let mut opts = PlayOptions::default(); let mut opts = PlayOptions::default();
opts.controls = true; opts.controls = true;
@ -151,7 +131,6 @@ impl PlayOptions {
"nocontrols" => opts.controls = false, "nocontrols" => opts.controls = false,
"linenumbers" => opts.line_numbers = Some(true), "linenumbers" => opts.line_numbers = Some(true),
_ => { _ => {
if let Some((key, value)) = opt.split_once('=') { if let Some((key, value)) = opt.split_once('=') {
match key { match key {
"start" => opts.start_time = value.parse().ok(), "start" => opts.start_time = value.parse().ok(),
@ -177,7 +156,6 @@ impl PlayOptions {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlayContent { pub struct PlayContent {
pub id: Uuid, pub id: Uuid,
@ -189,7 +167,6 @@ pub struct PlayContent {
pub created_at: chrono::DateTime<chrono::Utc>, pub created_at: chrono::DateTime<chrono::Utc>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlayResponse { pub struct PlayResponse {
pub player_id: Uuid, pub player_id: Uuid,
@ -201,7 +178,6 @@ pub struct PlayResponse {
pub metadata: HashMap<String, String>, pub metadata: HashMap<String, String>,
} }
pub fn play_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) { pub fn play_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
if let Err(e) = play_simple_keyword(state.clone(), user.clone(), engine) { if let Err(e) = play_simple_keyword(state.clone(), user.clone(), engine) {
log::error!("Failed to register PLAY keyword: {}", e); log::error!("Failed to register PLAY keyword: {}", e);
@ -220,7 +196,6 @@ pub fn play_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine
} }
} }
fn play_simple_keyword( fn play_simple_keyword(
state: Arc<AppState>, state: Arc<AppState>,
user: UserSession, user: UserSession,
@ -269,7 +244,6 @@ fn play_simple_keyword(
Ok(()) Ok(())
} }
fn play_with_options_keyword( fn play_with_options_keyword(
state: Arc<AppState>, state: Arc<AppState>,
user: UserSession, user: UserSession,
@ -334,7 +308,6 @@ fn play_with_options_keyword(
Ok(()) Ok(())
} }
fn stop_keyword( fn stop_keyword(
state: Arc<AppState>, state: Arc<AppState>,
user: UserSession, user: UserSession,
@ -373,7 +346,6 @@ fn stop_keyword(
Ok(()) Ok(())
} }
fn pause_keyword( fn pause_keyword(
state: Arc<AppState>, state: Arc<AppState>,
user: UserSession, user: UserSession,
@ -413,7 +385,6 @@ fn pause_keyword(
Ok(()) Ok(())
} }
fn resume_keyword( fn resume_keyword(
state: Arc<AppState>, state: Arc<AppState>,
user: UserSession, user: UserSession,
@ -453,34 +424,25 @@ fn resume_keyword(
Ok(()) Ok(())
} }
async fn execute_play( async fn execute_play(
state: &AppState, state: &AppState,
session_id: Uuid, session_id: Uuid,
source: &str, source: &str,
options: PlayOptions, options: PlayOptions,
) -> Result<PlayResponse, String> { ) -> Result<PlayResponse, String> {
let content_type = detect_content_type(source); let content_type = detect_content_type(source);
let source_url = resolve_source_url(state, session_id, source).await?; let source_url = resolve_source_url(state, session_id, source).await?;
let metadata = get_content_metadata(state, &source_url, &content_type).await?; let metadata = get_content_metadata(state, &source_url, &content_type).await?;
let player_id = Uuid::new_v4(); let player_id = Uuid::new_v4();
let title = metadata let title = metadata
.get("title") .get("title")
.cloned() .cloned()
.unwrap_or_else(|| extract_title_from_source(source)); .unwrap_or_else(|| extract_title_from_source(source));
let response = PlayResponse { let response = PlayResponse {
player_id, player_id,
content_type: content_type.clone(), content_type: content_type.clone(),
@ -491,7 +453,6 @@ async fn execute_play(
metadata, metadata,
}; };
send_play_to_client(state, session_id, &response).await?; send_play_to_client(state, session_id, &response).await?;
info!( info!(
@ -502,11 +463,8 @@ async fn execute_play(
Ok(response) Ok(response)
} }
pub fn detect_content_type(source: &str) -> ContentType {
fn detect_content_type(source: &str) -> ContentType {
if source.starts_with("http://") || source.starts_with("https://") { if source.starts_with("http://") || source.starts_with("https://") {
if source.contains("youtube.com") if source.contains("youtube.com")
|| source.contains("youtu.be") || source.contains("youtu.be")
|| source.contains("vimeo.com") || source.contains("vimeo.com")
@ -514,7 +472,6 @@ fn detect_content_type(source: &str) -> ContentType {
return ContentType::Video; return ContentType::Video;
} }
if source.contains("imgur.com") if source.contains("imgur.com")
|| source.contains("unsplash.com") || source.contains("unsplash.com")
|| source.contains("pexels.com") || source.contains("pexels.com")
@ -522,18 +479,15 @@ fn detect_content_type(source: &str) -> ContentType {
return ContentType::Image; return ContentType::Image;
} }
if let Some(path) = source.split('?').next() { if let Some(path) = source.split('?').next() {
if let Some(ext) = Path::new(path).extension() { if let Some(ext) = Path::new(path).extension() {
return ContentType::from_extension(&ext.to_string_lossy()); return ContentType::from_extension(&ext.to_string_lossy());
} }
} }
return ContentType::Iframe; return ContentType::Iframe;
} }
if let Some(ext) = Path::new(source).extension() { if let Some(ext) = Path::new(source).extension() {
return ContentType::from_extension(&ext.to_string_lossy()); return ContentType::from_extension(&ext.to_string_lossy());
} }
@ -541,20 +495,16 @@ fn detect_content_type(source: &str) -> ContentType {
ContentType::Unknown ContentType::Unknown
} }
async fn resolve_source_url( async fn resolve_source_url(
_state: &AppState, _state: &AppState,
session_id: Uuid, session_id: Uuid,
source: &str, source: &str,
) -> Result<String, String> { ) -> Result<String, String> {
if source.starts_with("http://") || source.starts_with("https://") { if source.starts_with("http://") || source.starts_with("https://") {
return Ok(source.to_string()); return Ok(source.to_string());
} }
if source.starts_with("/") || source.contains(".gbdrive") { if source.starts_with("/") || source.contains(".gbdrive") {
let file_url = format!( let file_url = format!(
"/api/drive/file/{}?session={}", "/api/drive/file/{}?session={}",
urlencoding::encode(source), urlencoding::encode(source),
@ -563,7 +513,6 @@ async fn resolve_source_url(
return Ok(file_url); return Ok(file_url);
} }
let file_url = format!( let file_url = format!(
"/api/drive/file/{}?session={}", "/api/drive/file/{}?session={}",
urlencoding::encode(source), urlencoding::encode(source),
@ -573,7 +522,6 @@ async fn resolve_source_url(
Ok(file_url) Ok(file_url)
} }
async fn get_content_metadata( async fn get_content_metadata(
_state: &AppState, _state: &AppState,
source_url: &str, source_url: &str,
@ -584,7 +532,6 @@ async fn get_content_metadata(
metadata.insert("source".to_string(), source_url.to_string()); metadata.insert("source".to_string(), source_url.to_string());
metadata.insert("type".to_string(), format!("{:?}", content_type)); metadata.insert("type".to_string(), format!("{:?}", content_type));
match content_type { match content_type {
ContentType::Video => { ContentType::Video => {
metadata.insert("player".to_string(), "html5".to_string()); metadata.insert("player".to_string(), "html5".to_string());
@ -613,9 +560,7 @@ async fn get_content_metadata(
Ok(metadata) Ok(metadata)
} }
pub fn extract_title_from_source(source: &str) -> String {
fn extract_title_from_source(source: &str) -> String {
let path = source.split('?').next().unwrap_or(source); let path = source.split('?').next().unwrap_or(source);
Path::new(path) Path::new(path)
@ -624,7 +569,6 @@ fn extract_title_from_source(source: &str) -> String {
.unwrap_or_else(|| "Untitled".to_string()) .unwrap_or_else(|| "Untitled".to_string())
} }
async fn send_play_to_client( async fn send_play_to_client(
state: &AppState, state: &AppState,
session_id: Uuid, session_id: Uuid,
@ -638,7 +582,6 @@ async fn send_play_to_client(
let message_str = let message_str =
serde_json::to_string(&message).map_err(|e| format!("Failed to serialize: {}", e))?; serde_json::to_string(&message).map_err(|e| format!("Failed to serialize: {}", e))?;
let bot_response = crate::shared::models::BotResponse { let bot_response = crate::shared::models::BotResponse {
bot_id: String::new(), bot_id: String::new(),
user_id: String::new(), user_id: String::new(),
@ -663,7 +606,6 @@ async fn send_play_to_client(
Ok(()) Ok(())
} }
async fn send_player_command( async fn send_player_command(
state: &AppState, state: &AppState,
session_id: Uuid, session_id: Uuid,
@ -677,7 +619,6 @@ async fn send_player_command(
let message_str = let message_str =
serde_json::to_string(&message).map_err(|e| format!("Failed to serialize: {}", e))?; serde_json::to_string(&message).map_err(|e| format!("Failed to serialize: {}", e))?;
let _ = state let _ = state
.web_adapter .web_adapter
.send_message_to_session( .send_message_to_session(

View file

@ -28,18 +28,6 @@
| | | |
\*****************************************************************************/ \*****************************************************************************/
use crate::core::config::ConfigManager; use crate::core::config::ConfigManager;
use crate::shared::models::UserSession; use crate::shared::models::UserSession;
use crate::shared::state::AppState; use crate::shared::state::AppState;
@ -49,7 +37,6 @@ use serde::{Deserialize, Serialize};
use std::sync::Arc; use std::sync::Arc;
use uuid::Uuid; use uuid::Uuid;
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub enum SmsProvider { pub enum SmsProvider {
Twilio, Twilio,
@ -59,7 +46,6 @@ pub enum SmsProvider {
Custom(String), Custom(String),
} }
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SmsPriority { pub enum SmsPriority {
Low, Low,
@ -108,7 +94,6 @@ impl From<&str> for SmsProvider {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SmsSendResult { pub struct SmsSendResult {
pub success: bool, pub success: bool,
@ -119,15 +104,12 @@ pub struct SmsSendResult {
pub error: Option<String>, pub error: Option<String>,
} }
pub fn register_sms_keywords(state: Arc<AppState>, user: UserSession, engine: &mut Engine) { pub fn register_sms_keywords(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
register_send_sms_keyword(state.clone(), user.clone(), engine); register_send_sms_keyword(state.clone(), user.clone(), engine);
register_send_sms_with_third_arg_keyword(state.clone(), user.clone(), engine); register_send_sms_with_third_arg_keyword(state.clone(), user.clone(), engine);
register_send_sms_full_keyword(state, user, engine); register_send_sms_full_keyword(state, user, engine);
} }
pub fn register_send_sms_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) { pub fn register_send_sms_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
let state_clone = Arc::clone(&state); let state_clone = Arc::clone(&state);
let user_clone = user.clone(); let user_clone = user.clone();
@ -211,9 +193,6 @@ pub fn register_send_sms_keyword(state: Arc<AppState>, user: UserSession, engine
.unwrap(); .unwrap();
} }
pub fn register_send_sms_with_third_arg_keyword( pub fn register_send_sms_with_third_arg_keyword(
state: Arc<AppState>, state: Arc<AppState>,
user: UserSession, user: UserSession,
@ -231,7 +210,6 @@ pub fn register_send_sms_with_third_arg_keyword(
let message = context.eval_expression_tree(&inputs[1])?.to_string(); let message = context.eval_expression_tree(&inputs[1])?.to_string();
let third_arg = context.eval_expression_tree(&inputs[2])?.to_string(); let third_arg = context.eval_expression_tree(&inputs[2])?.to_string();
let is_priority = matches!( let is_priority = matches!(
third_arg.to_lowercase().as_str(), third_arg.to_lowercase().as_str(),
"low" | "normal" | "high" | "urgent" | "critical" "low" | "normal" | "high" | "urgent" | "critical"
@ -319,8 +297,6 @@ pub fn register_send_sms_with_third_arg_keyword(
.unwrap(); .unwrap();
} }
pub fn register_send_sms_full_keyword( pub fn register_send_sms_full_keyword(
state: Arc<AppState>, state: Arc<AppState>,
user: UserSession, user: UserSession,
@ -414,7 +390,6 @@ pub fn register_send_sms_full_keyword(
) )
.unwrap(); .unwrap();
let state_clone2 = Arc::clone(&state); let state_clone2 = Arc::clone(&state);
let user_clone2 = user.clone(); let user_clone2 = user.clone();
@ -517,7 +492,6 @@ async fn execute_send_sms(
let config_manager = ConfigManager::new(state.conn.clone()); let config_manager = ConfigManager::new(state.conn.clone());
let bot_id = user.bot_id; let bot_id = user.bot_id;
let provider_name = match provider_override { let provider_name = match provider_override {
Some(p) => p.to_string(), Some(p) => p.to_string(),
None => config_manager None => config_manager
@ -527,7 +501,6 @@ async fn execute_send_sms(
let provider = SmsProvider::from(provider_name.as_str()); let provider = SmsProvider::from(provider_name.as_str());
let priority = match priority_override { let priority = match priority_override {
Some(p) => SmsPriority::from(p), Some(p) => SmsPriority::from(p),
None => { None => {
@ -538,10 +511,8 @@ async fn execute_send_sms(
} }
}; };
let normalized_phone = normalize_phone_number(phone); let normalized_phone = normalize_phone_number(phone);
if matches!(priority, SmsPriority::High | SmsPriority::Urgent) { if matches!(priority, SmsPriority::High | SmsPriority::Urgent) {
info!( info!(
"High priority SMS to {}: priority={}", "High priority SMS to {}: priority={}",
@ -549,7 +520,6 @@ async fn execute_send_sms(
); );
} }
let result = match provider { let result = match provider {
SmsProvider::Twilio => { SmsProvider::Twilio => {
send_via_twilio(state, &bot_id, &normalized_phone, message, &priority).await send_via_twilio(state, &bot_id, &normalized_phone, message, &priority).await
@ -602,20 +572,16 @@ async fn execute_send_sms(
} }
fn normalize_phone_number(phone: &str) -> String { fn normalize_phone_number(phone: &str) -> String {
let has_plus = phone.starts_with('+'); let has_plus = phone.starts_with('+');
let digits: String = phone.chars().filter(|c| c.is_ascii_digit()).collect(); let digits: String = phone.chars().filter(|c| c.is_ascii_digit()).collect();
if has_plus { if has_plus {
format!("+{}", digits) format!("+{}", digits)
} else if digits.len() == 10 { } else if digits.len() == 10 {
format!("+1{}", digits) format!("+1{}", digits)
} else if digits.len() == 11 && digits.starts_with('1') { } else if digits.len() == 11 && digits.starts_with('1') {
format!("+{}", digits) format!("+{}", digits)
} else { } else {
format!("+{}", digits) format!("+{}", digits)
} }
} }
@ -647,8 +613,6 @@ async fn send_via_twilio(
account_sid account_sid
); );
let final_message = match priority { let final_message = match priority {
SmsPriority::Urgent => format!("[URGENT] {}", message), SmsPriority::Urgent => format!("[URGENT] {}", message),
SmsPriority::High => format!("[HIGH] {}", message), SmsPriority::High => format!("[HIGH] {}", message),
@ -699,22 +663,17 @@ async fn send_via_aws_sns(
.get_config(bot_id, "aws-region", Some("us-east-1")) .get_config(bot_id, "aws-region", Some("us-east-1"))
.unwrap_or_else(|_| "us-east-1".to_string()); .unwrap_or_else(|_| "us-east-1".to_string());
let client = reqwest::Client::new(); let client = reqwest::Client::new();
let url = format!("https://sns.{}.amazonaws.com/", region); let url = format!("https://sns.{}.amazonaws.com/", region);
let timestamp = chrono::Utc::now().format("%Y%m%dT%H%M%SZ").to_string(); let timestamp = chrono::Utc::now().format("%Y%m%dT%H%M%SZ").to_string();
let _date = &timestamp[..8]; let _date = &timestamp[..8];
let sms_type = match priority { let sms_type = match priority {
SmsPriority::High | SmsPriority::Urgent => "Transactional", SmsPriority::High | SmsPriority::Urgent => "Transactional",
_ => "Promotional", _ => "Promotional",
}; };
let params = [ let params = [
("Action", "Publish"), ("Action", "Publish"),
("PhoneNumber", phone), ("PhoneNumber", phone),
@ -725,8 +684,6 @@ async fn send_via_aws_sns(
("MessageAttributes.entry.1.Value.StringValue", sms_type), ("MessageAttributes.entry.1.Value.StringValue", sms_type),
]; ];
let response = client let response = client
.post(&url) .post(&url)
.form(&params) .form(&params)
@ -774,7 +731,6 @@ async fn send_via_vonage(
let client = reqwest::Client::new(); let client = reqwest::Client::new();
let message_class = match priority { let message_class = match priority {
SmsPriority::Urgent => Some("0"), SmsPriority::Urgent => Some("0"),
_ => None, _ => None,
@ -798,25 +754,24 @@ async fn send_via_vonage(
.send() .send()
.await?; .await?;
if response.status().is_success() { if !response.status().is_success() {
let json: serde_json::Value = response.json().await?;
let messages = json["messages"].as_array();
if let Some(msgs) = messages {
if let Some(first) = msgs.first() {
if first["status"].as_str() == Some("0") {
return Ok(first["message-id"].as_str().map(|s| s.to_string()));
} else {
let error_text = first["error-text"].as_str().unwrap_or("Unknown error");
return Err(format!("Vonage error: {}", error_text).into());
}
}
}
Err("Invalid Vonage response".into())
} else {
let error_text = response.text().await?; let error_text = response.text().await?;
Err(format!("Vonage API error: {}", error_text).into()) return Err(format!("Vonage API error: {error_text}").into());
} }
let json: serde_json::Value = response.json().await?;
let messages = json["messages"].as_array();
if let Some(msgs) = messages {
if let Some(first) = msgs.first() {
if first["status"].as_str() == Some("0") {
return Ok(first["message-id"].as_str().map(|s| s.to_string()));
}
let error_text = first["error-text"].as_str().unwrap_or("Unknown error");
return Err(format!("Vonage error: {error_text}").into());
}
}
Err("Invalid Vonage response".into())
} }
async fn send_via_messagebird( async fn send_via_messagebird(
@ -840,7 +795,6 @@ async fn send_via_messagebird(
let client = reqwest::Client::new(); let client = reqwest::Client::new();
let type_details = match priority { let type_details = match priority {
SmsPriority::Urgent => Some(serde_json::json!({"class": 0})), SmsPriority::Urgent => Some(serde_json::json!({"class": 0})),
SmsPriority::High => Some(serde_json::json!({"class": 1})), SmsPriority::High => Some(serde_json::json!({"class": 1})),

View file

@ -1,42 +1,3 @@
use crate::shared::models::UserSession; use crate::shared::models::UserSession;
use crate::shared::state::AppState; use crate::shared::state::AppState;
use chrono::Utc; use chrono::Utc;
@ -48,10 +9,8 @@ use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
use uuid::Uuid; use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransferToHumanRequest { pub struct TransferToHumanRequest {
pub name: Option<String>, pub name: Option<String>,
pub department: Option<String>, pub department: Option<String>,
@ -63,7 +22,6 @@ pub struct TransferToHumanRequest {
pub context: Option<String>, pub context: Option<String>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransferResult { pub struct TransferResult {
pub success: bool, pub success: bool,
@ -75,11 +33,9 @@ pub struct TransferResult {
pub message: String, pub message: String,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum TransferStatus { pub enum TransferStatus {
Queued, Queued,
Assigned, Assigned,
@ -95,7 +51,6 @@ pub enum TransferStatus {
Error, Error,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Attendant { pub struct Attendant {
pub id: String, pub id: String,
@ -107,7 +62,6 @@ pub struct Attendant {
pub status: AttendantStatus, pub status: AttendantStatus,
} }
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum AttendantStatus { pub enum AttendantStatus {
@ -123,14 +77,12 @@ impl Default for AttendantStatus {
} }
} }
pub fn is_crm_enabled(bot_id: Uuid, work_path: &str) -> bool { pub fn is_crm_enabled(bot_id: Uuid, work_path: &str) -> bool {
let config_path = PathBuf::from(work_path) let config_path = PathBuf::from(work_path)
.join(format!("{}.gbai", bot_id)) .join(format!("{}.gbai", bot_id))
.join("config.csv"); .join("config.csv");
if !config_path.exists() { if !config_path.exists() {
let alt_path = PathBuf::from(work_path).join("config.csv"); let alt_path = PathBuf::from(work_path).join("config.csv");
if alt_path.exists() { if alt_path.exists() {
return check_config_for_crm(&alt_path); return check_config_for_crm(&alt_path);
@ -167,14 +119,12 @@ fn check_config_for_crm(config_path: &PathBuf) -> bool {
} }
} }
pub fn read_attendants(bot_id: Uuid, work_path: &str) -> Vec<Attendant> { pub fn read_attendants(bot_id: Uuid, work_path: &str) -> Vec<Attendant> {
let attendant_path = PathBuf::from(work_path) let attendant_path = PathBuf::from(work_path)
.join(format!("{}.gbai", bot_id)) .join(format!("{}.gbai", bot_id))
.join("attendant.csv"); .join("attendant.csv");
if !attendant_path.exists() { if !attendant_path.exists() {
let alt_path = PathBuf::from(work_path).join("attendant.csv"); let alt_path = PathBuf::from(work_path).join("attendant.csv");
if alt_path.exists() { if alt_path.exists() {
return parse_attendants_csv(&alt_path); return parse_attendants_csv(&alt_path);
@ -192,11 +142,9 @@ fn parse_attendants_csv(path: &PathBuf) -> Vec<Attendant> {
let mut attendants = Vec::new(); let mut attendants = Vec::new();
let mut lines = content.lines(); let mut lines = content.lines();
let header = lines.next().unwrap_or(""); let header = lines.next().unwrap_or("");
let headers: Vec<String> = header.split(',').map(|s| s.trim().to_lowercase()).collect(); let headers: Vec<String> = header.split(',').map(|s| s.trim().to_lowercase()).collect();
let id_idx = headers.iter().position(|h| h == "id").unwrap_or(0); let id_idx = headers.iter().position(|h| h == "id").unwrap_or(0);
let name_idx = headers.iter().position(|h| h == "name").unwrap_or(1); let name_idx = headers.iter().position(|h| h == "name").unwrap_or(1);
let channel_idx = headers.iter().position(|h| h == "channel").unwrap_or(2); let channel_idx = headers.iter().position(|h| h == "channel").unwrap_or(2);
@ -239,7 +187,6 @@ fn parse_attendants_csv(path: &PathBuf) -> Vec<Attendant> {
} }
} }
pub fn find_attendant<'a>( pub fn find_attendant<'a>(
attendants: &'a [Attendant], attendants: &'a [Attendant],
name: Option<&str>, name: Option<&str>,
@ -248,7 +195,6 @@ pub fn find_attendant<'a>(
if let Some(search_name) = name { if let Some(search_name) = name {
let search_lower = search_name.to_lowercase(); let search_lower = search_name.to_lowercase();
if let Some(att) = attendants if let Some(att) = attendants
.iter() .iter()
.find(|a| a.name.to_lowercase() == search_lower) .find(|a| a.name.to_lowercase() == search_lower)
@ -256,7 +202,6 @@ pub fn find_attendant<'a>(
return Some(att); return Some(att);
} }
if let Some(att) = attendants if let Some(att) = attendants
.iter() .iter()
.find(|a| a.name.to_lowercase().contains(&search_lower)) .find(|a| a.name.to_lowercase().contains(&search_lower))
@ -264,7 +209,6 @@ pub fn find_attendant<'a>(
return Some(att); return Some(att);
} }
if let Some(att) = attendants if let Some(att) = attendants
.iter() .iter()
.find(|a| a.aliases.contains(&search_lower)) .find(|a| a.aliases.contains(&search_lower))
@ -272,7 +216,6 @@ pub fn find_attendant<'a>(
return Some(att); return Some(att);
} }
if let Some(att) = attendants if let Some(att) = attendants
.iter() .iter()
.find(|a| a.id.to_lowercase() == search_lower) .find(|a| a.id.to_lowercase() == search_lower)
@ -284,7 +227,6 @@ pub fn find_attendant<'a>(
if let Some(dept) = department { if let Some(dept) = department {
let dept_lower = dept.to_lowercase(); let dept_lower = dept.to_lowercase();
if let Some(att) = attendants.iter().find(|a| { if let Some(att) = attendants.iter().find(|a| {
a.department a.department
.as_ref() .as_ref()
@ -295,7 +237,6 @@ pub fn find_attendant<'a>(
return Some(att); return Some(att);
} }
if let Some(att) = attendants.iter().find(|a| { if let Some(att) = attendants.iter().find(|a| {
a.preferences.to_lowercase().contains(&dept_lower) a.preferences.to_lowercase().contains(&dept_lower)
&& a.status == AttendantStatus::Online && a.status == AttendantStatus::Online
@ -304,13 +245,11 @@ pub fn find_attendant<'a>(
} }
} }
attendants attendants
.iter() .iter()
.find(|a| a.status == AttendantStatus::Online) .find(|a| a.status == AttendantStatus::Online)
} }
fn priority_to_int(priority: Option<&str>) -> i32 { fn priority_to_int(priority: Option<&str>) -> i32 {
match priority.map(|p| p.to_lowercase()).as_deref() { match priority.map(|p| p.to_lowercase()).as_deref() {
Some("urgent") => 3, Some("urgent") => 3,
@ -321,7 +260,6 @@ fn priority_to_int(priority: Option<&str>) -> i32 {
} }
} }
pub async fn execute_transfer( pub async fn execute_transfer(
state: Arc<AppState>, state: Arc<AppState>,
session: &UserSession, session: &UserSession,
@ -330,7 +268,6 @@ pub async fn execute_transfer(
let work_path = "./work"; let work_path = "./work";
let bot_id = session.bot_id; let bot_id = session.bot_id;
if !is_crm_enabled(bot_id, work_path) { if !is_crm_enabled(bot_id, work_path) {
return TransferResult { return TransferResult {
success: false, success: false,
@ -344,7 +281,6 @@ pub async fn execute_transfer(
}; };
} }
let attendants = read_attendants(bot_id, work_path); let attendants = read_attendants(bot_id, work_path);
if attendants.is_empty() { if attendants.is_empty() {
return TransferResult { return TransferResult {
@ -359,14 +295,12 @@ pub async fn execute_transfer(
}; };
} }
let attendant = find_attendant( let attendant = find_attendant(
&attendants, &attendants,
request.name.as_deref(), request.name.as_deref(),
request.department.as_deref(), request.department.as_deref(),
); );
if request.name.is_some() && attendant.is_none() { if request.name.is_some() && attendant.is_none() {
return TransferResult { return TransferResult {
success: false, success: false,
@ -387,7 +321,6 @@ pub async fn execute_transfer(
}; };
} }
let priority = priority_to_int(request.priority.as_deref()); let priority = priority_to_int(request.priority.as_deref());
let transfer_context = serde_json::json!({ let transfer_context = serde_json::json!({
"transfer_requested_at": Utc::now().to_rfc3339(), "transfer_requested_at": Utc::now().to_rfc3339(),
@ -402,7 +335,6 @@ pub async fn execute_transfer(
"status": if attendant.is_some() { "assigned" } else { "queued" }, "status": if attendant.is_some() { "assigned" } else { "queued" },
}); });
let session_id = session.id; let session_id = session.id;
let conn = state.conn.clone(); let conn = state.conn.clone();
let ctx_data = transfer_context.clone(); let ctx_data = transfer_context.clone();
@ -491,7 +423,6 @@ pub async fn execute_transfer(
} }
} }
impl TransferResult { impl TransferResult {
pub fn to_dynamic(&self) -> Dynamic { pub fn to_dynamic(&self) -> Dynamic {
let mut map = Map::new(); let mut map = Map::new();
@ -532,11 +463,11 @@ async fn calculate_queue_position(state: &Arc<AppState>, current_session_id: Uui
}; };
let query = diesel::sql_query( let query = diesel::sql_query(
r#"SELECT COUNT(*) as position FROM user_sessions r"SELECT COUNT(*) as position FROM user_sessions
WHERE context_data->>'needs_human' = 'true' WHERE context_data->>'needs_human' = 'true'
AND context_data->>'status' = 'queued' AND context_data->>'status' = 'queued'
AND created_at <= (SELECT created_at FROM user_sessions WHERE id = $1) AND created_at <= (SELECT created_at FROM user_sessions WHERE id = $1)
AND id != $2"#, AND id != $2",
) )
.bind::<diesel::sql_types::Uuid, _>(current_session_id) .bind::<diesel::sql_types::Uuid, _>(current_session_id)
.bind::<diesel::sql_types::Uuid, _>(current_session_id); .bind::<diesel::sql_types::Uuid, _>(current_session_id);
@ -563,13 +494,11 @@ async fn calculate_queue_position(state: &Arc<AppState>, current_session_id: Uui
} }
} }
pub fn register_transfer_to_human_keyword( pub fn register_transfer_to_human_keyword(
state: Arc<AppState>, state: Arc<AppState>,
user: UserSession, user: UserSession,
engine: &mut Engine, engine: &mut Engine,
) { ) {
let state_clone = state.clone(); let state_clone = state.clone();
let user_clone = user.clone(); let user_clone = user.clone();
engine.register_fn("transfer_to_human", move || -> Dynamic { engine.register_fn("transfer_to_human", move || -> Dynamic {
@ -595,7 +524,6 @@ pub fn register_transfer_to_human_keyword(
result.to_dynamic() result.to_dynamic()
}); });
let state_clone = state.clone(); let state_clone = state.clone();
let user_clone = user.clone(); let user_clone = user.clone();
engine.register_fn("transfer_to_human", move |name: &str| -> Dynamic { engine.register_fn("transfer_to_human", move |name: &str| -> Dynamic {
@ -622,7 +550,6 @@ pub fn register_transfer_to_human_keyword(
result.to_dynamic() result.to_dynamic()
}); });
let state_clone = state.clone(); let state_clone = state.clone();
let user_clone = user.clone(); let user_clone = user.clone();
engine.register_fn( engine.register_fn(
@ -653,7 +580,6 @@ pub fn register_transfer_to_human_keyword(
}, },
); );
let state_clone = state.clone(); let state_clone = state.clone();
let user_clone = user.clone(); let user_clone = user.clone();
engine.register_fn( engine.register_fn(
@ -685,7 +611,6 @@ pub fn register_transfer_to_human_keyword(
}, },
); );
let state_clone = state.clone(); let state_clone = state.clone();
let user_clone = user.clone(); let user_clone = user.clone();
engine.register_fn("transfer_to_human_ex", move |params: Map| -> Dynamic { engine.register_fn("transfer_to_human_ex", move |params: Map| -> Dynamic {
@ -730,7 +655,6 @@ pub fn register_transfer_to_human_keyword(
debug!("Registered TRANSFER TO HUMAN keywords"); debug!("Registered TRANSFER TO HUMAN keywords");
} }
pub const TRANSFER_TO_HUMAN_TOOL_SCHEMA: &str = r#"{ pub const TRANSFER_TO_HUMAN_TOOL_SCHEMA: &str = r#"{
"name": "transfer_to_human", "name": "transfer_to_human",
"description": "Transfer the conversation to a human attendant. Use this when the customer explicitly asks to speak with a person, when the issue is too complex for automated handling, or when emotional support is needed.", "description": "Transfer the conversation to a human attendant. Use this when the customer explicitly asks to speak with a person, when the issue is too complex for automated handling, or when emotional support is needed.",
@ -761,7 +685,6 @@ pub const TRANSFER_TO_HUMAN_TOOL_SCHEMA: &str = r#"{
} }
}"#; }"#;
pub fn get_tool_definition() -> serde_json::Value { pub fn get_tool_definition() -> serde_json::Value {
serde_json::json!({ serde_json::json!({
"type": "function", "type": "function",

View file

@ -6,7 +6,7 @@ use std::time::Duration;
pub fn wait_keyword(_state: &AppState, _user: UserSession, engine: &mut Engine) { pub fn wait_keyword(_state: &AppState, _user: UserSession, engine: &mut Engine) {
engine engine
.register_custom_syntax(&["WAIT", "$expr$"], false, move |context, inputs| { .register_custom_syntax(["WAIT", "$expr$"], false, move |context, inputs| {
let seconds = context.eval_expression_tree(&inputs[0])?; let seconds = context.eval_expression_tree(&inputs[0])?;
let duration_secs = if seconds.is::<i64>() { let duration_secs = if seconds.is::<i64>() {
#[allow(clippy::cast_precision_loss)] #[allow(clippy::cast_precision_loss)]

View file

@ -345,9 +345,7 @@ async fn scrape_table(
.chain(row.select(&td_sel)) .chain(row.select(&td_sel))
.map(|cell| cell.text().collect::<Vec<_>>().join(" ").trim().to_string()) .map(|cell| cell.text().collect::<Vec<_>>().join(" ").trim().to_string())
.collect(); .collect();
if headers.is_empty() { if headers.is_empty() {}
continue;
}
} else { } else {
let mut row_map = Map::new(); let mut row_map = Map::new();
for (j, cell) in row.select(&td_sel).enumerate() { for (j, cell) in row.select(&td_sel).enumerate() {

View file

@ -1,8 +1,3 @@
use axum::{ use axum::{
http::StatusCode, http::StatusCode,
response::{IntoResponse, Response}, response::{IntoResponse, Response},
@ -14,11 +9,7 @@ use std::sync::Arc;
use super::CalendarEngine; use super::CalendarEngine;
use crate::shared::state::AppState; use crate::shared::state::AppState;
pub fn create_caldav_router(_engine: Arc<CalendarEngine>) -> Router<Arc<AppState>> { pub fn create_caldav_router(_engine: Arc<CalendarEngine>) -> Router<Arc<AppState>> {
Router::new() Router::new()
.route("/caldav", get(caldav_root)) .route("/caldav", get(caldav_root))
.route("/caldav/principals", get(caldav_principals)) .route("/caldav/principals", get(caldav_principals))
@ -30,7 +21,6 @@ pub fn create_caldav_router(_engine: Arc<CalendarEngine>) -> Router<Arc<AppState
) )
} }
async fn caldav_root() -> impl IntoResponse { async fn caldav_root() -> impl IntoResponse {
Response::builder() Response::builder()
.status(StatusCode::OK) .status(StatusCode::OK)
@ -57,7 +47,6 @@ async fn caldav_root() -> impl IntoResponse {
.unwrap() .unwrap()
} }
async fn caldav_principals() -> impl IntoResponse { async fn caldav_principals() -> impl IntoResponse {
Response::builder() Response::builder()
.status(StatusCode::OK) .status(StatusCode::OK)
@ -86,7 +75,6 @@ async fn caldav_principals() -> impl IntoResponse {
.unwrap() .unwrap()
} }
async fn caldav_calendars() -> impl IntoResponse { async fn caldav_calendars() -> impl IntoResponse {
Response::builder() Response::builder()
.status(StatusCode::OK) .status(StatusCode::OK)
@ -129,7 +117,6 @@ async fn caldav_calendars() -> impl IntoResponse {
.unwrap() .unwrap()
} }
async fn caldav_calendar() -> impl IntoResponse { async fn caldav_calendar() -> impl IntoResponse {
Response::builder() Response::builder()
.status(StatusCode::OK) .status(StatusCode::OK)
@ -156,14 +143,12 @@ async fn caldav_calendar() -> impl IntoResponse {
.unwrap() .unwrap()
} }
async fn caldav_event() -> impl IntoResponse { async fn caldav_event() -> impl IntoResponse {
Response::builder() Response::builder()
.status(StatusCode::OK) .status(StatusCode::OK)
.header("Content-Type", "text/calendar; charset=utf-8") .header("Content-Type", "text/calendar; charset=utf-8")
.body( .body(
r#"BEGIN:VCALENDAR r"BEGIN:VCALENDAR
VERSION:2.0 VERSION:2.0
PRODID:- PRODID:-
BEGIN:VEVENT BEGIN:VEVENT
@ -173,15 +158,13 @@ DTSTART:20240101T090000Z
DTEND:20240101T100000Z DTEND:20240101T100000Z
SUMMARY:Placeholder Event SUMMARY:Placeholder Event
END:VEVENT END:VEVENT
END:VCALENDAR"# END:VCALENDAR"
.to_string(), .to_string(),
) )
.unwrap() .unwrap()
} }
async fn caldav_put_event() -> impl IntoResponse { async fn caldav_put_event() -> impl IntoResponse {
Response::builder() Response::builder()
.status(StatusCode::CREATED) .status(StatusCode::CREATED)
.header("ETag", "\"placeholder-etag\"") .header("ETag", "\"placeholder-etag\"")

View file

@ -1,16 +1,9 @@
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use regex::Regex; use regex::Regex;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use walkdir::WalkDir; use walkdir::WalkDir;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum IssueSeverity { pub enum IssueSeverity {
@ -33,7 +26,6 @@ impl std::fmt::Display for IssueSeverity {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum IssueType { pub enum IssueType {
@ -64,7 +56,6 @@ impl std::fmt::Display for IssueType {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeIssue { pub struct CodeIssue {
pub id: String, pub id: String,
@ -80,7 +71,6 @@ pub struct CodeIssue {
pub detected_at: DateTime<Utc>, pub detected_at: DateTime<Utc>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BotScanResult { pub struct BotScanResult {
pub bot_id: String, pub bot_id: String,
@ -91,7 +81,6 @@ pub struct BotScanResult {
pub stats: ScanStats, pub stats: ScanStats,
} }
#[derive(Debug, Clone, Serialize, Deserialize, Default)] #[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ScanStats { pub struct ScanStats {
pub critical: usize, pub critical: usize,
@ -124,7 +113,6 @@ impl ScanStats {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComplianceScanResult { pub struct ComplianceScanResult {
pub scanned_at: DateTime<Utc>, pub scanned_at: DateTime<Utc>,
@ -135,7 +123,6 @@ pub struct ComplianceScanResult {
pub bot_results: Vec<BotScanResult>, pub bot_results: Vec<BotScanResult>,
} }
struct ScanPattern { struct ScanPattern {
regex: Regex, regex: Regex,
issue_type: IssueType, issue_type: IssueType,
@ -146,14 +133,12 @@ struct ScanPattern {
category: String, category: String,
} }
pub struct CodeScanner { pub struct CodeScanner {
patterns: Vec<ScanPattern>, patterns: Vec<ScanPattern>,
base_path: PathBuf, base_path: PathBuf,
} }
impl CodeScanner { impl CodeScanner {
pub fn new(base_path: impl AsRef<Path>) -> Self { pub fn new(base_path: impl AsRef<Path>) -> Self {
let patterns = Self::build_patterns(); let patterns = Self::build_patterns();
Self { Self {
@ -162,11 +147,9 @@ impl CodeScanner {
} }
} }
fn build_patterns() -> Vec<ScanPattern> { fn build_patterns() -> Vec<ScanPattern> {
let mut patterns = Vec::new(); let mut patterns = Vec::new();
patterns.push(ScanPattern { patterns.push(ScanPattern {
regex: Regex::new(r#"(?i)password\s*=\s*["'][^"']+["']"#).unwrap(), regex: Regex::new(r#"(?i)password\s*=\s*["'][^"']+["']"#).unwrap(),
issue_type: IssueType::PasswordInConfig, issue_type: IssueType::PasswordInConfig,
@ -197,9 +180,8 @@ impl CodeScanner {
category: "Security".to_string(), category: "Security".to_string(),
}); });
patterns.push(ScanPattern { patterns.push(ScanPattern {
regex: Regex::new(r#"(?i)IF\s+.*\binput\b"#).unwrap(), regex: Regex::new(r"(?i)IF\s+.*\binput\b").unwrap(),
issue_type: IssueType::DeprecatedIfInput, issue_type: IssueType::DeprecatedIfInput,
severity: IssueSeverity::Medium, severity: IssueSeverity::Medium,
title: "Deprecated IF...input Pattern".to_string(), title: "Deprecated IF...input Pattern".to_string(),
@ -211,9 +193,8 @@ impl CodeScanner {
category: "Code Quality".to_string(), category: "Code Quality".to_string(),
}); });
patterns.push(ScanPattern { patterns.push(ScanPattern {
regex: Regex::new(r#"(?i)\b(GET_BOT_MEMORY|SET_BOT_MEMORY|GET_USER_MEMORY|SET_USER_MEMORY|USE_KB|USE_TOOL|SEND_MAIL|CREATE_TASK)\b"#).unwrap(), regex: Regex::new(r"(?i)\b(GET_BOT_MEMORY|SET_BOT_MEMORY|GET_USER_MEMORY|SET_USER_MEMORY|USE_KB|USE_TOOL|SEND_MAIL|CREATE_TASK)\b").unwrap(),
issue_type: IssueType::UnderscoreInKeyword, issue_type: IssueType::UnderscoreInKeyword,
severity: IssueSeverity::Low, severity: IssueSeverity::Low,
title: "Underscore in Keyword".to_string(), title: "Underscore in Keyword".to_string(),
@ -222,9 +203,8 @@ impl CodeScanner {
category: "Naming Convention".to_string(), category: "Naming Convention".to_string(),
}); });
patterns.push(ScanPattern { patterns.push(ScanPattern {
regex: Regex::new(r#"(?i)POST\s+TO\s+INSTAGRAM\s+\w+\s*,\s*\w+"#).unwrap(), regex: Regex::new(r"(?i)POST\s+TO\s+INSTAGRAM\s+\w+\s*,\s*\w+").unwrap(),
issue_type: IssueType::InsecurePattern, issue_type: IssueType::InsecurePattern,
severity: IssueSeverity::High, severity: IssueSeverity::High,
title: "Instagram Credentials in Code".to_string(), title: "Instagram Credentials in Code".to_string(),
@ -235,10 +215,8 @@ impl CodeScanner {
category: "Security".to_string(), category: "Security".to_string(),
}); });
patterns.push(ScanPattern { patterns.push(ScanPattern {
regex: Regex::new(r#"(?i)(SELECT|INSERT|UPDATE|DELETE)\s+.*(FROM|INTO|SET)\s+"#) regex: Regex::new(r"(?i)(SELECT|INSERT|UPDATE|DELETE)\s+.*(FROM|INTO|SET)\s+").unwrap(),
.unwrap(),
issue_type: IssueType::FragileCode, issue_type: IssueType::FragileCode,
severity: IssueSeverity::Medium, severity: IssueSeverity::Medium,
title: "Raw SQL Query".to_string(), title: "Raw SQL Query".to_string(),
@ -250,9 +228,8 @@ impl CodeScanner {
category: "Security".to_string(), category: "Security".to_string(),
}); });
patterns.push(ScanPattern { patterns.push(ScanPattern {
regex: Regex::new(r#"(?i)\bEVAL\s*\("#).unwrap(), regex: Regex::new(r"(?i)\bEVAL\s*\(").unwrap(),
issue_type: IssueType::FragileCode, issue_type: IssueType::FragileCode,
severity: IssueSeverity::High, severity: IssueSeverity::High,
title: "Dynamic Code Execution".to_string(), title: "Dynamic Code Execution".to_string(),
@ -261,7 +238,6 @@ impl CodeScanner {
category: "Security".to_string(), category: "Security".to_string(),
}); });
patterns.push(ScanPattern { patterns.push(ScanPattern {
regex: Regex::new( regex: Regex::new(
r#"(?i)(password|secret|key|token)\s*=\s*["'][A-Za-z0-9+/=]{40,}["']"#, r#"(?i)(password|secret|key|token)\s*=\s*["'][A-Za-z0-9+/=]{40,}["']"#,
@ -276,9 +252,8 @@ impl CodeScanner {
category: "Security".to_string(), category: "Security".to_string(),
}); });
patterns.push(ScanPattern { patterns.push(ScanPattern {
regex: Regex::new(r#"(?i)(AKIA[0-9A-Z]{16})"#).unwrap(), regex: Regex::new(r"(?i)(AKIA[0-9A-Z]{16})").unwrap(),
issue_type: IssueType::HardcodedSecret, issue_type: IssueType::HardcodedSecret,
severity: IssueSeverity::Critical, severity: IssueSeverity::Critical,
title: "AWS Access Key".to_string(), title: "AWS Access Key".to_string(),
@ -288,9 +263,8 @@ impl CodeScanner {
category: "Security".to_string(), category: "Security".to_string(),
}); });
patterns.push(ScanPattern { patterns.push(ScanPattern {
regex: Regex::new(r#"-----BEGIN\s+(RSA\s+)?PRIVATE\s+KEY-----"#).unwrap(), regex: Regex::new(r"-----BEGIN\s+(RSA\s+)?PRIVATE\s+KEY-----").unwrap(),
issue_type: IssueType::HardcodedSecret, issue_type: IssueType::HardcodedSecret,
severity: IssueSeverity::Critical, severity: IssueSeverity::Critical,
title: "Private Key in Code".to_string(), title: "Private Key in Code".to_string(),
@ -300,9 +274,8 @@ impl CodeScanner {
category: "Security".to_string(), category: "Security".to_string(),
}); });
patterns.push(ScanPattern { patterns.push(ScanPattern {
regex: Regex::new(r#"(?i)(postgres|mysql|mongodb|redis)://[^:]+:[^@]+@"#).unwrap(), regex: Regex::new(r"(?i)(postgres|mysql|mongodb|redis)://[^:]+:[^@]+@").unwrap(),
issue_type: IssueType::HardcodedSecret, issue_type: IssueType::HardcodedSecret,
severity: IssueSeverity::Critical, severity: IssueSeverity::Critical,
title: "Database Credentials in Connection String".to_string(), title: "Database Credentials in Connection String".to_string(),
@ -314,7 +287,6 @@ impl CodeScanner {
patterns patterns
} }
pub async fn scan_all( pub async fn scan_all(
&self, &self,
) -> Result<ComplianceScanResult, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<ComplianceScanResult, Box<dyn std::error::Error + Send + Sync>> {
@ -323,13 +295,11 @@ impl CodeScanner {
let mut total_stats = ScanStats::default(); let mut total_stats = ScanStats::default();
let mut total_files = 0; let mut total_files = 0;
let templates_path = self.base_path.join("templates"); let templates_path = self.base_path.join("templates");
let work_path = self.base_path.join("work"); let work_path = self.base_path.join("work");
let mut bot_paths = Vec::new(); let mut bot_paths = Vec::new();
if templates_path.exists() { if templates_path.exists() {
for entry in WalkDir::new(&templates_path).max_depth(3) { for entry in WalkDir::new(&templates_path).max_depth(3) {
if let Ok(entry) = entry { if let Ok(entry) = entry {
@ -344,7 +314,6 @@ impl CodeScanner {
} }
} }
if work_path.exists() { if work_path.exists() {
for entry in WalkDir::new(&work_path).max_depth(3) { for entry in WalkDir::new(&work_path).max_depth(3) {
if let Ok(entry) = entry { if let Ok(entry) = entry {
@ -359,7 +328,6 @@ impl CodeScanner {
} }
} }
for bot_path in &bot_paths { for bot_path in &bot_paths {
let result = self.scan_bot(bot_path).await?; let result = self.scan_bot(bot_path).await?;
total_files += result.files_scanned; total_files += result.files_scanned;
@ -379,7 +347,6 @@ impl CodeScanner {
}) })
} }
pub async fn scan_bot( pub async fn scan_bot(
&self, &self,
bot_path: &Path, bot_path: &Path,
@ -397,7 +364,6 @@ impl CodeScanner {
let mut stats = ScanStats::default(); let mut stats = ScanStats::default();
let mut files_scanned = 0; let mut files_scanned = 0;
for entry in WalkDir::new(bot_path) { for entry in WalkDir::new(bot_path) {
if let Ok(entry) = entry { if let Ok(entry) = entry {
let path = entry.path(); let path = entry.path();
@ -415,7 +381,6 @@ impl CodeScanner {
} }
} }
let config_path = bot_path.join("config.csv"); let config_path = bot_path.join("config.csv");
if config_path.exists() { if config_path.exists() {
let vault_configured = self.check_vault_config(&config_path).await?; let vault_configured = self.check_vault_config(&config_path).await?;
@ -438,7 +403,6 @@ impl CodeScanner {
} }
} }
issues.sort_by(|a, b| b.severity.cmp(&a.severity)); issues.sort_by(|a, b| b.severity.cmp(&a.severity));
Ok(BotScanResult { Ok(BotScanResult {
@ -451,7 +415,6 @@ impl CodeScanner {
}) })
} }
async fn scan_file( async fn scan_file(
&self, &self,
file_path: &Path, file_path: &Path,
@ -468,7 +431,6 @@ impl CodeScanner {
for (line_number, line) in content.lines().enumerate() { for (line_number, line) in content.lines().enumerate() {
let line_num = line_number + 1; let line_num = line_number + 1;
let trimmed = line.trim(); let trimmed = line.trim();
if trimmed.starts_with("REM") || trimmed.starts_with("'") || trimmed.starts_with("//") { if trimmed.starts_with("REM") || trimmed.starts_with("'") || trimmed.starts_with("//") {
continue; continue;
@ -476,7 +438,6 @@ impl CodeScanner {
for pattern in &self.patterns { for pattern in &self.patterns {
if pattern.regex.is_match(line) { if pattern.regex.is_match(line) {
let snippet = self.redact_sensitive(line); let snippet = self.redact_sensitive(line);
let issue = CodeIssue { let issue = CodeIssue {
@ -500,18 +461,15 @@ impl CodeScanner {
Ok(issues) Ok(issues)
} }
fn redact_sensitive(&self, line: &str) -> String { fn redact_sensitive(&self, line: &str) -> String {
let mut result = line.to_string(); let mut result = line.to_string();
let secret_pattern = Regex::new(r#"(["'])[^"']{8,}(["'])"#).unwrap(); let secret_pattern = Regex::new(r#"(["'])[^"']{8,}(["'])"#).unwrap();
result = secret_pattern result = secret_pattern
.replace_all(&result, "$1***REDACTED***$2") .replace_all(&result, "$1***REDACTED***$2")
.to_string(); .to_string();
let aws_pattern = Regex::new(r"AKIA[0-9A-Z]{16}").unwrap();
let aws_pattern = Regex::new(r#"AKIA[0-9A-Z]{16}"#).unwrap();
result = aws_pattern result = aws_pattern
.replace_all(&result, "AKIA***REDACTED***") .replace_all(&result, "AKIA***REDACTED***")
.to_string(); .to_string();
@ -519,14 +477,12 @@ impl CodeScanner {
result result
} }
async fn check_vault_config( async fn check_vault_config(
&self, &self,
config_path: &Path, config_path: &Path,
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
let content = tokio::fs::read_to_string(config_path).await?; let content = tokio::fs::read_to_string(config_path).await?;
let has_vault = content.to_lowercase().contains("vault_addr") let has_vault = content.to_lowercase().contains("vault_addr")
|| content.to_lowercase().contains("vault_token") || content.to_lowercase().contains("vault_token")
|| content.to_lowercase().contains("vault-"); || content.to_lowercase().contains("vault-");
@ -534,7 +490,6 @@ impl CodeScanner {
Ok(has_vault) Ok(has_vault)
} }
pub async fn scan_bots( pub async fn scan_bots(
&self, &self,
bot_ids: &[String], bot_ids: &[String],
@ -543,14 +498,11 @@ impl CodeScanner {
return self.scan_all().await; return self.scan_all().await;
} }
let mut full_result = self.scan_all().await?; let mut full_result = self.scan_all().await?;
full_result full_result
.bot_results .bot_results
.retain(|r| bot_ids.contains(&r.bot_id) || bot_ids.contains(&r.bot_name)); .retain(|r| bot_ids.contains(&r.bot_id) || bot_ids.contains(&r.bot_name));
let mut new_stats = ScanStats::default(); let mut new_stats = ScanStats::default();
for bot in &full_result.bot_results { for bot in &full_result.bot_results {
new_stats.merge(&bot.stats); new_stats.merge(&bot.stats);
@ -562,11 +514,9 @@ impl CodeScanner {
} }
} }
pub struct ComplianceReporter; pub struct ComplianceReporter;
impl ComplianceReporter { impl ComplianceReporter {
pub fn to_html(result: &ComplianceScanResult) -> String { pub fn to_html(result: &ComplianceScanResult) -> String {
let mut html = String::new(); let mut html = String::new();
@ -617,12 +567,10 @@ impl ComplianceReporter {
html html
} }
pub fn to_json(result: &ComplianceScanResult) -> Result<String, serde_json::Error> { pub fn to_json(result: &ComplianceScanResult) -> Result<String, serde_json::Error> {
serde_json::to_string_pretty(result) serde_json::to_string_pretty(result)
} }
pub fn to_csv(result: &ComplianceScanResult) -> String { pub fn to_csv(result: &ComplianceScanResult) -> String {
let mut csv = String::new(); let mut csv = String::new();
csv.push_str("Severity,Type,Category,File,Line,Title,Description,Remediation\n"); csv.push_str("Severity,Type,Category,File,Line,Title,Description,Remediation\n");
@ -650,7 +598,6 @@ impl ComplianceReporter {
} }
} }
fn escape_csv(s: &str) -> String { fn escape_csv(s: &str) -> String {
if s.contains(',') || s.contains('"') || s.contains('\n') { if s.contains(',') || s.contains('"') || s.contains('\n') {
format!("\"{}\"", s.replace('"', "\"\"")) format!("\"{}\"", s.replace('"', "\"\""))

View file

@ -1,13 +1,3 @@
use crate::shared::platform_name; use crate::shared::platform_name;
use crate::shared::BOTSERVER_VERSION; use crate::shared::BOTSERVER_VERSION;
use crossterm::{ use crossterm::{
@ -21,38 +11,27 @@ use serde::{Deserialize, Serialize};
use std::io::{self, Write}; use std::io::{self, Write};
use std::path::PathBuf; use std::path::PathBuf;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WizardConfig { pub struct WizardConfig {
pub llm_provider: LlmProvider, pub llm_provider: LlmProvider,
pub llm_api_key: Option<String>, pub llm_api_key: Option<String>,
pub local_model_path: Option<String>, pub local_model_path: Option<String>,
pub components: Vec<ComponentChoice>, pub components: Vec<ComponentChoice>,
pub admin: AdminConfig, pub admin: AdminConfig,
pub organization: OrgConfig, pub organization: OrgConfig,
pub template: Option<String>, pub template: Option<String>,
pub install_mode: InstallMode, pub install_mode: InstallMode,
pub data_dir: PathBuf, pub data_dir: PathBuf,
} }
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum LlmProvider { pub enum LlmProvider {
Claude, Claude,
@ -74,7 +53,6 @@ impl std::fmt::Display for LlmProvider {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ComponentChoice { pub enum ComponentChoice {
Drive, Drive,
@ -104,7 +82,6 @@ impl std::fmt::Display for ComponentChoice {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize, Default)] #[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AdminConfig { pub struct AdminConfig {
pub username: String, pub username: String,
@ -113,7 +90,6 @@ pub struct AdminConfig {
pub display_name: String, pub display_name: String,
} }
#[derive(Debug, Clone, Serialize, Deserialize, Default)] #[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct OrgConfig { pub struct OrgConfig {
pub name: String, pub name: String,
@ -121,7 +97,6 @@ pub struct OrgConfig {
pub domain: Option<String>, pub domain: Option<String>,
} }
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum InstallMode { pub enum InstallMode {
Development, Development,
@ -149,7 +124,6 @@ impl Default for WizardConfig {
} }
} }
#[derive(Debug)] #[derive(Debug)]
pub struct StartupWizard { pub struct StartupWizard {
config: WizardConfig, config: WizardConfig,
@ -166,12 +140,10 @@ impl StartupWizard {
} }
} }
pub fn run(&mut self) -> io::Result<WizardConfig> { pub fn run(&mut self) -> io::Result<WizardConfig> {
terminal::enable_raw_mode()?; terminal::enable_raw_mode()?;
let mut stdout = io::stdout(); let mut stdout = io::stdout();
execute!( execute!(
stdout, stdout,
terminal::Clear(ClearType::All), terminal::Clear(ClearType::All),
@ -181,31 +153,24 @@ impl StartupWizard {
self.show_welcome(&mut stdout)?; self.show_welcome(&mut stdout)?;
self.wait_for_enter()?; self.wait_for_enter()?;
self.current_step = 1; self.current_step = 1;
self.step_install_mode(&mut stdout)?; self.step_install_mode(&mut stdout)?;
self.current_step = 2; self.current_step = 2;
self.step_llm_provider(&mut stdout)?; self.step_llm_provider(&mut stdout)?;
self.current_step = 3; self.current_step = 3;
self.step_components(&mut stdout)?; self.step_components(&mut stdout)?;
self.current_step = 4; self.current_step = 4;
self.step_organization(&mut stdout)?; self.step_organization(&mut stdout)?;
self.current_step = 5; self.current_step = 5;
self.step_admin_user(&mut stdout)?; self.step_admin_user(&mut stdout)?;
self.current_step = 6; self.current_step = 6;
self.step_template(&mut stdout)?; self.step_template(&mut stdout)?;
self.current_step = 7; self.current_step = 7;
self.step_summary(&mut stdout)?; self.step_summary(&mut stdout)?;
@ -220,7 +185,7 @@ impl StartupWizard {
cursor::MoveTo(0, 0) cursor::MoveTo(0, 0)
)?; )?;
let banner = r#" let banner = r"
@ -237,7 +202,7 @@ impl StartupWizard {
"#; ";
execute!( execute!(
stdout, stdout,
@ -288,7 +253,6 @@ impl StartupWizard {
cursor::MoveTo(0, 0) cursor::MoveTo(0, 0)
)?; )?;
let progress = format!("Step {}/{}: {}", self.current_step, self.total_steps, title); let progress = format!("Step {}/{}: {}", self.current_step, self.total_steps, title);
let bar_width = 50; let bar_width = 50;
let filled = (self.current_step * bar_width) / self.total_steps; let filled = (self.current_step * bar_width) / self.total_steps;
@ -396,7 +360,6 @@ impl StartupWizard {
let selected = self.select_option(stdout, &options, 0)?; let selected = self.select_option(stdout, &options, 0)?;
self.config.llm_provider = options[selected].2.clone(); self.config.llm_provider = options[selected].2.clone();
if self.config.llm_provider != LlmProvider::Local if self.config.llm_provider != LlmProvider::Local
&& self.config.llm_provider != LlmProvider::None && self.config.llm_provider != LlmProvider::None
{ {
@ -485,7 +448,6 @@ impl StartupWizard {
io::stdin().read_line(&mut org_name)?; io::stdin().read_line(&mut org_name)?;
self.config.organization.name = org_name.trim().to_string(); self.config.organization.name = org_name.trim().to_string();
self.config.organization.slug = self self.config.organization.slug = self
.config .config
.organization .organization
@ -557,7 +519,6 @@ impl StartupWizard {
execute!(stdout, cursor::MoveTo(2, 13), Print("Admin password: "))?; execute!(stdout, cursor::MoveTo(2, 13), Print("Admin password: "))?;
stdout.flush()?; stdout.flush()?;
let mut password = String::new(); let mut password = String::new();
io::stdin().read_line(&mut password)?; io::stdin().read_line(&mut password)?;
self.config.admin.password = password.trim().to_string(); self.config.admin.password = password.trim().to_string();
@ -824,7 +785,6 @@ impl StartupWizard {
} }
KeyCode::Char(' ') => { KeyCode::Char(' ') => {
if options[cursor].2 { if options[cursor].2 {
selected[cursor] = !selected[cursor]; selected[cursor] = !selected[cursor];
} }
} }
@ -857,7 +817,6 @@ impl StartupWizard {
} }
} }
pub fn save_wizard_config(config: &WizardConfig, path: &str) -> io::Result<()> { pub fn save_wizard_config(config: &WizardConfig, path: &str) -> io::Result<()> {
let content = toml::to_string_pretty(config) let content = toml::to_string_pretty(config)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
@ -865,7 +824,6 @@ pub fn save_wizard_config(config: &WizardConfig, path: &str) -> io::Result<()> {
Ok(()) Ok(())
} }
pub fn load_wizard_config(path: &str) -> io::Result<WizardConfig> { pub fn load_wizard_config(path: &str) -> io::Result<WizardConfig> {
let content = std::fs::read_to_string(path)?; let content = std::fs::read_to_string(path)?;
let config: WizardConfig = let config: WizardConfig =
@ -873,32 +831,26 @@ pub fn load_wizard_config(path: &str) -> io::Result<WizardConfig> {
Ok(config) Ok(config)
} }
pub fn should_run_wizard() -> bool { pub fn should_run_wizard() -> bool {
!std::path::Path::new("./botserver-stack").exists() !std::path::Path::new("./botserver-stack").exists()
&& !std::path::Path::new("/opt/gbo").exists() && !std::path::Path::new("/opt/gbo").exists()
} }
pub fn apply_wizard_config(config: &WizardConfig) -> io::Result<()> { pub fn apply_wizard_config(config: &WizardConfig) -> io::Result<()> {
use std::fs; use std::fs;
fs::create_dir_all(&config.data_dir)?; fs::create_dir_all(&config.data_dir)?;
let subdirs = ["bots", "logs", "cache", "uploads", "config"]; let subdirs = ["bots", "logs", "cache", "uploads", "config"];
for subdir in &subdirs { for subdir in &subdirs {
fs::create_dir_all(config.data_dir.join(subdir))?; fs::create_dir_all(config.data_dir.join(subdir))?;
} }
save_wizard_config( save_wizard_config(
config, config,
&config.data_dir.join("config/wizard.toml").to_string_lossy(), &config.data_dir.join("config/wizard.toml").to_string_lossy(),
)?; )?;
let mut env_content = String::new(); let mut env_content = String::new();
env_content.push_str(&format!( env_content.push_str(&format!(
"# Generated by {} Setup Wizard\n\n", "# Generated by {} Setup Wizard\n\n",

File diff suppressed because it is too large Load diff

View file

@ -71,7 +71,6 @@ impl PackageManager {
.await?; .await?;
} }
if !component.data_download_list.is_empty() { if !component.data_download_list.is_empty() {
let cache_base = self.base_path.parent().unwrap_or(&self.base_path); let cache_base = self.base_path.parent().unwrap_or(&self.base_path);
let cache = DownloadCache::new(cache_base).ok(); let cache = DownloadCache::new(cache_base).ok();
@ -83,18 +82,15 @@ impl PackageManager {
.join(&component.name) .join(&component.name)
.join(&filename); .join(&filename);
if output_path.exists() { if output_path.exists() {
info!("Data file already exists: {:?}", output_path); info!("Data file already exists: {:?}", output_path);
continue; continue;
} }
if let Some(parent) = output_path.parent() { if let Some(parent) = output_path.parent() {
std::fs::create_dir_all(parent)?; std::fs::create_dir_all(parent)?;
} }
if let Some(ref c) = cache { if let Some(ref c) = cache {
if let Some(cached_path) = c.get_cached_path(&filename) { if let Some(cached_path) = c.get_cached_path(&filename) {
info!("Using cached data file: {:?}", cached_path); info!("Using cached data file: {:?}", cached_path);
@ -103,7 +99,6 @@ impl PackageManager {
} }
} }
let download_target = if let Some(ref c) = cache { let download_target = if let Some(ref c) = cache {
c.get_cache_path(&filename) c.get_cache_path(&filename)
} else { } else {
@ -114,7 +109,6 @@ impl PackageManager {
println!("Downloading {}", url); println!("Downloading {}", url);
utils::download_file(url, download_target.to_str().unwrap()).await?; utils::download_file(url, download_target.to_str().unwrap()).await?;
if cache.is_some() && download_target != output_path { if cache.is_some() && download_target != output_path {
std::fs::copy(&download_target, &output_path)?; std::fs::copy(&download_target, &output_path)?;
info!("Copied cached file to: {:?}", output_path); info!("Copied cached file to: {:?}", output_path);
@ -127,10 +121,8 @@ impl PackageManager {
pub fn install_container(&self, component: &ComponentConfig) -> Result<InstallResult> { pub fn install_container(&self, component: &ComponentConfig) -> Result<InstallResult> {
let container_name = format!("{}-{}", self.tenant, component.name); let container_name = format!("{}-{}", self.tenant, component.name);
let _ = Command::new("lxd").args(&["init", "--auto"]).output(); let _ = Command::new("lxd").args(&["init", "--auto"]).output();
let images = [ let images = [
"ubuntu:24.04", "ubuntu:24.04",
"ubuntu:22.04", "ubuntu:22.04",
@ -157,14 +149,13 @@ impl PackageManager {
info!("Successfully created container with image: {}", image); info!("Successfully created container with image: {}", image);
success = true; success = true;
break; break;
} else {
last_error = String::from_utf8_lossy(&output.stderr).to_string();
warn!("Failed to create container with {}: {}", image, last_error);
let _ = Command::new("lxc")
.args(&["delete", &container_name, "--force"])
.output();
} }
last_error = String::from_utf8_lossy(&output.stderr).to_string();
warn!("Failed to create container with {}: {}", image, last_error);
let _ = Command::new("lxc")
.args(&["delete", &container_name, "--force"])
.output();
} }
if !success { if !success {
@ -176,7 +167,6 @@ impl PackageManager {
std::thread::sleep(std::time::Duration::from_secs(15)); std::thread::sleep(std::time::Duration::from_secs(15));
self.exec_in_container(&container_name, "mkdir -p /opt/gbo/{bin,data,conf,logs}")?; self.exec_in_container(&container_name, "mkdir -p /opt/gbo/{bin,data,conf,logs}")?;
self.exec_in_container( self.exec_in_container(
&container_name, &container_name,
"echo 'nameserver 8.8.8.8' > /etc/resolv.conf", "echo 'nameserver 8.8.8.8' > /etc/resolv.conf",
@ -186,7 +176,6 @@ impl PackageManager {
"echo 'nameserver 8.8.4.4' >> /etc/resolv.conf", "echo 'nameserver 8.8.4.4' >> /etc/resolv.conf",
)?; )?;
self.exec_in_container(&container_name, "apt-get update -qq")?; self.exec_in_container(&container_name, "apt-get update -qq")?;
self.exec_in_container( self.exec_in_container(
&container_name, &container_name,
@ -247,15 +236,12 @@ impl PackageManager {
} }
self.setup_port_forwarding(&container_name, &component.ports)?; self.setup_port_forwarding(&container_name, &component.ports)?;
let container_ip = self.get_container_ip(&container_name)?; let container_ip = self.get_container_ip(&container_name)?;
if component.name == "vault" { if component.name == "vault" {
self.initialize_vault(&container_name, &container_ip)?; self.initialize_vault(&container_name, &container_ip)?;
} }
let (connection_info, env_vars) = let (connection_info, env_vars) =
self.generate_connection_info(&component.name, &container_ip, &component.ports); self.generate_connection_info(&component.name, &container_ip, &component.ports);
@ -275,12 +261,9 @@ impl PackageManager {
}) })
} }
fn get_container_ip(&self, container_name: &str) -> Result<String> { fn get_container_ip(&self, container_name: &str) -> Result<String> {
std::thread::sleep(std::time::Duration::from_secs(2)); std::thread::sleep(std::time::Duration::from_secs(2));
let output = Command::new("lxc") let output = Command::new("lxc")
.args(&["list", container_name, "-c", "4", "--format", "csv"]) .args(&["list", container_name, "-c", "4", "--format", "csv"])
.output()?; .output()?;
@ -289,7 +272,6 @@ impl PackageManager {
let ip_output = String::from_utf8_lossy(&output.stdout).trim().to_string(); let ip_output = String::from_utf8_lossy(&output.stdout).trim().to_string();
if !ip_output.is_empty() { if !ip_output.is_empty() {
let ip = ip_output let ip = ip_output
.split(|c| c == ' ' || c == '(') .split(|c| c == ' ' || c == '(')
.next() .next()
@ -301,7 +283,6 @@ impl PackageManager {
} }
} }
let output = Command::new("lxc") let output = Command::new("lxc")
.args(&["exec", container_name, "--", "hostname", "-I"]) .args(&["exec", container_name, "--", "hostname", "-I"])
.output()?; .output()?;
@ -318,15 +299,11 @@ impl PackageManager {
Ok("unknown".to_string()) Ok("unknown".to_string())
} }
fn initialize_vault(&self, container_name: &str, ip: &str) -> Result<()> { fn initialize_vault(&self, container_name: &str, ip: &str) -> Result<()> {
info!("Initializing Vault..."); info!("Initializing Vault...");
std::thread::sleep(std::time::Duration::from_secs(5)); std::thread::sleep(std::time::Duration::from_secs(5));
let output = Command::new("lxc") let output = Command::new("lxc")
.args(&[ .args(&[
"exec", "exec",
@ -350,7 +327,6 @@ impl PackageManager {
let init_output = String::from_utf8_lossy(&output.stdout); let init_output = String::from_utf8_lossy(&output.stdout);
let init_json: serde_json::Value = let init_json: serde_json::Value =
serde_json::from_str(&init_output).context("Failed to parse Vault init output")?; serde_json::from_str(&init_output).context("Failed to parse Vault init output")?;
@ -361,12 +337,10 @@ impl PackageManager {
.as_str() .as_str()
.context("No root token in output")?; .context("No root token in output")?;
let unseal_keys_file = PathBuf::from("vault-unseal-keys"); let unseal_keys_file = PathBuf::from("vault-unseal-keys");
let mut unseal_content = String::new(); let mut unseal_content = String::new();
for (i, key) in unseal_keys.iter().enumerate() { for (i, key) in unseal_keys.iter().enumerate() {
if i < 3 { if i < 3 {
unseal_content.push_str(&format!( unseal_content.push_str(&format!(
"VAULT_UNSEAL_KEY_{}={}\n", "VAULT_UNSEAL_KEY_{}={}\n",
i + 1, i + 1,
@ -376,7 +350,6 @@ impl PackageManager {
} }
std::fs::write(&unseal_keys_file, &unseal_content)?; std::fs::write(&unseal_keys_file, &unseal_content)?;
#[cfg(unix)] #[cfg(unix)]
{ {
use std::os::unix::fs::PermissionsExt; use std::os::unix::fs::PermissionsExt;
@ -385,7 +358,6 @@ impl PackageManager {
info!("Created {}", unseal_keys_file.display()); info!("Created {}", unseal_keys_file.display());
let env_file = PathBuf::from(".env"); let env_file = PathBuf::from(".env");
let env_content = format!( let env_content = format!(
"\n# Vault Configuration (auto-generated)\nVAULT_ADDR=http://{}:8200\nVAULT_TOKEN={}\nVAULT_UNSEAL_KEYS_FILE=vault-unseal-keys\n", "\n# Vault Configuration (auto-generated)\nVAULT_ADDR=http://{}:8200\nVAULT_TOKEN={}\nVAULT_UNSEAL_KEYS_FILE=vault-unseal-keys\n",
@ -393,11 +365,9 @@ impl PackageManager {
); );
if env_file.exists() { if env_file.exists() {
let existing = std::fs::read_to_string(&env_file)?; let existing = std::fs::read_to_string(&env_file)?;
if !existing.contains("VAULT_ADDR=") { if !existing.contains("VAULT_ADDR=") {
let mut file = std::fs::OpenOptions::new().append(true).open(&env_file)?; let mut file = std::fs::OpenOptions::new().append(true).open(&env_file)?;
use std::io::Write; use std::io::Write;
file.write_all(env_content.as_bytes())?; file.write_all(env_content.as_bytes())?;
@ -406,13 +376,10 @@ impl PackageManager {
warn!(".env already contains VAULT_ADDR, not overwriting"); warn!(".env already contains VAULT_ADDR, not overwriting");
} }
} else { } else {
std::fs::write(&env_file, env_content.trim_start())?; std::fs::write(&env_file, env_content.trim_start())?;
info!("Created .env with Vault config"); info!("Created .env with Vault config");
} }
for i in 0..3 { for i in 0..3 {
if let Some(key) = unseal_keys.get(i) { if let Some(key) = unseal_keys.get(i) {
let key_str = key.as_str().unwrap_or(""); let key_str = key.as_str().unwrap_or("");
@ -434,8 +401,6 @@ impl PackageManager {
Ok(()) Ok(())
} }
fn generate_connection_info( fn generate_connection_info(
&self, &self,
component: &str, component: &str,
@ -445,9 +410,8 @@ impl PackageManager {
let env_vars = HashMap::new(); let env_vars = HashMap::new();
let connection_info = match component { let connection_info = match component {
"vault" => { "vault" => {
format!( format!(
r#"Vault Server: r"Vault Server:
URL: http://{}:8200 URL: http://{}:8200
UI: http://{}:8200/ui UI: http://{}:8200/ui
@ -466,141 +430,141 @@ Or manually:
lxc exec {}-vault -- /opt/gbo/bin/vault operator unseal <key> lxc exec {}-vault -- /opt/gbo/bin/vault operator unseal <key>
For other auto-unseal options (TPM, HSM, Transit), see: For other auto-unseal options (TPM, HSM, Transit), see:
https://generalbots.github.io/botbook/chapter-08/secrets-management.html"#, https://generalbots.github.io/botbook/chapter-08/secrets-management.html",
ip, ip, self.tenant ip, ip, self.tenant
) )
} }
"vector_db" => { "vector_db" => {
format!( format!(
r#"Qdrant Vector Database: r"Qdrant Vector Database:
REST API: http://{}:6333 REST API: http://{}:6333
gRPC: {}:6334 gRPC: {}:6334
Dashboard: http://{}:6333/dashboard Dashboard: http://{}:6333/dashboard
Store credentials in Vault: Store credentials in Vault:
botserver vault put gbo/vectordb host={} port=6333"#, botserver vault put gbo/vectordb host={} port=6333",
ip, ip, ip, ip ip, ip, ip, ip
) )
} }
"tables" => { "tables" => {
format!( format!(
r#"PostgreSQL Database: r"PostgreSQL Database:
Host: {} Host: {}
Port: 5432 Port: 5432
Database: botserver Database: botserver
User: gbuser User: gbuser
Store credentials in Vault: Store credentials in Vault:
botserver vault put gbo/tables host={} port=5432 database=botserver username=gbuser password=<your-password>"#, botserver vault put gbo/tables host={} port=5432 database=botserver username=gbuser password=<your-password>",
ip, ip ip, ip
) )
} }
"drive" => { "drive" => {
format!( format!(
r#"MinIO Object Storage: r"MinIO Object Storage:
API: http://{}:9000 API: http://{}:9000
Console: http://{}:9001 Console: http://{}:9001
Store credentials in Vault: Store credentials in Vault:
botserver vault put gbo/drive server={} port=9000 accesskey=minioadmin secret=<your-secret>"#, botserver vault put gbo/drive server={} port=9000 accesskey=minioadmin secret=<your-secret>",
ip, ip, ip ip, ip, ip
) )
} }
"cache" => { "cache" => {
format!( format!(
r#"Redis/Valkey Cache: r"Redis/Valkey Cache:
Host: {} Host: {}
Port: 6379 Port: 6379
Store credentials in Vault: Store credentials in Vault:
botserver vault put gbo/cache host={} port=6379 password=<your-password>"#, botserver vault put gbo/cache host={} port=6379 password=<your-password>",
ip, ip ip, ip
) )
} }
"email" => { "email" => {
format!( format!(
r#"Email Server (Stalwart): r"Email Server (Stalwart):
SMTP: {}:25 SMTP: {}:25
IMAP: {}:143 IMAP: {}:143
Web: http://{}:8080 Web: http://{}:8080
Store credentials in Vault: Store credentials in Vault:
botserver vault put gbo/email server={} port=25 username=admin password=<your-password>"#, botserver vault put gbo/email server={} port=25 username=admin password=<your-password>",
ip, ip, ip, ip ip, ip, ip, ip
) )
} }
"directory" => { "directory" => {
format!( format!(
r#"Zitadel Identity Provider: r"Zitadel Identity Provider:
URL: http://{}:8080 URL: http://{}:8080
Console: http://{}:8080/ui/console Console: http://{}:8080/ui/console
Store credentials in Vault: Store credentials in Vault:
botserver vault put gbo/directory url=http://{}:8080 client_id=<client-id> client_secret=<client-secret>"#, botserver vault put gbo/directory url=http://{}:8080 client_id=<client-id> client_secret=<client-secret>",
ip, ip, ip ip, ip, ip
) )
} }
"llm" => { "llm" => {
format!( format!(
r#"LLM Server (llama.cpp): r"LLM Server (llama.cpp):
API: http://{}:8081 API: http://{}:8081
Test: Test:
curl http://{}:8081/v1/models curl http://{}:8081/v1/models
Store credentials in Vault: Store credentials in Vault:
botserver vault put gbo/llm url=http://{}:8081 local=true"#, botserver vault put gbo/llm url=http://{}:8081 local=true",
ip, ip, ip ip, ip, ip
) )
} }
"meeting" => { "meeting" => {
format!( format!(
r#"LiveKit Meeting Server: r"LiveKit Meeting Server:
WebSocket: ws://{}:7880 WebSocket: ws://{}:7880
API: http://{}:7880 API: http://{}:7880
Store credentials in Vault: Store credentials in Vault:
botserver vault put gbo/meet url=ws://{}:7880 api_key=<api-key> api_secret=<api-secret>"#, botserver vault put gbo/meet url=ws://{}:7880 api_key=<api-key> api_secret=<api-secret>",
ip, ip, ip ip, ip, ip
) )
} }
"proxy" => { "proxy" => {
format!( format!(
r#"Caddy Reverse Proxy: r"Caddy Reverse Proxy:
HTTP: http://{}:80 HTTP: http://{}:80
HTTPS: https://{}:443 HTTPS: https://{}:443
Admin: http://{}:2019"#, Admin: http://{}:2019",
ip, ip, ip ip, ip, ip
) )
} }
"timeseries_db" => { "timeseries_db" => {
format!( format!(
r#"InfluxDB Time Series Database: r"InfluxDB Time Series Database:
API: http://{}:8086 API: http://{}:8086
Store credentials in Vault: Store credentials in Vault:
botserver vault put gbo/observability url=http://{}:8086 token=<influx-token> org=pragmatismo bucket=metrics"#, botserver vault put gbo/observability url=http://{}:8086 token=<influx-token> org=pragmatismo bucket=metrics",
ip, ip ip, ip
) )
} }
"observability" => { "observability" => {
format!( format!(
r#"Vector Log Aggregation: r"Vector Log Aggregation:
API: http://{}:8686 API: http://{}:8686
Store credentials in Vault: Store credentials in Vault:
botserver vault put gbo/observability vector_url=http://{}:8686"#, botserver vault put gbo/observability vector_url=http://{}:8686",
ip, ip ip, ip
) )
} }
"alm" => { "alm" => {
format!( format!(
r#"Forgejo Git Server: r"Forgejo Git Server:
Web: http://{}:3000 Web: http://{}:3000
SSH: {}:22 SSH: {}:22
Store credentials in Vault: Store credentials in Vault:
botserver vault put gbo/alm url=http://{}:3000 token=<api-token>"#, botserver vault put gbo/alm url=http://{}:3000 token=<api-token>",
ip, ip, ip ip, ip, ip
) )
} }
@ -611,11 +575,11 @@ Store credentials in Vault:
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join("\n"); .join("\n");
format!( format!(
r#"Component: {} r"Component: {}
Container: {}-{} Container: {}-{}
IP: {} IP: {}
Ports: Ports:
{}"#, {}",
component, self.tenant, component, ip, ports_str component, self.tenant, component, ip, ports_str
) )
} }
@ -740,7 +704,6 @@ Store credentials in Vault:
let bin_path = self.base_path.join("bin").join(component); let bin_path = self.base_path.join("bin").join(component);
std::fs::create_dir_all(&bin_path)?; std::fs::create_dir_all(&bin_path)?;
let cache_base = self.base_path.parent().unwrap_or(&self.base_path); let cache_base = self.base_path.parent().unwrap_or(&self.base_path);
let cache = DownloadCache::new(cache_base).unwrap_or_else(|e| { let cache = DownloadCache::new(cache_base).unwrap_or_else(|e| {
warn!("Failed to initialize download cache: {}", e); warn!("Failed to initialize download cache: {}", e);
@ -748,7 +711,6 @@ Store credentials in Vault:
DownloadCache::new(&self.base_path).expect("Failed to create fallback cache") DownloadCache::new(&self.base_path).expect("Failed to create fallback cache")
}); });
let cache_result = cache.resolve_component_url(component, url); let cache_result = cache.resolve_component_url(component, url);
let source_file = match cache_result { let source_file = match cache_result {
@ -763,7 +725,6 @@ Store credentials in Vault:
info!("Downloading {} from {}", component, download_url); info!("Downloading {} from {}", component, download_url);
println!("Downloading {}", download_url); println!("Downloading {}", download_url);
self.download_with_reqwest(&download_url, &cache_path, component) self.download_with_reqwest(&download_url, &cache_path, component)
.await?; .await?;
@ -772,7 +733,6 @@ Store credentials in Vault:
} }
}; };
self.handle_downloaded_file(&source_file, &bin_path, binary_name)?; self.handle_downloaded_file(&source_file, &bin_path, binary_name)?;
Ok(()) Ok(())
} }
@ -785,7 +745,6 @@ Store credentials in Vault:
const MAX_RETRIES: u32 = 3; const MAX_RETRIES: u32 = 3;
const RETRY_DELAY: std::time::Duration = std::time::Duration::from_secs(2); const RETRY_DELAY: std::time::Duration = std::time::Duration::from_secs(2);
if let Some(parent) = target_file.parent() { if let Some(parent) = target_file.parent() {
std::fs::create_dir_all(parent)?; std::fs::create_dir_all(parent)?;
} }
@ -870,7 +829,6 @@ Store credentials in Vault:
} else { } else {
let final_path = bin_path.join(temp_file.file_name().unwrap()); let final_path = bin_path.join(temp_file.file_name().unwrap());
if temp_file.to_string_lossy().contains("botserver-installers") { if temp_file.to_string_lossy().contains("botserver-installers") {
std::fs::copy(temp_file, &final_path)?; std::fs::copy(temp_file, &final_path)?;
} else { } else {
@ -894,7 +852,6 @@ Store credentials in Vault:
)); ));
} }
if !temp_file.to_string_lossy().contains("botserver-installers") { if !temp_file.to_string_lossy().contains("botserver-installers") {
std::fs::remove_file(temp_file)?; std::fs::remove_file(temp_file)?;
} }
@ -912,7 +869,6 @@ Store credentials in Vault:
)); ));
} }
#[cfg(unix)] #[cfg(unix)]
{ {
use std::os::unix::fs::PermissionsExt; use std::os::unix::fs::PermissionsExt;
@ -935,8 +891,6 @@ Store credentials in Vault:
} }
} }
if !temp_file.to_string_lossy().contains("botserver-installers") { if !temp_file.to_string_lossy().contains("botserver-installers") {
std::fs::remove_file(temp_file)?; std::fs::remove_file(temp_file)?;
} }
@ -950,7 +904,6 @@ Store credentials in Vault:
) -> Result<()> { ) -> Result<()> {
let final_path = bin_path.join(name); let final_path = bin_path.join(name);
if temp_file.to_string_lossy().contains("botserver-installers") { if temp_file.to_string_lossy().contains("botserver-installers") {
std::fs::copy(temp_file, &final_path)?; std::fs::copy(temp_file, &final_path)?;
} else { } else {
@ -981,7 +934,6 @@ Store credentials in Vault:
PathBuf::from("/opt/gbo/data") PathBuf::from("/opt/gbo/data")
}; };
let conf_path = if target == "local" { let conf_path = if target == "local" {
self.base_path.join("conf") self.base_path.join("conf")
} else { } else {
@ -993,15 +945,12 @@ Store credentials in Vault:
PathBuf::from("/opt/gbo/logs") PathBuf::from("/opt/gbo/logs")
}; };
let db_password = match get_database_url_sync() { let db_password = match get_database_url_sync() {
Ok(url) => { Ok(url) => {
let (_, password, _, _, _) = parse_database_url(&url); let (_, password, _, _, _) = parse_database_url(&url);
password password
} }
Err(_) => { Err(_) => {
trace!("Vault not available for DB_PASSWORD, using empty string"); trace!("Vault not available for DB_PASSWORD, using empty string");
String::new() String::new()
} }
@ -1124,8 +1073,6 @@ Store credentials in Vault:
exec_cmd: &str, exec_cmd: &str,
env_vars: &HashMap<String, String>, env_vars: &HashMap<String, String>,
) -> Result<()> { ) -> Result<()> {
let db_password = match get_database_url_sync() { let db_password = match get_database_url_sync() {
Ok(url) => { Ok(url) => {
let (_, password, _, _, _) = parse_database_url(&url); let (_, password, _, _, _) = parse_database_url(&url);

View file

@ -5,7 +5,6 @@ use std::time::Duration;
use tokio::fs; use tokio::fs;
use tokio::time::sleep; use tokio::time::sleep;
#[derive(Debug)] #[derive(Debug)]
pub struct EmailSetup { pub struct EmailSetup {
base_url: String, base_url: String,
@ -34,7 +33,6 @@ pub struct EmailDomain {
impl EmailSetup { impl EmailSetup {
pub fn new(base_url: String, config_path: PathBuf) -> Self { pub fn new(base_url: String, config_path: PathBuf) -> Self {
let admin_user = format!( let admin_user = format!(
"admin_{}@botserver.local", "admin_{}@botserver.local",
uuid::Uuid::new_v4() uuid::Uuid::new_v4()
@ -53,7 +51,6 @@ impl EmailSetup {
} }
} }
fn generate_secure_password() -> String { fn generate_secure_password() -> String {
use rand::distr::Alphanumeric; use rand::distr::Alphanumeric;
use rand::Rng; use rand::Rng;
@ -66,12 +63,10 @@ impl EmailSetup {
.collect() .collect()
} }
pub async fn wait_for_ready(&self, max_attempts: u32) -> Result<()> { pub async fn wait_for_ready(&self, max_attempts: u32) -> Result<()> {
log::info!("Waiting for Email service to be ready..."); log::info!("Waiting for Email service to be ready...");
for attempt in 1..=max_attempts { for attempt in 1..=max_attempts {
if let Ok(_) = tokio::net::TcpStream::connect("127.0.0.1:25").await { if let Ok(_) = tokio::net::TcpStream::connect("127.0.0.1:25").await {
log::info!("Email service is ready!"); log::info!("Email service is ready!");
return Ok(()); return Ok(());
@ -88,27 +83,22 @@ impl EmailSetup {
anyhow::bail!("Email service did not become ready in time") anyhow::bail!("Email service did not become ready in time")
} }
pub async fn initialize( pub async fn initialize(
&mut self, &mut self,
directory_config_path: Option<PathBuf>, directory_config_path: Option<PathBuf>,
) -> Result<EmailConfig> { ) -> Result<EmailConfig> {
log::info!(" Initializing Email (Stalwart) server..."); log::info!(" Initializing Email (Stalwart) server...");
if let Ok(existing_config) = self.load_existing_config().await { if let Ok(existing_config) = self.load_existing_config().await {
log::info!("Email already initialized, using existing config"); log::info!("Email already initialized, using existing config");
return Ok(existing_config); return Ok(existing_config);
} }
self.wait_for_ready(30).await?; self.wait_for_ready(30).await?;
self.create_default_domain().await?; self.create_default_domain().await?;
log::info!(" Created default email domain: localhost"); log::info!(" Created default email domain: localhost");
let directory_integration = if let Some(dir_config_path) = directory_config_path { let directory_integration = if let Some(dir_config_path) = directory_config_path {
match self.setup_directory_integration(&dir_config_path).await { match self.setup_directory_integration(&dir_config_path).await {
Ok(_) => { Ok(_) => {
@ -124,7 +114,6 @@ impl EmailSetup {
false false
}; };
self.create_admin_account().await?; self.create_admin_account().await?;
log::info!(" Created admin email account: {}", self.admin_user); log::info!(" Created admin email account: {}", self.admin_user);
@ -139,7 +128,6 @@ impl EmailSetup {
directory_integration, directory_integration,
}; };
self.save_config(&config).await?; self.save_config(&config).await?;
log::info!(" Saved Email configuration"); log::info!(" Saved Email configuration");
@ -151,14 +139,10 @@ impl EmailSetup {
Ok(config) Ok(config)
} }
async fn create_default_domain(&self) -> Result<()> { async fn create_default_domain(&self) -> Result<()> {
Ok(()) Ok(())
} }
async fn create_admin_account(&self) -> Result<()> { async fn create_admin_account(&self) -> Result<()> {
log::info!("Creating admin email account via Stalwart API..."); log::info!("Creating admin email account via Stalwart API...");
@ -166,14 +150,13 @@ impl EmailSetup {
.timeout(Duration::from_secs(30)) .timeout(Duration::from_secs(30))
.build()?; .build()?;
let api_url = format!("{}/api/account", self.base_url); let api_url = format!("{}/api/account", self.base_url);
let account_data = serde_json::json!({ let account_data = serde_json::json!({
"name": self.admin_user, "name": self.admin_user,
"secret": self.admin_pass, "secret": self.admin_pass,
"description": "BotServer Admin Account", "description": "BotServer Admin Account",
"quota": 1073741824, "quota": 1_073_741_824,
"type": "individual", "type": "individual",
"emails": [self.admin_user.clone()], "emails": [self.admin_user.clone()],
"memberOf": ["administrators"], "memberOf": ["administrators"],
@ -196,7 +179,6 @@ impl EmailSetup {
); );
Ok(()) Ok(())
} else if resp.status().as_u16() == 409 { } else if resp.status().as_u16() == 409 {
log::info!("Admin email account already exists: {}", self.admin_user); log::info!("Admin email account already exists: {}", self.admin_user);
Ok(()) Ok(())
} else { } else {
@ -222,7 +204,6 @@ impl EmailSetup {
} }
} }
async fn setup_directory_integration(&self, directory_config_path: &PathBuf) -> Result<()> { async fn setup_directory_integration(&self, directory_config_path: &PathBuf) -> Result<()> {
let content = fs::read_to_string(directory_config_path).await?; let content = fs::read_to_string(directory_config_path).await?;
let dir_config: serde_json::Value = serde_json::from_str(&content)?; let dir_config: serde_json::Value = serde_json::from_str(&content)?;
@ -234,31 +215,25 @@ impl EmailSetup {
log::info!("Setting up OIDC authentication with Directory..."); log::info!("Setting up OIDC authentication with Directory...");
log::info!("Issuer URL: {}", issuer_url); log::info!("Issuer URL: {}", issuer_url);
Ok(()) Ok(())
} }
async fn save_config(&self, config: &EmailConfig) -> Result<()> { async fn save_config(&self, config: &EmailConfig) -> Result<()> {
let json = serde_json::to_string_pretty(config)?; let json = serde_json::to_string_pretty(config)?;
fs::write(&self.config_path, json).await?; fs::write(&self.config_path, json).await?;
Ok(()) Ok(())
} }
async fn load_existing_config(&self) -> Result<EmailConfig> { async fn load_existing_config(&self) -> Result<EmailConfig> {
let content = fs::read_to_string(&self.config_path).await?; let content = fs::read_to_string(&self.config_path).await?;
let config: EmailConfig = serde_json::from_str(&content)?; let config: EmailConfig = serde_json::from_str(&content)?;
Ok(config) Ok(config)
} }
pub async fn get_config(&self) -> Result<EmailConfig> { pub async fn get_config(&self) -> Result<EmailConfig> {
self.load_existing_config().await self.load_existing_config().await
} }
pub async fn create_user_mailbox( pub async fn create_user_mailbox(
&self, &self,
_username: &str, _username: &str,
@ -267,20 +242,15 @@ impl EmailSetup {
) -> Result<()> { ) -> Result<()> {
log::info!("Creating mailbox for user: {}", email); log::info!("Creating mailbox for user: {}", email);
Ok(()) Ok(())
} }
pub async fn sync_users_from_directory(&self, directory_config_path: &PathBuf) -> Result<()> { pub async fn sync_users_from_directory(&self, directory_config_path: &PathBuf) -> Result<()> {
log::info!("Syncing users from Directory to Email..."); log::info!("Syncing users from Directory to Email...");
let content = fs::read_to_string(directory_config_path).await?; let content = fs::read_to_string(directory_config_path).await?;
let dir_config: serde_json::Value = serde_json::from_str(&content)?; let dir_config: serde_json::Value = serde_json::from_str(&content)?;
if let Some(default_user) = dir_config.get("default_user") { if let Some(default_user) = dir_config.get("default_user") {
let email = default_user["email"].as_str().unwrap_or(""); let email = default_user["email"].as_str().unwrap_or("");
let password = default_user["password"].as_str().unwrap_or(""); let password = default_user["password"].as_str().unwrap_or("");
@ -296,7 +266,6 @@ impl EmailSetup {
} }
} }
pub async fn generate_email_config( pub async fn generate_email_config(
config_path: PathBuf, config_path: PathBuf,
data_path: PathBuf, data_path: PathBuf,
@ -352,7 +321,6 @@ store = "sqlite"
data_path.display() data_path.display()
); );
if directory_integration { if directory_integration {
config.push_str( config.push_str(
r#" r#"

View file

@ -10,10 +10,7 @@ use uuid::Uuid;
#[cfg(feature = "vectordb")] #[cfg(feature = "vectordb")]
use qdrant_client::{ use qdrant_client::{
qdrant::{ qdrant::{Distance, PointStruct, VectorParams},
vectors_config::Config, CreateCollection, Distance, PointStruct, VectorParams,
VectorsConfig,
},
Qdrant, Qdrant,
}; };

File diff suppressed because it is too large Load diff

View file

@ -1,23 +1,3 @@
use anyhow::{anyhow, Context, Result}; use anyhow::{anyhow, Context, Result};
use chrono::{DateTime, NaiveDate, Utc}; use chrono::{DateTime, NaiveDate, Utc};
use reqwest::{Client, Method, StatusCode}; use reqwest::{Client, Method, StatusCode};
@ -26,23 +6,14 @@ use serde_json::{json, Value};
use std::time::Duration; use std::time::Duration;
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
const DEFAULT_TIMEOUT_SECS: u64 = 30; const DEFAULT_TIMEOUT_SECS: u64 = 30;
pub const DEFAULT_QUEUE_POLL_INTERVAL_SECS: u64 = 30; pub const DEFAULT_QUEUE_POLL_INTERVAL_SECS: u64 = 30;
pub const DEFAULT_METRICS_POLL_INTERVAL_SECS: u64 = 60; pub const DEFAULT_METRICS_POLL_INTERVAL_SECS: u64 = 60;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueueStatus { pub struct QueueStatus {
pub is_running: bool, pub is_running: bool,
pub total_queued: u64, pub total_queued: u64,
@ -50,10 +21,8 @@ pub struct QueueStatus {
pub messages: Vec<QueuedMessage>, pub messages: Vec<QueuedMessage>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueuedMessage { pub struct QueuedMessage {
pub id: String, pub id: String,
pub from: String, pub from: String,
@ -81,7 +50,6 @@ pub struct QueuedMessage {
pub queued_at: Option<DateTime<Utc>>, pub queued_at: Option<DateTime<Utc>>,
} }
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum DeliveryStatus { pub enum DeliveryStatus {
@ -94,7 +62,6 @@ pub enum DeliveryStatus {
Unknown, Unknown,
} }
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
struct QueueListResponse { struct QueueListResponse {
#[serde(default)] #[serde(default)]
@ -103,9 +70,6 @@ struct QueueListResponse {
items: Vec<QueuedMessage>, items: Vec<QueuedMessage>,
} }
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum PrincipalType { pub enum PrincipalType {
@ -119,10 +83,8 @@ pub enum PrincipalType {
Other, Other,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Principal { pub struct Principal {
pub id: Option<u64>, pub id: Option<u64>,
#[serde(rename = "type")] #[serde(rename = "type")]
@ -149,10 +111,8 @@ pub struct Principal {
pub disabled: bool, pub disabled: bool,
} }
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]
pub struct AccountUpdate { pub struct AccountUpdate {
pub action: String, pub action: String,
pub field: String, pub field: String,
@ -161,7 +121,6 @@ pub struct AccountUpdate {
} }
impl AccountUpdate { impl AccountUpdate {
pub fn set(field: &str, value: impl Into<Value>) -> Self { pub fn set(field: &str, value: impl Into<Value>) -> Self {
Self { Self {
action: "set".to_string(), action: "set".to_string(),
@ -170,7 +129,6 @@ impl AccountUpdate {
} }
} }
pub fn add_item(field: &str, value: impl Into<Value>) -> Self { pub fn add_item(field: &str, value: impl Into<Value>) -> Self {
Self { Self {
action: "addItem".to_string(), action: "addItem".to_string(),
@ -179,7 +137,6 @@ impl AccountUpdate {
} }
} }
pub fn remove_item(field: &str, value: impl Into<Value>) -> Self { pub fn remove_item(field: &str, value: impl Into<Value>) -> Self {
Self { Self {
action: "removeItem".to_string(), action: "removeItem".to_string(),
@ -188,7 +145,6 @@ impl AccountUpdate {
} }
} }
pub fn clear(field: &str) -> Self { pub fn clear(field: &str) -> Self {
Self { Self {
action: "clear".to_string(), action: "clear".to_string(),
@ -198,12 +154,8 @@ impl AccountUpdate {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutoResponderConfig { pub struct AutoResponderConfig {
pub enabled: bool, pub enabled: bool,
pub subject: String, pub subject: String,
@ -235,7 +187,8 @@ impl Default for AutoResponderConfig {
Self { Self {
enabled: false, enabled: false,
subject: "Out of Office".to_string(), subject: "Out of Office".to_string(),
body_plain: "I am currently out of the office and will respond upon my return.".to_string(), body_plain: "I am currently out of the office and will respond upon my return."
.to_string(),
body_html: None, body_html: None,
start_date: None, start_date: None,
end_date: None, end_date: None,
@ -245,10 +198,8 @@ impl Default for AutoResponderConfig {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmailRule { pub struct EmailRule {
pub id: String, pub id: String,
pub name: String, pub name: String,
@ -270,10 +221,8 @@ fn default_stop_processing() -> bool {
true true
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuleCondition { pub struct RuleCondition {
pub field: String, pub field: String,
pub operator: String, pub operator: String,
@ -287,22 +236,16 @@ pub struct RuleCondition {
pub case_sensitive: bool, pub case_sensitive: bool,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuleAction { pub struct RuleAction {
pub action_type: String, pub action_type: String,
#[serde(default)] #[serde(default)]
pub value: String, pub value: String,
} }
#[derive(Debug, Clone, Serialize, Deserialize, Default)] #[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Metrics { pub struct Metrics {
#[serde(default)] #[serde(default)]
pub messages_received: u64, pub messages_received: u64,
@ -334,10 +277,8 @@ pub struct Metrics {
pub extra: std::collections::HashMap<String, Value>, pub extra: std::collections::HashMap<String, Value>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogEntry { pub struct LogEntry {
pub timestamp: DateTime<Utc>, pub timestamp: DateTime<Utc>,
pub level: String, pub level: String,
@ -351,7 +292,6 @@ pub struct LogEntry {
pub context: Option<Value>, pub context: Option<Value>,
} }
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
pub struct LogList { pub struct LogList {
#[serde(default)] #[serde(default)]
@ -360,10 +300,8 @@ pub struct LogList {
pub items: Vec<LogEntry>, pub items: Vec<LogEntry>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceEvent { pub struct TraceEvent {
pub timestamp: DateTime<Utc>, pub timestamp: DateTime<Utc>,
pub event_type: String, pub event_type: String,
@ -390,7 +328,6 @@ pub struct TraceEvent {
pub duration_ms: Option<u64>, pub duration_ms: Option<u64>,
} }
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
pub struct TraceList { pub struct TraceList {
#[serde(default)] #[serde(default)]
@ -399,12 +336,8 @@ pub struct TraceList {
pub items: Vec<TraceEvent>, pub items: Vec<TraceEvent>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Report { pub struct Report {
pub id: String, pub id: String,
pub report_type: String, pub report_type: String,
@ -423,7 +356,6 @@ pub struct Report {
pub data: Value, pub data: Value,
} }
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
pub struct ReportList { pub struct ReportList {
#[serde(default)] #[serde(default)]
@ -432,12 +364,8 @@ pub struct ReportList {
pub items: Vec<Report>, pub items: Vec<Report>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpamClassifyRequest { pub struct SpamClassifyRequest {
pub from: String, pub from: String,
pub to: Vec<String>, pub to: Vec<String>,
@ -455,10 +383,8 @@ pub struct SpamClassifyRequest {
pub body: Option<String>, pub body: Option<String>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpamClassifyResult { pub struct SpamClassifyResult {
pub score: f64, pub score: f64,
pub classification: String, pub classification: String,
@ -470,10 +396,8 @@ pub struct SpamClassifyResult {
pub action: Option<String>, pub action: Option<String>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpamTest { pub struct SpamTest {
pub name: String, pub name: String,
pub score: f64, pub score: f64,
@ -482,9 +406,6 @@ pub struct SpamTest {
pub description: Option<String>, pub description: Option<String>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
#[serde(untagged)] #[serde(untagged)]
enum ApiResponse<T> { enum ApiResponse<T> {
@ -493,9 +414,6 @@ enum ApiResponse<T> {
Error { error: String }, Error { error: String },
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct StalwartClient { pub struct StalwartClient {
base_url: String, base_url: String,
@ -504,18 +422,6 @@ pub struct StalwartClient {
} }
impl StalwartClient { impl StalwartClient {
pub fn new(base_url: &str, token: &str) -> Self { pub fn new(base_url: &str, token: &str) -> Self {
let http_client = Client::builder() let http_client = Client::builder()
.timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS)) .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
@ -529,7 +435,6 @@ impl StalwartClient {
} }
} }
pub fn with_timeout(base_url: &str, token: &str, timeout_secs: u64) -> Self { pub fn with_timeout(base_url: &str, token: &str, timeout_secs: u64) -> Self {
let http_client = Client::builder() let http_client = Client::builder()
.timeout(Duration::from_secs(timeout_secs)) .timeout(Duration::from_secs(timeout_secs))
@ -543,7 +448,6 @@ impl StalwartClient {
} }
} }
async fn request<T: DeserializeOwned>( async fn request<T: DeserializeOwned>(
&self, &self,
method: Method, method: Method,
@ -563,20 +467,24 @@ impl StalwartClient {
req = req.header("Content-Type", "application/json").json(b); req = req.header("Content-Type", "application/json").json(b);
} }
let resp = req.send().await.context("Failed to send request to Stalwart")?; let resp = req
.send()
.await
.context("Failed to send request to Stalwart")?;
let status = resp.status(); let status = resp.status();
if !status.is_success() { if !status.is_success() {
let error_text = resp.text().await.unwrap_or_else(|_| "Unknown error".to_string()); let error_text = resp
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
error!("Stalwart API error: {} - {}", status, error_text); error!("Stalwart API error: {} - {}", status, error_text);
return Err(anyhow!("Stalwart API error ({}): {}", status, error_text)); return Err(anyhow!("Stalwart API error ({}): {}", status, error_text));
} }
let text = resp.text().await.context("Failed to read response body")?; let text = resp.text().await.context("Failed to read response body")?;
if text.is_empty() || text == "null" { if text.is_empty() || text == "null" {
return serde_json::from_str("null") return serde_json::from_str("null")
.or_else(|_| serde_json::from_str("{}")) .or_else(|_| serde_json::from_str("{}"))
.or_else(|_| serde_json::from_str("true")) .or_else(|_| serde_json::from_str("true"))
@ -586,8 +494,13 @@ impl StalwartClient {
serde_json::from_str(&text).context("Failed to parse Stalwart API response") serde_json::from_str(&text).context("Failed to parse Stalwart API response")
} }
async fn request_raw(
async fn request_raw(&self, method: Method, path: &str, body: &str, content_type: &str) -> Result<()> { &self,
method: Method,
path: &str,
body: &str,
content_type: &str,
) -> Result<()> {
let url = format!("{}{}", self.base_url, path); let url = format!("{}{}", self.base_url, path);
debug!("Stalwart API raw request: {} {}", method, url); debug!("Stalwart API raw request: {} {}", method, url);
@ -603,18 +516,16 @@ impl StalwartClient {
let status = resp.status(); let status = resp.status();
if !status.is_success() { if !status.is_success() {
let error_text = resp.text().await.unwrap_or_else(|_| "Unknown error".to_string()); let error_text = resp
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow!("Stalwart API error ({}): {}", status, error_text)); return Err(anyhow!("Stalwart API error ({}): {}", status, error_text));
} }
Ok(()) Ok(())
} }
pub async fn get_queue_status(&self) -> Result<QueueStatus> { pub async fn get_queue_status(&self) -> Result<QueueStatus> {
let status: bool = self let status: bool = self
.request(Method::GET, "/api/queue/status", None) .request(Method::GET, "/api/queue/status", None)
@ -636,7 +547,6 @@ impl StalwartClient {
}) })
} }
pub async fn get_queued_message(&self, message_id: &str) -> Result<QueuedMessage> { pub async fn get_queued_message(&self, message_id: &str) -> Result<QueuedMessage> {
self.request( self.request(
Method::GET, Method::GET,
@ -646,7 +556,6 @@ impl StalwartClient {
.await .await
} }
pub async fn list_queued_messages( pub async fn list_queued_messages(
&self, &self,
limit: u32, limit: u32,
@ -660,7 +569,6 @@ impl StalwartClient {
self.request(Method::GET, &path, None).await self.request(Method::GET, &path, None).await
} }
pub async fn retry_delivery(&self, message_id: &str) -> Result<bool> { pub async fn retry_delivery(&self, message_id: &str) -> Result<bool> {
self.request( self.request(
Method::PATCH, Method::PATCH,
@ -670,7 +578,6 @@ impl StalwartClient {
.await .await
} }
pub async fn cancel_delivery(&self, message_id: &str) -> Result<bool> { pub async fn cancel_delivery(&self, message_id: &str) -> Result<bool> {
self.request( self.request(
Method::DELETE, Method::DELETE,
@ -680,17 +587,16 @@ impl StalwartClient {
.await .await
} }
pub async fn stop_queue(&self) -> Result<bool> { pub async fn stop_queue(&self) -> Result<bool> {
self.request(Method::PATCH, "/api/queue/status/stop", None).await self.request(Method::PATCH, "/api/queue/status/stop", None)
.await
} }
pub async fn start_queue(&self) -> Result<bool> { pub async fn start_queue(&self) -> Result<bool> {
self.request(Method::PATCH, "/api/queue/status/start", None).await self.request(Method::PATCH, "/api/queue/status/start", None)
.await
} }
pub async fn get_failed_delivery_count(&self) -> Result<u64> { pub async fn get_failed_delivery_count(&self) -> Result<u64> {
let resp: QueueListResponse = self let resp: QueueListResponse = self
.request( .request(
@ -702,11 +608,6 @@ impl StalwartClient {
Ok(resp.total) Ok(resp.total)
} }
pub async fn create_account( pub async fn create_account(
&self, &self,
email: &str, email: &str,
@ -725,19 +626,19 @@ impl StalwartClient {
"roles": ["user"] "roles": ["user"]
}); });
self.request(Method::POST, "/api/principal", Some(body)).await self.request(Method::POST, "/api/principal", Some(body))
.await
} }
pub async fn create_account_full(&self, principal: &Principal, password: &str) -> Result<u64> { pub async fn create_account_full(&self, principal: &Principal, password: &str) -> Result<u64> {
let mut body = serde_json::to_value(principal)?; let mut body = serde_json::to_value(principal)?;
if let Some(obj) = body.as_object_mut() { if let Some(obj) = body.as_object_mut() {
obj.insert("secrets".to_string(), json!([password])); obj.insert("secrets".to_string(), json!([password]));
} }
self.request(Method::POST, "/api/principal", Some(body)).await self.request(Method::POST, "/api/principal", Some(body))
.await
} }
pub async fn create_distribution_list( pub async fn create_distribution_list(
&self, &self,
name: &str, name: &str,
@ -752,10 +653,10 @@ impl StalwartClient {
"description": format!("Distribution list: {}", name) "description": format!("Distribution list: {}", name)
}); });
self.request(Method::POST, "/api/principal", Some(body)).await self.request(Method::POST, "/api/principal", Some(body))
.await
} }
pub async fn create_shared_mailbox( pub async fn create_shared_mailbox(
&self, &self,
name: &str, name: &str,
@ -770,20 +671,15 @@ impl StalwartClient {
"description": format!("Shared mailbox: {}", name) "description": format!("Shared mailbox: {}", name)
}); });
self.request(Method::POST, "/api/principal", Some(body)).await self.request(Method::POST, "/api/principal", Some(body))
.await
} }
pub async fn get_account(&self, account_id: &str) -> Result<Principal> { pub async fn get_account(&self, account_id: &str) -> Result<Principal> {
self.request( self.request(Method::GET, &format!("/api/principal/{}", account_id), None)
Method::GET, .await
&format!("/api/principal/{}", account_id),
None,
)
.await
} }
pub async fn get_account_by_email(&self, email: &str) -> Result<Principal> { pub async fn get_account_by_email(&self, email: &str) -> Result<Principal> {
self.request( self.request(
Method::GET, Method::GET,
@ -793,8 +689,11 @@ impl StalwartClient {
.await .await
} }
pub async fn update_account(
pub async fn update_account(&self, account_id: &str, updates: Vec<AccountUpdate>) -> Result<()> { &self,
account_id: &str,
updates: Vec<AccountUpdate>,
) -> Result<()> {
let body: Vec<Value> = updates let body: Vec<Value> = updates
.iter() .iter()
.map(|u| { .map(|u| {
@ -815,7 +714,6 @@ impl StalwartClient {
Ok(()) Ok(())
} }
pub async fn delete_account(&self, account_id: &str) -> Result<()> { pub async fn delete_account(&self, account_id: &str) -> Result<()> {
self.request::<Value>( self.request::<Value>(
Method::DELETE, Method::DELETE,
@ -826,8 +724,10 @@ impl StalwartClient {
Ok(()) Ok(())
} }
pub async fn list_principals(
pub async fn list_principals(&self, principal_type: Option<PrincipalType>) -> Result<Vec<Principal>> { &self,
principal_type: Option<PrincipalType>,
) -> Result<Vec<Principal>> {
let path = match principal_type { let path = match principal_type {
Some(t) => format!("/api/principal?type={:?}", t).to_lowercase(), Some(t) => format!("/api/principal?type={:?}", t).to_lowercase(),
None => "/api/principal".to_string(), None => "/api/principal".to_string(),
@ -835,7 +735,6 @@ impl StalwartClient {
self.request(Method::GET, &path, None).await self.request(Method::GET, &path, None).await
} }
pub async fn add_members(&self, account_id: &str, members: Vec<String>) -> Result<()> { pub async fn add_members(&self, account_id: &str, members: Vec<String>) -> Result<()> {
let updates: Vec<AccountUpdate> = members let updates: Vec<AccountUpdate> = members
.into_iter() .into_iter()
@ -844,7 +743,6 @@ impl StalwartClient {
self.update_account(account_id, updates).await self.update_account(account_id, updates).await
} }
pub async fn remove_members(&self, account_id: &str, members: Vec<String>) -> Result<()> { pub async fn remove_members(&self, account_id: &str, members: Vec<String>) -> Result<()> {
let updates: Vec<AccountUpdate> = members let updates: Vec<AccountUpdate> = members
.into_iter() .into_iter()
@ -853,11 +751,6 @@ impl StalwartClient {
self.update_account(account_id, updates).await self.update_account(account_id, updates).await
} }
pub async fn set_auto_responder( pub async fn set_auto_responder(
&self, &self,
account_id: &str, account_id: &str,
@ -879,7 +772,6 @@ impl StalwartClient {
Ok(script_id) Ok(script_id)
} }
pub async fn disable_auto_responder(&self, account_id: &str) -> Result<()> { pub async fn disable_auto_responder(&self, account_id: &str) -> Result<()> {
let script_id = format!("{}_vacation", account_id); let script_id = format!("{}_vacation", account_id);
@ -895,10 +787,9 @@ impl StalwartClient {
Ok(()) Ok(())
} }
pub fn generate_vacation_sieve(&self, config: &AutoResponderConfig) -> String { pub fn generate_vacation_sieve(&self, config: &AutoResponderConfig) -> String {
let mut script = String::from("require [\"vacation\", \"variables\", \"date\", \"relational\"];\n\n"); let mut script =
String::from("require [\"vacation\", \"variables\", \"date\", \"relational\"];\n\n");
if config.start_date.is_some() || config.end_date.is_some() { if config.start_date.is_some() || config.end_date.is_some() {
script.push_str("# Date-based activation\n"); script.push_str("# Date-based activation\n");
@ -920,7 +811,6 @@ impl StalwartClient {
script.push('\n'); script.push('\n');
} }
let subject = config.subject.replace('"', "\\\"").replace('\n', " "); let subject = config.subject.replace('"', "\\\"").replace('\n', " ");
let body = config.body_plain.replace('"', "\\\"").replace('\n', "\\n"); let body = config.body_plain.replace('"', "\\\"").replace('\n', "\\n");
@ -932,7 +822,6 @@ impl StalwartClient {
script script
} }
pub async fn set_filter_rule(&self, account_id: &str, rule: &EmailRule) -> Result<String> { pub async fn set_filter_rule(&self, account_id: &str, rule: &EmailRule) -> Result<String> {
let sieve_script = self.generate_filter_sieve(rule); let sieve_script = self.generate_filter_sieve(rule);
let script_id = format!("{}_filter_{}", account_id, rule.id); let script_id = format!("{}_filter_{}", account_id, rule.id);
@ -950,7 +839,6 @@ impl StalwartClient {
Ok(script_id) Ok(script_id)
} }
pub async fn delete_filter_rule(&self, account_id: &str, rule_id: &str) -> Result<()> { pub async fn delete_filter_rule(&self, account_id: &str, rule_id: &str) -> Result<()> {
let script_id = format!("{}_filter_{}", account_id, rule_id); let script_id = format!("{}_filter_{}", account_id, rule_id);
@ -966,10 +854,10 @@ impl StalwartClient {
Ok(()) Ok(())
} }
pub fn generate_filter_sieve(&self, rule: &EmailRule) -> String { pub fn generate_filter_sieve(&self, rule: &EmailRule) -> String {
let mut script = let mut script = String::from(
String::from("require [\"fileinto\", \"reject\", \"vacation\", \"imap4flags\", \"copy\"];\n\n"); "require [\"fileinto\", \"reject\", \"vacation\", \"imap4flags\", \"copy\"];\n\n",
);
script.push_str(&format!("# Rule: {}\n", rule.name)); script.push_str(&format!("# Rule: {}\n", rule.name));
@ -978,7 +866,6 @@ impl StalwartClient {
return script; return script;
} }
let mut conditions = Vec::new(); let mut conditions = Vec::new();
for condition in &rule.conditions { for condition in &rule.conditions {
let cond_str = self.generate_condition_sieve(condition); let cond_str = self.generate_condition_sieve(condition);
@ -988,14 +875,11 @@ impl StalwartClient {
} }
if conditions.is_empty() { if conditions.is_empty() {
script.push_str("# Always applies\n"); script.push_str("# Always applies\n");
} else { } else {
script.push_str(&format!("if allof ({}) {{\n", conditions.join(", "))); script.push_str(&format!("if allof ({}) {{\n", conditions.join(", ")));
} }
for action in &rule.actions { for action in &rule.actions {
let action_str = self.generate_action_sieve(action); let action_str = self.generate_action_sieve(action);
if !action_str.is_empty() { if !action_str.is_empty() {
@ -1007,7 +891,6 @@ impl StalwartClient {
} }
} }
if rule.stop_processing { if rule.stop_processing {
if conditions.is_empty() { if conditions.is_empty() {
script.push_str("stop;\n"); script.push_str("stop;\n");
@ -1016,7 +899,6 @@ impl StalwartClient {
} }
} }
if !conditions.is_empty() { if !conditions.is_empty() {
script.push_str("}\n"); script.push_str("}\n");
} }
@ -1024,7 +906,6 @@ impl StalwartClient {
script script
} }
pub fn generate_condition_sieve(&self, condition: &RuleCondition) -> String { pub fn generate_condition_sieve(&self, condition: &RuleCondition) -> String {
let field_header = match condition.field.as_str() { let field_header = match condition.field.as_str() {
"from" => "From", "from" => "From",
@ -1044,17 +925,34 @@ impl StalwartClient {
let value = condition.value.replace('"', "\\\""); let value = condition.value.replace('"', "\\\"");
match condition.operator.as_str() { match condition.operator.as_str() {
"contains" => format!("header :contains{} \"{}\" \"{}\"", comparator, field_header, value), "contains" => format!(
"equals" => format!("header :is{} \"{}\" \"{}\"", comparator, field_header, value), "header :contains{} \"{}\" \"{}\"",
"startsWith" => format!("header :matches{} \"{}\" \"{}*\"", comparator, field_header, value), comparator, field_header, value
"endsWith" => format!("header :matches{} \"{}\" \"*{}\"", comparator, field_header, value), ),
"regex" => format!("header :regex{} \"{}\" \"{}\"", comparator, field_header, value), "equals" => format!(
"notContains" => format!("not header :contains{} \"{}\" \"{}\"", comparator, field_header, value), "header :is{} \"{}\" \"{}\"",
comparator, field_header, value
),
"startsWith" => format!(
"header :matches{} \"{}\" \"{}*\"",
comparator, field_header, value
),
"endsWith" => format!(
"header :matches{} \"{}\" \"*{}\"",
comparator, field_header, value
),
"regex" => format!(
"header :regex{} \"{}\" \"{}\"",
comparator, field_header, value
),
"notContains" => format!(
"not header :contains{} \"{}\" \"{}\"",
comparator, field_header, value
),
_ => String::new(), _ => String::new(),
} }
} }
pub fn generate_action_sieve(&self, action: &RuleAction) -> String { pub fn generate_action_sieve(&self, action: &RuleAction) -> String {
match action.action_type.as_str() { match action.action_type.as_str() {
"move" => format!("fileinto \"{}\";", action.value.replace('"', "\\\"")), "move" => format!("fileinto \"{}\";", action.value.replace('"', "\\\"")),
@ -1068,16 +966,11 @@ impl StalwartClient {
} }
} }
pub async fn get_metrics(&self) -> Result<Metrics> { pub async fn get_metrics(&self) -> Result<Metrics> {
self.request(Method::GET, "/api/telemetry/metrics", None).await self.request(Method::GET, "/api/telemetry/metrics", None)
.await
} }
pub async fn get_logs(&self, page: u32, limit: u32) -> Result<LogList> { pub async fn get_logs(&self, page: u32, limit: u32) -> Result<LogList> {
self.request( self.request(
Method::GET, Method::GET,
@ -1087,7 +980,6 @@ impl StalwartClient {
.await .await
} }
pub async fn get_logs_by_level(&self, level: &str, page: u32, limit: u32) -> Result<LogList> { pub async fn get_logs_by_level(&self, level: &str, page: u32, limit: u32) -> Result<LogList> {
self.request( self.request(
Method::GET, Method::GET,
@ -1097,7 +989,6 @@ impl StalwartClient {
.await .await
} }
pub async fn get_traces(&self, trace_type: &str, page: u32) -> Result<TraceList> { pub async fn get_traces(&self, trace_type: &str, page: u32) -> Result<TraceList> {
self.request( self.request(
Method::GET, Method::GET,
@ -1110,7 +1001,6 @@ impl StalwartClient {
.await .await
} }
pub async fn get_recent_traces(&self, limit: u32) -> Result<TraceList> { pub async fn get_recent_traces(&self, limit: u32) -> Result<TraceList> {
self.request( self.request(
Method::GET, Method::GET,
@ -1120,7 +1010,6 @@ impl StalwartClient {
.await .await
} }
pub async fn get_trace(&self, trace_id: &str) -> Result<Vec<TraceEvent>> { pub async fn get_trace(&self, trace_id: &str) -> Result<Vec<TraceEvent>> {
self.request( self.request(
Method::GET, Method::GET,
@ -1130,7 +1019,6 @@ impl StalwartClient {
.await .await
} }
pub async fn get_dmarc_reports(&self, page: u32) -> Result<ReportList> { pub async fn get_dmarc_reports(&self, page: u32) -> Result<ReportList> {
self.request( self.request(
Method::GET, Method::GET,
@ -1140,7 +1028,6 @@ impl StalwartClient {
.await .await
} }
pub async fn get_tls_reports(&self, page: u32) -> Result<ReportList> { pub async fn get_tls_reports(&self, page: u32) -> Result<ReportList> {
self.request( self.request(
Method::GET, Method::GET,
@ -1150,7 +1037,6 @@ impl StalwartClient {
.await .await
} }
pub async fn get_arf_reports(&self, page: u32) -> Result<ReportList> { pub async fn get_arf_reports(&self, page: u32) -> Result<ReportList> {
self.request( self.request(
Method::GET, Method::GET,
@ -1160,23 +1046,16 @@ impl StalwartClient {
.await .await
} }
pub async fn get_live_metrics_token(&self) -> Result<String> { pub async fn get_live_metrics_token(&self) -> Result<String> {
self.request(Method::GET, "/api/telemetry/live/metrics-token", None) self.request(Method::GET, "/api/telemetry/live/metrics-token", None)
.await .await
} }
pub async fn get_live_tracing_token(&self) -> Result<String> { pub async fn get_live_tracing_token(&self) -> Result<String> {
self.request(Method::GET, "/api/telemetry/live/tracing-token", None) self.request(Method::GET, "/api/telemetry/live/tracing-token", None)
.await .await
} }
pub async fn train_spam(&self, raw_message: &str) -> Result<()> { pub async fn train_spam(&self, raw_message: &str) -> Result<()> {
self.request_raw( self.request_raw(
Method::POST, Method::POST,
@ -1189,7 +1068,6 @@ impl StalwartClient {
Ok(()) Ok(())
} }
pub async fn train_ham(&self, raw_message: &str) -> Result<()> { pub async fn train_ham(&self, raw_message: &str) -> Result<()> {
self.request_raw( self.request_raw(
Method::POST, Method::POST,
@ -1202,8 +1080,10 @@ impl StalwartClient {
Ok(()) Ok(())
} }
pub async fn classify_message(
pub async fn classify_message(&self, message: &SpamClassifyRequest) -> Result<SpamClassifyResult> { &self,
message: &SpamClassifyRequest,
) -> Result<SpamClassifyResult> {
self.request( self.request(
Method::POST, Method::POST,
"/api/spam-filter/classify", "/api/spam-filter/classify",
@ -1212,21 +1092,18 @@ impl StalwartClient {
.await .await
} }
pub async fn troubleshoot_delivery(&self, recipient: &str) -> Result<Value> { pub async fn troubleshoot_delivery(&self, recipient: &str) -> Result<Value> {
self.request( self.request(
Method::GET, Method::GET,
&format!("/api/troubleshoot/delivery/{}", urlencoding::encode(recipient)), &format!(
"/api/troubleshoot/delivery/{}",
urlencoding::encode(recipient)
),
None, None,
) )
.await .await
} }
pub async fn check_dmarc(&self, domain: &str, from_email: &str) -> Result<Value> { pub async fn check_dmarc(&self, domain: &str, from_email: &str) -> Result<Value> {
let body = json!({ let body = json!({
"domain": domain, "domain": domain,
@ -1236,17 +1113,11 @@ impl StalwartClient {
.await .await
} }
pub async fn get_dns_records(&self, domain: &str) -> Result<Value> { pub async fn get_dns_records(&self, domain: &str) -> Result<Value> {
self.request( self.request(Method::GET, &format!("/api/dns/records/{}", domain), None)
Method::GET, .await
&format!("/api/dns/records/{}", domain),
None,
)
.await
} }
pub async fn undelete_messages(&self, account_id: &str) -> Result<Value> { pub async fn undelete_messages(&self, account_id: &str) -> Result<Value> {
self.request( self.request(
Method::POST, Method::POST,
@ -1256,7 +1127,6 @@ impl StalwartClient {
.await .await
} }
pub async fn purge_account(&self, account_id: &str) -> Result<()> { pub async fn purge_account(&self, account_id: &str) -> Result<()> {
self.request::<Value>( self.request::<Value>(
Method::GET, Method::GET,
@ -1268,9 +1138,11 @@ impl StalwartClient {
Ok(()) Ok(())
} }
pub async fn health_check(&self) -> Result<bool> { pub async fn health_check(&self) -> Result<bool> {
match self.request::<Value>(Method::GET, "/api/queue/status", None).await { match self
.request::<Value>(Method::GET, "/api/queue/status", None)
.await
{
Ok(_) => Ok(true), Ok(_) => Ok(true),
Err(e) => { Err(e) => {
warn!("Stalwart health check failed: {}", e); warn!("Stalwart health check failed: {}", e);
@ -1279,5 +1151,3 @@ impl StalwartClient {
} }
} }
} }

View file

@ -4,7 +4,6 @@ use serde::{Deserialize, Serialize};
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
use tokio::fs; use tokio::fs;
use uuid::Uuid; use uuid::Uuid;
#[cfg(feature = "vectordb")] #[cfg(feature = "vectordb")]
@ -13,7 +12,6 @@ use qdrant_client::{
Qdrant, Qdrant,
}; };
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmailDocument { pub struct EmailDocument {
pub id: String, pub id: String,
@ -29,7 +27,6 @@ pub struct EmailDocument {
pub thread_id: Option<String>, pub thread_id: Option<String>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmailSearchQuery { pub struct EmailSearchQuery {
pub query_text: String, pub query_text: String,
@ -40,7 +37,6 @@ pub struct EmailSearchQuery {
pub limit: usize, pub limit: usize,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmailSearchResult { pub struct EmailSearchResult {
pub email: EmailDocument, pub email: EmailDocument,
@ -48,7 +44,6 @@ pub struct EmailSearchResult {
pub snippet: String, pub snippet: String,
} }
pub struct UserEmailVectorDB { pub struct UserEmailVectorDB {
user_id: Uuid, user_id: Uuid,
bot_id: Uuid, bot_id: Uuid,
@ -59,7 +54,6 @@ pub struct UserEmailVectorDB {
} }
impl UserEmailVectorDB { impl UserEmailVectorDB {
pub fn new(user_id: Uuid, bot_id: Uuid, db_path: PathBuf) -> Self { pub fn new(user_id: Uuid, bot_id: Uuid, db_path: PathBuf) -> Self {
let collection_name = format!("emails_{}_{}", bot_id, user_id); let collection_name = format!("emails_{}_{}", bot_id, user_id);
@ -73,12 +67,10 @@ impl UserEmailVectorDB {
} }
} }
#[cfg(feature = "vectordb")] #[cfg(feature = "vectordb")]
pub async fn initialize(&mut self, qdrant_url: &str) -> Result<()> { pub async fn initialize(&mut self, qdrant_url: &str) -> Result<()> {
let client = Qdrant::from_url(qdrant_url).build()?; let client = Qdrant::from_url(qdrant_url).build()?;
let collections = client.list_collections().await?; let collections = client.list_collections().await?;
let exists = collections let exists = collections
.collections .collections
@ -86,7 +78,6 @@ impl UserEmailVectorDB {
.any(|c| c.name == self.collection_name); .any(|c| c.name == self.collection_name);
if !exists { if !exists {
client client
.create_collection( .create_collection(
qdrant_client::qdrant::CreateCollectionBuilder::new(&self.collection_name) qdrant_client::qdrant::CreateCollectionBuilder::new(&self.collection_name)
@ -111,7 +102,6 @@ impl UserEmailVectorDB {
Ok(()) Ok(())
} }
#[cfg(feature = "vectordb")] #[cfg(feature = "vectordb")]
pub async fn index_email(&self, email: &EmailDocument, embedding: Vec<f32>) -> Result<()> { pub async fn index_email(&self, email: &EmailDocument, embedding: Vec<f32>) -> Result<()> {
let client = self let client = self
@ -131,9 +121,10 @@ impl UserEmailVectorDB {
let point = PointStruct::new(email.id.clone(), embedding, payload); let point = PointStruct::new(email.id.clone(), embedding, payload);
client client
.upsert_points( .upsert_points(qdrant_client::qdrant::UpsertPointsBuilder::new(
qdrant_client::qdrant::UpsertPointsBuilder::new(&self.collection_name, vec![point]), &self.collection_name,
) vec![point],
))
.await?; .await?;
log::debug!("Indexed email: {} - {}", email.id, email.subject); log::debug!("Indexed email: {} - {}", email.id, email.subject);
@ -142,14 +133,12 @@ impl UserEmailVectorDB {
#[cfg(not(feature = "vectordb"))] #[cfg(not(feature = "vectordb"))]
pub async fn index_email(&self, email: &EmailDocument, _embedding: Vec<f32>) -> Result<()> { pub async fn index_email(&self, email: &EmailDocument, _embedding: Vec<f32>) -> Result<()> {
let file_path = self.db_path.join(format!("{}.json", email.id)); let file_path = self.db_path.join(format!("{}.json", email.id));
let json = serde_json::to_string_pretty(email)?; let json = serde_json::to_string_pretty(email)?;
fs::write(file_path, json).await?; fs::write(file_path, json).await?;
Ok(()) Ok(())
} }
pub async fn index_emails_batch(&self, emails: &[(EmailDocument, Vec<f32>)]) -> Result<()> { pub async fn index_emails_batch(&self, emails: &[(EmailDocument, Vec<f32>)]) -> Result<()> {
for (email, embedding) in emails { for (email, embedding) in emails {
self.index_email(email, embedding.clone()).await?; self.index_email(email, embedding.clone()).await?;
@ -157,7 +146,6 @@ impl UserEmailVectorDB {
Ok(()) Ok(())
} }
#[cfg(feature = "vectordb")] #[cfg(feature = "vectordb")]
pub async fn search( pub async fn search(
&self, &self,
@ -169,7 +157,6 @@ impl UserEmailVectorDB {
.as_ref() .as_ref()
.ok_or_else(|| anyhow::anyhow!("Vector DB not initialized"))?; .ok_or_else(|| anyhow::anyhow!("Vector DB not initialized"))?;
let filter = if query.account_id.is_some() || query.folder.is_some() { let filter = if query.account_id.is_some() || query.folder.is_some() {
let mut conditions = vec![]; let mut conditions = vec![];
@ -210,7 +197,8 @@ impl UserEmailVectorDB {
let payload = &point.payload; let payload = &point.payload;
if !payload.is_empty() { if !payload.is_empty() {
let get_str = |key: &str| -> String { let get_str = |key: &str| -> String {
payload.get(key) payload
.get(key)
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.map(|s| s.to_string()) .map(|s| s.to_string())
.unwrap_or_default() .unwrap_or_default()
@ -227,10 +215,12 @@ impl UserEmailVectorDB {
date: chrono::Utc::now(), date: chrono::Utc::now(),
folder: get_str("folder"), folder: get_str("folder"),
has_attachments: false, has_attachments: false,
thread_id: payload.get("thread_id").and_then(|v| v.as_str()).map(|s| s.to_string()), thread_id: payload
.get("thread_id")
.and_then(|v| v.as_str())
.map(|s| s.to_string()),
}; };
let snippet = if email.body_text.len() > 200 { let snippet = if email.body_text.len() > 200 {
format!("{}...", &email.body_text[..200]) format!("{}...", &email.body_text[..200])
} else { } else {
@ -254,7 +244,6 @@ impl UserEmailVectorDB {
query: &EmailSearchQuery, query: &EmailSearchQuery,
_query_embedding: Vec<f32>, _query_embedding: Vec<f32>,
) -> Result<Vec<EmailSearchResult>> { ) -> Result<Vec<EmailSearchResult>> {
let mut results = Vec::new(); let mut results = Vec::new();
let mut entries = fs::read_dir(&self.db_path).await?; let mut entries = fs::read_dir(&self.db_path).await?;
@ -262,7 +251,6 @@ impl UserEmailVectorDB {
if entry.path().extension().and_then(|s| s.to_str()) == Some("json") { if entry.path().extension().and_then(|s| s.to_str()) == Some("json") {
let content = fs::read_to_string(entry.path()).await?; let content = fs::read_to_string(entry.path()).await?;
if let Ok(email) = serde_json::from_str::<EmailDocument>(&content) { if let Ok(email) = serde_json::from_str::<EmailDocument>(&content) {
let query_lower = query.query_text.to_lowercase(); let query_lower = query.query_text.to_lowercase();
if email.subject.to_lowercase().contains(&query_lower) if email.subject.to_lowercase().contains(&query_lower)
|| email.body_text.to_lowercase().contains(&query_lower) || email.body_text.to_lowercase().contains(&query_lower)
@ -291,7 +279,6 @@ impl UserEmailVectorDB {
Ok(results) Ok(results)
} }
#[cfg(feature = "vectordb")] #[cfg(feature = "vectordb")]
pub async fn delete_email(&self, email_id: &str) -> Result<()> { pub async fn delete_email(&self, email_id: &str) -> Result<()> {
let client = self let client = self
@ -301,8 +288,9 @@ impl UserEmailVectorDB {
client client
.delete_points( .delete_points(
qdrant_client::qdrant::DeletePointsBuilder::new(&self.collection_name) qdrant_client::qdrant::DeletePointsBuilder::new(&self.collection_name).points(
.points(vec![qdrant_client::qdrant::PointId::from(email_id.to_string())]), vec![qdrant_client::qdrant::PointId::from(email_id.to_string())],
),
) )
.await?; .await?;
@ -319,7 +307,6 @@ impl UserEmailVectorDB {
Ok(()) Ok(())
} }
#[cfg(feature = "vectordb")] #[cfg(feature = "vectordb")]
pub async fn get_count(&self) -> Result<u64> { pub async fn get_count(&self) -> Result<u64> {
let client = self let client = self
@ -346,7 +333,6 @@ impl UserEmailVectorDB {
Ok(count) Ok(count)
} }
#[cfg(feature = "vectordb")] #[cfg(feature = "vectordb")]
pub async fn clear(&self) -> Result<()> { pub async fn clear(&self) -> Result<()> {
let client = self let client = self
@ -354,10 +340,7 @@ impl UserEmailVectorDB {
.as_ref() .as_ref()
.ok_or_else(|| anyhow::anyhow!("Vector DB not initialized"))?; .ok_or_else(|| anyhow::anyhow!("Vector DB not initialized"))?;
client client.delete_collection(&self.collection_name).await?;
.delete_collection(&self.collection_name)
.await?;
client client
.create_collection( .create_collection(
@ -384,7 +367,6 @@ impl UserEmailVectorDB {
} }
} }
pub struct EmailEmbeddingGenerator { pub struct EmailEmbeddingGenerator {
pub llm_endpoint: String, pub llm_endpoint: String,
} }
@ -394,15 +376,12 @@ impl EmailEmbeddingGenerator {
Self { llm_endpoint } Self { llm_endpoint }
} }
pub async fn generate_embedding(&self, email: &EmailDocument) -> Result<Vec<f32>> { pub async fn generate_embedding(&self, email: &EmailDocument) -> Result<Vec<f32>> {
let text = format!( let text = format!(
"From: {} <{}>\nSubject: {}\n\n{}", "From: {} <{}>\nSubject: {}\n\n{}",
email.from_name, email.from_email, email.subject, email.body_text email.from_name, email.from_email, email.subject, email.body_text
); );
let text = if text.len() > 8000 { let text = if text.len() > 8000 {
&text[..8000] &text[..8000]
} else { } else {
@ -412,17 +391,11 @@ impl EmailEmbeddingGenerator {
self.generate_text_embedding(text).await self.generate_text_embedding(text).await
} }
pub async fn generate_text_embedding(&self, text: &str) -> Result<Vec<f32>> { pub async fn generate_text_embedding(&self, text: &str) -> Result<Vec<f32>> {
let embedding_url = "http://localhost:8082".to_string(); let embedding_url = "http://localhost:8082".to_string();
return self.generate_local_embedding(text, &embedding_url).await; self.generate_local_embedding(text, &embedding_url).await
self.generate_hash_embedding(text)
} }
async fn generate_openai_embedding(&self, text: &str, api_key: &str) -> Result<Vec<f32>> { async fn generate_openai_embedding(&self, text: &str, api_key: &str) -> Result<Vec<f32>> {
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE}; use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
use serde_json::json; use serde_json::json;
@ -462,7 +435,6 @@ impl EmailEmbeddingGenerator {
Ok(embedding) Ok(embedding)
} }
async fn generate_local_embedding(&self, text: &str, embedding_url: &str) -> Result<Vec<f32>> { async fn generate_local_embedding(&self, text: &str, embedding_url: &str) -> Result<Vec<f32>> {
use serde_json::json; use serde_json::json;
@ -492,7 +464,6 @@ impl EmailEmbeddingGenerator {
Ok(embedding) Ok(embedding)
} }
fn generate_hash_embedding(&self, text: &str) -> Result<Vec<f32>> { fn generate_hash_embedding(&self, text: &str) -> Result<Vec<f32>> {
use std::collections::hash_map::DefaultHasher; use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
@ -500,7 +471,6 @@ impl EmailEmbeddingGenerator {
const EMBEDDING_DIM: usize = 1536; const EMBEDDING_DIM: usize = 1536;
let mut embedding = vec![0.0f32; EMBEDDING_DIM]; let mut embedding = vec![0.0f32; EMBEDDING_DIM];
let words: Vec<&str> = text.split_whitespace().collect(); let words: Vec<&str> = text.split_whitespace().collect();
for (i, chunk) in words.chunks(10).enumerate() { for (i, chunk) in words.chunks(10).enumerate() {
@ -508,7 +478,6 @@ impl EmailEmbeddingGenerator {
chunk.join(" ").hash(&mut hasher); chunk.join(" ").hash(&mut hasher);
let hash = hasher.finish(); let hash = hasher.finish();
for j in 0..64 { for j in 0..64 {
let idx = (i * 64 + j) % EMBEDDING_DIM; let idx = (i * 64 + j) % EMBEDDING_DIM;
let value = ((hash >> j) & 1) as f32; let value = ((hash >> j) & 1) as f32;
@ -516,7 +485,6 @@ impl EmailEmbeddingGenerator {
} }
} }
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt(); let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 { if norm > 0.0 {
for val in &mut embedding { for val in &mut embedding {

View file

@ -1,36 +1,3 @@
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use rhai::{Dynamic, Engine, Map}; use rhai::{Dynamic, Engine, Map};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -41,10 +8,8 @@ use tokio::sync::RwLock;
use tracing::info; use tracing::info;
use uuid::Uuid; use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequestMetrics { pub struct LLMRequestMetrics {
pub request_id: Uuid, pub request_id: Uuid,
pub session_id: Uuid, pub session_id: Uuid,
@ -78,7 +43,6 @@ pub struct LLMRequestMetrics {
pub metadata: HashMap<String, String>, pub metadata: HashMap<String, String>,
} }
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum RequestType { pub enum RequestType {
@ -113,10 +77,8 @@ impl std::fmt::Display for RequestType {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize, Default)] #[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AggregatedMetrics { pub struct AggregatedMetrics {
pub period_start: DateTime<Utc>, pub period_start: DateTime<Utc>,
pub period_end: DateTime<Utc>, pub period_end: DateTime<Utc>,
@ -162,10 +124,8 @@ pub struct AggregatedMetrics {
pub errors_by_type: HashMap<String, u64>, pub errors_by_type: HashMap<String, u64>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPricing { pub struct ModelPricing {
pub model: String, pub model: String,
pub input_cost_per_1k: f64, pub input_cost_per_1k: f64,
@ -189,10 +149,8 @@ impl Default for ModelPricing {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Budget { pub struct Budget {
pub daily_limit: f64, pub daily_limit: f64,
pub monthly_limit: f64, pub monthly_limit: f64,
@ -229,10 +187,8 @@ impl Default for Budget {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceEvent { pub struct TraceEvent {
pub id: Uuid, pub id: Uuid,
pub parent_id: Option<Uuid>, pub parent_id: Option<Uuid>,
@ -258,7 +214,6 @@ pub struct TraceEvent {
pub error: Option<String>, pub error: Option<String>,
} }
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum TraceEventType { pub enum TraceEventType {
@ -268,7 +223,6 @@ pub enum TraceEventType {
Metric, Metric,
} }
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum TraceStatus { pub enum TraceStatus {
@ -283,10 +237,8 @@ impl Default for TraceStatus {
} }
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ObservabilityConfig { pub struct ObservabilityConfig {
pub enabled: bool, pub enabled: bool,
pub metrics_interval: u64, pub metrics_interval: u64,
@ -322,11 +274,9 @@ impl Default for ObservabilityConfig {
} }
} }
fn default_model_pricing() -> HashMap<String, ModelPricing> { fn default_model_pricing() -> HashMap<String, ModelPricing> {
let mut pricing = HashMap::new(); let mut pricing = HashMap::new();
pricing.insert( pricing.insert(
"gpt-4".to_string(), "gpt-4".to_string(),
ModelPricing { ModelPricing {
@ -360,7 +310,6 @@ fn default_model_pricing() -> HashMap<String, ModelPricing> {
}, },
); );
pricing.insert( pricing.insert(
"claude-3-opus".to_string(), "claude-3-opus".to_string(),
ModelPricing { ModelPricing {
@ -383,7 +332,6 @@ fn default_model_pricing() -> HashMap<String, ModelPricing> {
}, },
); );
pricing.insert( pricing.insert(
"mixtral-8x7b-32768".to_string(), "mixtral-8x7b-32768".to_string(),
ModelPricing { ModelPricing {
@ -395,7 +343,6 @@ fn default_model_pricing() -> HashMap<String, ModelPricing> {
}, },
); );
pricing.insert( pricing.insert(
"local".to_string(), "local".to_string(),
ModelPricing { ModelPricing {
@ -410,7 +357,6 @@ fn default_model_pricing() -> HashMap<String, ModelPricing> {
pricing pricing
} }
#[derive(Debug)] #[derive(Debug)]
pub struct ObservabilityManager { pub struct ObservabilityManager {
config: ObservabilityConfig, config: ObservabilityConfig,
@ -431,7 +377,6 @@ pub struct ObservabilityManager {
} }
impl ObservabilityManager { impl ObservabilityManager {
pub fn new(config: ObservabilityConfig) -> Self { pub fn new(config: ObservabilityConfig) -> Self {
let budget = Budget { let budget = Budget {
daily_limit: config.budget_daily, daily_limit: config.budget_daily,
@ -454,7 +399,6 @@ impl ObservabilityManager {
} }
} }
pub fn from_config(config_map: &HashMap<String, String>) -> Self { pub fn from_config(config_map: &HashMap<String, String>) -> Self {
let config = ObservabilityConfig { let config = ObservabilityConfig {
enabled: config_map enabled: config_map
@ -494,13 +438,11 @@ impl ObservabilityManager {
ObservabilityManager::new(config) ObservabilityManager::new(config)
} }
pub async fn record_request(&self, metrics: LLMRequestMetrics) { pub async fn record_request(&self, metrics: LLMRequestMetrics) {
if !self.config.enabled { if !self.config.enabled {
return; return;
} }
self.request_count.fetch_add(1, Ordering::Relaxed); self.request_count.fetch_add(1, Ordering::Relaxed);
self.token_count self.token_count
.fetch_add(metrics.total_tokens, Ordering::Relaxed); .fetch_add(metrics.total_tokens, Ordering::Relaxed);
@ -515,24 +457,20 @@ impl ObservabilityManager {
self.error_count.fetch_add(1, Ordering::Relaxed); self.error_count.fetch_add(1, Ordering::Relaxed);
} }
if self.config.cost_tracking && metrics.estimated_cost > 0.0 { if self.config.cost_tracking && metrics.estimated_cost > 0.0 {
let mut budget = self.budget.write().await; let mut budget = self.budget.write().await;
budget.daily_spend += metrics.estimated_cost; budget.daily_spend += metrics.estimated_cost;
budget.monthly_spend += metrics.estimated_cost; budget.monthly_spend += metrics.estimated_cost;
} }
let mut buffer = self.metrics_buffer.write().await; let mut buffer = self.metrics_buffer.write().await;
buffer.push(metrics); buffer.push(metrics);
if buffer.len() > 10000 { if buffer.len() > 10000 {
buffer.drain(0..1000); buffer.drain(0..1000);
} }
} }
pub fn calculate_cost(&self, model: &str, input_tokens: u64, output_tokens: u64) -> f64 { pub fn calculate_cost(&self, model: &str, input_tokens: u64, output_tokens: u64) -> f64 {
if let Some(pricing) = self.config.model_pricing.get(model) { if let Some(pricing) = self.config.model_pricing.get(model) {
if pricing.is_local { if pricing.is_local {
@ -542,12 +480,10 @@ impl ObservabilityManager {
let output_cost = (output_tokens as f64 / 1000.0) * pricing.output_cost_per_1k; let output_cost = (output_tokens as f64 / 1000.0) * pricing.output_cost_per_1k;
input_cost + output_cost + pricing.cost_per_request input_cost + output_cost + pricing.cost_per_request
} else { } else {
0.0 0.0
} }
} }
pub async fn get_budget_status(&self) -> BudgetStatus { pub async fn get_budget_status(&self) -> BudgetStatus {
let budget = self.budget.read().await; let budget = self.budget.read().await;
BudgetStatus { BudgetStatus {
@ -567,7 +503,6 @@ impl ObservabilityManager {
} }
} }
pub async fn check_budget(&self, estimated_cost: f64) -> BudgetCheckResult { pub async fn check_budget(&self, estimated_cost: f64) -> BudgetCheckResult {
let budget = self.budget.read().await; let budget = self.budget.read().await;
@ -593,7 +528,6 @@ impl ObservabilityManager {
BudgetCheckResult::Ok BudgetCheckResult::Ok
} }
pub async fn reset_daily_budget(&self) { pub async fn reset_daily_budget(&self) {
let mut budget = self.budget.write().await; let mut budget = self.budget.write().await;
budget.daily_spend = 0.0; budget.daily_spend = 0.0;
@ -601,7 +535,6 @@ impl ObservabilityManager {
budget.daily_alert_sent = false; budget.daily_alert_sent = false;
} }
pub async fn reset_monthly_budget(&self) { pub async fn reset_monthly_budget(&self) {
let mut budget = self.budget.write().await; let mut budget = self.budget.write().await;
budget.monthly_spend = 0.0; budget.monthly_spend = 0.0;
@ -609,13 +542,11 @@ impl ObservabilityManager {
budget.monthly_alert_sent = false; budget.monthly_alert_sent = false;
} }
pub async fn record_trace(&self, event: TraceEvent) { pub async fn record_trace(&self, event: TraceEvent) {
if !self.config.enabled || !self.config.trace_enabled { if !self.config.enabled || !self.config.trace_enabled {
return; return;
} }
if self.config.trace_sample_rate < 1.0 { if self.config.trace_sample_rate < 1.0 {
let sample: f64 = rand::random(); let sample: f64 = rand::random();
if sample > self.config.trace_sample_rate { if sample > self.config.trace_sample_rate {
@ -626,13 +557,11 @@ impl ObservabilityManager {
let mut buffer = self.trace_buffer.write().await; let mut buffer = self.trace_buffer.write().await;
buffer.push(event); buffer.push(event);
if buffer.len() > 5000 { if buffer.len() > 5000 {
buffer.drain(0..500); buffer.drain(0..500);
} }
} }
pub fn start_span( pub fn start_span(
&self, &self,
trace_id: Uuid, trace_id: Uuid,
@ -656,7 +585,6 @@ impl ObservabilityManager {
} }
} }
pub fn end_span(&self, span: &mut TraceEvent, status: TraceStatus, error: Option<String>) { pub fn end_span(&self, span: &mut TraceEvent, status: TraceStatus, error: Option<String>) {
let end_time = Utc::now(); let end_time = Utc::now();
span.end_time = Some(end_time); span.end_time = Some(end_time);
@ -665,7 +593,6 @@ impl ObservabilityManager {
span.error = error; span.error = error;
} }
pub async fn get_aggregated_metrics( pub async fn get_aggregated_metrics(
&self, &self,
start: DateTime<Utc>, start: DateTime<Utc>,
@ -731,7 +658,6 @@ impl ObservabilityManager {
.or_insert(0) += 1; .or_insert(0) += 1;
} }
if !latencies.is_empty() { if !latencies.is_empty() {
latencies.sort(); latencies.sort();
let len = latencies.len(); let len = latencies.len();
@ -748,7 +674,6 @@ impl ObservabilityManager {
metrics metrics
} }
pub async fn get_current_metrics(&self) -> AggregatedMetrics { pub async fn get_current_metrics(&self) -> AggregatedMetrics {
self.current_metrics.read().await.clone() self.current_metrics.read().await.clone()
} }
@ -789,13 +714,11 @@ impl ObservabilityManager {
} }
} }
pub async fn get_recent_traces(&self, limit: usize) -> Vec<TraceEvent> { pub async fn get_recent_traces(&self, limit: usize) -> Vec<TraceEvent> {
let buffer = self.trace_buffer.read().await; let buffer = self.trace_buffer.read().await;
buffer.iter().rev().take(limit).cloned().collect() buffer.iter().rev().take(limit).cloned().collect()
} }
pub async fn get_trace(&self, trace_id: Uuid) -> Vec<TraceEvent> { pub async fn get_trace(&self, trace_id: Uuid) -> Vec<TraceEvent> {
let buffer = self.trace_buffer.read().await; let buffer = self.trace_buffer.read().await;
buffer buffer
@ -806,7 +729,6 @@ impl ObservabilityManager {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BudgetStatus { pub struct BudgetStatus {
pub daily_limit: f64, pub daily_limit: f64,
@ -823,7 +745,6 @@ pub struct BudgetStatus {
pub near_monthly_limit: bool, pub near_monthly_limit: bool,
} }
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub enum BudgetCheckResult { pub enum BudgetCheckResult {
Ok, Ok,
@ -833,7 +754,6 @@ pub enum BudgetCheckResult {
MonthlyExceeded, MonthlyExceeded,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuickStats { pub struct QuickStats {
pub total_requests: u64, pub total_requests: u64,
@ -845,7 +765,6 @@ pub struct QuickStats {
pub error_rate: f64, pub error_rate: f64,
} }
impl LLMRequestMetrics { impl LLMRequestMetrics {
pub fn to_dynamic(&self) -> Dynamic { pub fn to_dynamic(&self) -> Dynamic {
let mut map = Map::new(); let mut map = Map::new();
@ -935,10 +854,7 @@ impl BudgetStatus {
} }
} }
pub fn register_observability_keywords(engine: &mut Engine) { pub fn register_observability_keywords(engine: &mut Engine) {
engine.register_fn("metrics_total_requests", |metrics: Map| -> i64 { engine.register_fn("metrics_total_requests", |metrics: Map| -> i64 {
metrics metrics
.get("total_requests") .get("total_requests")
@ -1006,8 +922,7 @@ pub fn register_observability_keywords(engine: &mut Engine) {
info!("Observability keywords registered"); info!("Observability keywords registered");
} }
pub const OBSERVABILITY_SCHEMA: &str = r"
pub const OBSERVABILITY_SCHEMA: &str = r#"
-- LLM request metrics -- LLM request metrics
CREATE TABLE IF NOT EXISTS llm_metrics ( CREATE TABLE IF NOT EXISTS llm_metrics (
id UUID PRIMARY KEY, id UUID PRIMARY KEY,
@ -1102,4 +1017,4 @@ CREATE INDEX IF NOT EXISTS idx_llm_metrics_hourly_hour ON llm_metrics_hourly(hou
CREATE INDEX IF NOT EXISTS idx_llm_traces_trace_id ON llm_traces(trace_id); CREATE INDEX IF NOT EXISTS idx_llm_traces_trace_id ON llm_traces(trace_id);
CREATE INDEX IF NOT EXISTS idx_llm_traces_start_time ON llm_traces(start_time DESC); CREATE INDEX IF NOT EXISTS idx_llm_traces_start_time ON llm_traces(start_time DESC);
CREATE INDEX IF NOT EXISTS idx_llm_traces_component ON llm_traces(component); CREATE INDEX IF NOT EXISTS idx_llm_traces_component ON llm_traces(component);
"#; ";

View file

@ -50,7 +50,7 @@ pub fn configure() -> Router<Arc<AppState>> {
} }
async fn handle_incoming( async fn handle_incoming(
State(state): State<Arc<AppState>>, State(_state): State<Arc<AppState>>,
Json(activity): Json<TeamsActivity>, Json(activity): Json<TeamsActivity>,
) -> impl IntoResponse { ) -> impl IntoResponse {
match activity.activity_type.as_str() { match activity.activity_type.as_str() {

View file

@ -1,30 +1,3 @@
use anyhow::{anyhow, Context, Result}; use anyhow::{anyhow, Context, Result};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine}; use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use reqwest::{Certificate, Client, ClientBuilder}; use reqwest::{Certificate, Client, ClientBuilder};
@ -38,28 +11,20 @@ use std::time::Duration;
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
use x509_parser::prelude::*; use x509_parser::prelude::*;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CertPinningConfig { pub struct CertPinningConfig {
pub enabled: bool, pub enabled: bool,
pub pins: HashMap<String, Vec<PinnedCert>>, pub pins: HashMap<String, Vec<PinnedCert>>,
pub require_pins: bool, pub require_pins: bool,
pub allow_backup_pins: bool, pub allow_backup_pins: bool,
pub report_only: bool, pub report_only: bool,
pub config_path: Option<PathBuf>, pub config_path: Option<PathBuf>,
pub cache_ttl_secs: u64, pub cache_ttl_secs: u64,
} }
@ -78,12 +43,10 @@ impl Default for CertPinningConfig {
} }
impl CertPinningConfig { impl CertPinningConfig {
pub fn new() -> Self { pub fn new() -> Self {
Self::default() Self::default()
} }
pub fn strict() -> Self { pub fn strict() -> Self {
Self { Self {
enabled: true, enabled: true,
@ -96,7 +59,6 @@ impl CertPinningConfig {
} }
} }
pub fn report_only() -> Self { pub fn report_only() -> Self {
Self { Self {
enabled: true, enabled: true,
@ -109,28 +71,23 @@ impl CertPinningConfig {
} }
} }
pub fn add_pin(&mut self, pin: PinnedCert) { pub fn add_pin(&mut self, pin: PinnedCert) {
let hostname = pin.hostname.clone(); let hostname = pin.hostname.clone();
self.pins.entry(hostname).or_default().push(pin); self.pins.entry(hostname).or_default().push(pin);
} }
pub fn add_pins(&mut self, hostname: &str, pins: Vec<PinnedCert>) { pub fn add_pins(&mut self, hostname: &str, pins: Vec<PinnedCert>) {
self.pins.insert(hostname.to_string(), pins); self.pins.insert(hostname.to_string(), pins);
} }
pub fn remove_pins(&mut self, hostname: &str) { pub fn remove_pins(&mut self, hostname: &str) {
self.pins.remove(hostname); self.pins.remove(hostname);
} }
pub fn get_pins(&self, hostname: &str) -> Option<&Vec<PinnedCert>> { pub fn get_pins(&self, hostname: &str) -> Option<&Vec<PinnedCert>> {
self.pins.get(hostname) self.pins.get(hostname)
} }
pub fn load_from_file(path: &Path) -> Result<Self> { pub fn load_from_file(path: &Path) -> Result<Self> {
let content = fs::read_to_string(path) let content = fs::read_to_string(path)
.with_context(|| format!("Failed to read pin config from {:?}", path))?; .with_context(|| format!("Failed to read pin config from {:?}", path))?;
@ -142,7 +99,6 @@ impl CertPinningConfig {
Ok(config) Ok(config)
} }
pub fn save_to_file(&self, path: &Path) -> Result<()> { pub fn save_to_file(&self, path: &Path) -> Result<()> {
let content = let content =
serde_json::to_string_pretty(self).context("Failed to serialize pin configuration")?; serde_json::to_string_pretty(self).context("Failed to serialize pin configuration")?;
@ -155,31 +111,22 @@ impl CertPinningConfig {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PinnedCert { pub struct PinnedCert {
pub hostname: String, pub hostname: String,
pub fingerprint: String, pub fingerprint: String,
pub description: Option<String>, pub description: Option<String>,
pub is_backup: bool, pub is_backup: bool,
pub expires_at: Option<chrono::DateTime<chrono::Utc>>, pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
pub pin_type: PinType, pub pin_type: PinType,
} }
impl PinnedCert { impl PinnedCert {
pub fn new(hostname: &str, fingerprint: &str) -> Self { pub fn new(hostname: &str, fingerprint: &str) -> Self {
Self { Self {
hostname: hostname.to_string(), hostname: hostname.to_string(),
@ -191,7 +138,6 @@ impl PinnedCert {
} }
} }
pub fn backup(hostname: &str, fingerprint: &str) -> Self { pub fn backup(hostname: &str, fingerprint: &str) -> Self {
Self { Self {
hostname: hostname.to_string(), hostname: hostname.to_string(),
@ -203,25 +149,21 @@ impl PinnedCert {
} }
} }
pub fn with_type(mut self, pin_type: PinType) -> Self { pub fn with_type(mut self, pin_type: PinType) -> Self {
self.pin_type = pin_type; self.pin_type = pin_type;
self self
} }
pub fn with_description(mut self, desc: &str) -> Self { pub fn with_description(mut self, desc: &str) -> Self {
self.description = Some(desc.to_string()); self.description = Some(desc.to_string());
self self
} }
pub fn with_expiration(mut self, expires: chrono::DateTime<chrono::Utc>) -> Self { pub fn with_expiration(mut self, expires: chrono::DateTime<chrono::Utc>) -> Self {
self.expires_at = Some(expires); self.expires_at = Some(expires);
self self
} }
pub fn get_hash_bytes(&self) -> Result<Vec<u8>> { pub fn get_hash_bytes(&self) -> Result<Vec<u8>> {
let hash_str = self let hash_str = self
.fingerprint .fingerprint
@ -233,7 +175,6 @@ impl PinnedCert {
.context("Failed to decode base64 fingerprint") .context("Failed to decode base64 fingerprint")
} }
pub fn verify(&self, cert_der: &[u8]) -> Result<bool> { pub fn verify(&self, cert_der: &[u8]) -> Result<bool> {
let expected_hash = self.get_hash_bytes()?; let expected_hash = self.get_hash_bytes()?;
let actual_hash = compute_spki_fingerprint(cert_der)?; let actual_hash = compute_spki_fingerprint(cert_der)?;
@ -242,10 +183,8 @@ impl PinnedCert {
} }
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PinType { pub enum PinType {
Leaf, Leaf,
Intermediate, Intermediate,
@ -259,30 +198,22 @@ impl Default for PinType {
} }
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct PinValidationResult { pub struct PinValidationResult {
pub valid: bool, pub valid: bool,
pub hostname: String, pub hostname: String,
pub matched_pin: Option<String>, pub matched_pin: Option<String>,
pub actual_fingerprint: String, pub actual_fingerprint: String,
pub error: Option<String>, pub error: Option<String>,
pub backup_match: bool, pub backup_match: bool,
} }
impl PinValidationResult { impl PinValidationResult {
pub fn success(hostname: &str, fingerprint: &str, backup: bool) -> Self { pub fn success(hostname: &str, fingerprint: &str, backup: bool) -> Self {
Self { Self {
valid: true, valid: true,
@ -294,7 +225,6 @@ impl PinValidationResult {
} }
} }
pub fn failure(hostname: &str, actual: &str, error: &str) -> Self { pub fn failure(hostname: &str, actual: &str, error: &str) -> Self {
Self { Self {
valid: false, valid: false,
@ -307,7 +237,6 @@ impl PinValidationResult {
} }
} }
#[derive(Debug)] #[derive(Debug)]
pub struct CertPinningManager { pub struct CertPinningManager {
config: Arc<RwLock<CertPinningConfig>>, config: Arc<RwLock<CertPinningConfig>>,
@ -315,7 +244,6 @@ pub struct CertPinningManager {
} }
impl CertPinningManager { impl CertPinningManager {
pub fn new(config: CertPinningConfig) -> Self { pub fn new(config: CertPinningConfig) -> Self {
Self { Self {
config: Arc::new(RwLock::new(config)), config: Arc::new(RwLock::new(config)),
@ -323,17 +251,14 @@ impl CertPinningManager {
} }
} }
pub fn default_manager() -> Self { pub fn default_manager() -> Self {
Self::new(CertPinningConfig::default()) Self::new(CertPinningConfig::default())
} }
pub fn is_enabled(&self) -> bool { pub fn is_enabled(&self) -> bool {
self.config.read().unwrap().enabled self.config.read().unwrap().enabled
} }
pub fn add_pin(&self, pin: PinnedCert) -> Result<()> { pub fn add_pin(&self, pin: PinnedCert) -> Result<()> {
let mut config = self let mut config = self
.config .config
@ -343,7 +268,6 @@ impl CertPinningManager {
Ok(()) Ok(())
} }
pub fn remove_pins(&self, hostname: &str) -> Result<()> { pub fn remove_pins(&self, hostname: &str) -> Result<()> {
let mut config = self let mut config = self
.config .config
@ -351,7 +275,6 @@ impl CertPinningManager {
.map_err(|_| anyhow!("Failed to acquire write lock"))?; .map_err(|_| anyhow!("Failed to acquire write lock"))?;
config.remove_pins(hostname); config.remove_pins(hostname);
let mut cache = self let mut cache = self
.validation_cache .validation_cache
.write() .write()
@ -361,7 +284,6 @@ impl CertPinningManager {
Ok(()) Ok(())
} }
pub fn validate_certificate( pub fn validate_certificate(
&self, &self,
hostname: &str, hostname: &str,
@ -376,7 +298,6 @@ impl CertPinningManager {
return Ok(PinValidationResult::success(hostname, "disabled", false)); return Ok(PinValidationResult::success(hostname, "disabled", false));
} }
if let Ok(cache) = self.validation_cache.read() { if let Ok(cache) = self.validation_cache.read() {
if let Some((result, timestamp)) = cache.get(hostname) { if let Some((result, timestamp)) = cache.get(hostname) {
if timestamp.elapsed().as_secs() < config.cache_ttl_secs { if timestamp.elapsed().as_secs() < config.cache_ttl_secs {
@ -385,11 +306,9 @@ impl CertPinningManager {
} }
} }
let actual_hash = compute_spki_fingerprint(cert_der)?; let actual_hash = compute_spki_fingerprint(cert_der)?;
let actual_fingerprint = format!("sha256//{}", BASE64.encode(&actual_hash)); let actual_fingerprint = format!("sha256//{}", BASE64.encode(&actual_hash));
let pins = match config.get_pins(hostname) { let pins = match config.get_pins(hostname) {
Some(pins) => pins, Some(pins) => pins,
None => { None => {
@ -411,7 +330,6 @@ impl CertPinningManager {
return Ok(result); return Ok(result);
} }
return Ok(PinValidationResult::success( return Ok(PinValidationResult::success(
hostname, hostname,
"no-pins-required", "no-pins-required",
@ -420,7 +338,6 @@ impl CertPinningManager {
} }
}; };
for pin in pins { for pin in pins {
match pin.verify(cert_der) { match pin.verify(cert_der) {
Ok(true) => { Ok(true) => {
@ -435,7 +352,6 @@ impl CertPinningManager {
); );
} }
if let Ok(mut cache) = self.validation_cache.write() { if let Ok(mut cache) = self.validation_cache.write() {
cache.insert( cache.insert(
hostname.to_string(), hostname.to_string(),
@ -445,15 +361,13 @@ impl CertPinningManager {
return Ok(result); return Ok(result);
} }
Ok(false) => continue, Ok(false) => {}
Err(e) => { Err(e) => {
debug!("Pin verification error for {}: {}", hostname, e); debug!("Pin verification error for {}: {}", hostname, e);
continue;
} }
} }
} }
let result = PinValidationResult::failure( let result = PinValidationResult::failure(
hostname, hostname,
&actual_fingerprint, &actual_fingerprint,
@ -479,12 +393,10 @@ impl CertPinningManager {
Ok(result) Ok(result)
} }
pub fn create_pinned_client(&self, hostname: &str) -> Result<Client> { pub fn create_pinned_client(&self, hostname: &str) -> Result<Client> {
self.create_pinned_client_with_options(hostname, None, Duration::from_secs(30)) self.create_pinned_client_with_options(hostname, None, Duration::from_secs(30))
} }
pub fn create_pinned_client_with_options( pub fn create_pinned_client_with_options(
&self, &self,
hostname: &str, hostname: &str,
@ -503,14 +415,10 @@ impl CertPinningManager {
.https_only(true) .https_only(true)
.tls_built_in_root_certs(true); .tls_built_in_root_certs(true);
if let Some(cert) = ca_cert { if let Some(cert) = ca_cert {
builder = builder.add_root_certificate(cert.clone()); builder = builder.add_root_certificate(cert.clone());
} }
if config.enabled && config.get_pins(hostname).is_some() { if config.enabled && config.get_pins(hostname).is_some() {
debug!( debug!(
"Creating pinned client for {} with {} pins", "Creating pinned client for {} with {} pins",
@ -522,7 +430,6 @@ impl CertPinningManager {
builder.build().context("Failed to build HTTP client") builder.build().context("Failed to build HTTP client")
} }
pub fn validate_pem_file( pub fn validate_pem_file(
&self, &self,
hostname: &str, hostname: &str,
@ -535,12 +442,10 @@ impl CertPinningManager {
self.validate_certificate(hostname, &der) self.validate_certificate(hostname, &der)
} }
pub fn generate_pin_from_file(hostname: &str, cert_path: &Path) -> Result<PinnedCert> { pub fn generate_pin_from_file(hostname: &str, cert_path: &Path) -> Result<PinnedCert> {
let cert_data = fs::read(cert_path) let cert_data = fs::read(cert_path)
.with_context(|| format!("Failed to read certificate: {:?}", cert_path))?; .with_context(|| format!("Failed to read certificate: {:?}", cert_path))?;
let der = if cert_data.starts_with(b"-----BEGIN") { let der = if cert_data.starts_with(b"-----BEGIN") {
pem_to_der(&cert_data)? pem_to_der(&cert_data)?
} else { } else {
@ -553,7 +458,6 @@ impl CertPinningManager {
Ok(PinnedCert::new(hostname, &fingerprint_str)) Ok(PinnedCert::new(hostname, &fingerprint_str))
} }
pub fn generate_pins_from_directory( pub fn generate_pins_from_directory(
hostname: &str, hostname: &str,
cert_dir: &Path, cert_dir: &Path,
@ -583,7 +487,6 @@ impl CertPinningManager {
Ok(pins) Ok(pins)
} }
pub fn export_pins(&self, path: &Path) -> Result<()> { pub fn export_pins(&self, path: &Path) -> Result<()> {
let config = self let config = self
.config .config
@ -593,7 +496,6 @@ impl CertPinningManager {
config.save_to_file(path) config.save_to_file(path)
} }
pub fn import_pins(&self, path: &Path) -> Result<()> { pub fn import_pins(&self, path: &Path) -> Result<()> {
let imported = CertPinningConfig::load_from_file(path)?; let imported = CertPinningConfig::load_from_file(path)?;
@ -606,7 +508,6 @@ impl CertPinningManager {
config.pins.insert(hostname, pins); config.pins.insert(hostname, pins);
} }
if let Ok(mut cache) = self.validation_cache.write() { if let Ok(mut cache) = self.validation_cache.write() {
cache.clear(); cache.clear();
} }
@ -614,7 +515,6 @@ impl CertPinningManager {
Ok(()) Ok(())
} }
pub fn get_stats(&self) -> Result<PinningStats> { pub fn get_stats(&self) -> Result<PinningStats> {
let config = self let config = self
.config .config
@ -653,7 +553,6 @@ impl CertPinningManager {
} }
} }
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]
pub struct PinningStats { pub struct PinningStats {
pub enabled: bool, pub enabled: bool,
@ -664,31 +563,25 @@ pub struct PinningStats {
pub report_only: bool, pub report_only: bool,
} }
pub fn compute_spki_fingerprint(cert_der: &[u8]) -> Result<Vec<u8>> { pub fn compute_spki_fingerprint(cert_der: &[u8]) -> Result<Vec<u8>> {
let (_, cert) = X509Certificate::from_der(cert_der) let (_, cert) = X509Certificate::from_der(cert_der)
.map_err(|e| anyhow!("Failed to parse X.509 certificate: {}", e))?; .map_err(|e| anyhow!("Failed to parse X.509 certificate: {}", e))?;
let spki = cert.public_key().raw; let spki = cert.public_key().raw;
let hash = digest(&SHA256, spki); let hash = digest(&SHA256, spki);
Ok(hash.as_ref().to_vec()) Ok(hash.as_ref().to_vec())
} }
pub fn compute_cert_fingerprint(cert_der: &[u8]) -> Vec<u8> { pub fn compute_cert_fingerprint(cert_der: &[u8]) -> Vec<u8> {
let hash = digest(&SHA256, cert_der); let hash = digest(&SHA256, cert_der);
hash.as_ref().to_vec() hash.as_ref().to_vec()
} }
pub fn pem_to_der(pem_data: &[u8]) -> Result<Vec<u8>> { pub fn pem_to_der(pem_data: &[u8]) -> Result<Vec<u8>> {
let pem_str = std::str::from_utf8(pem_data).context("Invalid UTF-8 in PEM data")?; let pem_str = std::str::from_utf8(pem_data).context("Invalid UTF-8 in PEM data")?;
let start_marker = "-----BEGIN CERTIFICATE-----"; let start_marker = "-----BEGIN CERTIFICATE-----";
let end_marker = "-----END CERTIFICATE-----"; let end_marker = "-----END CERTIFICATE-----";
@ -708,7 +601,6 @@ pub fn pem_to_der(pem_data: &[u8]) -> Result<Vec<u8>> {
.context("Failed to decode base64 certificate data") .context("Failed to decode base64 certificate data")
} }
pub fn format_fingerprint(hash: &[u8]) -> String { pub fn format_fingerprint(hash: &[u8]) -> String {
hash.iter() hash.iter()
.map(|b| format!("{:02X}", b)) .map(|b| format!("{:02X}", b))
@ -716,16 +608,13 @@ pub fn format_fingerprint(hash: &[u8]) -> String {
.join(":") .join(":")
} }
pub fn parse_fingerprint(formatted: &str) -> Result<Vec<u8>> { pub fn parse_fingerprint(formatted: &str) -> Result<Vec<u8>> {
if let Some(base64_part) = formatted.strip_prefix("sha256//") { if let Some(base64_part) = formatted.strip_prefix("sha256//") {
return BASE64 return BASE64
.decode(base64_part) .decode(base64_part)
.context("Failed to decode base64 fingerprint"); .context("Failed to decode base64 fingerprint");
} }
if formatted.contains(':') { if formatted.contains(':') {
let bytes: Result<Vec<u8>, _> = formatted let bytes: Result<Vec<u8>, _> = formatted
.split(':') .split(':')
@ -735,7 +624,6 @@ pub fn parse_fingerprint(formatted: &str) -> Result<Vec<u8>> {
return bytes.context("Failed to parse hex fingerprint"); return bytes.context("Failed to parse hex fingerprint");
} }
let bytes: Result<Vec<u8>, _> = (0..formatted.len()) let bytes: Result<Vec<u8>, _> = (0..formatted.len())
.step_by(2) .step_by(2)
.map(|i| u8::from_str_radix(&formatted[i..i + 2], 16)) .map(|i| u8::from_str_radix(&formatted[i..i + 2], 16))

View file

@ -83,7 +83,7 @@ type TaskHandler = Arc<
impl TaskScheduler { impl TaskScheduler {
pub fn new(state: Arc<AppState>) -> Self { pub fn new(state: Arc<AppState>) -> Self {
let scheduler = Self { let scheduler = Self {
state: state, state,
running_tasks: Arc::new(RwLock::new(HashMap::new())), running_tasks: Arc::new(RwLock::new(HashMap::new())),
task_registry: Arc::new(RwLock::new(HashMap::new())), task_registry: Arc::new(RwLock::new(HashMap::new())),
scheduled_tasks: Arc::new(RwLock::new(Vec::new())), scheduled_tasks: Arc::new(RwLock::new(Vec::new())),
@ -101,14 +101,10 @@ impl TaskScheduler {
tokio::spawn(async move { tokio::spawn(async move {
let mut handlers = registry.write().await; let mut handlers = registry.write().await;
handlers.insert( handlers.insert(
"database_cleanup".to_string(), "database_cleanup".to_string(),
Arc::new(move |_state: Arc<AppState>, _payload: serde_json::Value| { Arc::new(move |_state: Arc<AppState>, _payload: serde_json::Value| {
Box::pin(async move { Box::pin(async move {
info!("Database cleanup task executed"); info!("Database cleanup task executed");
Ok(serde_json::json!({ Ok(serde_json::json!({
@ -120,7 +116,6 @@ impl TaskScheduler {
}), }),
); );
handlers.insert( handlers.insert(
"cache_cleanup".to_string(), "cache_cleanup".to_string(),
Arc::new(move |state: Arc<AppState>, _payload: serde_json::Value| { Arc::new(move |state: Arc<AppState>, _payload: serde_json::Value| {
@ -139,7 +134,6 @@ impl TaskScheduler {
}), }),
); );
handlers.insert( handlers.insert(
"backup".to_string(), "backup".to_string(),
Arc::new(move |state: Arc<AppState>, payload: serde_json::Value| { Arc::new(move |state: Arc<AppState>, payload: serde_json::Value| {
@ -157,7 +151,6 @@ impl TaskScheduler {
.arg(&backup_file) .arg(&backup_file)
.output()?; .output()?;
if state.s3_client.is_some() { if state.s3_client.is_some() {
let s3 = state.s3_client.as_ref().unwrap(); let s3 = state.s3_client.as_ref().unwrap();
let body = tokio::fs::read(&backup_file).await?; let body = tokio::fs::read(&backup_file).await?;
@ -196,7 +189,6 @@ impl TaskScheduler {
}), }),
); );
handlers.insert( handlers.insert(
"generate_report".to_string(), "generate_report".to_string(),
Arc::new(move |_state: Arc<AppState>, payload: serde_json::Value| { Arc::new(move |_state: Arc<AppState>, payload: serde_json::Value| {
@ -229,7 +221,6 @@ impl TaskScheduler {
}), }),
); );
handlers.insert( handlers.insert(
"health_check".to_string(), "health_check".to_string(),
Arc::new(move |state: Arc<AppState>, _payload: serde_json::Value| { Arc::new(move |state: Arc<AppState>, _payload: serde_json::Value| {
@ -240,17 +231,14 @@ impl TaskScheduler {
"timestamp": Utc::now() "timestamp": Utc::now()
}); });
let db_ok = state.conn.get().is_ok(); let db_ok = state.conn.get().is_ok();
health["database"] = serde_json::json!(db_ok); health["database"] = serde_json::json!(db_ok);
if let Some(cache) = &state.cache { if let Some(cache) = &state.cache {
let cache_ok = cache.get_connection().is_ok(); let cache_ok = cache.get_connection().is_ok();
health["cache"] = serde_json::json!(cache_ok); health["cache"] = serde_json::json!(cache_ok);
} }
if let Some(s3) = &state.s3_client { if let Some(s3) = &state.s3_client {
let s3_ok = s3.list_buckets().send().await.is_ok(); let s3_ok = s3.list_buckets().send().await.is_ok();
health["storage"] = serde_json::json!(s3_ok); health["storage"] = serde_json::json!(s3_ok);
@ -351,7 +339,6 @@ impl TaskScheduler {
let execution_id = Uuid::new_v4(); let execution_id = Uuid::new_v4();
let started_at = Utc::now(); let started_at = Utc::now();
let _execution = TaskExecution { let _execution = TaskExecution {
id: execution_id, id: execution_id,
scheduled_task_id: task_id, scheduled_task_id: task_id,
@ -363,11 +350,6 @@ impl TaskScheduler {
duration_ms: None, duration_ms: None,
}; };
let result = { let result = {
let handlers = registry.read().await; let handlers = registry.read().await;
if let Some(handler) = handlers.get(&task.task_type) { if let Some(handler) = handlers.get(&task.task_type) {
@ -388,25 +370,20 @@ impl TaskScheduler {
let completed_at = Utc::now(); let completed_at = Utc::now();
let _duration_ms = (completed_at - started_at).num_milliseconds(); let _duration_ms = (completed_at - started_at).num_milliseconds();
match result { match result {
Ok(_result) => { Ok(_result) => {
let schedule = Schedule::from_str(&task.cron_expression).ok(); let schedule = Schedule::from_str(&task.cron_expression).ok();
let _next_run = schedule let _next_run = schedule
.and_then(|s| s.upcoming(chrono::Local).take(1).next()) .and_then(|s| s.upcoming(chrono::Local).take(1).next())
.map(|dt| dt.with_timezone(&Utc)) .map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|| Utc::now() + Duration::hours(1)); .unwrap_or_else(|| Utc::now() + Duration::hours(1));
info!("Task {} completed successfully", task.name); info!("Task {} completed successfully", task.name);
} }
Err(e) => { Err(e) => {
let error_msg = format!("Task failed: {}", e); let error_msg = format!("Task failed: {}", e);
error!("{}", error_msg); error!("{}", error_msg);
task.retry_count += 1; task.retry_count += 1;
if task.retry_count < task.max_retries { if task.retry_count < task.max_retries {
let _retry_delay = let _retry_delay =
@ -424,12 +401,10 @@ impl TaskScheduler {
} }
} }
let mut running = running_tasks.write().await; let mut running = running_tasks.write().await;
running.remove(&task_id); running.remove(&task_id);
}); });
let mut running = self.running_tasks.write().await; let mut running = self.running_tasks.write().await;
running.insert(task_id, handle); running.insert(task_id, handle);
} }
@ -445,7 +420,6 @@ impl TaskScheduler {
info!("Stopped task: {}", task_id); info!("Stopped task: {}", task_id);
} }
let mut tasks = self.scheduled_tasks.write().await; let mut tasks = self.scheduled_tasks.write().await;
if let Some(task) = tasks.iter_mut().find(|t| t.id == task_id) { if let Some(task) = tasks.iter_mut().find(|t| t.id == task_id) {
task.enabled = false; task.enabled = false;