Add limits module and resilience improvements
This commit is contained in:
parent
94f333f983
commit
a50d229346
4 changed files with 576 additions and 17 deletions
|
|
@ -11,10 +11,11 @@ categories = ["api-bindings", "web-programming"]
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = []
|
default = []
|
||||||
full = ["database", "http-client", "validation"]
|
full = ["database", "http-client", "validation", "resilience"]
|
||||||
database = ["dep:diesel"]
|
database = ["dep:diesel"]
|
||||||
http-client = ["dep:reqwest"]
|
http-client = ["dep:reqwest"]
|
||||||
validation = ["dep:validator"]
|
validation = ["dep:validator"]
|
||||||
|
resilience = []
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
# Core
|
# Core
|
||||||
|
|
@ -26,6 +27,7 @@ serde = { version = "1.0", features = ["derive"] }
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
uuid = { version = "1.11", features = ["serde", "v4"] }
|
uuid = { version = "1.11", features = ["serde", "v4"] }
|
||||||
toml = "0.8"
|
toml = "0.8"
|
||||||
|
tokio = { version = "1.41", features = ["sync", "time"] }
|
||||||
|
|
||||||
# Optional: Database
|
# Optional: Database
|
||||||
diesel = { version = "2.1", features = ["postgres", "uuid", "chrono", "serde_json", "r2d2"], optional = true }
|
diesel = { version = "2.1", features = ["postgres", "uuid", "chrono", "serde_json", "r2d2"], optional = true }
|
||||||
|
|
|
||||||
17
src/lib.rs
17
src/lib.rs
|
|
@ -1,18 +1,33 @@
|
||||||
|
|
||||||
pub mod branding;
|
pub mod branding;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
#[cfg(feature = "http-client")]
|
#[cfg(feature = "http-client")]
|
||||||
pub mod http_client;
|
pub mod http_client;
|
||||||
|
pub mod limits;
|
||||||
pub mod message_types;
|
pub mod message_types;
|
||||||
pub mod models;
|
pub mod models;
|
||||||
|
pub mod resilience;
|
||||||
pub mod version;
|
pub mod version;
|
||||||
|
|
||||||
pub use branding::{
|
pub use branding::{
|
||||||
branding, init_branding, is_white_label, platform_name, platform_short, BrandingConfig,
|
branding, init_branding, is_white_label, platform_name, platform_short, BrandingConfig,
|
||||||
};
|
};
|
||||||
pub use error::{BotError, BotResult};
|
pub use error::{BotError, BotResult};
|
||||||
|
pub use limits::{
|
||||||
|
check_array_length_limit, check_file_size_limit, check_loop_limit, check_recursion_limit,
|
||||||
|
check_string_length_limit, format_limit_error_response, LimitExceeded, LimitType, RateLimiter,
|
||||||
|
SystemLimits, MAX_API_CALLS_PER_HOUR, MAX_API_CALLS_PER_MINUTE, MAX_ARRAY_LENGTH,
|
||||||
|
MAX_BOTS_PER_TENANT, MAX_CONCURRENT_REQUESTS_GLOBAL, MAX_CONCURRENT_REQUESTS_PER_USER,
|
||||||
|
MAX_DB_CONNECTIONS_PER_TENANT, MAX_DB_QUERY_RESULTS, MAX_DRIVE_STORAGE_BYTES,
|
||||||
|
MAX_FILE_SIZE_BYTES, MAX_KB_DOCUMENTS_PER_BOT, MAX_KB_DOCUMENT_SIZE_BYTES,
|
||||||
|
MAX_LLM_REQUESTS_PER_MINUTE, MAX_LLM_TOKENS_PER_REQUEST, MAX_LOOP_ITERATIONS,
|
||||||
|
MAX_PENDING_TASKS, MAX_RECURSION_DEPTH, MAX_REQUEST_BODY_BYTES, MAX_SCRIPT_EXECUTION_SECONDS,
|
||||||
|
MAX_SESSIONS_PER_USER, MAX_SESSION_IDLE_SECONDS, MAX_STRING_LENGTH, MAX_TOOLS_PER_BOT,
|
||||||
|
MAX_UPLOAD_SIZE_BYTES, MAX_WEBSOCKET_CONNECTIONS_GLOBAL, MAX_WEBSOCKET_CONNECTIONS_PER_USER,
|
||||||
|
RATE_LIMIT_BURST_MULTIPLIER, RATE_LIMIT_WINDOW_SECONDS,
|
||||||
|
};
|
||||||
pub use message_types::MessageType;
|
pub use message_types::MessageType;
|
||||||
pub use models::{ApiResponse, BotResponse, Session, Suggestion, UserMessage};
|
pub use models::{ApiResponse, BotResponse, Session, Suggestion, UserMessage};
|
||||||
|
pub use resilience::{ResilienceError, RetryConfig};
|
||||||
pub use version::{
|
pub use version::{
|
||||||
get_botserver_version, init_version_registry, register_component, version_string,
|
get_botserver_version, init_version_registry, register_component, version_string,
|
||||||
ComponentSource, ComponentStatus, ComponentVersion, VersionRegistry, BOTSERVER_VERSION,
|
ComponentSource, ComponentStatus, ComponentVersion, VersionRegistry, BOTSERVER_VERSION,
|
||||||
|
|
|
||||||
476
src/limits.rs
Normal file
476
src/limits.rs
Normal file
|
|
@ -0,0 +1,476 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
|
pub const MAX_LOOP_ITERATIONS: u32 = 100_000;
|
||||||
|
pub const MAX_RECURSION_DEPTH: u32 = 100;
|
||||||
|
pub const MAX_FILE_SIZE_BYTES: u64 = 100 * 1024 * 1024;
|
||||||
|
pub const MAX_UPLOAD_SIZE_BYTES: u64 = 50 * 1024 * 1024;
|
||||||
|
pub const MAX_REQUEST_BODY_BYTES: u64 = 10 * 1024 * 1024;
|
||||||
|
pub const MAX_STRING_LENGTH: usize = 10 * 1024 * 1024;
|
||||||
|
pub const MAX_ARRAY_LENGTH: usize = 1_000_000;
|
||||||
|
pub const MAX_CONCURRENT_REQUESTS_PER_USER: u32 = 100;
|
||||||
|
pub const MAX_CONCURRENT_REQUESTS_GLOBAL: u32 = 10_000;
|
||||||
|
pub const MAX_WEBSOCKET_CONNECTIONS_PER_USER: u32 = 10;
|
||||||
|
pub const MAX_WEBSOCKET_CONNECTIONS_GLOBAL: u32 = 50_000;
|
||||||
|
pub const MAX_DB_QUERY_RESULTS: u32 = 10_000;
|
||||||
|
pub const MAX_DB_CONNECTIONS_PER_TENANT: u32 = 20;
|
||||||
|
pub const MAX_LLM_TOKENS_PER_REQUEST: u32 = 128_000;
|
||||||
|
pub const MAX_LLM_REQUESTS_PER_MINUTE: u32 = 60;
|
||||||
|
pub const MAX_KB_DOCUMENTS_PER_BOT: u32 = 100_000;
|
||||||
|
pub const MAX_KB_DOCUMENT_SIZE_BYTES: u64 = 50 * 1024 * 1024;
|
||||||
|
pub const MAX_SCRIPT_EXECUTION_SECONDS: u64 = 300;
|
||||||
|
pub const MAX_API_CALLS_PER_MINUTE: u32 = 1000;
|
||||||
|
pub const MAX_API_CALLS_PER_HOUR: u32 = 10_000;
|
||||||
|
pub const MAX_DRIVE_STORAGE_BYTES: u64 = 10 * 1024 * 1024 * 1024;
|
||||||
|
pub const MAX_SESSION_IDLE_SECONDS: u64 = 3600;
|
||||||
|
pub const MAX_SESSIONS_PER_USER: u32 = 10;
|
||||||
|
pub const MAX_BOTS_PER_TENANT: u32 = 100;
|
||||||
|
pub const MAX_TOOLS_PER_BOT: u32 = 500;
|
||||||
|
pub const MAX_PENDING_TASKS: u32 = 1000;
|
||||||
|
pub const RATE_LIMIT_WINDOW_SECONDS: u64 = 60;
|
||||||
|
pub const RATE_LIMIT_BURST_MULTIPLIER: f64 = 1.5;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct SystemLimits {
|
||||||
|
pub max_loop_iterations: u32,
|
||||||
|
pub max_recursion_depth: u32,
|
||||||
|
pub max_file_size_bytes: u64,
|
||||||
|
pub max_upload_size_bytes: u64,
|
||||||
|
pub max_request_body_bytes: u64,
|
||||||
|
pub max_string_length: usize,
|
||||||
|
pub max_array_length: usize,
|
||||||
|
pub max_concurrent_requests_per_user: u32,
|
||||||
|
pub max_concurrent_requests_global: u32,
|
||||||
|
pub max_websocket_connections_per_user: u32,
|
||||||
|
pub max_websocket_connections_global: u32,
|
||||||
|
pub max_db_query_results: u32,
|
||||||
|
pub max_db_connections_per_tenant: u32,
|
||||||
|
pub max_llm_tokens_per_request: u32,
|
||||||
|
pub max_llm_requests_per_minute: u32,
|
||||||
|
pub max_kb_documents_per_bot: u32,
|
||||||
|
pub max_kb_document_size_bytes: u64,
|
||||||
|
pub max_script_execution_seconds: u64,
|
||||||
|
pub max_api_calls_per_minute: u32,
|
||||||
|
pub max_api_calls_per_hour: u32,
|
||||||
|
pub max_drive_storage_bytes: u64,
|
||||||
|
pub max_session_idle_seconds: u64,
|
||||||
|
pub max_sessions_per_user: u32,
|
||||||
|
pub max_bots_per_tenant: u32,
|
||||||
|
pub max_tools_per_bot: u32,
|
||||||
|
pub max_pending_tasks: u32,
|
||||||
|
pub rate_limit_window_seconds: u64,
|
||||||
|
pub rate_limit_burst_multiplier: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for SystemLimits {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
max_loop_iterations: MAX_LOOP_ITERATIONS,
|
||||||
|
max_recursion_depth: MAX_RECURSION_DEPTH,
|
||||||
|
max_file_size_bytes: MAX_FILE_SIZE_BYTES,
|
||||||
|
max_upload_size_bytes: MAX_UPLOAD_SIZE_BYTES,
|
||||||
|
max_request_body_bytes: MAX_REQUEST_BODY_BYTES,
|
||||||
|
max_string_length: MAX_STRING_LENGTH,
|
||||||
|
max_array_length: MAX_ARRAY_LENGTH,
|
||||||
|
max_concurrent_requests_per_user: MAX_CONCURRENT_REQUESTS_PER_USER,
|
||||||
|
max_concurrent_requests_global: MAX_CONCURRENT_REQUESTS_GLOBAL,
|
||||||
|
max_websocket_connections_per_user: MAX_WEBSOCKET_CONNECTIONS_PER_USER,
|
||||||
|
max_websocket_connections_global: MAX_WEBSOCKET_CONNECTIONS_GLOBAL,
|
||||||
|
max_db_query_results: MAX_DB_QUERY_RESULTS,
|
||||||
|
max_db_connections_per_tenant: MAX_DB_CONNECTIONS_PER_TENANT,
|
||||||
|
max_llm_tokens_per_request: MAX_LLM_TOKENS_PER_REQUEST,
|
||||||
|
max_llm_requests_per_minute: MAX_LLM_REQUESTS_PER_MINUTE,
|
||||||
|
max_kb_documents_per_bot: MAX_KB_DOCUMENTS_PER_BOT,
|
||||||
|
max_kb_document_size_bytes: MAX_KB_DOCUMENT_SIZE_BYTES,
|
||||||
|
max_script_execution_seconds: MAX_SCRIPT_EXECUTION_SECONDS,
|
||||||
|
max_api_calls_per_minute: MAX_API_CALLS_PER_MINUTE,
|
||||||
|
max_api_calls_per_hour: MAX_API_CALLS_PER_HOUR,
|
||||||
|
max_drive_storage_bytes: MAX_DRIVE_STORAGE_BYTES,
|
||||||
|
max_session_idle_seconds: MAX_SESSION_IDLE_SECONDS,
|
||||||
|
max_sessions_per_user: MAX_SESSIONS_PER_USER,
|
||||||
|
max_bots_per_tenant: MAX_BOTS_PER_TENANT,
|
||||||
|
max_tools_per_bot: MAX_TOOLS_PER_BOT,
|
||||||
|
max_pending_tasks: MAX_PENDING_TASKS,
|
||||||
|
rate_limit_window_seconds: RATE_LIMIT_WINDOW_SECONDS,
|
||||||
|
rate_limit_burst_multiplier: RATE_LIMIT_BURST_MULTIPLIER,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum LimitType {
|
||||||
|
LoopIterations,
|
||||||
|
RecursionDepth,
|
||||||
|
FileSize,
|
||||||
|
UploadSize,
|
||||||
|
RequestBody,
|
||||||
|
StringLength,
|
||||||
|
ArrayLength,
|
||||||
|
ConcurrentRequests,
|
||||||
|
WebsocketConnections,
|
||||||
|
DbQueryResults,
|
||||||
|
DbConnections,
|
||||||
|
LlmTokens,
|
||||||
|
LlmRequests,
|
||||||
|
KbDocuments,
|
||||||
|
KbDocumentSize,
|
||||||
|
ScriptExecution,
|
||||||
|
ApiCallsMinute,
|
||||||
|
ApiCallsHour,
|
||||||
|
DriveStorage,
|
||||||
|
SessionIdle,
|
||||||
|
SessionsPerUser,
|
||||||
|
BotsPerTenant,
|
||||||
|
ToolsPerBot,
|
||||||
|
PendingTasks,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for LimitType {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
Self::LoopIterations => write!(f, "loop_iterations"),
|
||||||
|
Self::RecursionDepth => write!(f, "recursion_depth"),
|
||||||
|
Self::FileSize => write!(f, "file_size"),
|
||||||
|
Self::UploadSize => write!(f, "upload_size"),
|
||||||
|
Self::RequestBody => write!(f, "request_body"),
|
||||||
|
Self::StringLength => write!(f, "string_length"),
|
||||||
|
Self::ArrayLength => write!(f, "array_length"),
|
||||||
|
Self::ConcurrentRequests => write!(f, "concurrent_requests"),
|
||||||
|
Self::WebsocketConnections => write!(f, "websocket_connections"),
|
||||||
|
Self::DbQueryResults => write!(f, "db_query_results"),
|
||||||
|
Self::DbConnections => write!(f, "db_connections"),
|
||||||
|
Self::LlmTokens => write!(f, "llm_tokens"),
|
||||||
|
Self::LlmRequests => write!(f, "llm_requests"),
|
||||||
|
Self::KbDocuments => write!(f, "kb_documents"),
|
||||||
|
Self::KbDocumentSize => write!(f, "kb_document_size"),
|
||||||
|
Self::ScriptExecution => write!(f, "script_execution"),
|
||||||
|
Self::ApiCallsMinute => write!(f, "api_calls_minute"),
|
||||||
|
Self::ApiCallsHour => write!(f, "api_calls_hour"),
|
||||||
|
Self::DriveStorage => write!(f, "drive_storage"),
|
||||||
|
Self::SessionIdle => write!(f, "session_idle"),
|
||||||
|
Self::SessionsPerUser => write!(f, "sessions_per_user"),
|
||||||
|
Self::BotsPerTenant => write!(f, "bots_per_tenant"),
|
||||||
|
Self::ToolsPerBot => write!(f, "tools_per_bot"),
|
||||||
|
Self::PendingTasks => write!(f, "pending_tasks"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct LimitExceeded {
|
||||||
|
pub limit_type: LimitType,
|
||||||
|
pub current: u64,
|
||||||
|
pub maximum: u64,
|
||||||
|
pub retry_after_secs: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for LimitExceeded {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"Limit exceeded for {}: {} > {} (max)",
|
||||||
|
self.limit_type, self.current, self.maximum
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for LimitExceeded {}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct RateLimitEntry {
|
||||||
|
count: AtomicU64,
|
||||||
|
window_start: RwLock<Instant>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RateLimitEntry {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
count: AtomicU64::new(0),
|
||||||
|
window_start: RwLock::new(Instant::now()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct RateLimiter {
|
||||||
|
limits: SystemLimits,
|
||||||
|
per_user_minute: RwLock<HashMap<String, Arc<RateLimitEntry>>>,
|
||||||
|
per_user_hour: RwLock<HashMap<String, Arc<RateLimitEntry>>>,
|
||||||
|
global_minute: Arc<RateLimitEntry>,
|
||||||
|
global_hour: Arc<RateLimitEntry>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for RateLimiter {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new(SystemLimits::default())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RateLimiter {
|
||||||
|
pub fn new(limits: SystemLimits) -> Self {
|
||||||
|
Self {
|
||||||
|
limits,
|
||||||
|
per_user_minute: RwLock::new(HashMap::new()),
|
||||||
|
per_user_hour: RwLock::new(HashMap::new()),
|
||||||
|
global_minute: Arc::new(RateLimitEntry::new()),
|
||||||
|
global_hour: Arc::new(RateLimitEntry::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn check_rate_limit(&self, user_id: &str) -> Result<(), LimitExceeded> {
|
||||||
|
self.check_global_limits().await?;
|
||||||
|
self.check_user_limits(user_id).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn check_global_limits(&self) -> Result<(), LimitExceeded> {
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
{
|
||||||
|
let window_start = self.global_minute.window_start.read().await;
|
||||||
|
if now.duration_since(*window_start) > Duration::from_secs(60) {
|
||||||
|
drop(window_start);
|
||||||
|
let mut window_start = self.global_minute.window_start.write().await;
|
||||||
|
*window_start = now;
|
||||||
|
self.global_minute.count.store(0, Ordering::SeqCst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let count = self.global_minute.count.fetch_add(1, Ordering::SeqCst) + 1;
|
||||||
|
let max = u64::from(self.limits.max_api_calls_per_minute) * 100;
|
||||||
|
|
||||||
|
if count > max {
|
||||||
|
self.global_minute.count.fetch_sub(1, Ordering::SeqCst);
|
||||||
|
return Err(LimitExceeded {
|
||||||
|
limit_type: LimitType::ApiCallsMinute,
|
||||||
|
current: count,
|
||||||
|
maximum: max,
|
||||||
|
retry_after_secs: Some(60),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
let window_start = self.global_hour.window_start.read().await;
|
||||||
|
if now.duration_since(*window_start) > Duration::from_secs(3600) {
|
||||||
|
drop(window_start);
|
||||||
|
let mut window_start = self.global_hour.window_start.write().await;
|
||||||
|
*window_start = now;
|
||||||
|
self.global_hour.count.store(0, Ordering::SeqCst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let hour_count = self.global_hour.count.fetch_add(1, Ordering::SeqCst) + 1;
|
||||||
|
let hour_max = u64::from(self.limits.max_api_calls_per_hour) * 100;
|
||||||
|
|
||||||
|
if hour_count > hour_max {
|
||||||
|
self.global_hour.count.fetch_sub(1, Ordering::SeqCst);
|
||||||
|
return Err(LimitExceeded {
|
||||||
|
limit_type: LimitType::ApiCallsHour,
|
||||||
|
current: hour_count,
|
||||||
|
maximum: hour_max,
|
||||||
|
retry_after_secs: Some(3600),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn check_user_limits(&self, user_id: &str) -> Result<(), LimitExceeded> {
|
||||||
|
self.check_user_minute_limit(user_id).await?;
|
||||||
|
self.check_user_hour_limit(user_id).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn check_user_minute_limit(&self, user_id: &str) -> Result<(), LimitExceeded> {
|
||||||
|
let entry = {
|
||||||
|
let map = self.per_user_minute.read().await;
|
||||||
|
map.get(user_id).cloned()
|
||||||
|
};
|
||||||
|
|
||||||
|
let entry = match entry {
|
||||||
|
Some(e) => e,
|
||||||
|
None => {
|
||||||
|
let new_entry = Arc::new(RateLimitEntry::new());
|
||||||
|
let mut map = self.per_user_minute.write().await;
|
||||||
|
map.insert(user_id.to_string(), Arc::clone(&new_entry));
|
||||||
|
new_entry
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let now = Instant::now();
|
||||||
|
{
|
||||||
|
let window_start = entry.window_start.read().await;
|
||||||
|
if now.duration_since(*window_start) > Duration::from_secs(60) {
|
||||||
|
drop(window_start);
|
||||||
|
let mut window_start = entry.window_start.write().await;
|
||||||
|
*window_start = now;
|
||||||
|
entry.count.store(0, Ordering::SeqCst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let count = entry.count.fetch_add(1, Ordering::SeqCst) + 1;
|
||||||
|
let max = u64::from(self.limits.max_api_calls_per_minute);
|
||||||
|
|
||||||
|
if count > max {
|
||||||
|
entry.count.fetch_sub(1, Ordering::SeqCst);
|
||||||
|
return Err(LimitExceeded {
|
||||||
|
limit_type: LimitType::ApiCallsMinute,
|
||||||
|
current: count,
|
||||||
|
maximum: max,
|
||||||
|
retry_after_secs: Some(60),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn check_user_hour_limit(&self, user_id: &str) -> Result<(), LimitExceeded> {
|
||||||
|
let entry = {
|
||||||
|
let map = self.per_user_hour.read().await;
|
||||||
|
map.get(user_id).cloned()
|
||||||
|
};
|
||||||
|
|
||||||
|
let entry = match entry {
|
||||||
|
Some(e) => e,
|
||||||
|
None => {
|
||||||
|
let new_entry = Arc::new(RateLimitEntry::new());
|
||||||
|
let mut map = self.per_user_hour.write().await;
|
||||||
|
map.insert(user_id.to_string(), Arc::clone(&new_entry));
|
||||||
|
new_entry
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let now = Instant::now();
|
||||||
|
{
|
||||||
|
let window_start = entry.window_start.read().await;
|
||||||
|
if now.duration_since(*window_start) > Duration::from_secs(3600) {
|
||||||
|
drop(window_start);
|
||||||
|
let mut window_start = entry.window_start.write().await;
|
||||||
|
*window_start = now;
|
||||||
|
entry.count.store(0, Ordering::SeqCst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let count = entry.count.fetch_add(1, Ordering::SeqCst) + 1;
|
||||||
|
let max = u64::from(self.limits.max_api_calls_per_hour);
|
||||||
|
|
||||||
|
if count > max {
|
||||||
|
entry.count.fetch_sub(1, Ordering::SeqCst);
|
||||||
|
return Err(LimitExceeded {
|
||||||
|
limit_type: LimitType::ApiCallsHour,
|
||||||
|
current: count,
|
||||||
|
maximum: max,
|
||||||
|
retry_after_secs: Some(3600),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn cleanup_stale_entries(&self) {
|
||||||
|
let now = Instant::now();
|
||||||
|
let stale_threshold = Duration::from_secs(7200);
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut map = self.per_user_minute.write().await;
|
||||||
|
let mut to_remove = Vec::new();
|
||||||
|
for (user_id, entry) in map.iter() {
|
||||||
|
let window_start = entry.window_start.read().await;
|
||||||
|
if now.duration_since(*window_start) > stale_threshold {
|
||||||
|
to_remove.push(user_id.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for user_id in to_remove {
|
||||||
|
map.remove(&user_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut map = self.per_user_hour.write().await;
|
||||||
|
let mut to_remove = Vec::new();
|
||||||
|
for (user_id, entry) in map.iter() {
|
||||||
|
let window_start = entry.window_start.read().await;
|
||||||
|
if now.duration_since(*window_start) > stale_threshold {
|
||||||
|
to_remove.push(user_id.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for user_id in to_remove {
|
||||||
|
map.remove(&user_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn check_loop_limit(iterations: u32, max: u32) -> Result<(), LimitExceeded> {
|
||||||
|
if iterations >= max {
|
||||||
|
return Err(LimitExceeded {
|
||||||
|
limit_type: LimitType::LoopIterations,
|
||||||
|
current: u64::from(iterations),
|
||||||
|
maximum: u64::from(max),
|
||||||
|
retry_after_secs: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn check_recursion_limit(depth: u32, max: u32) -> Result<(), LimitExceeded> {
|
||||||
|
if depth >= max {
|
||||||
|
return Err(LimitExceeded {
|
||||||
|
limit_type: LimitType::RecursionDepth,
|
||||||
|
current: u64::from(depth),
|
||||||
|
maximum: u64::from(max),
|
||||||
|
retry_after_secs: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn check_file_size_limit(size: u64, max: u64) -> Result<(), LimitExceeded> {
|
||||||
|
if size > max {
|
||||||
|
return Err(LimitExceeded {
|
||||||
|
limit_type: LimitType::FileSize,
|
||||||
|
current: size,
|
||||||
|
maximum: max,
|
||||||
|
retry_after_secs: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn check_string_length_limit(length: usize, max: usize) -> Result<(), LimitExceeded> {
|
||||||
|
if length > max {
|
||||||
|
return Err(LimitExceeded {
|
||||||
|
limit_type: LimitType::StringLength,
|
||||||
|
current: length as u64,
|
||||||
|
maximum: max as u64,
|
||||||
|
retry_after_secs: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn check_array_length_limit(length: usize, max: usize) -> Result<(), LimitExceeded> {
|
||||||
|
if length > max {
|
||||||
|
return Err(LimitExceeded {
|
||||||
|
limit_type: LimitType::ArrayLength,
|
||||||
|
current: length as u64,
|
||||||
|
maximum: max as u64,
|
||||||
|
retry_after_secs: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn format_limit_error_response(error: &LimitExceeded) -> (u16, String) {
|
||||||
|
let status = 429;
|
||||||
|
let body = serde_json::json!({
|
||||||
|
"error": "rate_limit_exceeded",
|
||||||
|
"message": error.to_string(),
|
||||||
|
"limit_type": error.limit_type.to_string(),
|
||||||
|
"current": error.current,
|
||||||
|
"maximum": error.maximum,
|
||||||
|
"retry_after_secs": error.retry_after_secs,
|
||||||
|
});
|
||||||
|
(status, body.to_string())
|
||||||
|
}
|
||||||
|
|
@ -1,19 +1,15 @@
|
||||||
|
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::Duration;
|
||||||
use tokio::sync::{RwLock, Semaphore, SemaphorePermit};
|
|
||||||
use tokio::time::{sleep, timeout};
|
use tokio::time::{sleep, timeout};
|
||||||
|
|
||||||
|
pub type RetryPredicate = Arc<dyn Fn(&str) -> bool + Send + Sync>;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum ResilienceError {
|
pub enum ResilienceError {
|
||||||
Timeout { duration: Duration },
|
Timeout { duration: Duration },
|
||||||
CircuitOpen { until: Option<Duration> },
|
CircuitOpen { until: Option<Duration> },
|
||||||
RetriesExhausted {
|
RetriesExhausted { attempts: u32, last_error: String },
|
||||||
attempts: u32,
|
|
||||||
last_error: String,
|
|
||||||
},
|
|
||||||
BulkheadFull { max_concurrent: usize },
|
BulkheadFull { max_concurrent: usize },
|
||||||
Operation(String),
|
Operation(String),
|
||||||
}
|
}
|
||||||
|
|
@ -55,15 +51,27 @@ impl std::fmt::Display for ResilienceError {
|
||||||
|
|
||||||
impl std::error::Error for ResilienceError {}
|
impl std::error::Error for ResilienceError {}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct RetryConfig {
|
pub struct RetryConfig {
|
||||||
pub max_attempts: u32,
|
pub max_attempts: u32,
|
||||||
pub initial_delay: Duration,
|
pub initial_delay: Duration,
|
||||||
pub max_delay: Duration,
|
pub max_delay: Duration,
|
||||||
pub backoff_multiplier: f64,
|
pub backoff_multiplier: f64,
|
||||||
pub jitter_factor: f64,
|
pub jitter_factor: f64,
|
||||||
retryable: Option<Arc<dyn Fn(&str) -> bool + Send + Sync>>,
|
retryable: Option<RetryPredicate>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for RetryConfig {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("RetryConfig")
|
||||||
|
.field("max_attempts", &self.max_attempts)
|
||||||
|
.field("initial_delay", &self.initial_delay)
|
||||||
|
.field("max_delay", &self.max_delay)
|
||||||
|
.field("backoff_multiplier", &self.backoff_multiplier)
|
||||||
|
.field("jitter_factor", &self.jitter_factor)
|
||||||
|
.field("retryable", &self.retryable.is_some())
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for RetryConfig {
|
impl Default for RetryConfig {
|
||||||
|
|
@ -80,31 +88,37 @@ impl Default for RetryConfig {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RetryConfig {
|
impl RetryConfig {
|
||||||
|
/// Create a new retry config with custom max attempts
|
||||||
pub fn with_max_attempts(mut self, attempts: u32) -> Self {
|
pub fn with_max_attempts(mut self, attempts: u32) -> Self {
|
||||||
self.max_attempts = attempts.max(1);
|
self.max_attempts = attempts.max(1);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set initial delay
|
||||||
pub fn with_initial_delay(mut self, delay: Duration) -> Self {
|
pub fn with_initial_delay(mut self, delay: Duration) -> Self {
|
||||||
self.initial_delay = delay;
|
self.initial_delay = delay;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set maximum delay cap
|
||||||
pub fn with_max_delay(mut self, delay: Duration) -> Self {
|
pub fn with_max_delay(mut self, delay: Duration) -> Self {
|
||||||
self.max_delay = delay;
|
self.max_delay = delay;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set backoff multiplier
|
||||||
pub fn with_backoff_multiplier(mut self, multiplier: f64) -> Self {
|
pub fn with_backoff_multiplier(mut self, multiplier: f64) -> Self {
|
||||||
self.backoff_multiplier = multiplier.max(1.0);
|
self.backoff_multiplier = multiplier.max(1.0);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set jitter factor (0.0 to 1.0)
|
||||||
pub fn with_jitter(mut self, jitter: f64) -> Self {
|
pub fn with_jitter(mut self, jitter: f64) -> Self {
|
||||||
self.jitter_factor = jitter.clamp(0.0, 1.0);
|
self.jitter_factor = jitter.clamp(0.0, 1.0);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set custom retryable predicate
|
||||||
pub fn with_retryable<F>(mut self, predicate: F) -> Self
|
pub fn with_retryable<F>(mut self, predicate: F) -> Self
|
||||||
where
|
where
|
||||||
F: Fn(&str) -> bool + Send + Sync + 'static,
|
F: Fn(&str) -> bool + Send + Sync + 'static,
|
||||||
|
|
@ -113,6 +127,7 @@ impl RetryConfig {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Aggressive retry for critical operations
|
||||||
pub fn aggressive() -> Self {
|
pub fn aggressive() -> Self {
|
||||||
Self {
|
Self {
|
||||||
max_attempts: 5,
|
max_attempts: 5,
|
||||||
|
|
@ -124,6 +139,7 @@ impl RetryConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Conservative retry for non-critical operations
|
||||||
pub fn conservative() -> Self {
|
pub fn conservative() -> Self {
|
||||||
Self {
|
Self {
|
||||||
max_attempts: 2,
|
max_attempts: 2,
|
||||||
|
|
@ -136,15 +152,15 @@ impl RetryConfig {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn calculate_delay(&self, attempt: u32) -> Duration {
|
fn calculate_delay(&self, attempt: u32) -> Duration {
|
||||||
let base_delay = self.initial_delay.as_secs_f64()
|
let exponent = i32::try_from(attempt.saturating_sub(1)).unwrap_or(0);
|
||||||
* self.backoff_multiplier.powi(attempt.saturating_sub(1) as i32);
|
let base_delay = self.backoff_multiplier.powi(exponent) * self.initial_delay.as_secs_f64();
|
||||||
|
|
||||||
let capped_delay = base_delay.min(self.max_delay.as_secs_f64());
|
let capped_delay = base_delay.min(self.max_delay.as_secs_f64());
|
||||||
|
|
||||||
let jitter = if self.jitter_factor > 0.0 {
|
let jitter = if self.jitter_factor > 0.0 {
|
||||||
let jitter_range = capped_delay * self.jitter_factor;
|
let jitter_range = capped_delay * self.jitter_factor;
|
||||||
let pseudo_random = ((attempt as f64 * 1.618033988749895) % 1.0) * 2.0 - 1.0;
|
let pseudo_random = (f64::from(attempt) * 1.618_033_988_749_895) % 1.0;
|
||||||
jitter_range * pseudo_random
|
(2.0_f64).mul_add(pseudo_random, -1.0) * jitter_range
|
||||||
} else {
|
} else {
|
||||||
0.0
|
0.0
|
||||||
};
|
};
|
||||||
|
|
@ -156,3 +172,53 @@ impl RetryConfig {
|
||||||
if let Some(ref predicate) = self.retryable {
|
if let Some(ref predicate) = self.retryable {
|
||||||
predicate(error)
|
predicate(error)
|
||||||
} else {
|
} else {
|
||||||
|
error.contains("timeout")
|
||||||
|
|| error.contains("connection")
|
||||||
|
|| error.contains("temporarily")
|
||||||
|
|| error.contains("503")
|
||||||
|
|| error.contains("429")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn retry<F, Fut, T>(config: &RetryConfig, mut operation: F) -> Result<T, ResilienceError>
|
||||||
|
where
|
||||||
|
F: FnMut() -> Fut,
|
||||||
|
Fut: Future<Output = Result<T, String>>,
|
||||||
|
{
|
||||||
|
let mut last_error = String::new();
|
||||||
|
|
||||||
|
for attempt in 1..=config.max_attempts {
|
||||||
|
match operation().await {
|
||||||
|
Ok(result) => return Ok(result),
|
||||||
|
Err(e) => {
|
||||||
|
if attempt == config.max_attempts {
|
||||||
|
last_error = e;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !config.is_retryable(&e) {
|
||||||
|
return Err(ResilienceError::Operation(e));
|
||||||
|
}
|
||||||
|
|
||||||
|
last_error = e;
|
||||||
|
let delay = config.calculate_delay(attempt);
|
||||||
|
sleep(delay).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(ResilienceError::RetriesExhausted {
|
||||||
|
attempts: config.max_attempts,
|
||||||
|
last_error,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn with_timeout<F, T>(duration: Duration, future: F) -> Result<T, ResilienceError>
|
||||||
|
where
|
||||||
|
F: Future<Output = T>,
|
||||||
|
{
|
||||||
|
timeout(duration, future)
|
||||||
|
.await
|
||||||
|
.map_err(|_| ResilienceError::Timeout { duration })
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue