From 81c60ceb25e71119786e50c6c0df7cd35ab76910 Mon Sep 17 00:00:00 2001
From: "Rodrigo Rodriguez (Pragmatismo)"
Date: Mon, 13 Apr 2026 17:36:31 -0300
Subject: [PATCH] feat: add Kimi client and GLM thinking mode support, fix tool
 exec direct return

---
 src/core/bot/mod.rs |  42 +++---
 src/llm/glm.rs      |  81 +++++++---
 src/llm/kimi.rs     | 356 ++++++++++++++++++++++++++++++++++++++++++++
 src/llm/mod.rs      |  45 +++++-
 4 files changed, 471 insertions(+), 53 deletions(-)
 create mode 100644 src/llm/kimi.rs

diff --git a/src/core/bot/mod.rs b/src/core/bot/mod.rs
index 09f1a009..d8de1f23 100644
--- a/src/core/bot/mod.rs
+++ b/src/core/bot/mod.rs
@@ -404,32 +404,24 @@ impl BotOrchestrator {
                 format!("Erro ao executar '{}': {}", tool_name, tool_result.error.unwrap_or_default())
             };
 
-            let trimmed_result = response_content.trim();
-            let has_kb = crate::basic::keywords::use_kb::get_active_kbs_for_session(
-                &self.state.conn, session_id
-            ).unwrap_or_default();
+            // Direct tool execution — return result immediately, no LLM call
+            let final_response = BotResponse {
+                bot_id: message.bot_id.clone(),
+                user_id: message.user_id.clone(),
+                session_id: message.session_id.clone(),
+                channel: message.channel.clone(),
+                content: response_content,
+                message_type: MessageType::BOT_RESPONSE,
+                stream_token: None,
+                is_complete: true,
+                suggestions: vec![],
+                context_name: None,
+                context_length: 0,
+                context_max_length: 0,
+            };
 
-            if !has_kb.is_empty() && (trimmed_result.is_empty() || trimmed_result.eq_ignore_ascii_case(tool_name)) {
-                info!("[TOOL_EXEC] Tool '{}' activated KB, falling through to LLM pipeline", tool_name);
-            } else {
-                let final_response = BotResponse {
-                    bot_id: message.bot_id.clone(),
-                    user_id: message.user_id.clone(),
-                    session_id: message.session_id.clone(),
-                    channel: message.channel.clone(),
-                    content: response_content,
-                    message_type: MessageType::BOT_RESPONSE,
-                    stream_token: None,
-                    is_complete: true,
-                    suggestions: vec![],
-                    context_name: None,
-                    context_length: 0,
-                    context_max_length: 0,
-                };
-
-                let _ = response_tx.send(final_response).await;
-                return Ok(());
-            }
+            let _ = response_tx.send(final_response).await;
+            return Ok(());
         }
     }
diff --git a/src/llm/glm.rs b/src/llm/glm.rs
index 75334b96..61664d13 100644
--- a/src/llm/glm.rs
+++ b/src/llm/glm.rs
@@ -31,9 +31,25 @@ pub struct GLMRequest {
     #[serde(skip_serializing_if = "Option::is_none")]
     pub temperature: Option<f32>,
     #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_p: Option<f32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
     pub tools: Option<Vec<Value>>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub tool_choice: Option<Value>,
+    #[serde(rename = "extra_body", skip_serializing_if = "Option::is_none")]
+    pub extra_body: Option<GLMExtraBody>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct GLMExtraBody {
+    #[serde(rename = "chat_template_kwargs")]
+    pub chat_template_kwargs: GLMChatTemplateKwargs,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct GLMChatTemplateKwargs {
+    pub enable_thinking: bool,
+    pub clear_thinking: bool,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -140,17 +156,28 @@ impl LLMProvider for GLMClient {
             tool_calls: None,
         }];
 
-        // Use glm-4.7 instead of glm-4 for z.ai API
-        let model_name = if model == "glm-4" { "glm-4.7" } else { model };
+        // NVIDIA API uses z-ai/glm4.7 as the model identifier
+        let model_name = if model == "glm-4" || model == "glm-4.7" {
+            "z-ai/glm4.7"
+        } else {
+            model
+        };
 
         let request = GLMRequest {
             model: model_name.to_string(),
             messages,
             stream: Some(false),
             max_tokens: None,
-            temperature: None,
+            temperature: Some(1.0),
+            top_p: Some(1.0),
             tools: None,
             tool_choice: None,
+            extra_body: Some(GLMExtraBody {
+                chat_template_kwargs: GLMChatTemplateKwargs {
+                    enable_thinking: true,
+                    clear_thinking: false,
+                },
+            }),
         };
 
         let url = self.build_url();
@@ -198,15 +225,12 @@ impl LLMProvider for GLMClient {
                     let role = m.get("role")?.as_str()?;
                     let content = m.get("content")?.as_str()?;
                     let sanitized = Self::sanitize_utf8(content);
-                    if !sanitized.is_empty() {
-                        Some(GLMMessage {
-                            role: role.to_string(),
-                            content: Some(sanitized),
-                            tool_calls: None,
-                        })
-                    } else {
-                        None
-                    }
+                    // NVIDIA API accepts empty content, don't filter them out
+                    Some(GLMMessage {
+                        role: role.to_string(),
+                        content: Some(sanitized),
+                        tool_calls: None,
+                    })
                 })
                 .collect::<Vec<_>>()
         } else {
@@ -218,14 +242,18 @@ impl LLMProvider for GLMClient {
             vec![GLMMessage {
                 role: "user".to_string(),
                 content: Some(Self::sanitize_utf8(prompt)),
                 tool_calls: None,
             }]
         };
 
-        // If no messages or all empty, return error
+        // If no messages, return error
        if messages.is_empty() {
            return Err("No valid messages in request".into());
        }
 
-        // Use glm-4.7 for tool calling support
+        // NVIDIA API uses z-ai/glm4.7 as the model identifier
         // GLM-4.7 supports standard OpenAI-compatible function calling
-        let model_name = if model == "glm-4" { "glm-4.7" } else { model };
+        let model_name = if model == "glm-4" || model == "glm-4.7" {
+            "z-ai/glm4.7"
+        } else {
+            model
+        };
 
         // Set tool_choice to "auto" when tools are present - this tells GLM to automatically decide when to call a tool
         let tool_choice = if tools.is_some() {
@@ -239,9 +267,16 @@ impl LLMProvider for GLMClient {
             messages,
             stream: Some(true),
             max_tokens: None,
-            temperature: None,
+            temperature: Some(1.0),
+            top_p: Some(1.0),
             tools: tools.cloned(),
             tool_choice,
+            extra_body: Some(GLMExtraBody {
+                chat_template_kwargs: GLMChatTemplateKwargs {
+                    enable_thinking: true,
+                    clear_thinking: false,
+                },
+            }),
         };
 
         let url = self.build_url();
@@ -307,12 +342,14 @@ impl LLMProvider for GLMClient {
                                     }
                                 }
 
-                                // GLM/z.ai returns both reasoning_content (thinking) and content (response)
-                                // We only send the actual content, ignoring reasoning_content
-                                // This makes GLM behave like OpenAI-compatible APIs
-                                if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
-                                    if !content.is_empty() {
-                                        match tx.send(content.to_string()).await {
+                                // GLM-4.7 on NVIDIA sends text via reasoning_content when thinking is enabled
+                                // content may be null; we accept both fields
+                                let content = delta.get("content").and_then(|c| c.as_str())
+                                    .or_else(|| delta.get("reasoning_content").and_then(|c| c.as_str()));
+
+                                if let Some(text) = content {
+                                    if !text.is_empty() {
+                                        match tx.send(text.to_string()).await {
                                             Ok(_) => {},
                                             Err(e) => {
                                                 error!("Failed to send to channel: {}", e);
diff --git a/src/llm/kimi.rs b/src/llm/kimi.rs
new file mode 100644
index 00000000..20b60bf3
--- /dev/null
+++ b/src/llm/kimi.rs
@@ -0,0 +1,356 @@
+use async_trait::async_trait;
+use futures::StreamExt;
+use log::{error, info};
+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+use tokio::sync::mpsc;
+
+use super::LLMProvider;
+
+// Kimi K2.5 (moonshotai) API Client
+// NVIDIA endpoint with special chat_template_kwargs for thinking mode
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct KimiMessage {
+    pub role: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub content: Option<String>,
+    #[serde(default)]
+    pub tool_calls: Option<Vec<Value>>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct KimiRequest {
+    pub model: String,
+    pub messages: Vec<KimiMessage>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stream: Option<bool>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub max_tokens: Option<u32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub temperature: Option<f32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_p: Option<f32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tools: Option<Vec<Value>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_choice: Option<Value>,
+    #[serde(rename = "chat_template_kwargs", skip_serializing_if = "Option::is_none")]
+    pub chat_template_kwargs: Option<KimiChatTemplateKwargs>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct KimiChatTemplateKwargs {
+    pub thinking: bool,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct KimiResponseChoice {
+    #[serde(default)]
+    pub index: u32,
+    pub message: KimiMessage,
+    #[serde(default)]
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct KimiResponse {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub choices: Vec<KimiResponseChoice>,
+    #[serde(default)]
+    pub usage: Option<Value>,
+}
+
+// Streaming structures
+#[derive(Debug, Clone, Default, Serialize, Deserialize)]
+pub struct KimiStreamDelta {
+    #[serde(default)]
+    pub content: Option<String>,
+    #[serde(default)]
+    pub role: Option<String>,
+    #[serde(default)]
+    pub tool_calls: Option<Vec<Value>>,
+    #[serde(default)]
+    pub reasoning_content: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct KimiStreamChoice {
+    #[serde(default)]
+    pub index: u32,
+    #[serde(default)]
+    pub delta: KimiStreamDelta,
+    #[serde(default)]
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct KimiStreamChunk {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub choices: Vec<KimiStreamChoice>,
+    #[serde(default)]
+    pub usage: Option<Value>,
+}
+
+#[derive(Debug)]
+pub struct KimiClient {
+    client: reqwest::Client,
+    base_url: String,
+}
+
+impl KimiClient {
+    pub fn new(base_url: String) -> Self {
+        let base = base_url.trim_end_matches('/').to_string();
+
+        Self {
+            client: reqwest::Client::new(),
+            base_url: base,
+        }
+    }
+
+    fn build_url(&self) -> String {
+        if self.base_url.contains("/chat/completions") {
+            self.base_url.clone()
+        } else {
+            format!("{}/chat/completions", self.base_url)
+        }
+    }
+
+    fn sanitize_utf8(input: &str) -> String {
+        input.chars()
+            .filter(|c| {
+                let cp = *c as u32;
+                !(0xD800..=0xDBFF).contains(&cp) && !(0xDC00..=0xDFFF).contains(&cp)
+            })
+            .collect()
+    }
+}
+
+#[async_trait]
+impl LLMProvider for KimiClient {
+    async fn generate(
+        &self,
+        prompt: &str,
+        _config: &Value,
+        model: &str,
+        key: &str,
+    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
+        let messages = vec![KimiMessage {
+            role: "user".to_string(),
+            content: Some(prompt.to_string()),
+            tool_calls: None,
+        }];
+
+        let model_name = if model == "kimi-k2.5" || model == "kimi-k2" {
+            "moonshotai/kimi-k2.5"
+        } else {
+            model
+        };
+
+        let request = KimiRequest {
+            model: model_name.to_string(),
+            messages,
+            stream: Some(false),
+            max_tokens: None,
+            temperature: Some(1.0),
+            top_p: Some(1.0),
+            tools: None,
+            tool_choice: None,
+            chat_template_kwargs: Some(KimiChatTemplateKwargs {
+                thinking: true,
+            }),
+        };
+
+        let url = self.build_url();
+        info!("Kimi non-streaming request to: {}", url);
+
+        let response = self
+            .client
+            .post(&url)
+            .header("Authorization", format!("Bearer {}", key))
+            .header("Content-Type", "application/json")
+            .header("Accept", "application/json")
+            .json(&request)
+            .send()
+            .await?;
+
+        if !response.status().is_success() {
+            let error_text = response.text().await.unwrap_or_default();
+            error!("Kimi API error: {}", error_text);
+            return Err(format!("Kimi API error: {}", error_text).into());
+        }
+
+        let kimi_response: KimiResponse = response.json().await?;
+        let content = kimi_response
+            .choices
+            .first()
+            .and_then(|c| c.message.content.clone())
+            .unwrap_or_default();
+
+        Ok(content)
+    }
+
+    async fn generate_stream(
+        &self,
+        prompt: &str,
+        config: &Value,
+        tx: mpsc::Sender<String>,
+        model: &str,
+        key: &str,
+        tools: Option<&Vec<Value>>,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+        let messages = if let Some(msgs) = config.as_array() {
+            msgs.iter()
+                .filter_map(|m| {
+                    let role = m.get("role")?.as_str()?;
+                    let content = m.get("content")?.as_str()?;
+                    let sanitized = Self::sanitize_utf8(content);
+                    Some(KimiMessage {
+                        role: role.to_string(),
+                        content: Some(sanitized),
+                        tool_calls: None,
+                    })
+                })
+                .collect::<Vec<_>>()
+        } else {
+            vec![KimiMessage {
+                role: "user".to_string(),
+                content: Some(Self::sanitize_utf8(prompt)),
+                tool_calls: None,
+            }]
+        };
+
+        if messages.is_empty() {
+            return Err("No valid messages in request".into());
+        }
+
+        let model_name = if model == "kimi-k2.5" || model == "kimi-k2" {
+            "moonshotai/kimi-k2.5"
+        } else {
+            model
+        };
+
+        let tool_choice = if tools.is_some() {
+            Some(serde_json::json!("auto"))
+        } else {
+            None
+        };
+
+        let request = KimiRequest {
+            model: model_name.to_string(),
+            messages,
+            stream: Some(true),
+            max_tokens: None,
+            temperature: Some(1.0),
+            top_p: Some(1.0),
+            tools: tools.cloned(),
+            tool_choice,
+            chat_template_kwargs: Some(KimiChatTemplateKwargs {
+                thinking: true,
+            }),
+        };
+
+        let url = self.build_url();
+        info!("Kimi streaming request to: {}", url);
+
+        let response = self
+            .client
+            .post(&url)
+            .header("Authorization", format!("Bearer {}", key))
+            .header("Content-Type", "application/json")
+            .header("Accept", "text/event-stream")
+            .json(&request)
+            .send()
+            .await?;
+
+        if !response.status().is_success() {
+            let error_text = response.text().await.unwrap_or_default();
+            error!("Kimi streaming error: {}", error_text);
+            return Err(format!("Kimi streaming error: {}", error_text).into());
+        }
+
+        let mut stream = response.bytes_stream();
+
+        let mut buffer = Vec::new();
+        while let Some(chunk_result) = stream.next().await {
+            let chunk = chunk_result.map_err(|e| format!("Stream error: {}", e))?;
+
+            buffer.extend_from_slice(&chunk);
+            let data = String::from_utf8_lossy(&buffer);
+
+            for line in data.lines() {
+                let line = line.trim();
+
+                if line.is_empty() {
+                    continue;
+                }
+
+                if line == "data: [DONE]" {
+                    std::mem::drop(tx.send(String::new()));
+                    return Ok(());
+                }
+
+                if let Some(json_str) = line.strip_prefix("data: ") {
+                    let json_str = json_str.trim();
+                    if let Ok(chunk_data) = serde_json::from_str::<Value>(json_str) {
+                        if let Some(choices) = chunk_data.get("choices").and_then(|c| c.as_array()) {
+                            for choice in choices {
+                                if let Some(delta) = choice.get("delta") {
+                                    if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
+                                        for tool_call in tool_calls {
+                                            let tool_call_json = serde_json::json!({
+                                                "type": "tool_call",
+                                                "content": tool_call
+                                            }).to_string();
+                                            let _ = tx.send(tool_call_json).await;
+                                        }
+                                    }
+
+                                    // Kimi K2.5 sends text via reasoning_content (thinking mode)
+                                    // content may be null; we accept both fields
+                                    let text = delta.get("content").and_then(|c| c.as_str())
+                                        .or_else(|| delta.get("reasoning_content").and_then(|c| c.as_str()));
+
+                                    if let Some(text) = text {
+                                        if !text.is_empty() {
+                                            let _ = tx.send(text.to_string()).await;
+                                        }
+                                    }
+                                }
+
+                                if let Some(reason) = choice.get("finish_reason").and_then(|r| r.as_str()) {
+                                    if !reason.is_empty() {
+                                        info!("Kimi stream finished: {}", reason);
+                                        std::mem::drop(tx.send(String::new()));
+                                        return Ok(());
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+
+            if let Some(last_newline) = data.rfind('\n') {
+                buffer = buffer[last_newline + 1..].to_vec();
+            }
+        }
+
+        std::mem::drop(tx.send(String::new()));
+        Ok(())
+    }
+
+    async fn cancel_job(
+        &self,
+        _session_id: &str,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+        info!("Kimi cancel requested for session {} (no-op)", _session_id);
+        Ok(())
+    }
+}
diff --git a/src/llm/mod.rs b/src/llm/mod.rs
index 5c324cfc..ac96ae40 100644
--- a/src/llm/mod.rs
+++ b/src/llm/mod.rs
@@ -10,6 +10,7 @@ pub mod claude;
 pub mod episodic_memory;
 pub mod glm;
 pub mod hallucination_detector;
+pub mod kimi;
 pub mod llm_models;
 pub mod local;
 pub mod observability;
@@ -20,6 +21,7 @@ pub mod bedrock;
 
 pub use claude::ClaudeClient;
 pub use glm::GLMClient;
+pub use kimi::KimiClient;
 pub use llm_models::get_handler;
 pub use rate_limiter::{ApiRateLimiter, RateLimits};
 pub use vertex::VertexTokenManager;
@@ -584,6 +586,8 @@ pub enum LLMProviderType {
     Claude,
     AzureClaude,
     GLM,
+    Kimi,
+    OSS,
     Bedrock,
     Vertex,
 }
@@ -597,12 +601,19 @@ impl From<&str> for LLMProviderType {
             } else {
                 Self::Claude
             }
-        } else if lower.contains("z.ai") || lower.contains("glm") || lower.contains("nvidia") {
+        } else if lower.contains("kimi") || lower.contains("moonshot") {
+            Self::Kimi
+        } else if lower.contains("z.ai") || lower.contains("glm") {
             Self::GLM
+        } else if lower.contains("gpt-oss") {
+            Self::OSS
         } else if lower.contains("bedrock") {
             Self::Bedrock
         } else if lower.contains("googleapis.com") || lower.contains("vertex") || lower.contains("generativelanguage") {
             Self::Vertex
+        } else if lower.contains("nvidia") {
+            // Default NVIDIA provider: GLM for now, can be overridden
+            Self::GLM
         } else {
             Self::OpenAI
         }
@@ -640,6 +651,18 @@ pub fn create_llm_provider(
             info!("Creating GLM/z.ai LLM provider with URL: {}", base_url);
             std::sync::Arc::new(GLMClient::new(base_url))
         }
+        LLMProviderType::Kimi => {
+            info!("Creating Kimi K2.5 LLM provider with URL: {}", base_url);
+            std::sync::Arc::new(KimiClient::new(base_url))
+        }
+        LLMProviderType::OSS => {
+            info!("Creating GPT-OSS LLM provider with URL: {}", base_url);
+            std::sync::Arc::new(OpenAIClient::new(
+                "empty".to_string(),
+                Some(base_url),
+                endpoint_path,
+            ))
+        }
         LLMProviderType::Bedrock => {
             info!("Creating Bedrock LLM provider with exact URL: {}", base_url);
             std::sync::Arc::new(BedrockClient::new(base_url))
@@ -653,17 +676,27 @@ pub fn create_llm_provider(
 }
 
 /// Create LLM provider from URL with optional explicit provider type override.
-/// If explicit_provider is Some, it takes precedence over URL-based detection.
+/// Priority: explicit_provider > model name > URL-based detection.
 pub fn create_llm_provider_from_url(
     url: &str,
     model: Option<String>,
     endpoint_path: Option<String>,
     explicit_provider: Option<LLMProviderType>,
 ) -> std::sync::Arc<dyn LLMProvider> {
-    let provider_type = explicit_provider.as_ref().map(|p| *p).unwrap_or_else(|| LLMProviderType::from(url));
-    if explicit_provider.is_some() {
-        info!("Using explicit LLM provider type: {:?} for URL: {}", provider_type, url);
-    }
+    let provider_type = if let Some(p) = explicit_provider {
+        info!("Using explicit LLM provider type: {:?} for URL: {}", p, url);
+        p
+    } else if let Some(ref m) = model {
+        // Model name takes precedence over URL
+        let detected = LLMProviderType::from(m.as_str());
+        info!("Detected LLM provider type from model '{}': {:?}", m, detected);
+        detected
+    } else {
+        // Fall back to URL-based detection
+        let detected = LLMProviderType::from(url);
+        info!("Detected LLM provider type from URL: {:?}", detected);
+        detected
+    };
 
     create_llm_provider(provider_type, url.to_string(), model, endpoint_path)
 }