Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 237 additions & 4 deletions rust-executor/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -453,10 +453,181 @@ impl Ad4mDb {
Ok(result > 0)
}

/// Validates that a notification query is safe and well-formed
fn validate_notification_query(query: &str) -> Result<(), String> {
let query_trimmed = query.trim();
let query_upper = query_trimmed.to_uppercase();

// Check for empty query
if query_trimmed.is_empty() {
return Err("Query cannot be empty".to_string());
}

// Check query length (prevent extremely long queries)
if query_trimmed.len() > 10000 {
return Err("Query is too long (max 10000 characters)".to_string());
}

// Validate that query starts with SELECT, RETURN, LET, or WITH
let first_word = query_upper.split_whitespace().next().unwrap_or("");
if !matches!(first_word, "SELECT" | "RETURN" | "LET" | "WITH") {
return Err(format!(
"Query must start with SELECT, RETURN, LET, or WITH. Got: {}",
first_word
));
}

// Check for mutating operations
// Use a single pass that tracks string literals to avoid false positives
let mutating_operations = [
"INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "REMOVE", "DEFINE", "ALTER", "RELATE",
"BEGIN", "COMMIT", "CANCEL",
];

let query_bytes = query_upper.as_bytes();

for operation in &mutating_operations {
let op_bytes = operation.as_bytes();
let mut search_pos = 0;

while let Some(pos) = query_upper[search_pos..].find(operation) {
let absolute_pos = search_pos + pos;

// Re-track string state up to this position
let mut in_string = false;
let mut escaped = false;
let mut string_char: u8 = 0;

for i in 0..absolute_pos {
let byte = query_bytes[i];

match byte {
b'\\' if in_string => {
escaped = !escaped;
}
b'\'' | b'"' => {
if !escaped {
if in_string {
if byte == string_char {
in_string = false;
}
} else {
in_string = true;
string_char = byte;
}
}
if escaped {
escaped = false;
}
}
_ => {
if escaped {
escaped = false;
}
}
}
}

// Skip if inside a string literal
if in_string {
search_pos = absolute_pos + 1;
continue;
}

// Check what comes before (byte-based)
let before_ok = if absolute_pos == 0 {
true
} else {
let before_byte = query_bytes[absolute_pos - 1];
matches!(before_byte, b' ' | b'\t' | b'\n' | b'\r' | b';' | b'(')
};

// Check what comes after (byte-based)
let after_pos = absolute_pos + op_bytes.len();
let after_ok = if after_pos >= query_bytes.len() {
true
} else {
let after_byte = query_bytes[after_pos];
matches!(after_byte, b' ' | b'\t' | b'\n' | b'\r' | b';' | b'(')
};

if before_ok && after_ok {
return Err(format!(
"Query contains mutating operation '{}' which is not allowed",
operation
));
}

search_pos = absolute_pos + 1;
}
}

// Basic syntax check - ensure balanced parentheses
// Skip counting parentheses inside string literals
let mut paren_count = 0;
let mut in_string = false;
let mut string_char = ' '; // Track which quote character started the string
let mut escaped = false;

for c in query_trimmed.chars() {
match c {
'\\' if in_string => {
// Toggle escaped state for backslashes inside strings
// If already escaped, this is a literal backslash (\\)
// If not escaped, next char will be escaped
escaped = !escaped;
}
'\'' | '"' => {
if !escaped {
if in_string {
// Check if this closes the current string
if c == string_char {
in_string = false;
}
} else {
// Start a new string
in_string = true;
string_char = c;
}
}
// Clear escaped state after processing
if escaped {
escaped = false;
}
}
'(' if !in_string => paren_count += 1,
')' if !in_string => paren_count -= 1,
_ => {
// Any other character clears the escaped state
if escaped {
escaped = false;
}
}
}

if paren_count < 0 {
return Err("Unbalanced parentheses in query".to_string());
}
}
if paren_count != 0 {
return Err("Unbalanced parentheses in query".to_string());
}

Ok(())
}

pub fn add_notification(
&self,
notification: NotificationInput,
) -> Result<String, rusqlite::Error> {
// Validate the trigger query before storing
if let Err(e) = Self::validate_notification_query(&notification.trigger) {
return Err(rusqlite::Error::InvalidParameterName(format!(
"Invalid notification query: {}",
e
)));
}

let id = uuid::Uuid::new_v4().to_string();
self.conn.execute(
"INSERT INTO notifications (id, granted, description, appName, appUrl, appIconPath, trigger, perspective_ids, webhookUrl, webhookAuth) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)",
Expand Down Expand Up @@ -535,6 +706,14 @@ impl Ad4mDb {
id: String,
updated_notification: &Notification,
) -> Result<bool, rusqlite::Error> {
// Validate the trigger query before updating
if let Err(e) = Self::validate_notification_query(&updated_notification.trigger) {
return Err(rusqlite::Error::InvalidParameterName(format!(
"Invalid notification query: {}",
e
)));
}

let result = self.conn.execute(
"UPDATE notifications SET description = ?2, appName = ?3, appUrl = ?4, appIconPath = ?5, trigger = ?6, perspective_ids = ?7, webhookUrl = ?8, webhookAuth = ?9, granted = ?10 WHERE id = ?1",
params![
Expand Down Expand Up @@ -2864,7 +3043,7 @@ mod tests {
app_name: "Test App".to_string(),
app_url: "http://test.app".to_string(),
app_icon_path: "/test/icon.png".to_string(),
trigger: "test-trigger".to_string(),
trigger: "SELECT * FROM link WHERE predicate = 'test://trigger'".to_string(),
perspective_ids: vec![perspective1.uuid.clone()],
webhook_url: "http://test.webhook".to_string(),
webhook_auth: "test-auth".to_string(),
Expand Down Expand Up @@ -3208,7 +3387,7 @@ mod tests {
app_name: "Test App Name".to_string(),
app_url: "Test App URL".to_string(),
app_icon_path: "Test App Icon Path".to_string(),
trigger: "Test Trigger".to_string(),
trigger: "SELECT * FROM link WHERE predicate = 'test://trigger'".to_string(),
perspective_ids: vec!["Test Perspective ID".to_string()],
webhook_url: "Test Webhook URL".to_string(),
webhook_auth: "Test Webhook Auth".to_string(),
Expand All @@ -3231,7 +3410,10 @@ mod tests {
test_notification.app_icon_path,
"Test App Icon Path".to_string()
);
assert_eq!(test_notification.trigger, "Test Trigger");
assert_eq!(
test_notification.trigger,
"SELECT * FROM link WHERE predicate = 'test://trigger'"
);
assert_eq!(
test_notification.perspective_ids,
vec!["Test Perspective ID".to_string()]
Expand All @@ -3247,7 +3429,7 @@ mod tests {
app_name: "Test App Name".to_string(),
app_url: "Test App URL".to_string(),
app_icon_path: "Test App Icon Path".to_string(),
trigger: "Test Trigger".to_string(),
trigger: "SELECT * FROM link WHERE predicate = 'test://updated'".to_string(),
perspective_ids: vec!["Test Perspective ID".to_string()],
webhook_url: "Test Webhook URL".to_string(),
webhook_auth: "Test Webhook Auth".to_string(),
Expand Down Expand Up @@ -3279,6 +3461,57 @@ mod tests {
.all(|n| n.id != notification_id));
}

#[test]
fn test_notification_query_validation_with_string_literals() {
let db = Ad4mDb::new(":memory:").unwrap();

// Should accept: keyword inside string literal
let notification1 = NotificationInput {
description: "Test".to_string(),
app_name: "Test".to_string(),
app_url: "Test".to_string(),
app_icon_path: "Test".to_string(),
trigger: "SELECT * FROM link WHERE data = 'DELETE this'".to_string(),
perspective_ids: vec!["test".to_string()],
webhook_url: "".to_string(),
webhook_auth: "".to_string(),
};

let result1 = db.add_notification(notification1);
assert!(result1.is_ok(), "Should allow DELETE inside string literal");

// Should reject: actual DELETE operation
let notification2 = NotificationInput {
description: "Test".to_string(),
app_name: "Test".to_string(),
app_url: "Test".to_string(),
app_icon_path: "Test".to_string(),
trigger: "DELETE FROM link WHERE id = 123".to_string(),
perspective_ids: vec!["test".to_string()],
webhook_url: "".to_string(),
webhook_auth: "".to_string(),
};

let result2 = db.add_notification(notification2);
assert!(result2.is_err(), "Should reject actual DELETE operation");
assert!(result2.unwrap_err().to_string().contains("DELETE"));

// Should accept: escaped quotes with keyword
let notification3 = NotificationInput {
description: "Test".to_string(),
app_name: "Test".to_string(),
app_url: "Test".to_string(),
app_icon_path: "Test".to_string(),
trigger: r#"SELECT * FROM link WHERE data = "Don\'t DELETE this""#.to_string(),
perspective_ids: vec!["test".to_string()],
webhook_url: "".to_string(),
webhook_auth: "".to_string(),
};

let result3 = db.add_notification(notification3);
assert!(result3.is_ok(), "Should allow DELETE inside escaped string");
}

#[test]
fn test_task_operations() {
let db = Ad4mDb::new(":memory:").unwrap();
Expand Down
Loading