
[Data] Limit operator is applied before map_groups() #60872

@aschuh-hf

Description


What happened + What you expected to happen

I have a Ray Data pipeline consisting of the following steps:

  1. Image source which loads input images
  2. Sampling of a set of patches from each image
  3. Image patch encoder (batch inference)
  4. Group by image UID, and collect patch embeddings for each image

While developing this pipeline, in particular the use of groupby().map_groups(), which was new to me, I used ds.take() at the end so I could check the output of the map_groups() operation. The result was incorrect: with take(1), each image's embedding consisted of only 1 element instead of the expected number of embeddings (i.e., a PxD matrix, where P is the number of patches extracted from the given image).

I noticed in the Data Dashboard that the limit operation added by take() was actually executed before the map_groups(). Only after setting up the toy script posted below did I realize that this was causing the unexpected behavior.

With the above pipeline, I would have expected take(n) to return the embeddings computed for n images, not to limit the number of per-patch embeddings to n before executing the map_groups() operation.
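The difference can be illustrated without Ray at all. In the toy stand-in below (plain Python, no Ray APIs; `group_by_image` and the row layout are made up for illustration), applying a limit *before* the group-by aggregation truncates the patches feeding each group, while applying it *after* only truncates the number of output groups:

```python
from collections import defaultdict


def group_by_image(rows):
    """Collect per-patch rows into one aggregated entry per image_id."""
    groups = defaultdict(list)
    for row in rows:
        groups[row["image_id"]].append(row["feat"])
    return [{"image_id": k, "feats": v} for k, v in groups.items()]


# 3 patches for one image, 2 for another
rows = [{"image_id": "UUID_0", "feat": f} for f in (0.1, 0.2, 0.3)]
rows += [{"image_id": "UUID_1", "feat": f} for f in (0.4, 0.5)]

limit_after = group_by_image(rows)[:1]   # expected: full group, output truncated
limit_before = group_by_image(rows[:1])  # observed: group built from 1 patch

print(len(limit_after[0]["feats"]))   # 3 -- all patches of UUID_0
print(len(limit_before[0]["feats"]))  # 1 -- only a single patch survived
```

This is the shape mismatch reported below: take(1) behaves like `limit_before`, returning a (1, D) feature matrix instead of (P, D).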

Versions / Dependencies

Ray 2.53.0

Reproduction script

The following script outputs:

Image ID: UUID_2 Points: (1, 3) Features: (1, 128) Labels: (1,)
Image ID: UUID_2 Points: (18, 3) Features: (18, 128) Labels: (18,)

The first line comes from take(1) and exhibits the issue; the second line, obtained via take_all(), is the expected result.

"""Ray Data pipeline to test and debug ray.data.Dataset.groupby().map_groups()."""

# %%
# Imports
import numpy as np
import ray.data
from ray.data.block import DataBatch


# %%
# Initialize Ray
if ray.is_initialized():
    ray.shutdown()

ray.init(
    runtime_env={
        "env_vars": {
            "RAY_DEBUG": "0",
            "RAY_DEBUG_POST_MORTEM": "0",
        },
    }
)


# %%
# Dummy data source
image_ds = ray.data.from_items([{"image_id": f"UUID_{i}"} for i in range(10)])


# %%
# Extract image patches
def extract_patches(row: dict) -> list[dict]:
    """Dummy function to simulate extraction of multiple patches from each image (1-to-N flat_map)."""
    image_id = row["image_id"]
    n_points = np.random.randint(10, 20)
    point_ids = np.arange(n_points)
    points = np.random.rand(n_points, 3)
    labels = np.random.randint(0, 2, size=n_points)
    patches = [
        {
            "image_id": image_id,
            "point_id": point_id,
            "point": point,
            "label": label
        }
        for point_id, point, label in zip(point_ids, points, labels)
    ]
    return patches


patches_ds = image_ds.flat_map(extract_patches, num_cpus=0.01)


# %%
# Process image patches
def process_patches(batch: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Dummy function to simulate processing of image patches (e.g. feature extraction)."""
    batch_size = batch["point"].shape[0]
    batch["feats"] = np.random.rand(batch_size, 128)
    return batch


processed_patches_ds = patches_ds.map_batches(
    process_patches,  # type: ignore[arg-type]
    batch_format="numpy",
    batch_size=64,
    num_cpus=0.01,
)


# %%
# Collect image features
def agg_grouped_features(group: DataBatch) -> DataBatch:
    """Aggregate patch features for each image into one matrix."""
    image_id = np.unique(group["image_id"])
    if len(image_id) != 1:
        raise ValueError(f"Expected exactly one unique image_id in group, found {image_id}")
    point_ids = group["point_id"]
    n_points = point_ids.shape[0]
    if not np.array_equal(np.sort(point_ids), np.arange(n_points)):
        raise ValueError(f"Expected point_ids of group to contain all indices from 0 to {n_points - 1}")
    order = np.argsort(point_ids)
    points = group["point"][order]
    feats = group["feats"][order]
    labels = group["label"][order]
    return {
        "image_id": image_id,
        "points": np.expand_dims(points, axis=0),
        "feats": np.expand_dims(feats, axis=0),
        "labels": np.expand_dims(labels, axis=0),
    }


features_ds = processed_patches_ds.groupby("image_id").map_groups(
    agg_grouped_features,
    batch_format="numpy",
    num_cpus=0.01,
)

# Uncomment the below line to avoid the issue.
# features_ds = features_ds.materialize()


# %%
# Take samples from final dataset
#
# FIXME: The limit() operator is moved before the map_groups() operation, causing the wrong number of items to be grouped.
# Instead of obtaining 1 item with the features from all patches across that image, only the features of 1 patch are returned.
#
# This is not the case when using features_ds.take_all() or after materializing the dataset first.
samples = features_ds.take_all()
sample = features_ds.take(1)[0]

print(
    # fmt: off
    "Image ID:", sample["image_id"],
    "Points:", sample["points"].shape,
    "Features:", sample["feats"].shape,
    "Labels:", sample["labels"].shape,
    # fmt: on
)

print(
    # fmt: off
    "Image ID:", samples[0]["image_id"],
    "Points:", samples[0]["points"].shape,
    "Features:", samples[0]["feats"].shape,
    "Labels:", samples[0]["labels"].shape,
    # fmt: on
)

Issue Severity

None

Metadata

Labels: P0 (issues that should be fixed in short order), bug (something that is supposed to be working, but isn't), community-backlog

Status

Done
