Fix Bedrock config for OpenAI GPT-OSS models
All checks were successful
BotServer CI / build (push) Successful in 12m35s

This commit is contained in:
Rodrigo Rodriguez (Pragmatismo) 2026-03-10 12:36:24 -03:00
parent c523cee177
commit 82bfd0a443
10 changed files with 857 additions and 44 deletions

View file

@ -23,7 +23,7 @@ impl LlmAssistConfig {
if let Ok(content) = std::fs::read_to_string(&path) {
for line in content.lines() {
let parts: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
let parts: Vec<&str> = line.splitn(2, ',').map(|s| s.trim()).collect();
if parts.len() < 2 {
continue;

View file

@ -160,30 +160,25 @@ impl BootstrapManager {
if pm.is_installed("directory") {
// Wait for Zitadel to be ready - it might have been started during installation
// Use incremental backoff: check frequently at first, then slow down
// Use very aggressive backoff for fastest startup detection
let mut directory_already_running = zitadel_health_check();
if !directory_already_running {
info!("Zitadel not responding to health check, waiting with incremental backoff...");
// Check intervals: 1s x5, 2s x5, 5s x5, 10s x3 = ~60s total
let intervals: [(u64, u32); 4] = [(1, 5), (2, 5), (5, 5), (10, 3)];
let mut total_waited: u64 = 0;
for (interval_secs, count) in intervals {
for i in 0..count {
if zitadel_health_check() {
info!("Zitadel/Directory service is now responding (waited {}s)", total_waited);
directory_already_running = true;
break;
}
sleep(Duration::from_secs(interval_secs)).await;
total_waited += interval_secs;
// Show incremental progress every ~10s
if total_waited % 10 == 0 || i == 0 {
info!("Zitadel health check: {}s elapsed, retrying...", total_waited);
}
}
if directory_already_running {
info!("Zitadel not responding to health check, waiting...");
// Check every 500ms for fast detection (was: 1s, 2s, 5s, 10s)
let mut checks = 0;
let max_checks = 120; // 60 seconds max
while checks < max_checks {
if zitadel_health_check() {
info!("Zitadel/Directory service is now responding (checked {} times)", checks);
directory_already_running = true;
break;
}
sleep(Duration::from_millis(500)).await;
checks += 1;
// Log progress every 10 checks (5 seconds)
if checks % 10 == 0 {
info!("Zitadel health check: {}s elapsed, retrying...", checks / 2);
}
}
}

View file

@ -207,16 +207,27 @@ pub enum BotExistsResult {
/// Check if Zitadel directory is healthy
pub fn zitadel_health_check() -> bool {
// Check if Zitadel is responding on port 8300
if let Ok(output) = Command::new("curl")
.args(["-f", "-s", "--connect-timeout", "2", "http://localhost:8300/debug/healthz"])
.output()
{
if output.status.success() {
return true;
// Use very short timeout for fast startup detection
let output = Command::new("curl")
.args(["-f", "-s", "--connect-timeout", "1", "-m", "2", "http://localhost:8300/debug/healthz"])
.output();
match output {
Ok(result) => {
if result.status.success() {
let response = String::from_utf8_lossy(&result.stdout);
debug!("Zitadel health check response: {}", response);
return response.trim() == "ok";
}
let stderr = String::from_utf8_lossy(&result.stderr);
debug!("Zitadel health check failed: {}", stderr);
}
Err(e) => {
debug!("Zitadel health check error: {}", e);
}
}
// Fallback: just check if port 8300 is listening
// Fast fallback: just check if port 8300 is listening
match Command::new("nc")
.args(["-z", "-w", "1", "127.0.0.1", "8300"])
.output()

View file

@ -503,7 +503,7 @@ impl WhatsAppAdapter {
}
if !list_block.is_empty() {
if !current_part.is_empty() && current_part.len() + list_block.len() + 1 <= max_length {
if !current_part.is_empty() && current_part.len() + list_block.len() < max_length {
current_part.push('\n');
current_part.push_str(&list_block);
} else {

View file

@ -416,7 +416,7 @@ impl BotOrchestrator {
let session_id = Uuid::parse_str(&message.session_id)?;
let message_content = message.content.clone();
let (session, context_data, history, model, key, system_prompt) = {
let (session, context_data, history, model, key, system_prompt, bot_llm_url) = {
let state_clone = self.state.clone();
tokio::task::spawn_blocking(
move || -> Result<_, Box<dyn std::error::Error + Send + Sync>> {
@ -453,14 +453,19 @@ impl BotOrchestrator {
.get_config(&session.bot_id, "llm-key", Some(""))
.unwrap_or_default();
// Load bot-specific llm-url (may differ from global default)
let bot_llm_url = config_manager
.get_bot_config_value(&session.bot_id, "llm-url")
.ok();
// Load system-prompt from config.csv, fallback to default
let system_prompt = config_manager
.get_config(&session.bot_id, "system-prompt", Some("You are a helpful assistant with access to tools that can help you complete tasks. When a user's request matches one of your available tools, use the appropriate tool instead of providing a generic response."))
.unwrap_or_else(|_| "You are a helpful assistant.".to_string());
.unwrap_or_else(|_| "You are a helpful General Bots assistant.".to_string());
info!("Loaded system-prompt for bot {}: {}", session.bot_id, &system_prompt[..system_prompt.len().min(200)]);
info!("Loaded system-prompt for bot {}: {}", session.bot_id, &system_prompt[..system_prompt.len().min(500)]);
Ok((session, context_data, history, model, key, system_prompt))
Ok((session, context_data, history, model, key, system_prompt, bot_llm_url))
},
)
.await??
@ -615,7 +620,14 @@ impl BotOrchestrator {
}
let (stream_tx, mut stream_rx) = mpsc::channel::<String>(100);
let llm = self.state.llm_provider.clone();
// Use bot-specific LLM provider if the bot has its own llm-url configured
let llm: std::sync::Arc<dyn crate::llm::LLMProvider> = if let Some(ref url) = bot_llm_url {
info!("Bot has custom llm-url: {}, creating per-bot LLM provider", url);
crate::llm::create_llm_provider_from_url(url, Some(model.clone()), None)
} else {
self.state.llm_provider.clone()
};
let model_clone = model.clone();
let key_clone = key.clone();

View file

@ -482,7 +482,7 @@ impl ConfigManager {
let mut updated = 0;
for line in content.lines().skip(1) {
let parts: Vec<&str> = line.split(',').collect();
let parts: Vec<&str> = line.splitn(2, ',').collect();
if parts.len() >= 2 {
let key = parts[0].trim();
let value = parts[1].trim();

239
src/llm/bedrock.rs Normal file
View file

@ -0,0 +1,239 @@
use async_trait::async_trait;
use futures::StreamExt;
use log::{error, info};
use serde_json::Value;
use tokio::sync::mpsc;
use crate::llm::LLMProvider;
/// HTTP client for an AWS Bedrock endpoint that speaks the OpenAI-style
/// chat-completions JSON format (`model`/`messages`/`stream` request,
/// `choices[0].message.content` response — see the LLMProvider impl below).
#[derive(Debug)]
pub struct BedrockClient {
    // Shared reqwest client; reused across requests for connection pooling.
    client: reqwest::Client,
    // Full URL of the Bedrock invoke endpoint; generate_stream derives the
    // streaming URL from it by replacing "/invoke" with
    // "/invoke-with-response-stream".
    base_url: String,
}
impl BedrockClient {
pub fn new(base_url: String) -> Self {
Self {
client: reqwest::Client::new(),
base_url,
}
}
}
#[async_trait]
impl LLMProvider for BedrockClient {
    /// Non-streaming completion: POSTs an OpenAI-style body to `base_url`
    /// and returns `choices[0].message.content` from the JSON response.
    ///
    /// `messages` is used as-is when it is a non-empty JSON array; otherwise a
    /// single user message is synthesized from `prompt`. `key` is sent as a
    /// Bearer token (the "Bearer " prefix is added if absent).
    async fn generate(
        &self,
        prompt: &str,
        messages: &Value,
        model: &str,
        key: &str,
    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
        // Fall back to a single user-turn built from `prompt` when the caller
        // supplied no usable message array.
        let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
        let raw_messages = if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() {
            messages
        } else {
            &default_messages
        };
        // Copy messages into an owned array. Despite the name, no limiting is
        // applied here — every message is forwarded unchanged.
        let mut messages_limited = Vec::new();
        if let Some(msg_array) = raw_messages.as_array() {
            for msg in msg_array {
                messages_limited.push(msg.clone());
            }
        }
        let formatted_messages = serde_json::Value::Array(messages_limited);
        // Accept keys given either with or without the "Bearer " prefix.
        let auth_header = if key.starts_with("Bearer ") {
            key.to_string()
        } else {
            format!("Bearer {}", key)
        };
        let request_body = serde_json::json!({
            "model": model,
            "messages": formatted_messages,
            "stream": false
        });
        info!("Sending request to Bedrock endpoint: {}", self.base_url);
        let response = self.client
            .post(&self.base_url)
            .header("Authorization", &auth_header)
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await?;
        let status = response.status();
        if !status.is_success() {
            // Surface the upstream error body in both the log and the Err.
            let error_text = response.text().await.unwrap_or_default();
            error!("Bedrock generate error: {}", error_text);
            return Err(format!("Bedrock API error ({}): {}", status, error_text).into());
        }
        let json: Value = response.json().await?;
        // OpenAI-compatible response shape: choices[0].message.content.
        if let Some(choices) = json.get("choices") {
            if let Some(first_choice) = choices.get(0) {
                if let Some(message) = first_choice.get("message") {
                    if let Some(content) = message.get("content") {
                        if let Some(content_str) = content.as_str() {
                            return Ok(content_str.to_string());
                        }
                    }
                }
            }
        }
        Err("Failed to parse response from Bedrock".into())
    }
    /// Streaming completion: POSTs with `"stream": true`, parses SSE
    /// `data: ...` lines, and forwards each `choices[0].delta.content`
    /// fragment through `tx`. Tool-call deltas are accumulated into
    /// `tool_call_buffer` and emitted once at the end, prefixed with
    /// "`tool_call`: " (same convention as the other providers in this module).
    async fn generate_stream(
        &self,
        prompt: &str,
        messages: &Value,
        tx: mpsc::Sender<String>,
        model: &str,
        key: &str,
        tools: Option<&Vec<Value>>,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        // Same message-fallback logic as generate().
        let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
        let raw_messages = if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() {
            messages
        } else {
            &default_messages
        };
        let mut messages_limited = Vec::new();
        if let Some(msg_array) = raw_messages.as_array() {
            for msg in msg_array {
                messages_limited.push(msg.clone());
            }
        }
        let formatted_messages = serde_json::Value::Array(messages_limited);
        let auth_header = if key.starts_with("Bearer ") {
            key.to_string()
        } else {
            format!("Bearer {}", key)
        };
        let mut request_body = serde_json::json!({
            "model": model,
            "messages": formatted_messages,
            "stream": true
        });
        // Forward tool schemas only when present and non-empty.
        if let Some(tools_value) = tools {
            if !tools_value.is_empty() {
                request_body["tools"] = serde_json::json!(tools_value);
                info!("Added {} tools to Bedrock request", tools_value.len());
            }
        }
        // Bedrock uses a distinct path for streamed responses.
        let stream_url = if self.base_url.ends_with("/invoke") {
            self.base_url.replace("/invoke", "/invoke-with-response-stream")
        } else {
            self.base_url.clone()
        };
        info!("Sending streaming request to Bedrock endpoint: {}", stream_url);
        let response = self.client
            .post(&stream_url)
            .header("Authorization", &auth_header)
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await?;
        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            error!("Bedrock generate_stream error: {}", error_text);
            return Err(format!("Bedrock API error ({}): {}", status, error_text).into());
        }
        let mut stream = response.bytes_stream();
        let mut tool_call_buffer = String::new();
        while let Some(chunk_result) = stream.next().await {
            match chunk_result {
                Ok(chunk) => {
                    // NOTE(review): assumes SSE events are not split across
                    // chunk (or UTF-8) boundaries; a partial `data:` line in
                    // one chunk would be dropped — TODO confirm against the
                    // actual endpoint's framing.
                    if let Ok(text) = std::str::from_utf8(&chunk) {
                        for line in text.split('\n') {
                            let line = line.trim();
                            if let Some(data) = line.strip_prefix("data: ") {
                                if data == "[DONE]" {
                                    continue;
                                }
                                if let Ok(json) = serde_json::from_str::<Value>(data) {
                                    if let Some(choices) = json.get("choices") {
                                        if let Some(first_choice) = choices.get(0) {
                                            if let Some(delta) = first_choice.get("delta") {
                                                // Handle standard content streaming
                                                if let Some(content) = delta.get("content") {
                                                    if let Some(content_str) = content.as_str() {
                                                        // Receiver gone => stop quietly.
                                                        if !content_str.is_empty() && tx.send(content_str.to_string()).await.is_err() {
                                                            return Ok(());
                                                        }
                                                    }
                                                }
                                                // Handle tool calls streaming
                                                if let Some(tool_calls) = delta.get("tool_calls") {
                                                    if let Some(calls_array) = tool_calls.as_array() {
                                                        if let Some(first_call) = calls_array.first() {
                                                            if let Some(function) = first_call.get("function") {
                                                                // Stream function JSON representation just like OpenAI does
                                                                if let Some(name) = function.get("name") {
                                                                    if let Some(name_str) = name.as_str() {
                                                                        tool_call_buffer = format!("{{\"name\": \"{}\", \"arguments\": \"", name_str);
                                                                    }
                                                                }
                                                                if let Some(args) = function.get("arguments") {
                                                                    if let Some(args_str) = args.as_str() {
                                                                        // NOTE(review): args fragments are appended raw
                                                                        // inside a quoted JSON string without escaping;
                                                                        // any quote in the arguments yields invalid JSON
                                                                        // in the final buffer — verify downstream parser.
                                                                        tool_call_buffer.push_str(args_str);
                                                                    }
                                                                }
                                                            }
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
                Err(e) => {
                    // Stop on transport error but still flush any buffered tool call.
                    error!("Bedrock stream reading error: {}", e);
                    break;
                }
            }
        }
        // Finalize tool call JSON parsing
        if !tool_call_buffer.is_empty() {
            tool_call_buffer.push_str("\"}");
            if tx.send(format!("`tool_call`: {}", tool_call_buffer)).await.is_err() {
                return Ok(());
            }
        }
        Ok(())
    }
    /// No-op: Bedrock requests cannot be cancelled per session here.
    async fn cancel_job(&self, _session_id: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        Ok(())
    }
}

View file

@ -14,11 +14,15 @@ pub mod llm_models;
pub mod local;
pub mod rate_limiter;
pub mod smart_router;
pub mod vertex;
pub mod bedrock;
pub use claude::ClaudeClient;
pub use glm::GLMClient;
pub use llm_models::get_handler;
pub use rate_limiter::{ApiRateLimiter, RateLimits};
pub use vertex::VertexTokenManager;
pub use bedrock::BedrockClient;
#[async_trait]
pub trait LLMProvider: Send + Sync {
@ -248,14 +252,16 @@ impl LLMProvider for OpenAIClient {
// GLM-4.7 has 202750 tokens, other models vary
let context_limit = if model.contains("glm-4") || model.contains("GLM-4") {
202750
} else if model.contains("gpt-4") {
128000
} else if model.contains("gemini") {
1000000 // Google Gemini models have 1M+ token context
} else if model.contains("gpt-oss") || model.contains("gpt-4") {
128000 // Cerebras gpt-oss models and GPT-4 variants
} else if model.contains("gpt-3.5") {
16385
} else if model.starts_with("http://localhost:808") || model == "local" {
768 // Local llama.cpp server context limit
} else {
4096 // Default conservative limit
32768 // Default conservative limit for modern models
};
let messages = OpenAIClient::ensure_token_limit(raw_messages, context_limit);
@ -329,14 +335,16 @@ impl LLMProvider for OpenAIClient {
// GLM-4.7 has 202750 tokens, other models vary
let context_limit = if model.contains("glm-4") || model.contains("GLM-4") {
202750
} else if model.contains("gpt-4") {
128000
} else if model.contains("gemini") {
1000000 // Google Gemini models have 1M+ token context
} else if model.contains("gpt-oss") || model.contains("gpt-4") {
128000 // Cerebras gpt-oss models and GPT-4 variants
} else if model.contains("gpt-3.5") {
16385
} else if model.starts_with("http://localhost:808") || model == "local" {
768 // Local llama.cpp server context limit
} else {
4096 // Default conservative limit
32768 // Default conservative limit for modern models
};
let messages = OpenAIClient::ensure_token_limit(raw_messages, context_limit);
@ -448,6 +456,8 @@ pub enum LLMProviderType {
Claude,
AzureClaude,
GLM,
Bedrock,
Vertex,
}
impl From<&str> for LLMProviderType {
@ -461,6 +471,10 @@ impl From<&str> for LLMProviderType {
}
} else if lower.contains("z.ai") || lower.contains("glm") {
Self::GLM
} else if lower.contains("bedrock") {
Self::Bedrock
} else if lower.contains("googleapis.com") || lower.contains("vertex") || lower.contains("generativelanguage") {
Self::Vertex
} else {
Self::OpenAI
}
@ -498,6 +512,15 @@ pub fn create_llm_provider(
info!("Creating GLM/z.ai LLM provider with URL: {}", base_url);
std::sync::Arc::new(GLMClient::new(base_url))
}
LLMProviderType::Bedrock => {
info!("Creating Bedrock LLM provider with exact URL: {}", base_url);
std::sync::Arc::new(BedrockClient::new(base_url))
}
LLMProviderType::Vertex => {
info!("Creating Vertex/Gemini LLM provider with URL: {}", base_url);
// Re-export the struct if we haven't already
std::sync::Arc::new(crate::llm::vertex::VertexClient::new(base_url, endpoint_path))
}
}
}

515
src/llm/vertex.rs Normal file
View file

@ -0,0 +1,515 @@
//! Google Vertex AI and Gemini Native API Integration
//! Support for both OpenAI-compatible endpoints and native Google AI Studio / Vertex AI endpoints.
use log::{error, info, trace};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::RwLock;
const GOOGLE_TOKEN_URL: &str = "https://oauth2.googleapis.com/token";
const GOOGLE_SCOPE: &str = "https://www.googleapis.com/auth/cloud-platform";
/// Subset of a Google service-account JSON key used by the JWT-bearer flow.
/// Field names must match the key file's JSON keys (serde Deserialize).
#[derive(Debug, Deserialize)]
struct ServiceAccountKey {
    // Becomes the JWT `iss` claim in generate_token().
    client_email: String,
    // PEM-encoded RSA private key used to sign the JWT assertion.
    private_key: String,
    // Optional OAuth token endpoint override; defaults to GOOGLE_TOKEN_URL.
    token_uri: Option<String>,
}
/// OAuth2 token-endpoint response (only the fields this module reads).
#[derive(Debug, Deserialize)]
struct TokenResponse {
    // Bearer token to attach to Vertex AI requests.
    access_token: String,
    // Lifetime in seconds; used to compute the cache expiry Instant.
    expires_in: u64,
}
/// An access token together with its computed expiry, stored in the
/// VertexTokenManager cache.
#[derive(Debug, Clone)]
struct CachedToken {
    access_token: String,
    // Instant after which the token is considered stale (with a 60s safety
    // margin applied by get_access_token).
    expires_at: std::time::Instant,
}
/// Manages OAuth2 access tokens for Google Vertex AI.
///
/// Holds either a path to a service-account JSON key file or the inline JSON
/// itself (see generate_token), plus a shared cache of the last issued token.
#[derive(Debug, Clone)]
pub struct VertexTokenManager {
    // Path to the key file, or the raw JSON when it starts with '{'.
    credentials_path: String,
    // Clones of the manager share this cache (Arc), refreshed under RwLock.
    cached_token: Arc<RwLock<Option<CachedToken>>>,
}
impl VertexTokenManager {
    /// Create a manager for the given credentials path (or inline JSON).
    /// No I/O happens until the first get_access_token() call.
    pub fn new(credentials_path: &str) -> Self {
        Self {
            credentials_path: credentials_path.to_string(),
            cached_token: Arc::new(RwLock::new(None)),
        }
    }
    /// Get a valid access token, refreshing if necessary.
    ///
    /// Returns the cached token when it has more than 60 seconds of life
    /// left; otherwise mints a new one via the JWT-bearer flow and caches it.
    pub async fn get_access_token(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
        // Check cached token. Read lock is scoped to this block so it is
        // released before the (slow) refresh path below.
        {
            let cache = self.cached_token.read().await;
            if let Some(ref token) = *cache {
                // Return cached token if it expires in more than 60 seconds
                if token.expires_at > std::time::Instant::now() + std::time::Duration::from_secs(60) {
                    trace!("Using cached Vertex AI access token");
                    return Ok(token.access_token.clone());
                }
            }
        }
        // Generate new token. NOTE(review): concurrent callers can race past
        // the read check and each mint a token; last writer wins — harmless
        // but produces duplicate token requests under contention.
        info!("Generating new Vertex AI access token from service account");
        let token = self.generate_token().await?;
        // Cache it
        {
            let mut cache = self.cached_token.write().await;
            *cache = Some(token.clone());
        }
        Ok(token.access_token)
    }
    /// Mint a fresh token: read/parse the service-account key, build and sign
    /// an RS256 JWT assertion, and exchange it at the token endpoint.
    async fn generate_token(&self) -> Result<CachedToken, Box<dyn std::error::Error + Send + Sync>> {
        // `credentials_path` may be the JSON itself (starts with '{') or a
        // filesystem path, with '~' expanded to $HOME (default "/root").
        let content = if self.credentials_path.trim().starts_with('{') {
            self.credentials_path.clone()
        } else {
            // Expand ~ to home directory
            let expanded_path = if self.credentials_path.starts_with('~') {
                let home = std::env::var("HOME").unwrap_or_else(|_| "/root".to_string());
                self.credentials_path.replacen('~', &home, 1)
            } else {
                self.credentials_path.clone()
            };
            tokio::fs::read_to_string(&expanded_path)
                .await
                .map_err(|e| format!("Failed to read service account key from '{}': {}", expanded_path, e))?
        };
        let sa_key: ServiceAccountKey = serde_json::from_str(&content)
            .map_err(|e| format!("Failed to parse service account JSON: {}", e))?;
        let token_uri = sa_key.token_uri.as_deref().unwrap_or(GOOGLE_TOKEN_URL);
        // Create JWT assertion valid for one hour (exp = iat + 3600).
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_err(|e| format!("System time error: {}", e))?
            .as_secs();
        let jwt_claims = JwtClaims {
            iss: sa_key.client_email.clone(),
            scope: GOOGLE_SCOPE.to_string(),
            aud: token_uri.to_string(),
            iat: now,
            exp: now + 3600,
        };
        let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256);
        let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(sa_key.private_key.as_bytes())
            .map_err(|e| format!("Failed to parse RSA private key: {}", e))?;
        let jwt = jsonwebtoken::encode(&header, &jwt_claims, &encoding_key)
            .map_err(|e| format!("Failed to encode JWT: {}", e))?;
        // Exchange JWT for access token (standard jwt-bearer grant).
        let client = reqwest::Client::new();
        let response = client
            .post(token_uri)
            .form(&[
                ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
                ("assertion", &jwt),
            ])
            .send()
            .await
            .map_err(|e| format!("Failed to request token: {}", e))?;
        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            error!("Google OAuth2 token request failed: {} - {}", status, body);
            return Err(format!("Token request failed with status {}: {}", status, body).into());
        }
        let token_response: TokenResponse = response
            .json()
            .await
            .map_err(|e| format!("Failed to parse token response: {}", e))?;
        info!("Successfully obtained Vertex AI access token (expires in {}s)", token_response.expires_in);
        Ok(CachedToken {
            access_token: token_response.access_token,
            // Expiry computed from the server-reported lifetime; the 60s
            // safety margin is applied at read time in get_access_token.
            expires_at: std::time::Instant::now() + std::time::Duration::from_secs(token_response.expires_in),
        })
    }
}
/// Claims for the OAuth2 JWT-bearer assertion (RFC 7523 style):
/// issuer = service-account email, audience = token endpoint,
/// scope = cloud-platform, iat/exp in Unix seconds.
#[derive(Debug, Serialize)]
struct JwtClaims {
    iss: String,
    scope: String,
    aud: String,
    iat: u64,
    exp: u64,
}
/// Builds the Vertex AI OpenAI-compatible endpoint URL for the given
/// project and location (location appears in both host and path).
pub fn build_vertex_url(project: &str, location: &str) -> String {
    [
        "https://",
        location,
        "-aiplatform.googleapis.com/v1/projects/",
        project,
        "/locations/",
        location,
        "/endpoints/openapi",
    ]
    .concat()
}
/// Check if a URL targets a Google Vertex AI or Google AI Studio
/// (generativelanguage) endpoint.
pub fn is_vertex_ai_url(url: &str) -> bool {
    ["aiplatform.googleapis.com", "generativelanguage.googleapis.com"]
        .iter()
        .any(|host| url.contains(host))
}
use crate::llm::LLMProvider;
use serde_json::Value;
use tokio::sync::mpsc;
use async_trait::async_trait;
use futures::StreamExt;
/// LLM client for Google endpoints. Supports both the OpenAI-compatible
/// chat-completions surface and the native Gemini generateContent API;
/// the request format is auto-detected per call from `base_url` and
/// `endpoint_path` (see the LLMProvider impl).
#[derive(Debug)]
pub struct VertexClient {
    client: reqwest::Client,
    // Host (and optional prefix) of the Google endpoint.
    base_url: String,
    // Path appended to base_url for OpenAI-compatible calls; may be empty,
    // in which case native URLs are built from the model name.
    endpoint_path: String,
    // Lazily-created token manager, populated on first service-account auth
    // (see get_auth_header). None until then.
    token_manager: Arc<tokio::sync::RwLock<Option<Arc<VertexTokenManager>>>>,
}
impl VertexClient {
    /// Create a client. When `endpoint_path` is None, a default is chosen
    /// from `base_url`: the OpenAI-compat path for generativelanguage hosts,
    /// empty otherwise (native URLs are then built per request).
    pub fn new(base_url: String, endpoint_path: Option<String>) -> Self {
        let endpoint = endpoint_path.unwrap_or_else(|| {
            if base_url.contains("generativelanguage.googleapis.com") {
                // Default to OpenAI compatibility if possible, but generate_stream will auto-detect
                "/v1beta/openai/chat/completions".to_string()
            } else if base_url.contains("aiplatform.googleapis.com") && !base_url.contains("openapi") {
                // If it's a naked aiplatform URL, it might be missing the project/location path
                // We'll leave it empty and let the caller provide a full URL if needed
                "".to_string()
            } else {
                "".to_string()
            }
        });
        Self {
            client: reqwest::Client::new(),
            base_url,
            endpoint_path: endpoint,
            token_manager: Arc::new(tokio::sync::RwLock::new(None)),
        }
    }
    /// Resolve the auth header for a request.
    ///
    /// Service-account credentials (inline JSON, '~'/'/' path, or *.json)
    /// yield ("Authorization", "Bearer <oauth-token>") via VertexTokenManager;
    /// anything else is treated as a plain API key sent as "x-goog-api-key".
    /// On token failure it logs the error and falls back to the API-key form.
    async fn get_auth_header(&self, key: &str) -> (&'static str, String) {
        // If the key is a path or JSON (starts with {), it's a Service Account
        if key.starts_with('{') || key.starts_with('~') || key.starts_with('/') || key.ends_with(".json") {
            // NOTE(review): the write lock is held across get_access_token's
            // network await, serializing concurrent callers — confirm this is
            // acceptable under load.
            let mut manager_opt = self.token_manager.write().await;
            if manager_opt.is_none() {
                *manager_opt = Some(Arc::new(VertexTokenManager::new(key)));
            }
            if let Some(manager) = manager_opt.as_ref() {
                match manager.get_access_token().await {
                    Ok(token) => return ("Authorization", format!("Bearer {}", token)),
                    Err(e) => error!("Failed to get Vertex OAuth token: {}", e),
                }
            }
        }
        // Default to Google API Key (Google AI Studio / Gemini Developer APIs)
        ("x-goog-api-key", key.to_string())
    }
    /// Translate an OpenAI-style message array into Gemini `contents`:
    /// role "assistant" -> "model"; "system" and everything else -> "user";
    /// string content becomes a single text part. Non-string content is
    /// dropped to "".
    fn convert_messages_to_gemini(&self, messages: &Value) -> Value {
        let mut contents = Vec::new();
        if let Some(msg_array) = messages.as_array() {
            for msg in msg_array {
                let role = msg.get("role").and_then(|r| r.as_str()).unwrap_or("user");
                let content = msg.get("content").and_then(|c| c.as_str()).unwrap_or("");
                let gemini_role = match role {
                    "assistant" => "model",
                    "system" => "user", // Gemini doesn't have system role in 'contents' by default, often wrapped in systemInstruction
                    _ => "user",
                };
                contents.push(serde_json::json!({
                    "role": gemini_role,
                    "parts": [{"text": content}]
                }));
            }
        }
        serde_json::json!(contents)
    }
}
#[async_trait]
impl LLMProvider for VertexClient {
    /// Non-streaming completion against either the OpenAI-compatible surface
    /// or the native Gemini generateContent API (auto-detected), parsing
    /// whichever response shape comes back.
    async fn generate(
        &self,
        prompt: &str,
        messages: &Value,
        model: &str,
        key: &str,
    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
        // Fall back to a single user-turn built from `prompt` when the caller
        // supplied no usable message array.
        let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
        let raw_messages = if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() {
            messages
        } else {
            &default_messages
        };
        let (header_name, auth_value) = self.get_auth_header(key).await;
        // Auto-detect if we should use native Gemini API: either the path
        // already targets :generateContent, or we're on generativelanguage
        // without the OpenAI-compat path.
        let is_native = self.endpoint_path.contains(":generateContent") ||
            self.base_url.contains("generativelanguage") && !self.endpoint_path.contains("openai");
        let full_url = if is_native && !self.endpoint_path.contains(":generateContent") {
            // Build native URL for generativelanguage
            format!("{}/v1beta/models/{}:generateContent", self.base_url.trim_end_matches('/'), model)
        } else {
            format!("{}{}", self.base_url, self.endpoint_path)
        }.trim_end_matches('/').to_string();
        // Native calls take Gemini `contents`; OpenAI-compat calls take
        // `model`/`messages` unchanged.
        let request_body = if is_native {
            serde_json::json!({
                "contents": self.convert_messages_to_gemini(raw_messages),
                "generationConfig": {
                    "temperature": 0.7,
                }
            })
        } else {
            serde_json::json!({
                "model": model,
                "messages": raw_messages,
                "stream": false
            })
        };
        info!("Sending request to Vertex/Gemini endpoint: {}", full_url);
        let response = self.client
            .post(full_url)
            .header(header_name, &auth_value)
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await?;
        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            error!("Vertex/Gemini generate error: {}", error_text);
            return Err(format!("Google API error ({}): {}", status, error_text).into());
        }
        let json: Value = response.json().await?;
        // Try parsing OpenAI format: choices[0].message.content.
        if let Some(choices) = json.get("choices") {
            if let Some(first_choice) = choices.get(0) {
                if let Some(message) = first_choice.get("message") {
                    if let Some(content) = message.get("content") {
                        if let Some(content_str) = content.as_str() {
                            return Ok(content_str.to_string());
                        }
                    }
                }
            }
        }
        // Try parsing Native Gemini format: candidates[0].content.parts[0].text.
        if let Some(candidates) = json.get("candidates") {
            if let Some(first_candidate) = candidates.get(0) {
                if let Some(content) = first_candidate.get("content") {
                    if let Some(parts) = content.get("parts") {
                        if let Some(first_part) = parts.get(0) {
                            if let Some(text) = first_part.get("text") {
                                if let Some(text_str) = text.as_str() {
                                    return Ok(text_str.to_string());
                                }
                            }
                        }
                    }
                }
            }
        }
        Err("Failed to parse response from Vertex/Gemini (unknown format)".into())
    }
    /// Streaming completion. Auto-detects native vs OpenAI-compat like
    /// generate(), then parses the byte stream line by line: `data:`-prefixed
    /// lines as OpenAI SSE deltas, everything else as native Gemini JSON
    /// array elements. Content fragments are forwarded through `tx`;
    /// tool-call deltas are buffered and emitted once at the end with the
    /// "`tool_call`: " prefix.
    async fn generate_stream(
        &self,
        prompt: &str,
        messages: &Value,
        tx: mpsc::Sender<String>,
        model: &str,
        key: &str,
        tools: Option<&Vec<Value>>,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
        let raw_messages = if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() {
            messages
        } else {
            &default_messages
        };
        let (header_name, auth_value) = self.get_auth_header(key).await;
        // Auto-detect if we should use native Gemini API. NOTE(review): this
        // check ("GenerateContent", googleapis.com host) is broader than the
        // one in generate() (":generateContent", generativelanguage host) —
        // confirm the two paths are meant to classify URLs differently.
        let is_native = self.endpoint_path.contains("GenerateContent") ||
            (self.base_url.contains("googleapis.com") &&
            !self.endpoint_path.contains("openai") &&
            !self.endpoint_path.contains("completions"));
        let full_url = if is_native && !self.endpoint_path.contains("GenerateContent") {
            // Build native URL for Gemini if it looks like we need it
            if self.base_url.contains("aiplatform") {
                // Global Vertex endpoint format
                format!("{}/v1/publishers/google/models/{}:streamGenerateContent", self.base_url.trim_end_matches('/'), model)
            } else {
                // Google AI Studio format
                format!("{}/v1beta/models/{}:streamGenerateContent", self.base_url.trim_end_matches('/'), model)
            }
        } else {
            format!("{}{}", self.base_url, self.endpoint_path)
        }.trim_end_matches('/').to_string();
        let mut request_body = if is_native {
            let mut body = serde_json::json!({
                "contents": self.convert_messages_to_gemini(raw_messages),
                "generationConfig": {
                    "temperature": 0.7,
                }
            });
            // Handle thinking models if requested
            if model.contains("preview") || model.contains("3.1-flash") {
                body["generationConfig"]["thinkingConfig"] = serde_json::json!({
                    "thinkingLevel": "LOW"
                });
            }
            body
        } else {
            serde_json::json!({
                "model": model,
                "messages": raw_messages,
                "stream": true
            })
        };
        if let Some(tools_value) = tools {
            if !tools_value.is_empty() {
                // NOTE(review): tools are attached in OpenAI schema even for
                // native Gemini bodies, which expect a different tool format —
                // verify against the native endpoint.
                request_body["tools"] = serde_json::json!(tools_value);
                info!("Added {} tools to Vertex request", tools_value.len());
            }
        }
        info!("Sending streaming request to Vertex/Gemini endpoint: {}", full_url);
        let response = self.client
            .post(full_url)
            .header(header_name, &auth_value)
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await?;
        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            error!("Vertex/Gemini generate_stream error: {}", error_text);
            return Err(format!("Google API error ({}): {}", status, error_text).into());
        }
        let mut stream = response.bytes_stream();
        let mut tool_call_buffer = String::new();
        while let Some(chunk_result) = stream.next().await {
            match chunk_result {
                Ok(chunk) => {
                    // NOTE(review): assumes events do not straddle chunk
                    // boundaries; a JSON object split across chunks is
                    // silently dropped by the per-line parse below.
                    if let Ok(text) = std::str::from_utf8(&chunk) {
                        for line in text.split('\n') {
                            let line = line.trim();
                            if line.is_empty() { continue; }
                            if let Some(data) = line.strip_prefix("data: ") {
                                // --- OpenAI SSE Format ---
                                if data == "[DONE]" { continue; }
                                if let Ok(json) = serde_json::from_str::<Value>(data) {
                                    if let Some(choices) = json.get("choices") {
                                        if let Some(first_choice) = choices.get(0) {
                                            if let Some(delta) = first_choice.get("delta") {
                                                if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
                                                    if !content.is_empty() {
                                                        // Send errors (receiver dropped) are ignored here.
                                                        let _ = tx.send(content.to_string()).await;
                                                    }
                                                }
                                                // Handle tool calls...
                                                if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
                                                    if let Some(first_call) = tool_calls.first() {
                                                        if let Some(function) = first_call.get("function") {
                                                            if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
                                                                tool_call_buffer = format!("{{\"name\": \"{}\", \"arguments\": \"", name);
                                                            }
                                                            if let Some(args) = function.get("arguments").and_then(|a| a.as_str()) {
                                                                // NOTE(review): raw fragments inside a quoted JSON
                                                                // string without escaping — quotes in args break the
                                                                // final buffer's JSON validity.
                                                                tool_call_buffer.push_str(args);
                                                            }
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            } else {
                                // --- Native Gemini JSON format ---
                                // It usually arrives as raw JSON objects, sometimes with leading commas or brackets in a stream array
                                let trimmed = line.trim_start_matches(|c| c == ',' || c == '[' || c == ']').trim_end_matches(']');
                                if trimmed.is_empty() { continue; }
                                if let Ok(json) = serde_json::from_str::<Value>(trimmed) {
                                    if let Some(candidates) = json.get("candidates").and_then(|c| c.as_array()) {
                                        if let Some(candidate) = candidates.first() {
                                            if let Some(content) = candidate.get("content") {
                                                if let Some(parts) = content.get("parts").and_then(|p| p.as_array()) {
                                                    // Forward every text part of the first candidate.
                                                    for part in parts {
                                                        if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
                                                            if !text.is_empty() {
                                                                let _ = tx.send(text.to_string()).await;
                                                            }
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
                Err(e) => {
                    // Stop on transport error; buffered tool call is still flushed.
                    error!("Vertex stream reading error: {}", e);
                    break;
                }
            }
        }
        if !tool_call_buffer.is_empty() {
            tool_call_buffer.push_str("\"}");
            let _ = tx.send(format!("`tool_call`: {}", tool_call_buffer)).await;
        }
        Ok(())
    }
    /// No-op: per-session cancellation is not supported for this provider.
    async fn cancel_job(&self, _session_id: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        Ok(())
    }
}

View file

@ -880,12 +880,30 @@ async fn start_drive_monitors(
.await
.unwrap_or_default();
let local_dev_bots = tokio::task::spawn_blocking(move || {
let mut bots = std::collections::HashSet::new();
let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/opt/gbo/data".to_string());
if let Ok(entries) = std::fs::read_dir(data_dir) {
for entry in entries.flatten() {
let path = entry.path();
if path.extension().and_then(|e| e.to_str()).map(|e| e.eq_ignore_ascii_case("gbai")).unwrap_or(false) {
if let Some(bot_name) = path.file_stem().and_then(|s| s.to_str()) {
bots.insert(bot_name.to_string());
}
}
}
}
bots
})
.await
.unwrap_or_default();
info!("Found {} active bots to monitor", bots_to_monitor.len());
for (bot_id, bot_name) in bots_to_monitor {
// Skip default bot - it's managed locally via ConfigWatcher
if bot_name == "default" {
info!("Skipping DriveMonitor for 'default' bot - managed via ConfigWatcher");
// Skip default bot and local dev bots - they are managed locally
if bot_name == "default" || local_dev_bots.contains(&bot_name) {
info!("Skipping DriveMonitor for '{}' bot - managed locally via ConfigWatcher/LocalFileMonitor", bot_name);
continue;
}