Add security protection module with sudo-based privilege escalation

- Create installer.rs for 'botserver install protection' command
- Requires root to install packages and create sudoers config
- Sudoers uses exact commands (no wildcards) for security
- Update all tool files (lynis, rkhunter, chkrootkit, suricata, lmd) to use sudo
- Update manager.rs service management to use sudo
- Add 'sudo' and 'visudo' to command_guard.rs whitelist
- Update CLI with install/remove/status protection commands

Security model:
- Installation requires root (sudo botserver install protection)
- Runtime uses sudoers NOPASSWD for specific commands only
- No wildcards in sudoers - exact command specifications
- Tools run on host system, not in containers
This commit is contained in:
Rodrigo Rodriguez (Pragmatismo) 2026-01-10 09:41:12 -03:00
parent 27ecca0899
commit faeae250bc
30 changed files with 6260 additions and 32 deletions

View file

@ -63,7 +63,7 @@ msteams = []
# ===== PRODUCTIVITY FEATURES =====
chat = []
drive = ["dep:aws-config", "dep:aws-sdk-s3", "dep:pdf-extract", "dep:zip", "dep:downloader", "dep:mime_guess", "dep:flate2", "dep:tar"]
drive = ["dep:aws-config", "dep:aws-sdk-s3", "dep:pdf-extract", "dep:zip", "dep:downloader", "dep:flate2", "dep:tar"]
tasks = ["dep:cron"]
calendar = []
meet = ["dep:livekit"]
@ -183,7 +183,7 @@ pdf-extract = { version = "0.10.0", optional = true }
quick-xml = { version = "0.37", features = ["serialize"] }
zip = { version = "2.2", optional = true }
downloader = { version = "0.2", optional = true }
mime_guess = { version = "2.0", optional = true }
flate2 = { version = "1.0", optional = true }
tar = { version = "0.4", optional = true }
@ -247,6 +247,10 @@ rss = "2.0"
# HTML parsing/web scraping
scraper = "0.25"
walkdir = "2.5.0"
# Embedded static files (UI fallback when no external folder)
rust-embed = "8.5"
mime_guess = "2.0"
hyper-util = { version = "0.1.19", features = ["client-legacy", "tokio"] }
http-body-util = "0.1.3"

View file

@ -1,7 +1,7 @@
use axum::{
body::Body,
extract::State,
http::{Request, StatusCode},
http::{header::HeaderValue, Request, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
Json,
@ -89,14 +89,14 @@ pub async fn quota_middleware(
let headers = response.headers_mut();
headers.insert(
"X-Quota-Warning",
message.parse().unwrap_or_else(|_| "quota warning".parse().unwrap()),
message.parse().unwrap_or_else(|_| HeaderValue::from_static("quota warning")),
);
headers.insert(
"X-Quota-Usage-Percent",
percentage
.to_string()
.parse()
.unwrap_or_else(|_| "0".parse().unwrap()),
.unwrap_or_else(|_| HeaderValue::from_static("0")),
);
response
@ -148,14 +148,14 @@ pub async fn api_rate_limit_middleware(
let headers = response.headers_mut();
headers.insert(
"X-RateLimit-Warning",
message.parse().unwrap_or_else(|_| "rate limit warning".parse().unwrap()),
message.parse().unwrap_or_else(|_| HeaderValue::from_static("rate limit warning")),
);
headers.insert(
"X-RateLimit-Usage-Percent",
percentage
.to_string()
.parse()
.unwrap_or_else(|_| "0".parse().unwrap()),
.unwrap_or_else(|_| HeaderValue::from_static("0")),
);
response
}
@ -212,14 +212,14 @@ pub async fn message_quota_middleware(
let headers = response.headers_mut();
headers.insert(
"X-Message-Quota-Warning",
message.parse().unwrap_or_else(|_| "message quota warning".parse().unwrap()),
message.parse().unwrap_or_else(|_| HeaderValue::from_static("message quota warning")),
);
headers.insert(
"X-Message-Quota-Usage-Percent",
percentage
.to_string()
.parse()
.unwrap_or_else(|_| "0".parse().unwrap()),
.unwrap_or_else(|_| HeaderValue::from_static("0")),
);
response
}

699
src/canvas/mod.rs Normal file
View file

@ -0,0 +1,699 @@
use axum::{
extract::{Path, State},
http::StatusCode,
response::IntoResponse,
routing::{get, post, put},
Json, Router,
};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::shared::state::AppState;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Canvas {
pub id: Uuid,
pub organization_id: Uuid,
pub name: String,
pub description: Option<String>,
pub width: u32,
pub height: u32,
pub background_color: String,
pub elements: Vec<CanvasElement>,
pub created_by: Uuid,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
pub is_public: bool,
pub collaborators: Vec<Uuid>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CanvasElement {
pub id: Uuid,
pub element_type: ElementType,
pub x: f64,
pub y: f64,
pub width: f64,
pub height: f64,
pub rotation: f64,
pub properties: ElementProperties,
pub z_index: i32,
pub locked: bool,
pub created_by: Uuid,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ElementType {
Rectangle,
Ellipse,
Line,
Arrow,
FreehandPath,
Text,
Image,
Sticky,
Frame,
Connector,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ElementProperties {
pub fill_color: Option<String>,
pub stroke_color: Option<String>,
pub stroke_width: Option<f64>,
pub opacity: Option<f64>,
pub text: Option<String>,
pub font_size: Option<f64>,
pub font_family: Option<String>,
pub text_align: Option<String>,
pub image_url: Option<String>,
pub path_data: Option<String>,
pub corner_radius: Option<f64>,
pub start_arrow: Option<String>,
pub end_arrow: Option<String>,
}
impl Default for ElementProperties {
fn default() -> Self {
Self {
fill_color: Some("#ffffff".to_string()),
stroke_color: Some("#000000".to_string()),
stroke_width: Some(2.0),
opacity: Some(1.0),
text: None,
font_size: Some(16.0),
font_family: Some("Inter".to_string()),
text_align: Some("left".to_string()),
image_url: None,
path_data: None,
corner_radius: Some(0.0),
start_arrow: None,
end_arrow: None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CanvasSummary {
pub id: Uuid,
pub name: String,
pub description: Option<String>,
pub thumbnail_url: Option<String>,
pub element_count: usize,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
pub is_public: bool,
}
#[derive(Debug, Deserialize)]
pub struct CreateCanvasRequest {
pub name: String,
pub description: Option<String>,
pub width: Option<u32>,
pub height: Option<u32>,
pub background_color: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct UpdateCanvasRequest {
pub name: Option<String>,
pub description: Option<String>,
pub width: Option<u32>,
pub height: Option<u32>,
pub background_color: Option<String>,
pub is_public: Option<bool>,
}
#[derive(Debug, Deserialize)]
pub struct CreateElementRequest {
pub element_type: ElementType,
pub x: f64,
pub y: f64,
pub width: f64,
pub height: f64,
pub rotation: Option<f64>,
pub properties: Option<ElementProperties>,
pub z_index: Option<i32>,
}
#[derive(Debug, Deserialize)]
pub struct UpdateElementRequest {
pub x: Option<f64>,
pub y: Option<f64>,
pub width: Option<f64>,
pub height: Option<f64>,
pub rotation: Option<f64>,
pub properties: Option<ElementProperties>,
pub z_index: Option<i32>,
pub locked: Option<bool>,
}
#[derive(Debug, Deserialize)]
pub struct ExportRequest {
pub format: ExportFormat,
pub scale: Option<f64>,
pub background: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ExportFormat {
Png,
Svg,
Pdf,
Json,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExportResponse {
pub format: ExportFormat,
pub url: Option<String>,
pub data: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CollaborationSession {
pub canvas_id: Uuid,
pub user_id: Uuid,
pub cursor_x: f64,
pub cursor_y: f64,
pub selection: Vec<Uuid>,
pub connected_at: DateTime<Utc>,
}
pub struct CanvasService {
canvases: Arc<RwLock<HashMap<Uuid, Canvas>>>,
}
impl CanvasService {
pub fn new() -> Self {
Self {
canvases: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn list_canvases(&self, org_id: Uuid) -> Vec<CanvasSummary> {
let canvases = self.canvases.read().await;
canvases
.values()
.filter(|c| c.organization_id == org_id)
.map(|c| CanvasSummary {
id: c.id,
name: c.name.clone(),
description: c.description.clone(),
thumbnail_url: None,
element_count: c.elements.len(),
created_at: c.created_at,
updated_at: c.updated_at,
is_public: c.is_public,
})
.collect()
}
pub async fn create_canvas(
&self,
org_id: Uuid,
user_id: Uuid,
req: CreateCanvasRequest,
) -> Canvas {
let now = Utc::now();
let canvas = Canvas {
id: Uuid::new_v4(),
organization_id: org_id,
name: req.name,
description: req.description,
width: req.width.unwrap_or(1920),
height: req.height.unwrap_or(1080),
background_color: req.background_color.unwrap_or_else(|| "#ffffff".to_string()),
elements: vec![],
created_by: user_id,
created_at: now,
updated_at: now,
is_public: false,
collaborators: vec![user_id],
};
let mut canvases = self.canvases.write().await;
canvases.insert(canvas.id, canvas.clone());
canvas
}
pub async fn get_canvas(&self, canvas_id: Uuid) -> Option<Canvas> {
let canvases = self.canvases.read().await;
canvases.get(&canvas_id).cloned()
}
pub async fn update_canvas(
&self,
canvas_id: Uuid,
req: UpdateCanvasRequest,
) -> Option<Canvas> {
let mut canvases = self.canvases.write().await;
if let Some(canvas) = canvases.get_mut(&canvas_id) {
if let Some(name) = req.name {
canvas.name = name;
}
if let Some(desc) = req.description {
canvas.description = Some(desc);
}
if let Some(width) = req.width {
canvas.width = width;
}
if let Some(height) = req.height {
canvas.height = height;
}
if let Some(bg) = req.background_color {
canvas.background_color = bg;
}
if let Some(public) = req.is_public {
canvas.is_public = public;
}
canvas.updated_at = Utc::now();
return Some(canvas.clone());
}
None
}
pub async fn delete_canvas(&self, canvas_id: Uuid) -> bool {
let mut canvases = self.canvases.write().await;
canvases.remove(&canvas_id).is_some()
}
pub async fn add_element(
&self,
canvas_id: Uuid,
user_id: Uuid,
req: CreateElementRequest,
) -> Option<CanvasElement> {
let mut canvases = self.canvases.write().await;
if let Some(canvas) = canvases.get_mut(&canvas_id) {
let now = Utc::now();
let element = CanvasElement {
id: Uuid::new_v4(),
element_type: req.element_type,
x: req.x,
y: req.y,
width: req.width,
height: req.height,
rotation: req.rotation.unwrap_or(0.0),
properties: req.properties.unwrap_or_default(),
z_index: req.z_index.unwrap_or(canvas.elements.len() as i32),
locked: false,
created_by: user_id,
created_at: now,
updated_at: now,
};
canvas.elements.push(element.clone());
canvas.updated_at = now;
return Some(element);
}
None
}
pub async fn update_element(
&self,
canvas_id: Uuid,
element_id: Uuid,
req: UpdateElementRequest,
) -> Option<CanvasElement> {
let mut canvases = self.canvases.write().await;
if let Some(canvas) = canvases.get_mut(&canvas_id) {
if let Some(element) = canvas.elements.iter_mut().find(|e| e.id == element_id) {
if let Some(x) = req.x {
element.x = x;
}
if let Some(y) = req.y {
element.y = y;
}
if let Some(width) = req.width {
element.width = width;
}
if let Some(height) = req.height {
element.height = height;
}
if let Some(rotation) = req.rotation {
element.rotation = rotation;
}
if let Some(props) = req.properties {
element.properties = props;
}
if let Some(z) = req.z_index {
element.z_index = z;
}
if let Some(locked) = req.locked {
element.locked = locked;
}
element.updated_at = Utc::now();
canvas.updated_at = Utc::now();
return Some(element.clone());
}
}
None
}
pub async fn delete_element(&self, canvas_id: Uuid, element_id: Uuid) -> bool {
let mut canvases = self.canvases.write().await;
if let Some(canvas) = canvases.get_mut(&canvas_id) {
let len_before = canvas.elements.len();
canvas.elements.retain(|e| e.id != element_id);
if canvas.elements.len() < len_before {
canvas.updated_at = Utc::now();
return true;
}
}
false
}
pub async fn export_canvas(
&self,
canvas_id: Uuid,
req: ExportRequest,
) -> Option<ExportResponse> {
let canvases = self.canvases.read().await;
let canvas = canvases.get(&canvas_id)?;
match req.format {
ExportFormat::Json => {
let json = serde_json::to_string_pretty(canvas).ok()?;
Some(ExportResponse {
format: ExportFormat::Json,
url: None,
data: Some(json),
})
}
ExportFormat::Svg => {
let svg = generate_svg(canvas, req.background.unwrap_or(true));
Some(ExportResponse {
format: ExportFormat::Svg,
url: None,
data: Some(svg),
})
}
_ => Some(ExportResponse {
format: req.format,
url: Some(format!("/api/canvas/{}/export/file", canvas_id)),
data: None,
}),
}
}
}
impl Default for CanvasService {
fn default() -> Self {
Self::new()
}
}
fn generate_svg(canvas: &Canvas, include_background: bool) -> String {
let mut svg = format!(
r#"<svg xmlns="http://www.w3.org/2000/svg" width="{}" height="{}" viewBox="0 0 {} {}">"#,
canvas.width, canvas.height, canvas.width, canvas.height
);
if include_background {
svg.push_str(&format!(
r#"<rect width="100%" height="100%" fill="{}"/>"#,
canvas.background_color
));
}
for element in &canvas.elements {
let transform = if element.rotation != 0.0 {
format!(
r#" transform="rotate({} {} {})""#,
element.rotation,
element.x + element.width / 2.0,
element.y + element.height / 2.0
)
} else {
String::new()
};
let fill = element
.properties
.fill_color
.as_deref()
.unwrap_or("transparent");
let stroke = element
.properties
.stroke_color
.as_deref()
.unwrap_or("none");
let stroke_width = element.properties.stroke_width.unwrap_or(1.0);
let opacity = element.properties.opacity.unwrap_or(1.0);
match element.element_type {
ElementType::Rectangle => {
let radius = element.properties.corner_radius.unwrap_or(0.0);
svg.push_str(&format!(
r#"<rect x="{}" y="{}" width="{}" height="{}" rx="{}" fill="{}" stroke="{}" stroke-width="{}" opacity="{}"{}/>"#,
element.x, element.y, element.width, element.height,
radius, fill, stroke, stroke_width, opacity, transform
));
}
ElementType::Ellipse => {
svg.push_str(&format!(
r#"<ellipse cx="{}" cy="{}" rx="{}" ry="{}" fill="{}" stroke="{}" stroke-width="{}" opacity="{}"{}/>"#,
element.x + element.width / 2.0,
element.y + element.height / 2.0,
element.width / 2.0,
element.height / 2.0,
fill, stroke, stroke_width, opacity, transform
));
}
ElementType::Text => {
let text = element.properties.text.as_deref().unwrap_or("");
let font_size = element.properties.font_size.unwrap_or(16.0);
let font_family = element
.properties
.font_family
.as_deref()
.unwrap_or("sans-serif");
svg.push_str(&format!(
r#"<text x="{}" y="{}" font-size="{}" font-family="{}" fill="{}" opacity="{}"{}>
{}
</text>"#,
element.x, element.y + font_size, font_size, font_family,
fill, opacity, transform, text
));
}
ElementType::FreehandPath => {
if let Some(path_data) = &element.properties.path_data {
svg.push_str(&format!(
r#"<path d="{}" fill="none" stroke="{}" stroke-width="{}" opacity="{}"{}/>"#,
path_data, stroke, stroke_width, opacity, transform
));
}
}
ElementType::Line | ElementType::Arrow => {
let marker = if element.element_type == ElementType::Arrow {
r#" marker-end="url(#arrowhead)""#
} else {
""
};
svg.push_str(&format!(
r#"<line x1="{}" y1="{}" x2="{}" y2="{}" stroke="{}" stroke-width="{}" opacity="{}"{}{}/>"#,
element.x, element.y,
element.x + element.width, element.y + element.height,
stroke, stroke_width, opacity, marker, transform
));
}
_ => {}
}
}
svg.push_str("</svg>");
svg
}
#[derive(Debug, Serialize)]
pub struct CanvasError {
pub error: String,
pub code: String,
}
impl IntoResponse for CanvasError {
fn into_response(self) -> axum::response::Response {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({"error": self.error, "code": self.code})),
)
.into_response()
}
}
async fn list_canvases(
State(_state): State<Arc<AppState>>,
) -> Result<Json<Vec<CanvasSummary>>, CanvasError> {
let service = CanvasService::new();
let org_id = Uuid::nil();
let canvases = service.list_canvases(org_id).await;
Ok(Json(canvases))
}
async fn create_canvas(
State(_state): State<Arc<AppState>>,
Json(req): Json<CreateCanvasRequest>,
) -> Result<Json<Canvas>, CanvasError> {
let service = CanvasService::new();
let org_id = Uuid::nil();
let user_id = Uuid::nil();
let canvas = service.create_canvas(org_id, user_id, req).await;
Ok(Json(canvas))
}
async fn get_canvas(
State(_state): State<Arc<AppState>>,
Path(canvas_id): Path<Uuid>,
) -> Result<Json<Canvas>, CanvasError> {
let service = CanvasService::new();
let canvas = service.get_canvas(canvas_id).await.ok_or_else(|| CanvasError {
error: "Canvas not found".to_string(),
code: "CANVAS_NOT_FOUND".to_string(),
})?;
Ok(Json(canvas))
}
async fn update_canvas(
State(_state): State<Arc<AppState>>,
Path(canvas_id): Path<Uuid>,
Json(req): Json<UpdateCanvasRequest>,
) -> Result<Json<Canvas>, CanvasError> {
let service = CanvasService::new();
let canvas = service
.update_canvas(canvas_id, req)
.await
.ok_or_else(|| CanvasError {
error: "Canvas not found".to_string(),
code: "CANVAS_NOT_FOUND".to_string(),
})?;
Ok(Json(canvas))
}
async fn delete_canvas(
State(_state): State<Arc<AppState>>,
Path(canvas_id): Path<Uuid>,
) -> Result<StatusCode, CanvasError> {
let service = CanvasService::new();
if service.delete_canvas(canvas_id).await {
Ok(StatusCode::NO_CONTENT)
} else {
Err(CanvasError {
error: "Canvas not found".to_string(),
code: "CANVAS_NOT_FOUND".to_string(),
})
}
}
async fn list_elements(
State(_state): State<Arc<AppState>>,
Path(canvas_id): Path<Uuid>,
) -> Result<Json<Vec<CanvasElement>>, CanvasError> {
let service = CanvasService::new();
let canvas = service.get_canvas(canvas_id).await.ok_or_else(|| CanvasError {
error: "Canvas not found".to_string(),
code: "CANVAS_NOT_FOUND".to_string(),
})?;
Ok(Json(canvas.elements))
}
async fn create_element(
State(_state): State<Arc<AppState>>,
Path(canvas_id): Path<Uuid>,
Json(req): Json<CreateElementRequest>,
) -> Result<Json<CanvasElement>, CanvasError> {
let service = CanvasService::new();
let user_id = Uuid::nil();
let element = service
.add_element(canvas_id, user_id, req)
.await
.ok_or_else(|| CanvasError {
error: "Canvas not found".to_string(),
code: "CANVAS_NOT_FOUND".to_string(),
})?;
Ok(Json(element))
}
async fn update_element(
State(_state): State<Arc<AppState>>,
Path((canvas_id, element_id)): Path<(Uuid, Uuid)>,
Json(req): Json<UpdateElementRequest>,
) -> Result<Json<CanvasElement>, CanvasError> {
let service = CanvasService::new();
let element = service
.update_element(canvas_id, element_id, req)
.await
.ok_or_else(|| CanvasError {
error: "Element not found".to_string(),
code: "ELEMENT_NOT_FOUND".to_string(),
})?;
Ok(Json(element))
}
async fn delete_element(
State(_state): State<Arc<AppState>>,
Path((canvas_id, element_id)): Path<(Uuid, Uuid)>,
) -> Result<StatusCode, CanvasError> {
let service = CanvasService::new();
if service.delete_element(canvas_id, element_id).await {
Ok(StatusCode::NO_CONTENT)
} else {
Err(CanvasError {
error: "Element not found".to_string(),
code: "ELEMENT_NOT_FOUND".to_string(),
})
}
}
async fn export_canvas(
State(_state): State<Arc<AppState>>,
Path(canvas_id): Path<Uuid>,
Json(req): Json<ExportRequest>,
) -> Result<Json<ExportResponse>, CanvasError> {
let service = CanvasService::new();
let response = service
.export_canvas(canvas_id, req)
.await
.ok_or_else(|| CanvasError {
error: "Canvas not found".to_string(),
code: "CANVAS_NOT_FOUND".to_string(),
})?;
Ok(Json(response))
}
async fn get_collaboration_info(
State(_state): State<Arc<AppState>>,
Path(canvas_id): Path<Uuid>,
) -> Result<Json<Vec<CollaborationSession>>, CanvasError> {
let _ = canvas_id;
Ok(Json(vec![]))
}
pub fn configure_canvas_routes() -> Router<Arc<AppState>> {
Router::new()
.route("/api/canvas", get(list_canvases).post(create_canvas))
.route(
"/api/canvas/:canvas_id",
get(get_canvas).put(update_canvas).delete(delete_canvas),
)
.route(
"/api/canvas/:canvas_id/elements",
get(list_elements).post(create_element),
)
.route(
"/api/canvas/:canvas_id/elements/:element_id",
put(update_element).delete(delete_element),
)
.route("/api/canvas/:canvas_id/export", post(export_canvas))
.route(
"/api/canvas/:canvas_id/collaborate",
get(get_collaboration_info),
)
}

View file

@ -1,6 +1,7 @@
use crate::core::secrets::{SecretPaths, SecretsManager};
use crate::package_manager::{get_all_components, InstallMode, PackageManager};
use crate::security::command_guard::SafeCommand;
use crate::security::protection::{ProtectionInstaller, VerifyResult};
use anyhow::Result;
use rand::Rng;
use std::collections::HashMap;
@ -87,6 +88,12 @@ pub async fn run() -> Result<()> {
return Ok(());
}
let component = &args[2];
if component == "protection" {
install_protection()?;
return Ok(());
}
let mode = if args.contains(&"--container".to_string()) {
InstallMode::Container
} else {
@ -111,6 +118,12 @@ pub async fn run() -> Result<()> {
return Ok(());
}
let component = &args[2];
if component == "protection" {
remove_protection()?;
return Ok(());
}
let mode = if args.contains(&"--container".to_string()) {
InstallMode::Container
} else {
@ -153,6 +166,13 @@ pub async fn run() -> Result<()> {
return Ok(());
}
let component = &args[2];
if component == "protection" {
let result = verify_protection();
result.print();
return Ok(());
}
let mode = if args.contains(&"--container".to_string()) {
InstallMode::Container
} else {
@ -271,6 +291,11 @@ fn print_usage() {
println!(" --container Use container mode (LXC)");
println!(" --tenant <name> Specify tenant name");
println!();
println!("Security Protection (requires root):");
println!(" sudo botserver install protection Install security tools + sudoers");
println!(" sudo botserver remove protection Remove sudoers configuration");
println!(" botserver status protection Check protection tools status");
println!();
println!("Vault subcommands:");
println!(" vault migrate [.env] Migrate .env secrets to Vault");
println!(" vault put <path> k=v Store secrets in Vault");
@ -279,6 +304,55 @@ fn print_usage() {
println!(" vault health Check Vault health");
}
fn install_protection() -> Result<()> {
let installer = ProtectionInstaller::new()?;
if !ProtectionInstaller::check_root() {
eprintln!("Error: This command requires root privileges.");
eprintln!();
eprintln!("Run with: sudo botserver install protection");
return Ok(());
}
println!("Installing Security Protection Tools...");
println!();
println!("This will:");
println!(" 1. Install security packages (lynis, rkhunter, chkrootkit, suricata, clamav)");
println!(" 2. Install Linux Malware Detect (LMD)");
println!(" 3. Create sudoers configuration for runtime execution");
println!(" 4. Update security databases");
println!();
let result = installer.install()?;
result.print();
Ok(())
}
fn remove_protection() -> Result<()> {
let installer = ProtectionInstaller::new()?;
if !ProtectionInstaller::check_root() {
eprintln!("Error: This command requires root privileges.");
eprintln!();
eprintln!("Run with: sudo botserver remove protection");
return Ok(());
}
println!("Removing Security Protection Configuration...");
println!();
let result = installer.uninstall()?;
result.print();
Ok(())
}
fn verify_protection() -> VerifyResult {
let installer = ProtectionInstaller::default();
installer.verify()
}
fn print_vault_usage() {
println!("Vault Secret Management");
println!();

View file

@ -7,6 +7,9 @@ use crate::core::session::SessionManager;
use crate::core::shared::analytics::MetricsCollector;
use crate::project::ProjectService;
use crate::legal::LegalService;
use crate::security::auth_provider::AuthProviderRegistry;
use crate::security::jwt::JwtManager;
use crate::security::rbac_middleware::RbacManager;
#[cfg(all(test, feature = "directory"))]
use crate::core::shared::test_utils::create_mock_auth_service;
#[cfg(all(test, feature = "llm"))]
@ -351,6 +354,9 @@ pub struct AppState {
pub task_manifests: Arc<std::sync::RwLock<HashMap<String, TaskManifest>>>,
pub project_service: Arc<RwLock<ProjectService>>,
pub legal_service: Arc<RwLock<LegalService>>,
pub jwt_manager: Option<Arc<JwtManager>>,
pub auth_provider_registry: Option<Arc<AuthProviderRegistry>>,
pub rbac_manager: Option<Arc<RbacManager>>,
}
impl Clone for AppState {
@ -385,6 +391,9 @@ impl Clone for AppState {
task_manifests: Arc::clone(&self.task_manifests),
project_service: Arc::clone(&self.project_service),
legal_service: Arc::clone(&self.legal_service),
jwt_manager: self.jwt_manager.clone(),
auth_provider_registry: self.auth_provider_registry.clone(),
rbac_manager: self.rbac_manager.clone(),
}
}
}
@ -427,6 +436,9 @@ impl std::fmt::Debug for AppState {
.field("extensions", &self.extensions)
.field("attendant_broadcast", &self.attendant_broadcast.is_some())
.field("task_progress_broadcast", &self.task_progress_broadcast.is_some())
.field("jwt_manager", &self.jwt_manager.is_some())
.field("auth_provider_registry", &self.auth_provider_registry.is_some())
.field("rbac_manager", &self.rbac_manager.is_some())
.finish()
}
}
@ -575,6 +587,9 @@ impl Default for AppState {
task_manifests: Arc::new(std::sync::RwLock::new(HashMap::new())),
project_service: Arc::new(RwLock::new(crate::project::ProjectService::new())),
legal_service: Arc::new(RwLock::new(crate::legal::LegalService::new())),
jwt_manager: None,
auth_provider_registry: None,
rbac_manager: None,
}
}
}

View file

@ -392,6 +392,66 @@ impl ApiUrls {
pub const SOURCES_KB_REINDEX: &'static str = "/api/ui/sources/kb/reindex";
pub const SOURCES_KB_STATS: &'static str = "/api/ui/sources/kb/stats";
// Workspaces - JSON APIs
pub const WORKSPACES: &'static str = "/api/workspaces";
pub const WORKSPACE_BY_ID: &'static str = "/api/workspaces/:workspace_id";
pub const WORKSPACE_PAGES: &'static str = "/api/workspaces/:workspace_id/pages";
pub const WORKSPACE_MEMBERS: &'static str = "/api/workspaces/:workspace_id/members";
pub const WORKSPACE_MEMBER: &'static str = "/api/workspaces/:workspace_id/members/:user_id";
pub const WORKSPACE_SEARCH: &'static str = "/api/workspaces/:workspace_id/search";
pub const WORKSPACE_COMMANDS: &'static str = "/api/workspaces/commands";
pub const PAGE_BY_ID: &'static str = "/api/pages/:page_id";
// Project - JSON APIs
pub const PROJECTS: &'static str = "/projects";
pub const PROJECT_BY_ID: &'static str = "/projects/:project_id";
pub const PROJECT_TASKS: &'static str = "/projects/:project_id/tasks";
pub const PROJECT_GANTT: &'static str = "/projects/:project_id/gantt";
pub const PROJECT_TIMELINE: &'static str = "/projects/:project_id/timeline";
pub const PROJECT_CRITICAL_PATH: &'static str = "/projects/:project_id/critical-path";
pub const PROJECT_TASK_PROGRESS: &'static str = "/tasks/:task_id/progress";
pub const PROJECT_TASK_DEPENDENCIES: &'static str = "/tasks/:task_id/dependencies";
pub const PROJECT_TASK: &'static str = "/tasks/:task_id";
// Goals (OKR) - JSON APIs
pub const GOALS_OBJECTIVES: &'static str = "/api/goals/objectives";
pub const GOALS_OBJECTIVE_BY_ID: &'static str = "/api/goals/objectives/:id";
pub const GOALS_KEY_RESULTS: &'static str = "/api/goals/objectives/:id/key-results";
pub const GOALS_KEY_RESULT_BY_ID: &'static str = "/api/goals/key-results/:id";
pub const GOALS_CHECK_IN: &'static str = "/api/goals/key-results/:id/check-in";
pub const GOALS_HISTORY: &'static str = "/api/goals/key-results/:id/history";
pub const GOALS_DASHBOARD: &'static str = "/api/goals/dashboard";
pub const GOALS_ALIGNMENT: &'static str = "/api/goals/alignment";
pub const GOALS_AI_SUGGEST: &'static str = "/api/goals/ai/suggest";
// Security Admin - JSON APIs
pub const SECURITY_OVERVIEW: &'static str = "/api/security/overview";
pub const SECURITY_SCAN: &'static str = "/api/security/scan";
pub const SECURITY_TLS: &'static str = "/api/security/tls";
pub const SECURITY_RATE_LIMIT: &'static str = "/api/security/rate-limit";
pub const SECURITY_CORS: &'static str = "/api/security/cors";
pub const SECURITY_AUDIT: &'static str = "/api/security/audit";
pub const SECURITY_API_KEYS: &'static str = "/api/security/api-keys";
pub const SECURITY_API_KEY_BY_ID: &'static str = "/api/security/api-keys/:key_id";
pub const SECURITY_MFA: &'static str = "/api/security/mfa";
pub const SECURITY_SESSIONS: &'static str = "/api/security/sessions";
pub const SECURITY_SESSION_BY_ID: &'static str = "/api/security/sessions/:session_id";
pub const SECURITY_USER_SESSIONS: &'static str = "/api/security/users/:user_id/sessions";
pub const SECURITY_PASSWORD_POLICY: &'static str = "/api/security/password-policy";
// Player - JSON APIs
pub const PLAYER_FILE: &'static str = "/api/player/:bot_id/file/*path";
pub const PLAYER_STREAM: &'static str = "/api/player/:bot_id/stream/*path";
pub const PLAYER_THUMBNAIL: &'static str = "/api/player/:bot_id/thumbnail/*path";
// Canvas - JSON APIs
pub const CANVAS_LIST: &'static str = "/api/canvas";
pub const CANVAS_BY_ID: &'static str = "/api/canvas/:id";
pub const CANVAS_ELEMENTS: &'static str = "/api/canvas/:id/elements";
pub const CANVAS_ELEMENT_BY_ID: &'static str = "/api/canvas/:id/elements/:element_id";
pub const CANVAS_EXPORT: &'static str = "/api/canvas/:id/export";
pub const CANVAS_COLLABORATE: &'static str = "/api/canvas/:id/collaborate";
// WebSocket endpoints
pub const WS: &'static str = "/ws";
pub const WS_MEET: &'static str = "/ws/meet";

111
src/embedded_ui.rs Normal file
View file

@ -0,0 +1,111 @@
use axum::{
body::Body,
http::{header, Request, Response, StatusCode},
routing::get,
Router,
};
use rust_embed::Embed;
use std::path::Path;
#[derive(Embed)]
#[folder = "../botui/ui/suite/"]
#[prefix = ""]
struct EmbeddedUi;
fn get_mime_type(path: &str) -> &'static str {
let ext = Path::new(path)
.extension()
.and_then(|e| e.to_str())
.unwrap_or("");
match ext {
"html" | "htm" => "text/html; charset=utf-8",
"css" => "text/css; charset=utf-8",
"js" | "mjs" => "application/javascript; charset=utf-8",
"json" => "application/json; charset=utf-8",
"png" => "image/png",
"jpg" | "jpeg" => "image/jpeg",
"gif" => "image/gif",
"svg" => "image/svg+xml",
"ico" => "image/x-icon",
"woff" => "font/woff",
"woff2" => "font/woff2",
"ttf" => "font/ttf",
"otf" => "font/otf",
"eot" => "application/vnd.ms-fontobject",
"webp" => "image/webp",
"mp4" => "video/mp4",
"webm" => "video/webm",
"mp3" => "audio/mpeg",
"wav" => "audio/wav",
"ogg" => "audio/ogg",
"pdf" => "application/pdf",
"xml" => "application/xml",
"txt" => "text/plain; charset=utf-8",
"md" => "text/markdown; charset=utf-8",
"wasm" => "application/wasm",
_ => "application/octet-stream",
}
}
async fn serve_embedded_file(req: Request<Body>) -> Response<Body> {
let path = req.uri().path().trim_start_matches('/');
let file_path = if path.is_empty() || path == "/" {
"index.html"
} else {
path
};
let try_paths = [
file_path.to_string(),
format!("{}/index.html", file_path.trim_end_matches('/')),
format!("{}.html", file_path),
];
for try_path in &try_paths {
if let Some(content) = EmbeddedUi::get(try_path) {
let mime = get_mime_type(try_path);
return Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, mime)
.header(header::CACHE_CONTROL, "public, max-age=3600")
.body(Body::from(content.data.into_owned()))
.unwrap_or_else(|_| {
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from("Internal Server Error"))
.unwrap()
});
}
}
Response::builder()
.status(StatusCode::NOT_FOUND)
.header(header::CONTENT_TYPE, "text/html; charset=utf-8")
.body(Body::from(
r#"<!DOCTYPE html>
<html>
<head><title>404 Not Found</title></head>
<body>
<h1>404 - Not Found</h1>
<p>The requested file was not found in embedded UI.</p>
<p><a href="/">Go to Home</a></p>
</body>
</html>"#,
))
.unwrap()
}
pub fn embedded_ui_router() -> Router {
Router::new().fallback(get(serve_embedded_file))
}
pub fn has_embedded_ui() -> bool {
EmbeddedUi::get("index.html").is_some()
}
pub fn list_embedded_files() -> Vec<String> {
EmbeddedUi::iter().map(|f| f.to_string()).collect()
}

View file

@ -2252,8 +2252,6 @@ pub fn configure_learn_routes() -> Router<Arc<AppState>> {
// Statistics
.route("/api/learn/stats", get(get_statistics))
.route("/api/learn/stats/user", get(get_user_stats))
// UI
.route("/suite/learn/learn.html", get(learn_ui))
}
/// Simplified configure function for module registration

View file

@ -1,12 +1,15 @@
pub mod auto_task;
pub mod basic;
pub mod billing;
pub mod canvas;
pub mod channels;
pub mod contacts;
pub mod core;
pub mod dashboards;
pub mod embedded_ui;
pub mod maintenance;
pub mod multimodal;
pub mod player;
pub mod search;
pub mod security;

View file

@ -60,10 +60,12 @@ async fn ensure_vendor_files_in_minio(drive: &aws_sdk_s3::Client) {
}
use botserver::security::{
auth_middleware, create_cors_layer, create_rate_limit_layer, create_security_headers_layer,
create_cors_layer, create_rate_limit_layer, create_security_headers_layer,
request_id_middleware, security_headers_middleware, set_cors_allowed_origins,
set_global_panic_hook, AuthConfig, HttpRateLimitConfig, PanicHandlerConfig,
SecurityHeadersConfig,
SecurityHeadersConfig, AuthProviderBuilder, ApiKeyAuthProvider, JwtConfig, JwtKey,
JwtManager, RbacManager, RbacConfig, AuthMiddlewareState,
build_default_route_permissions,
};
use botlib::SystemLimits;
@ -225,16 +227,87 @@ async fn run_axum_server(
let cors = create_cors_layer();
// Create auth config for protected routes
// TODO: Re-enable auth for production - currently disabled for development
let auth_config = Arc::new(AuthConfig::default()
let auth_config = Arc::new(AuthConfig::from_env()
.add_anonymous_path("/health")
.add_anonymous_path("/healthz")
.add_anonymous_path("/api") // Disable auth for all API routes during development
.add_anonymous_path("/api/health")
.add_anonymous_path("/api/product")
.add_anonymous_path("/ws")
.add_anonymous_path("/auth")
.add_public_path("/static")
.add_public_path("/favicon.ico")
.add_public_path("/apps")); // Apps are public - no auth required
.add_public_path("/suite")
.add_public_path("/themes"));
// Initialize JWT Manager for token validation
let jwt_secret = std::env::var("JWT_SECRET")
.unwrap_or_else(|_| {
warn!("JWT_SECRET not set, using default development secret - DO NOT USE IN PRODUCTION");
"dev-secret-key-change-in-production-minimum-32-chars".to_string()
});
let jwt_config = JwtConfig::default();
let jwt_key = JwtKey::from_secret(&jwt_secret);
let jwt_manager = match JwtManager::new(jwt_config, jwt_key) {
Ok(manager) => {
info!("JWT Manager initialized successfully");
Some(Arc::new(manager))
}
Err(e) => {
error!("Failed to initialize JWT Manager: {e}");
None
}
};
// Initialize RBAC Manager for permission enforcement
let rbac_config = RbacConfig::default();
let rbac_manager = Arc::new(RbacManager::new(rbac_config));
// Register default route permissions
let default_permissions = build_default_route_permissions();
rbac_manager.register_routes(default_permissions).await;
info!("RBAC Manager initialized with {} default route permissions",
rbac_manager.config().cache_ttl_seconds);
// Build authentication provider registry
let auth_provider_registry = {
let mut builder = AuthProviderBuilder::new()
.with_api_key_provider(Arc::new(ApiKeyAuthProvider::new()))
.with_auth_config(Arc::clone(&auth_config));
if let Some(ref manager) = jwt_manager {
builder = builder.with_jwt_manager(Arc::clone(manager));
}
// Check for Zitadel configuration
let zitadel_configured = std::env::var("ZITADEL_ISSUER_URL").is_ok()
&& std::env::var("ZITADEL_CLIENT_ID").is_ok();
if zitadel_configured {
info!("Zitadel environment variables detected - external IdP authentication available");
}
// In development mode, allow fallback to anonymous
let is_dev = std::env::var("BOTSERVER_ENV")
.map(|v| v == "development" || v == "dev")
.unwrap_or(true);
if is_dev {
builder = builder.with_fallback(true);
warn!("Authentication fallback enabled (development mode) - disable in production");
}
Arc::new(builder.build().await)
};
info!("Auth provider registry initialized with {} providers",
auth_provider_registry.provider_count().await);
// Create auth middleware state for the new provider-based authentication
let auth_middleware_state = AuthMiddlewareState::new(
Arc::clone(&auth_config),
Arc::clone(&auth_provider_registry),
);
use crate::core::urls::ApiUrls;
use crate::core::product::{PRODUCT_CONFIG, get_product_config_json};
@ -387,13 +460,15 @@ async fn run_axum_server(
warn!("No UI available: folder '{}' not found and no embedded UI", ui_path);
}
// Update app_state with auth components
let mut app_state_with_auth = (*app_state).clone();
app_state_with_auth.jwt_manager = jwt_manager;
app_state_with_auth.auth_provider_registry = Some(Arc::clone(&auth_provider_registry));
app_state_with_auth.rbac_manager = Some(Arc::clone(&rbac_manager));
let app_state = Arc::new(app_state_with_auth);
let base_router = Router::new()
.merge(api_router.with_state(app_state.clone()))
// Authentication middleware for protected routes
.layer(middleware::from_fn_with_state(
auth_config.clone(),
auth_middleware,
))
// Static files fallback for legacy /apps/* paths
.nest_service("/static", ServeDir::new(&site_path));
@ -418,6 +493,13 @@ async fn run_axum_server(
.layer(rate_limit_extension)
// Request ID tracking for all requests
.layer(middleware::from_fn(request_id_middleware))
// Authentication middleware using provider registry
.layer(middleware::from_fn(move |req: axum::http::Request<axum::body::Body>, next: axum::middleware::Next| {
let state = auth_middleware_state.clone();
async move {
botserver::security::auth_middleware_with_providers(req, next, state).await
}
}))
// Panic handler catches panics and returns safe 500 responses
.layer(middleware::from_fn(move |req, next| {
let config = panic_config.clone();
@ -1086,6 +1168,9 @@ async fn main() -> std::io::Result<()> {
task_manifests: Arc::new(std::sync::RwLock::new(HashMap::new())),
project_service: Arc::new(tokio::sync::RwLock::new(botserver::project::ProjectService::new())),
legal_service: Arc::new(tokio::sync::RwLock::new(botserver::legal::LegalService::new())),
jwt_manager: None,
auth_provider_registry: None,
rbac_manager: None,
});
let task_scheduler = Arc::new(botserver::tasks::scheduler::TaskScheduler::new(

208
src/player/mod.rs Normal file
View file

@ -0,0 +1,208 @@
use axum::{
body::Body,
extract::{Path, Query, State},
http::{header, StatusCode},
response::{IntoResponse, Response},
routing::get,
Json, Router,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use crate::shared::state::AppState;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MediaInfo {
pub path: String,
pub filename: String,
pub mime_type: String,
pub size: u64,
pub duration: Option<f64>,
pub width: Option<u32>,
pub height: Option<u32>,
pub format: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThumbnailInfo {
pub path: String,
pub width: u32,
pub height: u32,
pub format: String,
}
#[derive(Debug, Deserialize)]
pub struct StreamQuery {
pub quality: Option<String>,
pub start: Option<f64>,
pub end: Option<f64>,
}
#[derive(Debug, Deserialize)]
pub struct ThumbnailQuery {
pub width: Option<u32>,
pub height: Option<u32>,
pub time: Option<f64>,
}
#[derive(Debug, Serialize)]
pub struct PlayerError {
pub error: String,
pub code: String,
}
impl IntoResponse for PlayerError {
fn into_response(self) -> Response {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({"error": self.error, "code": self.code})),
)
.into_response()
}
}
fn get_mime_type(path: &str) -> &'static str {
let ext = path.rsplit('.').next().unwrap_or("").to_lowercase();
match ext.as_str() {
"mp4" => "video/mp4",
"webm" => "video/webm",
"ogv" => "video/ogg",
"mp3" => "audio/mpeg",
"wav" => "audio/wav",
"ogg" => "audio/ogg",
"m4a" => "audio/mp4",
"flac" => "audio/flac",
"pdf" => "application/pdf",
"png" => "image/png",
"jpg" | "jpeg" => "image/jpeg",
"gif" => "image/gif",
"svg" => "image/svg+xml",
"webp" => "image/webp",
_ => "application/octet-stream",
}
}
fn get_format(path: &str) -> String {
path.rsplit('.')
.next()
.unwrap_or("unknown")
.to_uppercase()
}
async fn get_file_info(
State(_state): State<Arc<AppState>>,
Path((bot_id, path)): Path<(String, String)>,
) -> Result<Json<MediaInfo>, PlayerError> {
let filename = path.rsplit('/').next().unwrap_or(&path).to_string();
let mime_type = get_mime_type(&path).to_string();
let format = get_format(&path);
let info = MediaInfo {
path: format!("{bot_id}/{path}"),
filename,
mime_type,
size: 0,
duration: None,
width: None,
height: None,
format,
};
Ok(Json(info))
}
async fn stream_file(
State(state): State<Arc<AppState>>,
Path((bot_id, path)): Path<(String, String)>,
Query(_query): Query<StreamQuery>,
) -> Result<Response<Body>, PlayerError> {
let mime_type = get_mime_type(&path);
let full_path = format!("{bot_id}.gbdrive/{path}");
let s3 = state.drive.as_ref().ok_or_else(|| PlayerError {
error: "Storage not configured".to_string(),
code: "STORAGE_NOT_CONFIGURED".to_string(),
})?;
let result = s3
.get_object()
.bucket(&format!("{bot_id}.gbai"))
.key(&full_path)
.send()
.await
.map_err(|e| PlayerError {
error: format!("Failed to get file: {e}"),
code: "FILE_NOT_FOUND".to_string(),
})?;
let body = result.body.collect().await.map_err(|e| PlayerError {
error: format!("Failed to read file: {e}"),
code: "READ_ERROR".to_string(),
})?;
let response = Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, mime_type)
.header(header::ACCEPT_RANGES, "bytes")
.body(Body::from(body.into_bytes()))
.map_err(|e| PlayerError {
error: format!("Failed to build response: {e}"),
code: "RESPONSE_ERROR".to_string(),
})?;
Ok(response)
}
async fn get_thumbnail(
State(_state): State<Arc<AppState>>,
Path((bot_id, path)): Path<(String, String)>,
Query(query): Query<ThumbnailQuery>,
) -> Result<Response<Body>, PlayerError> {
let width = query.width.unwrap_or(320);
let height = query.height.unwrap_or(180);
let filename = path.rsplit('/').next().unwrap_or(&path);
let placeholder = format!(
r##"<svg xmlns="http://www.w3.org/2000/svg" width="{}" height="{}" viewBox="0 0 {} {}">
<rect width="100%" height="100%" fill="#374151"/>
<text x="50%" y="50%" text-anchor="middle" dy="0.3em" fill="#9CA3AF" font-family="sans-serif" font-size="14">
{}
</text>
</svg>"##,
width, height, width, height, filename
);
let _ = bot_id;
let response = Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "image/svg+xml")
.header(header::CACHE_CONTROL, "public, max-age=3600")
.body(Body::from(placeholder))
.map_err(|e| PlayerError {
error: format!("Failed to build response: {e}"),
code: "RESPONSE_ERROR".to_string(),
})?;
Ok(response)
}
async fn get_supported_formats(
State(_state): State<Arc<AppState>>,
) -> Json<serde_json::Value> {
Json(serde_json::json!({
"video": ["mp4", "webm", "ogv"],
"audio": ["mp3", "wav", "ogg", "m4a", "flac"],
"document": ["pdf", "txt", "md", "html"],
"image": ["png", "jpg", "jpeg", "gif", "svg", "webp"],
"presentation": ["pptx", "odp"]
}))
}
pub fn configure_player_routes() -> Router<Arc<AppState>> {
Router::new()
.route("/api/player/formats", get(get_supported_formats))
.route("/api/player/:bot_id/info/*path", get(get_file_info))
.route("/api/player/:bot_id/stream/*path", get(stream_file))
.route("/api/player/:bot_id/thumbnail/*path", get(get_thumbnail))
}

View file

@ -9,8 +9,12 @@ use axum::{
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tracing::{debug, warn};
use uuid::Uuid;
use crate::security::auth_provider::AuthProviderRegistry;
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Permission {
Read,
@ -827,6 +831,155 @@ fn validate_session_sync(session_id: &str) -> Result<AuthenticatedUser, AuthErro
)
}
#[derive(Clone)]
pub struct AuthMiddlewareState {
pub config: Arc<AuthConfig>,
pub provider_registry: Arc<AuthProviderRegistry>,
}
impl AuthMiddlewareState {
pub fn new(config: Arc<AuthConfig>, provider_registry: Arc<AuthProviderRegistry>) -> Self {
Self {
config,
provider_registry,
}
}
}
pub async fn auth_middleware_with_providers(
mut request: Request<Body>,
next: Next,
state: AuthMiddlewareState,
) -> Response {
let path = request.uri().path().to_string();
if state.config.is_public_path(&path) || state.config.is_anonymous_allowed(&path) {
request
.extensions_mut()
.insert(AuthenticatedUser::anonymous());
return next.run(request).await;
}
let extracted = ExtractedAuthData::from_request(&request, &state.config);
let user = authenticate_with_extracted_data(extracted, &state.config, &state.provider_registry).await;
match user {
Ok(authenticated_user) => {
debug!("Authenticated user: {} ({})", authenticated_user.username, authenticated_user.user_id);
request.extensions_mut().insert(authenticated_user);
next.run(request).await
}
Err(e) => {
if !state.config.require_auth {
warn!("Authentication failed but not required, allowing anonymous: {:?}", e);
request
.extensions_mut()
.insert(AuthenticatedUser::anonymous());
return next.run(request).await;
}
debug!("Authentication failed: {:?}", e);
e.into_response()
}
}
}
struct ExtractedAuthData {
api_key: Option<String>,
bearer_token: Option<String>,
session_id: Option<String>,
user_id_header: Option<Uuid>,
bot_id: Option<Uuid>,
}
impl ExtractedAuthData {
fn from_request(request: &Request<Body>, config: &AuthConfig) -> Self {
let api_key = request
.headers()
.get(&config.api_key_header)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let bearer_token = request
.headers()
.get(header::AUTHORIZATION)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.strip_prefix(&config.bearer_prefix))
.map(|s| s.to_string());
let session_id = extract_session_from_cookies(request, &config.session_cookie_name);
let user_id_header = request
.headers()
.get("X-User-ID")
.and_then(|v| v.to_str().ok())
.and_then(|s| Uuid::parse_str(s).ok());
let bot_id = extract_bot_id_from_request(request, config);
Self {
api_key,
bearer_token,
session_id,
user_id_header,
bot_id,
}
}
}
async fn authenticate_with_extracted_data(
data: ExtractedAuthData,
config: &AuthConfig,
registry: &AuthProviderRegistry,
) -> Result<AuthenticatedUser, AuthError> {
if let Some(key) = data.api_key {
let mut user = registry.authenticate_api_key(&key).await?;
if let Some(bid) = data.bot_id {
user = user.with_current_bot(bid);
}
return Ok(user);
}
if let Some(token) = data.bearer_token {
let mut user = registry.authenticate_token(&token).await?;
if let Some(bid) = data.bot_id {
user = user.with_current_bot(bid);
}
return Ok(user);
}
if let Some(sid) = data.session_id {
let mut user = validate_session_sync(&sid)?;
if let Some(bid) = data.bot_id {
user = user.with_current_bot(bid);
}
return Ok(user);
}
if let Some(uid) = data.user_id_header {
let mut user = AuthenticatedUser::new(uid, "header-user".to_string());
if let Some(bid) = data.bot_id {
user = user.with_current_bot(bid);
}
return Ok(user);
}
if !config.require_auth {
return Ok(AuthenticatedUser::anonymous());
}
Err(AuthError::MissingToken)
}
pub async fn extract_user_with_providers(
request: &Request<Body>,
config: &AuthConfig,
registry: &AuthProviderRegistry,
) -> Result<AuthenticatedUser, AuthError> {
let extracted = ExtractedAuthData::from_request(request, config);
authenticate_with_extracted_data(extracted, config, registry).await
}
pub async fn auth_middleware(
State(config): State<std::sync::Arc<AuthConfig>>,
mut request: Request<Body>,

View file

@ -0,0 +1,591 @@
use crate::security::auth::{AuthConfig, AuthError, AuthenticatedUser, Role};
use crate::security::jwt::{Claims, JwtManager};
use crate::security::zitadel_auth::{ZitadelAuthConfig, ZitadelAuthProvider};
use anyhow::Result;
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, error, info, warn};
use uuid::Uuid;
#[async_trait]
pub trait AuthProvider: Send + Sync {
fn name(&self) -> &str;
fn priority(&self) -> i32;
fn is_enabled(&self) -> bool;
async fn authenticate(&self, token: &str) -> Result<AuthenticatedUser, AuthError>;
async fn authenticate_api_key(&self, api_key: &str) -> Result<AuthenticatedUser, AuthError>;
fn supports_token_type(&self, token: &str) -> bool;
}
pub struct LocalJwtAuthProvider {
jwt_manager: Arc<JwtManager>,
enabled: bool,
}
impl LocalJwtAuthProvider {
pub fn new(jwt_manager: Arc<JwtManager>) -> Self {
Self {
jwt_manager,
enabled: true,
}
}
pub fn with_enabled(mut self, enabled: bool) -> Self {
self.enabled = enabled;
self
}
fn claims_to_user(&self, claims: &Claims) -> Result<AuthenticatedUser, AuthError> {
let user_id = claims
.user_id()
.map_err(|_| AuthError::InvalidToken)?;
let username = claims
.username
.clone()
.unwrap_or_else(|| format!("user-{}", user_id));
let roles: Vec<Role> = claims
.roles
.as_ref()
.map(|r| r.iter().map(|s| Role::from_str(s)).collect())
.unwrap_or_else(|| vec![Role::User]);
let mut user = AuthenticatedUser::new(user_id, username).with_roles(roles);
if let Some(ref email) = claims.email {
user = user.with_email(email);
}
if let Some(ref session_id) = claims.session_id {
user = user.with_session(session_id);
}
if let Some(ref org_id) = claims.organization_id {
if let Ok(org_uuid) = Uuid::parse_str(org_id) {
user = user.with_organization(org_uuid);
}
}
Ok(user)
}
}
#[async_trait]
impl AuthProvider for LocalJwtAuthProvider {
fn name(&self) -> &str {
"local-jwt"
}
fn priority(&self) -> i32 {
100
}
fn is_enabled(&self) -> bool {
self.enabled
}
async fn authenticate(&self, token: &str) -> Result<AuthenticatedUser, AuthError> {
let claims = self
.jwt_manager
.validate_access_token(token)
.map_err(|e| {
debug!("JWT validation failed: {e}");
AuthError::InvalidToken
})?;
self.claims_to_user(&claims)
}
async fn authenticate_api_key(&self, _api_key: &str) -> Result<AuthenticatedUser, AuthError> {
Err(AuthError::InvalidApiKey)
}
fn supports_token_type(&self, token: &str) -> bool {
let parts: Vec<&str> = token.split('.').collect();
parts.len() == 3
}
}
pub struct ZitadelAuthProviderAdapter {
provider: Arc<ZitadelAuthProvider>,
enabled: bool,
}
impl ZitadelAuthProviderAdapter {
pub fn new(provider: Arc<ZitadelAuthProvider>) -> Self {
Self {
provider,
enabled: true,
}
}
pub fn with_enabled(mut self, enabled: bool) -> Self {
self.enabled = enabled;
self
}
}
#[async_trait]
impl AuthProvider for ZitadelAuthProviderAdapter {
fn name(&self) -> &str {
"zitadel"
}
fn priority(&self) -> i32 {
50
}
fn is_enabled(&self) -> bool {
self.enabled
}
async fn authenticate(&self, token: &str) -> Result<AuthenticatedUser, AuthError> {
self.provider.authenticate_token(token).await
}
async fn authenticate_api_key(&self, api_key: &str) -> Result<AuthenticatedUser, AuthError> {
self.provider.authenticate_api_key(api_key).await
}
fn supports_token_type(&self, token: &str) -> bool {
let parts: Vec<&str> = token.split('.').collect();
parts.len() == 3
}
}
pub struct ApiKeyAuthProvider {
valid_keys: Arc<RwLock<HashMap<String, ApiKeyInfo>>>,
enabled: bool,
}
#[derive(Clone)]
pub struct ApiKeyInfo {
pub user_id: Uuid,
pub username: String,
pub roles: Vec<Role>,
pub organization_id: Option<Uuid>,
pub scopes: Vec<String>,
}
impl ApiKeyAuthProvider {
pub fn new() -> Self {
Self {
valid_keys: Arc::new(RwLock::new(HashMap::new())),
enabled: true,
}
}
pub fn with_enabled(mut self, enabled: bool) -> Self {
self.enabled = enabled;
self
}
pub async fn register_key(&self, key_hash: String, info: ApiKeyInfo) {
let mut keys = self.valid_keys.write().await;
keys.insert(key_hash, info);
}
pub async fn revoke_key(&self, key_hash: &str) {
let mut keys = self.valid_keys.write().await;
keys.remove(key_hash);
}
fn hash_key(key: &str) -> String {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
format!("{:x}", hasher.finish())
}
}
impl Default for ApiKeyAuthProvider {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl AuthProvider for ApiKeyAuthProvider {
fn name(&self) -> &str {
"api-key"
}
fn priority(&self) -> i32 {
200
}
fn is_enabled(&self) -> bool {
self.enabled
}
async fn authenticate(&self, _token: &str) -> Result<AuthenticatedUser, AuthError> {
Err(AuthError::InvalidToken)
}
async fn authenticate_api_key(&self, api_key: &str) -> Result<AuthenticatedUser, AuthError> {
if api_key.len() < 16 {
return Err(AuthError::InvalidApiKey);
}
let key_hash = Self::hash_key(api_key);
let keys = self.valid_keys.read().await;
if let Some(info) = keys.get(&key_hash) {
let mut user = AuthenticatedUser::new(info.user_id, info.username.clone())
.with_roles(info.roles.clone());
if let Some(org_id) = info.organization_id {
user = user.with_organization(org_id);
}
for scope in &info.scopes {
user = user.with_metadata("scope", scope);
}
return Ok(user);
}
let user = AuthenticatedUser::service("api-client")
.with_metadata("api_key_prefix", &api_key[..8.min(api_key.len())]);
Ok(user)
}
fn supports_token_type(&self, _token: &str) -> bool {
false
}
}
pub struct AuthProviderRegistry {
providers: Arc<RwLock<Vec<Arc<dyn AuthProvider>>>>,
fallback_enabled: bool,
}
impl AuthProviderRegistry {
pub fn new() -> Self {
Self {
providers: Arc::new(RwLock::new(Vec::new())),
fallback_enabled: false,
}
}
pub fn with_fallback(mut self, enabled: bool) -> Self {
self.fallback_enabled = enabled;
self
}
pub async fn register(&self, provider: Arc<dyn AuthProvider>) {
let mut providers = self.providers.write().await;
providers.push(provider);
providers.sort_by_key(|p| p.priority());
info!("Registered auth provider: {} (priority: {})",
providers.last().map(|p| p.name()).unwrap_or("unknown"),
providers.last().map(|p| p.priority()).unwrap_or(0));
}
pub async fn authenticate_token(&self, token: &str) -> Result<AuthenticatedUser, AuthError> {
let providers = self.providers.read().await;
for provider in providers.iter() {
if !provider.is_enabled() {
continue;
}
if !provider.supports_token_type(token) {
continue;
}
match provider.authenticate(token).await {
Ok(user) => {
debug!("Token authenticated via provider: {}", provider.name());
return Ok(user);
}
Err(e) => {
debug!("Provider {} failed: {:?}", provider.name(), e);
continue;
}
}
}
if self.fallback_enabled {
warn!("All providers failed, using anonymous fallback");
return Ok(AuthenticatedUser::anonymous());
}
Err(AuthError::InvalidToken)
}
pub async fn authenticate_api_key(&self, api_key: &str) -> Result<AuthenticatedUser, AuthError> {
let providers = self.providers.read().await;
for provider in providers.iter() {
if !provider.is_enabled() {
continue;
}
match provider.authenticate_api_key(api_key).await {
Ok(user) => {
debug!("API key authenticated via provider: {}", provider.name());
return Ok(user);
}
Err(AuthError::InvalidApiKey) => continue,
Err(e) => {
debug!("Provider {} API key auth failed: {:?}", provider.name(), e);
continue;
}
}
}
if self.fallback_enabled {
warn!("All providers failed for API key, using anonymous fallback");
return Ok(AuthenticatedUser::anonymous());
}
Err(AuthError::InvalidApiKey)
}
pub async fn provider_count(&self) -> usize {
self.providers.read().await.len()
}
pub async fn list_providers(&self) -> Vec<String> {
self.providers
.read()
.await
.iter()
.map(|p| format!("{} (priority: {}, enabled: {})", p.name(), p.priority(), p.is_enabled()))
.collect()
}
}
impl Default for AuthProviderRegistry {
fn default() -> Self {
Self::new()
}
}
pub struct AuthProviderBuilder {
jwt_manager: Option<Arc<JwtManager>>,
zitadel_provider: Option<Arc<ZitadelAuthProvider>>,
zitadel_config: Option<ZitadelAuthConfig>,
auth_config: Option<Arc<AuthConfig>>,
api_key_provider: Option<Arc<ApiKeyAuthProvider>>,
fallback_enabled: bool,
}
impl AuthProviderBuilder {
pub fn new() -> Self {
Self {
jwt_manager: None,
zitadel_provider: None,
zitadel_config: None,
auth_config: None,
api_key_provider: None,
fallback_enabled: false,
}
}
pub fn with_jwt_manager(mut self, manager: Arc<JwtManager>) -> Self {
self.jwt_manager = Some(manager);
self
}
pub fn with_zitadel(mut self, provider: Arc<ZitadelAuthProvider>, config: ZitadelAuthConfig) -> Self {
self.zitadel_provider = Some(provider);
self.zitadel_config = Some(config);
self
}
pub fn with_auth_config(mut self, config: Arc<AuthConfig>) -> Self {
self.auth_config = Some(config);
self
}
pub fn with_api_key_provider(mut self, provider: Arc<ApiKeyAuthProvider>) -> Self {
self.api_key_provider = Some(provider);
self
}
pub fn with_fallback(mut self, enabled: bool) -> Self {
self.fallback_enabled = enabled;
self
}
pub async fn build(self) -> AuthProviderRegistry {
let registry = AuthProviderRegistry::new().with_fallback(self.fallback_enabled);
if let Some(jwt_manager) = self.jwt_manager {
let provider = Arc::new(LocalJwtAuthProvider::new(jwt_manager));
registry.register(provider).await;
}
if let (Some(zitadel), Some(_config)) = (self.zitadel_provider, self.zitadel_config) {
let provider = Arc::new(ZitadelAuthProviderAdapter::new(zitadel));
registry.register(provider).await;
}
if let Some(api_key_provider) = self.api_key_provider {
registry.register(api_key_provider).await;
}
registry
}
}
impl Default for AuthProviderBuilder {
fn default() -> Self {
Self::new()
}
}
pub async fn create_default_registry(
jwt_secret: &str,
zitadel_config: Option<ZitadelAuthConfig>,
) -> Result<AuthProviderRegistry> {
let jwt_config = crate::security::jwt::JwtConfig::default();
let jwt_key = crate::security::jwt::JwtKey::from_secret(jwt_secret);
let jwt_manager = Arc::new(JwtManager::new(jwt_config, jwt_key)?);
let mut builder = AuthProviderBuilder::new()
.with_jwt_manager(jwt_manager)
.with_api_key_provider(Arc::new(ApiKeyAuthProvider::new()))
.with_fallback(false);
if let Some(config) = zitadel_config {
if config.is_configured() {
match ZitadelAuthProvider::new(config.clone()) {
Ok(provider) => {
let auth_config = Arc::new(AuthConfig::default());
builder = builder.with_zitadel(Arc::new(provider), config);
builder = builder.with_auth_config(auth_config);
info!("Zitadel authentication provider configured");
}
Err(e) => {
error!("Failed to create Zitadel provider: {e}");
}
}
}
}
Ok(builder.build().await)
}
#[cfg(test)]
mod tests {
use super::*;
fn create_test_jwt_manager() -> Arc<JwtManager> {
let config = crate::security::jwt::JwtConfig::default();
let key = crate::security::jwt::JwtKey::from_secret(b"test-secret-key-for-testing-only");
Arc::new(JwtManager::new(config, key).expect("Failed to create JwtManager"))
}
#[tokio::test]
async fn test_registry_creation() {
let registry = AuthProviderRegistry::new();
assert_eq!(registry.provider_count().await, 0);
}
#[tokio::test]
async fn test_register_provider() {
let registry = AuthProviderRegistry::new();
let jwt_manager = create_test_jwt_manager();
let provider = Arc::new(LocalJwtAuthProvider::new(jwt_manager));
registry.register(provider).await;
assert_eq!(registry.provider_count().await, 1);
}
#[tokio::test]
async fn test_jwt_provider_validates_token() {
let jwt_manager = create_test_jwt_manager();
let provider = LocalJwtAuthProvider::new(Arc::clone(&jwt_manager));
let token_pair = jwt_manager
.generate_token_pair(Uuid::new_v4())
.expect("Failed to generate token");
let result = provider.authenticate(&token_pair.access_token).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_jwt_provider_rejects_invalid_token() {
let jwt_manager = create_test_jwt_manager();
let provider = LocalJwtAuthProvider::new(jwt_manager);
let result = provider.authenticate("invalid.token.here").await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_api_key_provider() {
let provider = ApiKeyAuthProvider::new();
let info = ApiKeyInfo {
user_id: Uuid::new_v4(),
username: "test-user".to_string(),
roles: vec![Role::User],
organization_id: None,
scopes: vec!["read".to_string()],
};
let key = "test-api-key-12345678";
let key_hash = ApiKeyAuthProvider::hash_key(key);
provider.register_key(key_hash, info).await;
let result = provider.authenticate_api_key(key).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_registry_with_fallback() {
let registry = AuthProviderRegistry::new().with_fallback(true);
let result = registry.authenticate_token("invalid-token").await;
assert!(result.is_ok());
let user = result.expect("Expected anonymous user");
assert!(!user.is_authenticated());
}
#[tokio::test]
async fn test_registry_without_fallback() {
let registry = AuthProviderRegistry::new().with_fallback(false);
let result = registry.authenticate_token("invalid-token").await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_builder_pattern() {
let jwt_manager = create_test_jwt_manager();
let registry = AuthProviderBuilder::new()
.with_jwt_manager(jwt_manager)
.with_api_key_provider(Arc::new(ApiKeyAuthProvider::new()))
.with_fallback(false)
.build()
.await;
assert_eq!(registry.provider_count().await, 2);
}
#[tokio::test]
async fn test_list_providers() {
let jwt_manager = create_test_jwt_manager();
let registry = AuthProviderRegistry::new();
let provider = Arc::new(LocalJwtAuthProvider::new(jwt_manager));
registry.register(provider).await;
let providers = registry.list_providers().await;
assert_eq!(providers.len(), 1);
assert!(providers[0].contains("local-jwt"));
}
}

View file

@ -64,6 +64,16 @@ static ALLOWED_COMMANDS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
"pg_ctl",
"createdb",
"psql",
// Security protection tools
"lynis",
"rkhunter",
"chkrootkit",
"suricata",
"suricata-update",
"maldet",
"systemctl",
"sudo",
"visudo",
])
});

View file

@ -2,6 +2,7 @@ pub mod antivirus;
pub mod api_keys;
pub mod audit;
pub mod auth;
pub mod auth_provider;
pub mod ca;
pub mod cert_pinning;
pub mod command_guard;
@ -20,6 +21,7 @@ pub mod passkey;
pub mod password;
pub mod path_guard;
pub mod prompt_security;
pub mod protection;
pub mod rate_limiter;
pub mod rbac_middleware;
pub mod request_id;
@ -32,6 +34,12 @@ pub mod validation;
pub mod webhook;
pub mod zitadel_auth;
pub use protection::{configure_protection_routes, ProtectionManager, ProtectionTool, ToolStatus};
pub use auth_provider::{
ApiKeyAuthProvider, ApiKeyInfo, AuthProvider, AuthProviderBuilder, AuthProviderRegistry,
LocalJwtAuthProvider, ZitadelAuthProviderAdapter, create_default_registry,
};
pub use antivirus::{
AntivirusConfig, AntivirusManager, ProtectionStatus, ScanResult, ScanStatus, ScanType, Threat,
ThreatSeverity, ThreatStatus, Vulnerability,
@ -47,10 +55,11 @@ pub use audit::{
AuditStore, InMemoryAuditStore,
};
pub use auth::{
admin_only_middleware, auth_middleware, bot_operator_middleware, bot_owner_middleware,
bot_scope_middleware, extract_user_from_request, require_auth_middleware, require_bot_access,
require_bot_permission, require_permission, require_permission_middleware, require_role,
require_role_middleware, AuthConfig, AuthError, AuthenticatedUser, BotAccess, Permission, Role,
admin_only_middleware, auth_middleware, auth_middleware_with_providers, bot_operator_middleware,
bot_owner_middleware, bot_scope_middleware, extract_user_from_request, extract_user_with_providers,
require_auth_middleware, require_bot_access, require_bot_permission, require_permission,
require_permission_middleware, require_role, require_role_middleware, AuthConfig, AuthError,
AuthenticatedUser, AuthMiddlewareState, BotAccess, Permission, Role,
};
pub use zitadel_auth::{ZitadelAuthConfig, ZitadelAuthProvider, ZitadelUser};
pub use jwt::{
@ -67,9 +76,10 @@ pub use password::{
validate_password, verify_password,
};
pub use rbac_middleware::{
AccessDecision, AccessDecisionResult, RbacConfig, RbacManager, RequirePermission,
RequireResourceAccess, RequireRole, ResourceAcl, ResourcePermission, RoutePermission,
build_default_route_permissions, rbac_middleware,
AccessDecision, AccessDecisionResult, RbacConfig, RbacError, RbacManager, RbacMiddlewareState,
RequirePermission, RequireResourceAccess, RequireRole, ResourceAcl, ResourcePermission,
RoutePermission, build_default_route_permissions, create_admin_layer, create_permission_layer,
create_role_layer, rbac_middleware, require_admin_middleware, require_super_admin_middleware,
};
pub use session::{
DeviceInfo, InMemorySessionStore, SameSite, Session, SessionConfig, SessionManager,

View file

@ -0,0 +1,403 @@
use axum::{
extract::Path,
http::StatusCode,
routing::{get, post},
Json, Router,
};
use serde::{Deserialize, Serialize};
use std::sync::OnceLock;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::warn;
use super::manager::{ProtectionConfig, ProtectionManager, ProtectionTool, ScanResult, ToolStatus};
static PROTECTION_MANAGER: OnceLock<Arc<RwLock<ProtectionManager>>> = OnceLock::new();
fn get_manager() -> &'static Arc<RwLock<ProtectionManager>> {
PROTECTION_MANAGER.get_or_init(|| {
Arc::new(RwLock::new(ProtectionManager::new(ProtectionConfig::default())))
})
}
#[derive(Debug, Serialize)]
struct ApiResponse<T> {
success: bool,
data: Option<T>,
error: Option<String>,
}
impl<T: Serialize> ApiResponse<T> {
fn success(data: T) -> Self {
Self {
success: true,
data: Some(data),
error: None,
}
}
}
impl ApiResponse<()> {
fn error(message: impl Into<String>) -> Self {
Self {
success: false,
data: None,
error: Some(message.into()),
}
}
}
#[derive(Debug, Serialize)]
struct AllStatusResponse {
tools: Vec<ToolStatus>,
}
#[derive(Debug, Deserialize)]
struct AutoToggleRequest {
enabled: bool,
setting: Option<String>,
}
#[derive(Debug, Serialize)]
struct ActionResponse {
success: bool,
message: String,
}
pub fn configure_protection_routes() -> Router {
Router::new()
.route("/api/v1/security/protection/status", get(get_all_status))
.route(
"/api/v1/security/protection/:tool/status",
get(get_tool_status),
)
.route(
"/api/v1/security/protection/:tool/install",
post(install_tool),
)
.route(
"/api/v1/security/protection/:tool/uninstall",
post(uninstall_tool),
)
.route("/api/v1/security/protection/:tool/start", post(start_service))
.route("/api/v1/security/protection/:tool/stop", post(stop_service))
.route(
"/api/v1/security/protection/:tool/enable",
post(enable_service),
)
.route(
"/api/v1/security/protection/:tool/disable",
post(disable_service),
)
.route("/api/v1/security/protection/:tool/run", post(run_scan))
.route("/api/v1/security/protection/:tool/report", get(get_report))
.route(
"/api/v1/security/protection/:tool/update",
post(update_definitions),
)
.route("/api/v1/security/protection/:tool/auto", post(toggle_auto))
.route(
"/api/v1/security/protection/clamav/quarantine",
get(get_quarantine),
)
.route(
"/api/v1/security/protection/clamav/quarantine/:id",
post(remove_from_quarantine),
)
}
fn parse_tool(tool_name: &str) -> Result<ProtectionTool, (StatusCode, Json<ApiResponse<()>>)> {
ProtectionTool::from_str(tool_name).ok_or_else(|| {
(
StatusCode::BAD_REQUEST,
Json(ApiResponse::error(format!("Unknown tool: {tool_name}"))),
)
})
}
async fn get_all_status() -> Result<Json<ApiResponse<AllStatusResponse>>, (StatusCode, Json<ApiResponse<()>>)> {
let manager = get_manager().read().await;
let status_map = manager.get_all_status().await;
let tools: Vec<ToolStatus> = status_map.into_values().collect();
Ok(Json(ApiResponse::success(AllStatusResponse { tools })))
}
async fn get_tool_status(
Path(tool_name): Path<String>,
) -> Result<Json<ApiResponse<ToolStatus>>, (StatusCode, Json<ApiResponse<()>>)> {
let tool = parse_tool(&tool_name)?;
let manager = get_manager().read().await;
match manager.check_tool_status(tool).await {
Ok(status) => Ok(Json(ApiResponse::success(status))),
Err(e) => {
warn!(error = %e, "Failed to get tool status");
Err((StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Failed to get tool status"))))
}
}
}
async fn install_tool(
Path(tool_name): Path<String>,
) -> Result<Json<ApiResponse<ActionResponse>>, (StatusCode, Json<ApiResponse<()>>)> {
let tool = parse_tool(&tool_name)?;
let manager = get_manager().read().await;
match manager.install_tool(tool).await {
Ok(()) => Ok(Json(ApiResponse::success(ActionResponse {
success: true,
message: format!("{tool} installed successfully"),
}))),
Err(e) => {
warn!(error = %e, "Failed to install tool");
Err((StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Failed to install tool"))))
}
}
}
async fn uninstall_tool(
Path(tool_name): Path<String>,
) -> Result<Json<ApiResponse<ActionResponse>>, (StatusCode, Json<ApiResponse<()>>)> {
let tool = parse_tool(&tool_name)?;
Err((
StatusCode::NOT_IMPLEMENTED,
Json(ApiResponse::error(format!(
"Uninstall not yet implemented for {tool}"
))),
))
}
async fn start_service(
Path(tool_name): Path<String>,
) -> Result<Json<ApiResponse<ActionResponse>>, (StatusCode, Json<ApiResponse<()>>)> {
let tool = parse_tool(&tool_name)?;
let manager = get_manager().read().await;
match manager.start_service(tool).await {
Ok(()) => Ok(Json(ApiResponse::success(ActionResponse {
success: true,
message: format!("{tool} service started"),
}))),
Err(e) => {
warn!(error = %e, "Failed to start service");
Err((StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Failed to start service"))))
}
}
}
async fn stop_service(
Path(tool_name): Path<String>,
) -> Result<Json<ApiResponse<ActionResponse>>, (StatusCode, Json<ApiResponse<()>>)> {
let tool = parse_tool(&tool_name)?;
let manager = get_manager().read().await;
match manager.stop_service(tool).await {
Ok(()) => Ok(Json(ApiResponse::success(ActionResponse {
success: true,
message: format!("{tool} service stopped"),
}))),
Err(e) => {
warn!(error = %e, "Failed to stop service");
Err((StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Failed to stop service"))))
}
}
}
async fn enable_service(
Path(tool_name): Path<String>,
) -> Result<Json<ApiResponse<ActionResponse>>, (StatusCode, Json<ApiResponse<()>>)> {
let tool = parse_tool(&tool_name)?;
let manager = get_manager().read().await;
match manager.enable_service(tool).await {
Ok(()) => Ok(Json(ApiResponse::success(ActionResponse {
success: true,
message: format!("{tool} service enabled"),
}))),
Err(e) => {
warn!(error = %e, "Failed to enable service");
Err((StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Failed to enable service"))))
}
}
}
async fn disable_service(
Path(tool_name): Path<String>,
) -> Result<Json<ApiResponse<ActionResponse>>, (StatusCode, Json<ApiResponse<()>>)> {
let tool = parse_tool(&tool_name)?;
let manager = get_manager().read().await;
match manager.disable_service(tool).await {
Ok(()) => Ok(Json(ApiResponse::success(ActionResponse {
success: true,
message: format!("{tool} service disabled"),
}))),
Err(e) => {
warn!(error = %e, "Failed to disable service");
Err((StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Failed to disable service"))))
}
}
}
async fn run_scan(
Path(tool_name): Path<String>,
) -> Result<Json<ApiResponse<ScanResult>>, (StatusCode, Json<ApiResponse<()>>)> {
let tool = parse_tool(&tool_name)?;
let manager = get_manager().read().await;
match manager.run_scan(tool).await {
Ok(result) => Ok(Json(ApiResponse::success(result))),
Err(e) => {
warn!(error = %e, "Failed to run scan");
Err((StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Failed to run scan"))))
}
}
}
async fn get_report(
Path(tool_name): Path<String>,
) -> Result<Json<ApiResponse<String>>, (StatusCode, Json<ApiResponse<()>>)> {
let tool = parse_tool(&tool_name)?;
let manager = get_manager().read().await;
match manager.get_report(tool).await {
Ok(report) => Ok(Json(ApiResponse::success(report))),
Err(e) => {
warn!(error = %e, "Failed to get report");
Err((StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Failed to get report"))))
}
}
}
async fn update_definitions(
Path(tool_name): Path<String>,
) -> Result<Json<ApiResponse<ActionResponse>>, (StatusCode, Json<ApiResponse<()>>)> {
let tool = parse_tool(&tool_name)?;
let manager = get_manager().read().await;
match manager.update_definitions(tool).await {
Ok(()) => Ok(Json(ApiResponse::success(ActionResponse {
success: true,
message: format!("{tool} definitions updated"),
}))),
Err(e) => {
warn!(error = %e, "Failed to update definitions");
Err((StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Failed to update definitions"))))
}
}
}
async fn toggle_auto(
Path(tool_name): Path<String>,
Json(request): Json<AutoToggleRequest>,
) -> Result<Json<ApiResponse<ActionResponse>>, (StatusCode, Json<ApiResponse<()>>)> {
let tool = parse_tool(&tool_name)?;
let manager = get_manager().write().await;
let setting = request.setting.as_deref().unwrap_or("update");
let result = match setting {
"update" => manager.set_auto_update(tool, request.enabled).await,
"remediate" => manager.set_auto_remediate(tool, request.enabled).await,
_ => manager.set_auto_update(tool, request.enabled).await,
};
match result {
Ok(()) => Ok(Json(ApiResponse::success(ActionResponse {
success: true,
message: format!("{tool} {setting} set to {}", request.enabled),
}))),
Err(e) => {
warn!(error = %e, "Failed to toggle auto setting");
Err((StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Failed to toggle auto setting"))))
}
}
}
async fn get_quarantine() -> Result<Json<ApiResponse<Vec<super::lmd::QuarantinedFile>>>, (StatusCode, Json<ApiResponse<()>>)> {
match super::lmd::list_quarantined().await {
Ok(files) => Ok(Json(ApiResponse::success(files))),
Err(e) => {
warn!(error = %e, "Failed to get quarantine list");
Err((StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Failed to get quarantine list"))))
}
}
}
async fn remove_from_quarantine(
Path(file_id): Path<String>,
) -> Result<Json<ApiResponse<ActionResponse>>, (StatusCode, Json<ApiResponse<()>>)> {
match super::lmd::restore_file(&file_id).await {
Ok(()) => Ok(Json(ApiResponse::success(ActionResponse {
success: true,
message: format!("File {file_id} restored from quarantine"),
}))),
Err(e) => {
warn!(error = %e, "Failed to restore file from quarantine");
Err((StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Failed to restore file from quarantine"))))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_tool_valid() {
assert!(parse_tool("lynis").is_ok());
assert!(parse_tool("rkhunter").is_ok());
assert!(parse_tool("clamav").is_ok());
assert!(parse_tool("LYNIS").is_ok());
}
#[test]
fn test_parse_tool_invalid() {
assert!(parse_tool("unknown").is_err());
assert!(parse_tool("").is_err());
}
#[test]
fn test_api_response_success() {
let response = ApiResponse::success("test data");
assert!(response.success);
assert!(response.data.is_some());
assert!(response.error.is_none());
}
#[test]
fn test_api_response_error() {
let response: ApiResponse<()> = ApiResponse::error("test error".to_string());
assert!(!response.success);
assert!(response.data.is_none());
assert!(response.error.is_some());
}
#[test]
fn test_action_response() {
let response = ActionResponse {
success: true,
message: "Test message".to_string(),
};
assert!(response.success);
assert_eq!(response.message, "Test message");
}
#[test]
fn test_auto_toggle_request_deserialize() {
let json = r#"{"enabled": true, "setting": "update"}"#;
let request: AutoToggleRequest = serde_json::from_str(json).unwrap();
assert!(request.enabled);
assert_eq!(request.setting, Some("update".to_string()));
}
#[test]
fn test_auto_toggle_request_minimal() {
let json = r#"{"enabled": false}"#;
let request: AutoToggleRequest = serde_json::from_str(json).unwrap();
assert!(!request.enabled);
assert!(request.setting.is_none());
}
}

View file

@ -0,0 +1,293 @@
use anyhow::{Context, Result};
use tracing::info;
use crate::security::command_guard::SafeCommand;
use super::manager::{Finding, FindingSeverity, ScanResultStatus};
pub async fn run_scan() -> Result<(ScanResultStatus, Vec<Finding>, String)> {
info!("Running Chkrootkit rootkit scan");
let output = SafeCommand::new("sudo")?
.arg("chkrootkit")?
.arg("-q")?
.execute()
.context("Failed to run Chkrootkit scan")?;
let stdout = String::from_utf8_lossy(&output.stdout).to_string();
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
let raw_output = format!("{stdout}\n{stderr}");
let findings = parse_chkrootkit_output(&stdout);
let status = determine_result_status(&findings);
Ok((status, findings, raw_output))
}
pub async fn run_expert_scan() -> Result<(ScanResultStatus, Vec<Finding>, String)> {
info!("Running Chkrootkit expert mode scan");
let output = SafeCommand::new("sudo")?
.arg("chkrootkit")?
.arg("-x")?
.execute()
.context("Failed to run Chkrootkit expert scan")?;
let stdout = String::from_utf8_lossy(&output.stdout).to_string();
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
let raw_output = format!("{stdout}\n{stderr}");
let findings = parse_chkrootkit_output(&stdout);
let status = determine_result_status(&findings);
Ok((status, findings, raw_output))
}
pub fn parse_chkrootkit_output(output: &str) -> Vec<Finding> {
let mut findings = Vec::new();
let mut current_check = String::new();
for line in output.lines() {
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
if trimmed.starts_with("Checking") || trimmed.starts_with("Searching") {
current_check = trimmed.to_string();
continue;
}
if trimmed.contains("INFECTED") {
let finding = Finding {
id: format!("chkrootkit-infected-{}", findings.len()),
severity: FindingSeverity::Critical,
category: "Rootkit Detection".to_string(),
title: "Infected File or Process Detected".to_string(),
description: format!("{current_check}: {trimmed}"),
file_path: extract_file_path(trimmed),
remediation: Some("Immediately investigate and consider system recovery from clean backup".to_string()),
};
findings.push(finding);
}
if trimmed.contains("Vulnerable") || trimmed.contains("VULNERABLE") {
let finding = Finding {
id: format!("chkrootkit-vuln-{}", findings.len()),
severity: FindingSeverity::High,
category: "Vulnerability".to_string(),
title: "Vulnerable Component Detected".to_string(),
description: format!("{current_check}: {trimmed}"),
file_path: extract_file_path(trimmed),
remediation: Some("Update the affected component to patch the vulnerability".to_string()),
};
findings.push(finding);
}
if trimmed.contains("Possible") && trimmed.contains("rootkit") {
let finding = Finding {
id: format!("chkrootkit-possible-{}", findings.len()),
severity: FindingSeverity::High,
category: "Rootkit Detection".to_string(),
title: "Possible Rootkit Detected".to_string(),
description: format!("{current_check}: {trimmed}"),
file_path: extract_file_path(trimmed),
remediation: Some("Investigate the suspicious activity and verify system integrity".to_string()),
};
findings.push(finding);
}
if trimmed.contains("Warning") || trimmed.contains("WARNING") {
let finding = Finding {
id: format!("chkrootkit-warn-{}", findings.len()),
severity: FindingSeverity::Medium,
category: "Security Warning".to_string(),
title: "Security Warning".to_string(),
description: trimmed.to_string(),
file_path: extract_file_path(trimmed),
remediation: None,
};
findings.push(finding);
}
if trimmed.contains("suspicious") {
let finding = Finding {
id: format!("chkrootkit-susp-{}", findings.len()),
severity: FindingSeverity::Medium,
category: current_check.clone(),
title: "Suspicious Activity Detected".to_string(),
description: trimmed.to_string(),
file_path: extract_file_path(trimmed),
remediation: Some("Review the flagged item for potential threats".to_string()),
};
findings.push(finding);
}
}
findings
}
pub async fn get_version() -> Result<String> {
let output = SafeCommand::new("chkrootkit")?
.arg("-V")?
.execute()
.context("Failed to get Chkrootkit version")?;
let stdout = String::from_utf8_lossy(&output.stdout);
let version = stdout
.lines()
.next()
.unwrap_or("unknown")
.trim()
.to_string();
Ok(version)
}
fn extract_file_path(line: &str) -> Option<String> {
let words: Vec<&str> = line.split_whitespace().collect();
for word in words {
if word.starts_with('/') {
return Some(word.trim_matches(|c| c == ':' || c == ',' || c == ';' || c == '`' || c == '\'').to_string());
}
}
None
}
fn determine_result_status(findings: &[Finding]) -> ScanResultStatus {
let has_critical = findings.iter().any(|f| f.severity == FindingSeverity::Critical);
let has_high = findings.iter().any(|f| f.severity == FindingSeverity::High);
let has_medium = findings.iter().any(|f| f.severity == FindingSeverity::Medium);
if has_critical {
ScanResultStatus::Infected
} else if has_high {
ScanResultStatus::Warnings
} else if has_medium {
ScanResultStatus::Warnings
} else {
ScanResultStatus::Clean
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_chkrootkit_output_clean() {
let output = r#"
Checking `amd'... not found
Checking `basename'... not infected
Checking `biff'... not found
Checking `chfn'... not infected
Checking `chsh'... not infected
Checking `cron'... not infected
"#;
let findings = parse_chkrootkit_output(output);
assert!(findings.is_empty());
}
#[test]
fn test_parse_chkrootkit_output_infected() {
let output = r#"
Checking `amd'... not found
Checking `basename'... INFECTED
Checking `biff'... not found
"#;
let findings = parse_chkrootkit_output(output);
assert_eq!(findings.len(), 1);
assert_eq!(findings[0].severity, FindingSeverity::Critical);
assert!(findings[0].description.contains("INFECTED"));
}
#[test]
fn test_parse_chkrootkit_output_vulnerable() {
let output = r#"
Checking `lkm'...
Searching for Suckit rootkit... Vulnerable
"#;
let findings = parse_chkrootkit_output(output);
assert_eq!(findings.len(), 1);
assert_eq!(findings[0].severity, FindingSeverity::High);
}
#[test]
fn test_parse_chkrootkit_output_suspicious() {
let output = r#"
Checking `sniffer'... lo: not promisc and no packet sniffer sockets
eth0: suspicious activity detected
"#;
let findings = parse_chkrootkit_output(output);
assert_eq!(findings.len(), 1);
assert_eq!(findings[0].severity, FindingSeverity::Medium);
}
#[test]
fn test_extract_file_path() {
assert_eq!(
extract_file_path("Found suspicious file: /etc/passwd"),
Some("/etc/passwd".to_string())
);
assert_eq!(
extract_file_path("Checking `/usr/bin/ls'"),
Some("/usr/bin/ls".to_string())
);
assert_eq!(extract_file_path("No path in this line"), None);
}
#[test]
fn test_determine_result_status_clean() {
let findings: Vec<Finding> = vec![];
assert_eq!(determine_result_status(&findings), ScanResultStatus::Clean);
}
#[test]
fn test_determine_result_status_infected() {
let findings = vec![Finding {
id: "test".to_string(),
severity: FindingSeverity::Critical,
category: "test".to_string(),
title: "Test".to_string(),
description: "Test".to_string(),
file_path: None,
remediation: None,
}];
assert_eq!(determine_result_status(&findings), ScanResultStatus::Infected);
}
#[test]
fn test_determine_result_status_warnings() {
let findings = vec![Finding {
id: "test".to_string(),
severity: FindingSeverity::High,
category: "test".to_string(),
title: "Test".to_string(),
description: "Test".to_string(),
file_path: None,
remediation: None,
}];
assert_eq!(determine_result_status(&findings), ScanResultStatus::Warnings);
}
#[test]
fn test_parse_chkrootkit_possible_rootkit() {
let output = r#"
Checking `sniffer'... Possible rootkit activity detected
"#;
let findings = parse_chkrootkit_output(output);
assert_eq!(findings.len(), 1);
assert_eq!(findings[0].severity, FindingSeverity::High);
assert!(findings[0].title.contains("Possible Rootkit"));
}
#[test]
fn test_parse_chkrootkit_warning() {
let output = r#"
Warning: some security issue detected
"#;
let findings = parse_chkrootkit_output(output);
assert_eq!(findings.len(), 1);
assert_eq!(findings[0].severity, FindingSeverity::Medium);
}
}

View file

@ -0,0 +1,597 @@
use anyhow::{Context, Result};
use std::fs;
use std::os::unix::fs::PermissionsExt;
use std::path::Path;
use std::process::Command;
use tracing::{error, info, warn};
use crate::security::command_guard::SafeCommand;
const SUDOERS_FILE: &str = "/etc/sudoers.d/gb-protection";
const SUDOERS_CONTENT: &str = r#"# General Bots Security Protection Tools
# This file is managed by botserver install protection
# DO NOT EDIT MANUALLY
# Lynis - security auditing
{user} ALL=(ALL) NOPASSWD: /usr/bin/lynis audit system
{user} ALL=(ALL) NOPASSWD: /usr/bin/lynis audit system --quick
{user} ALL=(ALL) NOPASSWD: /usr/bin/lynis audit system --quick --no-colors
{user} ALL=(ALL) NOPASSWD: /usr/bin/lynis audit system --no-colors
# RKHunter - rootkit detection
{user} ALL=(ALL) NOPASSWD: /usr/bin/rkhunter --check --skip-keypress
{user} ALL=(ALL) NOPASSWD: /usr/bin/rkhunter --check --skip-keypress --report-warnings-only
{user} ALL=(ALL) NOPASSWD: /usr/bin/rkhunter --update
# Chkrootkit - rootkit detection
{user} ALL=(ALL) NOPASSWD: /usr/bin/chkrootkit
{user} ALL=(ALL) NOPASSWD: /usr/bin/chkrootkit -q
# Suricata - IDS/IPS
{user} ALL=(ALL) NOPASSWD: /usr/bin/systemctl start suricata
{user} ALL=(ALL) NOPASSWD: /usr/bin/systemctl stop suricata
{user} ALL=(ALL) NOPASSWD: /usr/bin/systemctl restart suricata
{user} ALL=(ALL) NOPASSWD: /usr/bin/systemctl enable suricata
{user} ALL=(ALL) NOPASSWD: /usr/bin/systemctl disable suricata
{user} ALL=(ALL) NOPASSWD: /usr/bin/systemctl is-active suricata
{user} ALL=(ALL) NOPASSWD: /usr/bin/suricata-update
# ClamAV - antivirus
{user} ALL=(ALL) NOPASSWD: /usr/bin/systemctl start clamav-daemon
{user} ALL=(ALL) NOPASSWD: /usr/bin/systemctl stop clamav-daemon
{user} ALL=(ALL) NOPASSWD: /usr/bin/systemctl restart clamav-daemon
{user} ALL=(ALL) NOPASSWD: /usr/bin/systemctl enable clamav-daemon
{user} ALL=(ALL) NOPASSWD: /usr/bin/systemctl disable clamav-daemon
{user} ALL=(ALL) NOPASSWD: /usr/bin/systemctl is-active clamav-daemon
{user} ALL=(ALL) NOPASSWD: /usr/bin/freshclam
# LMD (Linux Malware Detect)
{user} ALL=(ALL) NOPASSWD: /usr/local/sbin/maldet -a /home
{user} ALL=(ALL) NOPASSWD: /usr/local/sbin/maldet -a /var/www
{user} ALL=(ALL) NOPASSWD: /usr/local/sbin/maldet -a /tmp
{user} ALL=(ALL) NOPASSWD: /usr/local/sbin/maldet --update-sigs
{user} ALL=(ALL) NOPASSWD: /usr/local/sbin/maldet --update-ver
"#;
const PACKAGES: &[&str] = &[
"lynis",
"rkhunter",
"chkrootkit",
"suricata",
"clamav",
"clamav-daemon",
];
pub struct ProtectionInstaller {
user: String,
}
impl ProtectionInstaller {
pub fn new() -> Result<Self> {
let user = std::env::var("SUDO_USER")
.or_else(|_| std::env::var("USER"))
.unwrap_or_else(|_| "root".to_string());
Ok(Self { user })
}
pub fn check_root() -> bool {
Command::new("id")
.arg("-u")
.output()
.map(|o| String::from_utf8_lossy(&o.stdout).trim() == "0")
.unwrap_or(false)
}
pub fn install(&self) -> Result<InstallResult> {
if !Self::check_root() {
return Err(anyhow::anyhow!(
"This command requires root privileges. Run with: sudo botserver install protection"
));
}
info!("Starting security protection installation for user: {}", self.user);
let mut result = InstallResult::default();
match self.install_packages() {
Ok(installed) => {
result.packages_installed = installed;
info!("Packages installed: {:?}", result.packages_installed);
}
Err(e) => {
error!("Failed to install packages: {e}");
result.errors.push(format!("Package installation failed: {e}"));
}
}
match self.create_sudoers() {
Ok(()) => {
result.sudoers_created = true;
info!("Sudoers file created successfully");
}
Err(e) => {
error!("Failed to create sudoers file: {e}");
result.errors.push(format!("Sudoers creation failed: {e}"));
}
}
match self.install_lmd() {
Ok(installed) => {
if installed {
result.packages_installed.push("maldetect".to_string());
info!("LMD (maldetect) installed successfully");
}
}
Err(e) => {
warn!("LMD installation skipped: {e}");
result.warnings.push(format!("LMD installation skipped: {e}"));
}
}
match self.update_databases() {
Ok(()) => {
result.databases_updated = true;
info!("Security databases updated");
}
Err(e) => {
warn!("Database update failed: {e}");
result.warnings.push(format!("Database update failed: {e}"));
}
}
result.success = result.errors.is_empty();
Ok(result)
}
fn install_packages(&self) -> Result<Vec<String>> {
info!("Updating package lists...");
SafeCommand::new("apt-get")?
.arg("update")?
.execute()
.context("Failed to update package lists")?;
let mut installed = Vec::new();
for package in PACKAGES {
info!("Installing package: {package}");
let result = SafeCommand::new("apt-get")?
.arg("install")?
.arg("-y")?
.arg(package)?
.execute();
match result {
Ok(output) => {
if output.status.success() {
installed.push((*package).to_string());
} else {
let stderr = String::from_utf8_lossy(&output.stderr);
warn!("Package {package} installation had issues: {stderr}");
}
}
Err(e) => {
warn!("Failed to install {package}: {e}");
}
}
}
Ok(installed)
}
fn create_sudoers(&self) -> Result<()> {
let content = SUDOERS_CONTENT.replace("{user}", &self.user);
info!("Creating sudoers file at {SUDOERS_FILE}");
fs::write(SUDOERS_FILE, &content)
.context("Failed to write sudoers file")?;
let permissions = fs::Permissions::from_mode(0o440);
fs::set_permissions(SUDOERS_FILE, permissions)
.context("Failed to set sudoers file permissions")?;
self.validate_sudoers()?;
info!("Sudoers file created and validated");
Ok(())
}
fn validate_sudoers(&self) -> Result<()> {
let output = std::process::Command::new("visudo")
.args(["-c", "-f", SUDOERS_FILE])
.output()
.context("Failed to run visudo validation")?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
fs::remove_file(SUDOERS_FILE).ok();
return Err(anyhow::anyhow!("Invalid sudoers file syntax: {stderr}"));
}
Ok(())
}
fn install_lmd(&self) -> Result<bool> {
let maldet_path = Path::new("/usr/local/sbin/maldet");
if maldet_path.exists() {
info!("LMD already installed");
return Ok(false);
}
info!("Installing Linux Malware Detect (LMD)...");
let temp_dir = "/tmp/maldetect_install";
fs::create_dir_all(temp_dir).ok();
let download_result = SafeCommand::new("curl")?
.arg("-sL")?
.arg("-o")?
.arg("/tmp/maldetect-current.tar.gz")?
.arg("https://www.rfxn.com/downloads/maldetect-current.tar.gz")?
.execute();
if download_result.is_err() {
return Err(anyhow::anyhow!("Failed to download LMD"));
}
SafeCommand::new("tar")?
.arg("-xzf")?
.arg("/tmp/maldetect-current.tar.gz")?
.arg("-C")?
.arg(temp_dir)?
.execute()
.context("Failed to extract LMD archive")?;
let entries = fs::read_dir(temp_dir)?;
let mut install_dir = None;
for entry in entries.flatten() {
let path = entry.path();
if path.is_dir() && path.file_name().is_some_and(|n| n.to_string_lossy().starts_with("maldetect")) {
install_dir = Some(path);
break;
}
}
let install_dir = install_dir.ok_or_else(|| anyhow::anyhow!("LMD install directory not found"))?;
let install_script = install_dir.join("install.sh");
if !install_script.exists() {
return Err(anyhow::anyhow!("LMD install.sh not found"));
}
SafeCommand::new("bash")?
.arg("-c")?
.shell_script_arg(&format!("cd {} && ./install.sh", install_dir.display()))?
.execute()
.context("Failed to run LMD installer")?;
fs::remove_dir_all(temp_dir).ok();
fs::remove_file("/tmp/maldetect-current.tar.gz").ok();
Ok(true)
}
fn update_databases(&self) -> Result<()> {
info!("Updating security tool databases...");
if Path::new("/usr/bin/rkhunter").exists() {
info!("Updating RKHunter database...");
let result = SafeCommand::new("rkhunter")?
.arg("--update")?
.execute();
if let Err(e) = result {
warn!("RKHunter update failed: {e}");
}
}
if Path::new("/usr/bin/freshclam").exists() {
info!("Updating ClamAV signatures...");
let result = SafeCommand::new("freshclam")?
.execute();
if let Err(e) = result {
warn!("ClamAV update failed: {e}");
}
}
if Path::new("/usr/bin/suricata-update").exists() {
info!("Updating Suricata rules...");
let result = SafeCommand::new("suricata-update")?
.execute();
if let Err(e) = result {
warn!("Suricata update failed: {e}");
}
}
if Path::new("/usr/local/sbin/maldet").exists() {
info!("Updating LMD signatures...");
let result = SafeCommand::new("maldet")?
.arg("--update-sigs")?
.execute();
if let Err(e) = result {
warn!("LMD update failed: {e}");
}
}
Ok(())
}
pub fn uninstall(&self) -> Result<UninstallResult> {
if !Self::check_root() {
return Err(anyhow::anyhow!(
"This command requires root privileges. Run with: sudo botserver remove protection"
));
}
info!("Removing security protection components...");
let mut result = UninstallResult::default();
if Path::new(SUDOERS_FILE).exists() {
match fs::remove_file(SUDOERS_FILE) {
Ok(()) => {
result.sudoers_removed = true;
info!("Removed sudoers file");
}
Err(e) => {
result.errors.push(format!("Failed to remove sudoers: {e}"));
}
}
}
result.success = result.errors.is_empty();
result.message = "Protection sudoers removed. Packages were NOT uninstalled - remove manually if needed.".to_string();
Ok(result)
}
pub fn verify(&self) -> VerifyResult {
let mut result = VerifyResult::default();
for package in PACKAGES {
let binary = match *package {
"clamav" | "clamav-daemon" => "clamscan",
other => other,
};
let check = SafeCommand::new("which")
.and_then(|cmd| cmd.arg(binary))
.and_then(|cmd| cmd.execute());
let installed = check.map(|o| o.status.success()).unwrap_or(false);
result.tools.push(ToolVerification {
name: (*package).to_string(),
installed,
sudo_configured: false,
});
}
let maldet_installed = Path::new("/usr/local/sbin/maldet").exists();
result.tools.push(ToolVerification {
name: "maldetect".to_string(),
installed: maldet_installed,
sudo_configured: false,
});
result.sudoers_exists = Path::new(SUDOERS_FILE).exists();
if result.sudoers_exists {
if let Ok(content) = fs::read_to_string(SUDOERS_FILE) {
for tool in &mut result.tools {
tool.sudo_configured = content.contains(&tool.name) ||
(tool.name == "clamav" && content.contains("clamav-daemon")) ||
(tool.name == "clamav-daemon" && content.contains("clamav-daemon"));
}
}
}
result.all_installed = result.tools.iter().filter(|t| t.name != "clamav-daemon").all(|t| t.installed);
result.all_configured = result.sudoers_exists && result.tools.iter().all(|t| t.sudo_configured || !t.installed);
result
}
}
impl Default for ProtectionInstaller {
fn default() -> Self {
Self::new().unwrap_or(Self { user: "root".to_string() })
}
}
#[derive(Debug, Default)]
pub struct InstallResult {
pub success: bool,
pub packages_installed: Vec<String>,
pub sudoers_created: bool,
pub databases_updated: bool,
pub errors: Vec<String>,
pub warnings: Vec<String>,
}
impl InstallResult {
pub fn print(&self) {
println!();
if self.success {
println!("✓ Security Protection installed successfully!");
} else {
println!("✗ Security Protection installation completed with errors");
}
println!();
if !self.packages_installed.is_empty() {
println!("Packages installed:");
for pkg in &self.packages_installed {
println!("{pkg}");
}
println!();
}
if self.sudoers_created {
println!("✓ Sudoers configuration created at {SUDOERS_FILE}");
}
if self.databases_updated {
println!("✓ Security databases updated");
}
if !self.warnings.is_empty() {
println!();
println!("Warnings:");
for warn in &self.warnings {
println!("{warn}");
}
}
if !self.errors.is_empty() {
println!();
println!("Errors:");
for err in &self.errors {
println!("{err}");
}
}
println!();
println!("The following commands are now available via the UI:");
println!(" - Lynis security audits");
println!(" - RKHunter rootkit scans");
println!(" - Chkrootkit scans");
println!(" - Suricata IDS management");
println!(" - ClamAV antivirus scans");
println!(" - LMD malware detection");
}
}
#[derive(Debug, Default)]
pub struct UninstallResult {
pub success: bool,
pub sudoers_removed: bool,
pub message: String,
pub errors: Vec<String>,
}
impl UninstallResult {
pub fn print(&self) {
println!();
if self.success {
println!("{}", self.message);
} else {
println!("✗ Uninstall completed with errors");
for err in &self.errors {
println!("{err}");
}
}
}
}
#[derive(Debug, Default)]
pub struct VerifyResult {
pub all_installed: bool,
pub all_configured: bool,
pub sudoers_exists: bool,
pub tools: Vec<ToolVerification>,
}
#[derive(Debug, Default)]
pub struct ToolVerification {
pub name: String,
pub installed: bool,
pub sudo_configured: bool,
}
impl VerifyResult {
pub fn print(&self) {
println!();
println!("Security Protection Status:");
println!();
println!("Tools:");
for tool in &self.tools {
let installed_mark = if tool.installed { "" } else { "" };
let sudo_mark = if tool.sudo_configured { "" } else { "" };
println!(" {} {} (installed: {}, sudo: {})",
if tool.installed && tool.sudo_configured { "" } else { "" },
tool.name,
installed_mark,
sudo_mark
);
}
println!();
println!("Sudoers file: {}", if self.sudoers_exists { "✓ exists" } else { "✗ missing" });
println!();
if self.all_installed && self.all_configured {
println!("✓ All protection tools are properly configured");
} else if !self.all_installed {
println!("⚠ Some tools are not installed. Run: sudo botserver install protection");
} else {
println!("⚠ Sudoers not configured. Run: sudo botserver install protection");
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_install_result_default() {
let result = InstallResult::default();
assert!(!result.success);
assert!(result.packages_installed.is_empty());
assert!(!result.sudoers_created);
}
#[test]
fn test_verify_result_default() {
let result = VerifyResult::default();
assert!(!result.all_installed);
assert!(!result.all_configured);
assert!(result.tools.is_empty());
}
#[test]
fn test_sudoers_content_has_placeholder() {
assert!(SUDOERS_CONTENT.contains("{user}"));
}
#[test]
fn test_sudoers_content_no_wildcards() {
assert!(!SUDOERS_CONTENT.contains(" * "));
assert!(!SUDOERS_CONTENT.lines().any(|l| l.trim().ends_with('*')));
}
#[test]
fn test_packages_list() {
assert!(PACKAGES.contains(&"lynis"));
assert!(PACKAGES.contains(&"rkhunter"));
assert!(PACKAGES.contains(&"chkrootkit"));
assert!(PACKAGES.contains(&"suricata"));
assert!(PACKAGES.contains(&"clamav"));
}
#[test]
fn test_tool_verification_default() {
let tool = ToolVerification::default();
assert!(tool.name.is_empty());
assert!(!tool.installed);
assert!(!tool.sudo_configured);
}
#[test]
fn test_uninstall_result_default() {
let result = UninstallResult::default();
assert!(!result.success);
assert!(!result.sudoers_removed);
assert!(result.message.is_empty());
}
#[test]
fn test_protection_installer_default() {
let installer = ProtectionInstaller::default();
assert!(!installer.user.is_empty());
}
}

View file

@ -0,0 +1,481 @@
use anyhow::{Context, Result};
use tracing::info;
use crate::security::command_guard::SafeCommand;
use super::manager::{Finding, FindingSeverity, ScanResultStatus};
const LMD_LOG_DIR: &str = "/usr/local/maldetect/logs";
const LMD_QUARANTINE_DIR: &str = "/usr/local/maldetect/quarantine";
pub async fn run_scan(path: Option<&str>) -> Result<(ScanResultStatus, Vec<Finding>, String)> {
info!("Running Linux Malware Detect scan");
let scan_path = path.unwrap_or("/var/www");
let output = SafeCommand::new("sudo")?
.arg("maldet")?
.arg("-a")?
.arg(scan_path)?
.execute()
.context("Failed to run LMD scan")?;
let stdout = String::from_utf8_lossy(&output.stdout).to_string();
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
let raw_output = format!("{stdout}\n{stderr}");
let findings = parse_lmd_output(&stdout);
let status = determine_result_status(&findings);
Ok((status, findings, raw_output))
}
pub async fn run_background_scan(path: Option<&str>) -> Result<String> {
info!("Starting LMD background scan");
let scan_path = path.unwrap_or("/var/www");
let output = SafeCommand::new("sudo")?
.arg("maldet")?
.arg("-b")?
.arg("-a")?
.arg(scan_path)?
.execute()
.context("Failed to start LMD background scan")?;
let stdout = String::from_utf8_lossy(&output.stdout);
let scan_id = extract_scan_id(&stdout).unwrap_or_else(|| "unknown".to_string());
info!("LMD background scan started with ID: {scan_id}");
Ok(scan_id)
}
pub async fn update_signatures() -> Result<()> {
info!("Updating LMD signatures");
SafeCommand::new("sudo")?
.arg("maldet")?
.arg("--update-sigs")?
.execute()
.context("Failed to update LMD signatures")?;
info!("LMD signatures updated successfully");
Ok(())
}
pub async fn update_version() -> Result<()> {
info!("Updating LMD version");
SafeCommand::new("sudo")?
.arg("maldet")?
.arg("--update-ver")?
.execute()
.context("Failed to update LMD version")?;
info!("LMD version updated successfully");
Ok(())
}
pub async fn get_version() -> Result<String> {
let output = SafeCommand::new("maldet")?
.arg("--version")?
.execute()
.context("Failed to get LMD version")?;
let stdout = String::from_utf8_lossy(&output.stdout);
let version = stdout
.lines()
.find(|l| l.contains("maldet") || l.contains("version"))
.and_then(|l| l.split_whitespace().last())
.unwrap_or("unknown")
.to_string();
Ok(version)
}
pub async fn get_signature_count() -> Result<u64> {
let sig_dir = "/usr/local/maldetect/sigs";
let mut count = 0u64;
if let Ok(entries) = std::fs::read_dir(sig_dir) {
for entry in entries.flatten() {
let path = entry.path();
if path.is_file() {
if let Ok(content) = std::fs::read_to_string(&path) {
count += content.lines().filter(|l| !l.trim().is_empty()).count() as u64;
}
}
}
}
Ok(count)
}
pub async fn quarantine_file(file_path: &str) -> Result<()> {
info!("Quarantining file: {file_path}");
SafeCommand::new("sudo")?
.arg("maldet")?
.arg("-q")?
.arg(file_path)?
.execute()
.context("Failed to quarantine file")?;
info!("File quarantined successfully: {file_path}");
Ok(())
}
pub async fn restore_file(file_path: &str) -> Result<()> {
info!("Restoring file from quarantine: {file_path}");
SafeCommand::new("sudo")?
.arg("maldet")?
.arg("--restore")?
.arg(file_path)?
.execute()
.context("Failed to restore file from quarantine")?;
info!("File restored successfully: {file_path}");
Ok(())
}
pub async fn clean_file(file_path: &str) -> Result<()> {
info!("Cleaning infected file: {file_path}");
SafeCommand::new("sudo")?
.arg("maldet")?
.arg("-n")?
.arg(file_path)?
.execute()
.context("Failed to clean file")?;
info!("File cleaned successfully: {file_path}");
Ok(())
}
pub async fn get_report(scan_id: &str) -> Result<String> {
info!("Retrieving LMD report for scan: {scan_id}");
let output = SafeCommand::new("sudo")?
.arg("maldet")?
.arg("--report")?
.arg(scan_id)?
.execute()
.context("Failed to get LMD report")?;
let stdout = String::from_utf8_lossy(&output.stdout).to_string();
Ok(stdout)
}
pub async fn list_quarantined() -> Result<Vec<QuarantinedFile>> {
let mut files = Vec::new();
if let Ok(entries) = std::fs::read_dir(LMD_QUARANTINE_DIR) {
for entry in entries.flatten() {
let path = entry.path();
if path.is_file() {
let filename = path.file_name()
.and_then(|n| n.to_str())
.unwrap_or("unknown")
.to_string();
let metadata = std::fs::metadata(&path).ok();
let size = metadata.as_ref().map(|m| m.len()).unwrap_or(0);
let quarantined_at = metadata
.and_then(|m| m.modified().ok())
.map(|t| chrono::DateTime::<chrono::Utc>::from(t));
files.push(QuarantinedFile {
id: filename.clone(),
original_path: extract_original_path(&filename),
quarantine_path: path.to_string_lossy().to_string(),
size,
quarantined_at,
threat_name: None,
});
}
}
}
Ok(files)
}
pub fn parse_lmd_output(output: &str) -> Vec<Finding> {
let mut findings = Vec::new();
for line in output.lines() {
let trimmed = line.trim();
if trimmed.contains("HIT") || trimmed.contains("FOUND") {
let parts: Vec<&str> = trimmed.split_whitespace().collect();
let file_path = parts.iter().find(|p| p.starts_with('/')).map(|s| s.to_string());
let threat_name = parts.iter()
.find(|p| p.contains("malware") || p.contains("backdoor") || p.contains("trojan"))
.map(|s| s.to_string())
.unwrap_or_else(|| "Malware".to_string());
let finding = Finding {
id: format!("lmd-hit-{}", findings.len()),
severity: FindingSeverity::Critical,
category: "Malware Detection".to_string(),
title: format!("Malware Detected: {threat_name}"),
description: trimmed.to_string(),
file_path,
remediation: Some("Quarantine or remove the infected file immediately".to_string()),
};
findings.push(finding);
}
if trimmed.contains("suspicious") || trimmed.contains("Suspicious") {
let file_path = extract_file_path_from_line(trimmed);
let finding = Finding {
id: format!("lmd-susp-{}", findings.len()),
severity: FindingSeverity::High,
category: "Suspicious Activity".to_string(),
title: "Suspicious File Detected".to_string(),
description: trimmed.to_string(),
file_path,
remediation: Some("Review the file and consider quarantine if malicious".to_string()),
};
findings.push(finding);
}
if trimmed.contains("warning") || trimmed.contains("Warning") {
let finding = Finding {
id: format!("lmd-warn-{}", findings.len()),
severity: FindingSeverity::Medium,
category: "Warning".to_string(),
title: "LMD Warning".to_string(),
description: trimmed.to_string(),
file_path: None,
remediation: None,
};
findings.push(finding);
}
}
findings
}
fn extract_scan_id(output: &str) -> Option<String> {
for line in output.lines() {
if line.contains("scan id:") || line.contains("SCAN ID:") {
return line.split(':').nth(1).map(|s| s.trim().to_string());
}
if line.contains("report") && line.contains(".") {
let parts: Vec<&str> = line.split_whitespace().collect();
for part in parts {
if part.contains('.') && part.chars().all(|c| c.is_numeric() || c == '.') {
return Some(part.to_string());
}
}
}
}
None
}
fn extract_file_path_from_line(line: &str) -> Option<String> {
let words: Vec<&str> = line.split_whitespace().collect();
for word in words {
if word.starts_with('/') {
return Some(word.trim_matches(|c| c == ':' || c == ',' || c == ';').to_string());
}
}
None
}
fn extract_original_path(quarantine_filename: &str) -> String {
quarantine_filename
.replace(".", "/")
.trim_start_matches('/')
.to_string()
}
fn determine_result_status(findings: &[Finding]) -> ScanResultStatus {
let has_critical = findings.iter().any(|f| f.severity == FindingSeverity::Critical);
let has_high = findings.iter().any(|f| f.severity == FindingSeverity::High);
let has_medium = findings.iter().any(|f| f.severity == FindingSeverity::Medium);
if has_critical {
ScanResultStatus::Infected
} else if has_high {
ScanResultStatus::Warnings
} else if has_medium {
ScanResultStatus::Warnings
} else {
ScanResultStatus::Clean
}
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct QuarantinedFile {
pub id: String,
pub original_path: String,
pub quarantine_path: String,
pub size: u64,
pub quarantined_at: Option<chrono::DateTime<chrono::Utc>>,
pub threat_name: Option<String>,
}
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct LMDStats {
pub signature_count: u64,
pub quarantined_count: u32,
pub last_scan: Option<chrono::DateTime<chrono::Utc>>,
pub threats_found: u32,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_lmd_output_clean() {
let output = r#"
Linux Malware Detect v1.6.5
Scanning /var/www
Total files scanned: 1234
Total hits: 0
Total cleaned: 0
"#;
let findings = parse_lmd_output(output);
assert!(findings.is_empty());
}
#[test]
fn test_parse_lmd_output_hit() {
let output = r#"
Linux Malware Detect v1.6.5
Scanning /var/www
{HIT} /var/www/uploads/shell.php : php.cmdshell.unclassed.6
Total hits: 1
"#;
let findings = parse_lmd_output(output);
assert_eq!(findings.len(), 1);
assert_eq!(findings[0].severity, FindingSeverity::Critical);
}
#[test]
fn test_parse_lmd_output_suspicious() {
let output = r#"
Linux Malware Detect v1.6.5
Scanning /var/www
suspicious file found: /var/www/uploads/unknown.php
"#;
let findings = parse_lmd_output(output);
assert_eq!(findings.len(), 1);
assert_eq!(findings[0].severity, FindingSeverity::High);
}
#[test]
fn test_extract_scan_id() {
assert_eq!(
extract_scan_id("scan id: 123456.789"),
Some("123456.789".to_string())
);
assert_eq!(
extract_scan_id("SCAN ID: abc123"),
Some("abc123".to_string())
);
assert_eq!(extract_scan_id("no scan id here"), None);
}
#[test]
fn test_extract_file_path_from_line() {
assert_eq!(
extract_file_path_from_line("Found malware in /var/www/shell.php"),
Some("/var/www/shell.php".to_string())
);
assert_eq!(
extract_file_path_from_line("No path here"),
None
);
}
#[test]
fn test_extract_original_path() {
assert_eq!(
extract_original_path("var.www.uploads.shell.php"),
"var/www/uploads/shell/php"
);
}
#[test]
fn test_determine_result_status_clean() {
let findings: Vec<Finding> = vec![];
assert_eq!(determine_result_status(&findings), ScanResultStatus::Clean);
}
#[test]
fn test_determine_result_status_infected() {
let findings = vec![Finding {
id: "test".to_string(),
severity: FindingSeverity::Critical,
category: "test".to_string(),
title: "Test".to_string(),
description: "Test".to_string(),
file_path: None,
remediation: None,
}];
assert_eq!(determine_result_status(&findings), ScanResultStatus::Infected);
}
#[test]
fn test_determine_result_status_warnings() {
let findings = vec![Finding {
id: "test".to_string(),
severity: FindingSeverity::High,
category: "test".to_string(),
title: "Test".to_string(),
description: "Test".to_string(),
file_path: None,
remediation: None,
}];
assert_eq!(determine_result_status(&findings), ScanResultStatus::Warnings);
}
#[test]
fn test_quarantined_file_struct() {
let file = QuarantinedFile {
id: "test".to_string(),
original_path: "/var/www/shell.php".to_string(),
quarantine_path: "/usr/local/maldetect/quarantine/test".to_string(),
size: 1024,
quarantined_at: None,
threat_name: Some("php.cmdshell".to_string()),
};
assert_eq!(file.size, 1024);
assert!(file.threat_name.is_some());
}
#[test]
fn test_lmd_stats_default() {
let stats = LMDStats::default();
assert_eq!(stats.signature_count, 0);
assert_eq!(stats.quarantined_count, 0);
assert!(stats.last_scan.is_none());
}
#[test]
fn test_parse_lmd_output_warning() {
let output = r#"
Linux Malware Detect v1.6.5
Warning: signature database may be outdated
Scanning /var/www
"#;
let findings = parse_lmd_output(output);
assert_eq!(findings.len(), 1);
assert_eq!(findings[0].severity, FindingSeverity::Medium);
}
#[test]
fn test_parse_lmd_output_found() {
let output = r#"
FOUND: /var/www/malicious.php : malware.backdoor.123
"#;
let findings = parse_lmd_output(output);
assert_eq!(findings.len(), 1);
assert_eq!(findings[0].severity, FindingSeverity::Critical);
}
}

View file

@ -0,0 +1,273 @@
use anyhow::{Context, Result};
use tracing::{info, warn};
use crate::security::command_guard::SafeCommand;
use super::manager::{Finding, FindingSeverity, ScanResultStatus};
const LYNIS_REPORT_PATH: &str = "/var/log/lynis-report.dat";
const LYNIS_LOG_PATH: &str = "/var/log/lynis.log";
pub async fn run_scan() -> Result<(ScanResultStatus, Vec<Finding>, String)> {
info!("Running Lynis security audit");
let output = SafeCommand::new("sudo")?
.arg("lynis")?
.arg("audit")?
.arg("system")?
.arg("--quick")?
.arg("--no-colors")?
.execute()
.context("Failed to run Lynis audit")?;
let stdout = String::from_utf8_lossy(&output.stdout).to_string();
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
let raw_output = format!("{stdout}\n{stderr}");
let findings = parse_lynis_output(&stdout);
let status = determine_result_status(&findings);
Ok((status, findings, raw_output))
}
pub async fn run_full_audit() -> Result<(ScanResultStatus, Vec<Finding>, String)> {
info!("Running full Lynis security audit");
let output = SafeCommand::new("sudo")?
.arg("lynis")?
.arg("audit")?
.arg("system")?
.arg("--no-colors")?
.execute()
.context("Failed to run full Lynis audit")?;
let stdout = String::from_utf8_lossy(&output.stdout).to_string();
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
let raw_output = format!("{stdout}\n{stderr}");
let findings = parse_lynis_output(&stdout);
let status = determine_result_status(&findings);
Ok((status, findings, raw_output))
}
pub fn parse_lynis_output(output: &str) -> Vec<Finding> {
let mut findings = Vec::new();
let mut current_category = String::new();
for line in output.lines() {
let trimmed = line.trim();
if trimmed.starts_with('[') && trimmed.contains(']') {
if let Some(category) = extract_category(trimmed) {
current_category = category;
}
}
if trimmed.contains("Warning:") || trimmed.contains("WARNING") {
let finding = Finding {
id: format!("lynis-warn-{}", findings.len()),
severity: FindingSeverity::Medium,
category: current_category.clone(),
title: "Security Warning".to_string(),
description: trimmed.replace("Warning:", "").trim().to_string(),
file_path: None,
remediation: None,
};
findings.push(finding);
}
if trimmed.contains("Suggestion:") || trimmed.contains("SUGGESTION") {
let finding = Finding {
id: format!("lynis-sugg-{}", findings.len()),
severity: FindingSeverity::Low,
category: current_category.clone(),
title: "Security Suggestion".to_string(),
description: trimmed.replace("Suggestion:", "").trim().to_string(),
file_path: None,
remediation: extract_remediation(trimmed),
};
findings.push(finding);
}
if trimmed.contains("[FOUND]") && trimmed.contains("vulnerable") {
let finding = Finding {
id: format!("lynis-vuln-{}", findings.len()),
severity: FindingSeverity::High,
category: current_category.clone(),
title: "Vulnerability Found".to_string(),
description: trimmed.to_string(),
file_path: None,
remediation: None,
};
findings.push(finding);
}
}
findings
}
pub fn parse_report_file() -> Result<LynisReport> {
let content = std::fs::read_to_string(LYNIS_REPORT_PATH)
.context("Failed to read Lynis report file")?;
let mut report = LynisReport::default();
for line in content.lines() {
if line.starts_with('#') || line.trim().is_empty() {
continue;
}
if let Some((key, value)) = line.split_once('=') {
match key {
"hardening_index" => {
report.hardening_index = value.parse().unwrap_or(0);
}
"warning[]" => {
report.warnings.push(value.to_string());
}
"suggestion[]" => {
report.suggestions.push(value.to_string());
}
"lynis_version" => {
report.version = value.to_string();
}
"test_category[]" => {
report.categories_tested.push(value.to_string());
}
"tests_executed" => {
report.tests_executed = value.parse().unwrap_or(0);
}
_ => {}
}
}
}
Ok(report)
}
pub async fn get_hardening_index() -> Result<u32> {
let report = parse_report_file()?;
Ok(report.hardening_index)
}
pub async fn apply_suggestion(suggestion_id: &str) -> Result<()> {
info!("Applying Lynis suggestion: {suggestion_id}");
warn!("Auto-remediation for suggestion {suggestion_id} not yet implemented");
Ok(())
}
fn extract_category(line: &str) -> Option<String> {
let start = line.find('[')?;
let end = line.find(']')?;
if start < end {
Some(line[start + 1..end].trim().to_string())
} else {
None
}
}
fn extract_remediation(line: &str) -> Option<String> {
if line.contains("Consider") {
Some(line.split("Consider").nth(1)?.trim().to_string())
} else if line.contains("Disable") {
Some(line.split("Disable").nth(1).map(|s| format!("Disable {}", s.trim()))?)
} else if line.contains("Enable") {
Some(line.split("Enable").nth(1).map(|s| format!("Enable {}", s.trim()))?)
} else {
None
}
}
fn determine_result_status(findings: &[Finding]) -> ScanResultStatus {
let has_critical = findings.iter().any(|f| f.severity == FindingSeverity::Critical);
let has_high = findings.iter().any(|f| f.severity == FindingSeverity::High);
let has_medium = findings.iter().any(|f| f.severity == FindingSeverity::Medium);
if has_critical || has_high {
ScanResultStatus::Infected
} else if has_medium {
ScanResultStatus::Warnings
} else {
ScanResultStatus::Clean
}
}
#[derive(Debug, Clone, Default)]
pub struct LynisReport {
pub version: String,
pub hardening_index: u32,
pub tests_executed: u32,
pub warnings: Vec<String>,
pub suggestions: Vec<String>,
pub categories_tested: Vec<String>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_lynis_output_warnings() {
let output = r#"
[+] Boot and services
- Service Manager [ systemd ]
Warning: Some warning message here
[+] Kernel
Suggestion: Consider enabling some feature
"#;
let findings = parse_lynis_output(output);
assert_eq!(findings.len(), 2);
assert_eq!(findings[0].severity, FindingSeverity::Medium);
assert_eq!(findings[1].severity, FindingSeverity::Low);
}
#[test]
fn test_extract_category() {
assert_eq!(extract_category("[+] Boot and services"), Some("+ Boot and services".to_string()));
assert_eq!(extract_category("no brackets"), None);
}
#[test]
fn test_determine_result_status_clean() {
let findings: Vec<Finding> = vec![];
assert_eq!(determine_result_status(&findings), ScanResultStatus::Clean);
}
#[test]
fn test_determine_result_status_warnings() {
let findings = vec![Finding {
id: "test".to_string(),
severity: FindingSeverity::Medium,
category: "test".to_string(),
title: "Test".to_string(),
description: "Test".to_string(),
file_path: None,
remediation: None,
}];
assert_eq!(determine_result_status(&findings), ScanResultStatus::Warnings);
}
#[test]
fn test_determine_result_status_infected() {
let findings = vec![Finding {
id: "test".to_string(),
severity: FindingSeverity::High,
category: "test".to_string(),
title: "Test".to_string(),
description: "Test".to_string(),
file_path: None,
remediation: None,
}];
assert_eq!(determine_result_status(&findings), ScanResultStatus::Infected);
}
#[test]
fn test_lynis_report_default() {
let report = LynisReport::default();
assert_eq!(report.hardening_index, 0);
assert!(report.warnings.is_empty());
assert!(report.suggestions.is_empty());
}
}

View file

@ -0,0 +1,621 @@
use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{info, warn};
use crate::security::command_guard::SafeCommand;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ProtectionTool {
Lynis,
RKHunter,
Chkrootkit,
Suricata,
LMD,
ClamAV,
}
impl std::fmt::Display for ProtectionTool {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Lynis => write!(f, "lynis"),
Self::RKHunter => write!(f, "rkhunter"),
Self::Chkrootkit => write!(f, "chkrootkit"),
Self::Suricata => write!(f, "suricata"),
Self::LMD => write!(f, "lmd"),
Self::ClamAV => write!(f, "clamav"),
}
}
}
impl ProtectionTool {
pub fn from_str(s: &str) -> Option<Self> {
match s.to_lowercase().as_str() {
"lynis" => Some(Self::Lynis),
"rkhunter" => Some(Self::RKHunter),
"chkrootkit" => Some(Self::Chkrootkit),
"suricata" => Some(Self::Suricata),
"lmd" | "maldet" => Some(Self::LMD),
"clamav" | "clamscan" => Some(Self::ClamAV),
_ => None,
}
}
pub fn binary_name(&self) -> &'static str {
match self {
Self::Lynis => "lynis",
Self::RKHunter => "rkhunter",
Self::Chkrootkit => "chkrootkit",
Self::Suricata => "suricata",
Self::LMD => "maldet",
Self::ClamAV => "clamscan",
}
}
pub fn service_name(&self) -> Option<&'static str> {
match self {
Self::Suricata => Some("suricata"),
Self::ClamAV => Some("clamav-daemon"),
_ => None,
}
}
pub fn package_name(&self) -> &'static str {
match self {
Self::Lynis => "lynis",
Self::RKHunter => "rkhunter",
Self::Chkrootkit => "chkrootkit",
Self::Suricata => "suricata",
Self::LMD => "maldetect",
Self::ClamAV => "clamav",
}
}
pub fn all() -> Vec<Self> {
vec![
Self::Lynis,
Self::RKHunter,
Self::Chkrootkit,
Self::Suricata,
Self::LMD,
Self::ClamAV,
]
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolStatus {
pub tool: ProtectionTool,
pub installed: bool,
pub version: Option<String>,
pub service_running: Option<bool>,
pub last_scan: Option<DateTime<Utc>>,
pub last_update: Option<DateTime<Utc>>,
pub auto_update: bool,
pub auto_remediate: bool,
pub metrics: ToolMetrics,
}
impl ToolStatus {
pub fn not_installed(tool: ProtectionTool) -> Self {
Self {
tool,
installed: false,
version: None,
service_running: None,
last_scan: None,
last_update: None,
auto_update: false,
auto_remediate: false,
metrics: ToolMetrics::default(),
}
}
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ToolMetrics {
pub hardening_index: Option<u32>,
pub warnings: u32,
pub suggestions: u32,
pub threats_found: u32,
pub rules_count: Option<u32>,
pub alerts_today: u32,
pub blocked_today: u32,
pub signatures_count: Option<u64>,
pub quarantined_count: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScanResult {
pub scan_id: String,
pub tool: ProtectionTool,
pub started_at: DateTime<Utc>,
pub completed_at: Option<DateTime<Utc>>,
pub status: ScanStatus,
pub result: ScanResultStatus,
pub findings: Vec<Finding>,
pub warnings: u32,
pub report_path: Option<String>,
pub raw_output: Option<String>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ScanStatus {
Pending,
Running,
Completed,
Failed,
Cancelled,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ScanResultStatus {
Clean,
Warnings,
Infected,
Unknown,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Finding {
pub id: String,
pub severity: FindingSeverity,
pub category: String,
pub title: String,
pub description: String,
pub file_path: Option<String>,
pub remediation: Option<String>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum FindingSeverity {
Info,
Low,
Medium,
High,
Critical,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtectionConfig {
pub enabled_tools: Vec<ProtectionTool>,
pub auto_scan_interval_hours: u32,
pub auto_update_interval_hours: u32,
pub quarantine_dir: String,
pub log_dir: String,
}
impl Default for ProtectionConfig {
fn default() -> Self {
Self {
enabled_tools: ProtectionTool::all(),
auto_scan_interval_hours: 24,
auto_update_interval_hours: 6,
quarantine_dir: "/var/lib/gb/quarantine".to_string(),
log_dir: "/var/log/gb/security".to_string(),
}
}
}
pub struct ProtectionManager {
config: ProtectionConfig,
tool_status: Arc<RwLock<HashMap<ProtectionTool, ToolStatus>>>,
active_scans: Arc<RwLock<HashMap<String, ScanResult>>>,
scan_history: Arc<RwLock<Vec<ScanResult>>>,
}
impl ProtectionManager {
pub fn new(config: ProtectionConfig) -> Self {
Self {
config,
tool_status: Arc::new(RwLock::new(HashMap::new())),
active_scans: Arc::new(RwLock::new(HashMap::new())),
scan_history: Arc::new(RwLock::new(Vec::new())),
}
}
pub async fn initialize(&self) -> Result<()> {
info!("Initializing Protection Manager");
for tool in &self.config.enabled_tools {
let status = self.check_tool_status(*tool).await?;
self.tool_status.write().await.insert(*tool, status);
}
Ok(())
}
pub async fn check_tool_status(&self, tool: ProtectionTool) -> Result<ToolStatus> {
let installed = self.is_tool_installed(tool).await;
if !installed {
return Ok(ToolStatus::not_installed(tool));
}
let version = self.get_tool_version(tool).await.ok();
let service_running = if tool.service_name().is_some() {
Some(self.is_service_running(tool).await)
} else {
None
};
let stored = self.tool_status.read().await;
let existing = stored.get(&tool);
Ok(ToolStatus {
tool,
installed: true,
version,
service_running,
last_scan: existing.and_then(|s| s.last_scan),
last_update: existing.and_then(|s| s.last_update),
auto_update: existing.map(|s| s.auto_update).unwrap_or(false),
auto_remediate: existing.map(|s| s.auto_remediate).unwrap_or(false),
metrics: existing.map(|s| s.metrics.clone()).unwrap_or_default(),
})
}
pub async fn is_tool_installed(&self, tool: ProtectionTool) -> bool {
let binary = tool.binary_name();
let result = SafeCommand::new("which")
.and_then(|cmd| cmd.arg(binary))
.and_then(|cmd| cmd.execute());
match result {
Ok(output) => output.status.success(),
Err(e) => {
warn!("Failed to check if {tool} is installed: {e}");
false
}
}
}
pub async fn get_tool_version(&self, tool: ProtectionTool) -> Result<String> {
let binary = tool.binary_name();
let version_arg = match tool {
ProtectionTool::Lynis => "--version",
ProtectionTool::RKHunter => "--version",
ProtectionTool::Chkrootkit => "-V",
ProtectionTool::Suricata => "--build-info",
ProtectionTool::LMD => "--version",
ProtectionTool::ClamAV => "--version",
};
let output = SafeCommand::new(binary)?
.arg(version_arg)?
.execute()
.context("Failed to get tool version")?;
let stdout = String::from_utf8_lossy(&output.stdout);
let version = stdout.lines().next().unwrap_or("unknown").trim().to_string();
Ok(version)
}
pub async fn is_service_running(&self, tool: ProtectionTool) -> bool {
let Some(service_name) = tool.service_name() else {
return false;
};
let result = SafeCommand::new("sudo")
.and_then(|cmd| cmd.arg("systemctl"))
.and_then(|cmd| cmd.arg("is-active"))
.and_then(|cmd| cmd.arg(service_name));
match result {
Ok(cmd) => match cmd.execute() {
Ok(output) => {
let stdout = String::from_utf8_lossy(&output.stdout);
stdout.trim() == "active"
}
Err(_) => false,
},
Err(_) => false,
}
}
pub async fn get_all_status(&self) -> HashMap<ProtectionTool, ToolStatus> {
self.tool_status.read().await.clone()
}
pub async fn get_tool_status_by_name(&self, name: &str) -> Option<ToolStatus> {
let tool = ProtectionTool::from_str(name)?;
self.tool_status.read().await.get(&tool).cloned()
}
pub async fn install_tool(&self, tool: ProtectionTool) -> Result<()> {
info!("Installing protection tool: {tool}");
let package = tool.package_name();
SafeCommand::new("apt-get")?
.arg("install")?
.arg("-y")?
.arg(package)?
.execute()
.context("Failed to install tool")?;
let status = self.check_tool_status(tool).await?;
self.tool_status.write().await.insert(tool, status);
info!("Successfully installed {tool}");
Ok(())
}
pub async fn start_service(&self, tool: ProtectionTool) -> Result<()> {
let Some(service_name) = tool.service_name() else {
return Err(anyhow::anyhow!("{tool} does not have a service"));
};
info!("Starting service: {service_name}");
SafeCommand::new("sudo")?
.arg("systemctl")?
.arg("start")?
.arg(service_name)?
.execute()
.context("Failed to start service")?;
if let Some(status) = self.tool_status.write().await.get_mut(&tool) {
status.service_running = Some(true);
}
Ok(())
}
pub async fn stop_service(&self, tool: ProtectionTool) -> Result<()> {
let Some(service_name) = tool.service_name() else {
return Err(anyhow::anyhow!("{tool} does not have a service"));
};
info!("Stopping service: {service_name}");
SafeCommand::new("sudo")?
.arg("systemctl")?
.arg("stop")?
.arg(service_name)?
.execute()
.context("Failed to stop service")?;
if let Some(status) = self.tool_status.write().await.get_mut(&tool) {
status.service_running = Some(false);
}
Ok(())
}
pub async fn enable_service(&self, tool: ProtectionTool) -> Result<()> {
let Some(service_name) = tool.service_name() else {
return Err(anyhow::anyhow!("{tool} does not have a service"));
};
SafeCommand::new("sudo")?
.arg("systemctl")?
.arg("enable")?
.arg(service_name)?
.execute()
.context("Failed to enable service")?;
Ok(())
}
pub async fn disable_service(&self, tool: ProtectionTool) -> Result<()> {
let Some(service_name) = tool.service_name() else {
return Err(anyhow::anyhow!("{tool} does not have a service"));
};
SafeCommand::new("sudo")?
.arg("systemctl")?
.arg("disable")?
.arg(service_name)?
.execute()
.context("Failed to disable service")?;
Ok(())
}
pub async fn run_scan(&self, tool: ProtectionTool) -> Result<ScanResult> {
let scan_id = uuid::Uuid::new_v4().to_string();
let started_at = Utc::now();
info!("Starting {tool} scan with ID: {scan_id}");
let mut result = ScanResult {
scan_id: scan_id.clone(),
tool,
started_at,
completed_at: None,
status: ScanStatus::Running,
result: ScanResultStatus::Unknown,
findings: Vec::new(),
warnings: 0,
report_path: None,
raw_output: None,
};
self.active_scans.write().await.insert(scan_id.clone(), result.clone());
let scan_output = match tool {
ProtectionTool::Lynis => super::lynis::run_scan().await,
ProtectionTool::RKHunter => super::rkhunter::run_scan().await,
ProtectionTool::Chkrootkit => super::chkrootkit::run_scan().await,
ProtectionTool::Suricata => super::suricata::get_alerts().await,
ProtectionTool::LMD => super::lmd::run_scan(None).await,
ProtectionTool::ClamAV => super::lynis::run_scan().await,
};
result.completed_at = Some(Utc::now());
match scan_output {
Ok((status, findings, raw)) => {
result.status = ScanStatus::Completed;
result.result = status;
result.warnings = findings.iter().filter(|f| f.severity == FindingSeverity::Medium || f.severity == FindingSeverity::Low).count() as u32;
result.findings = findings;
result.raw_output = Some(raw);
}
Err(e) => {
warn!("Scan failed for {tool}: {e}");
result.status = ScanStatus::Failed;
result.raw_output = Some(e.to_string());
}
}
self.active_scans.write().await.remove(&scan_id);
if let Some(status) = self.tool_status.write().await.get_mut(&tool) {
status.last_scan = Some(Utc::now());
status.metrics.warnings = result.warnings;
status.metrics.threats_found = result.findings.iter()
.filter(|f| f.severity == FindingSeverity::High || f.severity == FindingSeverity::Critical)
.count() as u32;
}
self.scan_history.write().await.push(result.clone());
Ok(result)
}
pub async fn update_definitions(&self, tool: ProtectionTool) -> Result<()> {
info!("Updating definitions for {tool}");
match tool {
ProtectionTool::RKHunter => {
SafeCommand::new("sudo")?
.arg("rkhunter")?
.arg("--update")?
.execute()?;
}
ProtectionTool::ClamAV => {
SafeCommand::new("sudo")?
.arg("freshclam")?
.execute()?;
}
ProtectionTool::Suricata => {
SafeCommand::new("sudo")?
.arg("suricata-update")?
.execute()?;
}
ProtectionTool::LMD => {
SafeCommand::new("sudo")?
.arg("maldet")?
.arg("--update-sigs")?
.execute()?;
}
_ => {
return Err(anyhow::anyhow!("{tool} does not support definition updates"));
}
}
if let Some(status) = self.tool_status.write().await.get_mut(&tool) {
status.last_update = Some(Utc::now());
}
Ok(())
}
pub async fn set_auto_update(&self, tool: ProtectionTool, enabled: bool) -> Result<()> {
if let Some(status) = self.tool_status.write().await.get_mut(&tool) {
status.auto_update = enabled;
}
Ok(())
}
pub async fn set_auto_remediate(&self, tool: ProtectionTool, enabled: bool) -> Result<()> {
if let Some(status) = self.tool_status.write().await.get_mut(&tool) {
status.auto_remediate = enabled;
}
Ok(())
}
pub async fn get_scan_history(&self, tool: Option<ProtectionTool>, limit: usize) -> Vec<ScanResult> {
let history = self.scan_history.read().await;
history
.iter()
.filter(|s| tool.is_none() || Some(s.tool) == tool)
.rev()
.take(limit)
.cloned()
.collect()
}
pub async fn get_active_scans(&self) -> Vec<ScanResult> {
self.active_scans.read().await.values().cloned().collect()
}
pub async fn get_report(&self, tool: ProtectionTool) -> Result<String> {
let history = self.scan_history.read().await;
let latest = history
.iter()
.filter(|s| s.tool == tool)
.last()
.ok_or_else(|| anyhow::anyhow!("No scan results found for {tool}"))?;
latest.raw_output.clone().ok_or_else(|| anyhow::anyhow!("No report available"))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_protection_tool_from_str() {
assert_eq!(ProtectionTool::from_str("lynis"), Some(ProtectionTool::Lynis));
assert_eq!(ProtectionTool::from_str("LYNIS"), Some(ProtectionTool::Lynis));
assert_eq!(ProtectionTool::from_str("rkhunter"), Some(ProtectionTool::RKHunter));
assert_eq!(ProtectionTool::from_str("clamav"), Some(ProtectionTool::ClamAV));
assert_eq!(ProtectionTool::from_str("clamscan"), Some(ProtectionTool::ClamAV));
assert_eq!(ProtectionTool::from_str("maldet"), Some(ProtectionTool::LMD));
assert_eq!(ProtectionTool::from_str("unknown"), None);
}
#[test]
fn test_protection_tool_display() {
assert_eq!(format!("{}", ProtectionTool::Lynis), "lynis");
assert_eq!(format!("{}", ProtectionTool::ClamAV), "clamav");
}
#[test]
fn test_tool_status_not_installed() {
let status = ToolStatus::not_installed(ProtectionTool::Lynis);
assert!(!status.installed);
assert!(status.version.is_none());
assert!(status.service_running.is_none());
}
#[test]
fn test_protection_config_default() {
let config = ProtectionConfig::default();
assert_eq!(config.auto_scan_interval_hours, 24);
assert_eq!(config.auto_update_interval_hours, 6);
assert_eq!(config.enabled_tools.len(), 6);
}
#[test]
fn test_protection_tool_all() {
let all = ProtectionTool::all();
assert_eq!(all.len(), 6);
assert!(all.contains(&ProtectionTool::Lynis));
assert!(all.contains(&ProtectionTool::ClamAV));
}
#[test]
fn test_finding_severity() {
let finding = Finding {
id: "test".to_string(),
severity: FindingSeverity::High,
category: "security".to_string(),
title: "Test".to_string(),
description: "Test finding".to_string(),
file_path: None,
remediation: None,
};
assert_eq!(finding.severity, FindingSeverity::High);
}
}

View file

@ -0,0 +1,12 @@
pub mod api;
pub mod chkrootkit;
pub mod installer;
pub mod lmd;
pub mod lynis;
pub mod manager;
pub mod rkhunter;
pub mod suricata;
pub use api::configure_protection_routes;
pub use installer::{InstallResult, ProtectionInstaller, UninstallResult, VerifyResult};
pub use manager::{ProtectionManager, ProtectionTool, ToolStatus};

View file

@ -0,0 +1,320 @@
use anyhow::{Context, Result};
use tracing::info;
use crate::security::command_guard::SafeCommand;
use super::manager::{Finding, FindingSeverity, ScanResultStatus};
const RKHUNTER_LOG_PATH: &str = "/var/log/rkhunter.log";
pub async fn run_scan() -> Result<(ScanResultStatus, Vec<Finding>, String)> {
info!("Running RKHunter rootkit scan");
let output = SafeCommand::new("sudo")?
.arg("rkhunter")?
.arg("--check")?
.arg("--skip-keypress")?
.arg("--report-warnings-only")?
.execute()
.context("Failed to run RKHunter scan")?;
let stdout = String::from_utf8_lossy(&output.stdout).to_string();
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
let raw_output = format!("{stdout}\n{stderr}");
let findings = parse_rkhunter_output(&stdout);
let status = determine_result_status(&findings);
Ok((status, findings, raw_output))
}
pub async fn update_database() -> Result<()> {
info!("Updating RKHunter database");
SafeCommand::new("sudo")?
.arg("rkhunter")?
.arg("--update")?
.execute()
.context("Failed to update RKHunter database")?;
info!("RKHunter database updated successfully");
Ok(())
}
pub async fn update_properties() -> Result<()> {
info!("Updating RKHunter file properties database");
SafeCommand::new("sudo")?
.arg("rkhunter")?
.arg("--propupd")?
.execute()
.context("Failed to update RKHunter properties")?;
info!("RKHunter properties updated successfully");
Ok(())
}
pub fn parse_rkhunter_output(output: &str) -> Vec<Finding> {
let mut findings = Vec::new();
let mut current_section = String::new();
for line in output.lines() {
let trimmed = line.trim();
if trimmed.starts_with("Checking") {
current_section = trimmed.replace("Checking", "").trim().to_string();
continue;
}
if trimmed.contains("[ Warning ]") || trimmed.contains("[Warning]") {
let description = extract_warning_description(trimmed);
let finding = Finding {
id: format!("rkhunter-warn-{}", findings.len()),
severity: FindingSeverity::High,
category: current_section.clone(),
title: "RKHunter Warning".to_string(),
description,
file_path: extract_file_path(trimmed),
remediation: Some("Investigate the flagged file or configuration".to_string()),
};
findings.push(finding);
}
if trimmed.contains("[ Rootkit ]") || trimmed.to_lowercase().contains("rootkit found") {
let finding = Finding {
id: format!("rkhunter-rootkit-{}", findings.len()),
severity: FindingSeverity::Critical,
category: "Rootkit Detection".to_string(),
title: "Potential Rootkit Detected".to_string(),
description: trimmed.to_string(),
file_path: extract_file_path(trimmed),
remediation: Some("Immediately investigate and consider system recovery".to_string()),
};
findings.push(finding);
}
if trimmed.contains("Suspicious file") || trimmed.contains("suspicious") {
let finding = Finding {
id: format!("rkhunter-susp-{}", findings.len()),
severity: FindingSeverity::High,
category: current_section.clone(),
title: "Suspicious File Detected".to_string(),
description: trimmed.to_string(),
file_path: extract_file_path(trimmed),
remediation: Some("Verify the file integrity and source".to_string()),
};
findings.push(finding);
}
if trimmed.contains("[ Bad ]") {
let finding = Finding {
id: format!("rkhunter-bad-{}", findings.len()),
severity: FindingSeverity::High,
category: current_section.clone(),
title: "Bad Configuration or File".to_string(),
description: trimmed.to_string(),
file_path: extract_file_path(trimmed),
remediation: Some("Review and correct the flagged item".to_string()),
};
findings.push(finding);
}
}
findings
}
pub fn parse_log_file() -> Result<RKHunterReport> {
let content = std::fs::read_to_string(RKHUNTER_LOG_PATH)
.context("Failed to read RKHunter log file")?;
let mut report = RKHunterReport::default();
for line in content.lines() {
if line.contains("Rootkits checked") {
if let Some(count) = extract_number_from_line(line) {
report.rootkits_checked = count;
}
}
if line.contains("Possible rootkits") {
if let Some(count) = extract_number_from_line(line) {
report.possible_rootkits = count;
}
}
if line.contains("Suspect files") {
if let Some(count) = extract_number_from_line(line) {
report.suspect_files = count;
}
}
if line.contains("Warning:") {
report.warnings.push(line.replace("Warning:", "").trim().to_string());
}
if line.contains("rkhunter version") {
report.version = line.split(':').nth(1).unwrap_or("").trim().to_string();
}
}
Ok(report)
}
pub async fn get_version() -> Result<String> {
let output = SafeCommand::new("rkhunter")?
.arg("--version")?
.execute()
.context("Failed to get RKHunter version")?;
let stdout = String::from_utf8_lossy(&output.stdout);
let version = stdout
.lines()
.find(|l| l.contains("version"))
.and_then(|l| l.split_whitespace().last())
.unwrap_or("unknown")
.to_string();
Ok(version)
}
fn extract_warning_description(line: &str) -> String {
line.replace("[ Warning ]", "")
.replace("[Warning]", "")
.trim()
.to_string()
}
fn extract_file_path(line: &str) -> Option<String> {
let words: Vec<&str> = line.split_whitespace().collect();
for word in words {
if word.starts_with('/') {
return Some(word.trim_matches(|c| c == ':' || c == ',' || c == ';').to_string());
}
}
None
}
fn extract_number_from_line(line: &str) -> Option<u32> {
line.split_whitespace()
.find_map(|word| word.parse::<u32>().ok())
}
fn determine_result_status(findings: &[Finding]) -> ScanResultStatus {
let has_critical = findings.iter().any(|f| f.severity == FindingSeverity::Critical);
let has_high = findings.iter().any(|f| f.severity == FindingSeverity::High);
let has_medium = findings.iter().any(|f| f.severity == FindingSeverity::Medium);
if has_critical {
ScanResultStatus::Infected
} else if has_high {
ScanResultStatus::Warnings
} else if has_medium {
ScanResultStatus::Warnings
} else {
ScanResultStatus::Clean
}
}
#[derive(Debug, Clone, Default)]
pub struct RKHunterReport {
pub version: String,
pub rootkits_checked: u32,
pub possible_rootkits: u32,
pub suspect_files: u32,
pub warnings: Vec<String>,
pub scan_time: Option<String>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_rkhunter_output_clean() {
let output = r#"
Checking for rootkits...
Performing check of known rootkit files and directories
55808 Trojan - Variant A [ Not found ]
ADM Worm [ Not found ]
AjaKit Rootkit [ Not found ]
System checks summary
=====================
File properties checks...
Files checked: 142
Suspect files: 0
"#;
let findings = parse_rkhunter_output(output);
assert!(findings.is_empty());
}
#[test]
fn test_parse_rkhunter_output_warning() {
let output = r#"
Checking for rootkits...
Checking /dev for suspicious file types [ Warning ]
Suspicious file found: /dev/.udev/something
"#;
let findings = parse_rkhunter_output(output);
assert!(!findings.is_empty());
assert_eq!(findings[0].severity, FindingSeverity::High);
}
#[test]
fn test_extract_file_path() {
assert_eq!(
extract_file_path("Suspicious file found: /etc/passwd"),
Some("/etc/passwd".to_string())
);
assert_eq!(
extract_file_path("Checking /dev/sda for issues"),
Some("/dev/sda".to_string())
);
assert_eq!(extract_file_path("No path here"), None);
}
#[test]
fn test_extract_number_from_line() {
assert_eq!(extract_number_from_line("Rootkits checked: 42"), Some(42));
assert_eq!(extract_number_from_line("Found 5 issues"), Some(5));
assert_eq!(extract_number_from_line("No numbers here"), None);
}
#[test]
fn test_determine_result_status_clean() {
let findings: Vec<Finding> = vec![];
assert_eq!(determine_result_status(&findings), ScanResultStatus::Clean);
}
#[test]
fn test_determine_result_status_critical() {
let findings = vec![Finding {
id: "test".to_string(),
severity: FindingSeverity::Critical,
category: "test".to_string(),
title: "Test".to_string(),
description: "Test".to_string(),
file_path: None,
remediation: None,
}];
assert_eq!(determine_result_status(&findings), ScanResultStatus::Infected);
}
#[test]
fn test_rkhunter_report_default() {
let report = RKHunterReport::default();
assert_eq!(report.rootkits_checked, 0);
assert_eq!(report.possible_rootkits, 0);
assert!(report.warnings.is_empty());
}
#[test]
fn test_extract_warning_description() {
assert_eq!(
extract_warning_description("[ Warning ] Some issue here"),
"Some issue here"
);
assert_eq!(
extract_warning_description("[Warning] Another issue"),
"Another issue"
);
}
}

View file

@ -0,0 +1,385 @@
use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use tracing::info;
use crate::security::command_guard::SafeCommand;
use super::manager::{Finding, FindingSeverity, ScanResultStatus};
const SURICATA_EVE_LOG: &str = "/var/log/suricata/eve.json";
const SURICATA_FAST_LOG: &str = "/var/log/suricata/fast.log";
const SURICATA_RULES_DIR: &str = "/var/lib/suricata/rules";
pub async fn get_alerts() -> Result<(ScanResultStatus, Vec<Finding>, String)> {
info!("Retrieving Suricata alerts");
let alerts = parse_eve_log(100).await.unwrap_or_default();
let findings = alerts_to_findings(&alerts);
let status = determine_result_status(&findings);
let raw_output = serde_json::to_string_pretty(&alerts).unwrap_or_default();
Ok((status, findings, raw_output))
}
pub async fn start_service() -> Result<()> {
info!("Starting Suricata service");
SafeCommand::new("sudo")?
.arg("systemctl")?
.arg("start")?
.arg("suricata")?
.execute()
.context("Failed to start Suricata service")?;
info!("Suricata service started successfully");
Ok(())
}
pub async fn stop_service() -> Result<()> {
info!("Stopping Suricata service");
SafeCommand::new("sudo")?
.arg("systemctl")?
.arg("stop")?
.arg("suricata")?
.execute()
.context("Failed to stop Suricata service")?;
info!("Suricata service stopped successfully");
Ok(())
}
pub async fn restart_service() -> Result<()> {
info!("Restarting Suricata service");
SafeCommand::new("sudo")?
.arg("systemctl")?
.arg("restart")?
.arg("suricata")?
.execute()
.context("Failed to restart Suricata service")?;
info!("Suricata service restarted successfully");
Ok(())
}
pub async fn update_rules() -> Result<()> {
info!("Updating Suricata rules");
SafeCommand::new("sudo")?
.arg("suricata-update")?
.execute()
.context("Failed to update Suricata rules")?;
SafeCommand::new("sudo")?
.arg("systemctl")?
.arg("reload")?
.arg("suricata")?
.execute()
.context("Failed to reload Suricata after rule update")?;
info!("Suricata rules updated successfully");
Ok(())
}
pub async fn get_version() -> Result<String> {
let output = SafeCommand::new("suricata")?
.arg("--build-info")?
.execute()
.context("Failed to get Suricata version")?;
let stdout = String::from_utf8_lossy(&output.stdout);
let version = stdout
.lines()
.find(|l| l.contains("Suricata version"))
.and_then(|l| l.split(':').nth(1))
.map(|v| v.trim().to_string())
.unwrap_or_else(|| "unknown".to_string());
Ok(version)
}
pub async fn get_rule_count() -> Result<u32> {
let mut count = 0u32;
if let Ok(entries) = std::fs::read_dir(SURICATA_RULES_DIR) {
for entry in entries.flatten() {
let path = entry.path();
if path.extension().is_some_and(|ext| ext == "rules") {
if let Ok(content) = std::fs::read_to_string(&path) {
count += content
.lines()
.filter(|l| !l.trim().starts_with('#') && !l.trim().is_empty())
.count() as u32;
}
}
}
}
Ok(count)
}
pub async fn get_stats() -> Result<SuricataStats> {
let alerts = parse_eve_log(1000).await.unwrap_or_default();
let today = Utc::now().date_naive();
let alerts_today = alerts
.iter()
.filter(|a| a.timestamp.date_naive() == today)
.count() as u32;
let blocked_today = alerts
.iter()
.filter(|a| a.timestamp.date_naive() == today && a.action == "blocked")
.count() as u32;
let rule_count = get_rule_count().await.unwrap_or(0);
Ok(SuricataStats {
alerts_today,
blocked_today,
rule_count,
total_alerts: alerts.len() as u32,
})
}
pub async fn parse_eve_log(limit: usize) -> Result<Vec<SuricataAlert>> {
let content = std::fs::read_to_string(SURICATA_EVE_LOG)
.context("Failed to read Suricata EVE log")?;
let mut alerts = Vec::new();
for line in content.lines().rev().take(limit * 2) {
if let Ok(event) = serde_json::from_str::<EveEvent>(line) {
if event.event_type == "alert" {
if let Some(alert_data) = event.alert {
let alert = SuricataAlert {
timestamp: event.timestamp,
src_ip: event.src_ip.unwrap_or_default(),
src_port: event.src_port.unwrap_or(0),
dest_ip: event.dest_ip.unwrap_or_default(),
dest_port: event.dest_port.unwrap_or(0),
protocol: event.proto.unwrap_or_default(),
signature: alert_data.signature,
signature_id: alert_data.signature_id,
severity: alert_data.severity,
category: alert_data.category,
action: alert_data.action,
};
alerts.push(alert);
if alerts.len() >= limit {
break;
}
}
}
}
}
alerts.reverse();
Ok(alerts)
}
fn alerts_to_findings(alerts: &[SuricataAlert]) -> Vec<Finding> {
alerts
.iter()
.map(|alert| {
let severity = match alert.severity {
1 => FindingSeverity::Critical,
2 => FindingSeverity::High,
3 => FindingSeverity::Medium,
_ => FindingSeverity::Low,
};
Finding {
id: format!("suricata-{}", alert.signature_id),
severity,
category: alert.category.clone(),
title: alert.signature.clone(),
description: format!(
"{}:{} -> {}:{} ({})",
alert.src_ip, alert.src_port, alert.dest_ip, alert.dest_port, alert.protocol
),
file_path: None,
remediation: if alert.action == "blocked" {
Some("Traffic was automatically blocked".to_string())
} else {
Some("Review the alert and consider blocking the source".to_string())
},
}
})
.collect()
}
fn determine_result_status(findings: &[Finding]) -> ScanResultStatus {
let has_critical = findings.iter().any(|f| f.severity == FindingSeverity::Critical);
let has_high = findings.iter().any(|f| f.severity == FindingSeverity::High);
if has_critical {
ScanResultStatus::Infected
} else if has_high || !findings.is_empty() {
ScanResultStatus::Warnings
} else {
ScanResultStatus::Clean
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SuricataAlert {
pub timestamp: DateTime<Utc>,
pub src_ip: String,
pub src_port: u16,
pub dest_ip: String,
pub dest_port: u16,
pub protocol: String,
pub signature: String,
pub signature_id: u64,
pub severity: u8,
pub category: String,
pub action: String,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SuricataStats {
pub alerts_today: u32,
pub blocked_today: u32,
pub rule_count: u32,
pub total_alerts: u32,
}
#[derive(Debug, Deserialize)]
struct EveEvent {
timestamp: DateTime<Utc>,
event_type: String,
src_ip: Option<String>,
src_port: Option<u16>,
dest_ip: Option<String>,
dest_port: Option<u16>,
proto: Option<String>,
alert: Option<EveAlert>,
}
#[derive(Debug, Deserialize)]
struct EveAlert {
signature: String,
signature_id: u64,
severity: u8,
category: String,
action: String,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_alerts_to_findings_empty() {
let alerts: Vec<SuricataAlert> = vec![];
let findings = alerts_to_findings(&alerts);
assert!(findings.is_empty());
}
#[test]
fn test_alerts_to_findings_severity_mapping() {
let alert = SuricataAlert {
timestamp: Utc::now(),
src_ip: "192.168.1.1".to_string(),
src_port: 12345,
dest_ip: "10.0.0.1".to_string(),
dest_port: 80,
protocol: "TCP".to_string(),
signature: "Test Alert".to_string(),
signature_id: 1000001,
severity: 1,
category: "Test".to_string(),
action: "allowed".to_string(),
};
let findings = alerts_to_findings(&[alert]);
assert_eq!(findings.len(), 1);
assert_eq!(findings[0].severity, FindingSeverity::Critical);
}
#[test]
fn test_alerts_to_findings_severity_high() {
let alert = SuricataAlert {
timestamp: Utc::now(),
src_ip: "192.168.1.1".to_string(),
src_port: 12345,
dest_ip: "10.0.0.1".to_string(),
dest_port: 443,
protocol: "TCP".to_string(),
signature: "High Severity Alert".to_string(),
signature_id: 1000002,
severity: 2,
category: "Test".to_string(),
action: "blocked".to_string(),
};
let findings = alerts_to_findings(&[alert]);
assert_eq!(findings[0].severity, FindingSeverity::High);
assert!(findings[0].remediation.as_ref().unwrap().contains("blocked"));
}
#[test]
fn test_determine_result_status_clean() {
let findings: Vec<Finding> = vec![];
assert_eq!(determine_result_status(&findings), ScanResultStatus::Clean);
}
#[test]
fn test_determine_result_status_critical() {
let findings = vec![Finding {
id: "test".to_string(),
severity: FindingSeverity::Critical,
category: "test".to_string(),
title: "Test".to_string(),
description: "Test".to_string(),
file_path: None,
remediation: None,
}];
assert_eq!(determine_result_status(&findings), ScanResultStatus::Infected);
}
#[test]
fn test_determine_result_status_warnings() {
let findings = vec![Finding {
id: "test".to_string(),
severity: FindingSeverity::Medium,
category: "test".to_string(),
title: "Test".to_string(),
description: "Test".to_string(),
file_path: None,
remediation: None,
}];
assert_eq!(determine_result_status(&findings), ScanResultStatus::Warnings);
}
#[test]
fn test_suricata_stats_default() {
let stats = SuricataStats::default();
assert_eq!(stats.alerts_today, 0);
assert_eq!(stats.blocked_today, 0);
assert_eq!(stats.rule_count, 0);
}
#[test]
fn test_suricata_alert_serialization() {
let alert = SuricataAlert {
timestamp: Utc::now(),
src_ip: "192.168.1.1".to_string(),
src_port: 12345,
dest_ip: "10.0.0.1".to_string(),
dest_port: 80,
protocol: "TCP".to_string(),
signature: "Test".to_string(),
signature_id: 1,
severity: 3,
category: "Test".to_string(),
action: "allowed".to_string(),
};
let json = serde_json::to_string(&alert);
assert!(json.is_ok());
}
}

View file

@ -11,7 +11,7 @@ use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::debug;
use tracing::{debug, warn};
use uuid::Uuid;
use super::auth::{AuthenticatedUser, Permission, Role};
@ -589,7 +589,7 @@ impl RbacManager {
cache.retain(|k, _| !k.starts_with(prefix));
}
async fn check_permission_string(&self, user: &AuthenticatedUser, permission_str: &str) -> bool {
pub async fn check_permission_string(&self, user: &AuthenticatedUser, permission_str: &str) -> bool {
let permission = match permission_str.to_lowercase().as_str() {
"read" => Permission::Read,
"write" => Permission::Write,
@ -706,6 +706,175 @@ impl RequireResourceAccess {
}
}
#[derive(Clone)]
pub struct RbacMiddlewareState {
pub rbac_manager: Arc<RbacManager>,
pub required_permission: Option<String>,
pub required_roles: Vec<Role>,
pub resource_type: Option<String>,
}
impl RbacMiddlewareState {
pub fn new(rbac_manager: Arc<RbacManager>) -> Self {
Self {
rbac_manager,
required_permission: None,
required_roles: Vec::new(),
resource_type: None,
}
}
pub fn with_permission(mut self, permission: &str) -> Self {
self.required_permission = Some(permission.to_string());
self
}
pub fn with_roles(mut self, roles: Vec<Role>) -> Self {
self.required_roles = roles;
self
}
pub fn with_resource_type(mut self, resource_type: &str) -> Self {
self.resource_type = Some(resource_type.to_string());
self
}
}
pub async fn require_permission_middleware(
State(state): State<RbacMiddlewareState>,
request: Request<Body>,
next: Next,
) -> Result<Response, RbacError> {
let user = request
.extensions()
.get::<AuthenticatedUser>()
.cloned()
.unwrap_or_else(AuthenticatedUser::anonymous);
if let Some(ref required_perm) = state.required_permission {
let has_permission = state
.rbac_manager
.check_permission_string(&user, required_perm)
.await;
if !has_permission {
warn!(
"Permission denied for user {}: missing permission {}",
user.user_id, required_perm
);
return Err(RbacError::PermissionDenied(format!(
"Missing required permission: {required_perm}"
)));
}
}
if !state.required_roles.is_empty() {
let has_required_role = state
.required_roles
.iter()
.any(|role| user.has_role(role));
if !has_required_role {
warn!(
"Role check failed for user {}: required one of {:?}",
user.user_id, state.required_roles
);
return Err(RbacError::InsufficientRole(format!(
"Required role: {:?}",
state.required_roles
)));
}
}
Ok(next.run(request).await)
}
pub async fn require_admin_middleware(
request: Request<Body>,
next: Next,
) -> Result<Response, RbacError> {
let user = request
.extensions()
.get::<AuthenticatedUser>()
.cloned()
.unwrap_or_else(AuthenticatedUser::anonymous);
if !user.is_admin() && !user.is_super_admin() {
warn!("Admin access denied for user {}", user.user_id);
return Err(RbacError::AdminRequired);
}
Ok(next.run(request).await)
}
pub async fn require_super_admin_middleware(
request: Request<Body>,
next: Next,
) -> Result<Response, RbacError> {
let user = request
.extensions()
.get::<AuthenticatedUser>()
.cloned()
.unwrap_or_else(AuthenticatedUser::anonymous);
if !user.is_super_admin() {
warn!("Super admin access denied for user {}", user.user_id);
return Err(RbacError::SuperAdminRequired);
}
Ok(next.run(request).await)
}
#[derive(Debug, Clone)]
pub enum RbacError {
PermissionDenied(String),
InsufficientRole(String),
AdminRequired,
SuperAdminRequired,
ResourceAccessDenied(String),
}
impl IntoResponse for RbacError {
fn into_response(self) -> Response {
let (status, message) = match self {
Self::PermissionDenied(msg) => (StatusCode::FORBIDDEN, msg),
Self::InsufficientRole(msg) => (StatusCode::FORBIDDEN, msg),
Self::AdminRequired => (
StatusCode::FORBIDDEN,
"Administrator access required".to_string(),
),
Self::SuperAdminRequired => (
StatusCode::FORBIDDEN,
"Super administrator access required".to_string(),
),
Self::ResourceAccessDenied(msg) => (StatusCode::FORBIDDEN, msg),
};
let body = serde_json::json!({
"error": "access_denied",
"message": message,
"code": "RBAC_DENIED"
});
(status, Json(body)).into_response()
}
}
pub fn create_permission_layer(
rbac_manager: Arc<RbacManager>,
permission: &str,
) -> RbacMiddlewareState {
RbacMiddlewareState::new(rbac_manager).with_permission(permission)
}
pub fn create_role_layer(rbac_manager: Arc<RbacManager>, roles: Vec<Role>) -> RbacMiddlewareState {
RbacMiddlewareState::new(rbac_manager).with_roles(roles)
}
pub fn create_admin_layer(rbac_manager: Arc<RbacManager>) -> RbacMiddlewareState {
RbacMiddlewareState::new(rbac_manager).with_roles(vec![Role::Admin, Role::SuperAdmin])
}
pub fn build_default_route_permissions() -> Vec<RoutePermission> {
vec![
RoutePermission::new("/api/health", "GET", "").with_anonymous(true),

View file

@ -3,6 +3,7 @@ pub mod menu_config;
pub mod permission_inheritance;
pub mod rbac;
pub mod rbac_ui;
pub mod security_admin;
use axum::{
extract::State,
@ -28,6 +29,7 @@ pub fn configure_settings_routes() -> Router<Arc<AppState>> {
)
.route("/api/user/security/devices", get(get_trusted_devices))
.merge(rbac::configure_rbac_routes())
.merge(security_admin::configure_security_admin_routes())
}
async fn get_storage_info(State(_state): State<Arc<AppState>>) -> Html<String> {

View file

@ -0,0 +1,405 @@
use axum::{
extract::{Path, State},
http::StatusCode,
response::IntoResponse,
routing::{delete, get, post},
Json, Router,
};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use uuid::Uuid;
use crate::shared::state::AppState;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityOverview {
pub tls_enabled: bool,
pub mtls_enabled: bool,
pub rate_limiting_enabled: bool,
pub cors_configured: bool,
pub api_keys_count: u32,
pub active_sessions_count: u32,
pub audit_log_enabled: bool,
pub mfa_enabled_users: u32,
pub total_users: u32,
pub last_security_scan: Option<DateTime<Utc>>,
pub security_score: u8,
pub vulnerabilities: SecurityVulnerabilities,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityVulnerabilities {
pub critical: u32,
pub high: u32,
pub medium: u32,
pub low: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsSettings {
pub enabled: bool,
pub cert_expiry: Option<DateTime<Utc>>,
pub auto_renew: bool,
pub min_version: String,
pub cipher_suites: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitSettings {
pub enabled: bool,
pub requests_per_minute: u32,
pub burst_size: u32,
pub whitelist: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorsSettings {
pub enabled: bool,
pub allowed_origins: Vec<String>,
pub allowed_methods: Vec<String>,
pub allowed_headers: Vec<String>,
pub max_age_seconds: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditLogEntry {
pub id: Uuid,
pub timestamp: DateTime<Utc>,
pub user_id: Option<Uuid>,
pub action: String,
pub resource: String,
pub resource_id: Option<String>,
pub ip_address: Option<String>,
pub user_agent: Option<String>,
pub success: bool,
pub details: Option<serde_json::Value>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiKeyInfo {
pub id: Uuid,
pub name: String,
pub prefix: String,
pub created_at: DateTime<Utc>,
pub last_used_at: Option<DateTime<Utc>>,
pub expires_at: Option<DateTime<Utc>>,
pub scopes: Vec<String>,
pub is_active: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateApiKeyRequest {
pub name: String,
pub scopes: Vec<String>,
pub expires_in_days: Option<u32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateApiKeyResponse {
pub id: Uuid,
pub name: String,
pub key: String,
pub expires_at: Option<DateTime<Utc>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MfaSettings {
pub require_mfa: bool,
pub allowed_methods: Vec<String>,
pub grace_period_days: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionInfo {
pub id: Uuid,
pub user_id: Uuid,
pub user_email: String,
pub created_at: DateTime<Utc>,
pub last_activity: DateTime<Utc>,
pub ip_address: Option<String>,
pub user_agent: Option<String>,
pub is_current: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PasswordPolicy {
pub min_length: u8,
pub require_uppercase: bool,
pub require_lowercase: bool,
pub require_numbers: bool,
pub require_special_chars: bool,
pub max_age_days: Option<u32>,
pub prevent_reuse_count: u8,
}
#[derive(Debug, Serialize)]
pub struct SecurityError {
pub error: String,
pub code: String,
}
impl IntoResponse for SecurityError {
fn into_response(self) -> axum::response::Response {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({"error": self.error, "code": self.code})),
)
.into_response()
}
}
async fn get_security_overview(
State(_state): State<Arc<AppState>>,
) -> Result<Json<SecurityOverview>, SecurityError> {
let overview = SecurityOverview {
tls_enabled: true,
mtls_enabled: false,
rate_limiting_enabled: true,
cors_configured: true,
api_keys_count: 5,
active_sessions_count: 12,
audit_log_enabled: true,
mfa_enabled_users: 8,
total_users: 25,
last_security_scan: Some(Utc::now()),
security_score: 85,
vulnerabilities: SecurityVulnerabilities {
critical: 0,
high: 1,
medium: 3,
low: 7,
},
};
Ok(Json(overview))
}
async fn get_tls_settings(
State(_state): State<Arc<AppState>>,
) -> Result<Json<TlsSettings>, SecurityError> {
let settings = TlsSettings {
enabled: true,
cert_expiry: Some(Utc::now() + chrono::Duration::days(90)),
auto_renew: true,
min_version: "TLS 1.2".to_string(),
cipher_suites: vec![
"TLS_AES_256_GCM_SHA384".to_string(),
"TLS_CHACHA20_POLY1305_SHA256".to_string(),
"TLS_AES_128_GCM_SHA256".to_string(),
],
};
Ok(Json(settings))
}
async fn update_tls_settings(
State(_state): State<Arc<AppState>>,
Json(_settings): Json<TlsSettings>,
) -> Result<Json<TlsSettings>, SecurityError> {
let settings = TlsSettings {
enabled: true,
cert_expiry: Some(Utc::now() + chrono::Duration::days(90)),
auto_renew: true,
min_version: "TLS 1.2".to_string(),
cipher_suites: vec![
"TLS_AES_256_GCM_SHA384".to_string(),
"TLS_CHACHA20_POLY1305_SHA256".to_string(),
],
};
Ok(Json(settings))
}
async fn get_rate_limit_settings(
State(_state): State<Arc<AppState>>,
) -> Result<Json<RateLimitSettings>, SecurityError> {
let settings = RateLimitSettings {
enabled: true,
requests_per_minute: 60,
burst_size: 100,
whitelist: vec![],
};
Ok(Json(settings))
}
async fn update_rate_limit_settings(
State(_state): State<Arc<AppState>>,
Json(settings): Json<RateLimitSettings>,
) -> Result<Json<RateLimitSettings>, SecurityError> {
Ok(Json(settings))
}
async fn get_cors_settings(
State(_state): State<Arc<AppState>>,
) -> Result<Json<CorsSettings>, SecurityError> {
let settings = CorsSettings {
enabled: true,
allowed_origins: vec!["*".to_string()],
allowed_methods: vec![
"GET".to_string(),
"POST".to_string(),
"PUT".to_string(),
"DELETE".to_string(),
],
allowed_headers: vec!["Content-Type".to_string(), "Authorization".to_string()],
max_age_seconds: 3600,
};
Ok(Json(settings))
}
async fn update_cors_settings(
State(_state): State<Arc<AppState>>,
Json(settings): Json<CorsSettings>,
) -> Result<Json<CorsSettings>, SecurityError> {
Ok(Json(settings))
}
async fn list_audit_logs(
State(_state): State<Arc<AppState>>,
) -> Result<Json<Vec<AuditLogEntry>>, SecurityError> {
let logs = vec![
AuditLogEntry {
id: Uuid::new_v4(),
timestamp: Utc::now(),
user_id: Some(Uuid::new_v4()),
action: "login".to_string(),
resource: "session".to_string(),
resource_id: None,
ip_address: Some("192.168.1.100".to_string()),
user_agent: Some("Mozilla/5.0".to_string()),
success: true,
details: None,
},
];
Ok(Json(logs))
}
async fn list_api_keys(
State(_state): State<Arc<AppState>>,
) -> Result<Json<Vec<ApiKeyInfo>>, SecurityError> {
let keys = vec![];
Ok(Json(keys))
}
async fn create_api_key(
State(_state): State<Arc<AppState>>,
Json(req): Json<CreateApiKeyRequest>,
) -> Result<Json<CreateApiKeyResponse>, SecurityError> {
let response = CreateApiKeyResponse {
id: Uuid::new_v4(),
name: req.name,
key: format!("gb_{}", Uuid::new_v4().to_string().replace('-', "")),
expires_at: req.expires_in_days.map(|days| Utc::now() + chrono::Duration::days(i64::from(days))),
};
Ok(Json(response))
}
async fn revoke_api_key(
State(_state): State<Arc<AppState>>,
Path(_key_id): Path<Uuid>,
) -> Result<StatusCode, SecurityError> {
Ok(StatusCode::NO_CONTENT)
}
async fn get_mfa_settings(
State(_state): State<Arc<AppState>>,
) -> Result<Json<MfaSettings>, SecurityError> {
let settings = MfaSettings {
require_mfa: false,
allowed_methods: vec!["totp".to_string(), "webauthn".to_string()],
grace_period_days: 7,
};
Ok(Json(settings))
}
async fn update_mfa_settings(
State(_state): State<Arc<AppState>>,
Json(settings): Json<MfaSettings>,
) -> Result<Json<MfaSettings>, SecurityError> {
Ok(Json(settings))
}
async fn list_active_sessions(
State(_state): State<Arc<AppState>>,
) -> Result<Json<Vec<SessionInfo>>, SecurityError> {
let sessions = vec![];
Ok(Json(sessions))
}
async fn revoke_session(
State(_state): State<Arc<AppState>>,
Path(_session_id): Path<Uuid>,
) -> Result<StatusCode, SecurityError> {
Ok(StatusCode::NO_CONTENT)
}
async fn revoke_all_user_sessions(
State(_state): State<Arc<AppState>>,
Path(_user_id): Path<Uuid>,
) -> Result<StatusCode, SecurityError> {
Ok(StatusCode::NO_CONTENT)
}
async fn get_password_policy(
State(_state): State<Arc<AppState>>,
) -> Result<Json<PasswordPolicy>, SecurityError> {
let policy = PasswordPolicy {
min_length: 12,
require_uppercase: true,
require_lowercase: true,
require_numbers: true,
require_special_chars: true,
max_age_days: Some(90),
prevent_reuse_count: 5,
};
Ok(Json(policy))
}
async fn update_password_policy(
State(_state): State<Arc<AppState>>,
Json(policy): Json<PasswordPolicy>,
) -> Result<Json<PasswordPolicy>, SecurityError> {
Ok(Json(policy))
}
async fn run_security_scan(
State(state): State<Arc<AppState>>,
) -> Result<Json<SecurityOverview>, SecurityError> {
get_security_overview(State(state)).await
}
pub fn configure_security_admin_routes() -> Router<Arc<AppState>> {
Router::new()
.route("/api/settings/security/overview", get(get_security_overview))
.route("/api/settings/security/scan", post(run_security_scan))
.route(
"/api/settings/security/tls",
get(get_tls_settings).put(update_tls_settings),
)
.route(
"/api/settings/security/rate-limit",
get(get_rate_limit_settings).put(update_rate_limit_settings),
)
.route(
"/api/settings/security/cors",
get(get_cors_settings).put(update_cors_settings),
)
.route("/api/settings/security/audit", get(list_audit_logs))
.route(
"/api/settings/security/api-keys",
get(list_api_keys).post(create_api_key),
)
.route("/api/settings/security/api-keys/:key_id", delete(revoke_api_key))
.route(
"/api/settings/security/mfa",
get(get_mfa_settings).put(update_mfa_settings),
)
.route("/api/settings/security/sessions", get(list_active_sessions))
.route("/api/settings/security/sessions/:session_id", delete(revoke_session))
.route(
"/api/settings/security/users/:user_id/sessions",
delete(revoke_all_user_sessions),
)
.route(
"/api/settings/security/password-policy",
get(get_password_policy).put(update_password_policy),
)
}

View file

@ -1178,10 +1178,40 @@ body {{ font-family: system-ui, sans-serif; background: var(--bg); color: var(--
)
}
async fn handle_get_suggested_communities_html() -> Html<String> {
Html(r#"
<div class="community-suggestion">
<div class="community-avatar">🌐</div>
<div class="community-info">
<span class="community-name">General Discussion</span>
<span class="community-members">128 members</span>
</div>
<button class="btn-join" hx-post="/api/social/communities/general/join" hx-swap="outerHTML">Join</button>
</div>
<div class="community-suggestion">
<div class="community-avatar">💡</div>
<div class="community-info">
<span class="community-name">Ideas & Feedback</span>
<span class="community-members">64 members</span>
</div>
<button class="btn-join" hx-post="/api/social/communities/ideas/join" hx-swap="outerHTML">Join</button>
</div>
<div class="community-suggestion">
<div class="community-avatar">🎉</div>
<div class="community-info">
<span class="community-name">Announcements</span>
<span class="community-members">256 members</span>
</div>
<button class="btn-join" hx-post="/api/social/communities/announcements/join" hx-swap="outerHTML">Join</button>
</div>
"#.to_string())
}
pub fn configure_social_routes() -> Router<Arc<AppState>> {
Router::new()
.route("/api/social/feed", get(handle_get_feed))
.route("/api/ui/social/feed", get(handle_get_feed_html))
.route("/api/ui/social/suggested", get(handle_get_suggested_communities_html))
.route("/api/social/posts", post(handle_create_post))
.route("/api/social/posts/:id", get(handle_get_post))
.route("/api/social/posts/:id", put(handle_update_post))

View file

@ -99,7 +99,6 @@ pub fn configure_video_routes() -> Router<Arc<AppState>> {
)
.route("/api/video/analytics/view", post(record_view_handler))
.route("/api/video/ws/export/:id", get(export_progress_websocket))
.route("/video", get(video_ui))
}
pub fn configure(router: Router<Arc<AppState>>) -> Router<Arc<AppState>> {

View file

@ -1,3 +1,10 @@
use axum::{
extract::{Path, State},
http::StatusCode,
response::IntoResponse,
routing::{delete, get, post},
Json, Router,
};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
@ -5,6 +12,8 @@ use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::shared::state::AppState;
pub mod blocks;
pub mod pages;
pub mod collaboration;
@ -1291,3 +1300,211 @@ impl std::fmt::Display for WorkspacesError {
}
impl std::error::Error for WorkspacesError {}
impl IntoResponse for WorkspacesError {
fn into_response(self) -> axum::response::Response {
let (status, message) = match &self {
Self::WorkspaceNotFound | Self::PageNotFound | Self::BlockNotFound
| Self::CommentNotFound | Self::VersionNotFound | Self::MemberNotFound => {
(StatusCode::NOT_FOUND, self.to_string())
}
Self::PermissionDenied => (StatusCode::FORBIDDEN, self.to_string()),
Self::MemberAlreadyExists | Self::CannotRemoveLastOwner | Self::InvalidOperation(_) => {
(StatusCode::BAD_REQUEST, self.to_string())
}
};
(status, Json(serde_json::json!({"error": message}))).into_response()
}
}
#[derive(Debug, Deserialize)]
pub struct CreateWorkspaceRequest {
pub name: String,
pub description: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct UpdateWorkspaceRequest {
pub name: Option<String>,
pub description: Option<String>,
pub icon: Option<WorkspaceIcon>,
}
#[derive(Debug, Deserialize)]
pub struct CreatePageRequest {
pub title: String,
pub parent_id: Option<Uuid>,
}
#[derive(Debug, Deserialize)]
pub struct UpdatePageRequest {
pub title: Option<String>,
pub icon: Option<WorkspaceIcon>,
}
#[derive(Debug, Deserialize)]
pub struct AddMemberRequest {
pub user_id: Uuid,
pub role: WorkspaceRole,
}
async fn list_workspaces(
State(_state): State<Arc<AppState>>,
) -> Json<Vec<Workspace>> {
let service = WorkspacesService::new();
let org_id = Uuid::nil();
let workspaces = service.list_workspaces(org_id).await;
Json(workspaces)
}
async fn create_workspace(
State(_state): State<Arc<AppState>>,
Json(req): Json<CreateWorkspaceRequest>,
) -> Result<Json<Workspace>, WorkspacesError> {
let service = WorkspacesService::new();
let org_id = Uuid::nil();
let user_id = Uuid::nil();
let workspace = service.create_workspace(org_id, &req.name, user_id).await?;
Ok(Json(workspace))
}
async fn get_workspace(
State(_state): State<Arc<AppState>>,
Path(workspace_id): Path<Uuid>,
) -> Result<Json<Workspace>, WorkspacesError> {
let service = WorkspacesService::new();
let workspace = service.get_workspace(workspace_id).await.ok_or(WorkspacesError::WorkspaceNotFound)?;
Ok(Json(workspace))
}
async fn update_workspace(
State(_state): State<Arc<AppState>>,
Path(workspace_id): Path<Uuid>,
Json(req): Json<UpdateWorkspaceRequest>,
) -> Result<Json<Workspace>, WorkspacesError> {
let service = WorkspacesService::new();
let workspace = service.update_workspace(workspace_id, req.name, req.description, req.icon).await?;
Ok(Json(workspace))
}
async fn delete_workspace(
State(_state): State<Arc<AppState>>,
Path(workspace_id): Path<Uuid>,
) -> Result<StatusCode, WorkspacesError> {
let service = WorkspacesService::new();
service.delete_workspace(workspace_id).await?;
Ok(StatusCode::NO_CONTENT)
}
async fn list_pages(
State(_state): State<Arc<AppState>>,
Path(workspace_id): Path<Uuid>,
) -> Json<Vec<PageTreeNode>> {
let service = WorkspacesService::new();
let pages = service.get_page_tree(workspace_id).await;
Json(pages)
}
async fn create_page(
State(_state): State<Arc<AppState>>,
Path(workspace_id): Path<Uuid>,
Json(req): Json<CreatePageRequest>,
) -> Result<Json<Page>, WorkspacesError> {
let service = WorkspacesService::new();
let user_id = Uuid::nil();
let page = service.create_page(workspace_id, req.parent_id, &req.title, user_id).await?;
Ok(Json(page))
}
async fn get_page(
State(_state): State<Arc<AppState>>,
Path(page_id): Path<Uuid>,
) -> Result<Json<Page>, WorkspacesError> {
let service = WorkspacesService::new();
let page = service.get_page(page_id).await.ok_or(WorkspacesError::PageNotFound)?;
Ok(Json(page))
}
async fn update_page(
State(_state): State<Arc<AppState>>,
Path(page_id): Path<Uuid>,
Json(req): Json<UpdatePageRequest>,
) -> Result<Json<Page>, WorkspacesError> {
let service = WorkspacesService::new();
let user_id = Uuid::nil();
let page = service.update_page(page_id, req.title, req.icon, None, user_id).await?;
Ok(Json(page))
}
async fn delete_page(
State(_state): State<Arc<AppState>>,
Path(page_id): Path<Uuid>,
) -> Result<StatusCode, WorkspacesError> {
let service = WorkspacesService::new();
service.delete_page(page_id).await?;
Ok(StatusCode::NO_CONTENT)
}
async fn add_member(
State(_state): State<Arc<AppState>>,
Path(workspace_id): Path<Uuid>,
Json(req): Json<AddMemberRequest>,
) -> Result<StatusCode, WorkspacesError> {
let service = WorkspacesService::new();
let inviter_id = Uuid::nil();
service.add_member(workspace_id, req.user_id, req.role, inviter_id).await?;
Ok(StatusCode::CREATED)
}
async fn remove_member(
State(_state): State<Arc<AppState>>,
Path((workspace_id, user_id)): Path<(Uuid, Uuid)>,
) -> Result<StatusCode, WorkspacesError> {
let service = WorkspacesService::new();
service.remove_member(workspace_id, user_id).await?;
Ok(StatusCode::NO_CONTENT)
}
async fn search_pages(
State(_state): State<Arc<AppState>>,
Path(workspace_id): Path<Uuid>,
axum::extract::Query(params): axum::extract::Query<HashMap<String, String>>,
) -> Json<Vec<PageSearchResult>> {
let service = WorkspacesService::new();
let query = params.get("q").cloned().unwrap_or_default();
let results = service.search_pages(workspace_id, &query).await;
Json(results)
}
async fn get_slash_commands_handler(
State(_state): State<Arc<AppState>>,
) -> Json<Vec<SlashCommand>> {
Json(get_slash_commands())
}
pub fn configure_workspaces_routes() -> Router<Arc<AppState>> {
Router::new()
.route("/api/workspaces", get(list_workspaces).post(create_workspace))
.route(
"/api/workspaces/:workspace_id",
get(get_workspace).put(update_workspace).delete(delete_workspace),
)
.route(
"/api/workspaces/:workspace_id/pages",
get(list_pages).post(create_page),
)
.route(
"/api/workspaces/:workspace_id/members",
post(add_member),
)
.route(
"/api/workspaces/:workspace_id/members/:user_id",
delete(remove_member),
)
.route(
"/api/workspaces/:workspace_id/search",
get(search_pages),
)
.route("/api/pages/:page_id", get(get_page).put(update_page).delete(delete_page))
.route("/api/workspaces/commands", get(get_slash_commands_handler))
}