2525 _BATCH_SIZE_PRESERVING_STUB_COL_NAME ,
2626 ArrowBlockAccessor ,
2727)
28+ from ray .data ._internal .planner .plan_expression .expression_visitors import (
29+ get_column_references ,
30+ )
2831from ray .data ._internal .progress_bar import ProgressBar
2932from ray .data ._internal .remote_fn import cached_remote_fn
3033from ray .data ._internal .util import (
5255from ray .data .datasource .path_util import (
5356 _resolve_paths_and_filesystem ,
5457)
58+ from ray .data .expressions import BinaryExpr , Expr , Operation
5559from ray .util .debug import log_once
5660
5761if TYPE_CHECKING :
@@ -160,6 +164,120 @@ def check_for_legacy_tensor_type(schema):
160164 )
161165
162166
@dataclass
class _SplitPredicateResult:
    """Outcome of partitioning a predicate by the kind of column it touches.

    Attributes:
        data_predicate: The portion of the predicate that references only data
            columns, suitable for PyArrow row-level pushdown; None when no
            data-column part could be extracted.
        partition_predicate: The portion that references only partition
            columns, suitable for pruning whole files by path; None when no
            partition-column part could be extracted.
    """

    data_predicate: Optional[Expr]
    partition_predicate: Optional[Expr]
181+
def _split_predicate_by_columns(
    predicate: Expr,
    partition_columns: set,
) -> _SplitPredicateResult:
    """Split a predicate into data-only and partition-only parts.

    This function extracts both data column predicates and partition column
    predicates from AND chains, enabling both PyArrow pushdown (data part) and
    partition pruning (partition part).

    Args:
        predicate: The predicate expression to analyze.
        partition_columns: Set of partition column names.

    Returns:
        _SplitPredicateResult containing:
            - data_predicate: Expression with only data columns (for PyArrow
              pushdown), or None if no data predicates can be extracted.
            - partition_predicate: Expression with only partition columns (for
              pruning), or None if no partition predicates can be extracted.

    Examples:
        >>> from ray.data.expressions import col
        >>> # Pure data predicate:
        >>> result = _split_predicate_by_columns(col("data1") > 5, {"partition_col"})
        >>> result.data_predicate is not None
        True
        >>> result.partition_predicate is None
        True

        >>> # Pure partition predicate:
        >>> result = _split_predicate_by_columns(
        ...     col("partition_col") == "US", {"partition_col"}
        ... )
        >>> result.data_predicate is None
        True
        >>> result.partition_predicate is not None
        True

        >>> # Mixed AND - can split both parts:
        >>> result = _split_predicate_by_columns(
        ...     (col("data1") > 5) & (col("partition_col") == "US"),
        ...     {"partition_col"},
        ... )
        >>> result.data_predicate is not None
        True
        >>> result.partition_predicate is not None
        True

        >>> # Mixed OR - can't split safely:
        >>> result = _split_predicate_by_columns(
        ...     (col("data1") > 5) | (col("partition_col") == "US"),
        ...     {"partition_col"},
        ... )
        >>> result.data_predicate is None
        True
        >>> result.partition_predicate is None
        True
    """
    referenced_cols = set(get_column_references(predicate))
    data_cols = referenced_cols - partition_columns
    partition_cols_in_predicate = referenced_cols & partition_columns

    if not partition_cols_in_predicate:
        # Pure data predicate.
        return _SplitPredicateResult(data_predicate=predicate, partition_predicate=None)

    if not data_cols:
        # Pure partition predicate.
        return _SplitPredicateResult(data_predicate=None, partition_predicate=predicate)

    # Mixed predicate - try to split if it's an AND chain.
    if isinstance(predicate, BinaryExpr) and predicate.op == Operation.AND:
        # Recursively split left and right sides.
        left_result = _split_predicate_by_columns(predicate.left, partition_columns)
        right_result = _split_predicate_by_columns(predicate.right, partition_columns)

        def combine_predicates(
            left: Optional[Expr], right: Optional[Expr]
        ) -> Optional[Expr]:
            # AND the two halves when both are present; otherwise take
            # whichever exists. Use explicit ``is not None`` checks instead of
            # truthiness: these values are Optional[Expr], and coercing an
            # expression object to bool is fragile (expression DSLs commonly
            # override ``__bool__``), which could silently drop predicates.
            if left is not None and right is not None:
                return left & right
            return left if left is not None else right

        data_predicate = combine_predicates(
            left_result.data_predicate, right_result.data_predicate
        )
        partition_predicate = combine_predicates(
            left_result.partition_predicate, right_result.partition_predicate
        )

        return _SplitPredicateResult(
            data_predicate=data_predicate, partition_predicate=partition_predicate
        )

    # For OR, NOT, or other operations with mixed columns,
    # we can't safely split - must evaluate the full predicate together.
    return _SplitPredicateResult(data_predicate=None, partition_predicate=None)
279+
280+
163281class ParquetDatasource (Datasource ):
164282 """Parquet datasource, for reading and writing Parquet files.
165283
@@ -255,9 +373,13 @@ def __init__(
255373 # columns manually.
256374 data_columns , partition_columns = None , None
257375 if columns is not None :
258- data_columns , partition_columns = _infer_data_and_partition_columns (
259- columns , pq_ds .fragments [0 ], partitioning
260- )
376+ if pq_ds .fragments :
377+ data_columns , partition_columns = _infer_data_and_partition_columns (
378+ columns , pq_ds .fragments [0 ], partitioning
379+ )
380+ else :
381+ # Empty dataset - can't infer columns without fragments
382+ data_columns , partition_columns = [], []
261383
262384 if to_batch_kwargs is None :
263385 to_batch_kwargs = {}
@@ -274,11 +396,35 @@ def __init__(
274396 self ._to_batches_kwargs = to_batch_kwargs
275397 # Store as projection_map (identity mapping if columns specified, None otherwise)
276398 # Note: Empty list [] means no columns, None means all columns
277- if data_columns is None :
399+ # Include partition columns in projection_map if they were requested, so that
400+ # projection pushdown can properly track them
401+ if data_columns is None and partition_columns is None :
278402 self ._projection_map = None
279403 else :
280- self ._projection_map = {col : col for col in data_columns }
281- self ._partition_columns = partition_columns
404+ self ._projection_map = {}
405+ if data_columns is not None :
406+ self ._projection_map .update ({col : col for col in data_columns })
407+ if partition_columns is not None :
408+ self ._projection_map .update ({col : col for col in partition_columns })
409+
410+ # Eagerly compute the actual partition columns for _partition_columns.
411+ # This ensures _partition_columns is always a list (never None).
412+ actual_partition_columns = partition_columns
413+ if partition_columns is None and partitioning is not None and pq_ds .fragments :
414+ parse = PathPartitionParser (partitioning )
415+ parsed_partitions = parse (pq_ds .fragments [0 ].path )
416+ if parsed_partitions :
417+ actual_partition_columns = list (parsed_partitions .keys ())
418+
419+ # Store selected partition columns. Always a list (never None) representing
420+ # the actual partition columns to include.
421+ self ._partition_columns = (
422+ actual_partition_columns if actual_partition_columns is not None else []
423+ )
424+ # Track whether partition columns were explicitly part of the user's column selection
425+ self ._partition_columns_selected = (
426+ partition_columns is not None and len (self ._partition_columns ) > 0
427+ )
282428 self ._read_schema = schema
283429 self ._file_schema = pq_ds .schema
284430 self ._partition_schema = _get_partition_columns_schema (
@@ -390,7 +536,7 @@ def get_read_tasks(
390536 self ._default_batch_size ,
391537 self ._get_data_columns (),
392538 self .get_column_renames (),
393- self ._partition_columns ,
539+ self ._get_partition_columns () ,
394540 self ._read_schema ,
395541 self ._include_paths ,
396542 self ._partitioning ,
@@ -441,10 +587,123 @@ def get_current_projection(self) -> Optional[List[str]]:
441587 # NOTE: In case there's no projection both file and partition columns
442588 # will be none
443589 data_columns = self ._get_data_columns ()
444- if data_columns is None and self ._partition_columns is None :
590+ partition_columns = self ._get_partition_columns ()
591+ if data_columns is None and partition_columns is None :
445592 return None
446593
447- return (data_columns or []) + (self ._partition_columns or [])
594+ return (data_columns or []) + (partition_columns or [])
595+
596+ def _get_partition_columns (self ) -> Optional [List [str ]]:
597+ """Extract partition columns from projection map.
598+
599+ This method extracts partition columns from _projection_map, which is the
600+ source of truth after projection pushdown. Since partition columns are now
601+ included in _projection_map during initialization when requested, we can
602+ reliably extract them from the map.
603+
604+ Returns:
605+ List of partition column names in the projection, None if there's
606+ no projection (meaning include all partition columns), or [] if
607+ partition columns aren't in the projection map (meaning include
608+ no partition columns).
609+ """
610+ if self ._projection_map is None :
611+ return None
612+
613+ if not self ._partition_columns :
614+ return None
615+
616+ # Extract partition columns that are in the projection map
617+ partition_cols = [
618+ col for col in self ._projection_map .keys () if col in self ._partition_columns
619+ ]
620+
621+ # If partition columns are found in projection map, return them
622+ if partition_cols :
623+ return partition_cols
624+
625+ # No partition columns in projection map.
626+ # Since the projection map exists and is the source of truth after
627+ # projection pushdown, return [] (no partition columns to include).
628+ return []
629+
630+ def _get_data_columns (self ) -> Optional [List [str ]]:
631+ """Extract data columns from projection map, excluding partition columns.
632+
633+ Partition columns aren't in the physical file schema, so they must be
634+ filtered out before passing to PyArrow's to_batches().
635+
636+ Returns:
637+ List of data column names to read from files, or None if no projection.
638+ Can return empty list if only partition columns are projected.
639+ """
640+ if self ._projection_map is None :
641+ return None
642+
643+ # Get partition columns and filter them out from the projection
644+ partition_cols = self ._partition_columns
645+ data_cols = [
646+ col for col in self ._projection_map .keys () if col not in partition_cols
647+ ]
648+
649+ return data_cols
650+
    def apply_predicate(
        self,
        predicate_expr: Expr,
    ) -> "ParquetDatasource":
        """Apply a predicate with data pushdown and partition pruning.

        This method optimizes predicates in three ways:
        1. Data predicates → pushed to PyArrow (row-level filtering)
        2. Partition predicates → used for partition pruning (file-level filtering)
        3. Mixed predicates → both optimizations applied together
        """
        partition_cols = set(self._partition_columns)

        if not partition_cols:
            # No partition columns - can push down everything normally
            return super().apply_predicate(predicate_expr)

        # Split predicate into data and partition parts
        split_result = _split_predicate_by_columns(predicate_expr, partition_cols)

        # Apply partition pruning if we have a partition predicate
        if (
            split_result.partition_predicate is not None
            and self._partitioning is not None
        ):
            parser = PathPartitionParser(self._partitioning)
            pruned_fragments = []
            pruned_paths = []

            # Keep only fragments whose path-encoded partition values satisfy
            # the partition predicate; fragments and paths are kept in lockstep.
            for fragment, path in zip(self._pq_fragments, self._pq_paths):
                # Evaluate partition predicate - skip if it doesn't match
                if parser.evaluate_predicate_on_partition(
                    path, split_result.partition_predicate
                ):
                    pruned_fragments.append(fragment)
                    pruned_paths.append(path)

            # Apply partition pruning directly to self
            # NOTE(review): this mutates ``self`` *before* the shallow copy
            # below, so the original datasource also loses the pruned
            # fragments — confirm this side effect on the caller's instance
            # is intended rather than pruning only the returned copy.
            self._pq_fragments = pruned_fragments
            self._pq_paths = pruned_paths

        # Push down data predicate to PyArrow if present
        # Create a copy and push down the data predicate to PyArrow
        import copy

        # Shallow copy: the pruned fragment/path lists are shared with self.
        datasource = copy.copy(self)

        # Only call apply_predicate if there's a data predicate to push down
        # If data_predicate is None (pure partition predicate), skip it to avoid
        # creating invalid expressions like existing_expr & None
        if split_result.data_predicate is not None:
            # Explicit two-argument super() so the parent implementation is
            # invoked bound to the copy, not to self.
            return super(ParquetDatasource, datasource).apply_predicate(
                split_result.data_predicate
            )

        return datasource
448707
449708 def _estimate_in_mem_size (self , fragments : List [_ParquetFragment ]) -> int :
450709 in_mem_size = sum ([f .file_size for f in fragments ]) * self ._encoding_ratio
0 commit comments