diff --git a/src/llm/hallucination_detector.rs b/src/llm/hallucination_detector.rs index 43d6fa78..40f3625c 100644 --- a/src/llm/hallucination_detector.rs +++ b/src/llm/hallucination_detector.rs @@ -1,67 +1,36 @@ -//! Hallucination Loop Detector +//! Simple Hallucination Loop Detector //! -//! Detects when an LLM gets stuck in a repetition loop (hallucination). -//! This module provides detection for all channels (web, WhatsApp, Telegram, etc.). +//! Detects when an LLM gets stuck in a repetition loop. +//! By default, triggers once the same pattern has been checked 50+ times in total (not necessarily consecutively) within 60 seconds of its first occurrence. -use std::collections::hash_map::DefaultHasher; -use std::hash::{Hash, Hasher}; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::Mutex; +use log::warn; + +const THRESHOLD: usize = 50; +const WINDOW: Duration = Duration::from_secs(60); -/// Configuration for hallucination detection #[derive(Debug, Clone)] pub struct HallucinationConfig { - /// Minimum text length before detection starts - pub min_text_length: usize, - /// Pattern lengths to check (in characters) - pub pattern_lengths: Vec, - /// Number of consecutive repetitions to trigger detection - pub consecutive_threshold: usize, - /// Number of total occurrences in recent text to trigger detection - pub occurrence_threshold: usize, - /// Recent text window size for occurrence counting - pub recent_text_window: usize, - /// Number of identical tokens to trigger detection - pub identical_token_threshold: usize, - /// Common words to ignore (won't trigger detection when repeated) - pub ignore_words: Vec, + pub threshold: usize, + pub window: Duration, } -/// Default list of common words that shouldn't trigger hallucination detection -const DEFAULT_IGNORE_WORDS: &[&str] = &[ - "the", "a", "an", "is", "are", "was", "were", "be", "been", "being", - "have", "has", "had", "do", "does", "did", "will", "would", "could", "should", - "may", "might", "must", "shall", "can", "need", "dare", 
"ought", "used", - "to", "of", "in", "for", "on", "with", "at", "by", "from", "as", - "into", "through", "during", "before", "after", "above", "below", "between", - "and", "but", "or", "nor", "so", "yet", "both", "either", "neither", - "not", "only", "own", "same", "than", "too", "very", "just", - "de", "da", "do", "das", "dos", "e", "é", "em", "no", "na", "nos", "nas", - "para", "por", "com", "sem", "sobre", "entre", "após", "antes", "depois", - "que", "se", "ou", "mas", "porém", "como", "assim", "também", "ainda", - "um", "uma", "uns", "umas", "o", "a", "os", "as", -]; - impl Default for HallucinationConfig { fn default() -> Self { Self { - min_text_length: 50, - pattern_lengths: vec![3, 4, 5, 6, 8, 10, 15, 20], - consecutive_threshold: 10, // Increased from 5 to 10 - occurrence_threshold: 15, // Increased from 8 to 15 - recent_text_window: 500, - identical_token_threshold: 15, // Increased from 10 to 15 - ignore_words: DEFAULT_IGNORE_WORDS.iter().map(|s| s.to_string()).collect(), + threshold: THRESHOLD, + window: WINDOW, } } } -/// State for tracking hallucination during streaming -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct HallucinationDetector { config: HallucinationConfig, - last_content_hash: u64, - identical_count: usize, - detected: bool, - detected_pattern: Option, + pattern_counts: Arc>>, } impl Default for HallucinationDetector { @@ -71,166 +40,41 @@ impl Default for HallucinationDetector { } impl HallucinationDetector { - /// Create a new detector with custom configuration pub fn new(config: HallucinationConfig) -> Self { Self { config, - last_content_hash: 0, - identical_count: 0, - detected: false, - detected_pattern: None, + pattern_counts: Arc::new(Mutex::new(HashMap::new())), } } - /// Check if hallucination has been detected - pub fn is_detected(&self) -> bool { - self.detected + /// Check if a pattern is hallucinating (repeating 50+ times) + pub async fn check(&self, pattern: &str) -> bool { + if pattern.trim().is_empty() || 
pattern.len() < 3 { + return false; + } + + let mut counts = self.pattern_counts.lock().await; + let now = Instant::now(); + + // Clean old entries + counts.retain(|_, (_, time)| now.duration_since(*time) < self.config.window); + + // Increment count for this pattern + let (count, _) = counts.entry(pattern.to_string()).or_insert((0, now)); + *count += 1; + + if *count >= self.config.threshold { + warn!("Hallucination detected: pattern {:?} repeated {} times", pattern, count); + true + } else { + false + } } - /// Get the detected pattern if any - pub fn get_detected_pattern(&self) -> Option<&str> { - self.detected_pattern.as_deref() - } - - /// Get the detected pattern as owned String - pub fn get_detected_pattern_owned(&self) -> Option { - self.detected_pattern.clone() - } - - /// Check a new token/chunk for hallucination patterns - /// Returns true if hallucination is detected - pub fn check_token(&mut self, token: &str) -> bool { - if self.detected { - return true; - } - - // Check for identical token repetition - if !token.trim().is_empty() { - let mut hasher = DefaultHasher::new(); - token.hash(&mut hasher); - let content_hash = hasher.finish(); - - if content_hash == self.last_content_hash { - self.identical_count += 1; - if self.identical_count >= self.config.identical_token_threshold { - log::warn!( - "LLM hallucination detected: identical token repeated {} times: {:?}", - self.identical_count, - token - ); - self.detected = true; - self.detected_pattern = Some(format!("{} ({}x)", token.trim(), self.identical_count)); - return true; - } - } else { - self.identical_count = 0; - } - self.last_content_hash = content_hash; - } - - false - } - - /// Check accumulated text for repetition patterns - /// Returns Some(pattern) if hallucination is detected - pub fn check_text(&mut self, text: &str) -> Option { - if self.detected { - return self.detected_pattern.clone(); - } - - // Skip detection for short texts - if text.len() < self.config.min_text_length { - return 
None; - } - - // Check for repeated patterns of various lengths - for pattern_len in &self.config.pattern_lengths { - if text.len() < *pattern_len * 5 { - continue; - } - - // Get the last pattern to check - let chars: Vec = text.chars().collect(); - let start = chars.len().saturating_sub(*pattern_len); - let pattern: String = chars[start..].iter().collect(); - let pattern_str = pattern.trim(); - - if pattern_str.is_empty() || pattern_str.len() < 2 { - continue; - } - - // Ignore common Markdown separators - if pattern_str == "---" || pattern_str == "***" || pattern_str == "___" { - continue; - } - - // Count how many times this pattern appears consecutively at the end - let mut count = 0; - let mut search_text = text; - - while search_text.ends_with(pattern_str) || search_text.ends_with(&pattern) { - count += 1; - if count >= self.config.consecutive_threshold { - // Found threshold repetitions - likely hallucination - log::warn!( - "LLM hallucination loop detected: pattern {:?} repeated {} times consecutively", - pattern_str, - count - ); - self.detected = true; - self.detected_pattern = Some(pattern_str.to_string()); - return self.detected_pattern.clone(); - } - // Remove one occurrence and continue checking - if search_text.ends_with(pattern_str) { - search_text = &search_text[..search_text.len().saturating_sub(pattern_str.len())]; - } else { - search_text = &search_text[..search_text.len().saturating_sub(pattern.len())]; - } - } - - // Alternative: count total occurrences in recent text - let recent_start = chars.len().saturating_sub(self.config.recent_text_window); - let recent_text: String = chars[recent_start..].iter().collect(); - let total_count = recent_text.matches(pattern_str).count(); - if total_count >= self.config.occurrence_threshold && pattern_str.len() >= 3 { - log::warn!( - "LLM hallucination loop detected: pattern {:?} appears {} times in recent {} chars", - pattern_str, - total_count, - self.config.recent_text_window - ); - self.detected = 
true; - self.detected_pattern = Some(format!("{} ({}x)", pattern_str, total_count)); - return self.detected_pattern.clone(); - } - } - - None - } - - /// Combined check: both token and accumulated text - /// Returns true if hallucination detected - pub fn check(&mut self, token: &str, accumulated_text: &str) -> bool { - // First check token repetition - if self.check_token(token) { - return true; - } - - // Then check accumulated text for patterns - if self.check_text(accumulated_text).is_some() { - return true; - } - - false - } - - /// Reset the detector state (for new conversations) - pub fn reset(&mut self) { - self.last_content_hash = 0; - self.identical_count = 0; - self.detected = false; - self.detected_pattern = None; + /// Reset all counts + pub async fn reset(&self) { + let mut counts = self.pattern_counts.lock().await; + counts.clear(); } } @@ -238,55 +82,30 @@ impl HallucinationDetector { mod tests { use super::*; - #[test] - fn test_identical_token_detection() { - let mut detector = HallucinationDetector::default(); - - // Same token repeated - for _ in 0..9 { - assert!(!detector.check_token("GBJ2KP")); + #[tokio::test] + async fn test_no_hallucination_below_threshold() { + let detector = HallucinationDetector::default(); + for _ in 0..49 { + assert!(!detector.check("test_pattern").await); } - // 10th repetition should trigger - assert!(detector.check_token("GBJ2KP")); } - #[test] - fn test_pattern_repetition() { - let mut detector = HallucinationDetector::default(); - - // Build text with repeated pattern - let repeated = "XYZ123 ".repeat(6); - let result = detector.check_text(&repeated); - - assert!(result.is_some()); - assert!(detector.is_detected()); - } - - #[test] - fn test_normal_text_not_detected() { - let mut detector = HallucinationDetector::default(); - - let normal_text = "This is a normal response without any repetition patterns. 
\ - The LLM is generating coherent text that makes sense."; - - assert!(!detector.check_token("normal")); - assert!(detector.check_text(normal_text).is_none()); - assert!(!detector.is_detected()); - } - - #[test] - fn test_reset() { - let mut detector = HallucinationDetector::default(); - - // Trigger detection - for _ in 0..10 { - detector.check_token("REPEAT"); + #[tokio::test] + async fn test_hallucination_at_threshold() { + let detector = HallucinationDetector::default(); + for _ in 0..50 { + detector.check("test_pattern").await; } - assert!(detector.is_detected()); + assert!(detector.check("test_pattern").await); + } - // Reset - detector.reset(); - assert!(!detector.is_detected()); - assert!(detector.get_detected_pattern().is_none()); + #[tokio::test] + async fn test_reset() { + let detector = HallucinationDetector::default(); + for _ in 0..50 { + detector.check("pattern").await; + } + detector.reset().await; + assert!(!detector.check("pattern").await); } } diff --git a/src/whatsapp/mod.rs b/src/whatsapp/mod.rs index 94a17ee0..25b72b9e 100644 --- a/src/whatsapp/mod.rs +++ b/src/whatsapp/mod.rs @@ -1107,21 +1107,16 @@ async fn route_to_bot( // Rate limiting is handled by WhatsAppAdapter::send_whatsapp_message } - // Use the shared LLM hallucination detector - let mut hallucination_detector = crate::llm::hallucination_detector::HallucinationDetector::default(); + // Use the shared LLM hallucination detector (simple: 50+ repetitions = hallucination) + let detector = crate::llm::hallucination_detector::HallucinationDetector::default(); while let Some(response) = rx.recv().await { let is_final = response.is_complete; if !response.content.is_empty() { - buffer.push_str(&response.content); - - // Check for hallucination using the shared LLM detector - if hallucination_detector.check(&response.content, &buffer) { - warn!( - "WA hallucination detected: {:?}, stopping stream", - hallucination_detector.get_detected_pattern() - ); + // Check for hallucination (50+ 
repetitions of same pattern) + if detector.check(&response.content).await { + warn!("WA hallucination detected: {:?}, stopping stream", response.content); // Send what we have and stop if !buffer.trim().is_empty() { let clean_buffer = buffer.trim_end();