1 change: 1 addition & 0 deletions doc/source/data/api/api.rst
@@ -10,6 +10,7 @@ Ray Data API
dataset.rst
data_iterator.rst
execution_options.rst
checkpoint.rst
aggregate.rst
grouped_data.rst
expressions.rst
18 changes: 18 additions & 0 deletions doc/source/data/api/checkpoint.rst
@@ -0,0 +1,18 @@
.. _checkpoint-api:

Checkpoint API
==============

.. currentmodule:: ray.data.checkpoint.interfaces

Configuration
-------------

.. autosummary::
:nosignatures:
:toctree: doc/
:template: autosummary/class_without_autosummary.rst

CheckpointConfig
CheckpointBackend
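
For readers landing on this new API page, a minimal construction sketch may help. Only `CheckpointConfig`, `CheckpointBackend`, and (elsewhere in this PR) an `id_column` field are confirmed by the diff; the `backend` keyword and the `FILESYSTEM` member name below are assumptions, not the documented surface.

# A hedged sketch: `id_column` appears in this PR; the `backend` keyword
# and `CheckpointBackend.FILESYSTEM` are assumed, illustrative names.
from ray.data.checkpoint import CheckpointBackend, CheckpointConfig

config = CheckpointConfig(
    id_column="id",  # column that uniquely identifies each row
    backend=CheckpointBackend.FILESYSTEM,  # hypothetical enum member
)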

14 changes: 14 additions & 0 deletions python/ray/data/BUILD.bazel
@@ -1889,3 +1889,17 @@ py_test(
"//:ray_lib",
],
)

py_test(
name = "test_checkpoint",
size = "large",
srcs = ["tests/test_checkpoint.py"],
tags = [
"exclusive",
"team:data",
],
deps = [
":conftest",
"//:ray_lib",
],
)
7 changes: 7 additions & 0 deletions python/ray/data/_internal/planner/checkpoint/__init__.py
@@ -0,0 +1,7 @@
from .plan_read_op import plan_read_op_with_checkpoint_filter
from .plan_write_op import plan_write_op_with_checkpoint_writer

__all__ = [
"plan_read_op_with_checkpoint_filter",
"plan_write_op_with_checkpoint_writer",
]
47 changes: 47 additions & 0 deletions python/ray/data/_internal/planner/checkpoint/plan_read_op.py
@@ -0,0 +1,47 @@
import functools
from typing import Callable, List, Optional

from ray import ObjectRef
from ray.data._internal.execution.interfaces import PhysicalOperator
from ray.data._internal.execution.operators.map_transformer import (
BlockMapTransformFn,
)
from ray.data._internal.logical.operators.read_operator import Read
from ray.data._internal.output_buffer import OutputBlockSizeOption
from ray.data._internal.planner.plan_read_op import plan_read_op
from ray.data.checkpoint.util import (
CHECKPOINTED_IDS_KWARG_NAME,
filter_checkpointed_rows_for_blocks,
)
from ray.data.context import DataContext


def plan_read_op_with_checkpoint_filter(
op: Read,
physical_children: List[PhysicalOperator],
data_context: DataContext,
load_checkpoint: Optional[Callable[[], ObjectRef]] = None,
) -> PhysicalOperator:
physical_op = plan_read_op(op, physical_children, data_context)

    # TODO: Avoid modifying the map transformer in place.
physical_op._map_transformer.add_transform_fns(
[
BlockMapTransformFn(
functools.partial(
filter_checkpointed_rows_for_blocks,
checkpoint_config=data_context.checkpoint_config,
),
output_block_size_option=OutputBlockSizeOption.of(
target_max_block_size=data_context.target_max_block_size,
),
),
]
)

if load_checkpoint is not None:
physical_op.add_map_task_kwargs_fn(
lambda: {CHECKPOINTED_IDS_KWARG_NAME: load_checkpoint()}
)

return physical_op
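
For context, a simplified sketch of what the injected filter step does: drop rows whose IDs are already checkpointed. This is not the actual `filter_checkpointed_rows_for_blocks` from `ray.data.checkpoint.util`; it assumes PyArrow table blocks and that the checkpointed IDs arrive as a plain Python set rather than an `ObjectRef`.

from typing import Iterable, Set

import pyarrow as pa
import pyarrow.compute as pc


def filter_checkpointed_rows(
    blocks: Iterable[pa.Table], checkpointed_ids: Set[int], id_column: str
) -> Iterable[pa.Table]:
    # Keep only rows whose ID has not been checkpointed yet.
    id_set = pa.array(list(checkpointed_ids))
    for block in blocks:
        mask = pc.invert(pc.is_in(block[id_column], value_set=id_set))
        yield block.filter(mask)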
82 changes: 82 additions & 0 deletions python/ray/data/_internal/planner/checkpoint/plan_write_op.py
@@ -0,0 +1,82 @@
import itertools
from typing import Iterable, List

from ray.data._internal.execution.interfaces import PhysicalOperator
from ray.data._internal.execution.interfaces.task_context import TaskContext
from ray.data._internal.execution.operators.map_transformer import (
BlockMapTransformFn,
)
from ray.data._internal.logical.operators.write_operator import Write
from ray.data._internal.planner.plan_write_op import (
_plan_write_op_internal,
generate_collect_write_stats_fn,
)
from ray.data.block import Block, BlockAccessor
from ray.data.checkpoint.checkpoint_writer import CheckpointWriter
from ray.data.checkpoint.interfaces import (
InvalidCheckpointingOperators,
)
from ray.data.context import DataContext
from ray.data.datasource.datasink import Datasink


def plan_write_op_with_checkpoint_writer(
op: Write, physical_children: List[PhysicalOperator], data_context: DataContext
) -> PhysicalOperator:
assert data_context.checkpoint_config is not None

collect_stats_fn = generate_collect_write_stats_fn()
write_checkpoint_for_block_fn = _generate_checkpoint_writing_transform(
data_context, op
)

physical_op = _plan_write_op_internal(
op,
physical_children,
data_context,
extra_transformations=[
write_checkpoint_for_block_fn,
collect_stats_fn,
],
)

return physical_op


def _generate_checkpoint_writing_transform(
data_context: DataContext, logical_op: Write
) -> BlockMapTransformFn:
datasink = logical_op._datasink_or_legacy_datasource
if not isinstance(datasink, Datasink):
raise InvalidCheckpointingOperators(
f"To enable checkpointing, Write operation must use a "
f"Datasink and not a legacy Datasource, but got: "
f"{type(datasink)}"
)

checkpoint_writer = CheckpointWriter.create(data_context.checkpoint_config)

    # MapTransformFn that writes checkpoint files after the write completes.
def write_checkpoint_for_block(
blocks: Iterable[Block], ctx: TaskContext
) -> Iterable[Block]:
it1, it2 = itertools.tee(blocks, 2)
for block in it1:
ba = BlockAccessor.for_block(block)
if ba.num_rows() > 0:
if data_context.checkpoint_config.id_column not in ba.column_names():
raise ValueError(
f"ID column {data_context.checkpoint_config.id_column} is "
f"absent in the block to be written. Do not drop or rename "
f"this column."
)
checkpoint_writer.write_block_checkpoint(ba)

return list(it2)

return BlockMapTransformFn(
write_checkpoint_for_block,
is_udf=False,
# NOTE: No need for block-shaping
disable_block_shaping=True,
)
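
A note on the `itertools.tee` pattern above: one branch is consumed to write checkpoints as a side effect, while the other is materialized and returned unchanged so the downstream stats transform still sees every block. A self-contained illustration of that pass-through pattern, with illustrative names rather than Ray APIs:

import itertools
from typing import Callable, Iterable, List


def observe_and_pass_through(
    blocks: Iterable[dict], observe: Callable[[dict], None]
) -> List[dict]:
    # tee() lets us iterate the stream twice without re-running the producer.
    it1, it2 = itertools.tee(blocks, 2)
    for block in it1:
        observe(block)  # side effect, e.g. writing a checkpoint entry
    # Materializing guarantees the side effect finished before consumers run.
    return list(it2)


written: List[int] = []
out = observe_and_pass_through(
    iter([{"id": 1}, {"id": 2}]), lambda b: written.append(b["id"])
)
assert out == [{"id": 1}, {"id": 2}] and written == [1, 2]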
81 changes: 81 additions & 0 deletions python/ray/data/_internal/planner/planner.py
@@ -1,5 +1,9 @@
import functools
import warnings
from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar

from ray import ObjectRef
from ray.data._internal.execution.execution_callback import add_execution_callback
from ray.data._internal.execution.interfaces import PhysicalOperator
from ray.data._internal.execution.operators.aggregate_num_rows import (
AggregateNumRows,
@@ -35,6 +39,10 @@
from ray.data._internal.logical.operators.read_operator import Read
from ray.data._internal.logical.operators.streaming_split_operator import StreamingSplit
from ray.data._internal.logical.operators.write_operator import Write
from ray.data._internal.planner.checkpoint import (
plan_read_op_with_checkpoint_filter,
plan_write_op_with_checkpoint_writer,
)
from ray.data._internal.planner.plan_all_to_all_op import plan_all_to_all_op
from ray.data._internal.planner.plan_download_op import plan_download_op
from ray.data._internal.planner.plan_read_op import plan_read_op
@@ -45,6 +53,7 @@
plan_udf_map_op,
)
from ray.data._internal.planner.plan_write_op import plan_write_op
from ray.data.checkpoint.load_checkpoint_callback import LoadCheckpointCallback
from ray.data.context import DataContext

LogicalOperatorType = TypeVar("LogicalOperatorType", bound=LogicalOperator)
@@ -159,16 +168,50 @@ class Planner:
StreamingSplit: plan_streaming_split_op,
Download: plan_download_op,
}
# Operators that support checkpoint filtering. Subclasses can override.
_CHECKPOINT_FILTER_OPS = (Read,)

def __init__(self):
self._supports_checkpointing = False
self._plan_fns_for_checkpointing = {}

def plan(self, logical_plan: LogicalPlan) -> PhysicalPlan:
"""Convert logical to physical operators recursively in post-order."""
checkpoint_config = logical_plan.context.checkpoint_config
if checkpoint_config is not None and self._check_supports_checkpointing(
logical_plan
):
self._supports_checkpointing = True

checkpoint_callback = self._create_checkpoint_callback(checkpoint_config)
add_execution_callback(checkpoint_callback, logical_plan.context)
load_checkpoint = checkpoint_callback.load_checkpoint

            # Dynamically set the plan functions for checkpointing because they
            # need access to the checkpoint-loading callback.
self._plan_fns_for_checkpointing = self._get_plan_fns_for_checkpointing(
load_checkpoint
)

elif checkpoint_config is not None:
assert not self._check_supports_checkpointing(logical_plan)
warnings.warn(
"You've enabled checkpointing, but the logical plan doesn't support "
"checkpointing. Checkpointing will be disabled."
)
physical_dag, op_map = self._plan_recursively(
logical_plan.dag, logical_plan.context
)
physical_plan = PhysicalPlan(physical_dag, op_map, logical_plan.context)
return physical_plan

def get_plan_fn(self, logical_op: LogicalOperator) -> PlanLogicalOpFn:
if self._supports_checkpointing:
assert self._plan_fns_for_checkpointing
plan_fn = find_plan_fn(logical_op, self._plan_fns_for_checkpointing)
if plan_fn is not None:
return plan_fn

plan_fn = find_plan_fn(logical_op, self._DEFAULT_PLAN_FNS)
if plan_fn is not None:
return plan_fn
@@ -223,6 +266,44 @@ def _plan_recursively(
op_map[physical_op] = logical_op
return physical_op, op_map

def _create_checkpoint_callback(self, checkpoint_config) -> LoadCheckpointCallback:
"""Factory method to create the LoadCheckpointCallback.

Subclasses can override this to use a different callback implementation.
"""
return LoadCheckpointCallback(checkpoint_config)

def _get_plan_fns_for_checkpointing(
self,
load_checkpoint: Callable[[], ObjectRef],
) -> Dict[Type[LogicalOperator], PlanLogicalOpFn]:
plan_fns = {
Read: functools.partial(
plan_read_op_with_checkpoint_filter,
load_checkpoint=load_checkpoint,
),
Write: plan_write_op_with_checkpoint_writer,
}
return plan_fns

def _check_supports_checkpointing(self, logical_plan: LogicalPlan) -> bool:
"""Check if the logical plan supports checkpointing.

Subclasses can override _CHECKPOINT_FILTER_OPS to support more operators.
"""
if not isinstance(logical_plan.dag, (Write, StreamingSplit)):
return False

def _all_paths_contain_checkpoint_filter(op: LogicalOperator) -> bool:
if isinstance(op, self._CHECKPOINT_FILTER_OPS):
return True
return all(
_all_paths_contain_checkpoint_filter(input_dep)
for input_dep in op.input_dependencies
)

return _all_paths_contain_checkpoint_filter(logical_plan.dag)


def find_plan_fn(
logical_op: LogicalOperator, plan_fns: Dict[Type[LogicalOperator], PlanLogicalOpFn]
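
To make the recursion in `_check_supports_checkpointing` concrete, here is a toy model using stand-in node classes instead of Ray's logical operators. The non-read leaf case is guarded explicitly here because `all()` over an empty input list returns True; in Ray's planner, leaves are source operators such as `Read`.

from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    inputs: List["Node"] = field(default_factory=list)


@dataclass
class ReadNode(Node):
    pass


def all_paths_contain_read(op: Node) -> bool:
    if isinstance(op, ReadNode):
        return True
    if not op.inputs:
        # A non-read leaf: this path never passes through a read.
        return False
    return all(all_paths_contain_read(dep) for dep in op.inputs)


# Every path from the root bottoms out in a ReadNode -> supported.
assert all_paths_contain_read(Node(inputs=[ReadNode(), Node(inputs=[ReadNode()])]))
# One path ends in a plain leaf -> not supported.
assert not all_paths_contain_read(Node(inputs=[ReadNode(), Node()]))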
3 changes: 3 additions & 0 deletions python/ray/data/checkpoint/__init__.py
@@ -0,0 +1,3 @@
from .interfaces import CheckpointBackend, CheckpointConfig

__all__ = ["CheckpointConfig", "CheckpointBackend"]
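
Putting it together, a hedged end-to-end sketch of how a user might enable checkpointing. `DataContext.checkpoint_config` and the `id_column` field appear in this PR; treat the rest of the `CheckpointConfig` constructor as an assumption.

import ray
from ray.data.checkpoint import CheckpointConfig
from ray.data.context import DataContext

# Only `id_column` is confirmed by this diff; other CheckpointConfig
# fields (if any) may differ in the merged API.
DataContext.get_current().checkpoint_config = CheckpointConfig(id_column="id")

# ray.data.range() produces an "id" column, which satisfies the ID-column
# check enforced in the write path. A Read -> Write pipeline like this is
# the plan shape that _check_supports_checkpointing accepts.
ds = ray.data.range(1000)
ds.write_parquet("/tmp/checkpointed_output")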