revert: restore working LLM streaming code from 260a13e7
All checks were successful
BotServer CI/CD / build (push) Successful in 11m59s
All checks were successful
BotServer CI/CD / build (push) Successful in 11m59s
The recent LLM changes (timeouts, tool-call accumulation, extra logging) broke the WebSocket message flow, so this commit reverts src/llm/mod.rs to the known working version from 260a13e7.
This commit is contained in:
parent
301a7dda33
commit
ed3406dd80
1 changed file with 43 additions and 178 deletions
221
src/llm/mod.rs
221
src/llm/mod.rs
|
|
@ -3,7 +3,6 @@ use futures::StreamExt;
|
|||
use log::{error, info};
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
|
||||
pub mod cache;
|
||||
|
|
@ -11,10 +10,8 @@ pub mod claude;
|
|||
pub mod episodic_memory;
|
||||
pub mod glm;
|
||||
pub mod hallucination_detector;
|
||||
pub mod kimi;
|
||||
pub mod llm_models;
|
||||
pub mod local;
|
||||
pub mod observability;
|
||||
pub mod rate_limiter;
|
||||
pub mod smart_router;
|
||||
pub mod vertex;
|
||||
|
|
@ -22,7 +19,6 @@ pub mod bedrock;
|
|||
|
||||
pub use claude::ClaudeClient;
|
||||
pub use glm::GLMClient;
|
||||
pub use kimi::KimiClient;
|
||||
pub use llm_models::get_handler;
|
||||
pub use rate_limiter::{ApiRateLimiter, RateLimits};
|
||||
pub use vertex::VertexTokenManager;
|
||||
|
|
@ -173,12 +169,7 @@ impl OpenAIClient {
|
|||
let has_chat_path = !has_v1_path && trimmed_base.contains("/chat/completions");
|
||||
|
||||
let endpoint = if let Some(path) = endpoint_path {
|
||||
// If endpoint path is already in base URL, use empty to avoid duplication
|
||||
if trimmed_base.contains(&path) {
|
||||
"".to_string()
|
||||
} else {
|
||||
path
|
||||
}
|
||||
path
|
||||
} else if has_v1_path || (has_chat_path && !trimmed_base.contains("z.ai")) {
|
||||
// Path already in base_url, use empty endpoint
|
||||
"".to_string()
|
||||
|
|
@ -199,11 +190,7 @@ impl OpenAIClient {
|
|||
};
|
||||
|
||||
Self {
|
||||
client: reqwest::Client::builder()
|
||||
.connect_timeout(Duration::from_secs(5))
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()
|
||||
.unwrap_or_else(|_| reqwest::Client::new()),
|
||||
client: reqwest::Client::new(),
|
||||
base_url: base,
|
||||
endpoint_path: endpoint,
|
||||
rate_limiter: Arc::new(rate_limiter),
|
||||
|
|
@ -227,39 +214,21 @@ impl OpenAIClient {
|
|||
history: &[(String, String)],
|
||||
) -> Value {
|
||||
let mut messages = Vec::new();
|
||||
|
||||
let mut system_parts = Vec::new();
|
||||
if !system_prompt.is_empty() {
|
||||
system_parts.push(Self::sanitize_utf8(system_prompt));
|
||||
}
|
||||
if !context_data.is_empty() {
|
||||
system_parts.push(Self::sanitize_utf8(context_data));
|
||||
}
|
||||
|
||||
for (role, content) in history {
|
||||
if role == "episodic" || role == "compact" {
|
||||
system_parts.push(format!("[Previous conversation summary]: {}", Self::sanitize_utf8(content)));
|
||||
}
|
||||
}
|
||||
|
||||
if !system_parts.is_empty() {
|
||||
messages.push(serde_json::json!({
|
||||
"role": "system",
|
||||
"content": system_parts.join("\n\n")
|
||||
"content": Self::sanitize_utf8(system_prompt)
|
||||
}));
|
||||
}
|
||||
|
||||
for (role, content) in history {
|
||||
let normalized_role = match role.as_str() {
|
||||
"user" => "user",
|
||||
"assistant" => "assistant",
|
||||
"system" => "system",
|
||||
"episodic" | "compact" => continue, // Already added to system prompt
|
||||
_ => "user", // Fallback Default
|
||||
};
|
||||
|
||||
if !context_data.is_empty() {
|
||||
messages.push(serde_json::json!({
|
||||
"role": normalized_role,
|
||||
"role": "system",
|
||||
"content": Self::sanitize_utf8(context_data)
|
||||
}));
|
||||
}
|
||||
for (role, content) in history {
|
||||
messages.push(serde_json::json!({
|
||||
"role": role,
|
||||
"content": Self::sanitize_utf8(content)
|
||||
}));
|
||||
}
|
||||
|
|
@ -326,11 +295,7 @@ impl LLMProvider for OpenAIClient {
|
|||
.json(&serde_json::json!({
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"stream": false,
|
||||
"max_tokens": 131072,
|
||||
"extra_body": {
|
||||
"reasoning_split": false
|
||||
}
|
||||
"stream": false
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
|
@ -338,12 +303,6 @@ impl LLMProvider for OpenAIClient {
|
|||
let status = response.status();
|
||||
if status != reqwest::StatusCode::OK {
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
|
||||
// Handle 429 rate limit with user-friendly message
|
||||
if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
|
||||
return Ok("Server is busy, please try again in a few seconds...".to_string());
|
||||
}
|
||||
|
||||
error!("LLM generate error: {}", error_text);
|
||||
return Err(format!("LLM request failed with status: {}", status).into());
|
||||
}
|
||||
|
|
@ -423,11 +382,7 @@ impl LLMProvider for OpenAIClient {
|
|||
let mut request_body = serde_json::json!({
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"stream": true,
|
||||
"max_tokens": 131072,
|
||||
"extra_body": {
|
||||
"reasoning_split": false
|
||||
}
|
||||
"stream": true
|
||||
});
|
||||
|
||||
// Add tools to the request if provided
|
||||
|
|
@ -438,92 +393,45 @@ impl LLMProvider for OpenAIClient {
|
|||
}
|
||||
}
|
||||
|
||||
info!("LLM: Sending request to {}", full_url);
|
||||
let response = self
|
||||
.client
|
||||
.post(&full_url)
|
||||
.header("Authorization", &auth_header)
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await?;
|
||||
info!("LLM: Response received with status {}", response.status());
|
||||
let response = self
|
||||
.client
|
||||
.post(&full_url)
|
||||
.header("Authorization", &auth_header)
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if status != reqwest::StatusCode::OK {
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
|
||||
// Handle 429 rate limit with user-friendly message
|
||||
if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
|
||||
let _ = tx.send("Server is busy, please try again in a few seconds...".to_string()).await;
|
||||
return Err("Rate limit exceeded".into());
|
||||
}
|
||||
|
||||
error!("LLM generate_stream error: {}", error_text);
|
||||
return Err(format!("LLM request failed with status: {}", status).into());
|
||||
}
|
||||
|
||||
let handler = get_handler(model);
|
||||
let mut stream = response.bytes_stream();
|
||||
let handler = get_handler(model);
|
||||
let mut stream = response.bytes_stream();
|
||||
|
||||
// Accumulate tool calls here because OpenAI streams them in fragments
|
||||
let mut active_tool_calls: Vec<serde_json::Value> = Vec::new();
|
||||
let mut chunk_count = 0u32;
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
chunk_count += 1;
|
||||
if chunk_count <= 5 || chunk_count % 50 == 0 {
|
||||
info!("LLM: Received chunk #{}", chunk_count);
|
||||
}
|
||||
let chunk = chunk_result?;
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = chunk_result?;
|
||||
let chunk_str = String::from_utf8_lossy(&chunk);
|
||||
for line in chunk_str.lines() {
|
||||
if line.starts_with("data: ") && !line.contains("[DONE]") {
|
||||
if let Ok(data) = serde_json::from_str::<Value>(&line[6..]) {
|
||||
let content = data["choices"][0]["delta"]["content"].as_str();
|
||||
|
||||
// TEMP DISABLED: Thinking detection causing deadlock issues
|
||||
// Just pass content through directly without any thinking detection
|
||||
|
||||
if let Some(text) = content {
|
||||
let processed = handler.process_content(text);
|
||||
if !processed.is_empty() {
|
||||
let _ = tx.send(processed).await;
|
||||
}
|
||||
}
|
||||
if let Ok(data) = serde_json::from_str::<Value>(&line[6..]) {
|
||||
if let Some(content) = data["choices"][0]["delta"]["content"].as_str() {
|
||||
let processed = handler.process_content(content);
|
||||
if !processed.is_empty() {
|
||||
let _ = tx.send(processed).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle standard OpenAI tool_calls
|
||||
if let Some(tool_calls) = data["choices"][0]["delta"]["tool_calls"].as_array() {
|
||||
for tool_delta in tool_calls {
|
||||
if let Some(index) = tool_delta["index"].as_u64() {
|
||||
let idx = index as usize;
|
||||
if active_tool_calls.len() <= idx {
|
||||
active_tool_calls.resize(idx + 1, serde_json::json!({
|
||||
"id": "",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "",
|
||||
"arguments": ""
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
let current = &mut active_tool_calls[idx];
|
||||
|
||||
if let Some(id) = tool_delta["id"].as_str() {
|
||||
current["id"] = serde_json::Value::String(id.to_string());
|
||||
}
|
||||
|
||||
if let Some(func) = tool_delta.get("function") {
|
||||
if let Some(name) = func.get("name").and_then(|n| n.as_str()) {
|
||||
current["function"]["name"] = serde_json::Value::String(name.to_string());
|
||||
}
|
||||
if let Some(args) = func.get("arguments").and_then(|a| a.as_str()) {
|
||||
if let Some(existing_args) = current["function"]["arguments"].as_str() {
|
||||
let mut new_args = existing_args.to_string();
|
||||
new_args.push_str(args);
|
||||
current["function"]["arguments"] = serde_json::Value::String(new_args);
|
||||
}
|
||||
}
|
||||
for tool_call in tool_calls {
|
||||
// We send the tool_call object as a JSON string so stream_response
|
||||
// can buffer it and parse it using ToolExecutor::parse_tool_call
|
||||
if let Some(func) = tool_call.get("function") {
|
||||
if let Some(args) = func.get("arguments").and_then(|a| a.as_str()) {
|
||||
let _ = tx.send(args.to_string()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -532,20 +440,8 @@ let content = data["choices"][0]["delta"]["content"].as_str();
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send accumulated tool calls when stream finishes
|
||||
for tool_call in active_tool_calls {
|
||||
if !tool_call["function"]["name"].as_str().unwrap_or("").is_empty() {
|
||||
let tool_call_json = serde_json::json!({
|
||||
"type": "tool_call",
|
||||
"content": tool_call
|
||||
}).to_string();
|
||||
let _ = tx.send(tool_call_json).await;
|
||||
}
|
||||
}
|
||||
|
||||
info!("LLM: Stream complete, sent {} chunks", chunk_count);
|
||||
Ok(())
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cancel_job(
|
||||
|
|
@ -567,8 +463,6 @@ pub enum LLMProviderType {
|
|||
Claude,
|
||||
AzureClaude,
|
||||
GLM,
|
||||
Kimi,
|
||||
OSS,
|
||||
Bedrock,
|
||||
Vertex,
|
||||
}
|
||||
|
|
@ -582,19 +476,12 @@ impl From<&str> for LLMProviderType {
|
|||
} else {
|
||||
Self::Claude
|
||||
}
|
||||
} else if lower.contains("kimi") || lower.contains("moonshot") {
|
||||
Self::Kimi
|
||||
} else if lower.contains("z.ai") || lower.contains("glm") {
|
||||
Self::GLM
|
||||
} else if lower.contains("gpt-oss") {
|
||||
Self::OSS
|
||||
} else if lower.contains("bedrock") {
|
||||
Self::Bedrock
|
||||
} else if lower.contains("googleapis.com") || lower.contains("vertex") || lower.contains("generativelanguage") {
|
||||
Self::Vertex
|
||||
} else if lower.contains("nvidia") {
|
||||
// Default NVIDIA provider: GLM for now, can be overridden
|
||||
Self::GLM
|
||||
} else {
|
||||
Self::OpenAI
|
||||
}
|
||||
|
|
@ -632,18 +519,6 @@ pub fn create_llm_provider(
|
|||
info!("Creating GLM/z.ai LLM provider with URL: {}", base_url);
|
||||
std::sync::Arc::new(GLMClient::new(base_url))
|
||||
}
|
||||
LLMProviderType::Kimi => {
|
||||
info!("Creating Kimi K2.5 LLM provider with URL: {}", base_url);
|
||||
std::sync::Arc::new(KimiClient::new(base_url))
|
||||
}
|
||||
LLMProviderType::OSS => {
|
||||
info!("Creating GPT-OSS LLM provider with URL: {}", base_url);
|
||||
std::sync::Arc::new(OpenAIClient::new(
|
||||
"empty".to_string(),
|
||||
Some(base_url),
|
||||
endpoint_path,
|
||||
))
|
||||
}
|
||||
LLMProviderType::Bedrock => {
|
||||
info!("Creating Bedrock LLM provider with exact URL: {}", base_url);
|
||||
std::sync::Arc::new(BedrockClient::new(base_url))
|
||||
|
|
@ -657,27 +532,17 @@ pub fn create_llm_provider(
|
|||
}
|
||||
|
||||
/// Create LLM provider from URL with optional explicit provider type override.
|
||||
/// Priority: explicit_provider > model name > URL-based detection.
|
||||
/// If explicit_provider is Some, it takes precedence over URL-based detection.
|
||||
pub fn create_llm_provider_from_url(
|
||||
url: &str,
|
||||
model: Option<String>,
|
||||
endpoint_path: Option<String>,
|
||||
explicit_provider: Option<LLMProviderType>,
|
||||
) -> std::sync::Arc<dyn LLMProvider> {
|
||||
let provider_type = if let Some(p) = explicit_provider {
|
||||
info!("Using explicit LLM provider type: {:?} for URL: {}", p, url);
|
||||
p
|
||||
} else if let Some(ref m) = model {
|
||||
// Model name takes precedence over URL
|
||||
let detected = LLMProviderType::from(m.as_str());
|
||||
info!("Detected LLM provider type from model '{}': {:?}", m, detected);
|
||||
detected
|
||||
} else {
|
||||
// Fall back to URL-based detection
|
||||
let detected = LLMProviderType::from(url);
|
||||
info!("Detected LLM provider type from URL: {:?}", detected);
|
||||
detected
|
||||
};
|
||||
let provider_type = explicit_provider.as_ref().map(|p| *p).unwrap_or_else(|| LLMProviderType::from(url));
|
||||
if explicit_provider.is_some() {
|
||||
info!("Using explicit LLM provider type: {:?} for URL: {}", provider_type, url);
|
||||
}
|
||||
create_llm_provider(provider_type, url.to_string(), model, endpoint_path)
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue