diff --git a/rust-executor/src/prolog_service/engine.rs b/rust-executor/src/prolog_service/engine.rs index 3c94c41b2..3a3566cc6 100644 --- a/rust-executor/src/prolog_service/engine.rs +++ b/rust-executor/src/prolog_service/engine.rs @@ -1,15 +1,18 @@ use std::panic::AssertUnwindSafe; +use std::sync::mpsc; +use std::sync::Arc; +use std::sync::Mutex; use deno_core::anyhow::Error; -use scryer_prolog::{LeafAnswer, MachineBuilder, Term}; -use tokio::sync::{mpsc, oneshot}; +use scryer_prolog::MachineBuilder; +use tokio::task; use super::types::{query_result_from_leaf_answer, QueryResult}; #[derive(Debug)] pub enum PrologServiceRequest { - RunQuery(String, oneshot::Sender), - LoadModuleString(String, Vec, oneshot::Sender), + RunQuery(String, mpsc::Sender), + LoadModuleString(String, Vec, mpsc::Sender), #[allow(dead_code)] Drop, } @@ -21,97 +24,147 @@ pub enum PrologServiceResponse { LoadModuleResult(Result<(), Error>), } +struct SendableReceiver(Arc>>); + +unsafe impl Sync for SendableReceiver {} + pub struct PrologEngine { - request_sender: mpsc::UnboundedSender, - request_receiver: Option>, + request_sender: mpsc::Sender, + request_receiver: Option>, } impl PrologEngine { pub fn new() -> PrologEngine { - let (request_sender, request_receiver) = mpsc::unbounded_channel::(); + let (request_sender, request_receiver) = mpsc::channel::(); PrologEngine { request_sender, - request_receiver: Some(request_receiver), + request_receiver: Some(SendableReceiver(Arc::new(Mutex::new(request_receiver)))), } } pub async fn spawn(&mut self) -> Result<(), Error> { - let mut receiver = self + let receiver = self .request_receiver .take() .ok_or_else(|| Error::msg("PrologEngine::spawn called twice"))?; - let (response_sender, response_receiver) = oneshot::channel(); + let (response_sender, response_receiver) = mpsc::channel(); std::thread::spawn(move || { - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .thread_name(String::from("prolog_service")) - .build() - 
.expect("Failed to create Tokio runtime"); - let _guard = rt.enter(); - - tokio::task::block_in_place(|| { - rt.block_on(async move { - let mut machine = MachineBuilder::default().build(); - - response_sender - .send(PrologServiceResponse::InitComplete(Ok(()))) - .unwrap(); - - while let Some(message) = receiver.recv().await { - match message { - PrologServiceRequest::RunQuery(query, response) => { - let answer_result = std::panic::catch_unwind(AssertUnwindSafe(|| { - query_result_from_leaf_answer(machine - .run_query(query.clone()) - .collect::, Term>>()) - })); - - match answer_result { - Ok(result) => { - let _ = response - .send(PrologServiceResponse::QueryResult(result)); + let mut machine = MachineBuilder::default().build(); + let receiver = receiver.0.lock().unwrap(); + + response_sender + .send(PrologServiceResponse::InitComplete(Ok(()))) + .unwrap(); + + while let Ok(message) = receiver.recv() { + match message { + PrologServiceRequest::RunQuery(query, response) => { + let answer_result = std::panic::catch_unwind(AssertUnwindSafe(|| { + let mut results = Vec::new(); + let mut iter = machine.run_query(query.clone()); + const MAX_RESULTS: usize = 1_000_000; // Adjust as needed + + let mut panic = None; + while results.len() < MAX_RESULTS { + match std::panic::catch_unwind(AssertUnwindSafe(|| iter.next())) { + Ok(Some(Ok(answer))) => results.push(answer), + Ok(Some(Err(term))) => { + return query_result_from_leaf_answer(Err(term)) } + Ok(None) => break, // Iterator exhausted Err(e) => { let error_string = if let Some(string) = e.downcast_ref::() { - format!("Scryer panicked with: {:?} - when running query: {}", string, query) + format!( + "Scryer panicked in next(): {:?} - query: {}", + string, query + ) } else if let Some(&str) = e.downcast_ref::<&str>() { - format!("Scryer panicked with: {:?} - when running query: {}", str, query) + format!( + "Scryer panicked in next(): {:?} - query: {}", + str, query + ) } else { - format!("Scryer panicked with: {:?} - 
when running query: {}", e, query) + format!( + "Scryer panicked in next(): {:?} - query: {}", + e, query + ) }; - log::error!("{}", error_string); - let _ = - response.send(PrologServiceResponse::QueryResult(Err( - format!("Scryer panicked with: {:?}", error_string), - ))); + panic = Some(error_string); + break; } } } - PrologServiceRequest::LoadModuleString( - module_name, - program_lines, - response, - ) => { - let program = program_lines - .iter() - .map(|l| l.replace(['\n', '\r'], "")) - .collect::>() - .join("\n"); - machine.consult_module_string(module_name.as_str(), program); - let _ = - response.send(PrologServiceResponse::LoadModuleResult(Ok(()))); + + if let Some(error) = panic { + Err(error) + } else { + if results.len() >= MAX_RESULTS { + log::warn!( + "Query {} truncated at {} results", + query, + MAX_RESULTS + ); + } + + query_result_from_leaf_answer(Ok(results)) + } + })); + + match answer_result { + Ok(result) => { + let _ = response.send(PrologServiceResponse::QueryResult(result)); + } + Err(e) => { + let error_string = if let Some(string) = e.downcast_ref::() + { + format!( + "Scryer panicked with: {:?} - when running query: {}", + string, query + ) + } else if let Some(&str) = e.downcast_ref::<&str>() { + format!( + "Scryer panicked with: {:?} - when running query: {}", + str, query + ) + } else { + format!( + "Scryer panicked with: {:?} - when running query: {}", + e, query + ) + }; + log::error!("{}", error_string); + let _ = response.send(PrologServiceResponse::QueryResult(Err( + format!("Scryer panicked with: {:?}", error_string), + ))); } - PrologServiceRequest::Drop => return, } } - }) - }); + PrologServiceRequest::LoadModuleString( + module_name, + program_lines, + response, + ) => { + let program = program_lines + .iter() + .map(|l| l.replace(['\n', '\r'], "")) + .collect::>() + .join("\n"); + machine.consult_module_string(module_name.as_str(), program); + let _ = response.send(PrologServiceResponse::LoadModuleResult(Ok(()))); + } + 
PrologServiceRequest::Drop => return, + } + } }); - match response_receiver.await? { + let response = task::spawn_blocking(move || response_receiver.recv()) + .await + .map_err(|e| Error::msg(format!("Failed to spawn blocking task: {}", e)))??; + + match response { PrologServiceResponse::InitComplete(result) => result?, _ => unreachable!(), }; @@ -119,11 +172,20 @@ impl PrologEngine { Ok(()) } + // There are two levels of error handling here: + // 1. The query can fail and Prolog returns an error + // This is represented as a QueryResult with an error string + // 2. The Prolog engine can panic and we don't have a result + // This is represented with the outer Result pub async fn run_query(&self, query: String) -> Result { - let (response_sender, response_receiver) = oneshot::channel(); + let (response_sender, response_receiver) = mpsc::channel(); self.request_sender .send(PrologServiceRequest::RunQuery(query, response_sender))?; - let response = response_receiver.await?; + + let response = task::spawn_blocking(move || response_receiver.recv()) + .await + .map_err(|e| Error::msg(format!("Failed to spawn blocking task: {}", e)))??; + match response { PrologServiceResponse::QueryResult(query_result) => Ok(query_result), _ => unreachable!(), @@ -135,14 +197,18 @@ impl PrologEngine { module_name: String, program_lines: Vec, ) -> Result<(), Error> { - let (response_sender, response_receiver) = oneshot::channel(); + let (response_sender, response_receiver) = mpsc::channel(); self.request_sender .send(PrologServiceRequest::LoadModuleString( module_name, program_lines, response_sender, ))?; - let response = response_receiver.await?; + + let response = task::spawn_blocking(move || response_receiver.recv()) + .await + .map_err(|e| Error::msg(format!("Failed to spawn blocking task: {}", e)))??; + match response { PrologServiceResponse::LoadModuleResult(result) => result, _ => unreachable!(), @@ -158,6 +224,7 @@ impl PrologEngine { #[cfg(test)] mod prolog_test { use super::*; + use 
tokio; #[tokio::test] async fn test_init_prolog_engine() { @@ -178,16 +245,12 @@ mod prolog_test { println!("Facts loaded"); let query = String::from("triple(\"a\",P,\"b\")."); - //let query = String::from("write(\"A = \")."); - //let query = String::from("halt.\n"); println!("Running query: {}", query); let output = engine.run_query(query).await; println!("Output: {:?}", output); assert!(output.is_ok()); let query = String::from("triple(\"a\",\"p1\",\"b\")."); - //let query = String::from("write(\"A = \")."); - //let query = String::from("halt.\n"); println!("Running query: {}", query); let output = engine.run_query(query).await; println!("Output: {:?}", output); diff --git a/rust-executor/src/prolog_service/engine_pool.rs b/rust-executor/src/prolog_service/engine_pool.rs index 8ed76fe2b..bfac1c34d 100644 --- a/rust-executor/src/prolog_service/engine_pool.rs +++ b/rust-executor/src/prolog_service/engine_pool.rs @@ -87,44 +87,57 @@ impl PrologEnginePool { join_all(futures).await.into_iter().collect() } - pub async fn run_query(&self, query: String) -> Result { - let engines = self.engines.read().await; + async fn handle_engine_error( + &self, + engine_idx: usize, + error: impl std::fmt::Display, + query: &str, + ) -> Result { + log::error!("Prolog engine error: {}", error); + log::error!("when running query: {}", query); + let mut engines = self.engines.write().await; + engines[engine_idx] = None; + Err(anyhow!("Engine failed and was invalidated: {}", error)) + } - // Get a vec with all non-None (invalidated) engines - let valid_engines: Vec<_> = engines - .iter() - .enumerate() - .filter_map(|(i, e)| e.as_ref().map(|engine| (i, engine))) - .collect(); - if valid_engines.is_empty() { - log::error!("No valid Prolog engines available"); - return Err(anyhow!("No valid Prolog engines available")); - } + pub async fn run_query(&self, query: String) -> Result { + let (result, engine_idx) = { + let engines = self.engines.read().await; + + // Get a vec with all non-None 
(invalidated) engines + let valid_engines: Vec<_> = engines + .iter() + .enumerate() + .filter_map(|(i, e)| e.as_ref().map(|engine| (i, engine))) + .collect(); + if valid_engines.is_empty() { + log::error!("No valid Prolog engines available"); + return Err(anyhow!("No valid Prolog engines available")); + } - // Round-robin selection of engine - let current = self.next_engine.fetch_add(1, Ordering::SeqCst); - let idx = current % valid_engines.len(); - let (engine_idx, engine) = valid_engines[idx]; + // Round-robin selection of engine + let current = self.next_engine.fetch_add(1, Ordering::SeqCst); + let idx = current % valid_engines.len(); + let (engine_idx, engine) = valid_engines[idx]; - // Preprocess query to replace huge vector URLs with small cache IDs - let processed_query = self.replace_embedding_url(query.clone()).await; + // Preprocess query to replace huge vector URLs with small cache IDs + let processed_query = self.replace_embedding_url(query.clone()).await; - // Run query - let result = engine.run_query(processed_query.clone()).await; + // Run query + let result = engine.run_query(processed_query.clone()).await; + (result, engine_idx) + }; let result = match result { - Err(e) => { - log::error!("Prolog engine error: {}", e); - log::error!("when running query: {}", query); - drop(engines); - let mut engines = self.engines.write().await; - engines[engine_idx] = None; - Err(anyhow!("Engine failed and was invalidated: {}", e)) - } - Ok(mut result) => { + // Outer Result is an error -> engine panicked + Err(e) => self.handle_engine_error(engine_idx, e, &query).await, + // Inner Result is an error -> query failed + Ok(Err(e)) => Ok(Err(e)), + // Inner Result is a QueryResolution -> query succeeded + Ok(Ok(mut result)) => { // Postprocess result to replace small cache IDs with huge vector URLs // In-place and async/parallel processing of all values in all matches - if let Ok(QueryResolution::Matches(ref mut matches)) = result { + if let 
QueryResolution::Matches(ref mut matches) = result { join_all(matches.iter_mut().map(|m| { join_all(m.bindings.iter_mut().map(|(_, value)| { self.replace_embedding_url_in_value_recursively(value) @@ -132,7 +145,7 @@ impl PrologEnginePool { })) .await; } - Ok(result) + Ok(Ok(result)) } }; diff --git a/rust-executor/src/prolog_service/mod.rs b/rust-executor/src/prolog_service/mod.rs index 437625164..a2351f786 100644 --- a/rust-executor/src/prolog_service/mod.rs +++ b/rust-executor/src/prolog_service/mod.rs @@ -134,7 +134,7 @@ mod prolog_test { let result = service .run_query(perspective_id.clone(), query) .await - .expect("Error running query"); + .expect("no error running query"); assert_eq!( result, @@ -152,7 +152,7 @@ mod prolog_test { let result = service .run_query(perspective_id.clone(), query) .await - .expect("Error running query"); + .expect("no error running query"); assert_eq!(result, Ok(QueryResolution::True));