Description
What happened + What you expected to happen
I have a Ray Data pipeline consisting of the following steps:
- Image source which loads input images
- Sampling of a set of patches from each image
- Image patch encoder (batch inference)
- Group by image UID, and collect patch embeddings for each image
While developing this pipeline, especially the use of groupby().map_groups() which was new to me, I used ds.take() at the end to check the output of the map_groups operation. The result was incorrect: with take(1), each image had only a single embedding vector instead of the expected number of embeddings (i.e., a 1xD matrix rather than a PxD matrix, where P is the number of patches extracted from that image).
I noticed in the Data Dashboard that the limit operation added by take() was actually executed before the map_groups(). Only after setting up the toy script posted below did I realize that this was causing the unexpected behavior.
With the above pipeline, I would expect take(n) to return the embeddings computed for n images, not to limit the per-patch rows to n before the map_groups() operation runs.
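The semantic difference can be illustrated without Ray at all: take(1) on the grouped dataset should mean "group all patches, then take one group", whereas the pushed-down limit computes "take one patch row, then group". A minimal plain-Python sketch (hypothetical image_id/point_id data, no Ray involved):

```python
from collections import defaultdict

# Toy patch rows: two images with 3 and 2 patches respectively.
rows = [
    {"image_id": "UUID_0", "point_id": 0},
    {"image_id": "UUID_0", "point_id": 1},
    {"image_id": "UUID_0", "point_id": 2},
    {"image_id": "UUID_1", "point_id": 0},
    {"image_id": "UUID_1", "point_id": 1},
]

def group_by_image(rows):
    """Collect point_ids per image_id, mimicking groupby().map_groups()."""
    groups = defaultdict(list)
    for row in rows:
        groups[row["image_id"]].append(row["point_id"])
    return list(groups.values())

# Expected semantics: group first, then limit -> one group with all 3 patches.
expected = group_by_image(rows)[:1]
# Observed semantics: limit pushed before the grouping -> one group with 1 patch.
observed = group_by_image(rows[:1])

print(expected)  # [[0, 1, 2]]
print(observed)  # [[0]]
```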
Versions / Dependencies
Ray 2.53.0
Reproduction script
The following script outputs:
Image ID: UUID_2 Points: (1, 3) Features: (1, 128) Labels: (1,)
Image ID: UUID_2 Points: (18, 3) Features: (18, 128) Labels: (18,)
The first output is from the use of take(1) exhibiting the issue, while the second output is the expected result.
"""Ray Data pipeline to test and debug ray.data.Dataset.groupby().map_groups()."""
# %%
# Imports
import numpy as np
import ray.data
from ray.data.block import DataBatch
# %%
# Initialize Ray
if ray.is_initialized():
    ray.shutdown()
ray.init(
    runtime_env={
        "env_vars": {
            "RAY_DEBUG": "0",
            "RAY_DEBUG_POST_MORTEM": "0",
        },
    }
)
# %%
# Dummy data source
image_ds = ray.data.from_items([{"image_id": f"UUID_{i}"} for i in range(10)])
# %%
# Extract image patches
def extract_patches(row: dict) -> list[dict]:
    """Dummy function to simulate extraction of multiple patches from each image (1-to-N flat_map)."""
    image_id = row["image_id"]
    n_points = np.random.randint(10, 20)
    point_ids = np.arange(n_points)
    points = np.random.rand(n_points, 3)
    labels = np.random.randint(0, 2, size=n_points)
    patches = [
        {
            "image_id": image_id,
            "point_id": point_id,
            "point": point,
            "label": label,
        }
        for point_id, point, label in zip(point_ids, points, labels)
    ]
    return patches
patches_ds = image_ds.flat_map(extract_patches, num_cpus=0.01)
# %%
# Process image patches
def process_patches(batch: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Dummy function to simulate processing of image patches (e.g. feature extraction)."""
    batch_size = batch["point"].shape[0]
    batch["feats"] = np.random.rand(batch_size, 128)
    return batch

processed_patches_ds = patches_ds.map_batches(
    process_patches,  # type: ignore[arg-type]
    batch_format="numpy",
    batch_size=64,
    num_cpus=0.01,
)
# %%
# Collect image features
def agg_grouped_features(group: DataBatch) -> DataBatch:
    """Aggregate patch features for each image into one matrix."""
    image_id = np.unique(group["image_id"])
    if len(image_id) != 1:
        raise ValueError(f"Expected exactly one unique image_id in group, found {image_id}")
    point_ids = group["point_id"]
    n_points = point_ids.shape[0]
    if not np.array_equal(np.sort(point_ids), np.arange(n_points)):
        raise ValueError(f"Expected point_ids of group to contain all indices from 0 to {n_points - 1}")
    order = np.argsort(point_ids)
    points = group["point"][order]
    feats = group["feats"][order]
    labels = group["label"][order]
    return {
        "image_id": image_id,
        "points": np.expand_dims(points, axis=0),
        "feats": np.expand_dims(feats, axis=0),
        "labels": np.expand_dims(labels, axis=0),
    }

features_ds = processed_patches_ds.groupby("image_id").map_groups(
    agg_grouped_features,
    batch_format="numpy",
    num_cpus=0.01,
)
# Uncomment the below line to avoid the issue.
# features_ds = features_ds.materialize()
# %%
# Take samples from final dataset
#
# FIXME: The limit() operator added by take() is moved before the map_groups() operation,
# causing the wrong number of rows to be grouped. Instead of obtaining 1 item with the
# features from all patches of that image, only the features of 1 patch are returned.
#
# This is not the case when using features_ds.take_all() or after materializing the dataset first.
samples = features_ds.take_all()
sample = features_ds.take(1)[0]
print(
    # fmt: off
    "Image ID:", sample["image_id"],
    "Points:", sample["points"].shape,
    "Features:", sample["feats"].shape,
    "Labels:", sample["labels"].shape,
    # fmt: on
)
print(
    # fmt: off
    "Image ID:", samples[0]["image_id"],
    "Points:", samples[0]["points"].shape,
    "Features:", samples[0]["feats"].shape,
    "Labels:", samples[0]["labels"].shape,
    # fmt: on
)
Issue Severity
None