- Added table_exists_in_database() to check if table exists in PostgreSQL - Updated validate_table_name() to allow valid identifiers (not just whitelist) - Added validate_table_name_with_conn() for full validation with DB check - Added is_table_allowed_with_conn() for handlers to verify table existence - Updated list_records_handler and count_records_handler to use dynamic check - Uses parameterized query for table existence check (SQL injection safe)
447 lines
14 KiB
Rust
447 lines
14 KiB
Rust
use std::collections::HashSet;
|
|
use std::sync::LazyLock;
|
|
|
|
/// Core application tables that are always permitted.
///
/// Membership in this set is checked before (or instead of) asking the
/// database whether a table exists.
static SYSTEM_TABLES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    const NAMES: &[&str] = &[
        "automations",
        "bots",
        "bot_configurations",
        "bot_memories",
        "clicks",
        "group_members",
        "groups",
        "message_history",
        "organizations",
        "table_access",
        "tasks",
        "trigger_kinds",
        "user_login_tokens",
        "user_preferences",
        "user_sessions",
        "users",
        "a2a_messages",
        "api_records",
        "calendar_events",
        "documents",
        "email_accounts",
        "meetings",
        "notifications",
        "oauth_providers",
        "oauth_tokens",
        "research_sessions",
        "sources",
    ];

    NAMES.iter().copied().collect()
});
|
|
|
|
/// Check if a table exists in the database
|
|
/// This allows dynamically created tables (from app generator) to be accessed
|
|
pub fn table_exists_in_database(conn: &mut diesel::PgConnection, table_name: &str) -> bool {
|
|
use diesel::prelude::*;
|
|
use diesel::sql_query;
|
|
use diesel::sql_types::Text;
|
|
|
|
#[derive(diesel::QueryableByName)]
|
|
struct TableExists {
|
|
#[diesel(sql_type = diesel::sql_types::Bool)]
|
|
exists: bool,
|
|
}
|
|
|
|
let result: Result<TableExists, _> = sql_query(
|
|
"SELECT EXISTS (
|
|
SELECT FROM information_schema.tables
|
|
WHERE table_schema = 'public'
|
|
AND table_name = $1
|
|
) as exists"
|
|
)
|
|
.bind::<Text, _>(table_name)
|
|
.get_result(conn);
|
|
|
|
match result {
|
|
Ok(r) => r.exists,
|
|
Err(_) => false,
|
|
}
|
|
}
|
|
|
|
/// Column names that may appear in an `ORDER BY` clause built by this module.
static ALLOWED_ORDER_COLUMNS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    const COLUMNS: &[&str] = &[
        "id",
        "created_at",
        "updated_at",
        "name",
        "email",
        "title",
        "status",
        "priority",
        "due_date",
        "start_date",
        "end_date",
        "order",
        "position",
        "timestamp",
    ];

    COLUMNS.iter().copied().collect()
});
|
|
|
|
/// Reasons an identifier or input string was rejected by the SQL guard.
///
/// Each variant carries a `String` that is echoed in the `Display` output;
/// depending on the call site it is either the offending input or a short
/// explanatory message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SqlGuardError {
    /// The table name failed validation.
    InvalidTableName(String),
    /// The column name is not permitted.
    InvalidColumnName(String),
    /// The sort direction is neither `ASC` nor `DESC`.
    InvalidOrderDirection(String),
    /// A generic identifier failed validation.
    InvalidIdentifier(String),
    /// The input contained a pattern associated with SQL injection.
    PotentialInjection(String),
}

impl std::fmt::Display for SqlGuardError {
    /// Renders the error as `"<category>: <detail>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, detail) = match self {
            Self::InvalidTableName(name) => ("Invalid table name", name),
            Self::InvalidColumnName(name) => ("Invalid column name", name),
            Self::InvalidOrderDirection(dir) => ("Invalid order direction", dir),
            Self::InvalidIdentifier(id) => ("Invalid identifier", id),
            Self::PotentialInjection(input) => ("Potential SQL injection detected", input),
        };
        write!(f, "{label}: {detail}")
    }
}

impl std::error::Error for SqlGuardError {}
|
|
|
|
/// Validate table name - checks system tables whitelist only
|
|
/// For full validation including dynamic tables, use validate_table_name_with_conn
|
|
pub fn validate_table_name(table: &str) -> Result<&str, SqlGuardError> {
|
|
let sanitized = sanitize_identifier(table);
|
|
|
|
// Check basic identifier validity (no SQL injection)
|
|
if sanitized.is_empty() || sanitized.len() > 63 {
|
|
return Err(SqlGuardError::InvalidTableName(table.to_string()));
|
|
}
|
|
|
|
// Check for dangerous patterns
|
|
if sanitized.contains(';') || sanitized.contains("--") || sanitized.contains("/*") {
|
|
return Err(SqlGuardError::PotentialInjection(table.to_string()));
|
|
}
|
|
|
|
// System tables are always allowed
|
|
if SYSTEM_TABLES.contains(sanitized.as_str()) {
|
|
return Ok(table);
|
|
}
|
|
|
|
// For non-system tables, we allow them if they look safe
|
|
// The actual existence check should be done with validate_table_name_with_conn
|
|
// This allows dynamically created tables from app_generator to work
|
|
if sanitized.chars().all(|c| c.is_alphanumeric() || c == '_') {
|
|
Ok(table)
|
|
} else {
|
|
Err(SqlGuardError::InvalidTableName(table.to_string()))
|
|
}
|
|
}
|
|
|
|
/// Validate table name with database connection - checks if table actually exists
|
|
pub fn validate_table_name_with_conn<'a>(
|
|
conn: &mut diesel::PgConnection,
|
|
table: &'a str,
|
|
) -> Result<&'a str, SqlGuardError> {
|
|
let sanitized = sanitize_identifier(table);
|
|
|
|
// First do basic validation
|
|
if sanitized.is_empty() || sanitized.len() > 63 {
|
|
return Err(SqlGuardError::InvalidTableName(table.to_string()));
|
|
}
|
|
|
|
if sanitized.contains(';') || sanitized.contains("--") || sanitized.contains("/*") {
|
|
return Err(SqlGuardError::PotentialInjection(table.to_string()));
|
|
}
|
|
|
|
// System tables are always allowed
|
|
if SYSTEM_TABLES.contains(sanitized.as_str()) {
|
|
return Ok(table);
|
|
}
|
|
|
|
// Check if table exists in database (for dynamically created tables)
|
|
if table_exists_in_database(conn, &sanitized) {
|
|
Ok(table)
|
|
} else {
|
|
Err(SqlGuardError::InvalidTableName(table.to_string()))
|
|
}
|
|
}
|
|
|
|
pub fn validate_order_column(column: &str) -> Result<&str, SqlGuardError> {
|
|
let sanitized = sanitize_identifier(column);
|
|
|
|
if ALLOWED_ORDER_COLUMNS.contains(sanitized.as_str()) {
|
|
Ok(column)
|
|
} else {
|
|
Err(SqlGuardError::InvalidColumnName(column.to_string()))
|
|
}
|
|
}
|
|
|
|
pub fn validate_order_direction(direction: &str) -> Result<&'static str, SqlGuardError> {
|
|
match direction.to_uppercase().as_str() {
|
|
"ASC" => Ok("ASC"),
|
|
"DESC" => Ok("DESC"),
|
|
_ => Err(SqlGuardError::InvalidOrderDirection(direction.to_string())),
|
|
}
|
|
}
|
|
|
|
/// Returns a copy of `name` with every character other than ASCII letters,
/// ASCII digits and `_` removed.
///
/// The result may be empty, and may differ from the input; callers that need
/// to know whether the input was already clean must compare the two.
pub fn sanitize_identifier(name: &str) -> String {
    let mut cleaned = name.to_owned();
    cleaned.retain(|ch| ch == '_' || ch.is_ascii_alphanumeric());
    cleaned
}
|
|
|
|
pub fn validate_identifier(name: &str) -> Result<String, SqlGuardError> {
|
|
let sanitized = sanitize_identifier(name);
|
|
|
|
if sanitized.is_empty() {
|
|
return Err(SqlGuardError::InvalidIdentifier(name.to_string()));
|
|
}
|
|
|
|
if sanitized.len() > 64 {
|
|
return Err(SqlGuardError::InvalidIdentifier("Identifier too long".to_string()));
|
|
}
|
|
|
|
if sanitized.chars().next().map(|c| c.is_ascii_digit()).unwrap_or(false) {
|
|
return Err(SqlGuardError::InvalidIdentifier("Identifier cannot start with digit".to_string()));
|
|
}
|
|
|
|
Ok(sanitized)
|
|
}
|
|
|
|
pub fn check_for_injection_patterns(input: &str) -> Result<(), SqlGuardError> {
|
|
let lower = input.to_lowercase();
|
|
|
|
let dangerous_patterns = [
|
|
"--",
|
|
"/*",
|
|
"*/",
|
|
";",
|
|
"union",
|
|
"select",
|
|
"insert",
|
|
"update",
|
|
"delete",
|
|
"drop",
|
|
"truncate",
|
|
"exec",
|
|
"execute",
|
|
"xp_",
|
|
"sp_",
|
|
"0x",
|
|
"char(",
|
|
"nchar(",
|
|
"varchar(",
|
|
"nvarchar(",
|
|
"cast(",
|
|
"convert(",
|
|
"@@",
|
|
"waitfor",
|
|
"delay",
|
|
"benchmark",
|
|
"sleep(",
|
|
];
|
|
|
|
for pattern in dangerous_patterns {
|
|
if lower.contains(pattern) {
|
|
return Err(SqlGuardError::PotentialInjection(format!(
|
|
"Dangerous pattern '{}' detected",
|
|
pattern
|
|
)));
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Escapes `value` for embedding inside a single-quoted SQL string literal.
///
/// Each single quote is doubled and each backslash is doubled; every other
/// character is copied unchanged.
///
/// NOTE(review): doubling backslashes is only appropriate when the server
/// interprets backslash escapes in ordinary string literals — confirm this
/// matches the database configuration in use.
pub fn escape_string_literal(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '\'' => escaped.push_str("''"),
            '\\' => escaped.push_str("\\\\"),
            other => escaped.push(other),
        }
    }
    escaped
}
|
|
|
|
pub fn build_safe_select_query(
|
|
table: &str,
|
|
order_by: Option<&str>,
|
|
order_dir: Option<&str>,
|
|
limit: i32,
|
|
offset: i32,
|
|
) -> Result<String, SqlGuardError> {
|
|
let validated_table = validate_table_name(table)?;
|
|
|
|
let safe_limit = limit.clamp(1, 1000);
|
|
let safe_offset = offset.max(0);
|
|
|
|
let order_clause = match (order_by, order_dir) {
|
|
(Some(col), Some(dir)) => {
|
|
let validated_col = validate_order_column(col)?;
|
|
let validated_dir = validate_order_direction(dir)?;
|
|
format!("ORDER BY {} {}", validated_col, validated_dir)
|
|
}
|
|
(Some(col), None) => {
|
|
let validated_col = validate_order_column(col)?;
|
|
format!("ORDER BY {} ASC", validated_col)
|
|
}
|
|
_ => "ORDER BY id ASC".to_string(),
|
|
};
|
|
|
|
Ok(format!(
|
|
"SELECT row_to_json(t.*) as data FROM {} t {} LIMIT {} OFFSET {}",
|
|
validated_table, order_clause, safe_limit, safe_offset
|
|
))
|
|
}
|
|
|
|
pub fn build_safe_count_query(table: &str) -> Result<String, SqlGuardError> {
|
|
let validated_table = validate_table_name(table)?;
|
|
Ok(format!("SELECT COUNT(*) as count FROM {}", validated_table))
|
|
}
|
|
|
|
pub fn build_safe_delete_query(table: &str) -> Result<String, SqlGuardError> {
|
|
let validated_table = validate_table_name(table)?;
|
|
Ok(format!("DELETE FROM {} WHERE id = $1", validated_table))
|
|
}
|
|
|
|
pub fn build_safe_select_by_id_query(table: &str) -> Result<String, SqlGuardError> {
|
|
let validated_table = validate_table_name(table)?;
|
|
Ok(format!(
|
|
"SELECT row_to_json(t.*) as data FROM {} t WHERE id = $1",
|
|
validated_table
|
|
))
|
|
}
|
|
|
|
pub fn register_dynamic_table(table_name: &'static str) {
|
|
log::info!("Dynamic table registration requested for: {}", table_name);
|
|
}
|
|
|
|
/// Check if table is in the system tables whitelist
|
|
pub fn is_system_table(table: &str) -> bool {
|
|
let sanitized = sanitize_identifier(table);
|
|
SYSTEM_TABLES.contains(sanitized.as_str())
|
|
}
|
|
|
|
/// Returns `true` when `table` is acceptable as a table name without
/// consulting the database.
///
/// A name is accepted when the raw input is non-empty, at most 63 bytes
/// long, and consists solely of ASCII letters, digits and underscores. Every
/// system table satisfies this, as do well-formed names of dynamically
/// created tables. Inputs that would only look valid after stripping
/// characters (for example `users; DROP TABLE users;--`) are rejected.
///
/// This does not confirm that the table exists; use
/// `is_table_allowed_with_conn` for that.
pub fn is_table_allowed(table: &str) -> bool {
    !table.is_empty()
        && table.len() <= 63
        && table.chars().all(|c| c.is_ascii_alphanumeric() || c == '_')
}
|
|
|
|
/// Check if table is allowed with database connection
|
|
pub fn is_table_allowed_with_conn(conn: &mut diesel::PgConnection, table: &str) -> bool {
|
|
let sanitized = sanitize_identifier(table);
|
|
if SYSTEM_TABLES.contains(sanitized.as_str()) {
|
|
return true;
|
|
}
|
|
table_exists_in_database(conn, &sanitized)
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// System tables and identifier-shaped names pass the connection-free
    /// check; existence is only verified by `validate_table_name_with_conn`.
    #[test]
    fn test_validate_table_name_allowed() {
        assert!(validate_table_name("users").is_ok());
        assert!(validate_table_name("bots").is_ok());
        assert!(validate_table_name("tasks").is_ok());
        // Not a system table, but a well-formed identifier.
        assert!(validate_table_name("evil_table").is_ok());
    }

    /// Injection attempts, empty names and over-long names are rejected.
    #[test]
    fn test_validate_table_name_disallowed() {
        assert!(validate_table_name("users; DROP TABLE users;--").is_err());
        assert!(validate_table_name("").is_err());
        assert!(validate_table_name(&"a".repeat(64)).is_err());
    }

    #[test]
    fn test_validate_order_direction() {
        assert_eq!(validate_order_direction("ASC").unwrap(), "ASC");
        assert_eq!(validate_order_direction("desc").unwrap(), "DESC");
        assert_eq!(validate_order_direction("Asc").unwrap(), "ASC");
        assert!(validate_order_direction("RANDOM").is_err());
    }

    #[test]
    fn test_sanitize_identifier() {
        assert_eq!(sanitize_identifier("valid_name"), "valid_name");
        assert_eq!(sanitize_identifier("name123"), "name123");
        assert_eq!(sanitize_identifier("name; DROP--"), "nameDROP");
        assert_eq!(sanitize_identifier(""), "");
    }

    #[test]
    fn test_check_for_injection_patterns() {
        assert!(check_for_injection_patterns("normal text").is_ok());
        assert!(check_for_injection_patterns("hello world").is_ok());

        assert!(check_for_injection_patterns("'; DROP TABLE users;--").is_err());
        assert!(check_for_injection_patterns("1 UNION SELECT * FROM passwords").is_err());
        assert!(check_for_injection_patterns("test/*comment*/").is_err());
        assert!(check_for_injection_patterns("WAITFOR DELAY '0:0:5'").is_err());
    }

    #[test]
    fn test_escape_string_literal() {
        assert_eq!(escape_string_literal("hello"), "hello");
        assert_eq!(escape_string_literal("it's"), "it''s");
        assert_eq!(escape_string_literal("back\\slash"), "back\\\\slash");
        assert_eq!(escape_string_literal("O'Brien's"), "O''Brien''s");
    }

    #[test]
    fn test_build_safe_select_query() {
        let query = build_safe_select_query("users", Some("created_at"), Some("DESC"), 10, 0);
        assert!(query.is_ok());
        let q = query.unwrap();
        assert!(q.contains("users"));
        assert!(q.contains("ORDER BY created_at DESC"));
        assert!(q.contains("LIMIT 10"));
        assert!(q.contains("OFFSET 0"));
    }

    /// A table name that fails validation must abort query construction.
    #[test]
    fn test_build_safe_select_query_invalid_table() {
        let query = build_safe_select_query("", None, None, 10, 0);
        assert!(query.is_err());
    }

    #[test]
    fn test_build_safe_count_query() {
        let query = build_safe_count_query("users");
        assert!(query.is_ok());
        assert!(query.unwrap().contains("SELECT COUNT(*)"));
    }

    #[test]
    fn test_build_safe_delete_query() {
        let query = build_safe_delete_query("tasks");
        assert!(query.is_ok());
        assert!(query.unwrap().contains("DELETE FROM tasks WHERE id = $1"));
    }

    #[test]
    fn test_validate_identifier() {
        assert!(validate_identifier("valid_name").is_ok());
        assert!(validate_identifier("name123").is_ok());
        assert!(validate_identifier("").is_err());
        assert!(validate_identifier("123name").is_err());
        assert!(validate_identifier(&"a".repeat(65)).is_err());
    }

    #[test]
    fn test_limit_clamping() {
        let query = build_safe_select_query("users", None, None, 10000, 0).unwrap();
        assert!(query.contains("LIMIT 1000"));

        let query2 = build_safe_select_query("users", None, None, -5, -10).unwrap();
        assert!(query2.contains("LIMIT 1"));
        assert!(query2.contains("OFFSET 0"));
    }

    /// Without a connection, system tables and identifier-shaped names are
    /// allowed; an empty name never is.
    #[test]
    fn test_is_table_allowed() {
        assert!(is_table_allowed("users"));
        assert!(is_table_allowed("bots"));
        assert!(is_table_allowed("app_generated_table"));
        assert!(!is_table_allowed(""));
    }
}
|