Skip to content

Commit 0f713f1

Browse files
hkt74copybara-github
authored andcommitted
feat: support list models to return base models
PiperOrigin-RevId: 713744661
1 parent a1ed3fa commit 0f713f1

6 files changed

Lines changed: 154 additions & 43 deletions

File tree

google/genai/_api_client.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ class HttpOptions(BaseModel):
6060
default=None,
6161
description="""Timeout for the request in seconds.""",
6262
)
63+
skip_project_and_location_in_path: bool = Field(
64+
default=False,
65+
description="""If set to True, the project and location will not be appended to the path.""",
66+
)
6367

6468

6569
class HttpOptionsDict(TypedDict):
@@ -75,7 +79,8 @@ class HttpOptionsDict(TypedDict):
7579
"""If set, the response payload will be returned int the supplied dict."""
7680
timeout: Optional[Union[float, Tuple[float, float]]] = None
7781
"""Timeout for the request in seconds."""
78-
82+
skip_project_and_location_in_path: bool = False
83+
"""If set to True, the project and location will not be appended to the path."""
7984

8085
HttpOptionsOrDict = Union[HttpOptions, HttpOptionsDict]
8186

@@ -266,7 +271,14 @@ def _build_request(
266271
)
267272
else:
268273
patched_http_options = self._http_options
269-
if self.vertexai and not path.startswith('projects/'):
274+
skip_project_and_location_in_path_val = patched_http_options.get(
275+
'skip_project_and_location_in_path', False
276+
)
277+
if (
278+
self.vertexai
279+
and not path.startswith('projects/')
280+
and not skip_project_and_location_in_path_val
281+
):
270282
path = f'projects/{self.project}/locations/{self.location}/' + path
271283
url = _join_url_path(
272284
patched_http_options['base_url'],

google/genai/_replay_api_client.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ def _redact_request_url(url: str) -> str:
7272
'{VERTEX_URL_PREFIX}/',
7373
url,
7474
)
75+
result = re.sub(
76+
r'.*-aiplatform.googleapis.com/[^/]+/',
77+
'{VERTEX_URL_PREFIX}/',
78+
result,
79+
)
7580
result = re.sub(
7681
r'https://generativelanguage.googleapis.com/[^/]+',
7782
'{MLDEV_URL_PREFIX}',

google/genai/_transformers.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,30 @@ def t_model(client: _api_client.ApiClient, model: str):
142142
else:
143143
return f'models/{model}'
144144

145+
def t_models_url(api_client: _api_client.ApiClient, base_models: bool) -> str:
146+
if api_client.vertexai:
147+
if base_models:
148+
return 'publishers/google/models'
149+
else:
150+
return 'models'
151+
else:
152+
if base_models:
153+
return 'models'
154+
else:
155+
return 'tunedModels'
156+
157+
158+
def t_extract_models(api_client: _api_client.ApiClient, response: dict) -> list[types.Model]:
159+
if response.get('models') is not None:
160+
return response.get('models')
161+
elif response.get('tunedModels') is not None:
162+
return response.get('tunedModels')
163+
elif response.get('publisherModels') is not None:
164+
return response.get('publisherModels')
165+
else:
166+
raise ValueError('Cannot determine the models type.')
167+
168+
145169
def t_caches_model(api_client: _api_client.ApiClient, model: str):
146170
model = t_model(api_client, model)
147171
if not model:

google/genai/models.py

Lines changed: 65 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from . import _extra_utils
2323
from . import _transformers as t
2424
from . import types
25-
from ._api_client import ApiClient
25+
from ._api_client import ApiClient, HttpOptionsDict
2626
from ._common import get_value_by_path as getv
2727
from ._common import set_value_by_path as setv
2828
from .pagers import AsyncPager, Pager
@@ -2280,6 +2280,9 @@ def _ListModelsConfig_to_mldev(
22802280
parent_object: dict = None,
22812281
) -> dict:
22822282
to_object = {}
2283+
if getv(from_object, ['http_options']) is not None:
2284+
setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
2285+
22832286
if getv(from_object, ['page_size']) is not None:
22842287
setv(
22852288
parent_object, ['_query', 'pageSize'], getv(from_object, ['page_size'])
@@ -2295,6 +2298,13 @@ def _ListModelsConfig_to_mldev(
22952298
if getv(from_object, ['filter']) is not None:
22962299
setv(parent_object, ['_query', 'filter'], getv(from_object, ['filter']))
22972300

2301+
if getv(from_object, ['query_base']) is not None:
2302+
setv(
2303+
parent_object,
2304+
['_url', 'models_url'],
2305+
t.t_models_url(api_client, getv(from_object, ['query_base'])),
2306+
)
2307+
22982308
return to_object
22992309

23002310

@@ -2304,6 +2314,9 @@ def _ListModelsConfig_to_vertex(
23042314
parent_object: dict = None,
23052315
) -> dict:
23062316
to_object = {}
2317+
if getv(from_object, ['http_options']) is not None:
2318+
setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
2319+
23072320
if getv(from_object, ['page_size']) is not None:
23082321
setv(
23092322
parent_object, ['_query', 'pageSize'], getv(from_object, ['page_size'])
@@ -2319,6 +2332,13 @@ def _ListModelsConfig_to_vertex(
23192332
if getv(from_object, ['filter']) is not None:
23202333
setv(parent_object, ['_query', 'filter'], getv(from_object, ['filter']))
23212334

2335+
if getv(from_object, ['query_base']) is not None:
2336+
setv(
2337+
parent_object,
2338+
['_url', 'models_url'],
2339+
t.t_models_url(api_client, getv(from_object, ['query_base'])),
2340+
)
2341+
23222342
return to_object
23232343

23242344

@@ -3524,13 +3544,15 @@ def _ListModelsResponse_from_mldev(
35243544
if getv(from_object, ['nextPageToken']) is not None:
35253545
setv(to_object, ['next_page_token'], getv(from_object, ['nextPageToken']))
35263546

3527-
if getv(from_object, ['tunedModels']) is not None:
3547+
if getv(from_object, ['_self']) is not None:
35283548
setv(
35293549
to_object,
35303550
['models'],
35313551
[
35323552
_Model_from_mldev(api_client, item, to_object)
3533-
for item in getv(from_object, ['tunedModels'])
3553+
for item in t.t_extract_models(
3554+
api_client, getv(from_object, ['_self'])
3555+
)
35343556
],
35353557
)
35363558

@@ -3546,13 +3568,15 @@ def _ListModelsResponse_from_vertex(
35463568
if getv(from_object, ['nextPageToken']) is not None:
35473569
setv(to_object, ['next_page_token'], getv(from_object, ['nextPageToken']))
35483570

3549-
if getv(from_object, ['models']) is not None:
3571+
if getv(from_object, ['_self']) is not None:
35503572
setv(
35513573
to_object,
35523574
['models'],
35533575
[
35543576
_Model_from_vertex(api_client, item, to_object)
3555-
for item in getv(from_object, ['models'])
3577+
for item in t.t_extract_models(
3578+
api_client, getv(from_object, ['_self'])
3579+
)
35563580
],
35573581
)
35583582

@@ -4091,12 +4115,12 @@ def _list(
40914115
request_dict = _ListModelsParameters_to_vertex(
40924116
self.api_client, parameter_model
40934117
)
4094-
path = 'models'.format_map(request_dict.get('_url'))
4118+
path = '{models_url}'.format_map(request_dict.get('_url'))
40954119
else:
40964120
request_dict = _ListModelsParameters_to_mldev(
40974121
self.api_client, parameter_model
40984122
)
4099-
path = 'tunedModels'.format_map(request_dict.get('_url'))
4123+
path = '{models_url}'.format_map(request_dict.get('_url'))
41004124
query_params = request_dict.get('_query')
41014125
if query_params:
41024126
path = f'{path}?{urlencode(query_params)}'
@@ -4523,17 +4547,24 @@ def list(
45234547
types._ListModelsParameters(config=config).config
45244548
or types.ListModelsConfig()
45254549
)
4526-
45274550
if self.api_client.vertexai:
4528-
# Filter for tuning jobs artifacts by labels.
45294551
config = config.copy()
4530-
filter_value = config.filter
4531-
config.filter = (
4532-
filter_value + '&filter=labels.tune-type:*'
4533-
if filter_value
4534-
else 'labels.tune-type:*'
4535-
)
4536-
4552+
if config.query_base:
4553+
http_options = (
4554+
config.http_options if config.http_options else HttpOptionsDict()
4555+
)
4556+
http_options['skip_project_and_location_in_path'] = True
4557+
config.http_options = http_options
4558+
else:
4559+
# Filter for tuning jobs artifacts by labels.
4560+
filter_value = config.filter
4561+
config.filter = (
4562+
filter_value + '&filter=labels.tune-type:*'
4563+
if filter_value
4564+
else 'labels.tune-type:*'
4565+
)
4566+
if not config.query_base:
4567+
config.query_base = False
45374568
return Pager(
45384569
'models',
45394570
self._list,
@@ -4999,12 +5030,12 @@ async def _list(
49995030
request_dict = _ListModelsParameters_to_vertex(
50005031
self.api_client, parameter_model
50015032
)
5002-
path = 'models'.format_map(request_dict.get('_url'))
5033+
path = '{models_url}'.format_map(request_dict.get('_url'))
50035034
else:
50045035
request_dict = _ListModelsParameters_to_mldev(
50055036
self.api_client, parameter_model
50065037
)
5007-
path = 'tunedModels'.format_map(request_dict.get('_url'))
5038+
path = '{models_url}'.format_map(request_dict.get('_url'))
50085039
query_params = request_dict.get('_query')
50095040
if query_params:
50105041
path = f'{path}?{urlencode(query_params)}'
@@ -5366,16 +5397,24 @@ async def list(
53665397
types._ListModelsParameters(config=config).config
53675398
or types.ListModelsConfig()
53685399
)
5369-
53705400
if self.api_client.vertexai:
5371-
# Filter for tuning jobs artifacts by labels.
53725401
config = config.copy()
5373-
filter_value = config.filter
5374-
config.filter = (
5375-
filter_value + '&filter=labels.tune-type:*'
5376-
if filter_value
5377-
else 'labels.tune-type:*'
5378-
)
5402+
if config.query_base:
5403+
http_options = (
5404+
config.http_options if config.http_options else HttpOptionsDict()
5405+
)
5406+
http_options['skip_project_and_location_in_path'] = True
5407+
config.http_options = http_options
5408+
else:
5409+
# Filter for tuning jobs artifacts by labels.
5410+
filter_value = config.filter
5411+
config.filter = (
5412+
filter_value + '&filter=labels.tune-type:*'
5413+
if filter_value
5414+
else 'labels.tune-type:*'
5415+
)
5416+
if not config.query_base:
5417+
config.query_base = False
53795418
return AsyncPager(
53805419
'models',
53815420
self._list,

google/genai/tests/models/test_list.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,15 @@
2424

2525
test_table: list[pytest_helper.TestTableItem] = [
2626
pytest_helper.TestTableItem(
27-
name='test_list_models',
27+
name='test_tuned_models',
2828
parameters=types._ListModelsParameters(),
2929
),
3030
pytest_helper.TestTableItem(
31-
name='test_list_models_with_config',
31+
name='test_base_models',
32+
parameters=types._ListModelsParameters(config={'query_base': True}),
33+
),
34+
pytest_helper.TestTableItem(
35+
name='test_with_config',
3236
parameters=types._ListModelsParameters(config={'page_size': 3}),
3337
),
3438
]
@@ -40,30 +44,44 @@
4044
)
4145

4246

43-
def test_pager(client):
44-
models = client.models.list(config={'page_size': 10})
47+
def test_tuned_models_pager(client):
48+
pager = client.models.list(config={'page_size': 10})
49+
50+
assert pager.name == 'models'
51+
assert pager.page_size == 10
52+
assert len(pager) <= 10
53+
54+
# Iterate through all the pages. Then next_page() should raise an exception.
55+
for _ in pager:
56+
pass
57+
with pytest.raises(IndexError, match='No more pages to fetch.'):
58+
pager.next_page()
59+
60+
61+
def test_base_models_pager(client):
62+
pager = client.models.list(config={'page_size': 10, 'query_base': True})
4563

46-
assert models.name == 'models'
47-
assert models.page_size == 10
48-
assert len(models) <= 10
64+
assert pager.name == 'models'
65+
assert pager.page_size == 10
66+
assert len(pager) <= 10
4967

5068
# Iterate through all the pages. Then next_page() should raise an exception.
51-
for _ in models:
69+
for _ in pager:
5270
pass
5371
with pytest.raises(IndexError, match='No more pages to fetch.'):
54-
models.next_page()
72+
pager.next_page()
5573

5674

5775
@pytest.mark.asyncio
5876
async def test_async_pager(client):
59-
models = await client.aio.models.list(config={'page_size': 10})
77+
pager = await client.aio.models.list(config={'page_size': 10})
6078

61-
assert models.name == 'models'
62-
assert models.page_size == 10
63-
assert len(models) <= 10
79+
assert pager.name == 'models'
80+
assert pager.page_size == 10
81+
assert len(pager) <= 10
6482

6583
# Iterate through all the pages. Then next_page() should raise an exception.
66-
async for _ in models:
84+
async for _ in pager:
6785
pass
6886
with pytest.raises(IndexError, match='No more pages to fetch.'):
69-
await models.next_page()
87+
await pager.next_page()

google/genai/types.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3863,13 +3863,23 @@ class ModelDict(TypedDict, total=False):
38633863

38643864
class ListModelsConfig(_common.BaseModel):
38653865

3866+
http_options: Optional[dict[str, Any]] = Field(
3867+
default=None, description="""Used to override HTTP request options."""
3868+
)
38663869
page_size: Optional[int] = Field(default=None, description="""""")
38673870
page_token: Optional[str] = Field(default=None, description="""""")
38683871
filter: Optional[str] = Field(default=None, description="""""")
3872+
query_base: Optional[bool] = Field(
3873+
default=None,
3874+
description="""Set true to list base models, false to list tuned models.""",
3875+
)
38693876

38703877

38713878
class ListModelsConfigDict(TypedDict, total=False):
38723879

3880+
http_options: Optional[dict[str, Any]]
3881+
"""Used to override HTTP request options."""
3882+
38733883
page_size: Optional[int]
38743884
""""""
38753885

@@ -3879,6 +3889,9 @@ class ListModelsConfigDict(TypedDict, total=False):
38793889
filter: Optional[str]
38803890
""""""
38813891

3892+
query_base: Optional[bool]
3893+
"""Set true to list base models, false to list tuned models."""
3894+
38823895

38833896
ListModelsConfigOrDict = Union[ListModelsConfig, ListModelsConfigDict]
38843897

0 commit comments

Comments
 (0)