13 changes: 8 additions & 5 deletions litellm/llms/gemini/image_generation/transformation.py
@@ -92,12 +92,15 @@ def _transform_image_usage(self, usage_metadata: dict) -> ImageUsage:
         tokens_details = usage_metadata.get("promptTokensDetails", [])
         for details in tokens_details:
             if isinstance(details, dict):
-                modality = details.get("modality")
-                token_count = details.get("tokenCount", 0)
+                modality = str(details.get("modality", "")).upper()
+                raw_token_count = details.get(
+                    "tokenCount", details.get("token_count", 0)
+                )
+                token_count = raw_token_count if isinstance(raw_token_count, int) else 0
                 if modality == "TEXT":
-                    input_tokens_details.text_tokens = token_count
+                    input_tokens_details.text_tokens += token_count
                 elif modality == "IMAGE":
-                    input_tokens_details.image_tokens = token_count
+                    input_tokens_details.image_tokens += token_count
 
         return ImageUsage(
             input_tokens=usage_metadata.get("promptTokenCount", 0),
@@ -274,4 +277,4 @@ def transform_image_generation_response(
                 b64_json=prediction.get("bytesBase64Encoded", None),
                 url=None,  # Google AI returns base64, not URLs
             ))
-        return model_response
+        return model_response
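
For intuition, here is a minimal standalone sketch of the behavior this hunk fixes (the helper name and data are illustrative, not the litellm API): with duplicate IMAGE entries, the old per-modality assignment kept only the last count, whereas accumulation reports their sum; the token_count fallback mirrors the snake_case key the new code also accepts.

def count_modality(details: list, target: str) -> int:
    # Sum tokenCount across every entry of one modality, tolerating the
    # snake_case token_count variant and non-integer values.
    total = 0
    for d in details:
        if str(d.get("modality", "")).upper() != target:
            continue
        raw = d.get("tokenCount", d.get("token_count", 0))
        total += raw if isinstance(raw, int) else 0
    return total

details = [
    {"modality": "TEXT", "tokenCount": 10},
    {"modality": "IMAGE", "tokenCount": 90},
    {"modality": "IMAGE", "token_count": 100},
]
assert count_modality(details, "IMAGE") == 190  # old assignment logic reported 100
assert count_modality(details, "TEXT") == 10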
75 changes: 49 additions & 26 deletions litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py
@@ -1623,6 +1623,11 @@ def _calculate_usage( # noqa: PLR0915
         response_tokens: Optional[int] = None
         response_tokens_details: Optional[CompletionTokensDetailsWrapper] = None
         usage_metadata = completion_response["usageMetadata"]
+
+        def _get_token_count(detail: dict) -> int:
+            raw_token_count = detail.get("tokenCount", detail.get("token_count", 0))
+            return raw_token_count if isinstance(raw_token_count, int) else 0
+
         if "cachedContentTokenCount" in usage_metadata:
             cached_tokens = usage_metadata["cachedContentTokenCount"]
@@ -1632,10 +1637,16 @@ def _calculate_usage( # noqa: PLR0915
         if "responseTokensDetails" in usage_metadata:
             response_tokens_details = CompletionTokensDetailsWrapper()
             for detail in usage_metadata["responseTokensDetails"]:
-                if detail["modality"] == "TEXT":
-                    response_tokens_details.text_tokens = detail.get("tokenCount", 0)
-                elif detail["modality"] == "AUDIO":
-                    response_tokens_details.audio_tokens = detail.get("tokenCount", 0)
+                modality = str(detail.get("modality", "")).upper()
+                token_count = _get_token_count(detail)
+                if modality == "TEXT":
+                    response_tokens_details.text_tokens = (
+                        response_tokens_details.text_tokens or 0
+                    ) + token_count
+                elif modality == "AUDIO":
+                    response_tokens_details.audio_tokens = (
+                        response_tokens_details.audio_tokens or 0
+                    ) + token_count
 
         #########################################################
 
@@ -1644,16 +1655,24 @@ def _calculate_usage( # noqa: PLR0915
             if response_tokens_details is None:
                 response_tokens_details = CompletionTokensDetailsWrapper()
             for detail in usage_metadata["candidatesTokensDetails"]:
-                modality = detail.get("modality")
-                token_count = detail.get("tokenCount", 0)
+                modality = str(detail.get("modality", "")).upper()
+                token_count = _get_token_count(detail)
                 if modality == "TEXT":
-                    response_tokens_details.text_tokens = token_count
+                    response_tokens_details.text_tokens = (
+                        response_tokens_details.text_tokens or 0
+                    ) + token_count
                 elif modality == "AUDIO":
-                    response_tokens_details.audio_tokens = token_count
+                    response_tokens_details.audio_tokens = (
+                        response_tokens_details.audio_tokens or 0
+                    ) + token_count
                 elif modality == "IMAGE":
-                    response_tokens_details.image_tokens = token_count
+                    response_tokens_details.image_tokens = (
+                        response_tokens_details.image_tokens or 0
+                    ) + token_count
                 elif modality == "VIDEO":
-                    response_tokens_details.video_tokens = token_count
+                    response_tokens_details.video_tokens = (
+                        response_tokens_details.video_tokens or 0
+                    ) + token_count
 
             # Calculate text_tokens if not explicitly provided in candidatesTokensDetails
             # candidatesTokenCount includes all modalities, so: text = total - (image + audio + video)
@@ -1677,14 +1696,16 @@ def _calculate_usage( # noqa: PLR0915
         ## Parse promptTokensDetails (total tokens by modality, includes cached + non-cached)
         if "promptTokensDetails" in usage_metadata:
             for detail in usage_metadata["promptTokensDetails"]:
-                if detail["modality"] == "AUDIO":
-                    prompt_audio_tokens = detail.get("tokenCount", 0)
-                elif detail["modality"] == "TEXT":
-                    prompt_text_tokens = detail.get("tokenCount", 0)
-                elif detail["modality"] == "IMAGE":
-                    prompt_image_tokens = detail.get("tokenCount", 0)
-                elif detail["modality"] == "VIDEO":
-                    prompt_video_tokens = detail.get("tokenCount", 0)
+                modality = str(detail.get("modality", "")).upper()
+                token_count = _get_token_count(detail)
+                if modality == "AUDIO":
+                    prompt_audio_tokens = (prompt_audio_tokens or 0) + token_count
+                elif modality == "TEXT":
+                    prompt_text_tokens = (prompt_text_tokens or 0) + token_count
+                elif modality == "IMAGE":
+                    prompt_image_tokens = (prompt_image_tokens or 0) + token_count
+                elif modality == "VIDEO":
+                    prompt_video_tokens = (prompt_video_tokens or 0) + token_count
 
         ## Parse cacheTokensDetails (breakdown of cached tokens by modality)
         ## When explicit caching is used, Gemini provides this field to show which modalities were cached
@@ -1695,14 +1716,16 @@ def _calculate_usage( # noqa: PLR0915
 
         if "cacheTokensDetails" in usage_metadata:
             for detail in usage_metadata["cacheTokensDetails"]:
-                if detail["modality"] == "AUDIO":
-                    cached_audio_tokens = detail.get("tokenCount", 0)
-                elif detail["modality"] == "TEXT":
-                    cached_text_tokens = detail.get("tokenCount", 0)
-                elif detail["modality"] == "IMAGE":
-                    cached_image_tokens = detail.get("tokenCount", 0)
-                elif detail["modality"] == "VIDEO":
-                    cached_video_tokens = detail.get("tokenCount", 0)
+                modality = str(detail.get("modality", "")).upper()
+                token_count = _get_token_count(detail)
+                if modality == "AUDIO":
+                    cached_audio_tokens = (cached_audio_tokens or 0) + token_count
+                elif modality == "TEXT":
+                    cached_text_tokens = (cached_text_tokens or 0) + token_count
+                elif modality == "IMAGE":
+                    cached_image_tokens = (cached_image_tokens or 0) + token_count
+                elif modality == "VIDEO":
+                    cached_video_tokens = (cached_video_tokens or 0) + token_count
 
         ## Calculate non-cached tokens by subtracting cached from total (per modality)
         ## This is necessary because promptTokensDetails includes both cached and non-cached tokens
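
To make the arithmetic in those comments concrete, here is a small worked example (illustrative numbers borrowed from the regression tests below; a sketch, not litellm code):

# 1) candidatesTokenCount spans all modalities, so when TEXT is not listed
#    explicitly: text = total - (image + audio + video).
candidates_total = 50
image_out, audio_out, video_out = 30, 0, 0
text_out = candidates_total - (image_out + audio_out + video_out)
assert text_out == 20

# 2) promptTokensDetails counts cached + non-cached tokens per modality, so
#    the non-cached share is the accumulated total minus the accumulated
#    cached breakdown.
prompt_image_total = 90 + 100  # two IMAGE entries, now summed
cached_image_total = 40 + 10   # two cached IMAGE entries, now summed
assert prompt_image_total - cached_image_total == 140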
55 changes: 55 additions & 0 deletions tests/llm_translation/test_gemini_image_usage.py
@@ -4,9 +4,11 @@
 This test verifies the fix for issue #18323 where image_generation()
 was returning usage=0 while completion() returned proper token usage.
 """
+import os
 import pytest
 from unittest.mock import patch, MagicMock
 import litellm
 from litellm.llms.gemini.image_generation.transformation import GoogleImageGenConfig
+from litellm.types.utils import ImageResponse, ImageObject, ImageUsage

@@ -211,3 +213,56 @@ def test_gemini_imagen_models_no_usage_extraction():
 
     # For Imagen models, we don't extract usage from the predictions format
     # This test just ensures we don't crash
+
+
+def test_gemini_image_generation_accumulates_multiple_image_prompt_token_details():
+    """
+    Regression test: promptTokensDetails can include multiple IMAGE entries.
+    These must be accumulated instead of overwritten.
+    """
+    previous_local_model_cost_map = os.environ.get("LITELLM_LOCAL_MODEL_COST_MAP")
+    previous_model_cost = litellm.model_cost
+    try:
+        os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+        litellm.model_cost = litellm.get_model_cost_map(url="")
+
+        model = "gemini/gemini-3-pro-image-preview"
+        config = GoogleImageGenConfig()
+
+        usage_metadata = {
+            "promptTokenCount": 200,
+            "candidatesTokenCount": 0,
+            "totalTokenCount": 200,
+            "promptTokensDetails": [
+                {"modality": "TEXT", "tokenCount": 10},
+                {"modality": "IMAGE", "tokenCount": 90},
+                {"modality": "IMAGE", "tokenCount": 100},
+            ],
+        }
+
+        parsed_usage = config._transform_image_usage(usage_metadata)
+        image_response = ImageResponse(
+            data=[ImageObject(b64_json="fake_image_data")],
+            usage=parsed_usage,
+        )
+
+        observed_cost = litellm.completion_cost(
+            completion_response=image_response,
+            model=model,
+            custom_llm_provider="gemini",
+        )
+
+        model_info = litellm.get_model_info(model=model, custom_llm_provider="gemini")
+        expected_image_tokens = 190
+        expected_total_prompt_tokens = 200
+        expected_prompt_cost = expected_total_prompt_tokens * model_info["input_cost_per_token"]
+
+        assert parsed_usage.input_tokens_details.image_tokens == expected_image_tokens
+        assert parsed_usage.input_tokens_details.text_tokens == 10
+        assert observed_cost == pytest.approx(expected_prompt_cost, rel=1e-12)
+    finally:
+        if previous_local_model_cost_map is None:
+            os.environ.pop("LITELLM_LOCAL_MODEL_COST_MAP", None)
+        else:
+            os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = previous_local_model_cost_map
+        litellm.model_cost = previous_model_cost
Contributor comment on lines +218 to +268:
Missing test for vertex _calculate_usage accumulation

This test covers the GoogleImageGenConfig._transform_image_usage path (image generation), but the PR also changed token accumulation logic in VertexGeminiConfig._calculate_usage across four loops (responseTokensDetails, candidatesTokensDetails, promptTokensDetails, cacheTokensDetails). None of those accumulation changes are covered by a test with duplicate modality entries.

Consider adding a test in tests/test_litellm/llms/vertex_ai/gemini/test_vertex_and_google_ai_studio_gemini.py that passes promptTokensDetails (or candidatesTokensDetails) with multiple entries of the same modality to verify the accumulation works there too.

tests/test_litellm/llms/vertex_ai/gemini/test_vertex_and_google_ai_studio_gemini.py
@@ -862,6 +862,42 @@ def test_vertex_ai_usage_metadata_with_image_tokens_in_prompt():
     )
 
 
+def test_vertex_ai_usage_metadata_accumulates_duplicate_modalities():
+    """Ensure _calculate_usage accumulates repeated modality entries."""
+    v = VertexGeminiConfig()
+    usage_metadata = {
+        "promptTokenCount": 210,
+        "candidatesTokenCount": 50,
+        "totalTokenCount": 260,
+        "promptTokensDetails": [
+            {"modality": "TEXT", "tokenCount": 20},
+            {"modality": "IMAGE", "tokenCount": 90},
+            {"modality": "IMAGE", "token_count": 100},
+        ],
+        "candidatesTokensDetails": [
+            {"modality": "IMAGE", "tokenCount": 30},
+            {"modality": "TEXT", "tokenCount": 15},
+            {"modality": "TEXT", "token_count": 5},
+        ],
+        "cacheTokensDetails": [
+            {"modality": "TEXT", "tokenCount": 4},
+            {"modality": "IMAGE", "tokenCount": 40},
+            {"modality": "IMAGE", "token_count": 10},
+        ],
+    }
+    usage_metadata = UsageMetadata(**usage_metadata)
+    result = v._calculate_usage(completion_response={"usageMetadata": usage_metadata})
+
+    # prompt details are total - cached per modality
+    assert result.prompt_tokens_details.text_tokens == 16  # 20 - 4
+    assert result.prompt_tokens_details.image_tokens == 140  # (90 + 100) - (40 + 10)
+
+    # candidates details accumulate duplicate modalities
+    assert result.completion_tokens_details.text_tokens == 20  # 15 + 5
+    assert result.completion_tokens_details.image_tokens == 30
+    assert result.completion_tokens == 50
 
 
 def test_vertex_ai_map_thinking_param_with_budget_tokens_0():
     """
     If budget_tokens is 0, do not set includeThoughts to True
@@ -3723,4 +3759,3 @@ def test_vertex_ai_usage_metadata_video_tokens_with_caching():
         "Prompt video tokens should be 10240 - 5120 (cached) = 5120"
     assert result.prompt_tokens_details.text_tokens == 9
     assert result.prompt_tokens_details.audio_tokens == 200
-