use axum::{ body::Body, http::{Request, StatusCode}, middleware::Next, response::{IntoResponse, Response}, Json, }; use futures_util::FutureExt; use serde_json::json; use std::panic::{catch_unwind, AssertUnwindSafe}; use tracing::{error, warn}; #[derive(Debug, Clone)] pub struct PanicHandlerConfig { pub log_panics: bool, pub include_backtrace: bool, pub custom_message: Option, pub notify_on_panic: bool, } impl Default for PanicHandlerConfig { fn default() -> Self { Self { log_panics: true, include_backtrace: cfg!(debug_assertions), custom_message: None, notify_on_panic: false, } } } impl PanicHandlerConfig { pub fn new() -> Self { Self::default() } pub fn production() -> Self { Self { log_panics: true, include_backtrace: false, custom_message: Some("An unexpected error occurred. Please try again later.".to_string()), notify_on_panic: true, } } pub fn development() -> Self { Self { log_panics: true, include_backtrace: true, custom_message: None, notify_on_panic: false, } } pub fn with_message(mut self, message: impl Into) -> Self { self.custom_message = Some(message.into()); self } pub fn with_backtrace(mut self, include: bool) -> Self { self.include_backtrace = include; self } pub fn with_logging(mut self, log: bool) -> Self { self.log_panics = log; self } pub fn with_notification(mut self, notify: bool) -> Self { self.notify_on_panic = notify; self } } pub async fn panic_handler_middleware(request: Request, next: Next) -> Response { panic_handler_middleware_with_config(request, next, &PanicHandlerConfig::default()).await } pub async fn panic_handler_middleware_with_config( request: Request, next: Next, config: &PanicHandlerConfig, ) -> Response { let method = request.method().clone(); let uri = request.uri().clone(); let request_id = request .headers() .get("x-request-id") .and_then(|v| v.to_str().ok()) .map(|s| s.to_string()) .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); let result = AssertUnwindSafe(next.run(request)).catch_unwind().await; match result { Ok(response) => response, Err(panic_info) => { let panic_message = extract_panic_message(&panic_info); if config.log_panics { error!( request_id = %request_id, method = %method, uri = %uri, panic_message = %panic_message, "Request handler panicked" ); if config.include_backtrace { let backtrace = std::backtrace::Backtrace::capture(); error!(backtrace = %backtrace, "Panic backtrace"); } } if config.notify_on_panic { notify_panic(&request_id, method.as_ref(), &uri.to_string(), &panic_message); } create_panic_response(&request_id, config) } } } fn extract_panic_message(panic_info: &Box) -> String { if let Some(s) = panic_info.downcast_ref::<&str>() { s.to_string() } else if let Some(s) = panic_info.downcast_ref::() { s.clone() } else { "Unknown panic".to_string() } } fn create_panic_response(request_id: &str, config: &PanicHandlerConfig) -> Response { let message = config .custom_message .as_deref() .unwrap_or("An internal error occurred"); let body = json!({ "error": "internal_server_error", "message": message, "request_id": request_id }); (StatusCode::INTERNAL_SERVER_ERROR, Json(body)).into_response() } fn notify_panic(request_id: &str, method: &str, uri: &str, message: &str) { warn!( request_id = %request_id, method = %method, uri = %uri, message = %message, "PANIC NOTIFICATION: Server panic occurred" ); } pub fn set_global_panic_hook() { std::panic::set_hook(Box::new(|panic_info| { let location = panic_info .location() .map(|l| format!("{}:{}:{}", l.file(), l.line(), l.column())) .unwrap_or_else(|| "unknown location".to_string()); let message = if let Some(s) = panic_info.payload().downcast_ref::<&str>() { s.to_string() } else if let Some(s) = panic_info.payload().downcast_ref::() { s.clone() } else { "Unknown panic payload".to_string() }; error!( location = %location, message = %message, "Global panic handler caught panic" ); })); } pub fn catch_panic(f: F) -> Result where F: FnOnce() -> R + std::panic::UnwindSafe, { catch_unwind(f).map_err(|e| PanicError { message: extract_panic_message(&e), }) } pub async fn catch_panic_async(f: F) -> Result where F: FnOnce() -> Fut, Fut: std::future::Future, { match AssertUnwindSafe(f()).catch_unwind().await { Ok(result) => Ok(result), Err(e) => Err(PanicError { message: extract_panic_message(&e), }), } } #[derive(Debug, Clone)] pub struct PanicError { pub message: String, } impl std::fmt::Display for PanicError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Panic: {}", self.message) } } impl std::error::Error for PanicError {} impl IntoResponse for PanicError { fn into_response(self) -> Response { let body = json!({ "error": "internal_server_error", "message": "An internal error occurred" }); (StatusCode::INTERNAL_SERVER_ERROR, Json(body)).into_response() } } pub struct PanicGuard { name: String, logged: bool, } impl PanicGuard { pub fn new(name: impl Into) -> Self { Self { name: name.into(), logged: false, } } pub fn mark_completed(&mut self) { self.logged = true; } } impl Drop for PanicGuard { fn drop(&mut self) { if !self.logged && std::thread::panicking() { error!( guard_name = %self.name, "PanicGuard detected panic during drop" ); } } } #[macro_export] macro_rules! with_panic_guard { ($name:expr, $body:expr) => {{ let mut guard = $crate::security::panic_handler::PanicGuard::new($name); let result = $body; guard.mark_completed(); result }}; } #[cfg(test)] mod tests { use super::*; #[test] fn test_default_config() { let config = PanicHandlerConfig::default(); assert!(config.log_panics); assert!(!config.notify_on_panic); } #[test] fn test_production_config() { let config = PanicHandlerConfig::production(); assert!(config.log_panics); assert!(!config.include_backtrace); assert!(config.notify_on_panic); assert!(config.custom_message.is_some()); } #[test] fn test_development_config() { let config = PanicHandlerConfig::development(); assert!(config.log_panics); assert!(config.include_backtrace); assert!(!config.notify_on_panic); } #[test] fn test_config_builder() { let config = PanicHandlerConfig::new() .with_message("Custom error") .with_backtrace(true) .with_logging(false) .with_notification(true); assert_eq!(config.custom_message, Some("Custom error".to_string())); assert!(config.include_backtrace); assert!(!config.log_panics); assert!(config.notify_on_panic); } #[test] fn test_extract_panic_message_str() { let panic: Box = Box::new("test panic"); let message = extract_panic_message(&panic); assert_eq!(message, "test panic"); } #[test] fn test_extract_panic_message_string() { let panic: Box = Box::new("string panic".to_string()); let message = extract_panic_message(&panic); assert_eq!(message, "string panic"); } #[test] fn test_extract_panic_message_unknown() { let panic: Box = Box::new(42i32); let message = extract_panic_message(&panic); assert_eq!(message, "Unknown panic"); } #[test] fn test_catch_panic_success() { let result = catch_panic(|| 42); assert_eq!(result.unwrap(), 42); } #[test] fn test_catch_panic_failure() { let result = catch_panic(|| { panic!("test panic"); #[allow(unreachable_code)] 42 }); assert!(result.is_err()); assert!(result.unwrap_err().message.contains("test panic")); } #[test] fn test_panic_error_display() { let error = PanicError { message: "test error".to_string(), }; assert_eq!(format!("{}", error), "Panic: test error"); } #[test] fn test_panic_guard_normal() { let mut guard = PanicGuard::new("test"); guard.mark_completed(); } #[tokio::test] async fn test_catch_panic_async_success() { let result = catch_panic_async(|| async { 42 }).await; assert_eq!(result.unwrap(), 42); } #[test] fn test_create_panic_response() { let config = PanicHandlerConfig::default(); let response = create_panic_response("test-id", &config); assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); } #[test] fn test_create_panic_response_custom_message() { let config = PanicHandlerConfig::new().with_message("Custom error message"); let response = create_panic_response("test-id", &config); assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); } }