Skip to content

Commit 3f07659

Browse files
goutamvenkat-anyscale and elliot-barn
authored and committed
[Data] [1/n] Predicate Expression Support (#56313)
<!-- Thank you for your contribution! Please review https://github.com/ray-project/ray/blob/master/CONTRIBUTING.rst before opening a pull request. -->
<!-- Please add a reviewer to the assignee section when you create a PR. If you don't have the access to it, we will shortly find a reviewer and assign them to your PR. -->

## Why are these changes needed?

Adds support for predicate expressions in Ray Data's Expression System. Involves the following:

1. Support for predicate operations: unary NOT, IS_NULL and IS_NOT_NULL, plus the IN / NOT_IN membership checks
2. Add predicate expression (`UnaryExpr`) support to the expression evaluator

## Related issue number

## Checks

- [x] I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- [x] I've run `scripts/format.sh` to lint the changes in this PR.
- [ ] I've included any doc changes needed for https://docs.ray.io/en/master/.
- [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.
- [ ] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
  - [x] Unit tests
  - [ ] Release tests
  - [ ] This PR is not tested :(

---------

Signed-off-by: Goutam V. <goutam@anyscale.com>
Signed-off-by: elliot-barn <elliot.barnwell@anyscale.com>
1 parent 9983920 commit 3f07659

File tree

7 files changed

+745
-14
lines changed

7 files changed

+745
-14
lines changed

doc/source/data/api/expressions.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,6 @@ instantiate them directly, but you may encounter them when working with expressi
3535
Expr
3636
ColumnExpr
3737
LiteralExpr
38-
BinaryExpr
38+
BinaryExpr
39+
UnaryExpr
40+
UDFExpr

python/ray/data/_expression_evaluator.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,34 +16,56 @@
1616
LiteralExpr,
1717
Operation,
1818
UDFExpr,
19+
UnaryExpr,
1920
)
2021

21-
_PANDAS_EXPR_OPS_MAP = {
22+
23+
def _pa_is_in(left: Any, right: Any) -> Any:
    """Arrow membership test: report whether each value of ``left`` is in ``right``.

    Args:
        left: The values to look up; passed unchanged to ``pc.is_in``.
        right: The candidate set. A ``pa.Array`` or ``pa.ChunkedArray`` is used
            as-is; a ``pa.Scalar`` is unwrapped with ``as_py()`` and then
            converted with ``pa.array``; anything else is handed directly to
            ``pa.array``.

    Returns:
        The result of ``pc.is_in(left, <value set>)``.
    """
    if isinstance(right, (pa.Array, pa.ChunkedArray)):
        value_set = right
    elif isinstance(right, pa.Scalar):
        # Unwrap the scalar to its Python value before building the array.
        value_set = pa.array(right.as_py())
    else:
        value_set = pa.array(right)
    return pc.is_in(left, value_set)
27+
28+
29+
def _pd_logical_not(operand: Any) -> Any:
    """Logical NOT for the pandas evaluation path.

    ``operator.not_`` calls ``bool()`` on its argument, which raises
    ``ValueError`` for a multi-element ``pd.Series``. Element-wise negation
    needs ``~`` instead. A plain Python ``bool`` keeps ``not`` semantics,
    because ``~True`` is the integer ``-2`` rather than ``False``.
    """
    if isinstance(operand, bool):
        return not operand
    return ~operand


# Maps each Operation to the callable used when evaluating an expression
# against pandas data.
_PANDAS_EXPR_OPS_MAP: Dict[Operation, Callable[..., Any]] = {
    Operation.ADD: operator.add,
    Operation.SUB: operator.sub,
    Operation.MUL: operator.mul,
    Operation.DIV: operator.truediv,
    Operation.FLOORDIV: operator.floordiv,
    Operation.GT: operator.gt,
    Operation.LT: operator.lt,
    Operation.GE: operator.ge,
    Operation.LE: operator.le,
    Operation.EQ: operator.eq,
    Operation.NE: operator.ne,
    Operation.AND: operator.and_,
    Operation.OR: operator.or_,
    # Element-wise negation; operator.not_ would fail on a Series.
    Operation.NOT: _pd_logical_not,
    Operation.IS_NULL: pd.isna,
    Operation.IS_NOT_NULL: pd.notna,
    # pandas spells the membership test ``isin`` (there is no ``Series.is_in``).
    Operation.IN: lambda left, right: left.isin(right),
    Operation.NOT_IN: lambda left, right: ~left.isin(right),
}
3449

35-
_ARROW_EXPR_OPS_MAP = {
50+
def _pa_floordiv(left: Any, right: Any) -> Any:
    """Arrow floor division: divide ``left`` by ``right``, then apply ``pc.floor``."""
    # NOTE(review): pc.divide on integer inputs is presumably truncating integer
    # division, in which case negative quotients differ from Python's ``//``;
    # confirm against the pyarrow compute documentation.
    quotient = pc.divide(left, right)
    return pc.floor(quotient)


def _pa_not_in(left: Any, right: Any) -> Any:
    """Arrow negated membership test: invert the result of ``_pa_is_in``."""
    membership = _pa_is_in(left, right)
    return pc.invert(membership)


# Maps each Operation to the pyarrow.compute callable used when evaluating an
# expression against Arrow data.
_ARROW_EXPR_OPS_MAP: Dict[Operation, Callable[..., Any]] = {
    Operation.ADD: pc.add,
    Operation.SUB: pc.subtract,
    Operation.MUL: pc.multiply,
    Operation.DIV: pc.divide,
    Operation.FLOORDIV: _pa_floordiv,
    Operation.GT: pc.greater,
    Operation.LT: pc.less,
    Operation.GE: pc.greater_equal,
    Operation.LE: pc.less_equal,
    Operation.EQ: pc.equal,
    Operation.NE: pc.not_equal,
    Operation.AND: pc.and_kleene,
    Operation.OR: pc.or_kleene,
    Operation.NOT: pc.invert,
    Operation.IS_NULL: pc.is_null,
    Operation.IS_NOT_NULL: pc.is_valid,
    Operation.IN: _pa_is_in,
    Operation.NOT_IN: _pa_not_in,
}
4870

4971

@@ -63,6 +85,10 @@ def _eval_expr_recursive(
6385
_eval_expr_recursive(expr.left, batch, ops),
6486
_eval_expr_recursive(expr.right, batch, ops),
6587
)
88+
if isinstance(expr, UnaryExpr):
89+
# TODO: Use Visitor pattern here and store ops in shared state.
90+
return ops[expr.op](_eval_expr_recursive(expr.operand, batch, ops))
91+
6692
if isinstance(expr, UDFExpr):
6793
args = [_eval_expr_recursive(arg, batch, ops) for arg in expr.args]
6894
kwargs = {
@@ -79,6 +105,7 @@ def _eval_expr_recursive(
79105
)
80106

81107
return result
108+
82109
raise TypeError(f"Unsupported expression node: {type(expr).__name__}")
83110

84111

python/ray/data/_internal/pandas_block.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,8 @@ def rename_columns(self, columns_rename: Dict[str, str]) -> "pandas.DataFrame":
320320
def upsert_column(
321321
self, column_name: str, column_data: BlockColumn
322322
) -> "pandas.DataFrame":
323+
import pyarrow
324+
323325
if isinstance(column_data, (pyarrow.Array, pyarrow.ChunkedArray)):
324326
column_data = column_data.to_pandas()
325327

python/ray/data/_internal/planner/plan_expression/expression_evaluator.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ def visit_Compare(self, node: ast.Compare) -> ds.Expression:
6565

6666
op = node.ops[0]
6767
if isinstance(op, ast.In):
68-
return left_expr.isin(comparators[0])
68+
return left_expr.is_in(comparators[0])
6969
elif isinstance(op, ast.NotIn):
70-
return ~left_expr.isin(comparators[0])
70+
return ~left_expr.is_in(comparators[0])
7171
elif isinstance(op, ast.Eq):
7272
return left_expr == comparators[0]
7373
elif isinstance(op, ast.NotEq):
@@ -210,7 +210,7 @@ def visit_Call(self, node: ast.Call) -> ds.Expression:
210210
nan_is_null=nan_is_null
211211
),
212212
"is_valid": lambda arg: arg.is_valid(),
213-
"isin": lambda arg1, arg2: arg1.isin(arg2),
213+
"is_in": lambda arg1, arg2: arg1.is_in(arg2),
214214
}
215215

216216
if func_name in function_map:
@@ -224,11 +224,11 @@ def visit_Call(self, node: ast.Call) -> ds.Expression:
224224
return function_map[func_name](args[0], args[1])
225225
else:
226226
raise ValueError("is_null function requires one or two arguments.")
227-
# Handle the "isin" function with exactly two arguments
228-
elif func_name == "isin" and len(args) != 2:
229-
raise ValueError("isin function requires two arguments.")
227+
# Handle the "is_in" function with exactly two arguments
228+
elif func_name == "is_in" and len(args) != 2:
229+
raise ValueError("is_in function requires two arguments.")
230230
# Ensure the function has one argument (for functions like is_valid)
231-
elif func_name != "isin" and len(args) != 1:
231+
elif func_name != "is_in" and len(args) != 1:
232232
raise ValueError(f"{func_name} function requires exactly one argument.")
233233
# Call the corresponding function with the arguments
234234
return function_map[func_name](*args)

python/ray/data/expressions.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from abc import ABC, abstractmethod
55
from dataclasses import dataclass, field
66
from enum import Enum
7-
from typing import Any, Callable, Dict, List
7+
from typing import Any, Callable, Dict, List, Union
88

99
from ray.data.block import BatchColumn
1010
from ray.data.datatype import DataType
@@ -23,26 +23,40 @@ class Operation(Enum):
2323
SUB: Subtraction operation (-)
2424
MUL: Multiplication operation (*)
2525
DIV: Division operation (/)
26+
FLOORDIV: Floor division operation (//)
2627
GT: Greater than comparison (>)
2728
LT: Less than comparison (<)
2829
GE: Greater than or equal comparison (>=)
2930
LE: Less than or equal comparison (<=)
3031
EQ: Equality comparison (==)
32+
NE: Not equal comparison (!=)
3133
AND: Logical AND operation (&)
3234
OR: Logical OR operation (|)
35+
NOT: Logical NOT operation (~)
36+
IS_NULL: Check if value is null
37+
IS_NOT_NULL: Check if value is not null
38+
IN: Check if value is in a list
39+
NOT_IN: Check if value is not in a list
3340
"""
3441

3542
ADD = "add"
3643
SUB = "sub"
3744
MUL = "mul"
3845
DIV = "div"
46+
FLOORDIV = "floordiv"
3947
GT = "gt"
4048
LT = "lt"
4149
GE = "ge"
4250
LE = "le"
4351
EQ = "eq"
52+
NE = "ne"
4453
AND = "and"
4554
OR = "or"
55+
NOT = "not"
56+
IS_NULL = "is_null"
57+
IS_NOT_NULL = "is_not_null"
58+
IN = "in"
59+
NOT_IN = "not_in"
4660

4761

4862
@DeveloperAPI(stability="alpha")
@@ -127,6 +141,14 @@ def __rtruediv__(self, other: Any) -> "Expr":
127141
"""Reverse division operator (for literal / expr)."""
128142
return LiteralExpr(other)._bin(self, Operation.DIV)
129143

144+
def __floordiv__(self, other: Any) -> "Expr":
    """Build the expression for ``self // other``."""
    floor_div = Operation.FLOORDIV
    return self._bin(other, floor_div)

def __rfloordiv__(self, other: Any) -> "Expr":
    """Build the expression for ``other // self`` when ``other`` is a literal."""
    # Wrap the left-hand value so it can act as the left operand.
    wrapped = LiteralExpr(other)
    return wrapped._bin(self, Operation.FLOORDIV)
151+
130152
# comparison
131153
def __gt__(self, other: Any) -> "Expr":
132154
"""Greater than operator (>)."""
@@ -148,6 +170,10 @@ def __eq__(self, other: Any) -> "Expr":
148170
"""Equality operator (==)."""
149171
return self._bin(other, Operation.EQ)
150172

173+
def __ne__(self, other: Any) -> "Expr":
    """Build the inequality comparison expression for ``self != other``."""
    not_equal = Operation.NE
    return self._bin(other, not_equal)
176+
151177
# boolean
152178
def __and__(self, other: Any) -> "Expr":
153179
"""Logical AND operator (&)."""
@@ -157,6 +183,31 @@ def __or__(self, other: Any) -> "Expr":
157183
"""Logical OR operator (|)."""
158184
return self._bin(other, Operation.OR)
159185

186+
def __invert__(self) -> "Expr":
    """Build the logical NOT expression for ``~self``."""
    return UnaryExpr(op=Operation.NOT, operand=self)

# predicate methods
def is_null(self) -> "Expr":
    """Return an expression that tests whether this expression's value is null."""
    return UnaryExpr(op=Operation.IS_NULL, operand=self)

def is_not_null(self) -> "Expr":
    """Return an expression that tests whether this expression's value is not null."""
    return UnaryExpr(op=Operation.IS_NOT_NULL, operand=self)

def is_in(self, values: Union[List[Any], "Expr"]) -> "Expr":
    """Return an expression that tests membership of this value in ``values``.

    Args:
        values: Either an ``Expr`` or a plain value (such as a list), which is
            wrapped in a ``LiteralExpr``.
    """
    candidates = values if isinstance(values, Expr) else LiteralExpr(values)
    return self._bin(candidates, Operation.IN)

def not_in(self, values: Union[List[Any], "Expr"]) -> "Expr":
    """Return an expression that tests that this value is absent from ``values``.

    Args:
        values: Either an ``Expr`` or a plain value (such as a list), which is
            wrapped in a ``LiteralExpr``.
    """
    candidates = values if isinstance(values, Expr) else LiteralExpr(values)
    return self._bin(candidates, Operation.NOT_IN)
210+
160211

161212
@DeveloperAPI(stability="alpha")
162213
@dataclass(frozen=True, eq=False)
@@ -257,6 +308,39 @@ def structurally_equals(self, other: Any) -> bool:
257308
)
258309

259310

311+
@DeveloperAPI(stability="alpha")
@dataclass(frozen=True, eq=False)
class UnaryExpr(Expr):
    """An expression that applies a single-operand operation to another expression.

    Typical unary operations are logical NOT, IS NULL and IS NOT NULL.

    Args:
        op: The operation to perform (from Operation enum)
        operand: The operand expression

    Example:
        >>> from ray.data.expressions import col
        >>> # Check if a column is null
        >>> expr = col("age").is_null()  # Creates UnaryExpr(IS_NULL, col("age"))
        >>> # Logical not
        >>> expr = ~(col("active"))  # Creates UnaryExpr(NOT, col("active"))
    """

    op: Operation
    operand: Expr

    # Excluded from the generated __init__.
    data_type: DataType = field(init=False)

    def structurally_equals(self, other: Any) -> bool:
        """Return whether ``other`` is a ``UnaryExpr`` with the same op and operand.

        The operation is compared by identity and the operand is compared
        recursively through its own ``structurally_equals``.
        """
        if not isinstance(other, UnaryExpr):
            return False
        if self.op is not other.op:
            return False
        return self.operand.structurally_equals(other.operand)
342+
343+
260344
@DeveloperAPI(stability="alpha")
261345
@dataclass(frozen=True, eq=False)
262346
class UDFExpr(Expr):
@@ -517,6 +601,7 @@ def download(uri_column_name: str) -> DownloadExpr:
517601
"ColumnExpr",
518602
"LiteralExpr",
519603
"BinaryExpr",
604+
"UnaryExpr",
520605
"UDFExpr",
521606
"udf",
522607
"DownloadExpr",

0 commit comments

Comments
 (0)