Update attendance, keywords, calendar, compliance, console, core, drive, email, llm, msteams, security, and tasks modules

This commit is contained in:
Rodrigo Rodriguez (Pragmatismo) 2025-12-24 09:29:27 -03:00
parent 883c6d07e1
commit 7cbfe43319
33 changed files with 525 additions and 2412 deletions

View file

@ -2,7 +2,7 @@
use anyhow::{anyhow, Result};
use chrono::{DateTime, Duration, Local, NaiveTime, Utc};
use chrono::{DateTime, Duration, Local, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;

View file

@ -1,40 +1,3 @@
use crate::core::config::ConfigManager;
use crate::shared::models::UserSession;
use crate::shared::state::AppState;
@ -46,18 +9,14 @@ use axum::{
};
use chrono::Utc;
use diesel::prelude::*;
use log::{debug, error, info, warn};
use log::{error, info, warn};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::Arc;
use uuid::Uuid;
#[derive(Debug, Clone, Default)]
pub struct LlmAssistConfig {
pub tips_enabled: bool,
pub polish_enabled: bool,
@ -74,7 +33,6 @@ pub struct LlmAssistConfig {
}
impl LlmAssistConfig {
pub fn from_config(bot_id: Uuid, work_path: &str) -> Self {
let config_path = PathBuf::from(work_path)
.join(format!("{}.gbai", bot_id))
@ -143,7 +101,6 @@ impl LlmAssistConfig {
config
}
pub fn any_enabled(&self) -> bool {
self.tips_enabled
|| self.polish_enabled
@ -153,9 +110,6 @@ impl LlmAssistConfig {
}
}
#[derive(Debug, Deserialize)]
pub struct TipRequest {
pub session_id: Uuid,
@ -165,7 +119,6 @@ pub struct TipRequest {
pub history: Vec<ConversationMessage>,
}
#[derive(Debug, Deserialize)]
pub struct PolishRequest {
pub session_id: Uuid,
@ -179,7 +132,6 @@ fn default_tone() -> String {
"professional".to_string()
}
#[derive(Debug, Deserialize)]
pub struct SmartRepliesRequest {
pub session_id: Uuid,
@ -187,13 +139,11 @@ pub struct SmartRepliesRequest {
pub history: Vec<ConversationMessage>,
}
#[derive(Debug, Deserialize)]
pub struct SummaryRequest {
pub session_id: Uuid,
}
#[derive(Debug, Deserialize)]
pub struct SentimentRequest {
pub session_id: Uuid,
@ -202,7 +152,6 @@ pub struct SentimentRequest {
pub history: Vec<ConversationMessage>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationMessage {
pub role: String,
@ -210,7 +159,6 @@ pub struct ConversationMessage {
pub timestamp: Option<String>,
}
#[derive(Debug, Serialize)]
pub struct TipResponse {
pub success: bool,
@ -219,7 +167,6 @@ pub struct TipResponse {
pub error: Option<String>,
}
#[derive(Debug, Clone, Serialize)]
pub struct AttendantTip {
pub tip_type: TipType,
@ -228,11 +175,9 @@ pub struct AttendantTip {
pub priority: i32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TipType {
Intent,
Action,
@ -246,7 +191,6 @@ pub enum TipType {
General,
}
#[derive(Debug, Serialize)]
pub struct PolishResponse {
pub success: bool,
@ -257,7 +201,6 @@ pub struct PolishResponse {
pub error: Option<String>,
}
#[derive(Debug, Serialize)]
pub struct SmartRepliesResponse {
pub success: bool,
@ -266,7 +209,6 @@ pub struct SmartRepliesResponse {
pub error: Option<String>,
}
#[derive(Debug, Clone, Serialize)]
pub struct SmartReply {
pub text: String,
@ -275,7 +217,6 @@ pub struct SmartReply {
pub category: String,
}
#[derive(Debug, Serialize)]
pub struct SummaryResponse {
pub success: bool,
@ -284,7 +225,6 @@ pub struct SummaryResponse {
pub error: Option<String>,
}
#[derive(Debug, Clone, Serialize, Default)]
pub struct ConversationSummary {
pub brief: String,
@ -297,7 +237,6 @@ pub struct ConversationSummary {
pub duration_minutes: i32,
}
#[derive(Debug, Serialize)]
pub struct SentimentResponse {
pub success: bool,
@ -306,7 +245,6 @@ pub struct SentimentResponse {
pub error: Option<String>,
}
#[derive(Debug, Clone, Serialize, Default)]
pub struct SentimentAnalysis {
pub overall: String,
@ -317,16 +255,12 @@ pub struct SentimentAnalysis {
pub emoji: String,
}
#[derive(Debug, Clone, Serialize)]
pub struct Emotion {
pub name: String,
pub intensity: f32,
}
async fn execute_llm_with_context(
state: &Arc<AppState>,
bot_id: Uuid,
@ -351,7 +285,6 @@ async fn execute_llm_with_context(
.unwrap_or_default()
});
let messages = serde_json::json!([
{
"role": "system",
@ -368,29 +301,24 @@ async fn execute_llm_with_context(
.generate(user_prompt, &messages, &model, &key)
.await?;
let handler = crate::llm::llm_models::get_handler(&model);
let processed = handler.process_content(&response);
Ok(processed)
}
fn get_bot_system_prompt(bot_id: Uuid, work_path: &str) -> String {
let config = LlmAssistConfig::from_config(bot_id, work_path);
if let Some(prompt) = config.bot_system_prompt {
return prompt;
}
let start_bas_path = PathBuf::from(work_path)
.join(format!("{}.gbai", bot_id))
.join(format!("{}.gbdialog", bot_id))
.join("start.bas");
if let Ok(content) = std::fs::read_to_string(&start_bas_path) {
let mut description_lines = Vec::new();
for line in content.lines() {
let trimmed = line.trim();
@ -406,21 +334,15 @@ fn get_bot_system_prompt(bot_id: Uuid, work_path: &str) -> String {
}
}
"You are a professional customer service assistant. Be helpful, empathetic, and solution-oriented. Maintain a friendly but professional tone.".to_string()
}
pub async fn generate_tips(
State(state): State<Arc<AppState>>,
Json(request): Json<TipRequest>,
) -> (StatusCode, Json<TipResponse>) {
info!("Generating tips for session {}", request.session_id);
let session_result = get_session(&state, request.session_id).await;
let session = match session_result {
Ok(s) => s,
@ -436,7 +358,6 @@ pub async fn generate_tips(
}
};
let work_path = std::env::var("WORK_PATH").unwrap_or_else(|_| "./work".to_string());
let config = LlmAssistConfig::from_config(session.bot_id, &work_path);
@ -451,7 +372,6 @@ pub async fn generate_tips(
);
}
let history_context = request
.history
.iter()
@ -497,7 +417,6 @@ Provide tips for the attendant."#,
match execute_llm_with_context(&state, session.bot_id, &system_prompt, &user_prompt).await {
Ok(response) => {
let tips = parse_tips_response(&response);
(
StatusCode::OK,
@ -523,8 +442,6 @@ Provide tips for the attendant."#,
}
}
pub async fn polish_message(
State(state): State<Arc<AppState>>,
Json(request): Json<PolishRequest>,
@ -621,8 +538,6 @@ Respond in JSON format:
}
}
pub async fn generate_smart_replies(
State(state): State<Arc<AppState>>,
Json(request): Json<SmartRepliesRequest>,
@ -677,7 +592,7 @@ The service has this personality: {}
Generate exactly 3 reply suggestions that:
1. Are contextually appropriate
2. Sound natural and human (not robotic)
3. Vary in approach (one empathetic, one solution-focused, one follow-up)
3. Vary in approach (one empathetic, one solution-focused, one follow_up)
4. Are ready to send (no placeholders like [name])
Respond in JSON format:
@ -692,10 +607,10 @@ Respond in JSON format:
);
let user_prompt = format!(
r#"Conversation:
r"Conversation:
{}
Generate 3 reply options for the attendant."#,
Generate 3 reply options for the attendant.",
history_context
);
@ -725,8 +640,6 @@ Generate 3 reply options for the attendant."#,
}
}
pub async fn generate_summary(
State(state): State<Arc<AppState>>,
Path(session_id): Path<Uuid>,
@ -762,7 +675,6 @@ pub async fn generate_summary(
);
}
let history = load_conversation_history(&state, session_id).await;
if history.is_empty() {
@ -806,9 +718,9 @@ Respond in JSON format:
);
let user_prompt = format!(
r#"Summarize this conversation:
r"Summarize this conversation:
{}"#,
{}",
history_text
);
@ -817,7 +729,6 @@ Respond in JSON format:
let mut summary = parse_summary_response(&response);
summary.message_count = history.len() as i32;
if let (Some(first_ts), Some(last_ts)) = (
history.first().and_then(|m| m.timestamp.as_ref()),
history.last().and_then(|m| m.timestamp.as_ref()),
@ -857,8 +768,6 @@ Respond in JSON format:
}
}
pub async fn analyze_sentiment(
State(state): State<Arc<AppState>>,
Json(request): Json<SentimentRequest>,
@ -884,7 +793,6 @@ pub async fn analyze_sentiment(
let config = LlmAssistConfig::from_config(session.bot_id, &work_path);
if !config.sentiment_enabled {
let sentiment = analyze_sentiment_keywords(&request.message);
return (
StatusCode::OK,
@ -959,8 +867,6 @@ Analyze the customer's sentiment."#,
}
}
pub async fn get_llm_config(
State(_state): State<Arc<AppState>>,
Path(bot_id): Path<Uuid>,
@ -981,9 +887,6 @@ pub async fn get_llm_config(
)
}
pub async fn process_attendant_command(
state: &Arc<AppState>,
attendant_phone: &str,
@ -1020,7 +923,6 @@ pub async fn process_attendant_command(
}
async fn handle_queue_command(state: &Arc<AppState>) -> Result<String, String> {
let conn = state.conn.clone();
let result = tokio::task::spawn_blocking(move || {
let mut db_conn = conn.get().map_err(|e| e.to_string())?;
@ -1174,8 +1076,6 @@ async fn handle_status_command(
}
};
let conn = state.conn.clone();
let phone = attendant_phone.to_string();
let status_val = status_value.to_string();
@ -1185,8 +1085,6 @@ async fn handle_status_command(
use crate::shared::models::schema::user_sessions;
let sessions: Vec<UserSession> = user_sessions::table
.filter(
user_sessions::context_data
@ -1243,7 +1141,6 @@ async fn handle_transfer_command(
let target = args.join(" ");
let target_clean = target.trim_start_matches('@').to_string();
let conn = state.conn.clone();
let target_attendant = target_clean.clone();
@ -1252,13 +1149,11 @@ async fn handle_transfer_command(
use crate::shared::models::schema::user_sessions;
let session: UserSession = user_sessions::table
.find(session_id)
.first(&mut db_conn)
.map_err(|e| format!("Session not found: {}", e))?;
let mut ctx = session.context_data.clone();
let previous_attendant = ctx
.get("assigned_to_phone")
@ -1271,11 +1166,9 @@ async fn handle_transfer_command(
ctx["transferred_at"] = serde_json::json!(Utc::now().to_rfc3339());
ctx["status"] = serde_json::json!("pending_transfer");
ctx["assigned_to_phone"] = serde_json::Value::Null;
ctx["assigned_to"] = serde_json::Value::Null;
ctx["needs_human"] = serde_json::json!(true);
diesel::update(user_sessions::table.filter(user_sessions::id.eq(session_id)))
@ -1347,7 +1240,6 @@ async fn handle_tips_command(
) -> Result<String, String> {
let session_id = current_session.ok_or("No active conversation. Use /take first.")?;
let history = load_conversation_history(state, session_id).await;
if history.is_empty() {
@ -1369,7 +1261,6 @@ async fn handle_tips_command(
history,
};
let (_, Json(tip_response)) = generate_tips(State(state.clone()), Json(request)).await;
if tip_response.tips.is_empty() {
@ -1446,7 +1337,8 @@ async fn handle_replies_command(
history,
};
let (_, Json(replies_response)) = generate_smart_replies(State(state.clone()), Json(request)).await;
let (_, Json(replies_response)) =
generate_smart_replies(State(state.clone()), Json(request)).await;
if replies_response.replies.is_empty() {
return Ok(" No reply suggestions available.".to_string());
@ -1474,7 +1366,8 @@ async fn handle_summary_command(
) -> Result<String, String> {
let session_id = current_session.ok_or("No active conversation")?;
let (_, Json(summary_response)) = generate_summary(State(state.clone()), Path(session_id)).await;
let (_, Json(summary_response)) =
generate_summary(State(state.clone()), Path(session_id)).await;
if !summary_response.success {
return Err(summary_response
@ -1527,7 +1420,7 @@ async fn handle_summary_command(
}
fn get_help_text() -> String {
r#" *Attendant Commands*
r#"*Attendant Commands*
*Queue Management:*
`/queue` - View waiting conversations
@ -1549,9 +1442,6 @@ _Portuguese: /fila, /pegar, /transferir, /resolver, /dicas, /polir, /respostas,
.to_string()
}
async fn get_session(state: &Arc<AppState>, session_id: Uuid) -> Result<UserSession, String> {
let conn = state.conn.clone();
@ -1569,7 +1459,6 @@ async fn get_session(state: &Arc<AppState>, session_id: Uuid) -> Result<UserSess
.map_err(|e| format!("Task error: {}", e))?
}
async fn load_conversation_history(
state: &Arc<AppState>,
session_id: Uuid,
@ -1616,9 +1505,7 @@ async fn load_conversation_history(
result
}
fn parse_tips_response(response: &str) -> Vec<AttendantTip> {
let json_str = extract_json(response);
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&json_str) {
@ -1653,7 +1540,6 @@ fn parse_tips_response(response: &str) -> Vec<AttendantTip> {
}
}
if !response.trim().is_empty() {
vec![AttendantTip {
tip_type: TipType::General,
@ -1666,7 +1552,6 @@ fn parse_tips_response(response: &str) -> Vec<AttendantTip> {
}
}
fn parse_polish_response(response: &str, original: &str) -> (String, Vec<String>) {
let json_str = extract_json(response);
@ -1690,14 +1575,12 @@ fn parse_polish_response(response: &str, original: &str) -> (String, Vec<String>
return (polished, changes);
}
(
response.trim().to_string(),
vec!["Message improved".to_string()],
)
}
fn parse_smart_replies_response(response: &str) -> Vec<SmartReply> {
let json_str = extract_json(response);
@ -1731,7 +1614,6 @@ fn parse_smart_replies_response(response: &str) -> Vec<SmartReply> {
generate_fallback_replies()
}
fn parse_summary_response(response: &str) -> ConversationSummary {
let json_str = extract_json(response);
@ -1789,7 +1671,6 @@ fn parse_summary_response(response: &str) -> ConversationSummary {
}
}
fn parse_sentiment_response(response: &str) -> SentimentAnalysis {
let json_str = extract_json(response);
@ -1839,9 +1720,7 @@ fn parse_sentiment_response(response: &str) -> SentimentAnalysis {
SentimentAnalysis::default()
}
fn extract_json(response: &str) -> String {
if let Some(start) = response.find('{') {
if let Some(end) = response.rfind('}') {
if end > start {
@ -1850,7 +1729,6 @@ fn extract_json(response: &str) -> String {
}
}
if let Some(start) = response.find('[') {
if let Some(end) = response.rfind(']') {
if end > start {
@ -1862,12 +1740,10 @@ fn extract_json(response: &str) -> String {
response.to_string()
}
fn generate_fallback_tips(message: &str) -> Vec<AttendantTip> {
let msg_lower = message.to_lowercase();
let mut tips = Vec::new();
if msg_lower.contains("urgent")
|| msg_lower.contains("asap")
|| msg_lower.contains("immediately")
@ -1881,7 +1757,6 @@ fn generate_fallback_tips(message: &str) -> Vec<AttendantTip> {
});
}
if msg_lower.contains("frustrated")
|| msg_lower.contains("angry")
|| msg_lower.contains("ridiculous")
@ -1895,7 +1770,6 @@ fn generate_fallback_tips(message: &str) -> Vec<AttendantTip> {
});
}
if message.contains('?') {
tips.push(AttendantTip {
tip_type: TipType::Intent,
@ -1905,7 +1779,6 @@ fn generate_fallback_tips(message: &str) -> Vec<AttendantTip> {
});
}
if msg_lower.contains("problem")
|| msg_lower.contains("issue")
|| msg_lower.contains("not working")
@ -1919,7 +1792,6 @@ fn generate_fallback_tips(message: &str) -> Vec<AttendantTip> {
});
}
if msg_lower.contains("thank")
|| msg_lower.contains("great")
|| msg_lower.contains("perfect")
@ -1934,7 +1806,6 @@ fn generate_fallback_tips(message: &str) -> Vec<AttendantTip> {
});
}
if tips.is_empty() {
tips.push(AttendantTip {
tip_type: TipType::General,
@ -1947,7 +1818,6 @@ fn generate_fallback_tips(message: &str) -> Vec<AttendantTip> {
tips
}
fn generate_fallback_replies() -> Vec<SmartReply> {
vec![
SmartReply {
@ -1971,7 +1841,6 @@ fn generate_fallback_replies() -> Vec<SmartReply> {
]
}
fn analyze_sentiment_keywords(message: &str) -> SentimentAnalysis {
let msg_lower = message.to_lowercase();
@ -2097,5 +1966,3 @@ fn analyze_sentiment_keywords(message: &str) -> SentimentAnalysis {
emoji: emoji.to_string(),
}
}

View file

@ -1,40 +1,8 @@
pub mod drive;
pub mod keyword_services;
pub mod llm_assist;
pub mod queue;
pub use drive::{AttendanceDriveConfig, AttendanceDriveService, RecordMetadata, SyncResult};
pub use keyword_services::{
AttendanceCommand, AttendanceRecord, AttendanceResponse, AttendanceService, KeywordConfig,
@ -58,14 +26,13 @@ use crate::shared::state::{AppState, AttendantNotification};
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
Path, Query, State,
Query, State,
},
http::StatusCode,
response::IntoResponse,
routing::{get, post},
Json, Router,
};
use botlib::MessageType;
use chrono::Utc;
use diesel::prelude::*;
use futures::{SinkExt, StreamExt};
@ -76,10 +43,8 @@ use std::sync::Arc;
use tokio::sync::broadcast;
use uuid::Uuid;
pub fn configure_attendance_routes() -> Router<Arc<AppState>> {
Router::new()
.route("/api/attendance/queue", get(queue::list_queue))
.route("/api/attendance/attendants", get(queue::list_attendants))
.route("/api/attendance/assign", post(queue::assign_conversation))
@ -92,11 +57,8 @@ pub fn configure_attendance_routes() -> Router<Arc<AppState>> {
post(queue::resolve_conversation),
)
.route("/api/attendance/insights", get(queue::get_insights))
.route("/api/attendance/respond", post(attendant_respond))
.route("/ws/attendant", get(attendant_websocket_handler))
.route("/api/attendance/llm/tips", post(llm_assist::generate_tips))
.route(
"/api/attendance/llm/polish",
@ -120,7 +82,6 @@ pub fn configure_attendance_routes() -> Router<Arc<AppState>> {
)
}
#[derive(Debug, Deserialize)]
pub struct AttendantRespondRequest {
pub session_id: String,
@ -128,7 +89,6 @@ pub struct AttendantRespondRequest {
pub attendant_id: String,
}
#[derive(Debug, Serialize)]
pub struct AttendantRespondResponse {
pub success: bool,
@ -137,7 +97,6 @@ pub struct AttendantRespondResponse {
pub error: Option<String>,
}
pub async fn attendant_respond(
State(state): State<Arc<AppState>>,
Json(request): Json<AttendantRespondRequest>,
@ -161,7 +120,6 @@ pub async fn attendant_respond(
}
};
let conn = state.conn.clone();
let session_result = tokio::task::spawn_blocking(move || {
let mut db_conn = conn.get().ok()?;
@ -189,26 +147,22 @@ pub async fn attendant_respond(
}
};
let channel = session
.context_data
.get("channel")
.and_then(|v| v.as_str())
.unwrap_or("web");
let recipient = session
.context_data
.get("phone")
.and_then(|v| v.as_str())
.unwrap_or("");
if let Err(e) = save_message_to_history(&state, &session, &request.message, "attendant").await {
error!("Failed to save attendant message: {}", e);
}
match channel {
"whatsapp" => {
if recipient.is_empty() {
@ -240,7 +194,6 @@ pub async fn attendant_respond(
match adapter.send_message(response).await {
Ok(_) => {
broadcast_attendant_action(&state, &session, &request, "attendant_response")
.await;
@ -264,7 +217,6 @@ pub async fn attendant_respond(
}
}
"web" | _ => {
let sent = if let Some(tx) = state
.response_channels
.lock()
@ -282,15 +234,14 @@ pub async fn attendant_respond(
is_complete: true,
suggestions: vec![],
context_name: None,
context_length: 0,
context_max_length: 0,
context_length: 0,
context_max_length: 0,
};
tx.send(response).await.is_ok()
} else {
false
};
broadcast_attendant_action(&state, &session, &request, "attendant_response").await;
if sent {
@ -303,7 +254,6 @@ pub async fn attendant_respond(
}),
)
} else {
(
StatusCode::OK,
Json(AttendantRespondResponse {
@ -317,7 +267,6 @@ pub async fn attendant_respond(
}
}
async fn save_message_to_history(
state: &Arc<AppState>,
session: &UserSession,
@ -356,7 +305,6 @@ async fn save_message_to_history(
Ok(())
}
async fn broadcast_attendant_action(
state: &Arc<AppState>,
session: &UserSession,
@ -396,7 +344,6 @@ async fn broadcast_attendant_action(
}
}
pub async fn attendant_websocket_handler(
ws: WebSocketUpgrade,
State(state): State<Arc<AppState>>,
@ -422,13 +369,11 @@ pub async fn attendant_websocket_handler(
.into_response()
}
async fn handle_attendant_websocket(socket: WebSocket, state: Arc<AppState>, attendant_id: String) {
let (mut sender, mut receiver) = socket.split();
info!("Attendant WebSocket connected: {}", attendant_id);
let welcome = serde_json::json!({
"type": "connected",
"attendant_id": attendant_id,
@ -447,7 +392,6 @@ async fn handle_attendant_websocket(socket: WebSocket, state: Arc<AppState>, att
}
}
let mut broadcast_rx = if let Some(broadcast_tx) = state.attendant_broadcast.as_ref() {
broadcast_tx.subscribe()
} else {
@ -455,14 +399,11 @@ async fn handle_attendant_websocket(socket: WebSocket, state: Arc<AppState>, att
return;
};
let attendant_id_clone = attendant_id.clone();
let mut send_task = tokio::spawn(async move {
loop {
match broadcast_rx.recv().await {
Ok(notification) => {
let should_send = notification.assigned_to.is_none()
|| notification.assigned_to.as_ref() == Some(&attendant_id_clone);
@ -493,7 +434,6 @@ async fn handle_attendant_websocket(socket: WebSocket, state: Arc<AppState>, att
}
});
let state_clone = state.clone();
let attendant_id_for_recv = attendant_id.clone();
let mut recv_task = tokio::spawn(async move {
@ -505,7 +445,6 @@ async fn handle_attendant_websocket(socket: WebSocket, state: Arc<AppState>, att
attendant_id_for_recv, text
);
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) {
handle_attendant_message(&state_clone, &attendant_id_for_recv, parsed)
.await;
@ -513,7 +452,6 @@ async fn handle_attendant_websocket(socket: WebSocket, state: Arc<AppState>, att
}
Message::Ping(data) => {
debug!("Received ping from attendant {}", attendant_id_for_recv);
}
Message::Close(_) => {
info!(
@ -527,7 +465,6 @@ async fn handle_attendant_websocket(socket: WebSocket, state: Arc<AppState>, att
}
});
tokio::select! {
_ = (&mut send_task) => {
recv_task.abort();
@ -540,7 +477,6 @@ async fn handle_attendant_websocket(socket: WebSocket, state: Arc<AppState>, att
info!("Attendant WebSocket disconnected: {}", attendant_id);
}
async fn handle_attendant_message(
state: &Arc<AppState>,
attendant_id: &str,
@ -553,24 +489,19 @@ async fn handle_attendant_message(
match msg_type {
"status_update" => {
if let Some(status) = message.get("status").and_then(|v| v.as_str()) {
info!("Attendant {} status update: {}", attendant_id, status);
}
}
"typing" => {
if let Some(session_id) = message.get("session_id").and_then(|v| v.as_str()) {
debug!(
"Attendant {} typing in session {}",
attendant_id, session_id
);
}
}
"read" => {
if let Some(session_id) = message.get("session_id").and_then(|v| v.as_str()) {
debug!(
"Attendant {} marked session {} as read",
@ -579,7 +510,6 @@ async fn handle_attendant_message(
}
}
"respond" => {
if let (Some(session_id), Some(content)) = (
message.get("session_id").and_then(|v| v.as_str()),
message.get("content").and_then(|v| v.as_str()),
@ -589,14 +519,12 @@ async fn handle_attendant_message(
attendant_id, session_id
);
let request = AttendantRespondRequest {
session_id: session_id.to_string(),
message: content.to_string(),
attendant_id: attendant_id.to_string(),
};
if let Ok(uuid) = Uuid::parse_str(session_id) {
let conn = state.conn.clone();
if let Some(session) = tokio::task::spawn_blocking(move || {
@ -611,11 +539,9 @@ async fn handle_attendant_message(
.ok()
.flatten()
{
let _ =
save_message_to_history(state, &session, content, "attendant").await;
let channel = session
.context_data
.get("channel")
@ -639,14 +565,13 @@ async fn handle_attendant_message(
is_complete: true,
suggestions: vec![],
context_name: None,
context_length: 0,
context_max_length: 0,
context_length: 0,
context_max_length: 0,
};
let _ = adapter.send_message(response).await;
}
}
broadcast_attendant_action(state, &session, &request, "attendant_response")
.await;
}

View file

@ -1,22 +1,3 @@
use crate::shared::models::UserSession;
use crate::shared::state::AppState;
use diesel::prelude::*;
@ -27,10 +8,8 @@ use serde::{Deserialize, Serialize};
use std::sync::Arc;
use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ReflectionType {
ConversationQuality,
ResponseAccuracy,
@ -66,7 +45,6 @@ impl From<&str> for ReflectionType {
}
impl ReflectionType {
pub fn prompt_template(&self) -> String {
match self {
ReflectionType::ConversationQuality => {
@ -190,10 +168,8 @@ Provide your analysis in JSON format:
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReflectionConfig {
pub enabled: bool,
pub interval: u32,
@ -224,7 +200,6 @@ impl Default for ReflectionConfig {
}
impl ReflectionConfig {
pub fn from_bot_config(state: &AppState, bot_id: Uuid) -> Self {
let mut config = Self::default();
@ -278,10 +253,8 @@ impl ReflectionConfig {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReflectionResult {
pub id: Uuid,
pub bot_id: Uuid,
@ -325,7 +298,6 @@ impl ReflectionResult {
}
}
pub fn from_llm_response(
bot_id: Uuid,
session_id: Uuid,
@ -337,16 +309,13 @@ impl ReflectionResult {
result.raw_response = response.to_string();
result.messages_analyzed = messages_analyzed;
if let Ok(json) = serde_json::from_str::<serde_json::Value>(response) {
result.score = json
.get("score")
.or_else(|| json.get("overall_score"))
.and_then(|v| v.as_f64())
.unwrap_or(5.0) as f32;
if let Some(insights) = json.get("key_insights").or_else(|| json.get("insights")) {
if let Some(arr) = insights.as_array() {
result.insights = arr
@ -356,7 +325,6 @@ impl ReflectionResult {
}
}
if let Some(improvements) = json
.get("improvements")
.or_else(|| json.get("critical_improvements"))
@ -370,7 +338,6 @@ impl ReflectionResult {
}
}
if let Some(patterns) = json
.get("positive_patterns")
.or_else(|| json.get("strengths"))
@ -384,7 +351,6 @@ impl ReflectionResult {
}
}
if let Some(concerns) = json
.get("weaknesses")
.or_else(|| json.get("concerns"))
@ -399,7 +365,6 @@ impl ReflectionResult {
}
}
} else {
warn!("Reflection response was not valid JSON, extracting from plain text");
result.insights = extract_insights_from_text(response);
result.score = 5.0;
@ -408,12 +373,10 @@ impl ReflectionResult {
result
}
pub fn needs_improvement(&self, threshold: f32) -> bool {
self.score < threshold
}
pub fn summary(&self) -> String {
format!(
"Reflection Score: {:.1}/10\n\
@ -430,11 +393,9 @@ impl ReflectionResult {
}
}
fn extract_insights_from_text(text: &str) -> Vec<String> {
pub fn extract_insights_from_text(text: &str) -> Vec<String> {
let mut insights = Vec::new();
for line in text.lines() {
let trimmed = line.trim();
if trimmed.starts_with(|c: char| c.is_ascii_digit())
@ -453,7 +414,6 @@ fn extract_insights_from_text(text: &str) -> Vec<String> {
}
}
if insights.is_empty() {
for sentence in text.split(|c| c == '.' || c == '!' || c == '?') {
let trimmed = sentence.trim();
@ -467,7 +427,6 @@ fn extract_insights_from_text(text: &str) -> Vec<String> {
insights
}
pub struct ReflectionEngine {
state: Arc<AppState>,
config: ReflectionConfig,
@ -501,7 +460,6 @@ impl ReflectionEngine {
}
}
pub async fn reflect(
&self,
session_id: Uuid,
@ -511,7 +469,6 @@ impl ReflectionEngine {
return Err("Reflection is not enabled for this bot".to_string());
}
let history = self.get_recent_history(session_id, 20).await?;
if history.is_empty() {
@ -520,13 +477,10 @@ impl ReflectionEngine {
let messages_count = history.len();
let prompt = self.build_reflection_prompt(&reflection_type, &history)?;
let response = self.call_llm_for_reflection(&prompt).await?;
let result = ReflectionResult::from_llm_response(
self.bot_id,
session_id,
@ -535,10 +489,8 @@ impl ReflectionEngine {
messages_count,
);
self.store_reflection(&result).await?;
if self.config.auto_apply && result.needs_improvement(self.config.improvement_threshold) {
self.apply_improvements(&result).await?;
}
@ -551,7 +503,6 @@ impl ReflectionEngine {
Ok(result)
}
async fn get_recent_history(
&self,
session_id: Uuid,
@ -595,7 +546,6 @@ impl ReflectionEngine {
Ok(history)
}
fn build_reflection_prompt(
&self,
reflection_type: &ReflectionType,
@ -607,7 +557,6 @@ impl ReflectionEngine {
reflection_type.prompt_template()
};
let conversation = history
.iter()
.map(|m| {
@ -629,9 +578,7 @@ impl ReflectionEngine {
Ok(prompt)
}
async fn call_llm_for_reflection(&self, prompt: &str) -> Result<String, String> {
let (llm_url, llm_model, llm_key) = self.get_llm_config().await?;
let client = reqwest::Client::new();
@ -680,7 +627,6 @@ impl ReflectionEngine {
Ok(content)
}
async fn get_llm_config(&self) -> Result<(String, String, String), String> {
let mut conn = self
.state
@ -720,7 +666,6 @@ impl ReflectionEngine {
Ok((llm_url, llm_model, llm_key))
}
async fn store_reflection(&self, result: &ReflectionResult) -> Result<(), String> {
let mut conn = self
.state
@ -759,9 +704,7 @@ impl ReflectionEngine {
Ok(())
}
async fn apply_improvements(&self, result: &ReflectionResult) -> Result<(), String> {
let mut conn = self
.state
.conn
@ -796,7 +739,6 @@ impl ReflectionEngine {
Ok(())
}
pub async fn get_insights(&self, limit: usize) -> Result<Vec<ReflectionResult>, String> {
let mut conn = self
.state
@ -867,13 +809,11 @@ impl ReflectionEngine {
Ok(results)
}
pub async fn should_reflect(&self, session_id: Uuid) -> bool {
if !self.config.enabled {
return false;
}
if let Ok(mut conn) = self.state.conn.get() {
#[derive(QueryableByName)]
struct CountRow {
@ -897,7 +837,6 @@ impl ReflectionEngine {
}
}
#[derive(Debug, Clone)]
struct ConversationMessage {
pub role: String,
@ -905,14 +844,12 @@ struct ConversationMessage {
pub timestamp: chrono::DateTime<chrono::Utc>,
}
pub fn register_reflection_keywords(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
set_bot_reflection_keyword(state.clone(), user.clone(), engine);
reflect_on_keyword(state.clone(), user.clone(), engine);
get_reflection_insights_keyword(state.clone(), user.clone(), engine);
}
pub fn set_bot_reflection_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
let state_clone = Arc::clone(&state);
let user_clone = user.clone();
@ -964,7 +901,6 @@ pub fn set_bot_reflection_keyword(state: Arc<AppState>, user: UserSession, engin
.expect("Failed to register SET BOT REFLECTION syntax");
}
pub fn reflect_on_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
let state_clone = Arc::clone(&state);
let user_clone = user.clone();
@ -1016,7 +952,6 @@ pub fn reflect_on_keyword(state: Arc<AppState>, user: UserSession, engine: &mut
.expect("Failed to register REFLECT ON syntax");
}
pub fn get_reflection_insights_keyword(
state: Arc<AppState>,
user: UserSession,
@ -1045,7 +980,6 @@ pub fn get_reflection_insights_keyword(
});
}
async fn set_reflection_enabled(
state: &AppState,
bot_id: Uuid,
@ -1077,5 +1011,3 @@ async fn set_reflection_enabled(
if enabled { "enabled" } else { "disabled" }
))
}

View file

@ -1,25 +1,3 @@
use crate::shared::models::UserSession;
use crate::shared::state::AppState;
use diesel::prelude::*;
@ -33,10 +11,8 @@ use std::time::Duration;
use tokio::time::timeout;
use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum SandboxRuntime {
LXC,
Docker,
@ -63,7 +39,6 @@ impl From<&str> for SandboxRuntime {
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum CodeLanguage {
Python,
@ -108,10 +83,8 @@ impl CodeLanguage {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SandboxConfig {
pub enabled: bool,
pub runtime: SandboxRuntime,
@ -151,7 +124,6 @@ impl Default for SandboxConfig {
}
impl SandboxConfig {
pub fn from_bot_config(state: &AppState, bot_id: Uuid) -> Self {
let mut config = Self::default();
@ -218,10 +190,8 @@ impl SandboxConfig {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionResult {
pub stdout: String,
pub stderr: String,
@ -289,7 +259,6 @@ impl ExecutionResult {
}
}
pub struct CodeSandbox {
config: SandboxConfig,
session_id: Uuid,
@ -309,7 +278,6 @@ impl CodeSandbox {
Self { config, session_id }
}
pub async fn execute(&self, code: &str, language: CodeLanguage) -> ExecutionResult {
if !self.config.enabled {
return ExecutionResult::error("Sandbox execution is disabled");
@ -339,20 +307,16 @@ impl CodeSandbox {
}
}
async fn execute_lxc(
&self,
code: &str,
language: &CodeLanguage,
) -> Result<ExecutionResult, String> {
let container_name = format!("gb-sandbox-{}", Uuid::new_v4());
std::fs::create_dir_all(&self.config.work_dir)
.map_err(|e| format!("Failed to create work dir: {}", e))?;
let code_file = format!(
"{}/{}.{}",
self.config.work_dir,
@ -362,7 +326,6 @@ impl CodeSandbox {
std::fs::write(&code_file, code)
.map_err(|e| format!("Failed to write code file: {}", e))?;
let mut cmd = Command::new("lxc-execute");
cmd.arg("-n")
.arg(&container_name)
@ -372,7 +335,6 @@ impl CodeSandbox {
.arg(language.interpreter())
.arg(&code_file);
cmd.env(
"LXC_CGROUP_MEMORY_LIMIT",
format!("{}M", self.config.memory_limit_mb),
@ -382,7 +344,6 @@ impl CodeSandbox {
format!("{}", self.config.cpu_limit_percent * 1000),
);
let timeout_duration = Duration::from_secs(self.config.timeout_seconds);
let output = timeout(timeout_duration, async {
tokio::process::Command::new("lxc-execute")
@ -398,7 +359,6 @@ impl CodeSandbox {
})
.await;
let _ = std::fs::remove_file(&code_file);
match output {
@ -422,20 +382,17 @@ impl CodeSandbox {
}
}
async fn execute_docker(
&self,
code: &str,
language: &CodeLanguage,
) -> Result<ExecutionResult, String> {
let image = match language {
CodeLanguage::Python => "python:3.11-slim",
CodeLanguage::JavaScript => "node:20-slim",
CodeLanguage::Bash => "alpine:latest",
};
let args = vec![
"run".to_string(),
"--rm".to_string(),
@ -459,7 +416,6 @@ impl CodeSandbox {
code.to_string(),
];
let timeout_duration = Duration::from_secs(self.config.timeout_seconds);
let output = timeout(timeout_duration, async {
tokio::process::Command::new("docker")
@ -490,42 +446,34 @@ impl CodeSandbox {
}
}
async fn execute_firecracker(
&self,
code: &str,
language: &CodeLanguage,
) -> Result<ExecutionResult, String> {
warn!("Firecracker runtime not yet implemented, falling back to process isolation");
self.execute_process(code, language).await
}
async fn execute_process(
&self,
code: &str,
language: &CodeLanguage,
) -> Result<ExecutionResult, String> {
let temp_dir = format!("{}/{}", self.config.work_dir, Uuid::new_v4());
std::fs::create_dir_all(&temp_dir)
.map_err(|e| format!("Failed to create temp dir: {}", e))?;
let code_file = format!("{}/code.{}", temp_dir, language.file_extension());
std::fs::write(&code_file, code)
.map_err(|e| format!("Failed to write code file: {}", e))?;
let (cmd_name, cmd_args): (&str, Vec<&str>) = match language {
CodeLanguage::Python => ("python3", vec![&code_file]),
CodeLanguage::JavaScript => ("node", vec![&code_file]),
CodeLanguage::Bash => ("bash", vec![&code_file]),
};
let timeout_duration = Duration::from_secs(self.config.timeout_seconds);
let output = timeout(timeout_duration, async {
tokio::process::Command::new(cmd_name)
@ -540,7 +488,6 @@ impl CodeSandbox {
})
.await;
let _ = std::fs::remove_dir_all(&temp_dir);
match output {
@ -564,7 +511,6 @@ impl CodeSandbox {
}
}
pub async fn execute_file(&self, file_path: &str, language: CodeLanguage) -> ExecutionResult {
match std::fs::read_to_string(file_path) {
Ok(code) => self.execute(&code, language).await,
@ -573,7 +519,6 @@ impl CodeSandbox {
}
}
pub fn register_sandbox_keywords(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
run_python_keyword(state.clone(), user.clone(), engine);
run_javascript_keyword(state.clone(), user.clone(), engine);
@ -581,7 +526,6 @@ pub fn register_sandbox_keywords(state: Arc<AppState>, user: UserSession, engine
run_file_keyword(state.clone(), user.clone(), engine);
}
pub fn run_python_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
let state_clone = Arc::clone(&state);
let user_clone = user.clone();
@ -627,7 +571,6 @@ pub fn run_python_keyword(state: Arc<AppState>, user: UserSession, engine: &mut
.expect("Failed to register RUN PYTHON syntax");
}
pub fn run_javascript_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
let state_clone = Arc::clone(&state);
let user_clone = user.clone();
@ -672,7 +615,6 @@ pub fn run_javascript_keyword(state: Arc<AppState>, user: UserSession, engine: &
)
.expect("Failed to register RUN JAVASCRIPT syntax");
let state_clone2 = Arc::clone(&state);
let user_clone2 = user.clone();
@ -711,7 +653,6 @@ pub fn run_javascript_keyword(state: Arc<AppState>, user: UserSession, engine: &
.expect("Failed to register RUN JS syntax");
}
pub fn run_bash_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
let state_clone = Arc::clone(&state);
let user_clone = user.clone();
@ -753,7 +694,6 @@ pub fn run_bash_keyword(state: Arc<AppState>, user: UserSession, engine: &mut En
.expect("Failed to register RUN BASH syntax");
}
pub fn run_file_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
let state_clone = Arc::clone(&state);
let user_clone = user.clone();
@ -802,7 +742,6 @@ pub fn run_file_keyword(state: Arc<AppState>, user: UserSession, engine: &mut En
)
.expect("Failed to register RUN PYTHON WITH FILE syntax");
let state_clone2 = Arc::clone(&state);
let user_clone2 = user.clone();
@ -847,11 +786,8 @@ pub fn run_file_keyword(state: Arc<AppState>, user: UserSession, engine: &mut En
.expect("Failed to register RUN JAVASCRIPT WITH FILE syntax");
}
pub fn generate_python_lxc_config() -> String {
r#"
r"
# LXC configuration for Python sandbox
lxc.include = /usr/share/lxc/config/common.conf
lxc.arch = linux64
@ -879,13 +815,12 @@ lxc.mount.auto = proc:mixed sys:ro
lxc.mount.entry = /usr/bin/python3 usr/bin/python3 none ro,bind 0 0
lxc.mount.entry = /usr/lib/python3 usr/lib/python3 none ro,bind 0 0
lxc.mount.entry = tmpfs tmp tmpfs defaults 0 0
"#
"
.to_string()
}
pub fn generate_node_lxc_config() -> String {
r#"
r"
# LXC configuration for Node.js sandbox
lxc.include = /usr/share/lxc/config/common.conf
lxc.arch = linux64
@ -913,8 +848,6 @@ lxc.mount.auto = proc:mixed sys:ro
lxc.mount.entry = /usr/bin/node usr/bin/node none ro,bind 0 0
lxc.mount.entry = /usr/lib/node_modules usr/lib/node_modules none ro,bind 0 0
lxc.mount.entry = tmpfs tmp tmpfs defaults 0 0
"#
"
.to_string()
}

View file

@ -1,8 +1,3 @@
use crate::llm::LLMProvider;
use crate::shared::models::UserSession;
use crate::shared::state::AppState;
@ -16,7 +11,6 @@ use std::io::Read;
use std::path::PathBuf;
use std::sync::Arc;
pub fn create_site_keyword(state: &AppState, user: UserSession, engine: &mut Engine) {
let state_clone = state.clone();
let user_clone = user.clone();
@ -58,12 +52,6 @@ pub fn create_site_keyword(state: &AppState, user: UserSession, engine: &mut Eng
.unwrap();
}
async fn create_site(
config: crate::config::AppConfig,
s3: Option<std::sync::Arc<aws_sdk_s3::Client>>,
@ -83,20 +71,16 @@ async fn create_site(
alias_str, template_dir_str
);
let base_path = PathBuf::from(&config.site_path);
let template_path = base_path.join(&template_dir_str);
let combined_content = load_templates(&template_path)?;
let generated_html = generate_html_from_prompt(llm, &combined_content, &prompt_str).await?;
let drive_path = format!("apps/{}", alias_str);
store_to_drive(&s3, &bucket, &bot_id, &drive_path, &generated_html).await?;
let serve_path = base_path.join(&alias_str);
sync_to_serve_path(&serve_path, &generated_html, &template_path).await?;
@ -108,7 +92,6 @@ async fn create_site(
Ok(format!("/apps/{}", alias_str))
}
fn load_templates(template_path: &PathBuf) -> Result<String, Box<dyn Error + Send + Sync>> {
let mut combined_content = String::new();
@ -141,7 +124,6 @@ fn load_templates(template_path: &PathBuf) -> Result<String, Box<dyn Error + Sen
Ok(combined_content)
}
async fn generate_html_from_prompt(
llm: Option<Arc<dyn LLMProvider>>,
templates: &str,
@ -209,7 +191,6 @@ OUTPUT: Complete index.html file only, no explanations."#,
Ok(html)
}
fn extract_html_from_response(response: &str) -> String {
let trimmed = response.trim();
@ -234,7 +215,6 @@ fn extract_html_from_response(response: &str) -> String {
trimmed.to_string()
}
fn generate_placeholder_html(prompt: &str) -> String {
format!(
r##"<!DOCTYPE html>
@ -278,7 +258,6 @@ fn generate_placeholder_html(prompt: &str) -> String {
)
}
async fn store_to_drive(
s3: &Option<std::sync::Arc<aws_sdk_s3::Client>>,
bucket: &str,
@ -304,7 +283,6 @@ async fn store_to_drive(
.await
.map_err(|e| format!("Failed to store to drive: {}", e))?;
let schema_key = format!("{}.gbdrive/{}/schema.json", bot_id, drive_path);
let schema = r#"{"tables": {}, "version": 1}"#;
@ -321,23 +299,19 @@ async fn store_to_drive(
Ok(())
}
async fn sync_to_serve_path(
serve_path: &PathBuf,
html_content: &str,
template_path: &PathBuf,
) -> Result<(), Box<dyn Error + Send + Sync>> {
fs::create_dir_all(serve_path).map_err(|e| format!("Failed to create serve path: {}", e))?;
let index_path = serve_path.join("index.html");
fs::write(&index_path, html_content)
.map_err(|e| format!("Failed to write index.html: {}", e))?;
info!("Written: {:?}", index_path);
let template_assets = template_path.join("_assets");
let serve_assets = serve_path.join("_assets");
@ -345,26 +319,21 @@ async fn sync_to_serve_path(
copy_dir_recursive(&template_assets, &serve_assets)?;
info!("Copied assets to: {:?}", serve_assets);
} else {
fs::create_dir_all(&serve_assets)
.map_err(|e| format!("Failed to create assets dir: {}", e))?;
let htmx_path = serve_assets.join("htmx.min.js");
if !htmx_path.exists() {
fs::write(&htmx_path, "/* HTMX - include from CDN or bundle */")
.map_err(|e| format!("Failed to write htmx: {}", e))?;
}
let styles_path = serve_assets.join("styles.css");
if !styles_path.exists() {
fs::write(&styles_path, DEFAULT_STYLES)
.map_err(|e| format!("Failed to write styles: {}", e))?;
}
let app_js_path = serve_assets.join("app.js");
if !app_js_path.exists() {
fs::write(&app_js_path, DEFAULT_APP_JS)
@ -372,7 +341,6 @@ async fn sync_to_serve_path(
}
}
let schema_path = serve_path.join("schema.json");
fs::write(&schema_path, r#"{"tables": {}, "version": 1}"#)
.map_err(|e| format!("Failed to write schema.json: {}", e))?;
@ -380,7 +348,6 @@ async fn sync_to_serve_path(
Ok(())
}
fn copy_dir_recursive(src: &PathBuf, dst: &PathBuf) -> Result<(), Box<dyn Error + Send + Sync>> {
fs::create_dir_all(dst).map_err(|e| format!("Failed to create dir {:?}: {}", dst, e))?;
@ -400,8 +367,7 @@ fn copy_dir_recursive(src: &PathBuf, dst: &PathBuf) -> Result<(), Box<dyn Error
Ok(())
}
const DEFAULT_STYLES: &str = r#"
const DEFAULT_STYLES: &str = r"
:root {
--primary: #0ea5e9;
--success: #22c55e;
@ -494,10 +460,9 @@ th, td {
.htmx-request .htmx-indicator {
opacity: 1;
}
"#;
";
const DEFAULT_APP_JS: &str = r#"
const DEFAULT_APP_JS: &str = r"
function toast(message, type = 'info') {
const el = document.createElement('div');
@ -525,4 +490,4 @@ function openModal(id) {
function closeModal(id) {
document.getElementById(id)?.classList.remove('active');
}
"#;
";

View file

@ -1,52 +1,11 @@
use chrono::{DateTime, Duration, Utc};
use rhai::{Array, Dynamic, Engine, Map};
use serde::{Deserialize, Serialize};
use tracing::info;
use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Episode {
pub id: Uuid,
pub user_id: Uuid,
@ -79,10 +38,8 @@ pub struct Episode {
pub metadata: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActionItem {
pub description: String,
pub assignee: Option<String>,
@ -94,7 +51,6 @@ pub struct ActionItem {
pub completed: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Priority {
@ -110,10 +66,8 @@ impl Default for Priority {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Sentiment {
pub score: f64,
pub label: SentimentLabel,
@ -121,7 +75,6 @@ pub struct Sentiment {
pub confidence: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SentimentLabel {
@ -148,7 +101,6 @@ impl Default for Sentiment {
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ResolutionStatus {
@ -165,10 +117,8 @@ impl Default for ResolutionStatus {
}
}
#[derive(Debug, Clone)]
pub struct EpisodicMemoryConfig {
pub enabled: bool,
pub threshold: usize,
@ -198,7 +148,6 @@ impl Default for EpisodicMemoryConfig {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationMessage {
pub id: Uuid,
@ -207,19 +156,16 @@ pub struct ConversationMessage {
pub timestamp: DateTime<Utc>,
}
#[derive(Debug)]
pub struct EpisodicMemoryManager {
config: EpisodicMemoryConfig,
}
impl EpisodicMemoryManager {
pub fn new(config: EpisodicMemoryConfig) -> Self {
EpisodicMemoryManager { config }
}
pub fn from_config(config_map: &std::collections::HashMap<String, String>) -> Self {
let config = EpisodicMemoryConfig {
enabled: config_map
@ -254,22 +200,18 @@ impl EpisodicMemoryManager {
EpisodicMemoryManager::new(config)
}
pub fn should_summarize(&self, message_count: usize) -> bool {
self.config.enabled && self.config.auto_summarize && message_count >= self.config.threshold
}
pub fn get_history_to_keep(&self) -> usize {
self.config.history
}
pub fn get_threshold(&self) -> usize {
self.config.threshold
}
pub fn generate_summary_prompt(&self, messages: &[ConversationMessage]) -> String {
let formatted_messages = messages
.iter()
@ -309,7 +251,6 @@ Respond with valid JSON only:
)
}
pub fn parse_summary_response(
&self,
response: &str,
@ -318,7 +259,6 @@ Respond with valid JSON only:
bot_id: Uuid,
session_id: Uuid,
) -> Result<Episode, String> {
let json_str = extract_json(response)?;
let parsed: serde_json::Value =
@ -418,22 +358,18 @@ Respond with valid JSON only:
})
}
pub fn get_retention_cutoff(&self) -> DateTime<Utc> {
Utc::now() - Duration::days(self.config.retention_days as i64)
}
}
fn extract_json(response: &str) -> Result<String, String> {
pub fn extract_json(response: &str) -> Result<String, String> {
if let Some(start) = response.find("```json") {
if let Some(end) = response[start + 7..].find("```") {
return Ok(response[start + 7..start + 7 + end].trim().to_string());
}
}
if let Some(start) = response.find("```") {
let after_start = start + 3;
@ -446,7 +382,6 @@ fn extract_json(response: &str) -> Result<String, String> {
}
}
if let Some(start) = response.find('{') {
if let Some(end) = response.rfind('}') {
if end > start {
@ -458,7 +393,6 @@ fn extract_json(response: &str) -> Result<String, String> {
Err("No JSON found in response".to_string())
}
impl Episode {
pub fn to_dynamic(&self) -> Dynamic {
let mut map = Map::new();
@ -531,12 +465,7 @@ impl Episode {
}
}
pub fn register_episodic_memory_keywords(engine: &mut Engine) {
engine.register_fn("episode_summary", |episode: Map| -> String {
episode
.get("summary")
@ -584,7 +513,6 @@ pub fn register_episodic_memory_keywords(engine: &mut Engine) {
info!("Episodic memory keywords registered");
}
pub const EPISODIC_MEMORY_SCHEMA: &str = r#"
-- Conversation episodes (summaries)
CREATE TABLE IF NOT EXISTS conversation_episodes (
@ -619,9 +547,8 @@ CREATE INDEX IF NOT EXISTS idx_episodes_summary_fts ON conversation_episodes
USING GIN(to_tsvector('english', summary));
"#;
pub mod sql {
pub const INSERT_EPISODE: &str = r#"
pub const INSERT_EPISODE: &str = r"
INSERT INTO conversation_episodes (
id, user_id, bot_id, session_id, summary, key_topics, decisions,
action_items, sentiment, resolution, message_count, message_ids,
@ -629,22 +556,22 @@ pub mod sql {
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16
)
"#;
";
pub const GET_EPISODES_BY_USER: &str = r#"
pub const GET_EPISODES_BY_USER: &str = r"
SELECT * FROM conversation_episodes
WHERE user_id = $1
ORDER BY created_at DESC
LIMIT $2
"#;
";
pub const GET_EPISODES_BY_SESSION: &str = r#"
pub const GET_EPISODES_BY_SESSION: &str = r"
SELECT * FROM conversation_episodes
WHERE session_id = $1
ORDER BY created_at DESC
"#;
";
pub const SEARCH_EPISODES: &str = r#"
pub const SEARCH_EPISODES: &str = r"
SELECT * FROM conversation_episodes
WHERE user_id = $1
AND (
@ -653,19 +580,19 @@ pub mod sql {
)
ORDER BY created_at DESC
LIMIT $4
"#;
";
pub const DELETE_OLD_EPISODES: &str = r#"
pub const DELETE_OLD_EPISODES: &str = r"
DELETE FROM conversation_episodes
WHERE created_at < $1
"#;
";
pub const COUNT_USER_EPISODES: &str = r#"
pub const COUNT_USER_EPISODES: &str = r"
SELECT COUNT(*) FROM conversation_episodes
WHERE user_id = $1
"#;
";
pub const DELETE_OLDEST_EPISODES: &str = r#"
pub const DELETE_OLDEST_EPISODES: &str = r"
DELETE FROM conversation_episodes
WHERE id IN (
SELECT id FROM conversation_episodes
@ -673,5 +600,5 @@ pub mod sql {
ORDER BY created_at ASC
LIMIT $2
)
"#;
";
}

View file

@ -3,7 +3,7 @@ use rhai::Engine;
pub fn first_keyword(engine: &mut Engine) {
engine
.register_custom_syntax(&["FIRST", "$expr$"], false, {
.register_custom_syntax(["FIRST", "$expr$"], false, {
move |context, inputs| {
let input_string = context.eval_expression_tree(&inputs[0])?;
let input_str = input_string.to_string();

View file

@ -1,29 +1,3 @@
use crate::shared::message_types::MessageType;
use crate::shared::models::{BotResponse, UserSession};
use crate::shared::state::AppState;
@ -34,8 +8,7 @@ use serde::{Deserialize, Serialize};
use std::sync::Arc;
use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum InputType {
Any,
Email,
@ -67,81 +40,74 @@ pub enum InputType {
}
impl InputType {
#[must_use]
pub fn error_message(&self) -> String {
match self {
InputType::Any => "".to_string(),
InputType::Email => {
Self::Any => String::new(),
Self::Email => {
"Please enter a valid email address (e.g., user@example.com)".to_string()
}
InputType::Date => {
"Please enter a valid date (e.g., 25/12/2024 or 2024-12-25)".to_string()
}
InputType::Name => "Please enter a valid name (letters and spaces only)".to_string(),
InputType::Integer => "Please enter a valid whole number".to_string(),
InputType::Float => "Please enter a valid number".to_string(),
InputType::Boolean => "Please answer yes or no".to_string(),
InputType::Hour => "Please enter a valid time (e.g., 14:30 or 2:30 PM)".to_string(),
InputType::Money => {
"Please enter a valid amount (e.g., 100.00 or R$ 100,00)".to_string()
}
InputType::Mobile => "Please enter a valid mobile number".to_string(),
InputType::Zipcode => "Please enter a valid ZIP/postal code".to_string(),
InputType::Language => {
"Please enter a valid language code (e.g., en, pt, es)".to_string()
}
InputType::Cpf => "Please enter a valid CPF (11 digits)".to_string(),
InputType::Cnpj => "Please enter a valid CNPJ (14 digits)".to_string(),
InputType::QrCode => "Please send an image containing a QR code".to_string(),
InputType::Login => "Please complete the authentication process".to_string(),
InputType::Menu(options) => format!("Please select one of: {}", options.join(", ")),
InputType::File => "Please upload a file".to_string(),
InputType::Image => "Please send an image".to_string(),
InputType::Audio => "Please send an audio file or voice message".to_string(),
InputType::Video => "Please send a video".to_string(),
InputType::Document => "Please send a document (PDF, Word, etc.)".to_string(),
InputType::Url => "Please enter a valid URL".to_string(),
InputType::Uuid => "Please enter a valid UUID".to_string(),
InputType::Color => "Please enter a valid color (e.g., #FF0000 or red)".to_string(),
InputType::CreditCard => "Please enter a valid credit card number".to_string(),
InputType::Password => "Please enter a password (minimum 8 characters)".to_string(),
Self::Date => "Please enter a valid date (e.g., 25/12/2024 or 2024-12-25)".to_string(),
Self::Name => "Please enter a valid name (letters and spaces only)".to_string(),
Self::Integer => "Please enter a valid whole number".to_string(),
Self::Float => "Please enter a valid number".to_string(),
Self::Boolean => "Please answer yes or no".to_string(),
Self::Hour => "Please enter a valid time (e.g., 14:30 or 2:30 PM)".to_string(),
Self::Money => "Please enter a valid amount (e.g., 100.00 or R$ 100,00)".to_string(),
Self::Mobile => "Please enter a valid mobile number".to_string(),
Self::Zipcode => "Please enter a valid ZIP/postal code".to_string(),
Self::Language => "Please enter a valid language code (e.g., en, pt, es)".to_string(),
Self::Cpf => "Please enter a valid CPF (11 digits)".to_string(),
Self::Cnpj => "Please enter a valid CNPJ (14 digits)".to_string(),
Self::QrCode => "Please send an image containing a QR code".to_string(),
Self::Login => "Please complete the authentication process".to_string(),
Self::Menu(options) => format!("Please select one of: {}", options.join(", ")),
Self::File => "Please upload a file".to_string(),
Self::Image => "Please send an image".to_string(),
Self::Audio => "Please send an audio file or voice message".to_string(),
Self::Video => "Please send a video".to_string(),
Self::Document => "Please send a document (PDF, Word, etc.)".to_string(),
Self::Url => "Please enter a valid URL".to_string(),
Self::Uuid => "Please enter a valid UUID".to_string(),
Self::Color => "Please enter a valid color (e.g., #FF0000 or red)".to_string(),
Self::CreditCard => "Please enter a valid credit card number".to_string(),
Self::Password => "Please enter a password (minimum 8 characters)".to_string(),
}
}
#[must_use]
pub fn from_str(s: &str) -> Self {
match s.to_uppercase().as_str() {
"EMAIL" => InputType::Email,
"DATE" => InputType::Date,
"NAME" => InputType::Name,
"INTEGER" | "INT" | "NUMBER" => InputType::Integer,
"FLOAT" | "DECIMAL" | "DOUBLE" => InputType::Float,
"BOOLEAN" | "BOOL" => InputType::Boolean,
"HOUR" | "TIME" => InputType::Hour,
"MONEY" | "CURRENCY" | "AMOUNT" => InputType::Money,
"MOBILE" | "PHONE" | "TELEPHONE" => InputType::Mobile,
"ZIPCODE" | "ZIP" | "CEP" | "POSTALCODE" => InputType::Zipcode,
"LANGUAGE" | "LANG" => InputType::Language,
"CPF" => InputType::Cpf,
"CNPJ" => InputType::Cnpj,
"QRCODE" | "QR" => InputType::QrCode,
"LOGIN" | "AUTH" => InputType::Login,
"FILE" => InputType::File,
"IMAGE" | "PHOTO" | "PICTURE" => InputType::Image,
"AUDIO" | "VOICE" | "SOUND" => InputType::Audio,
"VIDEO" => InputType::Video,
"DOCUMENT" | "DOC" | "PDF" => InputType::Document,
"URL" | "LINK" => InputType::Url,
"UUID" | "GUID" => InputType::Uuid,
"COLOR" | "COLOUR" => InputType::Color,
"CREDITCARD" | "CARD" => InputType::CreditCard,
"PASSWORD" | "PASS" | "SECRET" => InputType::Password,
_ => InputType::Any,
"EMAIL" => Self::Email,
"DATE" => Self::Date,
"NAME" => Self::Name,
"INTEGER" | "INT" | "NUMBER" => Self::Integer,
"FLOAT" | "DECIMAL" | "DOUBLE" => Self::Float,
"BOOLEAN" | "BOOL" => Self::Boolean,
"HOUR" | "TIME" => Self::Hour,
"MONEY" | "CURRENCY" | "AMOUNT" => Self::Money,
"MOBILE" | "PHONE" | "TELEPHONE" => Self::Mobile,
"ZIPCODE" | "ZIP" | "CEP" | "POSTALCODE" => Self::Zipcode,
"LANGUAGE" | "LANG" => Self::Language,
"CPF" => Self::Cpf,
"CNPJ" => Self::Cnpj,
"QRCODE" | "QR" => Self::QrCode,
"LOGIN" | "AUTH" => Self::Login,
"FILE" => Self::File,
"IMAGE" | "PHOTO" | "PICTURE" => Self::Image,
"AUDIO" | "VOICE" | "SOUND" => Self::Audio,
"VIDEO" => Self::Video,
"DOCUMENT" | "DOC" | "PDF" => Self::Document,
"URL" | "LINK" => Self::Url,
"UUID" | "GUID" => Self::Uuid,
"COLOR" | "COLOUR" => Self::Color,
"CREDITCARD" | "CARD" => Self::CreditCard,
"PASSWORD" | "PASS" | "SECRET" => Self::Password,
_ => Self::Any,
}
}
}
#[derive(Debug, Clone)]
pub struct ValidationResult {
pub is_valid: bool,
@ -151,6 +117,7 @@ pub struct ValidationResult {
}
impl ValidationResult {
#[must_use]
pub fn valid(value: String) -> Self {
Self {
is_valid: true,
@ -160,6 +127,7 @@ impl ValidationResult {
}
}
#[must_use]
pub fn valid_with_metadata(value: String, metadata: serde_json::Value) -> Self {
Self {
is_valid: true,
@ -169,6 +137,7 @@ impl ValidationResult {
}
}
#[must_use]
pub fn invalid(error: String) -> Self {
Self {
is_valid: false,
@ -179,26 +148,20 @@ impl ValidationResult {
}
}
pub fn hear_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
register_hear_basic(state.clone(), user.clone(), engine);
register_hear_as_type(state.clone(), user.clone(), engine);
register_hear_as_menu(state.clone(), user.clone(), engine);
}
fn register_hear_basic(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
let session_id = user.id;
let state_clone = Arc::clone(&state);
engine
.register_custom_syntax(&["HEAR", "$ident$"], true, move |_context, inputs| {
let variable_name = inputs[0]
.get_string_value()
.expect("Expected identifier as string")
@ -223,10 +186,9 @@ fn register_hear_basic(state: Arc<AppState>, user: UserSession, engine: &mut Eng
let mut session_manager = state_for_spawn.session_manager.lock().await;
session_manager.mark_waiting(session_id_clone);
if let Some(redis_client) = &state_for_spawn.cache {
if let Ok(mut conn) = redis_client.get_multiplexed_async_connection().await {
let key = format!("hear:{}:{}", session_id_clone, var_name_clone);
let key = format!("hear:{session_id_clone}:{var_name_clone}");
let wait_data = serde_json::json!({
"variable": var_name_clone,
"type": "any",
@ -252,7 +214,6 @@ fn register_hear_basic(state: Arc<AppState>, user: UserSession, engine: &mut Eng
.unwrap();
}
fn register_hear_as_type(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
let session_id = user.id;
let state_clone = Arc::clone(&state);
@ -262,7 +223,6 @@ fn register_hear_as_type(state: Arc<AppState>, user: UserSession, engine: &mut E
&["HEAR", "$ident$", "AS", "$ident$"],
true,
move |_context, inputs| {
let variable_name = inputs[0]
.get_string_value()
.expect("Expected identifier for variable")
@ -274,11 +234,7 @@ fn register_hear_as_type(state: Arc<AppState>, user: UserSession, engine: &mut E
let _input_type = InputType::from_str(&type_name);
trace!(
"HEAR {} AS {} - waiting for validated input",
variable_name,
type_name
);
trace!("HEAR {variable_name} AS {type_name} - waiting for validated input");
let state_for_spawn = Arc::clone(&state_clone);
let session_id_clone = session_id;
@ -289,11 +245,10 @@ fn register_hear_as_type(state: Arc<AppState>, user: UserSession, engine: &mut E
let mut session_manager = state_for_spawn.session_manager.lock().await;
session_manager.mark_waiting(session_id_clone);
if let Some(redis_client) = &state_for_spawn.cache {
if let Ok(mut conn) = redis_client.get_multiplexed_async_connection().await
{
let key = format!("hear:{}:{}", session_id_clone, var_name_clone);
let key = format!("hear:{session_id_clone}:{var_name_clone}");
let wait_data = serde_json::json!({
"variable": var_name_clone,
"type": type_clone.to_lowercase(),
@ -321,44 +276,34 @@ fn register_hear_as_type(state: Arc<AppState>, user: UserSession, engine: &mut E
.unwrap();
}
fn register_hear_as_menu(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
let session_id = user.id;
let state_clone = Arc::clone(&state);
engine
.register_custom_syntax(
&["HEAR", "$ident$", "AS", "$expr$"],
true,
move |context, inputs| {
let variable_name = inputs[0]
.get_string_value()
.expect("Expected identifier for variable")
.to_lowercase();
let options_expr = context.eval_expression_tree(&inputs[1])?;
let options_str = options_expr.to_string();
let input_type = InputType::from_str(&options_str);
if input_type != InputType::Any {
return Err(Box::new(EvalAltResult::ErrorRuntime(
"Use HEAR AS TYPE syntax".into(),
rhai::Position::NONE,
)));
}
let options: Vec<String> = if options_str.starts_with('[') {
serde_json::from_str(&options_str).unwrap_or_default()
} else {
options_str
.split(',')
.map(|s| s.trim().trim_matches('"').to_string())
@ -384,11 +329,10 @@ fn register_hear_as_menu(state: Arc<AppState>, user: UserSession, engine: &mut E
let mut session_manager = state_for_spawn.session_manager.lock().await;
session_manager.mark_waiting(session_id_clone);
if let Some(redis_client) = &state_for_spawn.cache {
if let Ok(mut conn) = redis_client.get_multiplexed_async_connection().await
{
let key = format!("hear:{}:{}", session_id_clone, var_name_clone);
let key = format!("hear:{session_id_clone}:{var_name_clone}");
let wait_data = serde_json::json!({
"variable": var_name_clone,
"type": "menu",
@ -404,9 +348,8 @@ fn register_hear_as_menu(state: Arc<AppState>, user: UserSession, engine: &mut E
.query_async(&mut conn)
.await;
let suggestions_key =
format!("suggestions:{}:{}", session_id_clone, session_id_clone);
format!("suggestions:{session_id_clone}:{session_id_clone}");
for opt in &options_clone {
let suggestion = serde_json::json!({
"text": opt,
@ -431,15 +374,19 @@ fn register_hear_as_menu(state: Arc<AppState>, user: UserSession, engine: &mut E
.unwrap();
}
#[must_use]
pub fn validate_input(input: &str, input_type: &InputType) -> ValidationResult {
let trimmed = input.trim();
match input_type {
InputType::Any => ValidationResult::valid(trimmed.to_string()),
InputType::Any
| InputType::QrCode
| InputType::File
| InputType::Image
| InputType::Audio
| InputType::Video
| InputType::Document
| InputType::Login => ValidationResult::valid(trimmed.to_string()),
InputType::Email => validate_email(trimmed),
InputType::Date => validate_date(trimmed),
InputType::Name => validate_name(trimmed),
@ -458,18 +405,7 @@ pub fn validate_input(input: &str, input_type: &InputType) -> ValidationResult {
InputType::Color => validate_color(trimmed),
InputType::CreditCard => validate_credit_card(trimmed),
InputType::Password => validate_password(trimmed),
InputType::Menu(options) => validate_menu(trimmed, options),
InputType::QrCode
| InputType::File
| InputType::Image
| InputType::Audio
| InputType::Video
| InputType::Document => ValidationResult::valid(trimmed.to_string()),
InputType::Login => ValidationResult::valid(trimmed.to_string()),
}
}
@ -486,32 +422,23 @@ fn validate_email(input: &str) -> ValidationResult {
}
fn validate_date(input: &str) -> ValidationResult {
let formats = [
"%d/%m/%Y",
"%d-%m-%Y",
"%Y-%m-%d",
"%Y/%m/%d",
"%d.%m.%Y",
"%m/%d/%Y",
"%d %b %Y",
"%d/%m/%Y", "%d-%m-%Y", "%Y-%m-%d", "%Y/%m/%d", "%d.%m.%Y", "%m/%d/%Y", "%d %b %Y",
"%d %B %Y",
];
for format in &formats {
if let Ok(date) = chrono::NaiveDate::parse_from_str(input, format) {
return ValidationResult::valid_with_metadata(
date.format("%Y-%m-%d").to_string(),
serde_json::json!({
"original": input,
"parsed_format": format
"parsed_format": *format
}),
);
}
}
let lower = input.to_lowercase();
let today = chrono::Local::now().date_naive();
@ -537,7 +464,6 @@ fn validate_date(input: &str) -> ValidationResult {
}
fn validate_name(input: &str) -> ValidationResult {
let name_regex = Regex::new(r"^[\p{L}\s\-']+$").unwrap();
if input.len() < 2 {
@ -549,7 +475,6 @@ fn validate_name(input: &str) -> ValidationResult {
}
if name_regex.is_match(input) {
let normalized = input
.split_whitespace()
.map(|word| {
@ -568,11 +493,10 @@ fn validate_name(input: &str) -> ValidationResult {
}
fn validate_integer(input: &str) -> ValidationResult {
let cleaned = input
.replace(",", "")
.replace(".", "")
.replace(" ", "")
.replace(',', "")
.replace('.', "")
.replace(' ', "")
.trim()
.to_string();
@ -586,7 +510,6 @@ fn validate_integer(input: &str) -> ValidationResult {
}
fn validate_float(input: &str) -> ValidationResult {
let cleaned = input.replace(" ", "").replace(",", ".").trim().to_string();
match cleaned.parse::<f64>() {
@ -644,7 +567,6 @@ fn validate_boolean(input: &str) -> ValidationResult {
}
fn validate_hour(input: &str) -> ValidationResult {
let time_24_regex = Regex::new(r"^([01]?\d|2[0-3]):([0-5]\d)$").unwrap();
if let Some(caps) = time_24_regex.captures(input) {
let hour: u32 = caps[1].parse().unwrap();
@ -655,7 +577,6 @@ fn validate_hour(input: &str) -> ValidationResult {
);
}
let time_12_regex =
Regex::new(r"^(1[0-2]|0?[1-9]):([0-5]\d)\s*(AM|PM|am|pm|a\.m\.|p\.m\.)$").unwrap();
if let Some(caps) = time_12_regex.captures(input) {
@ -679,7 +600,6 @@ fn validate_hour(input: &str) -> ValidationResult {
}
fn validate_money(input: &str) -> ValidationResult {
let cleaned = input
.replace("R$", "")
.replace("$", "")
@ -690,21 +610,16 @@ fn validate_money(input: &str) -> ValidationResult {
.trim()
.to_string();
let normalized = if cleaned.contains(',') && cleaned.contains('.') {
let last_comma = cleaned.rfind(',').unwrap_or(0);
let last_dot = cleaned.rfind('.').unwrap_or(0);
if last_comma > last_dot {
cleaned.replace(".", "").replace(",", ".")
cleaned.replace('.', "").replace(',', ".")
} else {
cleaned.replace(",", "")
}
} else if cleaned.contains(',') {
cleaned.replace(",", ".")
} else {
cleaned
@ -720,26 +635,19 @@ fn validate_money(input: &str) -> ValidationResult {
}
fn validate_mobile(input: &str) -> ValidationResult {
let digits: String = input.chars().filter(|c| c.is_ascii_digit()).collect();
if digits.len() < 10 || digits.len() > 15 {
return ValidationResult::invalid(InputType::Mobile.error_message());
}
let formatted = if digits.len() == 11 && digits.starts_with('9') {
format!("({}) {}-{}", &digits[0..2], &digits[2..7], &digits[7..11])
} else if digits.len() == 11 {
format!("({}) {}-{}", &digits[0..2], &digits[2..7], &digits[7..11])
} else if digits.len() == 10 {
format!("({}) {}-{}", &digits[0..3], &digits[3..6], &digits[6..10])
} else {
format!("+{}", digits)
};
@ -750,13 +658,11 @@ fn validate_mobile(input: &str) -> ValidationResult {
}
fn validate_zipcode(input: &str) -> ValidationResult {
let cleaned: String = input
.chars()
.filter(|c| c.is_ascii_alphanumeric())
.collect();
if cleaned.len() == 8 && cleaned.chars().all(|c| c.is_ascii_digit()) {
let formatted = format!("{}-{}", &cleaned[0..5], &cleaned[5..8]);
return ValidationResult::valid_with_metadata(
@ -765,10 +671,9 @@ fn validate_zipcode(input: &str) -> ValidationResult {
);
}
if (cleaned.len() == 5 || cleaned.len() == 9) && cleaned.chars().all(|c| c.is_ascii_digit()) {
let formatted = if cleaned.len() == 9 {
format!("{}-{}", &cleaned[0..5], &cleaned[5..9])
format!("{}-{}", &cleaned[0..5], &cleaned[5..8])
} else {
cleaned.clone()
};
@ -778,7 +683,6 @@ fn validate_zipcode(input: &str) -> ValidationResult {
);
}
let uk_regex = Regex::new(r"^[A-Z]{1,2}\d[A-Z\d]?\s?\d[A-Z]{2}$").unwrap();
if uk_regex.is_match(&cleaned.to_uppercase()) {
return ValidationResult::valid_with_metadata(
@ -793,7 +697,6 @@ fn validate_zipcode(input: &str) -> ValidationResult {
fn validate_language(input: &str) -> ValidationResult {
let lower = input.to_lowercase().trim().to_string();
let languages = [
("en", "english", "inglês", "ingles"),
("pt", "portuguese", "português", "portugues"),
@ -827,7 +730,6 @@ fn validate_language(input: &str) -> ValidationResult {
}
}
if lower.len() == 2 && lower.chars().all(|c| c.is_ascii_lowercase()) {
return ValidationResult::valid(lower);
}
@ -836,21 +738,17 @@ fn validate_language(input: &str) -> ValidationResult {
}
fn validate_cpf(input: &str) -> ValidationResult {
let digits: String = input.chars().filter(|c| c.is_ascii_digit()).collect();
if digits.len() != 11 {
return ValidationResult::invalid(InputType::Cpf.error_message());
}
if digits.chars().all(|c| c == digits.chars().next().unwrap()) {
return ValidationResult::invalid("Invalid CPF".to_string());
}
let digits_vec: Vec<u32> = digits.chars().map(|c| c.to_digit(10).unwrap()).collect();
let digits_vec: Vec<u32> = digits.chars().filter_map(|c| c.to_digit(10)).collect();
let sum1: u32 = digits_vec[0..9]
.iter()
@ -864,7 +762,6 @@ fn validate_cpf(input: &str) -> ValidationResult {
return ValidationResult::invalid("Invalid CPF".to_string());
}
let sum2: u32 = digits_vec[0..10]
.iter()
.enumerate()
@ -877,7 +774,6 @@ fn validate_cpf(input: &str) -> ValidationResult {
return ValidationResult::invalid("Invalid CPF".to_string());
}
let formatted = format!(
"{}.{}.{}-{}",
&digits[0..3],
@ -893,16 +789,13 @@ fn validate_cpf(input: &str) -> ValidationResult {
}
fn validate_cnpj(input: &str) -> ValidationResult {
let digits: String = input.chars().filter(|c| c.is_ascii_digit()).collect();
if digits.len() != 14 {
return ValidationResult::invalid(InputType::Cnpj.error_message());
}
let digits_vec: Vec<u32> = digits.chars().map(|c| c.to_digit(10).unwrap()).collect();
let digits_vec: Vec<u32> = digits.chars().filter_map(|c| c.to_digit(10)).collect();
let weights1 = [5, 4, 3, 2, 9, 8, 7, 6, 5, 4, 3, 2];
let sum1: u32 = digits_vec[0..12]
@ -917,7 +810,6 @@ fn validate_cnpj(input: &str) -> ValidationResult {
return ValidationResult::invalid("Invalid CNPJ".to_string());
}
let weights2 = [6, 5, 4, 3, 2, 9, 8, 7, 6, 5, 4, 3, 2];
let sum2: u32 = digits_vec[0..13]
.iter()
@ -931,7 +823,6 @@ fn validate_cnpj(input: &str) -> ValidationResult {
return ValidationResult::invalid("Invalid CNPJ".to_string());
}
let formatted = format!(
"{}.{}.{}/{}-{}",
&digits[0..2],
@ -948,9 +839,8 @@ fn validate_cnpj(input: &str) -> ValidationResult {
}
fn validate_url(input: &str) -> ValidationResult {
let url_str = if !input.starts_with("http://") && !input.starts_with("https://") {
format!("https://{}", input)
format!("https://{input}")
} else {
input.to_string()
};
@ -976,7 +866,6 @@ fn validate_uuid(input: &str) -> ValidationResult {
fn validate_color(input: &str) -> ValidationResult {
let lower = input.to_lowercase().trim().to_string();
let named_colors = [
("red", "#FF0000"),
("green", "#00FF00"),
@ -1003,7 +892,6 @@ fn validate_color(input: &str) -> ValidationResult {
}
}
let hex_regex = Regex::new(r"^#?([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$").unwrap();
if let Some(caps) = hex_regex.captures(&lower) {
let hex = caps[1].to_uppercase();
@ -1017,7 +905,6 @@ fn validate_color(input: &str) -> ValidationResult {
return ValidationResult::valid(format!("#{}", full_hex));
}
let rgb_regex =
Regex::new(r"^rgb\s*\(\s*(\d{1,3})\s*,\s*(\d{1,3})\s*,\s*(\d{1,3})\s*\)$").unwrap();
if let Some(caps) = rgb_regex.captures(&lower) {
@ -1031,14 +918,12 @@ fn validate_color(input: &str) -> ValidationResult {
}
fn validate_credit_card(input: &str) -> ValidationResult {
let digits: String = input.chars().filter(|c| c.is_ascii_digit()).collect();
if digits.len() < 13 || digits.len() > 19 {
return ValidationResult::invalid(InputType::CreditCard.error_message());
}
let mut sum = 0;
let mut double = false;
@ -1058,7 +943,6 @@ fn validate_credit_card(input: &str) -> ValidationResult {
return ValidationResult::invalid("Invalid card number".to_string());
}
let card_type = if digits.starts_with('4') {
"Visa"
} else if digits.starts_with("51")
@ -1078,7 +962,6 @@ fn validate_credit_card(input: &str) -> ValidationResult {
"Unknown"
};
let masked = format!(
"{} **** **** {}",
&digits[0..4],
@ -1113,7 +996,6 @@ fn validate_password(input: &str) -> ValidationResult {
_ => "weak",
};
ValidationResult::valid_with_metadata(
"[PASSWORD SET]".to_string(),
serde_json::json!({
@ -1126,7 +1008,6 @@ fn validate_password(input: &str) -> ValidationResult {
fn validate_menu(input: &str, options: &[String]) -> ValidationResult {
let lower_input = input.to_lowercase().trim().to_string();
for (i, opt) in options.iter().enumerate() {
if opt.to_lowercase() == lower_input {
return ValidationResult::valid_with_metadata(
@ -1136,7 +1017,6 @@ fn validate_menu(input: &str, options: &[String]) -> ValidationResult {
}
}
if let Ok(num) = lower_input.parse::<usize>() {
if num >= 1 && num <= options.len() {
let selected = &options[num - 1];
@ -1147,7 +1027,6 @@ fn validate_menu(input: &str, options: &[String]) -> ValidationResult {
}
}
let matches: Vec<&String> = options
.iter()
.filter(|opt| opt.to_lowercase().contains(&lower_input))
@ -1161,11 +1040,10 @@ fn validate_menu(input: &str, options: &[String]) -> ValidationResult {
);
}
ValidationResult::invalid(format!("Please select one of: {}", options.join(", ")))
let opts = options.join(", ");
ValidationResult::invalid(format!("Please select one of: {opts}"))
}
pub async fn execute_talk(
state: Arc<AppState>,
user_session: UserSession,
@ -1173,7 +1051,6 @@ pub async fn execute_talk(
) -> Result<BotResponse, Box<dyn std::error::Error + Send + Sync>> {
let mut suggestions = Vec::new();
if let Some(redis_client) = &state.cache {
if let Ok(mut conn) = redis_client.get_multiplexed_async_connection().await {
let redis_key = format!("suggestions:{}:{}", user_session.user_id, user_session.id);
@ -1212,7 +1089,6 @@ pub async fn execute_talk(
let user_id = user_session.id.to_string();
let response_clone = response.clone();
let web_adapter = Arc::clone(&state.web_adapter);
tokio::spawn(async move {
if let Err(e) = web_adapter
@ -1249,9 +1125,6 @@ pub fn talk_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine
.unwrap();
}
pub async fn process_hear_input(
state: &AppState,
session_id: Uuid,
@ -1259,10 +1132,9 @@ pub async fn process_hear_input(
input: &str,
attachments: Option<Vec<crate::shared::models::Attachment>>,
) -> Result<(String, Option<serde_json::Value>), String> {
let wait_data = if let Some(redis_client) = &state.cache {
if let Ok(mut conn) = redis_client.get_multiplexed_async_connection().await {
let key = format!("hear:{}:{}", session_id, variable_name);
let key = format!("hear:{session_id}:{variable_name}");
let data: Result<String, _> = redis::cmd("GET").arg(&key).query_async(&mut conn).await;
@ -1293,14 +1165,12 @@ pub async fn process_hear_input(
.collect::<Vec<_>>()
});
let validation_type = if let Some(opts) = options {
InputType::Menu(opts)
} else {
InputType::from_str(input_type)
};
match validation_type {
InputType::Image | InputType::QrCode => {
if let Some(atts) = &attachments {
@ -1309,7 +1179,6 @@ pub async fn process_hear_input(
.find(|a| a.mime_type.as_deref().unwrap_or("").starts_with("image/"))
{
if validation_type == InputType::QrCode {
return process_qrcode(state, &img.url).await;
}
return Ok((
@ -1326,7 +1195,6 @@ pub async fn process_hear_input(
.iter()
.find(|a| a.mime_type.as_deref().unwrap_or("").starts_with("audio/"))
{
return process_audio_to_text(state, &audio.url).await;
}
}
@ -1338,7 +1206,6 @@ pub async fn process_hear_input(
.iter()
.find(|a| a.mime_type.as_deref().unwrap_or("").starts_with("video/"))
{
return process_video_description(state, &video.url).await;
}
}
@ -1358,14 +1225,12 @@ pub async fn process_hear_input(
_ => {}
}
let result = validate_input(input, &validation_type);
if result.is_valid {
if let Some(redis_client) = &state.cache {
if let Ok(mut conn) = redis_client.get_multiplexed_async_connection().await {
let key = format!("hear:{}:{}", session_id, variable_name);
let key = format!("hear:{session_id}:{variable_name}");
let _: Result<(), _> = redis::cmd("DEL").arg(&key).query_async(&mut conn).await;
}
}
@ -1378,12 +1243,10 @@ pub async fn process_hear_input(
}
}
async fn process_qrcode(
state: &AppState,
image_url: &str,
) -> Result<(String, Option<serde_json::Value>), String> {
let botmodels_url = {
let config_url = state.conn.get().ok().and_then(|mut conn| {
use crate::shared::models::schema::bot_memories::dsl::*;
@ -1401,7 +1264,6 @@ async fn process_qrcode(
let client = reqwest::Client::new();
let image_data = client
.get(image_url)
.send()
@ -1409,11 +1271,10 @@ async fn process_qrcode(
.map_err(|e| format!("Failed to download image: {}", e))?
.bytes()
.await
.map_err(|e| format!("Failed to read image: {}", e))?;
.map_err(|e| format!("Failed to fetch image: {e}"))?;
let response = client
.post(format!("{}/api/v1/vision/qrcode", botmodels_url))
.post(format!("{botmodels_url}/api/v1/vision/qrcode"))
.header("Content-Type", "application/octet-stream")
.body(image_data.to_vec())
.send()
@ -1424,7 +1285,7 @@ async fn process_qrcode(
let result: serde_json::Value = response
.json()
.await
.map_err(|e| format!("Failed to parse response: {}", e))?;
.map_err(|e| format!("Failed to read image: {e}"))?;
if let Some(qr_data) = result.get("data").and_then(|d| d.as_str()) {
Ok((
@ -1442,7 +1303,6 @@ async fn process_qrcode(
}
}
async fn process_audio_to_text(
_state: &AppState,
audio_url: &str,
@ -1452,7 +1312,6 @@ async fn process_audio_to_text(
let client = reqwest::Client::new();
let audio_data = client
.get(audio_url)
.send()
@ -1460,11 +1319,10 @@ async fn process_audio_to_text(
.map_err(|e| format!("Failed to download audio: {}", e))?
.bytes()
.await
.map_err(|e| format!("Failed to read audio: {}", e))?;
.map_err(|e| format!("Failed to read audio: {e}"))?;
let response = client
.post(format!("{}/api/v1/speech/to-text", botmodels_url))
.post(format!("{botmodels_url}/api/v1/speech/to-text"))
.header("Content-Type", "application/octet-stream")
.body(audio_data.to_vec())
.send()
@ -1494,7 +1352,6 @@ async fn process_audio_to_text(
}
}
async fn process_video_description(
_state: &AppState,
video_url: &str,
@ -1504,7 +1361,6 @@ async fn process_video_description(
let client = reqwest::Client::new();
let video_data = client
.get(video_url)
.send()
@ -1512,16 +1368,15 @@ async fn process_video_description(
.map_err(|e| format!("Failed to download video: {}", e))?
.bytes()
.await
.map_err(|e| format!("Failed to read video: {}", e))?;
.map_err(|e| format!("Failed to fetch video: {e}"))?;
let response = client
.post(format!("{}/api/v1/vision/describe-video", botmodels_url))
.post(format!("{botmodels_url}/api/v1/vision/describe-video"))
.header("Content-Type", "application/octet-stream")
.body(video_data.to_vec())
.send()
.await
.map_err(|e| format!("Failed to call botmodels: {}", e))?;
.map_err(|e| format!("Failed to read video: {e}"))?;
if response.status().is_success() {
let result: serde_json::Value = response

View file

@ -1,57 +1,3 @@
use chrono::{DateTime, Duration, Utc};
use rhai::{Array, Dynamic, Engine, Map};
use serde::{Deserialize, Serialize};
@ -59,10 +5,8 @@ use std::collections::HashMap;
use tracing::info;
use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApprovalRequest {
pub id: Uuid,
pub bot_id: Uuid,
@ -106,11 +50,9 @@ pub struct ApprovalRequest {
pub comments: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ApprovalStatus {
Pending,
Approved,
@ -132,7 +74,6 @@ impl Default for ApprovalStatus {
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ApprovalDecision {
@ -143,7 +84,6 @@ pub enum ApprovalDecision {
RequestInfo,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ApprovalChannel {
@ -176,10 +116,8 @@ impl std::fmt::Display for ApprovalChannel {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApprovalChain {
pub name: String,
pub bot_id: Uuid,
@ -195,10 +133,8 @@ pub struct ApprovalChain {
pub created_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApprovalLevel {
pub level: u32,
pub channel: ApprovalChannel,
@ -216,10 +152,8 @@ pub struct ApprovalLevel {
pub required_approvals: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApprovalAuditEntry {
pub id: Uuid,
pub request_id: Uuid,
@ -237,7 +171,6 @@ pub struct ApprovalAuditEntry {
pub user_agent: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum AuditAction {
@ -253,10 +186,8 @@ pub enum AuditAction {
ContextUpdated,
}
#[derive(Debug, Clone)]
pub struct ApprovalConfig {
pub enabled: bool,
pub default_timeout: u64,
@ -289,19 +220,16 @@ impl Default for ApprovalConfig {
}
}
#[derive(Debug)]
pub struct ApprovalManager {
config: ApprovalConfig,
}
impl ApprovalManager {
pub fn new(config: ApprovalConfig) -> Self {
ApprovalManager { config }
}
pub fn from_config(config_map: &HashMap<String, String>) -> Self {
let config = ApprovalConfig {
enabled: config_map
@ -331,7 +259,6 @@ impl ApprovalManager {
ApprovalManager::new(config)
}
pub fn create_request(
&self,
bot_id: Uuid,
@ -373,12 +300,10 @@ impl ApprovalManager {
}
}
pub fn is_expired(&self, request: &ApprovalRequest) -> bool {
Utc::now() > request.expires_at
}
pub fn should_send_reminder(&self, request: &ApprovalRequest) -> bool {
if request.status != ApprovalStatus::Pending {
return false;
@ -398,7 +323,6 @@ impl ApprovalManager {
since_last.num_seconds() >= self.config.reminder_interval as i64
}
pub fn generate_approval_url(&self, request_id: Uuid, action: &str, token: &str) -> String {
let base_url = self
.config
@ -412,7 +336,6 @@ impl ApprovalManager {
)
}
pub fn generate_email_content(&self, request: &ApprovalRequest, token: &str) -> EmailContent {
let approve_url = self.generate_approval_url(request.id, "approve", token);
let reject_url = self.generate_approval_url(request.id, "reject", token);
@ -423,7 +346,7 @@ impl ApprovalManager {
);
let body = format!(
r#"
r"
An approval is requested for:
Type: {}
@ -438,7 +361,7 @@ To approve, click: {}
To reject, click: {}
If you have questions, reply to this email.
"#,
",
request.approval_type,
request.message,
serde_json::to_string_pretty(&request.context).unwrap_or_default(),
@ -454,7 +377,6 @@ If you have questions, reply to this email.
}
}
pub fn process_decision(
&self,
request: &mut ApprovalRequest,
@ -475,7 +397,6 @@ If you have questions, reply to this email.
};
}
pub fn handle_timeout(&self, request: &mut ApprovalRequest) {
if let Some(default_action) = &request.default_action {
request.decision = Some(default_action.clone());
@ -491,14 +412,11 @@ If you have questions, reply to this email.
}
}
pub fn evaluate_condition(
&self,
condition: &str,
context: &serde_json::Value,
) -> Result<bool, String> {
let parts: Vec<&str> = condition.split_whitespace().collect();
if parts.len() != 3 {
return Err(format!("Invalid condition format: {}", condition));
@ -531,7 +449,6 @@ If you have questions, reply to this email.
}
}
#[derive(Debug, Clone)]
pub struct EmailContent {
pub subject: String,
@ -539,7 +456,6 @@ pub struct EmailContent {
pub html_body: Option<String>,
}
impl ApprovalRequest {
pub fn to_dynamic(&self) -> Dynamic {
let mut map = Map::new();
@ -589,7 +505,6 @@ impl ApprovalRequest {
}
}
fn json_to_dynamic(value: &serde_json::Value) -> Dynamic {
match value {
serde_json::Value::Null => Dynamic::UNIT,
@ -618,10 +533,7 @@ fn json_to_dynamic(value: &serde_json::Value) -> Dynamic {
}
}
pub fn register_approval_keywords(engine: &mut Engine) {
engine.register_fn("approval_is_approved", |request: Map| -> bool {
request
.get("status")
@ -678,7 +590,6 @@ pub fn register_approval_keywords(engine: &mut Engine) {
info!("Approval keywords registered");
}
pub const APPROVAL_SCHEMA: &str = r#"
-- Approval requests
CREATE TABLE IF NOT EXISTS approval_requests (
@ -758,9 +669,8 @@ CREATE INDEX IF NOT EXISTS idx_approval_tokens_token ON approval_tokens(token);
CREATE INDEX IF NOT EXISTS idx_approval_tokens_request_id ON approval_tokens(request_id);
"#;
pub mod sql {
pub const INSERT_REQUEST: &str = r#"
pub const INSERT_REQUEST: &str = r"
INSERT INTO approval_requests (
id, bot_id, session_id, initiated_by, approval_type, status,
channel, recipient, context, message, timeout_seconds,
@ -769,9 +679,9 @@ pub mod sql {
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17
)
"#;
";
pub const UPDATE_REQUEST: &str = r#"
pub const UPDATE_REQUEST: &str = r"
UPDATE approval_requests
SET status = $2,
decision = $3,
@ -779,71 +689,71 @@ pub mod sql {
decided_at = $5,
comments = $6
WHERE id = $1
"#;
";
pub const GET_REQUEST: &str = r#"
pub const GET_REQUEST: &str = r"
SELECT * FROM approval_requests WHERE id = $1
"#;
";
pub const GET_PENDING_REQUESTS: &str = r#"
pub const GET_PENDING_REQUESTS: &str = r"
SELECT * FROM approval_requests
WHERE status = 'pending'
AND expires_at > NOW()
ORDER BY created_at ASC
"#;
";
pub const GET_EXPIRED_REQUESTS: &str = r#"
pub const GET_EXPIRED_REQUESTS: &str = r"
SELECT * FROM approval_requests
WHERE status = 'pending'
AND expires_at <= NOW()
"#;
";
pub const GET_REQUESTS_BY_SESSION: &str = r#"
pub const GET_REQUESTS_BY_SESSION: &str = r"
SELECT * FROM approval_requests
WHERE session_id = $1
ORDER BY created_at DESC
"#;
";
pub const UPDATE_REMINDERS: &str = r#"
pub const UPDATE_REMINDERS: &str = r"
UPDATE approval_requests
SET reminders_sent = reminders_sent || $2::jsonb
WHERE id = $1
"#;
";
pub const INSERT_AUDIT: &str = r#"
pub const INSERT_AUDIT: &str = r"
INSERT INTO approval_audit_log (
id, request_id, action, actor, details, timestamp, ip_address, user_agent
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8
)
"#;
";
pub const GET_AUDIT_LOG: &str = r#"
pub const GET_AUDIT_LOG: &str = r"
SELECT * FROM approval_audit_log
WHERE request_id = $1
ORDER BY timestamp ASC
"#;
";
pub const INSERT_TOKEN: &str = r#"
pub const INSERT_TOKEN: &str = r"
INSERT INTO approval_tokens (
id, request_id, token, action, expires_at, created_at
) VALUES (
$1, $2, $3, $4, $5, $6
)
"#;
";
pub const GET_TOKEN: &str = r#"
pub const GET_TOKEN: &str = r"
SELECT * FROM approval_tokens
WHERE token = $1 AND used = false AND expires_at > NOW()
"#;
";
pub const USE_TOKEN: &str = r#"
pub const USE_TOKEN: &str = r"
UPDATE approval_tokens
SET used = true, used_at = NOW()
WHERE token = $1
"#;
";
pub const INSERT_CHAIN: &str = r#"
pub const INSERT_CHAIN: &str = r"
INSERT INTO approval_chains (
id, name, bot_id, levels, stop_on_reject, require_all, description, created_at
) VALUES (
@ -855,10 +765,10 @@ pub mod sql {
stop_on_reject = $5,
require_all = $6,
description = $7
"#;
";
pub const GET_CHAIN: &str = r#"
pub const GET_CHAIN: &str = r"
SELECT * FROM approval_chains
WHERE bot_id = $1 AND name = $2
"#;
";
}

View file

@ -1,50 +1,3 @@
use chrono::{DateTime, Utc};
use rhai::{Array, Dynamic, Engine, Map};
use serde::{Deserialize, Serialize};
@ -52,10 +5,8 @@ use std::collections::HashMap;
use tracing::info;
use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KgEntity {
pub id: Uuid,
pub bot_id: Uuid,
@ -77,7 +28,6 @@ pub struct KgEntity {
pub updated_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum EntitySource {
@ -93,10 +43,8 @@ impl Default for EntitySource {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KgRelationship {
pub id: Uuid,
pub bot_id: Uuid,
@ -118,10 +66,8 @@ pub struct KgRelationship {
pub created_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedEntity {
pub name: String,
pub canonical_name: String,
@ -137,10 +83,8 @@ pub struct ExtractedEntity {
pub properties: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedRelationship {
pub from_entity: String,
pub to_entity: String,
@ -152,10 +96,8 @@ pub struct ExtractedRelationship {
pub evidence: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractionResult {
pub entities: Vec<ExtractedEntity>,
pub relationships: Vec<ExtractedRelationship>,
@ -163,10 +105,8 @@ pub struct ExtractionResult {
pub metadata: ExtractionMetadata,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractionMetadata {
pub model: String,
pub processing_time_ms: u64,
@ -176,10 +116,8 @@ pub struct ExtractionMetadata {
pub text_length: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphQueryResult {
pub entities: Vec<KgEntity>,
pub relationships: Vec<KgRelationship>,
@ -189,10 +127,8 @@ pub struct GraphQueryResult {
pub confidence: f64,
}
#[derive(Debug, Clone)]
pub struct KnowledgeGraphConfig {
pub enabled: bool,
pub backend: String,
@ -233,19 +169,16 @@ impl Default for KnowledgeGraphConfig {
}
}
#[derive(Debug)]
pub struct KnowledgeGraphManager {
config: KnowledgeGraphConfig,
}
impl KnowledgeGraphManager {
pub fn new(config: KnowledgeGraphConfig) -> Self {
KnowledgeGraphManager { config }
}
pub fn from_config(config_map: &HashMap<String, String>) -> Self {
let config = KnowledgeGraphConfig {
enabled: config_map
@ -284,7 +217,6 @@ impl KnowledgeGraphManager {
KnowledgeGraphManager::new(config)
}
pub fn generate_extraction_prompt(&self, text: &str) -> String {
let entity_types = self.config.entity_types.join(", ");
@ -320,7 +252,6 @@ Respond with valid JSON only:
)
}
pub fn generate_query_prompt(&self, query: &str, context: &str) -> String {
format!(
r#"Answer this question using the knowledge graph context.
@ -336,7 +267,6 @@ If the information is not available, say so clearly.
)
}
pub fn parse_extraction_response(
&self,
response: &str,
@ -401,12 +331,10 @@ If the information is not available, say so clearly.
})
}
pub fn should_extract(&self) -> bool {
self.config.enabled && self.config.extract_entities
}
pub fn is_valid_entity_type(&self, entity_type: &str) -> bool {
self.config
.entity_types
@ -415,16 +343,13 @@ If the information is not available, say so clearly.
}
}
fn extract_json(response: &str) -> Result<String, String> {
if let Some(start) = response.find("```json") {
if let Some(end) = response[start + 7..].find("```") {
return Ok(response[start + 7..start + 7 + end].trim().to_string());
}
}
if let Some(start) = response.find("```") {
let after_start = start + 3;
let json_start = response[after_start..]
@ -436,7 +361,6 @@ fn extract_json(response: &str) -> Result<String, String> {
}
}
if let Some(start) = response.find('{') {
if let Some(end) = response.rfind('}') {
if end > start {
@ -448,7 +372,6 @@ fn extract_json(response: &str) -> Result<String, String> {
Err("No JSON found in response".to_string())
}
impl KgEntity {
pub fn to_dynamic(&self) -> Dynamic {
let mut map = Map::new();
@ -478,7 +401,6 @@ impl KgEntity {
}
}
impl KgRelationship {
pub fn to_dynamic(&self) -> Dynamic {
let mut map = Map::new();
@ -507,7 +429,6 @@ impl KgRelationship {
}
}
fn json_to_dynamic(value: &serde_json::Value) -> Dynamic {
match value {
serde_json::Value::Null => Dynamic::UNIT,
@ -536,10 +457,7 @@ fn json_to_dynamic(value: &serde_json::Value) -> Dynamic {
}
}
pub fn register_knowledge_graph_keywords(engine: &mut Engine) {
engine.register_fn("entity_name", |entity: Map| -> String {
entity
.get("entity_name")
@ -576,7 +494,6 @@ pub fn register_knowledge_graph_keywords(engine: &mut Engine) {
info!("Knowledge graph keywords registered");
}
pub const KNOWLEDGE_GRAPH_SCHEMA: &str = r#"
-- Knowledge graph entities
CREATE TABLE IF NOT EXISTS kg_entities (
@ -627,9 +544,8 @@ CREATE INDEX IF NOT EXISTS idx_kg_entities_name_fts ON kg_entities
USING GIN(to_tsvector('english', entity_name));
"#;
pub mod sql {
pub const INSERT_ENTITY: &str = r#"
pub const INSERT_ENTITY: &str = r"
INSERT INTO kg_entities (
id, bot_id, entity_type, entity_name, aliases, properties,
confidence, source, created_at, updated_at
@ -643,9 +559,9 @@ pub mod sql {
confidence = GREATEST(kg_entities.confidence, $7),
updated_at = $10
RETURNING id
"#;
";
pub const INSERT_RELATIONSHIP: &str = r#"
pub const INSERT_RELATIONSHIP: &str = r"
INSERT INTO kg_relationships (
id, bot_id, from_entity_id, to_entity_id, relationship_type,
properties, confidence, bidirectional, source, created_at
@ -657,9 +573,9 @@ pub mod sql {
properties = kg_relationships.properties || $6,
confidence = GREATEST(kg_relationships.confidence, $7)
RETURNING id
"#;
";
pub const GET_ENTITY_BY_NAME: &str = r#"
pub const GET_ENTITY_BY_NAME: &str = r"
SELECT * FROM kg_entities
WHERE bot_id = $1
AND (
@ -667,13 +583,13 @@ pub mod sql {
OR aliases @> $3::jsonb
)
LIMIT 1
"#;
";
pub const GET_ENTITY_BY_ID: &str = r#"
pub const GET_ENTITY_BY_ID: &str = r"
SELECT * FROM kg_entities WHERE id = $1
"#;
";
pub const SEARCH_ENTITIES: &str = r#"
pub const SEARCH_ENTITIES: &str = r"
SELECT * FROM kg_entities
WHERE bot_id = $1
AND (
@ -682,16 +598,16 @@ pub mod sql {
)
ORDER BY confidence DESC
LIMIT $4
"#;
";
pub const GET_ENTITIES_BY_TYPE: &str = r#"
pub const GET_ENTITIES_BY_TYPE: &str = r"
SELECT * FROM kg_entities
WHERE bot_id = $1 AND entity_type = $2
ORDER BY entity_name
LIMIT $3
"#;
";
pub const GET_RELATED_ENTITIES: &str = r#"
pub const GET_RELATED_ENTITIES: &str = r"
SELECT e.*, r.relationship_type, r.confidence as rel_confidence
FROM kg_entities e
JOIN kg_relationships r ON (
@ -701,9 +617,9 @@ pub mod sql {
WHERE r.bot_id = $2
ORDER BY r.confidence DESC
LIMIT $3
"#;
";
pub const GET_RELATED_BY_TYPE: &str = r#"
pub const GET_RELATED_BY_TYPE: &str = r"
SELECT e.*, r.relationship_type, r.confidence as rel_confidence
FROM kg_entities e
JOIN kg_relationships r ON (
@ -713,17 +629,17 @@ pub mod sql {
WHERE r.bot_id = $2 AND r.relationship_type = $3
ORDER BY r.confidence DESC
LIMIT $4
"#;
";
pub const GET_RELATIONSHIP: &str = r#"
pub const GET_RELATIONSHIP: &str = r"
SELECT * FROM kg_relationships
WHERE bot_id = $1
AND from_entity_id = $2
AND to_entity_id = $3
AND relationship_type = $4
"#;
";
pub const GET_ALL_RELATIONSHIPS_FOR_ENTITY: &str = r#"
pub const GET_ALL_RELATIONSHIPS_FOR_ENTITY: &str = r"
SELECT r.*,
e1.entity_name as from_name, e1.entity_type as from_type,
e2.entity_name as to_name, e2.entity_type as to_type
@ -733,42 +649,41 @@ pub mod sql {
WHERE r.bot_id = $1
AND (r.from_entity_id = $2 OR r.to_entity_id = $2)
ORDER BY r.confidence DESC
"#;
";
pub const DELETE_ENTITY: &str = r#"
pub const DELETE_ENTITY: &str = r"
DELETE FROM kg_entities WHERE id = $1 AND bot_id = $2
"#;
";
pub const DELETE_RELATIONSHIP: &str = r#"
pub const DELETE_RELATIONSHIP: &str = r"
DELETE FROM kg_relationships WHERE id = $1 AND bot_id = $2
"#;
";
pub const COUNT_ENTITIES: &str = r#"
pub const COUNT_ENTITIES: &str = r"
SELECT COUNT(*) FROM kg_entities WHERE bot_id = $1
"#;
";
pub const COUNT_RELATIONSHIPS: &str = r#"
pub const COUNT_RELATIONSHIPS: &str = r"
SELECT COUNT(*) FROM kg_relationships WHERE bot_id = $1
"#;
";
pub const GET_ENTITY_TYPES: &str = r#"
pub const GET_ENTITY_TYPES: &str = r"
SELECT DISTINCT entity_type, COUNT(*) as count
FROM kg_entities
WHERE bot_id = $1
GROUP BY entity_type
ORDER BY count DESC
"#;
";
pub const GET_RELATIONSHIP_TYPES: &str = r#"
pub const GET_RELATIONSHIP_TYPES: &str = r"
SELECT DISTINCT relationship_type, COUNT(*) as count
FROM kg_relationships
WHERE bot_id = $1
GROUP BY relationship_type
ORDER BY count DESC
"#;
";
pub const FIND_PATH: &str = r#"
pub const FIND_PATH: &str = r"
WITH RECURSIVE path_finder AS (
-- Base case: start from source entity
SELECT
@ -799,10 +714,9 @@ pub mod sql {
WHERE to_entity_id = $3
ORDER BY depth
LIMIT 1
"#;
";
}
pub mod relationship_types {
pub const WORKS_ON: &str = "works_on";
pub const REPORTS_TO: &str = "reports_to";
@ -820,7 +734,6 @@ pub mod relationship_types {
pub const ALIAS_OF: &str = "alias_of";
}
pub mod entity_types {
pub const PERSON: &str = "person";
pub const ORGANIZATION: &str = "organization";

View file

@ -3,7 +3,7 @@ use rhai::Engine;
pub fn last_keyword(engine: &mut Engine) {
engine
.register_custom_syntax(&["LAST", "(", "$expr$", ")"], false, {
.register_custom_syntax(["LAST", "(", "$expr$", ")"], false, {
move |context, inputs| {
let input_string = context.eval_expression_tree(&inputs[0])?;
let input_str = input_string.to_string();

View file

@ -28,14 +28,6 @@
| |
\*****************************************************************************/
use crate::core::config::ConfigManager;
use crate::shared::models::UserSession;
use crate::shared::state::AppState;
@ -45,7 +37,6 @@ use std::sync::Arc;
use std::time::Duration;
use uuid::Uuid;
pub fn register_llm_macros(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
register_calculate_keyword(state.clone(), user.clone(), engine);
register_validate_keyword(state.clone(), user.clone(), engine);
@ -53,7 +44,6 @@ pub fn register_llm_macros(state: Arc<AppState>, user: UserSession, engine: &mut
register_summarize_keyword(state, user, engine);
}
async fn call_llm(
state: &AppState,
prompt: &str,
@ -75,7 +65,6 @@ async fn call_llm(
Ok(processed)
}
fn run_llm_with_timeout(
state: Arc<AppState>,
prompt: String,
@ -117,8 +106,6 @@ fn run_llm_with_timeout(
}
}
pub fn register_calculate_keyword(state: Arc<AppState>, _user: UserSession, engine: &mut Engine) {
let state_clone = Arc::clone(&state);
@ -158,7 +145,7 @@ fn build_calculate_prompt(formula: &str, variables: &Dynamic) -> String {
};
format!(
r#"You are a precise calculator. Evaluate the following expression.
r"You are a precise calculator. Evaluate the following expression.
Formula: {}
Variables: {}
@ -169,7 +156,7 @@ Instructions:
3. Return ONLY the final result (number, boolean, or text)
4. No explanations, just the result
Result:"#,
Result:",
formula, vars_str
)
}
@ -198,8 +185,6 @@ fn parse_calculate_result(result: &str) -> Result<Dynamic, Box<rhai::EvalAltResu
Ok(Dynamic::from(trimmed.to_string()))
}
pub fn register_validate_keyword(state: Arc<AppState>, _user: UserSession, engine: &mut Engine) {
let state_clone = Arc::clone(&state);
@ -306,8 +291,6 @@ fn parse_validate_result(result: &str) -> Result<Dynamic, Box<rhai::EvalAltResul
Ok(Dynamic::from(map))
}
pub fn register_translate_keyword(state: Arc<AppState>, _user: UserSession, engine: &mut Engine) {
let state_clone = Arc::clone(&state);
@ -336,20 +319,18 @@ pub fn register_translate_keyword(state: Arc<AppState>, _user: UserSession, engi
fn build_translate_prompt(text: &str, language: &str) -> String {
format!(
r#"Translate the following text to {}.
r"Translate the following text to {}.
Text:
{}
Return ONLY the translated text, no explanations:
Translation:"#,
Translation:",
language, text
)
}
pub fn register_summarize_keyword(state: Arc<AppState>, _user: UserSession, engine: &mut Engine) {
let state_clone = Arc::clone(&state);
@ -369,14 +350,14 @@ pub fn register_summarize_keyword(state: Arc<AppState>, _user: UserSession, engi
fn build_summarize_prompt(text: &str) -> String {
format!(
r#"Summarize the following text concisely.
r"Summarize the following text concisely.
Text:
{}
Return ONLY the summary:
Summary:"#,
Summary:",
text
)
}

View file

@ -126,7 +126,7 @@ pub fn parse_folder_path(path: &str) -> (FolderProvider, Option<String>, String)
(FolderProvider::Local, None, path.to_string())
}
fn detect_provider_from_email(email: &str) -> FolderProvider {
pub fn detect_provider_from_email(email: &str) -> FolderProvider {
let lower = email.to_lowercase();
if lower.ends_with("@gmail.com") || lower.contains("google") {
FolderProvider::GDrive
@ -285,7 +285,7 @@ fn register_on_change_with_events(state: &AppState, user: UserSession, engine: &
.unwrap();
}
fn sanitize_path_for_filename(path: &str) -> String {
pub fn sanitize_path_for_filename(path: &str) -> String {
path.replace('/', "_")
.replace('\\', "_")
.replace(':', "_")

View file

@ -1,16 +1,3 @@
use crate::shared::models::UserSession;
use crate::shared::state::AppState;
use log::{info, trace};
@ -21,7 +8,6 @@ use std::path::Path;
use std::sync::Arc;
use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ContentType {
Video,
@ -39,10 +25,8 @@ pub enum ContentType {
}
impl ContentType {
pub fn from_extension(ext: &str) -> Self {
match ext.to_lowercase().as_str() {
"mp4" | "webm" | "ogg" | "mov" | "avi" | "mkv" | "m4v" => ContentType::Video,
"mp3" | "wav" | "flac" | "aac" | "m4a" | "wma" => ContentType::Audio,
@ -68,7 +52,6 @@ impl ContentType {
}
}
pub fn from_mime(mime: &str) -> Self {
if mime.starts_with("video/") {
ContentType::Video
@ -97,7 +80,6 @@ impl ContentType {
}
}
pub fn player_component(&self) -> &'static str {
match self {
ContentType::Video => "video-player",
@ -116,7 +98,6 @@ impl ContentType {
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PlayOptions {
pub autoplay: bool,
@ -137,7 +118,6 @@ pub struct PlayOptions {
}
impl PlayOptions {
pub fn from_string(options_str: &str) -> Self {
let mut opts = PlayOptions::default();
opts.controls = true;
@ -151,7 +131,6 @@ impl PlayOptions {
"nocontrols" => opts.controls = false,
"linenumbers" => opts.line_numbers = Some(true),
_ => {
if let Some((key, value)) = opt.split_once('=') {
match key {
"start" => opts.start_time = value.parse().ok(),
@ -177,7 +156,6 @@ impl PlayOptions {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlayContent {
pub id: Uuid,
@ -189,7 +167,6 @@ pub struct PlayContent {
pub created_at: chrono::DateTime<chrono::Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlayResponse {
pub player_id: Uuid,
@ -201,7 +178,6 @@ pub struct PlayResponse {
pub metadata: HashMap<String, String>,
}
pub fn play_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
if let Err(e) = play_simple_keyword(state.clone(), user.clone(), engine) {
log::error!("Failed to register PLAY keyword: {}", e);
@ -220,7 +196,6 @@ pub fn play_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine
}
}
fn play_simple_keyword(
state: Arc<AppState>,
user: UserSession,
@ -269,7 +244,6 @@ fn play_simple_keyword(
Ok(())
}
fn play_with_options_keyword(
state: Arc<AppState>,
user: UserSession,
@ -334,7 +308,6 @@ fn play_with_options_keyword(
Ok(())
}
fn stop_keyword(
state: Arc<AppState>,
user: UserSession,
@ -373,7 +346,6 @@ fn stop_keyword(
Ok(())
}
fn pause_keyword(
state: Arc<AppState>,
user: UserSession,
@ -413,7 +385,6 @@ fn pause_keyword(
Ok(())
}
fn resume_keyword(
state: Arc<AppState>,
user: UserSession,
@ -453,34 +424,25 @@ fn resume_keyword(
Ok(())
}
async fn execute_play(
state: &AppState,
session_id: Uuid,
source: &str,
options: PlayOptions,
) -> Result<PlayResponse, String> {
let content_type = detect_content_type(source);
let source_url = resolve_source_url(state, session_id, source).await?;
let metadata = get_content_metadata(state, &source_url, &content_type).await?;
let player_id = Uuid::new_v4();
let title = metadata
.get("title")
.cloned()
.unwrap_or_else(|| extract_title_from_source(source));
let response = PlayResponse {
player_id,
content_type: content_type.clone(),
@ -491,7 +453,6 @@ async fn execute_play(
metadata,
};
send_play_to_client(state, session_id, &response).await?;
info!(
@ -502,11 +463,8 @@ async fn execute_play(
Ok(response)
}
fn detect_content_type(source: &str) -> ContentType {
pub fn detect_content_type(source: &str) -> ContentType {
if source.starts_with("http://") || source.starts_with("https://") {
if source.contains("youtube.com")
|| source.contains("youtu.be")
|| source.contains("vimeo.com")
@ -514,7 +472,6 @@ fn detect_content_type(source: &str) -> ContentType {
return ContentType::Video;
}
if source.contains("imgur.com")
|| source.contains("unsplash.com")
|| source.contains("pexels.com")
@ -522,18 +479,15 @@ fn detect_content_type(source: &str) -> ContentType {
return ContentType::Image;
}
if let Some(path) = source.split('?').next() {
if let Some(ext) = Path::new(path).extension() {
return ContentType::from_extension(&ext.to_string_lossy());
}
}
return ContentType::Iframe;
}
if let Some(ext) = Path::new(source).extension() {
return ContentType::from_extension(&ext.to_string_lossy());
}
@ -541,20 +495,16 @@ fn detect_content_type(source: &str) -> ContentType {
ContentType::Unknown
}
async fn resolve_source_url(
_state: &AppState,
session_id: Uuid,
source: &str,
) -> Result<String, String> {
if source.starts_with("http://") || source.starts_with("https://") {
return Ok(source.to_string());
}
if source.starts_with("/") || source.contains(".gbdrive") {
let file_url = format!(
"/api/drive/file/{}?session={}",
urlencoding::encode(source),
@ -563,7 +513,6 @@ async fn resolve_source_url(
return Ok(file_url);
}
let file_url = format!(
"/api/drive/file/{}?session={}",
urlencoding::encode(source),
@ -573,7 +522,6 @@ async fn resolve_source_url(
Ok(file_url)
}
async fn get_content_metadata(
_state: &AppState,
source_url: &str,
@ -584,7 +532,6 @@ async fn get_content_metadata(
metadata.insert("source".to_string(), source_url.to_string());
metadata.insert("type".to_string(), format!("{:?}", content_type));
match content_type {
ContentType::Video => {
metadata.insert("player".to_string(), "html5".to_string());
@ -613,9 +560,7 @@ async fn get_content_metadata(
Ok(metadata)
}
fn extract_title_from_source(source: &str) -> String {
pub fn extract_title_from_source(source: &str) -> String {
let path = source.split('?').next().unwrap_or(source);
Path::new(path)
@ -624,7 +569,6 @@ fn extract_title_from_source(source: &str) -> String {
.unwrap_or_else(|| "Untitled".to_string())
}
async fn send_play_to_client(
state: &AppState,
session_id: Uuid,
@ -638,7 +582,6 @@ async fn send_play_to_client(
let message_str =
serde_json::to_string(&message).map_err(|e| format!("Failed to serialize: {}", e))?;
let bot_response = crate::shared::models::BotResponse {
bot_id: String::new(),
user_id: String::new(),
@ -663,7 +606,6 @@ async fn send_play_to_client(
Ok(())
}
async fn send_player_command(
state: &AppState,
session_id: Uuid,
@ -677,7 +619,6 @@ async fn send_player_command(
let message_str =
serde_json::to_string(&message).map_err(|e| format!("Failed to serialize: {}", e))?;
let _ = state
.web_adapter
.send_message_to_session(

View file

@ -28,18 +28,6 @@
| |
\*****************************************************************************/
use crate::core::config::ConfigManager;
use crate::shared::models::UserSession;
use crate::shared::state::AppState;
@ -49,7 +37,6 @@ use serde::{Deserialize, Serialize};
use std::sync::Arc;
use uuid::Uuid;
#[derive(Debug, Clone, PartialEq)]
pub enum SmsProvider {
Twilio,
@ -59,7 +46,6 @@ pub enum SmsProvider {
Custom(String),
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SmsPriority {
Low,
@ -108,7 +94,6 @@ impl From<&str> for SmsProvider {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SmsSendResult {
pub success: bool,
@ -119,15 +104,12 @@ pub struct SmsSendResult {
pub error: Option<String>,
}
pub fn register_sms_keywords(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
register_send_sms_keyword(state.clone(), user.clone(), engine);
register_send_sms_with_third_arg_keyword(state.clone(), user.clone(), engine);
register_send_sms_full_keyword(state, user, engine);
}
pub fn register_send_sms_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
let state_clone = Arc::clone(&state);
let user_clone = user.clone();
@ -211,9 +193,6 @@ pub fn register_send_sms_keyword(state: Arc<AppState>, user: UserSession, engine
.unwrap();
}
pub fn register_send_sms_with_third_arg_keyword(
state: Arc<AppState>,
user: UserSession,
@ -231,7 +210,6 @@ pub fn register_send_sms_with_third_arg_keyword(
let message = context.eval_expression_tree(&inputs[1])?.to_string();
let third_arg = context.eval_expression_tree(&inputs[2])?.to_string();
let is_priority = matches!(
third_arg.to_lowercase().as_str(),
"low" | "normal" | "high" | "urgent" | "critical"
@ -319,8 +297,6 @@ pub fn register_send_sms_with_third_arg_keyword(
.unwrap();
}
pub fn register_send_sms_full_keyword(
state: Arc<AppState>,
user: UserSession,
@ -414,7 +390,6 @@ pub fn register_send_sms_full_keyword(
)
.unwrap();
let state_clone2 = Arc::clone(&state);
let user_clone2 = user.clone();
@ -517,7 +492,6 @@ async fn execute_send_sms(
let config_manager = ConfigManager::new(state.conn.clone());
let bot_id = user.bot_id;
let provider_name = match provider_override {
Some(p) => p.to_string(),
None => config_manager
@ -527,7 +501,6 @@ async fn execute_send_sms(
let provider = SmsProvider::from(provider_name.as_str());
let priority = match priority_override {
Some(p) => SmsPriority::from(p),
None => {
@ -538,10 +511,8 @@ async fn execute_send_sms(
}
};
let normalized_phone = normalize_phone_number(phone);
if matches!(priority, SmsPriority::High | SmsPriority::Urgent) {
info!(
"High priority SMS to {}: priority={}",
@ -549,7 +520,6 @@ async fn execute_send_sms(
);
}
let result = match provider {
SmsProvider::Twilio => {
send_via_twilio(state, &bot_id, &normalized_phone, message, &priority).await
@ -602,20 +572,16 @@ async fn execute_send_sms(
}
fn normalize_phone_number(phone: &str) -> String {
let has_plus = phone.starts_with('+');
let digits: String = phone.chars().filter(|c| c.is_ascii_digit()).collect();
if has_plus {
format!("+{}", digits)
} else if digits.len() == 10 {
format!("+1{}", digits)
} else if digits.len() == 11 && digits.starts_with('1') {
format!("+{}", digits)
} else {
format!("+{}", digits)
}
}
@ -647,8 +613,6 @@ async fn send_via_twilio(
account_sid
);
let final_message = match priority {
SmsPriority::Urgent => format!("[URGENT] {}", message),
SmsPriority::High => format!("[HIGH] {}", message),
@ -699,22 +663,17 @@ async fn send_via_aws_sns(
.get_config(bot_id, "aws-region", Some("us-east-1"))
.unwrap_or_else(|_| "us-east-1".to_string());
let client = reqwest::Client::new();
let url = format!("https://sns.{}.amazonaws.com/", region);
let timestamp = chrono::Utc::now().format("%Y%m%dT%H%M%SZ").to_string();
let _date = &timestamp[..8];
let sms_type = match priority {
SmsPriority::High | SmsPriority::Urgent => "Transactional",
_ => "Promotional",
};
let params = [
("Action", "Publish"),
("PhoneNumber", phone),
@ -725,8 +684,6 @@ async fn send_via_aws_sns(
("MessageAttributes.entry.1.Value.StringValue", sms_type),
];
let response = client
.post(&url)
.form(&params)
@ -774,7 +731,6 @@ async fn send_via_vonage(
let client = reqwest::Client::new();
let message_class = match priority {
SmsPriority::Urgent => Some("0"),
_ => None,
@ -798,25 +754,24 @@ async fn send_via_vonage(
.send()
.await?;
if response.status().is_success() {
let json: serde_json::Value = response.json().await?;
let messages = json["messages"].as_array();
if let Some(msgs) = messages {
if let Some(first) = msgs.first() {
if first["status"].as_str() == Some("0") {
return Ok(first["message-id"].as_str().map(|s| s.to_string()));
} else {
let error_text = first["error-text"].as_str().unwrap_or("Unknown error");
return Err(format!("Vonage error: {}", error_text).into());
}
}
}
Err("Invalid Vonage response".into())
} else {
if !response.status().is_success() {
let error_text = response.text().await?;
Err(format!("Vonage API error: {}", error_text).into())
return Err(format!("Vonage API error: {error_text}").into());
}
let json: serde_json::Value = response.json().await?;
let messages = json["messages"].as_array();
if let Some(msgs) = messages {
if let Some(first) = msgs.first() {
if first["status"].as_str() == Some("0") {
return Ok(first["message-id"].as_str().map(|s| s.to_string()));
}
let error_text = first["error-text"].as_str().unwrap_or("Unknown error");
return Err(format!("Vonage error: {error_text}").into());
}
}
Err("Invalid Vonage response".into())
}
async fn send_via_messagebird(
@ -840,7 +795,6 @@ async fn send_via_messagebird(
let client = reqwest::Client::new();
let type_details = match priority {
SmsPriority::Urgent => Some(serde_json::json!({"class": 0})),
SmsPriority::High => Some(serde_json::json!({"class": 1})),

View file

@ -1,42 +1,3 @@
use crate::shared::models::UserSession;
use crate::shared::state::AppState;
use chrono::Utc;
@ -48,10 +9,8 @@ use std::path::PathBuf;
use std::sync::Arc;
use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransferToHumanRequest {
pub name: Option<String>,
pub department: Option<String>,
@ -63,7 +22,6 @@ pub struct TransferToHumanRequest {
pub context: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransferResult {
pub success: bool,
@ -75,11 +33,9 @@ pub struct TransferResult {
pub message: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TransferStatus {
Queued,
Assigned,
@ -95,7 +51,6 @@ pub enum TransferStatus {
Error,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Attendant {
pub id: String,
@ -107,7 +62,6 @@ pub struct Attendant {
pub status: AttendantStatus,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum AttendantStatus {
@ -123,14 +77,12 @@ impl Default for AttendantStatus {
}
}
pub fn is_crm_enabled(bot_id: Uuid, work_path: &str) -> bool {
let config_path = PathBuf::from(work_path)
.join(format!("{}.gbai", bot_id))
.join("config.csv");
if !config_path.exists() {
let alt_path = PathBuf::from(work_path).join("config.csv");
if alt_path.exists() {
return check_config_for_crm(&alt_path);
@ -167,14 +119,12 @@ fn check_config_for_crm(config_path: &PathBuf) -> bool {
}
}
pub fn read_attendants(bot_id: Uuid, work_path: &str) -> Vec<Attendant> {
let attendant_path = PathBuf::from(work_path)
.join(format!("{}.gbai", bot_id))
.join("attendant.csv");
if !attendant_path.exists() {
let alt_path = PathBuf::from(work_path).join("attendant.csv");
if alt_path.exists() {
return parse_attendants_csv(&alt_path);
@ -192,11 +142,9 @@ fn parse_attendants_csv(path: &PathBuf) -> Vec<Attendant> {
let mut attendants = Vec::new();
let mut lines = content.lines();
let header = lines.next().unwrap_or("");
let headers: Vec<String> = header.split(',').map(|s| s.trim().to_lowercase()).collect();
let id_idx = headers.iter().position(|h| h == "id").unwrap_or(0);
let name_idx = headers.iter().position(|h| h == "name").unwrap_or(1);
let channel_idx = headers.iter().position(|h| h == "channel").unwrap_or(2);
@ -239,7 +187,6 @@ fn parse_attendants_csv(path: &PathBuf) -> Vec<Attendant> {
}
}
pub fn find_attendant<'a>(
attendants: &'a [Attendant],
name: Option<&str>,
@ -248,7 +195,6 @@ pub fn find_attendant<'a>(
if let Some(search_name) = name {
let search_lower = search_name.to_lowercase();
if let Some(att) = attendants
.iter()
.find(|a| a.name.to_lowercase() == search_lower)
@ -256,7 +202,6 @@ pub fn find_attendant<'a>(
return Some(att);
}
if let Some(att) = attendants
.iter()
.find(|a| a.name.to_lowercase().contains(&search_lower))
@ -264,7 +209,6 @@ pub fn find_attendant<'a>(
return Some(att);
}
if let Some(att) = attendants
.iter()
.find(|a| a.aliases.contains(&search_lower))
@ -272,7 +216,6 @@ pub fn find_attendant<'a>(
return Some(att);
}
if let Some(att) = attendants
.iter()
.find(|a| a.id.to_lowercase() == search_lower)
@ -284,7 +227,6 @@ pub fn find_attendant<'a>(
if let Some(dept) = department {
let dept_lower = dept.to_lowercase();
if let Some(att) = attendants.iter().find(|a| {
a.department
.as_ref()
@ -295,7 +237,6 @@ pub fn find_attendant<'a>(
return Some(att);
}
if let Some(att) = attendants.iter().find(|a| {
a.preferences.to_lowercase().contains(&dept_lower)
&& a.status == AttendantStatus::Online
@ -304,13 +245,11 @@ pub fn find_attendant<'a>(
}
}
attendants
.iter()
.find(|a| a.status == AttendantStatus::Online)
}
fn priority_to_int(priority: Option<&str>) -> i32 {
match priority.map(|p| p.to_lowercase()).as_deref() {
Some("urgent") => 3,
@ -321,7 +260,6 @@ fn priority_to_int(priority: Option<&str>) -> i32 {
}
}
pub async fn execute_transfer(
state: Arc<AppState>,
session: &UserSession,
@ -330,7 +268,6 @@ pub async fn execute_transfer(
let work_path = "./work";
let bot_id = session.bot_id;
if !is_crm_enabled(bot_id, work_path) {
return TransferResult {
success: false,
@ -344,7 +281,6 @@ pub async fn execute_transfer(
};
}
let attendants = read_attendants(bot_id, work_path);
if attendants.is_empty() {
return TransferResult {
@ -359,14 +295,12 @@ pub async fn execute_transfer(
};
}
let attendant = find_attendant(
&attendants,
request.name.as_deref(),
request.department.as_deref(),
);
if request.name.is_some() && attendant.is_none() {
return TransferResult {
success: false,
@ -387,7 +321,6 @@ pub async fn execute_transfer(
};
}
let priority = priority_to_int(request.priority.as_deref());
let transfer_context = serde_json::json!({
"transfer_requested_at": Utc::now().to_rfc3339(),
@ -402,7 +335,6 @@ pub async fn execute_transfer(
"status": if attendant.is_some() { "assigned" } else { "queued" },
});
let session_id = session.id;
let conn = state.conn.clone();
let ctx_data = transfer_context.clone();
@ -491,7 +423,6 @@ pub async fn execute_transfer(
}
}
impl TransferResult {
pub fn to_dynamic(&self) -> Dynamic {
let mut map = Map::new();
@ -532,11 +463,11 @@ async fn calculate_queue_position(state: &Arc<AppState>, current_session_id: Uui
};
let query = diesel::sql_query(
r#"SELECT COUNT(*) as position FROM user_sessions
r"SELECT COUNT(*) as position FROM user_sessions
WHERE context_data->>'needs_human' = 'true'
AND context_data->>'status' = 'queued'
AND created_at <= (SELECT created_at FROM user_sessions WHERE id = $1)
AND id != $2"#,
AND id != $2",
)
.bind::<diesel::sql_types::Uuid, _>(current_session_id)
.bind::<diesel::sql_types::Uuid, _>(current_session_id);
@ -563,13 +494,11 @@ async fn calculate_queue_position(state: &Arc<AppState>, current_session_id: Uui
}
}
pub fn register_transfer_to_human_keyword(
state: Arc<AppState>,
user: UserSession,
engine: &mut Engine,
) {
let state_clone = state.clone();
let user_clone = user.clone();
engine.register_fn("transfer_to_human", move || -> Dynamic {
@ -595,7 +524,6 @@ pub fn register_transfer_to_human_keyword(
result.to_dynamic()
});
let state_clone = state.clone();
let user_clone = user.clone();
engine.register_fn("transfer_to_human", move |name: &str| -> Dynamic {
@ -622,7 +550,6 @@ pub fn register_transfer_to_human_keyword(
result.to_dynamic()
});
let state_clone = state.clone();
let user_clone = user.clone();
engine.register_fn(
@ -653,7 +580,6 @@ pub fn register_transfer_to_human_keyword(
},
);
let state_clone = state.clone();
let user_clone = user.clone();
engine.register_fn(
@ -685,7 +611,6 @@ pub fn register_transfer_to_human_keyword(
},
);
let state_clone = state.clone();
let user_clone = user.clone();
engine.register_fn("transfer_to_human_ex", move |params: Map| -> Dynamic {
@ -730,7 +655,6 @@ pub fn register_transfer_to_human_keyword(
debug!("Registered TRANSFER TO HUMAN keywords");
}
pub const TRANSFER_TO_HUMAN_TOOL_SCHEMA: &str = r#"{
"name": "transfer_to_human",
"description": "Transfer the conversation to a human attendant. Use this when the customer explicitly asks to speak with a person, when the issue is too complex for automated handling, or when emotional support is needed.",
@ -761,7 +685,6 @@ pub const TRANSFER_TO_HUMAN_TOOL_SCHEMA: &str = r#"{
}
}"#;
pub fn get_tool_definition() -> serde_json::Value {
serde_json::json!({
"type": "function",

View file

@ -6,7 +6,7 @@ use std::time::Duration;
pub fn wait_keyword(_state: &AppState, _user: UserSession, engine: &mut Engine) {
engine
.register_custom_syntax(&["WAIT", "$expr$"], false, move |context, inputs| {
.register_custom_syntax(["WAIT", "$expr$"], false, move |context, inputs| {
let seconds = context.eval_expression_tree(&inputs[0])?;
let duration_secs = if seconds.is::<i64>() {
#[allow(clippy::cast_precision_loss)]

View file

@ -345,9 +345,7 @@ async fn scrape_table(
.chain(row.select(&td_sel))
.map(|cell| cell.text().collect::<Vec<_>>().join(" ").trim().to_string())
.collect();
if headers.is_empty() {
continue;
}
if headers.is_empty() {}
} else {
let mut row_map = Map::new();
for (j, cell) in row.select(&td_sel).enumerate() {

View file

@ -1,8 +1,3 @@
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
@ -14,11 +9,7 @@ use std::sync::Arc;
use super::CalendarEngine;
use crate::shared::state::AppState;
pub fn create_caldav_router(_engine: Arc<CalendarEngine>) -> Router<Arc<AppState>> {
Router::new()
.route("/caldav", get(caldav_root))
.route("/caldav/principals", get(caldav_principals))
@ -30,7 +21,6 @@ pub fn create_caldav_router(_engine: Arc<CalendarEngine>) -> Router<Arc<AppState
)
}
async fn caldav_root() -> impl IntoResponse {
Response::builder()
.status(StatusCode::OK)
@ -57,7 +47,6 @@ async fn caldav_root() -> impl IntoResponse {
.unwrap()
}
async fn caldav_principals() -> impl IntoResponse {
Response::builder()
.status(StatusCode::OK)
@ -86,7 +75,6 @@ async fn caldav_principals() -> impl IntoResponse {
.unwrap()
}
async fn caldav_calendars() -> impl IntoResponse {
Response::builder()
.status(StatusCode::OK)
@ -129,7 +117,6 @@ async fn caldav_calendars() -> impl IntoResponse {
.unwrap()
}
async fn caldav_calendar() -> impl IntoResponse {
Response::builder()
.status(StatusCode::OK)
@ -156,14 +143,12 @@ async fn caldav_calendar() -> impl IntoResponse {
.unwrap()
}
async fn caldav_event() -> impl IntoResponse {
Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "text/calendar; charset=utf-8")
.body(
r#"BEGIN:VCALENDAR
r"BEGIN:VCALENDAR
VERSION:2.0
PRODID:-
BEGIN:VEVENT
@ -173,15 +158,13 @@ DTSTART:20240101T090000Z
DTEND:20240101T100000Z
SUMMARY:Placeholder Event
END:VEVENT
END:VCALENDAR"#
END:VCALENDAR"
.to_string(),
)
.unwrap()
}
async fn caldav_put_event() -> impl IntoResponse {
Response::builder()
.status(StatusCode::CREATED)
.header("ETag", "\"placeholder-etag\"")

View file

@ -1,16 +1,9 @@
use chrono::{DateTime, Utc};
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use walkdir::WalkDir;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
#[serde(rename_all = "lowercase")]
pub enum IssueSeverity {
@ -33,7 +26,6 @@ impl std::fmt::Display for IssueSeverity {
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum IssueType {
@ -64,7 +56,6 @@ impl std::fmt::Display for IssueType {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeIssue {
pub id: String,
@ -80,7 +71,6 @@ pub struct CodeIssue {
pub detected_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BotScanResult {
pub bot_id: String,
@ -91,7 +81,6 @@ pub struct BotScanResult {
pub stats: ScanStats,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ScanStats {
pub critical: usize,
@ -124,7 +113,6 @@ impl ScanStats {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComplianceScanResult {
pub scanned_at: DateTime<Utc>,
@ -135,7 +123,6 @@ pub struct ComplianceScanResult {
pub bot_results: Vec<BotScanResult>,
}
struct ScanPattern {
regex: Regex,
issue_type: IssueType,
@ -146,14 +133,12 @@ struct ScanPattern {
category: String,
}
pub struct CodeScanner {
patterns: Vec<ScanPattern>,
base_path: PathBuf,
}
impl CodeScanner {
pub fn new(base_path: impl AsRef<Path>) -> Self {
let patterns = Self::build_patterns();
Self {
@ -162,11 +147,9 @@ impl CodeScanner {
}
}
fn build_patterns() -> Vec<ScanPattern> {
let mut patterns = Vec::new();
patterns.push(ScanPattern {
regex: Regex::new(r#"(?i)password\s*=\s*["'][^"']+["']"#).unwrap(),
issue_type: IssueType::PasswordInConfig,
@ -197,9 +180,8 @@ impl CodeScanner {
category: "Security".to_string(),
});
patterns.push(ScanPattern {
regex: Regex::new(r#"(?i)IF\s+.*\binput\b"#).unwrap(),
regex: Regex::new(r"(?i)IF\s+.*\binput\b").unwrap(),
issue_type: IssueType::DeprecatedIfInput,
severity: IssueSeverity::Medium,
title: "Deprecated IF...input Pattern".to_string(),
@ -211,9 +193,8 @@ impl CodeScanner {
category: "Code Quality".to_string(),
});
patterns.push(ScanPattern {
regex: Regex::new(r#"(?i)\b(GET_BOT_MEMORY|SET_BOT_MEMORY|GET_USER_MEMORY|SET_USER_MEMORY|USE_KB|USE_TOOL|SEND_MAIL|CREATE_TASK)\b"#).unwrap(),
regex: Regex::new(r"(?i)\b(GET_BOT_MEMORY|SET_BOT_MEMORY|GET_USER_MEMORY|SET_USER_MEMORY|USE_KB|USE_TOOL|SEND_MAIL|CREATE_TASK)\b").unwrap(),
issue_type: IssueType::UnderscoreInKeyword,
severity: IssueSeverity::Low,
title: "Underscore in Keyword".to_string(),
@ -222,9 +203,8 @@ impl CodeScanner {
category: "Naming Convention".to_string(),
});
patterns.push(ScanPattern {
regex: Regex::new(r#"(?i)POST\s+TO\s+INSTAGRAM\s+\w+\s*,\s*\w+"#).unwrap(),
regex: Regex::new(r"(?i)POST\s+TO\s+INSTAGRAM\s+\w+\s*,\s*\w+").unwrap(),
issue_type: IssueType::InsecurePattern,
severity: IssueSeverity::High,
title: "Instagram Credentials in Code".to_string(),
@ -235,10 +215,8 @@ impl CodeScanner {
category: "Security".to_string(),
});
patterns.push(ScanPattern {
regex: Regex::new(r#"(?i)(SELECT|INSERT|UPDATE|DELETE)\s+.*(FROM|INTO|SET)\s+"#)
.unwrap(),
regex: Regex::new(r"(?i)(SELECT|INSERT|UPDATE|DELETE)\s+.*(FROM|INTO|SET)\s+").unwrap(),
issue_type: IssueType::FragileCode,
severity: IssueSeverity::Medium,
title: "Raw SQL Query".to_string(),
@ -250,9 +228,8 @@ impl CodeScanner {
category: "Security".to_string(),
});
patterns.push(ScanPattern {
regex: Regex::new(r#"(?i)\bEVAL\s*\("#).unwrap(),
regex: Regex::new(r"(?i)\bEVAL\s*\(").unwrap(),
issue_type: IssueType::FragileCode,
severity: IssueSeverity::High,
title: "Dynamic Code Execution".to_string(),
@ -261,7 +238,6 @@ impl CodeScanner {
category: "Security".to_string(),
});
patterns.push(ScanPattern {
regex: Regex::new(
r#"(?i)(password|secret|key|token)\s*=\s*["'][A-Za-z0-9+/=]{40,}["']"#,
@ -276,9 +252,8 @@ impl CodeScanner {
category: "Security".to_string(),
});
patterns.push(ScanPattern {
regex: Regex::new(r#"(?i)(AKIA[0-9A-Z]{16})"#).unwrap(),
regex: Regex::new(r"(?i)(AKIA[0-9A-Z]{16})").unwrap(),
issue_type: IssueType::HardcodedSecret,
severity: IssueSeverity::Critical,
title: "AWS Access Key".to_string(),
@ -288,9 +263,8 @@ impl CodeScanner {
category: "Security".to_string(),
});
patterns.push(ScanPattern {
regex: Regex::new(r#"-----BEGIN\s+(RSA\s+)?PRIVATE\s+KEY-----"#).unwrap(),
regex: Regex::new(r"-----BEGIN\s+(RSA\s+)?PRIVATE\s+KEY-----").unwrap(),
issue_type: IssueType::HardcodedSecret,
severity: IssueSeverity::Critical,
title: "Private Key in Code".to_string(),
@ -300,9 +274,8 @@ impl CodeScanner {
category: "Security".to_string(),
});
patterns.push(ScanPattern {
regex: Regex::new(r#"(?i)(postgres|mysql|mongodb|redis)://[^:]+:[^@]+@"#).unwrap(),
regex: Regex::new(r"(?i)(postgres|mysql|mongodb|redis)://[^:]+:[^@]+@").unwrap(),
issue_type: IssueType::HardcodedSecret,
severity: IssueSeverity::Critical,
title: "Database Credentials in Connection String".to_string(),
@ -314,7 +287,6 @@ impl CodeScanner {
patterns
}
pub async fn scan_all(
&self,
) -> Result<ComplianceScanResult, Box<dyn std::error::Error + Send + Sync>> {
@ -323,13 +295,11 @@ impl CodeScanner {
let mut total_stats = ScanStats::default();
let mut total_files = 0;
let templates_path = self.base_path.join("templates");
let work_path = self.base_path.join("work");
let mut bot_paths = Vec::new();
if templates_path.exists() {
for entry in WalkDir::new(&templates_path).max_depth(3) {
if let Ok(entry) = entry {
@ -344,7 +314,6 @@ impl CodeScanner {
}
}
if work_path.exists() {
for entry in WalkDir::new(&work_path).max_depth(3) {
if let Ok(entry) = entry {
@ -359,7 +328,6 @@ impl CodeScanner {
}
}
for bot_path in &bot_paths {
let result = self.scan_bot(bot_path).await?;
total_files += result.files_scanned;
@ -379,7 +347,6 @@ impl CodeScanner {
})
}
pub async fn scan_bot(
&self,
bot_path: &Path,
@ -397,7 +364,6 @@ impl CodeScanner {
let mut stats = ScanStats::default();
let mut files_scanned = 0;
for entry in WalkDir::new(bot_path) {
if let Ok(entry) = entry {
let path = entry.path();
@ -415,7 +381,6 @@ impl CodeScanner {
}
}
let config_path = bot_path.join("config.csv");
if config_path.exists() {
let vault_configured = self.check_vault_config(&config_path).await?;
@ -438,7 +403,6 @@ impl CodeScanner {
}
}
issues.sort_by(|a, b| b.severity.cmp(&a.severity));
Ok(BotScanResult {
@ -451,7 +415,6 @@ impl CodeScanner {
})
}
async fn scan_file(
&self,
file_path: &Path,
@ -468,7 +431,6 @@ impl CodeScanner {
for (line_number, line) in content.lines().enumerate() {
let line_num = line_number + 1;
let trimmed = line.trim();
if trimmed.starts_with("REM") || trimmed.starts_with("'") || trimmed.starts_with("//") {
continue;
@ -476,7 +438,6 @@ impl CodeScanner {
for pattern in &self.patterns {
if pattern.regex.is_match(line) {
let snippet = self.redact_sensitive(line);
let issue = CodeIssue {
@ -500,18 +461,15 @@ impl CodeScanner {
Ok(issues)
}
fn redact_sensitive(&self, line: &str) -> String {
let mut result = line.to_string();
let secret_pattern = Regex::new(r#"(["'])[^"']{8,}(["'])"#).unwrap();
result = secret_pattern
.replace_all(&result, "$1***REDACTED***$2")
.to_string();
let aws_pattern = Regex::new(r#"AKIA[0-9A-Z]{16}"#).unwrap();
let aws_pattern = Regex::new(r"AKIA[0-9A-Z]{16}").unwrap();
result = aws_pattern
.replace_all(&result, "AKIA***REDACTED***")
.to_string();
@ -519,14 +477,12 @@ impl CodeScanner {
result
}
async fn check_vault_config(
&self,
config_path: &Path,
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
let content = tokio::fs::read_to_string(config_path).await?;
let has_vault = content.to_lowercase().contains("vault_addr")
|| content.to_lowercase().contains("vault_token")
|| content.to_lowercase().contains("vault-");
@ -534,7 +490,6 @@ impl CodeScanner {
Ok(has_vault)
}
pub async fn scan_bots(
&self,
bot_ids: &[String],
@ -543,14 +498,11 @@ impl CodeScanner {
return self.scan_all().await;
}
let mut full_result = self.scan_all().await?;
full_result
.bot_results
.retain(|r| bot_ids.contains(&r.bot_id) || bot_ids.contains(&r.bot_name));
let mut new_stats = ScanStats::default();
for bot in &full_result.bot_results {
new_stats.merge(&bot.stats);
@ -562,11 +514,9 @@ impl CodeScanner {
}
}
pub struct ComplianceReporter;
impl ComplianceReporter {
pub fn to_html(result: &ComplianceScanResult) -> String {
let mut html = String::new();
@ -617,12 +567,10 @@ impl ComplianceReporter {
html
}
pub fn to_json(result: &ComplianceScanResult) -> Result<String, serde_json::Error> {
serde_json::to_string_pretty(result)
}
pub fn to_csv(result: &ComplianceScanResult) -> String {
let mut csv = String::new();
csv.push_str("Severity,Type,Category,File,Line,Title,Description,Remediation\n");
@ -650,7 +598,6 @@ impl ComplianceReporter {
}
}
fn escape_csv(s: &str) -> String {
if s.contains(',') || s.contains('"') || s.contains('\n') {
format!("\"{}\"", s.replace('"', "\"\""))

View file

@ -1,13 +1,3 @@
use crate::shared::platform_name;
use crate::shared::BOTSERVER_VERSION;
use crossterm::{
@ -21,38 +11,27 @@ use serde::{Deserialize, Serialize};
use std::io::{self, Write};
use std::path::PathBuf;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WizardConfig {
pub llm_provider: LlmProvider,
pub llm_api_key: Option<String>,
pub local_model_path: Option<String>,
pub components: Vec<ComponentChoice>,
pub admin: AdminConfig,
pub organization: OrgConfig,
pub template: Option<String>,
pub install_mode: InstallMode,
pub data_dir: PathBuf,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum LlmProvider {
Claude,
@ -74,7 +53,6 @@ impl std::fmt::Display for LlmProvider {
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ComponentChoice {
Drive,
@ -104,7 +82,6 @@ impl std::fmt::Display for ComponentChoice {
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AdminConfig {
pub username: String,
@ -113,7 +90,6 @@ pub struct AdminConfig {
pub display_name: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct OrgConfig {
pub name: String,
@ -121,7 +97,6 @@ pub struct OrgConfig {
pub domain: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum InstallMode {
Development,
@ -149,7 +124,6 @@ impl Default for WizardConfig {
}
}
#[derive(Debug)]
pub struct StartupWizard {
config: WizardConfig,
@ -166,12 +140,10 @@ impl StartupWizard {
}
}
pub fn run(&mut self) -> io::Result<WizardConfig> {
terminal::enable_raw_mode()?;
let mut stdout = io::stdout();
execute!(
stdout,
terminal::Clear(ClearType::All),
@ -181,31 +153,24 @@ impl StartupWizard {
self.show_welcome(&mut stdout)?;
self.wait_for_enter()?;
self.current_step = 1;
self.step_install_mode(&mut stdout)?;
self.current_step = 2;
self.step_llm_provider(&mut stdout)?;
self.current_step = 3;
self.step_components(&mut stdout)?;
self.current_step = 4;
self.step_organization(&mut stdout)?;
self.current_step = 5;
self.step_admin_user(&mut stdout)?;
self.current_step = 6;
self.step_template(&mut stdout)?;
self.current_step = 7;
self.step_summary(&mut stdout)?;
@ -220,7 +185,7 @@ impl StartupWizard {
cursor::MoveTo(0, 0)
)?;
let banner = r#"
let banner = r"
@ -237,7 +202,7 @@ impl StartupWizard {
"#;
";
execute!(
stdout,
@ -288,7 +253,6 @@ impl StartupWizard {
cursor::MoveTo(0, 0)
)?;
let progress = format!("Step {}/{}: {}", self.current_step, self.total_steps, title);
let bar_width = 50;
let filled = (self.current_step * bar_width) / self.total_steps;
@ -396,7 +360,6 @@ impl StartupWizard {
let selected = self.select_option(stdout, &options, 0)?;
self.config.llm_provider = options[selected].2.clone();
if self.config.llm_provider != LlmProvider::Local
&& self.config.llm_provider != LlmProvider::None
{
@ -485,7 +448,6 @@ impl StartupWizard {
io::stdin().read_line(&mut org_name)?;
self.config.organization.name = org_name.trim().to_string();
self.config.organization.slug = self
.config
.organization
@ -557,7 +519,6 @@ impl StartupWizard {
execute!(stdout, cursor::MoveTo(2, 13), Print("Admin password: "))?;
stdout.flush()?;
let mut password = String::new();
io::stdin().read_line(&mut password)?;
self.config.admin.password = password.trim().to_string();
@ -824,7 +785,6 @@ impl StartupWizard {
}
KeyCode::Char(' ') => {
if options[cursor].2 {
selected[cursor] = !selected[cursor];
}
}
@ -857,7 +817,6 @@ impl StartupWizard {
}
}
pub fn save_wizard_config(config: &WizardConfig, path: &str) -> io::Result<()> {
let content = toml::to_string_pretty(config)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
@ -865,7 +824,6 @@ pub fn save_wizard_config(config: &WizardConfig, path: &str) -> io::Result<()> {
Ok(())
}
pub fn load_wizard_config(path: &str) -> io::Result<WizardConfig> {
let content = std::fs::read_to_string(path)?;
let config: WizardConfig =
@ -873,32 +831,26 @@ pub fn load_wizard_config(path: &str) -> io::Result<WizardConfig> {
Ok(config)
}
pub fn should_run_wizard() -> bool {
!std::path::Path::new("./botserver-stack").exists()
&& !std::path::Path::new("/opt/gbo").exists()
}
pub fn apply_wizard_config(config: &WizardConfig) -> io::Result<()> {
use std::fs;
fs::create_dir_all(&config.data_dir)?;
let subdirs = ["bots", "logs", "cache", "uploads", "config"];
for subdir in &subdirs {
fs::create_dir_all(config.data_dir.join(subdir))?;
}
save_wizard_config(
config,
&config.data_dir.join("config/wizard.toml").to_string_lossy(),
)?;
let mut env_content = String::new();
env_content.push_str(&format!(
"# Generated by {} Setup Wizard\n\n",

File diff suppressed because it is too large Load diff

View file

@ -71,7 +71,6 @@ impl PackageManager {
.await?;
}
if !component.data_download_list.is_empty() {
let cache_base = self.base_path.parent().unwrap_or(&self.base_path);
let cache = DownloadCache::new(cache_base).ok();
@ -83,18 +82,15 @@ impl PackageManager {
.join(&component.name)
.join(&filename);
if output_path.exists() {
info!("Data file already exists: {:?}", output_path);
continue;
}
if let Some(parent) = output_path.parent() {
std::fs::create_dir_all(parent)?;
}
if let Some(ref c) = cache {
if let Some(cached_path) = c.get_cached_path(&filename) {
info!("Using cached data file: {:?}", cached_path);
@ -103,7 +99,6 @@ impl PackageManager {
}
}
let download_target = if let Some(ref c) = cache {
c.get_cache_path(&filename)
} else {
@ -114,7 +109,6 @@ impl PackageManager {
println!("Downloading {}", url);
utils::download_file(url, download_target.to_str().unwrap()).await?;
if cache.is_some() && download_target != output_path {
std::fs::copy(&download_target, &output_path)?;
info!("Copied cached file to: {:?}", output_path);
@ -127,10 +121,8 @@ impl PackageManager {
pub fn install_container(&self, component: &ComponentConfig) -> Result<InstallResult> {
let container_name = format!("{}-{}", self.tenant, component.name);
let _ = Command::new("lxd").args(&["init", "--auto"]).output();
let images = [
"ubuntu:24.04",
"ubuntu:22.04",
@ -157,14 +149,13 @@ impl PackageManager {
info!("Successfully created container with image: {}", image);
success = true;
break;
} else {
last_error = String::from_utf8_lossy(&output.stderr).to_string();
warn!("Failed to create container with {}: {}", image, last_error);
let _ = Command::new("lxc")
.args(&["delete", &container_name, "--force"])
.output();
}
last_error = String::from_utf8_lossy(&output.stderr).to_string();
warn!("Failed to create container with {}: {}", image, last_error);
let _ = Command::new("lxc")
.args(&["delete", &container_name, "--force"])
.output();
}
if !success {
@ -176,7 +167,6 @@ impl PackageManager {
std::thread::sleep(std::time::Duration::from_secs(15));
self.exec_in_container(&container_name, "mkdir -p /opt/gbo/{bin,data,conf,logs}")?;
self.exec_in_container(
&container_name,
"echo 'nameserver 8.8.8.8' > /etc/resolv.conf",
@ -186,7 +176,6 @@ impl PackageManager {
"echo 'nameserver 8.8.4.4' >> /etc/resolv.conf",
)?;
self.exec_in_container(&container_name, "apt-get update -qq")?;
self.exec_in_container(
&container_name,
@ -247,15 +236,12 @@ impl PackageManager {
}
self.setup_port_forwarding(&container_name, &component.ports)?;
let container_ip = self.get_container_ip(&container_name)?;
if component.name == "vault" {
self.initialize_vault(&container_name, &container_ip)?;
}
let (connection_info, env_vars) =
self.generate_connection_info(&component.name, &container_ip, &component.ports);
@ -275,12 +261,9 @@ impl PackageManager {
})
}
fn get_container_ip(&self, container_name: &str) -> Result<String> {
std::thread::sleep(std::time::Duration::from_secs(2));
let output = Command::new("lxc")
.args(&["list", container_name, "-c", "4", "--format", "csv"])
.output()?;
@ -289,7 +272,6 @@ impl PackageManager {
let ip_output = String::from_utf8_lossy(&output.stdout).trim().to_string();
if !ip_output.is_empty() {
let ip = ip_output
.split(|c| c == ' ' || c == '(')
.next()
@ -301,7 +283,6 @@ impl PackageManager {
}
}
let output = Command::new("lxc")
.args(&["exec", container_name, "--", "hostname", "-I"])
.output()?;
@ -318,15 +299,11 @@ impl PackageManager {
Ok("unknown".to_string())
}
fn initialize_vault(&self, container_name: &str, ip: &str) -> Result<()> {
info!("Initializing Vault...");
std::thread::sleep(std::time::Duration::from_secs(5));
let output = Command::new("lxc")
.args(&[
"exec",
@ -350,7 +327,6 @@ impl PackageManager {
let init_output = String::from_utf8_lossy(&output.stdout);
let init_json: serde_json::Value =
serde_json::from_str(&init_output).context("Failed to parse Vault init output")?;
@ -361,12 +337,10 @@ impl PackageManager {
.as_str()
.context("No root token in output")?;
let unseal_keys_file = PathBuf::from("vault-unseal-keys");
let mut unseal_content = String::new();
for (i, key) in unseal_keys.iter().enumerate() {
if i < 3 {
unseal_content.push_str(&format!(
"VAULT_UNSEAL_KEY_{}={}\n",
i + 1,
@ -376,7 +350,6 @@ impl PackageManager {
}
std::fs::write(&unseal_keys_file, &unseal_content)?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
@ -385,7 +358,6 @@ impl PackageManager {
info!("Created {}", unseal_keys_file.display());
let env_file = PathBuf::from(".env");
let env_content = format!(
"\n# Vault Configuration (auto-generated)\nVAULT_ADDR=http://{}:8200\nVAULT_TOKEN={}\nVAULT_UNSEAL_KEYS_FILE=vault-unseal-keys\n",
@ -393,11 +365,9 @@ impl PackageManager {
);
if env_file.exists() {
let existing = std::fs::read_to_string(&env_file)?;
if !existing.contains("VAULT_ADDR=") {
let mut file = std::fs::OpenOptions::new().append(true).open(&env_file)?;
use std::io::Write;
file.write_all(env_content.as_bytes())?;
@ -406,13 +376,10 @@ impl PackageManager {
warn!(".env already contains VAULT_ADDR, not overwriting");
}
} else {
std::fs::write(&env_file, env_content.trim_start())?;
info!("Created .env with Vault config");
}
for i in 0..3 {
if let Some(key) = unseal_keys.get(i) {
let key_str = key.as_str().unwrap_or("");
@ -434,8 +401,6 @@ impl PackageManager {
Ok(())
}
fn generate_connection_info(
&self,
component: &str,
@ -445,9 +410,8 @@ impl PackageManager {
let env_vars = HashMap::new();
let connection_info = match component {
"vault" => {
format!(
r#"Vault Server:
r"Vault Server:
URL: http://{}:8200
UI: http://{}:8200/ui
@ -466,141 +430,141 @@ Or manually:
lxc exec {}-vault -- /opt/gbo/bin/vault operator unseal <key>
For other auto-unseal options (TPM, HSM, Transit), see:
https://generalbots.github.io/botbook/chapter-08/secrets-management.html"#,
https://generalbots.github.io/botbook/chapter-08/secrets-management.html",
ip, ip, self.tenant
)
}
"vector_db" => {
format!(
r#"Qdrant Vector Database:
r"Qdrant Vector Database:
REST API: http://{}:6333
gRPC: {}:6334
Dashboard: http://{}:6333/dashboard
Store credentials in Vault:
botserver vault put gbo/vectordb host={} port=6333"#,
botserver vault put gbo/vectordb host={} port=6333",
ip, ip, ip, ip
)
}
"tables" => {
format!(
r#"PostgreSQL Database:
r"PostgreSQL Database:
Host: {}
Port: 5432
Database: botserver
User: gbuser
Store credentials in Vault:
botserver vault put gbo/tables host={} port=5432 database=botserver username=gbuser password=<your-password>"#,
botserver vault put gbo/tables host={} port=5432 database=botserver username=gbuser password=<your-password>",
ip, ip
)
}
"drive" => {
format!(
r#"MinIO Object Storage:
r"MinIO Object Storage:
API: http://{}:9000
Console: http://{}:9001
Store credentials in Vault:
botserver vault put gbo/drive server={} port=9000 accesskey=minioadmin secret=<your-secret>"#,
botserver vault put gbo/drive server={} port=9000 accesskey=minioadmin secret=<your-secret>",
ip, ip, ip
)
}
"cache" => {
format!(
r#"Redis/Valkey Cache:
r"Redis/Valkey Cache:
Host: {}
Port: 6379
Store credentials in Vault:
botserver vault put gbo/cache host={} port=6379 password=<your-password>"#,
botserver vault put gbo/cache host={} port=6379 password=<your-password>",
ip, ip
)
}
"email" => {
format!(
r#"Email Server (Stalwart):
r"Email Server (Stalwart):
SMTP: {}:25
IMAP: {}:143
Web: http://{}:8080
Store credentials in Vault:
botserver vault put gbo/email server={} port=25 username=admin password=<your-password>"#,
botserver vault put gbo/email server={} port=25 username=admin password=<your-password>",
ip, ip, ip, ip
)
}
"directory" => {
format!(
r#"Zitadel Identity Provider:
r"Zitadel Identity Provider:
URL: http://{}:8080
Console: http://{}:8080/ui/console
Store credentials in Vault:
botserver vault put gbo/directory url=http://{}:8080 client_id=<client-id> client_secret=<client-secret>"#,
botserver vault put gbo/directory url=http://{}:8080 client_id=<client-id> client_secret=<client-secret>",
ip, ip, ip
)
}
"llm" => {
format!(
r#"LLM Server (llama.cpp):
r"LLM Server (llama.cpp):
API: http://{}:8081
Test:
curl http://{}:8081/v1/models
Store credentials in Vault:
botserver vault put gbo/llm url=http://{}:8081 local=true"#,
botserver vault put gbo/llm url=http://{}:8081 local=true",
ip, ip, ip
)
}
"meeting" => {
format!(
r#"LiveKit Meeting Server:
r"LiveKit Meeting Server:
WebSocket: ws://{}:7880
API: http://{}:7880
Store credentials in Vault:
botserver vault put gbo/meet url=ws://{}:7880 api_key=<api-key> api_secret=<api-secret>"#,
botserver vault put gbo/meet url=ws://{}:7880 api_key=<api-key> api_secret=<api-secret>",
ip, ip, ip
)
}
"proxy" => {
format!(
r#"Caddy Reverse Proxy:
r"Caddy Reverse Proxy:
HTTP: http://{}:80
HTTPS: https://{}:443
Admin: http://{}:2019"#,
Admin: http://{}:2019",
ip, ip, ip
)
}
"timeseries_db" => {
format!(
r#"InfluxDB Time Series Database:
r"InfluxDB Time Series Database:
API: http://{}:8086
Store credentials in Vault:
botserver vault put gbo/observability url=http://{}:8086 token=<influx-token> org=pragmatismo bucket=metrics"#,
botserver vault put gbo/observability url=http://{}:8086 token=<influx-token> org=pragmatismo bucket=metrics",
ip, ip
)
}
"observability" => {
format!(
r#"Vector Log Aggregation:
r"Vector Log Aggregation:
API: http://{}:8686
Store credentials in Vault:
botserver vault put gbo/observability vector_url=http://{}:8686"#,
botserver vault put gbo/observability vector_url=http://{}:8686",
ip, ip
)
}
"alm" => {
format!(
r#"Forgejo Git Server:
r"Forgejo Git Server:
Web: http://{}:3000
SSH: {}:22
Store credentials in Vault:
botserver vault put gbo/alm url=http://{}:3000 token=<api-token>"#,
botserver vault put gbo/alm url=http://{}:3000 token=<api-token>",
ip, ip, ip
)
}
@ -611,11 +575,11 @@ Store credentials in Vault:
.collect::<Vec<_>>()
.join("\n");
format!(
r#"Component: {}
r"Component: {}
Container: {}-{}
IP: {}
Ports:
{}"#,
{}",
component, self.tenant, component, ip, ports_str
)
}
@ -740,7 +704,6 @@ Store credentials in Vault:
let bin_path = self.base_path.join("bin").join(component);
std::fs::create_dir_all(&bin_path)?;
let cache_base = self.base_path.parent().unwrap_or(&self.base_path);
let cache = DownloadCache::new(cache_base).unwrap_or_else(|e| {
warn!("Failed to initialize download cache: {}", e);
@ -748,7 +711,6 @@ Store credentials in Vault:
DownloadCache::new(&self.base_path).expect("Failed to create fallback cache")
});
let cache_result = cache.resolve_component_url(component, url);
let source_file = match cache_result {
@ -763,7 +725,6 @@ Store credentials in Vault:
info!("Downloading {} from {}", component, download_url);
println!("Downloading {}", download_url);
self.download_with_reqwest(&download_url, &cache_path, component)
.await?;
@ -772,7 +733,6 @@ Store credentials in Vault:
}
};
self.handle_downloaded_file(&source_file, &bin_path, binary_name)?;
Ok(())
}
@ -785,7 +745,6 @@ Store credentials in Vault:
const MAX_RETRIES: u32 = 3;
const RETRY_DELAY: std::time::Duration = std::time::Duration::from_secs(2);
if let Some(parent) = target_file.parent() {
std::fs::create_dir_all(parent)?;
}
@ -870,7 +829,6 @@ Store credentials in Vault:
} else {
let final_path = bin_path.join(temp_file.file_name().unwrap());
if temp_file.to_string_lossy().contains("botserver-installers") {
std::fs::copy(temp_file, &final_path)?;
} else {
@ -894,7 +852,6 @@ Store credentials in Vault:
));
}
if !temp_file.to_string_lossy().contains("botserver-installers") {
std::fs::remove_file(temp_file)?;
}
@ -912,7 +869,6 @@ Store credentials in Vault:
));
}
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
@ -935,8 +891,6 @@ Store credentials in Vault:
}
}
if !temp_file.to_string_lossy().contains("botserver-installers") {
std::fs::remove_file(temp_file)?;
}
@ -950,7 +904,6 @@ Store credentials in Vault:
) -> Result<()> {
let final_path = bin_path.join(name);
if temp_file.to_string_lossy().contains("botserver-installers") {
std::fs::copy(temp_file, &final_path)?;
} else {
@ -981,7 +934,6 @@ Store credentials in Vault:
PathBuf::from("/opt/gbo/data")
};
let conf_path = if target == "local" {
self.base_path.join("conf")
} else {
@ -993,15 +945,12 @@ Store credentials in Vault:
PathBuf::from("/opt/gbo/logs")
};
let db_password = match get_database_url_sync() {
Ok(url) => {
let (_, password, _, _, _) = parse_database_url(&url);
password
}
Err(_) => {
trace!("Vault not available for DB_PASSWORD, using empty string");
String::new()
}
@ -1124,8 +1073,6 @@ Store credentials in Vault:
exec_cmd: &str,
env_vars: &HashMap<String, String>,
) -> Result<()> {
let db_password = match get_database_url_sync() {
Ok(url) => {
let (_, password, _, _, _) = parse_database_url(&url);

View file

@ -5,7 +5,6 @@ use std::time::Duration;
use tokio::fs;
use tokio::time::sleep;
#[derive(Debug)]
pub struct EmailSetup {
base_url: String,
@ -34,7 +33,6 @@ pub struct EmailDomain {
impl EmailSetup {
pub fn new(base_url: String, config_path: PathBuf) -> Self {
let admin_user = format!(
"admin_{}@botserver.local",
uuid::Uuid::new_v4()
@ -53,7 +51,6 @@ impl EmailSetup {
}
}
fn generate_secure_password() -> String {
use rand::distr::Alphanumeric;
use rand::Rng;
@ -66,12 +63,10 @@ impl EmailSetup {
.collect()
}
pub async fn wait_for_ready(&self, max_attempts: u32) -> Result<()> {
log::info!("Waiting for Email service to be ready...");
for attempt in 1..=max_attempts {
if let Ok(_) = tokio::net::TcpStream::connect("127.0.0.1:25").await {
log::info!("Email service is ready!");
return Ok(());
@ -88,27 +83,22 @@ impl EmailSetup {
anyhow::bail!("Email service did not become ready in time")
}
pub async fn initialize(
&mut self,
directory_config_path: Option<PathBuf>,
) -> Result<EmailConfig> {
log::info!(" Initializing Email (Stalwart) server...");
if let Ok(existing_config) = self.load_existing_config().await {
log::info!("Email already initialized, using existing config");
return Ok(existing_config);
}
self.wait_for_ready(30).await?;
self.create_default_domain().await?;
log::info!(" Created default email domain: localhost");
let directory_integration = if let Some(dir_config_path) = directory_config_path {
match self.setup_directory_integration(&dir_config_path).await {
Ok(_) => {
@ -124,7 +114,6 @@ impl EmailSetup {
false
};
self.create_admin_account().await?;
log::info!(" Created admin email account: {}", self.admin_user);
@ -139,7 +128,6 @@ impl EmailSetup {
directory_integration,
};
self.save_config(&config).await?;
log::info!(" Saved Email configuration");
@ -151,14 +139,10 @@ impl EmailSetup {
Ok(config)
}
async fn create_default_domain(&self) -> Result<()> {
Ok(())
}
async fn create_admin_account(&self) -> Result<()> {
log::info!("Creating admin email account via Stalwart API...");
@ -166,14 +150,13 @@ impl EmailSetup {
.timeout(Duration::from_secs(30))
.build()?;
let api_url = format!("{}/api/account", self.base_url);
let account_data = serde_json::json!({
"name": self.admin_user,
"secret": self.admin_pass,
"description": "BotServer Admin Account",
"quota": 1073741824,
"quota": 1_073_741_824,
"type": "individual",
"emails": [self.admin_user.clone()],
"memberOf": ["administrators"],
@ -196,7 +179,6 @@ impl EmailSetup {
);
Ok(())
} else if resp.status().as_u16() == 409 {
log::info!("Admin email account already exists: {}", self.admin_user);
Ok(())
} else {
@ -222,7 +204,6 @@ impl EmailSetup {
}
}
async fn setup_directory_integration(&self, directory_config_path: &PathBuf) -> Result<()> {
let content = fs::read_to_string(directory_config_path).await?;
let dir_config: serde_json::Value = serde_json::from_str(&content)?;
@ -234,31 +215,25 @@ impl EmailSetup {
log::info!("Setting up OIDC authentication with Directory...");
log::info!("Issuer URL: {}", issuer_url);
Ok(())
}
async fn save_config(&self, config: &EmailConfig) -> Result<()> {
let json = serde_json::to_string_pretty(config)?;
fs::write(&self.config_path, json).await?;
Ok(())
}
async fn load_existing_config(&self) -> Result<EmailConfig> {
let content = fs::read_to_string(&self.config_path).await?;
let config: EmailConfig = serde_json::from_str(&content)?;
Ok(config)
}
pub async fn get_config(&self) -> Result<EmailConfig> {
self.load_existing_config().await
}
pub async fn create_user_mailbox(
&self,
_username: &str,
@ -267,20 +242,15 @@ impl EmailSetup {
) -> Result<()> {
log::info!("Creating mailbox for user: {}", email);
Ok(())
}
pub async fn sync_users_from_directory(&self, directory_config_path: &PathBuf) -> Result<()> {
log::info!("Syncing users from Directory to Email...");
let content = fs::read_to_string(directory_config_path).await?;
let dir_config: serde_json::Value = serde_json::from_str(&content)?;
if let Some(default_user) = dir_config.get("default_user") {
let email = default_user["email"].as_str().unwrap_or("");
let password = default_user["password"].as_str().unwrap_or("");
@ -296,7 +266,6 @@ impl EmailSetup {
}
}
pub async fn generate_email_config(
config_path: PathBuf,
data_path: PathBuf,
@ -352,7 +321,6 @@ store = "sqlite"
data_path.display()
);
if directory_integration {
config.push_str(
r#"

View file

@ -10,10 +10,7 @@ use uuid::Uuid;
#[cfg(feature = "vectordb")]
use qdrant_client::{
qdrant::{
vectors_config::Config, CreateCollection, Distance, PointStruct, VectorParams,
VectorsConfig,
},
qdrant::{Distance, PointStruct, VectorParams},
Qdrant,
};

File diff suppressed because it is too large Load diff

View file

@ -1,23 +1,3 @@
use anyhow::{anyhow, Context, Result};
use chrono::{DateTime, NaiveDate, Utc};
use reqwest::{Client, Method, StatusCode};
@ -26,23 +6,14 @@ use serde_json::{json, Value};
use std::time::Duration;
use tracing::{debug, error, info, warn};
const DEFAULT_TIMEOUT_SECS: u64 = 30;
pub const DEFAULT_QUEUE_POLL_INTERVAL_SECS: u64 = 30;
pub const DEFAULT_METRICS_POLL_INTERVAL_SECS: u64 = 60;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueueStatus {
pub is_running: bool,
pub total_queued: u64,
@ -50,10 +21,8 @@ pub struct QueueStatus {
pub messages: Vec<QueuedMessage>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueuedMessage {
pub id: String,
pub from: String,
@ -81,7 +50,6 @@ pub struct QueuedMessage {
pub queued_at: Option<DateTime<Utc>>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum DeliveryStatus {
@ -94,7 +62,6 @@ pub enum DeliveryStatus {
Unknown,
}
#[derive(Debug, Clone, Deserialize)]
struct QueueListResponse {
#[serde(default)]
@ -103,9 +70,6 @@ struct QueueListResponse {
items: Vec<QueuedMessage>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum PrincipalType {
@ -119,10 +83,8 @@ pub enum PrincipalType {
Other,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Principal {
pub id: Option<u64>,
#[serde(rename = "type")]
@ -149,10 +111,8 @@ pub struct Principal {
pub disabled: bool,
}
#[derive(Debug, Clone, Serialize)]
pub struct AccountUpdate {
pub action: String,
pub field: String,
@ -161,7 +121,6 @@ pub struct AccountUpdate {
}
impl AccountUpdate {
pub fn set(field: &str, value: impl Into<Value>) -> Self {
Self {
action: "set".to_string(),
@ -170,7 +129,6 @@ impl AccountUpdate {
}
}
pub fn add_item(field: &str, value: impl Into<Value>) -> Self {
Self {
action: "addItem".to_string(),
@ -179,7 +137,6 @@ impl AccountUpdate {
}
}
pub fn remove_item(field: &str, value: impl Into<Value>) -> Self {
Self {
action: "removeItem".to_string(),
@ -188,7 +145,6 @@ impl AccountUpdate {
}
}
pub fn clear(field: &str) -> Self {
Self {
action: "clear".to_string(),
@ -198,12 +154,8 @@ impl AccountUpdate {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutoResponderConfig {
pub enabled: bool,
pub subject: String,
@ -235,7 +187,8 @@ impl Default for AutoResponderConfig {
Self {
enabled: false,
subject: "Out of Office".to_string(),
body_plain: "I am currently out of the office and will respond upon my return.".to_string(),
body_plain: "I am currently out of the office and will respond upon my return."
.to_string(),
body_html: None,
start_date: None,
end_date: None,
@ -245,10 +198,8 @@ impl Default for AutoResponderConfig {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmailRule {
pub id: String,
pub name: String,
@ -270,10 +221,8 @@ fn default_stop_processing() -> bool {
true
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuleCondition {
pub field: String,
pub operator: String,
@ -287,22 +236,16 @@ pub struct RuleCondition {
pub case_sensitive: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuleAction {
pub action_type: String,
#[serde(default)]
pub value: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Metrics {
#[serde(default)]
pub messages_received: u64,
@ -334,10 +277,8 @@ pub struct Metrics {
pub extra: std::collections::HashMap<String, Value>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogEntry {
pub timestamp: DateTime<Utc>,
pub level: String,
@ -351,7 +292,6 @@ pub struct LogEntry {
pub context: Option<Value>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct LogList {
#[serde(default)]
@ -360,10 +300,8 @@ pub struct LogList {
pub items: Vec<LogEntry>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceEvent {
pub timestamp: DateTime<Utc>,
pub event_type: String,
@ -390,7 +328,6 @@ pub struct TraceEvent {
pub duration_ms: Option<u64>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct TraceList {
#[serde(default)]
@ -399,12 +336,8 @@ pub struct TraceList {
pub items: Vec<TraceEvent>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Report {
pub id: String,
pub report_type: String,
@ -423,7 +356,6 @@ pub struct Report {
pub data: Value,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ReportList {
#[serde(default)]
@ -432,12 +364,8 @@ pub struct ReportList {
pub items: Vec<Report>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpamClassifyRequest {
pub from: String,
pub to: Vec<String>,
@ -455,10 +383,8 @@ pub struct SpamClassifyRequest {
pub body: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpamClassifyResult {
pub score: f64,
pub classification: String,
@ -470,10 +396,8 @@ pub struct SpamClassifyResult {
pub action: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpamTest {
pub name: String,
pub score: f64,
@ -482,9 +406,6 @@ pub struct SpamTest {
pub description: Option<String>,
}
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
@ -493,9 +414,6 @@ enum ApiResponse<T> {
Error { error: String },
}
#[derive(Debug, Clone)]
pub struct StalwartClient {
base_url: String,
@ -504,18 +422,6 @@ pub struct StalwartClient {
}
impl StalwartClient {
pub fn new(base_url: &str, token: &str) -> Self {
let http_client = Client::builder()
.timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
@ -529,7 +435,6 @@ impl StalwartClient {
}
}
pub fn with_timeout(base_url: &str, token: &str, timeout_secs: u64) -> Self {
let http_client = Client::builder()
.timeout(Duration::from_secs(timeout_secs))
@ -543,7 +448,6 @@ impl StalwartClient {
}
}
async fn request<T: DeserializeOwned>(
&self,
method: Method,
@ -563,20 +467,24 @@ impl StalwartClient {
req = req.header("Content-Type", "application/json").json(b);
}
let resp = req.send().await.context("Failed to send request to Stalwart")?;
let resp = req
.send()
.await
.context("Failed to send request to Stalwart")?;
let status = resp.status();
if !status.is_success() {
let error_text = resp.text().await.unwrap_or_else(|_| "Unknown error".to_string());
let error_text = resp
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
error!("Stalwart API error: {} - {}", status, error_text);
return Err(anyhow!("Stalwart API error ({}): {}", status, error_text));
}
let text = resp.text().await.context("Failed to read response body")?;
if text.is_empty() || text == "null" {
return serde_json::from_str("null")
.or_else(|_| serde_json::from_str("{}"))
.or_else(|_| serde_json::from_str("true"))
@ -586,8 +494,13 @@ impl StalwartClient {
serde_json::from_str(&text).context("Failed to parse Stalwart API response")
}
async fn request_raw(&self, method: Method, path: &str, body: &str, content_type: &str) -> Result<()> {
async fn request_raw(
&self,
method: Method,
path: &str,
body: &str,
content_type: &str,
) -> Result<()> {
let url = format!("{}{}", self.base_url, path);
debug!("Stalwart API raw request: {} {}", method, url);
@ -603,18 +516,16 @@ impl StalwartClient {
let status = resp.status();
if !status.is_success() {
let error_text = resp.text().await.unwrap_or_else(|_| "Unknown error".to_string());
let error_text = resp
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow!("Stalwart API error ({}): {}", status, error_text));
}
Ok(())
}
pub async fn get_queue_status(&self) -> Result<QueueStatus> {
let status: bool = self
.request(Method::GET, "/api/queue/status", None)
@ -636,7 +547,6 @@ impl StalwartClient {
})
}
pub async fn get_queued_message(&self, message_id: &str) -> Result<QueuedMessage> {
self.request(
Method::GET,
@ -646,7 +556,6 @@ impl StalwartClient {
.await
}
pub async fn list_queued_messages(
&self,
limit: u32,
@ -660,7 +569,6 @@ impl StalwartClient {
self.request(Method::GET, &path, None).await
}
pub async fn retry_delivery(&self, message_id: &str) -> Result<bool> {
self.request(
Method::PATCH,
@ -670,7 +578,6 @@ impl StalwartClient {
.await
}
pub async fn cancel_delivery(&self, message_id: &str) -> Result<bool> {
self.request(
Method::DELETE,
@ -680,17 +587,16 @@ impl StalwartClient {
.await
}
pub async fn stop_queue(&self) -> Result<bool> {
self.request(Method::PATCH, "/api/queue/status/stop", None).await
self.request(Method::PATCH, "/api/queue/status/stop", None)
.await
}
pub async fn start_queue(&self) -> Result<bool> {
self.request(Method::PATCH, "/api/queue/status/start", None).await
self.request(Method::PATCH, "/api/queue/status/start", None)
.await
}
pub async fn get_failed_delivery_count(&self) -> Result<u64> {
let resp: QueueListResponse = self
.request(
@ -702,11 +608,6 @@ impl StalwartClient {
Ok(resp.total)
}
pub async fn create_account(
&self,
email: &str,
@ -725,19 +626,19 @@ impl StalwartClient {
"roles": ["user"]
});
self.request(Method::POST, "/api/principal", Some(body)).await
self.request(Method::POST, "/api/principal", Some(body))
.await
}
pub async fn create_account_full(&self, principal: &Principal, password: &str) -> Result<u64> {
let mut body = serde_json::to_value(principal)?;
if let Some(obj) = body.as_object_mut() {
obj.insert("secrets".to_string(), json!([password]));
}
self.request(Method::POST, "/api/principal", Some(body)).await
self.request(Method::POST, "/api/principal", Some(body))
.await
}
pub async fn create_distribution_list(
&self,
name: &str,
@ -752,10 +653,10 @@ impl StalwartClient {
"description": format!("Distribution list: {}", name)
});
self.request(Method::POST, "/api/principal", Some(body)).await
self.request(Method::POST, "/api/principal", Some(body))
.await
}
pub async fn create_shared_mailbox(
&self,
name: &str,
@ -770,20 +671,15 @@ impl StalwartClient {
"description": format!("Shared mailbox: {}", name)
});
self.request(Method::POST, "/api/principal", Some(body)).await
self.request(Method::POST, "/api/principal", Some(body))
.await
}
pub async fn get_account(&self, account_id: &str) -> Result<Principal> {
self.request(
Method::GET,
&format!("/api/principal/{}", account_id),
None,
)
.await
self.request(Method::GET, &format!("/api/principal/{}", account_id), None)
.await
}
pub async fn get_account_by_email(&self, email: &str) -> Result<Principal> {
self.request(
Method::GET,
@ -793,8 +689,11 @@ impl StalwartClient {
.await
}
pub async fn update_account(&self, account_id: &str, updates: Vec<AccountUpdate>) -> Result<()> {
pub async fn update_account(
&self,
account_id: &str,
updates: Vec<AccountUpdate>,
) -> Result<()> {
let body: Vec<Value> = updates
.iter()
.map(|u| {
@ -815,7 +714,6 @@ impl StalwartClient {
Ok(())
}
pub async fn delete_account(&self, account_id: &str) -> Result<()> {
self.request::<Value>(
Method::DELETE,
@ -826,8 +724,10 @@ impl StalwartClient {
Ok(())
}
pub async fn list_principals(&self, principal_type: Option<PrincipalType>) -> Result<Vec<Principal>> {
pub async fn list_principals(
&self,
principal_type: Option<PrincipalType>,
) -> Result<Vec<Principal>> {
let path = match principal_type {
Some(t) => format!("/api/principal?type={:?}", t).to_lowercase(),
None => "/api/principal".to_string(),
@ -835,7 +735,6 @@ impl StalwartClient {
self.request(Method::GET, &path, None).await
}
pub async fn add_members(&self, account_id: &str, members: Vec<String>) -> Result<()> {
let updates: Vec<AccountUpdate> = members
.into_iter()
@ -844,7 +743,6 @@ impl StalwartClient {
self.update_account(account_id, updates).await
}
pub async fn remove_members(&self, account_id: &str, members: Vec<String>) -> Result<()> {
let updates: Vec<AccountUpdate> = members
.into_iter()
@ -853,11 +751,6 @@ impl StalwartClient {
self.update_account(account_id, updates).await
}
pub async fn set_auto_responder(
&self,
account_id: &str,
@ -879,7 +772,6 @@ impl StalwartClient {
Ok(script_id)
}
pub async fn disable_auto_responder(&self, account_id: &str) -> Result<()> {
let script_id = format!("{}_vacation", account_id);
@ -895,10 +787,9 @@ impl StalwartClient {
Ok(())
}
pub fn generate_vacation_sieve(&self, config: &AutoResponderConfig) -> String {
let mut script = String::from("require [\"vacation\", \"variables\", \"date\", \"relational\"];\n\n");
let mut script =
String::from("require [\"vacation\", \"variables\", \"date\", \"relational\"];\n\n");
if config.start_date.is_some() || config.end_date.is_some() {
script.push_str("# Date-based activation\n");
@ -920,7 +811,6 @@ impl StalwartClient {
script.push('\n');
}
let subject = config.subject.replace('"', "\\\"").replace('\n', " ");
let body = config.body_plain.replace('"', "\\\"").replace('\n', "\\n");
@ -932,7 +822,6 @@ impl StalwartClient {
script
}
pub async fn set_filter_rule(&self, account_id: &str, rule: &EmailRule) -> Result<String> {
let sieve_script = self.generate_filter_sieve(rule);
let script_id = format!("{}_filter_{}", account_id, rule.id);
@ -950,7 +839,6 @@ impl StalwartClient {
Ok(script_id)
}
pub async fn delete_filter_rule(&self, account_id: &str, rule_id: &str) -> Result<()> {
let script_id = format!("{}_filter_{}", account_id, rule_id);
@ -966,10 +854,10 @@ impl StalwartClient {
Ok(())
}
pub fn generate_filter_sieve(&self, rule: &EmailRule) -> String {
let mut script =
String::from("require [\"fileinto\", \"reject\", \"vacation\", \"imap4flags\", \"copy\"];\n\n");
let mut script = String::from(
"require [\"fileinto\", \"reject\", \"vacation\", \"imap4flags\", \"copy\"];\n\n",
);
script.push_str(&format!("# Rule: {}\n", rule.name));
@ -978,7 +866,6 @@ impl StalwartClient {
return script;
}
let mut conditions = Vec::new();
for condition in &rule.conditions {
let cond_str = self.generate_condition_sieve(condition);
@ -988,14 +875,11 @@ impl StalwartClient {
}
if conditions.is_empty() {
script.push_str("# Always applies\n");
} else {
script.push_str(&format!("if allof ({}) {{\n", conditions.join(", ")));
}
for action in &rule.actions {
let action_str = self.generate_action_sieve(action);
if !action_str.is_empty() {
@ -1007,7 +891,6 @@ impl StalwartClient {
}
}
if rule.stop_processing {
if conditions.is_empty() {
script.push_str("stop;\n");
@ -1016,7 +899,6 @@ impl StalwartClient {
}
}
if !conditions.is_empty() {
script.push_str("}\n");
}
@ -1024,7 +906,6 @@ impl StalwartClient {
script
}
pub fn generate_condition_sieve(&self, condition: &RuleCondition) -> String {
let field_header = match condition.field.as_str() {
"from" => "From",
@ -1044,17 +925,34 @@ impl StalwartClient {
let value = condition.value.replace('"', "\\\"");
match condition.operator.as_str() {
"contains" => format!("header :contains{} \"{}\" \"{}\"", comparator, field_header, value),
"equals" => format!("header :is{} \"{}\" \"{}\"", comparator, field_header, value),
"startsWith" => format!("header :matches{} \"{}\" \"{}*\"", comparator, field_header, value),
"endsWith" => format!("header :matches{} \"{}\" \"*{}\"", comparator, field_header, value),
"regex" => format!("header :regex{} \"{}\" \"{}\"", comparator, field_header, value),
"notContains" => format!("not header :contains{} \"{}\" \"{}\"", comparator, field_header, value),
"contains" => format!(
"header :contains{} \"{}\" \"{}\"",
comparator, field_header, value
),
"equals" => format!(
"header :is{} \"{}\" \"{}\"",
comparator, field_header, value
),
"startsWith" => format!(
"header :matches{} \"{}\" \"{}*\"",
comparator, field_header, value
),
"endsWith" => format!(
"header :matches{} \"{}\" \"*{}\"",
comparator, field_header, value
),
"regex" => format!(
"header :regex{} \"{}\" \"{}\"",
comparator, field_header, value
),
"notContains" => format!(
"not header :contains{} \"{}\" \"{}\"",
comparator, field_header, value
),
_ => String::new(),
}
}
pub fn generate_action_sieve(&self, action: &RuleAction) -> String {
match action.action_type.as_str() {
"move" => format!("fileinto \"{}\";", action.value.replace('"', "\\\"")),
@ -1068,16 +966,11 @@ impl StalwartClient {
}
}
pub async fn get_metrics(&self) -> Result<Metrics> {
self.request(Method::GET, "/api/telemetry/metrics", None).await
self.request(Method::GET, "/api/telemetry/metrics", None)
.await
}
pub async fn get_logs(&self, page: u32, limit: u32) -> Result<LogList> {
self.request(
Method::GET,
@ -1087,7 +980,6 @@ impl StalwartClient {
.await
}
pub async fn get_logs_by_level(&self, level: &str, page: u32, limit: u32) -> Result<LogList> {
self.request(
Method::GET,
@ -1097,7 +989,6 @@ impl StalwartClient {
.await
}
pub async fn get_traces(&self, trace_type: &str, page: u32) -> Result<TraceList> {
self.request(
Method::GET,
@ -1110,7 +1001,6 @@ impl StalwartClient {
.await
}
pub async fn get_recent_traces(&self, limit: u32) -> Result<TraceList> {
self.request(
Method::GET,
@ -1120,7 +1010,6 @@ impl StalwartClient {
.await
}
pub async fn get_trace(&self, trace_id: &str) -> Result<Vec<TraceEvent>> {
self.request(
Method::GET,
@ -1130,7 +1019,6 @@ impl StalwartClient {
.await
}
pub async fn get_dmarc_reports(&self, page: u32) -> Result<ReportList> {
self.request(
Method::GET,
@ -1140,7 +1028,6 @@ impl StalwartClient {
.await
}
pub async fn get_tls_reports(&self, page: u32) -> Result<ReportList> {
self.request(
Method::GET,
@ -1150,7 +1037,6 @@ impl StalwartClient {
.await
}
pub async fn get_arf_reports(&self, page: u32) -> Result<ReportList> {
self.request(
Method::GET,
@ -1160,23 +1046,16 @@ impl StalwartClient {
.await
}
pub async fn get_live_metrics_token(&self) -> Result<String> {
self.request(Method::GET, "/api/telemetry/live/metrics-token", None)
.await
}
pub async fn get_live_tracing_token(&self) -> Result<String> {
self.request(Method::GET, "/api/telemetry/live/tracing-token", None)
.await
}
pub async fn train_spam(&self, raw_message: &str) -> Result<()> {
self.request_raw(
Method::POST,
@ -1189,7 +1068,6 @@ impl StalwartClient {
Ok(())
}
pub async fn train_ham(&self, raw_message: &str) -> Result<()> {
self.request_raw(
Method::POST,
@ -1202,8 +1080,10 @@ impl StalwartClient {
Ok(())
}
pub async fn classify_message(&self, message: &SpamClassifyRequest) -> Result<SpamClassifyResult> {
pub async fn classify_message(
&self,
message: &SpamClassifyRequest,
) -> Result<SpamClassifyResult> {
self.request(
Method::POST,
"/api/spam-filter/classify",
@ -1212,21 +1092,18 @@ impl StalwartClient {
.await
}
pub async fn troubleshoot_delivery(&self, recipient: &str) -> Result<Value> {
self.request(
Method::GET,
&format!("/api/troubleshoot/delivery/{}", urlencoding::encode(recipient)),
&format!(
"/api/troubleshoot/delivery/{}",
urlencoding::encode(recipient)
),
None,
)
.await
}
pub async fn check_dmarc(&self, domain: &str, from_email: &str) -> Result<Value> {
let body = json!({
"domain": domain,
@ -1236,17 +1113,11 @@ impl StalwartClient {
.await
}
pub async fn get_dns_records(&self, domain: &str) -> Result<Value> {
self.request(
Method::GET,
&format!("/api/dns/records/{}", domain),
None,
)
.await
self.request(Method::GET, &format!("/api/dns/records/{}", domain), None)
.await
}
pub async fn undelete_messages(&self, account_id: &str) -> Result<Value> {
self.request(
Method::POST,
@ -1256,7 +1127,6 @@ impl StalwartClient {
.await
}
pub async fn purge_account(&self, account_id: &str) -> Result<()> {
self.request::<Value>(
Method::GET,
@ -1268,9 +1138,11 @@ impl StalwartClient {
Ok(())
}
pub async fn health_check(&self) -> Result<bool> {
match self.request::<Value>(Method::GET, "/api/queue/status", None).await {
match self
.request::<Value>(Method::GET, "/api/queue/status", None)
.await
{
Ok(_) => Ok(true),
Err(e) => {
warn!("Stalwart health check failed: {}", e);
@ -1279,5 +1151,3 @@ impl StalwartClient {
}
}
}

View file

@ -4,7 +4,6 @@ use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::Arc;
use tokio::fs;
use uuid::Uuid;
#[cfg(feature = "vectordb")]
@ -13,7 +12,6 @@ use qdrant_client::{
Qdrant,
};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmailDocument {
pub id: String,
@ -29,7 +27,6 @@ pub struct EmailDocument {
pub thread_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmailSearchQuery {
pub query_text: String,
@ -40,7 +37,6 @@ pub struct EmailSearchQuery {
pub limit: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmailSearchResult {
pub email: EmailDocument,
@ -48,7 +44,6 @@ pub struct EmailSearchResult {
pub snippet: String,
}
pub struct UserEmailVectorDB {
user_id: Uuid,
bot_id: Uuid,
@ -59,7 +54,6 @@ pub struct UserEmailVectorDB {
}
impl UserEmailVectorDB {
pub fn new(user_id: Uuid, bot_id: Uuid, db_path: PathBuf) -> Self {
let collection_name = format!("emails_{}_{}", bot_id, user_id);
@ -73,12 +67,10 @@ impl UserEmailVectorDB {
}
}
#[cfg(feature = "vectordb")]
pub async fn initialize(&mut self, qdrant_url: &str) -> Result<()> {
let client = Qdrant::from_url(qdrant_url).build()?;
let collections = client.list_collections().await?;
let exists = collections
.collections
@ -86,7 +78,6 @@ impl UserEmailVectorDB {
.any(|c| c.name == self.collection_name);
if !exists {
client
.create_collection(
qdrant_client::qdrant::CreateCollectionBuilder::new(&self.collection_name)
@ -111,7 +102,6 @@ impl UserEmailVectorDB {
Ok(())
}
#[cfg(feature = "vectordb")]
pub async fn index_email(&self, email: &EmailDocument, embedding: Vec<f32>) -> Result<()> {
let client = self
@ -131,9 +121,10 @@ impl UserEmailVectorDB {
let point = PointStruct::new(email.id.clone(), embedding, payload);
client
.upsert_points(
qdrant_client::qdrant::UpsertPointsBuilder::new(&self.collection_name, vec![point]),
)
.upsert_points(qdrant_client::qdrant::UpsertPointsBuilder::new(
&self.collection_name,
vec![point],
))
.await?;
log::debug!("Indexed email: {} - {}", email.id, email.subject);
@ -142,14 +133,12 @@ impl UserEmailVectorDB {
#[cfg(not(feature = "vectordb"))]
pub async fn index_email(&self, email: &EmailDocument, _embedding: Vec<f32>) -> Result<()> {
let file_path = self.db_path.join(format!("{}.json", email.id));
let json = serde_json::to_string_pretty(email)?;
fs::write(file_path, json).await?;
Ok(())
}
pub async fn index_emails_batch(&self, emails: &[(EmailDocument, Vec<f32>)]) -> Result<()> {
for (email, embedding) in emails {
self.index_email(email, embedding.clone()).await?;
@ -157,7 +146,6 @@ impl UserEmailVectorDB {
Ok(())
}
#[cfg(feature = "vectordb")]
pub async fn search(
&self,
@ -169,7 +157,6 @@ impl UserEmailVectorDB {
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Vector DB not initialized"))?;
let filter = if query.account_id.is_some() || query.folder.is_some() {
let mut conditions = vec![];
@ -210,7 +197,8 @@ impl UserEmailVectorDB {
let payload = &point.payload;
if !payload.is_empty() {
let get_str = |key: &str| -> String {
payload.get(key)
payload
.get(key)
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.unwrap_or_default()
@ -227,10 +215,12 @@ impl UserEmailVectorDB {
date: chrono::Utc::now(),
folder: get_str("folder"),
has_attachments: false,
thread_id: payload.get("thread_id").and_then(|v| v.as_str()).map(|s| s.to_string()),
thread_id: payload
.get("thread_id")
.and_then(|v| v.as_str())
.map(|s| s.to_string()),
};
let snippet = if email.body_text.len() > 200 {
format!("{}...", &email.body_text[..200])
} else {
@ -254,7 +244,6 @@ impl UserEmailVectorDB {
query: &EmailSearchQuery,
_query_embedding: Vec<f32>,
) -> Result<Vec<EmailSearchResult>> {
let mut results = Vec::new();
let mut entries = fs::read_dir(&self.db_path).await?;
@ -262,7 +251,6 @@ impl UserEmailVectorDB {
if entry.path().extension().and_then(|s| s.to_str()) == Some("json") {
let content = fs::read_to_string(entry.path()).await?;
if let Ok(email) = serde_json::from_str::<EmailDocument>(&content) {
let query_lower = query.query_text.to_lowercase();
if email.subject.to_lowercase().contains(&query_lower)
|| email.body_text.to_lowercase().contains(&query_lower)
@ -291,7 +279,6 @@ impl UserEmailVectorDB {
Ok(results)
}
#[cfg(feature = "vectordb")]
pub async fn delete_email(&self, email_id: &str) -> Result<()> {
let client = self
@ -301,8 +288,9 @@ impl UserEmailVectorDB {
client
.delete_points(
qdrant_client::qdrant::DeletePointsBuilder::new(&self.collection_name)
.points(vec![qdrant_client::qdrant::PointId::from(email_id.to_string())]),
qdrant_client::qdrant::DeletePointsBuilder::new(&self.collection_name).points(
vec![qdrant_client::qdrant::PointId::from(email_id.to_string())],
),
)
.await?;
@ -319,7 +307,6 @@ impl UserEmailVectorDB {
Ok(())
}
#[cfg(feature = "vectordb")]
pub async fn get_count(&self) -> Result<u64> {
let client = self
@ -346,7 +333,6 @@ impl UserEmailVectorDB {
Ok(count)
}
#[cfg(feature = "vectordb")]
pub async fn clear(&self) -> Result<()> {
let client = self
@ -354,10 +340,7 @@ impl UserEmailVectorDB {
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Vector DB not initialized"))?;
client
.delete_collection(&self.collection_name)
.await?;
client.delete_collection(&self.collection_name).await?;
client
.create_collection(
@ -384,7 +367,6 @@ impl UserEmailVectorDB {
}
}
pub struct EmailEmbeddingGenerator {
pub llm_endpoint: String,
}
@ -394,15 +376,12 @@ impl EmailEmbeddingGenerator {
Self { llm_endpoint }
}
pub async fn generate_embedding(&self, email: &EmailDocument) -> Result<Vec<f32>> {
let text = format!(
"From: {} <{}>\nSubject: {}\n\n{}",
email.from_name, email.from_email, email.subject, email.body_text
);
let text = if text.len() > 8000 {
&text[..8000]
} else {
@ -412,17 +391,11 @@ impl EmailEmbeddingGenerator {
self.generate_text_embedding(text).await
}
pub async fn generate_text_embedding(&self, text: &str) -> Result<Vec<f32>> {
let embedding_url = "http://localhost:8082".to_string();
return self.generate_local_embedding(text, &embedding_url).await;
self.generate_hash_embedding(text)
self.generate_local_embedding(text, &embedding_url).await
}
async fn generate_openai_embedding(&self, text: &str, api_key: &str) -> Result<Vec<f32>> {
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
use serde_json::json;
@ -462,7 +435,6 @@ impl EmailEmbeddingGenerator {
Ok(embedding)
}
async fn generate_local_embedding(&self, text: &str, embedding_url: &str) -> Result<Vec<f32>> {
use serde_json::json;
@ -492,7 +464,6 @@ impl EmailEmbeddingGenerator {
Ok(embedding)
}
fn generate_hash_embedding(&self, text: &str) -> Result<Vec<f32>> {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
@ -500,7 +471,6 @@ impl EmailEmbeddingGenerator {
const EMBEDDING_DIM: usize = 1536;
let mut embedding = vec![0.0f32; EMBEDDING_DIM];
let words: Vec<&str> = text.split_whitespace().collect();
for (i, chunk) in words.chunks(10).enumerate() {
@ -508,7 +478,6 @@ impl EmailEmbeddingGenerator {
chunk.join(" ").hash(&mut hasher);
let hash = hasher.finish();
for j in 0..64 {
let idx = (i * 64 + j) % EMBEDDING_DIM;
let value = ((hash >> j) & 1) as f32;
@ -516,7 +485,6 @@ impl EmailEmbeddingGenerator {
}
}
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for val in &mut embedding {

View file

@ -1,36 +1,3 @@
use chrono::{DateTime, Utc};
use rhai::{Dynamic, Engine, Map};
use serde::{Deserialize, Serialize};
@ -41,10 +8,8 @@ use tokio::sync::RwLock;
use tracing::info;
use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequestMetrics {
pub request_id: Uuid,
pub session_id: Uuid,
@ -78,7 +43,6 @@ pub struct LLMRequestMetrics {
pub metadata: HashMap<String, String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum RequestType {
@ -113,10 +77,8 @@ impl std::fmt::Display for RequestType {
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AggregatedMetrics {
pub period_start: DateTime<Utc>,
pub period_end: DateTime<Utc>,
@ -162,10 +124,8 @@ pub struct AggregatedMetrics {
pub errors_by_type: HashMap<String, u64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPricing {
pub model: String,
pub input_cost_per_1k: f64,
@ -189,10 +149,8 @@ impl Default for ModelPricing {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Budget {
pub daily_limit: f64,
pub monthly_limit: f64,
@ -229,10 +187,8 @@ impl Default for Budget {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceEvent {
pub id: Uuid,
pub parent_id: Option<Uuid>,
@ -258,7 +214,6 @@ pub struct TraceEvent {
pub error: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum TraceEventType {
@ -268,7 +223,6 @@ pub enum TraceEventType {
Metric,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum TraceStatus {
@ -283,10 +237,8 @@ impl Default for TraceStatus {
}
}
#[derive(Debug, Clone)]
pub struct ObservabilityConfig {
pub enabled: bool,
pub metrics_interval: u64,
@ -322,11 +274,9 @@ impl Default for ObservabilityConfig {
}
}
fn default_model_pricing() -> HashMap<String, ModelPricing> {
let mut pricing = HashMap::new();
pricing.insert(
"gpt-4".to_string(),
ModelPricing {
@ -360,7 +310,6 @@ fn default_model_pricing() -> HashMap<String, ModelPricing> {
},
);
pricing.insert(
"claude-3-opus".to_string(),
ModelPricing {
@ -383,7 +332,6 @@ fn default_model_pricing() -> HashMap<String, ModelPricing> {
},
);
pricing.insert(
"mixtral-8x7b-32768".to_string(),
ModelPricing {
@ -395,7 +343,6 @@ fn default_model_pricing() -> HashMap<String, ModelPricing> {
},
);
pricing.insert(
"local".to_string(),
ModelPricing {
@ -410,7 +357,6 @@ fn default_model_pricing() -> HashMap<String, ModelPricing> {
pricing
}
#[derive(Debug)]
pub struct ObservabilityManager {
config: ObservabilityConfig,
@ -431,7 +377,6 @@ pub struct ObservabilityManager {
}
impl ObservabilityManager {
pub fn new(config: ObservabilityConfig) -> Self {
let budget = Budget {
daily_limit: config.budget_daily,
@ -454,7 +399,6 @@ impl ObservabilityManager {
}
}
pub fn from_config(config_map: &HashMap<String, String>) -> Self {
let config = ObservabilityConfig {
enabled: config_map
@ -494,13 +438,11 @@ impl ObservabilityManager {
ObservabilityManager::new(config)
}
pub async fn record_request(&self, metrics: LLMRequestMetrics) {
if !self.config.enabled {
return;
}
self.request_count.fetch_add(1, Ordering::Relaxed);
self.token_count
.fetch_add(metrics.total_tokens, Ordering::Relaxed);
@ -515,24 +457,20 @@ impl ObservabilityManager {
self.error_count.fetch_add(1, Ordering::Relaxed);
}
if self.config.cost_tracking && metrics.estimated_cost > 0.0 {
let mut budget = self.budget.write().await;
budget.daily_spend += metrics.estimated_cost;
budget.monthly_spend += metrics.estimated_cost;
}
let mut buffer = self.metrics_buffer.write().await;
buffer.push(metrics);
if buffer.len() > 10000 {
buffer.drain(0..1000);
}
}
pub fn calculate_cost(&self, model: &str, input_tokens: u64, output_tokens: u64) -> f64 {
if let Some(pricing) = self.config.model_pricing.get(model) {
if pricing.is_local {
@ -542,12 +480,10 @@ impl ObservabilityManager {
let output_cost = (output_tokens as f64 / 1000.0) * pricing.output_cost_per_1k;
input_cost + output_cost + pricing.cost_per_request
} else {
0.0
}
}
pub async fn get_budget_status(&self) -> BudgetStatus {
let budget = self.budget.read().await;
BudgetStatus {
@ -567,7 +503,6 @@ impl ObservabilityManager {
}
}
pub async fn check_budget(&self, estimated_cost: f64) -> BudgetCheckResult {
let budget = self.budget.read().await;
@ -593,7 +528,6 @@ impl ObservabilityManager {
BudgetCheckResult::Ok
}
pub async fn reset_daily_budget(&self) {
let mut budget = self.budget.write().await;
budget.daily_spend = 0.0;
@ -601,7 +535,6 @@ impl ObservabilityManager {
budget.daily_alert_sent = false;
}
pub async fn reset_monthly_budget(&self) {
let mut budget = self.budget.write().await;
budget.monthly_spend = 0.0;
@ -609,13 +542,11 @@ impl ObservabilityManager {
budget.monthly_alert_sent = false;
}
pub async fn record_trace(&self, event: TraceEvent) {
if !self.config.enabled || !self.config.trace_enabled {
return;
}
if self.config.trace_sample_rate < 1.0 {
let sample: f64 = rand::random();
if sample > self.config.trace_sample_rate {
@ -626,13 +557,11 @@ impl ObservabilityManager {
let mut buffer = self.trace_buffer.write().await;
buffer.push(event);
if buffer.len() > 5000 {
buffer.drain(0..500);
}
}
pub fn start_span(
&self,
trace_id: Uuid,
@ -656,7 +585,6 @@ impl ObservabilityManager {
}
}
pub fn end_span(&self, span: &mut TraceEvent, status: TraceStatus, error: Option<String>) {
let end_time = Utc::now();
span.end_time = Some(end_time);
@ -665,7 +593,6 @@ impl ObservabilityManager {
span.error = error;
}
pub async fn get_aggregated_metrics(
&self,
start: DateTime<Utc>,
@ -731,7 +658,6 @@ impl ObservabilityManager {
.or_insert(0) += 1;
}
if !latencies.is_empty() {
latencies.sort();
let len = latencies.len();
@ -748,7 +674,6 @@ impl ObservabilityManager {
metrics
}
pub async fn get_current_metrics(&self) -> AggregatedMetrics {
self.current_metrics.read().await.clone()
}
@ -789,13 +714,11 @@ impl ObservabilityManager {
}
}
pub async fn get_recent_traces(&self, limit: usize) -> Vec<TraceEvent> {
let buffer = self.trace_buffer.read().await;
buffer.iter().rev().take(limit).cloned().collect()
}
pub async fn get_trace(&self, trace_id: Uuid) -> Vec<TraceEvent> {
let buffer = self.trace_buffer.read().await;
buffer
@ -806,7 +729,6 @@ impl ObservabilityManager {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BudgetStatus {
pub daily_limit: f64,
@ -823,7 +745,6 @@ pub struct BudgetStatus {
pub near_monthly_limit: bool,
}
#[derive(Debug, Clone, PartialEq)]
pub enum BudgetCheckResult {
Ok,
@ -833,7 +754,6 @@ pub enum BudgetCheckResult {
MonthlyExceeded,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuickStats {
pub total_requests: u64,
@ -845,7 +765,6 @@ pub struct QuickStats {
pub error_rate: f64,
}
impl LLMRequestMetrics {
pub fn to_dynamic(&self) -> Dynamic {
let mut map = Map::new();
@ -935,10 +854,7 @@ impl BudgetStatus {
}
}
pub fn register_observability_keywords(engine: &mut Engine) {
engine.register_fn("metrics_total_requests", |metrics: Map| -> i64 {
metrics
.get("total_requests")
@ -1006,8 +922,7 @@ pub fn register_observability_keywords(engine: &mut Engine) {
info!("Observability keywords registered");
}
pub const OBSERVABILITY_SCHEMA: &str = r#"
pub const OBSERVABILITY_SCHEMA: &str = r"
-- LLM request metrics
CREATE TABLE IF NOT EXISTS llm_metrics (
id UUID PRIMARY KEY,
@ -1102,4 +1017,4 @@ CREATE INDEX IF NOT EXISTS idx_llm_metrics_hourly_hour ON llm_metrics_hourly(hou
CREATE INDEX IF NOT EXISTS idx_llm_traces_trace_id ON llm_traces(trace_id);
CREATE INDEX IF NOT EXISTS idx_llm_traces_start_time ON llm_traces(start_time DESC);
CREATE INDEX IF NOT EXISTS idx_llm_traces_component ON llm_traces(component);
"#;
";

View file

@ -50,7 +50,7 @@ pub fn configure() -> Router<Arc<AppState>> {
}
async fn handle_incoming(
State(state): State<Arc<AppState>>,
State(_state): State<Arc<AppState>>,
Json(activity): Json<TeamsActivity>,
) -> impl IntoResponse {
match activity.activity_type.as_str() {

View file

@ -1,30 +1,3 @@
use anyhow::{anyhow, Context, Result};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use reqwest::{Certificate, Client, ClientBuilder};
@ -38,28 +11,20 @@ use std::time::Duration;
use tracing::{debug, error, info, warn};
use x509_parser::prelude::*;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CertPinningConfig {
pub enabled: bool,
pub pins: HashMap<String, Vec<PinnedCert>>,
pub require_pins: bool,
pub allow_backup_pins: bool,
pub report_only: bool,
pub config_path: Option<PathBuf>,
pub cache_ttl_secs: u64,
}
@ -78,12 +43,10 @@ impl Default for CertPinningConfig {
}
impl CertPinningConfig {
pub fn new() -> Self {
Self::default()
}
pub fn strict() -> Self {
Self {
enabled: true,
@ -96,7 +59,6 @@ impl CertPinningConfig {
}
}
pub fn report_only() -> Self {
Self {
enabled: true,
@ -109,28 +71,23 @@ impl CertPinningConfig {
}
}
pub fn add_pin(&mut self, pin: PinnedCert) {
let hostname = pin.hostname.clone();
self.pins.entry(hostname).or_default().push(pin);
}
pub fn add_pins(&mut self, hostname: &str, pins: Vec<PinnedCert>) {
self.pins.insert(hostname.to_string(), pins);
}
pub fn remove_pins(&mut self, hostname: &str) {
self.pins.remove(hostname);
}
pub fn get_pins(&self, hostname: &str) -> Option<&Vec<PinnedCert>> {
self.pins.get(hostname)
}
pub fn load_from_file(path: &Path) -> Result<Self> {
let content = fs::read_to_string(path)
.with_context(|| format!("Failed to read pin config from {:?}", path))?;
@ -142,7 +99,6 @@ impl CertPinningConfig {
Ok(config)
}
pub fn save_to_file(&self, path: &Path) -> Result<()> {
let content =
serde_json::to_string_pretty(self).context("Failed to serialize pin configuration")?;
@ -155,31 +111,22 @@ impl CertPinningConfig {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PinnedCert {
pub hostname: String,
pub fingerprint: String,
pub description: Option<String>,
pub is_backup: bool,
pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
pub pin_type: PinType,
}
impl PinnedCert {
pub fn new(hostname: &str, fingerprint: &str) -> Self {
Self {
hostname: hostname.to_string(),
@ -191,7 +138,6 @@ impl PinnedCert {
}
}
pub fn backup(hostname: &str, fingerprint: &str) -> Self {
Self {
hostname: hostname.to_string(),
@ -203,25 +149,21 @@ impl PinnedCert {
}
}
pub fn with_type(mut self, pin_type: PinType) -> Self {
self.pin_type = pin_type;
self
}
pub fn with_description(mut self, desc: &str) -> Self {
self.description = Some(desc.to_string());
self
}
pub fn with_expiration(mut self, expires: chrono::DateTime<chrono::Utc>) -> Self {
self.expires_at = Some(expires);
self
}
pub fn get_hash_bytes(&self) -> Result<Vec<u8>> {
let hash_str = self
.fingerprint
@ -233,7 +175,6 @@ impl PinnedCert {
.context("Failed to decode base64 fingerprint")
}
pub fn verify(&self, cert_der: &[u8]) -> Result<bool> {
let expected_hash = self.get_hash_bytes()?;
let actual_hash = compute_spki_fingerprint(cert_der)?;
@ -242,10 +183,8 @@ impl PinnedCert {
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PinType {
Leaf,
Intermediate,
@ -259,30 +198,22 @@ impl Default for PinType {
}
}
#[derive(Debug, Clone)]
pub struct PinValidationResult {
pub valid: bool,
pub hostname: String,
pub matched_pin: Option<String>,
pub actual_fingerprint: String,
pub error: Option<String>,
pub backup_match: bool,
}
impl PinValidationResult {
pub fn success(hostname: &str, fingerprint: &str, backup: bool) -> Self {
Self {
valid: true,
@ -294,7 +225,6 @@ impl PinValidationResult {
}
}
pub fn failure(hostname: &str, actual: &str, error: &str) -> Self {
Self {
valid: false,
@ -307,7 +237,6 @@ impl PinValidationResult {
}
}
#[derive(Debug)]
pub struct CertPinningManager {
config: Arc<RwLock<CertPinningConfig>>,
@ -315,7 +244,6 @@ pub struct CertPinningManager {
}
impl CertPinningManager {
pub fn new(config: CertPinningConfig) -> Self {
Self {
config: Arc::new(RwLock::new(config)),
@ -323,17 +251,14 @@ impl CertPinningManager {
}
}
pub fn default_manager() -> Self {
Self::new(CertPinningConfig::default())
}
pub fn is_enabled(&self) -> bool {
self.config.read().unwrap().enabled
}
pub fn add_pin(&self, pin: PinnedCert) -> Result<()> {
let mut config = self
.config
@ -343,7 +268,6 @@ impl CertPinningManager {
Ok(())
}
pub fn remove_pins(&self, hostname: &str) -> Result<()> {
let mut config = self
.config
@ -351,7 +275,6 @@ impl CertPinningManager {
.map_err(|_| anyhow!("Failed to acquire write lock"))?;
config.remove_pins(hostname);
let mut cache = self
.validation_cache
.write()
@ -361,7 +284,6 @@ impl CertPinningManager {
Ok(())
}
pub fn validate_certificate(
&self,
hostname: &str,
@ -376,7 +298,6 @@ impl CertPinningManager {
return Ok(PinValidationResult::success(hostname, "disabled", false));
}
if let Ok(cache) = self.validation_cache.read() {
if let Some((result, timestamp)) = cache.get(hostname) {
if timestamp.elapsed().as_secs() < config.cache_ttl_secs {
@ -385,11 +306,9 @@ impl CertPinningManager {
}
}
let actual_hash = compute_spki_fingerprint(cert_der)?;
let actual_fingerprint = format!("sha256//{}", BASE64.encode(&actual_hash));
let pins = match config.get_pins(hostname) {
Some(pins) => pins,
None => {
@ -411,7 +330,6 @@ impl CertPinningManager {
return Ok(result);
}
return Ok(PinValidationResult::success(
hostname,
"no-pins-required",
@ -420,7 +338,6 @@ impl CertPinningManager {
}
};
for pin in pins {
match pin.verify(cert_der) {
Ok(true) => {
@ -435,7 +352,6 @@ impl CertPinningManager {
);
}
if let Ok(mut cache) = self.validation_cache.write() {
cache.insert(
hostname.to_string(),
@ -445,15 +361,13 @@ impl CertPinningManager {
return Ok(result);
}
Ok(false) => continue,
Ok(false) => {}
Err(e) => {
debug!("Pin verification error for {}: {}", hostname, e);
continue;
}
}
}
let result = PinValidationResult::failure(
hostname,
&actual_fingerprint,
@ -479,12 +393,10 @@ impl CertPinningManager {
Ok(result)
}
pub fn create_pinned_client(&self, hostname: &str) -> Result<Client> {
self.create_pinned_client_with_options(hostname, None, Duration::from_secs(30))
}
pub fn create_pinned_client_with_options(
&self,
hostname: &str,
@ -503,14 +415,10 @@ impl CertPinningManager {
.https_only(true)
.tls_built_in_root_certs(true);
if let Some(cert) = ca_cert {
builder = builder.add_root_certificate(cert.clone());
}
if config.enabled && config.get_pins(hostname).is_some() {
debug!(
"Creating pinned client for {} with {} pins",
@ -522,7 +430,6 @@ impl CertPinningManager {
builder.build().context("Failed to build HTTP client")
}
pub fn validate_pem_file(
&self,
hostname: &str,
@ -535,12 +442,10 @@ impl CertPinningManager {
self.validate_certificate(hostname, &der)
}
pub fn generate_pin_from_file(hostname: &str, cert_path: &Path) -> Result<PinnedCert> {
let cert_data = fs::read(cert_path)
.with_context(|| format!("Failed to read certificate: {:?}", cert_path))?;
let der = if cert_data.starts_with(b"-----BEGIN") {
pem_to_der(&cert_data)?
} else {
@ -553,7 +458,6 @@ impl CertPinningManager {
Ok(PinnedCert::new(hostname, &fingerprint_str))
}
pub fn generate_pins_from_directory(
hostname: &str,
cert_dir: &Path,
@ -583,7 +487,6 @@ impl CertPinningManager {
Ok(pins)
}
pub fn export_pins(&self, path: &Path) -> Result<()> {
let config = self
.config
@ -593,7 +496,6 @@ impl CertPinningManager {
config.save_to_file(path)
}
pub fn import_pins(&self, path: &Path) -> Result<()> {
let imported = CertPinningConfig::load_from_file(path)?;
@ -606,7 +508,6 @@ impl CertPinningManager {
config.pins.insert(hostname, pins);
}
if let Ok(mut cache) = self.validation_cache.write() {
cache.clear();
}
@ -614,7 +515,6 @@ impl CertPinningManager {
Ok(())
}
pub fn get_stats(&self) -> Result<PinningStats> {
let config = self
.config
@ -653,7 +553,6 @@ impl CertPinningManager {
}
}
#[derive(Debug, Clone, Serialize)]
pub struct PinningStats {
pub enabled: bool,
@ -664,31 +563,25 @@ pub struct PinningStats {
pub report_only: bool,
}
pub fn compute_spki_fingerprint(cert_der: &[u8]) -> Result<Vec<u8>> {
let (_, cert) = X509Certificate::from_der(cert_der)
.map_err(|e| anyhow!("Failed to parse X.509 certificate: {}", e))?;
let spki = cert.public_key().raw;
let hash = digest(&SHA256, spki);
Ok(hash.as_ref().to_vec())
}
pub fn compute_cert_fingerprint(cert_der: &[u8]) -> Vec<u8> {
let hash = digest(&SHA256, cert_der);
hash.as_ref().to_vec()
}
pub fn pem_to_der(pem_data: &[u8]) -> Result<Vec<u8>> {
let pem_str = std::str::from_utf8(pem_data).context("Invalid UTF-8 in PEM data")?;
let start_marker = "-----BEGIN CERTIFICATE-----";
let end_marker = "-----END CERTIFICATE-----";
@ -708,7 +601,6 @@ pub fn pem_to_der(pem_data: &[u8]) -> Result<Vec<u8>> {
.context("Failed to decode base64 certificate data")
}
pub fn format_fingerprint(hash: &[u8]) -> String {
hash.iter()
.map(|b| format!("{:02X}", b))
@ -716,16 +608,13 @@ pub fn format_fingerprint(hash: &[u8]) -> String {
.join(":")
}
pub fn parse_fingerprint(formatted: &str) -> Result<Vec<u8>> {
if let Some(base64_part) = formatted.strip_prefix("sha256//") {
return BASE64
.decode(base64_part)
.context("Failed to decode base64 fingerprint");
}
if formatted.contains(':') {
let bytes: Result<Vec<u8>, _> = formatted
.split(':')
@ -735,7 +624,6 @@ pub fn parse_fingerprint(formatted: &str) -> Result<Vec<u8>> {
return bytes.context("Failed to parse hex fingerprint");
}
let bytes: Result<Vec<u8>, _> = (0..formatted.len())
.step_by(2)
.map(|i| u8::from_str_radix(&formatted[i..i + 2], 16))

View file

@ -83,7 +83,7 @@ type TaskHandler = Arc<
impl TaskScheduler {
pub fn new(state: Arc<AppState>) -> Self {
let scheduler = Self {
state: state,
state,
running_tasks: Arc::new(RwLock::new(HashMap::new())),
task_registry: Arc::new(RwLock::new(HashMap::new())),
scheduled_tasks: Arc::new(RwLock::new(Vec::new())),
@ -101,14 +101,10 @@ impl TaskScheduler {
tokio::spawn(async move {
let mut handlers = registry.write().await;
handlers.insert(
"database_cleanup".to_string(),
Arc::new(move |_state: Arc<AppState>, _payload: serde_json::Value| {
Box::pin(async move {
info!("Database cleanup task executed");
Ok(serde_json::json!({
@ -120,7 +116,6 @@ impl TaskScheduler {
}),
);
handlers.insert(
"cache_cleanup".to_string(),
Arc::new(move |state: Arc<AppState>, _payload: serde_json::Value| {
@ -139,7 +134,6 @@ impl TaskScheduler {
}),
);
handlers.insert(
"backup".to_string(),
Arc::new(move |state: Arc<AppState>, payload: serde_json::Value| {
@ -157,7 +151,6 @@ impl TaskScheduler {
.arg(&backup_file)
.output()?;
if state.s3_client.is_some() {
let s3 = state.s3_client.as_ref().unwrap();
let body = tokio::fs::read(&backup_file).await?;
@ -196,7 +189,6 @@ impl TaskScheduler {
}),
);
handlers.insert(
"generate_report".to_string(),
Arc::new(move |_state: Arc<AppState>, payload: serde_json::Value| {
@ -229,7 +221,6 @@ impl TaskScheduler {
}),
);
handlers.insert(
"health_check".to_string(),
Arc::new(move |state: Arc<AppState>, _payload: serde_json::Value| {
@ -240,17 +231,14 @@ impl TaskScheduler {
"timestamp": Utc::now()
});
let db_ok = state.conn.get().is_ok();
health["database"] = serde_json::json!(db_ok);
if let Some(cache) = &state.cache {
let cache_ok = cache.get_connection().is_ok();
health["cache"] = serde_json::json!(cache_ok);
}
if let Some(s3) = &state.s3_client {
let s3_ok = s3.list_buckets().send().await.is_ok();
health["storage"] = serde_json::json!(s3_ok);
@ -351,7 +339,6 @@ impl TaskScheduler {
let execution_id = Uuid::new_v4();
let started_at = Utc::now();
let _execution = TaskExecution {
id: execution_id,
scheduled_task_id: task_id,
@ -363,11 +350,6 @@ impl TaskScheduler {
duration_ms: None,
};
let result = {
let handlers = registry.read().await;
if let Some(handler) = handlers.get(&task.task_type) {
@ -388,25 +370,20 @@ impl TaskScheduler {
let completed_at = Utc::now();
let _duration_ms = (completed_at - started_at).num_milliseconds();
match result {
Ok(_result) => {
let schedule = Schedule::from_str(&task.cron_expression).ok();
let _next_run = schedule
.and_then(|s| s.upcoming(chrono::Local).take(1).next())
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|| Utc::now() + Duration::hours(1));
info!("Task {} completed successfully", task.name);
}
Err(e) => {
let error_msg = format!("Task failed: {}", e);
error!("{}", error_msg);
task.retry_count += 1;
if task.retry_count < task.max_retries {
let _retry_delay =
@ -424,12 +401,10 @@ impl TaskScheduler {
}
}
let mut running = running_tasks.write().await;
running.remove(&task_id);
});
let mut running = self.running_tasks.write().await;
running.insert(task_id, handle);
}
@ -445,7 +420,6 @@ impl TaskScheduler {
info!("Stopped task: {}", task_id);
}
let mut tasks = self.scheduled_tasks.write().await;
if let Some(task) = tasks.iter_mut().find(|t| t.id == task_id) {
task.enabled = false;