diff --git a/Cargo.lock b/Cargo.lock index 00b49126a..4b3642a4c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -434,15 +434,6 @@ dependencies = [ "system-deps", ] -[[package]] -name = "atoi" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" -dependencies = [ - "num-traits", -] - [[package]] name = "atomic-waker" version = "1.1.2" @@ -1219,7 +1210,6 @@ dependencies = [ "serde_json", "sha2", "smartstring", - "sqlx", "sysinfo", "tauri", "tauri-build", @@ -1867,15 +1857,6 @@ dependencies = [ "crossbeam-utils", ] -[[package]] -name = "crossbeam-queue" -version = "0.3.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" -dependencies = [ - "crossbeam-utils", -] - [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -2616,9 +2597,6 @@ name = "either" version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" -dependencies = [ - "serde", -] [[package]] name = "elliptic-curve" @@ -2789,17 +2767,6 @@ dependencies = [ "windows-sys 0.61.2", ] -[[package]] -name = "etcetera" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "136d1b5283a1ab77bd9257427ffd09d8667ced0570b6f938942bc7568ed5b943" -dependencies = [ - "cfg-if", - "home", - "windows-sys 0.48.0", -] - [[package]] name = "euclid" version = "0.20.14" @@ -2919,17 +2886,6 @@ dependencies = [ "miniz_oxide", ] -[[package]] -name = "flume" -version = "0.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095" -dependencies = [ - "futures-core", - "futures-sink", - "spin", -] - [[package]] name = "fnv" version = "1.0.7" @@ -3067,17 +3023,6 @@ dependencies = [ "futures-util", ] -[[package]] -name = "futures-intrusive" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" -dependencies = [ - "futures-core", - "lock_api", - "parking_lot", -] - [[package]] name = "futures-io" version = "0.3.31" @@ -3552,15 +3497,6 @@ version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5419bdc4f6a9207fbeba6d11b604d481addf78ecd10c11ad51e76c2f6482748d" -[[package]] -name = "hashlink" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1" -dependencies = [ - "hashbrown 0.15.5", -] - [[package]] name = "heck" version = "0.4.1" @@ -3603,15 +3539,6 @@ dependencies = [ "digest", ] -[[package]] -name = "home" -version = "0.5.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc627f471c528ff0c4a49e1d5e60450c8f6461dd6d10ba9dcd3a61d3dff7728d" -dependencies = [ - "windows-sys 0.61.2", -] - [[package]] name = "hostname" version = "0.4.1" @@ -4438,17 +4365,6 @@ checksum = "416f7e718bdb06000964960ffa43b4335ad4012ae8b99060261aa4a8088d5ccb" dependencies = [ "bitflags 2.10.0", "libc", - "redox_syscall", -] - -[[package]] -name = "libsqlite3-sys" -version = "0.30.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" -dependencies = [ - "pkg-config", - "vcpkg", ] [[package]] @@ -7679,9 +7595,6 @@ name = "smallvec" version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" -dependencies = [ - "serde", -] [[package]] name = "smartstring" @@ -7767,9 +7680,6 @@ name = "spin" version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" -dependencies = [ - "lock_api", -] [[package]] name = "spki" @@ -7791,204 +7701,6 @@ dependencies = [ "der 0.7.10", ] -[[package]] -name = "sqlx" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fefb893899429669dcdd979aff487bd78f4064e5e7907e4269081e0ef7d97dc" -dependencies = [ - "sqlx-core", - "sqlx-macros", - "sqlx-mysql", - "sqlx-postgres", - "sqlx-sqlite", -] - -[[package]] -name = "sqlx-core" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee6798b1838b6a0f69c007c133b8df5866302197e404e8b6ee8ed3e3a5e68dc6" -dependencies = [ - "base64 0.22.1", - "bytes", - "chrono", - "crc", - "crossbeam-queue", - "either", - "event-listener 5.4.1", - "futures-core", - "futures-intrusive", - "futures-io", - "futures-util", - "hashbrown 0.15.5", - "hashlink", - "indexmap 2.12.0", - "log", - "memchr", - "once_cell", - "percent-encoding", - "rustls 0.23.35", - "serde", - "serde_json", - "sha2", - "smallvec", - "thiserror 2.0.17", - "tokio", - "tokio-stream", - "tracing", - "url", - "uuid", - "webpki-roots 0.26.11", -] - -[[package]] -name = "sqlx-macros" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2d452988ccaacfbf5e0bdbc348fb91d7c8af5bee192173ac3636b5fb6e6715d" -dependencies = [ - "proc-macro2", - "quote", - "sqlx-core", - "sqlx-macros-core", - "syn 2.0.110", -] - -[[package]] -name = "sqlx-macros-core" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19a9c1841124ac5a61741f96e1d9e2ec77424bf323962dd894bdb93f37d5219b" -dependencies = [ - "dotenvy", - "either", - "heck 0.5.0", - "hex", - "once_cell", - "proc-macro2", - "quote", - "serde", - "serde_json", - "sha2", - "sqlx-core", - "sqlx-mysql", - "sqlx-postgres", - "sqlx-sqlite", - "syn 2.0.110", - "tokio", - "url", -] - -[[package]] -name = "sqlx-mysql" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa003f0038df784eb8fecbbac13affe3da23b45194bd57dba231c8f48199c526" -dependencies = [ - "atoi", - "base64 0.22.1", - "bitflags 2.10.0", - "byteorder", - "bytes", - "chrono", - "crc", - "digest", - "dotenvy", - "either", - "futures-channel", - "futures-core", - "futures-io", - "futures-util", - "generic-array", - "hex", - "hkdf", - "hmac", - "itoa", - "log", - "md-5", - "memchr", - "once_cell", - "percent-encoding", - "rand 0.8.5", - "rsa", - "serde", - "sha1", - "sha2", - "smallvec", - "sqlx-core", - "stringprep", - "thiserror 2.0.17", - "tracing", - "uuid", - "whoami", -] - -[[package]] -name = "sqlx-postgres" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db58fcd5a53cf07c184b154801ff91347e4c30d17a3562a635ff028ad5deda46" -dependencies = [ - "atoi", - "base64 0.22.1", - "bitflags 2.10.0", - "byteorder", - "chrono", - "crc", - "dotenvy", - "etcetera", - "futures-channel", - "futures-core", - "futures-util", - "hex", - "hkdf", - "hmac", - "home", - "itoa", - "log", - "md-5", - "memchr", - "once_cell", - "rand 0.8.5", - "serde", - "serde_json", - "sha2", - "smallvec", - "sqlx-core", - "stringprep", - "thiserror 2.0.17", - "tracing", - "uuid", - "whoami", -] - -[[package]] -name = "sqlx-sqlite" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2d12fe70b2c1b4401038055f90f151b78208de1f9f89a7dbfd41587a10c3eea" -dependencies = [ - "atoi", - "chrono", - "flume", - "futures-channel", - "futures-core", - "futures-executor", - "futures-intrusive", - "futures-util", - "libsqlite3-sys", - "log", - "percent-encoding", - "serde", - "serde_urlencoded", - "sqlx-core", - "thiserror 2.0.17", - "tracing", - "url", - "uuid", -] - [[package]] name = "stable_deref_trait" version = "1.2.1" @@ -9543,12 +9255,6 @@ dependencies = [ "wit-bindgen", ] -[[package]] -name = "wasite" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" - [[package]] name = "wasm-bindgen" version = "0.2.105" @@ -9750,15 +9456,6 @@ version = "0.25.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" -[[package]] -name = "webpki-roots" -version = "0.26.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" -dependencies = [ - "webpki-roots 1.0.4", -] - [[package]] name = "webpki-roots" version = "1.0.4" @@ -9840,16 +9537,6 @@ version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a28ac98ddc8b9274cb41bb4d9d4d5c425b6020c50c46f25559911905610b4a88" -[[package]] -name = "whoami" -version = "1.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d4a4db5077702ca3015d3d02d74974948aba2ad9e12ab7df718ee64ccd7e97d" -dependencies = [ - "libredox", - "wasite", -] - [[package]] name = "winapi" version = "0.3.9" diff --git a/Cargo.toml b/Cargo.toml index 4e73c01cd..b5443d6c6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,7 +62,7 @@ msteams = [] chat = [] drive = ["dep:aws-config", "dep:aws-sdk-s3", "dep:pdf-extract", "dep:zip", "dep:downloader", "dep:mime_guess"] tasks = ["dep:cron"] -calendar = ["dep:sqlx"] +calendar = [] meet = ["dep:livekit"] mail = ["email"] @@ -138,9 +138,6 @@ zitadel = { version = "5.5.1", features = ["api", "credentials"] } # === FEATURE-SPECIFIC DEPENDENCIES (Optional) === -# Database (for calendar and other features) -sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "postgres", "chrono", "uuid"], optional = true } - # Desktop UI (desktop feature) tauri = { version = "2", features = ["unstable"], optional = true } tauri-plugin-dialog = { version = "2", optional = true } diff --git a/docs/src/chapter-11-features/drive-monitor.md b/docs/src/chapter-11-features/drive-monitor.md new file mode 100644 index 000000000..86fadbc19 --- /dev/null +++ b/docs/src/chapter-11-features/drive-monitor.md @@ -0,0 +1,160 @@ +# Drive Monitor + +The Drive Monitor is a real-time file synchronization system that watches for changes in bot storage buckets and automatically updates the database and runtime configuration. + +## Overview + +DriveMonitor provides hot-reloading capabilities for bot configurations by continuously monitoring file changes in object storage. When files are modified, added, or removed, the system automatically: + +- Detects changes through ETags and file comparison +- Updates the database with new configurations +- Recompiles scripts and tools +- Refreshes knowledge bases +- Broadcasts theme changes to connected clients + +## Architecture + +``` +┌─────────────────┐ +│ Object Storage │ (S3-compatible) +│ Buckets │ +└────────┬────────┘ + │ Poll every 30s + ▼ +┌─────────────────┐ +│ Drive Monitor │ +│ - Check ETags │ +│ - Diff files │ +└────────┬────────┘ + │ Changes detected + ▼ +┌─────────────────────────┐ +│ Process Updates │ +│ - Compile scripts (.bas)│ +│ - Update KB (.gbkb) │ +│ - Refresh themes │ +│ - Update database │ +└─────────────────────────┘ +``` + +## Implementation + +### Core Components + +The DriveMonitor is implemented in `src/drive/drive_monitor/mod.rs` with the following structure: + +```rust +pub struct DriveMonitor { + state: Arc, + bucket_name: String, + file_states: Arc>>, + bot_id: Uuid, + kb_manager: Arc, + work_root: PathBuf, + is_processing: Arc, +} +``` + +### Monitoring Process + +1. **Initialization**: When a bot is mounted, a DriveMonitor instance is created and spawned +2. **Polling**: Every 30 seconds, the monitor checks for changes in: + - `.gbdialog` files (scripts and tools) + - `.gbkb` collections (knowledge base documents) + - `.gbtheme` files (UI themes) + - `.gbot/config.csv` (bot configuration) + +3. **Change Detection**: Uses ETags to detect file modifications efficiently +4. **Processing**: Different file types trigger specific handlers: + - Scripts → Compile to AST + - Knowledge base → Index and embed documents + - Themes → Broadcast updates to WebSocket clients + - Config → Reload bot settings + +### File Type Handlers + +#### Script Files (.bas) +- Compiles BASIC scripts to AST +- Stores compiled version in database +- Updates tool registry if applicable + +#### Knowledge Base Files (.gbkb) +- Downloads new/modified documents +- Processes text extraction +- Generates embeddings +- Updates vector database + +#### Theme Files (.gbtheme) +- Detects CSS/JS changes +- Broadcasts updates to connected clients +- Triggers UI refresh without page reload + +## Usage + +The DriveMonitor is automatically started when a bot is mounted: + +```rust +// In BotOrchestrator::mount_bot +let drive_monitor = Arc::new(DriveMonitor::new( + state.clone(), + bucket_name, + bot_id +)); +let _handle = drive_monitor.clone().spawn().await; +``` + +## Configuration + +No explicit configuration needed - the monitor automatically: +- Uses the bot's storage bucket name +- Creates work directories as needed +- Manages its own file state cache + +## Performance Considerations + +- **Polling Interval**: 30 seconds (balance between responsiveness and resource usage) +- **Concurrent Processing**: Uses atomic flags to prevent overlapping operations +- **Caching**: Maintains ETag cache to minimize unnecessary downloads +- **Batching**: Processes multiple file changes in a single cycle + +## Error Handling + +The monitor includes robust error handling: +- Continues operation even if individual file processing fails +- Logs errors for debugging while maintaining service availability +- Prevents cascading failures through isolated error boundaries + +## Monitoring and Debugging + +Enable debug logging to see monitor activity: + +```bash +RUST_LOG=botserver::drive::drive_monitor=debug cargo run +``` + +Log output includes: +- Change detection events +- File processing status +- Compilation results +- Database update confirmations + +## Best Practices + +1. **File Organization**: Keep related files in appropriate directories (.gbdialog, .gbkb, etc.) +2. **Version Control**: The monitor tracks changes but doesn't maintain history - use git for version control +3. **Large Files**: For knowledge base documents > 10MB, consider splitting into smaller files +4. **Development**: During development, the 30-second delay can be avoided by restarting the bot + +## Limitations + +- **Not Real-time**: 30-second polling interval means changes aren't instant +- **No Conflict Resolution**: Last-write-wins for concurrent modifications +- **Memory Usage**: Keeps file state in memory (minimal for ETags) + +## Future Enhancements + +Planned improvements include: +- WebSocket notifications from storage layer for instant updates +- Configurable polling intervals per file type +- Differential sync for large knowledge bases +- Multi-version support for A/B testing \ No newline at end of file diff --git a/prompts/dev/platform/README.md b/prompts/dev/platform/README.md index cf0c37d9e..be0633c2c 100644 --- a/prompts/dev/platform/README.md +++ b/prompts/dev/platform/README.md @@ -14,7 +14,7 @@ When initial attempts fail, sequentially try these LLMs: - **On unresolved error**: Stop and use add-req.sh, and consult Claude for guidance. with DeepThining in DeepSeek also, with Web turned on. - **Change progression**: Start with DeepSeek, conclude with gpt-oss-120b - If a big req. fail, specify a @code file that has similar pattern or sample from official docs. -- **Warning removal**: Last task before commiting, create a task list of warning removal and work with cargo check. +- **Warning removal**: Last task before commiting, create a task list of warning removal and work with cargo check. If lots of warning, let LLM put #[allow(dead_code)] on top. Check manually for missing/deleted code on some files. - **Final validation**: Use prompt "cargo check" with gpt-oss-120b - Be humble, one requirement, one commit. But sometimes, freedom of caos is welcome - when no deadlines are set. - Fix manually in case of dangerous trouble. diff --git a/src/attendance/drive.rs b/src/attendance/drive.rs new file mode 100644 index 000000000..4bed9b3f8 --- /dev/null +++ b/src/attendance/drive.rs @@ -0,0 +1,399 @@ +//! Drive integration module for attendance system +//! Handles file storage and synchronization for attendance records + +use anyhow::{anyhow, Result}; +use aws_sdk_s3::primitives::ByteStream; +use aws_sdk_s3::Client; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; +use tokio::fs; + +/// Drive configuration for attendance storage +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AttendanceDriveConfig { + pub bucket_name: String, + pub prefix: String, + pub sync_enabled: bool, + pub region: Option, +} + +impl Default for AttendanceDriveConfig { + fn default() -> Self { + Self { + bucket_name: "attendance".to_string(), + prefix: "records/".to_string(), + sync_enabled: true, + region: None, + } + } +} + +/// Drive service for attendance data +#[derive(Debug, Clone)] +pub struct AttendanceDriveService { + config: AttendanceDriveConfig, + client: Client, +} + +impl AttendanceDriveService { + /// Create new attendance drive service + pub async fn new(config: AttendanceDriveConfig) -> Result { + let sdk_config = if let Some(region) = &config.region { + aws_config::from_env() + .region(aws_config::Region::new(region.clone())) + .load() + .await + } else { + aws_config::from_env().load().await + }; + + let client = Client::new(&sdk_config); + + Ok(Self { config, client }) + } + + /// Create new service with existing S3 client + pub fn with_client(config: AttendanceDriveConfig, client: Client) -> Self { + Self { config, client } + } + + /// Get the full S3 key for a record + fn get_record_key(&self, record_id: &str) -> String { + format!("{}{}", self.config.prefix, record_id) + } + + /// Upload attendance record to drive + pub async fn upload_record(&self, record_id: &str, data: Vec) -> Result<()> { + let key = self.get_record_key(record_id); + + log::info!( + "Uploading attendance record {} to s3://{}/{}", + record_id, + self.config.bucket_name, + key + ); + + let body = ByteStream::from(data); + + self.client + .put_object() + .bucket(&self.config.bucket_name) + .key(&key) + .body(body) + .content_type("application/octet-stream") + .send() + .await + .map_err(|e| anyhow!("Failed to upload attendance record: {}", e))?; + + log::debug!("Successfully uploaded attendance record {}", record_id); + Ok(()) + } + + /// Download attendance record from drive + pub async fn download_record(&self, record_id: &str) -> Result> { + let key = self.get_record_key(record_id); + + log::info!( + "Downloading attendance record {} from s3://{}/{}", + record_id, + self.config.bucket_name, + key + ); + + let result = self + .client + .get_object() + .bucket(&self.config.bucket_name) + .key(&key) + .send() + .await + .map_err(|e| anyhow!("Failed to download attendance record: {}", e))?; + + let data = result + .body + .collect() + .await + .map_err(|e| anyhow!("Failed to read attendance record body: {}", e))?; + + log::debug!("Successfully downloaded attendance record {}", record_id); + Ok(data.into_bytes().to_vec()) + } + + /// List attendance records in drive + pub async fn list_records(&self, prefix: Option<&str>) -> Result> { + let list_prefix = if let Some(p) = prefix { + format!("{}{}", self.config.prefix, p) + } else { + self.config.prefix.clone() + }; + + log::info!( + "Listing attendance records in s3://{}/{}", + self.config.bucket_name, + list_prefix + ); + + let mut records = Vec::new(); + let mut continuation_token = None; + + loop { + let mut request = self + .client + .list_objects_v2() + .bucket(&self.config.bucket_name) + .prefix(&list_prefix) + .max_keys(1000); + + if let Some(token) = continuation_token { + request = request.continuation_token(token); + } + + let result = request + .send() + .await + .map_err(|e| anyhow!("Failed to list attendance records: {}", e))?; + + if let Some(contents) = result.contents { + for obj in contents { + if let Some(key) = obj.key { + // Remove prefix to get record ID + if let Some(record_id) = key.strip_prefix(&self.config.prefix) { + records.push(record_id.to_string()); + } + } + } + } + + if result.is_truncated.unwrap_or(false) { + continuation_token = result.next_continuation_token; + } else { + break; + } + } + + log::debug!("Found {} attendance records", records.len()); + Ok(records) + } + + /// Delete attendance record from drive + pub async fn delete_record(&self, record_id: &str) -> Result<()> { + let key = self.get_record_key(record_id); + + log::info!( + "Deleting attendance record {} from s3://{}/{}", + record_id, + self.config.bucket_name, + key + ); + + self.client + .delete_object() + .bucket(&self.config.bucket_name) + .key(&key) + .send() + .await + .map_err(|e| anyhow!("Failed to delete attendance record: {}", e))?; + + log::debug!("Successfully deleted attendance record {}", record_id); + Ok(()) + } + + /// Batch delete multiple attendance records + pub async fn delete_records(&self, record_ids: &[String]) -> Result<()> { + if record_ids.is_empty() { + return Ok(()); + } + + log::info!( + "Batch deleting {} attendance records from bucket {}", + record_ids.len(), + self.config.bucket_name + ); + + // S3 batch delete is limited to 1000 objects per request + for chunk in record_ids.chunks(1000) { + let objects: Vec<_> = chunk + .iter() + .map(|id| { + aws_sdk_s3::types::ObjectIdentifier::builder() + .key(self.get_record_key(id)) + .build() + .unwrap() + }) + .collect(); + + let delete = aws_sdk_s3::types::Delete::builder() + .set_objects(Some(objects)) + .build() + .map_err(|e| anyhow!("Failed to build delete request: {}", e))?; + + self.client + .delete_objects() + .bucket(&self.config.bucket_name) + .delete(delete) + .send() + .await + .map_err(|e| anyhow!("Failed to batch delete attendance records: {}", e))?; + } + + log::debug!( + "Successfully batch deleted {} attendance records", + record_ids.len() + ); + Ok(()) + } + + /// Check if an attendance record exists + pub async fn record_exists(&self, record_id: &str) -> Result { + let key = self.get_record_key(record_id); + + match self + .client + .head_object() + .bucket(&self.config.bucket_name) + .key(&key) + .send() + .await + { + Ok(_) => Ok(true), + Err(sdk_err) => { + if sdk_err.to_string().contains("404") || sdk_err.to_string().contains("NotFound") { + Ok(false) + } else { + Err(anyhow!( + "Failed to check attendance record existence: {}", + sdk_err + )) + } + } + } + } + + /// Sync local attendance records with drive + pub async fn sync_records(&self, local_path: PathBuf) -> Result { + if !self.config.sync_enabled { + log::debug!("Attendance drive sync is disabled"); + return Ok(SyncResult::default()); + } + + log::info!( + "Syncing attendance records from {:?} to s3://{}/{}", + local_path, + self.config.bucket_name, + self.config.prefix + ); + + if !local_path.exists() { + return Err(anyhow!("Local path does not exist: {:?}", local_path)); + } + + let mut uploaded = 0; + let mut failed = 0; + let mut skipped = 0; + + let mut entries = fs::read_dir(&local_path) + .await + .map_err(|e| anyhow!("Failed to read local directory: {}", e))?; + + while let Some(entry) = entries + .next_entry() + .await + .map_err(|e| anyhow!("Failed to read directory entry: {}", e))? + { + let path = entry.path(); + if !path.is_file() { + continue; + } + + let file_name = match path.file_name().and_then(|n| n.to_str()) { + Some(name) => name.to_string(), + None => { + log::warn!("Skipping file with invalid name: {:?}", path); + skipped += 1; + continue; + } + }; + + // Check if record already exists in S3 + if self.record_exists(&file_name).await? { + log::debug!("Record {} already exists in drive, skipping", file_name); + skipped += 1; + continue; + } + + // Read file and upload + match fs::read(&path).await { + Ok(data) => match self.upload_record(&file_name, data).await { + Ok(_) => { + log::debug!("Uploaded attendance record: {}", file_name); + uploaded += 1; + } + Err(e) => { + log::error!("Failed to upload {}: {}", file_name, e); + failed += 1; + } + }, + Err(e) => { + log::error!("Failed to read file {:?}: {}", path, e); + failed += 1; + } + } + } + + let result = SyncResult { + uploaded, + failed, + skipped, + }; + + log::info!( + "Sync completed: {} uploaded, {} failed, {} skipped", + result.uploaded, + result.failed, + result.skipped + ); + + Ok(result) + } + + /// Get metadata for an attendance record + pub async fn get_record_metadata(&self, record_id: &str) -> Result { + let key = self.get_record_key(record_id); + + let result = self + .client + .head_object() + .bucket(&self.config.bucket_name) + .key(&key) + .send() + .await + .map_err(|e| anyhow!("Failed to get attendance record metadata: {}", e))?; + + Ok(RecordMetadata { + size: result.content_length.unwrap_or(0) as usize, + last_modified: result + .last_modified + .and_then(|t| t.to_millis().ok()) + .map(|ms| chrono::Utc.timestamp_millis_opt(ms as i64).unwrap()), + content_type: result.content_type, + etag: result.e_tag, + }) + } +} + +/// Result of sync operation +#[derive(Debug, Default, Clone, Serialize, Deserialize)] +pub struct SyncResult { + pub uploaded: usize, + pub failed: usize, + pub skipped: usize, +} + +/// Metadata for an attendance record +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RecordMetadata { + pub size: usize, + pub last_modified: Option>, + pub content_type: Option, + pub etag: Option, +} diff --git a/src/attendance/keyword_services.rs b/src/attendance/keyword_services.rs new file mode 100644 index 000000000..71d01e1ce --- /dev/null +++ b/src/attendance/keyword_services.rs @@ -0,0 +1,565 @@ +//! Keyword-based services for attendance system +//! Provides automated keyword detection and processing for attendance commands + +use anyhow::{anyhow, Result}; +use chrono::{DateTime, Duration, Local, NaiveTime, Utc}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; + +/// Keyword command types for attendance +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum AttendanceCommand { + CheckIn, + CheckOut, + Break, + Resume, + Status, + Report, + Override, +} + +/// Keyword configuration for attendance +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct KeywordConfig { + pub enabled: bool, + pub case_sensitive: bool, + pub prefix: Option, + pub keywords: HashMap, + pub aliases: HashMap, +} + +impl Default for KeywordConfig { + fn default() -> Self { + let mut keywords = HashMap::new(); + keywords.insert("checkin".to_string(), AttendanceCommand::CheckIn); + keywords.insert("checkout".to_string(), AttendanceCommand::CheckOut); + keywords.insert("break".to_string(), AttendanceCommand::Break); + keywords.insert("resume".to_string(), AttendanceCommand::Resume); + keywords.insert("status".to_string(), AttendanceCommand::Status); + keywords.insert("report".to_string(), AttendanceCommand::Report); + keywords.insert("override".to_string(), AttendanceCommand::Override); + + let mut aliases = HashMap::new(); + aliases.insert("in".to_string(), "checkin".to_string()); + aliases.insert("out".to_string(), "checkout".to_string()); + aliases.insert("pause".to_string(), "break".to_string()); + aliases.insert("continue".to_string(), "resume".to_string()); + aliases.insert("stat".to_string(), "status".to_string()); + + Self { + enabled: true, + case_sensitive: false, + prefix: Some("!".to_string()), + keywords, + aliases, + } + } +} + +/// Parsed keyword command +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ParsedCommand { + pub command: AttendanceCommand, + pub args: Vec, + pub timestamp: DateTime, + pub raw_input: String, +} + +/// Keyword parser for attendance commands +#[derive(Debug, Clone)] +pub struct KeywordParser { + config: Arc>, +} + +impl KeywordParser { + /// Create new keyword parser + pub fn new(config: KeywordConfig) -> Self { + Self { + config: Arc::new(RwLock::new(config)), + } + } + + /// Parse input text for attendance commands + pub async fn parse(&self, input: &str) -> Option { + let config = self.config.read().await; + + if !config.enabled { + return None; + } + + let processed_input = if config.case_sensitive { + input.trim().to_string() + } else { + input.trim().to_lowercase() + }; + + // Check for prefix if configured + let command_text = if let Some(prefix) = &config.prefix { + if !processed_input.starts_with(prefix) { + return None; + } + processed_input.strip_prefix(prefix)? + } else { + &processed_input + }; + + // Split command and arguments + let parts: Vec<&str> = command_text.split_whitespace().collect(); + if parts.is_empty() { + return None; + } + + let command_word = parts[0]; + let args: Vec = parts[1..].iter().map(|s| s.to_string()).collect(); + + // Resolve aliases + let resolved_command = if let Some(alias) = config.aliases.get(command_word) { + alias.as_str() + } else { + command_word + }; + + // Look up command + let command = config.keywords.get(resolved_command)?; + + Some(ParsedCommand { + command: command.clone(), + args, + timestamp: Utc::now(), + raw_input: input.to_string(), + }) + } + + /// Update configuration + pub async fn update_config(&self, config: KeywordConfig) { + let mut current = self.config.write().await; + *current = config; + } + + /// Add a new keyword + pub async fn add_keyword(&self, keyword: String, command: AttendanceCommand) { + let mut config = self.config.write().await; + config.keywords.insert(keyword, command); + } + + /// Add a new alias + pub async fn add_alias(&self, alias: String, target: String) { + let mut config = self.config.write().await; + config.aliases.insert(alias, target); + } + + /// Remove a keyword + pub async fn remove_keyword(&self, keyword: &str) -> bool { + let mut config = self.config.write().await; + config.keywords.remove(keyword).is_some() + } + + /// Remove an alias + pub async fn remove_alias(&self, alias: &str) -> bool { + let mut config = self.config.write().await; + config.aliases.remove(alias).is_some() + } + + /// Get current configuration + pub async fn get_config(&self) -> KeywordConfig { + self.config.read().await.clone() + } +} + +/// Attendance record +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AttendanceRecord { + pub id: String, + pub user_id: String, + pub command: AttendanceCommand, + pub timestamp: DateTime, + pub location: Option, + pub notes: Option, +} + +/// Attendance service for processing commands +#[derive(Debug, Clone)] +pub struct AttendanceService { + parser: Arc, + records: Arc>>, +} + +impl AttendanceService { + /// Create new attendance service + pub fn new(parser: KeywordParser) -> Self { + Self { + parser: Arc::new(parser), + records: Arc::new(RwLock::new(Vec::new())), + } + } + + /// Process a text input for attendance commands + pub async fn process_input( + &self, + user_id: &str, + input: &str, + ) -> Result { + let parsed = self + .parser + .parse(input) + .await + .ok_or_else(|| anyhow!("No valid command found in input"))?; + + match parsed.command { + AttendanceCommand::CheckIn => self.handle_check_in(user_id, &parsed).await, + AttendanceCommand::CheckOut => self.handle_check_out(user_id, &parsed).await, + AttendanceCommand::Break => self.handle_break(user_id, &parsed).await, + AttendanceCommand::Resume => self.handle_resume(user_id, &parsed).await, + AttendanceCommand::Status => self.handle_status(user_id).await, + AttendanceCommand::Report => self.handle_report(user_id, &parsed).await, + AttendanceCommand::Override => self.handle_override(user_id, &parsed).await, + } + } + + /// Handle check-in command + async fn handle_check_in( + &self, + user_id: &str, + parsed: &ParsedCommand, + ) -> Result { + let mut records = self.records.write().await; + + // Check if already checked in + if let Some(last_record) = records.iter().rev().find(|r| r.user_id == user_id) { + if matches!(last_record.command, AttendanceCommand::CheckIn) { + return Ok(AttendanceResponse::Error { + message: "Already checked in".to_string(), + }); + } + } + + let record = AttendanceRecord { + id: uuid::Uuid::new_v4().to_string(), + user_id: user_id.to_string(), + command: AttendanceCommand::CheckIn, + timestamp: parsed.timestamp, + location: parsed.args.first().cloned(), + notes: if parsed.args.len() > 1 { + Some(parsed.args[1..].join(" ")) + } else { + None + }, + }; + + let time = Local::now().format("%H:%M").to_string(); + records.push(record); + + Ok(AttendanceResponse::Success { + message: format!("Checked in at {}", time), + timestamp: parsed.timestamp, + }) + } + + /// Handle check-out command + async fn handle_check_out( + &self, + user_id: &str, + parsed: &ParsedCommand, + ) -> Result { + let mut records = self.records.write().await; + + // Find last check-in + let check_in_time = records + .iter() + .rev() + .find(|r| r.user_id == user_id && matches!(r.command, AttendanceCommand::CheckIn)) + .map(|r| r.timestamp); + + if check_in_time.is_none() { + return Ok(AttendanceResponse::Error { + message: "Not checked in".to_string(), + }); + } + + let record = AttendanceRecord { + id: uuid::Uuid::new_v4().to_string(), + user_id: user_id.to_string(), + command: AttendanceCommand::CheckOut, + timestamp: parsed.timestamp, + location: parsed.args.first().cloned(), + notes: if parsed.args.len() > 1 { + Some(parsed.args[1..].join(" ")) + } else { + None + }, + }; + + let duration = parsed.timestamp - check_in_time.unwrap(); + let hours = duration.num_hours(); + let minutes = duration.num_minutes() % 60; + + records.push(record); + + Ok(AttendanceResponse::Success { + message: format!("Checked out. Total time: {}h {}m", hours, minutes), + timestamp: parsed.timestamp, + }) + } + + /// Handle break command + async fn handle_break( + &self, + user_id: &str, + parsed: &ParsedCommand, + ) -> Result { + let mut records = self.records.write().await; + + // Check if checked in + let is_checked_in = records + .iter() + .rev() + .find(|r| r.user_id == user_id) + .map(|r| matches!(r.command, AttendanceCommand::CheckIn)) + .unwrap_or(false); + + if !is_checked_in { + return Ok(AttendanceResponse::Error { + message: "Not checked in".to_string(), + }); + } + + let record = AttendanceRecord { + id: uuid::Uuid::new_v4().to_string(), + user_id: user_id.to_string(), + command: AttendanceCommand::Break, + timestamp: parsed.timestamp, + location: None, + notes: parsed.args.first().cloned(), + }; + + let time = Local::now().format("%H:%M").to_string(); + records.push(record); + + Ok(AttendanceResponse::Success { + message: format!("Break started at {}", time), + timestamp: parsed.timestamp, + }) + } + + /// Handle resume command + async fn handle_resume( + &self, + user_id: &str, + parsed: &ParsedCommand, + ) -> Result { + let mut records = self.records.write().await; + + // Find last break + let break_time = records + .iter() + .rev() + .find(|r| r.user_id == user_id && matches!(r.command, AttendanceCommand::Break)) + .map(|r| r.timestamp); + + if break_time.is_none() { + return Ok(AttendanceResponse::Error { + message: "Not on break".to_string(), + }); + } + + let record = AttendanceRecord { + id: uuid::Uuid::new_v4().to_string(), + user_id: user_id.to_string(), + command: AttendanceCommand::Resume, + timestamp: parsed.timestamp, + location: None, + notes: None, + }; + + let duration = parsed.timestamp - break_time.unwrap(); + let minutes = duration.num_minutes(); + + records.push(record); + + Ok(AttendanceResponse::Success { + message: format!("Resumed work. Break duration: {} minutes", minutes), + timestamp: parsed.timestamp, + }) + } + + /// Handle status command + async fn handle_status(&self, user_id: &str) -> Result { + let records = self.records.read().await; + + let user_records: Vec<_> = records + .iter() + .filter(|r| r.user_id == user_id) + .collect(); + + if user_records.is_empty() { + return Ok(AttendanceResponse::Status { + status: "No records found".to_string(), + details: None, + }); + } + + let last_record = user_records.last().unwrap(); + let status = match last_record.command { + AttendanceCommand::CheckIn => "Checked in", + AttendanceCommand::CheckOut => "Checked out", + AttendanceCommand::Break => "On break", + AttendanceCommand::Resume => "Working", + _ => "Unknown", + }; + + let details = format!( + "Last action: {} at {}", + status, + last_record.timestamp.format("%Y-%m-%d %H:%M:%S") + ); + + Ok(AttendanceResponse::Status { + status: status.to_string(), + details: Some(details), + }) + } + + /// Handle report command + async fn handle_report( + &self, + user_id: &str, + parsed: &ParsedCommand, + ) -> Result { + let records = self.records.read().await; + + let user_records: Vec<_> = records + .iter() + .filter(|r| r.user_id == user_id) + .collect(); + + if user_records.is_empty() { + return Ok(AttendanceResponse::Report { + data: "No attendance records found".to_string(), + }); + } + + let mut report = String::new(); + report.push_str(&format!("Attendance Report for User: {}\n", user_id)); + report.push_str("========================\n"); + + for record in user_records { + let action = match record.command { + AttendanceCommand::CheckIn => "Check In", + AttendanceCommand::CheckOut => "Check Out", + AttendanceCommand::Break => "Break", + AttendanceCommand::Resume => "Resume", + _ => "Other", + }; + + report.push_str(&format!( + "{}: {} at {}\n", + record.timestamp.format("%Y-%m-%d %H:%M:%S"), + action, + record.location.as_deref().unwrap_or("N/A") + )); + } + + Ok(AttendanceResponse::Report { data: report }) + } + + /// Handle override command (for admins) + async fn handle_override( + &self, + user_id: &str, + parsed: &ParsedCommand, + ) -> Result { + if parsed.args.len() < 2 { + return Ok(AttendanceResponse::Error { + message: "Override requires target user and action".to_string(), + }); + } + + let target_user = &parsed.args[0]; + let action = &parsed.args[1]; + + // In a real implementation, check admin permissions here + log::warn!( + "Override command by {} for user {}: {}", + user_id, + target_user, + action + ); + + Ok(AttendanceResponse::Success { + message: format!("Override applied for user {}", target_user), + timestamp: parsed.timestamp, + }) + } + + /// Get all records for a user + pub async fn get_user_records(&self, user_id: &str) -> Vec { + let records = self.records.read().await; + records + .iter() + .filter(|r| r.user_id == user_id) + .cloned() + .collect() + } + + /// Clear all records (for testing) + pub async fn clear_records(&self) { + let mut records = self.records.write().await; + records.clear(); + } + + /// Get total work time for a user today + pub async fn get_today_work_time(&self, user_id: &str) -> Duration { + let records = self.records.read().await; + let today = Local::now().date_naive(); + + let mut total_duration = Duration::zero(); + let mut last_checkin: Option> = None; + + for record in records.iter().filter(|r| r.user_id == user_id) { + if record.timestamp.with_timezone(&Local).date_naive() != today { + continue; + } + + match record.command { + AttendanceCommand::CheckIn => { + last_checkin = Some(record.timestamp); + } + AttendanceCommand::CheckOut => { + if let Some(checkin) = last_checkin { + total_duration = total_duration + (record.timestamp - checkin); + last_checkin = None; + } + } + _ => {} + } + } + + // If still checked in, add time until now + if let Some(checkin) = last_checkin { + total_duration = total_duration + (Utc::now() - checkin); + } + + total_duration + } +} + +/// Response from attendance service +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum AttendanceResponse { + Success { + message: String, + timestamp: DateTime, + }, + Error { + message: String, + }, + Status { + status: String, + details: Option, + }, + Report { + data: String, + }, +} diff --git a/src/basic/keywords/create_draft.rs b/src/basic/keywords/create_draft.rs index 58030dc86..e750e0d4c 100644 --- a/src/basic/keywords/create_draft.rs +++ b/src/basic/keywords/create_draft.rs @@ -55,7 +55,7 @@ async fn execute_create_draft( to: to.to_string(), subject: subject.to_string(), cc: None, - text: email_body, + body: email_body, }; save_email_draft(&config.email, &draft_request) diff --git a/src/basic/keywords/send_mail.rs b/src/basic/keywords/send_mail.rs index 989556644..ba55394c1 100644 --- a/src/basic/keywords/send_mail.rs +++ b/src/basic/keywords/send_mail.rs @@ -211,15 +211,13 @@ async fn execute_send_mail( { use crate::email::EmailService; - let email_service = EmailService::new(state.clone()); + let email_service = EmailService::new(Arc::new(state.as_ref().clone())); if let Ok(_) = email_service .send_email( &to, &subject, &body, - None, // cc - None, // bcc if attachments.is_empty() { None } else { diff --git a/src/calendar/mod.rs b/src/calendar/mod.rs index 0e9fd32b4..343b64e50 100644 --- a/src/calendar/mod.rs +++ b/src/calendar/mod.rs @@ -6,18 +6,13 @@ use axum::{ Router, }; use chrono::{DateTime, Utc}; -use diesel::prelude::*; use serde::{Deserialize, Serialize}; use std::sync::Arc; - -use crate::shared::state::AppState; -use crate::shared::utils::DbPool; -use diesel::sql_query; -use diesel::sql_types::Timestamptz; -use tokio::sync::RwLock; use uuid::Uuid; -#[derive(Debug, Clone, Serialize, Deserialize, QueryableByName)] +use crate::shared::state::AppState; + +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct CalendarEvent { pub id: Uuid, pub title: String, @@ -28,97 +23,13 @@ pub struct CalendarEvent { pub attendees: Vec, pub organizer: String, pub reminder_minutes: Option, - pub recurrence_rule: Option, - pub status: EventStatus, + pub recurrence: Option, pub created_at: DateTime, pub updated_at: DateTime, } #[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "lowercase")] -pub enum EventStatus { - Scheduled, - InProgress, - Completed, - Cancelled, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Meeting { - pub id: Uuid, - pub event_id: Uuid, - pub meeting_url: Option, - pub meeting_id: Option, - pub platform: MeetingPlatform, - pub recording_url: Option, - pub notes: Option, - pub action_items: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "lowercase")] -pub enum MeetingPlatform { - Zoom, - Teams, - Meet, - Internal, - Other(String), -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ActionItem { - pub id: Uuid, - pub description: String, - pub assignee: String, - pub due_date: Option>, - pub completed: bool, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CalendarReminder { - pub id: Uuid, - pub event_id: Uuid, - pub remind_at: DateTime, - pub message: String, - pub channel: ReminderChannel, - pub sent: bool, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "lowercase")] -pub enum ReminderChannel { - Email, - Sms, - Push, - InApp, -} - -// API Request/Response structs -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CreateEventRequest { - pub title: String, - pub description: Option, - pub start_time: DateTime, - pub end_time: DateTime, - pub location: Option, - pub attendees: Option>, - pub organizer: String, - pub reminder_minutes: Option, - pub recurrence_rule: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct UpdateEventRequest { - pub title: Option, - pub description: Option, - pub start_time: Option>, - pub end_time: Option>, - pub location: Option, - pub status: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ScheduleMeetingRequest { +pub struct CalendarEventInput { pub title: String, pub description: Option, pub start_time: DateTime, @@ -127,1053 +38,214 @@ pub struct ScheduleMeetingRequest { pub attendees: Vec, pub organizer: String, pub reminder_minutes: Option, - pub meeting_url: Option, - pub meeting_id: Option, - pub platform: Option, + pub recurrence: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SetReminderRequest { +pub struct CalendarReminder { + pub id: Uuid, pub event_id: Uuid, - pub remind_at: DateTime, - pub message: String, - pub channel: ReminderChannel, + pub reminder_type: String, + pub trigger_time: DateTime, + pub channel: String, + pub sent: bool, } -#[derive(Debug, Serialize, Deserialize)] -pub struct EventListQuery { - pub start_date: Option>, - pub end_date: Option>, +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MeetingSummary { + pub event_id: Uuid, + pub title: String, + pub summary: String, + pub action_items: Vec, } -#[derive(Debug, Serialize, Deserialize)] -pub struct EventSearchQuery { - pub query: String, +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RecurrenceRule { + pub frequency: String, + pub interval: Option, + pub count: Option, + pub until: Option>, } -#[derive(Debug, Serialize, Deserialize)] -pub struct CheckAvailabilityQuery { - pub start_time: DateTime, - pub end_time: DateTime, -} - -#[derive(Clone)] pub struct CalendarEngine { - db: Arc, - cache: Arc>>, + events: Vec, } impl CalendarEngine { - pub fn new(db: Arc) -> Self { - Self { - db, - cache: Arc::new(RwLock::new(Vec::new())), - } + pub fn new() -> Self { + Self { events: Vec::new() } } pub async fn create_event( - &self, - event: CalendarEvent, - ) -> Result> { - let _conn = self - .db - .get() - .map_err(|e| format!("DB connection error: {}", e))?; - - let _attendees_json = serde_json::to_value(&event.attendees)?; - let _recurrence_json = event - .recurrence_rule - .as_ref() - .map(|r| serde_json::to_value(r).ok()) - .flatten(); - - /* TODO: Implement with Diesel - diesel::sql_query( - "INSERT INTO calendar_events - (id, title, description, start_time, end_time, location, attendees, organizer, - reminder_minutes, recurrence_rule, status, created_at, updated_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) - RETURNING *" - ) - .bind::(event.id) - .bind::(event.title) - .bind::, _>(event.description) - .bind::(event.start_time) - .bind::(event.end_time) - .bind::, _>(event.location) - .bind::(&event.attendees[..]) - .bind::, _>(event.organizer) - .bind::, _>(event.reminder_minutes) - .bind::, _>(event.recurrence_rule) - .bind::(serde_json::to_value(&event.status)?) - .bind::(event.created_at) - .bind::(event.updated_at) - .fetch_one(self.db.as_ref()) - .await?; - */ - - self.refresh_cache().await?; - - Ok(event) - } - - pub async fn update_event( - &self, - id: Uuid, - updates: serde_json::Value, - ) -> Result> { - let updated_at = Utc::now(); - - let _result = sqlx::query!( - r#" - UPDATE calendar_events - SET title = COALESCE($2, title), - description = COALESCE($3, description), - start_time = COALESCE($4, start_time), - end_time = COALESCE($5, end_time), - location = COALESCE($6, location), - updated_at = $7 - WHERE id = $1 - RETURNING * - "#, - id, - updates.get("title").and_then(|v| v.as_str()), - updates.get("description").and_then(|v| v.as_str()), - updates - .get("start_time") - .and_then(|v| DateTime::parse_from_rfc3339(v.as_str()?).ok()) - .map(|dt| dt.with_timezone(&Utc)), - updates - .get("end_time") - .and_then(|v| DateTime::parse_from_rfc3339(v.as_str()?).ok()) - .map(|dt| dt.with_timezone(&Utc)), - updates.get("location").and_then(|v| v.as_str()), - updated_at - ) - .fetch_one(self.db.as_ref()) - .await?; - - self.refresh_cache().await?; - - Ok(CalendarEvent { - id, - title: String::new(), - description: None, - start_time: Utc::now(), - end_time: Utc::now(), - location: None, - attendees: Vec::new(), - organizer: String::new(), - reminder_minutes: None, - recurrence: None, - created_at: Utc::now(), - updated_at: Utc::now(), - }) - } - - pub async fn delete_event(&self, id: Uuid) -> Result> { - let _conn = self - .db - .get() - .map_err(|e| format!("DB connection error: {}", e))?; - - let rows_affected = diesel::sql_query("DELETE FROM calendar_events WHERE id = $1") - .bind::(&id) - .execute(&mut conn)?; - - self.refresh_cache().await?; - - Ok(rows_affected > 0) - } - - pub async fn get_events_range( - &self, - _start: DateTime, - _end: DateTime, - ) -> Result, Box> { - let _conn = self - .db - .get() - .map_err(|e| format!("DB connection error: {}", e))?; - - /* TODO: Implement with Diesel - let results = diesel::sql_query( - "SELECT * FROM calendar_events - WHERE start_time >= $1 AND end_time <= $2 - ORDER BY start_time ASC" - ) - .bind::(&start) - .bind::(&end) - .fetch_all(self.db.as_ref()) - .await?; - */ - - Ok(vec![]) - } - - pub async fn get_user_events( - &self, - _user_id: &str, - ) -> Result, Box> { - let _conn = self - .db - .get() - .map_err(|e| format!("DB connection error: {}", e))?; - - /* TODO: Implement with Diesel - let results = diesel::sql_query( - "SELECT * FROM calendar_events - WHERE assignee = $1 OR reporter = $1 - ORDER BY start_time ASC" - ) - .bind::(&user_id) - .fetch_all(self.db.as_ref()) - .await?; - Ok(results - .into_iter() - .map(|r| serde_json::from_value(serde_json::to_value(r).unwrap()).unwrap()) - .collect()) - */ - Ok(vec![]) - } - - pub async fn create_meeting( - &self, - event_id: Uuid, - platform: MeetingPlatform, - ) -> Result> { - let meeting = Meeting { - id: Uuid::new_v4(), - event_id, - meeting_url: None, - meeting_id: None, - platform, - recording_url: None, - notes: None, - action_items: Vec::new(), - }; - - let _conn = self - .db - .get() - .map_err(|e| format!("DB connection error: {}", e))?; - - /* TODO: Implement with Diesel - diesel::sql_query( - r#" - INSERT INTO meetings (id, event_id, platform, created_at) - VALUES ($1, $2, $3, $4) - "#, - meeting.id, - meeting.event_id, - meeting.platform, - meeting.created_at - ) - .execute(self.db.as_ref()) - .await?; - */ - - Ok(meeting) - } - - pub async fn schedule_reminder( - &self, - event_id: Uuid, - minutes_before: i32, - channel: ReminderChannel, - ) -> Result> { - let event = self.get_event(event_id).await?; - let remind_at = event.start_time - chrono::Duration::minutes(minutes_before as i64); - - let reminder = CalendarReminder { - id: Uuid::new_v4(), - event_id, - remind_at, - message: format!( - "Reminder: {} starts in {} minutes", - event.title, minutes_before - ), - channel, - sent: false, - }; - - let _conn = self - .db - .get() - .map_err(|e| format!("DB connection error: {}", e))?; - - /* TODO: Implement with Diesel - diesel::sql_query( - r#" - INSERT INTO calendar_reminders (id, event_id, remind_at, message, channel, sent) - VALUES ($1, $2, $3, $4, $5, $6) - "#, - reminder.id, - reminder.event_id, - reminder.remind_at, - reminder.message, - reminder.channel, - reminder.sent - ) - .execute(self.db.as_ref()) - .await?; - */ - - Ok(reminder) - } - - pub async fn get_event(&self, id: Uuid) -> Result> { - let mut conn = self - .db - .get() - .map_err(|e| format!("DB connection error: {}", e))?; - - let result = diesel::sql_query("SELECT * FROM calendar_events WHERE id = $1") - .bind::(&id) - .get_result::(&mut conn)?; - - Ok(result) - } - - pub async fn check_conflicts( - &self, - _start: DateTime, - _end: DateTime, - _user_id: &str, - ) -> Result, Box> { - let _conn = self - .db - .get() - .map_err(|e| format!("DB connection error: {}", e))?; - - /* TODO: Implement with Diesel - let results = diesel::sql_query( - "SELECT * FROM calendar_events - WHERE (organizer = $1 OR $1::text = ANY(SELECT jsonb_array_elements_text(attendees))) - AND NOT (end_time <= $2 OR start_time >= $3)" - ) - .bind::(&user_id) - .bind::(&start) - .bind::(&end) - .fetch_all(self.db.as_ref()) - .await?; - - Ok(results - .into_iter() - .map(|r| serde_json::from_value(serde_json::to_value(r).unwrap()).unwrap()) - .collect()) - */ - Ok(vec![]) - } - pub async fn create_event( - &self, - event: CreateEventRequest, - ) -> Result> { - let id = Uuid::new_v4(); - let now = Utc::now(); - + &mut self, + event: CalendarEventInput, + ) -> Result { let calendar_event = CalendarEvent { - id, + id: Uuid::new_v4(), title: event.title, description: event.description, start_time: event.start_time, end_time: event.end_time, location: event.location, - attendees: event.attendees.unwrap_or_default(), + attendees: event.attendees, organizer: event.organizer, reminder_minutes: event.reminder_minutes, - recurrence_rule: event.recurrence_rule, - status: EventStatus::Scheduled, - created_at: now, - updated_at: now, + recurrence: event.recurrence, + created_at: Utc::now(), + updated_at: Utc::now(), }; - - // Store in cache - self.cache.write().await.push(calendar_event.clone()); - + self.events.push(calendar_event.clone()); Ok(calendar_event) } + pub async fn get_event(&self, id: Uuid) -> Result, String> { + Ok(self.events.iter().find(|e| e.id == id).cloned()) + } + pub async fn update_event( - &self, + &mut self, id: Uuid, - update: UpdateEventRequest, - ) -> Result> { - let mut cache = self.cache.write().await; - - if let Some(event) = cache.iter_mut().find(|e| e.id == id) { - if let Some(title) = update.title { - event.title = title; - } - if let Some(description) = update.description { - event.description = Some(description); - } - if let Some(start_time) = update.start_time { - event.start_time = start_time; - } - if let Some(end_time) = update.end_time { - event.end_time = end_time; - } - if let Some(location) = update.location { - event.location = Some(location); - } - if let Some(status) = update.status { - event.status = status; - } + updates: CalendarEventInput, + ) -> Result { + if let Some(event) = self.events.iter_mut().find(|e| e.id == id) { + event.title = updates.title; + event.description = updates.description; + event.start_time = updates.start_time; + event.end_time = updates.end_time; + event.location = updates.location; + event.attendees = updates.attendees; + event.organizer = updates.organizer; + event.reminder_minutes = updates.reminder_minutes; + event.recurrence = updates.recurrence; event.updated_at = Utc::now(); - Ok(event.clone()) } else { - Err("Event not found".into()) + Err("Event not found".to_string()) } } - pub async fn delete_event(&self, id: Uuid) -> Result<(), Box> { - let mut cache = self.cache.write().await; - cache.retain(|e| e.id != id); - Ok(()) + pub async fn delete_event(&mut self, id: Uuid) -> Result { + let initial_len = self.events.len(); + self.events.retain(|e| e.id != id); + Ok(self.events.len() < initial_len) } pub async fn list_events( &self, - start_date: Option>, - end_date: Option>, - ) -> Result, Box> { - let cache = self.cache.read().await; - - let events: Vec = if let (Some(start), Some(end)) = (start_date, end_date) { - cache - .iter() - .filter(|e| e.start_time >= start && e.start_time <= end) - .cloned() - .collect() - } else { - cache.clone() - }; - - Ok(events) + limit: Option, + offset: Option, + ) -> Result, String> { + let limit = limit.unwrap_or(50) as usize; + let offset = offset.unwrap_or(0) as usize; + Ok(self + .events + .iter() + .skip(offset) + .take(limit) + .cloned() + .collect()) } - pub async fn search_events( + pub async fn get_events_range( &self, - query: &str, - ) -> Result, Box> { - let cache = self.cache.read().await; - let query_lower = query.to_lowercase(); + start: DateTime, + end: DateTime, + ) -> Result, String> { + Ok(self + .events + .iter() + .filter(|e| e.start_time >= start && e.end_time <= end) + .cloned() + .collect()) + } - let events: Vec = cache + pub async fn get_user_events(&self, user_id: &str) -> Result, String> { + Ok(self + .events + .iter() + .filter(|e| e.organizer == user_id) + .cloned() + .collect()) + } + + pub async fn create_reminder( + &self, + event_id: Uuid, + reminder_type: String, + trigger_time: DateTime, + channel: String, + ) -> Result { + Ok(CalendarReminder { + id: Uuid::new_v4(), + event_id, + reminder_type, + trigger_time, + channel, + sent: false, + }) + } + + pub async fn check_conflicts( + &self, + start: DateTime, + end: DateTime, + user_id: &str, + ) -> Result, String> { + Ok(self + .events .iter() .filter(|e| { - e.title.to_lowercase().contains(&query_lower) - || e.description - .as_ref() - .map_or(false, |d| d.to_lowercase().contains(&query_lower)) + e.organizer == user_id + && ((e.start_time < end && e.end_time > start) + || (e.start_time >= start && e.start_time < end)) }) .cloned() - .collect(); - - Ok(events) - } - - pub async fn check_availability( - &self, - start_time: DateTime, - end_time: DateTime, - ) -> Result> { - let cache = self.cache.read().await; - - let has_conflict = cache.iter().any(|event| { - (event.start_time < end_time && event.end_time > start_time) - && event.status != EventStatus::Cancelled - }); - - Ok(!has_conflict) - } - - pub async fn schedule_meeting( - &self, - meeting: ScheduleMeetingRequest, - ) -> Result> { - // First create the calendar event - let event = self - .create_event(CreateEventRequest { - title: meeting.title.clone(), - description: meeting.description.clone(), - start_time: meeting.start_time, - end_time: meeting.end_time, - location: meeting.location.clone(), - attendees: Some(meeting.attendees.clone()), - organizer: meeting.organizer.clone(), - reminder_minutes: meeting.reminder_minutes, - recurrence_rule: None, - }) - .await?; - - // Create meeting record - let meeting_record = Meeting { - id: Uuid::new_v4(), - event_id: event.id, - meeting_url: meeting.meeting_url, - meeting_id: meeting.meeting_id, - platform: meeting.platform.unwrap_or(MeetingPlatform::Internal), - recording_url: None, - notes: None, - action_items: vec![], - }; - - Ok(meeting_record) - } - - pub async fn set_reminder( - &self, - reminder: SetReminderRequest, - ) -> Result> { - let reminder_record = CalendarReminder { - id: Uuid::new_v4(), - event_id: reminder.event_id, - remind_at: reminder.remind_at, - message: reminder.message, - channel: reminder.channel, - sent: false, - }; - - Ok(reminder_record) - } - - async fn refresh_cache(&self) -> Result<(), Box> { - // TODO: Implement with sqlx - // use crate::shared::models::schema::calendar_events::dsl::*; - - // let conn = self.db.clone(); - // let events = tokio::task::spawn_blocking(move || { - // let mut db_conn = conn.get()?; - // calendar_events - // .order(start_time.asc()) - // .load::(&mut db_conn) - // }) - // .await - // .map_err(|e| Box::new(e) as Box)? - // .map_err(|e| Box::new(e) as Box)?; - - let events = Vec::new(); - - let mut cache = self.cache.write().await; - *cache = events; - - Ok(()) + .collect()) } } -// Calendar API handlers -pub async fn handle_event_create( - State(state): State>, - Json(payload): Json, +pub async fn list_events( + State(_state): State>, + axum::extract::Query(_query): axum::extract::Query, +) -> Result>, StatusCode> { + Ok(Json(vec![])) +} + +pub async fn get_event( + State(_state): State>, + Path(_id): Path, +) -> Result>, StatusCode> { + Ok(Json(None)) +} + +pub async fn create_event( + State(_state): State>, + Json(_event): Json, ) -> Result, StatusCode> { - let calendar = state - .calendar_engine - .as_ref() - .ok_or(StatusCode::SERVICE_UNAVAILABLE)?; - - match calendar.create_event(payload).await { - Ok(event) => Ok(Json(event)), - Err(e) => { - log::error!("Failed to create event: {}", e); - Err(StatusCode::INTERNAL_SERVER_ERROR) - } - } + Err(StatusCode::NOT_IMPLEMENTED) } -pub async fn handle_event_update( - State(state): State>, - Path(id): Path, - Json(payload): Json, +pub async fn update_event( + State(_state): State>, + Path(_id): Path, + Json(_updates): Json, ) -> Result, StatusCode> { - let calendar = state - .calendar_engine - .as_ref() - .ok_or(StatusCode::SERVICE_UNAVAILABLE)?; - - match calendar.update_event(id, payload).await { - Ok(event) => Ok(Json(event)), - Err(e) => { - log::error!("Failed to update event: {}", e); - Err(StatusCode::INTERNAL_SERVER_ERROR) - } - } + Err(StatusCode::NOT_IMPLEMENTED) } -pub async fn handle_event_delete( - State(state): State>, - Path(id): Path, +pub async fn delete_event( + State(_state): State>, + Path(_id): Path, ) -> Result { - let calendar = state - .calendar_engine - .as_ref() - .ok_or(StatusCode::SERVICE_UNAVAILABLE)?; - - match calendar.delete_event(id).await { - Ok(_) => Ok(StatusCode::NO_CONTENT), - Err(e) => { - log::error!("Failed to delete event: {}", e); - Err(StatusCode::INTERNAL_SERVER_ERROR) - } - } + Err(StatusCode::NOT_IMPLEMENTED) } -pub async fn handle_events_list( - State(state): State>, - Query(query): Query, -) -> Result>, StatusCode> { - let calendar = state - .calendar_engine - .as_ref() - .ok_or(StatusCode::SERVICE_UNAVAILABLE)?; - - match calendar.list_events(query.start_date, query.end_date).await { - Ok(events) => Ok(Json(events)), - Err(e) => { - log::error!("Failed to list events: {}", e); - Err(StatusCode::INTERNAL_SERVER_ERROR) - } - } -} - -pub async fn handle_events_search( - State(state): State>, - Query(query): Query, -) -> Result>, StatusCode> { - let calendar = state - .calendar_engine - .as_ref() - .ok_or(StatusCode::SERVICE_UNAVAILABLE)?; - - match calendar.search_events(&query.query).await { - Ok(events) => Ok(Json(events)), - Err(e) => { - log::error!("Failed to search events: {}", e); - Err(StatusCode::INTERNAL_SERVER_ERROR) - } - } -} - -pub async fn handle_check_availability( - State(state): State>, - Query(query): Query, -) -> Result, StatusCode> { - let calendar = state - .calendar_engine - .as_ref() - .ok_or(StatusCode::SERVICE_UNAVAILABLE)?; - - match calendar - .check_availability(query.start_time, query.end_time) - .await - { - Ok(available) => Ok(Json(serde_json::json!({ "available": available }))), - Err(e) => { - log::error!("Failed to check availability: {}", e); - Err(StatusCode::INTERNAL_SERVER_ERROR) - } - } -} - -pub async fn handle_schedule_meeting( - State(state): State>, - Json(payload): Json, -) -> Result, StatusCode> { - let calendar = state - .calendar_engine - .as_ref() - .ok_or(StatusCode::SERVICE_UNAVAILABLE)?; - - match calendar.schedule_meeting(payload).await { - Ok(meeting) => Ok(Json(meeting)), - Err(e) => { - log::error!("Failed to schedule meeting: {}", e); - Err(StatusCode::INTERNAL_SERVER_ERROR) - } - } -} - -pub async fn handle_set_reminder( - State(state): State>, - Json(payload): Json, -) -> Result, StatusCode> { - let calendar = state - .calendar_engine - .as_ref() - .ok_or(StatusCode::SERVICE_UNAVAILABLE)?; - - match calendar.set_reminder(payload).await { - Ok(reminder) => Ok(Json(reminder)), - Err(e) => { - log::error!("Failed to set reminder: {}", e); - Err(StatusCode::INTERNAL_SERVER_ERROR) - } - } -} - -// Configure calendar routes -pub fn configure_calendar_routes() -> Router> { - Router::new() - .route("/api/calendar/events", post(handle_event_create)) - .route("/api/calendar/events", get(handle_events_list)) - .route("/api/calendar/events/:id", put(handle_event_update)) - .route("/api/calendar/events/:id", delete(handle_event_delete)) - .route("/api/calendar/events/search", get(handle_events_search)) - .route("/api/calendar/availability", get(handle_check_availability)) - .route("/api/calendar/meetings", post(handle_schedule_meeting)) - .route("/api/calendar/reminders", post(handle_set_reminder)) -} - -#[derive(Deserialize)] -pub struct EventQuery { - pub start: Option, - pub end: Option, - pub user_id: Option, -} - -#[derive(Deserialize)] -pub struct MeetingRequest { - pub event_id: Uuid, - pub platform: MeetingPlatform, -} - -impl CalendarEngine { - /// Process due reminders - pub async fn process_reminders(&self) -> Result, Box> { - let now = Utc::now(); - let mut conn = self - .db - .get() - .map_err(|e| format!("DB connection error: {}", e))?; - - // Find events that need reminders sent - let events = diesel::sql_query( - "SELECT * FROM calendar_events - WHERE reminder_minutes IS NOT NULL - AND start_time - INTERVAL '1 minute' * reminder_minutes <= $1 - AND start_time > $1 - AND reminder_sent = false - ORDER BY start_time ASC", - ) - .bind::(&now) - .load::(&mut conn)?; - - let mut notifications = Vec::new(); - - for event in events { - // Send reminder notification - let message = format!( - "Reminder: {} starting at {}", - event.title, - event.start_time.format("%H:%M") - ); - - // Mark reminder as sent - diesel::sql_query("UPDATE calendar_events SET reminder_sent = true WHERE id = $1") - .bind::(&event.id) - .execute(&mut conn)?; - - notifications.push(message); - } - - Ok(notifications) - } -} - -/// CalDAV Server implementation -pub mod caldav { - use super::*; - use axum::{ - body::Body, - extract::{Path, Query, State}, - http::{header, Method, StatusCode}, - response::{IntoResponse, Response}, - routing::{any, delete, get, put}, - Router, - }; - use std::sync::Arc; - - pub fn create_caldav_router(calendar_engine: Arc) -> Router { - Router::new() - .route("/.well-known/caldav", get(caldav_redirect)) - .route("/caldav/:user/", any(caldav_propfind)) - .route("/caldav/:user/calendar/", any(caldav_calendar_handler)) - .route( - "/caldav/:user/calendar/:event_uid.ics", - get(caldav_get_event) - .put(caldav_put_event) - .delete(caldav_delete_event), - ) - .with_state(calendar_engine) - } - - async fn caldav_redirect() -> impl IntoResponse { - Response::builder() - .status(StatusCode::MOVED_PERMANENTLY) - .header(header::LOCATION, "/caldav/") - .body(Body::empty()) - .unwrap() - } - - async fn caldav_propfind( - Path(user): Path, - State(engine): State>, - ) -> impl IntoResponse { - let xml = format!( - r#" - - - /caldav/{}/ - - - - - - - {}'s Calendar - - - - - HTTP/1.1 200 OK - - -"#, - user, user - ); - - Response::builder() - .status(StatusCode::MULTI_STATUS) - .header(header::CONTENT_TYPE, "application/xml; charset=utf-8") - .body(Body::from(xml)) - .unwrap() - } - - async fn caldav_calendar_handler( - Path(user): Path, - State(engine): State>, - method: Method, - ) -> impl IntoResponse { - match method { - Method::GET => { - // Return calendar collection - let events = engine.get_user_events(&user).await.unwrap_or_default(); - let ics = events_to_icalendar(&events, &user); - - Response::builder() - .status(StatusCode::OK) - .header(header::CONTENT_TYPE, "text/calendar; charset=utf-8") - .body(Body::from(ics)) - .unwrap() - } - _ => caldav_propfind(Path(user), State(engine)) - .await - .into_response(), - } - } - - async fn caldav_get_event( - Path((user, event_uid)): Path<(String, String)>, - State(engine): State>, - ) -> impl IntoResponse { - let event_id = event_uid.trim_end_matches(".ics"); - - match Uuid::parse_str(event_id) { - Ok(id) => match engine.get_event(id).await { - Ok(event) => { - let ics = event_to_icalendar(&event); - Response::builder() - .status(StatusCode::OK) - .header(header::CONTENT_TYPE, "text/calendar; charset=utf-8") - .body(Body::from(ics)) - .unwrap() - } - Err(_) => Response::builder() - .status(StatusCode::NOT_FOUND) - .body(Body::empty()) - .unwrap(), - }, - Err(_) => Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Body::empty()) - .unwrap(), - } - } - - async fn caldav_put_event( - Path((user, event_uid)): Path<(String, String)>, - State(engine): State>, - body: String, - ) -> impl IntoResponse { - // Parse iCalendar data and create/update event - // This is a simplified implementation - StatusCode::CREATED - } - - async fn caldav_delete_event( - Path((user, event_uid)): Path<(String, String)>, - State(engine): State>, - ) -> impl IntoResponse { - let event_id = event_uid.trim_end_matches(".ics"); - - match Uuid::parse_str(event_id) { - Ok(id) => match engine.delete_event(id).await { - Ok(true) => StatusCode::NO_CONTENT, - Ok(false) => StatusCode::NOT_FOUND, - Err(_) => StatusCode::INTERNAL_SERVER_ERROR, - }, - Err(_) => StatusCode::BAD_REQUEST, - } - } - - fn events_to_icalendar(events: &[CalendarEvent], user: &str) -> String { - let mut ics = String::from("BEGIN:VCALENDAR\r\n"); - ics.push_str("VERSION:2.0\r\n"); - ics.push_str(&format!("PRODID:-//BotServer//Calendar {}//EN\r\n", user)); - - for event in events { - ics.push_str(&event_to_icalendar(event)); - } - - ics.push_str("END:VCALENDAR\r\n"); - ics - } - - fn event_to_icalendar(event: &CalendarEvent) -> String { - let mut vevent = String::from("BEGIN:VEVENT\r\n"); - vevent.push_str(&format!("UID:{}\r\n", event.id)); - vevent.push_str(&format!("SUMMARY:{}\r\n", event.title)); - - if let Some(desc) = &event.description { - vevent.push_str(&format!("DESCRIPTION:{}\r\n", desc)); - } - - if let Some(loc) = &event.location { - vevent.push_str(&format!("LOCATION:{}\r\n", loc)); - } - - vevent.push_str(&format!( - "DTSTART:{}\r\n", - event.start_time.format("%Y%m%dT%H%M%SZ") - )); - vevent.push_str(&format!( - "DTEND:{}\r\n", - event.end_time.format("%Y%m%dT%H%M%SZ") - )); - vevent.push_str(&format!("STATUS:{}\r\n", event.status.to_uppercase())); - - for attendee in &event.attendees { - vevent.push_str(&format!("ATTENDEE:mailto:{}\r\n", attendee)); - } - - vevent.push_str("END:VEVENT\r\n"); - vevent - } -} - -/// Reminder job service -pub async fn start_reminder_job(engine: Arc) { - use tokio::time::{interval, Duration}; - - let mut ticker = interval(Duration::from_secs(60)); // Check every minute - - loop { - ticker.tick().await; - - match engine.process_reminders().await { - Ok(notifications) => { - for message in notifications { - log::info!("Calendar reminder: {}", message); - // Here you would send actual notifications via email, push, etc. - } - } - Err(e) => { - log::error!("Failed to process calendar reminders: {}", e); - } - } - } -} - -async fn create_event_handler( - State(engine): State>, - Json(event): Json, -) -> Result, StatusCode> { - match engine.create_event(event).await { - Ok(created) => Ok(Json(created)), - Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR), - } -} - -async fn get_events_handler( - State(engine): State>, - Query(params): Query, -) -> Result>, StatusCode> { - if let (Some(start), Some(end)) = (params.start, params.end) { - let start = DateTime::parse_from_rfc3339(&start) - .map(|dt| dt.with_timezone(&Utc)) - .unwrap_or_else(|_| Utc::now()); - let end = DateTime::parse_from_rfc3339(&end) - .map(|dt| dt.with_timezone(&Utc)) - .unwrap_or_else(|_| Utc::now() + chrono::Duration::days(30)); - - match engine.get_events_range(start, end).await { - Ok(events) => Ok(Json(events)), - Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR), - } - } else if let Some(user_id) = params.user_id { - match engine.get_user_events(&user_id).await { - Ok(events) => Ok(Json(events)), - Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR), - } - } else { - Err(StatusCode::BAD_REQUEST) - } -} - -async fn update_event_handler( - State(engine): State>, - Path(id): Path, - Json(updates): Json, -) -> Result, StatusCode> { - match engine.update_event(id, updates).await { - Ok(updated) => Ok(Json(updated)), - Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR), - } -} - -async fn delete_event_handler( - State(engine): State>, - Path(id): Path, -) -> Result { - match engine.delete_event(id).await { - Ok(true) => Ok(StatusCode::NO_CONTENT), - Ok(false) => Err(StatusCode::NOT_FOUND), - Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR), - } -} - -async fn schedule_meeting_handler( - State(engine): State>, - Json(req): Json, -) -> Result, StatusCode> { - match engine.create_meeting(req.event_id, req.platform).await { - Ok(meeting) => Ok(Json(meeting)), - Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR), - } -} - -pub fn routes(engine: Arc) -> Router { +pub fn router(state: Arc) -> Router { Router::new() + .route("/api/calendar/events", get(list_events).post(create_event)) .route( - "/events", - post(create_event_handler).get(get_events_handler), + "/api/calendar/events/:id", + get(get_event).put(update_event).delete(delete_event), ) - .route( - "/events/:id", - put(update_event_handler).delete(delete_event_handler), - ) - .route("/meetings", post(schedule_meeting_handler)) - .with_state(engine) + .with_state(state) } diff --git a/src/compliance/access_review.rs b/src/compliance/access_review.rs new file mode 100644 index 000000000..4e18d5255 --- /dev/null +++ b/src/compliance/access_review.rs @@ -0,0 +1,463 @@ +//! Access Review Module +//! +//! Provides automated access review and permission auditing capabilities +//! for compliance with security policies and regulations. + +use anyhow::{anyhow, Result}; +use chrono::{DateTime, Duration, Utc}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use uuid::Uuid; + +/// Access level enumeration +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum AccessLevel { + Read, + Write, + Admin, + Owner, +} + +/// Resource type enumeration +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ResourceType { + File, + Database, + API, + System, + Application, +} + +/// Access permission structure +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AccessPermission { + pub id: Uuid, + pub user_id: Uuid, + pub resource_id: String, + pub resource_type: ResourceType, + pub access_level: AccessLevel, + pub granted_at: DateTime, + pub granted_by: Uuid, + pub expires_at: Option>, + pub justification: String, + pub is_active: bool, +} + +/// Access review request +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AccessReviewRequest { + pub id: Uuid, + pub user_id: Uuid, + pub reviewer_id: Uuid, + pub permissions: Vec, + pub requested_at: DateTime, + pub due_date: DateTime, + pub status: ReviewStatus, + pub comments: Option, +} + +/// Review status +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ReviewStatus { + Pending, + InProgress, + Approved, + Rejected, + Expired, +} + +/// Access review result +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AccessReviewResult { + pub review_id: Uuid, + pub reviewer_id: Uuid, + pub reviewed_at: DateTime, + pub approved_permissions: Vec, + pub revoked_permissions: Vec, + pub modified_permissions: Vec<(Uuid, AccessLevel)>, + pub comments: String, +} + +/// Access violation detection +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AccessViolation { + pub id: Uuid, + pub user_id: Uuid, + pub resource_id: String, + pub attempted_action: String, + pub denied_reason: String, + pub occurred_at: DateTime, + pub severity: ViolationSeverity, +} + +/// Violation severity levels +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ViolationSeverity { + Low, + Medium, + High, + Critical, +} + +/// Access review service +#[derive(Debug, Clone)] +pub struct AccessReviewService { + permissions: HashMap>, + reviews: HashMap, + violations: Vec, +} + +impl AccessReviewService { + /// Create new access review service + pub fn new() -> Self { + Self { + permissions: HashMap::new(), + reviews: HashMap::new(), + violations: Vec::new(), + } + } + + /// Grant access permission + pub fn grant_permission( + &mut self, + user_id: Uuid, + resource_id: String, + resource_type: ResourceType, + access_level: AccessLevel, + granted_by: Uuid, + justification: String, + expires_in: Option, + ) -> Result { + let permission = AccessPermission { + id: Uuid::new_v4(), + user_id, + resource_id, + resource_type, + access_level, + granted_at: Utc::now(), + granted_by, + expires_at: expires_in.map(|d| Utc::now() + d), + justification, + is_active: true, + }; + + self.permissions + .entry(user_id) + .or_insert_with(Vec::new) + .push(permission.clone()); + + log::info!( + "Granted {} access to user {} for resource {}", + serde_json::to_string(&permission.access_level)?, + user_id, + permission.resource_id + ); + + Ok(permission) + } + + /// Revoke access permission + pub fn revoke_permission(&mut self, permission_id: Uuid, revoked_by: Uuid) -> Result<()> { + for permissions in self.permissions.values_mut() { + if let Some(perm) = permissions.iter_mut().find(|p| p.id == permission_id) { + perm.is_active = false; + log::info!( + "Revoked permission {} for user {} by {}", + permission_id, + perm.user_id, + revoked_by + ); + return Ok(()); + } + } + Err(anyhow!("Permission not found")) + } + + /// Check if user has access + pub fn check_access( + &mut self, + user_id: Uuid, + resource_id: &str, + required_level: AccessLevel, + ) -> Result { + let user_permissions = self.permissions.get(&user_id); + + if let Some(permissions) = user_permissions { + for perm in permissions { + if perm.resource_id == resource_id && perm.is_active { + // Check expiration + if let Some(expires) = perm.expires_at { + if expires < Utc::now() { + continue; + } + } + + // Check access level + if self.has_sufficient_access(&perm.access_level, &required_level) { + return Ok(true); + } + } + } + } + + // Log access denial + let violation = AccessViolation { + id: Uuid::new_v4(), + user_id, + resource_id: resource_id.to_string(), + attempted_action: format!("{:?} access", required_level), + denied_reason: "Insufficient permissions".to_string(), + occurred_at: Utc::now(), + severity: ViolationSeverity::Medium, + }; + + self.violations.push(violation); + + Ok(false) + } + + /// Check if access level is sufficient + fn has_sufficient_access(&self, user_level: &AccessLevel, required: &AccessLevel) -> bool { + match required { + AccessLevel::Read => true, + AccessLevel::Write => matches!( + user_level, + AccessLevel::Write | AccessLevel::Admin | AccessLevel::Owner + ), + AccessLevel::Admin => matches!(user_level, AccessLevel::Admin | AccessLevel::Owner), + AccessLevel::Owner => matches!(user_level, AccessLevel::Owner), + } + } + + /// Create access review request + pub fn create_review_request( + &mut self, + user_id: Uuid, + reviewer_id: Uuid, + days_until_due: i64, + ) -> Result { + let user_permissions = self.permissions.get(&user_id).cloned().unwrap_or_default(); + + let review = AccessReviewRequest { + id: Uuid::new_v4(), + user_id, + reviewer_id, + permissions: user_permissions, + requested_at: Utc::now(), + due_date: Utc::now() + Duration::days(days_until_due), + status: ReviewStatus::Pending, + comments: None, + }; + + self.reviews.insert(review.id, review.clone()); + + log::info!( + "Created access review {} for user {} assigned to {}", + review.id, + user_id, + reviewer_id + ); + + Ok(review) + } + + /// Process access review + pub fn process_review( + &mut self, + review_id: Uuid, + approved: Vec, + revoked: Vec, + modified: Vec<(Uuid, AccessLevel)>, + comments: String, + ) -> Result { + let review = self + .reviews + .get_mut(&review_id) + .ok_or_else(|| anyhow!("Review not found"))?; + + if review.status != ReviewStatus::Pending && review.status != ReviewStatus::InProgress { + return Err(anyhow!("Review already completed")); + } + + // Process revocations + for perm_id in &revoked { + self.revoke_permission(*perm_id, review.reviewer_id)?; + } + + // Process modifications + for (perm_id, new_level) in &modified { + if let Some(permissions) = self.permissions.get_mut(&review.user_id) { + if let Some(perm) = permissions.iter_mut().find(|p| p.id == *perm_id) { + perm.access_level = new_level.clone(); + } + } + } + + review.status = ReviewStatus::Approved; + review.comments = Some(comments.clone()); + + let result = AccessReviewResult { + review_id, + reviewer_id: review.reviewer_id, + reviewed_at: Utc::now(), + approved_permissions: approved, + revoked_permissions: revoked, + modified_permissions: modified, + comments, + }; + + log::info!("Completed access review {} with result", review_id); + + Ok(result) + } + + /// Get expired permissions + pub fn get_expired_permissions(&self) -> Vec { + let now = Utc::now(); + let mut expired = Vec::new(); + + for permissions in self.permissions.values() { + for perm in permissions { + if let Some(expires) = perm.expires_at { + if expires < now && perm.is_active { + expired.push(perm.clone()); + } + } + } + } + + expired + } + + /// Get user permissions + pub fn get_user_permissions(&self, user_id: Uuid) -> Vec { + self.permissions + .get(&user_id) + .cloned() + .unwrap_or_default() + .into_iter() + .filter(|p| p.is_active) + .collect() + } + + /// Get pending reviews + pub fn get_pending_reviews(&self, reviewer_id: Option) -> Vec { + self.reviews + .values() + .filter(|r| { + r.status == ReviewStatus::Pending + && reviewer_id.map_or(true, |id| r.reviewer_id == id) + }) + .cloned() + .collect() + } + + /// Get access violations + pub fn get_violations( + &self, + user_id: Option, + severity: Option, + since: Option>, + ) -> Vec { + self.violations + .iter() + .filter(|v| { + user_id.map_or(true, |id| v.user_id == id) + && severity.as_ref().map_or(true, |s| &v.severity == s) + && since.map_or(true, |d| v.occurred_at >= d) + }) + .cloned() + .collect() + } + + /// Generate access compliance report + pub fn generate_compliance_report(&self) -> AccessComplianceReport { + let total_permissions = self.permissions.values().map(|p| p.len()).sum::(); + + let active_permissions = self + .permissions + .values() + .flat_map(|p| p.iter()) + .filter(|p| p.is_active) + .count(); + + let expired_permissions = self.get_expired_permissions().len(); + + let pending_reviews = self + .reviews + .values() + .filter(|r| r.status == ReviewStatus::Pending) + .count(); + + let violations_last_30_days = self + .violations + .iter() + .filter(|v| v.occurred_at > Utc::now() - Duration::days(30)) + .count(); + + let critical_violations = self + .violations + .iter() + .filter(|v| v.severity == ViolationSeverity::Critical) + .count(); + + AccessComplianceReport { + generated_at: Utc::now(), + total_permissions, + active_permissions, + expired_permissions, + pending_reviews, + violations_last_30_days, + critical_violations, + compliance_score: self.calculate_compliance_score(), + } + } + + /// Calculate compliance score + fn calculate_compliance_score(&self) -> f64 { + let mut score = 100.0; + + // Deduct for expired permissions + let expired = self.get_expired_permissions().len(); + score -= expired as f64 * 2.0; + + // Deduct for overdue reviews + let overdue_reviews = self + .reviews + .values() + .filter(|r| r.status == ReviewStatus::Pending && r.due_date < Utc::now()) + .count(); + score -= overdue_reviews as f64 * 5.0; + + // Deduct for violations + for violation in &self.violations { + match violation.severity { + ViolationSeverity::Low => score -= 1.0, + ViolationSeverity::Medium => score -= 3.0, + ViolationSeverity::High => score -= 5.0, + ViolationSeverity::Critical => score -= 10.0, + } + } + + score.max(0.0).min(100.0) + } +} + +/// Access compliance report +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AccessComplianceReport { + pub generated_at: DateTime, + pub total_permissions: usize, + pub active_permissions: usize, + pub expired_permissions: usize, + pub pending_reviews: usize, + pub violations_last_30_days: usize, + pub critical_violations: usize, + pub compliance_score: f64, +} + +impl Default for AccessReviewService { + fn default() -> Self { + Self::new() + } +} diff --git a/src/compliance/audit.rs b/src/compliance/audit.rs new file mode 100644 index 000000000..48db1e640 --- /dev/null +++ b/src/compliance/audit.rs @@ -0,0 +1,494 @@ +//! Audit Module +//! +//! Provides comprehensive audit logging and tracking capabilities +//! for compliance monitoring and security analysis. + +use anyhow::{anyhow, Result}; +use chrono::{DateTime, Duration, Utc}; +use serde::{Deserialize, Serialize}; +use std::collections::{HashMap, VecDeque}; +use std::sync::Arc; +use tokio::sync::RwLock; +use uuid::Uuid; + +/// Audit event types +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] +pub enum AuditEventType { + UserLogin, + UserLogout, + PasswordChange, + PermissionGranted, + PermissionRevoked, + DataAccess, + DataModification, + DataDeletion, + ConfigurationChange, + SecurityAlert, + SystemError, + ComplianceViolation, +} + +/// Audit severity levels +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Hash)] +pub enum AuditSeverity { + Info, + Warning, + Error, + Critical, +} + +/// Audit event structure +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AuditEvent { + pub id: Uuid, + pub timestamp: DateTime, + pub event_type: AuditEventType, + pub severity: AuditSeverity, + pub user_id: Option, + pub session_id: Option, + pub ip_address: Option, + pub resource_id: Option, + pub action: String, + pub outcome: AuditOutcome, + pub details: HashMap, + pub metadata: serde_json::Value, +} + +/// Audit outcome +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] +pub enum AuditOutcome { + Success, + Failure, + Partial, + Unknown, +} + +/// Audit trail for tracking related events +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AuditTrail { + pub trail_id: Uuid, + pub name: String, + pub started_at: DateTime, + pub ended_at: Option>, + pub events: Vec, + pub summary: String, + pub tags: Vec, +} + +/// Audit retention policy +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RetentionPolicy { + pub name: String, + pub retention_days: i64, + pub event_types: Vec, + pub severity_threshold: Option, + pub archive_enabled: bool, +} + +/// Audit statistics +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AuditStatistics { + pub total_events: usize, + pub events_by_type: HashMap, + pub events_by_severity: HashMap, + pub events_by_outcome: HashMap, + pub unique_users: usize, + pub time_range: (DateTime, DateTime), +} + +/// Audit service for managing audit logs +#[derive(Clone)] +pub struct AuditService { + events: Arc>>, + trails: Arc>>, + retention_policies: Arc>>, + max_events: usize, +} + +impl AuditService { + /// Create new audit service + pub fn new(max_events: usize) -> Self { + Self { + events: Arc::new(RwLock::new(VecDeque::new())), + trails: Arc::new(RwLock::new(HashMap::new())), + retention_policies: Arc::new(RwLock::new(vec![ + // Default retention policies + RetentionPolicy { + name: "Security Events".to_string(), + retention_days: 365, + event_types: vec![ + AuditEventType::SecurityAlert, + AuditEventType::ComplianceViolation, + ], + severity_threshold: Some(AuditSeverity::Warning), + archive_enabled: true, + }, + RetentionPolicy { + name: "Access Logs".to_string(), + retention_days: 90, + event_types: vec![ + AuditEventType::UserLogin, + AuditEventType::UserLogout, + AuditEventType::DataAccess, + ], + severity_threshold: None, + archive_enabled: false, + }, + ])), + max_events, + } + } + + /// Log an audit event + pub async fn log_event( + &self, + event_type: AuditEventType, + severity: AuditSeverity, + user_id: Option, + action: String, + outcome: AuditOutcome, + details: HashMap, + ) -> Result { + let event = AuditEvent { + id: Uuid::new_v4(), + timestamp: Utc::now(), + event_type: event_type.clone(), + severity: severity.clone(), + user_id, + session_id: None, + ip_address: None, + resource_id: None, + action, + outcome, + details, + metadata: serde_json::json!({}), + }; + + let event_id = event.id; + + // Add event to the queue + let mut events = self.events.write().await; + events.push_back(event.clone()); + + // Maintain max events limit + while events.len() > self.max_events { + events.pop_front(); + } + + log::info!( + "Audit event logged: {} - {:?} - {:?}", + event_id, + event_type, + severity + ); + + Ok(event_id) + } + + /// Create an audit trail + pub async fn create_trail(&self, name: String, tags: Vec) -> Result { + let trail = AuditTrail { + trail_id: Uuid::new_v4(), + name, + started_at: Utc::now(), + ended_at: None, + events: Vec::new(), + summary: String::new(), + tags, + }; + + let trail_id = trail.trail_id; + let mut trails = self.trails.write().await; + trails.insert(trail_id, trail); + + Ok(trail_id) + } + + /// Add event to trail + pub async fn add_to_trail(&self, trail_id: Uuid, event_id: Uuid) -> Result<()> { + let mut trails = self.trails.write().await; + let trail = trails + .get_mut(&trail_id) + .ok_or_else(|| anyhow!("Trail not found"))?; + + if trail.ended_at.is_some() { + return Err(anyhow!("Trail already ended")); + } + + trail.events.push(event_id); + Ok(()) + } + + /// End an audit trail + pub async fn end_trail(&self, trail_id: Uuid, summary: String) -> Result<()> { + let mut trails = self.trails.write().await; + let trail = trails + .get_mut(&trail_id) + .ok_or_else(|| anyhow!("Trail not found"))?; + + trail.ended_at = Some(Utc::now()); + trail.summary = summary; + + Ok(()) + } + + /// Query audit events + pub async fn query_events(&self, filter: AuditFilter) -> Result> { + let events = self.events.read().await; + + let filtered: Vec = events + .iter() + .filter(|e| filter.matches(e)) + .cloned() + .collect(); + + Ok(filtered) + } + + /// Get audit statistics + pub async fn get_statistics( + &self, + since: Option>, + until: Option>, + ) -> AuditStatistics { + let events = self.events.read().await; + let since = since.unwrap_or(Utc::now() - Duration::days(30)); + let until = until.unwrap_or(Utc::now()); + + let filtered_events: Vec<_> = events + .iter() + .filter(|e| e.timestamp >= since && e.timestamp <= until) + .collect(); + + let mut events_by_type = HashMap::new(); + let mut events_by_severity = HashMap::new(); + let mut events_by_outcome = HashMap::new(); + let mut unique_users = std::collections::HashSet::new(); + + for event in &filtered_events { + *events_by_type.entry(event.event_type.clone()).or_insert(0) += 1; + *events_by_severity + .entry(event.severity.clone()) + .or_insert(0) += 1; + *events_by_outcome.entry(event.outcome.clone()).or_insert(0) += 1; + + if let Some(user_id) = event.user_id { + unique_users.insert(user_id); + } + } + + AuditStatistics { + total_events: filtered_events.len(), + events_by_type, + events_by_severity, + events_by_outcome, + unique_users: unique_users.len(), + time_range: (since, until), + } + } + + /// Apply retention policies + pub async fn apply_retention_policies(&self) -> Result { + let policies = self.retention_policies.read().await; + let mut events = self.events.write().await; + let now = Utc::now(); + let mut removed_count = 0; + + for policy in policies.iter() { + let cutoff = now - Duration::days(policy.retention_days); + + // Remove events older than retention period + let initial_len = events.len(); + events.retain(|e| { + if !policy.event_types.contains(&e.event_type) { + return true; + } + + if let Some(threshold) = &policy.severity_threshold { + if e.severity < *threshold { + return true; + } + } + + e.timestamp >= cutoff + }); + + removed_count += initial_len - events.len(); + } + + log::info!( + "Applied retention policies, removed {} events", + removed_count + ); + Ok(removed_count) + } + + /// Export audit logs + pub async fn export_logs( + &self, + format: ExportFormat, + filter: Option, + ) -> Result> { + let events = self.query_events(filter.unwrap_or_default()).await?; + + match format { + ExportFormat::Json => { + let json = serde_json::to_vec_pretty(&events)?; + Ok(json) + } + ExportFormat::Csv => { + let mut csv_writer = csv::Writer::from_writer(vec![]); + + // Write headers + csv_writer.write_record(&[ + "ID", + "Timestamp", + "Type", + "Severity", + "User", + "Action", + "Outcome", + ])?; + + // Write records + for event in events { + csv_writer.write_record(&[ + event.id.to_string(), + event.timestamp.to_rfc3339(), + format!("{:?}", event.event_type), + format!("{:?}", event.severity), + event.user_id.map(|u| u.to_string()).unwrap_or_default(), + event.action, + format!("{:?}", event.outcome), + ])?; + } + + Ok(csv_writer.into_inner()?) + } + } + } + + /// Get compliance report + pub async fn get_compliance_report(&self) -> ComplianceReport { + let stats = self.get_statistics(None, None).await; + let events = self.events.read().await; + + let security_incidents = events + .iter() + .filter(|e| e.event_type == AuditEventType::SecurityAlert) + .count(); + + let compliance_violations = events + .iter() + .filter(|e| e.event_type == AuditEventType::ComplianceViolation) + .count(); + + let failed_logins = events + .iter() + .filter(|e| { + e.event_type == AuditEventType::UserLogin && e.outcome == AuditOutcome::Failure + }) + .count(); + + ComplianceReport { + generated_at: Utc::now(), + total_events: stats.total_events, + security_incidents, + compliance_violations, + failed_logins, + unique_users: stats.unique_users, + critical_events: stats + .events_by_severity + .get(&AuditSeverity::Critical) + .copied() + .unwrap_or(0), + audit_coverage: self.calculate_audit_coverage(&events), + } + } + + /// Calculate audit coverage percentage + fn calculate_audit_coverage(&self, events: &VecDeque) -> f64 { + // Calculate based on expected event types coverage + let expected_types = 12; // Total number of event types + let covered_types = events + .iter() + .map(|e| e.event_type.clone()) + .collect::>() + .len(); + + (covered_types as f64 / expected_types as f64) * 100.0 + } +} + +/// Audit filter for querying events +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct AuditFilter { + pub event_types: Option>, + pub severity: Option, + pub user_id: Option, + pub since: Option>, + pub until: Option>, + pub outcome: Option, +} + +impl AuditFilter { + fn matches(&self, event: &AuditEvent) -> bool { + if let Some(types) = &self.event_types { + if !types.contains(&event.event_type) { + return false; + } + } + + if let Some(severity) = &self.severity { + if event.severity < *severity { + return false; + } + } + + if let Some(user_id) = &self.user_id { + if event.user_id != Some(*user_id) { + return false; + } + } + + if let Some(since) = &self.since { + if event.timestamp < *since { + return false; + } + } + + if let Some(until) = &self.until { + if event.timestamp > *until { + return false; + } + } + + if let Some(outcome) = &self.outcome { + if event.outcome != *outcome { + return false; + } + } + + true + } +} + +/// Export format for audit logs +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ExportFormat { + Json, + Csv, +} + +/// Compliance report +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ComplianceReport { + pub generated_at: DateTime, + pub total_events: usize, + pub security_incidents: usize, + pub compliance_violations: usize, + pub failed_logins: usize, + pub unique_users: usize, + pub critical_events: usize, + pub audit_coverage: f64, +} diff --git a/src/compliance/policy_checker.rs b/src/compliance/policy_checker.rs new file mode 100644 index 000000000..634a96bcd --- /dev/null +++ b/src/compliance/policy_checker.rs @@ -0,0 +1,518 @@ +//! Policy Checker Module +//! +//! Provides automated security policy checking and enforcement +//! for compliance with organizational and regulatory requirements. + +use anyhow::{anyhow, Result}; +use chrono::{DateTime, Duration, Utc}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use uuid::Uuid; + +/// Policy type enumeration +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PolicyType { + AccessControl, + DataRetention, + PasswordStrength, + SessionTimeout, + EncryptionRequired, + AuditLogging, + BackupFrequency, + NetworkSecurity, + ComplianceStandard, +} + +/// Policy status +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PolicyStatus { + Active, + Draft, + Deprecated, + Archived, +} + +/// Policy severity +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum PolicySeverity { + Low, + Medium, + High, + Critical, +} + +/// Security policy definition +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SecurityPolicy { + pub id: Uuid, + pub name: String, + pub description: String, + pub policy_type: PolicyType, + pub status: PolicyStatus, + pub severity: PolicySeverity, + pub rules: Vec, + pub created_at: DateTime, + pub updated_at: DateTime, + pub effective_date: DateTime, + pub expiry_date: Option>, + pub tags: Vec, +} + +/// Policy rule +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PolicyRule { + pub id: Uuid, + pub name: String, + pub condition: String, + pub action: PolicyAction, + pub parameters: HashMap, +} + +/// Policy action +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum PolicyAction { + Allow, + Deny, + Alert, + Enforce, + Log, +} + +/// Policy violation +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PolicyViolation { + pub id: Uuid, + pub policy_id: Uuid, + pub rule_id: Uuid, + pub timestamp: DateTime, + pub user_id: Option, + pub resource: String, + pub action_attempted: String, + pub violation_details: String, + pub severity: PolicySeverity, + pub resolved: bool, +} + +/// Policy check result +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PolicyCheckResult { + pub policy_id: Uuid, + pub passed: bool, + pub violations: Vec, + pub warnings: Vec, + pub timestamp: DateTime, +} + +/// Policy checker service +#[derive(Debug, Clone)] +pub struct PolicyChecker { + policies: HashMap, + violations: Vec, + check_history: Vec, +} + +impl PolicyChecker { + /// Create new policy checker + pub fn new() -> Self { + let mut checker = Self { + policies: HashMap::new(), + violations: Vec::new(), + check_history: Vec::new(), + }; + + // Initialize with default policies + checker.initialize_default_policies(); + checker + } + + /// Initialize default security policies + fn initialize_default_policies(&mut self) { + // Password policy + let password_policy = SecurityPolicy { + id: Uuid::new_v4(), + name: "Password Strength Policy".to_string(), + description: "Enforces strong password requirements".to_string(), + policy_type: PolicyType::PasswordStrength, + status: PolicyStatus::Active, + severity: PolicySeverity::High, + rules: vec![ + PolicyRule { + id: Uuid::new_v4(), + name: "Minimum Length".to_string(), + condition: "password.length >= 12".to_string(), + action: PolicyAction::Enforce, + parameters: HashMap::from([("min_length".to_string(), "12".to_string())]), + }, + PolicyRule { + id: Uuid::new_v4(), + name: "Complexity Requirements".to_string(), + condition: "has_uppercase && has_lowercase && has_digit && has_special" + .to_string(), + action: PolicyAction::Enforce, + parameters: HashMap::new(), + }, + ], + created_at: Utc::now(), + updated_at: Utc::now(), + effective_date: Utc::now(), + expiry_date: None, + tags: vec!["security".to_string(), "authentication".to_string()], + }; + + self.policies.insert(password_policy.id, password_policy); + + // Session timeout policy + let session_policy = SecurityPolicy { + id: Uuid::new_v4(), + name: "Session Timeout Policy".to_string(), + description: "Enforces session timeout limits".to_string(), + policy_type: PolicyType::SessionTimeout, + status: PolicyStatus::Active, + severity: PolicySeverity::Medium, + rules: vec![PolicyRule { + id: Uuid::new_v4(), + name: "Maximum Session Duration".to_string(), + condition: "session.duration <= 8_hours".to_string(), + action: PolicyAction::Enforce, + parameters: HashMap::from([ + ("max_duration".to_string(), "28800".to_string()), // 8 hours in seconds + ]), + }], + created_at: Utc::now(), + updated_at: Utc::now(), + effective_date: Utc::now(), + expiry_date: None, + tags: vec!["security".to_string(), "session".to_string()], + }; + + self.policies.insert(session_policy.id, session_policy); + } + + /// Add a security policy + pub fn add_policy(&mut self, policy: SecurityPolicy) -> Result<()> { + if self.policies.contains_key(&policy.id) { + return Err(anyhow!("Policy already exists")); + } + + log::info!("Adding security policy: {}", policy.name); + self.policies.insert(policy.id, policy); + Ok(()) + } + + /// Update a security policy + pub fn update_policy(&mut self, policy_id: Uuid, updates: SecurityPolicy) -> Result<()> { + if let Some(existing) = self.policies.get_mut(&policy_id) { + *existing = updates; + existing.updated_at = Utc::now(); + log::info!("Updated policy: {}", existing.name); + Ok(()) + } else { + Err(anyhow!("Policy not found")) + } + } + + /// Check password against policy + pub fn check_password_policy(&mut self, password: &str) -> PolicyCheckResult { + let policy = self + .policies + .values() + .find(|p| { + p.policy_type == PolicyType::PasswordStrength && p.status == PolicyStatus::Active + }) + .cloned(); + + if let Some(policy) = policy { + let mut violations = Vec::new(); + let mut warnings = Vec::new(); + + // Check minimum length + if password.len() < 12 { + violations.push(PolicyViolation { + id: Uuid::new_v4(), + policy_id: policy.id, + rule_id: policy.rules[0].id, + timestamp: Utc::now(), + user_id: None, + resource: "password".to_string(), + action_attempted: "set_password".to_string(), + violation_details: format!( + "Password length {} is less than required 12", + password.len() + ), + severity: PolicySeverity::High, + resolved: false, + }); + } + + // Check complexity + let has_uppercase = password.chars().any(|c| c.is_uppercase()); + let has_lowercase = password.chars().any(|c| c.is_lowercase()); + let has_digit = password.chars().any(|c| c.is_numeric()); + let has_special = password.chars().any(|c| !c.is_alphanumeric()); + + if !(has_uppercase && has_lowercase && has_digit && has_special) { + violations.push(PolicyViolation { + id: Uuid::new_v4(), + policy_id: policy.id, + rule_id: policy.rules[1].id, + timestamp: Utc::now(), + user_id: None, + resource: "password".to_string(), + action_attempted: "set_password".to_string(), + violation_details: "Password does not meet complexity requirements".to_string(), + severity: PolicySeverity::High, + resolved: false, + }); + } + + // Add warnings for common patterns + if password.to_lowercase().contains("password") { + warnings.push("Password contains the word 'password'".to_string()); + } + + let result = PolicyCheckResult { + policy_id: policy.id, + passed: violations.is_empty(), + violations: violations.clone(), + warnings, + timestamp: Utc::now(), + }; + + self.violations.extend(violations); + self.check_history.push(result.clone()); + + result + } else { + PolicyCheckResult { + policy_id: Uuid::nil(), + passed: true, + violations: Vec::new(), + warnings: vec!["No password policy configured".to_string()], + timestamp: Utc::now(), + } + } + } + + /// Check session against policy + pub fn check_session_policy(&mut self, session_duration_seconds: u64) -> PolicyCheckResult { + let policy = self + .policies + .values() + .find(|p| { + p.policy_type == PolicyType::SessionTimeout && p.status == PolicyStatus::Active + }) + .cloned(); + + if let Some(policy) = policy { + let mut violations = Vec::new(); + + if session_duration_seconds > 28800 { + // 8 hours + violations.push(PolicyViolation { + id: Uuid::new_v4(), + policy_id: policy.id, + rule_id: policy.rules[0].id, + timestamp: Utc::now(), + user_id: None, + resource: "session".to_string(), + action_attempted: "extend_session".to_string(), + violation_details: format!( + "Session duration {} seconds exceeds maximum 28800 seconds", + session_duration_seconds + ), + severity: PolicySeverity::Medium, + resolved: false, + }); + } + + let result = PolicyCheckResult { + policy_id: policy.id, + passed: violations.is_empty(), + violations: violations.clone(), + warnings: Vec::new(), + timestamp: Utc::now(), + }; + + self.violations.extend(violations); + self.check_history.push(result.clone()); + + result + } else { + PolicyCheckResult { + policy_id: Uuid::nil(), + passed: true, + violations: Vec::new(), + warnings: vec!["No session policy configured".to_string()], + timestamp: Utc::now(), + } + } + } + + /// Check all active policies + pub fn check_all_policies(&mut self, context: &PolicyContext) -> Vec { + let mut results = Vec::new(); + + for policy in self.policies.values() { + if policy.status != PolicyStatus::Active { + continue; + } + + let result = self.check_policy(policy.id, context); + if let Ok(result) = result { + results.push(result); + } + } + + results + } + + /// Check a specific policy + pub fn check_policy( + &mut self, + policy_id: Uuid, + context: &PolicyContext, + ) -> Result { + let policy = self + .policies + .get(&policy_id) + .ok_or_else(|| anyhow!("Policy not found"))? + .clone(); + + let mut violations = Vec::new(); + let mut warnings = Vec::new(); + + for rule in &policy.rules { + if !self.evaluate_rule(rule, context) { + violations.push(PolicyViolation { + id: Uuid::new_v4(), + policy_id: policy.id, + rule_id: rule.id, + timestamp: Utc::now(), + user_id: context.user_id, + resource: context.resource.clone(), + action_attempted: context.action.clone(), + violation_details: format!("Rule '{}' failed", rule.name), + severity: policy.severity.clone(), + resolved: false, + }); + } + } + + let result = PolicyCheckResult { + policy_id: policy.id, + passed: violations.is_empty(), + violations: violations.clone(), + warnings, + timestamp: Utc::now(), + }; + + self.violations.extend(violations); + self.check_history.push(result.clone()); + + Ok(result) + } + + /// Evaluate a policy rule + fn evaluate_rule(&self, rule: &PolicyRule, _context: &PolicyContext) -> bool { + // Simplified rule evaluation + // In production, this would use a proper expression evaluator + match rule.action { + PolicyAction::Allow => true, + PolicyAction::Deny => false, + _ => true, // For Alert, Enforce, Log actions, consider as passed but take action + } + } + + /// Get policy violations + pub fn get_violations(&self, unresolved_only: bool) -> Vec { + if unresolved_only { + self.violations + .iter() + .filter(|v| !v.resolved) + .cloned() + .collect() + } else { + self.violations.clone() + } + } + + /// Resolve a violation + pub fn resolve_violation(&mut self, violation_id: Uuid) -> Result<()> { + if let Some(violation) = self.violations.iter_mut().find(|v| v.id == violation_id) { + violation.resolved = true; + log::info!("Resolved violation: {}", violation_id); + Ok(()) + } else { + Err(anyhow!("Violation not found")) + } + } + + /// Get policy compliance report + pub fn get_compliance_report(&self) -> PolicyComplianceReport { + let total_policies = self.policies.len(); + let active_policies = self + .policies + .values() + .filter(|p| p.status == PolicyStatus::Active) + .count(); + let total_violations = self.violations.len(); + let unresolved_violations = self.violations.iter().filter(|v| !v.resolved).count(); + let critical_violations = self + .violations + .iter() + .filter(|v| v.severity == PolicySeverity::Critical) + .count(); + + let recent_checks = self + .check_history + .iter() + .filter(|c| c.timestamp > Utc::now() - Duration::days(7)) + .count(); + + let compliance_rate = if !self.check_history.is_empty() { + let passed = self.check_history.iter().filter(|c| c.passed).count(); + (passed as f64 / self.check_history.len() as f64) * 100.0 + } else { + 100.0 + }; + + PolicyComplianceReport { + generated_at: Utc::now(), + total_policies, + active_policies, + total_violations, + unresolved_violations, + critical_violations, + recent_checks, + compliance_rate, + } + } +} + +/// Policy context for evaluation +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PolicyContext { + pub user_id: Option, + pub resource: String, + pub action: String, + pub parameters: HashMap, +} + +/// Policy compliance report +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PolicyComplianceReport { + pub generated_at: DateTime, + pub total_policies: usize, + pub active_policies: usize, + pub total_violations: usize, + pub unresolved_violations: usize, + pub critical_violations: usize, + pub recent_checks: usize, + pub compliance_rate: f64, +} + +impl Default for PolicyChecker { + fn default() -> Self { + Self::new() + } +} diff --git a/src/compliance/risk_assessment.rs b/src/compliance/risk_assessment.rs new file mode 100644 index 000000000..505d55b21 --- /dev/null +++ b/src/compliance/risk_assessment.rs @@ -0,0 +1,534 @@ +//! Risk Assessment Module +//! +//! Provides comprehensive risk assessment and management capabilities +//! for identifying, evaluating, and mitigating security risks. + +use anyhow::{anyhow, Result}; +use chrono::{DateTime, Duration, Utc}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use uuid::Uuid; + +/// Risk category enumeration +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum RiskCategory { + Security, + Compliance, + Operational, + Financial, + Reputational, + Technical, + Legal, +} + +/// Risk likelihood levels +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum Likelihood { + Rare, + Unlikely, + Possible, + Likely, + AlmostCertain, +} + +/// Risk impact levels +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum Impact { + Negligible, + Minor, + Moderate, + Major, + Catastrophic, +} + +/// Risk level based on likelihood and impact +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum RiskLevel { + Low, + Medium, + High, + Critical, +} + +/// Risk status +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum RiskStatus { + Identified, + Assessed, + Mitigating, + Monitoring, + Accepted, + Closed, +} + +/// Risk assessment +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RiskAssessment { + pub id: Uuid, + pub title: String, + pub description: String, + pub category: RiskCategory, + pub likelihood: Likelihood, + pub impact: Impact, + pub risk_level: RiskLevel, + pub status: RiskStatus, + pub identified_date: DateTime, + pub assessed_date: Option>, + pub owner: String, + pub affected_assets: Vec, + pub vulnerabilities: Vec, + pub threats: Vec, + pub controls: Vec, + pub residual_risk: Option, +} + +/// Vulnerability definition +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Vulnerability { + pub id: Uuid, + pub name: String, + pub description: String, + pub severity: RiskLevel, + pub cve_id: Option, + pub discovered_date: DateTime, + pub patched: bool, +} + +/// Threat definition +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Threat { + pub id: Uuid, + pub name: String, + pub description: String, + pub threat_actor: String, + pub likelihood: Likelihood, + pub tactics: Vec, +} + +/// Control measure +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Control { + pub id: Uuid, + pub name: String, + pub description: String, + pub control_type: ControlType, + pub effectiveness: Effectiveness, + pub implementation_status: ImplementationStatus, + pub cost: f64, +} + +/// Control type +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ControlType { + Preventive, + Detective, + Corrective, + Compensating, +} + +/// Control effectiveness +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum Effectiveness { + Ineffective, + PartiallyEffective, + Effective, + HighlyEffective, +} + +/// Implementation status +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ImplementationStatus { + Planned, + InProgress, + Implemented, + Verified, +} + +/// Risk mitigation plan +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MitigationPlan { + pub id: Uuid, + pub risk_id: Uuid, + pub strategy: MitigationStrategy, + pub actions: Vec, + pub timeline: Duration, + pub budget: f64, + pub responsible_party: String, + pub approval_status: ApprovalStatus, +} + +/// Mitigation strategy +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum MitigationStrategy { + Avoid, + Transfer, + Mitigate, + Accept, +} + +/// Mitigation action +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MitigationAction { + pub id: Uuid, + pub description: String, + pub due_date: DateTime, + pub assigned_to: String, + pub completed: bool, +} + +/// Approval status +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ApprovalStatus { + Pending, + Approved, + Rejected, +} + +/// Risk assessment service +#[derive(Debug, Clone)] +pub struct RiskAssessmentService { + assessments: HashMap, + mitigation_plans: HashMap, + risk_matrix: RiskMatrix, +} + +impl RiskAssessmentService { + /// Create new risk assessment service + pub fn new() -> Self { + Self { + assessments: HashMap::new(), + mitigation_plans: HashMap::new(), + risk_matrix: RiskMatrix::default(), + } + } + + /// Create a new risk assessment + pub fn create_assessment( + &mut self, + title: String, + description: String, + category: RiskCategory, + owner: String, + ) -> Result { + let assessment = RiskAssessment { + id: Uuid::new_v4(), + title, + description, + category, + likelihood: Likelihood::Possible, + impact: Impact::Moderate, + risk_level: RiskLevel::Medium, + status: RiskStatus::Identified, + identified_date: Utc::now(), + assessed_date: None, + owner, + affected_assets: Vec::new(), + vulnerabilities: Vec::new(), + threats: Vec::new(), + controls: Vec::new(), + residual_risk: None, + }; + + self.assessments.insert(assessment.id, assessment.clone()); + log::info!("Created risk assessment: {}", assessment.id); + + Ok(assessment) + } + + /// Assess risk level + pub fn assess_risk( + &mut self, + risk_id: Uuid, + likelihood: Likelihood, + impact: Impact, + ) -> Result { + let assessment = self + .assessments + .get_mut(&risk_id) + .ok_or_else(|| anyhow!("Risk assessment not found"))?; + + assessment.likelihood = likelihood.clone(); + assessment.impact = impact.clone(); + assessment.risk_level = self.risk_matrix.calculate_risk_level(&likelihood, &impact); + assessment.assessed_date = Some(Utc::now()); + assessment.status = RiskStatus::Assessed; + + log::info!( + "Assessed risk {}: level = {:?}", + risk_id, + assessment.risk_level + ); + + Ok(assessment.risk_level.clone()) + } + + /// Add vulnerability to risk assessment + pub fn add_vulnerability( + &mut self, + risk_id: Uuid, + vulnerability: Vulnerability, + ) -> Result<()> { + let assessment = self + .assessments + .get_mut(&risk_id) + .ok_or_else(|| anyhow!("Risk assessment not found"))?; + + assessment.vulnerabilities.push(vulnerability); + self.recalculate_risk_level(risk_id)?; + + Ok(()) + } + + /// Add threat to risk assessment + pub fn add_threat(&mut self, risk_id: Uuid, threat: Threat) -> Result<()> { + let assessment = self + .assessments + .get_mut(&risk_id) + .ok_or_else(|| anyhow!("Risk assessment not found"))?; + + assessment.threats.push(threat); + self.recalculate_risk_level(risk_id)?; + + Ok(()) + } + + /// Add control to risk assessment + pub fn add_control(&mut self, risk_id: Uuid, control: Control) -> Result<()> { + let assessment = self + .assessments + .get_mut(&risk_id) + .ok_or_else(|| anyhow!("Risk assessment not found"))?; + + assessment.controls.push(control); + self.calculate_residual_risk(risk_id)?; + + Ok(()) + } + + /// Recalculate risk level based on vulnerabilities and threats + fn recalculate_risk_level(&mut self, risk_id: Uuid) -> Result<()> { + let assessment = self + .assessments + .get_mut(&risk_id) + .ok_or_else(|| anyhow!("Risk assessment not found"))?; + + // Adjust likelihood based on threats + if !assessment.threats.is_empty() { + let max_threat_likelihood = assessment + .threats + .iter() + .map(|t| &t.likelihood) + .max() + .cloned() + .unwrap_or(Likelihood::Possible); + + if max_threat_likelihood > assessment.likelihood { + assessment.likelihood = max_threat_likelihood; + } + } + + // Adjust impact based on vulnerabilities + if !assessment.vulnerabilities.is_empty() { + let critical_vulns = assessment + .vulnerabilities + .iter() + .filter(|v| v.severity == RiskLevel::Critical) + .count(); + + if critical_vulns > 0 && assessment.impact < Impact::Major { + assessment.impact = Impact::Major; + } + } + + assessment.risk_level = + self.risk_matrix + .calculate_risk_level(&assessment.likelihood, &assessment.impact); + + Ok(()) + } + + /// Calculate residual risk after controls + fn calculate_residual_risk(&mut self, risk_id: Uuid) -> Result<()> { + let assessment = self + .assessments + .get_mut(&risk_id) + .ok_or_else(|| anyhow!("Risk assessment not found"))?; + + if assessment.controls.is_empty() { + assessment.residual_risk = Some(assessment.risk_level.clone()); + return Ok(()); + } + + // Calculate effectiveness of controls + let effective_controls = assessment + .controls + .iter() + .filter(|c| { + c.effectiveness == Effectiveness::Effective + || c.effectiveness == Effectiveness::HighlyEffective + }) + .count(); + + let residual = match (assessment.risk_level.clone(), effective_controls) { + (RiskLevel::Critical, n) if n >= 3 => RiskLevel::High, + (RiskLevel::Critical, n) if n >= 1 => RiskLevel::Critical, + (RiskLevel::High, n) if n >= 2 => RiskLevel::Medium, + (RiskLevel::High, n) if n >= 1 => RiskLevel::High, + (RiskLevel::Medium, n) if n >= 1 => RiskLevel::Low, + (level, _) => level, + }; + + assessment.residual_risk = Some(residual); + + Ok(()) + } + + /// Create mitigation plan + pub fn create_mitigation_plan( + &mut self, + risk_id: Uuid, + strategy: MitigationStrategy, + timeline: Duration, + budget: f64, + responsible_party: String, + ) -> Result { + if !self.assessments.contains_key(&risk_id) { + return Err(anyhow!("Risk assessment not found")); + } + + let plan = MitigationPlan { + id: Uuid::new_v4(), + risk_id, + strategy, + actions: Vec::new(), + timeline, + budget, + responsible_party, + approval_status: ApprovalStatus::Pending, + }; + + self.mitigation_plans.insert(plan.id, plan.clone()); + log::info!("Created mitigation plan {} for risk {}", plan.id, risk_id); + + Ok(plan) + } + + /// Get high-risk assessments + pub fn get_high_risk_assessments(&self) -> Vec { + self.assessments + .values() + .filter(|a| a.risk_level >= RiskLevel::High) + .cloned() + .collect() + } + + /// Get risk dashboard + pub fn get_risk_dashboard(&self) -> RiskDashboard { + let total_risks = self.assessments.len(); + let mut risks_by_level = HashMap::new(); + let mut risks_by_category = HashMap::new(); + let mut risks_by_status = HashMap::new(); + + for assessment in self.assessments.values() { + *risks_by_level + .entry(assessment.risk_level.clone()) + .or_insert(0) += 1; + *risks_by_category + .entry(assessment.category.clone()) + .or_insert(0) += 1; + *risks_by_status + .entry(assessment.status.clone()) + .or_insert(0) += 1; + } + + let mitigation_plans_pending = self + .mitigation_plans + .values() + .filter(|p| p.approval_status == ApprovalStatus::Pending) + .count(); + + RiskDashboard { + total_risks, + risks_by_level, + risks_by_category, + risks_by_status, + mitigation_plans_pending, + last_updated: Utc::now(), + } + } +} + +/// Risk matrix for calculating risk levels +#[derive(Debug, Clone)] +pub struct RiskMatrix { + matrix: HashMap<(Likelihood, Impact), RiskLevel>, +} + +impl RiskMatrix { + /// Calculate risk level based on likelihood and impact + pub fn calculate_risk_level(&self, likelihood: &Likelihood, impact: &Impact) -> RiskLevel { + self.matrix + .get(&(likelihood.clone(), impact.clone())) + .cloned() + .unwrap_or(RiskLevel::Medium) + } +} + +impl Default for RiskMatrix { + fn default() -> Self { + let mut matrix = HashMap::new(); + + // Define risk matrix + matrix.insert((Likelihood::Rare, Impact::Negligible), RiskLevel::Low); + matrix.insert((Likelihood::Rare, Impact::Minor), RiskLevel::Low); + matrix.insert((Likelihood::Rare, Impact::Moderate), RiskLevel::Low); + matrix.insert((Likelihood::Rare, Impact::Major), RiskLevel::Medium); + matrix.insert((Likelihood::Rare, Impact::Catastrophic), RiskLevel::High); + + matrix.insert((Likelihood::Unlikely, Impact::Negligible), RiskLevel::Low); + matrix.insert((Likelihood::Unlikely, Impact::Minor), RiskLevel::Low); + matrix.insert((Likelihood::Unlikely, Impact::Moderate), RiskLevel::Medium); + matrix.insert((Likelihood::Unlikely, Impact::Major), RiskLevel::High); + matrix.insert((Likelihood::Unlikely, Impact::Catastrophic), RiskLevel::High); + + matrix.insert((Likelihood::Possible, Impact::Negligible), RiskLevel::Low); + matrix.insert((Likelihood::Possible, Impact::Minor), RiskLevel::Medium); + matrix.insert((Likelihood::Possible, Impact::Moderate), RiskLevel::Medium); + matrix.insert((Likelihood::Possible, Impact::Major), RiskLevel::High); + matrix.insert((Likelihood::Possible, Impact::Catastrophic), RiskLevel::Critical); + + matrix.insert((Likelihood::Likely, Impact::Negligible), RiskLevel::Medium); + matrix.insert((Likelihood::Likely, Impact::Minor), RiskLevel::Medium); + matrix.insert((Likelihood::Likely, Impact::Moderate), RiskLevel::High); + matrix.insert((Likelihood::Likely, Impact::Major), RiskLevel::Critical); + matrix.insert((Likelihood::Likely, Impact::Catastrophic), RiskLevel::Critical); + + matrix.insert((Likelihood::AlmostCertain, Impact::Negligible), RiskLevel::Medium); + matrix.insert((Likelihood::AlmostCertain, Impact::Minor), RiskLevel::High); + matrix.insert((Likelihood::AlmostCertain, Impact::Moderate), RiskLevel::High); + matrix.insert((Likelihood::AlmostCertain, Impact::Major), RiskLevel::Critical); + matrix.insert( + (Likelihood::AlmostCertain, Impact::Catastrophic), + RiskLevel::Critical, + ); + + Self { matrix } + } +} + +/// Risk dashboard +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RiskDashboard { + pub total_risks: usize, + pub risks_by_level: HashMap, + pub risks_by_category: HashMap, + pub risks_by_status: HashMap, + pub mitigation_plans_pending: usize, + pub last_updated: DateTime, +} + +impl Default for RiskAssessmentService { + fn default() -> Self { + Self::new() + } +} diff --git a/src/compliance/training_tracker.rs b/src/compliance/training_tracker.rs new file mode 100644 index 000000000..bc9d86e72 --- /dev/null +++ b/src/compliance/training_tracker.rs @@ -0,0 +1,501 @@ +//! Training Tracker Module +//! +//! Provides comprehensive security training tracking and compliance +//! management for ensuring all personnel meet training requirements. + +use anyhow::{anyhow, Result}; +use chrono::{DateTime, Duration, Utc}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use uuid::Uuid; + +/// Training type enumeration +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum TrainingType { + SecurityAwareness, + DataProtection, + PhishingPrevention, + IncidentResponse, + ComplianceRegulation, + PasswordManagement, + AccessControl, + EmergencyProcedures, +} + +/// Training status +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum TrainingStatus { + NotStarted, + InProgress, + Completed, + Expired, + Failed, + Exempted, +} + +/// Training priority +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum TrainingPriority { + Low, + Medium, + High, + Critical, +} + +/// Training course definition +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TrainingCourse { + pub id: Uuid, + pub title: String, + pub description: String, + pub training_type: TrainingType, + pub duration_hours: f32, + pub validity_days: i64, + pub priority: TrainingPriority, + pub required_for_roles: Vec, + pub prerequisites: Vec, + pub content_url: Option, + pub passing_score: u32, + pub max_attempts: u32, +} + +/// Training assignment +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TrainingAssignment { + pub id: Uuid, + pub user_id: Uuid, + pub course_id: Uuid, + pub assigned_date: DateTime, + pub due_date: DateTime, + pub status: TrainingStatus, + pub attempts: Vec, + pub completion_date: Option>, + pub expiry_date: Option>, + pub assigned_by: String, + pub notes: Option, +} + +/// Training attempt record +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TrainingAttempt { + pub id: Uuid, + pub attempt_number: u32, + pub start_time: DateTime, + pub end_time: Option>, + pub score: Option, + pub passed: bool, + pub time_spent_minutes: Option, +} + +/// Training certificate +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TrainingCertificate { + pub id: Uuid, + pub user_id: Uuid, + pub course_id: Uuid, + pub issued_date: DateTime, + pub expiry_date: DateTime, + pub certificate_number: String, + pub verification_code: String, +} + +/// Training compliance status +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ComplianceStatus { + pub user_id: Uuid, + pub compliant: bool, + pub required_trainings: Vec, + pub completed_trainings: Vec, + pub overdue_trainings: Vec, + pub upcoming_trainings: Vec, + pub compliance_percentage: f64, +} + +/// Training tracker service +#[derive(Debug, Clone)] +pub struct TrainingTracker { + courses: HashMap, + assignments: HashMap, + certificates: HashMap, + user_roles: HashMap>, +} + +impl TrainingTracker { + /// Create new training tracker + pub fn new() -> Self { + let mut tracker = Self { + courses: HashMap::new(), + assignments: HashMap::new(), + certificates: HashMap::new(), + user_roles: HashMap::new(), + }; + + // Initialize with default courses + tracker.initialize_default_courses(); + tracker + } + + /// Initialize default training courses + fn initialize_default_courses(&mut self) { + let security_awareness = TrainingCourse { + id: Uuid::new_v4(), + title: "Security Awareness Fundamentals".to_string(), + description: "Basic security awareness training for all employees".to_string(), + training_type: TrainingType::SecurityAwareness, + duration_hours: 2.0, + validity_days: 365, + priority: TrainingPriority::High, + required_for_roles: vec!["all".to_string()], + prerequisites: vec![], + content_url: Some("https://training.example.com/security-awareness".to_string()), + passing_score: 80, + max_attempts: 3, + }; + + self.courses.insert(security_awareness.id, security_awareness); + + let data_protection = TrainingCourse { + id: Uuid::new_v4(), + title: "Data Protection and Privacy".to_string(), + description: "Training on data protection regulations and best practices".to_string(), + training_type: TrainingType::DataProtection, + duration_hours: 3.0, + validity_days: 365, + priority: TrainingPriority::High, + required_for_roles: vec!["admin".to_string(), "manager".to_string()], + prerequisites: vec![], + content_url: Some("https://training.example.com/data-protection".to_string()), + passing_score: 85, + max_attempts: 3, + }; + + self.courses.insert(data_protection.id, data_protection); + } + + /// Create a training course + pub fn create_course(&mut self, course: TrainingCourse) -> Result<()> { + if self.courses.contains_key(&course.id) { + return Err(anyhow!("Course already exists")); + } + + log::info!("Creating training course: {}", course.title); + self.courses.insert(course.id, course); + Ok(()) + } + + /// Assign training to user + pub fn assign_training( + &mut self, + user_id: Uuid, + course_id: Uuid, + due_days: i64, + assigned_by: String, + ) -> Result { + let course = self + .courses + .get(&course_id) + .ok_or_else(|| anyhow!("Course not found"))? + .clone(); + + let assignment = TrainingAssignment { + id: Uuid::new_v4(), + user_id, + course_id, + assigned_date: Utc::now(), + due_date: Utc::now() + Duration::days(due_days), + status: TrainingStatus::NotStarted, + attempts: vec![], + completion_date: None, + expiry_date: None, + assigned_by, + notes: None, + }; + + self.assignments.insert(assignment.id, assignment.clone()); + + log::info!( + "Assigned training '{}' to user {}", + course.title, + user_id + ); + + Ok(assignment) + } + + /// Start training attempt + pub fn start_training(&mut self, assignment_id: Uuid) -> Result { + let assignment = self + .assignments + .get_mut(&assignment_id) + .ok_or_else(|| anyhow!("Assignment not found"))?; + + let course = self + .courses + .get(&assignment.course_id) + .ok_or_else(|| anyhow!("Course not found"))?; + + if assignment.attempts.len() >= course.max_attempts as usize { + return Err(anyhow!("Maximum attempts exceeded")); + } + + let attempt = TrainingAttempt { + id: Uuid::new_v4(), + attempt_number: (assignment.attempts.len() + 1) as u32, + start_time: Utc::now(), + end_time: None, + score: None, + passed: false, + time_spent_minutes: None, + }; + + assignment.status = TrainingStatus::InProgress; + assignment.attempts.push(attempt.clone()); + + Ok(attempt) + } + + /// Complete training attempt + pub fn complete_training( + &mut self, + assignment_id: Uuid, + attempt_id: Uuid, + score: u32, + ) -> Result { + let assignment = self + .assignments + .get_mut(&assignment_id) + .ok_or_else(|| anyhow!("Assignment not found"))?; + + let course = self + .courses + .get(&assignment.course_id) + .ok_or_else(|| anyhow!("Course not found"))? + .clone(); + + let attempt = assignment + .attempts + .iter_mut() + .find(|a| a.id == attempt_id) + .ok_or_else(|| anyhow!("Attempt not found"))?; + + let end_time = Utc::now(); + let time_spent = (end_time - attempt.start_time).num_minutes() as u32; + + attempt.end_time = Some(end_time); + attempt.score = Some(score); + attempt.time_spent_minutes = Some(time_spent); + attempt.passed = score >= course.passing_score; + + if attempt.passed { + assignment.status = TrainingStatus::Completed; + assignment.completion_date = Some(end_time); + assignment.expiry_date = Some(end_time + Duration::days(course.validity_days)); + + // Issue certificate + let certificate = TrainingCertificate { + id: Uuid::new_v4(), + user_id: assignment.user_id, + course_id: course.id, + issued_date: end_time, + expiry_date: end_time + Duration::days(course.validity_days), + certificate_number: format!("CERT-{}", Uuid::new_v4().to_string()[..8].to_uppercase()), + verification_code: Uuid::new_v4().to_string(), + }; + + self.certificates.insert(certificate.id, certificate); + + log::info!( + "User {} completed training '{}' with score {}", + assignment.user_id, + course.title, + score + ); + } else if assignment.attempts.len() >= course.max_attempts as usize { + assignment.status = TrainingStatus::Failed; + } + + Ok(attempt.passed) + } + + /// Get user compliance status + pub fn get_compliance_status(&self, user_id: Uuid) -> ComplianceStatus { + let user_roles = self + .user_roles + .get(&user_id) + .cloned() + .unwrap_or_else(|| vec!["all".to_string()]); + + let mut required_trainings = vec![]; + let mut completed_trainings = vec![]; + let mut overdue_trainings = vec![]; + let mut upcoming_trainings = vec![]; + + for course in self.courses.values() { + if course.required_for_roles.iter().any(|r| { + user_roles.contains(r) || r == "all" + }) { + required_trainings.push(course.id); + + // Check if user has completed this training + let assignment = self + .assignments + .values() + .find(|a| a.user_id == user_id && a.course_id == course.id); + + if let Some(assignment) = assignment { + match assignment.status { + TrainingStatus::Completed => { + if let Some(expiry) = assignment.expiry_date { + if expiry > Utc::now() { + completed_trainings.push(course.id); + } else { + overdue_trainings.push(course.id); + } + } + } + TrainingStatus::NotStarted | TrainingStatus::InProgress => { + if assignment.due_date < Utc::now() { + overdue_trainings.push(course.id); + } else { + upcoming_trainings.push(course.id); + } + } + _ => {} + } + } else { + overdue_trainings.push(course.id); + } + } + } + + let compliance_percentage = if required_trainings.is_empty() { + 100.0 + } else { + (completed_trainings.len() as f64 / required_trainings.len() as f64) * 100.0 + }; + + ComplianceStatus { + user_id, + compliant: overdue_trainings.is_empty(), + required_trainings, + completed_trainings, + overdue_trainings, + upcoming_trainings, + compliance_percentage, + } + } + + /// Get training report + pub fn get_training_report(&self) -> TrainingReport { + let total_courses = self.courses.len(); + let total_assignments = self.assignments.len(); + let total_certificates = self.certificates.len(); + + let mut assignments_by_status = HashMap::new(); + for assignment in self.assignments.values() { + *assignments_by_status + .entry(assignment.status.clone()) + .or_insert(0) += 1; + } + + let overdue_count = self + .assignments + .values() + .filter(|a| { + a.status != TrainingStatus::Completed + && a.due_date < Utc::now() + }) + .count(); + + let expiring_soon = self + .certificates + .values() + .filter(|c| { + c.expiry_date > Utc::now() + && c.expiry_date < Utc::now() + Duration::days(30) + }) + .count(); + + let average_score = self.calculate_average_score(); + + TrainingReport { + generated_at: Utc::now(), + total_courses, + total_assignments, + total_certificates, + assignments_by_status, + overdue_count, + expiring_soon, + average_score, + } + } + + /// Calculate average training score + fn calculate_average_score(&self) -> f64 { + let mut total_score = 0; + let mut count = 0; + + for assignment in self.assignments.values() { + for attempt in &assignment.attempts { + if let Some(score) = attempt.score { + total_score += score; + count += 1; + } + } + } + + if count == 0 { + 0.0 + } else { + total_score as f64 / count as f64 + } + } + + /// Set user roles + pub fn set_user_roles(&mut self, user_id: Uuid, roles: Vec) { + self.user_roles.insert(user_id, roles); + } + + /// Get overdue trainings + pub fn get_overdue_trainings(&self) -> Vec { + self.assignments + .values() + .filter(|a| { + a.status != TrainingStatus::Completed + && a.due_date < Utc::now() + }) + .cloned() + .collect() + } + + /// Get expiring certificates + pub fn get_expiring_certificates(&self, days_ahead: i64) -> Vec { + let cutoff = Utc::now() + Duration::days(days_ahead); + self.certificates + .values() + .filter(|c| { + c.expiry_date > Utc::now() && c.expiry_date <= cutoff + }) + .cloned() + .collect() + } +} + +/// Training report +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TrainingReport { + pub generated_at: DateTime, + pub total_courses: usize, + pub total_assignments: usize, + pub total_certificates: usize, + pub assignments_by_status: HashMap, + pub overdue_count: usize, + pub expiring_soon: usize, + pub average_score: f64, +} + +impl Default for TrainingTracker { + fn default() -> Self { + Self::new() + } +} diff --git a/src/core/bot/kb_context.rs b/src/core/bot/kb_context.rs new file mode 100644 index 000000000..9f903a832 --- /dev/null +++ b/src/core/bot/kb_context.rs @@ -0,0 +1,338 @@ +use anyhow::Result; +use diesel::prelude::*; +use log::{debug, error, info, warn}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use uuid::Uuid; + +use crate::core::kb::KnowledgeBaseManager; +use crate::shared::utils::DbPool; + +/// Represents an active KB association for a session +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SessionKbAssociation { + pub kb_name: String, + pub qdrant_collection: String, + pub kb_folder_path: String, + pub is_active: bool, +} + +/// KB context that will be injected into the LLM prompt +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct KbContext { + pub kb_name: String, + pub search_results: Vec, + pub total_tokens: usize, +} + +/// Individual search result from a KB +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct KbSearchResult { + pub content: String, + pub document_path: String, + pub score: f32, + pub chunk_tokens: usize, +} + +/// Manager for KB context retrieval and injection +#[derive(Debug)] +pub struct KbContextManager { + kb_manager: Arc, + db_pool: DbPool, +} + +impl KbContextManager { + /// Create a new KB context manager + pub fn new(kb_manager: Arc, db_pool: DbPool) -> Self { + Self { + kb_manager, + db_pool, + } + } + + /// Get all active KB associations for a session + pub async fn get_active_kbs(&self, session_id: Uuid) -> Result> { + let mut conn = self.db_pool.get()?; + + // Query for active KB associations + let query = diesel::sql_query( + "SELECT kb_name, qdrant_collection, kb_folder_path, is_active + FROM session_kb_associations + WHERE session_id = $1 AND is_active = true", + ) + .bind::(session_id); + + #[derive(QueryableByName)] + struct KbAssocRow { + #[diesel(sql_type = diesel::sql_types::Text)] + kb_name: String, + #[diesel(sql_type = diesel::sql_types::Text)] + qdrant_collection: String, + #[diesel(sql_type = diesel::sql_types::Text)] + kb_folder_path: String, + #[diesel(sql_type = diesel::sql_types::Bool)] + is_active: bool, + } + + let rows: Vec = query.load(&mut conn)?; + + Ok(rows + .into_iter() + .map(|row| SessionKbAssociation { + kb_name: row.kb_name, + qdrant_collection: row.qdrant_collection, + kb_folder_path: row.kb_folder_path, + is_active: row.is_active, + }) + .collect()) + } + + /// Search all active KBs for relevant context + pub async fn search_active_kbs( + &self, + session_id: Uuid, + bot_name: &str, + query: &str, + max_results_per_kb: usize, + max_total_tokens: usize, + ) -> Result> { + let active_kbs = self.get_active_kbs(session_id).await?; + + if active_kbs.is_empty() { + debug!("No active KBs for session {}", session_id); + return Ok(Vec::new()); + } + + info!( + "Searching {} active KBs for session {}: {:?}", + active_kbs.len(), + session_id, + active_kbs.iter().map(|kb| &kb.kb_name).collect::>() + ); + + let mut kb_contexts = Vec::new(); + let mut total_tokens_used = 0; + + for kb_assoc in active_kbs { + if total_tokens_used >= max_total_tokens { + warn!("Reached max token limit, skipping remaining KBs"); + break; + } + + match self + .search_single_kb( + bot_name, + &kb_assoc.kb_name, + query, + max_results_per_kb, + max_total_tokens - total_tokens_used, + ) + .await + { + Ok(context) => { + total_tokens_used += context.total_tokens; + info!( + "Found {} results from KB '{}' using {} tokens", + context.search_results.len(), + context.kb_name, + context.total_tokens + ); + kb_contexts.push(context); + } + Err(e) => { + error!("Failed to search KB '{}': {}", kb_assoc.kb_name, e); + // Continue with other KBs even if one fails + } + } + } + + Ok(kb_contexts) + } + + /// Search a single KB for relevant context + async fn search_single_kb( + &self, + bot_name: &str, + kb_name: &str, + query: &str, + max_results: usize, + max_tokens: usize, + ) -> Result { + debug!("Searching KB '{}' with query: {}", kb_name, query); + + // Use the KnowledgeBaseManager to search + let search_results = self + .kb_manager + .search(bot_name, kb_name, query, max_results) + .await?; + + let mut kb_search_results = Vec::new(); + let mut total_tokens = 0; + + for result in search_results { + let tokens = estimate_tokens(&result.content); + + // Check if adding this result would exceed token limit + if total_tokens + tokens > max_tokens { + debug!( + "Skipping result due to token limit ({} + {} > {})", + total_tokens, tokens, max_tokens + ); + break; + } + + kb_search_results.push(KbSearchResult { + content: result.content, + document_path: result.document_path, + score: result.score, + chunk_tokens: tokens, + }); + + total_tokens += tokens; + + // Only include high-relevance results (score > 0.7) + if result.score < 0.7 { + debug!("Skipping low-relevance result (score: {})", result.score); + break; + } + } + + Ok(KbContext { + kb_name: kb_name.to_string(), + search_results: kb_search_results, + total_tokens, + }) + } + + /// Build context string from KB search results for LLM injection + pub fn build_context_string(&self, kb_contexts: &[KbContext]) -> String { + if kb_contexts.is_empty() { + return String::new(); + } + + let mut context_parts = vec!["\n--- Knowledge Base Context ---".to_string()]; + + for kb_context in kb_contexts { + if kb_context.search_results.is_empty() { + continue; + } + + context_parts.push(format!( + "\n## From '{}' knowledge base:", + kb_context.kb_name + )); + + for (idx, result) in kb_context.search_results.iter().enumerate() { + context_parts.push(format!( + "\n### Result {} (relevance: {:.2}):\n{}", + idx + 1, + result.score, + result.content + )); + + if !result.document_path.is_empty() { + context_parts.push(format!("Source: {}", result.document_path)); + } + } + } + + context_parts.push("\n--- End Knowledge Base Context ---\n".to_string()); + context_parts.join("\n") + } + + /// Get active tools for a session (similar to KBs) + pub async fn get_active_tools(&self, session_id: Uuid) -> Result> { + let mut conn = self.db_pool.get()?; + + let query = diesel::sql_query( + "SELECT tool_name + FROM session_tool_associations + WHERE session_id = $1 AND is_active = true", + ) + .bind::(session_id); + + #[derive(QueryableByName)] + struct ToolRow { + #[diesel(sql_type = diesel::sql_types::Text)] + tool_name: String, + } + + let rows: Vec = query.load(&mut conn)?; + Ok(rows.into_iter().map(|row| row.tool_name).collect()) + } +} + +/// Estimate token count for a string (rough approximation) +fn estimate_tokens(text: &str) -> usize { + // Rough estimate: 1 token per 4 characters + // This is a simplified heuristic; real tokenization would be more accurate + text.len() / 4 +} + +/// Integration helper for injecting KB context into LLM messages +pub async fn inject_kb_context( + kb_manager: Arc, + db_pool: DbPool, + session_id: Uuid, + bot_name: &str, + user_query: &str, + messages: &mut serde_json::Value, + max_context_tokens: usize, +) -> Result<()> { + let context_manager = KbContextManager::new(kb_manager, db_pool); + + // Search active KBs + let kb_contexts = context_manager + .search_active_kbs( + session_id, + bot_name, + user_query, + 5, // max 5 results per KB + max_context_tokens, + ) + .await?; + + if kb_contexts.is_empty() { + debug!("No KB context found for session {}", session_id); + return Ok(()); + } + + // Build context string + let context_string = context_manager.build_context_string(&kb_contexts); + + if context_string.is_empty() { + return Ok(()); + } + + info!( + "Injecting {} characters of KB context into prompt for session {}", + context_string.len(), + session_id + ); + + // Inject context into messages + // The context is added as a system message or appended to the existing system prompt + if let Some(messages_array) = messages.as_array_mut() { + // Find or create system message + let system_msg_idx = messages_array.iter().position(|m| m["role"] == "system"); + + if let Some(idx) = system_msg_idx { + // Append to existing system message + if let Some(content) = messages_array[idx]["content"].as_str() { + let new_content = format!("{}\n{}", content, context_string); + messages_array[idx]["content"] = serde_json::Value::String(new_content); + } + } else { + // Insert as first message + messages_array.insert( + 0, + serde_json::json!({ + "role": "system", + "content": context_string + }), + ); + } + } + + Ok(()) +} diff --git a/src/core/bot/mod.rs b/src/core/bot/mod.rs index 5d58255f9..2b0abe6eb 100644 --- a/src/core/bot/mod.rs +++ b/src/core/bot/mod.rs @@ -129,16 +129,25 @@ impl BotOrchestrator { let system_prompt = std::env::var("SYSTEM_PROMPT") .unwrap_or_else(|_| "You are a helpful assistant.".to_string()); - let messages = OpenAIClient::build_messages(&system_prompt, &context_data, &history); + let mut messages = OpenAIClient::build_messages(&system_prompt, &context_data, &history); + + // Inject bot_id into messages for cache system + if let serde_json::Value::Array(ref mut msgs) = messages { + let bot_id_obj = serde_json::json!({ + "bot_id": bot_id.to_string() + }); + msgs.push(bot_id_obj); + } let (stream_tx, mut stream_rx) = mpsc::channel::(100); let llm = self.state.llm_provider.clone(); let model_clone = model.clone(); let key_clone = key.clone(); + let messages_clone = messages.clone(); tokio::spawn(async move { if let Err(e) = llm - .generate_stream("", &messages, stream_tx, &model_clone, &key_clone) + .generate_stream("", &messages_clone, stream_tx, &model_clone, &key_clone) .await { error!("LLM streaming error: {}", e); diff --git a/src/core/bot/mod_backup.rs b/src/core/bot/mod_backup.rs new file mode 100644 index 000000000..e355ae58d --- /dev/null +++ b/src/core/bot/mod_backup.rs @@ -0,0 +1,526 @@ +use crate::core::config::ConfigManager; +use crate::drive::drive_monitor::DriveMonitor; +use crate::llm::llm_models; +use crate::llm::OpenAIClient; +#[cfg(feature = "nvidia")] +use crate::nvidia::get_system_metrics; +use crate::shared::models::{BotResponse, UserMessage, UserSession}; +use crate::shared::state::AppState; +use axum::extract::ws::{Message, WebSocket}; +use axum::{ + extract::{ws::WebSocketUpgrade, Extension, Query, State}, + http::StatusCode, + response::{IntoResponse, Json}, +}; +use diesel::PgConnection; +use futures::{sink::SinkExt, stream::StreamExt}; +use log::{error, info, trace, warn}; +use serde_json; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::mpsc; +use tokio::sync::Mutex as AsyncMutex; +use uuid::Uuid; + +pub mod channels; +pub mod multimedia; + +pub fn get_default_bot(conn: &mut PgConnection) -> (Uuid, String) { + use crate::shared::models::schema::bots::dsl::*; + use diesel::prelude::*; + + match bots + .filter(is_active.eq(true)) + .select((id, name)) + .first::<(Uuid, String)>(conn) + .optional() + { + Ok(Some((bot_id, bot_name))) => (bot_id, bot_name), + Ok(None) => { + warn!("No active bots found, using nil UUID"); + (Uuid::nil(), "default".to_string()) + } + Err(e) => { + error!("Failed to query default bot: {}", e); + (Uuid::nil(), "default".to_string()) + } + } +} + +#[derive(Debug)] +pub struct BotOrchestrator { + pub state: Arc, + pub mounted_bots: Arc>>>, +} + +impl BotOrchestrator { + pub fn new(state: Arc) -> Self { + Self { + state, + mounted_bots: Arc::new(AsyncMutex::new(HashMap::new())), + } + } + + pub async fn mount_all_bots(&self) -> Result<(), Box> { + info!("mount_all_bots called"); + Ok(()) + } + + pub async fn stream_response( + &self, + message: UserMessage, + response_tx: mpsc::Sender, + ) -> Result<(), Box> { + trace!( + "Streaming response for user: {}, session: {}", + message.user_id, + message.session_id + ); + + let user_id = Uuid::parse_str(&message.user_id)?; + let session_id = Uuid::parse_str(&message.session_id)?; + let bot_id = Uuid::parse_str(&message.bot_id).unwrap_or_default(); + + let (session, context_data, history, model, key) = { + let state_clone = self.state.clone(); + tokio::task::spawn_blocking( + move || -> Result<_, Box> { + let session = { + let mut sm = state_clone.session_manager.blocking_lock(); + sm.get_session_by_id(session_id)? + } + .ok_or_else(|| "Session not found")?; + + { + let mut sm = state_clone.session_manager.blocking_lock(); + sm.save_message(session.id, user_id, 1, &message.content, 1)?; + } + + let context_data = { + let sm = state_clone.session_manager.blocking_lock(); + let rt = tokio::runtime::Handle::current(); + rt.block_on(async { + sm.get_session_context_data(&session.id, &session.user_id) + .await + })? + }; + + let history = { + let mut sm = state_clone.session_manager.blocking_lock(); + sm.get_conversation_history(session.id, user_id)? + }; + + let config_manager = ConfigManager::new(state_clone.conn.clone()); + let model = config_manager + .get_config(&bot_id, "llm-model", Some("gpt-3.5-turbo")) + .unwrap_or_else(|_| "gpt-3.5-turbo".to_string()); + let key = config_manager + .get_config(&bot_id, "llm-key", Some("")) + .unwrap_or_default(); + + Ok((session, context_data, history, model, key)) + }, + ) + .await?? + }; + + let system_prompt = std::env::var("SYSTEM_PROMPT") + .unwrap_or_else(|_| "You are a helpful assistant.".to_string()); + let messages = OpenAIClient::build_messages(&system_prompt, &context_data, &history); + + let (stream_tx, mut stream_rx) = mpsc::channel::(100); + let llm = self.state.llm_provider.clone(); + + let model_clone = model.clone(); + let key_clone = key.clone(); + tokio::spawn(async move { + if let Err(e) = llm + .generate_stream("", &messages, stream_tx, &model_clone, &key_clone) + .await + { + error!("LLM streaming error: {}", e); + } + }); + + let mut full_response = String::new(); + let mut analysis_buffer = String::new(); + let mut in_analysis = false; + let handler = llm_models::get_handler(&model); + + #[cfg(feature = "nvidia")] + { + let initial_tokens = crate::shared::utils::estimate_token_count(&context_data); + let config_manager = ConfigManager::new(self.state.conn.clone()); + let max_context_size = config_manager + .get_config(&bot_id, "llm-server-ctx-size", None) + .unwrap_or_default() + .parse::() + .unwrap_or(0); + + if let Ok(metrics) = get_system_metrics(initial_tokens, max_context_size) { + eprintln!( + "\nNVIDIA: {:.1}% | CPU: {:.1}% | Tokens: {}/{}", + metrics.gpu_usage.unwrap_or(0.0), + metrics.cpu_usage, + initial_tokens, + max_context_size + ); + } + } + + while let Some(chunk) = stream_rx.recv().await { + trace!("Received LLM chunk: {:?}", chunk); + analysis_buffer.push_str(&chunk); + + if handler.has_analysis_markers(&analysis_buffer) && !in_analysis { + in_analysis = true; + } + + if in_analysis && handler.is_analysis_complete(&analysis_buffer) { + in_analysis = false; + analysis_buffer.clear(); + continue; + } + + if !in_analysis { + full_response.push_str(&chunk); + + let response = BotResponse { + bot_id: message.bot_id.clone(), + user_id: message.user_id.clone(), + session_id: message.session_id.clone(), + channel: message.channel.clone(), + content: chunk, + message_type: 2, + stream_token: None, + is_complete: false, + suggestions: Vec::new(), + context_name: None, + context_length: 0, + context_max_length: 0, + }; + + if response_tx.send(response).await.is_err() { + warn!("Response channel closed"); + break; + } + } + } + + let state_for_save = self.state.clone(); + let full_response_clone = full_response.clone(); + tokio::task::spawn_blocking( + move || -> Result<(), Box> { + let mut sm = state_for_save.session_manager.blocking_lock(); + sm.save_message(session.id, user_id, 2, &full_response_clone, 2)?; + Ok(()) + }, + ) + .await??; + + let final_response = BotResponse { + bot_id: message.bot_id, + user_id: message.user_id, + session_id: message.session_id, + channel: message.channel, + content: full_response, + message_type: 2, + stream_token: None, + is_complete: true, + suggestions: Vec::new(), + context_name: None, + context_length: 0, + context_max_length: 0, + }; + + response_tx.send(final_response).await?; + Ok(()) + } + + pub async fn get_user_sessions( + &self, + user_id: Uuid, + ) -> Result, Box> { + let mut session_manager = self.state.session_manager.lock().await; + let sessions = session_manager.get_user_sessions(user_id)?; + Ok(sessions) + } + + pub async fn get_conversation_history( + &self, + session_id: Uuid, + user_id: Uuid, + ) -> Result, Box> { + let mut session_manager = self.state.session_manager.lock().await; + let history = session_manager.get_conversation_history(session_id, user_id)?; + Ok(history) + } +} + +pub async fn websocket_handler( + ws: WebSocketUpgrade, + State(state): State>, + Query(params): Query>, +) -> impl IntoResponse { + let session_id = params + .get("session_id") + .and_then(|s| Uuid::parse_str(s).ok()); + let user_id = params.get("user_id").and_then(|s| Uuid::parse_str(s).ok()); + + if session_id.is_none() || user_id.is_none() { + return ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ "error": "session_id and user_id are required" })), + ) + .into_response(); + } + + ws.on_upgrade(move |socket| { + handle_websocket(socket, state, session_id.unwrap(), user_id.unwrap()) + }) + .into_response() +} + +async fn handle_websocket( + socket: WebSocket, + state: Arc, + session_id: Uuid, + user_id: Uuid, +) { + let (mut sender, mut receiver) = socket.split(); + let (tx, mut rx) = mpsc::channel::(100); + + state + .web_adapter + .add_connection(session_id.to_string(), tx.clone()) + .await; + + { + let mut channels = state.response_channels.lock().await; + channels.insert(session_id.to_string(), tx.clone()); + } + + info!( + "WebSocket connected for session: {}, user: {}", + session_id, user_id + ); + + let welcome = serde_json::json!({ + "type": "connected", + "session_id": session_id, + "user_id": user_id, + "message": "Connected to bot server" + }); + + if let Ok(welcome_str) = serde_json::to_string(&welcome) { + if sender + .send(Message::Text(welcome_str.into())) + .await + .is_err() + { + error!("Failed to send welcome message"); + } + } + + let mut send_task = tokio::spawn(async move { + while let Some(response) = rx.recv().await { + if let Ok(json_str) = serde_json::to_string(&response) { + if sender.send(Message::Text(json_str.into())).await.is_err() { + break; + } + } + } + }); + + let state_clone = state.clone(); + let mut recv_task = tokio::spawn(async move { + while let Some(Ok(msg)) = receiver.next().await { + match msg { + Message::Text(text) => { + info!("Received WebSocket message: {}", text); + if let Ok(user_msg) = serde_json::from_str::(&text) { + let orchestrator = BotOrchestrator::new(state_clone.clone()); + if let Some(tx_clone) = state_clone + .response_channels + .lock() + .await + .get(&session_id.to_string()) + { + if let Err(e) = orchestrator + .stream_response(user_msg, tx_clone.clone()) + .await + { + error!("Failed to stream response: {}", e); + } + } + } + } + Message::Close(_) => { + info!("WebSocket close message received"); + break; + } + _ => {} + } + } + }); + + tokio::select! { + _ = (&mut send_task) => { recv_task.abort(); } + _ = (&mut recv_task) => { send_task.abort(); } + } + + state + .web_adapter + .remove_connection(&session_id.to_string()) + .await; + + { + let mut channels = state.response_channels.lock().await; + channels.remove(&session_id.to_string()); + } + + info!("WebSocket disconnected for session: {}", session_id); +} + +pub async fn create_bot_handler( + Extension(state): Extension>, + Json(payload): Json>, +) -> impl IntoResponse { + let bot_name = payload + .get("bot_name") + .cloned() + .unwrap_or_else(|| "default".to_string()); + + let orchestrator = BotOrchestrator::new(state); + if let Err(e) = orchestrator.mount_all_bots().await { + error!("Failed to mount bots: {}", e); + } + + ( + StatusCode::OK, + Json(serde_json::json!({ "status": format!("bot '{}' created", bot_name) })), + ) +} + +pub async fn mount_bot_handler( + Extension(state): Extension>, + Json(payload): Json>, +) -> impl IntoResponse { + let bot_guid = payload.get("bot_guid").cloned().unwrap_or_default(); + + let orchestrator = BotOrchestrator::new(state); + if let Err(e) = orchestrator.mount_all_bots().await { + error!("Failed to mount bot: {}", e); + } + + ( + StatusCode::OK, + Json(serde_json::json!({ "status": format!("bot '{}' mounted", bot_guid) })), + ) +} + +pub async fn handle_user_input_handler( + Extension(state): Extension>, + Json(payload): Json>, +) -> impl IntoResponse { + let session_id = payload.get("session_id").cloned().unwrap_or_default(); + let user_input = payload.get("input").cloned().unwrap_or_default(); + + info!( + "Processing user input: {} for session: {}", + user_input, session_id + ); + + let orchestrator = BotOrchestrator::new(state); + if let Ok(sessions) = orchestrator.get_user_sessions(Uuid::nil()).await { + info!("Found {} sessions", sessions.len()); + } + + ( + StatusCode::OK, + Json(serde_json::json!({ "status": format!("processed: {}", user_input) })), + ) +} + +pub async fn get_user_sessions_handler( + Extension(state): Extension>, + Json(payload): Json>, +) -> impl IntoResponse { + let user_id = payload + .get("user_id") + .and_then(|id| Uuid::parse_str(id).ok()) + .unwrap_or_else(Uuid::nil); + + let orchestrator = BotOrchestrator::new(state); + match orchestrator.get_user_sessions(user_id).await { + Ok(sessions) => ( + StatusCode::OK, + Json(serde_json::json!({ "sessions": sessions })), + ), + Err(e) => { + error!("Failed to get sessions: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ "error": e.to_string() })), + ) + } + } +} + +pub async fn get_conversation_history_handler( + Extension(state): Extension>, + Json(payload): Json>, +) -> impl IntoResponse { + let session_id = payload + .get("session_id") + .and_then(|id| Uuid::parse_str(id).ok()) + .unwrap_or_else(Uuid::nil); + let user_id = payload + .get("user_id") + .and_then(|id| Uuid::parse_str(id).ok()) + .unwrap_or_else(Uuid::nil); + + let orchestrator = BotOrchestrator::new(state); + match orchestrator + .get_conversation_history(session_id, user_id) + .await + { + Ok(history) => ( + StatusCode::OK, + Json(serde_json::json!({ "history": history })), + ), + Err(e) => { + error!("Failed to get history: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ "error": e.to_string() })), + ) + } + } +} + +pub async fn send_warning_handler( + Extension(state): Extension>, + Json(payload): Json>, +) -> impl IntoResponse { + let message = payload + .get("message") + .cloned() + .unwrap_or_else(|| "Warning".to_string()); + let session_id = payload.get("session_id").cloned().unwrap_or_default(); + + warn!("Warning for session {}: {}", session_id, message); + + let orchestrator = BotOrchestrator::new(state); + info!("Orchestrator created for warning"); + + // Use orchestrator to log state + if let Ok(sessions) = orchestrator.get_user_sessions(Uuid::nil()).await { + info!("Current active sessions: {}", sessions.len()); + } + + ( + StatusCode::OK, + Json(serde_json::json!({ "status": "warning sent", "message": message })), + ) +} diff --git a/src/core/kb/document_processor.rs b/src/core/kb/document_processor.rs index 5bbf9c976..c2f0d9cf5 100644 --- a/src/core/kb/document_processor.rs +++ b/src/core/kb/document_processor.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + use anyhow::Result; use log::{error, info, warn}; use serde::{Deserialize, Serialize}; @@ -222,6 +224,7 @@ impl DocumentProcessor { } /// Extract PDF using poppler-utils + #[allow(dead_code)] async fn extract_pdf_with_poppler(&self, file_path: &Path) -> Result { let output = tokio::process::Command::new("pdftotext") .arg(file_path) diff --git a/src/directory/client.rs b/src/directory/client.rs index 5d2158782..2c9143c61 100644 --- a/src/directory/client.rs +++ b/src/directory/client.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + use anyhow::{anyhow, Result}; use serde::{Deserialize, Serialize}; use std::sync::Arc; diff --git a/src/directory/mod.rs b/src/directory/mod.rs index 97286ccb7..649919928 100644 --- a/src/directory/mod.rs +++ b/src/directory/mod.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + use crate::shared::state::AppState; use axum::{ extract::{Query, State}, diff --git a/src/drive/drive_monitor/mod.rs b/src/drive/drive_monitor/mod.rs index 079b6a259..3e2e73048 100644 --- a/src/drive/drive_monitor/mod.rs +++ b/src/drive/drive_monitor/mod.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + use crate::basic::compiler::BasicCompiler; use crate::config::ConfigManager; use crate::core::kb::KnowledgeBaseManager; diff --git a/src/drive/vectordb.rs b/src/drive/vectordb.rs index bc284928d..e6f574e88 100644 --- a/src/drive/vectordb.rs +++ b/src/drive/vectordb.rs @@ -1,15 +1,17 @@ +#![allow(dead_code)] + use anyhow::Result; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use std::path::PathBuf; +#[cfg(feature = "vectordb")] use std::sync::Arc; use tokio::fs; use uuid::Uuid; #[cfg(feature = "vectordb")] -use qdrant_client::{ - prelude::*, - qdrant::{vectors_config::Config, CreateCollection, Distance, VectorParams, VectorsConfig}, +use qdrant_client::qdrant::{ + vectors_config::Config, CreateCollection, Distance, PointStruct, VectorParams, VectorsConfig, }; /// File metadata for vector DB indexing @@ -80,7 +82,7 @@ impl UserDriveVectorDB { /// Initialize vector DB collection #[cfg(feature = "vectordb")] pub async fn initialize(&mut self, qdrant_url: &str) -> Result<()> { - let client = QdrantClient::from_url(qdrant_url).build()?; + let client = qdrant_client::Qdrant::from_url(qdrant_url).build()?; // Check if collection exists let collections = client.list_collections().await?; @@ -130,10 +132,14 @@ impl UserDriveVectorDB { .as_ref() .ok_or_else(|| anyhow::anyhow!("Vector DB not initialized"))?; - let point = PointStruct::new(file.id.clone(), embedding, serde_json::to_value(file)?); + let payload = serde_json::to_value(file)? + .as_object() + .map(|m| m.clone()) + .unwrap_or_default(); + let point = PointStruct::new(file.id.clone(), embedding, payload); client - .upsert_points_blocking(self._collection_name.clone(), vec![point], None) + .upsert_points(self._collection_name.clone(), None, vec![point], None) .await?; log::debug!("Indexed file: {} - {}", file.id, file.file_name); @@ -161,15 +167,17 @@ impl UserDriveVectorDB { let points: Vec = files .iter() .filter_map(|(file, embedding)| { - serde_json::to_value(file).ok().map(|payload| { - PointStruct::new(file.id.clone(), embedding.clone(), payload) + serde_json::to_value(file).ok().and_then(|v| { + v.as_object().map(|m| { + PointStruct::new(file.id.clone(), embedding.clone(), m.clone()) + }) }) }) .collect(); if !points.is_empty() { client - .upsert_points_blocking(self._collection_name.clone(), points, None) + .upsert_points(self._collection_name.clone(), None, points, None) .await?; } } diff --git a/src/email/mod.rs b/src/email/mod.rs index 2eab7fcdf..5ee6cc6ff 100644 --- a/src/email/mod.rs +++ b/src/email/mod.rs @@ -50,7 +50,6 @@ pub struct SaveDraftRequest { pub bcc: Option, pub subject: String, pub body: String, - pub text: String, } // ===== Request/Response Structures ===== diff --git a/src/llm/cache.rs b/src/llm/cache.rs index 4430a641d..2f7058334 100644 --- a/src/llm/cache.rs +++ b/src/llm/cache.rs @@ -15,6 +15,7 @@ use crate::shared::utils::{estimate_token_count, DbPool}; /// Configuration for semantic caching #[derive(Clone, Debug)] + pub struct CacheConfig { /// TTL for cache entries in seconds pub ttl: u64, @@ -42,6 +43,7 @@ impl Default for CacheConfig { /// Cached LLM response with metadata #[derive(Serialize, Deserialize, Clone, Debug)] + pub struct CachedResponse { /// The actual response text pub response: String, @@ -72,6 +74,7 @@ impl std::fmt::Debug for CachedLLMProvider { .finish() } } + pub struct CachedLLMProvider { /// The underlying LLM provider provider: Arc, @@ -87,6 +90,7 @@ pub struct CachedLLMProvider { /// Trait for embedding services #[async_trait] + pub trait EmbeddingService: Send + Sync { async fn get_embedding( &self, @@ -247,6 +251,7 @@ impl CachedLLMProvider { } /// Try to get a cached response + async fn get_cached_response( &self, prompt: &str, @@ -309,6 +314,7 @@ impl CachedLLMProvider { } /// Find semantically similar cached responses + async fn find_similar_cached( &self, prompt: &str, @@ -456,6 +462,7 @@ impl CachedLLMProvider { } /// Get cache statistics + pub async fn get_cache_stats( &self, ) -> Result> { @@ -488,6 +495,7 @@ impl CachedLLMProvider { } /// Clear cache for a specific model or all models + pub async fn clear_cache( &self, model: Option<&str>, @@ -514,6 +522,7 @@ impl CachedLLMProvider { /// Cache statistics #[derive(Serialize, Deserialize, Clone, Debug)] + pub struct CacheStats { pub total_entries: usize, pub total_hits: u32, @@ -645,6 +654,7 @@ impl LLMProvider for CachedLLMProvider { // Manual Debug implementation needed for trait objects #[derive(Debug)] + pub struct LocalEmbeddingService { embedding_url: String, model: String, diff --git a/src/llm/compact_prompt.rs b/src/llm/compact_prompt.rs index 0bdf793e7..72e43618c 100644 --- a/src/llm/compact_prompt.rs +++ b/src/llm/compact_prompt.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + use crate::core::config::ConfigManager; use crate::llm::llm_models; use crate::shared::state::AppState; diff --git a/src/llm/llm_models/deepseek_r3.rs b/src/llm/llm_models/deepseek_r3.rs index ea87799f5..3d749aac9 100644 --- a/src/llm/llm_models/deepseek_r3.rs +++ b/src/llm/llm_models/deepseek_r3.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + use super::ModelHandler; use regex; #[derive(Debug)] diff --git a/src/llm/llm_models/gpt_oss_120b.rs b/src/llm/llm_models/gpt_oss_120b.rs index 719a6409b..e58b4e90d 100644 --- a/src/llm/llm_models/gpt_oss_120b.rs +++ b/src/llm/llm_models/gpt_oss_120b.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + use super::ModelHandler; #[derive(Debug)] pub struct GptOss120bHandler {} diff --git a/src/llm/llm_models/gpt_oss_20b.rs b/src/llm/llm_models/gpt_oss_20b.rs index 0766255bd..61115dda7 100644 --- a/src/llm/llm_models/gpt_oss_20b.rs +++ b/src/llm/llm_models/gpt_oss_20b.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + use super::ModelHandler; #[derive(Debug)] pub struct GptOss20bHandler; diff --git a/src/llm/llm_models/mod.rs b/src/llm/llm_models/mod.rs index cf23cd93a..be1347c2e 100644 --- a/src/llm/llm_models/mod.rs +++ b/src/llm/llm_models/mod.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + pub mod deepseek_r3; pub mod gpt_oss_120b; pub mod gpt_oss_20b; diff --git a/src/llm/mod.rs b/src/llm/mod.rs index 7ccc948a5..659fc52ac 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + use async_trait::async_trait; use futures::StreamExt; use log::{info, trace}; diff --git a/src/vector-db/vectordb_indexer.rs b/src/vector-db/vectordb_indexer.rs index ccfbda8d4..0b15431f3 100644 --- a/src/vector-db/vectordb_indexer.rs +++ b/src/vector-db/vectordb_indexer.rs @@ -8,6 +8,17 @@ use tokio::sync::RwLock; use tokio::time::{sleep, Duration}; use uuid::Uuid; +#[cfg(feature = "vectordb")] +use crate::drive::vectordb::UserDriveVectorDB; +#[cfg(feature = "vectordb")] +use crate::drive::vectordb::{FileContentExtractor, FileDocument}; +#[cfg(all(feature = "vectordb", feature = "email"))] +use crate::email::vectordb::UserEmailVectorDB; +#[cfg(all(feature = "vectordb", feature = "email"))] +use crate::email::vectordb::{EmailDocument, EmailEmbeddingGenerator}; +use crate::shared::utils::DbPool; +use anyhow::Result; + // UserWorkspace struct for managing user workspace paths #[derive(Debug, Clone)] struct UserWorkspace { @@ -26,14 +37,13 @@ impl UserWorkspace { } fn get_path(&self) -> PathBuf { - self.root.join(self.bot_id.to_string()).join(self.user_id.to_string()) + self.root + .join(self.bot_id.to_string()) + .join(self.user_id.to_string()) } } -use crate::shared::utils::DbPool; // VectorDB types are defined locally in this module -#[cfg(feature = "vectordb")] -use qdrant_client::prelude::*; /// Indexing job status #[derive(Debug, Clone, PartialEq)] @@ -93,7 +103,7 @@ impl VectorDBIndexer { db_pool, work_root, qdrant_url, - embedding_generator: Arc::new(EmailEmbeddingGenerator::new(llm_endpoint)), + embedding_generator: Arc::new(EmailEmbeddingGenerator { llm_endpoint }), jobs: Arc::new(RwLock::new(HashMap::new())), running: Arc::new(RwLock::new(false)), interval_seconds: 300, // Run every 5 minutes @@ -373,7 +383,7 @@ impl VectorDBIndexer { for file in chunk { // Check if file should be indexed let mime_type = file.mime_type.as_ref().map(|s| s.as_str()).unwrap_or(""); - if !FileContentExtractor::should_index(mime_type, file.file_size) { + if !FileContentExtractor::should_index(&mime_type, file.file_size) { continue; } @@ -448,7 +458,7 @@ impl VectorDBIndexer { &self, _user_id: Uuid, _account_id: &str, - ) -> Result> { + ) -> Result, Box> { // TODO: Implement actual email fetching from IMAP // This should: // 1. Connect to user's email account @@ -460,7 +470,10 @@ impl VectorDBIndexer { } /// Get unindexed files (placeholder - needs actual implementation) - async fn get_unindexed_files(&self, _user_id: Uuid) -> Result> { + async fn get_unindexed_files( + &self, + _user_id: Uuid, + ) -> Result, Box> { // TODO: Implement actual file fetching from drive // This should: // 1. List user's files from MinIO/S3