diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index 5526508f5aca..4374fed04921 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -783,7 +783,7 @@ PyObject *CPyBytes_Concat(PyObject *a, PyObject *b); PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter); CPyTagged CPyBytes_Ord(PyObject *obj); PyObject *CPyBytes_Multiply(PyObject *bytes, CPyTagged count); - +int CPyBytes_Startswith(PyObject *self, PyObject *subobj); int CPyBytes_Compare(PyObject *left, PyObject *right); diff --git a/mypyc/lib-rt/bytes_ops.c b/mypyc/lib-rt/bytes_ops.c index 8ecf9337c28b..25718b4603b3 100644 --- a/mypyc/lib-rt/bytes_ops.c +++ b/mypyc/lib-rt/bytes_ops.c @@ -171,3 +171,41 @@ PyObject *CPyBytes_Multiply(PyObject *bytes, CPyTagged count) { } return PySequence_Repeat(bytes, temp_count); } + +int CPyBytes_Startswith(PyObject *self, PyObject *subobj) { + if (PyBytes_CheckExact(self) && PyBytes_CheckExact(subobj)) { + if (self == subobj) { + return 1; + } + + Py_ssize_t subobj_len = PyBytes_GET_SIZE(subobj); + if (subobj_len == 0) { + return 1; + } + + Py_ssize_t self_len = PyBytes_GET_SIZE(self); + if (subobj_len > self_len) { + return 0; + } + + const char *self_buf = PyBytes_AS_STRING(self); + const char *subobj_buf = PyBytes_AS_STRING(subobj); + + return memcmp(self_buf, subobj_buf, (size_t)subobj_len) == 0 ? 1 : 0; + } + _Py_IDENTIFIER(startswith); + PyObject *name = _PyUnicode_FromId(&PyId_startswith); + if (name == NULL) { + return 2; + } + PyObject *result = PyObject_CallMethodOneArg(self, name, subobj); + if (result == NULL) { + return 2; + } + int ret = PyObject_IsTrue(result); + Py_DECREF(result); + if (ret < 0) { + return 2; + } + return ret; +} diff --git a/mypyc/primitives/bytes_ops.py b/mypyc/primitives/bytes_ops.py index 0669ddac00df..728da4181135 100644 --- a/mypyc/primitives/bytes_ops.py +++ b/mypyc/primitives/bytes_ops.py @@ -7,6 +7,7 @@ from mypyc.ir.rtypes import ( RUnion, bit_rprimitive, + bool_rprimitive, bytes_rprimitive, c_int_rprimitive, c_pyssize_t_rprimitive, @@ -139,6 +140,16 @@ dependencies=[BYTES_EXTRA_OPS], ) +# bytes.startswith(bytes) +method_op( + name="startswith", + arg_types=[bytes_rprimitive, bytes_rprimitive], + return_type=c_int_rprimitive, + c_function_name="CPyBytes_Startswith", + truncated_type=bool_rprimitive, + error_kind=ERR_MAGIC, +) + # Join bytes objects and return a new bytes. # The first argument is the total number of the following bytes. bytes_build_op = custom_op( diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index d9202707124b..592f6676e95e 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -179,6 +179,7 @@ def __getitem__(self, i: slice) -> bytes: ... def join(self, x: Iterable[object]) -> bytes: ... def decode(self, encoding: str=..., errors: str=...) -> str: ... def translate(self, t: bytes) -> bytes: ... + def startswith(self, t: bytes) -> bool: ... def __iter__(self) -> Iterator[int]: ... class bytearray: @@ -192,6 +193,7 @@ def __add__(self, s: bytes) -> bytearray: ... def __setitem__(self, i: int, o: int) -> None: ... def __getitem__(self, i: int) -> int: ... def decode(self, x: str = ..., y: str = ...) -> str: ... + def startswith(self, t: bytes) -> bool: ... class bool(int): def __init__(self, o: object = ...) -> None: ... diff --git a/mypyc/test-data/irbuild-bytes.test b/mypyc/test-data/irbuild-bytes.test index 9473944f44fe..5e7c546eb25a 100644 --- a/mypyc/test-data/irbuild-bytes.test +++ b/mypyc/test-data/irbuild-bytes.test @@ -248,3 +248,16 @@ def f(b, table): L0: r0 = CPyBytes_Translate(b, table) return r0 + +[case testBytesStartsWith] +def f(a: bytes, b: bytes) -> bool: + return a.startswith(b) +[out] +def f(a, b): + a, b :: bytes + r0 :: i32 + r1 :: bool +L0: + r0 = CPyBytes_Startswith(a, b) + r1 = truncate r0: i32 to builtins.bool + return r1 diff --git a/mypyc/test-data/run-bytes.test b/mypyc/test-data/run-bytes.test index 9a319b636772..6e4b57152a4b 100644 --- a/mypyc/test-data/run-bytes.test +++ b/mypyc/test-data/run-bytes.test @@ -200,6 +200,27 @@ def test_translate() -> None: with assertRaises(ValueError, "translation table must be 256 characters long"): b'test'.translate(bytes(100)) +def test_startswith() -> None: + # Test default behavior + test = b'some string' + assert test.startswith(b'some') + assert test.startswith(b'some string') + assert not test.startswith(b'other') + assert not test.startswith(b'some string but longer') + + # Test empty cases + assert test.startswith(b'') + assert b''.startswith(b'') + assert not b''.startswith(test) + + # Test bytearray to verify slow paths + assert test.startswith(bytearray(b'some')) + assert not test.startswith(bytearray(b'other')) + + test = bytearray(b'some string') + assert test.startswith(b'some') + assert not test.startswith(b'other') + [case testBytesSlicing] def test_bytes_slicing() -> None: b = b'abcdefg'