Skip to content

Commit b07c549

Browse files
sararobcopybara-github
authored andcommitted
feat: support new fields in FileData, GenerationConfig, GroundingChunkRetrievedContext, RetrievalConfig, Schema, TuningJob, VertexAISearch,
PiperOrigin-RevId: 761989390
1 parent 38acaed commit b07c549

9 files changed

Lines changed: 929 additions & 183 deletions

File tree

google/genai/_live_converters.py

Lines changed: 106 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,42 @@ def _Blob_to_vertex(
281281
return to_object
282282

283283

284+
def _FileData_to_mldev(
285+
api_client: BaseApiClient,
286+
from_object: Union[dict[str, Any], object],
287+
parent_object: Optional[dict[str, Any]] = None,
288+
) -> dict[str, Any]:
289+
to_object: dict[str, Any] = {}
290+
if getv(from_object, ['display_name']) is not None:
291+
raise ValueError('display_name parameter is not supported in Gemini API.')
292+
293+
if getv(from_object, ['file_uri']) is not None:
294+
setv(to_object, ['fileUri'], getv(from_object, ['file_uri']))
295+
296+
if getv(from_object, ['mime_type']) is not None:
297+
setv(to_object, ['mimeType'], getv(from_object, ['mime_type']))
298+
299+
return to_object
300+
301+
302+
def _FileData_to_vertex(
303+
api_client: BaseApiClient,
304+
from_object: Union[dict[str, Any], object],
305+
parent_object: Optional[dict[str, Any]] = None,
306+
) -> dict[str, Any]:
307+
to_object: dict[str, Any] = {}
308+
if getv(from_object, ['display_name']) is not None:
309+
setv(to_object, ['displayName'], getv(from_object, ['display_name']))
310+
311+
if getv(from_object, ['file_uri']) is not None:
312+
setv(to_object, ['fileUri'], getv(from_object, ['file_uri']))
313+
314+
if getv(from_object, ['mime_type']) is not None:
315+
setv(to_object, ['mimeType'], getv(from_object, ['mime_type']))
316+
317+
return to_object
318+
319+
284320
def _Part_to_mldev(
285321
api_client: BaseApiClient,
286322
from_object: Union[dict[str, Any], object],
@@ -308,6 +344,15 @@ def _Part_to_mldev(
308344
),
309345
)
310346

347+
if getv(from_object, ['file_data']) is not None:
348+
setv(
349+
to_object,
350+
['fileData'],
351+
_FileData_to_mldev(
352+
api_client, getv(from_object, ['file_data']), to_object
353+
),
354+
)
355+
311356
if getv(from_object, ['code_execution_result']) is not None:
312357
setv(
313358
to_object,
@@ -318,9 +363,6 @@ def _Part_to_mldev(
318363
if getv(from_object, ['executable_code']) is not None:
319364
setv(to_object, ['executableCode'], getv(from_object, ['executable_code']))
320365

321-
if getv(from_object, ['file_data']) is not None:
322-
setv(to_object, ['fileData'], getv(from_object, ['file_data']))
323-
324366
if getv(from_object, ['function_call']) is not None:
325367
setv(to_object, ['functionCall'], getv(from_object, ['function_call']))
326368

@@ -364,6 +406,15 @@ def _Part_to_vertex(
364406
),
365407
)
366408

409+
if getv(from_object, ['file_data']) is not None:
410+
setv(
411+
to_object,
412+
['fileData'],
413+
_FileData_to_vertex(
414+
api_client, getv(from_object, ['file_data']), to_object
415+
),
416+
)
417+
367418
if getv(from_object, ['code_execution_result']) is not None:
368419
setv(
369420
to_object,
@@ -374,9 +425,6 @@ def _Part_to_vertex(
374425
if getv(from_object, ['executable_code']) is not None:
375426
setv(to_object, ['executableCode'], getv(from_object, ['executable_code']))
376427

377-
if getv(from_object, ['file_data']) is not None:
378-
setv(to_object, ['fileData'], getv(from_object, ['file_data']))
379-
380428
if getv(from_object, ['function_call']) is not None:
381429
setv(to_object, ['functionCall'], getv(from_object, ['function_call']))
382430

@@ -2653,6 +2701,40 @@ def _Blob_from_vertex(
26532701
return to_object
26542702

26552703

2704+
def _FileData_from_mldev(
2705+
api_client: BaseApiClient,
2706+
from_object: Union[dict[str, Any], object],
2707+
parent_object: Optional[dict[str, Any]] = None,
2708+
) -> dict[str, Any]:
2709+
to_object: dict[str, Any] = {}
2710+
2711+
if getv(from_object, ['fileUri']) is not None:
2712+
setv(to_object, ['file_uri'], getv(from_object, ['fileUri']))
2713+
2714+
if getv(from_object, ['mimeType']) is not None:
2715+
setv(to_object, ['mime_type'], getv(from_object, ['mimeType']))
2716+
2717+
return to_object
2718+
2719+
2720+
def _FileData_from_vertex(
2721+
api_client: BaseApiClient,
2722+
from_object: Union[dict[str, Any], object],
2723+
parent_object: Optional[dict[str, Any]] = None,
2724+
) -> dict[str, Any]:
2725+
to_object: dict[str, Any] = {}
2726+
if getv(from_object, ['displayName']) is not None:
2727+
setv(to_object, ['display_name'], getv(from_object, ['displayName']))
2728+
2729+
if getv(from_object, ['fileUri']) is not None:
2730+
setv(to_object, ['file_uri'], getv(from_object, ['fileUri']))
2731+
2732+
if getv(from_object, ['mimeType']) is not None:
2733+
setv(to_object, ['mime_type'], getv(from_object, ['mimeType']))
2734+
2735+
return to_object
2736+
2737+
26562738
def _Part_from_mldev(
26572739
api_client: BaseApiClient,
26582740
from_object: Union[dict[str, Any], object],
@@ -2680,6 +2762,15 @@ def _Part_from_mldev(
26802762
),
26812763
)
26822764

2765+
if getv(from_object, ['fileData']) is not None:
2766+
setv(
2767+
to_object,
2768+
['file_data'],
2769+
_FileData_from_mldev(
2770+
api_client, getv(from_object, ['fileData']), to_object
2771+
),
2772+
)
2773+
26832774
if getv(from_object, ['codeExecutionResult']) is not None:
26842775
setv(
26852776
to_object,
@@ -2690,9 +2781,6 @@ def _Part_from_mldev(
26902781
if getv(from_object, ['executableCode']) is not None:
26912782
setv(to_object, ['executable_code'], getv(from_object, ['executableCode']))
26922783

2693-
if getv(from_object, ['fileData']) is not None:
2694-
setv(to_object, ['file_data'], getv(from_object, ['fileData']))
2695-
26962784
if getv(from_object, ['functionCall']) is not None:
26972785
setv(to_object, ['function_call'], getv(from_object, ['functionCall']))
26982786

@@ -2736,6 +2824,15 @@ def _Part_from_vertex(
27362824
),
27372825
)
27382826

2827+
if getv(from_object, ['fileData']) is not None:
2828+
setv(
2829+
to_object,
2830+
['file_data'],
2831+
_FileData_from_vertex(
2832+
api_client, getv(from_object, ['fileData']), to_object
2833+
),
2834+
)
2835+
27392836
if getv(from_object, ['codeExecutionResult']) is not None:
27402837
setv(
27412838
to_object,
@@ -2746,9 +2843,6 @@ def _Part_from_vertex(
27462843
if getv(from_object, ['executableCode']) is not None:
27472844
setv(to_object, ['executable_code'], getv(from_object, ['executableCode']))
27482845

2749-
if getv(from_object, ['fileData']) is not None:
2750-
setv(to_object, ['file_data'], getv(from_object, ['fileData']))
2751-
27522846
if getv(from_object, ['functionCall']) is not None:
27532847
setv(to_object, ['function_call'], getv(from_object, ['functionCall']))
27542848

google/genai/_tokens_converters.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,42 @@ def _Blob_to_vertex(
281281
return to_object
282282

283283

284+
def _FileData_to_mldev(
285+
api_client: BaseApiClient,
286+
from_object: Union[dict[str, Any], object],
287+
parent_object: Optional[dict[str, Any]] = None,
288+
) -> dict[str, Any]:
289+
to_object: dict[str, Any] = {}
290+
if getv(from_object, ['display_name']) is not None:
291+
raise ValueError('display_name parameter is not supported in Gemini API.')
292+
293+
if getv(from_object, ['file_uri']) is not None:
294+
setv(to_object, ['fileUri'], getv(from_object, ['file_uri']))
295+
296+
if getv(from_object, ['mime_type']) is not None:
297+
setv(to_object, ['mimeType'], getv(from_object, ['mime_type']))
298+
299+
return to_object
300+
301+
302+
def _FileData_to_vertex(
303+
api_client: BaseApiClient,
304+
from_object: Union[dict[str, Any], object],
305+
parent_object: Optional[dict[str, Any]] = None,
306+
) -> dict[str, Any]:
307+
to_object: dict[str, Any] = {}
308+
if getv(from_object, ['display_name']) is not None:
309+
setv(to_object, ['displayName'], getv(from_object, ['display_name']))
310+
311+
if getv(from_object, ['file_uri']) is not None:
312+
setv(to_object, ['fileUri'], getv(from_object, ['file_uri']))
313+
314+
if getv(from_object, ['mime_type']) is not None:
315+
setv(to_object, ['mimeType'], getv(from_object, ['mime_type']))
316+
317+
return to_object
318+
319+
284320
def _Part_to_mldev(
285321
api_client: BaseApiClient,
286322
from_object: Union[dict[str, Any], object],
@@ -308,6 +344,15 @@ def _Part_to_mldev(
308344
),
309345
)
310346

347+
if getv(from_object, ['file_data']) is not None:
348+
setv(
349+
to_object,
350+
['fileData'],
351+
_FileData_to_mldev(
352+
api_client, getv(from_object, ['file_data']), to_object
353+
),
354+
)
355+
311356
if getv(from_object, ['code_execution_result']) is not None:
312357
setv(
313358
to_object,
@@ -318,9 +363,6 @@ def _Part_to_mldev(
318363
if getv(from_object, ['executable_code']) is not None:
319364
setv(to_object, ['executableCode'], getv(from_object, ['executable_code']))
320365

321-
if getv(from_object, ['file_data']) is not None:
322-
setv(to_object, ['fileData'], getv(from_object, ['file_data']))
323-
324366
if getv(from_object, ['function_call']) is not None:
325367
setv(to_object, ['functionCall'], getv(from_object, ['function_call']))
326368

@@ -364,6 +406,15 @@ def _Part_to_vertex(
364406
),
365407
)
366408

409+
if getv(from_object, ['file_data']) is not None:
410+
setv(
411+
to_object,
412+
['fileData'],
413+
_FileData_to_vertex(
414+
api_client, getv(from_object, ['file_data']), to_object
415+
),
416+
)
417+
367418
if getv(from_object, ['code_execution_result']) is not None:
368419
setv(
369420
to_object,
@@ -374,9 +425,6 @@ def _Part_to_vertex(
374425
if getv(from_object, ['executable_code']) is not None:
375426
setv(to_object, ['executableCode'], getv(from_object, ['executable_code']))
376427

377-
if getv(from_object, ['file_data']) is not None:
378-
setv(to_object, ['fileData'], getv(from_object, ['file_data']))
379-
380428
if getv(from_object, ['function_call']) is not None:
381429
setv(to_object, ['functionCall'], getv(from_object, ['function_call']))
382430

google/genai/_transformers.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,21 @@ def handle_null_fields(schema: dict[str, Any]) -> None:
630630
del schema['anyOf']
631631

632632

633+
def _raise_for_unsupported_schema_type(origin: Any) -> None:
634+
"""Raises an error if the schema type is unsupported."""
635+
raise ValueError(f'Unsupported schema type: {origin}')
636+
637+
638+
def _raise_for_unsupported_mldev_properties(schema: Any, client: _api_client.BaseApiClient) -> None:
639+
if not client.vertexai and (
640+
schema.get('additionalProperties')
641+
or schema.get('additional_properties')
642+
):
643+
raise ValueError(
644+
'additionalProperties is not supported in the Gemini API.'
645+
)
646+
647+
633648
def process_schema(
634649
schema: dict[str, Any],
635650
client: _api_client.BaseApiClient,
@@ -700,6 +715,8 @@ def process_schema(
700715
if schema.get('title') == 'PlaceholderLiteralEnum':
701716
del schema['title']
702717

718+
_raise_for_unsupported_mldev_properties(schema, client)
719+
703720
# Standardize spelling for relevant schema fields. For example, if a dict is
704721
# provided directly to response_schema, it may use `any_of` instead of `anyOf.
705722
# Otherwise, model_json_schema() uses `anyOf`.
@@ -818,8 +835,9 @@ def t_schema(
818835
return _process_enum(origin, client)
819836
if isinstance(origin, types.Schema):
820837
if dict(origin) == dict(types.Schema()):
821-
# response_schema value was coerced to an empty Schema instance because it did not adhere to the Schema field annotation
822-
raise ValueError(f'Unsupported schema type.')
838+
# response_schema value was coerced to an empty Schema instance because
839+
# it did not adhere to the Schema field annotation
840+
_raise_for_unsupported_schema_type(origin)
823841
schema = origin.model_dump(exclude_unset=True)
824842
process_schema(schema, client)
825843
return types.Schema.model_validate(schema)

0 commit comments

Comments
 (0)