Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions src/genai_utils/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,50 +213,57 @@ def add_citations(response: types.GenerateContentResponse) -> str:
return text


def validate_labels(labels: dict[str, str]) -> None:
def validate_labels(labels: dict[str, str]) -> dict[str, str]:
"""
Validates labels for GCP requirements.
Validates labels for GCP requirements, removing any labels that would cause GCP to
return an error.

GCP label requirements:
- Keys must start with a lowercase letter
- Keys and values can only contain lowercase letters, numbers, hyphens, and underscores
- Keys and values must be max 63 characters
- Keys cannot be empty

Raises:
GeminiError: If labels don't meet GCP requirements
"""
label_pattern = re.compile(r"^[a-z0-9_-]{1,63}$")
key_start_pattern = re.compile(r"^[a-z]")

valid_labels: dict[str, str] = {}
for key, value in labels.items():
if not key:
raise GeminiError("Label keys cannot be empty")
_logger.warning("Label keys cannot be empty")
continue

if len(key) > 63:
raise GeminiError(
_logger.warning(
f"Label key '{key}' exceeds 63 characters (length: {len(key)})"
)
continue

if len(value) > 63:
raise GeminiError(
_logger.warning(
f"Label value for key '{key}' exceeds 63 characters (length: {len(value)})"
)
continue

if not key_start_pattern.match(key):
raise GeminiError(f"Label key '{key}' must start with a lowercase letter")
_logger.warning(f"Label key '{key}' must start with a lowercase letter")
continue

if not label_pattern.match(key):
raise GeminiError(
_logger.warning(
f"Label key '{key}' contains invalid characters. "
"Only lowercase letters, numbers, hyphens, and underscores are allowed"
)
continue

if not label_pattern.match(value):
raise GeminiError(
_logger.warning(
f"Label value '{value}' for key '{key}' contains invalid characters. "
"Only lowercase letters, numbers, hyphens, and underscores are allowed"
)
continue
valid_labels[key] = value
return valid_labels


def check_grounding_ran(response: types.GenerateContentResponse) -> bool:
Expand Down Expand Up @@ -543,8 +550,7 @@ class Movie(BaseModel):

if inline_citations and not use_grounding:
raise GeminiError("Inline citations only work if `use_grounding = True`")
merged_labels = DEFAULT_LABELS | labels
validate_labels(merged_labels)
merged_labels = validate_labels(DEFAULT_LABELS | labels)

response = await client.aio.models.generate_content(
model=model_config.model_name,
Expand Down
83 changes: 44 additions & 39 deletions tests/genai_utils/test_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from google.genai.models import Models

from genai_utils.gemini import (
GeminiError,
ModelConfig,
run_prompt_async,
validate_labels,
Expand All @@ -24,74 +23,74 @@ async def get_dummy():


def test_validate_labels_valid():
"""Test that valid labels pass validation"""
"""Test that valid labels pass validation and are returned"""
valid_labels = {
"team": "ai",
"project": "genai-utils",
"environment": "production",
"version": "1-2-3",
"my_label": "my_value",
}
# Should not raise any exception
validate_labels(valid_labels)
result = validate_labels(valid_labels)
assert result == valid_labels


def test_validate_labels_empty_key():
"""Test that empty keys are rejected"""
with pytest.raises(GeminiError, match="cannot be empty"):
validate_labels({"": "value"})
"""Test that empty keys are filtered out"""
result = validate_labels({"": "value", "valid": "label"})
assert result == {"valid": "label"}


def test_validate_labels_key_too_long():
"""Test that keys exceeding 63 characters are rejected"""
"""Test that keys exceeding 63 characters are filtered out"""
long_key = "a" * 64
with pytest.raises(GeminiError, match="exceeds 63 characters"):
validate_labels({long_key: "value"})
result = validate_labels({long_key: "value", "valid": "label"})
assert result == {"valid": "label"}


def test_validate_labels_value_too_long():
"""Test that values exceeding 63 characters are rejected"""
"""Test that values exceeding 63 characters are filtered out"""
long_value = "a" * 64
with pytest.raises(GeminiError, match="exceeds 63 characters"):
validate_labels({"key": long_value})
result = validate_labels({"key": long_value, "valid": "label"})
assert result == {"valid": "label"}


def test_validate_labels_key_starts_with_number():
"""Test that keys starting with numbers are rejected"""
with pytest.raises(GeminiError, match="must start with a lowercase letter"):
validate_labels({"1key": "value"})
"""Test that keys starting with numbers are filtered out"""
result = validate_labels({"1key": "value", "valid": "label"})
assert result == {"valid": "label"}


def test_validate_labels_key_starts_with_uppercase():
"""Test that keys starting with uppercase are rejected"""
with pytest.raises(GeminiError, match="must start with a lowercase letter"):
validate_labels({"Key": "value"})
"""Test that keys starting with uppercase are filtered out"""
result = validate_labels({"Key": "value", "valid": "label"})
assert result == {"valid": "label"}


@pytest.mark.parametrize(
"invalid_key", ["key@value", "key.name", "key$", "key with space", "key/name"]
)
def test_validate_labels_key_invalid_characters(invalid_key):
"""Test that keys with invalid characters are rejected"""
with pytest.raises(GeminiError, match="contains invalid characters"):
validate_labels({invalid_key: "value"})
"""Test that keys with invalid characters are filtered out"""
result = validate_labels({invalid_key: "value", "valid": "label"})
assert result == {"valid": "label"}


@pytest.mark.parametrize(
"invalid_value", ["value@", "value.txt", "value$", "value with space", "value/"]
)
def test_validate_labels_value_invalid_characters(invalid_value):
"""Test that values with invalid characters are rejected"""
with pytest.raises(GeminiError, match="contains invalid characters"):
validate_labels({"key": invalid_value})
"""Test that values with invalid characters are filtered out"""
result = validate_labels({"key": invalid_value, "valid": "label"})
assert result == {"valid": "label"}


def test_validate_labels_max_length_valid():
"""Test that keys and values at exactly 63 characters are valid"""
max_key = "a" * 63
max_value = "b" * 63
# Should not raise any exception
validate_labels({max_key: max_value})
result = validate_labels({max_key: max_value})
assert result == {max_key: max_value}


def test_validate_labels_valid_special_chars():
Expand All @@ -102,8 +101,8 @@ def test_validate_labels_valid_special_chars():
"my-key_name": "my-value_123",
"key123": "value456",
}
# Should not raise any exception
validate_labels(valid_labels)
result = validate_labels(valid_labels)
assert result == valid_labels


@patch("genai_utils.gemini.genai.Client")
Expand Down Expand Up @@ -137,7 +136,7 @@ async def test_run_prompt_with_valid_labels(mock_client):

@patch("genai_utils.gemini.genai.Client")
async def test_run_prompt_with_invalid_labels(mock_client):
"""Test that run_prompt rejects invalid labels"""
"""Test that run_prompt filters out invalid labels"""
client = Mock(Client)
models = Mock(Models)
async_client = Mock(AsyncClient)
Expand All @@ -147,16 +146,22 @@ async def test_run_prompt_with_invalid_labels(mock_client):
async_client.models = models
mock_client.return_value = client

invalid_labels = {"Invalid": "value"} # uppercase key
invalid_labels = {"Invalid": "value", "valid": "label"} # uppercase key is invalid

with pytest.raises(GeminiError, match="must start with a lowercase letter"):
await run_prompt_async(
"test prompt",
labels=invalid_labels,
model_config=ModelConfig(
project="project", location="location", model_name="model"
),
)
await run_prompt_async(
"test prompt",
labels=invalid_labels,
model_config=ModelConfig(
project="project", location="location", model_name="model"
),
)

# Verify the call was made with only valid labels
assert models.generate_content.called
call_kwargs = models.generate_content.call_args[1]
assert "config" in call_kwargs
# The invalid "Invalid" key should be filtered out
assert call_kwargs["config"].labels == {"valid": "label"}


@patch("genai_utils.gemini.genai.Client")
Expand Down