#![allow(unused_imports)] #![allow(unused_variables)] #![allow(dead_code)] use serde_json; use async_trait::async_trait; use serde_json::json; use std::sync::Arc; use tokio::sync::mpsc; struct MockLLMProvider { response: String, call_count: std::sync::atomic::AtomicUsize, } impl MockLLMProvider { fn new(response: &str) -> Self { Self { response: response.to_string(), call_count: std::sync::atomic::AtomicUsize::new(0), } } fn get_call_count(&self) -> usize { self.call_count.load(std::sync::atomic::Ordering::SeqCst) } } #[async_trait] impl LLMProvider for MockLLMProvider { async fn generate( &self, _prompt: &str, _messages: &serde_json::Value, _model: &str, _key: &str, ) -> Result> { self.call_count .fetch_add(1, std::sync::atomic::Ordering::SeqCst); Ok(self.response.clone()) } async fn generate_stream( &self, _prompt: &str, _messages: &serde_json::Value, tx: mpsc::Sender, _model: &str, _key: &str, ) -> Result<(), Box> { self.call_count .fetch_add(1, std::sync::atomic::Ordering::SeqCst); let _ = tx.send(self.response.clone()).await; Ok(()) } async fn cancel_job( &self, _session_id: &str, ) -> Result<(), Box> { Ok(()) } } struct MockEmbeddingService; #[async_trait] impl EmbeddingService for MockEmbeddingService { async fn get_embedding( &self, text: &str, ) -> Result, Box> { let hash = text.bytes().fold(0u32, |acc, b| acc.wrapping_add(b as u32)); Ok(vec![hash as f32 / 255.0; 10]) } async fn compute_similarity(&self, embedding1: &[f32], embedding2: &[f32]) -> f32 { if embedding1.len() != embedding2.len() { return 0.0; } let diff: f32 = embedding1 .iter() .zip(embedding2.iter()) .map(|(a, b)| (a - b).abs()) .sum(); 1.0 - (diff / embedding1.len() as f32).min(1.0) } } #[tokio::test] async fn test_exact_cache_hit() { let mock_provider = Arc::new(MockLLMProvider::new("Test response")); let cache_client = Arc::new(redis::Client::open("redis://127.0.0.1/").unwrap()); let config = CacheConfig { ttl: 60, semantic_matching: false, similarity_threshold: 0.95, max_similarity_checks: 10, key_prefix: "test_cache".to_string(), }; let cached_provider = CachedLLMProvider::new(mock_provider.clone(), cache_client, config, None); let prompt = "What is the weather?"; let messages = json!([{"role": "user", "content": prompt}]); let model = "test-model"; let key = "test-key"; let result1 = cached_provider .generate(prompt, &messages, model, key) .await .unwrap(); assert_eq!(result1, "Test response"); assert_eq!(mock_provider.get_call_count(), 1); let result2 = cached_provider .generate(prompt, &messages, model, key) .await .unwrap(); assert_eq!(result2, "Test response"); assert_eq!(mock_provider.get_call_count(), 1); } #[tokio::test] async fn test_semantic_cache_hit() { let mock_provider = Arc::new(MockLLMProvider::new("Weather is sunny")); let cache_client = Arc::new(redis::Client::open("redis://127.0.0.1/").unwrap()); let config = CacheConfig { ttl: 60, semantic_matching: true, similarity_threshold: 0.8, max_similarity_checks: 10, key_prefix: "test_semantic".to_string(), }; let embedding_service = Arc::new(MockEmbeddingService); let cached_provider = CachedLLMProvider::new( mock_provider.clone(), cache_client, config, Some(embedding_service), ); let messages = json!([{"role": "user", "content": "test"}]); let model = "test-model"; let key = "test-key"; let result1 = cached_provider .generate("What's the weather?", &messages, model, key) .await .unwrap(); assert_eq!(result1, "Weather is sunny"); assert_eq!(mock_provider.get_call_count(), 1); let result2 = cached_provider .generate("What is the weather?", &messages, model, key) .await .unwrap(); assert_eq!(result2, "Weather is sunny"); assert_eq!(mock_provider.get_call_count(), 2); } #[tokio::test] async fn test_cache_miss_different_model() { let mock_provider = Arc::new(MockLLMProvider::new("Response")); let cache_client = Arc::new(redis::Client::open("redis://127.0.0.1/").unwrap()); let config = CacheConfig::default(); let cached_provider = CachedLLMProvider::new(mock_provider.clone(), cache_client, config, None); let prompt = "Same prompt"; let messages = json!([{"role": "user", "content": prompt}]); let key = "test-key"; let _ = cached_provider .generate(prompt, &messages, "model1", key) .await .unwrap(); assert_eq!(mock_provider.get_call_count(), 1); let _ = cached_provider .generate(prompt, &messages, "model2", key) .await .unwrap(); assert_eq!(mock_provider.get_call_count(), 2); } #[tokio::test] async fn test_cache_statistics() { let mock_provider = Arc::new(MockLLMProvider::new("Response")); let cache_client = Arc::new(redis::Client::open("redis://127.0.0.1/").unwrap()); let config = CacheConfig { ttl: 60, semantic_matching: false, similarity_threshold: 0.95, max_similarity_checks: 10, key_prefix: "test_stats".to_string(), }; let cached_provider = CachedLLMProvider::new(mock_provider, cache_client, config, None); let _ = cached_provider.clear_cache(None).await; let messages = json!([]); for i in 0..5 { let _ = cached_provider .generate(&format!("prompt_{}", i), &messages, "model", "key") .await; } for i in 0..3 { let _ = cached_provider .generate(&format!("prompt_{}", i), &messages, "model", "key") .await; } let stats = cached_provider.get_cache_stats().await.unwrap(); assert_eq!(stats.total_entries, 5); assert_eq!(stats.total_hits, 3); assert!(stats.total_size_bytes > 0); assert_eq!(stats.model_distribution.get("model"), Some(&5)); } #[tokio::test] async fn test_stream_generation_with_cache() { let mock_provider = Arc::new(MockLLMProvider::new("Streamed response")); let cache_client = Arc::new(redis::Client::open("redis://127.0.0.1/").unwrap()); let config = CacheConfig { ttl: 60, semantic_matching: false, similarity_threshold: 0.95, max_similarity_checks: 10, key_prefix: "test_stream".to_string(), }; let cached_provider = CachedLLMProvider::new(mock_provider.clone(), cache_client, config, None); let prompt = "Stream this"; let messages = json!([{"role": "user", "content": prompt}]); let model = "test-model"; let key = "test-key"; let (tx1, mut rx1) = mpsc::channel(100); cached_provider .generate_stream(prompt, &messages, tx1, model, key) .await .unwrap(); let mut result1 = String::new(); while let Some(chunk) = rx1.recv().await { result1.push_str(&chunk); } assert_eq!(result1, "Streamed response"); assert_eq!(mock_provider.get_call_count(), 1); let (tx2, mut rx2) = mpsc::channel(100); cached_provider .generate_stream(prompt, &messages, tx2, model, key) .await .unwrap(); let mut result2 = String::new(); while let Some(chunk) = rx2.recv().await { result2.push_str(&chunk); } assert!(result2.contains("Streamed response")); assert_eq!(mock_provider.get_call_count(), 1); } #[test] fn test_cosine_similarity_calculation() { let service = LocalEmbeddingService::new( "http://localhost:8082".to_string(), "test-model".to_string(), ); let vec1 = vec![0.5, 0.5, 0.5]; let vec2 = vec![0.5, 0.5, 0.5]; let similarity = tokio::runtime::Runtime::new() .unwrap() .block_on(service.compute_similarity(&vec1, &vec2)); assert_eq!(similarity, 1.0); let vec3 = vec![1.0, 0.0]; let vec4 = vec![0.0, 1.0]; let similarity = tokio::runtime::Runtime::new() .unwrap() .block_on(service.compute_similarity(&vec3, &vec4)); assert_eq!(similarity, 0.0); let vec5 = vec![1.0, 1.0]; let vec6 = vec![-1.0, -1.0]; let similarity = tokio::runtime::Runtime::new() .unwrap() .block_on(service.compute_similarity(&vec5, &vec6)); assert_eq!(similarity, -1.0); }