// botserver/src/security/command_guard.rs
// (Header captured from a repository file view: 431 lines, 13 KiB, Rust;
// revision dated 2025-12-28 19:29:18 -03:00.)
use std::collections::HashSet;
use std::path::PathBuf;
use std::process::Output;
use std::sync::LazyLock;
/// Program names (file-name component only) that `SafeCommand::new` accepts.
///
/// NOTE(review): `powershell` is itself a command interpreter. Allowlisting it lets a
/// caller run arbitrary script text (for example via `-Command`) as long as that text
/// avoids the characters and patterns rejected by `validate_argument`. Consider
/// removing it or restricting the arguments it may receive.
static ALLOWED_COMMANDS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    HashSet::from([
        "pdftotext",
        "pandoc",
        "nvidia-smi",
        "powershell",
        "clamscan",
        "freshclam",
        "mc",
        "ffmpeg",
        "ffprobe",
        "convert",
        "gs",
        "tesseract",
        "which",
        "where",
    ])
});
/// Characters that `validate_argument` refuses anywhere in an argument: command
/// separators, pipes, `$`/backtick substitution, grouping and redirection brackets,
/// line breaks and NUL.
static FORBIDDEN_SHELL_CHARS: LazyLock<HashSet<char>> = LazyLock::new(|| {
    // A single literal listing all 14 characters; each one becomes a set member.
    ";|&$`(){}<>\n\r\0".chars().collect::<HashSet<char>>()
});
/// Reasons a guarded command can be refused or can fail.
///
/// Every variant carries a detail string. `Display` prints it after a fixed,
/// variant-specific prefix.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CommandGuardError {
    /// The program name is not on the allowlist.
    CommandNotAllowed(String),
    /// An argument failed a basic check (for example, it was empty or too long).
    InvalidArgument(String),
    /// A path could not be resolved or lies outside the allowed roots.
    PathTraversal(String),
    /// The process could not be run, or it ran and reported failure.
    ExecutionFailed(String),
    /// An argument contained a forbidden shell character or pattern.
    ShellInjectionAttempt(String),
}

impl std::fmt::Display for CommandGuardError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use CommandGuardError::{
            CommandNotAllowed, ExecutionFailed, InvalidArgument, PathTraversal,
            ShellInjectionAttempt,
        };
        match self {
            ShellInjectionAttempt(input) => {
                write!(f, "Shell injection attempt detected: {input}")
            }
            ExecutionFailed(msg) => write!(f, "Command execution failed: {msg}"),
            PathTraversal(path) => write!(f, "Path traversal detected: {path}"),
            InvalidArgument(arg) => write!(f, "Invalid argument: {arg}"),
            CommandNotAllowed(cmd) => write!(f, "Command not in allowlist: {cmd}"),
        }
    }
}

impl std::error::Error for CommandGuardError {}
/// Builder for running one allowlisted program directly, never through a shell.
///
/// Construction checks the program against `ALLOWED_COMMANDS`. Arguments added with
/// [`arg`](Self::arg) or [`args`](Self::args) are screened by `validate_argument`.
/// Paths added with [`path_arg`](Self::path_arg) or [`working_dir`](Self::working_dir)
/// must resolve, via `validate_path`, under one of the allowed roots. The child runs
/// with a cleared environment that contains only `PATH`, `HOME` and `LANG`.
#[derive(Debug, Clone)]
pub struct SafeCommand {
    /// Program to spawn, exactly as given to [`SafeCommand::new`].
    command: String,
    /// Validated arguments, in insertion order. They are stored as OS strings so that
    /// path arguments reach the child byte-for-byte.
    args: Vec<std::ffi::OsString>,
    /// Validated working directory for the child, if one was set.
    working_dir: Option<PathBuf>,
    /// Roots that `path_arg` / `working_dir` paths must resolve under.
    allowed_paths: Vec<PathBuf>,
}

impl SafeCommand {
    /// Directories used as the child's `PATH`. They are also the only directories an
    /// explicit program path passed to [`new`](Self::new) may point into.
    const TRUSTED_BIN_DIRS: [&'static str; 3] = ["/usr/local/bin", "/usr/bin", "/bin"];

    /// Create a command for `command`, whose file name must be allowlisted.
    ///
    /// `command` may be either of two forms:
    /// - A bare name. It is looked up through the fixed `PATH` that `execute` sets on
    ///   the child (std's lookup behaviour on Unix).
    /// - A path whose directory is one of [`Self::TRUSTED_BIN_DIRS`].
    ///
    /// The default allowed roots are `/tmp` and `/var/tmp`, plus the user's home
    /// directory and the current directory when those can be determined.
    ///
    /// # Errors
    /// Returns [`CommandGuardError::CommandNotAllowed`] if the file name is not
    /// allowlisted, or if an explicit program path points outside the trusted
    /// directories.
    pub fn new(command: &str) -> Result<Self, CommandGuardError> {
        let program = std::path::Path::new(command);
        let cmd_name = program
            .file_name()
            .and_then(|n| n.to_str())
            .unwrap_or(command);
        if !ALLOWED_COMMANDS.contains(cmd_name) {
            return Err(CommandGuardError::CommandNotAllowed(command.to_string()));
        }
        // The allowlist only inspects the file name. Without this check,
        // "/tmp/evil/pdftotext" would pass and that binary would be executed.
        if let Some(dir) = program.parent().filter(|d| !d.as_os_str().is_empty()) {
            let trusted = Self::TRUSTED_BIN_DIRS
                .iter()
                .any(|t| dir == std::path::Path::new(t));
            if !trusted {
                return Err(CommandGuardError::CommandNotAllowed(command.to_string()));
            }
        }
        // Add home and cwd only when they are known. Falling back to "/" would make
        // every path on the system an allowed path.
        let mut allowed_paths = vec![PathBuf::from("/tmp"), PathBuf::from("/var/tmp")];
        allowed_paths.extend(dirs::home_dir());
        allowed_paths.extend(std::env::current_dir().ok());
        Ok(Self {
            command: command.to_string(),
            args: Vec::new(),
            working_dir: None,
            allowed_paths,
        })
    }

    /// Append one argument after checking it with `validate_argument`.
    ///
    /// # Errors
    /// Returns whatever `validate_argument` reports for `arg`.
    pub fn arg(mut self, arg: &str) -> Result<Self, CommandGuardError> {
        validate_argument(arg)?;
        self.args.push(arg.into());
        Ok(self)
    }

    /// Append several arguments in order, validating each one.
    ///
    /// # Errors
    /// Returns the error for the first argument that fails `validate_argument`.
    pub fn args(mut self, args: &[&str]) -> Result<Self, CommandGuardError> {
        for arg in args {
            validate_argument(arg)?;
            self.args.push((*arg).into());
        }
        Ok(self)
    }

    /// Append a path argument. The child receives it in canonical form.
    ///
    /// # Errors
    /// Returns [`CommandGuardError::PathTraversal`] if `path` does not resolve under an
    /// allowed root.
    pub fn path_arg(mut self, path: &std::path::Path) -> Result<Self, CommandGuardError> {
        let validated = validate_path(path, &self.allowed_paths)?;
        // Pass the OS string through untouched. A lossy UTF-8 conversion would replace
        // invalid bytes and hand the child a different path.
        self.args.push(validated.into_os_string());
        Ok(self)
    }

    /// Set the child's working directory. It must resolve under an allowed root.
    ///
    /// # Errors
    /// Returns [`CommandGuardError::PathTraversal`] if `dir` does not resolve under an
    /// allowed root.
    pub fn working_dir(mut self, dir: &std::path::Path) -> Result<Self, CommandGuardError> {
        let validated = validate_path(dir, &self.allowed_paths)?;
        self.working_dir = Some(validated);
        Ok(self)
    }

    /// Add `path` as an extra allowed root.
    ///
    /// It applies only to `path_arg` / `working_dir` calls made after this one.
    pub fn allow_path(mut self, path: PathBuf) -> Self {
        self.allowed_paths.push(path);
        self
    }

    /// Environment given to the child after the inherited environment is cleared.
    fn child_env() -> [(&'static str, std::ffi::OsString); 3] {
        let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("/tmp"));
        [
            ("PATH", Self::TRUSTED_BIN_DIRS.join(":").into()),
            ("HOME", home.into_os_string()),
            ("LANG", "C.UTF-8".into()),
        ]
    }

    /// Run the command to completion on the current thread and capture its output.
    ///
    /// A non-zero exit status is not an error here; inspect `Output::status`.
    ///
    /// # Errors
    /// Returns [`CommandGuardError::ExecutionFailed`] if the process could not be
    /// spawned or waited on.
    pub fn execute(&self) -> Result<Output, CommandGuardError> {
        let mut cmd = std::process::Command::new(&self.command);
        cmd.args(&self.args).env_clear().envs(Self::child_env());
        if let Some(dir) = &self.working_dir {
            cmd.current_dir(dir);
        }
        cmd.output()
            .map_err(|e| CommandGuardError::ExecutionFailed(e.to_string()))
    }

    /// Async counterpart of [`execute`](Self::execute), using `tokio::process::Command`.
    ///
    /// # Errors
    /// Returns [`CommandGuardError::ExecutionFailed`] if the process could not be
    /// spawned or waited on.
    pub async fn execute_async(&self) -> Result<Output, CommandGuardError> {
        let mut cmd = tokio::process::Command::new(&self.command);
        cmd.args(&self.args).env_clear().envs(Self::child_env());
        if let Some(dir) = &self.working_dir {
            cmd.current_dir(dir);
        }
        cmd.output()
            .await
            .map_err(|e| CommandGuardError::ExecutionFailed(e.to_string()))
    }
}
/// Screen one command-line argument before it is handed to a child process.
///
/// Checks run in this order, and the first failure is returned:
/// 1. The argument is not empty.
/// 2. It is at most 4096 bytes long.
/// 3. It contains no character from `FORBIDDEN_SHELL_CHARS`.
/// 4. It contains none of a short list of suspicious multi-character patterns.
///
/// # Errors
/// Returns [`CommandGuardError::InvalidArgument`] for empty or over-long input.
/// Returns [`CommandGuardError::ShellInjectionAttempt`] for a forbidden character or
/// pattern.
pub fn validate_argument(arg: &str) -> Result<(), CommandGuardError> {
    const MAX_ARG_BYTES: usize = 4096;
    // With the current FORBIDDEN_SHELL_CHARS, the first six patterns can never match
    // here: their characters are already rejected by the per-character check. They
    // are kept as defence in depth.
    const DANGEROUS_PATTERNS: [&str; 9] =
        ["$(", "`", "&&", "||", ">>", "<<", "..", "//", "\\\\"];

    let invalid = |reason: &str| CommandGuardError::InvalidArgument(reason.to_string());
    if arg.is_empty() {
        return Err(invalid("Empty argument"));
    }
    if arg.len() > MAX_ARG_BYTES {
        return Err(invalid("Argument too long"));
    }

    if let Some(bad) = arg.chars().find(|c| FORBIDDEN_SHELL_CHARS.contains(c)) {
        return Err(CommandGuardError::ShellInjectionAttempt(format!(
            "Forbidden character '{}' in argument",
            bad.escape_default()
        )));
    }

    match DANGEROUS_PATTERNS.iter().find(|p| arg.contains(**p)) {
        Some(pattern) => Err(CommandGuardError::ShellInjectionAttempt(format!(
            "Dangerous pattern '{}' detected",
            pattern
        ))),
        None => Ok(()),
    }
}
/// Resolve `path` to an absolute, symlink-free form and require it to lie under one of
/// `allowed_roots`.
///
/// - Existing paths are canonicalized directly.
/// - A path that does not exist yet (for example an output file a tool will create) is
///   accepted when its parent directory exists. The parent is canonicalized and the
///   file name re-attached; a bare relative name uses the current directory as its
///   parent.
/// - Something that exists but cannot be canonicalized, such as a dangling symlink, is
///   rejected. A tool writing to it would follow the link to wherever it points.
/// - Each root is canonicalized before comparison, falling back to the root as given
///   if that fails. Roots that are themselves symlinks (e.g. `/tmp` on macOS) therefore
///   still match.
/// - An empty root means the current directory. Left as-is, an empty path would be a
///   prefix of every path.
///
/// # Errors
/// Returns [`CommandGuardError::PathTraversal`] in three cases:
/// - the path cannot be resolved;
/// - the resolved path still contains a `..` component;
/// - the resolved path falls outside every allowed root.
pub fn validate_path(
    path: &std::path::Path,
    allowed_roots: &[PathBuf],
) -> Result<PathBuf, CommandGuardError> {
    let canonical = canonicalize_lenient(path).ok_or_else(|| {
        CommandGuardError::PathTraversal(format!(
            "Cannot canonicalize path: {}",
            path.display()
        ))
    })?;
    // Canonical paths should never contain `..`; this is defence in depth. Whole
    // components are compared rather than substrings, so names such as
    // `report..v2.pdf` are not mistaken for traversal.
    if canonical
        .components()
        .any(|c| matches!(c, std::path::Component::ParentDir))
    {
        return Err(CommandGuardError::PathTraversal(format!(
            "Path contains traversal: {}",
            path.display()
        )));
    }
    let is_allowed = allowed_roots.iter().any(|root| {
        let root = if root.as_os_str().is_empty() {
            std::path::Path::new(".")
        } else {
            root.as_path()
        };
        let root = root.canonicalize().unwrap_or_else(|_| root.to_path_buf());
        canonical.starts_with(&root)
    });
    if !is_allowed {
        return Err(CommandGuardError::PathTraversal(format!(
            "Path outside allowed directories: {}",
            path.display()
        )));
    }
    Ok(canonical)
}

/// Canonicalize `path`, or, when nothing exists there yet, canonicalize its parent and
/// re-attach the file name. Returns `None` if neither works.
fn canonicalize_lenient(path: &std::path::Path) -> Option<PathBuf> {
    if let Ok(canonical) = path.canonicalize() {
        return Some(canonical);
    }
    // Fall back only when the entry is genuinely absent. A dangling symlink, or an
    // entry we cannot inspect, is refused rather than guessed at.
    match path.symlink_metadata() {
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
        _ => return None,
    }
    let file_name = path.file_name()?;
    let parent = match path.parent() {
        Some(p) if !p.as_os_str().is_empty() => p,
        _ => std::path::Path::new("."),
    };
    parent.canonicalize().ok().map(|dir| dir.join(file_name))
}
/// Reduce `filename` to a conservative character set, then strip leading dots.
///
/// Only alphanumeric characters (including Unicode ones), `.`, `-` and `_` are kept.
/// Removing leading dots prevents hidden files. The result may be empty, for example
/// when the input consists only of dots or separators.
pub fn sanitize_filename(filename: &str) -> String {
    let is_safe = |c: &char| c.is_alphanumeric() || matches!(*c, '.' | '-' | '_');
    let kept: String = filename.chars().filter(is_safe).collect();
    kept.trim_start_matches('.').to_owned()
}
/// Extract the text of `pdf_path`, with layout preserved, by running
/// `pdftotext -layout <pdf> -` and returning its stdout.
///
/// The PDF's parent directory (or `/tmp` when it has none) is added as an allowed root
/// before the path is validated.
///
/// NOTE(review): `_allowed_paths` is currently ignored. Because the PDF's own
/// directory is always allowed, the root check cannot reject this path. Confirm
/// whether callers expect their roots to be enforced.
///
/// # Errors
/// Propagates guard and spawn errors. A non-zero exit becomes
/// [`CommandGuardError::ExecutionFailed`] carrying the tool's stderr.
pub fn safe_pdftotext(
    pdf_path: &std::path::Path,
    _allowed_paths: &[PathBuf],
) -> Result<String, CommandGuardError> {
    let pdf_dir = pdf_path
        .parent()
        .map_or_else(|| PathBuf::from("/tmp"), std::path::Path::to_path_buf);
    let command = SafeCommand::new("pdftotext")?
        .allow_path(pdf_dir)
        .arg("-layout")?
        .path_arg(pdf_path)?
        .arg("-")?;
    let output = command.execute()?;

    if !output.status.success() {
        return Err(CommandGuardError::ExecutionFailed(
            String::from_utf8_lossy(&output.stderr).into_owned(),
        ));
    }
    Ok(String::from_utf8_lossy(&output.stdout).into_owned())
}
pub async fn safe_pdftotext_async(
pdf_path: &std::path::Path,
) -> Result<String, CommandGuardError> {
let parent = pdf_path.parent().unwrap_or(std::path::Path::new("/tmp")).to_path_buf();
let output = SafeCommand::new("pdftotext")?
.allow_path(parent)
.arg("-layout")?
.path_arg(pdf_path)?
.arg("-")?
.execute_async()
.await?;
if output.status.success() {
Ok(String::from_utf8_lossy(&output.stdout).to_string())
} else {
Err(CommandGuardError::ExecutionFailed(
String::from_utf8_lossy(&output.stderr).to_string(),
))
}
}
/// Convert `input_path` from `from_format` to `to_format` with
/// `pandoc -f <from> -t <to> <input>`, returning pandoc's stdout.
///
/// Both format names are first screened by `validate_argument`. They must then appear
/// in a fixed allowlist. The input's parent directory (or `/tmp` when it has none) is
/// added as an allowed root before the path is validated.
///
/// NOTE(review): pandoc's documented format names do not appear to include `txt`, and
/// `plain` appears to be an output-only format. Confirm these allowlist entries can
/// actually succeed in the positions they are accepted.
///
/// # Errors
/// - [`CommandGuardError::InvalidArgument`] with "Invalid format specified" for an
///   unlisted format.
/// - Guard and spawn errors are propagated.
/// - A non-zero exit becomes [`CommandGuardError::ExecutionFailed`] carrying stderr.
pub async fn safe_pandoc_async(
    input_path: &std::path::Path,
    from_format: &str,
    to_format: &str,
) -> Result<String, CommandGuardError> {
    const ALLOWED_FORMATS: [&str; 7] =
        ["docx", "plain", "html", "markdown", "rst", "latex", "txt"];

    for format in [from_format, to_format] {
        validate_argument(format)?;
    }
    let supported = |format: &str| ALLOWED_FORMATS.contains(&format);
    if !(supported(from_format) && supported(to_format)) {
        return Err(CommandGuardError::InvalidArgument(
            "Invalid format specified".to_string(),
        ));
    }

    let input_dir = match input_path.parent() {
        Some(dir) => dir.to_path_buf(),
        None => std::path::PathBuf::from("/tmp"),
    };
    let command = SafeCommand::new("pandoc")?
        .allow_path(input_dir)
        .args(&["-f", from_format, "-t", to_format])?
        .path_arg(input_path)?;
    let output = command.execute_async().await?;

    let decode = |bytes: &[u8]| String::from_utf8_lossy(bytes).into_owned();
    if output.status.success() {
        Ok(decode(&output.stdout))
    } else {
        Err(CommandGuardError::ExecutionFailed(decode(&output.stderr)))
    }
}
/// Query GPU and memory utilisation through `nvidia-smi`.
///
/// The map has keys `"gpu"` and `"memory"`, both taken from the last output row that
/// has at least two comma-separated fields. With several GPUs, only the last row's
/// values are reported. A field that fails to parse as `f32` becomes `0.0`. If no row
/// qualifies, the map is empty.
///
/// # Errors
/// - Guard and spawn errors are propagated.
/// - A non-zero exit becomes [`CommandGuardError::ExecutionFailed`] with a fixed
///   message.
pub fn safe_nvidia_smi() -> Result<std::collections::HashMap<String, f32>, CommandGuardError> {
    let output = SafeCommand::new("nvidia-smi")?
        .arg("--query-gpu=utilization.gpu,utilization.memory")?
        .arg("--format=csv,noheader,nounits")?
        .execute()?;
    if !output.status.success() {
        return Err(CommandGuardError::ExecutionFailed(
            "Failed to query GPU utilization".to_string(),
        ));
    }

    let stdout = String::from_utf8_lossy(&output.stdout);
    let parse = |field: &str| field.trim().parse::<f32>().unwrap_or_default();
    // Each qualifying row would overwrite the previous one, so only the last matters.
    let last_row = stdout
        .lines()
        .filter_map(|line| {
            let mut fields = line.split(',');
            Some((fields.next()?, fields.next()?))
        })
        .last();

    let mut util = std::collections::HashMap::new();
    if let Some((gpu, memory)) = last_row {
        util.insert("gpu".to_string(), parse(gpu));
        util.insert("memory".to_string(), parse(memory));
    }
    Ok(util)
}
pub fn has_nvidia_gpu_safe() -> bool {
SafeCommand::new("nvidia-smi")
.and_then(|cmd| {
cmd.arg("--query-gpu=utilization.gpu")?
.arg("--format=csv,noheader,nounits")
})
.and_then(|cmd| cmd.execute())
.map(|output| output.status.success())
.unwrap_or(false)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_argument_valid() {
        assert!(validate_argument("hello").is_ok());
        assert!(validate_argument("-f").is_ok());
        assert!(validate_argument("--format=csv").is_ok());
        assert!(validate_argument("/path/to/file.txt").is_ok());
    }

    #[test]
    fn test_validate_argument_invalid() {
        assert!(validate_argument("hello; rm -rf /").is_err());
        assert!(validate_argument("$(whoami)").is_err());
        assert!(validate_argument("file | cat").is_err());
        assert!(validate_argument("test && echo").is_err());
        assert!(validate_argument("`id`").is_err());
        assert!(validate_argument("").is_err());
    }

    #[test]
    fn test_safe_command_allowed() {
        assert!(SafeCommand::new("pdftotext").is_ok());
        assert!(SafeCommand::new("pandoc").is_ok());
        assert!(SafeCommand::new("nvidia-smi").is_ok());
    }

    #[test]
    fn test_safe_command_disallowed() {
        assert!(SafeCommand::new("rm").is_err());
        assert!(SafeCommand::new("bash").is_err());
        assert!(SafeCommand::new("sh").is_err());
        assert!(SafeCommand::new("curl").is_err());
        assert!(SafeCommand::new("wget").is_err());
    }

    #[test]
    fn test_sanitize_filename() {
        assert_eq!(sanitize_filename("test.pdf"), "test.pdf");
        assert_eq!(sanitize_filename("my-file_v1.txt"), "my-file_v1.txt");
        assert_eq!(sanitize_filename("../../../etc/passwd"), "etcpasswd");
        assert_eq!(sanitize_filename(".hidden"), "hidden");
        assert_eq!(sanitize_filename("file;rm -rf.txt"), "filerm-rf.txt");
    }

    #[test]
    fn test_path_traversal_detection() {
        // A traversal-looking string is rejected as a plain argument...
        assert!(validate_argument("../../../etc/passwd").is_err());

        // ...and a real path outside the allowed roots is rejected by `validate_path`.
        let allowed = vec![PathBuf::from("/tmp")];
        assert!(validate_path(std::path::Path::new("/etc/passwd"), &allowed).is_err());

        // A canonical root accepts itself.
        let root = std::env::temp_dir()
            .canonicalize()
            .expect("temp dir should exist");
        assert!(validate_path(&root, &[root.clone()]).is_ok());
    }

    #[test]
    fn test_command_guard_error_display() {
        let err = CommandGuardError::CommandNotAllowed("bash".to_string());
        assert!(err.to_string().contains("bash"));
        let err2 = CommandGuardError::ShellInjectionAttempt("$(id)".to_string());
        assert!(err2.to_string().contains("injection"));
    }
}