Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions Modules/_base64/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ struct BorrowedBuffer {
}

impl BorrowedBuffer {
unsafe fn from_object(obj: *mut PyObject) -> Result<Self, ()> {
fn from_object(obj: &PyObject) -> Result<Self, ()> {
let mut view = MaybeUninit::<Py_buffer>::uninit();
if unsafe { PyObject_GetBuffer(obj, view.as_mut_ptr(), PYBUF_SIMPLE) } != 0 {
if unsafe { PyObject_GetBuffer(obj.as_raw(), view.as_mut_ptr(), PYBUF_SIMPLE) } != 0 {
return Err(());
}
Ok(Self {
Expand All @@ -110,6 +110,9 @@ impl Drop for BorrowedBuffer {
}
}

/// # Safety
/// `module` must be a valid pointer of PyObject representing the module.
/// `args` must be a valid pointer to an array of valid PyObject pointers with length `nargs`.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn standard_b64encode(
_module: *mut PyObject,
Expand All @@ -126,10 +129,19 @@ pub unsafe extern "C" fn standard_b64encode(
return ptr::null_mut();
}

let source = unsafe { *args };
let buffer = match unsafe { BorrowedBuffer::from_object(source) } {
let source = unsafe { &**args };

// Safe cast by Safety
match standard_b64encode_impl(source) {
Ok(result) => result,
Err(_) => ptr::null_mut(),
}
}

fn standard_b64encode_impl(source: &PyObject) -> Result<*mut PyObject, ()> {
let buffer = match BorrowedBuffer::from_object(source) {
Ok(buf) => buf,
Err(_) => return ptr::null_mut(),
Err(_) => return Err(()),
};

let view_len = buffer.len();
Expand All @@ -140,44 +152,43 @@ pub unsafe extern "C" fn standard_b64encode(
c"standard_b64encode() argument has negative length".as_ptr(),
);
}
return ptr::null_mut();
return Err(());
}

let input_len = view_len as usize;
let input = unsafe { slice::from_raw_parts(buffer.as_ptr(), input_len) };

let Some(output_len) = encoded_output_len(input_len) else {
unsafe {
PyErr_NoMemory();
}
return ptr::null_mut();
return Err(());
};

if output_len > isize::MAX as usize {
unsafe {
PyErr_NoMemory();
}
return ptr::null_mut();
return Err(());
}

let result = unsafe {
PyBytes_FromStringAndSize(ptr::null(), output_len as Py_ssize_t)
};
let result = unsafe { PyBytes_FromStringAndSize(ptr::null(), output_len as Py_ssize_t) };
if result.is_null() {
return ptr::null_mut();
return Err(());
}

let dest_ptr = unsafe { PyBytes_AsString(result) };
if dest_ptr.is_null() {
unsafe {
Py_DecRef(result);
}
return ptr::null_mut();
return Err(());
}
let dest = unsafe { slice::from_raw_parts_mut(dest_ptr.cast::<u8>(), output_len) };

let written = encode_into(input, dest);
debug_assert_eq!(written, output_len);
result
Ok(result)
}

#[unsafe(no_mangle)]
Expand Down
1 change: 1 addition & 0 deletions Modules/cpython-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ fn generate_c_api_bindings(srcdir: &Path, builddir: Option<&str>, out_path: &Pat
.allowlist_type("_?Py.*")
.allowlist_var("_?Py.*")
.blocklist_type("^PyMethodDef$")
.blocklist_type("PyObject")
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
.generate()
.expect("Unable to generate bindings");
Expand Down
19 changes: 15 additions & 4 deletions Modules/cpython-sys/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,17 @@ pub const _Py_STATIC_IMMORTAL_INITIAL_REFCNT: Py_ssize_t =
#[cfg(not(target_pointer_width = "64"))]
pub const _Py_STATIC_IMMORTAL_INITIAL_REFCNT: Py_ssize_t = 7u32 << 28;

#[repr(transparent)]
pub struct PyObject(std::cell::UnsafeCell<_object>);

impl PyObject {
#[inline]
pub fn as_raw(&self) -> *mut Self {
self.0.get() as *mut Self
}
}


#[repr(C)]
pub union PyMethodDefFuncPointer {
pub PyCFunction: unsafe extern "C" fn(slf: *mut PyObject, args: *mut PyObject) -> *mut PyObject,
Expand Down Expand Up @@ -113,18 +124,18 @@ unsafe impl Send for PyMethodDef {}

#[cfg(py_gil_disabled)]
pub const PyObject_HEAD_INIT: PyObject = {
let mut obj: PyObject = unsafe { std::mem::MaybeUninit::zeroed().assume_init() };
let mut obj: _object = unsafe { std::mem::MaybeUninit::zeroed().assume_init() };
obj.ob_flags = _Py_STATICALLY_ALLOCATED_FLAG as _;
obj
PyObject(std::cell::UnsafeCell::new(obj))
};

#[cfg(not(py_gil_disabled))]
pub const PyObject_HEAD_INIT: PyObject = PyObject {
pub const PyObject_HEAD_INIT: PyObject = PyObject(std::cell::UnsafeCell::new(_object {
__bindgen_anon_1: _object__bindgen_ty_1 {
ob_refcnt_full: _Py_STATIC_IMMORTAL_INITIAL_REFCNT as i64,
},
ob_type: std::ptr::null_mut(),
};
}));

pub const PyModuleDef_HEAD_INIT: PyModuleDef_Base = PyModuleDef_Base {
ob_base: PyObject_HEAD_INIT,
Expand Down