Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions kauldron/contrib/evals/checkpointed_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
# from kauldron import data
from kauldron.evals import evaluators
from kauldron.train import auxiliaries
from kauldron.train import metric_writer
from kauldron.train import train_step
from kauldron.train import trainer_lib
from kauldron.utils import utils
Expand Down Expand Up @@ -174,12 +175,11 @@ def evaluate(
f"Dataset for eval {self.name!r} did not yield any elements."
)

self.writer.write_step_metrics(
values = metric_writer.prepare_step_metrics(
step=step,
aux=merged_aux,
schedules={},
log_summaries=True,
)
self.writer.write_step_metrics(step=step, values=values)

# Wait for the last checkpoint to be saved before completing the evaluation.
self.checkpointer.wait_until_finished()
Expand Down
5 changes: 2 additions & 3 deletions kauldron/evals/evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,12 +262,11 @@ def evaluate(
f'{epy.pretty_repr(self.ds)}'
)

self.writer.write_step_metrics(
values = metric_writer.prepare_step_metrics(
step=step,
aux=merged_aux,
schedules={},
log_summaries=True,
)
self.writer.write_step_metrics(step=step, values=values)
return merged_aux

@functools.partial(
Expand Down
235 changes: 137 additions & 98 deletions kauldron/train/metric_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,109 +163,46 @@ def write_context_structure(
) -> None:
"""Write the context structure."""

# TODO(b/378060021): tidy this function up after old summaries are removed
def write_step_metrics(
self,
*,
step: int,
aux: auxiliaries.AuxiliariesState,
schedules: Mapping[str, optax.Schedule],
log_summaries: bool,
timer: Optional[chrono_utils.Chrono] = None,
values: dict[str, Any] | None = None,
**legacy_kwargs,
) -> None:
"""Logs scalar and image summaries."""
aux_result = aux.compute(flatten=True)

if not status.is_lead_host:
return

# schedules
schedule_values = jax.tree.map(
lambda s: _compute_schedule(s, step), schedules
)
schedule_values = kontext.flatten_with_path(
schedule_values, prefix="schedules", separator="/"
)

if timer:
performance_stats = {
f"perf_stats/{k}": v for k, v in timer.flush_metrics().items()
}
else:
performance_stats = {}
self.write_scalars(
step=step,
scalars=(
aux_result.loss_values
| aux_result.metric_values
| schedule_values
| performance_stats
),
)

if log_summaries:
image_summaries = {
name: value
for name, value in aux_result.summary_values.items()
if isinstance(value, Float["n h w #3"])
}
# Throw an error if empty arrays are given. TB throws very odd errors
# and kills Colab runtimes if we don't catch these ourselves.
for name, image in image_summaries.items():
if image.size == 0:
raise ValueError(
f"Image summary `{name}` is empty array of shape {image.shape}."
)
self.write_images(step=step, images=image_summaries)

# histograms
hist_summaries = {
name: value
for name, value in aux_result.summary_values.items()
if isinstance(value, summaries.Histogram)
}
self.write_histograms(
step=step,
arrays={k: hist.tensor for k, hist in hist_summaries.items()},
num_buckets={
k: hist.num_buckets for k, hist in hist_summaries.items()
},
"""Write pre-computed metrics, dispatching by value type.

Args:
step: The current training/eval step.
values: A flat ``{name: value}`` dict. Each value is dispatched to the
appropriate ``write_*`` method based on its type (scalar, image,
histogram, text, point cloud).
**legacy_kwargs: Deprecated. The old ``aux``, ``schedules``,
``log_summaries``, ``timer`` keyword arguments. These trigger
``prepare_step_metrics`` internally for backwards compatibility.
"""
if values is not None and legacy_kwargs:
raise TypeError(
"Cannot mix 'values' with legacy keyword arguments "
"(aux, schedules, log_summaries, timer)."
)

# point clouds
pc_summaries = {
name: value
for name, value in aux_result.summary_values.items()
if isinstance(value, summaries.PointCloud)
}
self.write_pointcloud(
step=step,
point_clouds={
k: point_cloud.point_clouds
for k, point_cloud in pc_summaries.items()
},
point_colors={
k: point_cloud.point_colors
for k, point_cloud in pc_summaries.items()
},
configs={
k: point_cloud.configs for k, point_cloud in pc_summaries.items()
},
if values is None:
if not legacy_kwargs:
raise TypeError(
"write_step_metrics() requires either 'values' or legacy keyword "
"arguments (aux, schedules, log_summaries, timer)."
)
logging.warning(
"Deprecated: write_step_metrics() with (aux, schedules, "
"log_summaries, timer) keyword arguments is deprecated. "
"Use prepare_step_metrics() + write_step_metrics(values=...) instead."
)
values = prepare_step_metrics(step=step, **legacy_kwargs)

# Text summaries
text_summaries = {
name: value
for name, value in aux_result.summary_values.items()
if isinstance(value, str) or enp.is_array_str(value)
}
self.write_texts(
step=step,
texts={k: text for k, text in text_summaries.items()},
)
if not status.is_lead_host:
return

# TODO(epot): This is blocking and slow. Is it really required ?
# Should likely be only called once at the end of the training / eval.
_write_values_by_type(self, step=step, values=values)
self.flush()

def flush(self) -> None:
Expand Down Expand Up @@ -593,10 +530,8 @@ def write_step_metrics(
self,
*,
step: int,
aux: auxiliaries.AuxiliariesState,
schedules: Mapping[str, optax.Schedule],
log_summaries: bool,
timer: Optional[chrono_utils.Chrono] = None,
values: dict[str, Any] | None = None,
**legacy_kwargs,
) -> None:
pass

Expand All @@ -619,3 +554,107 @@ def _compute_schedule(sched: optax.Schedule, step: int):
"""Evaluate schedule for step and return result."""
with jax.transfer_guard("allow"):
return sched(step)


def prepare_step_metrics(
    *,
    step: int,
    aux: auxiliaries.AuxiliariesState,
    schedules: Mapping[str, optax.Schedule] | None = None,
    timer: Optional[chrono_utils.Chrono] = None,
    log_summaries: bool = True,
) -> dict[str, Any]:
  """Flatten losses, metrics, schedules and summaries into one writable dict.

  Extracts the compute/schedule/timer logic that previously lived inside
  ``write_step_metrics``. Call it first, then pass the result to
  ``writer.write_step_metrics(step=..., values=...)``.

  Args:
    step: Current step (used for schedule evaluation).
    aux: The accumulated auxiliaries state.
    schedules: Optax schedules to evaluate at ``step``.
    timer: Optional chrono timer whose ``flush_metrics`` will be included.
    log_summaries: If ``False``, summary values (images, histograms, etc.) are
      excluded from the returned dict.

  Returns:
    Flat ``{name: value}`` dict ready for ``write_step_metrics``.
  """
  computed = aux.compute(flatten=True)

  # Evaluate every schedule at the current step and flatten under "schedules/".
  evaluated_schedules = kontext.flatten_with_path(
      jax.tree.map(lambda s: _compute_schedule(s, step), schedules or {}),
      prefix="schedules",
      separator="/",
  )

  out: dict[str, Any] = {}
  out.update(computed.loss_values)
  out.update(computed.metric_values)
  out.update(evaluated_schedules)

  if timer:
    for key, val in timer.flush_metrics().items():
      out[f"perf_stats/{key}"] = val

  if log_summaries:
    out.update(computed.summary_values)

  return out


def _write_values_by_type(
    writer: WriterBase,
    *,
    step: int,
    values: dict[str, Any],
) -> None:
  """Route each value to the matching ``writer.write_*`` call by leaf type.

  Args:
    writer: Destination writer whose ``write_scalars`` / ``write_images`` /
      ``write_histograms`` / ``write_pointcloud`` / ``write_texts`` methods
      are invoked (only for the categories that are non-empty).
    step: Step passed through to every ``write_*`` call.
    values: Flat ``{name: value}`` dict; any value that matches no other
      category is treated as a scalar.
  """
  scalar_vals: dict[str, Scalar] = {}
  image_vals: dict[str, Array] = {}
  histograms: dict[str, summaries.Histogram] = {}
  point_clouds: dict[str, summaries.PointCloud] = {}
  text_vals: dict[str, str] = {}

  # Bucket each entry; dispatch order matters (e.g. histograms before the
  # generic scalar fallback).
  for key, val in values.items():
    if isinstance(val, summaries.Histogram):
      histograms[key] = val
    elif isinstance(val, summaries.PointCloud):
      point_clouds[key] = val
    elif isinstance(val, str) or enp.is_array_str(val):
      text_vals[key] = val
    elif isinstance(val, Float["n h w #3"]):
      # TB produces very odd errors (and can kill Colab runtimes) on empty
      # image arrays, so fail loudly here instead.
      if val.size == 0:
        raise ValueError(
            f"Image summary `{key}` is empty array of shape {val.shape}."
        )
      image_vals[key] = val
    else:
      scalar_vals[key] = val

  if scalar_vals:
    writer.write_scalars(step=step, scalars=scalar_vals)
  if image_vals:
    writer.write_images(step=step, images=image_vals)
  if histograms:
    writer.write_histograms(
        step=step,
        arrays={k: h.tensor for k, h in histograms.items()},
        num_buckets={k: h.num_buckets for k, h in histograms.items()},
    )
  if point_clouds:
    writer.write_pointcloud(
        step=step,
        point_clouds={k: p.point_clouds for k, p in point_clouds.items()},
        point_colors={k: p.point_colors for k, p in point_clouds.items()},
        configs={k: p.configs for k, p in point_clouds.items()},
    )
  if text_vals:
    writer.write_texts(step=step, texts=text_vals)
Loading
Loading