diff --git a/Modules/_base64/src/lib.rs b/Modules/_base64/src/lib.rs index f308a61daf301e..49fd7930045c0b 100644 --- a/Modules/_base64/src/lib.rs +++ b/Modules/_base64/src/lib.rs @@ -83,9 +83,9 @@ struct BorrowedBuffer { } impl BorrowedBuffer { - unsafe fn from_object(obj: *mut PyObject) -> Result { + fn from_object(obj: &PyObject) -> Result { let mut view = MaybeUninit::::uninit(); - if unsafe { PyObject_GetBuffer(obj, view.as_mut_ptr(), PYBUF_SIMPLE) } != 0 { + if unsafe { PyObject_GetBuffer(obj.as_raw(), view.as_mut_ptr(), PYBUF_SIMPLE) } != 0 { return Err(()); } Ok(Self { @@ -110,6 +110,9 @@ impl Drop for BorrowedBuffer { } } +/// # Safety +/// `module` must be a valid pointer of PyObject representing the module. +/// `args` must be a valid pointer to an array of valid PyObject pointers with length `nargs`. #[unsafe(no_mangle)] pub unsafe extern "C" fn standard_b64encode( _module: *mut PyObject, @@ -126,10 +129,19 @@ pub unsafe extern "C" fn standard_b64encode( return ptr::null_mut(); } - let source = unsafe { *args }; - let buffer = match unsafe { BorrowedBuffer::from_object(source) } { + let source = unsafe { &**args }; + + // Safe cast by Safety + match standard_b64encode_impl(source) { + Ok(result) => result, + Err(_) => ptr::null_mut(), + } +} + +fn standard_b64encode_impl(source: &PyObject) -> Result<*mut PyObject, ()> { + let buffer = match BorrowedBuffer::from_object(source) { Ok(buf) => buf, - Err(_) => return ptr::null_mut(), + Err(_) => return Err(()), }; let view_len = buffer.len(); @@ -140,8 +152,9 @@ pub unsafe extern "C" fn standard_b64encode( c"standard_b64encode() argument has negative length".as_ptr(), ); } - return ptr::null_mut(); + return Err(()); } + let input_len = view_len as usize; let input = unsafe { slice::from_raw_parts(buffer.as_ptr(), input_len) }; @@ -149,21 +162,19 @@ pub unsafe extern "C" fn standard_b64encode( unsafe { PyErr_NoMemory(); } - return ptr::null_mut(); + return Err(()); }; if output_len > isize::MAX as usize { unsafe { PyErr_NoMemory(); } - return ptr::null_mut(); + return Err(()); } - let result = unsafe { - PyBytes_FromStringAndSize(ptr::null(), output_len as Py_ssize_t) - }; + let result = unsafe { PyBytes_FromStringAndSize(ptr::null(), output_len as Py_ssize_t) }; if result.is_null() { - return ptr::null_mut(); + return Err(()); } let dest_ptr = unsafe { PyBytes_AsString(result) }; @@ -171,13 +182,13 @@ pub unsafe extern "C" fn standard_b64encode( unsafe { Py_DecRef(result); } - return ptr::null_mut(); + return Err(()); } let dest = unsafe { slice::from_raw_parts_mut(dest_ptr.cast::(), output_len) }; let written = encode_into(input, dest); debug_assert_eq!(written, output_len); - result + Ok(result) } #[unsafe(no_mangle)] diff --git a/Modules/cpython-sys/build.rs b/Modules/cpython-sys/build.rs index 680066c4fd5e9d..8256e2fc93cd03 100644 --- a/Modules/cpython-sys/build.rs +++ b/Modules/cpython-sys/build.rs @@ -55,6 +55,7 @@ fn generate_c_api_bindings(srcdir: &Path, builddir: Option<&str>, out_path: &Pat .allowlist_type("_?Py.*") .allowlist_var("_?Py.*") .blocklist_type("^PyMethodDef$") + .blocklist_type("PyObject") .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) .generate() .expect("Unable to generate bindings"); diff --git a/Modules/cpython-sys/src/lib.rs b/Modules/cpython-sys/src/lib.rs index ed1d68eedd600a..9a3c46b34d8c36 100644 --- a/Modules/cpython-sys/src/lib.rs +++ b/Modules/cpython-sys/src/lib.rs @@ -56,6 +56,17 @@ pub const _Py_STATIC_IMMORTAL_INITIAL_REFCNT: Py_ssize_t = #[cfg(not(target_pointer_width = "64"))] pub const _Py_STATIC_IMMORTAL_INITIAL_REFCNT: Py_ssize_t = 7u32 << 28; +#[repr(transparent)] +pub struct PyObject(std::cell::UnsafeCell<_object>); + +impl PyObject { + #[inline] + pub fn as_raw(&self) -> *mut Self { + self.0.get() as *mut Self + } +} + + #[repr(C)] pub union PyMethodDefFuncPointer { pub PyCFunction: unsafe extern "C" fn(slf: *mut PyObject, args: *mut PyObject) -> *mut PyObject, @@ -113,18 +124,18 @@ unsafe impl Send for PyMethodDef {} #[cfg(py_gil_disabled)] pub const PyObject_HEAD_INIT: PyObject = { - let mut obj: PyObject = unsafe { std::mem::MaybeUninit::zeroed().assume_init() }; + let mut obj: _object = unsafe { std::mem::MaybeUninit::zeroed().assume_init() }; obj.ob_flags = _Py_STATICALLY_ALLOCATED_FLAG as _; - obj + PyObject(std::cell::UnsafeCell::new(obj)) }; #[cfg(not(py_gil_disabled))] -pub const PyObject_HEAD_INIT: PyObject = PyObject { +pub const PyObject_HEAD_INIT: PyObject = PyObject(std::cell::UnsafeCell::new(_object { __bindgen_anon_1: _object__bindgen_ty_1 { ob_refcnt_full: _Py_STATIC_IMMORTAL_INITIAL_REFCNT as i64, }, ob_type: std::ptr::null_mut(), -}; +})); pub const PyModuleDef_HEAD_INIT: PyModuleDef_Base = PyModuleDef_Base { ob_base: PyObject_HEAD_INIT,