Skip to content

Commit 1176e43

Browse files
qiaodev, copybara-github
authored and committed
feat: Use ser_json_bytes and val_json_bytes in bytes type public interface
PiperOrigin-RevId: 715040448
1 parent 0e4b0e5 commit 1176e43

11 files changed

Lines changed: 432 additions & 104 deletions

google/genai/_api_client.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -574,13 +574,12 @@ def _verify_response(self, response_model: BaseModel):
574574
pass
575575

576576

577+
# TODO(b/389693448): Cleanup datetime hacks.
577578
class RequestJsonEncoder(json.JSONEncoder):
578579
"""Encode bytes as strings without modify its content."""
579580

580581
def default(self, o):
581-
if isinstance(o, bytes):
582-
return o.decode()
583-
elif isinstance(o, datetime.datetime):
582+
if isinstance(o, datetime.datetime):
584583
# This Zulu time format is used by the Vertex AI API and the test recorder
585584
# Using strftime works well, but we want to align with the replay encoder.
586585
# o.astimezone(datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%S.%fZ')

google/genai/_common.py

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,8 @@ class BaseModel(pydantic.BaseModel):
189189
extra='forbid',
190190
# This allows us to use arbitrary types in the model. E.g. PIL.Image.
191191
arbitrary_types_allowed=True,
192+
ser_json_bytes='base64',
193+
val_json_bytes='base64',
192194
)
193195

194196
@classmethod
@@ -200,7 +202,10 @@ def _from_response(
200202
# We will provide another mechanism to allow users to access these fields.
201203
_remove_extra_fields(cls, response)
202204
validated_response = cls.model_validate(response)
203-
return apply_base64_decoding_for_model(validated_response)
205+
return validated_response
206+
207+
def to_json_dict(self) -> dict[str, object]:
208+
return self.model_dump(exclude_none=True, mode='json')
204209

205210

206211
def timestamped_unique_name() -> str:
@@ -216,40 +221,21 @@ def timestamped_unique_name() -> str:
216221

217222
def apply_base64_encoding(data: dict[str, object]) -> dict[str, object]:
218223
"""Applies base64 encoding to bytes values in the given data."""
219-
return process_bytes_fields(data, encode=True)
220-
221-
222-
def apply_base64_decoding(data: dict[str, object]) -> dict[str, object]:
223-
"""Applies base64 decoding to bytes values in the given data."""
224-
return process_bytes_fields(data, encode=False)
225-
226-
227-
def apply_base64_decoding_for_model(data: BaseModel) -> BaseModel:
228-
d = data.model_dump(exclude_none=True)
229-
d = apply_base64_decoding(d)
230-
return data.model_validate(d)
231-
232-
233-
def process_bytes_fields(data: dict[str, object], encode=True) -> dict[str, object]:
234224
processed_data = {}
235225
if not isinstance(data, dict):
236226
return data
237227
for key, value in data.items():
238228
if isinstance(value, bytes):
239-
if encode:
240-
processed_data[key] = base64.b64encode(value)
241-
else:
242-
processed_data[key] = base64.b64decode(value)
229+
processed_data[key] = base64.urlsafe_b64encode(value).decode('ascii')
243230
elif isinstance(value, dict):
244-
processed_data[key] = process_bytes_fields(value, encode)
231+
processed_data[key] = apply_base64_encoding(value)
245232
elif isinstance(value, list):
246-
if encode and all(isinstance(v, bytes) for v in value):
247-
processed_data[key] = [base64.b64encode(v) for v in value]
248-
elif all(isinstance(v, bytes) for v in value):
249-
processed_data[key] = [base64.b64decode(v) for v in value]
233+
if all(isinstance(v, bytes) for v in value):
234+
processed_data[key] = [
235+
base64.urlsafe_b64encode(v).decode('ascii') for v in value
236+
]
250237
else:
251-
processed_data[key] = [process_bytes_fields(v, encode) for v in value]
238+
processed_data[key] = [apply_base64_encoding(v) for v in value]
252239
else:
253240
processed_data[key] = value
254241
return processed_data
255-

google/genai/_replay_api_client.py

Lines changed: 6 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from typing import Any, Literal, Optional, Union
2626

2727
import google.auth
28-
from pydantic import BaseModel
2928
from requests.exceptions import HTTPError
3029

3130
from . import errors
@@ -34,6 +33,7 @@
3433
from ._api_client import HttpRequest
3534
from ._api_client import HttpResponse
3635
from ._api_client import RequestJsonEncoder
36+
from ._common import BaseModel
3737

3838
def _redact_version_numbers(version_string: str) -> str:
3939
"""Redacts version numbers in the form x.y.z from a string."""
@@ -264,18 +264,9 @@ def close(self):
264264
replay_file_path = self._get_replay_file_path()
265265
os.makedirs(os.path.dirname(replay_file_path), exist_ok=True)
266266
with open(replay_file_path, 'w') as f:
267-
replay_session_dict = self.replay_session.model_dump()
268-
# Use for non-utf-8 bytes in image/video... output.
269-
for interaction in replay_session_dict['interactions']:
270-
segments = []
271-
for response in interaction['response']['sdk_response_segments']:
272-
segments.append(json.loads(json.dumps(
273-
response, cls=ResponseJsonEncoder
274-
)))
275-
interaction['response']['sdk_response_segments'] = segments
276267
f.write(
277268
json.dumps(
278-
replay_session_dict, indent=2, cls=RequestJsonEncoder
269+
self.replay_session.model_dump(mode='json'), indent=2, cls=ResponseJsonEncoder
279270
)
280271
)
281272
self.replay_session = None
@@ -376,15 +367,8 @@ def _verify_response(self, response_model: BaseModel):
376367
if isinstance(response_model, list):
377368
response_model = response_model[0]
378369
print('response_model: ', response_model.model_dump(exclude_none=True))
379-
actual = json.dumps(
380-
response_model.model_dump(exclude_none=True),
381-
cls=ResponseJsonEncoder,
382-
sort_keys=True,
383-
)
384-
expected = json.dumps(
385-
interaction.response.sdk_response_segments[self._sdk_response_index],
386-
sort_keys=True,
387-
)
370+
actual = response_model.model_dump(exclude_none=True, mode='json')
371+
expected = interaction.response.sdk_response_segments[self._sdk_response_index]
388372
assert (
389373
actual == expected
390374
), f'SDK response mismatch:\nActual: {actual}\nExpected: {expected}'
@@ -437,36 +421,12 @@ def upload_file(self, file_path: str, upload_url: str, upload_size: int):
437421
return self._build_response_from_replay(request).text
438422

439423

424+
# TODO(b/389693448): Cleanup datetime hacks.
440425
class ResponseJsonEncoder(json.JSONEncoder):
441426
"""The replay test json encoder for response.
442-
443-
We need RequestJsonEncoder and ResponseJsonEncoder because:
444-
1. In production, we only need RequestJsonEncoder to help json module
445-
to convert non-stringable and stringable types to json string. Especially
446-
for bytes type, the value of bytes field is encoded to base64 string so it
447-
is always stringable and the RequestJsonEncoder doesn't have to deal with
448-
utf-8 JSON broken issue.
449-
2. In replay test, we also need ResponseJsonEncoder to help json module
450-
convert non-stringable and stringable types to json string. But response
451-
object returned from SDK method is different from the request api_client
452-
sent to server. For the bytes type, there is no base64 string in response
453-
anymore, because SDK handles it internally. So bytes type in Response is
454-
non-stringable. The ResponseJsonEncoder uses different encoding
455-
strategy than the RequestJsonEncoder to deal with utf-8 JSON broken issue.
456427
"""
457428
def default(self, o):
458-
if isinstance(o, bytes):
459-
# Use base64.b64encode() to encode bytes to string so that the media bytes
460-
# fields are serializable.
461-
# o.decode(encoding='utf-8', errors='replace') doesn't work because it
462-
# uses a fixed error string `\ufffd` for all non-utf-8 characters,
463-
# which cannot be converted back to original bytes. And other languages
464-
# only have the original bytes to compare with.
465-
# Since we use base64.b64encoding() in replay test, a change that breaks
466-
# native bytes can be captured by
467-
# test_compute_tokens.py::test_token_bytes_deserialization.
468-
return base64.b64encode(o).decode(encoding='utf-8')
469-
elif isinstance(o, datetime.datetime):
429+
if isinstance(o, datetime.datetime):
470430
# dt.isoformat() prints "2024-11-15T23:27:45.624657+00:00"
471431
# but replay files want "2024-11-15T23:27:45.624657Z"
472432
if o.isoformat().endswith('+00:00'):

google/genai/_transformers.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,3 +476,17 @@ def t_tuning_job_status(
476476
return 'JOB_STATE_FAILED'
477477
else:
478478
return status
479+
480+
481+
# Some fields don't accept URL-safe base64 encoding.
# We shouldn't use this transformer if the backend adheres to the Cloud Type
# format https://cloud.google.com/docs/discovery/type-format.
# TODO(b/389133914): Remove the hack after the Vertex backend fixes the issue.
def t_bytes(api_client: '_api_client.ApiClient', data: bytes) -> str:
  """Encodes raw bytes as a base64 ASCII string for the target backend.

  Args:
    api_client: Client whose `vertexai` flag selects the base64 alphabet.
    data: The bytes to encode. A value that is not a `bytes` object is passed
      through untouched.

  Returns:
    Standard base64 text (`+` and `/`) when `api_client.vertexai` is truthy,
    URL-safe base64 text (`-` and `_`) otherwise, or `data` itself when it is
    not a `bytes` object.
  """
  if not isinstance(data, bytes):
    return data
  if api_client.vertexai:
    return base64.b64encode(data).decode('ascii')
  # The stdlib URL-safe encoder is `urlsafe_b64encode`.
  return base64.urlsafe_b64encode(data).decode('ascii')
492+

google/genai/models.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1547,7 +1547,11 @@ def _Image_to_mldev(
15471547
raise ValueError('gcs_uri parameter is not supported in Google AI.')
15481548

15491549
if getv(from_object, ['image_bytes']) is not None:
1550-
setv(to_object, ['bytesBase64Encoded'], getv(from_object, ['image_bytes']))
1550+
setv(
1551+
to_object,
1552+
['bytesBase64Encoded'],
1553+
t.t_bytes(api_client, getv(from_object, ['image_bytes'])),
1554+
)
15511555

15521556
return to_object
15531557

@@ -1562,7 +1566,11 @@ def _Image_to_vertex(
15621566
setv(to_object, ['gcsUri'], getv(from_object, ['gcs_uri']))
15631567

15641568
if getv(from_object, ['image_bytes']) is not None:
1565-
setv(to_object, ['bytesBase64Encoded'], getv(from_object, ['image_bytes']))
1569+
setv(
1570+
to_object,
1571+
['bytesBase64Encoded'],
1572+
t.t_bytes(api_client, getv(from_object, ['image_bytes'])),
1573+
)
15661574

15671575
return to_object
15681576

@@ -3193,7 +3201,11 @@ def _Image_from_mldev(
31933201
to_object = {}
31943202

31953203
if getv(from_object, ['bytesBase64Encoded']) is not None:
3196-
setv(to_object, ['image_bytes'], getv(from_object, ['bytesBase64Encoded']))
3204+
setv(
3205+
to_object,
3206+
['image_bytes'],
3207+
t.t_bytes(api_client, getv(from_object, ['bytesBase64Encoded'])),
3208+
)
31973209

31983210
return to_object
31993211

@@ -3208,7 +3220,11 @@ def _Image_from_vertex(
32083220
setv(to_object, ['gcs_uri'], getv(from_object, ['gcsUri']))
32093221

32103222
if getv(from_object, ['bytesBase64Encoded']) is not None:
3211-
setv(to_object, ['image_bytes'], getv(from_object, ['bytesBase64Encoded']))
3223+
setv(
3224+
to_object,
3225+
['image_bytes'],
3226+
t.t_bytes(api_client, getv(from_object, ['bytesBase64Encoded'])),
3227+
)
32123228

32133229
return to_object
32143230

google/genai/tests/client/test_json_encoder.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import datetime
12
# Copyright 2024 Google LLC
23
#
34
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,21 +20,26 @@
1920

2021

2122
def test_json_encoder():
22-
assert json.dumps({'key': 'value'}, cls=RequestJsonEncoder) == '{"key": "value"}'
23-
assert json.dumps({'key': b'value'}, cls=RequestJsonEncoder) == '{"key": "value"}'
23+
date_value = datetime.datetime.fromtimestamp(
24+
1736397612, tz=datetime.timezone.utc
25+
)
26+
assert (
27+
json.dumps({'key': date_value}, cls=RequestJsonEncoder)
28+
== '{"key": "2025-01-09T04:40:12Z"}'
29+
)
2430
assert (
25-
json.dumps({'nested': {'key': 'value'}}, cls=RequestJsonEncoder)
26-
== '{"nested": {"key": "value"}}'
31+
json.dumps({'nested': {'key': date_value}}, cls=RequestJsonEncoder)
32+
== '{"nested": {"key": "2025-01-09T04:40:12Z"}}'
2733
)
2834
assert (
29-
json.dumps({'nested': {'key': b'value'}}, cls=RequestJsonEncoder)
30-
== '{"nested": {"key": "value"}}'
35+
json.dumps({'nested': {'key': date_value}}, cls=RequestJsonEncoder)
36+
== '{"nested": {"key": "2025-01-09T04:40:12Z"}}'
3137
)
3238
assert (
33-
json.dumps({'list': ['value', 'value']}, cls=RequestJsonEncoder)
34-
== '{"list": ["value", "value"]}'
39+
json.dumps({'list': [date_value, date_value]}, cls=RequestJsonEncoder)
40+
== '{"list": ["2025-01-09T04:40:12Z", "2025-01-09T04:40:12Z"]}'
3541
)
3642
assert (
37-
json.dumps({'list': [b'value', b'value']}, cls=RequestJsonEncoder)
38-
== '{"list": ["value", "value"]}'
43+
json.dumps({'list': [date_value, date_value]}, cls=RequestJsonEncoder)
44+
== '{"list": ["2025-01-09T04:40:12Z", "2025-01-09T04:40:12Z"]}'
3945
)

google/genai/tests/live/test_live.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ async def test_async_session_send_realtime_input(
198198
api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
199199
)
200200
realtime_input = types.LiveClientRealtimeInput(
201-
media_chunks=[types.Blob(data='000000', mime_type='audio/pcm')]
201+
media_chunks=[types.Blob(data='MDAwMDAw', mime_type='audio/pcm')]
202202
)
203203
await session.send(input=realtime_input)
204204
mock_websocket.send.assert_called_once()

google/genai/tests/models/test_generate_content_part.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -309,9 +309,6 @@
309309
),
310310
],
311311
),
312-
# Base64 string is invalid input.
313-
exception_if_vertex='400',
314-
exception_if_mldev='400',
315312
),
316313
pytest_helper.TestTableItem(
317314
name='test_union_none_part',
@@ -613,15 +610,14 @@ def test_from_function_call_response(client):
613610

614611
@pytest.mark.asyncio
615612
async def test_image_base64_stream_async(client):
616-
with pytest.raises(errors.ClientError):
617-
async for part in client.aio.models.generate_content_stream(
618-
model='gemini-1.5-flash-001',
619-
contents=[
620-
'What is this image about?',
621-
{'inline_data': {'data': image_string, 'mimeType': 'image/png'}},
622-
],
623-
):
624-
pass
613+
async for part in client.aio.models.generate_content_stream(
614+
model='gemini-1.5-flash-001',
615+
contents=[
616+
'What is this image about?',
617+
{'inline_data': {'data': image_string, 'mimeType': 'image/png'}},
618+
],
619+
):
620+
pass
625621

626622

627623
# function_call and function_response are tested in generate_content_tools.py

google/genai/tests/pytest_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def setup(
167167
# exclude_unset=True is needed to avoid warnings.
168168
# See https://github.com/pydantic/pydantic/issues/6467.
169169
json.dumps(
170-
test_table_file.model_dump(exclude_unset=True, by_alias=True),
170+
test_table_file.model_dump(exclude_unset=True, by_alias=True, mode='json'),
171171
indent=2,
172172
cls=ResponseJsonEncoder,
173173
)

0 commit comments

Comments
 (0)