20 changes: 11 additions & 9 deletions python/ray/data/_internal/datasource/iceberg_datasink.py
@@ -15,9 +15,11 @@
 if TYPE_CHECKING:
     import pyarrow as pa
     from pyiceberg.catalog import Catalog
+    from pyiceberg.io import FileIO
     from pyiceberg.manifest import DataFile
     from pyiceberg.schema import Schema
     from pyiceberg.table import Table
+    from pyiceberg.table.metadata import TableMetadata
     from pyiceberg.table.update.schema import UpdateSchema
 
 logger = logging.getLogger(__name__)
@@ -135,6 +137,8 @@ def __init__(
         self._catalog_name = "default"
 
         self._table: "Table" = None
+        self._io: "FileIO" = None
+        self._table_metadata: "TableMetadata" = None
 
     def __getstate__(self) -> dict:
         """Exclude `_table` during pickling."""
@@ -155,14 +159,17 @@ def _reload_table(self) -> None:
         """Reload the Iceberg table from the catalog."""
         catalog = self._get_catalog()
         self._table = catalog.load_table(self.table_identifier)
+        self._io = self._table.io
+        self._table_metadata = self._table.metadata
 
     def _get_upsert_cols(self) -> List[str]:
         """Get join columns for upsert, using table identifier fields as fallback."""
         upsert_cols = self._upsert_kwargs.get(_UPSERT_COLS_ID, [])
         if not upsert_cols:
             # Use table's identifier fields as fallback
-            for field_id in self._table.metadata.schema().identifier_field_ids:
-                col_name = self._table.metadata.schema().find_column_name(field_id)
+            schema = self._table_metadata.schema()
+            for field_id in schema.identifier_field_ids:
+                col_name = schema.find_column_name(field_id)
                 if col_name:
                     upsert_cols.append(col_name)
         return upsert_cols
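For context on the `_get_upsert_cols` fallback above: `identifier_field_ids` and `find_column_name` are standard pyiceberg `Schema` APIs. A small standalone sketch of that fallback path; the two-column schema is made up for illustration.

```python
from pyiceberg.schema import Schema
from pyiceberg.types import LongType, NestedField, StringType

# Hypothetical schema: "id" is the identifier field (think primary key).
schema = Schema(
    NestedField(field_id=1, name="id", field_type=LongType(), required=True),
    NestedField(field_id=2, name="name", field_type=StringType(), required=False),
    identifier_field_ids=[1],
)

# Same fallback logic as _get_upsert_cols when no upsert columns were given.
upsert_cols = []
for field_id in schema.identifier_field_ids:
    col_name = schema.find_column_name(field_id)
    if col_name:
        upsert_cols.append(col_name)

print(upsert_cols)  # ['id']
```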
@@ -341,11 +348,6 @@ def write(self, blocks: Iterable[Block], ctx: TaskContext) -> IcebergWriteResult
         """
         from pyiceberg.io.pyarrow import _dataframe_to_data_files
 
-        # Workers receive a pickled datasink with _table=None (excluded during
-        # serialization), so we reload it on first use.
-        if self._table is None:
-            self._reload_table()
-
         all_data_files = []
         upsert_keys_tables = []
         block_schemas = []
@@ -365,9 +367,9 @@ def write(self, blocks: Iterable[Block], ctx: TaskContext) -> IcebergWriteResult
             # Write data files to storage
             data_files = list(
                 _dataframe_to_data_files(
-                    table_metadata=self._table.metadata,
+                    table_metadata=self._table_metadata,
                     df=pa_table,
-                    io=self._table.io,
+                    io=self._io,
                 )
             )
             all_data_files.extend(data_files)
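Net effect across the last two hunks: `write()` no longer lazily reloads the table on workers. It consumes `self._table_metadata` and `self._io`, which were cached when `_reload_table()` ran (presumably on the driver, before the sink was pickled) and which survive serialization, so a worker holding only the pickled sink state can produce data files without any catalog round-trip.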