Skip to content

Commit 05e7efd

Browse files
sampan-s-nayak and sampan authored
[Core] Optimize open telemetry metric recording calls (#59337)
## Description

This PR introduces the following optimizations in the `OpenTelemetryMetricRecorder` and some of its consumers:

- Use asynchronous (observable) instruments wherever available (counter and up-down counter).
- Introduce a batch API to record histogram metrics (to prevent lock contention caused by repeated `set_metric_value()` calls).
- Batch the events-received metric update in `aggregator_agent` instead of making individual calls.

---------

Signed-off-by: sampan <sampan@anyscale.com>
Co-authored-by: sampan <sampan@anyscale.com>
1 parent 2015205 commit 05e7efd

File tree

6 files changed

+322
-102
lines changed

6 files changed

+322
-102
lines changed

python/ray/_private/telemetry/metric_cardinality.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@
22
from typing import Callable, Dict, List
33

44
from ray._private.ray_constants import RAY_METRIC_CARDINALITY_LEVEL
5+
from ray._private.telemetry.metric_types import MetricType
56

67
# Keep in sync with the WorkerIdKey in src/ray/stats/tag_defs.cc
78
WORKER_ID_TAG_KEY = "WorkerId"
89
# Keep in sync with the NameKey in src/ray/stats/tag_defs.cc
910
TASK_OR_ACTOR_NAME_TAG_KEY = "Name"
10-
HIGH_CARDINALITY_METRICS_TO_AGGREGATION: Dict[str, Callable[[List[float]], float]] = {
11-
"tasks": lambda values: sum(values),
12-
"actors": lambda values: sum(values),
11+
# Aggregation functions for high-cardinality gauge metrics when labels are dropped.
12+
# Counter and Sum metrics always use sum() aggregation.
13+
HIGH_CARDINALITY_GAUGE_AGGREGATION: Dict[str, Callable[[List[float]], float]] = {
14+
"tasks": sum,
15+
"actors": sum,
1316
}
1417

1518
_CARDINALITY_LEVEL = None
@@ -45,14 +48,33 @@ def get_cardinality_level() -> "MetricCardinality":
4548
return _CARDINALITY_LEVEL
4649

4750
@staticmethod
48-
def get_aggregation_function(metric_name: str) -> Callable[[List[float]], float]:
49-
if metric_name in HIGH_CARDINALITY_METRICS_TO_AGGREGATION:
50-
return HIGH_CARDINALITY_METRICS_TO_AGGREGATION[metric_name]
51+
def get_aggregation_function(
52+
metric_name: str, metric_type: MetricType = MetricType.GAUGE
53+
) -> Callable[[List[float]], float]:
54+
"""Get the aggregation function for a metric when labels are dropped. This method does not currently support histogram metrics.
55+
56+
Args:
57+
metric_name: The name of the metric.
58+
metric_type: The type of the metric. If provided, Counter and Sum
59+
metrics always use sum() aggregation.
60+
61+
Returns:
62+
A function that takes a list of values and returns the aggregated value.
63+
"""
64+
# Counter and Sum metrics always aggregate by summing
65+
if metric_type in (MetricType.COUNTER, MetricType.SUM):
66+
return sum
67+
# Histogram metrics are not supported by this method
68+
if metric_type == MetricType.HISTOGRAM:
69+
raise ValueError("No Aggregation function for histogram metrics.")
70+
# Gauge metrics use metric-specific aggregation or default to first value
71+
if metric_name in HIGH_CARDINALITY_GAUGE_AGGREGATION:
72+
return HIGH_CARDINALITY_GAUGE_AGGREGATION[metric_name]
5173
return lambda values: values[0]
5274

5375
@staticmethod
5476
def get_high_cardinality_metrics() -> List[str]:
55-
return list(HIGH_CARDINALITY_METRICS_TO_AGGREGATION.keys())
77+
return list(HIGH_CARDINALITY_GAUGE_AGGREGATION.keys())
5678

5779
@staticmethod
5880
def get_high_cardinality_labels_to_drop(metric_name: str) -> List[str]:
python/ray/_private/telemetry/metric_types.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from enum import Enum
2+
3+
4+
class MetricType(Enum):
5+
"""Types of metrics supported by the telemetry system.
6+
7+
Note: SUMMARY metric type is not supported. SUMMARY is a Prometheus metric type
8+
that is not explicitly supported in OpenTelemetry. Use HISTOGRAM instead for
9+
similar use cases (e.g., latency distributions with quantiles).
10+
"""
11+
12+
GAUGE = 0
13+
COUNTER = 1
14+
SUM = 2
15+
HISTOGRAM = 3

python/ray/_private/telemetry/open_telemetry_metric_recorder.py

Lines changed: 154 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import threading
33
from collections import defaultdict
4-
from typing import List
4+
from typing import Callable, List
55

66
from opentelemetry import metrics
77
from opentelemetry.exporter.prometheus import PrometheusMetricReader
@@ -10,6 +10,7 @@
1010

1111
from ray._private.metrics_agent import Record
1212
from ray._private.telemetry.metric_cardinality import MetricCardinality
13+
from ray._private.telemetry.metric_types import MetricType
1314

1415
logger = logging.getLogger(__name__)
1516

@@ -29,11 +30,68 @@ class OpenTelemetryMetricRecorder:
2930
def __init__(self):
3031
self._lock = threading.Lock()
3132
self._registered_instruments = {}
32-
self._observations_by_name = defaultdict(dict)
33+
self._gauge_observations_by_name = defaultdict(dict)
34+
self._counter_observations_by_name = defaultdict(dict)
35+
self._sum_observations_by_name = defaultdict(dict)
3336
self._histogram_bucket_midpoints = defaultdict(list)
3437
self._init_metrics()
3538
self.meter = metrics.get_meter(__name__)
3639

40+
def _create_observable_callback(
    self, metric_name: str, metric_type: MetricType
) -> Callable[[dict], List[Observation]]:
    """
    Factory method to create callbacks for observable metrics.

    The returned callback takes a snapshot of the values stored for
    ``metric_name``, drops the high cardinality labels, and aggregates the
    values that end up sharing a label set into one ``Observation`` each.

    Args:
        metric_name: name of the metric for which the callback is being created
        metric_type: type of the metric for which the callback is being created.
            Only GAUGE, COUNTER and SUM have observation storage; for any other
            type the callback always returns an empty list.

    Returns:
        Callable: A callback function that can be used to record observations for the metric.
    """
    # The storage to read depends only on metric_type, which is fixed for the
    # lifetime of the callback, so resolve it once here instead of on every
    # collection cycle.
    storage_attr = {
        MetricType.GAUGE: "_gauge_observations_by_name",
        MetricType.COUNTER: "_counter_observations_by_name",
        MetricType.SUM: "_sum_observations_by_name",
    }.get(metric_type)
    # Gauges report the last value and are cleared after reading; counters and
    # sums are cumulative and must be kept.
    clear_after_read = metric_type == MetricType.GAUGE

    def callback(options):
        if storage_attr is None:
            # No observation storage exists for this metric type.
            return []

        # Hold the lock only long enough to take a consistent snapshot, so
        # that writers are not blocked while the snapshot is aggregated.
        with self._lock:
            storage = getattr(self, storage_attr)
            stored = storage.get(metric_name, {})
            if clear_after_read:
                # Take ownership of the current dict and install a fresh one.
                storage[metric_name] = {}
                observations = stored
            else:
                # The stored dict keeps being updated in place, so copy it
                # before it is read outside of the lock.
                observations = dict(stored)

            high_cardinality_labels = (
                MetricCardinality.get_high_cardinality_labels_to_drop(metric_name)
            )
            agg_fn = MetricCardinality.get_aggregation_function(
                metric_name, metric_type
            )

        # First, collect all values that share the same filtered tag set
        # (i.e. the tag set after dropping the high cardinality labels).
        values_by_filtered_tags = defaultdict(list)
        for tag_set, val in observations.items():
            filtered = frozenset(
                (k, v) for k, v in tag_set if k not in high_cardinality_labels
            )
            values_by_filtered_tags[filtered].append(val)

        # Then aggregate each group using the appropriate aggregation function.
        return [
            Observation(agg_fn(values), attributes=dict(filtered))
            for filtered, values in values_by_filtered_tags.items()
        ]

    return callback
94+
3795
def _init_metrics(self):
3896
# Initialize the global metrics provider and meter. We only do this once on
3997
# the first initialization of the class, because re-setting the meter provider
@@ -52,55 +110,19 @@ def register_gauge_metric(self, name: str, description: str) -> None:
52110
# Gauge with the same name is already registered.
53111
return
54112

55-
# Register ObservableGauge with a dynamic callback. Callbacks are special
56-
# features in OpenTelemetry that allow you to provide a function that will
57-
# compute the telemetry at collection time.
58-
def callback(options):
59-
# Take snapshot of current observations.
60-
with self._lock:
61-
observations = self._observations_by_name[name]
62-
# Clear the observations to avoid emitting dead observations.
63-
self._observations_by_name[name] = {}
64-
# Drop high cardinality from tag_set and sum up the value for
65-
# same tag set after dropping
66-
aggregated_observations = defaultdict(list)
67-
high_cardinality_labels = (
68-
MetricCardinality.get_high_cardinality_labels_to_drop(name)
69-
)
70-
for tag_set, val in observations.items():
71-
# Convert frozenset back to dict
72-
tags_dict = dict(tag_set)
73-
# Filter out high cardinality labels
74-
filtered_tags = {
75-
k: v
76-
for k, v in tags_dict.items()
77-
if k not in high_cardinality_labels
78-
}
79-
# Create a key for aggregation
80-
filtered_key = frozenset(filtered_tags.items())
81-
# Collect values for the same filtered tag set for aggregation
82-
aggregated_observations[filtered_key].append(val)
83-
84-
return [
85-
Observation(
86-
MetricCardinality.get_aggregation_function(name)(values),
87-
attributes=dict(tag_set),
88-
)
89-
for tag_set, values in aggregated_observations.items()
90-
]
91-
113+
callback = self._create_observable_callback(name, MetricType.GAUGE)
92114
instrument = self.meter.create_observable_gauge(
93115
name=f"{NAMESPACE}_{name}",
94116
description=description,
95117
unit="1",
96118
callbacks=[callback],
97119
)
98120
self._registered_instruments[name] = instrument
99-
self._observations_by_name[name] = {}
121+
self._gauge_observations_by_name[name] = {}
100122

101123
def register_counter_metric(self, name: str, description: str) -> None:
102124
"""
103-
Register a counter metric with the given name and description.
125+
Register an observable counter metric with the given name and description.
104126
"""
105127
with self._lock:
106128
if name in self._registered_instruments:
@@ -111,16 +133,19 @@ def register_counter_metric(self, name: str, description: str) -> None:
111133
# registered multiple times.
112134
return
113135

114-
instrument = self.meter.create_counter(
136+
callback = self._create_observable_callback(name, MetricType.COUNTER)
137+
instrument = self.meter.create_observable_counter(
115138
name=f"{NAMESPACE}_{name}",
116139
description=description,
117140
unit="1",
141+
callbacks=[callback],
118142
)
119143
self._registered_instruments[name] = instrument
144+
self._counter_observations_by_name[name] = {}
120145

121146
def register_sum_metric(self, name: str, description: str) -> None:
122147
"""
123-
Register a sum metric with the given name and description.
148+
Register an observable sum metric with the given name and description.
124149
"""
125150
with self._lock:
126151
if name in self._registered_instruments:
@@ -131,12 +156,15 @@ def register_sum_metric(self, name: str, description: str) -> None:
131156
# registered multiple times.
132157
return
133158

134-
instrument = self.meter.create_up_down_counter(
159+
callback = self._create_observable_callback(name, MetricType.SUM)
160+
instrument = self.meter.create_observable_up_down_counter(
135161
name=f"{NAMESPACE}_{name}",
136162
description=description,
137163
unit="1",
164+
callbacks=[callback],
138165
)
139166
self._registered_instruments[name] = instrument
167+
self._sum_observations_by_name[name] = {}
140168

141169
def register_histogram_metric(
142170
self, name: str, description: str, buckets: List[float]
@@ -187,35 +215,97 @@ def get_histogram_bucket_midpoints(self, name: str) -> List[float]:
187215

188216
def set_metric_value(self, name: str, tags: dict, value: float):
189217
"""
190-
Set the value of a metric with the given name and tags. If the metric is not
191-
registered, it lazily records the value for observable metrics or is a no-op for
218+
Set the value of a metric with the given name and tags.
219+
220+
For observable metrics (gauge, counter, sum), this stores the value internally
221+
and returns immediately. The value will be exported asynchronously when
222+
OpenTelemetry collects metrics.
223+
224+
For histograms, this calls record() synchronously since there is no observable
225+
histogram in OpenTelemetry.
226+
227+
If the metric is not registered, it lazily records the value for observable metrics or is a no-op for
192228
synchronous metrics.
193229
"""
194230
with self._lock:
195-
if self._observations_by_name.get(name) is not None:
196-
# Set the value of an observable metric with the given name and tags. It
197-
# lazily records the metric value by storing it in a dictionary until
198-
# the value actually gets exported by OpenTelemetry.
199-
self._observations_by_name[name][frozenset(tags.items())] = value
231+
tag_key = frozenset(tags.items())
232+
if self._gauge_observations_by_name.get(name) is not None:
233+
# Gauge - store the most recent value for the given tags.
234+
self._gauge_observations_by_name[name][tag_key] = value
235+
elif name in self._counter_observations_by_name:
236+
# Counter - increment the value for the given tags.
237+
self._counter_observations_by_name[name][tag_key] = (
238+
self._counter_observations_by_name[name].get(tag_key, 0) + value
239+
)
240+
elif name in self._sum_observations_by_name:
241+
# Sum - add the value for the given tags.
242+
self._sum_observations_by_name[name][tag_key] = (
243+
self._sum_observations_by_name[name].get(tag_key, 0) + value
244+
)
200245
else:
246+
# Histogram - record the value synchronously.
201247
instrument = self._registered_instruments.get(name)
202-
tags = {
203-
k: v
204-
for k, v in tags.items()
205-
if k
206-
not in MetricCardinality.get_high_cardinality_labels_to_drop(name)
207-
}
208-
if isinstance(instrument, metrics.Counter):
209-
instrument.add(value, attributes=tags)
210-
elif isinstance(instrument, metrics.UpDownCounter):
211-
instrument.add(value, attributes=tags)
212-
elif isinstance(instrument, metrics.Histogram):
213-
instrument.record(value, attributes=tags)
248+
if isinstance(instrument, metrics.Histogram):
249+
# Filter out high cardinality labels.
250+
filtered_tags = {
251+
k: v
252+
for k, v in tags.items()
253+
if k
254+
not in MetricCardinality.get_high_cardinality_labels_to_drop(
255+
name
256+
)
257+
}
258+
instrument.record(value, attributes=filtered_tags)
214259
else:
215260
logger.warning(
216-
f"Unsupported synchronous instrument type for metric: {name}."
261+
f"Metric {name} is not registered or unsupported type."
217262
)
218263

264+
def record_histogram_aggregated_batch(
    self,
    name: str,
    data_points: List[dict],
) -> None:
    """
    Record pre-aggregated histogram data for multiple data points in a single batch.

    This method takes pre-aggregated bucket counts and reconstructs individual
    observations using bucket midpoints. It acquires the lock once and performs
    all record() calls for ALL data points, minimizing lock contention.

    Note: The histogram sum value will be an approximation since we use bucket midpoints instead of actual values.

    Args:
        name: The name of the histogram metric. If it is not a registered
            histogram, a warning is logged and nothing is recorded.
        data_points: The data points to record. Each one is a dict with a
            "tags" entry (dict of tag name to tag value) and a "bucket_counts"
            entry (one count per bucket midpoint of the histogram).

    Raises:
        ValueError: If the number of bucket counts of a data point does not
            match the number of bucket midpoints of the histogram. The whole
            batch is validated before anything is recorded, so no data point
            is recorded in that case.
    """
    with self._lock:
        instrument = self._registered_instruments.get(name)
        if not isinstance(instrument, metrics.Histogram):
            logger.warning(
                f"Metric {name} is not a registered histogram, skipping recording."
            )
            return

        bucket_midpoints = self._histogram_bucket_midpoints[name]
        high_cardinality_labels = (
            MetricCardinality.get_high_cardinality_labels_to_drop(name)
        )

        # Validate the whole batch before recording anything. An explicit
        # check is used instead of `assert` so that it still runs under
        # `python -O`, and so that a malformed data point cannot leave a
        # partially recorded batch behind.
        prepared = []
        for dp in data_points:
            tags = dp["tags"]
            bucket_counts = dp["bucket_counts"]
            if len(bucket_counts) != len(bucket_midpoints):
                raise ValueError(
                    "Number of bucket counts and midpoints must match: got "
                    f"{len(bucket_counts)} bucket counts for "
                    f"{len(bucket_midpoints)} midpoints of metric {name}."
                )
            # Filter out high cardinality labels.
            filtered_tags = {
                k: v for k, v in tags.items() if k not in high_cardinality_labels
            }
            prepared.append((filtered_tags, bucket_counts))

        for filtered_tags, bucket_counts in prepared:
            for midpoint, bucket_count in zip(bucket_midpoints, bucket_counts):
                # Replay each bucket as `bucket_count` observations of its
                # midpoint; empty buckets record nothing.
                for _ in range(bucket_count):
                    instrument.record(midpoint, attributes=filtered_tags)
308+
219309
def record_and_export(self, records: List[Record], global_tags=None):
220310
"""
221311
Record a list of telemetry records and export them to Prometheus.

0 commit comments

Comments
 (0)