feat: add WhatsApp rate limiting and LLM hallucination detection
All checks were successful
BotServer CI / build (push) Successful in 11m51s
All checks were successful
BotServer CI / build (push) Successful in 11m51s
This commit is contained in:
parent
24709f7811
commit
77c35ccde5
7 changed files with 653 additions and 8 deletions
|
|
@ -2,6 +2,7 @@ pub mod instagram;
|
|||
pub mod teams;
|
||||
pub mod telegram;
|
||||
pub mod whatsapp;
|
||||
pub mod whatsapp_rate_limiter;
|
||||
|
||||
use crate::core::shared::models::BotResponse;
|
||||
use async_trait::async_trait;
|
||||
|
|
|
|||
|
|
@ -4,17 +4,22 @@ use serde::{Deserialize, Serialize};
|
|||
use uuid::Uuid;
|
||||
|
||||
use crate::core::bot::channels::ChannelAdapter;
|
||||
use crate::core::bot::channels::whatsapp_rate_limiter::WhatsAppRateLimiter;
|
||||
use crate::core::config::ConfigManager;
|
||||
use crate::core::shared::models::BotResponse;
|
||||
use crate::core::shared::utils::DbPool;
|
||||
|
||||
#[derive(Debug)]
|
||||
/// Global rate limiter for WhatsApp API (shared across all adapters)
|
||||
static WHATSAPP_RATE_LIMITER: std::sync::OnceLock<WhatsAppRateLimiter> = std::sync::OnceLock::new();
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WhatsAppAdapter {
|
||||
api_key: String,
|
||||
phone_number_id: String,
|
||||
webhook_verify_token: String,
|
||||
_business_account_id: String,
|
||||
api_version: String,
|
||||
rate_limiter: &'static WhatsAppRateLimiter,
|
||||
}
|
||||
|
||||
impl WhatsAppAdapter {
|
||||
|
|
@ -41,12 +46,27 @@ impl WhatsAppAdapter {
|
|||
.get_config(&bot_id, "whatsapp-api-version", Some("v17.0"))
|
||||
.unwrap_or_else(|_| "v17.0".to_string());
|
||||
|
||||
// Get rate limit tier from config (default to Tier 1 for safety)
|
||||
let tier_str = config_manager
|
||||
.get_config(&bot_id, "whatsapp-rate-tier", None)
|
||||
.unwrap_or_else(|_| "1".to_string());
|
||||
let tier = match tier_str.as_str() {
|
||||
"1" | "tier1" => super::whatsapp_rate_limiter::WhatsAppTier::Tier1,
|
||||
"2" | "tier2" => super::whatsapp_rate_limiter::WhatsAppTier::Tier2,
|
||||
"3" | "tier3" => super::whatsapp_rate_limiter::WhatsAppTier::Tier3,
|
||||
"4" | "tier4" => super::whatsapp_rate_limiter::WhatsAppTier::Tier4,
|
||||
_ => super::whatsapp_rate_limiter::WhatsAppTier::Tier1,
|
||||
};
|
||||
|
||||
Self {
|
||||
api_key,
|
||||
phone_number_id,
|
||||
webhook_verify_token: verify_token,
|
||||
_business_account_id: business_account_id,
|
||||
api_version,
|
||||
rate_limiter: WHATSAPP_RATE_LIMITER.get_or_init(|| {
|
||||
super::whatsapp_rate_limiter::WhatsAppRateLimiter::from_tier(tier)
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -114,6 +134,9 @@ impl WhatsAppAdapter {
|
|||
to: &str,
|
||||
message: &str,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// Wait for rate limiter before making API call
|
||||
self.rate_limiter.acquire().await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let url = format!(
|
||||
|
|
@ -157,6 +180,9 @@ impl WhatsAppAdapter {
|
|||
language_code: &str,
|
||||
components: Vec<serde_json::Value>,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// Wait for rate limiter before making API call
|
||||
self.rate_limiter.acquire().await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let url = format!(
|
||||
|
|
@ -643,9 +669,9 @@ impl ChannelAdapter for WhatsAppAdapter {
|
|||
i + 1, parts.len(), response.user_id, &part.chars().take(50).collect::<String>(), message_id
|
||||
);
|
||||
|
||||
// Small delay between messages to avoid rate limiting
|
||||
// Use rate limiter to wait before sending next message
|
||||
if i < parts.len() - 1 {
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
|
||||
self.rate_limiter.acquire().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
315
src/core/bot/channels/whatsapp_rate_limiter.rs
Normal file
315
src/core/bot/channels/whatsapp_rate_limiter.rs
Normal file
|
|
@ -0,0 +1,315 @@
|
|||
//! WhatsApp Rate Limiter
|
||||
//!
|
||||
//! Implements rate limiting for WhatsApp Cloud API based on Meta's throughput tiers.
|
||||
//!
|
||||
//! ## Meta WhatsApp Rate Limits (per phone number)
|
||||
//!
|
||||
//! | Tier | Messages/second | Conversations/day |
|
||||
//! |------|-----------------|-------------------|
|
||||
//! | 1 | 40 | 1,000 |
|
||||
//! | 2 | 80 | 10,000 |
|
||||
//! | 3 | 200 | 100,000 |
|
||||
//! | 4 | 400+ | Unlimited |
|
||||
//!
|
||||
//! New phone numbers start at Tier 1.
|
||||
|
||||
use governor::{
|
||||
clock::DefaultClock,
|
||||
middleware::NoOpMiddleware,
|
||||
state::{InMemoryState, NotKeyed},
|
||||
Quota, RateLimiter,
|
||||
};
|
||||
use std::num::NonZeroU32;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
|
||||
/// WhatsApp throughput tier levels (matches Meta's tiers)
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum WhatsAppTier {
|
||||
/// Tier 1: New phone numbers (40 msg/s, 1000 conv/day)
|
||||
Tier1,
|
||||
/// Tier 2: Medium quality (80 msg/s, 10000 conv/day)
|
||||
Tier2,
|
||||
/// Tier 3: High quality (200 msg/s, 100000 conv/day)
|
||||
Tier3,
|
||||
/// Tier 4: Premium (400+ msg/s, unlimited)
|
||||
Tier4,
|
||||
}
|
||||
|
||||
impl Default for WhatsAppTier {
|
||||
fn default() -> Self {
|
||||
Self::Tier1
|
||||
}
|
||||
}
|
||||
|
||||
impl WhatsAppTier {
|
||||
/// Get messages per second for this tier
|
||||
pub fn messages_per_second(&self) -> u32 {
|
||||
match self {
|
||||
Self::Tier1 => 40,
|
||||
Self::Tier2 => 80,
|
||||
Self::Tier3 => 200,
|
||||
Self::Tier4 => 400,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get burst size (slightly higher to allow brief spikes)
|
||||
pub fn burst_size(&self) -> u32 {
|
||||
match self {
|
||||
Self::Tier1 => 50,
|
||||
Self::Tier2 => 100,
|
||||
Self::Tier3 => 250,
|
||||
Self::Tier4 => 500,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get minimum delay between messages (for streaming)
|
||||
pub fn min_delay_ms(&self) -> u64 {
|
||||
match self {
|
||||
Self::Tier1 => 25, // 40 msg/s = 25ms between messages
|
||||
Self::Tier2 => 12, // 80 msg/s = 12.5ms
|
||||
Self::Tier3 => 5, // 200 msg/s = 5ms
|
||||
Self::Tier4 => 2, // 400 msg/s = 2.5ms
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for WhatsAppTier {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Tier1 => write!(f, "Tier 1 (40 msg/s)"),
|
||||
Self::Tier2 => write!(f, "Tier 2 (80 msg/s)"),
|
||||
Self::Tier3 => write!(f, "Tier 3 (200 msg/s)"),
|
||||
Self::Tier4 => write!(f, "Tier 4 (400+ msg/s)"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for WhatsApp rate limiting
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WhatsAppRateLimitConfig {
|
||||
/// Throughput tier (determines rate limits)
|
||||
pub tier: WhatsAppTier,
|
||||
/// Custom messages per second (overrides tier if set)
|
||||
pub custom_mps: Option<u32>,
|
||||
/// Custom burst size (overrides tier if set)
|
||||
pub custom_burst: Option<u32>,
|
||||
/// Enable rate limiting
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
impl Default for WhatsAppRateLimitConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
tier: WhatsAppTier::Tier1,
|
||||
custom_mps: None,
|
||||
custom_burst: None,
|
||||
enabled: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WhatsAppRateLimitConfig {
|
||||
/// Create config for a specific tier
|
||||
pub fn from_tier(tier: WhatsAppTier) -> Self {
|
||||
Self {
|
||||
tier,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create config with custom rate
|
||||
pub fn custom(messages_per_second: u32, burst_size: u32) -> Self {
|
||||
Self {
|
||||
tier: WhatsAppTier::Tier1,
|
||||
custom_mps: Some(messages_per_second),
|
||||
custom_burst: Some(burst_size),
|
||||
enabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get effective messages per second
|
||||
pub fn effective_mps(&self) -> u32 {
|
||||
self.custom_mps.unwrap_or_else(|| self.tier.messages_per_second())
|
||||
}
|
||||
|
||||
/// Get effective burst size
|
||||
pub fn effective_burst(&self) -> u32 {
|
||||
self.custom_burst.unwrap_or_else(|| self.tier.burst_size())
|
||||
}
|
||||
}
|
||||
|
||||
type Limiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>;
|
||||
|
||||
/// WhatsApp Rate Limiter
|
||||
///
|
||||
/// Uses token bucket algorithm via governor crate.
|
||||
/// Thread-safe and async-friendly.
|
||||
#[derive(Debug)]
|
||||
pub struct WhatsAppRateLimiter {
|
||||
limiter: Arc<Limiter>,
|
||||
config: WhatsAppRateLimitConfig,
|
||||
min_delay: Duration,
|
||||
}
|
||||
|
||||
impl WhatsAppRateLimiter {
|
||||
/// Create a new rate limiter with default Tier 1 settings
|
||||
pub fn new() -> Self {
|
||||
Self::with_config(WhatsAppRateLimitConfig::default())
|
||||
}
|
||||
|
||||
/// Create a rate limiter for a specific tier
|
||||
pub fn from_tier(tier: WhatsAppTier) -> Self {
|
||||
Self::with_config(WhatsAppRateLimitConfig::from_tier(tier))
|
||||
}
|
||||
|
||||
/// Create a rate limiter with custom configuration
|
||||
pub fn with_config(config: WhatsAppRateLimitConfig) -> Self {
|
||||
let mps = config.effective_mps();
|
||||
let burst = config.effective_burst();
|
||||
let min_delay = Duration::from_millis(config.tier.min_delay_ms());
|
||||
|
||||
let quota = Quota::per_second(
|
||||
NonZeroU32::new(mps).unwrap_or(NonZeroU32::MIN)
|
||||
)
|
||||
.allow_burst(
|
||||
NonZeroU32::new(burst).unwrap_or(NonZeroU32::MIN)
|
||||
);
|
||||
|
||||
Self {
|
||||
limiter: Arc::new(RateLimiter::direct(quota)),
|
||||
config,
|
||||
min_delay,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a message can be sent immediately
|
||||
pub fn check(&self) -> bool {
|
||||
self.limiter.check().is_ok()
|
||||
}
|
||||
|
||||
/// Wait until a message can be sent (async)
|
||||
///
|
||||
/// This will block until the rate limiter allows the message.
|
||||
/// Uses exponential backoff for waiting.
|
||||
pub async fn acquire(&self) {
|
||||
if !self.config.enabled {
|
||||
return;
|
||||
}
|
||||
|
||||
// Try to acquire immediately
|
||||
if self.limiter.check().is_ok() {
|
||||
return;
|
||||
}
|
||||
|
||||
// If not available, wait with minimum delay
|
||||
loop {
|
||||
sleep(self.min_delay).await;
|
||||
if self.limiter.check().is_ok() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to acquire with timeout
|
||||
///
|
||||
/// Returns true if acquired, false if timed out
|
||||
pub async fn try_acquire_timeout(&self, timeout: Duration) -> bool {
|
||||
if !self.config.enabled {
|
||||
return true;
|
||||
}
|
||||
|
||||
if self.limiter.check().is_ok() {
|
||||
return true;
|
||||
}
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
while start.elapsed() < timeout {
|
||||
sleep(self.min_delay).await;
|
||||
if self.limiter.check().is_ok() {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Get current configuration
|
||||
pub fn config(&self) -> &WhatsAppRateLimitConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Get the tier
|
||||
pub fn tier(&self) -> WhatsAppTier {
|
||||
self.config.tier
|
||||
}
|
||||
|
||||
/// Get minimum delay between messages
|
||||
pub fn min_delay(&self) -> Duration {
|
||||
self.min_delay
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for WhatsAppRateLimiter {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for WhatsAppRateLimiter {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
limiter: Arc::clone(&self.limiter),
|
||||
config: self.config.clone(),
|
||||
min_delay: self.min_delay,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_tier_defaults() {
|
||||
assert_eq!(WhatsAppTier::Tier1.messages_per_second(), 40);
|
||||
assert_eq!(WhatsAppTier::Tier2.messages_per_second(), 80);
|
||||
assert_eq!(WhatsAppTier::Tier3.messages_per_second(), 200);
|
||||
assert_eq!(WhatsAppTier::Tier4.messages_per_second(), 400);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rate_limiter_creation() {
|
||||
let limiter = WhatsAppRateLimiter::new();
|
||||
assert!(limiter.check());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tier_limiter() {
|
||||
let limiter = WhatsAppRateLimiter::from_tier(WhatsAppTier::Tier2);
|
||||
assert_eq!(limiter.tier(), WhatsAppTier::Tier2);
|
||||
assert!(limiter.check());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_custom_config() {
|
||||
let config = WhatsAppRateLimitConfig::custom(100, 150);
|
||||
assert_eq!(config.effective_mps(), 100);
|
||||
assert_eq!(config.effective_burst(), 150);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_acquire() {
|
||||
let limiter = WhatsAppRateLimiter::from_tier(WhatsAppTier::Tier4);
|
||||
// Should acquire immediately
|
||||
limiter.acquire().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_try_acquire_timeout() {
|
||||
let limiter = WhatsAppRateLimiter::new();
|
||||
// Should succeed immediately
|
||||
let result = limiter.try_acquire_timeout(Duration::from_millis(100)).await;
|
||||
assert!(result);
|
||||
}
|
||||
}
|
||||
|
|
@ -458,7 +458,7 @@ impl BotOrchestrator {
|
|||
.get_config(&session.bot_id, "system-prompt", Some("You are a helpful assistant with access to tools that can help you complete tasks. When a user's request matches one of your available tools, use the appropriate tool instead of providing a generic response."))
|
||||
.unwrap_or_else(|_| "You are a helpful assistant.".to_string());
|
||||
|
||||
trace!("Loaded system-prompt for bot {}: {}", session.bot_id, &system_prompt[..system_prompt.len().min(100)]);
|
||||
info!("Loaded system-prompt for bot {}: {}", session.bot_id, &system_prompt[..system_prompt.len().min(200)]);
|
||||
|
||||
Ok((session, context_data, history, model, key, system_prompt))
|
||||
},
|
||||
|
|
|
|||
287
src/llm/hallucination_detector.rs
Normal file
287
src/llm/hallucination_detector.rs
Normal file
|
|
@ -0,0 +1,287 @@
|
|||
//! Hallucination Loop Detector
|
||||
//!
|
||||
//! Detects when an LLM gets stuck in a repetition loop (hallucination).
|
||||
//! This module provides detection for all channels (web, WhatsApp, Telegram, etc.).
|
||||
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
|
||||
/// Configuration for hallucination detection
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HallucinationConfig {
|
||||
/// Minimum text length before detection starts
|
||||
pub min_text_length: usize,
|
||||
/// Pattern lengths to check (in characters)
|
||||
pub pattern_lengths: Vec<usize>,
|
||||
/// Number of consecutive repetitions to trigger detection
|
||||
pub consecutive_threshold: usize,
|
||||
/// Number of total occurrences in recent text to trigger detection
|
||||
pub occurrence_threshold: usize,
|
||||
/// Recent text window size for occurrence counting
|
||||
pub recent_text_window: usize,
|
||||
/// Number of identical tokens to trigger detection
|
||||
pub identical_token_threshold: usize,
|
||||
/// Common words to ignore (won't trigger detection when repeated)
|
||||
pub ignore_words: Vec<String>,
|
||||
}
|
||||
|
||||
/// Default list of common words that shouldn't trigger hallucination detection
|
||||
const DEFAULT_IGNORE_WORDS: &[&str] = &[
|
||||
"the", "a", "an", "is", "are", "was", "were", "be", "been", "being",
|
||||
"have", "has", "had", "do", "does", "did", "will", "would", "could", "should",
|
||||
"may", "might", "must", "shall", "can", "need", "dare", "ought", "used",
|
||||
"to", "of", "in", "for", "on", "with", "at", "by", "from", "as",
|
||||
"into", "through", "during", "before", "after", "above", "below", "between",
|
||||
"and", "but", "or", "nor", "so", "yet", "both", "either", "neither",
|
||||
"not", "only", "own", "same", "than", "too", "very", "just",
|
||||
"de", "da", "do", "das", "dos", "e", "é", "em", "no", "na", "nos", "nas",
|
||||
"para", "por", "com", "sem", "sobre", "entre", "após", "antes", "depois",
|
||||
"que", "se", "ou", "mas", "porém", "como", "assim", "também", "ainda",
|
||||
"um", "uma", "uns", "umas", "o", "a", "os", "as",
|
||||
];
|
||||
|
||||
impl Default for HallucinationConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
min_text_length: 50,
|
||||
pattern_lengths: vec![3, 4, 5, 6, 8, 10, 15, 20],
|
||||
consecutive_threshold: 5,
|
||||
occurrence_threshold: 8,
|
||||
recent_text_window: 500,
|
||||
identical_token_threshold: 10,
|
||||
ignore_words: DEFAULT_IGNORE_WORDS.iter().map(|s| s.to_string()).collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// State for tracking hallucination during streaming
|
||||
#[derive(Debug)]
|
||||
pub struct HallucinationDetector {
|
||||
config: HallucinationConfig,
|
||||
last_content_hash: u64,
|
||||
identical_count: usize,
|
||||
detected: bool,
|
||||
detected_pattern: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for HallucinationDetector {
|
||||
fn default() -> Self {
|
||||
Self::new(HallucinationConfig::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl HallucinationDetector {
|
||||
/// Create a new detector with custom configuration
|
||||
pub fn new(config: HallucinationConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
last_content_hash: 0,
|
||||
identical_count: 0,
|
||||
detected: false,
|
||||
detected_pattern: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if hallucination has been detected
|
||||
pub fn is_detected(&self) -> bool {
|
||||
self.detected
|
||||
}
|
||||
|
||||
/// Get the detected pattern if any
|
||||
pub fn get_detected_pattern(&self) -> Option<&str> {
|
||||
self.detected_pattern.as_deref()
|
||||
}
|
||||
|
||||
/// Get the detected pattern as owned String
|
||||
pub fn get_detected_pattern_owned(&self) -> Option<String> {
|
||||
self.detected_pattern.clone()
|
||||
}
|
||||
|
||||
/// Check a new token/chunk for hallucination patterns
|
||||
/// Returns true if hallucination is detected
|
||||
pub fn check_token(&mut self, token: &str) -> bool {
|
||||
if self.detected {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check for identical token repetition
|
||||
if !token.trim().is_empty() {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
token.hash(&mut hasher);
|
||||
let content_hash = hasher.finish();
|
||||
|
||||
if content_hash == self.last_content_hash {
|
||||
self.identical_count += 1;
|
||||
if self.identical_count >= self.config.identical_token_threshold {
|
||||
log::warn!(
|
||||
"LLM hallucination detected: identical token repeated {} times: {:?}",
|
||||
self.identical_count,
|
||||
token
|
||||
);
|
||||
self.detected = true;
|
||||
self.detected_pattern = Some(format!("{} ({}x)", token.trim(), self.identical_count));
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
self.identical_count = 0;
|
||||
}
|
||||
self.last_content_hash = content_hash;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Check accumulated text for repetition patterns
|
||||
/// Returns Some(pattern) if hallucination is detected
|
||||
pub fn check_text(&mut self, text: &str) -> Option<String> {
|
||||
if self.detected {
|
||||
return self.detected_pattern.clone();
|
||||
}
|
||||
|
||||
// Skip detection for short texts
|
||||
if text.len() < self.config.min_text_length {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check for repeated patterns of various lengths
|
||||
for pattern_len in &self.config.pattern_lengths {
|
||||
if text.len() < *pattern_len * 5 {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get the last pattern to check
|
||||
let chars: Vec<char> = text.chars().collect();
|
||||
let start = chars.len().saturating_sub(*pattern_len);
|
||||
let pattern: String = chars[start..].iter().collect();
|
||||
let pattern_str = pattern.trim();
|
||||
|
||||
if pattern_str.is_empty() || pattern_str.len() < 2 {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Count how many times this pattern appears consecutively at the end
|
||||
let mut count = 0;
|
||||
let mut search_text = text;
|
||||
|
||||
while search_text.ends_with(pattern_str) || search_text.ends_with(&pattern) {
|
||||
count += 1;
|
||||
if count >= self.config.consecutive_threshold {
|
||||
// Found threshold repetitions - likely hallucination
|
||||
log::warn!(
|
||||
"LLM hallucination loop detected: pattern {:?} repeated {} times consecutively",
|
||||
pattern_str,
|
||||
count
|
||||
);
|
||||
self.detected = true;
|
||||
self.detected_pattern = Some(pattern_str.to_string());
|
||||
return self.detected_pattern.clone();
|
||||
}
|
||||
// Remove one occurrence and continue checking
|
||||
if search_text.ends_with(pattern_str) {
|
||||
search_text = &search_text[..search_text.len().saturating_sub(pattern_str.len())];
|
||||
} else {
|
||||
search_text = &search_text[..search_text.len().saturating_sub(pattern.len())];
|
||||
}
|
||||
}
|
||||
|
||||
// Alternative: count total occurrences in recent text
|
||||
let recent_start = chars.len().saturating_sub(self.config.recent_text_window);
|
||||
let recent_text: String = chars[recent_start..].iter().collect();
|
||||
let total_count = recent_text.matches(pattern_str).count();
|
||||
if total_count >= self.config.occurrence_threshold && pattern_str.len() >= 3 {
|
||||
log::warn!(
|
||||
"LLM hallucination loop detected: pattern {:?} appears {} times in recent {} chars",
|
||||
pattern_str,
|
||||
total_count,
|
||||
self.config.recent_text_window
|
||||
);
|
||||
self.detected = true;
|
||||
self.detected_pattern = Some(format!("{} ({}x)", pattern_str, total_count));
|
||||
return self.detected_pattern.clone();
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Combined check: both token and accumulated text
|
||||
/// Returns true if hallucination detected
|
||||
pub fn check(&mut self, token: &str, accumulated_text: &str) -> bool {
|
||||
// First check token repetition
|
||||
if self.check_token(token) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Then check accumulated text for patterns
|
||||
if self.check_text(accumulated_text).is_some() {
|
||||
return true;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Reset the detector state (for new conversations)
|
||||
pub fn reset(&mut self) {
|
||||
self.last_content_hash = 0;
|
||||
self.identical_count = 0;
|
||||
self.detected = false;
|
||||
self.detected_pattern = None;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_identical_token_detection() {
|
||||
let mut detector = HallucinationDetector::default();
|
||||
|
||||
// Same token repeated
|
||||
for _ in 0..9 {
|
||||
assert!(!detector.check_token("GBJ2KP"));
|
||||
}
|
||||
// 10th repetition should trigger
|
||||
assert!(detector.check_token("GBJ2KP"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pattern_repetition() {
|
||||
let mut detector = HallucinationDetector::default();
|
||||
|
||||
// Build text with repeated pattern
|
||||
let repeated = "XYZ123 ".repeat(6);
|
||||
let result = detector.check_text(&repeated);
|
||||
|
||||
assert!(result.is_some());
|
||||
assert!(detector.is_detected());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normal_text_not_detected() {
|
||||
let mut detector = HallucinationDetector::default();
|
||||
|
||||
let normal_text = "This is a normal response without any repetition patterns. \
|
||||
The LLM is generating coherent text that makes sense.";
|
||||
|
||||
assert!(!detector.check_token("normal"));
|
||||
assert!(detector.check_text(normal_text).is_none());
|
||||
assert!(!detector.is_detected());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reset() {
|
||||
let mut detector = HallucinationDetector::default();
|
||||
|
||||
// Trigger detection
|
||||
for _ in 0..10 {
|
||||
detector.check_token("REPEAT");
|
||||
}
|
||||
assert!(detector.is_detected());
|
||||
|
||||
// Reset
|
||||
detector.reset();
|
||||
assert!(!detector.is_detected());
|
||||
assert!(detector.get_detected_pattern().is_none());
|
||||
}
|
||||
}
|
||||
|
|
@ -9,6 +9,7 @@ pub mod cache;
|
|||
pub mod claude;
|
||||
pub mod episodic_memory;
|
||||
pub mod glm;
|
||||
pub mod hallucination_detector;
|
||||
pub mod llm_models;
|
||||
pub mod local;
|
||||
pub mod rate_limiter;
|
||||
|
|
|
|||
|
|
@ -625,7 +625,7 @@ async fn process_incoming_message(
|
|||
}
|
||||
}
|
||||
|
||||
let (session, is_new) = find_or_create_session(&state, bot_id, &phone, &name).await?;
|
||||
let (session, is_new) = find_or_create_session(&state, &effective_bot_id, &phone, &name).await?;
|
||||
|
||||
let needs_human = check_needs_human(&session);
|
||||
|
||||
|
|
@ -1104,16 +1104,31 @@ async fn route_to_bot(
|
|||
if let Err(e) = adapter.send_message(wa_response).await {
|
||||
log::error!("Failed to send WhatsApp response part: {}", e);
|
||||
}
|
||||
|
||||
// Small delay between parts to avoid rate limiting
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
|
||||
// Rate limiting is handled by WhatsAppAdapter::send_whatsapp_message
|
||||
}
|
||||
|
||||
// Use the shared LLM hallucination detector
|
||||
let mut hallucination_detector = crate::llm::hallucination_detector::HallucinationDetector::default();
|
||||
|
||||
while let Some(response) = rx.recv().await {
|
||||
let is_final = response.is_complete;
|
||||
|
||||
if !response.content.is_empty() {
|
||||
buffer.push_str(&response.content);
|
||||
|
||||
// Check for hallucination using the shared LLM detector
|
||||
if hallucination_detector.check(&response.content, &buffer) {
|
||||
warn!(
|
||||
"WA hallucination detected: {:?}, stopping stream",
|
||||
hallucination_detector.get_detected_pattern()
|
||||
);
|
||||
// Send what we have and stop
|
||||
if !buffer.trim().is_empty() {
|
||||
let clean_buffer = buffer.trim_end();
|
||||
send_part(&adapter_for_send, &phone, clean_buffer.to_string(), true).await;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// IMPROVED LOGIC:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue