Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pgrx-examples/srf/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ fn result_table<'a>() -> Result<
Ok(Some(TableIterator::new(vec![(Some(1), Some(2))])))
}

#[pg_extern]
fn one_col<'a>() -> TableIterator<'a, (name!(a, Option<i32>),)> {
TableIterator::new(std::iter::once((Some(42),)))
}

#[cfg(any(test, feature = "pg_test"))]
#[pg_schema]
mod tests {
Expand Down
58 changes: 43 additions & 15 deletions pgrx-sql-entity-graph/src/pg_extern/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ impl PgExtern {
}
}
}
Returning::Iterated { tys: _retval_tys, optional, result } => {
Returning::Iterated { tys: retval_tys, optional, result } => {
let result_handler = if *optional && *result {
// don't need unsafe annotations because of the larger unsafe block coming up
quote_spanned! { self.func.sig.span() =>
Expand All @@ -543,20 +543,48 @@ impl PgExtern {
}
};

quote_spanned! { self.func.sig.span() =>
#[no_mangle]
#[doc(hidden)]
#[::pgrx::pgrx_macros::pg_guard]
pub unsafe extern "C" fn #func_name_wrapper #func_generics(#fcinfo_ident: ::pgrx::pg_sys::FunctionCallInfo) -> ::pgrx::pg_sys::Datum {
#[allow(unused_unsafe)]
unsafe {
// SAFETY: the caller has asserted that `fcinfo` is a valid FunctionCallInfo pointer, allocated by Postgres
// with all its fields properly setup. Unless the user is calling this wrapper function directly, this
// will always be the case
::pgrx::iter::TableIterator::srf_next(#fcinfo_ident, || {
#( #arg_fetches )*
#result_handler
})
if retval_tys.len() == 1 {
// Postgres considers functions returning a 1-field table (`RETURNS TABLE (T)`) to be
// a function that `RETRUNS SETOF T`. So we write a different wrapper implementation
// that transparently transforms the `TableIterator` returned by the user into a `SetOfIterator`
quote_spanned! { self.func.sig.span() =>
#[no_mangle]
#[doc(hidden)]
#[::pgrx::pgrx_macros::pg_guard]
pub unsafe extern "C" fn #func_name_wrapper #func_generics(#fcinfo_ident: ::pgrx::pg_sys::FunctionCallInfo) -> ::pgrx::pg_sys::Datum {
#[allow(unused_unsafe)]
unsafe {
// SAFETY: the caller has asserted that `fcinfo` is a valid FunctionCallInfo pointer, allocated by Postgres
// with all its fields properly setup. Unless the user is calling this wrapper function directly, this
// will always be the case
::pgrx::iter::SetOfIterator::srf_next(#fcinfo_ident, || {
#( #arg_fetches )*
let table_iterator = { #result_handler };

// we need to convert the 1-field `TableIterator` provided by the user
// into a SetOfIterator in order to properly handle the case of `RETURNS TABLE (T)`,
// which is a table that returns only 1 field.
table_iterator.map(|i| ::pgrx::iter::SetOfIterator::new(i.into_iter().map(|(v,)| v)))
})
}
}
}
} else {
quote_spanned! { self.func.sig.span() =>
#[no_mangle]
#[doc(hidden)]
#[::pgrx::pgrx_macros::pg_guard]
pub unsafe extern "C" fn #func_name_wrapper #func_generics(#fcinfo_ident: ::pgrx::pg_sys::FunctionCallInfo) -> ::pgrx::pg_sys::Datum {
#[allow(unused_unsafe)]
unsafe {
// SAFETY: the caller has asserted that `fcinfo` is a valid FunctionCallInfo pointer, allocated by Postgres
// with all its fields properly setup. Unless the user is calling this wrapper function directly, this
// will always be the case
::pgrx::iter::TableIterator::srf_next(#fcinfo_ident, || {
#( #arg_fetches )*
#result_handler
})
}
}
}
}
Expand Down
29 changes: 29 additions & 0 deletions pgrx-tests/src/tests/srf_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,27 @@ fn result_table_5_none<'a>() -> Result<
Ok(None)
}

#[pg_extern]
fn one_col<'a>() -> TableIterator<'a, (name!(a, i32),)> {
TableIterator::new(std::iter::once((42,)))
}

#[pg_extern]
fn one_col_option<'a>() -> Option<TableIterator<'a, (name!(a, i32),)>> {
Some(TableIterator::new(std::iter::once((42,))))
}

#[pg_extern]
fn one_col_result<'a>() -> Result<TableIterator<'a, (name!(a, i32),)>, Box<dyn std::error::Error>> {
Ok(TableIterator::new(std::iter::once((42,))))
}

#[pg_extern]
fn one_col_result_option<'a>(
) -> Result<Option<TableIterator<'a, (name!(a, i32),)>>, Box<dyn std::error::Error>> {
Ok(Some(TableIterator::new(std::iter::once((42,)))))
}

#[cfg(any(test, feature = "pg_test"))]
#[pgrx::pg_schema]
mod tests {
Expand Down Expand Up @@ -316,4 +337,12 @@ mod tests {
let result = Spi::get_two::<i32, i32>("SELECT * from result_table_5_none()");
assert_eq!(result, Err(spi::Error::InvalidPosition));
}

#[pg_test]
pub fn test_one_col_table() {
assert_eq!(Spi::get_one::<i32>("SELECT * from one_col()"), Ok(Some(42)));
assert_eq!(Spi::get_one::<i32>("SELECT * from one_col_option()"), Ok(Some(42)));
assert_eq!(Spi::get_one::<i32>("SELECT * from one_col_result()"), Ok(Some(42)));
assert_eq!(Spi::get_one::<i32>("SELECT * from one_col_result_option()"), Ok(Some(42)));
}
}