From a50d229346167c539d3fbae2dcf26160df311e41 Mon Sep 17 00:00:00 2001 From: "Rodrigo Rodriguez (Pragmatismo)" Date: Sun, 28 Dec 2025 11:50:48 -0300 Subject: [PATCH] Add limits module and resilience improvements --- Cargo.toml | 4 +- src/lib.rs | 17 +- src/limits.rs | 476 ++++++++++++++++++++++++++++++++++++++++++++++ src/resilience.rs | 96 ++++++++-- 4 files changed, 576 insertions(+), 17 deletions(-) create mode 100644 src/limits.rs diff --git a/Cargo.toml b/Cargo.toml index c10caec..b2ec5c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,10 +11,11 @@ categories = ["api-bindings", "web-programming"] [features] default = [] -full = ["database", "http-client", "validation"] +full = ["database", "http-client", "validation", "resilience"] database = ["dep:diesel"] http-client = ["dep:reqwest"] validation = ["dep:validator"] +resilience = [] [dependencies] # Core @@ -26,6 +27,7 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" uuid = { version = "1.11", features = ["serde", "v4"] } toml = "0.8" +tokio = { version = "1.41", features = ["sync", "time"] } # Optional: Database diesel = { version = "2.1", features = ["postgres", "uuid", "chrono", "serde_json", "r2d2"], optional = true } diff --git a/src/lib.rs b/src/lib.rs index ef41f92..f7c1a70 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,18 +1,33 @@ - pub mod branding; pub mod error; #[cfg(feature = "http-client")] pub mod http_client; +pub mod limits; pub mod message_types; pub mod models; +pub mod resilience; pub mod version; pub use branding::{ branding, init_branding, is_white_label, platform_name, platform_short, BrandingConfig, }; pub use error::{BotError, BotResult}; +pub use limits::{ + check_array_length_limit, check_file_size_limit, check_loop_limit, check_recursion_limit, + check_string_length_limit, format_limit_error_response, LimitExceeded, LimitType, RateLimiter, + SystemLimits, MAX_API_CALLS_PER_HOUR, MAX_API_CALLS_PER_MINUTE, MAX_ARRAY_LENGTH, + MAX_BOTS_PER_TENANT, 
MAX_CONCURRENT_REQUESTS_GLOBAL, MAX_CONCURRENT_REQUESTS_PER_USER, + MAX_DB_CONNECTIONS_PER_TENANT, MAX_DB_QUERY_RESULTS, MAX_DRIVE_STORAGE_BYTES, + MAX_FILE_SIZE_BYTES, MAX_KB_DOCUMENTS_PER_BOT, MAX_KB_DOCUMENT_SIZE_BYTES, + MAX_LLM_REQUESTS_PER_MINUTE, MAX_LLM_TOKENS_PER_REQUEST, MAX_LOOP_ITERATIONS, + MAX_PENDING_TASKS, MAX_RECURSION_DEPTH, MAX_REQUEST_BODY_BYTES, MAX_SCRIPT_EXECUTION_SECONDS, + MAX_SESSIONS_PER_USER, MAX_SESSION_IDLE_SECONDS, MAX_STRING_LENGTH, MAX_TOOLS_PER_BOT, + MAX_UPLOAD_SIZE_BYTES, MAX_WEBSOCKET_CONNECTIONS_GLOBAL, MAX_WEBSOCKET_CONNECTIONS_PER_USER, + RATE_LIMIT_BURST_MULTIPLIER, RATE_LIMIT_WINDOW_SECONDS, +}; pub use message_types::MessageType; pub use models::{ApiResponse, BotResponse, Session, Suggestion, UserMessage}; +pub use resilience::{ResilienceError, RetryConfig}; pub use version::{ get_botserver_version, init_version_registry, register_component, version_string, ComponentSource, ComponentStatus, ComponentVersion, VersionRegistry, BOTSERVER_VERSION, diff --git a/src/limits.rs b/src/limits.rs new file mode 100644 index 0000000..176c1d1 --- /dev/null +++ b/src/limits.rs @@ -0,0 +1,476 @@ +use std::collections::HashMap; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::RwLock; + +pub const MAX_LOOP_ITERATIONS: u32 = 100_000; +pub const MAX_RECURSION_DEPTH: u32 = 100; +pub const MAX_FILE_SIZE_BYTES: u64 = 100 * 1024 * 1024; +pub const MAX_UPLOAD_SIZE_BYTES: u64 = 50 * 1024 * 1024; +pub const MAX_REQUEST_BODY_BYTES: u64 = 10 * 1024 * 1024; +pub const MAX_STRING_LENGTH: usize = 10 * 1024 * 1024; +pub const MAX_ARRAY_LENGTH: usize = 1_000_000; +pub const MAX_CONCURRENT_REQUESTS_PER_USER: u32 = 100; +pub const MAX_CONCURRENT_REQUESTS_GLOBAL: u32 = 10_000; +pub const MAX_WEBSOCKET_CONNECTIONS_PER_USER: u32 = 10; +pub const MAX_WEBSOCKET_CONNECTIONS_GLOBAL: u32 = 50_000; +pub const MAX_DB_QUERY_RESULTS: u32 = 10_000; +pub const MAX_DB_CONNECTIONS_PER_TENANT: u32 = 
20; +pub const MAX_LLM_TOKENS_PER_REQUEST: u32 = 128_000; +pub const MAX_LLM_REQUESTS_PER_MINUTE: u32 = 60; +pub const MAX_KB_DOCUMENTS_PER_BOT: u32 = 100_000; +pub const MAX_KB_DOCUMENT_SIZE_BYTES: u64 = 50 * 1024 * 1024; +pub const MAX_SCRIPT_EXECUTION_SECONDS: u64 = 300; +pub const MAX_API_CALLS_PER_MINUTE: u32 = 1000; +pub const MAX_API_CALLS_PER_HOUR: u32 = 10_000; +pub const MAX_DRIVE_STORAGE_BYTES: u64 = 10 * 1024 * 1024 * 1024; +pub const MAX_SESSION_IDLE_SECONDS: u64 = 3600; +pub const MAX_SESSIONS_PER_USER: u32 = 10; +pub const MAX_BOTS_PER_TENANT: u32 = 100; +pub const MAX_TOOLS_PER_BOT: u32 = 500; +pub const MAX_PENDING_TASKS: u32 = 1000; +pub const RATE_LIMIT_WINDOW_SECONDS: u64 = 60; +pub const RATE_LIMIT_BURST_MULTIPLIER: f64 = 1.5; + +#[derive(Debug, Clone)] +pub struct SystemLimits { + pub max_loop_iterations: u32, + pub max_recursion_depth: u32, + pub max_file_size_bytes: u64, + pub max_upload_size_bytes: u64, + pub max_request_body_bytes: u64, + pub max_string_length: usize, + pub max_array_length: usize, + pub max_concurrent_requests_per_user: u32, + pub max_concurrent_requests_global: u32, + pub max_websocket_connections_per_user: u32, + pub max_websocket_connections_global: u32, + pub max_db_query_results: u32, + pub max_db_connections_per_tenant: u32, + pub max_llm_tokens_per_request: u32, + pub max_llm_requests_per_minute: u32, + pub max_kb_documents_per_bot: u32, + pub max_kb_document_size_bytes: u64, + pub max_script_execution_seconds: u64, + pub max_api_calls_per_minute: u32, + pub max_api_calls_per_hour: u32, + pub max_drive_storage_bytes: u64, + pub max_session_idle_seconds: u64, + pub max_sessions_per_user: u32, + pub max_bots_per_tenant: u32, + pub max_tools_per_bot: u32, + pub max_pending_tasks: u32, + pub rate_limit_window_seconds: u64, + pub rate_limit_burst_multiplier: f64, +} + +impl Default for SystemLimits { + fn default() -> Self { + Self { + max_loop_iterations: MAX_LOOP_ITERATIONS, + max_recursion_depth: MAX_RECURSION_DEPTH, 
+ max_file_size_bytes: MAX_FILE_SIZE_BYTES, + max_upload_size_bytes: MAX_UPLOAD_SIZE_BYTES, + max_request_body_bytes: MAX_REQUEST_BODY_BYTES, + max_string_length: MAX_STRING_LENGTH, + max_array_length: MAX_ARRAY_LENGTH, + max_concurrent_requests_per_user: MAX_CONCURRENT_REQUESTS_PER_USER, + max_concurrent_requests_global: MAX_CONCURRENT_REQUESTS_GLOBAL, + max_websocket_connections_per_user: MAX_WEBSOCKET_CONNECTIONS_PER_USER, + max_websocket_connections_global: MAX_WEBSOCKET_CONNECTIONS_GLOBAL, + max_db_query_results: MAX_DB_QUERY_RESULTS, + max_db_connections_per_tenant: MAX_DB_CONNECTIONS_PER_TENANT, + max_llm_tokens_per_request: MAX_LLM_TOKENS_PER_REQUEST, + max_llm_requests_per_minute: MAX_LLM_REQUESTS_PER_MINUTE, + max_kb_documents_per_bot: MAX_KB_DOCUMENTS_PER_BOT, + max_kb_document_size_bytes: MAX_KB_DOCUMENT_SIZE_BYTES, + max_script_execution_seconds: MAX_SCRIPT_EXECUTION_SECONDS, + max_api_calls_per_minute: MAX_API_CALLS_PER_MINUTE, + max_api_calls_per_hour: MAX_API_CALLS_PER_HOUR, + max_drive_storage_bytes: MAX_DRIVE_STORAGE_BYTES, + max_session_idle_seconds: MAX_SESSION_IDLE_SECONDS, + max_sessions_per_user: MAX_SESSIONS_PER_USER, + max_bots_per_tenant: MAX_BOTS_PER_TENANT, + max_tools_per_bot: MAX_TOOLS_PER_BOT, + max_pending_tasks: MAX_PENDING_TASKS, + rate_limit_window_seconds: RATE_LIMIT_WINDOW_SECONDS, + rate_limit_burst_multiplier: RATE_LIMIT_BURST_MULTIPLIER, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum LimitType { + LoopIterations, + RecursionDepth, + FileSize, + UploadSize, + RequestBody, + StringLength, + ArrayLength, + ConcurrentRequests, + WebsocketConnections, + DbQueryResults, + DbConnections, + LlmTokens, + LlmRequests, + KbDocuments, + KbDocumentSize, + ScriptExecution, + ApiCallsMinute, + ApiCallsHour, + DriveStorage, + SessionIdle, + SessionsPerUser, + BotsPerTenant, + ToolsPerBot, + PendingTasks, +} + +impl std::fmt::Display for LimitType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + 
match self { + Self::LoopIterations => write!(f, "loop_iterations"), + Self::RecursionDepth => write!(f, "recursion_depth"), + Self::FileSize => write!(f, "file_size"), + Self::UploadSize => write!(f, "upload_size"), + Self::RequestBody => write!(f, "request_body"), + Self::StringLength => write!(f, "string_length"), + Self::ArrayLength => write!(f, "array_length"), + Self::ConcurrentRequests => write!(f, "concurrent_requests"), + Self::WebsocketConnections => write!(f, "websocket_connections"), + Self::DbQueryResults => write!(f, "db_query_results"), + Self::DbConnections => write!(f, "db_connections"), + Self::LlmTokens => write!(f, "llm_tokens"), + Self::LlmRequests => write!(f, "llm_requests"), + Self::KbDocuments => write!(f, "kb_documents"), + Self::KbDocumentSize => write!(f, "kb_document_size"), + Self::ScriptExecution => write!(f, "script_execution"), + Self::ApiCallsMinute => write!(f, "api_calls_minute"), + Self::ApiCallsHour => write!(f, "api_calls_hour"), + Self::DriveStorage => write!(f, "drive_storage"), + Self::SessionIdle => write!(f, "session_idle"), + Self::SessionsPerUser => write!(f, "sessions_per_user"), + Self::BotsPerTenant => write!(f, "bots_per_tenant"), + Self::ToolsPerBot => write!(f, "tools_per_bot"), + Self::PendingTasks => write!(f, "pending_tasks"), + } + } +} + +#[derive(Debug)] +pub struct LimitExceeded { + pub limit_type: LimitType, + pub current: u64, + pub maximum: u64, + pub retry_after_secs: Option<u64>, +} + +impl std::fmt::Display for LimitExceeded { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Limit exceeded for {}: {} > {} (max)", + self.limit_type, self.current, self.maximum + ) + } +} + +impl std::error::Error for LimitExceeded {} + +#[derive(Debug)] +struct RateLimitEntry { + count: AtomicU64, + window_start: RwLock<Instant>, +} + +impl RateLimitEntry { + fn new() -> Self { + Self { + count: AtomicU64::new(0), + window_start: RwLock::new(Instant::now()), + } + } +} + +#[derive(Debug)] +pub 
struct RateLimiter { + limits: SystemLimits, + per_user_minute: RwLock<HashMap<String, Arc<RateLimitEntry>>>, + per_user_hour: RwLock<HashMap<String, Arc<RateLimitEntry>>>, + global_minute: Arc<RateLimitEntry>, + global_hour: Arc<RateLimitEntry>, +} + +impl Default for RateLimiter { + fn default() -> Self { + Self::new(SystemLimits::default()) + } +} + +impl RateLimiter { + pub fn new(limits: SystemLimits) -> Self { + Self { + limits, + per_user_minute: RwLock::new(HashMap::new()), + per_user_hour: RwLock::new(HashMap::new()), + global_minute: Arc::new(RateLimitEntry::new()), + global_hour: Arc::new(RateLimitEntry::new()), + } + } + + pub async fn check_rate_limit(&self, user_id: &str) -> Result<(), LimitExceeded> { + self.check_global_limits().await?; + self.check_user_limits(user_id).await + } + + async fn check_global_limits(&self) -> Result<(), LimitExceeded> { + let now = Instant::now(); + + { + let window_start = self.global_minute.window_start.read().await; + if now.duration_since(*window_start) > Duration::from_secs(60) { + drop(window_start); + let mut window_start = self.global_minute.window_start.write().await; + *window_start = now; + self.global_minute.count.store(0, Ordering::SeqCst); + } + } + + let count = self.global_minute.count.fetch_add(1, Ordering::SeqCst) + 1; + let max = u64::from(self.limits.max_api_calls_per_minute) * 100; + + if count > max { + self.global_minute.count.fetch_sub(1, Ordering::SeqCst); + return Err(LimitExceeded { + limit_type: LimitType::ApiCallsMinute, + current: count, + maximum: max, + retry_after_secs: Some(60), + }); + } + + { + let window_start = self.global_hour.window_start.read().await; + if now.duration_since(*window_start) > Duration::from_secs(3600) { + drop(window_start); + let mut window_start = self.global_hour.window_start.write().await; + *window_start = now; + self.global_hour.count.store(0, Ordering::SeqCst); + } + } + + let hour_count = self.global_hour.count.fetch_add(1, Ordering::SeqCst) + 1; + let hour_max = u64::from(self.limits.max_api_calls_per_hour) * 100; + + if hour_count > hour_max { 
self.global_hour.count.fetch_sub(1, Ordering::SeqCst); + return Err(LimitExceeded { + limit_type: LimitType::ApiCallsHour, + current: hour_count, + maximum: hour_max, + retry_after_secs: Some(3600), + }); + } + + Ok(()) + } + + async fn check_user_limits(&self, user_id: &str) -> Result<(), LimitExceeded> { + self.check_user_minute_limit(user_id).await?; + self.check_user_hour_limit(user_id).await + } + + async fn check_user_minute_limit(&self, user_id: &str) -> Result<(), LimitExceeded> { + let entry = { + let map = self.per_user_minute.read().await; + map.get(user_id).cloned() + }; + + let entry = match entry { + Some(e) => e, + None => { + let new_entry = Arc::new(RateLimitEntry::new()); + let mut map = self.per_user_minute.write().await; + map.insert(user_id.to_string(), Arc::clone(&new_entry)); + new_entry + } + }; + + let now = Instant::now(); + { + let window_start = entry.window_start.read().await; + if now.duration_since(*window_start) > Duration::from_secs(60) { + drop(window_start); + let mut window_start = entry.window_start.write().await; + *window_start = now; + entry.count.store(0, Ordering::SeqCst); + } + } + + let count = entry.count.fetch_add(1, Ordering::SeqCst) + 1; + let max = u64::from(self.limits.max_api_calls_per_minute); + + if count > max { + entry.count.fetch_sub(1, Ordering::SeqCst); + return Err(LimitExceeded { + limit_type: LimitType::ApiCallsMinute, + current: count, + maximum: max, + retry_after_secs: Some(60), + }); + } + + Ok(()) + } + + async fn check_user_hour_limit(&self, user_id: &str) -> Result<(), LimitExceeded> { + let entry = { + let map = self.per_user_hour.read().await; + map.get(user_id).cloned() + }; + + let entry = match entry { + Some(e) => e, + None => { + let new_entry = Arc::new(RateLimitEntry::new()); + let mut map = self.per_user_hour.write().await; + map.insert(user_id.to_string(), Arc::clone(&new_entry)); + new_entry + } + }; + + let now = Instant::now(); + { + let window_start = entry.window_start.read().await; 
+ if now.duration_since(*window_start) > Duration::from_secs(3600) { + drop(window_start); + let mut window_start = entry.window_start.write().await; + *window_start = now; + entry.count.store(0, Ordering::SeqCst); + } + } + + let count = entry.count.fetch_add(1, Ordering::SeqCst) + 1; + let max = u64::from(self.limits.max_api_calls_per_hour); + + if count > max { + entry.count.fetch_sub(1, Ordering::SeqCst); + return Err(LimitExceeded { + limit_type: LimitType::ApiCallsHour, + current: count, + maximum: max, + retry_after_secs: Some(3600), + }); + } + + Ok(()) + } + + pub async fn cleanup_stale_entries(&self) { + let now = Instant::now(); + let stale_threshold = Duration::from_secs(7200); + + { + let mut map = self.per_user_minute.write().await; + let mut to_remove = Vec::new(); + for (user_id, entry) in map.iter() { + let window_start = entry.window_start.read().await; + if now.duration_since(*window_start) > stale_threshold { + to_remove.push(user_id.clone()); + } + } + for user_id in to_remove { + map.remove(&user_id); + } + } + + { + let mut map = self.per_user_hour.write().await; + let mut to_remove = Vec::new(); + for (user_id, entry) in map.iter() { + let window_start = entry.window_start.read().await; + if now.duration_since(*window_start) > stale_threshold { + to_remove.push(user_id.clone()); + } + } + for user_id in to_remove { + map.remove(&user_id); + } + } + } +} + +pub fn check_loop_limit(iterations: u32, max: u32) -> Result<(), LimitExceeded> { + if iterations >= max { + return Err(LimitExceeded { + limit_type: LimitType::LoopIterations, + current: u64::from(iterations), + maximum: u64::from(max), + retry_after_secs: None, + }); + } + Ok(()) +} + +pub fn check_recursion_limit(depth: u32, max: u32) -> Result<(), LimitExceeded> { + if depth >= max { + return Err(LimitExceeded { + limit_type: LimitType::RecursionDepth, + current: u64::from(depth), + maximum: u64::from(max), + retry_after_secs: None, + }); + } + Ok(()) +} + +pub fn 
check_file_size_limit(size: u64, max: u64) -> Result<(), LimitExceeded> { + if size > max { + return Err(LimitExceeded { + limit_type: LimitType::FileSize, + current: size, + maximum: max, + retry_after_secs: None, + }); + } + Ok(()) +} + +pub fn check_string_length_limit(length: usize, max: usize) -> Result<(), LimitExceeded> { + if length > max { + return Err(LimitExceeded { + limit_type: LimitType::StringLength, + current: length as u64, + maximum: max as u64, + retry_after_secs: None, + }); + } + Ok(()) +} + +pub fn check_array_length_limit(length: usize, max: usize) -> Result<(), LimitExceeded> { + if length > max { + return Err(LimitExceeded { + limit_type: LimitType::ArrayLength, + current: length as u64, + maximum: max as u64, + retry_after_secs: None, + }); + } + Ok(()) +} + +pub fn format_limit_error_response(error: &LimitExceeded) -> (u16, String) { + let status = 429; + let body = serde_json::json!({ + "error": "rate_limit_exceeded", + "message": error.to_string(), + "limit_type": error.limit_type.to_string(), + "current": error.current, + "maximum": error.maximum, + "retry_after_secs": error.retry_after_secs, + }); + (status, body.to_string()) +} diff --git a/src/resilience.rs b/src/resilience.rs index 6fec0ef..dd83659 100644 --- a/src/resilience.rs +++ b/src/resilience.rs @@ -1,19 +1,15 @@ - use std::future::Future; -use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; use std::sync::Arc; -use std::time::{Duration, Instant}; -use tokio::sync::{RwLock, Semaphore, SemaphorePermit}; +use std::time::Duration; use tokio::time::{sleep, timeout}; +pub type RetryPredicate = Arc<dyn Fn(&str) -> bool + Send + Sync>; + #[derive(Debug, Clone)] pub enum ResilienceError { Timeout { duration: Duration }, CircuitOpen { until: Option<Instant> }, - RetriesExhausted { - attempts: u32, - last_error: String, - }, + RetriesExhausted { attempts: u32, last_error: String }, BulkheadFull { max_concurrent: usize }, Operation(String), } @@ -55,15 +51,27 @@ impl std::fmt::Display for ResilienceError 
{ impl std::error::Error for ResilienceError {} - -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct RetryConfig { pub max_attempts: u32, pub initial_delay: Duration, pub max_delay: Duration, pub backoff_multiplier: f64, pub jitter_factor: f64, - retryable: Option<Arc<dyn Fn(&str) -> bool + Send + Sync>>, + retryable: Option<RetryPredicate>, +} + +impl std::fmt::Debug for RetryConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RetryConfig") + .field("max_attempts", &self.max_attempts) + .field("initial_delay", &self.initial_delay) + .field("max_delay", &self.max_delay) + .field("backoff_multiplier", &self.backoff_multiplier) + .field("jitter_factor", &self.jitter_factor) + .field("retryable", &self.retryable.is_some()) + .finish() + } } impl Default for RetryConfig { @@ -80,31 +88,37 @@ impl Default for RetryConfig { } impl RetryConfig { + /// Create a new retry config with custom max attempts pub fn with_max_attempts(mut self, attempts: u32) -> Self { self.max_attempts = attempts.max(1); self } + /// Set initial delay pub fn with_initial_delay(mut self, delay: Duration) -> Self { self.initial_delay = delay; self } + /// Set maximum delay cap pub fn with_max_delay(mut self, delay: Duration) -> Self { self.max_delay = delay; self } + /// Set backoff multiplier pub fn with_backoff_multiplier(mut self, multiplier: f64) -> Self { self.backoff_multiplier = multiplier.max(1.0); self } + /// Set jitter factor (0.0 to 1.0) pub fn with_jitter(mut self, jitter: f64) -> Self { self.jitter_factor = jitter.clamp(0.0, 1.0); self } + /// Set custom retryable predicate pub fn with_retryable<F>(mut self, predicate: F) -> Self where F: Fn(&str) -> bool + Send + Sync + 'static, @@ -113,6 +127,7 @@ impl RetryConfig { self } + /// Aggressive retry for critical operations pub fn aggressive() -> Self { Self { max_attempts: 5, @@ -124,6 +139,7 @@ impl RetryConfig { } } + /// Conservative retry for non-critical operations pub fn conservative() -> Self { Self { max_attempts: 2, 
@@ -136,15 +152,15 @@ impl RetryConfig { } fn calculate_delay(&self, attempt: u32) -> Duration { - let base_delay = self.initial_delay.as_secs_f64() - * self.backoff_multiplier.powi(attempt.saturating_sub(1) as i32); + let exponent = i32::try_from(attempt.saturating_sub(1)).unwrap_or(0); + let base_delay = self.backoff_multiplier.powi(exponent) * self.initial_delay.as_secs_f64(); let capped_delay = base_delay.min(self.max_delay.as_secs_f64()); let jitter = if self.jitter_factor > 0.0 { let jitter_range = capped_delay * self.jitter_factor; - let pseudo_random = ((attempt as f64 * 1.618033988749895) % 1.0) * 2.0 - 1.0; - jitter_range * pseudo_random + let pseudo_random = (f64::from(attempt) * 1.618_033_988_749_895) % 1.0; + (2.0_f64).mul_add(pseudo_random, -1.0) * jitter_range } else { 0.0 }; @@ -156,3 +172,53 @@ impl RetryConfig { if let Some(ref predicate) = self.retryable { predicate(error) } else { + error.contains("timeout") + || error.contains("connection") + || error.contains("temporarily") + || error.contains("503") + || error.contains("429") + } + } +} + +pub async fn retry<T, F, Fut>(config: &RetryConfig, mut operation: F) -> Result<T, ResilienceError> +where + F: FnMut() -> Fut, + Fut: Future<Output = Result<T, String>>, +{ + let mut last_error = String::new(); + + for attempt in 1..=config.max_attempts { + match operation().await { + Ok(result) => return Ok(result), + Err(e) => { + if attempt == config.max_attempts { + last_error = e; + break; + } + + if !config.is_retryable(&e) { + return Err(ResilienceError::Operation(e)); + } + + last_error = e; + let delay = config.calculate_delay(attempt); + sleep(delay).await; + } + } + } + + Err(ResilienceError::RetriesExhausted { + attempts: config.max_attempts, + last_error, + }) +} + +pub async fn with_timeout<T, F>(duration: Duration, future: F) -> Result<T, ResilienceError> +where + F: Future<Output = T>, +{ + timeout(duration, future) + .await + .map_err(|_| ResilienceError::Timeout { duration }) +}