Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
c1329af
[Data] UDF Expression Support for with_column
goutamvenkat-anyscale Aug 20, 2025
e569229
Modify docs
goutamvenkat-anyscale Aug 20, 2025
702d73b
Small fixes
goutamvenkat-anyscale Aug 20, 2025
3733f01
Hash collision avoidance
goutamvenkat-anyscale Aug 20, 2025
c26460f
Merge branch 'master' into goutam/udf_expr
goutamvenkat-anyscale Aug 20, 2025
1c159ca
add todo
goutamvenkat-anyscale Aug 20, 2025
aada462
Merge branch 'master' into goutam/udf_expr
goutamvenkat-anyscale Aug 20, 2025
da27af3
Doc fix
goutamvenkat-anyscale Aug 20, 2025
7f815f2
Merge branch 'master' into goutam/udf_expr
goutamvenkat-anyscale Aug 20, 2025
8fc8e8b
Guard pydantic import
goutamvenkat-anyscale Aug 21, 2025
6cccdd7
Replace Pydantic usage with Dataclass
goutamvenkat-anyscale Aug 21, 2025
f41ddb6
Merge branch 'master' into goutam/udf_expr
goutamvenkat-anyscale Aug 21, 2025
1579925
one more doc lint
goutamvenkat-anyscale Aug 21, 2025
a64c879
more
goutamvenkat-anyscale Aug 21, 2025
bc305ef
Merge branch 'master' into goutam/udf_expr
goutamvenkat-anyscale Aug 21, 2025
696c6aa
Add rst changes
goutamvenkat-anyscale Aug 21, 2025
325981f
Some comments
goutamvenkat-anyscale Aug 25, 2025
ee0013d
Remove condition
goutamvenkat-anyscale Aug 25, 2025
2a1481c
Remove batch_size from with_column
goutamvenkat-anyscale Aug 25, 2025
c35b130
Merge branch 'master' into goutam/udf_expr
goutamvenkat-anyscale Aug 25, 2025
6cea48a
Merge branch 'master' into goutam/udf_expr
goutamvenkat-anyscale Aug 27, 2025
4864369
Merge branch 'goutam/udf_expr' of https://github.com/goutamvenkat-any…
goutamvenkat-anyscale Aug 27, 2025
3b11314
Remove datatype from this PR
goutamvenkat-anyscale Aug 27, 2025
8e0dd2b
Fix lint stuff
goutamvenkat-anyscale Aug 27, 2025
de098f0
Add rst back
goutamvenkat-anyscale Aug 27, 2025
519dcff
one more comment
goutamvenkat-anyscale Aug 28, 2025
60f37cb
Merge branch 'master' into goutam/udf_expr
goutamvenkat-anyscale Aug 28, 2025
4069ad6
Clean up
goutamvenkat-anyscale Aug 28, 2025
2746b4f
Address comments
goutamvenkat-anyscale Aug 29, 2025
dac02d4
more comments
goutamvenkat-anyscale Aug 29, 2025
0b9b87a
Clean up
goutamvenkat-anyscale Aug 29, 2025
edbd43e
Imports
goutamvenkat-anyscale Aug 29, 2025
e2194d6
Merge from master - conflicts
goutamvenkat-anyscale Aug 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/data/api/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ Ray Data API
data_context.rst
preprocessor.rst
llm.rst
from_other_data_libs.rst
from_other_data_libs.rst
3 changes: 2 additions & 1 deletion doc/source/data/api/expressions.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
.. _expressions-api:

Expressions API
===============
================

.. currentmodule:: ray.data.expressions

Expand All @@ -19,6 +19,7 @@ Public API

col
lit
udf
download

Expression Classes
Expand Down
27 changes: 24 additions & 3 deletions python/ray/data/_expression_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@
import operator
from typing import Any, Callable, Dict

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc

from ray.data.block import DataBatch
from ray.data.expressions import (
BinaryExpr,
ColumnExpr,
Expr,
LiteralExpr,
Operation,
UDFExpr,
)

_PANDAS_EXPR_OPS_MAP = {
Expand Down Expand Up @@ -44,7 +47,9 @@
}


def _eval_expr_recursive(expr: "Expr", batch, ops: Dict["Operation", Callable]) -> Any:
def _eval_expr_recursive(
expr: "Expr", batch: DataBatch, ops: Dict["Operation", Callable[..., Any]]
) -> Any:
"""Generic recursive expression evaluator."""
# TODO: Separate unresolved expressions (arbitrary AST with unresolved refs)
# and resolved expressions (bound to a schema) for better error handling
Expand All @@ -58,10 +63,26 @@ def _eval_expr_recursive(expr: "Expr", batch, ops: Dict["Operation", Callable])
_eval_expr_recursive(expr.left, batch, ops),
_eval_expr_recursive(expr.right, batch, ops),
)
raise TypeError(f"Unsupported expression node: {type(expr).__name__}")
if isinstance(expr, UDFExpr):
args = [_eval_expr_recursive(arg, batch, ops) for arg in expr.args]
kwargs = {
k: _eval_expr_recursive(v, batch, ops) for k, v in expr.kwargs.items()
}
result = expr.fn(*args, **kwargs)

# Can't perform type validation for unions if python version is < 3.10
if not isinstance(result, (pd.Series, np.ndarray, pa.Array, pa.ChunkedArray)):
function_name = expr.fn.__name__
raise TypeError(
f"UDF '{function_name}' returned invalid type {type(result).__name__}. "
f"Expected type (pandas.Series, numpy.ndarray, pyarrow.Array, or pyarrow.ChunkedArray)"
)

def eval_expr(expr: "Expr", batch) -> Any:
return result
raise TypeError(f"Unsupported expression node: {type(expr).__name__}")


def eval_expr(expr: "Expr", batch: DataBatch) -> Any:
"""Recursively evaluate *expr* against a batch of the appropriate type."""
if isinstance(batch, pd.DataFrame):
return _eval_expr_recursive(expr, batch, _PANDAS_EXPR_OPS_MAP)
Expand Down
5 changes: 5 additions & 0 deletions python/ray/data/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@
# Represents a single column of the ``Block``
BlockColumn = Union["pyarrow.ChunkedArray", "pyarrow.Array", "pandas.Series"]

# Represents a single column of a batch. Compared with ``BlockColumn`` above,
# this alias additionally admits a NumPy ndarray.
BatchColumn = Union[
    "pandas.Series", "np.ndarray", "pyarrow.Array", "pyarrow.ChunkedArray"
]


logger = logging.getLogger(__name__)

Expand Down
37 changes: 28 additions & 9 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,31 +783,50 @@ def _map_batches_without_batch_size_validation(
return Dataset(plan, logical_plan)

@PublicAPI(api_group=EXPRESSION_API_GROUP, stability="alpha")
def with_column(self, column_name: str, expr: Expr, **ray_remote_args) -> "Dataset":
def with_column(
self,
column_name: str,
expr: Expr,
**ray_remote_args,
) -> "Dataset":
"""
Add a new column to the dataset via an expression.

Examples:
This method allows you to add a new column to a dataset by applying an
expression. The expression can be composed of existing columns, literals,
and user-defined functions (UDFs).

Examples:
>>> import ray
>>> from ray.data.expressions import col
>>> ds = ray.data.range(100)
>>> ds.with_column("id_2", (col("id") * 2)).schema()
Column Type
------ ----
id int64
id_2 int64
>>> # Add a new column 'id_2' by multiplying 'id' by 2.
>>> ds.with_column("id_2", col("id") * 2).show(2)
{'id': 0, 'id_2': 0}
{'id': 1, 'id_2': 2}

>>> # Using a UDF with with_column
>>> from ray.data.expressions import udf
>>> import pyarrow.compute as pc
>>>
>>> @udf()
... def add_one(column):
... return pc.add(column, 1)
>>>
>>> ds.with_column("id_plus_one", add_one(col("id"))).show(2)
{'id': 0, 'id_plus_one': 1}
{'id': 1, 'id_plus_one': 2}

Args:
column_name: The name of the new column.
expr: An expression that defines the new column values.
**ray_remote_args: Additional resource requirements to request from
Ray (e.g., num_gpus=1 to request GPUs for the map tasks). See
:func:`ray.remote` for details.
Ray for the map tasks (e.g., `num_gpus=1`).

Returns:
A new dataset with the added column evaluated via the expression.
"""
# TODO: update schema based on the expression AST.
from ray.data._internal.logical.operators.map_operator import Download, Project

# TODO: Once the expression API supports UDFs, we can clean up the code here.
Expand Down
142 changes: 141 additions & 1 deletion python/ray/data/expressions.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import functools
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any
from typing import Any, Callable, Dict, List

from ray.data.block import BatchColumn
from ray.util.annotations import DeveloperAPI, PublicAPI


Expand Down Expand Up @@ -239,6 +241,142 @@ def structurally_equals(self, other: Any) -> bool:
)


@DeveloperAPI(stability="alpha")
@dataclass(frozen=True, eq=False)
class UDFExpr(Expr):
    """Expression that represents a call to a user-defined function (UDF).

    The node stores the callable together with the expressions that produce
    its positional and keyword arguments. Instances are normally created by
    calling a function decorated with :func:`udf` rather than constructed
    directly.

    UDFs operate on batches of data: every argument expression is evaluated
    against the batch first, and the resulting column-like values are passed
    to ``fn``.

    Args:
        fn: The user-defined function to call. It is annotated as returning a
            ``BatchColumn``.
        args: List of argument expressions (positional arguments).
        kwargs: Dictionary of keyword argument expressions.

    Example:
        >>> from ray.data.expressions import col, udf
        >>> import pyarrow as pa
        >>> import pyarrow.compute as pc
        >>>
        >>> @udf()
        ... def add_one(x: pa.Array) -> pa.Array:
        ...     return pc.add(x, 1)
        >>>
        >>> # Use in expressions
        >>> expr = add_one(col("value"))
    """

    fn: Callable[..., BatchColumn]
    args: List[Expr]
    kwargs: Dict[str, Expr]

    def structurally_equals(self, other: Any) -> bool:
        """Return True if ``other`` is a ``UDFExpr`` describing the same call.

        Two nodes match when they wrap equal callables, have the same number
        of positional arguments, the same keyword names, and every pair of
        corresponding argument expressions is itself structurally equal.
        """
        if not isinstance(other, UDFExpr):
            return False
        # Cheap checks first: the callable and the positional arity.
        if not (self.fn == other.fn) or len(self.args) != len(other.args):
            return False
        if not all(
            mine.structurally_equals(theirs)
            for mine, theirs in zip(self.args, other.args)
        ):
            return False
        # Keyword names must match exactly before the values are compared,
        # otherwise the per-key lookup below could raise KeyError.
        if self.kwargs.keys() != other.kwargs.keys():
            return False
        return all(
            self.kwargs[key].structurally_equals(other.kwargs[key])
            for key in self.kwargs
        )


def _create_udf_callable(fn: Callable[..., BatchColumn]) -> Callable[..., UDFExpr]:
    """Wrap ``fn`` so that calling it builds a ``UDFExpr`` instead of running it.

    The returned callable accepts arbitrary positional and keyword arguments.
    Any argument that is not already an ``Expr`` is wrapped in a
    ``LiteralExpr``. The metadata of ``fn`` is copied onto the wrapper, and
    ``fn`` itself is exposed as the wrapper's ``_original_fn`` attribute.
    """

    def _as_expr(value: Any) -> Expr:
        # Pass expressions through unchanged; lift plain values to literals.
        return value if isinstance(value, Expr) else LiteralExpr(value)

    # Preserve original function metadata
    @functools.wraps(fn)
    def udf_callable(*args, **kwargs) -> UDFExpr:
        positional = [_as_expr(item) for item in args]
        keyword = {name: _as_expr(item) for name, item in kwargs.items()}
        return UDFExpr(fn=fn, args=positional, kwargs=keyword)

    # Store the original function for access if needed
    udf_callable._original_fn = fn

    return udf_callable


@PublicAPI(stability="alpha")
def udf() -> Callable[[Callable[..., BatchColumn]], Callable[..., UDFExpr]]:
    """
    Decorator to convert a UDF into an expression-compatible function.

    ``udf()`` takes no arguments and returns a decorator. Applying that decorator
    to a function yields a callable that, instead of running the function
    immediately, returns a :class:`UDFExpr` describing the call, so the UDF can
    be combined with other expressions. Arguments that are not already
    expressions are wrapped as literals.

    IMPORTANT: UDFs operate on batches of data, not individual rows. When your UDF
    is called, each column argument will be passed as a PyArrow Array containing
    multiple values from that column across the batch. Under the hood, when working
    with multiple columns, they get translated to PyArrow arrays (one array per column).

    The UDF must return a ``pandas.Series``, ``numpy.ndarray``, ``pyarrow.Array``,
    or ``pyarrow.ChunkedArray``; the expression evaluator raises ``TypeError`` for
    any other return type.

    Returns:
        A decorator that takes the UDF and returns a callable which creates
        UDFExpr instances when called with expressions.

    Example:
        >>> from ray.data.expressions import col, udf
        >>> import pyarrow as pa
        >>> import pyarrow.compute as pc
        >>> import ray
        >>>
        >>> # UDF that operates on a batch of values (PyArrow Array)
        >>> @udf()
        ... def add_one(x: pa.Array) -> pa.Array:
        ...     return pc.add(x, 1)  # Vectorized operation on the entire Array
        >>>
        >>> # UDF that combines multiple columns (each as a PyArrow Array)
        >>> @udf()
        ... def format_name(first: pa.Array, last: pa.Array) -> pa.Array:
        ...     return pc.binary_join_element_wise(first, last, " ")  # Vectorized string concatenation
        >>>
        >>> # Use in dataset operations
        >>> ds = ray.data.from_items([
        ...     {"value": 5, "first": "John", "last": "Doe"},
        ...     {"value": 10, "first": "Jane", "last": "Smith"}
        ... ])
        >>>
        >>> # Single column transformation (operates on batches)
        >>> ds_incremented = ds.with_column("value_plus_one", add_one(col("value")))
        >>>
        >>> # Multi-column transformation (each column becomes a PyArrow Array)
        >>> ds_formatted = ds.with_column("full_name", format_name(col("first"), col("last")))
        >>>
        >>> # Can also be used in complex expressions
        >>> ds_complex = ds.with_column("doubled_plus_one", add_one(col("value")) * 2)
    """

    # NOTE(review): the docstring states arguments always arrive as PyArrow
    # arrays, but the evaluator also has a pandas DataFrame code path --
    # confirm what type a column argument has for pandas batches.
    def decorator(func: Callable[..., BatchColumn]) -> Callable[..., UDFExpr]:
        return _create_udf_callable(func)

    return decorator


@DeveloperAPI(stability="alpha")
@dataclass(frozen=True, eq=False)
class DownloadExpr(Expr):
Expand Down Expand Up @@ -356,6 +494,8 @@ def download(uri_column_name: str) -> DownloadExpr:
"ColumnExpr",
"LiteralExpr",
"BinaryExpr",
"UDFExpr",
"udf",
"DownloadExpr",
"col",
"lit",
Expand Down
Loading