Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 134 additions & 71 deletions rust-executor/src/prolog_service/engine.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
use std::panic::AssertUnwindSafe;
use std::sync::mpsc;
use std::sync::Arc;
use std::sync::Mutex;

use deno_core::anyhow::Error;
use scryer_prolog::{LeafAnswer, MachineBuilder, Term};
use tokio::sync::{mpsc, oneshot};
use scryer_prolog::MachineBuilder;
use tokio::task;

use super::types::{query_result_from_leaf_answer, QueryResult};

#[derive(Debug)]
pub enum PrologServiceRequest {
RunQuery(String, oneshot::Sender<PrologServiceResponse>),
LoadModuleString(String, Vec<String>, oneshot::Sender<PrologServiceResponse>),
RunQuery(String, mpsc::Sender<PrologServiceResponse>),
LoadModuleString(String, Vec<String>, mpsc::Sender<PrologServiceResponse>),
#[allow(dead_code)]
Drop,
}
Expand All @@ -21,109 +24,168 @@ pub enum PrologServiceResponse {
LoadModuleResult(Result<(), Error>),
}

struct SendableReceiver<T>(Arc<Mutex<mpsc::Receiver<T>>>);

unsafe impl<T> Sync for SendableReceiver<T> {}

pub struct PrologEngine {
request_sender: mpsc::UnboundedSender<PrologServiceRequest>,
request_receiver: Option<mpsc::UnboundedReceiver<PrologServiceRequest>>,
request_sender: mpsc::Sender<PrologServiceRequest>,
request_receiver: Option<SendableReceiver<PrologServiceRequest>>,
}

impl PrologEngine {
pub fn new() -> PrologEngine {
let (request_sender, request_receiver) = mpsc::unbounded_channel::<PrologServiceRequest>();
let (request_sender, request_receiver) = mpsc::channel::<PrologServiceRequest>();

PrologEngine {
request_sender,
request_receiver: Some(request_receiver),
request_receiver: Some(SendableReceiver(Arc::new(Mutex::new(request_receiver)))),
}
}

pub async fn spawn(&mut self) -> Result<(), Error> {
let mut receiver = self
let receiver = self
.request_receiver
.take()
.ok_or_else(|| Error::msg("PrologEngine::spawn called twice"))?;
let (response_sender, response_receiver) = oneshot::channel();
let (response_sender, response_receiver) = mpsc::channel();

std::thread::spawn(move || {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.thread_name(String::from("prolog_service"))
.build()
.expect("Failed to create Tokio runtime");
let _guard = rt.enter();

tokio::task::block_in_place(|| {
rt.block_on(async move {
let mut machine = MachineBuilder::default().build();

response_sender
.send(PrologServiceResponse::InitComplete(Ok(())))
.unwrap();

while let Some(message) = receiver.recv().await {
match message {
PrologServiceRequest::RunQuery(query, response) => {
let answer_result = std::panic::catch_unwind(AssertUnwindSafe(|| {
query_result_from_leaf_answer(machine
.run_query(query.clone())
.collect::<Result<Vec<LeafAnswer>, Term>>())
}));

match answer_result {
Ok(result) => {
let _ = response
.send(PrologServiceResponse::QueryResult(result));
let mut machine = MachineBuilder::default().build();
let receiver = receiver.0.lock().unwrap();

response_sender
.send(PrologServiceResponse::InitComplete(Ok(())))
.unwrap();

while let Ok(message) = receiver.recv() {
match message {
PrologServiceRequest::RunQuery(query, response) => {
let answer_result = std::panic::catch_unwind(AssertUnwindSafe(|| {
let mut results = Vec::new();
let mut iter = machine.run_query(query.clone());
const MAX_RESULTS: usize = 1_000_000; // Adjust as needed

let mut panic = None;
while results.len() < MAX_RESULTS {
match std::panic::catch_unwind(AssertUnwindSafe(|| iter.next())) {
Ok(Some(Ok(answer))) => results.push(answer),
Ok(Some(Err(term))) => {
return query_result_from_leaf_answer(Err(term))
}
Ok(None) => break, // Iterator exhausted
Err(e) => {
let error_string =
if let Some(string) = e.downcast_ref::<String>() {
format!("Scryer panicked with: {:?} - when running query: {}", string, query)
format!(
"Scryer panicked in next(): {:?} - query: {}",
string, query
)
} else if let Some(&str) = e.downcast_ref::<&str>() {
format!("Scryer panicked with: {:?} - when running query: {}", str, query)
format!(
"Scryer panicked in next(): {:?} - query: {}",
str, query
)
} else {
format!("Scryer panicked with: {:?} - when running query: {}", e, query)
format!(
"Scryer panicked in next(): {:?} - query: {}",
e, query
)
};
log::error!("{}", error_string);
let _ =
response.send(PrologServiceResponse::QueryResult(Err(
format!("Scryer panicked with: {:?}", error_string),
)));
panic = Some(error_string);
break;
}
}
}
PrologServiceRequest::LoadModuleString(
module_name,
program_lines,
response,
) => {
let program = program_lines
.iter()
.map(|l| l.replace(['\n', '\r'], ""))
.collect::<Vec<String>>()
.join("\n");
machine.consult_module_string(module_name.as_str(), program);
let _ =
response.send(PrologServiceResponse::LoadModuleResult(Ok(())));

if let Some(error) = panic {
Err(error)
} else {
if results.len() >= MAX_RESULTS {
log::warn!(
"Query {} truncated at {} results",
query,
MAX_RESULTS
);
}

query_result_from_leaf_answer(Ok(results))
}
}));

match answer_result {
Ok(result) => {
let _ = response.send(PrologServiceResponse::QueryResult(result));
}
Err(e) => {
let error_string = if let Some(string) = e.downcast_ref::<String>()
{
format!(
"Scryer panicked with: {:?} - when running query: {}",
string, query
)
} else if let Some(&str) = e.downcast_ref::<&str>() {
format!(
"Scryer panicked with: {:?} - when running query: {}",
str, query
)
} else {
format!(
"Scryer panicked with: {:?} - when running query: {}",
e, query
)
};
log::error!("{}", error_string);
let _ = response.send(PrologServiceResponse::QueryResult(Err(
format!("Scryer panicked with: {:?}", error_string),
)));
}
PrologServiceRequest::Drop => return,
}
}
})
});
PrologServiceRequest::LoadModuleString(
module_name,
program_lines,
response,
) => {
let program = program_lines
.iter()
.map(|l| l.replace(['\n', '\r'], ""))
.collect::<Vec<String>>()
.join("\n");
machine.consult_module_string(module_name.as_str(), program);
let _ = response.send(PrologServiceResponse::LoadModuleResult(Ok(())));
}
PrologServiceRequest::Drop => return,
}
}
});

match response_receiver.await? {
let response = task::spawn_blocking(move || response_receiver.recv())
.await
.map_err(|e| Error::msg(format!("Failed to spawn blocking task: {}", e)))??;

match response {
PrologServiceResponse::InitComplete(result) => result?,
_ => unreachable!(),
};

Ok(())
}

// There two levels of error handling here:
// 1. The query can fail and Prolog returns an error
// This is represented as a QueryResult with an error string
// 2. The Prolog engine can panic and we don't have a result
// This is represented with the outer Result
pub async fn run_query(&self, query: String) -> Result<QueryResult, Error> {
let (response_sender, response_receiver) = oneshot::channel();
let (response_sender, response_receiver) = mpsc::channel();
self.request_sender
.send(PrologServiceRequest::RunQuery(query, response_sender))?;
let response = response_receiver.await?;

let response = task::spawn_blocking(move || response_receiver.recv())
.await
.map_err(|e| Error::msg(format!("Failed to spawn blocking task: {}", e)))??;

match response {
PrologServiceResponse::QueryResult(query_result) => Ok(query_result),
_ => unreachable!(),
Expand All @@ -135,14 +197,18 @@ impl PrologEngine {
module_name: String,
program_lines: Vec<String>,
) -> Result<(), Error> {
let (response_sender, response_receiver) = oneshot::channel();
let (response_sender, response_receiver) = mpsc::channel();
self.request_sender
.send(PrologServiceRequest::LoadModuleString(
module_name,
program_lines,
response_sender,
))?;
let response = response_receiver.await?;

let response = task::spawn_blocking(move || response_receiver.recv())
.await
.map_err(|e| Error::msg(format!("Failed to spawn blocking task: {}", e)))??;

match response {
PrologServiceResponse::LoadModuleResult(result) => result,
_ => unreachable!(),
Expand All @@ -158,6 +224,7 @@ impl PrologEngine {
#[cfg(test)]
mod prolog_test {
use super::*;
use tokio;

#[tokio::test]
async fn test_init_prolog_engine() {
Expand All @@ -178,16 +245,12 @@ mod prolog_test {
println!("Facts loaded");

let query = String::from("triple(\"a\",P,\"b\").");
//let query = String::from("write(\"A = \").");
//let query = String::from("halt.\n");
println!("Running query: {}", query);
let output = engine.run_query(query).await;
println!("Output: {:?}", output);
assert!(output.is_ok());

let query = String::from("triple(\"a\",\"p1\",\"b\").");
//let query = String::from("write(\"A = \").");
//let query = String::from("halt.\n");
println!("Running query: {}", query);
let output = engine.run_query(query).await;
println!("Output: {:?}", output);
Expand Down
75 changes: 44 additions & 31 deletions rust-executor/src/prolog_service/engine_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,52 +87,65 @@ impl PrologEnginePool {
join_all(futures).await.into_iter().collect()
}

pub async fn run_query(&self, query: String) -> Result<QueryResult, Error> {
let engines = self.engines.read().await;
async fn handle_engine_error(
&self,
engine_idx: usize,
error: impl std::fmt::Display,
query: &str,
) -> Result<QueryResult, Error> {
log::error!("Prolog engine error: {}", error);
log::error!("when running query: {}", query);
let mut engines = self.engines.write().await;
engines[engine_idx] = None;
Err(anyhow!("Engine failed and was invalidated: {}", error))
}

// Get a vec with all non-None (invalidated) engines
let valid_engines: Vec<_> = engines
.iter()
.enumerate()
.filter_map(|(i, e)| e.as_ref().map(|engine| (i, engine)))
.collect();
if valid_engines.is_empty() {
log::error!("No valid Prolog engines available");
return Err(anyhow!("No valid Prolog engines available"));
}
pub async fn run_query(&self, query: String) -> Result<QueryResult, Error> {
let (result, engine_idx) = {
let engines = self.engines.read().await;

// Get a vec with all non-None (invalidated) engines
let valid_engines: Vec<_> = engines
.iter()
.enumerate()
.filter_map(|(i, e)| e.as_ref().map(|engine| (i, engine)))
.collect();
if valid_engines.is_empty() {
log::error!("No valid Prolog engines available");
return Err(anyhow!("No valid Prolog engines available"));
}

// Round-robin selection of engine
let current = self.next_engine.fetch_add(1, Ordering::SeqCst);
let idx = current % valid_engines.len();
let (engine_idx, engine) = valid_engines[idx];
// Round-robin selection of engine
let current = self.next_engine.fetch_add(1, Ordering::SeqCst);
let idx = current % valid_engines.len();
let (engine_idx, engine) = valid_engines[idx];

// Preprocess query to replace huge vector URLs with small cache IDs
let processed_query = self.replace_embedding_url(query.clone()).await;
// Preprocess query to replace huge vector URLs with small cache IDs
let processed_query = self.replace_embedding_url(query.clone()).await;

// Run query
let result = engine.run_query(processed_query.clone()).await;
// Run query
let result = engine.run_query(processed_query.clone()).await;
(result, engine_idx)
};

let result = match result {
Err(e) => {
log::error!("Prolog engine error: {}", e);
log::error!("when running query: {}", query);
drop(engines);
let mut engines = self.engines.write().await;
engines[engine_idx] = None;
Err(anyhow!("Engine failed and was invalidated: {}", e))
}
Ok(mut result) => {
// Outer Result is an error -> engine panicked
Err(e) => self.handle_engine_error(engine_idx, e, &query).await,
// Inner Result is an error -> query failed
Ok(Err(e)) => Ok(Err(e)),
// Inner Result is a QueryResolution -> query succeeded
Ok(Ok(mut result)) => {
// Postprocess result to replace small cache IDs with huge vector URLs
// In-place and async/parallel processing of all values in all matches
if let Ok(QueryResolution::Matches(ref mut matches)) = result {
if let QueryResolution::Matches(ref mut matches) = result {
join_all(matches.iter_mut().map(|m| {
join_all(m.bindings.iter_mut().map(|(_, value)| {
self.replace_embedding_url_in_value_recursively(value)
}))
}))
.await;
}
Ok(result)
Ok(Ok(result))
}
};

Expand Down
Loading