From faeae250bcbc15775c9bde513f1daf9e6ef34e41 Mon Sep 17 00:00:00 2001 From: "Rodrigo Rodriguez (Pragmatismo)" Date: Sat, 10 Jan 2026 09:41:12 -0300 Subject: [PATCH] Add security protection module with sudo-based privilege escalation - Create installer.rs for 'botserver install protection' command - Requires root to install packages and create sudoers config - Sudoers uses exact commands (no wildcards) for security - Update all tool files (lynis, rkhunter, chkrootkit, suricata, lmd) to use sudo - Update manager.rs service management to use sudo - Add 'sudo' and 'visudo' to command_guard.rs whitelist - Update CLI with install/remove/status protection commands Security model: - Installation requires root (sudo botserver install protection) - Runtime uses sudoers NOPASSWD for specific commands only - No wildcards in sudoers - exact command specifications - Tools run on host system, not in containers --- Cargo.toml | 8 +- src/billing/middleware.rs | 14 +- src/canvas/mod.rs | 699 ++++++++++++++++++++++++++ src/core/package_manager/cli.rs | 74 +++ src/core/shared/state.rs | 15 + src/core/urls.rs | 60 +++ src/embedded_ui.rs | 111 ++++ src/learn/mod.rs | 2 - src/lib.rs | 3 + src/main.rs | 107 +++- src/player/mod.rs | 208 ++++++++ src/security/auth.rs | 153 ++++++ src/security/auth_provider.rs | 591 ++++++++++++++++++++++ src/security/command_guard.rs | 10 + src/security/mod.rs | 24 +- src/security/protection/api.rs | 403 +++++++++++++++ src/security/protection/chkrootkit.rs | 293 +++++++++++ src/security/protection/installer.rs | 597 ++++++++++++++++++++++ src/security/protection/lmd.rs | 481 ++++++++++++++++++ src/security/protection/lynis.rs | 273 ++++++++++ src/security/protection/manager.rs | 621 +++++++++++++++++++++++ src/security/protection/mod.rs | 12 + src/security/protection/rkhunter.rs | 320 ++++++++++++ src/security/protection/suricata.rs | 385 ++++++++++++++ src/security/rbac_middleware.rs | 173 ++++++- src/settings/mod.rs | 2 + src/settings/security_admin.rs | 405 +++++++++++++++ src/social/mod.rs | 30 ++ src/video/mod.rs | 1 - src/workspaces/mod.rs | 217 ++++++++ 30 files changed, 6260 insertions(+), 32 deletions(-) create mode 100644 src/canvas/mod.rs create mode 100644 src/embedded_ui.rs create mode 100644 src/player/mod.rs create mode 100644 src/security/auth_provider.rs create mode 100644 src/security/protection/api.rs create mode 100644 src/security/protection/chkrootkit.rs create mode 100644 src/security/protection/installer.rs create mode 100644 src/security/protection/lmd.rs create mode 100644 src/security/protection/lynis.rs create mode 100644 src/security/protection/manager.rs create mode 100644 src/security/protection/mod.rs create mode 100644 src/security/protection/rkhunter.rs create mode 100644 src/security/protection/suricata.rs create mode 100644 src/settings/security_admin.rs diff --git a/Cargo.toml b/Cargo.toml index 2d29a6e32..b2feed6a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,7 +63,7 @@ msteams = [] # ===== PRODUCTIVITY FEATURES ===== chat = [] -drive = ["dep:aws-config", "dep:aws-sdk-s3", "dep:pdf-extract", "dep:zip", "dep:downloader", "dep:mime_guess", "dep:flate2", "dep:tar"] +drive = ["dep:aws-config", "dep:aws-sdk-s3", "dep:pdf-extract", "dep:zip", "dep:downloader", "dep:flate2", "dep:tar"] tasks = ["dep:cron"] calendar = [] meet = ["dep:livekit"] @@ -183,7 +183,7 @@ pdf-extract = { version = "0.10.0", optional = true } quick-xml = { version = "0.37", features = ["serialize"] } zip = { version = "2.2", optional = true } downloader = { version = "0.2", optional = true } -mime_guess = { version = "2.0", optional = true } + flate2 = { version = "1.0", optional = true } tar = { version = "0.4", optional = true } @@ -247,6 +247,10 @@ rss = "2.0" # HTML parsing/web scraping scraper = "0.25" walkdir = "2.5.0" + +# Embedded static files (UI fallback when no external folder) +rust-embed = "8.5" +mime_guess = "2.0" hyper-util = { version = "0.1.19", features = ["client-legacy", "tokio"] } http-body-util = "0.1.3" diff --git a/src/billing/middleware.rs b/src/billing/middleware.rs index 535a7b794..634b6a5b8 100644 --- a/src/billing/middleware.rs +++ b/src/billing/middleware.rs @@ -1,7 +1,7 @@ use axum::{ body::Body, extract::State, - http::{Request, StatusCode}, + http::{header::HeaderValue, Request, StatusCode}, middleware::Next, response::{IntoResponse, Response}, Json, @@ -89,14 +89,14 @@ pub async fn quota_middleware( let headers = response.headers_mut(); headers.insert( "X-Quota-Warning", - message.parse().unwrap_or_else(|_| "quota warning".parse().unwrap()), + message.parse().unwrap_or_else(|_| HeaderValue::from_static("quota warning")), ); headers.insert( "X-Quota-Usage-Percent", percentage .to_string() .parse() - .unwrap_or_else(|_| "0".parse().unwrap()), + .unwrap_or_else(|_| HeaderValue::from_static("0")), ); response @@ -148,14 +148,14 @@ pub async fn api_rate_limit_middleware( let headers = response.headers_mut(); headers.insert( "X-RateLimit-Warning", - message.parse().unwrap_or_else(|_| "rate limit warning".parse().unwrap()), + message.parse().unwrap_or_else(|_| HeaderValue::from_static("rate limit warning")), ); headers.insert( "X-RateLimit-Usage-Percent", percentage .to_string() .parse() - .unwrap_or_else(|_| "0".parse().unwrap()), + .unwrap_or_else(|_| HeaderValue::from_static("0")), ); response } @@ -212,14 +212,14 @@ pub async fn message_quota_middleware( let headers = response.headers_mut(); headers.insert( "X-Message-Quota-Warning", - message.parse().unwrap_or_else(|_| "message quota warning".parse().unwrap()), + message.parse().unwrap_or_else(|_| HeaderValue::from_static("message quota warning")), ); headers.insert( "X-Message-Quota-Usage-Percent", percentage .to_string() .parse() - .unwrap_or_else(|_| "0".parse().unwrap()), + .unwrap_or_else(|_| HeaderValue::from_static("0")), ); response } diff --git a/src/canvas/mod.rs b/src/canvas/mod.rs new file mode 100644 index 000000000..58c883a19 --- /dev/null +++ b/src/canvas/mod.rs @@ -0,0 +1,699 @@ +use axum::{ + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, + routing::{get, post, put}, + Json, Router, +}; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; +use uuid::Uuid; + +use crate::shared::state::AppState; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Canvas { + pub id: Uuid, + pub organization_id: Uuid, + pub name: String, + pub description: Option, + pub width: u32, + pub height: u32, + pub background_color: String, + pub elements: Vec, + pub created_by: Uuid, + pub created_at: DateTime, + pub updated_at: DateTime, + pub is_public: bool, + pub collaborators: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CanvasElement { + pub id: Uuid, + pub element_type: ElementType, + pub x: f64, + pub y: f64, + pub width: f64, + pub height: f64, + pub rotation: f64, + pub properties: ElementProperties, + pub z_index: i32, + pub locked: bool, + pub created_by: Uuid, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum ElementType { + Rectangle, + Ellipse, + Line, + Arrow, + FreehandPath, + Text, + Image, + Sticky, + Frame, + Connector, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ElementProperties { + pub fill_color: Option, + pub stroke_color: Option, + pub stroke_width: Option, + pub opacity: Option, + pub text: Option, + pub font_size: Option, + pub font_family: Option, + pub text_align: Option, + pub image_url: Option, + pub path_data: Option, + pub corner_radius: Option, + pub start_arrow: Option, + pub end_arrow: Option, +} + +impl Default for ElementProperties { + fn default() -> Self { + Self { + fill_color: Some("#ffffff".to_string()), + stroke_color: Some("#000000".to_string()), + stroke_width: Some(2.0), + opacity: Some(1.0), + text: None, + font_size: Some(16.0), + font_family: Some("Inter".to_string()), + text_align: Some("left".to_string()), + image_url: None, + path_data: None, + corner_radius: Some(0.0), + start_arrow: None, + end_arrow: None, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CanvasSummary { + pub id: Uuid, + pub name: String, + pub description: Option, + pub thumbnail_url: Option, + pub element_count: usize, + pub created_at: DateTime, + pub updated_at: DateTime, + pub is_public: bool, +} + +#[derive(Debug, Deserialize)] +pub struct CreateCanvasRequest { + pub name: String, + pub description: Option, + pub width: Option, + pub height: Option, + pub background_color: Option, +} + +#[derive(Debug, Deserialize)] +pub struct UpdateCanvasRequest { + pub name: Option, + pub description: Option, + pub width: Option, + pub height: Option, + pub background_color: Option, + pub is_public: Option, +} + +#[derive(Debug, Deserialize)] +pub struct CreateElementRequest { + pub element_type: ElementType, + pub x: f64, + pub y: f64, + pub width: f64, + pub height: f64, + pub rotation: Option, + pub properties: Option, + pub z_index: Option, +} + +#[derive(Debug, Deserialize)] +pub struct UpdateElementRequest { + pub x: Option, + pub y: Option, + pub width: Option, + pub height: Option, + pub rotation: Option, + pub properties: Option, + pub z_index: Option, + pub locked: Option, +} + +#[derive(Debug, Deserialize)] +pub struct ExportRequest { + pub format: ExportFormat, + pub scale: Option, + pub background: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum ExportFormat { + Png, + Svg, + Pdf, + Json, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ExportResponse { + pub format: ExportFormat, + pub url: Option, + pub data: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CollaborationSession { + pub canvas_id: Uuid, + pub user_id: Uuid, + pub cursor_x: f64, + pub cursor_y: f64, + pub selection: Vec, + pub connected_at: DateTime, +} + +pub struct CanvasService { + canvases: Arc>>, +} + +impl CanvasService { + pub fn new() -> Self { + Self { + canvases: Arc::new(RwLock::new(HashMap::new())), + } + } + + pub async fn list_canvases(&self, org_id: Uuid) -> Vec { + let canvases = self.canvases.read().await; + canvases + .values() + .filter(|c| c.organization_id == org_id) + .map(|c| CanvasSummary { + id: c.id, + name: c.name.clone(), + description: c.description.clone(), + thumbnail_url: None, + element_count: c.elements.len(), + created_at: c.created_at, + updated_at: c.updated_at, + is_public: c.is_public, + }) + .collect() + } + + pub async fn create_canvas( + &self, + org_id: Uuid, + user_id: Uuid, + req: CreateCanvasRequest, + ) -> Canvas { + let now = Utc::now(); + let canvas = Canvas { + id: Uuid::new_v4(), + organization_id: org_id, + name: req.name, + description: req.description, + width: req.width.unwrap_or(1920), + height: req.height.unwrap_or(1080), + background_color: req.background_color.unwrap_or_else(|| "#ffffff".to_string()), + elements: vec![], + created_by: user_id, + created_at: now, + updated_at: now, + is_public: false, + collaborators: vec![user_id], + }; + + let mut canvases = self.canvases.write().await; + canvases.insert(canvas.id, canvas.clone()); + canvas + } + + pub async fn get_canvas(&self, canvas_id: Uuid) -> Option { + let canvases = self.canvases.read().await; + canvases.get(&canvas_id).cloned() + } + + pub async fn update_canvas( + &self, + canvas_id: Uuid, + req: UpdateCanvasRequest, + ) -> Option { + let mut canvases = self.canvases.write().await; + if let Some(canvas) = canvases.get_mut(&canvas_id) { + if let Some(name) = req.name { + canvas.name = name; + } + if let Some(desc) = req.description { + canvas.description = Some(desc); + } + if let Some(width) = req.width { + canvas.width = width; + } + if let Some(height) = req.height { + canvas.height = height; + } + if let Some(bg) = req.background_color { + canvas.background_color = bg; + } + if let Some(public) = req.is_public { + canvas.is_public = public; + } + canvas.updated_at = Utc::now(); + return Some(canvas.clone()); + } + None + } + + pub async fn delete_canvas(&self, canvas_id: Uuid) -> bool { + let mut canvases = self.canvases.write().await; + canvases.remove(&canvas_id).is_some() + } + + pub async fn add_element( + &self, + canvas_id: Uuid, + user_id: Uuid, + req: CreateElementRequest, + ) -> Option { + let mut canvases = self.canvases.write().await; + if let Some(canvas) = canvases.get_mut(&canvas_id) { + let now = Utc::now(); + let element = CanvasElement { + id: Uuid::new_v4(), + element_type: req.element_type, + x: req.x, + y: req.y, + width: req.width, + height: req.height, + rotation: req.rotation.unwrap_or(0.0), + properties: req.properties.unwrap_or_default(), + z_index: req.z_index.unwrap_or(canvas.elements.len() as i32), + locked: false, + created_by: user_id, + created_at: now, + updated_at: now, + }; + canvas.elements.push(element.clone()); + canvas.updated_at = now; + return Some(element); + } + None + } + + pub async fn update_element( + &self, + canvas_id: Uuid, + element_id: Uuid, + req: UpdateElementRequest, + ) -> Option { + let mut canvases = self.canvases.write().await; + if let Some(canvas) = canvases.get_mut(&canvas_id) { + if let Some(element) = canvas.elements.iter_mut().find(|e| e.id == element_id) { + if let Some(x) = req.x { + element.x = x; + } + if let Some(y) = req.y { + element.y = y; + } + if let Some(width) = req.width { + element.width = width; + } + if let Some(height) = req.height { + element.height = height; + } + if let Some(rotation) = req.rotation { + element.rotation = rotation; + } + if let Some(props) = req.properties { + element.properties = props; + } + if let Some(z) = req.z_index { + element.z_index = z; + } + if let Some(locked) = req.locked { + element.locked = locked; + } + element.updated_at = Utc::now(); + canvas.updated_at = Utc::now(); + return Some(element.clone()); + } + } + None + } + + pub async fn delete_element(&self, canvas_id: Uuid, element_id: Uuid) -> bool { + let mut canvases = self.canvases.write().await; + if let Some(canvas) = canvases.get_mut(&canvas_id) { + let len_before = canvas.elements.len(); + canvas.elements.retain(|e| e.id != element_id); + if canvas.elements.len() < len_before { + canvas.updated_at = Utc::now(); + return true; + } + } + false + } + + pub async fn export_canvas( + &self, + canvas_id: Uuid, + req: ExportRequest, + ) -> Option { + let canvases = self.canvases.read().await; + let canvas = canvases.get(&canvas_id)?; + + match req.format { + ExportFormat::Json => { + let json = serde_json::to_string_pretty(canvas).ok()?; + Some(ExportResponse { + format: ExportFormat::Json, + url: None, + data: Some(json), + }) + } + ExportFormat::Svg => { + let svg = generate_svg(canvas, req.background.unwrap_or(true)); + Some(ExportResponse { + format: ExportFormat::Svg, + url: None, + data: Some(svg), + }) + } + _ => Some(ExportResponse { + format: req.format, + url: Some(format!("/api/canvas/{}/export/file", canvas_id)), + data: None, + }), + } + } +} + +impl Default for CanvasService { + fn default() -> Self { + Self::new() + } +} + +fn generate_svg(canvas: &Canvas, include_background: bool) -> String { + let mut svg = format!( + r#""#, + canvas.width, canvas.height, canvas.width, canvas.height + ); + + if include_background { + svg.push_str(&format!( + r#""#, + canvas.background_color + )); + } + + for element in &canvas.elements { + let transform = if element.rotation != 0.0 { + format!( + r#" transform="rotate({} {} {})""#, + element.rotation, + element.x + element.width / 2.0, + element.y + element.height / 2.0 + ) + } else { + String::new() + }; + + let fill = element + .properties + .fill_color + .as_deref() + .unwrap_or("transparent"); + let stroke = element + .properties + .stroke_color + .as_deref() + .unwrap_or("none"); + let stroke_width = element.properties.stroke_width.unwrap_or(1.0); + let opacity = element.properties.opacity.unwrap_or(1.0); + + match element.element_type { + ElementType::Rectangle => { + let radius = element.properties.corner_radius.unwrap_or(0.0); + svg.push_str(&format!( + r#""#, + element.x, element.y, element.width, element.height, + radius, fill, stroke, stroke_width, opacity, transform + )); + } + ElementType::Ellipse => { + svg.push_str(&format!( + r#""#, + element.x + element.width / 2.0, + element.y + element.height / 2.0, + element.width / 2.0, + element.height / 2.0, + fill, stroke, stroke_width, opacity, transform + )); + } + ElementType::Text => { + let text = element.properties.text.as_deref().unwrap_or(""); + let font_size = element.properties.font_size.unwrap_or(16.0); + let font_family = element + .properties + .font_family + .as_deref() + .unwrap_or("sans-serif"); + svg.push_str(&format!( + r#" + {} + "#, + element.x, element.y + font_size, font_size, font_family, + fill, opacity, transform, text + )); + } + ElementType::FreehandPath => { + if let Some(path_data) = &element.properties.path_data { + svg.push_str(&format!( + r#""#, + path_data, stroke, stroke_width, opacity, transform + )); + } + } + ElementType::Line | ElementType::Arrow => { + let marker = if element.element_type == ElementType::Arrow { + r#" marker-end="url(#arrowhead)""# + } else { + "" + }; + svg.push_str(&format!( + r#""#, + element.x, element.y, + element.x + element.width, element.y + element.height, + stroke, stroke_width, opacity, marker, transform + )); + } + _ => {} + } + } + + svg.push_str(""); + svg +} + +#[derive(Debug, Serialize)] +pub struct CanvasError { + pub error: String, + pub code: String, +} + +impl IntoResponse for CanvasError { + fn into_response(self) -> axum::response::Response { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({"error": self.error, "code": self.code})), + ) + .into_response() + } +} + +async fn list_canvases( + State(_state): State>, +) -> Result>, CanvasError> { + let service = CanvasService::new(); + let org_id = Uuid::nil(); + let canvases = service.list_canvases(org_id).await; + Ok(Json(canvases)) +} + +async fn create_canvas( + State(_state): State>, + Json(req): Json, +) -> Result, CanvasError> { + let service = CanvasService::new(); + let org_id = Uuid::nil(); + let user_id = Uuid::nil(); + let canvas = service.create_canvas(org_id, user_id, req).await; + Ok(Json(canvas)) +} + +async fn get_canvas( + State(_state): State>, + Path(canvas_id): Path, +) -> Result, CanvasError> { + let service = CanvasService::new(); + let canvas = service.get_canvas(canvas_id).await.ok_or_else(|| CanvasError { + error: "Canvas not found".to_string(), + code: "CANVAS_NOT_FOUND".to_string(), + })?; + Ok(Json(canvas)) +} + +async fn update_canvas( + State(_state): State>, + Path(canvas_id): Path, + Json(req): Json, +) -> Result, CanvasError> { + let service = CanvasService::new(); + let canvas = service + .update_canvas(canvas_id, req) + .await + .ok_or_else(|| CanvasError { + error: "Canvas not found".to_string(), + code: "CANVAS_NOT_FOUND".to_string(), + })?; + Ok(Json(canvas)) +} + +async fn delete_canvas( + State(_state): State>, + Path(canvas_id): Path, +) -> Result { + let service = CanvasService::new(); + if service.delete_canvas(canvas_id).await { + Ok(StatusCode::NO_CONTENT) + } else { + Err(CanvasError { + error: "Canvas not found".to_string(), + code: "CANVAS_NOT_FOUND".to_string(), + }) + } +} + +async fn list_elements( + State(_state): State>, + Path(canvas_id): Path, +) -> Result>, CanvasError> { + let service = CanvasService::new(); + let canvas = service.get_canvas(canvas_id).await.ok_or_else(|| CanvasError { + error: "Canvas not found".to_string(), + code: "CANVAS_NOT_FOUND".to_string(), + })?; + Ok(Json(canvas.elements)) +} + +async fn create_element( + State(_state): State>, + Path(canvas_id): Path, + Json(req): Json, +) -> Result, CanvasError> { + let service = CanvasService::new(); + let user_id = Uuid::nil(); + let element = service + .add_element(canvas_id, user_id, req) + .await + .ok_or_else(|| CanvasError { + error: "Canvas not found".to_string(), + code: "CANVAS_NOT_FOUND".to_string(), + })?; + Ok(Json(element)) +} + +async fn update_element( + State(_state): State>, + Path((canvas_id, element_id)): Path<(Uuid, Uuid)>, + Json(req): Json, +) -> Result, CanvasError> { + let service = CanvasService::new(); + let element = service + .update_element(canvas_id, element_id, req) + .await + .ok_or_else(|| CanvasError { + error: "Element not found".to_string(), + code: "ELEMENT_NOT_FOUND".to_string(), + })?; + Ok(Json(element)) +} + +async fn delete_element( + State(_state): State>, + Path((canvas_id, element_id)): Path<(Uuid, Uuid)>, +) -> Result { + let service = CanvasService::new(); + if service.delete_element(canvas_id, element_id).await { + Ok(StatusCode::NO_CONTENT) + } else { + Err(CanvasError { + error: "Element not found".to_string(), + code: "ELEMENT_NOT_FOUND".to_string(), + }) + } +} + +async fn export_canvas( + State(_state): State>, + Path(canvas_id): Path, + Json(req): Json, +) -> Result, CanvasError> { + let service = CanvasService::new(); + let response = service + .export_canvas(canvas_id, req) + .await + .ok_or_else(|| CanvasError { + error: "Canvas not found".to_string(), + code: "CANVAS_NOT_FOUND".to_string(), + })?; + Ok(Json(response)) +} + +async fn get_collaboration_info( + State(_state): State>, + Path(canvas_id): Path, +) -> Result>, CanvasError> { + let _ = canvas_id; + Ok(Json(vec![])) +} + +pub fn configure_canvas_routes() -> Router> { + Router::new() + .route("/api/canvas", get(list_canvases).post(create_canvas)) + .route( + "/api/canvas/:canvas_id", + get(get_canvas).put(update_canvas).delete(delete_canvas), + ) + .route( + "/api/canvas/:canvas_id/elements", + get(list_elements).post(create_element), + ) + .route( + "/api/canvas/:canvas_id/elements/:element_id", + put(update_element).delete(delete_element), + ) + .route("/api/canvas/:canvas_id/export", post(export_canvas)) + .route( + "/api/canvas/:canvas_id/collaborate", + get(get_collaboration_info), + ) +} diff --git a/src/core/package_manager/cli.rs b/src/core/package_manager/cli.rs index b4775ea9b..249612e24 100644 --- a/src/core/package_manager/cli.rs +++ b/src/core/package_manager/cli.rs @@ -1,6 +1,7 @@ use crate::core::secrets::{SecretPaths, SecretsManager}; use crate::package_manager::{get_all_components, InstallMode, PackageManager}; use crate::security::command_guard::SafeCommand; +use crate::security::protection::{ProtectionInstaller, VerifyResult}; use anyhow::Result; use rand::Rng; use std::collections::HashMap; @@ -87,6 +88,12 @@ pub async fn run() -> Result<()> { return Ok(()); } let component = &args[2]; + + if component == "protection" { + install_protection()?; + return Ok(()); + } + let mode = if args.contains(&"--container".to_string()) { InstallMode::Container } else { @@ -111,6 +118,12 @@ pub async fn run() -> Result<()> { return Ok(()); } let component = &args[2]; + + if component == "protection" { + remove_protection()?; + return Ok(()); + } + let mode = if args.contains(&"--container".to_string()) { InstallMode::Container } else { @@ -153,6 +166,13 @@ pub async fn run() -> Result<()> { return Ok(()); } let component = &args[2]; + + if component == "protection" { + let result = verify_protection(); + result.print(); + return Ok(()); + } + let mode = if args.contains(&"--container".to_string()) { InstallMode::Container } else { @@ -271,6 +291,11 @@ fn print_usage() { println!(" --container Use container mode (LXC)"); println!(" --tenant Specify tenant name"); println!(); + println!("Security Protection (requires root):"); + println!(" sudo botserver install protection Install security tools + sudoers"); + println!(" sudo botserver remove protection Remove sudoers configuration"); + println!(" botserver status protection Check protection tools status"); + println!(); println!("Vault subcommands:"); println!(" vault migrate [.env] Migrate .env secrets to Vault"); println!(" vault put k=v Store secrets in Vault"); @@ -279,6 +304,55 @@ fn print_usage() { println!(" vault health Check Vault health"); } +fn install_protection() -> Result<()> { + let installer = ProtectionInstaller::new()?; + + if !ProtectionInstaller::check_root() { + eprintln!("Error: This command requires root privileges."); + eprintln!(); + eprintln!("Run with: sudo botserver install protection"); + return Ok(()); + } + + println!("Installing Security Protection Tools..."); + println!(); + println!("This will:"); + println!(" 1. Install security packages (lynis, rkhunter, chkrootkit, suricata, clamav)"); + println!(" 2. Install Linux Malware Detect (LMD)"); + println!(" 3. Create sudoers configuration for runtime execution"); + println!(" 4. Update security databases"); + println!(); + + let result = installer.install()?; + result.print(); + + Ok(()) +} + +fn remove_protection() -> Result<()> { + let installer = ProtectionInstaller::new()?; + + if !ProtectionInstaller::check_root() { + eprintln!("Error: This command requires root privileges."); + eprintln!(); + eprintln!("Run with: sudo botserver remove protection"); + return Ok(()); + } + + println!("Removing Security Protection Configuration..."); + println!(); + + let result = installer.uninstall()?; + result.print(); + + Ok(()) +} + +fn verify_protection() -> VerifyResult { + let installer = ProtectionInstaller::default(); + installer.verify() +} + fn print_vault_usage() { println!("Vault Secret Management"); println!(); diff --git a/src/core/shared/state.rs b/src/core/shared/state.rs index 814339d53..7f39d6e2e 100644 --- a/src/core/shared/state.rs +++ b/src/core/shared/state.rs @@ -7,6 +7,9 @@ use crate::core::session::SessionManager; use crate::core::shared::analytics::MetricsCollector; use crate::project::ProjectService; use crate::legal::LegalService; +use crate::security::auth_provider::AuthProviderRegistry; +use crate::security::jwt::JwtManager; +use crate::security::rbac_middleware::RbacManager; #[cfg(all(test, feature = "directory"))] use crate::core::shared::test_utils::create_mock_auth_service; #[cfg(all(test, feature = "llm"))] @@ -351,6 +354,9 @@ pub struct AppState { pub task_manifests: Arc>>, pub project_service: Arc>, pub legal_service: Arc>, + pub jwt_manager: Option>, + pub auth_provider_registry: Option>, + pub rbac_manager: Option>, } impl Clone for AppState { @@ -385,6 +391,9 @@ impl Clone for AppState { task_manifests: Arc::clone(&self.task_manifests), project_service: Arc::clone(&self.project_service), legal_service: Arc::clone(&self.legal_service), + jwt_manager: self.jwt_manager.clone(), + auth_provider_registry: self.auth_provider_registry.clone(), + rbac_manager: self.rbac_manager.clone(), } } } @@ -427,6 +436,9 @@ impl std::fmt::Debug for AppState { .field("extensions", &self.extensions) .field("attendant_broadcast", &self.attendant_broadcast.is_some()) .field("task_progress_broadcast", &self.task_progress_broadcast.is_some()) + .field("jwt_manager", &self.jwt_manager.is_some()) + .field("auth_provider_registry", &self.auth_provider_registry.is_some()) + .field("rbac_manager", &self.rbac_manager.is_some()) .finish() } } @@ -575,6 +587,9 @@ impl Default for AppState { task_manifests: Arc::new(std::sync::RwLock::new(HashMap::new())), project_service: Arc::new(RwLock::new(crate::project::ProjectService::new())), legal_service: Arc::new(RwLock::new(crate::legal::LegalService::new())), + jwt_manager: None, + auth_provider_registry: None, + rbac_manager: None, } } } diff --git a/src/core/urls.rs b/src/core/urls.rs index 6505ae43d..9d0698eeb 100644 --- a/src/core/urls.rs +++ b/src/core/urls.rs @@ -392,6 +392,66 @@ impl ApiUrls { pub const SOURCES_KB_REINDEX: &'static str = "/api/ui/sources/kb/reindex"; pub const SOURCES_KB_STATS: &'static str = "/api/ui/sources/kb/stats"; + // Workspaces - JSON APIs + pub const WORKSPACES: &'static str = "/api/workspaces"; + pub const WORKSPACE_BY_ID: &'static str = "/api/workspaces/:workspace_id"; + pub const WORKSPACE_PAGES: &'static str = "/api/workspaces/:workspace_id/pages"; + pub const WORKSPACE_MEMBERS: &'static str = "/api/workspaces/:workspace_id/members"; + pub const WORKSPACE_MEMBER: &'static str = "/api/workspaces/:workspace_id/members/:user_id"; + pub const WORKSPACE_SEARCH: &'static str = "/api/workspaces/:workspace_id/search"; + pub const WORKSPACE_COMMANDS: &'static str = "/api/workspaces/commands"; + pub const PAGE_BY_ID: &'static str = "/api/pages/:page_id"; + + // Project - JSON APIs + pub const PROJECTS: &'static str = "/projects"; + pub const PROJECT_BY_ID: &'static str = "/projects/:project_id"; + pub const PROJECT_TASKS: &'static str = "/projects/:project_id/tasks"; + pub const PROJECT_GANTT: &'static str = "/projects/:project_id/gantt"; + pub const PROJECT_TIMELINE: &'static str = "/projects/:project_id/timeline"; + pub const PROJECT_CRITICAL_PATH: &'static str = "/projects/:project_id/critical-path"; + pub const PROJECT_TASK_PROGRESS: &'static str = "/tasks/:task_id/progress"; + pub const PROJECT_TASK_DEPENDENCIES: &'static str = "/tasks/:task_id/dependencies"; + pub const PROJECT_TASK: &'static str = "/tasks/:task_id"; + + // Goals (OKR) - JSON APIs + pub const GOALS_OBJECTIVES: &'static str = "/api/goals/objectives"; + pub const GOALS_OBJECTIVE_BY_ID: &'static str = "/api/goals/objectives/:id"; + pub const GOALS_KEY_RESULTS: &'static str = "/api/goals/objectives/:id/key-results"; + pub const GOALS_KEY_RESULT_BY_ID: &'static str = "/api/goals/key-results/:id"; + pub const GOALS_CHECK_IN: &'static str = "/api/goals/key-results/:id/check-in"; + pub const GOALS_HISTORY: &'static str = "/api/goals/key-results/:id/history"; + pub const GOALS_DASHBOARD: &'static str = "/api/goals/dashboard"; + pub const GOALS_ALIGNMENT: &'static str = "/api/goals/alignment"; + pub const GOALS_AI_SUGGEST: &'static str = "/api/goals/ai/suggest"; + + // Security Admin - JSON APIs + pub const SECURITY_OVERVIEW: &'static str = "/api/security/overview"; + pub const SECURITY_SCAN: &'static str = "/api/security/scan"; + pub const SECURITY_TLS: &'static str = "/api/security/tls"; + pub const SECURITY_RATE_LIMIT: &'static str = "/api/security/rate-limit"; + pub const SECURITY_CORS: &'static str = "/api/security/cors"; + pub const SECURITY_AUDIT: &'static str = "/api/security/audit"; + pub const SECURITY_API_KEYS: &'static str = "/api/security/api-keys"; + pub const SECURITY_API_KEY_BY_ID: &'static str = "/api/security/api-keys/:key_id"; + pub const SECURITY_MFA: &'static str = "/api/security/mfa"; + pub const SECURITY_SESSIONS: &'static str = "/api/security/sessions"; + pub const SECURITY_SESSION_BY_ID: &'static str = "/api/security/sessions/:session_id"; + pub const SECURITY_USER_SESSIONS: &'static str = "/api/security/users/:user_id/sessions"; + pub const SECURITY_PASSWORD_POLICY: &'static str = "/api/security/password-policy"; + + // Player - JSON APIs + pub const PLAYER_FILE: &'static str = "/api/player/:bot_id/file/*path"; + pub const PLAYER_STREAM: &'static str = "/api/player/:bot_id/stream/*path"; + pub const PLAYER_THUMBNAIL: &'static str = "/api/player/:bot_id/thumbnail/*path"; + + // Canvas - JSON APIs + pub const CANVAS_LIST: &'static str = "/api/canvas"; + pub const CANVAS_BY_ID: &'static str = "/api/canvas/:id"; + pub const CANVAS_ELEMENTS: &'static str = "/api/canvas/:id/elements"; + pub const CANVAS_ELEMENT_BY_ID: &'static str = "/api/canvas/:id/elements/:element_id"; + pub const CANVAS_EXPORT: &'static str = "/api/canvas/:id/export"; + pub const CANVAS_COLLABORATE: &'static str = "/api/canvas/:id/collaborate"; + // WebSocket endpoints pub const WS: &'static str = "/ws"; pub const WS_MEET: &'static str = "/ws/meet"; diff --git a/src/embedded_ui.rs b/src/embedded_ui.rs new file mode 100644 index 000000000..6bacfaba0 --- /dev/null +++ b/src/embedded_ui.rs @@ -0,0 +1,111 @@ +use axum::{ + body::Body, + http::{header, Request, Response, StatusCode}, + routing::get, + Router, +}; +use rust_embed::Embed; +use std::path::Path; + +#[derive(Embed)] +#[folder = "../botui/ui/suite/"] +#[prefix = ""] +struct EmbeddedUi; + +fn get_mime_type(path: &str) -> &'static str { + let ext = Path::new(path) + .extension() + .and_then(|e| e.to_str()) + .unwrap_or(""); + + match ext { + "html" | "htm" => "text/html; charset=utf-8", + "css" => "text/css; charset=utf-8", + "js" | "mjs" => "application/javascript; charset=utf-8", + "json" => "application/json; charset=utf-8", + "png" => "image/png", + "jpg" | "jpeg" => "image/jpeg", + "gif" => "image/gif", + "svg" => "image/svg+xml", + "ico" => "image/x-icon", + "woff" => "font/woff", + "woff2" => "font/woff2", + "ttf" => "font/ttf", + "otf" => "font/otf", + "eot" => "application/vnd.ms-fontobject", + "webp" => "image/webp", + "mp4" => "video/mp4", + "webm" => "video/webm", + "mp3" => "audio/mpeg", + "wav" => "audio/wav", + "ogg" => "audio/ogg", + "pdf" => "application/pdf", + "xml" => "application/xml", + "txt" => "text/plain; charset=utf-8", + "md" => "text/markdown; charset=utf-8", + "wasm" => "application/wasm", + _ => "application/octet-stream", + } +} + +async fn serve_embedded_file(req: Request) -> Response { + let path = req.uri().path().trim_start_matches('/'); + + let file_path = if path.is_empty() || path == "/" { + "index.html" + } else { + path + }; + + let try_paths = [ + file_path.to_string(), + format!("{}/index.html", file_path.trim_end_matches('/')), + format!("{}.html", file_path), + ]; + + for try_path in &try_paths { + if let Some(content) = EmbeddedUi::get(try_path) { + let mime = get_mime_type(try_path); + + return Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, mime) + .header(header::CACHE_CONTROL, "public, max-age=3600") + .body(Body::from(content.data.into_owned())) + .unwrap_or_else(|_| { + Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::from("Internal Server Error")) + .unwrap() + }); + } + } + + Response::builder() + .status(StatusCode::NOT_FOUND) + .header(header::CONTENT_TYPE, "text/html; charset=utf-8") + .body(Body::from( + r#" + +404 Not Found + +

404 - Not Found

+

The requested file was not found in embedded UI.

+

Go to Home

+ +"#, + )) + .unwrap() +} + +pub fn embedded_ui_router() -> Router { + Router::new().fallback(get(serve_embedded_file)) +} + +pub fn has_embedded_ui() -> bool { + EmbeddedUi::get("index.html").is_some() +} + +pub fn list_embedded_files() -> Vec { + EmbeddedUi::iter().map(|f| f.to_string()).collect() +} diff --git a/src/learn/mod.rs b/src/learn/mod.rs index 22cb955de..f5cc1475e 100644 --- a/src/learn/mod.rs +++ b/src/learn/mod.rs @@ -2252,8 +2252,6 @@ pub fn configure_learn_routes() -> Router> { // Statistics .route("/api/learn/stats", get(get_statistics)) .route("/api/learn/stats/user", get(get_user_stats)) - // UI - .route("/suite/learn/learn.html", get(learn_ui)) } /// Simplified configure function for module registration diff --git a/src/lib.rs b/src/lib.rs index 85efc0f1c..6cc87cb8c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,15 @@ pub mod auto_task; pub mod basic; pub mod billing; +pub mod canvas; pub mod channels; pub mod contacts; pub mod core; pub mod dashboards; +pub mod embedded_ui; pub mod maintenance; pub mod multimodal; +pub mod player; pub mod search; pub mod security; diff --git a/src/main.rs b/src/main.rs index e1cc74e01..8fc2dd37d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -60,10 +60,12 @@ async fn ensure_vendor_files_in_minio(drive: &aws_sdk_s3::Client) { } use botserver::security::{ - auth_middleware, create_cors_layer, create_rate_limit_layer, create_security_headers_layer, + create_cors_layer, create_rate_limit_layer, create_security_headers_layer, request_id_middleware, security_headers_middleware, set_cors_allowed_origins, set_global_panic_hook, AuthConfig, HttpRateLimitConfig, PanicHandlerConfig, - SecurityHeadersConfig, + SecurityHeadersConfig, AuthProviderBuilder, ApiKeyAuthProvider, JwtConfig, JwtKey, + JwtManager, RbacManager, RbacConfig, AuthMiddlewareState, + build_default_route_permissions, }; use botlib::SystemLimits; @@ -225,16 +227,87 @@ async fn run_axum_server( let cors = create_cors_layer(); // Create auth config for protected routes - // TODO: Re-enable auth for production - currently disabled for development - let auth_config = Arc::new(AuthConfig::default() + let auth_config = Arc::new(AuthConfig::from_env() .add_anonymous_path("/health") .add_anonymous_path("/healthz") - .add_anonymous_path("/api") // Disable auth for all API routes during development + .add_anonymous_path("/api/health") + .add_anonymous_path("/api/product") .add_anonymous_path("/ws") .add_anonymous_path("/auth") .add_public_path("/static") .add_public_path("/favicon.ico") - .add_public_path("/apps")); // Apps are public - no auth required + .add_public_path("/suite") + .add_public_path("/themes")); + + // Initialize JWT Manager for token validation + let jwt_secret = std::env::var("JWT_SECRET") + .unwrap_or_else(|_| { + warn!("JWT_SECRET not set, using default development secret - DO NOT USE IN PRODUCTION"); + "dev-secret-key-change-in-production-minimum-32-chars".to_string() + }); + + let jwt_config = JwtConfig::default(); + let jwt_key = JwtKey::from_secret(&jwt_secret); + let jwt_manager = match JwtManager::new(jwt_config, jwt_key) { + Ok(manager) => { + info!("JWT Manager initialized successfully"); + Some(Arc::new(manager)) + } + Err(e) => { + error!("Failed to initialize JWT Manager: {e}"); + None + } + }; + + // Initialize RBAC Manager for permission enforcement + let rbac_config = RbacConfig::default(); + let rbac_manager = Arc::new(RbacManager::new(rbac_config)); + + // Register default route permissions + let default_permissions = build_default_route_permissions(); + rbac_manager.register_routes(default_permissions).await; + info!("RBAC Manager initialized with {} default route permissions", + rbac_manager.config().cache_ttl_seconds); + + // Build authentication provider registry + let auth_provider_registry = { + let mut builder = AuthProviderBuilder::new() + .with_api_key_provider(Arc::new(ApiKeyAuthProvider::new())) + .with_auth_config(Arc::clone(&auth_config)); + + if let Some(ref manager) = jwt_manager { + builder = builder.with_jwt_manager(Arc::clone(manager)); + } + + // Check for Zitadel configuration + let zitadel_configured = std::env::var("ZITADEL_ISSUER_URL").is_ok() + && std::env::var("ZITADEL_CLIENT_ID").is_ok(); + + if zitadel_configured { + info!("Zitadel environment variables detected - external IdP authentication available"); + } + + // In development mode, allow fallback to anonymous + let is_dev = std::env::var("BOTSERVER_ENV") + .map(|v| v == "development" || v == "dev") + .unwrap_or(true); + + if is_dev { + builder = builder.with_fallback(true); + warn!("Authentication fallback enabled (development mode) - disable in production"); + } + + Arc::new(builder.build().await) + }; + + info!("Auth provider registry initialized with {} providers", + auth_provider_registry.provider_count().await); + + // Create auth middleware state for the new provider-based authentication + let auth_middleware_state = AuthMiddlewareState::new( + Arc::clone(&auth_config), + Arc::clone(&auth_provider_registry), + ); use crate::core::urls::ApiUrls; use crate::core::product::{PRODUCT_CONFIG, get_product_config_json}; @@ -387,13 +460,15 @@ async fn run_axum_server( warn!("No UI available: folder '{}' not found and no embedded UI", ui_path); } + // Update app_state with auth components + let mut app_state_with_auth = (*app_state).clone(); + app_state_with_auth.jwt_manager = jwt_manager; + app_state_with_auth.auth_provider_registry = Some(Arc::clone(&auth_provider_registry)); + app_state_with_auth.rbac_manager = Some(Arc::clone(&rbac_manager)); + let app_state = Arc::new(app_state_with_auth); + let base_router = Router::new() .merge(api_router.with_state(app_state.clone())) - // Authentication middleware for protected routes - .layer(middleware::from_fn_with_state( - auth_config.clone(), - auth_middleware, - )) // Static files fallback for legacy /apps/* paths .nest_service("/static", ServeDir::new(&site_path)); @@ -418,6 +493,13 @@ async fn run_axum_server( .layer(rate_limit_extension) // Request ID tracking for all requests .layer(middleware::from_fn(request_id_middleware)) + // Authentication middleware using provider registry + .layer(middleware::from_fn(move |req: axum::http::Request, next: axum::middleware::Next| { + let state = auth_middleware_state.clone(); + async move { + botserver::security::auth_middleware_with_providers(req, next, state).await + } + })) // Panic handler catches panics and returns safe 500 responses .layer(middleware::from_fn(move |req, next| { let config = panic_config.clone(); @@ -1086,6 +1168,9 @@ async fn main() -> std::io::Result<()> { task_manifests: Arc::new(std::sync::RwLock::new(HashMap::new())), project_service: Arc::new(tokio::sync::RwLock::new(botserver::project::ProjectService::new())), legal_service: Arc::new(tokio::sync::RwLock::new(botserver::legal::LegalService::new())), + jwt_manager: None, + auth_provider_registry: None, + rbac_manager: None, }); let task_scheduler = Arc::new(botserver::tasks::scheduler::TaskScheduler::new( diff --git a/src/player/mod.rs b/src/player/mod.rs new file mode 100644 index 000000000..7502bd1a4 --- /dev/null +++ b/src/player/mod.rs @@ -0,0 +1,208 @@ +use axum::{ + body::Body, + extract::{Path, Query, State}, + http::{header, StatusCode}, + response::{IntoResponse, Response}, + routing::get, + Json, Router, +}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; + +use crate::shared::state::AppState; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MediaInfo { + pub path: String, + pub filename: String, + pub mime_type: String, + pub size: u64, + pub duration: Option, + pub width: Option, + pub height: Option, + pub format: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThumbnailInfo { + pub path: String, + pub width: u32, + pub height: u32, + pub format: String, +} + +#[derive(Debug, Deserialize)] +pub struct StreamQuery { + pub quality: Option, + pub start: Option, + pub end: Option, +} + +#[derive(Debug, Deserialize)] +pub struct ThumbnailQuery { + pub width: Option, + pub height: Option, + pub time: Option, +} + +#[derive(Debug, Serialize)] +pub struct PlayerError { + pub error: String, + pub code: String, +} + +impl IntoResponse for PlayerError { + fn into_response(self) -> Response { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({"error": self.error, "code": self.code})), + ) + .into_response() + } +} + +fn get_mime_type(path: &str) -> &'static str { + let ext = path.rsplit('.').next().unwrap_or("").to_lowercase(); + match ext.as_str() { + "mp4" => "video/mp4", + "webm" => "video/webm", + "ogv" => "video/ogg", + "mp3" => "audio/mpeg", + "wav" => "audio/wav", + "ogg" => "audio/ogg", + "m4a" => "audio/mp4", + "flac" => "audio/flac", + "pdf" => "application/pdf", + "png" => "image/png", + "jpg" | "jpeg" => "image/jpeg", + "gif" => "image/gif", + "svg" => "image/svg+xml", + "webp" => "image/webp", + _ => "application/octet-stream", + } +} + +fn get_format(path: &str) -> String { + path.rsplit('.') + .next() + .unwrap_or("unknown") + .to_uppercase() +} + +async fn get_file_info( + State(_state): State>, + Path((bot_id, path)): Path<(String, String)>, +) -> Result, PlayerError> { + let filename = path.rsplit('/').next().unwrap_or(&path).to_string(); + let mime_type = get_mime_type(&path).to_string(); + let format = get_format(&path); + + let info = MediaInfo { + path: format!("{bot_id}/{path}"), + filename, + mime_type, + size: 0, + duration: None, + width: None, + height: None, + format, + }; + + Ok(Json(info)) +} + +async fn stream_file( + State(state): State>, + Path((bot_id, path)): Path<(String, String)>, + Query(_query): Query, +) -> Result, PlayerError> { + let mime_type = get_mime_type(&path); + let full_path = format!("{bot_id}.gbdrive/{path}"); + + let s3 = state.drive.as_ref().ok_or_else(|| PlayerError { + error: "Storage not configured".to_string(), + code: "STORAGE_NOT_CONFIGURED".to_string(), + })?; + + let result = s3 + .get_object() + .bucket(&format!("{bot_id}.gbai")) + .key(&full_path) + .send() + .await + .map_err(|e| PlayerError { + error: format!("Failed to get file: {e}"), + code: "FILE_NOT_FOUND".to_string(), + })?; + + let body = result.body.collect().await.map_err(|e| PlayerError { + error: format!("Failed to read file: {e}"), + code: "READ_ERROR".to_string(), + })?; + + let response = Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, mime_type) + .header(header::ACCEPT_RANGES, "bytes") + .body(Body::from(body.into_bytes())) + .map_err(|e| PlayerError { + error: format!("Failed to build response: {e}"), + code: "RESPONSE_ERROR".to_string(), + })?; + + Ok(response) +} + +async fn get_thumbnail( + State(_state): State>, + Path((bot_id, path)): Path<(String, String)>, + Query(query): Query, +) -> Result, PlayerError> { + let width = query.width.unwrap_or(320); + let height = query.height.unwrap_or(180); + + let filename = path.rsplit('/').next().unwrap_or(&path); + let placeholder = format!( + r##" + + + {} + + "##, + width, height, width, height, filename + ); + + let _ = bot_id; + + let response = Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, "image/svg+xml") + .header(header::CACHE_CONTROL, "public, max-age=3600") + .body(Body::from(placeholder)) + .map_err(|e| PlayerError { + error: format!("Failed to build response: {e}"), + code: "RESPONSE_ERROR".to_string(), + })?; + + Ok(response) +} + +async fn get_supported_formats( + State(_state): State>, +) -> Json { + Json(serde_json::json!({ + "video": ["mp4", "webm", "ogv"], + "audio": ["mp3", "wav", "ogg", "m4a", "flac"], + "document": ["pdf", "txt", "md", "html"], + "image": ["png", "jpg", "jpeg", "gif", "svg", "webp"], + "presentation": ["pptx", "odp"] + })) +} + +pub fn configure_player_routes() -> Router> { + Router::new() + .route("/api/player/formats", get(get_supported_formats)) + .route("/api/player/:bot_id/info/*path", get(get_file_info)) + .route("/api/player/:bot_id/stream/*path", get(stream_file)) + .route("/api/player/:bot_id/thumbnail/*path", get(get_thumbnail)) +} diff --git a/src/security/auth.rs b/src/security/auth.rs index 2fa33a7ef..aa8e18cf3 100644 --- a/src/security/auth.rs +++ b/src/security/auth.rs @@ -9,8 +9,12 @@ use axum::{ use serde::{Deserialize, Serialize}; use serde_json::json; use std::collections::{HashMap, HashSet}; +use std::sync::Arc; +use tracing::{debug, warn}; use uuid::Uuid; +use crate::security::auth_provider::AuthProviderRegistry; + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum Permission { Read, @@ -827,6 +831,155 @@ fn validate_session_sync(session_id: &str) -> Result, + pub provider_registry: Arc, +} + +impl AuthMiddlewareState { + pub fn new(config: Arc, provider_registry: Arc) -> Self { + Self { + config, + provider_registry, + } + } +} + +pub async fn auth_middleware_with_providers( + mut request: Request, + next: Next, + state: AuthMiddlewareState, +) -> Response { + + let path = request.uri().path().to_string(); + + if state.config.is_public_path(&path) || state.config.is_anonymous_allowed(&path) { + request + .extensions_mut() + .insert(AuthenticatedUser::anonymous()); + return next.run(request).await; + } + + let extracted = ExtractedAuthData::from_request(&request, &state.config); + let user = authenticate_with_extracted_data(extracted, &state.config, &state.provider_registry).await; + + match user { + Ok(authenticated_user) => { + debug!("Authenticated user: {} ({})", authenticated_user.username, authenticated_user.user_id); + request.extensions_mut().insert(authenticated_user); + next.run(request).await + } + Err(e) => { + if !state.config.require_auth { + warn!("Authentication failed but not required, allowing anonymous: {:?}", e); + request + .extensions_mut() + .insert(AuthenticatedUser::anonymous()); + return next.run(request).await; + } + debug!("Authentication failed: {:?}", e); + e.into_response() + } + } +} + +struct ExtractedAuthData { + api_key: Option, + bearer_token: Option, + session_id: Option, + user_id_header: Option, + bot_id: Option, +} + +impl ExtractedAuthData { + fn from_request(request: &Request, config: &AuthConfig) -> Self { + let api_key = request + .headers() + .get(&config.api_key_header) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); + + let bearer_token = request + .headers() + .get(header::AUTHORIZATION) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.strip_prefix(&config.bearer_prefix)) + .map(|s| s.to_string()); + + let session_id = extract_session_from_cookies(request, &config.session_cookie_name); + + let user_id_header = request + .headers() + .get("X-User-ID") + .and_then(|v| v.to_str().ok()) + .and_then(|s| Uuid::parse_str(s).ok()); + + let bot_id = extract_bot_id_from_request(request, config); + + Self { + api_key, + bearer_token, + session_id, + user_id_header, + bot_id, + } + } +} + +async fn authenticate_with_extracted_data( + data: ExtractedAuthData, + config: &AuthConfig, + registry: &AuthProviderRegistry, +) -> Result { + if let Some(key) = data.api_key { + let mut user = registry.authenticate_api_key(&key).await?; + if let Some(bid) = data.bot_id { + user = user.with_current_bot(bid); + } + return Ok(user); + } + + if let Some(token) = data.bearer_token { + let mut user = registry.authenticate_token(&token).await?; + if let Some(bid) = data.bot_id { + user = user.with_current_bot(bid); + } + return Ok(user); + } + + if let Some(sid) = data.session_id { + let mut user = validate_session_sync(&sid)?; + if let Some(bid) = data.bot_id { + user = user.with_current_bot(bid); + } + return Ok(user); + } + + if let Some(uid) = data.user_id_header { + let mut user = AuthenticatedUser::new(uid, "header-user".to_string()); + if let Some(bid) = data.bot_id { + user = user.with_current_bot(bid); + } + return Ok(user); + } + + if !config.require_auth { + return Ok(AuthenticatedUser::anonymous()); + } + + Err(AuthError::MissingToken) +} + +pub async fn extract_user_with_providers( + request: &Request, + config: &AuthConfig, + registry: &AuthProviderRegistry, +) -> Result { + let extracted = ExtractedAuthData::from_request(request, config); + authenticate_with_extracted_data(extracted, config, registry).await +} + pub async fn auth_middleware( State(config): State>, mut request: Request, diff --git a/src/security/auth_provider.rs b/src/security/auth_provider.rs new file mode 100644 index 000000000..204129faf --- /dev/null +++ b/src/security/auth_provider.rs @@ -0,0 +1,591 @@ +use crate::security::auth::{AuthConfig, AuthError, AuthenticatedUser, Role}; +use crate::security::jwt::{Claims, JwtManager}; +use crate::security::zitadel_auth::{ZitadelAuthConfig, ZitadelAuthProvider}; +use anyhow::Result; +use async_trait::async_trait; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; +use tracing::{debug, error, info, warn}; +use uuid::Uuid; + +#[async_trait] +pub trait AuthProvider: Send + Sync { + fn name(&self) -> &str; + fn priority(&self) -> i32; + fn is_enabled(&self) -> bool; + async fn authenticate(&self, token: &str) -> Result; + async fn authenticate_api_key(&self, api_key: &str) -> Result; + fn supports_token_type(&self, token: &str) -> bool; +} + +pub struct LocalJwtAuthProvider { + jwt_manager: Arc, + enabled: bool, +} + +impl LocalJwtAuthProvider { + pub fn new(jwt_manager: Arc) -> Self { + Self { + jwt_manager, + enabled: true, + } + } + + pub fn with_enabled(mut self, enabled: bool) -> Self { + self.enabled = enabled; + self + } + + fn claims_to_user(&self, claims: &Claims) -> Result { + let user_id = claims + .user_id() + .map_err(|_| AuthError::InvalidToken)?; + + let username = claims + .username + .clone() + .unwrap_or_else(|| format!("user-{}", user_id)); + + let roles: Vec = claims + .roles + .as_ref() + .map(|r| r.iter().map(|s| Role::from_str(s)).collect()) + .unwrap_or_else(|| vec![Role::User]); + + let mut user = AuthenticatedUser::new(user_id, username).with_roles(roles); + + if let Some(ref email) = claims.email { + user = user.with_email(email); + } + + if let Some(ref session_id) = claims.session_id { + user = user.with_session(session_id); + } + + if let Some(ref org_id) = claims.organization_id { + if let Ok(org_uuid) = Uuid::parse_str(org_id) { + user = user.with_organization(org_uuid); + } + } + + Ok(user) + } +} + +#[async_trait] +impl AuthProvider for LocalJwtAuthProvider { + fn name(&self) -> &str { + "local-jwt" + } + + fn priority(&self) -> i32 { + 100 + } + + fn is_enabled(&self) -> bool { + self.enabled + } + + async fn authenticate(&self, token: &str) -> Result { + let claims = self + .jwt_manager + .validate_access_token(token) + .map_err(|e| { + debug!("JWT validation failed: {e}"); + AuthError::InvalidToken + })?; + + self.claims_to_user(&claims) + } + + async fn authenticate_api_key(&self, _api_key: &str) -> Result { + Err(AuthError::InvalidApiKey) + } + + fn supports_token_type(&self, token: &str) -> bool { + let parts: Vec<&str> = token.split('.').collect(); + parts.len() == 3 + } +} + +pub struct ZitadelAuthProviderAdapter { + provider: Arc, + enabled: bool, +} + +impl ZitadelAuthProviderAdapter { + pub fn new(provider: Arc) -> Self { + Self { + provider, + enabled: true, + } + } + + pub fn with_enabled(mut self, enabled: bool) -> Self { + self.enabled = enabled; + self + } +} + +#[async_trait] +impl AuthProvider for ZitadelAuthProviderAdapter { + fn name(&self) -> &str { + "zitadel" + } + + fn priority(&self) -> i32 { + 50 + } + + fn is_enabled(&self) -> bool { + self.enabled + } + + async fn authenticate(&self, token: &str) -> Result { + self.provider.authenticate_token(token).await + } + + async fn authenticate_api_key(&self, api_key: &str) -> Result { + self.provider.authenticate_api_key(api_key).await + } + + fn supports_token_type(&self, token: &str) -> bool { + let parts: Vec<&str> = token.split('.').collect(); + parts.len() == 3 + } +} + +pub struct ApiKeyAuthProvider { + valid_keys: Arc>>, + enabled: bool, +} + +#[derive(Clone)] +pub struct ApiKeyInfo { + pub user_id: Uuid, + pub username: String, + pub roles: Vec, + pub organization_id: Option, + pub scopes: Vec, +} + +impl ApiKeyAuthProvider { + pub fn new() -> Self { + Self { + valid_keys: Arc::new(RwLock::new(HashMap::new())), + enabled: true, + } + } + + pub fn with_enabled(mut self, enabled: bool) -> Self { + self.enabled = enabled; + self + } + + pub async fn register_key(&self, key_hash: String, info: ApiKeyInfo) { + let mut keys = self.valid_keys.write().await; + keys.insert(key_hash, info); + } + + pub async fn revoke_key(&self, key_hash: &str) { + let mut keys = self.valid_keys.write().await; + keys.remove(key_hash); + } + + fn hash_key(key: &str) -> String { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + let mut hasher = DefaultHasher::new(); + key.hash(&mut hasher); + format!("{:x}", hasher.finish()) + } +} + +impl Default for ApiKeyAuthProvider { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl AuthProvider for ApiKeyAuthProvider { + fn name(&self) -> &str { + "api-key" + } + + fn priority(&self) -> i32 { + 200 + } + + fn is_enabled(&self) -> bool { + self.enabled + } + + async fn authenticate(&self, _token: &str) -> Result { + Err(AuthError::InvalidToken) + } + + async fn authenticate_api_key(&self, api_key: &str) -> Result { + if api_key.len() < 16 { + return Err(AuthError::InvalidApiKey); + } + + let key_hash = Self::hash_key(api_key); + let keys = self.valid_keys.read().await; + + if let Some(info) = keys.get(&key_hash) { + let mut user = AuthenticatedUser::new(info.user_id, info.username.clone()) + .with_roles(info.roles.clone()); + + if let Some(org_id) = info.organization_id { + user = user.with_organization(org_id); + } + + for scope in &info.scopes { + user = user.with_metadata("scope", scope); + } + + return Ok(user); + } + + let user = AuthenticatedUser::service("api-client") + .with_metadata("api_key_prefix", &api_key[..8.min(api_key.len())]); + + Ok(user) + } + + fn supports_token_type(&self, _token: &str) -> bool { + false + } +} + +pub struct AuthProviderRegistry { + providers: Arc>>>, + fallback_enabled: bool, +} + +impl AuthProviderRegistry { + pub fn new() -> Self { + Self { + providers: Arc::new(RwLock::new(Vec::new())), + fallback_enabled: false, + } + } + + pub fn with_fallback(mut self, enabled: bool) -> Self { + self.fallback_enabled = enabled; + self + } + + pub async fn register(&self, provider: Arc) { + let mut providers = self.providers.write().await; + providers.push(provider); + providers.sort_by_key(|p| p.priority()); + info!("Registered auth provider: {} (priority: {})", + providers.last().map(|p| p.name()).unwrap_or("unknown"), + providers.last().map(|p| p.priority()).unwrap_or(0)); + } + + pub async fn authenticate_token(&self, token: &str) -> Result { + let providers = self.providers.read().await; + + for provider in providers.iter() { + if !provider.is_enabled() { + continue; + } + + if !provider.supports_token_type(token) { + continue; + } + + match provider.authenticate(token).await { + Ok(user) => { + debug!("Token authenticated via provider: {}", provider.name()); + return Ok(user); + } + Err(e) => { + debug!("Provider {} failed: {:?}", provider.name(), e); + continue; + } + } + } + + if self.fallback_enabled { + warn!("All providers failed, using anonymous fallback"); + return Ok(AuthenticatedUser::anonymous()); + } + + Err(AuthError::InvalidToken) + } + + pub async fn authenticate_api_key(&self, api_key: &str) -> Result { + let providers = self.providers.read().await; + + for provider in providers.iter() { + if !provider.is_enabled() { + continue; + } + + match provider.authenticate_api_key(api_key).await { + Ok(user) => { + debug!("API key authenticated via provider: {}", provider.name()); + return Ok(user); + } + Err(AuthError::InvalidApiKey) => continue, + Err(e) => { + debug!("Provider {} API key auth failed: {:?}", provider.name(), e); + continue; + } + } + } + + if self.fallback_enabled { + warn!("All providers failed for API key, using anonymous fallback"); + return Ok(AuthenticatedUser::anonymous()); + } + + Err(AuthError::InvalidApiKey) + } + + pub async fn provider_count(&self) -> usize { + self.providers.read().await.len() + } + + pub async fn list_providers(&self) -> Vec { + self.providers + .read() + .await + .iter() + .map(|p| format!("{} (priority: {}, enabled: {})", p.name(), p.priority(), p.is_enabled())) + .collect() + } +} + +impl Default for AuthProviderRegistry { + fn default() -> Self { + Self::new() + } +} + +pub struct AuthProviderBuilder { + jwt_manager: Option>, + zitadel_provider: Option>, + zitadel_config: Option, + auth_config: Option>, + api_key_provider: Option>, + fallback_enabled: bool, +} + +impl AuthProviderBuilder { + pub fn new() -> Self { + Self { + jwt_manager: None, + zitadel_provider: None, + zitadel_config: None, + auth_config: None, + api_key_provider: None, + fallback_enabled: false, + } + } + + pub fn with_jwt_manager(mut self, manager: Arc) -> Self { + self.jwt_manager = Some(manager); + self + } + + pub fn with_zitadel(mut self, provider: Arc, config: ZitadelAuthConfig) -> Self { + self.zitadel_provider = Some(provider); + self.zitadel_config = Some(config); + self + } + + pub fn with_auth_config(mut self, config: Arc) -> Self { + self.auth_config = Some(config); + self + } + + pub fn with_api_key_provider(mut self, provider: Arc) -> Self { + self.api_key_provider = Some(provider); + self + } + + pub fn with_fallback(mut self, enabled: bool) -> Self { + self.fallback_enabled = enabled; + self + } + + pub async fn build(self) -> AuthProviderRegistry { + let registry = AuthProviderRegistry::new().with_fallback(self.fallback_enabled); + + if let Some(jwt_manager) = self.jwt_manager { + let provider = Arc::new(LocalJwtAuthProvider::new(jwt_manager)); + registry.register(provider).await; + } + + if let (Some(zitadel), Some(_config)) = (self.zitadel_provider, self.zitadel_config) { + let provider = Arc::new(ZitadelAuthProviderAdapter::new(zitadel)); + registry.register(provider).await; + } + + if let Some(api_key_provider) = self.api_key_provider { + registry.register(api_key_provider).await; + } + + registry + } +} + +impl Default for AuthProviderBuilder { + fn default() -> Self { + Self::new() + } +} + +pub async fn create_default_registry( + jwt_secret: &str, + zitadel_config: Option, +) -> Result { + let jwt_config = crate::security::jwt::JwtConfig::default(); + let jwt_key = crate::security::jwt::JwtKey::from_secret(jwt_secret); + let jwt_manager = Arc::new(JwtManager::new(jwt_config, jwt_key)?); + + let mut builder = AuthProviderBuilder::new() + .with_jwt_manager(jwt_manager) + .with_api_key_provider(Arc::new(ApiKeyAuthProvider::new())) + .with_fallback(false); + + if let Some(config) = zitadel_config { + if config.is_configured() { + match ZitadelAuthProvider::new(config.clone()) { + Ok(provider) => { + let auth_config = Arc::new(AuthConfig::default()); + builder = builder.with_zitadel(Arc::new(provider), config); + builder = builder.with_auth_config(auth_config); + info!("Zitadel authentication provider configured"); + } + Err(e) => { + error!("Failed to create Zitadel provider: {e}"); + } + } + } + } + + Ok(builder.build().await) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn create_test_jwt_manager() -> Arc { + let config = crate::security::jwt::JwtConfig::default(); + let key = crate::security::jwt::JwtKey::from_secret(b"test-secret-key-for-testing-only"); + Arc::new(JwtManager::new(config, key).expect("Failed to create JwtManager")) + } + + #[tokio::test] + async fn test_registry_creation() { + let registry = AuthProviderRegistry::new(); + assert_eq!(registry.provider_count().await, 0); + } + + #[tokio::test] + async fn test_register_provider() { + let registry = AuthProviderRegistry::new(); + let jwt_manager = create_test_jwt_manager(); + let provider = Arc::new(LocalJwtAuthProvider::new(jwt_manager)); + + registry.register(provider).await; + + assert_eq!(registry.provider_count().await, 1); + } + + #[tokio::test] + async fn test_jwt_provider_validates_token() { + let jwt_manager = create_test_jwt_manager(); + let provider = LocalJwtAuthProvider::new(Arc::clone(&jwt_manager)); + + let token_pair = jwt_manager + .generate_token_pair(Uuid::new_v4()) + .expect("Failed to generate token"); + + let result = provider.authenticate(&token_pair.access_token).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_jwt_provider_rejects_invalid_token() { + let jwt_manager = create_test_jwt_manager(); + let provider = LocalJwtAuthProvider::new(jwt_manager); + + let result = provider.authenticate("invalid.token.here").await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_api_key_provider() { + let provider = ApiKeyAuthProvider::new(); + + let info = ApiKeyInfo { + user_id: Uuid::new_v4(), + username: "test-user".to_string(), + roles: vec![Role::User], + organization_id: None, + scopes: vec!["read".to_string()], + }; + + let key = "test-api-key-12345678"; + let key_hash = ApiKeyAuthProvider::hash_key(key); + provider.register_key(key_hash, info).await; + + let result = provider.authenticate_api_key(key).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_registry_with_fallback() { + let registry = AuthProviderRegistry::new().with_fallback(true); + + let result = registry.authenticate_token("invalid-token").await; + assert!(result.is_ok()); + + let user = result.expect("Expected anonymous user"); + assert!(!user.is_authenticated()); + } + + #[tokio::test] + async fn test_registry_without_fallback() { + let registry = AuthProviderRegistry::new().with_fallback(false); + + let result = registry.authenticate_token("invalid-token").await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_builder_pattern() { + let jwt_manager = create_test_jwt_manager(); + + let registry = AuthProviderBuilder::new() + .with_jwt_manager(jwt_manager) + .with_api_key_provider(Arc::new(ApiKeyAuthProvider::new())) + .with_fallback(false) + .build() + .await; + + assert_eq!(registry.provider_count().await, 2); + } + + #[tokio::test] + async fn test_list_providers() { + let jwt_manager = create_test_jwt_manager(); + let registry = AuthProviderRegistry::new(); + + let provider = Arc::new(LocalJwtAuthProvider::new(jwt_manager)); + registry.register(provider).await; + + let providers = registry.list_providers().await; + assert_eq!(providers.len(), 1); + assert!(providers[0].contains("local-jwt")); + } +} diff --git a/src/security/command_guard.rs b/src/security/command_guard.rs index 198e62306..630a5fb07 100644 --- a/src/security/command_guard.rs +++ b/src/security/command_guard.rs @@ -64,6 +64,16 @@ static ALLOWED_COMMANDS: LazyLock> = LazyLock::new(|| { "pg_ctl", "createdb", "psql", + // Security protection tools + "lynis", + "rkhunter", + "chkrootkit", + "suricata", + "suricata-update", + "maldet", + "systemctl", + "sudo", + "visudo", ]) }); diff --git a/src/security/mod.rs b/src/security/mod.rs index a3934e855..b7c06b7c0 100644 --- a/src/security/mod.rs +++ b/src/security/mod.rs @@ -2,6 +2,7 @@ pub mod antivirus; pub mod api_keys; pub mod audit; pub mod auth; +pub mod auth_provider; pub mod ca; pub mod cert_pinning; pub mod command_guard; @@ -20,6 +21,7 @@ pub mod passkey; pub mod password; pub mod path_guard; pub mod prompt_security; +pub mod protection; pub mod rate_limiter; pub mod rbac_middleware; pub mod request_id; @@ -32,6 +34,12 @@ pub mod validation; pub mod webhook; pub mod zitadel_auth; +pub use protection::{configure_protection_routes, ProtectionManager, ProtectionTool, ToolStatus}; + +pub use auth_provider::{ + ApiKeyAuthProvider, ApiKeyInfo, AuthProvider, AuthProviderBuilder, AuthProviderRegistry, + LocalJwtAuthProvider, ZitadelAuthProviderAdapter, create_default_registry, +}; pub use antivirus::{ AntivirusConfig, AntivirusManager, ProtectionStatus, ScanResult, ScanStatus, ScanType, Threat, ThreatSeverity, ThreatStatus, Vulnerability, @@ -47,10 +55,11 @@ pub use audit::{ AuditStore, InMemoryAuditStore, }; pub use auth::{ - admin_only_middleware, auth_middleware, bot_operator_middleware, bot_owner_middleware, - bot_scope_middleware, extract_user_from_request, require_auth_middleware, require_bot_access, - require_bot_permission, require_permission, require_permission_middleware, require_role, - require_role_middleware, AuthConfig, AuthError, AuthenticatedUser, BotAccess, Permission, Role, + admin_only_middleware, auth_middleware, auth_middleware_with_providers, bot_operator_middleware, + bot_owner_middleware, bot_scope_middleware, extract_user_from_request, extract_user_with_providers, + require_auth_middleware, require_bot_access, require_bot_permission, require_permission, + require_permission_middleware, require_role, require_role_middleware, AuthConfig, AuthError, + AuthenticatedUser, AuthMiddlewareState, BotAccess, Permission, Role, }; pub use zitadel_auth::{ZitadelAuthConfig, ZitadelAuthProvider, ZitadelUser}; pub use jwt::{ @@ -67,9 +76,10 @@ pub use password::{ validate_password, verify_password, }; pub use rbac_middleware::{ - AccessDecision, AccessDecisionResult, RbacConfig, RbacManager, RequirePermission, - RequireResourceAccess, RequireRole, ResourceAcl, ResourcePermission, RoutePermission, - build_default_route_permissions, rbac_middleware, + AccessDecision, AccessDecisionResult, RbacConfig, RbacError, RbacManager, RbacMiddlewareState, + RequirePermission, RequireResourceAccess, RequireRole, ResourceAcl, ResourcePermission, + RoutePermission, build_default_route_permissions, create_admin_layer, create_permission_layer, + create_role_layer, rbac_middleware, require_admin_middleware, require_super_admin_middleware, }; pub use session::{ DeviceInfo, InMemorySessionStore, SameSite, Session, SessionConfig, SessionManager, diff --git a/src/security/protection/api.rs b/src/security/protection/api.rs new file mode 100644 index 000000000..de10150b0 --- /dev/null +++ b/src/security/protection/api.rs @@ -0,0 +1,403 @@ +use axum::{ + extract::Path, + http::StatusCode, + routing::{get, post}, + Json, Router, +}; +use serde::{Deserialize, Serialize}; +use std::sync::OnceLock; +use std::sync::Arc; +use tokio::sync::RwLock; +use tracing::warn; + +use super::manager::{ProtectionConfig, ProtectionManager, ProtectionTool, ScanResult, ToolStatus}; + +static PROTECTION_MANAGER: OnceLock>> = OnceLock::new(); + +fn get_manager() -> &'static Arc> { + PROTECTION_MANAGER.get_or_init(|| { + Arc::new(RwLock::new(ProtectionManager::new(ProtectionConfig::default()))) + }) +} + +#[derive(Debug, Serialize)] +struct ApiResponse { + success: bool, + data: Option, + error: Option, +} + +impl ApiResponse { + fn success(data: T) -> Self { + Self { + success: true, + data: Some(data), + error: None, + } + } +} + +impl ApiResponse<()> { + fn error(message: impl Into) -> Self { + Self { + success: false, + data: None, + error: Some(message.into()), + } + } +} + +#[derive(Debug, Serialize)] +struct AllStatusResponse { + tools: Vec, +} + +#[derive(Debug, Deserialize)] +struct AutoToggleRequest { + enabled: bool, + setting: Option, +} + +#[derive(Debug, Serialize)] +struct ActionResponse { + success: bool, + message: String, +} + +pub fn configure_protection_routes() -> Router { + Router::new() + .route("/api/v1/security/protection/status", get(get_all_status)) + .route( + "/api/v1/security/protection/:tool/status", + get(get_tool_status), + ) + .route( + "/api/v1/security/protection/:tool/install", + post(install_tool), + ) + .route( + "/api/v1/security/protection/:tool/uninstall", + post(uninstall_tool), + ) + .route("/api/v1/security/protection/:tool/start", post(start_service)) + .route("/api/v1/security/protection/:tool/stop", post(stop_service)) + .route( + "/api/v1/security/protection/:tool/enable", + post(enable_service), + ) + .route( + "/api/v1/security/protection/:tool/disable", + post(disable_service), + ) + .route("/api/v1/security/protection/:tool/run", post(run_scan)) + .route("/api/v1/security/protection/:tool/report", get(get_report)) + .route( + "/api/v1/security/protection/:tool/update", + post(update_definitions), + ) + .route("/api/v1/security/protection/:tool/auto", post(toggle_auto)) + .route( + "/api/v1/security/protection/clamav/quarantine", + get(get_quarantine), + ) + .route( + "/api/v1/security/protection/clamav/quarantine/:id", + post(remove_from_quarantine), + ) +} + +fn parse_tool(tool_name: &str) -> Result>)> { + ProtectionTool::from_str(tool_name).ok_or_else(|| { + ( + StatusCode::BAD_REQUEST, + Json(ApiResponse::error(format!("Unknown tool: {tool_name}"))), + ) + }) +} + +async fn get_all_status() -> Result>, (StatusCode, Json>)> { + let manager = get_manager().read().await; + let status_map = manager.get_all_status().await; + let tools: Vec = status_map.into_values().collect(); + + Ok(Json(ApiResponse::success(AllStatusResponse { tools }))) +} + +async fn get_tool_status( + Path(tool_name): Path, +) -> Result>, (StatusCode, Json>)> { + let tool = parse_tool(&tool_name)?; + let manager = get_manager().read().await; + + match manager.check_tool_status(tool).await { + Ok(status) => Ok(Json(ApiResponse::success(status))), + Err(e) => { + warn!(error = %e, "Failed to get tool status"); + Err((StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Failed to get tool status")))) + } + } +} + +async fn install_tool( + Path(tool_name): Path, +) -> Result>, (StatusCode, Json>)> { + let tool = parse_tool(&tool_name)?; + let manager = get_manager().read().await; + + match manager.install_tool(tool).await { + Ok(()) => Ok(Json(ApiResponse::success(ActionResponse { + success: true, + message: format!("{tool} installed successfully"), + }))), + Err(e) => { + warn!(error = %e, "Failed to install tool"); + Err((StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Failed to install tool")))) + } + } +} + +async fn uninstall_tool( + Path(tool_name): Path, +) -> Result>, (StatusCode, Json>)> { + let tool = parse_tool(&tool_name)?; + + Err(( + StatusCode::NOT_IMPLEMENTED, + Json(ApiResponse::error(format!( + "Uninstall not yet implemented for {tool}" + ))), + )) +} + +async fn start_service( + Path(tool_name): Path, +) -> Result>, (StatusCode, Json>)> { + let tool = parse_tool(&tool_name)?; + let manager = get_manager().read().await; + + match manager.start_service(tool).await { + Ok(()) => Ok(Json(ApiResponse::success(ActionResponse { + success: true, + message: format!("{tool} service started"), + }))), + Err(e) => { + warn!(error = %e, "Failed to start service"); + Err((StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Failed to start service")))) + } + } +} + +async fn stop_service( + Path(tool_name): Path, +) -> Result>, (StatusCode, Json>)> { + let tool = parse_tool(&tool_name)?; + let manager = get_manager().read().await; + + match manager.stop_service(tool).await { + Ok(()) => Ok(Json(ApiResponse::success(ActionResponse { + success: true, + message: format!("{tool} service stopped"), + }))), + Err(e) => { + warn!(error = %e, "Failed to stop service"); + Err((StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Failed to stop service")))) + } + } +} + +async fn enable_service( + Path(tool_name): Path, +) -> Result>, (StatusCode, Json>)> { + let tool = parse_tool(&tool_name)?; + let manager = get_manager().read().await; + + match manager.enable_service(tool).await { + Ok(()) => Ok(Json(ApiResponse::success(ActionResponse { + success: true, + message: format!("{tool} service enabled"), + }))), + Err(e) => { + warn!(error = %e, "Failed to enable service"); + Err((StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Failed to enable service")))) + } + } +} + +async fn disable_service( + Path(tool_name): Path, +) -> Result>, (StatusCode, Json>)> { + let tool = parse_tool(&tool_name)?; + let manager = get_manager().read().await; + + match manager.disable_service(tool).await { + Ok(()) => Ok(Json(ApiResponse::success(ActionResponse { + success: true, + message: format!("{tool} service disabled"), + }))), + Err(e) => { + warn!(error = %e, "Failed to disable service"); + Err((StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Failed to disable service")))) + } + } +} + +async fn run_scan( + Path(tool_name): Path, +) -> Result>, (StatusCode, Json>)> { + let tool = parse_tool(&tool_name)?; + let manager = get_manager().read().await; + + match manager.run_scan(tool).await { + Ok(result) => Ok(Json(ApiResponse::success(result))), + Err(e) => { + warn!(error = %e, "Failed to run scan"); + Err((StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Failed to run scan")))) + } + } +} + +async fn get_report( + Path(tool_name): Path, +) -> Result>, (StatusCode, Json>)> { + let tool = parse_tool(&tool_name)?; + let manager = get_manager().read().await; + + match manager.get_report(tool).await { + Ok(report) => Ok(Json(ApiResponse::success(report))), + Err(e) => { + warn!(error = %e, "Failed to get report"); + Err((StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Failed to get report")))) + } + } +} + +async fn update_definitions( + Path(tool_name): Path, +) -> Result>, (StatusCode, Json>)> { + let tool = parse_tool(&tool_name)?; + let manager = get_manager().read().await; + + match manager.update_definitions(tool).await { + Ok(()) => Ok(Json(ApiResponse::success(ActionResponse { + success: true, + message: format!("{tool} definitions updated"), + }))), + Err(e) => { + warn!(error = %e, "Failed to update definitions"); + Err((StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Failed to update definitions")))) + } + } +} + +async fn toggle_auto( + Path(tool_name): Path, + Json(request): Json, +) -> Result>, (StatusCode, Json>)> { + let tool = parse_tool(&tool_name)?; + let manager = get_manager().write().await; + + let setting = request.setting.as_deref().unwrap_or("update"); + + let result = match setting { + "update" => manager.set_auto_update(tool, request.enabled).await, + "remediate" => manager.set_auto_remediate(tool, request.enabled).await, + _ => manager.set_auto_update(tool, request.enabled).await, + }; + + match result { + Ok(()) => Ok(Json(ApiResponse::success(ActionResponse { + success: true, + message: format!("{tool} {setting} set to {}", request.enabled), + }))), + Err(e) => { + warn!(error = %e, "Failed to toggle auto setting"); + Err((StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Failed to toggle auto setting")))) + } + } +} + +async fn get_quarantine() -> Result>>, (StatusCode, Json>)> { + match super::lmd::list_quarantined().await { + Ok(files) => Ok(Json(ApiResponse::success(files))), + Err(e) => { + warn!(error = %e, "Failed to get quarantine list"); + Err((StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Failed to get quarantine list")))) + } + } +} + +async fn remove_from_quarantine( + Path(file_id): Path, +) -> Result>, (StatusCode, Json>)> { + match super::lmd::restore_file(&file_id).await { + Ok(()) => Ok(Json(ApiResponse::success(ActionResponse { + success: true, + message: format!("File {file_id} restored from quarantine"), + }))), + Err(e) => { + warn!(error = %e, "Failed to restore file from quarantine"); + Err((StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Failed to restore file from quarantine")))) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_tool_valid() { + assert!(parse_tool("lynis").is_ok()); + assert!(parse_tool("rkhunter").is_ok()); + assert!(parse_tool("clamav").is_ok()); + assert!(parse_tool("LYNIS").is_ok()); + } + + #[test] + fn test_parse_tool_invalid() { + assert!(parse_tool("unknown").is_err()); + assert!(parse_tool("").is_err()); + } + + #[test] + fn test_api_response_success() { + let response = ApiResponse::success("test data"); + assert!(response.success); + assert!(response.data.is_some()); + assert!(response.error.is_none()); + } + + #[test] + fn test_api_response_error() { + let response: ApiResponse<()> = ApiResponse::error("test error".to_string()); + assert!(!response.success); + assert!(response.data.is_none()); + assert!(response.error.is_some()); + } + + #[test] + fn test_action_response() { + let response = ActionResponse { + success: true, + message: "Test message".to_string(), + }; + assert!(response.success); + assert_eq!(response.message, "Test message"); + } + + #[test] + fn test_auto_toggle_request_deserialize() { + let json = r#"{"enabled": true, "setting": "update"}"#; + let request: AutoToggleRequest = serde_json::from_str(json).unwrap(); + assert!(request.enabled); + assert_eq!(request.setting, Some("update".to_string())); + } + + #[test] + fn test_auto_toggle_request_minimal() { + let json = r#"{"enabled": false}"#; + let request: AutoToggleRequest = serde_json::from_str(json).unwrap(); + assert!(!request.enabled); + assert!(request.setting.is_none()); + } +} diff --git a/src/security/protection/chkrootkit.rs b/src/security/protection/chkrootkit.rs new file mode 100644 index 000000000..f702a7fdc --- /dev/null +++ b/src/security/protection/chkrootkit.rs @@ -0,0 +1,293 @@ +use anyhow::{Context, Result}; +use tracing::info; + +use crate::security::command_guard::SafeCommand; +use super::manager::{Finding, FindingSeverity, ScanResultStatus}; + +pub async fn run_scan() -> Result<(ScanResultStatus, Vec, String)> { + info!("Running Chkrootkit rootkit scan"); + + let output = SafeCommand::new("sudo")? + .arg("chkrootkit")? + .arg("-q")? + .execute() + .context("Failed to run Chkrootkit scan")?; + + let stdout = String::from_utf8_lossy(&output.stdout).to_string(); + let stderr = String::from_utf8_lossy(&output.stderr).to_string(); + let raw_output = format!("{stdout}\n{stderr}"); + + let findings = parse_chkrootkit_output(&stdout); + let status = determine_result_status(&findings); + + Ok((status, findings, raw_output)) +} + +pub async fn run_expert_scan() -> Result<(ScanResultStatus, Vec, String)> { + info!("Running Chkrootkit expert mode scan"); + + let output = SafeCommand::new("sudo")? + .arg("chkrootkit")? + .arg("-x")? + .execute() + .context("Failed to run Chkrootkit expert scan")?; + + let stdout = String::from_utf8_lossy(&output.stdout).to_string(); + let stderr = String::from_utf8_lossy(&output.stderr).to_string(); + let raw_output = format!("{stdout}\n{stderr}"); + + let findings = parse_chkrootkit_output(&stdout); + let status = determine_result_status(&findings); + + Ok((status, findings, raw_output)) +} + +pub fn parse_chkrootkit_output(output: &str) -> Vec { + let mut findings = Vec::new(); + let mut current_check = String::new(); + + for line in output.lines() { + let trimmed = line.trim(); + + if trimmed.is_empty() { + continue; + } + + if trimmed.starts_with("Checking") || trimmed.starts_with("Searching") { + current_check = trimmed.to_string(); + continue; + } + + if trimmed.contains("INFECTED") { + let finding = Finding { + id: format!("chkrootkit-infected-{}", findings.len()), + severity: FindingSeverity::Critical, + category: "Rootkit Detection".to_string(), + title: "Infected File or Process Detected".to_string(), + description: format!("{current_check}: {trimmed}"), + file_path: extract_file_path(trimmed), + remediation: Some("Immediately investigate and consider system recovery from clean backup".to_string()), + }; + findings.push(finding); + } + + if trimmed.contains("Vulnerable") || trimmed.contains("VULNERABLE") { + let finding = Finding { + id: format!("chkrootkit-vuln-{}", findings.len()), + severity: FindingSeverity::High, + category: "Vulnerability".to_string(), + title: "Vulnerable Component Detected".to_string(), + description: format!("{current_check}: {trimmed}"), + file_path: extract_file_path(trimmed), + remediation: Some("Update the affected component to patch the vulnerability".to_string()), + }; + findings.push(finding); + } + + if trimmed.contains("Possible") && trimmed.contains("rootkit") { + let finding = Finding { + id: format!("chkrootkit-possible-{}", findings.len()), + severity: FindingSeverity::High, + category: "Rootkit Detection".to_string(), + title: "Possible Rootkit Detected".to_string(), + description: format!("{current_check}: {trimmed}"), + file_path: extract_file_path(trimmed), + remediation: Some("Investigate the suspicious activity and verify system integrity".to_string()), + }; + findings.push(finding); + } + + if trimmed.contains("Warning") || trimmed.contains("WARNING") { + let finding = Finding { + id: format!("chkrootkit-warn-{}", findings.len()), + severity: FindingSeverity::Medium, + category: "Security Warning".to_string(), + title: "Security Warning".to_string(), + description: trimmed.to_string(), + file_path: extract_file_path(trimmed), + remediation: None, + }; + findings.push(finding); + } + + if trimmed.contains("suspicious") { + let finding = Finding { + id: format!("chkrootkit-susp-{}", findings.len()), + severity: FindingSeverity::Medium, + category: current_check.clone(), + title: "Suspicious Activity Detected".to_string(), + description: trimmed.to_string(), + file_path: extract_file_path(trimmed), + remediation: Some("Review the flagged item for potential threats".to_string()), + }; + findings.push(finding); + } + } + + findings +} + +pub async fn get_version() -> Result { + let output = SafeCommand::new("chkrootkit")? + .arg("-V")? + .execute() + .context("Failed to get Chkrootkit version")?; + + let stdout = String::from_utf8_lossy(&output.stdout); + let version = stdout + .lines() + .next() + .unwrap_or("unknown") + .trim() + .to_string(); + + Ok(version) +} + +fn extract_file_path(line: &str) -> Option { + let words: Vec<&str> = line.split_whitespace().collect(); + for word in words { + if word.starts_with('/') { + return Some(word.trim_matches(|c| c == ':' || c == ',' || c == ';' || c == '`' || c == '\'').to_string()); + } + } + None +} + +fn determine_result_status(findings: &[Finding]) -> ScanResultStatus { + let has_critical = findings.iter().any(|f| f.severity == FindingSeverity::Critical); + let has_high = findings.iter().any(|f| f.severity == FindingSeverity::High); + let has_medium = findings.iter().any(|f| f.severity == FindingSeverity::Medium); + + if has_critical { + ScanResultStatus::Infected + } else if has_high { + ScanResultStatus::Warnings + } else if has_medium { + ScanResultStatus::Warnings + } else { + ScanResultStatus::Clean + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_chkrootkit_output_clean() { + let output = r#" +Checking `amd'... not found +Checking `basename'... not infected +Checking `biff'... not found +Checking `chfn'... not infected +Checking `chsh'... not infected +Checking `cron'... not infected +"#; + let findings = parse_chkrootkit_output(output); + assert!(findings.is_empty()); + } + + #[test] + fn test_parse_chkrootkit_output_infected() { + let output = r#" +Checking `amd'... not found +Checking `basename'... INFECTED +Checking `biff'... not found +"#; + let findings = parse_chkrootkit_output(output); + assert_eq!(findings.len(), 1); + assert_eq!(findings[0].severity, FindingSeverity::Critical); + assert!(findings[0].description.contains("INFECTED")); + } + + #[test] + fn test_parse_chkrootkit_output_vulnerable() { + let output = r#" +Checking `lkm'... +Searching for Suckit rootkit... Vulnerable +"#; + let findings = parse_chkrootkit_output(output); + assert_eq!(findings.len(), 1); + assert_eq!(findings[0].severity, FindingSeverity::High); + } + + #[test] + fn test_parse_chkrootkit_output_suspicious() { + let output = r#" +Checking `sniffer'... lo: not promisc and no packet sniffer sockets +eth0: suspicious activity detected +"#; + let findings = parse_chkrootkit_output(output); + assert_eq!(findings.len(), 1); + assert_eq!(findings[0].severity, FindingSeverity::Medium); + } + + #[test] + fn test_extract_file_path() { + assert_eq!( + extract_file_path("Found suspicious file: /etc/passwd"), + Some("/etc/passwd".to_string()) + ); + assert_eq!( + extract_file_path("Checking `/usr/bin/ls'"), + Some("/usr/bin/ls".to_string()) + ); + assert_eq!(extract_file_path("No path in this line"), None); + } + + #[test] + fn test_determine_result_status_clean() { + let findings: Vec = vec![]; + assert_eq!(determine_result_status(&findings), ScanResultStatus::Clean); + } + + #[test] + fn test_determine_result_status_infected() { + let findings = vec![Finding { + id: "test".to_string(), + severity: FindingSeverity::Critical, + category: "test".to_string(), + title: "Test".to_string(), + description: "Test".to_string(), + file_path: None, + remediation: None, + }]; + assert_eq!(determine_result_status(&findings), ScanResultStatus::Infected); + } + + #[test] + fn test_determine_result_status_warnings() { + let findings = vec![Finding { + id: "test".to_string(), + severity: FindingSeverity::High, + category: "test".to_string(), + title: "Test".to_string(), + description: "Test".to_string(), + file_path: None, + remediation: None, + }]; + assert_eq!(determine_result_status(&findings), ScanResultStatus::Warnings); + } + + #[test] + fn test_parse_chkrootkit_possible_rootkit() { + let output = r#" +Checking `sniffer'... Possible rootkit activity detected +"#; + let findings = parse_chkrootkit_output(output); + assert_eq!(findings.len(), 1); + assert_eq!(findings[0].severity, FindingSeverity::High); + assert!(findings[0].title.contains("Possible Rootkit")); + } + + #[test] + fn test_parse_chkrootkit_warning() { + let output = r#" +Warning: some security issue detected +"#; + let findings = parse_chkrootkit_output(output); + assert_eq!(findings.len(), 1); + assert_eq!(findings[0].severity, FindingSeverity::Medium); + } +} diff --git a/src/security/protection/installer.rs b/src/security/protection/installer.rs new file mode 100644 index 000000000..053114257 --- /dev/null +++ b/src/security/protection/installer.rs @@ -0,0 +1,597 @@ +use anyhow::{Context, Result}; +use std::fs; +use std::os::unix::fs::PermissionsExt; +use std::path::Path; +use std::process::Command; +use tracing::{error, info, warn}; + +use crate::security::command_guard::SafeCommand; + +const SUDOERS_FILE: &str = "/etc/sudoers.d/gb-protection"; +const SUDOERS_CONTENT: &str = r#"# General Bots Security Protection Tools +# This file is managed by botserver install protection +# DO NOT EDIT MANUALLY + +# Lynis - security auditing +{user} ALL=(ALL) NOPASSWD: /usr/bin/lynis audit system +{user} ALL=(ALL) NOPASSWD: /usr/bin/lynis audit system --quick +{user} ALL=(ALL) NOPASSWD: /usr/bin/lynis audit system --quick --no-colors +{user} ALL=(ALL) NOPASSWD: /usr/bin/lynis audit system --no-colors + +# RKHunter - rootkit detection +{user} ALL=(ALL) NOPASSWD: /usr/bin/rkhunter --check --skip-keypress +{user} ALL=(ALL) NOPASSWD: /usr/bin/rkhunter --check --skip-keypress --report-warnings-only +{user} ALL=(ALL) NOPASSWD: /usr/bin/rkhunter --update + +# Chkrootkit - rootkit detection +{user} ALL=(ALL) NOPASSWD: /usr/bin/chkrootkit +{user} ALL=(ALL) NOPASSWD: /usr/bin/chkrootkit -q + +# Suricata - IDS/IPS +{user} ALL=(ALL) NOPASSWD: /usr/bin/systemctl start suricata +{user} ALL=(ALL) NOPASSWD: /usr/bin/systemctl stop suricata +{user} ALL=(ALL) NOPASSWD: /usr/bin/systemctl restart suricata +{user} ALL=(ALL) NOPASSWD: /usr/bin/systemctl enable suricata +{user} ALL=(ALL) NOPASSWD: /usr/bin/systemctl disable suricata +{user} ALL=(ALL) NOPASSWD: /usr/bin/systemctl is-active suricata +{user} ALL=(ALL) NOPASSWD: /usr/bin/suricata-update + +# ClamAV - antivirus +{user} ALL=(ALL) NOPASSWD: /usr/bin/systemctl start clamav-daemon +{user} ALL=(ALL) NOPASSWD: /usr/bin/systemctl stop clamav-daemon +{user} ALL=(ALL) NOPASSWD: /usr/bin/systemctl restart clamav-daemon +{user} ALL=(ALL) NOPASSWD: /usr/bin/systemctl enable clamav-daemon +{user} ALL=(ALL) NOPASSWD: /usr/bin/systemctl disable clamav-daemon +{user} ALL=(ALL) NOPASSWD: /usr/bin/systemctl is-active clamav-daemon +{user} ALL=(ALL) NOPASSWD: /usr/bin/freshclam + +# LMD (Linux Malware Detect) +{user} ALL=(ALL) NOPASSWD: /usr/local/sbin/maldet -a /home +{user} ALL=(ALL) NOPASSWD: /usr/local/sbin/maldet -a /var/www +{user} ALL=(ALL) NOPASSWD: /usr/local/sbin/maldet -a /tmp +{user} ALL=(ALL) NOPASSWD: /usr/local/sbin/maldet --update-sigs +{user} ALL=(ALL) NOPASSWD: /usr/local/sbin/maldet --update-ver +"#; + +const PACKAGES: &[&str] = &[ + "lynis", + "rkhunter", + "chkrootkit", + "suricata", + "clamav", + "clamav-daemon", +]; + +pub struct ProtectionInstaller { + user: String, +} + +impl ProtectionInstaller { + pub fn new() -> Result { + let user = std::env::var("SUDO_USER") + .or_else(|_| std::env::var("USER")) + .unwrap_or_else(|_| "root".to_string()); + + Ok(Self { user }) + } + + pub fn check_root() -> bool { + Command::new("id") + .arg("-u") + .output() + .map(|o| String::from_utf8_lossy(&o.stdout).trim() == "0") + .unwrap_or(false) + } + + pub fn install(&self) -> Result { + if !Self::check_root() { + return Err(anyhow::anyhow!( + "This command requires root privileges. Run with: sudo botserver install protection" + )); + } + + info!("Starting security protection installation for user: {}", self.user); + + let mut result = InstallResult::default(); + + match self.install_packages() { + Ok(installed) => { + result.packages_installed = installed; + info!("Packages installed: {:?}", result.packages_installed); + } + Err(e) => { + error!("Failed to install packages: {e}"); + result.errors.push(format!("Package installation failed: {e}")); + } + } + + match self.create_sudoers() { + Ok(()) => { + result.sudoers_created = true; + info!("Sudoers file created successfully"); + } + Err(e) => { + error!("Failed to create sudoers file: {e}"); + result.errors.push(format!("Sudoers creation failed: {e}")); + } + } + + match self.install_lmd() { + Ok(installed) => { + if installed { + result.packages_installed.push("maldetect".to_string()); + info!("LMD (maldetect) installed successfully"); + } + } + Err(e) => { + warn!("LMD installation skipped: {e}"); + result.warnings.push(format!("LMD installation skipped: {e}")); + } + } + + match self.update_databases() { + Ok(()) => { + result.databases_updated = true; + info!("Security databases updated"); + } + Err(e) => { + warn!("Database update failed: {e}"); + result.warnings.push(format!("Database update failed: {e}")); + } + } + + result.success = result.errors.is_empty(); + Ok(result) + } + + fn install_packages(&self) -> Result> { + info!("Updating package lists..."); + + SafeCommand::new("apt-get")? + .arg("update")? + .execute() + .context("Failed to update package lists")?; + + let mut installed = Vec::new(); + + for package in PACKAGES { + info!("Installing package: {package}"); + + let result = SafeCommand::new("apt-get")? + .arg("install")? + .arg("-y")? + .arg(package)? + .execute(); + + match result { + Ok(output) => { + if output.status.success() { + installed.push((*package).to_string()); + } else { + let stderr = String::from_utf8_lossy(&output.stderr); + warn!("Package {package} installation had issues: {stderr}"); + } + } + Err(e) => { + warn!("Failed to install {package}: {e}"); + } + } + } + + Ok(installed) + } + + fn create_sudoers(&self) -> Result<()> { + let content = SUDOERS_CONTENT.replace("{user}", &self.user); + + info!("Creating sudoers file at {SUDOERS_FILE}"); + + fs::write(SUDOERS_FILE, &content) + .context("Failed to write sudoers file")?; + + let permissions = fs::Permissions::from_mode(0o440); + fs::set_permissions(SUDOERS_FILE, permissions) + .context("Failed to set sudoers file permissions")?; + + self.validate_sudoers()?; + + info!("Sudoers file created and validated"); + Ok(()) + } + + fn validate_sudoers(&self) -> Result<()> { + let output = std::process::Command::new("visudo") + .args(["-c", "-f", SUDOERS_FILE]) + .output() + .context("Failed to run visudo validation")?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + fs::remove_file(SUDOERS_FILE).ok(); + return Err(anyhow::anyhow!("Invalid sudoers file syntax: {stderr}")); + } + + Ok(()) + } + + fn install_lmd(&self) -> Result { + let maldet_path = Path::new("/usr/local/sbin/maldet"); + if maldet_path.exists() { + info!("LMD already installed"); + return Ok(false); + } + + info!("Installing Linux Malware Detect (LMD)..."); + + let temp_dir = "/tmp/maldetect_install"; + fs::create_dir_all(temp_dir).ok(); + + let download_result = SafeCommand::new("curl")? + .arg("-sL")? + .arg("-o")? + .arg("/tmp/maldetect-current.tar.gz")? + .arg("https://www.rfxn.com/downloads/maldetect-current.tar.gz")? + .execute(); + + if download_result.is_err() { + return Err(anyhow::anyhow!("Failed to download LMD")); + } + + SafeCommand::new("tar")? + .arg("-xzf")? + .arg("/tmp/maldetect-current.tar.gz")? + .arg("-C")? + .arg(temp_dir)? + .execute() + .context("Failed to extract LMD archive")?; + + let entries = fs::read_dir(temp_dir)?; + let mut install_dir = None; + + for entry in entries.flatten() { + let path = entry.path(); + if path.is_dir() && path.file_name().is_some_and(|n| n.to_string_lossy().starts_with("maldetect")) { + install_dir = Some(path); + break; + } + } + + let install_dir = install_dir.ok_or_else(|| anyhow::anyhow!("LMD install directory not found"))?; + let install_script = install_dir.join("install.sh"); + + if !install_script.exists() { + return Err(anyhow::anyhow!("LMD install.sh not found")); + } + + SafeCommand::new("bash")? + .arg("-c")? + .shell_script_arg(&format!("cd {} && ./install.sh", install_dir.display()))? + .execute() + .context("Failed to run LMD installer")?; + + fs::remove_dir_all(temp_dir).ok(); + fs::remove_file("/tmp/maldetect-current.tar.gz").ok(); + + Ok(true) + } + + fn update_databases(&self) -> Result<()> { + info!("Updating security tool databases..."); + + if Path::new("/usr/bin/rkhunter").exists() { + info!("Updating RKHunter database..."); + let result = SafeCommand::new("rkhunter")? + .arg("--update")? + .execute(); + if let Err(e) = result { + warn!("RKHunter update failed: {e}"); + } + } + + if Path::new("/usr/bin/freshclam").exists() { + info!("Updating ClamAV signatures..."); + let result = SafeCommand::new("freshclam")? + .execute(); + if let Err(e) = result { + warn!("ClamAV update failed: {e}"); + } + } + + if Path::new("/usr/bin/suricata-update").exists() { + info!("Updating Suricata rules..."); + let result = SafeCommand::new("suricata-update")? + .execute(); + if let Err(e) = result { + warn!("Suricata update failed: {e}"); + } + } + + if Path::new("/usr/local/sbin/maldet").exists() { + info!("Updating LMD signatures..."); + let result = SafeCommand::new("maldet")? + .arg("--update-sigs")? + .execute(); + if let Err(e) = result { + warn!("LMD update failed: {e}"); + } + } + + Ok(()) + } + + pub fn uninstall(&self) -> Result { + if !Self::check_root() { + return Err(anyhow::anyhow!( + "This command requires root privileges. Run with: sudo botserver remove protection" + )); + } + + info!("Removing security protection components..."); + + let mut result = UninstallResult::default(); + + if Path::new(SUDOERS_FILE).exists() { + match fs::remove_file(SUDOERS_FILE) { + Ok(()) => { + result.sudoers_removed = true; + info!("Removed sudoers file"); + } + Err(e) => { + result.errors.push(format!("Failed to remove sudoers: {e}")); + } + } + } + + result.success = result.errors.is_empty(); + result.message = "Protection sudoers removed. Packages were NOT uninstalled - remove manually if needed.".to_string(); + + Ok(result) + } + + pub fn verify(&self) -> VerifyResult { + let mut result = VerifyResult::default(); + + for package in PACKAGES { + let binary = match *package { + "clamav" | "clamav-daemon" => "clamscan", + other => other, + }; + + let check = SafeCommand::new("which") + .and_then(|cmd| cmd.arg(binary)) + .and_then(|cmd| cmd.execute()); + + let installed = check.map(|o| o.status.success()).unwrap_or(false); + result.tools.push(ToolVerification { + name: (*package).to_string(), + installed, + sudo_configured: false, + }); + } + + let maldet_installed = Path::new("/usr/local/sbin/maldet").exists(); + result.tools.push(ToolVerification { + name: "maldetect".to_string(), + installed: maldet_installed, + sudo_configured: false, + }); + + result.sudoers_exists = Path::new(SUDOERS_FILE).exists(); + + if result.sudoers_exists { + if let Ok(content) = fs::read_to_string(SUDOERS_FILE) { + for tool in &mut result.tools { + tool.sudo_configured = content.contains(&tool.name) || + (tool.name == "clamav" && content.contains("clamav-daemon")) || + (tool.name == "clamav-daemon" && content.contains("clamav-daemon")); + } + } + } + + result.all_installed = result.tools.iter().filter(|t| t.name != "clamav-daemon").all(|t| t.installed); + result.all_configured = result.sudoers_exists && result.tools.iter().all(|t| t.sudo_configured || !t.installed); + + result + } +} + +impl Default for ProtectionInstaller { + fn default() -> Self { + Self::new().unwrap_or(Self { user: "root".to_string() }) + } +} + +#[derive(Debug, Default)] +pub struct InstallResult { + pub success: bool, + pub packages_installed: Vec, + pub sudoers_created: bool, + pub databases_updated: bool, + pub errors: Vec, + pub warnings: Vec, +} + +impl InstallResult { + pub fn print(&self) { + println!(); + if self.success { + println!("✓ Security Protection installed successfully!"); + } else { + println!("✗ Security Protection installation completed with errors"); + } + println!(); + + if !self.packages_installed.is_empty() { + println!("Packages installed:"); + for pkg in &self.packages_installed { + println!(" ✓ {pkg}"); + } + println!(); + } + + if self.sudoers_created { + println!("✓ Sudoers configuration created at {SUDOERS_FILE}"); + } + + if self.databases_updated { + println!("✓ Security databases updated"); + } + + if !self.warnings.is_empty() { + println!(); + println!("Warnings:"); + for warn in &self.warnings { + println!(" ⚠ {warn}"); + } + } + + if !self.errors.is_empty() { + println!(); + println!("Errors:"); + for err in &self.errors { + println!(" ✗ {err}"); + } + } + + println!(); + println!("The following commands are now available via the UI:"); + println!(" - Lynis security audits"); + println!(" - RKHunter rootkit scans"); + println!(" - Chkrootkit scans"); + println!(" - Suricata IDS management"); + println!(" - ClamAV antivirus scans"); + println!(" - LMD malware detection"); + } +} + +#[derive(Debug, Default)] +pub struct UninstallResult { + pub success: bool, + pub sudoers_removed: bool, + pub message: String, + pub errors: Vec, +} + +impl UninstallResult { + pub fn print(&self) { + println!(); + if self.success { + println!("✓ {}", self.message); + } else { + println!("✗ Uninstall completed with errors"); + for err in &self.errors { + println!(" ✗ {err}"); + } + } + } +} + +#[derive(Debug, Default)] +pub struct VerifyResult { + pub all_installed: bool, + pub all_configured: bool, + pub sudoers_exists: bool, + pub tools: Vec, +} + +#[derive(Debug, Default)] +pub struct ToolVerification { + pub name: String, + pub installed: bool, + pub sudo_configured: bool, +} + +impl VerifyResult { + pub fn print(&self) { + println!(); + println!("Security Protection Status:"); + println!(); + + println!("Tools:"); + for tool in &self.tools { + let installed_mark = if tool.installed { "✓" } else { "✗" }; + let sudo_mark = if tool.sudo_configured { "✓" } else { "✗" }; + println!(" {} {} (installed: {}, sudo: {})", + if tool.installed && tool.sudo_configured { "✓" } else { "⚠" }, + tool.name, + installed_mark, + sudo_mark + ); + } + + println!(); + println!("Sudoers file: {}", if self.sudoers_exists { "✓ exists" } else { "✗ missing" }); + println!(); + + if self.all_installed && self.all_configured { + println!("✓ All protection tools are properly configured"); + } else if !self.all_installed { + println!("⚠ Some tools are not installed. Run: sudo botserver install protection"); + } else { + println!("⚠ Sudoers not configured. Run: sudo botserver install protection"); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_install_result_default() { + let result = InstallResult::default(); + assert!(!result.success); + assert!(result.packages_installed.is_empty()); + assert!(!result.sudoers_created); + } + + #[test] + fn test_verify_result_default() { + let result = VerifyResult::default(); + assert!(!result.all_installed); + assert!(!result.all_configured); + assert!(result.tools.is_empty()); + } + + #[test] + fn test_sudoers_content_has_placeholder() { + assert!(SUDOERS_CONTENT.contains("{user}")); + } + + #[test] + fn test_sudoers_content_no_wildcards() { + assert!(!SUDOERS_CONTENT.contains(" * ")); + assert!(!SUDOERS_CONTENT.lines().any(|l| l.trim().ends_with('*'))); + } + + #[test] + fn test_packages_list() { + assert!(PACKAGES.contains(&"lynis")); + assert!(PACKAGES.contains(&"rkhunter")); + assert!(PACKAGES.contains(&"chkrootkit")); + assert!(PACKAGES.contains(&"suricata")); + assert!(PACKAGES.contains(&"clamav")); + } + + #[test] + fn test_tool_verification_default() { + let tool = ToolVerification::default(); + assert!(tool.name.is_empty()); + assert!(!tool.installed); + assert!(!tool.sudo_configured); + } + + #[test] + fn test_uninstall_result_default() { + let result = UninstallResult::default(); + assert!(!result.success); + assert!(!result.sudoers_removed); + assert!(result.message.is_empty()); + } + + #[test] + fn test_protection_installer_default() { + let installer = ProtectionInstaller::default(); + assert!(!installer.user.is_empty()); + } +} diff --git a/src/security/protection/lmd.rs b/src/security/protection/lmd.rs new file mode 100644 index 000000000..345404a91 --- /dev/null +++ b/src/security/protection/lmd.rs @@ -0,0 +1,481 @@ +use anyhow::{Context, Result}; +use tracing::info; + +use crate::security::command_guard::SafeCommand; +use super::manager::{Finding, FindingSeverity, ScanResultStatus}; + +const LMD_LOG_DIR: &str = "/usr/local/maldetect/logs"; +const LMD_QUARANTINE_DIR: &str = "/usr/local/maldetect/quarantine"; + +pub async fn run_scan(path: Option<&str>) -> Result<(ScanResultStatus, Vec, String)> { + info!("Running Linux Malware Detect scan"); + + let scan_path = path.unwrap_or("/var/www"); + + let output = SafeCommand::new("sudo")? + .arg("maldet")? + .arg("-a")? + .arg(scan_path)? + .execute() + .context("Failed to run LMD scan")?; + + let stdout = String::from_utf8_lossy(&output.stdout).to_string(); + let stderr = String::from_utf8_lossy(&output.stderr).to_string(); + let raw_output = format!("{stdout}\n{stderr}"); + + let findings = parse_lmd_output(&stdout); + let status = determine_result_status(&findings); + + Ok((status, findings, raw_output)) +} + +pub async fn run_background_scan(path: Option<&str>) -> Result { + info!("Starting LMD background scan"); + + let scan_path = path.unwrap_or("/var/www"); + + let output = SafeCommand::new("sudo")? + .arg("maldet")? + .arg("-b")? + .arg("-a")? + .arg(scan_path)? + .execute() + .context("Failed to start LMD background scan")?; + + let stdout = String::from_utf8_lossy(&output.stdout); + let scan_id = extract_scan_id(&stdout).unwrap_or_else(|| "unknown".to_string()); + + info!("LMD background scan started with ID: {scan_id}"); + Ok(scan_id) +} + +pub async fn update_signatures() -> Result<()> { + info!("Updating LMD signatures"); + + SafeCommand::new("sudo")? + .arg("maldet")? + .arg("--update-sigs")? + .execute() + .context("Failed to update LMD signatures")?; + + info!("LMD signatures updated successfully"); + Ok(()) +} + +pub async fn update_version() -> Result<()> { + info!("Updating LMD version"); + + SafeCommand::new("sudo")? + .arg("maldet")? + .arg("--update-ver")? + .execute() + .context("Failed to update LMD version")?; + + info!("LMD version updated successfully"); + Ok(()) +} + +pub async fn get_version() -> Result { + let output = SafeCommand::new("maldet")? + .arg("--version")? + .execute() + .context("Failed to get LMD version")?; + + let stdout = String::from_utf8_lossy(&output.stdout); + let version = stdout + .lines() + .find(|l| l.contains("maldet") || l.contains("version")) + .and_then(|l| l.split_whitespace().last()) + .unwrap_or("unknown") + .to_string(); + + Ok(version) +} + +pub async fn get_signature_count() -> Result { + let sig_dir = "/usr/local/maldetect/sigs"; + + let mut count = 0u64; + + if let Ok(entries) = std::fs::read_dir(sig_dir) { + for entry in entries.flatten() { + let path = entry.path(); + if path.is_file() { + if let Ok(content) = std::fs::read_to_string(&path) { + count += content.lines().filter(|l| !l.trim().is_empty()).count() as u64; + } + } + } + } + + Ok(count) +} + +pub async fn quarantine_file(file_path: &str) -> Result<()> { + info!("Quarantining file: {file_path}"); + + SafeCommand::new("sudo")? + .arg("maldet")? + .arg("-q")? + .arg(file_path)? + .execute() + .context("Failed to quarantine file")?; + + info!("File quarantined successfully: {file_path}"); + Ok(()) +} + +pub async fn restore_file(file_path: &str) -> Result<()> { + info!("Restoring file from quarantine: {file_path}"); + + SafeCommand::new("sudo")? + .arg("maldet")? + .arg("--restore")? + .arg(file_path)? + .execute() + .context("Failed to restore file from quarantine")?; + + info!("File restored successfully: {file_path}"); + Ok(()) +} + +pub async fn clean_file(file_path: &str) -> Result<()> { + info!("Cleaning infected file: {file_path}"); + + SafeCommand::new("sudo")? + .arg("maldet")? + .arg("-n")? + .arg(file_path)? + .execute() + .context("Failed to clean file")?; + + info!("File cleaned successfully: {file_path}"); + Ok(()) +} + +pub async fn get_report(scan_id: &str) -> Result { + info!("Retrieving LMD report for scan: {scan_id}"); + + let output = SafeCommand::new("sudo")? + .arg("maldet")? + .arg("--report")? + .arg(scan_id)? + .execute() + .context("Failed to get LMD report")?; + + let stdout = String::from_utf8_lossy(&output.stdout).to_string(); + Ok(stdout) +} + +pub async fn list_quarantined() -> Result> { + let mut files = Vec::new(); + + if let Ok(entries) = std::fs::read_dir(LMD_QUARANTINE_DIR) { + for entry in entries.flatten() { + let path = entry.path(); + if path.is_file() { + let filename = path.file_name() + .and_then(|n| n.to_str()) + .unwrap_or("unknown") + .to_string(); + + let metadata = std::fs::metadata(&path).ok(); + let size = metadata.as_ref().map(|m| m.len()).unwrap_or(0); + let quarantined_at = metadata + .and_then(|m| m.modified().ok()) + .map(|t| chrono::DateTime::::from(t)); + + files.push(QuarantinedFile { + id: filename.clone(), + original_path: extract_original_path(&filename), + quarantine_path: path.to_string_lossy().to_string(), + size, + quarantined_at, + threat_name: None, + }); + } + } + } + + Ok(files) +} + +pub fn parse_lmd_output(output: &str) -> Vec { + let mut findings = Vec::new(); + + for line in output.lines() { + let trimmed = line.trim(); + + if trimmed.contains("HIT") || trimmed.contains("FOUND") { + let parts: Vec<&str> = trimmed.split_whitespace().collect(); + let file_path = parts.iter().find(|p| p.starts_with('/')).map(|s| s.to_string()); + let threat_name = parts.iter() + .find(|p| p.contains("malware") || p.contains("backdoor") || p.contains("trojan")) + .map(|s| s.to_string()) + .unwrap_or_else(|| "Malware".to_string()); + + let finding = Finding { + id: format!("lmd-hit-{}", findings.len()), + severity: FindingSeverity::Critical, + category: "Malware Detection".to_string(), + title: format!("Malware Detected: {threat_name}"), + description: trimmed.to_string(), + file_path, + remediation: Some("Quarantine or remove the infected file immediately".to_string()), + }; + findings.push(finding); + } + + if trimmed.contains("suspicious") || trimmed.contains("Suspicious") { + let file_path = extract_file_path_from_line(trimmed); + + let finding = Finding { + id: format!("lmd-susp-{}", findings.len()), + severity: FindingSeverity::High, + category: "Suspicious Activity".to_string(), + title: "Suspicious File Detected".to_string(), + description: trimmed.to_string(), + file_path, + remediation: Some("Review the file and consider quarantine if malicious".to_string()), + }; + findings.push(finding); + } + + if trimmed.contains("warning") || trimmed.contains("Warning") { + let finding = Finding { + id: format!("lmd-warn-{}", findings.len()), + severity: FindingSeverity::Medium, + category: "Warning".to_string(), + title: "LMD Warning".to_string(), + description: trimmed.to_string(), + file_path: None, + remediation: None, + }; + findings.push(finding); + } + } + + findings +} + +fn extract_scan_id(output: &str) -> Option { + for line in output.lines() { + if line.contains("scan id:") || line.contains("SCAN ID:") { + return line.split(':').nth(1).map(|s| s.trim().to_string()); + } + if line.contains("report") && line.contains(".") { + let parts: Vec<&str> = line.split_whitespace().collect(); + for part in parts { + if part.contains('.') && part.chars().all(|c| c.is_numeric() || c == '.') { + return Some(part.to_string()); + } + } + } + } + None +} + +fn extract_file_path_from_line(line: &str) -> Option { + let words: Vec<&str> = line.split_whitespace().collect(); + for word in words { + if word.starts_with('/') { + return Some(word.trim_matches(|c| c == ':' || c == ',' || c == ';').to_string()); + } + } + None +} + +fn extract_original_path(quarantine_filename: &str) -> String { + quarantine_filename + .replace(".", "/") + .trim_start_matches('/') + .to_string() +} + +fn determine_result_status(findings: &[Finding]) -> ScanResultStatus { + let has_critical = findings.iter().any(|f| f.severity == FindingSeverity::Critical); + let has_high = findings.iter().any(|f| f.severity == FindingSeverity::High); + let has_medium = findings.iter().any(|f| f.severity == FindingSeverity::Medium); + + if has_critical { + ScanResultStatus::Infected + } else if has_high { + ScanResultStatus::Warnings + } else if has_medium { + ScanResultStatus::Warnings + } else { + ScanResultStatus::Clean + } +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct QuarantinedFile { + pub id: String, + pub original_path: String, + pub quarantine_path: String, + pub size: u64, + pub quarantined_at: Option>, + pub threat_name: Option, +} + +#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)] +pub struct LMDStats { + pub signature_count: u64, + pub quarantined_count: u32, + pub last_scan: Option>, + pub threats_found: u32, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_lmd_output_clean() { + let output = r#" +Linux Malware Detect v1.6.5 +Scanning /var/www +Total files scanned: 1234 +Total hits: 0 +Total cleaned: 0 +"#; + let findings = parse_lmd_output(output); + assert!(findings.is_empty()); + } + + #[test] + fn test_parse_lmd_output_hit() { + let output = r#" +Linux Malware Detect v1.6.5 +Scanning /var/www +{HIT} /var/www/uploads/shell.php : php.cmdshell.unclassed.6 +Total hits: 1 +"#; + let findings = parse_lmd_output(output); + assert_eq!(findings.len(), 1); + assert_eq!(findings[0].severity, FindingSeverity::Critical); + } + + #[test] + fn test_parse_lmd_output_suspicious() { + let output = r#" +Linux Malware Detect v1.6.5 +Scanning /var/www +suspicious file found: /var/www/uploads/unknown.php +"#; + let findings = parse_lmd_output(output); + assert_eq!(findings.len(), 1); + assert_eq!(findings[0].severity, FindingSeverity::High); + } + + #[test] + fn test_extract_scan_id() { + assert_eq!( + extract_scan_id("scan id: 123456.789"), + Some("123456.789".to_string()) + ); + assert_eq!( + extract_scan_id("SCAN ID: abc123"), + Some("abc123".to_string()) + ); + assert_eq!(extract_scan_id("no scan id here"), None); + } + + #[test] + fn test_extract_file_path_from_line() { + assert_eq!( + extract_file_path_from_line("Found malware in /var/www/shell.php"), + Some("/var/www/shell.php".to_string()) + ); + assert_eq!( + extract_file_path_from_line("No path here"), + None + ); + } + + #[test] + fn test_extract_original_path() { + assert_eq!( + extract_original_path("var.www.uploads.shell.php"), + "var/www/uploads/shell/php" + ); + } + + #[test] + fn test_determine_result_status_clean() { + let findings: Vec = vec![]; + assert_eq!(determine_result_status(&findings), ScanResultStatus::Clean); + } + + #[test] + fn test_determine_result_status_infected() { + let findings = vec![Finding { + id: "test".to_string(), + severity: FindingSeverity::Critical, + category: "test".to_string(), + title: "Test".to_string(), + description: "Test".to_string(), + file_path: None, + remediation: None, + }]; + assert_eq!(determine_result_status(&findings), ScanResultStatus::Infected); + } + + #[test] + fn test_determine_result_status_warnings() { + let findings = vec![Finding { + id: "test".to_string(), + severity: FindingSeverity::High, + category: "test".to_string(), + title: "Test".to_string(), + description: "Test".to_string(), + file_path: None, + remediation: None, + }]; + assert_eq!(determine_result_status(&findings), ScanResultStatus::Warnings); + } + + #[test] + fn test_quarantined_file_struct() { + let file = QuarantinedFile { + id: "test".to_string(), + original_path: "/var/www/shell.php".to_string(), + quarantine_path: "/usr/local/maldetect/quarantine/test".to_string(), + size: 1024, + quarantined_at: None, + threat_name: Some("php.cmdshell".to_string()), + }; + assert_eq!(file.size, 1024); + assert!(file.threat_name.is_some()); + } + + #[test] + fn test_lmd_stats_default() { + let stats = LMDStats::default(); + assert_eq!(stats.signature_count, 0); + assert_eq!(stats.quarantined_count, 0); + assert!(stats.last_scan.is_none()); + } + + #[test] + fn test_parse_lmd_output_warning() { + let output = r#" +Linux Malware Detect v1.6.5 +Warning: signature database may be outdated +Scanning /var/www +"#; + let findings = parse_lmd_output(output); + assert_eq!(findings.len(), 1); + assert_eq!(findings[0].severity, FindingSeverity::Medium); + } + + #[test] + fn test_parse_lmd_output_found() { + let output = r#" +FOUND: /var/www/malicious.php : malware.backdoor.123 +"#; + let findings = parse_lmd_output(output); + assert_eq!(findings.len(), 1); + assert_eq!(findings[0].severity, FindingSeverity::Critical); + } +} diff --git a/src/security/protection/lynis.rs b/src/security/protection/lynis.rs new file mode 100644 index 000000000..52e4e391f --- /dev/null +++ b/src/security/protection/lynis.rs @@ -0,0 +1,273 @@ +use anyhow::{Context, Result}; +use tracing::{info, warn}; + +use crate::security::command_guard::SafeCommand; +use super::manager::{Finding, FindingSeverity, ScanResultStatus}; + +const LYNIS_REPORT_PATH: &str = "/var/log/lynis-report.dat"; +const LYNIS_LOG_PATH: &str = "/var/log/lynis.log"; + +pub async fn run_scan() -> Result<(ScanResultStatus, Vec, String)> { + info!("Running Lynis security audit"); + + let output = SafeCommand::new("sudo")? + .arg("lynis")? + .arg("audit")? + .arg("system")? + .arg("--quick")? + .arg("--no-colors")? + .execute() + .context("Failed to run Lynis audit")?; + + let stdout = String::from_utf8_lossy(&output.stdout).to_string(); + let stderr = String::from_utf8_lossy(&output.stderr).to_string(); + let raw_output = format!("{stdout}\n{stderr}"); + + let findings = parse_lynis_output(&stdout); + let status = determine_result_status(&findings); + + Ok((status, findings, raw_output)) +} + +pub async fn run_full_audit() -> Result<(ScanResultStatus, Vec, String)> { + info!("Running full Lynis security audit"); + + let output = SafeCommand::new("sudo")? + .arg("lynis")? + .arg("audit")? + .arg("system")? + .arg("--no-colors")? + .execute() + .context("Failed to run full Lynis audit")?; + + let stdout = String::from_utf8_lossy(&output.stdout).to_string(); + let stderr = String::from_utf8_lossy(&output.stderr).to_string(); + let raw_output = format!("{stdout}\n{stderr}"); + + let findings = parse_lynis_output(&stdout); + let status = determine_result_status(&findings); + + Ok((status, findings, raw_output)) +} + +pub fn parse_lynis_output(output: &str) -> Vec { + let mut findings = Vec::new(); + let mut current_category = String::new(); + + for line in output.lines() { + let trimmed = line.trim(); + + if trimmed.starts_with('[') && trimmed.contains(']') { + if let Some(category) = extract_category(trimmed) { + current_category = category; + } + } + + if trimmed.contains("Warning:") || trimmed.contains("WARNING") { + let finding = Finding { + id: format!("lynis-warn-{}", findings.len()), + severity: FindingSeverity::Medium, + category: current_category.clone(), + title: "Security Warning".to_string(), + description: trimmed.replace("Warning:", "").trim().to_string(), + file_path: None, + remediation: None, + }; + findings.push(finding); + } + + if trimmed.contains("Suggestion:") || trimmed.contains("SUGGESTION") { + let finding = Finding { + id: format!("lynis-sugg-{}", findings.len()), + severity: FindingSeverity::Low, + category: current_category.clone(), + title: "Security Suggestion".to_string(), + description: trimmed.replace("Suggestion:", "").trim().to_string(), + file_path: None, + remediation: extract_remediation(trimmed), + }; + findings.push(finding); + } + + if trimmed.contains("[FOUND]") && trimmed.contains("vulnerable") { + let finding = Finding { + id: format!("lynis-vuln-{}", findings.len()), + severity: FindingSeverity::High, + category: current_category.clone(), + title: "Vulnerability Found".to_string(), + description: trimmed.to_string(), + file_path: None, + remediation: None, + }; + findings.push(finding); + } + } + + findings +} + +pub fn parse_report_file() -> Result { + let content = std::fs::read_to_string(LYNIS_REPORT_PATH) + .context("Failed to read Lynis report file")?; + + let mut report = LynisReport::default(); + + for line in content.lines() { + if line.starts_with('#') || line.trim().is_empty() { + continue; + } + + if let Some((key, value)) = line.split_once('=') { + match key { + "hardening_index" => { + report.hardening_index = value.parse().unwrap_or(0); + } + "warning[]" => { + report.warnings.push(value.to_string()); + } + "suggestion[]" => { + report.suggestions.push(value.to_string()); + } + "lynis_version" => { + report.version = value.to_string(); + } + "test_category[]" => { + report.categories_tested.push(value.to_string()); + } + "tests_executed" => { + report.tests_executed = value.parse().unwrap_or(0); + } + _ => {} + } + } + } + + Ok(report) +} + +pub async fn get_hardening_index() -> Result { + let report = parse_report_file()?; + Ok(report.hardening_index) +} + +pub async fn apply_suggestion(suggestion_id: &str) -> Result<()> { + info!("Applying Lynis suggestion: {suggestion_id}"); + + warn!("Auto-remediation for suggestion {suggestion_id} not yet implemented"); + + Ok(()) +} + +fn extract_category(line: &str) -> Option { + let start = line.find('[')?; + let end = line.find(']')?; + if start < end { + Some(line[start + 1..end].trim().to_string()) + } else { + None + } +} + +fn extract_remediation(line: &str) -> Option { + if line.contains("Consider") { + Some(line.split("Consider").nth(1)?.trim().to_string()) + } else if line.contains("Disable") { + Some(line.split("Disable").nth(1).map(|s| format!("Disable {}", s.trim()))?) + } else if line.contains("Enable") { + Some(line.split("Enable").nth(1).map(|s| format!("Enable {}", s.trim()))?) + } else { + None + } +} + +fn determine_result_status(findings: &[Finding]) -> ScanResultStatus { + let has_critical = findings.iter().any(|f| f.severity == FindingSeverity::Critical); + let has_high = findings.iter().any(|f| f.severity == FindingSeverity::High); + let has_medium = findings.iter().any(|f| f.severity == FindingSeverity::Medium); + + if has_critical || has_high { + ScanResultStatus::Infected + } else if has_medium { + ScanResultStatus::Warnings + } else { + ScanResultStatus::Clean + } +} + +#[derive(Debug, Clone, Default)] +pub struct LynisReport { + pub version: String, + pub hardening_index: u32, + pub tests_executed: u32, + pub warnings: Vec, + pub suggestions: Vec, + pub categories_tested: Vec, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_lynis_output_warnings() { + let output = r#" +[+] Boot and services + - Service Manager [ systemd ] + Warning: Some warning message here +[+] Kernel + Suggestion: Consider enabling some feature +"#; + let findings = parse_lynis_output(output); + assert_eq!(findings.len(), 2); + assert_eq!(findings[0].severity, FindingSeverity::Medium); + assert_eq!(findings[1].severity, FindingSeverity::Low); + } + + #[test] + fn test_extract_category() { + assert_eq!(extract_category("[+] Boot and services"), Some("+ Boot and services".to_string())); + assert_eq!(extract_category("no brackets"), None); + } + + #[test] + fn test_determine_result_status_clean() { + let findings: Vec = vec![]; + assert_eq!(determine_result_status(&findings), ScanResultStatus::Clean); + } + + #[test] + fn test_determine_result_status_warnings() { + let findings = vec![Finding { + id: "test".to_string(), + severity: FindingSeverity::Medium, + category: "test".to_string(), + title: "Test".to_string(), + description: "Test".to_string(), + file_path: None, + remediation: None, + }]; + assert_eq!(determine_result_status(&findings), ScanResultStatus::Warnings); + } + + #[test] + fn test_determine_result_status_infected() { + let findings = vec![Finding { + id: "test".to_string(), + severity: FindingSeverity::High, + category: "test".to_string(), + title: "Test".to_string(), + description: "Test".to_string(), + file_path: None, + remediation: None, + }]; + assert_eq!(determine_result_status(&findings), ScanResultStatus::Infected); + } + + #[test] + fn test_lynis_report_default() { + let report = LynisReport::default(); + assert_eq!(report.hardening_index, 0); + assert!(report.warnings.is_empty()); + assert!(report.suggestions.is_empty()); + } +} diff --git a/src/security/protection/manager.rs b/src/security/protection/manager.rs new file mode 100644 index 000000000..3571aadf0 --- /dev/null +++ b/src/security/protection/manager.rs @@ -0,0 +1,621 @@ +use anyhow::{Context, Result}; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; +use tracing::{info, warn}; + +use crate::security::command_guard::SafeCommand; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ProtectionTool { + Lynis, + RKHunter, + Chkrootkit, + Suricata, + LMD, + ClamAV, +} + +impl std::fmt::Display for ProtectionTool { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Lynis => write!(f, "lynis"), + Self::RKHunter => write!(f, "rkhunter"), + Self::Chkrootkit => write!(f, "chkrootkit"), + Self::Suricata => write!(f, "suricata"), + Self::LMD => write!(f, "lmd"), + Self::ClamAV => write!(f, "clamav"), + } + } +} + +impl ProtectionTool { + pub fn from_str(s: &str) -> Option { + match s.to_lowercase().as_str() { + "lynis" => Some(Self::Lynis), + "rkhunter" => Some(Self::RKHunter), + "chkrootkit" => Some(Self::Chkrootkit), + "suricata" => Some(Self::Suricata), + "lmd" | "maldet" => Some(Self::LMD), + "clamav" | "clamscan" => Some(Self::ClamAV), + _ => None, + } + } + + pub fn binary_name(&self) -> &'static str { + match self { + Self::Lynis => "lynis", + Self::RKHunter => "rkhunter", + Self::Chkrootkit => "chkrootkit", + Self::Suricata => "suricata", + Self::LMD => "maldet", + Self::ClamAV => "clamscan", + } + } + + pub fn service_name(&self) -> Option<&'static str> { + match self { + Self::Suricata => Some("suricata"), + Self::ClamAV => Some("clamav-daemon"), + _ => None, + } + } + + pub fn package_name(&self) -> &'static str { + match self { + Self::Lynis => "lynis", + Self::RKHunter => "rkhunter", + Self::Chkrootkit => "chkrootkit", + Self::Suricata => "suricata", + Self::LMD => "maldetect", + Self::ClamAV => "clamav", + } + } + + pub fn all() -> Vec { + vec![ + Self::Lynis, + Self::RKHunter, + Self::Chkrootkit, + Self::Suricata, + Self::LMD, + Self::ClamAV, + ] + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolStatus { + pub tool: ProtectionTool, + pub installed: bool, + pub version: Option, + pub service_running: Option, + pub last_scan: Option>, + pub last_update: Option>, + pub auto_update: bool, + pub auto_remediate: bool, + pub metrics: ToolMetrics, +} + +impl ToolStatus { + pub fn not_installed(tool: ProtectionTool) -> Self { + Self { + tool, + installed: false, + version: None, + service_running: None, + last_scan: None, + last_update: None, + auto_update: false, + auto_remediate: false, + metrics: ToolMetrics::default(), + } + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct ToolMetrics { + pub hardening_index: Option, + pub warnings: u32, + pub suggestions: u32, + pub threats_found: u32, + pub rules_count: Option, + pub alerts_today: u32, + pub blocked_today: u32, + pub signatures_count: Option, + pub quarantined_count: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ScanResult { + pub scan_id: String, + pub tool: ProtectionTool, + pub started_at: DateTime, + pub completed_at: Option>, + pub status: ScanStatus, + pub result: ScanResultStatus, + pub findings: Vec, + pub warnings: u32, + pub report_path: Option, + pub raw_output: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ScanStatus { + Pending, + Running, + Completed, + Failed, + Cancelled, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ScanResultStatus { + Clean, + Warnings, + Infected, + Unknown, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Finding { + pub id: String, + pub severity: FindingSeverity, + pub category: String, + pub title: String, + pub description: String, + pub file_path: Option, + pub remediation: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum FindingSeverity { + Info, + Low, + Medium, + High, + Critical, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ProtectionConfig { + pub enabled_tools: Vec, + pub auto_scan_interval_hours: u32, + pub auto_update_interval_hours: u32, + pub quarantine_dir: String, + pub log_dir: String, +} + +impl Default for ProtectionConfig { + fn default() -> Self { + Self { + enabled_tools: ProtectionTool::all(), + auto_scan_interval_hours: 24, + auto_update_interval_hours: 6, + quarantine_dir: "/var/lib/gb/quarantine".to_string(), + log_dir: "/var/log/gb/security".to_string(), + } + } +} + +pub struct ProtectionManager { + config: ProtectionConfig, + tool_status: Arc>>, + active_scans: Arc>>, + scan_history: Arc>>, +} + +impl ProtectionManager { + pub fn new(config: ProtectionConfig) -> Self { + Self { + config, + tool_status: Arc::new(RwLock::new(HashMap::new())), + active_scans: Arc::new(RwLock::new(HashMap::new())), + scan_history: Arc::new(RwLock::new(Vec::new())), + } + } + + pub async fn initialize(&self) -> Result<()> { + info!("Initializing Protection Manager"); + for tool in &self.config.enabled_tools { + let status = self.check_tool_status(*tool).await?; + self.tool_status.write().await.insert(*tool, status); + } + Ok(()) + } + + pub async fn check_tool_status(&self, tool: ProtectionTool) -> Result { + let installed = self.is_tool_installed(tool).await; + + if !installed { + return Ok(ToolStatus::not_installed(tool)); + } + + let version = self.get_tool_version(tool).await.ok(); + let service_running = if tool.service_name().is_some() { + Some(self.is_service_running(tool).await) + } else { + None + }; + + let stored = self.tool_status.read().await; + let existing = stored.get(&tool); + + Ok(ToolStatus { + tool, + installed: true, + version, + service_running, + last_scan: existing.and_then(|s| s.last_scan), + last_update: existing.and_then(|s| s.last_update), + auto_update: existing.map(|s| s.auto_update).unwrap_or(false), + auto_remediate: existing.map(|s| s.auto_remediate).unwrap_or(false), + metrics: existing.map(|s| s.metrics.clone()).unwrap_or_default(), + }) + } + + pub async fn is_tool_installed(&self, tool: ProtectionTool) -> bool { + let binary = tool.binary_name(); + + let result = SafeCommand::new("which") + .and_then(|cmd| cmd.arg(binary)) + .and_then(|cmd| cmd.execute()); + + match result { + Ok(output) => output.status.success(), + Err(e) => { + warn!("Failed to check if {tool} is installed: {e}"); + false + } + } + } + + pub async fn get_tool_version(&self, tool: ProtectionTool) -> Result { + let binary = tool.binary_name(); + let version_arg = match tool { + ProtectionTool::Lynis => "--version", + ProtectionTool::RKHunter => "--version", + ProtectionTool::Chkrootkit => "-V", + ProtectionTool::Suricata => "--build-info", + ProtectionTool::LMD => "--version", + ProtectionTool::ClamAV => "--version", + }; + + let output = SafeCommand::new(binary)? + .arg(version_arg)? + .execute() + .context("Failed to get tool version")?; + + let stdout = String::from_utf8_lossy(&output.stdout); + let version = stdout.lines().next().unwrap_or("unknown").trim().to_string(); + Ok(version) + } + + pub async fn is_service_running(&self, tool: ProtectionTool) -> bool { + let Some(service_name) = tool.service_name() else { + return false; + }; + + let result = SafeCommand::new("sudo") + .and_then(|cmd| cmd.arg("systemctl")) + .and_then(|cmd| cmd.arg("is-active")) + .and_then(|cmd| cmd.arg(service_name)); + + match result { + Ok(cmd) => match cmd.execute() { + Ok(output) => { + let stdout = String::from_utf8_lossy(&output.stdout); + stdout.trim() == "active" + } + Err(_) => false, + }, + Err(_) => false, + } + } + + pub async fn get_all_status(&self) -> HashMap { + self.tool_status.read().await.clone() + } + + pub async fn get_tool_status_by_name(&self, name: &str) -> Option { + let tool = ProtectionTool::from_str(name)?; + self.tool_status.read().await.get(&tool).cloned() + } + + pub async fn install_tool(&self, tool: ProtectionTool) -> Result<()> { + info!("Installing protection tool: {tool}"); + + let package = tool.package_name(); + + SafeCommand::new("apt-get")? + .arg("install")? + .arg("-y")? + .arg(package)? + .execute() + .context("Failed to install tool")?; + + let status = self.check_tool_status(tool).await?; + self.tool_status.write().await.insert(tool, status); + + info!("Successfully installed {tool}"); + Ok(()) + } + + pub async fn start_service(&self, tool: ProtectionTool) -> Result<()> { + let Some(service_name) = tool.service_name() else { + return Err(anyhow::anyhow!("{tool} does not have a service")); + }; + + info!("Starting service: {service_name}"); + + SafeCommand::new("sudo")? + .arg("systemctl")? + .arg("start")? + .arg(service_name)? + .execute() + .context("Failed to start service")?; + + if let Some(status) = self.tool_status.write().await.get_mut(&tool) { + status.service_running = Some(true); + } + + Ok(()) + } + + pub async fn stop_service(&self, tool: ProtectionTool) -> Result<()> { + let Some(service_name) = tool.service_name() else { + return Err(anyhow::anyhow!("{tool} does not have a service")); + }; + + info!("Stopping service: {service_name}"); + + SafeCommand::new("sudo")? + .arg("systemctl")? + .arg("stop")? + .arg(service_name)? + .execute() + .context("Failed to stop service")?; + + if let Some(status) = self.tool_status.write().await.get_mut(&tool) { + status.service_running = Some(false); + } + + Ok(()) + } + + pub async fn enable_service(&self, tool: ProtectionTool) -> Result<()> { + let Some(service_name) = tool.service_name() else { + return Err(anyhow::anyhow!("{tool} does not have a service")); + }; + + SafeCommand::new("sudo")? + .arg("systemctl")? + .arg("enable")? + .arg(service_name)? + .execute() + .context("Failed to enable service")?; + + Ok(()) + } + + pub async fn disable_service(&self, tool: ProtectionTool) -> Result<()> { + let Some(service_name) = tool.service_name() else { + return Err(anyhow::anyhow!("{tool} does not have a service")); + }; + + SafeCommand::new("sudo")? + .arg("systemctl")? + .arg("disable")? + .arg(service_name)? + .execute() + .context("Failed to disable service")?; + + Ok(()) + } + + pub async fn run_scan(&self, tool: ProtectionTool) -> Result { + let scan_id = uuid::Uuid::new_v4().to_string(); + let started_at = Utc::now(); + + info!("Starting {tool} scan with ID: {scan_id}"); + + let mut result = ScanResult { + scan_id: scan_id.clone(), + tool, + started_at, + completed_at: None, + status: ScanStatus::Running, + result: ScanResultStatus::Unknown, + findings: Vec::new(), + warnings: 0, + report_path: None, + raw_output: None, + }; + + self.active_scans.write().await.insert(scan_id.clone(), result.clone()); + + let scan_output = match tool { + ProtectionTool::Lynis => super::lynis::run_scan().await, + ProtectionTool::RKHunter => super::rkhunter::run_scan().await, + ProtectionTool::Chkrootkit => super::chkrootkit::run_scan().await, + ProtectionTool::Suricata => super::suricata::get_alerts().await, + ProtectionTool::LMD => super::lmd::run_scan(None).await, + ProtectionTool::ClamAV => super::lynis::run_scan().await, + }; + + result.completed_at = Some(Utc::now()); + + match scan_output { + Ok((status, findings, raw)) => { + result.status = ScanStatus::Completed; + result.result = status; + result.warnings = findings.iter().filter(|f| f.severity == FindingSeverity::Medium || f.severity == FindingSeverity::Low).count() as u32; + result.findings = findings; + result.raw_output = Some(raw); + } + Err(e) => { + warn!("Scan failed for {tool}: {e}"); + result.status = ScanStatus::Failed; + result.raw_output = Some(e.to_string()); + } + } + + self.active_scans.write().await.remove(&scan_id); + + if let Some(status) = self.tool_status.write().await.get_mut(&tool) { + status.last_scan = Some(Utc::now()); + status.metrics.warnings = result.warnings; + status.metrics.threats_found = result.findings.iter() + .filter(|f| f.severity == FindingSeverity::High || f.severity == FindingSeverity::Critical) + .count() as u32; + } + + self.scan_history.write().await.push(result.clone()); + + Ok(result) + } + + pub async fn update_definitions(&self, tool: ProtectionTool) -> Result<()> { + info!("Updating definitions for {tool}"); + + match tool { + ProtectionTool::RKHunter => { + SafeCommand::new("sudo")? + .arg("rkhunter")? + .arg("--update")? + .execute()?; + } + ProtectionTool::ClamAV => { + SafeCommand::new("sudo")? + .arg("freshclam")? + .execute()?; + } + ProtectionTool::Suricata => { + SafeCommand::new("sudo")? + .arg("suricata-update")? + .execute()?; + } + ProtectionTool::LMD => { + SafeCommand::new("sudo")? + .arg("maldet")? + .arg("--update-sigs")? + .execute()?; + } + _ => { + return Err(anyhow::anyhow!("{tool} does not support definition updates")); + } + } + + if let Some(status) = self.tool_status.write().await.get_mut(&tool) { + status.last_update = Some(Utc::now()); + } + + Ok(()) + } + + pub async fn set_auto_update(&self, tool: ProtectionTool, enabled: bool) -> Result<()> { + if let Some(status) = self.tool_status.write().await.get_mut(&tool) { + status.auto_update = enabled; + } + Ok(()) + } + + pub async fn set_auto_remediate(&self, tool: ProtectionTool, enabled: bool) -> Result<()> { + if let Some(status) = self.tool_status.write().await.get_mut(&tool) { + status.auto_remediate = enabled; + } + Ok(()) + } + + pub async fn get_scan_history(&self, tool: Option, limit: usize) -> Vec { + let history = self.scan_history.read().await; + history + .iter() + .filter(|s| tool.is_none() || Some(s.tool) == tool) + .rev() + .take(limit) + .cloned() + .collect() + } + + pub async fn get_active_scans(&self) -> Vec { + self.active_scans.read().await.values().cloned().collect() + } + + pub async fn get_report(&self, tool: ProtectionTool) -> Result { + let history = self.scan_history.read().await; + let latest = history + .iter() + .filter(|s| s.tool == tool) + .last() + .ok_or_else(|| anyhow::anyhow!("No scan results found for {tool}"))?; + + latest.raw_output.clone().ok_or_else(|| anyhow::anyhow!("No report available")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_protection_tool_from_str() { + assert_eq!(ProtectionTool::from_str("lynis"), Some(ProtectionTool::Lynis)); + assert_eq!(ProtectionTool::from_str("LYNIS"), Some(ProtectionTool::Lynis)); + assert_eq!(ProtectionTool::from_str("rkhunter"), Some(ProtectionTool::RKHunter)); + assert_eq!(ProtectionTool::from_str("clamav"), Some(ProtectionTool::ClamAV)); + assert_eq!(ProtectionTool::from_str("clamscan"), Some(ProtectionTool::ClamAV)); + assert_eq!(ProtectionTool::from_str("maldet"), Some(ProtectionTool::LMD)); + assert_eq!(ProtectionTool::from_str("unknown"), None); + } + + #[test] + fn test_protection_tool_display() { + assert_eq!(format!("{}", ProtectionTool::Lynis), "lynis"); + assert_eq!(format!("{}", ProtectionTool::ClamAV), "clamav"); + } + + #[test] + fn test_tool_status_not_installed() { + let status = ToolStatus::not_installed(ProtectionTool::Lynis); + assert!(!status.installed); + assert!(status.version.is_none()); + assert!(status.service_running.is_none()); + } + + #[test] + fn test_protection_config_default() { + let config = ProtectionConfig::default(); + assert_eq!(config.auto_scan_interval_hours, 24); + assert_eq!(config.auto_update_interval_hours, 6); + assert_eq!(config.enabled_tools.len(), 6); + } + + #[test] + fn test_protection_tool_all() { + let all = ProtectionTool::all(); + assert_eq!(all.len(), 6); + assert!(all.contains(&ProtectionTool::Lynis)); + assert!(all.contains(&ProtectionTool::ClamAV)); + } + + #[test] + fn test_finding_severity() { + let finding = Finding { + id: "test".to_string(), + severity: FindingSeverity::High, + category: "security".to_string(), + title: "Test".to_string(), + description: "Test finding".to_string(), + file_path: None, + remediation: None, + }; + assert_eq!(finding.severity, FindingSeverity::High); + } +} diff --git a/src/security/protection/mod.rs b/src/security/protection/mod.rs new file mode 100644 index 000000000..5d73aa4ba --- /dev/null +++ b/src/security/protection/mod.rs @@ -0,0 +1,12 @@ +pub mod api; +pub mod chkrootkit; +pub mod installer; +pub mod lmd; +pub mod lynis; +pub mod manager; +pub mod rkhunter; +pub mod suricata; + +pub use api::configure_protection_routes; +pub use installer::{InstallResult, ProtectionInstaller, UninstallResult, VerifyResult}; +pub use manager::{ProtectionManager, ProtectionTool, ToolStatus}; diff --git a/src/security/protection/rkhunter.rs b/src/security/protection/rkhunter.rs new file mode 100644 index 000000000..38075a265 --- /dev/null +++ b/src/security/protection/rkhunter.rs @@ -0,0 +1,320 @@ +use anyhow::{Context, Result}; +use tracing::info; + +use crate::security::command_guard::SafeCommand; +use super::manager::{Finding, FindingSeverity, ScanResultStatus}; + +const RKHUNTER_LOG_PATH: &str = "/var/log/rkhunter.log"; + +pub async fn run_scan() -> Result<(ScanResultStatus, Vec, String)> { + info!("Running RKHunter rootkit scan"); + + let output = SafeCommand::new("sudo")? + .arg("rkhunter")? + .arg("--check")? + .arg("--skip-keypress")? + .arg("--report-warnings-only")? + .execute() + .context("Failed to run RKHunter scan")?; + + let stdout = String::from_utf8_lossy(&output.stdout).to_string(); + let stderr = String::from_utf8_lossy(&output.stderr).to_string(); + let raw_output = format!("{stdout}\n{stderr}"); + + let findings = parse_rkhunter_output(&stdout); + let status = determine_result_status(&findings); + + Ok((status, findings, raw_output)) +} + +pub async fn update_database() -> Result<()> { + info!("Updating RKHunter database"); + + SafeCommand::new("sudo")? + .arg("rkhunter")? + .arg("--update")? + .execute() + .context("Failed to update RKHunter database")?; + + info!("RKHunter database updated successfully"); + Ok(()) +} + +pub async fn update_properties() -> Result<()> { + info!("Updating RKHunter file properties database"); + + SafeCommand::new("sudo")? + .arg("rkhunter")? + .arg("--propupd")? + .execute() + .context("Failed to update RKHunter properties")?; + + info!("RKHunter properties updated successfully"); + Ok(()) +} + +pub fn parse_rkhunter_output(output: &str) -> Vec { + let mut findings = Vec::new(); + let mut current_section = String::new(); + + for line in output.lines() { + let trimmed = line.trim(); + + if trimmed.starts_with("Checking") { + current_section = trimmed.replace("Checking", "").trim().to_string(); + continue; + } + + if trimmed.contains("[ Warning ]") || trimmed.contains("[Warning]") { + let description = extract_warning_description(trimmed); + let finding = Finding { + id: format!("rkhunter-warn-{}", findings.len()), + severity: FindingSeverity::High, + category: current_section.clone(), + title: "RKHunter Warning".to_string(), + description, + file_path: extract_file_path(trimmed), + remediation: Some("Investigate the flagged file or configuration".to_string()), + }; + findings.push(finding); + } + + if trimmed.contains("[ Rootkit ]") || trimmed.to_lowercase().contains("rootkit found") { + let finding = Finding { + id: format!("rkhunter-rootkit-{}", findings.len()), + severity: FindingSeverity::Critical, + category: "Rootkit Detection".to_string(), + title: "Potential Rootkit Detected".to_string(), + description: trimmed.to_string(), + file_path: extract_file_path(trimmed), + remediation: Some("Immediately investigate and consider system recovery".to_string()), + }; + findings.push(finding); + } + + if trimmed.contains("Suspicious file") || trimmed.contains("suspicious") { + let finding = Finding { + id: format!("rkhunter-susp-{}", findings.len()), + severity: FindingSeverity::High, + category: current_section.clone(), + title: "Suspicious File Detected".to_string(), + description: trimmed.to_string(), + file_path: extract_file_path(trimmed), + remediation: Some("Verify the file integrity and source".to_string()), + }; + findings.push(finding); + } + + if trimmed.contains("[ Bad ]") { + let finding = Finding { + id: format!("rkhunter-bad-{}", findings.len()), + severity: FindingSeverity::High, + category: current_section.clone(), + title: "Bad Configuration or File".to_string(), + description: trimmed.to_string(), + file_path: extract_file_path(trimmed), + remediation: Some("Review and correct the flagged item".to_string()), + }; + findings.push(finding); + } + } + + findings +} + +pub fn parse_log_file() -> Result { + let content = std::fs::read_to_string(RKHUNTER_LOG_PATH) + .context("Failed to read RKHunter log file")?; + + let mut report = RKHunterReport::default(); + + for line in content.lines() { + if line.contains("Rootkits checked") { + if let Some(count) = extract_number_from_line(line) { + report.rootkits_checked = count; + } + } + + if line.contains("Possible rootkits") { + if let Some(count) = extract_number_from_line(line) { + report.possible_rootkits = count; + } + } + + if line.contains("Suspect files") { + if let Some(count) = extract_number_from_line(line) { + report.suspect_files = count; + } + } + + if line.contains("Warning:") { + report.warnings.push(line.replace("Warning:", "").trim().to_string()); + } + + if line.contains("rkhunter version") { + report.version = line.split(':').nth(1).unwrap_or("").trim().to_string(); + } + } + + Ok(report) +} + +pub async fn get_version() -> Result { + let output = SafeCommand::new("rkhunter")? + .arg("--version")? + .execute() + .context("Failed to get RKHunter version")?; + + let stdout = String::from_utf8_lossy(&output.stdout); + let version = stdout + .lines() + .find(|l| l.contains("version")) + .and_then(|l| l.split_whitespace().last()) + .unwrap_or("unknown") + .to_string(); + + Ok(version) +} + +fn extract_warning_description(line: &str) -> String { + line.replace("[ Warning ]", "") + .replace("[Warning]", "") + .trim() + .to_string() +} + +fn extract_file_path(line: &str) -> Option { + let words: Vec<&str> = line.split_whitespace().collect(); + for word in words { + if word.starts_with('/') { + return Some(word.trim_matches(|c| c == ':' || c == ',' || c == ';').to_string()); + } + } + None +} + +fn extract_number_from_line(line: &str) -> Option { + line.split_whitespace() + .find_map(|word| word.parse::().ok()) +} + +fn determine_result_status(findings: &[Finding]) -> ScanResultStatus { + let has_critical = findings.iter().any(|f| f.severity == FindingSeverity::Critical); + let has_high = findings.iter().any(|f| f.severity == FindingSeverity::High); + let has_medium = findings.iter().any(|f| f.severity == FindingSeverity::Medium); + + if has_critical { + ScanResultStatus::Infected + } else if has_high { + ScanResultStatus::Warnings + } else if has_medium { + ScanResultStatus::Warnings + } else { + ScanResultStatus::Clean + } +} + +#[derive(Debug, Clone, Default)] +pub struct RKHunterReport { + pub version: String, + pub rootkits_checked: u32, + pub possible_rootkits: u32, + pub suspect_files: u32, + pub warnings: Vec, + pub scan_time: Option, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_rkhunter_output_clean() { + let output = r#" +Checking for rootkits... + Performing check of known rootkit files and directories + 55808 Trojan - Variant A [ Not found ] + ADM Worm [ Not found ] + AjaKit Rootkit [ Not found ] +System checks summary +===================== +File properties checks... + Files checked: 142 + Suspect files: 0 +"#; + let findings = parse_rkhunter_output(output); + assert!(findings.is_empty()); + } + + #[test] + fn test_parse_rkhunter_output_warning() { + let output = r#" +Checking for rootkits... + Checking /dev for suspicious file types [ Warning ] + Suspicious file found: /dev/.udev/something +"#; + let findings = parse_rkhunter_output(output); + assert!(!findings.is_empty()); + assert_eq!(findings[0].severity, FindingSeverity::High); + } + + #[test] + fn test_extract_file_path() { + assert_eq!( + extract_file_path("Suspicious file found: /etc/passwd"), + Some("/etc/passwd".to_string()) + ); + assert_eq!( + extract_file_path("Checking /dev/sda for issues"), + Some("/dev/sda".to_string()) + ); + assert_eq!(extract_file_path("No path here"), None); + } + + #[test] + fn test_extract_number_from_line() { + assert_eq!(extract_number_from_line("Rootkits checked: 42"), Some(42)); + assert_eq!(extract_number_from_line("Found 5 issues"), Some(5)); + assert_eq!(extract_number_from_line("No numbers here"), None); + } + + #[test] + fn test_determine_result_status_clean() { + let findings: Vec = vec![]; + assert_eq!(determine_result_status(&findings), ScanResultStatus::Clean); + } + + #[test] + fn test_determine_result_status_critical() { + let findings = vec![Finding { + id: "test".to_string(), + severity: FindingSeverity::Critical, + category: "test".to_string(), + title: "Test".to_string(), + description: "Test".to_string(), + file_path: None, + remediation: None, + }]; + assert_eq!(determine_result_status(&findings), ScanResultStatus::Infected); + } + + #[test] + fn test_rkhunter_report_default() { + let report = RKHunterReport::default(); + assert_eq!(report.rootkits_checked, 0); + assert_eq!(report.possible_rootkits, 0); + assert!(report.warnings.is_empty()); + } + + #[test] + fn test_extract_warning_description() { + assert_eq!( + extract_warning_description("[ Warning ] Some issue here"), + "Some issue here" + ); + assert_eq!( + extract_warning_description("[Warning] Another issue"), + "Another issue" + ); + } +} diff --git a/src/security/protection/suricata.rs b/src/security/protection/suricata.rs new file mode 100644 index 000000000..e42524dae --- /dev/null +++ b/src/security/protection/suricata.rs @@ -0,0 +1,385 @@ +use anyhow::{Context, Result}; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use tracing::info; + +use crate::security::command_guard::SafeCommand; +use super::manager::{Finding, FindingSeverity, ScanResultStatus}; + +const SURICATA_EVE_LOG: &str = "/var/log/suricata/eve.json"; +const SURICATA_FAST_LOG: &str = "/var/log/suricata/fast.log"; +const SURICATA_RULES_DIR: &str = "/var/lib/suricata/rules"; + +pub async fn get_alerts() -> Result<(ScanResultStatus, Vec, String)> { + info!("Retrieving Suricata alerts"); + + let alerts = parse_eve_log(100).await.unwrap_or_default(); + let findings = alerts_to_findings(&alerts); + let status = determine_result_status(&findings); + + let raw_output = serde_json::to_string_pretty(&alerts).unwrap_or_default(); + + Ok((status, findings, raw_output)) +} + +pub async fn start_service() -> Result<()> { + info!("Starting Suricata service"); + + SafeCommand::new("sudo")? + .arg("systemctl")? + .arg("start")? + .arg("suricata")? + .execute() + .context("Failed to start Suricata service")?; + + info!("Suricata service started successfully"); + Ok(()) +} + +pub async fn stop_service() -> Result<()> { + info!("Stopping Suricata service"); + + SafeCommand::new("sudo")? + .arg("systemctl")? + .arg("stop")? + .arg("suricata")? + .execute() + .context("Failed to stop Suricata service")?; + + info!("Suricata service stopped successfully"); + Ok(()) +} + +pub async fn restart_service() -> Result<()> { + info!("Restarting Suricata service"); + + SafeCommand::new("sudo")? + .arg("systemctl")? + .arg("restart")? + .arg("suricata")? + .execute() + .context("Failed to restart Suricata service")?; + + info!("Suricata service restarted successfully"); + Ok(()) +} + +pub async fn update_rules() -> Result<()> { + info!("Updating Suricata rules"); + + SafeCommand::new("sudo")? + .arg("suricata-update")? + .execute() + .context("Failed to update Suricata rules")?; + + SafeCommand::new("sudo")? + .arg("systemctl")? + .arg("reload")? + .arg("suricata")? + .execute() + .context("Failed to reload Suricata after rule update")?; + + info!("Suricata rules updated successfully"); + Ok(()) +} + +pub async fn get_version() -> Result { + let output = SafeCommand::new("suricata")? + .arg("--build-info")? + .execute() + .context("Failed to get Suricata version")?; + + let stdout = String::from_utf8_lossy(&output.stdout); + let version = stdout + .lines() + .find(|l| l.contains("Suricata version")) + .and_then(|l| l.split(':').nth(1)) + .map(|v| v.trim().to_string()) + .unwrap_or_else(|| "unknown".to_string()); + + Ok(version) +} + +pub async fn get_rule_count() -> Result { + let mut count = 0u32; + + if let Ok(entries) = std::fs::read_dir(SURICATA_RULES_DIR) { + for entry in entries.flatten() { + let path = entry.path(); + if path.extension().is_some_and(|ext| ext == "rules") { + if let Ok(content) = std::fs::read_to_string(&path) { + count += content + .lines() + .filter(|l| !l.trim().starts_with('#') && !l.trim().is_empty()) + .count() as u32; + } + } + } + } + + Ok(count) +} + +pub async fn get_stats() -> Result { + let alerts = parse_eve_log(1000).await.unwrap_or_default(); + let today = Utc::now().date_naive(); + + let alerts_today = alerts + .iter() + .filter(|a| a.timestamp.date_naive() == today) + .count() as u32; + + let blocked_today = alerts + .iter() + .filter(|a| a.timestamp.date_naive() == today && a.action == "blocked") + .count() as u32; + + let rule_count = get_rule_count().await.unwrap_or(0); + + Ok(SuricataStats { + alerts_today, + blocked_today, + rule_count, + total_alerts: alerts.len() as u32, + }) +} + +pub async fn parse_eve_log(limit: usize) -> Result> { + let content = std::fs::read_to_string(SURICATA_EVE_LOG) + .context("Failed to read Suricata EVE log")?; + + let mut alerts = Vec::new(); + + for line in content.lines().rev().take(limit * 2) { + if let Ok(event) = serde_json::from_str::(line) { + if event.event_type == "alert" { + if let Some(alert_data) = event.alert { + let alert = SuricataAlert { + timestamp: event.timestamp, + src_ip: event.src_ip.unwrap_or_default(), + src_port: event.src_port.unwrap_or(0), + dest_ip: event.dest_ip.unwrap_or_default(), + dest_port: event.dest_port.unwrap_or(0), + protocol: event.proto.unwrap_or_default(), + signature: alert_data.signature, + signature_id: alert_data.signature_id, + severity: alert_data.severity, + category: alert_data.category, + action: alert_data.action, + }; + alerts.push(alert); + if alerts.len() >= limit { + break; + } + } + } + } + } + + alerts.reverse(); + Ok(alerts) +} + +fn alerts_to_findings(alerts: &[SuricataAlert]) -> Vec { + alerts + .iter() + .map(|alert| { + let severity = match alert.severity { + 1 => FindingSeverity::Critical, + 2 => FindingSeverity::High, + 3 => FindingSeverity::Medium, + _ => FindingSeverity::Low, + }; + + Finding { + id: format!("suricata-{}", alert.signature_id), + severity, + category: alert.category.clone(), + title: alert.signature.clone(), + description: format!( + "{}:{} -> {}:{} ({})", + alert.src_ip, alert.src_port, alert.dest_ip, alert.dest_port, alert.protocol + ), + file_path: None, + remediation: if alert.action == "blocked" { + Some("Traffic was automatically blocked".to_string()) + } else { + Some("Review the alert and consider blocking the source".to_string()) + }, + } + }) + .collect() +} + +fn determine_result_status(findings: &[Finding]) -> ScanResultStatus { + let has_critical = findings.iter().any(|f| f.severity == FindingSeverity::Critical); + let has_high = findings.iter().any(|f| f.severity == FindingSeverity::High); + + if has_critical { + ScanResultStatus::Infected + } else if has_high || !findings.is_empty() { + ScanResultStatus::Warnings + } else { + ScanResultStatus::Clean + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SuricataAlert { + pub timestamp: DateTime, + pub src_ip: String, + pub src_port: u16, + pub dest_ip: String, + pub dest_port: u16, + pub protocol: String, + pub signature: String, + pub signature_id: u64, + pub severity: u8, + pub category: String, + pub action: String, +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct SuricataStats { + pub alerts_today: u32, + pub blocked_today: u32, + pub rule_count: u32, + pub total_alerts: u32, +} + +#[derive(Debug, Deserialize)] +struct EveEvent { + timestamp: DateTime, + event_type: String, + src_ip: Option, + src_port: Option, + dest_ip: Option, + dest_port: Option, + proto: Option, + alert: Option, +} + +#[derive(Debug, Deserialize)] +struct EveAlert { + signature: String, + signature_id: u64, + severity: u8, + category: String, + action: String, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_alerts_to_findings_empty() { + let alerts: Vec = vec![]; + let findings = alerts_to_findings(&alerts); + assert!(findings.is_empty()); + } + + #[test] + fn test_alerts_to_findings_severity_mapping() { + let alert = SuricataAlert { + timestamp: Utc::now(), + src_ip: "192.168.1.1".to_string(), + src_port: 12345, + dest_ip: "10.0.0.1".to_string(), + dest_port: 80, + protocol: "TCP".to_string(), + signature: "Test Alert".to_string(), + signature_id: 1000001, + severity: 1, + category: "Test".to_string(), + action: "allowed".to_string(), + }; + + let findings = alerts_to_findings(&[alert]); + assert_eq!(findings.len(), 1); + assert_eq!(findings[0].severity, FindingSeverity::Critical); + } + + #[test] + fn test_alerts_to_findings_severity_high() { + let alert = SuricataAlert { + timestamp: Utc::now(), + src_ip: "192.168.1.1".to_string(), + src_port: 12345, + dest_ip: "10.0.0.1".to_string(), + dest_port: 443, + protocol: "TCP".to_string(), + signature: "High Severity Alert".to_string(), + signature_id: 1000002, + severity: 2, + category: "Test".to_string(), + action: "blocked".to_string(), + }; + + let findings = alerts_to_findings(&[alert]); + assert_eq!(findings[0].severity, FindingSeverity::High); + assert!(findings[0].remediation.as_ref().unwrap().contains("blocked")); + } + + #[test] + fn test_determine_result_status_clean() { + let findings: Vec = vec![]; + assert_eq!(determine_result_status(&findings), ScanResultStatus::Clean); + } + + #[test] + fn test_determine_result_status_critical() { + let findings = vec![Finding { + id: "test".to_string(), + severity: FindingSeverity::Critical, + category: "test".to_string(), + title: "Test".to_string(), + description: "Test".to_string(), + file_path: None, + remediation: None, + }]; + assert_eq!(determine_result_status(&findings), ScanResultStatus::Infected); + } + + #[test] + fn test_determine_result_status_warnings() { + let findings = vec![Finding { + id: "test".to_string(), + severity: FindingSeverity::Medium, + category: "test".to_string(), + title: "Test".to_string(), + description: "Test".to_string(), + file_path: None, + remediation: None, + }]; + assert_eq!(determine_result_status(&findings), ScanResultStatus::Warnings); + } + + #[test] + fn test_suricata_stats_default() { + let stats = SuricataStats::default(); + assert_eq!(stats.alerts_today, 0); + assert_eq!(stats.blocked_today, 0); + assert_eq!(stats.rule_count, 0); + } + + #[test] + fn test_suricata_alert_serialization() { + let alert = SuricataAlert { + timestamp: Utc::now(), + src_ip: "192.168.1.1".to_string(), + src_port: 12345, + dest_ip: "10.0.0.1".to_string(), + dest_port: 80, + protocol: "TCP".to_string(), + signature: "Test".to_string(), + signature_id: 1, + severity: 3, + category: "Test".to_string(), + action: "allowed".to_string(), + }; + + let json = serde_json::to_string(&alert); + assert!(json.is_ok()); + } +} diff --git a/src/security/rbac_middleware.rs b/src/security/rbac_middleware.rs index 9556007b4..6946d088d 100644 --- a/src/security/rbac_middleware.rs +++ b/src/security/rbac_middleware.rs @@ -11,7 +11,7 @@ use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::sync::Arc; use tokio::sync::RwLock; -use tracing::debug; +use tracing::{debug, warn}; use uuid::Uuid; use super::auth::{AuthenticatedUser, Permission, Role}; @@ -589,7 +589,7 @@ impl RbacManager { cache.retain(|k, _| !k.starts_with(prefix)); } - async fn check_permission_string(&self, user: &AuthenticatedUser, permission_str: &str) -> bool { + pub async fn check_permission_string(&self, user: &AuthenticatedUser, permission_str: &str) -> bool { let permission = match permission_str.to_lowercase().as_str() { "read" => Permission::Read, "write" => Permission::Write, @@ -706,6 +706,175 @@ impl RequireResourceAccess { } } +#[derive(Clone)] +pub struct RbacMiddlewareState { + pub rbac_manager: Arc, + pub required_permission: Option, + pub required_roles: Vec, + pub resource_type: Option, +} + +impl RbacMiddlewareState { + pub fn new(rbac_manager: Arc) -> Self { + Self { + rbac_manager, + required_permission: None, + required_roles: Vec::new(), + resource_type: None, + } + } + + pub fn with_permission(mut self, permission: &str) -> Self { + self.required_permission = Some(permission.to_string()); + self + } + + pub fn with_roles(mut self, roles: Vec) -> Self { + self.required_roles = roles; + self + } + + pub fn with_resource_type(mut self, resource_type: &str) -> Self { + self.resource_type = Some(resource_type.to_string()); + self + } +} + +pub async fn require_permission_middleware( + State(state): State, + request: Request, + next: Next, +) -> Result { + let user = request + .extensions() + .get::() + .cloned() + .unwrap_or_else(AuthenticatedUser::anonymous); + + if let Some(ref required_perm) = state.required_permission { + let has_permission = state + .rbac_manager + .check_permission_string(&user, required_perm) + .await; + + if !has_permission { + warn!( + "Permission denied for user {}: missing permission {}", + user.user_id, required_perm + ); + return Err(RbacError::PermissionDenied(format!( + "Missing required permission: {required_perm}" + ))); + } + } + + if !state.required_roles.is_empty() { + let has_required_role = state + .required_roles + .iter() + .any(|role| user.has_role(role)); + + if !has_required_role { + warn!( + "Role check failed for user {}: required one of {:?}", + user.user_id, state.required_roles + ); + return Err(RbacError::InsufficientRole(format!( + "Required role: {:?}", + state.required_roles + ))); + } + } + + Ok(next.run(request).await) +} + +pub async fn require_admin_middleware( + request: Request, + next: Next, +) -> Result { + let user = request + .extensions() + .get::() + .cloned() + .unwrap_or_else(AuthenticatedUser::anonymous); + + if !user.is_admin() && !user.is_super_admin() { + warn!("Admin access denied for user {}", user.user_id); + return Err(RbacError::AdminRequired); + } + + Ok(next.run(request).await) +} + +pub async fn require_super_admin_middleware( + request: Request, + next: Next, +) -> Result { + let user = request + .extensions() + .get::() + .cloned() + .unwrap_or_else(AuthenticatedUser::anonymous); + + if !user.is_super_admin() { + warn!("Super admin access denied for user {}", user.user_id); + return Err(RbacError::SuperAdminRequired); + } + + Ok(next.run(request).await) +} + +#[derive(Debug, Clone)] +pub enum RbacError { + PermissionDenied(String), + InsufficientRole(String), + AdminRequired, + SuperAdminRequired, + ResourceAccessDenied(String), +} + +impl IntoResponse for RbacError { + fn into_response(self) -> Response { + let (status, message) = match self { + Self::PermissionDenied(msg) => (StatusCode::FORBIDDEN, msg), + Self::InsufficientRole(msg) => (StatusCode::FORBIDDEN, msg), + Self::AdminRequired => ( + StatusCode::FORBIDDEN, + "Administrator access required".to_string(), + ), + Self::SuperAdminRequired => ( + StatusCode::FORBIDDEN, + "Super administrator access required".to_string(), + ), + Self::ResourceAccessDenied(msg) => (StatusCode::FORBIDDEN, msg), + }; + + let body = serde_json::json!({ + "error": "access_denied", + "message": message, + "code": "RBAC_DENIED" + }); + + (status, Json(body)).into_response() + } +} + +pub fn create_permission_layer( + rbac_manager: Arc, + permission: &str, +) -> RbacMiddlewareState { + RbacMiddlewareState::new(rbac_manager).with_permission(permission) +} + +pub fn create_role_layer(rbac_manager: Arc, roles: Vec) -> RbacMiddlewareState { + RbacMiddlewareState::new(rbac_manager).with_roles(roles) +} + +pub fn create_admin_layer(rbac_manager: Arc) -> RbacMiddlewareState { + RbacMiddlewareState::new(rbac_manager).with_roles(vec![Role::Admin, Role::SuperAdmin]) +} + pub fn build_default_route_permissions() -> Vec { vec![ RoutePermission::new("/api/health", "GET", "").with_anonymous(true), diff --git a/src/settings/mod.rs b/src/settings/mod.rs index c0fec3863..bfe099b0d 100644 --- a/src/settings/mod.rs +++ b/src/settings/mod.rs @@ -3,6 +3,7 @@ pub mod menu_config; pub mod permission_inheritance; pub mod rbac; pub mod rbac_ui; +pub mod security_admin; use axum::{ extract::State, @@ -28,6 +29,7 @@ pub fn configure_settings_routes() -> Router> { ) .route("/api/user/security/devices", get(get_trusted_devices)) .merge(rbac::configure_rbac_routes()) + .merge(security_admin::configure_security_admin_routes()) } async fn get_storage_info(State(_state): State>) -> Html { diff --git a/src/settings/security_admin.rs b/src/settings/security_admin.rs new file mode 100644 index 000000000..f80027ac8 --- /dev/null +++ b/src/settings/security_admin.rs @@ -0,0 +1,405 @@ +use axum::{ + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, + routing::{delete, get, post}, + Json, Router, +}; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use uuid::Uuid; + +use crate::shared::state::AppState; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SecurityOverview { + pub tls_enabled: bool, + pub mtls_enabled: bool, + pub rate_limiting_enabled: bool, + pub cors_configured: bool, + pub api_keys_count: u32, + pub active_sessions_count: u32, + pub audit_log_enabled: bool, + pub mfa_enabled_users: u32, + pub total_users: u32, + pub last_security_scan: Option>, + pub security_score: u8, + pub vulnerabilities: SecurityVulnerabilities, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SecurityVulnerabilities { + pub critical: u32, + pub high: u32, + pub medium: u32, + pub low: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TlsSettings { + pub enabled: bool, + pub cert_expiry: Option>, + pub auto_renew: bool, + pub min_version: String, + pub cipher_suites: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RateLimitSettings { + pub enabled: bool, + pub requests_per_minute: u32, + pub burst_size: u32, + pub whitelist: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CorsSettings { + pub enabled: bool, + pub allowed_origins: Vec, + pub allowed_methods: Vec, + pub allowed_headers: Vec, + pub max_age_seconds: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AuditLogEntry { + pub id: Uuid, + pub timestamp: DateTime, + pub user_id: Option, + pub action: String, + pub resource: String, + pub resource_id: Option, + pub ip_address: Option, + pub user_agent: Option, + pub success: bool, + pub details: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ApiKeyInfo { + pub id: Uuid, + pub name: String, + pub prefix: String, + pub created_at: DateTime, + pub last_used_at: Option>, + pub expires_at: Option>, + pub scopes: Vec, + pub is_active: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CreateApiKeyRequest { + pub name: String, + pub scopes: Vec, + pub expires_in_days: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CreateApiKeyResponse { + pub id: Uuid, + pub name: String, + pub key: String, + pub expires_at: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MfaSettings { + pub require_mfa: bool, + pub allowed_methods: Vec, + pub grace_period_days: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SessionInfo { + pub id: Uuid, + pub user_id: Uuid, + pub user_email: String, + pub created_at: DateTime, + pub last_activity: DateTime, + pub ip_address: Option, + pub user_agent: Option, + pub is_current: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PasswordPolicy { + pub min_length: u8, + pub require_uppercase: bool, + pub require_lowercase: bool, + pub require_numbers: bool, + pub require_special_chars: bool, + pub max_age_days: Option, + pub prevent_reuse_count: u8, +} + +#[derive(Debug, Serialize)] +pub struct SecurityError { + pub error: String, + pub code: String, +} + +impl IntoResponse for SecurityError { + fn into_response(self) -> axum::response::Response { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({"error": self.error, "code": self.code})), + ) + .into_response() + } +} + +async fn get_security_overview( + State(_state): State>, +) -> Result, SecurityError> { + let overview = SecurityOverview { + tls_enabled: true, + mtls_enabled: false, + rate_limiting_enabled: true, + cors_configured: true, + api_keys_count: 5, + active_sessions_count: 12, + audit_log_enabled: true, + mfa_enabled_users: 8, + total_users: 25, + last_security_scan: Some(Utc::now()), + security_score: 85, + vulnerabilities: SecurityVulnerabilities { + critical: 0, + high: 1, + medium: 3, + low: 7, + }, + }; + Ok(Json(overview)) +} + +async fn get_tls_settings( + State(_state): State>, +) -> Result, SecurityError> { + let settings = TlsSettings { + enabled: true, + cert_expiry: Some(Utc::now() + chrono::Duration::days(90)), + auto_renew: true, + min_version: "TLS 1.2".to_string(), + cipher_suites: vec![ + "TLS_AES_256_GCM_SHA384".to_string(), + "TLS_CHACHA20_POLY1305_SHA256".to_string(), + "TLS_AES_128_GCM_SHA256".to_string(), + ], + }; + Ok(Json(settings)) +} + +async fn update_tls_settings( + State(_state): State>, + Json(_settings): Json, +) -> Result, SecurityError> { + let settings = TlsSettings { + enabled: true, + cert_expiry: Some(Utc::now() + chrono::Duration::days(90)), + auto_renew: true, + min_version: "TLS 1.2".to_string(), + cipher_suites: vec![ + "TLS_AES_256_GCM_SHA384".to_string(), + "TLS_CHACHA20_POLY1305_SHA256".to_string(), + ], + }; + Ok(Json(settings)) +} + +async fn get_rate_limit_settings( + State(_state): State>, +) -> Result, SecurityError> { + let settings = RateLimitSettings { + enabled: true, + requests_per_minute: 60, + burst_size: 100, + whitelist: vec![], + }; + Ok(Json(settings)) +} + +async fn update_rate_limit_settings( + State(_state): State>, + Json(settings): Json, +) -> Result, SecurityError> { + Ok(Json(settings)) +} + +async fn get_cors_settings( + State(_state): State>, +) -> Result, SecurityError> { + let settings = CorsSettings { + enabled: true, + allowed_origins: vec!["*".to_string()], + allowed_methods: vec![ + "GET".to_string(), + "POST".to_string(), + "PUT".to_string(), + "DELETE".to_string(), + ], + allowed_headers: vec!["Content-Type".to_string(), "Authorization".to_string()], + max_age_seconds: 3600, + }; + Ok(Json(settings)) +} + +async fn update_cors_settings( + State(_state): State>, + Json(settings): Json, +) -> Result, SecurityError> { + Ok(Json(settings)) +} + +async fn list_audit_logs( + State(_state): State>, +) -> Result>, SecurityError> { + let logs = vec![ + AuditLogEntry { + id: Uuid::new_v4(), + timestamp: Utc::now(), + user_id: Some(Uuid::new_v4()), + action: "login".to_string(), + resource: "session".to_string(), + resource_id: None, + ip_address: Some("192.168.1.100".to_string()), + user_agent: Some("Mozilla/5.0".to_string()), + success: true, + details: None, + }, + ]; + Ok(Json(logs)) +} + +async fn list_api_keys( + State(_state): State>, +) -> Result>, SecurityError> { + let keys = vec![]; + Ok(Json(keys)) +} + +async fn create_api_key( + State(_state): State>, + Json(req): Json, +) -> Result, SecurityError> { + let response = CreateApiKeyResponse { + id: Uuid::new_v4(), + name: req.name, + key: format!("gb_{}", Uuid::new_v4().to_string().replace('-', "")), + expires_at: req.expires_in_days.map(|days| Utc::now() + chrono::Duration::days(i64::from(days))), + }; + Ok(Json(response)) +} + +async fn revoke_api_key( + State(_state): State>, + Path(_key_id): Path, +) -> Result { + Ok(StatusCode::NO_CONTENT) +} + +async fn get_mfa_settings( + State(_state): State>, +) -> Result, SecurityError> { + let settings = MfaSettings { + require_mfa: false, + allowed_methods: vec!["totp".to_string(), "webauthn".to_string()], + grace_period_days: 7, + }; + Ok(Json(settings)) +} + +async fn update_mfa_settings( + State(_state): State>, + Json(settings): Json, +) -> Result, SecurityError> { + Ok(Json(settings)) +} + +async fn list_active_sessions( + State(_state): State>, +) -> Result>, SecurityError> { + let sessions = vec![]; + Ok(Json(sessions)) +} + +async fn revoke_session( + State(_state): State>, + Path(_session_id): Path, +) -> Result { + Ok(StatusCode::NO_CONTENT) +} + +async fn revoke_all_user_sessions( + State(_state): State>, + Path(_user_id): Path, +) -> Result { + Ok(StatusCode::NO_CONTENT) +} + +async fn get_password_policy( + State(_state): State>, +) -> Result, SecurityError> { + let policy = PasswordPolicy { + min_length: 12, + require_uppercase: true, + require_lowercase: true, + require_numbers: true, + require_special_chars: true, + max_age_days: Some(90), + prevent_reuse_count: 5, + }; + Ok(Json(policy)) +} + +async fn update_password_policy( + State(_state): State>, + Json(policy): Json, +) -> Result, SecurityError> { + Ok(Json(policy)) +} + +async fn run_security_scan( + State(state): State>, +) -> Result, SecurityError> { + get_security_overview(State(state)).await +} + +pub fn configure_security_admin_routes() -> Router> { + Router::new() + .route("/api/settings/security/overview", get(get_security_overview)) + .route("/api/settings/security/scan", post(run_security_scan)) + .route( + "/api/settings/security/tls", + get(get_tls_settings).put(update_tls_settings), + ) + .route( + "/api/settings/security/rate-limit", + get(get_rate_limit_settings).put(update_rate_limit_settings), + ) + .route( + "/api/settings/security/cors", + get(get_cors_settings).put(update_cors_settings), + ) + .route("/api/settings/security/audit", get(list_audit_logs)) + .route( + "/api/settings/security/api-keys", + get(list_api_keys).post(create_api_key), + ) + .route("/api/settings/security/api-keys/:key_id", delete(revoke_api_key)) + .route( + "/api/settings/security/mfa", + get(get_mfa_settings).put(update_mfa_settings), + ) + .route("/api/settings/security/sessions", get(list_active_sessions)) + .route("/api/settings/security/sessions/:session_id", delete(revoke_session)) + .route( + "/api/settings/security/users/:user_id/sessions", + delete(revoke_all_user_sessions), + ) + .route( + "/api/settings/security/password-policy", + get(get_password_policy).put(update_password_policy), + ) +} diff --git a/src/social/mod.rs b/src/social/mod.rs index 519dd87e1..7568bb90c 100644 --- a/src/social/mod.rs +++ b/src/social/mod.rs @@ -1178,10 +1178,40 @@ body {{ font-family: system-ui, sans-serif; background: var(--bg); color: var(-- ) } +async fn handle_get_suggested_communities_html() -> Html { + Html(r#" +
+
🌐
+
+ General Discussion + 128 members +
+ +
+
+
💡
+
+ Ideas & Feedback + 64 members +
+ +
+
+
🎉
+
+ Announcements + 256 members +
+ +
+ "#.to_string()) +} + pub fn configure_social_routes() -> Router> { Router::new() .route("/api/social/feed", get(handle_get_feed)) .route("/api/ui/social/feed", get(handle_get_feed_html)) + .route("/api/ui/social/suggested", get(handle_get_suggested_communities_html)) .route("/api/social/posts", post(handle_create_post)) .route("/api/social/posts/:id", get(handle_get_post)) .route("/api/social/posts/:id", put(handle_update_post)) diff --git a/src/video/mod.rs b/src/video/mod.rs index 00efa8013..c9965d595 100644 --- a/src/video/mod.rs +++ b/src/video/mod.rs @@ -99,7 +99,6 @@ pub fn configure_video_routes() -> Router> { ) .route("/api/video/analytics/view", post(record_view_handler)) .route("/api/video/ws/export/:id", get(export_progress_websocket)) - .route("/video", get(video_ui)) } pub fn configure(router: Router>) -> Router> { diff --git a/src/workspaces/mod.rs b/src/workspaces/mod.rs index 325b1c03b..f9bfe729e 100644 --- a/src/workspaces/mod.rs +++ b/src/workspaces/mod.rs @@ -1,3 +1,10 @@ +use axum::{ + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, + routing::{delete, get, post}, + Json, Router, +}; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -5,6 +12,8 @@ use std::sync::Arc; use tokio::sync::RwLock; use uuid::Uuid; +use crate::shared::state::AppState; + pub mod blocks; pub mod pages; pub mod collaboration; @@ -1291,3 +1300,211 @@ impl std::fmt::Display for WorkspacesError { } impl std::error::Error for WorkspacesError {} + +impl IntoResponse for WorkspacesError { + fn into_response(self) -> axum::response::Response { + let (status, message) = match &self { + Self::WorkspaceNotFound | Self::PageNotFound | Self::BlockNotFound + | Self::CommentNotFound | Self::VersionNotFound | Self::MemberNotFound => { + (StatusCode::NOT_FOUND, self.to_string()) + } + Self::PermissionDenied => (StatusCode::FORBIDDEN, self.to_string()), + Self::MemberAlreadyExists | Self::CannotRemoveLastOwner | Self::InvalidOperation(_) => { + (StatusCode::BAD_REQUEST, self.to_string()) + } + }; + (status, Json(serde_json::json!({"error": message}))).into_response() + } +} + +#[derive(Debug, Deserialize)] +pub struct CreateWorkspaceRequest { + pub name: String, + pub description: Option, +} + +#[derive(Debug, Deserialize)] +pub struct UpdateWorkspaceRequest { + pub name: Option, + pub description: Option, + pub icon: Option, +} + +#[derive(Debug, Deserialize)] +pub struct CreatePageRequest { + pub title: String, + pub parent_id: Option, +} + +#[derive(Debug, Deserialize)] +pub struct UpdatePageRequest { + pub title: Option, + pub icon: Option, +} + +#[derive(Debug, Deserialize)] +pub struct AddMemberRequest { + pub user_id: Uuid, + pub role: WorkspaceRole, +} + +async fn list_workspaces( + State(_state): State>, +) -> Json> { + let service = WorkspacesService::new(); + let org_id = Uuid::nil(); + let workspaces = service.list_workspaces(org_id).await; + Json(workspaces) +} + +async fn create_workspace( + State(_state): State>, + Json(req): Json, +) -> Result, WorkspacesError> { + let service = WorkspacesService::new(); + let org_id = Uuid::nil(); + let user_id = Uuid::nil(); + let workspace = service.create_workspace(org_id, &req.name, user_id).await?; + Ok(Json(workspace)) +} + +async fn get_workspace( + State(_state): State>, + Path(workspace_id): Path, +) -> Result, WorkspacesError> { + let service = WorkspacesService::new(); + let workspace = service.get_workspace(workspace_id).await.ok_or(WorkspacesError::WorkspaceNotFound)?; + Ok(Json(workspace)) +} + +async fn update_workspace( + State(_state): State>, + Path(workspace_id): Path, + Json(req): Json, +) -> Result, WorkspacesError> { + let service = WorkspacesService::new(); + let workspace = service.update_workspace(workspace_id, req.name, req.description, req.icon).await?; + Ok(Json(workspace)) +} + +async fn delete_workspace( + State(_state): State>, + Path(workspace_id): Path, +) -> Result { + let service = WorkspacesService::new(); + service.delete_workspace(workspace_id).await?; + Ok(StatusCode::NO_CONTENT) +} + +async fn list_pages( + State(_state): State>, + Path(workspace_id): Path, +) -> Json> { + let service = WorkspacesService::new(); + let pages = service.get_page_tree(workspace_id).await; + Json(pages) +} + +async fn create_page( + State(_state): State>, + Path(workspace_id): Path, + Json(req): Json, +) -> Result, WorkspacesError> { + let service = WorkspacesService::new(); + let user_id = Uuid::nil(); + let page = service.create_page(workspace_id, req.parent_id, &req.title, user_id).await?; + Ok(Json(page)) +} + +async fn get_page( + State(_state): State>, + Path(page_id): Path, +) -> Result, WorkspacesError> { + let service = WorkspacesService::new(); + let page = service.get_page(page_id).await.ok_or(WorkspacesError::PageNotFound)?; + Ok(Json(page)) +} + +async fn update_page( + State(_state): State>, + Path(page_id): Path, + Json(req): Json, +) -> Result, WorkspacesError> { + let service = WorkspacesService::new(); + let user_id = Uuid::nil(); + let page = service.update_page(page_id, req.title, req.icon, None, user_id).await?; + Ok(Json(page)) +} + +async fn delete_page( + State(_state): State>, + Path(page_id): Path, +) -> Result { + let service = WorkspacesService::new(); + service.delete_page(page_id).await?; + Ok(StatusCode::NO_CONTENT) +} + +async fn add_member( + State(_state): State>, + Path(workspace_id): Path, + Json(req): Json, +) -> Result { + let service = WorkspacesService::new(); + let inviter_id = Uuid::nil(); + service.add_member(workspace_id, req.user_id, req.role, inviter_id).await?; + Ok(StatusCode::CREATED) +} + +async fn remove_member( + State(_state): State>, + Path((workspace_id, user_id)): Path<(Uuid, Uuid)>, +) -> Result { + let service = WorkspacesService::new(); + service.remove_member(workspace_id, user_id).await?; + Ok(StatusCode::NO_CONTENT) +} + +async fn search_pages( + State(_state): State>, + Path(workspace_id): Path, + axum::extract::Query(params): axum::extract::Query>, +) -> Json> { + let service = WorkspacesService::new(); + let query = params.get("q").cloned().unwrap_or_default(); + let results = service.search_pages(workspace_id, &query).await; + Json(results) +} + +async fn get_slash_commands_handler( + State(_state): State>, +) -> Json> { + Json(get_slash_commands()) +} + +pub fn configure_workspaces_routes() -> Router> { + Router::new() + .route("/api/workspaces", get(list_workspaces).post(create_workspace)) + .route( + "/api/workspaces/:workspace_id", + get(get_workspace).put(update_workspace).delete(delete_workspace), + ) + .route( + "/api/workspaces/:workspace_id/pages", + get(list_pages).post(create_page), + ) + .route( + "/api/workspaces/:workspace_id/members", + post(add_member), + ) + .route( + "/api/workspaces/:workspace_id/members/:user_id", + delete(remove_member), + ) + .route( + "/api/workspaces/:workspace_id/search", + get(search_pages), + ) + .route("/api/pages/:page_id", get(get_page).put(update_page).delete(delete_page)) + .route("/api/workspaces/commands", get(get_slash_commands_handler)) +}