diff --git a/src/auto_task/ask_later.rs b/src/auto_task/ask_later.rs index ec1343b72..7c18fb837 100644 --- a/src/auto_task/ask_later.rs +++ b/src/auto_task/ask_later.rs @@ -1,5 +1,14 @@ use crate::core::shared::models::UserSession; -use crate::core::shared::state::AppState; +use crate::shared::state::AppState; + +fn is_sensitive_config_key(key: &str) -> bool { + let key_lower = key.to_lowercase(); + let sensitive_patterns = [ + "password", "secret", "token", "key", "credential", "auth", + "api_key", "apikey", "pass", "pwd", "cert", "private", + ]; + sensitive_patterns.iter().any(|p| key_lower.contains(p)) +} use diesel::prelude::*; use diesel::sql_query; use diesel::sql_types::Text; @@ -75,7 +84,12 @@ pub fn ask_later_keyword(state: Arc, user: UserSession, engine: &mut E match fill_pending_info(&state, &user, config_key, value) { Ok(_) => { - info!("Pending info filled: {} = {}", config_key, value); + let safe_value = if is_sensitive_config_key(config_key) { + "[REDACTED]" + } else { + value + }; + info!("Pending info filled: {} = {}", config_key, safe_value); true } Err(e) => { diff --git a/src/main.rs b/src/main.rs index 866dc2cb2..84b7d483e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,9 +16,10 @@ use tower_http::services::ServeDir; use tower_http::trace::TraceLayer; use botserver::security::{ - create_cors_layer, create_rate_limit_layer, create_security_headers_layer, - request_id_middleware, security_headers_middleware, set_global_panic_hook, - HttpRateLimitConfig, PanicHandlerConfig, SecurityHeadersConfig, + auth_middleware, create_cors_layer, create_rate_limit_layer, create_security_headers_layer, + request_id_middleware, security_headers_middleware, set_cors_allowed_origins, + set_global_panic_hook, AuthConfig, HttpRateLimitConfig, PanicHandlerConfig, + SecurityHeadersConfig, }; use botlib::SystemLimits; @@ -145,11 +146,44 @@ async fn run_axum_server( port: u16, _worker_count: usize, ) -> std::io::Result<()> { - // Use hardened CORS configuration instead of allowing everything - // In production, set CORS_ALLOWED_ORIGINS env var to restrict origins - // In development, localhost origins are allowed by default + // Load CORS allowed origins from bot config database if available + // Config key: cors-allowed-origins in config.csv + if let Ok(mut conn) = app_state.conn.get() { + use crate::shared::models::schema::bot_configuration::dsl::*; + use diesel::prelude::*; + + if let Ok(origins_str) = bot_configuration + .filter(config_key.eq("cors-allowed-origins")) + .select(config_value) + .first::(&mut conn) + { + let origins: Vec = origins_str + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect(); + if !origins.is_empty() { + info!("Loaded {} CORS allowed origins from config", origins.len()); + set_cors_allowed_origins(origins); + } + } + } + + // Use hardened CORS configuration + // Origins configured via config.csv cors-allowed-origins or Vault let cors = create_cors_layer(); + // Create auth config for protected routes + let auth_config = Arc::new(AuthConfig::default() + .add_anonymous_path("/health") + .add_anonymous_path("/healthz") + .add_anonymous_path("/api/health") + .add_anonymous_path("/api/v1/health") + .add_anonymous_path("/ws") + .add_anonymous_path("/auth") + .add_public_path("/static") + .add_public_path("/favicon.ico")); + use crate::core::urls::ApiUrls; let mut api_router = Router::new() @@ -260,10 +294,15 @@ async fn run_axum_server( PanicHandlerConfig::development() }; - info!("Security middleware enabled: rate limiting, security headers, panic handler, request ID tracking"); + info!("Security middleware enabled: rate limiting, security headers, panic handler, request ID tracking, authentication"); let app = Router::new() .merge(api_router.with_state(app_state.clone())) + // Authentication middleware for protected routes + .layer(middleware::from_fn_with_state( + auth_config.clone(), + auth_middleware, + )) // Static files fallback for legacy /apps/* paths .nest_service("/static", ServeDir::new(&site_path)) // Security middleware stack (order matters - first added is outermost) diff --git a/src/security/cors.rs b/src/security/cors.rs index d9f7ff176..c8ad400af 100644 --- a/src/security/cors.rs +++ b/src/security/cors.rs @@ -1,6 +1,7 @@ use axum::http::{header, HeaderValue, Method}; use std::collections::HashSet; use tower_http::cors::{AllowOrigin, CorsLayer}; +use tracing::info; #[derive(Debug, Clone)] pub struct CorsConfig { @@ -51,6 +52,25 @@ impl CorsConfig { Self::default() } + pub fn from_config_value(allowed_origins: Option<&str>) -> Self { + let mut config = Self::production(); + + if let Some(origins) = allowed_origins { + let origins: Vec = origins + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect(); + + if !origins.is_empty() { + info!("CORS configured with {} allowed origins", origins.len()); + config.allowed_origins = origins; + } + } + + config + } + pub fn production() -> Self { Self { allowed_origins: vec![], @@ -188,7 +208,7 @@ impl CorsConfig { let mut cors = CorsLayer::new(); if self.allowed_origins.is_empty() { - let allowed_env_origins = get_allowed_origins_from_env(); + let allowed_env_origins = get_allowed_origins_from_config(); if allowed_env_origins.is_empty() { cors = cors.allow_origin(AllowOrigin::predicate(validate_origin)); } else { @@ -241,15 +261,24 @@ impl CorsConfig { } } -fn get_allowed_origins_from_env() -> Vec { - std::env::var("CORS_ALLOWED_ORIGINS") - .map(|v| { - v.split(',') - .map(|s| s.trim().to_string()) - .filter(|s| !s.is_empty()) - .collect() - }) - .unwrap_or_default() +fn get_allowed_origins_from_config() -> Vec { + if let Some(origins) = CORS_ALLOWED_ORIGINS.read().ok().and_then(|g| g.clone()) { + return origins; + } + Vec::new() +} + +static CORS_ALLOWED_ORIGINS: std::sync::RwLock>> = std::sync::RwLock::new(None); + +pub fn set_cors_allowed_origins(origins: Vec) { + if let Ok(mut guard) = CORS_ALLOWED_ORIGINS.write() { + info!("Setting CORS allowed origins: {:?}", origins); + *guard = Some(origins); + } +} + +pub fn get_cors_allowed_origins() -> Vec { + get_allowed_origins_from_config() } fn validate_origin(origin: &HeaderValue, _request: &axum::http::request::Parts) -> bool { @@ -262,9 +291,9 @@ fn validate_origin(origin: &HeaderValue, _request: &axum::http::request::Parts) return false; } - let env_origins = get_allowed_origins_from_env(); - if !env_origins.is_empty() { - return env_origins.iter().any(|allowed| allowed == origin_str); + let config_origins = get_allowed_origins_from_config(); + if !config_origins.is_empty() { + return config_origins.iter().any(|allowed| allowed == origin_str); } if is_valid_origin_format(origin_str) { @@ -305,14 +334,22 @@ fn is_valid_origin_format(origin: &str) -> bool { } pub fn create_cors_layer() -> CorsLayer { - let is_production = std::env::var("BOTSERVER_ENV") - .map(|v| v == "production" || v == "prod") - .unwrap_or(false); + let config_origins = get_allowed_origins_from_config(); - if is_production { + if !config_origins.is_empty() { + info!("Creating CORS layer with configured origins"); + CorsConfig::production().with_origins(config_origins).build() + } else { + info!("Creating CORS layer with development defaults (no origins configured)"); + CorsConfig::development().build() + } +} + +pub fn create_cors_layer_for_production(allowed_origins: Vec) -> CorsLayer { + if allowed_origins.is_empty() { CorsConfig::production().build() } else { - CorsConfig::development().build() + CorsConfig::production().with_origins(allowed_origins).build() } } @@ -357,35 +394,29 @@ impl OriginValidator { self } - pub fn from_env() -> Self { + pub fn from_config(origins: Vec, patterns: Vec, allow_localhost: bool) -> Self { let mut validator = Self::new(); - if let Ok(origins) = std::env::var("CORS_ALLOWED_ORIGINS") { - for origin in origins.split(',') { - let trimmed = origin.trim(); - if !trimmed.is_empty() { - validator.allowed_origins.insert(trimmed.to_string()); - } + for origin in origins { + if !origin.is_empty() { + validator.allowed_origins.insert(origin); } } - if let Ok(patterns) = std::env::var("CORS_ALLOWED_PATTERNS") { - for pattern in patterns.split(',') { - let trimmed = pattern.trim(); - if !trimmed.is_empty() { - validator.allowed_patterns.push(trimmed.to_string()); - } + for pattern in patterns { + if !pattern.is_empty() { + validator.allowed_patterns.push(pattern); } } - let allow_localhost = std::env::var("CORS_ALLOW_LOCALHOST") - .map(|v| v == "true" || v == "1") - .unwrap_or(false); validator.allow_localhost = allow_localhost; - validator } + pub fn from_allowed_origins(origins: Vec) -> Self { + Self::from_config(origins, Vec::new(), false) + } + pub fn is_allowed(&self, origin: &str) -> bool { if self.allowed_origins.contains(origin) { return true; diff --git a/src/security/mod.rs b/src/security/mod.rs index 51b842f73..3b920a390 100644 --- a/src/security/mod.rs +++ b/src/security/mod.rs @@ -35,7 +35,8 @@ pub use cert_pinning::{ CertPinningManager, PinType, PinValidationResult, PinnedCert, PinningStats, }; pub use cors::{ - create_cors_layer, create_cors_layer_with_origins, CorsConfig, OriginValidator, + create_cors_layer, create_cors_layer_for_production, create_cors_layer_with_origins, + get_cors_allowed_origins, set_cors_allowed_origins, CorsConfig, OriginValidator, }; pub use command_guard::{ has_nvidia_gpu_safe, safe_nvidia_smi, safe_pandoc_async, safe_pdftotext_async, diff --git a/src/security/zitadel_auth.rs b/src/security/zitadel_auth.rs index a6edc732d..06c07bc35 100644 --- a/src/security/zitadel_auth.rs +++ b/src/security/zitadel_auth.rs @@ -1,3 +1,4 @@ +use crate::core::secrets::SecretsManager; use crate::security::auth::{AuthConfig, AuthError, AuthenticatedUser, BotAccess, Permission, Role}; use anyhow::{anyhow, Result}; use axum::{ @@ -8,7 +9,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::RwLock; -use tracing::{error, warn}; +use tracing::{error, info, warn}; use uuid::Uuid; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -16,6 +17,7 @@ pub struct ZitadelAuthConfig { pub issuer_url: String, pub api_url: String, pub client_id: String, + #[serde(skip_serializing)] pub client_secret: String, pub project_id: String, pub cache_ttl_secs: u64, @@ -25,13 +27,11 @@ pub struct ZitadelAuthConfig { impl Default for ZitadelAuthConfig { fn default() -> Self { Self { - issuer_url: std::env::var("ZITADEL_ISSUER_URL") - .unwrap_or_else(|_| "https://localhost:8080".to_string()), - api_url: std::env::var("ZITADEL_API_URL") - .unwrap_or_else(|_| "https://localhost:8080".to_string()), - client_id: std::env::var("ZITADEL_CLIENT_ID").unwrap_or_default(), - client_secret: std::env::var("ZITADEL_CLIENT_SECRET").unwrap_or_default(), - project_id: std::env::var("ZITADEL_PROJECT_ID").unwrap_or_default(), + issuer_url: "https://localhost:8080".to_string(), + api_url: "https://localhost:8080".to_string(), + client_id: String::new(), + client_secret: String::new(), + project_id: String::new(), cache_ttl_secs: 300, introspect_tokens: true, } @@ -51,6 +51,36 @@ impl ZitadelAuthConfig { } } + pub async fn from_vault(secrets: &SecretsManager) -> Result { + let (url, project_id, client_id, client_secret) = secrets.get_directory_config().await?; + + info!("Loaded Zitadel configuration from Vault"); + + Ok(Self { + issuer_url: url.clone(), + api_url: url, + client_id, + client_secret, + project_id, + cache_ttl_secs: 300, + introspect_tokens: true, + }) + } + + pub async fn from_vault_or_default(secrets: &SecretsManager) -> Self { + match Self::from_vault(secrets).await { + Ok(config) => config, + Err(e) => { + warn!("Failed to load Zitadel config from Vault: {}. Using defaults.", e); + Self::default() + } + } + } + + pub fn is_configured(&self) -> bool { + !self.client_id.is_empty() && !self.client_secret.is_empty() + } + pub fn with_project_id(mut self, project_id: impl Into) -> Self { self.project_id = project_id.into(); self @@ -667,6 +697,18 @@ mod tests { let config = ZitadelAuthConfig::default(); assert_eq!(config.cache_ttl_secs, 300); assert!(config.introspect_tokens); + assert!(!config.is_configured()); + } + + #[test] + fn test_zitadel_auth_config_is_configured() { + let config = ZitadelAuthConfig::new( + "https://auth.example.com", + "https://api.example.com", + "client123", + "secret456", + ); + assert!(config.is_configured()); } #[test]