Remove unused sqlx dependency and related code
The sqlx database library has been removed from the project along with associated database-specific code that was no longer being used. This includes removal of various sqlx-related dependencies from Cargo.lock and cleanup of database connection pool references.
This commit is contained in:
parent
a42915f7fd
commit
12de4abf13
31 changed files with 4733 additions and 1421 deletions
313
Cargo.lock
generated
313
Cargo.lock
generated
|
|
@ -434,15 +434,6 @@ dependencies = [
|
|||
"system-deps",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atoi"
|
||||
version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atomic-waker"
|
||||
version = "1.1.2"
|
||||
|
|
@ -1219,7 +1210,6 @@ dependencies = [
|
|||
"serde_json",
|
||||
"sha2",
|
||||
"smartstring",
|
||||
"sqlx",
|
||||
"sysinfo",
|
||||
"tauri",
|
||||
"tauri-build",
|
||||
|
|
@ -1867,15 +1857,6 @@ dependencies = [
|
|||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-queue"
|
||||
version = "0.3.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115"
|
||||
dependencies = [
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-utils"
|
||||
version = "0.8.21"
|
||||
|
|
@ -2616,9 +2597,6 @@ name = "either"
|
|||
version = "1.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "elliptic-curve"
|
||||
|
|
@ -2789,17 +2767,6 @@ dependencies = [
|
|||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "etcetera"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "136d1b5283a1ab77bd9257427ffd09d8667ced0570b6f938942bc7568ed5b943"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"home",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "euclid"
|
||||
version = "0.20.14"
|
||||
|
|
@ -2919,17 +2886,6 @@ dependencies = [
|
|||
"miniz_oxide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "flume"
|
||||
version = "0.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
"spin",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fnv"
|
||||
version = "1.0.7"
|
||||
|
|
@ -3067,17 +3023,6 @@ dependencies = [
|
|||
"futures-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-intrusive"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"lock_api",
|
||||
"parking_lot",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-io"
|
||||
version = "0.3.31"
|
||||
|
|
@ -3552,15 +3497,6 @@ version = "0.16.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5419bdc4f6a9207fbeba6d11b604d481addf78ecd10c11ad51e76c2f6482748d"
|
||||
|
||||
[[package]]
|
||||
name = "hashlink"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1"
|
||||
dependencies = [
|
||||
"hashbrown 0.15.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.4.1"
|
||||
|
|
@ -3603,15 +3539,6 @@ dependencies = [
|
|||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "home"
|
||||
version = "0.5.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cc627f471c528ff0c4a49e1d5e60450c8f6461dd6d10ba9dcd3a61d3dff7728d"
|
||||
dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hostname"
|
||||
version = "0.4.1"
|
||||
|
|
@ -4438,17 +4365,6 @@ checksum = "416f7e718bdb06000964960ffa43b4335ad4012ae8b99060261aa4a8088d5ccb"
|
|||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"libc",
|
||||
"redox_syscall",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "libsqlite3-sys"
|
||||
version = "0.30.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149"
|
||||
dependencies = [
|
||||
"pkg-config",
|
||||
"vcpkg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -7679,9 +7595,6 @@ name = "smallvec"
|
|||
version = "1.15.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "smartstring"
|
||||
|
|
@ -7767,9 +7680,6 @@ name = "spin"
|
|||
version = "0.9.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
|
||||
dependencies = [
|
||||
"lock_api",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spki"
|
||||
|
|
@ -7791,204 +7701,6 @@ dependencies = [
|
|||
"der 0.7.10",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlx"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1fefb893899429669dcdd979aff487bd78f4064e5e7907e4269081e0ef7d97dc"
|
||||
dependencies = [
|
||||
"sqlx-core",
|
||||
"sqlx-macros",
|
||||
"sqlx-mysql",
|
||||
"sqlx-postgres",
|
||||
"sqlx-sqlite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlx-core"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ee6798b1838b6a0f69c007c133b8df5866302197e404e8b6ee8ed3e3a5e68dc6"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"crc",
|
||||
"crossbeam-queue",
|
||||
"either",
|
||||
"event-listener 5.4.1",
|
||||
"futures-core",
|
||||
"futures-intrusive",
|
||||
"futures-io",
|
||||
"futures-util",
|
||||
"hashbrown 0.15.5",
|
||||
"hashlink",
|
||||
"indexmap 2.12.0",
|
||||
"log",
|
||||
"memchr",
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"rustls 0.23.35",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"smallvec",
|
||||
"thiserror 2.0.17",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tracing",
|
||||
"url",
|
||||
"uuid",
|
||||
"webpki-roots 0.26.11",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlx-macros"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a2d452988ccaacfbf5e0bdbc348fb91d7c8af5bee192173ac3636b5fb6e6715d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"sqlx-core",
|
||||
"sqlx-macros-core",
|
||||
"syn 2.0.110",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlx-macros-core"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "19a9c1841124ac5a61741f96e1d9e2ec77424bf323962dd894bdb93f37d5219b"
|
||||
dependencies = [
|
||||
"dotenvy",
|
||||
"either",
|
||||
"heck 0.5.0",
|
||||
"hex",
|
||||
"once_cell",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"sqlx-core",
|
||||
"sqlx-mysql",
|
||||
"sqlx-postgres",
|
||||
"sqlx-sqlite",
|
||||
"syn 2.0.110",
|
||||
"tokio",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlx-mysql"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "aa003f0038df784eb8fecbbac13affe3da23b45194bd57dba231c8f48199c526"
|
||||
dependencies = [
|
||||
"atoi",
|
||||
"base64 0.22.1",
|
||||
"bitflags 2.10.0",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"crc",
|
||||
"digest",
|
||||
"dotenvy",
|
||||
"either",
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"futures-util",
|
||||
"generic-array",
|
||||
"hex",
|
||||
"hkdf",
|
||||
"hmac",
|
||||
"itoa",
|
||||
"log",
|
||||
"md-5",
|
||||
"memchr",
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"rand 0.8.5",
|
||||
"rsa",
|
||||
"serde",
|
||||
"sha1",
|
||||
"sha2",
|
||||
"smallvec",
|
||||
"sqlx-core",
|
||||
"stringprep",
|
||||
"thiserror 2.0.17",
|
||||
"tracing",
|
||||
"uuid",
|
||||
"whoami",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlx-postgres"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "db58fcd5a53cf07c184b154801ff91347e4c30d17a3562a635ff028ad5deda46"
|
||||
dependencies = [
|
||||
"atoi",
|
||||
"base64 0.22.1",
|
||||
"bitflags 2.10.0",
|
||||
"byteorder",
|
||||
"chrono",
|
||||
"crc",
|
||||
"dotenvy",
|
||||
"etcetera",
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"hex",
|
||||
"hkdf",
|
||||
"hmac",
|
||||
"home",
|
||||
"itoa",
|
||||
"log",
|
||||
"md-5",
|
||||
"memchr",
|
||||
"once_cell",
|
||||
"rand 0.8.5",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"smallvec",
|
||||
"sqlx-core",
|
||||
"stringprep",
|
||||
"thiserror 2.0.17",
|
||||
"tracing",
|
||||
"uuid",
|
||||
"whoami",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlx-sqlite"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c2d12fe70b2c1b4401038055f90f151b78208de1f9f89a7dbfd41587a10c3eea"
|
||||
dependencies = [
|
||||
"atoi",
|
||||
"chrono",
|
||||
"flume",
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-executor",
|
||||
"futures-intrusive",
|
||||
"futures-util",
|
||||
"libsqlite3-sys",
|
||||
"log",
|
||||
"percent-encoding",
|
||||
"serde",
|
||||
"serde_urlencoded",
|
||||
"sqlx-core",
|
||||
"thiserror 2.0.17",
|
||||
"tracing",
|
||||
"url",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "stable_deref_trait"
|
||||
version = "1.2.1"
|
||||
|
|
@ -9543,12 +9255,6 @@ dependencies = [
|
|||
"wit-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasite"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b"
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen"
|
||||
version = "0.2.105"
|
||||
|
|
@ -9750,15 +9456,6 @@ version = "0.25.4"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1"
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "0.26.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9"
|
||||
dependencies = [
|
||||
"webpki-roots 1.0.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "1.0.4"
|
||||
|
|
@ -9840,16 +9537,6 @@ version = "0.1.12"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a28ac98ddc8b9274cb41bb4d9d4d5c425b6020c50c46f25559911905610b4a88"
|
||||
|
||||
[[package]]
|
||||
name = "whoami"
|
||||
version = "1.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5d4a4db5077702ca3015d3d02d74974948aba2ad9e12ab7df718ee64ccd7e97d"
|
||||
dependencies = [
|
||||
"libredox",
|
||||
"wasite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ msteams = []
|
|||
chat = []
|
||||
drive = ["dep:aws-config", "dep:aws-sdk-s3", "dep:pdf-extract", "dep:zip", "dep:downloader", "dep:mime_guess"]
|
||||
tasks = ["dep:cron"]
|
||||
calendar = ["dep:sqlx"]
|
||||
calendar = []
|
||||
meet = ["dep:livekit"]
|
||||
mail = ["email"]
|
||||
|
||||
|
|
@ -138,9 +138,6 @@ zitadel = { version = "5.5.1", features = ["api", "credentials"] }
|
|||
|
||||
# === FEATURE-SPECIFIC DEPENDENCIES (Optional) ===
|
||||
|
||||
# Database (for calendar and other features)
|
||||
sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "postgres", "chrono", "uuid"], optional = true }
|
||||
|
||||
# Desktop UI (desktop feature)
|
||||
tauri = { version = "2", features = ["unstable"], optional = true }
|
||||
tauri-plugin-dialog = { version = "2", optional = true }
|
||||
|
|
|
|||
160
docs/src/chapter-11-features/drive-monitor.md
Normal file
160
docs/src/chapter-11-features/drive-monitor.md
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
# Drive Monitor
|
||||
|
||||
The Drive Monitor is a real-time file synchronization system that watches for changes in bot storage buckets and automatically updates the database and runtime configuration.
|
||||
|
||||
## Overview
|
||||
|
||||
DriveMonitor provides hot-reloading capabilities for bot configurations by continuously monitoring file changes in object storage. When files are modified, added, or removed, the system automatically:
|
||||
|
||||
- Detects changes through ETags and file comparison
|
||||
- Updates the database with new configurations
|
||||
- Recompiles scripts and tools
|
||||
- Refreshes knowledge bases
|
||||
- Broadcasts theme changes to connected clients
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌─────────────────┐
|
||||
│ Object Storage │ (S3-compatible)
|
||||
│ Buckets │
|
||||
└────────┬────────┘
|
||||
│ Poll every 30s
|
||||
▼
|
||||
┌─────────────────┐
|
||||
│ Drive Monitor │
|
||||
│ - Check ETags │
|
||||
│ - Diff files │
|
||||
└────────┬────────┘
|
||||
│ Changes detected
|
||||
▼
|
||||
┌─────────────────────────┐
|
||||
│ Process Updates │
|
||||
│ - Compile scripts (.bas)│
|
||||
│ - Update KB (.gbkb) │
|
||||
│ - Refresh themes │
|
||||
│ - Update database │
|
||||
└─────────────────────────┘
|
||||
```
|
||||
|
||||
## Implementation
|
||||
|
||||
### Core Components
|
||||
|
||||
The DriveMonitor is implemented in `src/drive/drive_monitor/mod.rs` with the following structure:
|
||||
|
||||
```rust
|
||||
pub struct DriveMonitor {
|
||||
state: Arc<AppState>,
|
||||
bucket_name: String,
|
||||
file_states: Arc<RwLock<HashMap<String, FileState>>>,
|
||||
bot_id: Uuid,
|
||||
kb_manager: Arc<KnowledgeBaseManager>,
|
||||
work_root: PathBuf,
|
||||
is_processing: Arc<AtomicBool>,
|
||||
}
|
||||
```
|
||||
|
||||
### Monitoring Process
|
||||
|
||||
1. **Initialization**: When a bot is mounted, a DriveMonitor instance is created and spawned
|
||||
2. **Polling**: Every 30 seconds, the monitor checks for changes in:
|
||||
- `.gbdialog` files (scripts and tools)
|
||||
- `.gbkb` collections (knowledge base documents)
|
||||
- `.gbtheme` files (UI themes)
|
||||
- `.gbot/config.csv` (bot configuration)
|
||||
|
||||
3. **Change Detection**: Uses ETags to detect file modifications efficiently
|
||||
4. **Processing**: Different file types trigger specific handlers:
|
||||
- Scripts → Compile to AST
|
||||
- Knowledge base → Index and embed documents
|
||||
- Themes → Broadcast updates to WebSocket clients
|
||||
- Config → Reload bot settings
|
||||
|
||||
### File Type Handlers
|
||||
|
||||
#### Script Files (.bas)
|
||||
- Compiles BASIC scripts to AST
|
||||
- Stores compiled version in database
|
||||
- Updates tool registry if applicable
|
||||
|
||||
#### Knowledge Base Files (.gbkb)
|
||||
- Downloads new/modified documents
|
||||
- Processes text extraction
|
||||
- Generates embeddings
|
||||
- Updates vector database
|
||||
|
||||
#### Theme Files (.gbtheme)
|
||||
- Detects CSS/JS changes
|
||||
- Broadcasts updates to connected clients
|
||||
- Triggers UI refresh without page reload
|
||||
|
||||
## Usage
|
||||
|
||||
The DriveMonitor is automatically started when a bot is mounted:
|
||||
|
||||
```rust
|
||||
// In BotOrchestrator::mount_bot
|
||||
let drive_monitor = Arc::new(DriveMonitor::new(
|
||||
state.clone(),
|
||||
bucket_name,
|
||||
bot_id
|
||||
));
|
||||
let _handle = drive_monitor.clone().spawn().await;
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
No explicit configuration needed - the monitor automatically:
|
||||
- Uses the bot's storage bucket name
|
||||
- Creates work directories as needed
|
||||
- Manages its own file state cache
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
- **Polling Interval**: 30 seconds (balance between responsiveness and resource usage)
|
||||
- **Concurrent Processing**: Uses atomic flags to prevent overlapping operations
|
||||
- **Caching**: Maintains ETag cache to minimize unnecessary downloads
|
||||
- **Batching**: Processes multiple file changes in a single cycle
|
||||
|
||||
## Error Handling
|
||||
|
||||
The monitor includes robust error handling:
|
||||
- Continues operation even if individual file processing fails
|
||||
- Logs errors for debugging while maintaining service availability
|
||||
- Prevents cascading failures through isolated error boundaries
|
||||
|
||||
## Monitoring and Debugging
|
||||
|
||||
Enable debug logging to see monitor activity:
|
||||
|
||||
```bash
|
||||
RUST_LOG=botserver::drive::drive_monitor=debug cargo run
|
||||
```
|
||||
|
||||
Log output includes:
|
||||
- Change detection events
|
||||
- File processing status
|
||||
- Compilation results
|
||||
- Database update confirmations
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **File Organization**: Keep related files in appropriate directories (.gbdialog, .gbkb, etc.)
|
||||
2. **Version Control**: The monitor tracks changes but doesn't maintain history - use git for version control
|
||||
3. **Large Files**: For knowledge base documents > 10MB, consider splitting into smaller files
|
||||
4. **Development**: During development, the 30-second delay can be avoided by restarting the bot
|
||||
|
||||
## Limitations
|
||||
|
||||
- **Not Real-time**: 30-second polling interval means changes aren't instant
|
||||
- **No Conflict Resolution**: Last-write-wins for concurrent modifications
|
||||
- **Memory Usage**: Keeps file state in memory (minimal for ETags)
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
Planned improvements include:
|
||||
- WebSocket notifications from storage layer for instant updates
|
||||
- Configurable polling intervals per file type
|
||||
- Differential sync for large knowledge bases
|
||||
- Multi-version support for A/B testing
|
||||
|
|
@ -14,7 +14,7 @@ When initial attempts fail, sequentially try these LLMs:
|
|||
- **On unresolved error**: Stop and use add-req.sh, and consult Claude for guidance. with DeepThining in DeepSeek also, with Web turned on.
|
||||
- **Change progression**: Start with DeepSeek, conclude with gpt-oss-120b
|
||||
- If a big req. fail, specify a @code file that has similar pattern or sample from official docs.
|
||||
- **Warning removal**: Last task before commiting, create a task list of warning removal and work with cargo check.
|
||||
- **Warning removal**: Last task before commiting, create a task list of warning removal and work with cargo check. If lots of warning, let LLM put #[allow(dead_code)] on top. Check manually for missing/deleted code on some files.
|
||||
- **Final validation**: Use prompt "cargo check" with gpt-oss-120b
|
||||
- Be humble, one requirement, one commit. But sometimes, freedom of caos is welcome - when no deadlines are set.
|
||||
- Fix manually in case of dangerous trouble.
|
||||
|
|
|
|||
399
src/attendance/drive.rs
Normal file
399
src/attendance/drive.rs
Normal file
|
|
@ -0,0 +1,399 @@
|
|||
//! Drive integration module for attendance system
|
||||
//! Handles file storage and synchronization for attendance records
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use aws_sdk_s3::primitives::ByteStream;
|
||||
use aws_sdk_s3::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use tokio::fs;
|
||||
|
||||
/// Drive configuration for attendance storage
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AttendanceDriveConfig {
|
||||
pub bucket_name: String,
|
||||
pub prefix: String,
|
||||
pub sync_enabled: bool,
|
||||
pub region: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for AttendanceDriveConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
bucket_name: "attendance".to_string(),
|
||||
prefix: "records/".to_string(),
|
||||
sync_enabled: true,
|
||||
region: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Drive service for attendance data
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AttendanceDriveService {
|
||||
config: AttendanceDriveConfig,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
impl AttendanceDriveService {
|
||||
/// Create new attendance drive service
|
||||
pub async fn new(config: AttendanceDriveConfig) -> Result<Self> {
|
||||
let sdk_config = if let Some(region) = &config.region {
|
||||
aws_config::from_env()
|
||||
.region(aws_config::Region::new(region.clone()))
|
||||
.load()
|
||||
.await
|
||||
} else {
|
||||
aws_config::from_env().load().await
|
||||
};
|
||||
|
||||
let client = Client::new(&sdk_config);
|
||||
|
||||
Ok(Self { config, client })
|
||||
}
|
||||
|
||||
/// Create new service with existing S3 client
|
||||
pub fn with_client(config: AttendanceDriveConfig, client: Client) -> Self {
|
||||
Self { config, client }
|
||||
}
|
||||
|
||||
/// Get the full S3 key for a record
|
||||
fn get_record_key(&self, record_id: &str) -> String {
|
||||
format!("{}{}", self.config.prefix, record_id)
|
||||
}
|
||||
|
||||
/// Upload attendance record to drive
|
||||
pub async fn upload_record(&self, record_id: &str, data: Vec<u8>) -> Result<()> {
|
||||
let key = self.get_record_key(record_id);
|
||||
|
||||
log::info!(
|
||||
"Uploading attendance record {} to s3://{}/{}",
|
||||
record_id,
|
||||
self.config.bucket_name,
|
||||
key
|
||||
);
|
||||
|
||||
let body = ByteStream::from(data);
|
||||
|
||||
self.client
|
||||
.put_object()
|
||||
.bucket(&self.config.bucket_name)
|
||||
.key(&key)
|
||||
.body(body)
|
||||
.content_type("application/octet-stream")
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to upload attendance record: {}", e))?;
|
||||
|
||||
log::debug!("Successfully uploaded attendance record {}", record_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Download attendance record from drive
|
||||
pub async fn download_record(&self, record_id: &str) -> Result<Vec<u8>> {
|
||||
let key = self.get_record_key(record_id);
|
||||
|
||||
log::info!(
|
||||
"Downloading attendance record {} from s3://{}/{}",
|
||||
record_id,
|
||||
self.config.bucket_name,
|
||||
key
|
||||
);
|
||||
|
||||
let result = self
|
||||
.client
|
||||
.get_object()
|
||||
.bucket(&self.config.bucket_name)
|
||||
.key(&key)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to download attendance record: {}", e))?;
|
||||
|
||||
let data = result
|
||||
.body
|
||||
.collect()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to read attendance record body: {}", e))?;
|
||||
|
||||
log::debug!("Successfully downloaded attendance record {}", record_id);
|
||||
Ok(data.into_bytes().to_vec())
|
||||
}
|
||||
|
||||
/// List attendance records in drive
|
||||
pub async fn list_records(&self, prefix: Option<&str>) -> Result<Vec<String>> {
|
||||
let list_prefix = if let Some(p) = prefix {
|
||||
format!("{}{}", self.config.prefix, p)
|
||||
} else {
|
||||
self.config.prefix.clone()
|
||||
};
|
||||
|
||||
log::info!(
|
||||
"Listing attendance records in s3://{}/{}",
|
||||
self.config.bucket_name,
|
||||
list_prefix
|
||||
);
|
||||
|
||||
let mut records = Vec::new();
|
||||
let mut continuation_token = None;
|
||||
|
||||
loop {
|
||||
let mut request = self
|
||||
.client
|
||||
.list_objects_v2()
|
||||
.bucket(&self.config.bucket_name)
|
||||
.prefix(&list_prefix)
|
||||
.max_keys(1000);
|
||||
|
||||
if let Some(token) = continuation_token {
|
||||
request = request.continuation_token(token);
|
||||
}
|
||||
|
||||
let result = request
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to list attendance records: {}", e))?;
|
||||
|
||||
if let Some(contents) = result.contents {
|
||||
for obj in contents {
|
||||
if let Some(key) = obj.key {
|
||||
// Remove prefix to get record ID
|
||||
if let Some(record_id) = key.strip_prefix(&self.config.prefix) {
|
||||
records.push(record_id.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if result.is_truncated.unwrap_or(false) {
|
||||
continuation_token = result.next_continuation_token;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
log::debug!("Found {} attendance records", records.len());
|
||||
Ok(records)
|
||||
}
|
||||
|
||||
/// Delete attendance record from drive
|
||||
pub async fn delete_record(&self, record_id: &str) -> Result<()> {
|
||||
let key = self.get_record_key(record_id);
|
||||
|
||||
log::info!(
|
||||
"Deleting attendance record {} from s3://{}/{}",
|
||||
record_id,
|
||||
self.config.bucket_name,
|
||||
key
|
||||
);
|
||||
|
||||
self.client
|
||||
.delete_object()
|
||||
.bucket(&self.config.bucket_name)
|
||||
.key(&key)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to delete attendance record: {}", e))?;
|
||||
|
||||
log::debug!("Successfully deleted attendance record {}", record_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Batch delete multiple attendance records
|
||||
pub async fn delete_records(&self, record_ids: &[String]) -> Result<()> {
|
||||
if record_ids.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
log::info!(
|
||||
"Batch deleting {} attendance records from bucket {}",
|
||||
record_ids.len(),
|
||||
self.config.bucket_name
|
||||
);
|
||||
|
||||
// S3 batch delete is limited to 1000 objects per request
|
||||
for chunk in record_ids.chunks(1000) {
|
||||
let objects: Vec<_> = chunk
|
||||
.iter()
|
||||
.map(|id| {
|
||||
aws_sdk_s3::types::ObjectIdentifier::builder()
|
||||
.key(self.get_record_key(id))
|
||||
.build()
|
||||
.unwrap()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let delete = aws_sdk_s3::types::Delete::builder()
|
||||
.set_objects(Some(objects))
|
||||
.build()
|
||||
.map_err(|e| anyhow!("Failed to build delete request: {}", e))?;
|
||||
|
||||
self.client
|
||||
.delete_objects()
|
||||
.bucket(&self.config.bucket_name)
|
||||
.delete(delete)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to batch delete attendance records: {}", e))?;
|
||||
}
|
||||
|
||||
log::debug!(
|
||||
"Successfully batch deleted {} attendance records",
|
||||
record_ids.len()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if an attendance record exists
|
||||
pub async fn record_exists(&self, record_id: &str) -> Result<bool> {
|
||||
let key = self.get_record_key(record_id);
|
||||
|
||||
match self
|
||||
.client
|
||||
.head_object()
|
||||
.bucket(&self.config.bucket_name)
|
||||
.key(&key)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(_) => Ok(true),
|
||||
Err(sdk_err) => {
|
||||
if sdk_err.to_string().contains("404") || sdk_err.to_string().contains("NotFound") {
|
||||
Ok(false)
|
||||
} else {
|
||||
Err(anyhow!(
|
||||
"Failed to check attendance record existence: {}",
|
||||
sdk_err
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Sync local attendance records with drive
|
||||
pub async fn sync_records(&self, local_path: PathBuf) -> Result<SyncResult> {
|
||||
if !self.config.sync_enabled {
|
||||
log::debug!("Attendance drive sync is disabled");
|
||||
return Ok(SyncResult::default());
|
||||
}
|
||||
|
||||
log::info!(
|
||||
"Syncing attendance records from {:?} to s3://{}/{}",
|
||||
local_path,
|
||||
self.config.bucket_name,
|
||||
self.config.prefix
|
||||
);
|
||||
|
||||
if !local_path.exists() {
|
||||
return Err(anyhow!("Local path does not exist: {:?}", local_path));
|
||||
}
|
||||
|
||||
let mut uploaded = 0;
|
||||
let mut failed = 0;
|
||||
let mut skipped = 0;
|
||||
|
||||
let mut entries = fs::read_dir(&local_path)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to read local directory: {}", e))?;
|
||||
|
||||
while let Some(entry) = entries
|
||||
.next_entry()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to read directory entry: {}", e))?
|
||||
{
|
||||
let path = entry.path();
|
||||
if !path.is_file() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let file_name = match path.file_name().and_then(|n| n.to_str()) {
|
||||
Some(name) => name.to_string(),
|
||||
None => {
|
||||
log::warn!("Skipping file with invalid name: {:?}", path);
|
||||
skipped += 1;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Check if record already exists in S3
|
||||
if self.record_exists(&file_name).await? {
|
||||
log::debug!("Record {} already exists in drive, skipping", file_name);
|
||||
skipped += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Read file and upload
|
||||
match fs::read(&path).await {
|
||||
Ok(data) => match self.upload_record(&file_name, data).await {
|
||||
Ok(_) => {
|
||||
log::debug!("Uploaded attendance record: {}", file_name);
|
||||
uploaded += 1;
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Failed to upload {}: {}", file_name, e);
|
||||
failed += 1;
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
log::error!("Failed to read file {:?}: {}", path, e);
|
||||
failed += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let result = SyncResult {
|
||||
uploaded,
|
||||
failed,
|
||||
skipped,
|
||||
};
|
||||
|
||||
log::info!(
|
||||
"Sync completed: {} uploaded, {} failed, {} skipped",
|
||||
result.uploaded,
|
||||
result.failed,
|
||||
result.skipped
|
||||
);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Get metadata for an attendance record
|
||||
pub async fn get_record_metadata(&self, record_id: &str) -> Result<RecordMetadata> {
|
||||
let key = self.get_record_key(record_id);
|
||||
|
||||
let result = self
|
||||
.client
|
||||
.head_object()
|
||||
.bucket(&self.config.bucket_name)
|
||||
.key(&key)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get attendance record metadata: {}", e))?;
|
||||
|
||||
Ok(RecordMetadata {
|
||||
size: result.content_length.unwrap_or(0) as usize,
|
||||
last_modified: result
|
||||
.last_modified
|
||||
.and_then(|t| t.to_millis().ok())
|
||||
.map(|ms| chrono::Utc.timestamp_millis_opt(ms as i64).unwrap()),
|
||||
content_type: result.content_type,
|
||||
etag: result.e_tag,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of sync operation
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
|
||||
pub struct SyncResult {
|
||||
pub uploaded: usize,
|
||||
pub failed: usize,
|
||||
pub skipped: usize,
|
||||
}
|
||||
|
||||
/// Metadata for an attendance record
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RecordMetadata {
|
||||
pub size: usize,
|
||||
pub last_modified: Option<chrono::DateTime<chrono::Utc>>,
|
||||
pub content_type: Option<String>,
|
||||
pub etag: Option<String>,
|
||||
}
|
||||
565
src/attendance/keyword_services.rs
Normal file
565
src/attendance/keyword_services.rs
Normal file
|
|
@ -0,0 +1,565 @@
|
|||
//! Keyword-based services for attendance system
|
||||
//! Provides automated keyword detection and processing for attendance commands
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use chrono::{DateTime, Duration, Local, NaiveTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Keyword command types for attendance
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AttendanceCommand {
    /// Start a work session.
    CheckIn,
    /// End the current work session.
    CheckOut,
    /// Pause the current work session.
    Break,
    /// Resume work after a break.
    Resume,
    /// Query the user's current attendance state.
    Status,
    /// Produce an attendance report for the user.
    Report,
    /// Administrative override of another user's state.
    Override,
}
|
||||
|
||||
/// Keyword configuration for attendance
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeywordConfig {
    /// Master switch; when false the parser matches nothing.
    pub enabled: bool,
    /// When false, input is lowercased before matching.
    pub case_sensitive: bool,
    /// Optional command prefix (e.g. "!"); input without it is ignored.
    pub prefix: Option<String>,
    /// Canonical keyword -> command table.
    pub keywords: HashMap<String, AttendanceCommand>,
    /// Alias -> canonical-keyword table, resolved before keyword lookup.
    pub aliases: HashMap<String, String>,
}
|
||||
|
||||
impl Default for KeywordConfig {
|
||||
fn default() -> Self {
|
||||
let mut keywords = HashMap::new();
|
||||
keywords.insert("checkin".to_string(), AttendanceCommand::CheckIn);
|
||||
keywords.insert("checkout".to_string(), AttendanceCommand::CheckOut);
|
||||
keywords.insert("break".to_string(), AttendanceCommand::Break);
|
||||
keywords.insert("resume".to_string(), AttendanceCommand::Resume);
|
||||
keywords.insert("status".to_string(), AttendanceCommand::Status);
|
||||
keywords.insert("report".to_string(), AttendanceCommand::Report);
|
||||
keywords.insert("override".to_string(), AttendanceCommand::Override);
|
||||
|
||||
let mut aliases = HashMap::new();
|
||||
aliases.insert("in".to_string(), "checkin".to_string());
|
||||
aliases.insert("out".to_string(), "checkout".to_string());
|
||||
aliases.insert("pause".to_string(), "break".to_string());
|
||||
aliases.insert("continue".to_string(), "resume".to_string());
|
||||
aliases.insert("stat".to_string(), "status".to_string());
|
||||
|
||||
Self {
|
||||
enabled: true,
|
||||
case_sensitive: false,
|
||||
prefix: Some("!".to_string()),
|
||||
keywords,
|
||||
aliases,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parsed keyword command
///
/// The result of a successful `KeywordParser::parse` call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParsedCommand {
    /// Resolved command (after alias resolution).
    pub command: AttendanceCommand,
    /// Whitespace-separated arguments following the keyword.
    pub args: Vec<String>,
    /// Time the input was parsed (UTC).
    pub timestamp: DateTime<Utc>,
    /// Original, untrimmed input line.
    pub raw_input: String,
}
|
||||
|
||||
/// Keyword parser for attendance commands
///
/// Cheaply cloneable: the configuration lives behind an `Arc<RwLock<_>>`
/// shared by all clones, so runtime config updates are visible everywhere.
#[derive(Debug, Clone)]
pub struct KeywordParser {
    // Shared, hot-reloadable configuration.
    config: Arc<RwLock<KeywordConfig>>,
}
|
||||
|
||||
impl KeywordParser {
|
||||
/// Create new keyword parser
|
||||
pub fn new(config: KeywordConfig) -> Self {
|
||||
Self {
|
||||
config: Arc::new(RwLock::new(config)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse input text for attendance commands
|
||||
pub async fn parse(&self, input: &str) -> Option<ParsedCommand> {
|
||||
let config = self.config.read().await;
|
||||
|
||||
if !config.enabled {
|
||||
return None;
|
||||
}
|
||||
|
||||
let processed_input = if config.case_sensitive {
|
||||
input.trim().to_string()
|
||||
} else {
|
||||
input.trim().to_lowercase()
|
||||
};
|
||||
|
||||
// Check for prefix if configured
|
||||
let command_text = if let Some(prefix) = &config.prefix {
|
||||
if !processed_input.starts_with(prefix) {
|
||||
return None;
|
||||
}
|
||||
processed_input.strip_prefix(prefix)?
|
||||
} else {
|
||||
&processed_input
|
||||
};
|
||||
|
||||
// Split command and arguments
|
||||
let parts: Vec<&str> = command_text.split_whitespace().collect();
|
||||
if parts.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let command_word = parts[0];
|
||||
let args: Vec<String> = parts[1..].iter().map(|s| s.to_string()).collect();
|
||||
|
||||
// Resolve aliases
|
||||
let resolved_command = if let Some(alias) = config.aliases.get(command_word) {
|
||||
alias.as_str()
|
||||
} else {
|
||||
command_word
|
||||
};
|
||||
|
||||
// Look up command
|
||||
let command = config.keywords.get(resolved_command)?;
|
||||
|
||||
Some(ParsedCommand {
|
||||
command: command.clone(),
|
||||
args,
|
||||
timestamp: Utc::now(),
|
||||
raw_input: input.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Update configuration
|
||||
pub async fn update_config(&self, config: KeywordConfig) {
|
||||
let mut current = self.config.write().await;
|
||||
*current = config;
|
||||
}
|
||||
|
||||
/// Add a new keyword
|
||||
pub async fn add_keyword(&self, keyword: String, command: AttendanceCommand) {
|
||||
let mut config = self.config.write().await;
|
||||
config.keywords.insert(keyword, command);
|
||||
}
|
||||
|
||||
/// Add a new alias
|
||||
pub async fn add_alias(&self, alias: String, target: String) {
|
||||
let mut config = self.config.write().await;
|
||||
config.aliases.insert(alias, target);
|
||||
}
|
||||
|
||||
/// Remove a keyword
|
||||
pub async fn remove_keyword(&self, keyword: &str) -> bool {
|
||||
let mut config = self.config.write().await;
|
||||
config.keywords.remove(keyword).is_some()
|
||||
}
|
||||
|
||||
/// Remove an alias
|
||||
pub async fn remove_alias(&self, alias: &str) -> bool {
|
||||
let mut config = self.config.write().await;
|
||||
config.aliases.remove(alias).is_some()
|
||||
}
|
||||
|
||||
/// Get current configuration
|
||||
pub async fn get_config(&self) -> KeywordConfig {
|
||||
self.config.read().await.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Attendance record
///
/// One immutable event in a user's attendance history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttendanceRecord {
    /// Unique record id (UUID v4 string).
    pub id: String,
    /// Id of the user the event belongs to.
    pub user_id: String,
    /// What happened (check-in, break, ...).
    pub command: AttendanceCommand,
    /// When it happened (UTC).
    pub timestamp: DateTime<Utc>,
    /// Optional location, taken from the first command argument.
    pub location: Option<String>,
    /// Optional free-form notes.
    pub notes: Option<String>,
}
|
||||
|
||||
/// Attendance service for processing commands
///
/// Holds the keyword parser and an in-memory, lock-protected record store.
/// Cloning shares both via `Arc`.
#[derive(Debug, Clone)]
pub struct AttendanceService {
    // Parser used to turn raw text into commands.
    parser: Arc<KeywordParser>,
    // Append-only in-memory event log, shared across clones.
    records: Arc<RwLock<Vec<AttendanceRecord>>>,
}
|
||||
|
||||
impl AttendanceService {
|
||||
/// Create new attendance service
|
||||
pub fn new(parser: KeywordParser) -> Self {
|
||||
Self {
|
||||
parser: Arc::new(parser),
|
||||
records: Arc::new(RwLock::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a text input for attendance commands
|
||||
pub async fn process_input(
|
||||
&self,
|
||||
user_id: &str,
|
||||
input: &str,
|
||||
) -> Result<AttendanceResponse> {
|
||||
let parsed = self
|
||||
.parser
|
||||
.parse(input)
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("No valid command found in input"))?;
|
||||
|
||||
match parsed.command {
|
||||
AttendanceCommand::CheckIn => self.handle_check_in(user_id, &parsed).await,
|
||||
AttendanceCommand::CheckOut => self.handle_check_out(user_id, &parsed).await,
|
||||
AttendanceCommand::Break => self.handle_break(user_id, &parsed).await,
|
||||
AttendanceCommand::Resume => self.handle_resume(user_id, &parsed).await,
|
||||
AttendanceCommand::Status => self.handle_status(user_id).await,
|
||||
AttendanceCommand::Report => self.handle_report(user_id, &parsed).await,
|
||||
AttendanceCommand::Override => self.handle_override(user_id, &parsed).await,
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle check-in command
|
||||
async fn handle_check_in(
|
||||
&self,
|
||||
user_id: &str,
|
||||
parsed: &ParsedCommand,
|
||||
) -> Result<AttendanceResponse> {
|
||||
let mut records = self.records.write().await;
|
||||
|
||||
// Check if already checked in
|
||||
if let Some(last_record) = records.iter().rev().find(|r| r.user_id == user_id) {
|
||||
if matches!(last_record.command, AttendanceCommand::CheckIn) {
|
||||
return Ok(AttendanceResponse::Error {
|
||||
message: "Already checked in".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let record = AttendanceRecord {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
user_id: user_id.to_string(),
|
||||
command: AttendanceCommand::CheckIn,
|
||||
timestamp: parsed.timestamp,
|
||||
location: parsed.args.first().cloned(),
|
||||
notes: if parsed.args.len() > 1 {
|
||||
Some(parsed.args[1..].join(" "))
|
||||
} else {
|
||||
None
|
||||
},
|
||||
};
|
||||
|
||||
let time = Local::now().format("%H:%M").to_string();
|
||||
records.push(record);
|
||||
|
||||
Ok(AttendanceResponse::Success {
|
||||
message: format!("Checked in at {}", time),
|
||||
timestamp: parsed.timestamp,
|
||||
})
|
||||
}
|
||||
|
||||
/// Handle check-out command
|
||||
async fn handle_check_out(
|
||||
&self,
|
||||
user_id: &str,
|
||||
parsed: &ParsedCommand,
|
||||
) -> Result<AttendanceResponse> {
|
||||
let mut records = self.records.write().await;
|
||||
|
||||
// Find last check-in
|
||||
let check_in_time = records
|
||||
.iter()
|
||||
.rev()
|
||||
.find(|r| r.user_id == user_id && matches!(r.command, AttendanceCommand::CheckIn))
|
||||
.map(|r| r.timestamp);
|
||||
|
||||
if check_in_time.is_none() {
|
||||
return Ok(AttendanceResponse::Error {
|
||||
message: "Not checked in".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let record = AttendanceRecord {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
user_id: user_id.to_string(),
|
||||
command: AttendanceCommand::CheckOut,
|
||||
timestamp: parsed.timestamp,
|
||||
location: parsed.args.first().cloned(),
|
||||
notes: if parsed.args.len() > 1 {
|
||||
Some(parsed.args[1..].join(" "))
|
||||
} else {
|
||||
None
|
||||
},
|
||||
};
|
||||
|
||||
let duration = parsed.timestamp - check_in_time.unwrap();
|
||||
let hours = duration.num_hours();
|
||||
let minutes = duration.num_minutes() % 60;
|
||||
|
||||
records.push(record);
|
||||
|
||||
Ok(AttendanceResponse::Success {
|
||||
message: format!("Checked out. Total time: {}h {}m", hours, minutes),
|
||||
timestamp: parsed.timestamp,
|
||||
})
|
||||
}
|
||||
|
||||
/// Handle break command
|
||||
async fn handle_break(
|
||||
&self,
|
||||
user_id: &str,
|
||||
parsed: &ParsedCommand,
|
||||
) -> Result<AttendanceResponse> {
|
||||
let mut records = self.records.write().await;
|
||||
|
||||
// Check if checked in
|
||||
let is_checked_in = records
|
||||
.iter()
|
||||
.rev()
|
||||
.find(|r| r.user_id == user_id)
|
||||
.map(|r| matches!(r.command, AttendanceCommand::CheckIn))
|
||||
.unwrap_or(false);
|
||||
|
||||
if !is_checked_in {
|
||||
return Ok(AttendanceResponse::Error {
|
||||
message: "Not checked in".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let record = AttendanceRecord {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
user_id: user_id.to_string(),
|
||||
command: AttendanceCommand::Break,
|
||||
timestamp: parsed.timestamp,
|
||||
location: None,
|
||||
notes: parsed.args.first().cloned(),
|
||||
};
|
||||
|
||||
let time = Local::now().format("%H:%M").to_string();
|
||||
records.push(record);
|
||||
|
||||
Ok(AttendanceResponse::Success {
|
||||
message: format!("Break started at {}", time),
|
||||
timestamp: parsed.timestamp,
|
||||
})
|
||||
}
|
||||
|
||||
/// Handle resume command
|
||||
async fn handle_resume(
|
||||
&self,
|
||||
user_id: &str,
|
||||
parsed: &ParsedCommand,
|
||||
) -> Result<AttendanceResponse> {
|
||||
let mut records = self.records.write().await;
|
||||
|
||||
// Find last break
|
||||
let break_time = records
|
||||
.iter()
|
||||
.rev()
|
||||
.find(|r| r.user_id == user_id && matches!(r.command, AttendanceCommand::Break))
|
||||
.map(|r| r.timestamp);
|
||||
|
||||
if break_time.is_none() {
|
||||
return Ok(AttendanceResponse::Error {
|
||||
message: "Not on break".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let record = AttendanceRecord {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
user_id: user_id.to_string(),
|
||||
command: AttendanceCommand::Resume,
|
||||
timestamp: parsed.timestamp,
|
||||
location: None,
|
||||
notes: None,
|
||||
};
|
||||
|
||||
let duration = parsed.timestamp - break_time.unwrap();
|
||||
let minutes = duration.num_minutes();
|
||||
|
||||
records.push(record);
|
||||
|
||||
Ok(AttendanceResponse::Success {
|
||||
message: format!("Resumed work. Break duration: {} minutes", minutes),
|
||||
timestamp: parsed.timestamp,
|
||||
})
|
||||
}
|
||||
|
||||
/// Handle status command
|
||||
async fn handle_status(&self, user_id: &str) -> Result<AttendanceResponse> {
|
||||
let records = self.records.read().await;
|
||||
|
||||
let user_records: Vec<_> = records
|
||||
.iter()
|
||||
.filter(|r| r.user_id == user_id)
|
||||
.collect();
|
||||
|
||||
if user_records.is_empty() {
|
||||
return Ok(AttendanceResponse::Status {
|
||||
status: "No records found".to_string(),
|
||||
details: None,
|
||||
});
|
||||
}
|
||||
|
||||
let last_record = user_records.last().unwrap();
|
||||
let status = match last_record.command {
|
||||
AttendanceCommand::CheckIn => "Checked in",
|
||||
AttendanceCommand::CheckOut => "Checked out",
|
||||
AttendanceCommand::Break => "On break",
|
||||
AttendanceCommand::Resume => "Working",
|
||||
_ => "Unknown",
|
||||
};
|
||||
|
||||
let details = format!(
|
||||
"Last action: {} at {}",
|
||||
status,
|
||||
last_record.timestamp.format("%Y-%m-%d %H:%M:%S")
|
||||
);
|
||||
|
||||
Ok(AttendanceResponse::Status {
|
||||
status: status.to_string(),
|
||||
details: Some(details),
|
||||
})
|
||||
}
|
||||
|
||||
/// Handle report command
|
||||
async fn handle_report(
|
||||
&self,
|
||||
user_id: &str,
|
||||
parsed: &ParsedCommand,
|
||||
) -> Result<AttendanceResponse> {
|
||||
let records = self.records.read().await;
|
||||
|
||||
let user_records: Vec<_> = records
|
||||
.iter()
|
||||
.filter(|r| r.user_id == user_id)
|
||||
.collect();
|
||||
|
||||
if user_records.is_empty() {
|
||||
return Ok(AttendanceResponse::Report {
|
||||
data: "No attendance records found".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let mut report = String::new();
|
||||
report.push_str(&format!("Attendance Report for User: {}\n", user_id));
|
||||
report.push_str("========================\n");
|
||||
|
||||
for record in user_records {
|
||||
let action = match record.command {
|
||||
AttendanceCommand::CheckIn => "Check In",
|
||||
AttendanceCommand::CheckOut => "Check Out",
|
||||
AttendanceCommand::Break => "Break",
|
||||
AttendanceCommand::Resume => "Resume",
|
||||
_ => "Other",
|
||||
};
|
||||
|
||||
report.push_str(&format!(
|
||||
"{}: {} at {}\n",
|
||||
record.timestamp.format("%Y-%m-%d %H:%M:%S"),
|
||||
action,
|
||||
record.location.as_deref().unwrap_or("N/A")
|
||||
));
|
||||
}
|
||||
|
||||
Ok(AttendanceResponse::Report { data: report })
|
||||
}
|
||||
|
||||
/// Handle override command (for admins)
|
||||
async fn handle_override(
|
||||
&self,
|
||||
user_id: &str,
|
||||
parsed: &ParsedCommand,
|
||||
) -> Result<AttendanceResponse> {
|
||||
if parsed.args.len() < 2 {
|
||||
return Ok(AttendanceResponse::Error {
|
||||
message: "Override requires target user and action".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let target_user = &parsed.args[0];
|
||||
let action = &parsed.args[1];
|
||||
|
||||
// In a real implementation, check admin permissions here
|
||||
log::warn!(
|
||||
"Override command by {} for user {}: {}",
|
||||
user_id,
|
||||
target_user,
|
||||
action
|
||||
);
|
||||
|
||||
Ok(AttendanceResponse::Success {
|
||||
message: format!("Override applied for user {}", target_user),
|
||||
timestamp: parsed.timestamp,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get all records for a user
|
||||
pub async fn get_user_records(&self, user_id: &str) -> Vec<AttendanceRecord> {
|
||||
let records = self.records.read().await;
|
||||
records
|
||||
.iter()
|
||||
.filter(|r| r.user_id == user_id)
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Clear all records (for testing)
|
||||
pub async fn clear_records(&self) {
|
||||
let mut records = self.records.write().await;
|
||||
records.clear();
|
||||
}
|
||||
|
||||
/// Get total work time for a user today
|
||||
pub async fn get_today_work_time(&self, user_id: &str) -> Duration {
|
||||
let records = self.records.read().await;
|
||||
let today = Local::now().date_naive();
|
||||
|
||||
let mut total_duration = Duration::zero();
|
||||
let mut last_checkin: Option<DateTime<Utc>> = None;
|
||||
|
||||
for record in records.iter().filter(|r| r.user_id == user_id) {
|
||||
if record.timestamp.with_timezone(&Local).date_naive() != today {
|
||||
continue;
|
||||
}
|
||||
|
||||
match record.command {
|
||||
AttendanceCommand::CheckIn => {
|
||||
last_checkin = Some(record.timestamp);
|
||||
}
|
||||
AttendanceCommand::CheckOut => {
|
||||
if let Some(checkin) = last_checkin {
|
||||
total_duration = total_duration + (record.timestamp - checkin);
|
||||
last_checkin = None;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// If still checked in, add time until now
|
||||
if let Some(checkin) = last_checkin {
|
||||
total_duration = total_duration + (Utc::now() - checkin);
|
||||
}
|
||||
|
||||
total_duration
|
||||
}
|
||||
}
|
||||
|
||||
/// Response from attendance service
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AttendanceResponse {
    /// Command succeeded; `timestamp` is the time recorded for the event.
    Success {
        message: String,
        timestamp: DateTime<Utc>,
    },
    /// Command was understood but rejected (e.g. invalid state transition).
    Error {
        message: String,
    },
    /// Answer to a Status query.
    Status {
        status: String,
        details: Option<String>,
    },
    /// Answer to a Report query; `data` is a preformatted text report.
    Report {
        data: String,
    },
}
|
||||
|
|
@ -55,7 +55,7 @@ async fn execute_create_draft(
|
|||
to: to.to_string(),
|
||||
subject: subject.to_string(),
|
||||
cc: None,
|
||||
text: email_body,
|
||||
body: email_body,
|
||||
};
|
||||
|
||||
save_email_draft(&config.email, &draft_request)
|
||||
|
|
|
|||
|
|
@ -211,15 +211,13 @@ async fn execute_send_mail(
|
|||
{
|
||||
use crate::email::EmailService;
|
||||
|
||||
let email_service = EmailService::new(state.clone());
|
||||
let email_service = EmailService::new(Arc::new(state.as_ref().clone()));
|
||||
|
||||
if let Ok(_) = email_service
|
||||
.send_email(
|
||||
&to,
|
||||
&subject,
|
||||
&body,
|
||||
None, // cc
|
||||
None, // bcc
|
||||
if attachments.is_empty() {
|
||||
None
|
||||
} else {
|
||||
|
|
|
|||
1226
src/calendar/mod.rs
1226
src/calendar/mod.rs
File diff suppressed because it is too large
Load diff
463
src/compliance/access_review.rs
Normal file
463
src/compliance/access_review.rs
Normal file
|
|
@ -0,0 +1,463 @@
|
|||
//! Access Review Module
|
||||
//!
|
||||
//! Provides automated access review and permission auditing capabilities
|
||||
//! for compliance with security policies and regulations.
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Access level enumeration
///
/// Ordered informally from weakest to strongest; comparisons are performed by
/// `AccessReviewService::has_sufficient_access`, not by derive ordering.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AccessLevel {
    /// Read-only access.
    Read,
    /// Read and write access.
    Write,
    /// Administrative access (includes write).
    Admin,
    /// Full ownership (includes admin).
    Owner,
}
|
||||
|
||||
/// Resource type enumeration
///
/// Coarse classification of the thing a permission applies to.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ResourceType {
    File,
    Database,
    API,
    System,
    Application,
}
|
||||
|
||||
/// Access permission structure
///
/// A single grant of access for one user on one resource. Revocation is
/// soft: `is_active` is flipped to false rather than deleting the record.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessPermission {
    /// Unique id of this grant.
    pub id: Uuid,
    /// User the grant applies to.
    pub user_id: Uuid,
    /// Identifier of the protected resource.
    pub resource_id: String,
    /// Kind of resource.
    pub resource_type: ResourceType,
    /// Level granted.
    pub access_level: AccessLevel,
    /// When the grant was made.
    pub granted_at: DateTime<Utc>,
    /// Who made the grant.
    pub granted_by: Uuid,
    /// Expiry time; `None` means the grant never expires.
    pub expires_at: Option<DateTime<Utc>>,
    /// Business justification recorded at grant time.
    pub justification: String,
    /// False once revoked (record is kept for audit).
    pub is_active: bool,
}
|
||||
|
||||
/// Access review request
///
/// A snapshot of a user's permissions assigned to a reviewer for audit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessReviewRequest {
    /// Unique review id.
    pub id: Uuid,
    /// User whose permissions are under review.
    pub user_id: Uuid,
    /// Reviewer responsible for this review.
    pub reviewer_id: Uuid,
    /// Permissions captured at request time (a copy, not live data).
    pub permissions: Vec<AccessPermission>,
    /// When the review was requested.
    pub requested_at: DateTime<Utc>,
    /// Deadline for completing the review.
    pub due_date: DateTime<Utc>,
    /// Current lifecycle state.
    pub status: ReviewStatus,
    /// Reviewer comments, set on completion.
    pub comments: Option<String>,
}
|
||||
|
||||
/// Review status
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReviewStatus {
    /// Created, not yet started.
    Pending,
    /// Reviewer has started working on it.
    InProgress,
    /// Completed and approved.
    Approved,
    /// Completed and rejected.
    Rejected,
    /// Deadline passed without completion.
    Expired,
}
|
||||
|
||||
/// Access review result
///
/// Outcome record produced by `AccessReviewService::process_review`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessReviewResult {
    /// Review this result belongs to.
    pub review_id: Uuid,
    /// Reviewer who made the decision.
    pub reviewer_id: Uuid,
    /// When the decision was recorded.
    pub reviewed_at: DateTime<Utc>,
    /// Permission ids confirmed as-is.
    pub approved_permissions: Vec<Uuid>,
    /// Permission ids revoked by the review.
    pub revoked_permissions: Vec<Uuid>,
    /// Permission ids whose level was changed, with the new level.
    pub modified_permissions: Vec<(Uuid, AccessLevel)>,
    /// Reviewer's comments.
    pub comments: String,
}
|
||||
|
||||
/// Access violation detection
///
/// Recorded whenever an access check is denied.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessViolation {
    /// Unique violation id.
    pub id: Uuid,
    /// User who attempted the access.
    pub user_id: Uuid,
    /// Resource that was targeted.
    pub resource_id: String,
    /// Human-readable description of what was attempted.
    pub attempted_action: String,
    /// Why the attempt was denied.
    pub denied_reason: String,
    /// When the attempt happened.
    pub occurred_at: DateTime<Utc>,
    /// Assessed severity.
    pub severity: ViolationSeverity,
}
|
||||
|
||||
/// Violation severity levels
///
/// Drives the per-violation deduction in the compliance score.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ViolationSeverity {
    Low,
    Medium,
    High,
    Critical,
}
|
||||
|
||||
/// Access review service
///
/// In-memory store of grants, review requests, and denial records. Not
/// thread-safe by itself — callers must provide synchronization if shared.
#[derive(Debug, Clone)]
pub struct AccessReviewService {
    // Grants keyed by user id (includes revoked entries with is_active=false).
    permissions: HashMap<Uuid, Vec<AccessPermission>>,
    // Review requests keyed by review id.
    reviews: HashMap<Uuid, AccessReviewRequest>,
    // Append-only log of denied access attempts.
    violations: Vec<AccessViolation>,
}
|
||||
|
||||
impl AccessReviewService {
|
||||
/// Create new access review service
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
permissions: HashMap::new(),
|
||||
reviews: HashMap::new(),
|
||||
violations: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Grant access permission
|
||||
pub fn grant_permission(
|
||||
&mut self,
|
||||
user_id: Uuid,
|
||||
resource_id: String,
|
||||
resource_type: ResourceType,
|
||||
access_level: AccessLevel,
|
||||
granted_by: Uuid,
|
||||
justification: String,
|
||||
expires_in: Option<Duration>,
|
||||
) -> Result<AccessPermission> {
|
||||
let permission = AccessPermission {
|
||||
id: Uuid::new_v4(),
|
||||
user_id,
|
||||
resource_id,
|
||||
resource_type,
|
||||
access_level,
|
||||
granted_at: Utc::now(),
|
||||
granted_by,
|
||||
expires_at: expires_in.map(|d| Utc::now() + d),
|
||||
justification,
|
||||
is_active: true,
|
||||
};
|
||||
|
||||
self.permissions
|
||||
.entry(user_id)
|
||||
.or_insert_with(Vec::new)
|
||||
.push(permission.clone());
|
||||
|
||||
log::info!(
|
||||
"Granted {} access to user {} for resource {}",
|
||||
serde_json::to_string(&permission.access_level)?,
|
||||
user_id,
|
||||
permission.resource_id
|
||||
);
|
||||
|
||||
Ok(permission)
|
||||
}
|
||||
|
||||
/// Revoke access permission
|
||||
pub fn revoke_permission(&mut self, permission_id: Uuid, revoked_by: Uuid) -> Result<()> {
|
||||
for permissions in self.permissions.values_mut() {
|
||||
if let Some(perm) = permissions.iter_mut().find(|p| p.id == permission_id) {
|
||||
perm.is_active = false;
|
||||
log::info!(
|
||||
"Revoked permission {} for user {} by {}",
|
||||
permission_id,
|
||||
perm.user_id,
|
||||
revoked_by
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
Err(anyhow!("Permission not found"))
|
||||
}
|
||||
|
||||
/// Check if user has access
|
||||
pub fn check_access(
|
||||
&mut self,
|
||||
user_id: Uuid,
|
||||
resource_id: &str,
|
||||
required_level: AccessLevel,
|
||||
) -> Result<bool> {
|
||||
let user_permissions = self.permissions.get(&user_id);
|
||||
|
||||
if let Some(permissions) = user_permissions {
|
||||
for perm in permissions {
|
||||
if perm.resource_id == resource_id && perm.is_active {
|
||||
// Check expiration
|
||||
if let Some(expires) = perm.expires_at {
|
||||
if expires < Utc::now() {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Check access level
|
||||
if self.has_sufficient_access(&perm.access_level, &required_level) {
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Log access denial
|
||||
let violation = AccessViolation {
|
||||
id: Uuid::new_v4(),
|
||||
user_id,
|
||||
resource_id: resource_id.to_string(),
|
||||
attempted_action: format!("{:?} access", required_level),
|
||||
denied_reason: "Insufficient permissions".to_string(),
|
||||
occurred_at: Utc::now(),
|
||||
severity: ViolationSeverity::Medium,
|
||||
};
|
||||
|
||||
self.violations.push(violation);
|
||||
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
/// Check if access level is sufficient
|
||||
fn has_sufficient_access(&self, user_level: &AccessLevel, required: &AccessLevel) -> bool {
|
||||
match required {
|
||||
AccessLevel::Read => true,
|
||||
AccessLevel::Write => matches!(
|
||||
user_level,
|
||||
AccessLevel::Write | AccessLevel::Admin | AccessLevel::Owner
|
||||
),
|
||||
AccessLevel::Admin => matches!(user_level, AccessLevel::Admin | AccessLevel::Owner),
|
||||
AccessLevel::Owner => matches!(user_level, AccessLevel::Owner),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create access review request
|
||||
pub fn create_review_request(
|
||||
&mut self,
|
||||
user_id: Uuid,
|
||||
reviewer_id: Uuid,
|
||||
days_until_due: i64,
|
||||
) -> Result<AccessReviewRequest> {
|
||||
let user_permissions = self.permissions.get(&user_id).cloned().unwrap_or_default();
|
||||
|
||||
let review = AccessReviewRequest {
|
||||
id: Uuid::new_v4(),
|
||||
user_id,
|
||||
reviewer_id,
|
||||
permissions: user_permissions,
|
||||
requested_at: Utc::now(),
|
||||
due_date: Utc::now() + Duration::days(days_until_due),
|
||||
status: ReviewStatus::Pending,
|
||||
comments: None,
|
||||
};
|
||||
|
||||
self.reviews.insert(review.id, review.clone());
|
||||
|
||||
log::info!(
|
||||
"Created access review {} for user {} assigned to {}",
|
||||
review.id,
|
||||
user_id,
|
||||
reviewer_id
|
||||
);
|
||||
|
||||
Ok(review)
|
||||
}
|
||||
|
||||
/// Process access review
|
||||
pub fn process_review(
|
||||
&mut self,
|
||||
review_id: Uuid,
|
||||
approved: Vec<Uuid>,
|
||||
revoked: Vec<Uuid>,
|
||||
modified: Vec<(Uuid, AccessLevel)>,
|
||||
comments: String,
|
||||
) -> Result<AccessReviewResult> {
|
||||
let review = self
|
||||
.reviews
|
||||
.get_mut(&review_id)
|
||||
.ok_or_else(|| anyhow!("Review not found"))?;
|
||||
|
||||
if review.status != ReviewStatus::Pending && review.status != ReviewStatus::InProgress {
|
||||
return Err(anyhow!("Review already completed"));
|
||||
}
|
||||
|
||||
// Process revocations
|
||||
for perm_id in &revoked {
|
||||
self.revoke_permission(*perm_id, review.reviewer_id)?;
|
||||
}
|
||||
|
||||
// Process modifications
|
||||
for (perm_id, new_level) in &modified {
|
||||
if let Some(permissions) = self.permissions.get_mut(&review.user_id) {
|
||||
if let Some(perm) = permissions.iter_mut().find(|p| p.id == *perm_id) {
|
||||
perm.access_level = new_level.clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
review.status = ReviewStatus::Approved;
|
||||
review.comments = Some(comments.clone());
|
||||
|
||||
let result = AccessReviewResult {
|
||||
review_id,
|
||||
reviewer_id: review.reviewer_id,
|
||||
reviewed_at: Utc::now(),
|
||||
approved_permissions: approved,
|
||||
revoked_permissions: revoked,
|
||||
modified_permissions: modified,
|
||||
comments,
|
||||
};
|
||||
|
||||
log::info!("Completed access review {} with result", review_id);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Get expired permissions
|
||||
pub fn get_expired_permissions(&self) -> Vec<AccessPermission> {
|
||||
let now = Utc::now();
|
||||
let mut expired = Vec::new();
|
||||
|
||||
for permissions in self.permissions.values() {
|
||||
for perm in permissions {
|
||||
if let Some(expires) = perm.expires_at {
|
||||
if expires < now && perm.is_active {
|
||||
expired.push(perm.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
expired
|
||||
}
|
||||
|
||||
/// Get user permissions
|
||||
pub fn get_user_permissions(&self, user_id: Uuid) -> Vec<AccessPermission> {
|
||||
self.permissions
|
||||
.get(&user_id)
|
||||
.cloned()
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.filter(|p| p.is_active)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get pending reviews
|
||||
pub fn get_pending_reviews(&self, reviewer_id: Option<Uuid>) -> Vec<AccessReviewRequest> {
|
||||
self.reviews
|
||||
.values()
|
||||
.filter(|r| {
|
||||
r.status == ReviewStatus::Pending
|
||||
&& reviewer_id.map_or(true, |id| r.reviewer_id == id)
|
||||
})
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get access violations
|
||||
pub fn get_violations(
|
||||
&self,
|
||||
user_id: Option<Uuid>,
|
||||
severity: Option<ViolationSeverity>,
|
||||
since: Option<DateTime<Utc>>,
|
||||
) -> Vec<AccessViolation> {
|
||||
self.violations
|
||||
.iter()
|
||||
.filter(|v| {
|
||||
user_id.map_or(true, |id| v.user_id == id)
|
||||
&& severity.as_ref().map_or(true, |s| &v.severity == s)
|
||||
&& since.map_or(true, |d| v.occurred_at >= d)
|
||||
})
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Generate access compliance report
|
||||
pub fn generate_compliance_report(&self) -> AccessComplianceReport {
|
||||
let total_permissions = self.permissions.values().map(|p| p.len()).sum::<usize>();
|
||||
|
||||
let active_permissions = self
|
||||
.permissions
|
||||
.values()
|
||||
.flat_map(|p| p.iter())
|
||||
.filter(|p| p.is_active)
|
||||
.count();
|
||||
|
||||
let expired_permissions = self.get_expired_permissions().len();
|
||||
|
||||
let pending_reviews = self
|
||||
.reviews
|
||||
.values()
|
||||
.filter(|r| r.status == ReviewStatus::Pending)
|
||||
.count();
|
||||
|
||||
let violations_last_30_days = self
|
||||
.violations
|
||||
.iter()
|
||||
.filter(|v| v.occurred_at > Utc::now() - Duration::days(30))
|
||||
.count();
|
||||
|
||||
let critical_violations = self
|
||||
.violations
|
||||
.iter()
|
||||
.filter(|v| v.severity == ViolationSeverity::Critical)
|
||||
.count();
|
||||
|
||||
AccessComplianceReport {
|
||||
generated_at: Utc::now(),
|
||||
total_permissions,
|
||||
active_permissions,
|
||||
expired_permissions,
|
||||
pending_reviews,
|
||||
violations_last_30_days,
|
||||
critical_violations,
|
||||
compliance_score: self.calculate_compliance_score(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate compliance score
|
||||
fn calculate_compliance_score(&self) -> f64 {
|
||||
let mut score = 100.0;
|
||||
|
||||
// Deduct for expired permissions
|
||||
let expired = self.get_expired_permissions().len();
|
||||
score -= expired as f64 * 2.0;
|
||||
|
||||
// Deduct for overdue reviews
|
||||
let overdue_reviews = self
|
||||
.reviews
|
||||
.values()
|
||||
.filter(|r| r.status == ReviewStatus::Pending && r.due_date < Utc::now())
|
||||
.count();
|
||||
score -= overdue_reviews as f64 * 5.0;
|
||||
|
||||
// Deduct for violations
|
||||
for violation in &self.violations {
|
||||
match violation.severity {
|
||||
ViolationSeverity::Low => score -= 1.0,
|
||||
ViolationSeverity::Medium => score -= 3.0,
|
||||
ViolationSeverity::High => score -= 5.0,
|
||||
ViolationSeverity::Critical => score -= 10.0,
|
||||
}
|
||||
}
|
||||
|
||||
score.max(0.0).min(100.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// Access compliance report
///
/// Snapshot produced by `AccessReviewService::generate_compliance_report`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessComplianceReport {
    /// When this report was generated.
    pub generated_at: DateTime<Utc>,
    /// All grants ever recorded, including revoked ones.
    pub total_permissions: usize,
    /// Grants currently marked active.
    pub active_permissions: usize,
    /// Active grants past their expiry time.
    pub expired_permissions: usize,
    /// Reviews still in `Pending` state.
    pub pending_reviews: usize,
    /// Violations recorded in the last 30 days.
    pub violations_last_30_days: usize,
    /// Violations with `Critical` severity (all time).
    pub critical_violations: usize,
    /// Overall score in [0, 100]; higher is better.
    pub compliance_score: f64,
}
|
||||
|
||||
impl Default for AccessReviewService {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
494
src/compliance/audit.rs
Normal file
494
src/compliance/audit.rs
Normal file
|
|
@ -0,0 +1,494 @@
|
|||
//! Audit Module
|
||||
//!
|
||||
//! Provides comprehensive audit logging and tracking capabilities
|
||||
//! for compliance monitoring and security analysis.
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Audit event types
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub enum AuditEventType {
    // Authentication lifecycle
    UserLogin,
    UserLogout,
    PasswordChange,
    // Authorization changes
    PermissionGranted,
    PermissionRevoked,
    // Data lifecycle
    DataAccess,
    DataModification,
    DataDeletion,
    // System and security events
    ConfigurationChange,
    SecurityAlert,
    SystemError,
    ComplianceViolation,
    // NOTE: `calculate_audit_coverage` hard-codes 12 as the variant count —
    // keep it in sync when adding or removing variants here.
}
|
||||
|
||||
/// Audit severity levels
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Hash)]
pub enum AuditSeverity {
    // Declaration order defines severity ordering (Ord is derived), so
    // Info < Warning < Error < Critical — retention thresholds and filters
    // rely on these comparisons.
    Info,
    Warning,
    Error,
    Critical,
}
|
||||
|
||||
/// Audit event structure
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditEvent {
    /// Unique identifier assigned at creation time.
    pub id: Uuid,
    /// When the event occurred (UTC, set by the logger).
    pub timestamp: DateTime<Utc>,
    pub event_type: AuditEventType,
    pub severity: AuditSeverity,
    /// Acting user, when one is known.
    pub user_id: Option<Uuid>,
    // NOTE(review): session_id / ip_address / resource_id are always None
    // as populated by `AuditService::log_event` — reserved for future use.
    pub session_id: Option<String>,
    pub ip_address: Option<String>,
    pub resource_id: Option<String>,
    /// Free-form description of the action performed.
    pub action: String,
    pub outcome: AuditOutcome,
    /// Arbitrary key/value context for the event.
    pub details: HashMap<String, String>,
    /// Structured extension data (currently written as an empty object).
    pub metadata: serde_json::Value,
}
|
||||
|
||||
/// Audit outcome
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub enum AuditOutcome {
    /// Action completed as intended.
    Success,
    /// Action was attempted and failed.
    Failure,
    /// Action completed only partially.
    Partial,
    /// Outcome could not be determined.
    Unknown,
}
|
||||
|
||||
/// Audit trail for tracking related events
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditTrail {
    pub trail_id: Uuid,
    pub name: String,
    pub started_at: DateTime<Utc>,
    /// Set once the trail is closed via `end_trail`; a trail with
    /// `ended_at` set no longer accepts new events.
    pub ended_at: Option<DateTime<Utc>>,
    /// Ids of the `AuditEvent`s grouped under this trail.
    pub events: Vec<Uuid>,
    /// Human-readable summary, filled in when the trail is ended.
    pub summary: String,
    pub tags: Vec<String>,
}
|
||||
|
||||
/// Audit retention policy
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetentionPolicy {
    pub name: String,
    /// Events of the covered types older than this many days are removed.
    pub retention_days: i64,
    /// Event types this policy applies to; other types are untouched.
    pub event_types: Vec<AuditEventType>,
    /// When set, only events at or above this severity are subject to the
    /// policy (see `apply_retention_policies`).
    pub severity_threshold: Option<AuditSeverity>,
    // NOTE(review): archiving is not implemented anywhere in this module;
    // the flag is currently informational only.
    pub archive_enabled: bool,
}
|
||||
|
||||
/// Audit statistics
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditStatistics {
    pub total_events: usize,
    pub events_by_type: HashMap<AuditEventType, usize>,
    pub events_by_severity: HashMap<AuditSeverity, usize>,
    pub events_by_outcome: HashMap<AuditOutcome, usize>,
    /// Distinct user ids seen among the counted events.
    pub unique_users: usize,
    /// (since, until) window the statistics were computed over.
    pub time_range: (DateTime<Utc>, DateTime<Utc>),
}
|
||||
|
||||
/// Audit service for managing audit logs
|
||||
#[derive(Clone)]
pub struct AuditService {
    // All state is behind Arc<RwLock<..>> so cloned handles share one store.
    /// Bounded FIFO of events; oldest evicted once `max_events` is exceeded.
    events: Arc<RwLock<VecDeque<AuditEvent>>>,
    trails: Arc<RwLock<HashMap<Uuid, AuditTrail>>>,
    retention_policies: Arc<RwLock<Vec<RetentionPolicy>>>,
    /// Upper bound on the in-memory event queue length.
    max_events: usize,
}
|
||||
|
||||
impl AuditService {
|
||||
    /// Create new audit service.
    ///
    /// `max_events` bounds the in-memory event queue (oldest events are
    /// evicted first, see `log_event`). Two retention policies are
    /// installed by default: 365 days for security/compliance events at
    /// Warning or above, and 90 days for access logs.
    pub fn new(max_events: usize) -> Self {
        Self {
            events: Arc::new(RwLock::new(VecDeque::new())),
            trails: Arc::new(RwLock::new(HashMap::new())),
            retention_policies: Arc::new(RwLock::new(vec![
                // Default retention policies
                RetentionPolicy {
                    name: "Security Events".to_string(),
                    retention_days: 365,
                    event_types: vec![
                        AuditEventType::SecurityAlert,
                        AuditEventType::ComplianceViolation,
                    ],
                    severity_threshold: Some(AuditSeverity::Warning),
                    archive_enabled: true,
                },
                RetentionPolicy {
                    name: "Access Logs".to_string(),
                    retention_days: 90,
                    event_types: vec![
                        AuditEventType::UserLogin,
                        AuditEventType::UserLogout,
                        AuditEventType::DataAccess,
                    ],
                    severity_threshold: None,
                    archive_enabled: false,
                },
            ])),
            max_events,
        }
    }
|
||||
|
||||
    /// Log an audit event.
    ///
    /// Builds an `AuditEvent` with a fresh id and the current UTC timestamp,
    /// appends it to the bounded queue (evicting the oldest events past
    /// `max_events`), and returns the new event's id.
    ///
    /// NOTE(review): `session_id`, `ip_address` and `resource_id` are always
    /// recorded as `None`; this API gives callers no way to supply them.
    pub async fn log_event(
        &self,
        event_type: AuditEventType,
        severity: AuditSeverity,
        user_id: Option<Uuid>,
        action: String,
        outcome: AuditOutcome,
        details: HashMap<String, String>,
    ) -> Result<Uuid> {
        let event = AuditEvent {
            id: Uuid::new_v4(),
            timestamp: Utc::now(),
            // Cloned because the originals are moved into the log macro below.
            event_type: event_type.clone(),
            severity: severity.clone(),
            user_id,
            session_id: None,
            ip_address: None,
            resource_id: None,
            action,
            outcome,
            details,
            metadata: serde_json::json!({}),
        };

        let event_id = event.id;

        // Add event to the queue
        let mut events = self.events.write().await;
        events.push_back(event.clone());

        // Maintain max events limit: evict from the front (oldest first)
        while events.len() > self.max_events {
            events.pop_front();
        }

        log::info!(
            "Audit event logged: {} - {:?} - {:?}",
            event_id,
            event_type,
            severity
        );

        Ok(event_id)
    }
|
||||
|
||||
    /// Create an audit trail.
    ///
    /// Registers a new, open trail (no `ended_at`, empty event list) and
    /// returns its id for use with `add_to_trail` / `end_trail`.
    pub async fn create_trail(&self, name: String, tags: Vec<String>) -> Result<Uuid> {
        let trail = AuditTrail {
            trail_id: Uuid::new_v4(),
            name,
            started_at: Utc::now(),
            ended_at: None,
            events: Vec::new(),
            summary: String::new(),
            tags,
        };

        let trail_id = trail.trail_id;
        let mut trails = self.trails.write().await;
        trails.insert(trail_id, trail);

        Ok(trail_id)
    }
|
||||
|
||||
    /// Add event to trail.
    ///
    /// Appends `event_id` to the trail's event list. Fails if the trail
    /// does not exist or has already been ended.
    ///
    /// NOTE(review): `event_id` is not validated against the event queue,
    /// so ids of evicted or never-logged events are accepted.
    pub async fn add_to_trail(&self, trail_id: Uuid, event_id: Uuid) -> Result<()> {
        let mut trails = self.trails.write().await;
        let trail = trails
            .get_mut(&trail_id)
            .ok_or_else(|| anyhow!("Trail not found"))?;

        // A trail becomes read-only once ended.
        if trail.ended_at.is_some() {
            return Err(anyhow!("Trail already ended"));
        }

        trail.events.push(event_id);
        Ok(())
    }
|
||||
|
||||
    /// End an audit trail.
    ///
    /// Stamps `ended_at` with the current time and records the final
    /// summary. Fails if the trail does not exist; ending an already-ended
    /// trail simply refreshes its end time and summary.
    pub async fn end_trail(&self, trail_id: Uuid, summary: String) -> Result<()> {
        let mut trails = self.trails.write().await;
        let trail = trails
            .get_mut(&trail_id)
            .ok_or_else(|| anyhow!("Trail not found"))?;

        trail.ended_at = Some(Utc::now());
        trail.summary = summary;

        Ok(())
    }
|
||||
|
||||
    /// Query audit events.
    ///
    /// Returns clones of all in-memory events accepted by `filter`
    /// (see `AuditFilter::matches`), in queue (insertion) order.
    pub async fn query_events(&self, filter: AuditFilter) -> Result<Vec<AuditEvent>> {
        let events = self.events.read().await;

        let filtered: Vec<AuditEvent> = events
            .iter()
            .filter(|e| filter.matches(e))
            .cloned()
            .collect();

        Ok(filtered)
    }
|
||||
|
||||
    /// Get audit statistics.
    ///
    /// Aggregates event counts by type, severity and outcome over the
    /// inclusive window [`since`, `until`]. `since` defaults to 30 days
    /// ago and `until` to now.
    pub async fn get_statistics(
        &self,
        since: Option<DateTime<Utc>>,
        until: Option<DateTime<Utc>>,
    ) -> AuditStatistics {
        let events = self.events.read().await;
        let since = since.unwrap_or(Utc::now() - Duration::days(30));
        let until = until.unwrap_or(Utc::now());

        let filtered_events: Vec<_> = events
            .iter()
            .filter(|e| e.timestamp >= since && e.timestamp <= until)
            .collect();

        let mut events_by_type = HashMap::new();
        let mut events_by_severity = HashMap::new();
        let mut events_by_outcome = HashMap::new();
        let mut unique_users = std::collections::HashSet::new();

        // Single pass: bump each histogram and collect distinct user ids.
        for event in &filtered_events {
            *events_by_type.entry(event.event_type.clone()).or_insert(0) += 1;
            *events_by_severity
                .entry(event.severity.clone())
                .or_insert(0) += 1;
            *events_by_outcome.entry(event.outcome.clone()).or_insert(0) += 1;

            if let Some(user_id) = event.user_id {
                unique_users.insert(user_id);
            }
        }

        AuditStatistics {
            total_events: filtered_events.len(),
            events_by_type,
            events_by_severity,
            events_by_outcome,
            unique_users: unique_users.len(),
            time_range: (since, until),
        }
    }
|
||||
|
||||
    /// Apply retention policies.
    ///
    /// For each policy, removes events whose type is covered, whose
    /// severity meets the policy's threshold (when one is set), and whose
    /// timestamp is older than `retention_days`. Returns the total number
    /// of events removed.
    ///
    /// NOTE(review): events *below* a policy's severity threshold are
    /// retained by that policy regardless of age — the threshold narrows
    /// which events the policy applies to, it is not a keep-longer rule.
    pub async fn apply_retention_policies(&self) -> Result<usize> {
        let policies = self.retention_policies.read().await;
        let mut events = self.events.write().await;
        let now = Utc::now();
        let mut removed_count = 0;

        for policy in policies.iter() {
            let cutoff = now - Duration::days(policy.retention_days);

            // Remove events older than retention period
            let initial_len = events.len();
            events.retain(|e| {
                // Policy does not cover this event type: keep.
                if !policy.event_types.contains(&e.event_type) {
                    return true;
                }

                // Below the severity threshold: policy does not apply.
                if let Some(threshold) = &policy.severity_threshold {
                    if e.severity < *threshold {
                        return true;
                    }
                }

                // Covered event: keep only if still within the window.
                e.timestamp >= cutoff
            });

            removed_count += initial_len - events.len();
        }

        log::info!(
            "Applied retention policies, removed {} events",
            removed_count
        );
        Ok(removed_count)
    }
|
||||
|
||||
    /// Export audit logs.
    ///
    /// Serializes the events accepted by `filter` (all events when `None`)
    /// into the requested format and returns the raw bytes:
    /// - `Json`: pretty-printed array of full `AuditEvent` records.
    /// - `Csv`: header row plus one row per event with a reduced column
    ///   set (id, timestamp, type, severity, user, action, outcome).
    pub async fn export_logs(
        &self,
        format: ExportFormat,
        filter: Option<AuditFilter>,
    ) -> Result<Vec<u8>> {
        let events = self.query_events(filter.unwrap_or_default()).await?;

        match format {
            ExportFormat::Json => {
                let json = serde_json::to_vec_pretty(&events)?;
                Ok(json)
            }
            ExportFormat::Csv => {
                // Buffer CSV output in memory via the csv crate.
                let mut csv_writer = csv::Writer::from_writer(vec![]);

                // Write headers
                csv_writer.write_record(&[
                    "ID",
                    "Timestamp",
                    "Type",
                    "Severity",
                    "User",
                    "Action",
                    "Outcome",
                ])?;

                // Write records; enum columns use their Debug rendering.
                for event in events {
                    csv_writer.write_record(&[
                        event.id.to_string(),
                        event.timestamp.to_rfc3339(),
                        format!("{:?}", event.event_type),
                        format!("{:?}", event.severity),
                        event.user_id.map(|u| u.to_string()).unwrap_or_default(),
                        event.action,
                        format!("{:?}", event.outcome),
                    ])?;
                }

                Ok(csv_writer.into_inner()?)
            }
        }
    }
|
||||
|
||||
    /// Get compliance report.
    ///
    /// Summarizes the current in-memory log: incident/violation counts and
    /// failed logins are tallied over *all* retained events, while
    /// `total_events` / `unique_users` / `critical_events` come from
    /// `get_statistics` and therefore cover only its default 30-day window.
    pub async fn get_compliance_report(&self) -> ComplianceReport {
        let stats = self.get_statistics(None, None).await;
        let events = self.events.read().await;

        let security_incidents = events
            .iter()
            .filter(|e| e.event_type == AuditEventType::SecurityAlert)
            .count();

        let compliance_violations = events
            .iter()
            .filter(|e| e.event_type == AuditEventType::ComplianceViolation)
            .count();

        let failed_logins = events
            .iter()
            .filter(|e| {
                e.event_type == AuditEventType::UserLogin && e.outcome == AuditOutcome::Failure
            })
            .count();

        ComplianceReport {
            generated_at: Utc::now(),
            total_events: stats.total_events,
            security_incidents,
            compliance_violations,
            failed_logins,
            unique_users: stats.unique_users,
            critical_events: stats
                .events_by_severity
                .get(&AuditSeverity::Critical)
                .copied()
                .unwrap_or(0),
            audit_coverage: self.calculate_audit_coverage(&events),
        }
    }
|
||||
|
||||
    /// Calculate audit coverage percentage.
    ///
    /// Coverage = (distinct `AuditEventType`s present in the retained
    /// events / 12) * 100.
    fn calculate_audit_coverage(&self, events: &VecDeque<AuditEvent>) -> f64 {
        // Calculate based on expected event types coverage.
        // NOTE(review): 12 is hard-coded to the current variant count of
        // AuditEventType — it will silently drift if variants are added.
        let expected_types = 12; // Total number of event types
        let covered_types = events
            .iter()
            .map(|e| e.event_type.clone())
            .collect::<std::collections::HashSet<_>>()
            .len();

        (covered_types as f64 / expected_types as f64) * 100.0
    }
|
||||
}
|
||||
|
||||
/// Audit filter for querying events
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AuditFilter {
    // Every field is optional; `None` means "no constraint" (Default
    // therefore matches all events).
    /// Restrict to these event types.
    pub event_types: Option<Vec<AuditEventType>>,
    /// Minimum severity (inclusive).
    pub severity: Option<AuditSeverity>,
    /// Restrict to events attributed to this user.
    pub user_id: Option<Uuid>,
    /// Inclusive lower bound on the event timestamp.
    pub since: Option<DateTime<Utc>>,
    /// Inclusive upper bound on the event timestamp.
    pub until: Option<DateTime<Utc>>,
    /// Restrict to this exact outcome.
    pub outcome: Option<AuditOutcome>,
}
|
||||
|
||||
impl AuditFilter {
|
||||
fn matches(&self, event: &AuditEvent) -> bool {
|
||||
if let Some(types) = &self.event_types {
|
||||
if !types.contains(&event.event_type) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(severity) = &self.severity {
|
||||
if event.severity < *severity {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(user_id) = &self.user_id {
|
||||
if event.user_id != Some(*user_id) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(since) = &self.since {
|
||||
if event.timestamp < *since {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(until) = &self.until {
|
||||
if event.timestamp > *until {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(outcome) = &self.outcome {
|
||||
if event.outcome != *outcome {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// Export format for audit logs
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExportFormat {
    /// Pretty-printed JSON array of full event records.
    Json,
    /// CSV with a reduced column set (see `export_logs`).
    Csv,
}
|
||||
|
||||
/// Compliance report
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComplianceReport {
    /// When this report was generated.
    pub generated_at: DateTime<Utc>,
    /// Event count within the statistics window (30 days by default).
    pub total_events: usize,
    /// SecurityAlert events across all retained logs.
    pub security_incidents: usize,
    /// ComplianceViolation events across all retained logs.
    pub compliance_violations: usize,
    /// UserLogin events with a Failure outcome across all retained logs.
    pub failed_logins: usize,
    /// Distinct user ids within the statistics window.
    pub unique_users: usize,
    /// Critical-severity events within the statistics window.
    pub critical_events: usize,
    /// Percentage of event types observed (see `calculate_audit_coverage`).
    pub audit_coverage: f64,
}
|
||||
518
src/compliance/policy_checker.rs
Normal file
518
src/compliance/policy_checker.rs
Normal file
|
|
@ -0,0 +1,518 @@
|
|||
//! Policy Checker Module
|
||||
//!
|
||||
//! Provides automated security policy checking and enforcement
|
||||
//! for compliance with organizational and regulatory requirements.
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Policy type enumeration
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum PolicyType {
    // Category a SecurityPolicy enforces; the built-in checkers look up
    // policies by this type (PasswordStrength, SessionTimeout).
    AccessControl,
    DataRetention,
    PasswordStrength,
    SessionTimeout,
    EncryptionRequired,
    AuditLogging,
    BackupFrequency,
    NetworkSecurity,
    ComplianceStandard,
}
|
||||
|
||||
/// Policy status
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum PolicyStatus {
    /// In force — only Active policies are evaluated by the checkers.
    Active,
    /// Authored but not yet in force.
    Draft,
    /// Superseded; kept for reference.
    Deprecated,
    /// Retired and retained for history only.
    Archived,
}
|
||||
|
||||
/// Policy severity
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum PolicySeverity {
    // Ordered (derived Ord follows declaration order): Low < Medium <
    // High < Critical.
    Low,
    Medium,
    High,
    Critical,
}
|
||||
|
||||
/// Security policy definition
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityPolicy {
    pub id: Uuid,
    pub name: String,
    pub description: String,
    pub policy_type: PolicyType,
    /// Only `Active` policies participate in checks.
    pub status: PolicyStatus,
    /// Severity assigned to violations produced by this policy.
    pub severity: PolicySeverity,
    /// Ordered rule list — the built-in password/session checkers index
    /// into this by position.
    pub rules: Vec<PolicyRule>,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
    // NOTE(review): effective_date/expiry_date are stored but not consulted
    // by the checking code visible in this module.
    pub effective_date: DateTime<Utc>,
    pub expiry_date: Option<DateTime<Utc>>,
    pub tags: Vec<String>,
}
|
||||
|
||||
/// Policy rule
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyRule {
    pub id: Uuid,
    pub name: String,
    /// Human-readable condition expression; `evaluate_rule` does not parse
    /// it (evaluation is stubbed on `action`).
    pub condition: String,
    pub action: PolicyAction,
    /// Free-form parameters (e.g. "min_length", "max_duration").
    pub parameters: HashMap<String, String>,
}
|
||||
|
||||
/// Policy action
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PolicyAction {
    /// Permit the action (rule passes in the stub evaluator).
    Allow,
    /// Reject the action (rule fails in the stub evaluator).
    Deny,
    /// Raise an alert but allow.
    Alert,
    /// Actively enforce the constraint.
    Enforce,
    /// Record only.
    Log,
}
|
||||
|
||||
/// Policy violation
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyViolation {
    pub id: Uuid,
    /// Policy that was breached.
    pub policy_id: Uuid,
    /// Specific rule within the policy that failed.
    pub rule_id: Uuid,
    pub timestamp: DateTime<Utc>,
    pub user_id: Option<Uuid>,
    /// Resource the attempted action targeted (e.g. "password", "session").
    pub resource: String,
    pub action_attempted: String,
    pub violation_details: String,
    pub severity: PolicySeverity,
    /// Flipped to true by `resolve_violation`.
    pub resolved: bool,
}
|
||||
|
||||
/// Policy check result
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyCheckResult {
    /// Checked policy; `Uuid::nil()` when no applicable policy was found.
    pub policy_id: Uuid,
    /// True iff no violations were produced.
    pub passed: bool,
    pub violations: Vec<PolicyViolation>,
    /// Non-fatal advisories (e.g. weak-but-allowed passwords).
    pub warnings: Vec<String>,
    pub timestamp: DateTime<Utc>,
}
|
||||
|
||||
/// Policy checker service
|
||||
#[derive(Debug, Clone)]
pub struct PolicyChecker {
    /// Policies keyed by their `SecurityPolicy::id`.
    policies: HashMap<Uuid, SecurityPolicy>,
    /// Every violation ever recorded by this checker (grows unbounded).
    violations: Vec<PolicyViolation>,
    /// Result of every check performed (grows unbounded).
    check_history: Vec<PolicyCheckResult>,
}
|
||||
|
||||
impl PolicyChecker {
|
||||
    /// Create new policy checker.
    ///
    /// Starts with empty violation/check history and the built-in default
    /// policies (password strength and session timeout) installed.
    pub fn new() -> Self {
        let mut checker = Self {
            policies: HashMap::new(),
            violations: Vec::new(),
            check_history: Vec::new(),
        };

        // Initialize with default policies
        checker.initialize_default_policies();
        checker
    }
|
||||
|
||||
    /// Initialize default security policies.
    ///
    /// Installs two Active policies:
    /// - Password strength (High): min length 12 (rule 0) and character
    ///   complexity (rule 1). `check_password_policy` indexes these rules
    ///   by position.
    /// - Session timeout (Medium): max duration 8 hours (rule 0), used by
    ///   `check_session_policy`.
    fn initialize_default_policies(&mut self) {
        // Password policy
        let password_policy = SecurityPolicy {
            id: Uuid::new_v4(),
            name: "Password Strength Policy".to_string(),
            description: "Enforces strong password requirements".to_string(),
            policy_type: PolicyType::PasswordStrength,
            status: PolicyStatus::Active,
            severity: PolicySeverity::High,
            rules: vec![
                PolicyRule {
                    id: Uuid::new_v4(),
                    name: "Minimum Length".to_string(),
                    condition: "password.length >= 12".to_string(),
                    action: PolicyAction::Enforce,
                    parameters: HashMap::from([("min_length".to_string(), "12".to_string())]),
                },
                PolicyRule {
                    id: Uuid::new_v4(),
                    name: "Complexity Requirements".to_string(),
                    condition: "has_uppercase && has_lowercase && has_digit && has_special"
                        .to_string(),
                    action: PolicyAction::Enforce,
                    parameters: HashMap::new(),
                },
            ],
            created_at: Utc::now(),
            updated_at: Utc::now(),
            effective_date: Utc::now(),
            expiry_date: None,
            tags: vec!["security".to_string(), "authentication".to_string()],
        };

        self.policies.insert(password_policy.id, password_policy);

        // Session timeout policy
        let session_policy = SecurityPolicy {
            id: Uuid::new_v4(),
            name: "Session Timeout Policy".to_string(),
            description: "Enforces session timeout limits".to_string(),
            policy_type: PolicyType::SessionTimeout,
            status: PolicyStatus::Active,
            severity: PolicySeverity::Medium,
            rules: vec![PolicyRule {
                id: Uuid::new_v4(),
                name: "Maximum Session Duration".to_string(),
                condition: "session.duration <= 8_hours".to_string(),
                action: PolicyAction::Enforce,
                parameters: HashMap::from([
                    ("max_duration".to_string(), "28800".to_string()), // 8 hours in seconds
                ]),
            }],
            created_at: Utc::now(),
            updated_at: Utc::now(),
            effective_date: Utc::now(),
            expiry_date: None,
            tags: vec!["security".to_string(), "session".to_string()],
        };

        self.policies.insert(session_policy.id, session_policy);
    }
|
||||
|
||||
    /// Add a security policy.
    ///
    /// Inserts `policy` keyed by its own id; rejects duplicates rather
    /// than silently overwriting an existing policy.
    pub fn add_policy(&mut self, policy: SecurityPolicy) -> Result<()> {
        if self.policies.contains_key(&policy.id) {
            return Err(anyhow!("Policy already exists"));
        }

        log::info!("Adding security policy: {}", policy.name);
        self.policies.insert(policy.id, policy);
        Ok(())
    }
|
||||
|
||||
/// Update a security policy
|
||||
pub fn update_policy(&mut self, policy_id: Uuid, updates: SecurityPolicy) -> Result<()> {
|
||||
if let Some(existing) = self.policies.get_mut(&policy_id) {
|
||||
*existing = updates;
|
||||
existing.updated_at = Utc::now();
|
||||
log::info!("Updated policy: {}", existing.name);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow!("Policy not found"))
|
||||
}
|
||||
}
|
||||
|
||||
    /// Check password against policy.
    ///
    /// Evaluates `password` against the active `PasswordStrength` policy:
    /// minimum length of 12 and at least one uppercase, lowercase, digit
    /// and special character. Violations are recorded on the checker and
    /// returned in the result; when no active password policy exists the
    /// check passes with a warning and `policy_id == Uuid::nil()`.
    ///
    /// NOTE(review): length uses `password.len()` (UTF-8 bytes, not chars),
    /// so multi-byte passwords are over-counted — confirm intent.
    /// NOTE(review): `policy.rules[0]` / `rules[1]` assume the default rule
    /// layout; a custom policy with fewer rules would panic here.
    pub fn check_password_policy(&mut self, password: &str) -> PolicyCheckResult {
        let policy = self
            .policies
            .values()
            .find(|p| {
                p.policy_type == PolicyType::PasswordStrength && p.status == PolicyStatus::Active
            })
            .cloned();

        if let Some(policy) = policy {
            let mut violations = Vec::new();
            let mut warnings = Vec::new();

            // Check minimum length
            if password.len() < 12 {
                violations.push(PolicyViolation {
                    id: Uuid::new_v4(),
                    policy_id: policy.id,
                    rule_id: policy.rules[0].id,
                    timestamp: Utc::now(),
                    user_id: None,
                    resource: "password".to_string(),
                    action_attempted: "set_password".to_string(),
                    violation_details: format!(
                        "Password length {} is less than required 12",
                        password.len()
                    ),
                    severity: PolicySeverity::High,
                    resolved: false,
                });
            }

            // Check complexity: all four character classes must be present.
            let has_uppercase = password.chars().any(|c| c.is_uppercase());
            let has_lowercase = password.chars().any(|c| c.is_lowercase());
            let has_digit = password.chars().any(|c| c.is_numeric());
            let has_special = password.chars().any(|c| !c.is_alphanumeric());

            if !(has_uppercase && has_lowercase && has_digit && has_special) {
                violations.push(PolicyViolation {
                    id: Uuid::new_v4(),
                    policy_id: policy.id,
                    rule_id: policy.rules[1].id,
                    timestamp: Utc::now(),
                    user_id: None,
                    resource: "password".to_string(),
                    action_attempted: "set_password".to_string(),
                    violation_details: "Password does not meet complexity requirements".to_string(),
                    severity: PolicySeverity::High,
                    resolved: false,
                });
            }

            // Add warnings for common patterns (warning only, never fails)
            if password.to_lowercase().contains("password") {
                warnings.push("Password contains the word 'password'".to_string());
            }

            let result = PolicyCheckResult {
                policy_id: policy.id,
                passed: violations.is_empty(),
                violations: violations.clone(),
                warnings,
                timestamp: Utc::now(),
            };

            // Record both the violations and the check itself.
            self.violations.extend(violations);
            self.check_history.push(result.clone());

            result
        } else {
            PolicyCheckResult {
                policy_id: Uuid::nil(),
                passed: true,
                violations: Vec::new(),
                warnings: vec!["No password policy configured".to_string()],
                timestamp: Utc::now(),
            }
        }
    }
|
||||
|
||||
    /// Check session against policy.
    ///
    /// Evaluates a session's duration (in seconds) against the active
    /// `SessionTimeout` policy's hard-coded 28800-second (8-hour) limit.
    /// Violations are recorded on the checker; with no active session
    /// policy the check passes with a warning and `policy_id == Uuid::nil()`.
    ///
    /// NOTE(review): the 28800 limit is duplicated here rather than read
    /// from the rule's "max_duration" parameter — the two could drift.
    pub fn check_session_policy(&mut self, session_duration_seconds: u64) -> PolicyCheckResult {
        let policy = self
            .policies
            .values()
            .find(|p| {
                p.policy_type == PolicyType::SessionTimeout && p.status == PolicyStatus::Active
            })
            .cloned();

        if let Some(policy) = policy {
            let mut violations = Vec::new();

            if session_duration_seconds > 28800 {
                // 8 hours
                violations.push(PolicyViolation {
                    id: Uuid::new_v4(),
                    policy_id: policy.id,
                    rule_id: policy.rules[0].id,
                    timestamp: Utc::now(),
                    user_id: None,
                    resource: "session".to_string(),
                    action_attempted: "extend_session".to_string(),
                    violation_details: format!(
                        "Session duration {} seconds exceeds maximum 28800 seconds",
                        session_duration_seconds
                    ),
                    severity: PolicySeverity::Medium,
                    resolved: false,
                });
            }

            let result = PolicyCheckResult {
                policy_id: policy.id,
                passed: violations.is_empty(),
                violations: violations.clone(),
                warnings: Vec::new(),
                timestamp: Utc::now(),
            };

            // Record both the violations and the check itself.
            self.violations.extend(violations);
            self.check_history.push(result.clone());

            result
        } else {
            PolicyCheckResult {
                policy_id: Uuid::nil(),
                passed: true,
                violations: Vec::new(),
                warnings: vec!["No session policy configured".to_string()],
                timestamp: Utc::now(),
            }
        }
    }
|
||||
|
||||
/// Check all active policies
|
||||
pub fn check_all_policies(&mut self, context: &PolicyContext) -> Vec<PolicyCheckResult> {
|
||||
let mut results = Vec::new();
|
||||
|
||||
for policy in self.policies.values() {
|
||||
if policy.status != PolicyStatus::Active {
|
||||
continue;
|
||||
}
|
||||
|
||||
let result = self.check_policy(policy.id, context);
|
||||
if let Ok(result) = result {
|
||||
results.push(result);
|
||||
}
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
    /// Check a specific policy.
    ///
    /// Evaluates each rule of the policy via `evaluate_rule` and produces
    /// one violation (at the policy's severity) per failed rule. The
    /// violations and the check result are recorded on the checker.
    /// Fails only when `policy_id` is unknown.
    pub fn check_policy(
        &mut self,
        policy_id: Uuid,
        context: &PolicyContext,
    ) -> Result<PolicyCheckResult> {
        // Cloned so the immutable borrow of self.policies ends before we
        // mutate self.violations / self.check_history below.
        let policy = self
            .policies
            .get(&policy_id)
            .ok_or_else(|| anyhow!("Policy not found"))?
            .clone();

        let mut violations = Vec::new();
        // No code path pushes warnings here yet; kept for result shape.
        let mut warnings = Vec::new();

        for rule in &policy.rules {
            if !self.evaluate_rule(rule, context) {
                violations.push(PolicyViolation {
                    id: Uuid::new_v4(),
                    policy_id: policy.id,
                    rule_id: rule.id,
                    timestamp: Utc::now(),
                    user_id: context.user_id,
                    resource: context.resource.clone(),
                    action_attempted: context.action.clone(),
                    violation_details: format!("Rule '{}' failed", rule.name),
                    severity: policy.severity.clone(),
                    resolved: false,
                });
            }
        }

        let result = PolicyCheckResult {
            policy_id: policy.id,
            passed: violations.is_empty(),
            violations: violations.clone(),
            warnings,
            timestamp: Utc::now(),
        };

        self.violations.extend(violations);
        self.check_history.push(result.clone());

        Ok(result)
    }
|
||||
|
||||
    /// Evaluate a policy rule.
    ///
    /// Stub implementation: the rule's `condition` string is ignored and
    /// the verdict depends only on the rule's action — `Deny` fails,
    /// everything else passes.
    fn evaluate_rule(&self, rule: &PolicyRule, _context: &PolicyContext) -> bool {
        // Simplified rule evaluation.
        // In production, this would use a proper expression evaluator.
        match rule.action {
            PolicyAction::Allow => true,
            PolicyAction::Deny => false,
            _ => true, // For Alert, Enforce, Log actions, consider as passed but take action
        }
    }
|
||||
|
||||
/// Get policy violations
|
||||
pub fn get_violations(&self, unresolved_only: bool) -> Vec<PolicyViolation> {
|
||||
if unresolved_only {
|
||||
self.violations
|
||||
.iter()
|
||||
.filter(|v| !v.resolved)
|
||||
.cloned()
|
||||
.collect()
|
||||
} else {
|
||||
self.violations.clone()
|
||||
}
|
||||
}
|
||||
|
||||
    /// Resolve a violation.
    ///
    /// Marks the violation with the given id as resolved. Fails when no
    /// violation with that id exists.
    pub fn resolve_violation(&mut self, violation_id: Uuid) -> Result<()> {
        if let Some(violation) = self.violations.iter_mut().find(|v| v.id == violation_id) {
            violation.resolved = true;
            log::info!("Resolved violation: {}", violation_id);
            Ok(())
        } else {
            Err(anyhow!("Violation not found"))
        }
    }
|
||||
|
||||
    /// Get policy compliance report.
    ///
    /// Summarizes the checker's state: policy counts, violation counts,
    /// checks run in the last 7 days, and a compliance rate defined as the
    /// percentage of all recorded checks that passed (100.0 when no checks
    /// have been run yet).
    pub fn get_compliance_report(&self) -> PolicyComplianceReport {
        let total_policies = self.policies.len();
        let active_policies = self
            .policies
            .values()
            .filter(|p| p.status == PolicyStatus::Active)
            .count();
        let total_violations = self.violations.len();
        let unresolved_violations = self.violations.iter().filter(|v| !v.resolved).count();
        let critical_violations = self
            .violations
            .iter()
            .filter(|v| v.severity == PolicySeverity::Critical)
            .count();

        // Checks performed within the trailing 7-day window.
        let recent_checks = self
            .check_history
            .iter()
            .filter(|c| c.timestamp > Utc::now() - Duration::days(7))
            .count();

        // Pass rate over the full check history; vacuously 100% when empty.
        let compliance_rate = if !self.check_history.is_empty() {
            let passed = self.check_history.iter().filter(|c| c.passed).count();
            (passed as f64 / self.check_history.len() as f64) * 100.0
        } else {
            100.0
        };

        PolicyComplianceReport {
            generated_at: Utc::now(),
            total_policies,
            active_policies,
            total_violations,
            unresolved_violations,
            critical_violations,
            recent_checks,
            compliance_rate,
        }
    }
|
||||
}
|
||||
|
||||
/// Policy context for evaluation
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyContext {
    /// User attempting the action, when known; copied into any violations.
    pub user_id: Option<Uuid>,
    /// Resource the action targets.
    pub resource: String,
    /// Action being attempted.
    pub action: String,
    /// Extra evaluation inputs (unused by the stub `evaluate_rule`).
    pub parameters: HashMap<String, String>,
}
|
||||
|
||||
/// Policy compliance report
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyComplianceReport {
    pub generated_at: DateTime<Utc>,
    pub total_policies: usize,
    /// Policies with `PolicyStatus::Active`.
    pub active_policies: usize,
    pub total_violations: usize,
    pub unresolved_violations: usize,
    /// Violations at `PolicySeverity::Critical`.
    pub critical_violations: usize,
    /// Checks run within the trailing 7-day window.
    pub recent_checks: usize,
    /// Percentage of all recorded checks that passed.
    pub compliance_rate: f64,
}
|
||||
|
||||
impl Default for PolicyChecker {
    /// Default construction delegates to `new()` (default policies included).
    fn default() -> Self {
        Self::new()
    }
}
|
||||
534
src/compliance/risk_assessment.rs
Normal file
534
src/compliance/risk_assessment.rs
Normal file
|
|
@ -0,0 +1,534 @@
|
|||
//! Risk Assessment Module
|
||||
//!
|
||||
//! Provides comprehensive risk assessment and management capabilities
|
||||
//! for identifying, evaluating, and mitigating security risks.
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Risk category enumeration
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum RiskCategory {
    // Broad classification of where a risk's impact lands.
    Security,
    Compliance,
    Operational,
    Financial,
    Reputational,
    Technical,
    Legal,
}
|
||||
|
||||
/// Risk likelihood levels
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Likelihood {
    // Ordered (derived Ord follows declaration order): Rare < Unlikely <
    // Possible < Likely < AlmostCertain.
    Rare,
    Unlikely,
    Possible,
    Likely,
    AlmostCertain,
}
|
||||
|
||||
/// Risk impact levels
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Impact {
    // Ordered (derived Ord follows declaration order): Negligible < Minor <
    // Moderate < Major < Catastrophic.
    Negligible,
    Minor,
    Moderate,
    Major,
    Catastrophic,
}
|
||||
|
||||
/// Risk level based on likelihood and impact
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum RiskLevel {
    // Combined rating derived from likelihood x impact; ordered
    // Low < Medium < High < Critical.
    Low,
    Medium,
    High,
    Critical,
}
|
||||
|
||||
/// Risk status
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum RiskStatus {
    // Lifecycle stages of a risk record.
    Identified,
    Assessed,
    Mitigating,
    Monitoring,
    /// Risk consciously accepted without further mitigation.
    Accepted,
    Closed,
}
|
||||
|
||||
/// Risk assessment
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RiskAssessment {
    pub id: Uuid,
    pub title: String,
    pub description: String,
    pub category: RiskCategory,
    pub likelihood: Likelihood,
    pub impact: Impact,
    /// Overall rating derived from likelihood and impact.
    pub risk_level: RiskLevel,
    pub status: RiskStatus,
    pub identified_date: DateTime<Utc>,
    /// Set once a formal assessment has been performed.
    pub assessed_date: Option<DateTime<Utc>>,
    /// Person or team accountable for the risk.
    pub owner: String,
    pub affected_assets: Vec<String>,
    pub vulnerabilities: Vec<Vulnerability>,
    pub threats: Vec<Threat>,
    /// Mitigating controls already associated with this risk.
    pub controls: Vec<Control>,
    /// Risk level remaining after controls are applied, once determined.
    pub residual_risk: Option<RiskLevel>,
}
|
||||
|
||||
/// A vulnerability attached to a risk assessment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Vulnerability {
    pub id: Uuid,
    pub name: String,
    pub description: String,
    /// Severity expressed on the same scale as overall risk levels;
    /// a `Critical` vulnerability escalates the parent assessment's impact.
    pub severity: RiskLevel,
    /// CVE identifier when one exists (e.g. "CVE-2024-1234").
    pub cve_id: Option<String>,
    pub discovered_date: DateTime<Utc>,
    /// Whether a patch/remediation has been applied.
    pub patched: bool,
}
|
||||
|
||||
/// A threat attached to a risk assessment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Threat {
    pub id: Uuid,
    pub name: String,
    pub description: String,
    /// Who or what originates the threat (e.g. insider, criminal group).
    pub threat_actor: String,
    /// Likelihood of this threat materializing; may raise the parent
    /// assessment's overall likelihood during recalculation.
    pub likelihood: Likelihood,
    /// Free-form tactic labels (e.g. MITRE ATT&CK names — TODO confirm).
    pub tactics: Vec<String>,
}
|
||||
|
||||
/// A control (countermeasure) attached to a risk assessment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Control {
    pub id: Uuid,
    pub name: String,
    pub description: String,
    pub control_type: ControlType,
    /// How well the control works; `Effective`/`HighlyEffective` controls
    /// count toward residual-risk reduction.
    pub effectiveness: Effectiveness,
    pub implementation_status: ImplementationStatus,
    /// Cost of the control (currency units unspecified — TODO confirm).
    pub cost: f64,
}
|
||||
|
||||
/// Classification of a control by how it acts on a risk.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ControlType {
    /// Stops the event from happening.
    Preventive,
    /// Detects the event when it happens.
    Detective,
    /// Repairs the damage after the event.
    Corrective,
    /// Substitutes for a control that cannot be applied directly.
    Compensating,
}
|
||||
|
||||
/// How effective a control is in practice.
///
/// Only `Effective` and `HighlyEffective` controls are counted when
/// computing residual risk.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Effectiveness {
    Ineffective,
    PartiallyEffective,
    Effective,
    HighlyEffective,
}
|
||||
|
||||
/// Rollout status of a control.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImplementationStatus {
    Planned,
    InProgress,
    Implemented,
    /// Implemented and independently confirmed to be working.
    Verified,
}
|
||||
|
||||
/// A plan for mitigating a specific risk assessment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MitigationPlan {
    pub id: Uuid,
    /// The `RiskAssessment` this plan targets.
    pub risk_id: Uuid,
    pub strategy: MitigationStrategy,
    /// Concrete actions; starts empty at plan creation.
    pub actions: Vec<MitigationAction>,
    /// Expected time to execute the plan.
    pub timeline: Duration,
    /// Allocated budget (currency units unspecified — TODO confirm).
    pub budget: f64,
    pub responsible_party: String,
    /// Plans are created `Pending` and must be approved separately.
    pub approval_status: ApprovalStatus,
}
|
||||
|
||||
/// The four classical risk-treatment strategies.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum MitigationStrategy {
    /// Eliminate the activity that causes the risk.
    Avoid,
    /// Shift the risk to a third party (e.g. insurance).
    Transfer,
    /// Reduce likelihood or impact with controls.
    Mitigate,
    /// Accept the risk as-is.
    Accept,
}
|
||||
|
||||
/// A single actionable item within a `MitigationPlan`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MitigationAction {
    pub id: Uuid,
    pub description: String,
    pub due_date: DateTime<Utc>,
    /// Person the action is assigned to.
    pub assigned_to: String,
    pub completed: bool,
}
|
||||
|
||||
/// Approval state of a mitigation plan.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ApprovalStatus {
    Pending,
    Approved,
    Rejected,
}
|
||||
|
||||
/// In-memory risk assessment service.
///
/// Owns all assessments and mitigation plans, keyed by id, plus the
/// matrix used to derive risk levels. State is not persisted and the
/// service is not internally synchronized.
#[derive(Debug, Clone)]
pub struct RiskAssessmentService {
    /// All assessments, keyed by `RiskAssessment::id`.
    assessments: HashMap<Uuid, RiskAssessment>,
    /// All mitigation plans, keyed by `MitigationPlan::id` (not risk id).
    mitigation_plans: HashMap<Uuid, MitigationPlan>,
    /// Lookup table mapping (likelihood, impact) to a risk level.
    risk_matrix: RiskMatrix,
}
|
||||
|
||||
impl RiskAssessmentService {
    /// Create an empty service using the default 5x5 risk matrix.
    pub fn new() -> Self {
        Self {
            assessments: HashMap::new(),
            mitigation_plans: HashMap::new(),
            risk_matrix: RiskMatrix::default(),
        }
    }

    /// Create and register a new assessment.
    ///
    /// New assessments start with placeholder ratings
    /// (`Possible`/`Moderate`/`Medium`) and status `Identified`; call
    /// `assess_risk` to set real ratings. Returns a clone of the stored
    /// record.
    pub fn create_assessment(
        &mut self,
        title: String,
        description: String,
        category: RiskCategory,
        owner: String,
    ) -> Result<RiskAssessment> {
        let assessment = RiskAssessment {
            id: Uuid::new_v4(),
            title,
            description,
            category,
            likelihood: Likelihood::Possible,
            impact: Impact::Moderate,
            risk_level: RiskLevel::Medium,
            status: RiskStatus::Identified,
            identified_date: Utc::now(),
            assessed_date: None,
            owner,
            affected_assets: Vec::new(),
            vulnerabilities: Vec::new(),
            threats: Vec::new(),
            controls: Vec::new(),
            residual_risk: None,
        };

        self.assessments.insert(assessment.id, assessment.clone());
        log::info!("Created risk assessment: {}", assessment.id);

        Ok(assessment)
    }

    /// Set the likelihood/impact of an existing assessment and derive its
    /// risk level from the matrix.
    ///
    /// Also stamps `assessed_date` and moves the status to `Assessed`.
    /// Errors if `risk_id` is unknown.
    pub fn assess_risk(
        &mut self,
        risk_id: Uuid,
        likelihood: Likelihood,
        impact: Impact,
    ) -> Result<RiskLevel> {
        let assessment = self
            .assessments
            .get_mut(&risk_id)
            .ok_or_else(|| anyhow!("Risk assessment not found"))?;

        assessment.likelihood = likelihood.clone();
        assessment.impact = impact.clone();
        // Disjoint-field borrow: `assessment` borrows `self.assessments`
        // mutably while `self.risk_matrix` is read immutably.
        assessment.risk_level = self.risk_matrix.calculate_risk_level(&likelihood, &impact);
        assessment.assessed_date = Some(Utc::now());
        assessment.status = RiskStatus::Assessed;

        log::info!(
            "Assessed risk {}: level = {:?}",
            risk_id,
            assessment.risk_level
        );

        Ok(assessment.risk_level.clone())
    }

    /// Attach a vulnerability and recalculate the risk level.
    ///
    /// Errors if `risk_id` is unknown.
    pub fn add_vulnerability(
        &mut self,
        risk_id: Uuid,
        vulnerability: Vulnerability,
    ) -> Result<()> {
        let assessment = self
            .assessments
            .get_mut(&risk_id)
            .ok_or_else(|| anyhow!("Risk assessment not found"))?;

        assessment.vulnerabilities.push(vulnerability);
        self.recalculate_risk_level(risk_id)?;

        Ok(())
    }

    /// Attach a threat and recalculate the risk level.
    ///
    /// Errors if `risk_id` is unknown.
    pub fn add_threat(&mut self, risk_id: Uuid, threat: Threat) -> Result<()> {
        let assessment = self
            .assessments
            .get_mut(&risk_id)
            .ok_or_else(|| anyhow!("Risk assessment not found"))?;

        assessment.threats.push(threat);
        self.recalculate_risk_level(risk_id)?;

        Ok(())
    }

    /// Attach a control and recompute the residual risk.
    ///
    /// Errors if `risk_id` is unknown.
    pub fn add_control(&mut self, risk_id: Uuid, control: Control) -> Result<()> {
        let assessment = self
            .assessments
            .get_mut(&risk_id)
            .ok_or_else(|| anyhow!("Risk assessment not found"))?;

        assessment.controls.push(control);
        self.calculate_residual_risk(risk_id)?;

        Ok(())
    }

    /// Re-derive an assessment's risk level after new threats or
    /// vulnerabilities were attached.
    ///
    /// Likelihood is raised to the highest threat likelihood, and impact
    /// is raised to at least `Major` when any critical vulnerability is
    /// present; ratings are never lowered here.
    fn recalculate_risk_level(&mut self, risk_id: Uuid) -> Result<()> {
        let assessment = self
            .assessments
            .get_mut(&risk_id)
            .ok_or_else(|| anyhow!("Risk assessment not found"))?;

        // Adjust likelihood based on threats
        if !assessment.threats.is_empty() {
            let max_threat_likelihood = assessment
                .threats
                .iter()
                .map(|t| &t.likelihood)
                .max()
                .cloned()
                .unwrap_or(Likelihood::Possible);

            if max_threat_likelihood > assessment.likelihood {
                assessment.likelihood = max_threat_likelihood;
            }
        }

        // Adjust impact based on vulnerabilities
        if !assessment.vulnerabilities.is_empty() {
            let critical_vulns = assessment
                .vulnerabilities
                .iter()
                .filter(|v| v.severity == RiskLevel::Critical)
                .count();

            if critical_vulns > 0 && assessment.impact < Impact::Major {
                assessment.impact = Impact::Major;
            }
        }

        assessment.risk_level =
            self.risk_matrix
                .calculate_risk_level(&assessment.likelihood, &assessment.impact);

        Ok(())
    }

    /// Recompute residual risk from the inherent level and the number of
    /// effective controls.
    ///
    /// With no controls the residual equals the inherent level. Otherwise
    /// the level steps down by fixed rules: Critical needs >=3 effective
    /// controls to drop to High, High needs >=2 to drop to Medium, Medium
    /// needs >=1 to drop to Low.
    fn calculate_residual_risk(&mut self, risk_id: Uuid) -> Result<()> {
        let assessment = self
            .assessments
            .get_mut(&risk_id)
            .ok_or_else(|| anyhow!("Risk assessment not found"))?;

        if assessment.controls.is_empty() {
            assessment.residual_risk = Some(assessment.risk_level.clone());
            return Ok(());
        }

        // Calculate effectiveness of controls
        let effective_controls = assessment
            .controls
            .iter()
            .filter(|c| {
                c.effectiveness == Effectiveness::Effective
                    || c.effectiveness == Effectiveness::HighlyEffective
            })
            .count();

        // Match-arm order matters: the first satisfied guard wins.
        let residual = match (assessment.risk_level.clone(), effective_controls) {
            (RiskLevel::Critical, n) if n >= 3 => RiskLevel::High,
            (RiskLevel::Critical, n) if n >= 1 => RiskLevel::Critical,
            (RiskLevel::High, n) if n >= 2 => RiskLevel::Medium,
            (RiskLevel::High, n) if n >= 1 => RiskLevel::High,
            (RiskLevel::Medium, n) if n >= 1 => RiskLevel::Low,
            (level, _) => level,
        };

        assessment.residual_risk = Some(residual);

        Ok(())
    }

    /// Create a mitigation plan (status `Pending`) for an existing risk.
    ///
    /// Errors if `risk_id` is unknown. Returns a clone of the stored plan.
    pub fn create_mitigation_plan(
        &mut self,
        risk_id: Uuid,
        strategy: MitigationStrategy,
        timeline: Duration,
        budget: f64,
        responsible_party: String,
    ) -> Result<MitigationPlan> {
        if !self.assessments.contains_key(&risk_id) {
            return Err(anyhow!("Risk assessment not found"));
        }

        let plan = MitigationPlan {
            id: Uuid::new_v4(),
            risk_id,
            strategy,
            actions: Vec::new(),
            timeline,
            budget,
            responsible_party,
            approval_status: ApprovalStatus::Pending,
        };

        self.mitigation_plans.insert(plan.id, plan.clone());
        log::info!("Created mitigation plan {} for risk {}", plan.id, risk_id);

        Ok(plan)
    }

    /// Return clones of all assessments rated `High` or `Critical`.
    pub fn get_high_risk_assessments(&self) -> Vec<RiskAssessment> {
        self.assessments
            .values()
            .filter(|a| a.risk_level >= RiskLevel::High)
            .cloned()
            .collect()
    }

    /// Build a point-in-time dashboard: counts of risks by level,
    /// category, and status, plus pending mitigation plans.
    pub fn get_risk_dashboard(&self) -> RiskDashboard {
        let total_risks = self.assessments.len();
        let mut risks_by_level = HashMap::new();
        let mut risks_by_category = HashMap::new();
        let mut risks_by_status = HashMap::new();

        for assessment in self.assessments.values() {
            *risks_by_level
                .entry(assessment.risk_level.clone())
                .or_insert(0) += 1;
            *risks_by_category
                .entry(assessment.category.clone())
                .or_insert(0) += 1;
            *risks_by_status
                .entry(assessment.status.clone())
                .or_insert(0) += 1;
        }

        let mitigation_plans_pending = self
            .mitigation_plans
            .values()
            .filter(|p| p.approval_status == ApprovalStatus::Pending)
            .count();

        RiskDashboard {
            total_risks,
            risks_by_level,
            risks_by_category,
            risks_by_status,
            mitigation_plans_pending,
            last_updated: Utc::now(),
        }
    }
}
|
||||
|
||||
/// Lookup table mapping a (likelihood, impact) pair to a risk level.
///
/// Populated by `Default` with a standard 5x5 matrix; pairs missing from
/// the map fall back to `RiskLevel::Medium` at lookup time.
#[derive(Debug, Clone)]
pub struct RiskMatrix {
    matrix: HashMap<(Likelihood, Impact), RiskLevel>,
}
|
||||
|
||||
impl RiskMatrix {
|
||||
/// Calculate risk level based on likelihood and impact
|
||||
pub fn calculate_risk_level(&self, likelihood: &Likelihood, impact: &Impact) -> RiskLevel {
|
||||
self.matrix
|
||||
.get(&(likelihood.clone(), impact.clone()))
|
||||
.cloned()
|
||||
.unwrap_or(RiskLevel::Medium)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RiskMatrix {
|
||||
fn default() -> Self {
|
||||
let mut matrix = HashMap::new();
|
||||
|
||||
// Define risk matrix
|
||||
matrix.insert((Likelihood::Rare, Impact::Negligible), RiskLevel::Low);
|
||||
matrix.insert((Likelihood::Rare, Impact::Minor), RiskLevel::Low);
|
||||
matrix.insert((Likelihood::Rare, Impact::Moderate), RiskLevel::Low);
|
||||
matrix.insert((Likelihood::Rare, Impact::Major), RiskLevel::Medium);
|
||||
matrix.insert((Likelihood::Rare, Impact::Catastrophic), RiskLevel::High);
|
||||
|
||||
matrix.insert((Likelihood::Unlikely, Impact::Negligible), RiskLevel::Low);
|
||||
matrix.insert((Likelihood::Unlikely, Impact::Minor), RiskLevel::Low);
|
||||
matrix.insert((Likelihood::Unlikely, Impact::Moderate), RiskLevel::Medium);
|
||||
matrix.insert((Likelihood::Unlikely, Impact::Major), RiskLevel::High);
|
||||
matrix.insert((Likelihood::Unlikely, Impact::Catastrophic), RiskLevel::High);
|
||||
|
||||
matrix.insert((Likelihood::Possible, Impact::Negligible), RiskLevel::Low);
|
||||
matrix.insert((Likelihood::Possible, Impact::Minor), RiskLevel::Medium);
|
||||
matrix.insert((Likelihood::Possible, Impact::Moderate), RiskLevel::Medium);
|
||||
matrix.insert((Likelihood::Possible, Impact::Major), RiskLevel::High);
|
||||
matrix.insert((Likelihood::Possible, Impact::Catastrophic), RiskLevel::Critical);
|
||||
|
||||
matrix.insert((Likelihood::Likely, Impact::Negligible), RiskLevel::Medium);
|
||||
matrix.insert((Likelihood::Likely, Impact::Minor), RiskLevel::Medium);
|
||||
matrix.insert((Likelihood::Likely, Impact::Moderate), RiskLevel::High);
|
||||
matrix.insert((Likelihood::Likely, Impact::Major), RiskLevel::Critical);
|
||||
matrix.insert((Likelihood::Likely, Impact::Catastrophic), RiskLevel::Critical);
|
||||
|
||||
matrix.insert((Likelihood::AlmostCertain, Impact::Negligible), RiskLevel::Medium);
|
||||
matrix.insert((Likelihood::AlmostCertain, Impact::Minor), RiskLevel::High);
|
||||
matrix.insert((Likelihood::AlmostCertain, Impact::Moderate), RiskLevel::High);
|
||||
matrix.insert((Likelihood::AlmostCertain, Impact::Major), RiskLevel::Critical);
|
||||
matrix.insert(
|
||||
(Likelihood::AlmostCertain, Impact::Catastrophic),
|
||||
RiskLevel::Critical,
|
||||
);
|
||||
|
||||
Self { matrix }
|
||||
}
|
||||
}
|
||||
|
||||
/// Snapshot of the risk landscape produced by
/// `RiskAssessmentService::get_risk_dashboard`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RiskDashboard {
    pub total_risks: usize,
    /// Count of assessments per current risk level.
    pub risks_by_level: HashMap<RiskLevel, usize>,
    /// Count of assessments per category.
    pub risks_by_category: HashMap<RiskCategory, usize>,
    /// Count of assessments per lifecycle status.
    pub risks_by_status: HashMap<RiskStatus, usize>,
    /// Mitigation plans still awaiting approval.
    pub mitigation_plans_pending: usize,
    /// When this snapshot was generated.
    pub last_updated: DateTime<Utc>,
}
|
||||
|
||||
impl Default for RiskAssessmentService {
    /// Equivalent to `RiskAssessmentService::new()`.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
501
src/compliance/training_tracker.rs
Normal file
501
src/compliance/training_tracker.rs
Normal file
|
|
@ -0,0 +1,501 @@
|
|||
//! Training Tracker Module
|
||||
//!
|
||||
//! Provides comprehensive security training tracking and compliance
|
||||
//! management for ensuring all personnel meet training requirements.
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Subject area of a training course.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum TrainingType {
    SecurityAwareness,
    DataProtection,
    PhishingPrevention,
    IncidentResponse,
    ComplianceRegulation,
    PasswordManagement,
    AccessControl,
    EmergencyProcedures,
}
|
||||
|
||||
/// Training status
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum TrainingStatus {
|
||||
NotStarted,
|
||||
InProgress,
|
||||
Completed,
|
||||
Expired,
|
||||
Failed,
|
||||
Exempted,
|
||||
}
|
||||
|
||||
/// Priority of a training course, ordered from lowest to highest
/// (`Ord` is derived so priorities can be compared directly).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum TrainingPriority {
    Low,
    Medium,
    High,
    Critical,
}
|
||||
|
||||
/// Definition of a training course in the catalog.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingCourse {
    pub id: Uuid,
    pub title: String,
    pub description: String,
    pub training_type: TrainingType,
    /// Nominal course length in hours.
    pub duration_hours: f32,
    /// How long a completion (and its certificate) remains valid.
    pub validity_days: i64,
    pub priority: TrainingPriority,
    /// Role names this course is mandatory for; the sentinel "all"
    /// makes it mandatory for everyone.
    pub required_for_roles: Vec<String>,
    /// Courses that must be completed first (by course id).
    pub prerequisites: Vec<Uuid>,
    /// Where the course material lives, if hosted externally.
    pub content_url: Option<String>,
    /// Minimum score (same scale as attempt scores) required to pass.
    pub passing_score: u32,
    /// Maximum number of attempts before the assignment is failed.
    pub max_attempts: u32,
}
|
||||
|
||||
/// A course assigned to a specific user, with its attempt history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingAssignment {
    pub id: Uuid,
    pub user_id: Uuid,
    pub course_id: Uuid,
    pub assigned_date: DateTime<Utc>,
    /// Deadline; assignments past this date without completion count
    /// as overdue.
    pub due_date: DateTime<Utc>,
    pub status: TrainingStatus,
    /// All attempts so far, in chronological order.
    pub attempts: Vec<TrainingAttempt>,
    /// Set when a passing attempt is recorded.
    pub completion_date: Option<DateTime<Utc>>,
    /// Completion date plus the course's validity window.
    pub expiry_date: Option<DateTime<Utc>>,
    pub assigned_by: String,
    pub notes: Option<String>,
}
|
||||
|
||||
/// A single attempt at completing an assigned course.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingAttempt {
    pub id: Uuid,
    /// 1-based position within the assignment's attempt history.
    pub attempt_number: u32,
    pub start_time: DateTime<Utc>,
    /// Set when the attempt is completed.
    pub end_time: Option<DateTime<Utc>>,
    /// Final score; `None` while the attempt is in progress.
    pub score: Option<u32>,
    /// True when `score` met the course's passing score.
    pub passed: bool,
    /// Wall-clock minutes between start and end, once finished.
    pub time_spent_minutes: Option<u32>,
}
|
||||
|
||||
/// Certificate issued when a user passes a course.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingCertificate {
    pub id: Uuid,
    pub user_id: Uuid,
    pub course_id: Uuid,
    pub issued_date: DateTime<Utc>,
    /// Issue date plus the course's validity window.
    pub expiry_date: DateTime<Utc>,
    /// Human-readable identifier of the form "CERT-XXXXXXXX".
    pub certificate_number: String,
    /// Opaque code (a UUID string) used to verify authenticity.
    pub verification_code: String,
}
|
||||
|
||||
/// Per-user training compliance summary.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComplianceStatus {
    pub user_id: Uuid,
    /// True when the user has no overdue trainings.
    pub compliant: bool,
    /// Course ids mandatory for the user's roles.
    pub required_trainings: Vec<Uuid>,
    /// Required courses completed and still within validity.
    pub completed_trainings: Vec<Uuid>,
    /// Required courses past due, expired, or never assigned.
    pub overdue_trainings: Vec<Uuid>,
    /// Assigned, not yet completed, not yet past due.
    pub upcoming_trainings: Vec<Uuid>,
    /// completed / required, as a percentage (100 when nothing is required).
    pub compliance_percentage: f64,
}
|
||||
|
||||
/// In-memory training tracking service.
///
/// Owns the course catalog, per-user assignments, issued certificates,
/// and the role map used to decide which courses a user must take.
/// State is not persisted and the tracker is not internally synchronized.
#[derive(Debug, Clone)]
pub struct TrainingTracker {
    /// Course catalog, keyed by course id.
    courses: HashMap<Uuid, TrainingCourse>,
    /// Assignments, keyed by assignment id (not user or course id).
    assignments: HashMap<Uuid, TrainingAssignment>,
    /// Issued certificates, keyed by certificate id.
    certificates: HashMap<Uuid, TrainingCertificate>,
    /// Role names per user; users absent from the map default to ["all"].
    user_roles: HashMap<Uuid, Vec<String>>,
}
|
||||
|
||||
impl TrainingTracker {
|
||||
/// Create new training tracker
|
||||
pub fn new() -> Self {
|
||||
let mut tracker = Self {
|
||||
courses: HashMap::new(),
|
||||
assignments: HashMap::new(),
|
||||
certificates: HashMap::new(),
|
||||
user_roles: HashMap::new(),
|
||||
};
|
||||
|
||||
// Initialize with default courses
|
||||
tracker.initialize_default_courses();
|
||||
tracker
|
||||
}
|
||||
|
||||
/// Initialize default training courses
|
||||
fn initialize_default_courses(&mut self) {
|
||||
let security_awareness = TrainingCourse {
|
||||
id: Uuid::new_v4(),
|
||||
title: "Security Awareness Fundamentals".to_string(),
|
||||
description: "Basic security awareness training for all employees".to_string(),
|
||||
training_type: TrainingType::SecurityAwareness,
|
||||
duration_hours: 2.0,
|
||||
validity_days: 365,
|
||||
priority: TrainingPriority::High,
|
||||
required_for_roles: vec!["all".to_string()],
|
||||
prerequisites: vec![],
|
||||
content_url: Some("https://training.example.com/security-awareness".to_string()),
|
||||
passing_score: 80,
|
||||
max_attempts: 3,
|
||||
};
|
||||
|
||||
self.courses.insert(security_awareness.id, security_awareness);
|
||||
|
||||
let data_protection = TrainingCourse {
|
||||
id: Uuid::new_v4(),
|
||||
title: "Data Protection and Privacy".to_string(),
|
||||
description: "Training on data protection regulations and best practices".to_string(),
|
||||
training_type: TrainingType::DataProtection,
|
||||
duration_hours: 3.0,
|
||||
validity_days: 365,
|
||||
priority: TrainingPriority::High,
|
||||
required_for_roles: vec!["admin".to_string(), "manager".to_string()],
|
||||
prerequisites: vec![],
|
||||
content_url: Some("https://training.example.com/data-protection".to_string()),
|
||||
passing_score: 85,
|
||||
max_attempts: 3,
|
||||
};
|
||||
|
||||
self.courses.insert(data_protection.id, data_protection);
|
||||
}
|
||||
|
||||
/// Create a training course
|
||||
pub fn create_course(&mut self, course: TrainingCourse) -> Result<()> {
|
||||
if self.courses.contains_key(&course.id) {
|
||||
return Err(anyhow!("Course already exists"));
|
||||
}
|
||||
|
||||
log::info!("Creating training course: {}", course.title);
|
||||
self.courses.insert(course.id, course);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Assign training to user
|
||||
pub fn assign_training(
|
||||
&mut self,
|
||||
user_id: Uuid,
|
||||
course_id: Uuid,
|
||||
due_days: i64,
|
||||
assigned_by: String,
|
||||
) -> Result<TrainingAssignment> {
|
||||
let course = self
|
||||
.courses
|
||||
.get(&course_id)
|
||||
.ok_or_else(|| anyhow!("Course not found"))?
|
||||
.clone();
|
||||
|
||||
let assignment = TrainingAssignment {
|
||||
id: Uuid::new_v4(),
|
||||
user_id,
|
||||
course_id,
|
||||
assigned_date: Utc::now(),
|
||||
due_date: Utc::now() + Duration::days(due_days),
|
||||
status: TrainingStatus::NotStarted,
|
||||
attempts: vec![],
|
||||
completion_date: None,
|
||||
expiry_date: None,
|
||||
assigned_by,
|
||||
notes: None,
|
||||
};
|
||||
|
||||
self.assignments.insert(assignment.id, assignment.clone());
|
||||
|
||||
log::info!(
|
||||
"Assigned training '{}' to user {}",
|
||||
course.title,
|
||||
user_id
|
||||
);
|
||||
|
||||
Ok(assignment)
|
||||
}
|
||||
|
||||
/// Start training attempt
|
||||
pub fn start_training(&mut self, assignment_id: Uuid) -> Result<TrainingAttempt> {
|
||||
let assignment = self
|
||||
.assignments
|
||||
.get_mut(&assignment_id)
|
||||
.ok_or_else(|| anyhow!("Assignment not found"))?;
|
||||
|
||||
let course = self
|
||||
.courses
|
||||
.get(&assignment.course_id)
|
||||
.ok_or_else(|| anyhow!("Course not found"))?;
|
||||
|
||||
if assignment.attempts.len() >= course.max_attempts as usize {
|
||||
return Err(anyhow!("Maximum attempts exceeded"));
|
||||
}
|
||||
|
||||
let attempt = TrainingAttempt {
|
||||
id: Uuid::new_v4(),
|
||||
attempt_number: (assignment.attempts.len() + 1) as u32,
|
||||
start_time: Utc::now(),
|
||||
end_time: None,
|
||||
score: None,
|
||||
passed: false,
|
||||
time_spent_minutes: None,
|
||||
};
|
||||
|
||||
assignment.status = TrainingStatus::InProgress;
|
||||
assignment.attempts.push(attempt.clone());
|
||||
|
||||
Ok(attempt)
|
||||
}
|
||||
|
||||
/// Complete training attempt
|
||||
pub fn complete_training(
|
||||
&mut self,
|
||||
assignment_id: Uuid,
|
||||
attempt_id: Uuid,
|
||||
score: u32,
|
||||
) -> Result<bool> {
|
||||
let assignment = self
|
||||
.assignments
|
||||
.get_mut(&assignment_id)
|
||||
.ok_or_else(|| anyhow!("Assignment not found"))?;
|
||||
|
||||
let course = self
|
||||
.courses
|
||||
.get(&assignment.course_id)
|
||||
.ok_or_else(|| anyhow!("Course not found"))?
|
||||
.clone();
|
||||
|
||||
let attempt = assignment
|
||||
.attempts
|
||||
.iter_mut()
|
||||
.find(|a| a.id == attempt_id)
|
||||
.ok_or_else(|| anyhow!("Attempt not found"))?;
|
||||
|
||||
let end_time = Utc::now();
|
||||
let time_spent = (end_time - attempt.start_time).num_minutes() as u32;
|
||||
|
||||
attempt.end_time = Some(end_time);
|
||||
attempt.score = Some(score);
|
||||
attempt.time_spent_minutes = Some(time_spent);
|
||||
attempt.passed = score >= course.passing_score;
|
||||
|
||||
if attempt.passed {
|
||||
assignment.status = TrainingStatus::Completed;
|
||||
assignment.completion_date = Some(end_time);
|
||||
assignment.expiry_date = Some(end_time + Duration::days(course.validity_days));
|
||||
|
||||
// Issue certificate
|
||||
let certificate = TrainingCertificate {
|
||||
id: Uuid::new_v4(),
|
||||
user_id: assignment.user_id,
|
||||
course_id: course.id,
|
||||
issued_date: end_time,
|
||||
expiry_date: end_time + Duration::days(course.validity_days),
|
||||
certificate_number: format!("CERT-{}", Uuid::new_v4().to_string()[..8].to_uppercase()),
|
||||
verification_code: Uuid::new_v4().to_string(),
|
||||
};
|
||||
|
||||
self.certificates.insert(certificate.id, certificate);
|
||||
|
||||
log::info!(
|
||||
"User {} completed training '{}' with score {}",
|
||||
assignment.user_id,
|
||||
course.title,
|
||||
score
|
||||
);
|
||||
} else if assignment.attempts.len() >= course.max_attempts as usize {
|
||||
assignment.status = TrainingStatus::Failed;
|
||||
}
|
||||
|
||||
Ok(attempt.passed)
|
||||
}
|
||||
|
||||
/// Get user compliance status
|
||||
pub fn get_compliance_status(&self, user_id: Uuid) -> ComplianceStatus {
|
||||
let user_roles = self
|
||||
.user_roles
|
||||
.get(&user_id)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| vec!["all".to_string()]);
|
||||
|
||||
let mut required_trainings = vec![];
|
||||
let mut completed_trainings = vec![];
|
||||
let mut overdue_trainings = vec![];
|
||||
let mut upcoming_trainings = vec![];
|
||||
|
||||
for course in self.courses.values() {
|
||||
if course.required_for_roles.iter().any(|r| {
|
||||
user_roles.contains(r) || r == "all"
|
||||
}) {
|
||||
required_trainings.push(course.id);
|
||||
|
||||
// Check if user has completed this training
|
||||
let assignment = self
|
||||
.assignments
|
||||
.values()
|
||||
.find(|a| a.user_id == user_id && a.course_id == course.id);
|
||||
|
||||
if let Some(assignment) = assignment {
|
||||
match assignment.status {
|
||||
TrainingStatus::Completed => {
|
||||
if let Some(expiry) = assignment.expiry_date {
|
||||
if expiry > Utc::now() {
|
||||
completed_trainings.push(course.id);
|
||||
} else {
|
||||
overdue_trainings.push(course.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
TrainingStatus::NotStarted | TrainingStatus::InProgress => {
|
||||
if assignment.due_date < Utc::now() {
|
||||
overdue_trainings.push(course.id);
|
||||
} else {
|
||||
upcoming_trainings.push(course.id);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
} else {
|
||||
overdue_trainings.push(course.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let compliance_percentage = if required_trainings.is_empty() {
|
||||
100.0
|
||||
} else {
|
||||
(completed_trainings.len() as f64 / required_trainings.len() as f64) * 100.0
|
||||
};
|
||||
|
||||
ComplianceStatus {
|
||||
user_id,
|
||||
compliant: overdue_trainings.is_empty(),
|
||||
required_trainings,
|
||||
completed_trainings,
|
||||
overdue_trainings,
|
||||
upcoming_trainings,
|
||||
compliance_percentage,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get training report
|
||||
pub fn get_training_report(&self) -> TrainingReport {
|
||||
let total_courses = self.courses.len();
|
||||
let total_assignments = self.assignments.len();
|
||||
let total_certificates = self.certificates.len();
|
||||
|
||||
let mut assignments_by_status = HashMap::new();
|
||||
for assignment in self.assignments.values() {
|
||||
*assignments_by_status
|
||||
.entry(assignment.status.clone())
|
||||
.or_insert(0) += 1;
|
||||
}
|
||||
|
||||
let overdue_count = self
|
||||
.assignments
|
||||
.values()
|
||||
.filter(|a| {
|
||||
a.status != TrainingStatus::Completed
|
||||
&& a.due_date < Utc::now()
|
||||
})
|
||||
.count();
|
||||
|
||||
let expiring_soon = self
|
||||
.certificates
|
||||
.values()
|
||||
.filter(|c| {
|
||||
c.expiry_date > Utc::now()
|
||||
&& c.expiry_date < Utc::now() + Duration::days(30)
|
||||
})
|
||||
.count();
|
||||
|
||||
let average_score = self.calculate_average_score();
|
||||
|
||||
TrainingReport {
|
||||
generated_at: Utc::now(),
|
||||
total_courses,
|
||||
total_assignments,
|
||||
total_certificates,
|
||||
assignments_by_status,
|
||||
overdue_count,
|
||||
expiring_soon,
|
||||
average_score,
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate average training score
|
||||
fn calculate_average_score(&self) -> f64 {
|
||||
let mut total_score = 0;
|
||||
let mut count = 0;
|
||||
|
||||
for assignment in self.assignments.values() {
|
||||
for attempt in &assignment.attempts {
|
||||
if let Some(score) = attempt.score {
|
||||
total_score += score;
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
0.0
|
||||
} else {
|
||||
total_score as f64 / count as f64
|
||||
}
|
||||
}
|
||||
|
||||
/// Set user roles
|
||||
pub fn set_user_roles(&mut self, user_id: Uuid, roles: Vec<String>) {
|
||||
self.user_roles.insert(user_id, roles);
|
||||
}
|
||||
|
||||
/// Get overdue trainings
|
||||
pub fn get_overdue_trainings(&self) -> Vec<TrainingAssignment> {
|
||||
self.assignments
|
||||
.values()
|
||||
.filter(|a| {
|
||||
a.status != TrainingStatus::Completed
|
||||
&& a.due_date < Utc::now()
|
||||
})
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get expiring certificates
|
||||
pub fn get_expiring_certificates(&self, days_ahead: i64) -> Vec<TrainingCertificate> {
|
||||
let cutoff = Utc::now() + Duration::days(days_ahead);
|
||||
self.certificates
|
||||
.values()
|
||||
.filter(|c| {
|
||||
c.expiry_date > Utc::now() && c.expiry_date <= cutoff
|
||||
})
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Organization-wide training snapshot produced by
/// `TrainingTracker::get_training_report`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingReport {
    /// When this report was generated.
    pub generated_at: DateTime<Utc>,
    pub total_courses: usize,
    pub total_assignments: usize,
    pub total_certificates: usize,
    /// Count of assignments per lifecycle status.
    pub assignments_by_status: HashMap<TrainingStatus, usize>,
    /// Assignments past due and not completed.
    pub overdue_count: usize,
    /// Valid certificates expiring within the next 30 days.
    pub expiring_soon: usize,
    /// Mean score over all scored attempts (0.0 when none).
    pub average_score: f64,
}
|
||||
|
||||
impl Default for TrainingTracker {
    /// Equivalent to `TrainingTracker::new()` (includes default courses).
    fn default() -> Self {
        Self::new()
    }
}
|
||||
338
src/core/bot/kb_context.rs
Normal file
338
src/core/bot/kb_context.rs
Normal file
|
|
@ -0,0 +1,338 @@
|
|||
use anyhow::Result;
|
||||
use diesel::prelude::*;
|
||||
use log::{debug, error, info, warn};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::core::kb::KnowledgeBaseManager;
|
||||
use crate::shared::utils::DbPool;
|
||||
|
||||
/// An active knowledge-base association for a chat session, as stored in
/// the `session_kb_associations` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionKbAssociation {
    /// Logical name of the knowledge base.
    pub kb_name: String,
    /// Name of the Qdrant vector collection backing this KB.
    pub qdrant_collection: String,
    /// Filesystem path of the KB's source folder.
    pub kb_folder_path: String,
    /// Whether the association is currently enabled for the session.
    pub is_active: bool,
}
|
||||
|
||||
/// Retrieved KB context destined for injection into the LLM prompt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KbContext {
    /// Name of the knowledge base the results came from.
    pub kb_name: String,
    /// Search hits, as returned by the KB search.
    pub search_results: Vec<KbSearchResult>,
    /// Total token count across `search_results`, used for prompt budgeting.
    pub total_tokens: usize,
}
|
||||
|
||||
/// A single chunk returned from a knowledge-base search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KbSearchResult {
    /// The chunk text itself.
    pub content: String,
    /// Path of the source document the chunk came from.
    pub document_path: String,
    /// Similarity score from the vector search (higher is more relevant
    /// — TODO confirm scale against the Qdrant configuration).
    pub score: f32,
    /// Token count of `content`, used for prompt budgeting.
    pub chunk_tokens: usize,
}
|
||||
|
||||
/// Retrieves KB context for sessions and prepares it for prompt injection.
///
/// Combines the vector-search side (`KnowledgeBaseManager`) with the
/// relational side (`DbPool`) that stores which KBs a session has active.
#[derive(Debug)]
pub struct KbContextManager {
    /// Shared handle to the knowledge-base search layer.
    kb_manager: Arc<KnowledgeBaseManager>,
    /// Diesel connection pool for session/KB association lookups.
    db_pool: DbPool,
}
|
||||
|
||||
impl KbContextManager {
|
||||
    /// Create a new KB context manager from a shared KB search handle and
    /// a database connection pool.
    pub fn new(kb_manager: Arc<KnowledgeBaseManager>, db_pool: DbPool) -> Self {
        Self {
            kb_manager,
            db_pool,
        }
    }
|
||||
|
||||
    /// Fetch all KB associations for `session_id` that are currently
    /// active.
    ///
    /// Queries `session_kb_associations` with raw SQL (the table has no
    /// generated Diesel schema here) and maps each row into a
    /// `SessionKbAssociation`. Errors if a pooled connection cannot be
    /// obtained or the query fails.
    ///
    /// NOTE(review): declared `async` but performs blocking Diesel I/O on
    /// the current thread — confirm callers run this on a blocking-safe
    /// executor.
    pub async fn get_active_kbs(&self, session_id: Uuid) -> Result<Vec<SessionKbAssociation>> {
        let mut conn = self.db_pool.get()?;

        // Query for active KB associations
        let query = diesel::sql_query(
            "SELECT kb_name, qdrant_collection, kb_folder_path, is_active
             FROM session_kb_associations
             WHERE session_id = $1 AND is_active = true",
        )
        .bind::<diesel::sql_types::Uuid, _>(session_id);

        // Row type for the raw SQL above; field names must match the
        // selected column names.
        #[derive(QueryableByName)]
        struct KbAssocRow {
            #[diesel(sql_type = diesel::sql_types::Text)]
            kb_name: String,
            #[diesel(sql_type = diesel::sql_types::Text)]
            qdrant_collection: String,
            #[diesel(sql_type = diesel::sql_types::Text)]
            kb_folder_path: String,
            #[diesel(sql_type = diesel::sql_types::Bool)]
            is_active: bool,
        }

        let rows: Vec<KbAssocRow> = query.load(&mut conn)?;

        Ok(rows
            .into_iter()
            .map(|row| SessionKbAssociation {
                kb_name: row.kb_name,
                qdrant_collection: row.qdrant_collection,
                kb_folder_path: row.kb_folder_path,
                is_active: row.is_active,
            })
            .collect())
    }
|
||||
|
||||
/// Search all active KBs for relevant context
|
||||
pub async fn search_active_kbs(
|
||||
&self,
|
||||
session_id: Uuid,
|
||||
bot_name: &str,
|
||||
query: &str,
|
||||
max_results_per_kb: usize,
|
||||
max_total_tokens: usize,
|
||||
) -> Result<Vec<KbContext>> {
|
||||
let active_kbs = self.get_active_kbs(session_id).await?;
|
||||
|
||||
if active_kbs.is_empty() {
|
||||
debug!("No active KBs for session {}", session_id);
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
info!(
|
||||
"Searching {} active KBs for session {}: {:?}",
|
||||
active_kbs.len(),
|
||||
session_id,
|
||||
active_kbs.iter().map(|kb| &kb.kb_name).collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
let mut kb_contexts = Vec::new();
|
||||
let mut total_tokens_used = 0;
|
||||
|
||||
for kb_assoc in active_kbs {
|
||||
if total_tokens_used >= max_total_tokens {
|
||||
warn!("Reached max token limit, skipping remaining KBs");
|
||||
break;
|
||||
}
|
||||
|
||||
match self
|
||||
.search_single_kb(
|
||||
bot_name,
|
||||
&kb_assoc.kb_name,
|
||||
query,
|
||||
max_results_per_kb,
|
||||
max_total_tokens - total_tokens_used,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(context) => {
|
||||
total_tokens_used += context.total_tokens;
|
||||
info!(
|
||||
"Found {} results from KB '{}' using {} tokens",
|
||||
context.search_results.len(),
|
||||
context.kb_name,
|
||||
context.total_tokens
|
||||
);
|
||||
kb_contexts.push(context);
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to search KB '{}': {}", kb_assoc.kb_name, e);
|
||||
// Continue with other KBs even if one fails
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(kb_contexts)
|
||||
}
|
||||
|
||||
/// Search a single KB for relevant context
|
||||
async fn search_single_kb(
|
||||
&self,
|
||||
bot_name: &str,
|
||||
kb_name: &str,
|
||||
query: &str,
|
||||
max_results: usize,
|
||||
max_tokens: usize,
|
||||
) -> Result<KbContext> {
|
||||
debug!("Searching KB '{}' with query: {}", kb_name, query);
|
||||
|
||||
// Use the KnowledgeBaseManager to search
|
||||
let search_results = self
|
||||
.kb_manager
|
||||
.search(bot_name, kb_name, query, max_results)
|
||||
.await?;
|
||||
|
||||
let mut kb_search_results = Vec::new();
|
||||
let mut total_tokens = 0;
|
||||
|
||||
for result in search_results {
|
||||
let tokens = estimate_tokens(&result.content);
|
||||
|
||||
// Check if adding this result would exceed token limit
|
||||
if total_tokens + tokens > max_tokens {
|
||||
debug!(
|
||||
"Skipping result due to token limit ({} + {} > {})",
|
||||
total_tokens, tokens, max_tokens
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
kb_search_results.push(KbSearchResult {
|
||||
content: result.content,
|
||||
document_path: result.document_path,
|
||||
score: result.score,
|
||||
chunk_tokens: tokens,
|
||||
});
|
||||
|
||||
total_tokens += tokens;
|
||||
|
||||
// Only include high-relevance results (score > 0.7)
|
||||
if result.score < 0.7 {
|
||||
debug!("Skipping low-relevance result (score: {})", result.score);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(KbContext {
|
||||
kb_name: kb_name.to_string(),
|
||||
search_results: kb_search_results,
|
||||
total_tokens,
|
||||
})
|
||||
}
|
||||
|
||||
/// Build context string from KB search results for LLM injection
|
||||
pub fn build_context_string(&self, kb_contexts: &[KbContext]) -> String {
|
||||
if kb_contexts.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
let mut context_parts = vec!["\n--- Knowledge Base Context ---".to_string()];
|
||||
|
||||
for kb_context in kb_contexts {
|
||||
if kb_context.search_results.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
context_parts.push(format!(
|
||||
"\n## From '{}' knowledge base:",
|
||||
kb_context.kb_name
|
||||
));
|
||||
|
||||
for (idx, result) in kb_context.search_results.iter().enumerate() {
|
||||
context_parts.push(format!(
|
||||
"\n### Result {} (relevance: {:.2}):\n{}",
|
||||
idx + 1,
|
||||
result.score,
|
||||
result.content
|
||||
));
|
||||
|
||||
if !result.document_path.is_empty() {
|
||||
context_parts.push(format!("Source: {}", result.document_path));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
context_parts.push("\n--- End Knowledge Base Context ---\n".to_string());
|
||||
context_parts.join("\n")
|
||||
}
|
||||
|
||||
/// Get active tools for a session (similar to KBs)
|
||||
pub async fn get_active_tools(&self, session_id: Uuid) -> Result<Vec<String>> {
|
||||
let mut conn = self.db_pool.get()?;
|
||||
|
||||
let query = diesel::sql_query(
|
||||
"SELECT tool_name
|
||||
FROM session_tool_associations
|
||||
WHERE session_id = $1 AND is_active = true",
|
||||
)
|
||||
.bind::<diesel::sql_types::Uuid, _>(session_id);
|
||||
|
||||
#[derive(QueryableByName)]
|
||||
struct ToolRow {
|
||||
#[diesel(sql_type = diesel::sql_types::Text)]
|
||||
tool_name: String,
|
||||
}
|
||||
|
||||
let rows: Vec<ToolRow> = query.load(&mut conn)?;
|
||||
Ok(rows.into_iter().map(|row| row.tool_name).collect())
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate token count for a string (rough approximation).
///
/// Uses the common ~4-bytes-per-token heuristic; real tokenization would be
/// more accurate. Integer division means inputs shorter than four bytes
/// count as zero tokens.
fn estimate_tokens(text: &str) -> usize {
    // Byte length (not char count), matching the original heuristic.
    const BYTES_PER_TOKEN: usize = 4;
    text.len() / BYTES_PER_TOKEN
}
|
||||
|
||||
/// Integration helper for injecting KB context into LLM messages
|
||||
pub async fn inject_kb_context(
|
||||
kb_manager: Arc<KnowledgeBaseManager>,
|
||||
db_pool: DbPool,
|
||||
session_id: Uuid,
|
||||
bot_name: &str,
|
||||
user_query: &str,
|
||||
messages: &mut serde_json::Value,
|
||||
max_context_tokens: usize,
|
||||
) -> Result<()> {
|
||||
let context_manager = KbContextManager::new(kb_manager, db_pool);
|
||||
|
||||
// Search active KBs
|
||||
let kb_contexts = context_manager
|
||||
.search_active_kbs(
|
||||
session_id,
|
||||
bot_name,
|
||||
user_query,
|
||||
5, // max 5 results per KB
|
||||
max_context_tokens,
|
||||
)
|
||||
.await?;
|
||||
|
||||
if kb_contexts.is_empty() {
|
||||
debug!("No KB context found for session {}", session_id);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Build context string
|
||||
let context_string = context_manager.build_context_string(&kb_contexts);
|
||||
|
||||
if context_string.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
info!(
|
||||
"Injecting {} characters of KB context into prompt for session {}",
|
||||
context_string.len(),
|
||||
session_id
|
||||
);
|
||||
|
||||
// Inject context into messages
|
||||
// The context is added as a system message or appended to the existing system prompt
|
||||
if let Some(messages_array) = messages.as_array_mut() {
|
||||
// Find or create system message
|
||||
let system_msg_idx = messages_array.iter().position(|m| m["role"] == "system");
|
||||
|
||||
if let Some(idx) = system_msg_idx {
|
||||
// Append to existing system message
|
||||
if let Some(content) = messages_array[idx]["content"].as_str() {
|
||||
let new_content = format!("{}\n{}", content, context_string);
|
||||
messages_array[idx]["content"] = serde_json::Value::String(new_content);
|
||||
}
|
||||
} else {
|
||||
// Insert as first message
|
||||
messages_array.insert(
|
||||
0,
|
||||
serde_json::json!({
|
||||
"role": "system",
|
||||
"content": context_string
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -129,16 +129,25 @@ impl BotOrchestrator {
|
|||
|
||||
let system_prompt = std::env::var("SYSTEM_PROMPT")
|
||||
.unwrap_or_else(|_| "You are a helpful assistant.".to_string());
|
||||
let messages = OpenAIClient::build_messages(&system_prompt, &context_data, &history);
|
||||
let mut messages = OpenAIClient::build_messages(&system_prompt, &context_data, &history);
|
||||
|
||||
// Inject bot_id into messages for cache system
|
||||
if let serde_json::Value::Array(ref mut msgs) = messages {
|
||||
let bot_id_obj = serde_json::json!({
|
||||
"bot_id": bot_id.to_string()
|
||||
});
|
||||
msgs.push(bot_id_obj);
|
||||
}
|
||||
|
||||
let (stream_tx, mut stream_rx) = mpsc::channel::<String>(100);
|
||||
let llm = self.state.llm_provider.clone();
|
||||
|
||||
let model_clone = model.clone();
|
||||
let key_clone = key.clone();
|
||||
let messages_clone = messages.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = llm
|
||||
.generate_stream("", &messages, stream_tx, &model_clone, &key_clone)
|
||||
.generate_stream("", &messages_clone, stream_tx, &model_clone, &key_clone)
|
||||
.await
|
||||
{
|
||||
error!("LLM streaming error: {}", e);
|
||||
|
|
|
|||
526
src/core/bot/mod_backup.rs
Normal file
526
src/core/bot/mod_backup.rs
Normal file
|
|
@ -0,0 +1,526 @@
|
|||
use crate::core::config::ConfigManager;
|
||||
use crate::drive::drive_monitor::DriveMonitor;
|
||||
use crate::llm::llm_models;
|
||||
use crate::llm::OpenAIClient;
|
||||
#[cfg(feature = "nvidia")]
|
||||
use crate::nvidia::get_system_metrics;
|
||||
use crate::shared::models::{BotResponse, UserMessage, UserSession};
|
||||
use crate::shared::state::AppState;
|
||||
use axum::extract::ws::{Message, WebSocket};
|
||||
use axum::{
|
||||
extract::{ws::WebSocketUpgrade, Extension, Query, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Json},
|
||||
};
|
||||
use diesel::PgConnection;
|
||||
use futures::{sink::SinkExt, stream::StreamExt};
|
||||
use log::{error, info, trace, warn};
|
||||
use serde_json;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::Mutex as AsyncMutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub mod channels;
|
||||
pub mod multimedia;
|
||||
|
||||
pub fn get_default_bot(conn: &mut PgConnection) -> (Uuid, String) {
|
||||
use crate::shared::models::schema::bots::dsl::*;
|
||||
use diesel::prelude::*;
|
||||
|
||||
match bots
|
||||
.filter(is_active.eq(true))
|
||||
.select((id, name))
|
||||
.first::<(Uuid, String)>(conn)
|
||||
.optional()
|
||||
{
|
||||
Ok(Some((bot_id, bot_name))) => (bot_id, bot_name),
|
||||
Ok(None) => {
|
||||
warn!("No active bots found, using nil UUID");
|
||||
(Uuid::nil(), "default".to_string())
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to query default bot: {}", e);
|
||||
(Uuid::nil(), "default".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Orchestrates message flow between connected clients, the session store
/// and the LLM provider.
#[derive(Debug)]
pub struct BotOrchestrator {
    // Shared application state (session manager, LLM provider, channels, ...).
    pub state: Arc<AppState>,
    // Drive monitors keyed by bot name; populated as bots are mounted.
    pub mounted_bots: Arc<AsyncMutex<HashMap<String, Arc<DriveMonitor>>>>,
}
|
||||
|
||||
impl BotOrchestrator {
    /// Build an orchestrator over the shared application state with an
    /// initially empty set of mounted bots.
    pub fn new(state: Arc<AppState>) -> Self {
        Self {
            state,
            mounted_bots: Arc::new(AsyncMutex::new(HashMap::new())),
        }
    }

    /// Mount every configured bot. Currently a stub that only logs.
    pub async fn mount_all_bots(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        info!("mount_all_bots called");
        Ok(())
    }

    /// Stream an LLM response for `message`, forwarding chunks over
    /// `response_tx`.
    ///
    /// Flow:
    /// 1. on a blocking thread: load the session, persist the user message,
    ///    gather context/history and per-bot LLM config;
    /// 2. spawn the LLM stream and relay non-"analysis" chunks to the client;
    /// 3. persist the full response and send a final `is_complete` frame.
    pub async fn stream_response(
        &self,
        message: UserMessage,
        response_tx: mpsc::Sender<BotResponse>,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        trace!(
            "Streaming response for user: {}, session: {}",
            message.user_id,
            message.session_id
        );

        let user_id = Uuid::parse_str(&message.user_id)?;
        let session_id = Uuid::parse_str(&message.session_id)?;
        // An unparseable bot_id silently falls back to the nil UUID.
        let bot_id = Uuid::parse_str(&message.bot_id).unwrap_or_default();

        // All session-store access is synchronous, so it runs on a blocking
        // thread via spawn_blocking.
        let (session, context_data, history, model, key) = {
            let state_clone = self.state.clone();
            tokio::task::spawn_blocking(
                move || -> Result<_, Box<dyn std::error::Error + Send + Sync>> {
                    let session = {
                        let mut sm = state_clone.session_manager.blocking_lock();
                        sm.get_session_by_id(session_id)?
                    }
                    .ok_or_else(|| "Session not found")?;

                    // Persist the incoming user message.
                    // NOTE(review): the two literal `1` args presumably encode
                    // role/message-kind — confirm against save_message's signature.
                    {
                        let mut sm = state_clone.session_manager.blocking_lock();
                        sm.save_message(session.id, user_id, 1, &message.content, 1)?;
                    }

                    // get_session_context_data is async; re-enter the runtime
                    // from this blocking thread to await it. The session
                    // manager lock is held across the block_on call.
                    let context_data = {
                        let sm = state_clone.session_manager.blocking_lock();
                        let rt = tokio::runtime::Handle::current();
                        rt.block_on(async {
                            sm.get_session_context_data(&session.id, &session.user_id)
                                .await
                        })?
                    };

                    let history = {
                        let mut sm = state_clone.session_manager.blocking_lock();
                        sm.get_conversation_history(session.id, user_id)?
                    };

                    // Per-bot LLM settings, with hard-coded fallbacks.
                    let config_manager = ConfigManager::new(state_clone.conn.clone());
                    let model = config_manager
                        .get_config(&bot_id, "llm-model", Some("gpt-3.5-turbo"))
                        .unwrap_or_else(|_| "gpt-3.5-turbo".to_string());
                    let key = config_manager
                        .get_config(&bot_id, "llm-key", Some(""))
                        .unwrap_or_default();

                    Ok((session, context_data, history, model, key))
                },
            )
            .await??
        };

        let system_prompt = std::env::var("SYSTEM_PROMPT")
            .unwrap_or_else(|_| "You are a helpful assistant.".to_string());
        let messages = OpenAIClient::build_messages(&system_prompt, &context_data, &history);

        // The LLM stream runs in its own task and feeds chunks through this
        // channel.
        let (stream_tx, mut stream_rx) = mpsc::channel::<String>(100);
        let llm = self.state.llm_provider.clone();

        let model_clone = model.clone();
        let key_clone = key.clone();
        tokio::spawn(async move {
            if let Err(e) = llm
                .generate_stream("", &messages, stream_tx, &model_clone, &key_clone)
                .await
            {
                error!("LLM streaming error: {}", e);
            }
        });

        let mut full_response = String::new();
        let mut analysis_buffer = String::new();
        let mut in_analysis = false;
        // Model-specific handler that recognizes that model's
        // "analysis"/reasoning markers in the stream.
        let handler = llm_models::get_handler(&model);

        // Optional GPU/CPU telemetry printed before streaming starts.
        #[cfg(feature = "nvidia")]
        {
            let initial_tokens = crate::shared::utils::estimate_token_count(&context_data);
            let config_manager = ConfigManager::new(self.state.conn.clone());
            let max_context_size = config_manager
                .get_config(&bot_id, "llm-server-ctx-size", None)
                .unwrap_or_default()
                .parse::<usize>()
                .unwrap_or(0);

            if let Ok(metrics) = get_system_metrics(initial_tokens, max_context_size) {
                eprintln!(
                    "\nNVIDIA: {:.1}% | CPU: {:.1}% | Tokens: {}/{}",
                    metrics.gpu_usage.unwrap_or(0.0),
                    metrics.cpu_usage,
                    initial_tokens,
                    max_context_size
                );
            }
        }

        while let Some(chunk) = stream_rx.recv().await {
            trace!("Received LLM chunk: {:?}", chunk);
            analysis_buffer.push_str(&chunk);

            // Suppress the model's internal "analysis" section: once markers
            // appear, chunks are buffered (not forwarded) until it completes.
            if handler.has_analysis_markers(&analysis_buffer) && !in_analysis {
                in_analysis = true;
            }

            if in_analysis && handler.is_analysis_complete(&analysis_buffer) {
                in_analysis = false;
                analysis_buffer.clear();
                continue;
            }

            if !in_analysis {
                full_response.push_str(&chunk);

                // Partial frame: one chunk, is_complete = false.
                let response = BotResponse {
                    bot_id: message.bot_id.clone(),
                    user_id: message.user_id.clone(),
                    session_id: message.session_id.clone(),
                    channel: message.channel.clone(),
                    content: chunk,
                    message_type: 2,
                    stream_token: None,
                    is_complete: false,
                    suggestions: Vec::new(),
                    context_name: None,
                    context_length: 0,
                    context_max_length: 0,
                };

                if response_tx.send(response).await.is_err() {
                    warn!("Response channel closed");
                    break;
                }
            }
        }

        // Persist the assistant's full reply on a blocking thread.
        let state_for_save = self.state.clone();
        let full_response_clone = full_response.clone();
        tokio::task::spawn_blocking(
            move || -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
                let mut sm = state_for_save.session_manager.blocking_lock();
                sm.save_message(session.id, user_id, 2, &full_response_clone, 2)?;
                Ok(())
            },
        )
        .await??;

        // Final frame carries the whole text and is_complete = true.
        let final_response = BotResponse {
            bot_id: message.bot_id,
            user_id: message.user_id,
            session_id: message.session_id,
            channel: message.channel,
            content: full_response,
            message_type: 2,
            stream_token: None,
            is_complete: true,
            suggestions: Vec::new(),
            context_name: None,
            context_length: 0,
            context_max_length: 0,
        };

        response_tx.send(final_response).await?;
        Ok(())
    }

    /// List all sessions belonging to `user_id`.
    pub async fn get_user_sessions(
        &self,
        user_id: Uuid,
    ) -> Result<Vec<UserSession>, Box<dyn std::error::Error + Send + Sync>> {
        let mut session_manager = self.state.session_manager.lock().await;
        let sessions = session_manager.get_user_sessions(user_id)?;
        Ok(sessions)
    }

    /// Fetch the stored conversation history for a session/user pair.
    pub async fn get_conversation_history(
        &self,
        session_id: Uuid,
        user_id: Uuid,
    ) -> Result<Vec<(String, String)>, Box<dyn std::error::Error + Send + Sync>> {
        let mut session_manager = self.state.session_manager.lock().await;
        let history = session_manager.get_conversation_history(session_id, user_id)?;
        Ok(history)
    }
}
|
||||
|
||||
pub async fn websocket_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(params): Query<HashMap<String, String>>,
|
||||
) -> impl IntoResponse {
|
||||
let session_id = params
|
||||
.get("session_id")
|
||||
.and_then(|s| Uuid::parse_str(s).ok());
|
||||
let user_id = params.get("user_id").and_then(|s| Uuid::parse_str(s).ok());
|
||||
|
||||
if session_id.is_none() || user_id.is_none() {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({ "error": "session_id and user_id are required" })),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
ws.on_upgrade(move |socket| {
|
||||
handle_websocket(socket, state, session_id.unwrap(), user_id.unwrap())
|
||||
})
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// Drives a single WebSocket connection for a session.
///
/// Registers the outbound channel with the web adapter and the shared
/// `response_channels` map, sends a welcome frame, then runs two tasks:
/// one forwarding `BotResponse`s to the socket, one feeding inbound user
/// messages into the orchestrator. Whichever task finishes first aborts
/// the other, after which both registrations are removed.
async fn handle_websocket(
    socket: WebSocket,
    state: Arc<AppState>,
    session_id: Uuid,
    user_id: Uuid,
) {
    let (mut sender, mut receiver) = socket.split();
    let (tx, mut rx) = mpsc::channel::<BotResponse>(100);

    // Register the sender side so other parts of the app can push responses
    // to this connection.
    state
        .web_adapter
        .add_connection(session_id.to_string(), tx.clone())
        .await;

    {
        let mut channels = state.response_channels.lock().await;
        channels.insert(session_id.to_string(), tx.clone());
    }

    info!(
        "WebSocket connected for session: {}, user: {}",
        session_id, user_id
    );

    // Handshake frame confirming the connection to the client.
    let welcome = serde_json::json!({
        "type": "connected",
        "session_id": session_id,
        "user_id": user_id,
        "message": "Connected to bot server"
    });

    if let Ok(welcome_str) = serde_json::to_string(&welcome) {
        if sender
            .send(Message::Text(welcome_str.into()))
            .await
            .is_err()
        {
            error!("Failed to send welcome message");
        }
    }

    // Outbound pump: serialize each BotResponse and write it to the socket.
    let mut send_task = tokio::spawn(async move {
        while let Some(response) = rx.recv().await {
            if let Ok(json_str) = serde_json::to_string(&response) {
                if sender.send(Message::Text(json_str.into())).await.is_err() {
                    break;
                }
            }
        }
    });

    // Inbound pump: parse user messages and stream LLM responses back.
    let state_clone = state.clone();
    let mut recv_task = tokio::spawn(async move {
        while let Some(Ok(msg)) = receiver.next().await {
            match msg {
                Message::Text(text) => {
                    info!("Received WebSocket message: {}", text);
                    if let Ok(user_msg) = serde_json::from_str::<UserMessage>(&text) {
                        let orchestrator = BotOrchestrator::new(state_clone.clone());
                        // NOTE(review): the if-let keeps the response_channels
                        // mutex guard alive for the whole stream_response
                        // call, blocking other sessions' lookups meanwhile —
                        // confirm this is intended.
                        if let Some(tx_clone) = state_clone
                            .response_channels
                            .lock()
                            .await
                            .get(&session_id.to_string())
                        {
                            if let Err(e) = orchestrator
                                .stream_response(user_msg, tx_clone.clone())
                                .await
                            {
                                error!("Failed to stream response: {}", e);
                            }
                        }
                    }
                }
                Message::Close(_) => {
                    info!("WebSocket close message received");
                    break;
                }
                _ => {}
            }
        }
    });

    // First task to finish tears down the other.
    tokio::select! {
        _ = (&mut send_task) => { recv_task.abort(); }
        _ = (&mut recv_task) => { send_task.abort(); }
    }

    // Deregister both response routes for this session.
    state
        .web_adapter
        .remove_connection(&session_id.to_string())
        .await;

    {
        let mut channels = state.response_channels.lock().await;
        channels.remove(&session_id.to_string());
    }

    info!("WebSocket disconnected for session: {}", session_id);
}
|
||||
|
||||
pub async fn create_bot_handler(
|
||||
Extension(state): Extension<Arc<AppState>>,
|
||||
Json(payload): Json<HashMap<String, String>>,
|
||||
) -> impl IntoResponse {
|
||||
let bot_name = payload
|
||||
.get("bot_name")
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "default".to_string());
|
||||
|
||||
let orchestrator = BotOrchestrator::new(state);
|
||||
if let Err(e) = orchestrator.mount_all_bots().await {
|
||||
error!("Failed to mount bots: {}", e);
|
||||
}
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({ "status": format!("bot '{}' created", bot_name) })),
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn mount_bot_handler(
|
||||
Extension(state): Extension<Arc<AppState>>,
|
||||
Json(payload): Json<HashMap<String, String>>,
|
||||
) -> impl IntoResponse {
|
||||
let bot_guid = payload.get("bot_guid").cloned().unwrap_or_default();
|
||||
|
||||
let orchestrator = BotOrchestrator::new(state);
|
||||
if let Err(e) = orchestrator.mount_all_bots().await {
|
||||
error!("Failed to mount bot: {}", e);
|
||||
}
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({ "status": format!("bot '{}' mounted", bot_guid) })),
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn handle_user_input_handler(
|
||||
Extension(state): Extension<Arc<AppState>>,
|
||||
Json(payload): Json<HashMap<String, String>>,
|
||||
) -> impl IntoResponse {
|
||||
let session_id = payload.get("session_id").cloned().unwrap_or_default();
|
||||
let user_input = payload.get("input").cloned().unwrap_or_default();
|
||||
|
||||
info!(
|
||||
"Processing user input: {} for session: {}",
|
||||
user_input, session_id
|
||||
);
|
||||
|
||||
let orchestrator = BotOrchestrator::new(state);
|
||||
if let Ok(sessions) = orchestrator.get_user_sessions(Uuid::nil()).await {
|
||||
info!("Found {} sessions", sessions.len());
|
||||
}
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({ "status": format!("processed: {}", user_input) })),
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn get_user_sessions_handler(
|
||||
Extension(state): Extension<Arc<AppState>>,
|
||||
Json(payload): Json<HashMap<String, String>>,
|
||||
) -> impl IntoResponse {
|
||||
let user_id = payload
|
||||
.get("user_id")
|
||||
.and_then(|id| Uuid::parse_str(id).ok())
|
||||
.unwrap_or_else(Uuid::nil);
|
||||
|
||||
let orchestrator = BotOrchestrator::new(state);
|
||||
match orchestrator.get_user_sessions(user_id).await {
|
||||
Ok(sessions) => (
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({ "sessions": sessions })),
|
||||
),
|
||||
Err(e) => {
|
||||
error!("Failed to get sessions: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_conversation_history_handler(
|
||||
Extension(state): Extension<Arc<AppState>>,
|
||||
Json(payload): Json<HashMap<String, String>>,
|
||||
) -> impl IntoResponse {
|
||||
let session_id = payload
|
||||
.get("session_id")
|
||||
.and_then(|id| Uuid::parse_str(id).ok())
|
||||
.unwrap_or_else(Uuid::nil);
|
||||
let user_id = payload
|
||||
.get("user_id")
|
||||
.and_then(|id| Uuid::parse_str(id).ok())
|
||||
.unwrap_or_else(Uuid::nil);
|
||||
|
||||
let orchestrator = BotOrchestrator::new(state);
|
||||
match orchestrator
|
||||
.get_conversation_history(session_id, user_id)
|
||||
.await
|
||||
{
|
||||
Ok(history) => (
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({ "history": history })),
|
||||
),
|
||||
Err(e) => {
|
||||
error!("Failed to get history: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_warning_handler(
|
||||
Extension(state): Extension<Arc<AppState>>,
|
||||
Json(payload): Json<HashMap<String, String>>,
|
||||
) -> impl IntoResponse {
|
||||
let message = payload
|
||||
.get("message")
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "Warning".to_string());
|
||||
let session_id = payload.get("session_id").cloned().unwrap_or_default();
|
||||
|
||||
warn!("Warning for session {}: {}", session_id, message);
|
||||
|
||||
let orchestrator = BotOrchestrator::new(state);
|
||||
info!("Orchestrator created for warning");
|
||||
|
||||
// Use orchestrator to log state
|
||||
if let Ok(sessions) = orchestrator.get_user_sessions(Uuid::nil()).await {
|
||||
info!("Current active sessions: {}", sessions.len());
|
||||
}
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({ "status": "warning sent", "message": message })),
|
||||
)
|
||||
}
|
||||
|
|
@ -1,3 +1,5 @@
|
|||
#![allow(dead_code)]
|
||||
|
||||
use anyhow::Result;
|
||||
use log::{error, info, warn};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
|
@ -222,6 +224,7 @@ impl DocumentProcessor {
|
|||
}
|
||||
|
||||
/// Extract PDF using poppler-utils
|
||||
#[allow(dead_code)]
|
||||
async fn extract_pdf_with_poppler(&self, file_path: &Path) -> Result<String> {
|
||||
let output = tokio::process::Command::new("pdftotext")
|
||||
.arg(file_path)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
#![allow(dead_code)]
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
#![allow(dead_code)]
|
||||
|
||||
use crate::shared::state::AppState;
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
#![allow(dead_code)]
|
||||
|
||||
use crate::basic::compiler::BasicCompiler;
|
||||
use crate::config::ConfigManager;
|
||||
use crate::core::kb::KnowledgeBaseManager;
|
||||
|
|
|
|||
|
|
@ -1,15 +1,17 @@
|
|||
#![allow(dead_code)]
|
||||
|
||||
use anyhow::Result;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
#[cfg(feature = "vectordb")]
|
||||
use std::sync::Arc;
|
||||
use tokio::fs;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[cfg(feature = "vectordb")]
|
||||
use qdrant_client::{
|
||||
prelude::*,
|
||||
qdrant::{vectors_config::Config, CreateCollection, Distance, VectorParams, VectorsConfig},
|
||||
use qdrant_client::qdrant::{
|
||||
vectors_config::Config, CreateCollection, Distance, PointStruct, VectorParams, VectorsConfig,
|
||||
};
|
||||
|
||||
/// File metadata for vector DB indexing
|
||||
|
|
@ -80,7 +82,7 @@ impl UserDriveVectorDB {
|
|||
/// Initialize vector DB collection
|
||||
#[cfg(feature = "vectordb")]
|
||||
pub async fn initialize(&mut self, qdrant_url: &str) -> Result<()> {
|
||||
let client = QdrantClient::from_url(qdrant_url).build()?;
|
||||
let client = qdrant_client::Qdrant::from_url(qdrant_url).build()?;
|
||||
|
||||
// Check if collection exists
|
||||
let collections = client.list_collections().await?;
|
||||
|
|
@ -130,10 +132,14 @@ impl UserDriveVectorDB {
|
|||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Vector DB not initialized"))?;
|
||||
|
||||
let point = PointStruct::new(file.id.clone(), embedding, serde_json::to_value(file)?);
|
||||
let payload = serde_json::to_value(file)?
|
||||
.as_object()
|
||||
.map(|m| m.clone())
|
||||
.unwrap_or_default();
|
||||
let point = PointStruct::new(file.id.clone(), embedding, payload);
|
||||
|
||||
client
|
||||
.upsert_points_blocking(self._collection_name.clone(), vec![point], None)
|
||||
.upsert_points(self._collection_name.clone(), None, vec![point], None)
|
||||
.await?;
|
||||
|
||||
log::debug!("Indexed file: {} - {}", file.id, file.file_name);
|
||||
|
|
@ -161,15 +167,17 @@ impl UserDriveVectorDB {
|
|||
let points: Vec<PointStruct> = files
|
||||
.iter()
|
||||
.filter_map(|(file, embedding)| {
|
||||
serde_json::to_value(file).ok().map(|payload| {
|
||||
PointStruct::new(file.id.clone(), embedding.clone(), payload)
|
||||
serde_json::to_value(file).ok().and_then(|v| {
|
||||
v.as_object().map(|m| {
|
||||
PointStruct::new(file.id.clone(), embedding.clone(), m.clone())
|
||||
})
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
if !points.is_empty() {
|
||||
client
|
||||
.upsert_points_blocking(self._collection_name.clone(), points, None)
|
||||
.upsert_points(self._collection_name.clone(), None, points, None)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -50,7 +50,6 @@ pub struct SaveDraftRequest {
|
|||
pub bcc: Option<String>,
|
||||
pub subject: String,
|
||||
pub body: String,
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
// ===== Request/Response Structures =====
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ use crate::shared::utils::{estimate_token_count, DbPool};
|
|||
|
||||
/// Configuration for semantic caching
|
||||
#[derive(Clone, Debug)]
|
||||
|
||||
pub struct CacheConfig {
|
||||
/// TTL for cache entries in seconds
|
||||
pub ttl: u64,
|
||||
|
|
@ -42,6 +43,7 @@ impl Default for CacheConfig {
|
|||
|
||||
/// Cached LLM response with metadata
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
|
||||
pub struct CachedResponse {
|
||||
/// The actual response text
|
||||
pub response: String,
|
||||
|
|
@ -72,6 +74,7 @@ impl std::fmt::Debug for CachedLLMProvider {
|
|||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CachedLLMProvider {
|
||||
/// The underlying LLM provider
|
||||
provider: Arc<dyn LLMProvider>,
|
||||
|
|
@ -87,6 +90,7 @@ pub struct CachedLLMProvider {
|
|||
|
||||
/// Trait for embedding services
|
||||
#[async_trait]
|
||||
|
||||
pub trait EmbeddingService: Send + Sync {
|
||||
async fn get_embedding(
|
||||
&self,
|
||||
|
|
@ -247,6 +251,7 @@ impl CachedLLMProvider {
|
|||
}
|
||||
|
||||
/// Try to get a cached response
|
||||
|
||||
async fn get_cached_response(
|
||||
&self,
|
||||
prompt: &str,
|
||||
|
|
@ -309,6 +314,7 @@ impl CachedLLMProvider {
|
|||
}
|
||||
|
||||
/// Find semantically similar cached responses
|
||||
|
||||
async fn find_similar_cached(
|
||||
&self,
|
||||
prompt: &str,
|
||||
|
|
@ -456,6 +462,7 @@ impl CachedLLMProvider {
|
|||
}
|
||||
|
||||
/// Get cache statistics
|
||||
|
||||
pub async fn get_cache_stats(
|
||||
&self,
|
||||
) -> Result<CacheStats, Box<dyn std::error::Error + Send + Sync>> {
|
||||
|
|
@ -488,6 +495,7 @@ impl CachedLLMProvider {
|
|||
}
|
||||
|
||||
/// Clear cache for a specific model or all models
|
||||
|
||||
pub async fn clear_cache(
|
||||
&self,
|
||||
model: Option<&str>,
|
||||
|
|
@ -514,6 +522,7 @@ impl CachedLLMProvider {
|
|||
|
||||
/// Cache statistics
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
|
||||
pub struct CacheStats {
|
||||
pub total_entries: usize,
|
||||
pub total_hits: u32,
|
||||
|
|
@ -645,6 +654,7 @@ impl LLMProvider for CachedLLMProvider {
|
|||
// Manual Debug implementation needed for trait objects
|
||||
|
||||
#[derive(Debug)]
|
||||
|
||||
pub struct LocalEmbeddingService {
|
||||
embedding_url: String,
|
||||
model: String,
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
#![allow(dead_code)]
|
||||
|
||||
use crate::core::config::ConfigManager;
|
||||
use crate::llm::llm_models;
|
||||
use crate::shared::state::AppState;
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
#![allow(dead_code)]
|
||||
|
||||
use super::ModelHandler;
|
||||
use regex;
|
||||
#[derive(Debug)]
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
#![allow(dead_code)]
|
||||
|
||||
use super::ModelHandler;
|
||||
#[derive(Debug)]
|
||||
pub struct GptOss120bHandler {}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
#![allow(dead_code)]
|
||||
|
||||
use super::ModelHandler;
|
||||
#[derive(Debug)]
|
||||
pub struct GptOss20bHandler;
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
#![allow(dead_code)]
|
||||
|
||||
pub mod deepseek_r3;
|
||||
pub mod gpt_oss_120b;
|
||||
pub mod gpt_oss_20b;
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
#![allow(dead_code)]
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::StreamExt;
|
||||
use log::{info, trace};
|
||||
|
|
|
|||
|
|
@ -8,6 +8,17 @@ use tokio::sync::RwLock;
|
|||
use tokio::time::{sleep, Duration};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[cfg(feature = "vectordb")]
|
||||
use crate::drive::vectordb::UserDriveVectorDB;
|
||||
#[cfg(feature = "vectordb")]
|
||||
use crate::drive::vectordb::{FileContentExtractor, FileDocument};
|
||||
#[cfg(all(feature = "vectordb", feature = "email"))]
|
||||
use crate::email::vectordb::UserEmailVectorDB;
|
||||
#[cfg(all(feature = "vectordb", feature = "email"))]
|
||||
use crate::email::vectordb::{EmailDocument, EmailEmbeddingGenerator};
|
||||
use crate::shared::utils::DbPool;
|
||||
use anyhow::Result;
|
||||
|
||||
// UserWorkspace struct for managing user workspace paths
|
||||
#[derive(Debug, Clone)]
|
||||
struct UserWorkspace {
|
||||
|
|
@ -26,14 +37,13 @@ impl UserWorkspace {
|
|||
}
|
||||
|
||||
fn get_path(&self) -> PathBuf {
|
||||
self.root.join(self.bot_id.to_string()).join(self.user_id.to_string())
|
||||
self.root
|
||||
.join(self.bot_id.to_string())
|
||||
.join(self.user_id.to_string())
|
||||
}
|
||||
}
|
||||
use crate::shared::utils::DbPool;
|
||||
|
||||
// VectorDB types are defined locally in this module
|
||||
#[cfg(feature = "vectordb")]
|
||||
use qdrant_client::prelude::*;
|
||||
|
||||
/// Indexing job status
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
|
|
@ -93,7 +103,7 @@ impl VectorDBIndexer {
|
|||
db_pool,
|
||||
work_root,
|
||||
qdrant_url,
|
||||
embedding_generator: Arc::new(EmailEmbeddingGenerator::new(llm_endpoint)),
|
||||
embedding_generator: Arc::new(EmailEmbeddingGenerator { llm_endpoint }),
|
||||
jobs: Arc::new(RwLock::new(HashMap::new())),
|
||||
running: Arc::new(RwLock::new(false)),
|
||||
interval_seconds: 300, // Run every 5 minutes
|
||||
|
|
@ -373,7 +383,7 @@ impl VectorDBIndexer {
|
|||
for file in chunk {
|
||||
// Check if file should be indexed
|
||||
let mime_type = file.mime_type.as_ref().map(|s| s.as_str()).unwrap_or("");
|
||||
if !FileContentExtractor::should_index(mime_type, file.file_size) {
|
||||
if !FileContentExtractor::should_index(&mime_type, file.file_size) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -448,7 +458,7 @@ impl VectorDBIndexer {
|
|||
&self,
|
||||
_user_id: Uuid,
|
||||
_account_id: &str,
|
||||
) -> Result<Vec<EmailDocument>> {
|
||||
) -> Result<Vec<EmailDocument>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// TODO: Implement actual email fetching from IMAP
|
||||
// This should:
|
||||
// 1. Connect to user's email account
|
||||
|
|
@ -460,7 +470,10 @@ impl VectorDBIndexer {
|
|||
}
|
||||
|
||||
/// Get unindexed files (placeholder - needs actual implementation)
|
||||
async fn get_unindexed_files(&self, _user_id: Uuid) -> Result<Vec<FileDocument>> {
|
||||
async fn get_unindexed_files(
|
||||
&self,
|
||||
_user_id: Uuid,
|
||||
) -> Result<Vec<FileDocument>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// TODO: Implement actual file fetching from drive
|
||||
// This should:
|
||||
// 1. List user's files from MinIO/S3
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue