Skip to content

Commit 309dd26

Browse files
Ark-kuncopybara-github
authored andcommitted
feat: Added Operation and PredictOperation (internal module)
PiperOrigin-RevId: 721069275
1 parent f68aa1f commit 309dd26

2 files changed

Lines changed: 506 additions & 0 deletions

File tree

google/genai/_operations.py

Lines changed: 365 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,365 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
# Code generated by the Google Gen AI SDK generator DO NOT EDIT.
17+
18+
from typing import Optional, Union
19+
from urllib.parse import urlencode
20+
from . import _api_module
21+
from . import _common
22+
from . import types
23+
from ._api_client import ApiClient
24+
from ._common import get_value_by_path as getv
25+
from ._common import set_value_by_path as setv
26+
27+
28+
def _GetOperationParameters_to_mldev(
29+
api_client: ApiClient,
30+
from_object: Union[dict, object],
31+
parent_object: dict = None,
32+
) -> dict:
33+
to_object = {}
34+
if getv(from_object, ['operation_name']) is not None:
35+
setv(
36+
to_object,
37+
['_url', 'operationName'],
38+
getv(from_object, ['operation_name']),
39+
)
40+
41+
if getv(from_object, ['config']) is not None:
42+
setv(to_object, ['config'], getv(from_object, ['config']))
43+
44+
return to_object
45+
46+
47+
def _GetOperationParameters_to_vertex(
48+
api_client: ApiClient,
49+
from_object: Union[dict, object],
50+
parent_object: dict = None,
51+
) -> dict:
52+
to_object = {}
53+
if getv(from_object, ['operation_name']) is not None:
54+
setv(
55+
to_object,
56+
['_url', 'operationName'],
57+
getv(from_object, ['operation_name']),
58+
)
59+
60+
if getv(from_object, ['config']) is not None:
61+
setv(to_object, ['config'], getv(from_object, ['config']))
62+
63+
return to_object
64+
65+
66+
def _FetchPredictOperationParameters_to_mldev(
67+
api_client: ApiClient,
68+
from_object: Union[dict, object],
69+
parent_object: dict = None,
70+
) -> dict:
71+
to_object = {}
72+
if getv(from_object, ['operation_name']) is not None:
73+
raise ValueError('operation_name parameter is not supported in Gemini API.')
74+
75+
if getv(from_object, ['resource_name']) is not None:
76+
raise ValueError('resource_name parameter is not supported in Gemini API.')
77+
78+
if getv(from_object, ['config']) is not None:
79+
raise ValueError('config parameter is not supported in Gemini API.')
80+
81+
return to_object
82+
83+
84+
def _FetchPredictOperationParameters_to_vertex(
85+
api_client: ApiClient,
86+
from_object: Union[dict, object],
87+
parent_object: dict = None,
88+
) -> dict:
89+
to_object = {}
90+
if getv(from_object, ['operation_name']) is not None:
91+
setv(to_object, ['operationName'], getv(from_object, ['operation_name']))
92+
93+
if getv(from_object, ['resource_name']) is not None:
94+
setv(
95+
to_object,
96+
['_url', 'resourceName'],
97+
getv(from_object, ['resource_name']),
98+
)
99+
100+
if getv(from_object, ['config']) is not None:
101+
setv(to_object, ['config'], getv(from_object, ['config']))
102+
103+
return to_object
104+
105+
106+
def _Operation_from_mldev(
107+
api_client: ApiClient,
108+
from_object: Union[dict, object],
109+
parent_object: dict = None,
110+
) -> dict:
111+
to_object = {}
112+
if getv(from_object, ['name']) is not None:
113+
setv(to_object, ['name'], getv(from_object, ['name']))
114+
115+
if getv(from_object, ['metadata']) is not None:
116+
setv(to_object, ['metadata'], getv(from_object, ['metadata']))
117+
118+
if getv(from_object, ['done']) is not None:
119+
setv(to_object, ['done'], getv(from_object, ['done']))
120+
121+
if getv(from_object, ['error']) is not None:
122+
setv(to_object, ['error'], getv(from_object, ['error']))
123+
124+
if getv(from_object, ['response']) is not None:
125+
setv(to_object, ['response'], getv(from_object, ['response']))
126+
127+
return to_object
128+
129+
130+
def _Operation_from_vertex(
131+
api_client: ApiClient,
132+
from_object: Union[dict, object],
133+
parent_object: dict = None,
134+
) -> dict:
135+
to_object = {}
136+
if getv(from_object, ['name']) is not None:
137+
setv(to_object, ['name'], getv(from_object, ['name']))
138+
139+
if getv(from_object, ['metadata']) is not None:
140+
setv(to_object, ['metadata'], getv(from_object, ['metadata']))
141+
142+
if getv(from_object, ['done']) is not None:
143+
setv(to_object, ['done'], getv(from_object, ['done']))
144+
145+
if getv(from_object, ['error']) is not None:
146+
setv(to_object, ['error'], getv(from_object, ['error']))
147+
148+
if getv(from_object, ['response']) is not None:
149+
setv(to_object, ['response'], getv(from_object, ['response']))
150+
151+
return to_object
152+
153+
154+
class _operations(_api_module.BaseModule):
155+
156+
def _get_operation(
157+
self,
158+
*,
159+
operation_name: str,
160+
config: Optional[types.GetOperationConfigOrDict] = None,
161+
) -> types.Operation:
162+
parameter_model = types._GetOperationParameters(
163+
operation_name=operation_name,
164+
config=config,
165+
)
166+
167+
if self._api_client.vertexai:
168+
request_dict = _GetOperationParameters_to_vertex(
169+
self._api_client, parameter_model
170+
)
171+
path = '{operationName}'.format_map(request_dict.get('_url'))
172+
else:
173+
request_dict = _GetOperationParameters_to_mldev(
174+
self._api_client, parameter_model
175+
)
176+
path = '{operationName}'.format_map(request_dict.get('_url'))
177+
query_params = request_dict.get('_query')
178+
if query_params:
179+
path = f'{path}?{urlencode(query_params)}'
180+
# TODO: remove the hack that pops config.
181+
request_dict.pop('config', None)
182+
183+
http_options = None
184+
if isinstance(config, dict):
185+
http_options = config.get('http_options', None)
186+
elif hasattr(config, 'http_options'):
187+
http_options = config.http_options
188+
189+
request_dict = _common.convert_to_dict(request_dict)
190+
request_dict = _common.encode_unserializable_types(request_dict)
191+
192+
response_dict = self._api_client.request(
193+
'get', path, request_dict, http_options
194+
)
195+
196+
if self._api_client.vertexai:
197+
response_dict = _Operation_from_vertex(self._api_client, response_dict)
198+
else:
199+
response_dict = _Operation_from_mldev(self._api_client, response_dict)
200+
201+
return_value = types.Operation._from_response(
202+
response=response_dict, kwargs=parameter_model
203+
)
204+
self._api_client._verify_response(return_value)
205+
return return_value
206+
207+
def _fetch_predict_operation(
208+
self,
209+
*,
210+
operation_name: str,
211+
resource_name: str,
212+
config: Optional[types.FetchPredictOperationConfigOrDict] = None,
213+
) -> types.Operation:
214+
parameter_model = types._FetchPredictOperationParameters(
215+
operation_name=operation_name,
216+
resource_name=resource_name,
217+
config=config,
218+
)
219+
220+
if not self._api_client.vertexai:
221+
raise ValueError('This method is only supported in the Vertex AI client.')
222+
else:
223+
request_dict = _FetchPredictOperationParameters_to_vertex(
224+
self._api_client, parameter_model
225+
)
226+
path = '{resourceName}:fetchPredictOperation'.format_map(
227+
request_dict.get('_url')
228+
)
229+
230+
query_params = request_dict.get('_query')
231+
if query_params:
232+
path = f'{path}?{urlencode(query_params)}'
233+
# TODO: remove the hack that pops config.
234+
request_dict.pop('config', None)
235+
236+
http_options = None
237+
if isinstance(config, dict):
238+
http_options = config.get('http_options', None)
239+
elif hasattr(config, 'http_options'):
240+
http_options = config.http_options
241+
242+
request_dict = _common.convert_to_dict(request_dict)
243+
request_dict = _common.encode_unserializable_types(request_dict)
244+
245+
response_dict = self._api_client.request(
246+
'post', path, request_dict, http_options
247+
)
248+
249+
if self._api_client.vertexai:
250+
response_dict = _Operation_from_vertex(self._api_client, response_dict)
251+
else:
252+
response_dict = _Operation_from_mldev(self._api_client, response_dict)
253+
254+
return_value = types.Operation._from_response(
255+
response=response_dict, kwargs=parameter_model
256+
)
257+
self._api_client._verify_response(return_value)
258+
return return_value
259+
260+
261+
class Async_operations(_api_module.BaseModule):
262+
263+
async def _get_operation(
264+
self,
265+
*,
266+
operation_name: str,
267+
config: Optional[types.GetOperationConfigOrDict] = None,
268+
) -> types.Operation:
269+
parameter_model = types._GetOperationParameters(
270+
operation_name=operation_name,
271+
config=config,
272+
)
273+
274+
if self._api_client.vertexai:
275+
request_dict = _GetOperationParameters_to_vertex(
276+
self._api_client, parameter_model
277+
)
278+
path = '{operationName}'.format_map(request_dict.get('_url'))
279+
else:
280+
request_dict = _GetOperationParameters_to_mldev(
281+
self._api_client, parameter_model
282+
)
283+
path = '{operationName}'.format_map(request_dict.get('_url'))
284+
query_params = request_dict.get('_query')
285+
if query_params:
286+
path = f'{path}?{urlencode(query_params)}'
287+
# TODO: remove the hack that pops config.
288+
request_dict.pop('config', None)
289+
290+
http_options = None
291+
if isinstance(config, dict):
292+
http_options = config.get('http_options', None)
293+
elif hasattr(config, 'http_options'):
294+
http_options = config.http_options
295+
296+
request_dict = _common.convert_to_dict(request_dict)
297+
request_dict = _common.encode_unserializable_types(request_dict)
298+
299+
response_dict = await self._api_client.async_request(
300+
'get', path, request_dict, http_options
301+
)
302+
303+
if self._api_client.vertexai:
304+
response_dict = _Operation_from_vertex(self._api_client, response_dict)
305+
else:
306+
response_dict = _Operation_from_mldev(self._api_client, response_dict)
307+
308+
return_value = types.Operation._from_response(
309+
response=response_dict, kwargs=parameter_model
310+
)
311+
self._api_client._verify_response(return_value)
312+
return return_value
313+
314+
async def _fetch_predict_operation(
315+
self,
316+
*,
317+
operation_name: str,
318+
resource_name: str,
319+
config: Optional[types.FetchPredictOperationConfigOrDict] = None,
320+
) -> types.Operation:
321+
parameter_model = types._FetchPredictOperationParameters(
322+
operation_name=operation_name,
323+
resource_name=resource_name,
324+
config=config,
325+
)
326+
327+
if not self._api_client.vertexai:
328+
raise ValueError('This method is only supported in the Vertex AI client.')
329+
else:
330+
request_dict = _FetchPredictOperationParameters_to_vertex(
331+
self._api_client, parameter_model
332+
)
333+
path = '{resourceName}:fetchPredictOperation'.format_map(
334+
request_dict.get('_url')
335+
)
336+
337+
query_params = request_dict.get('_query')
338+
if query_params:
339+
path = f'{path}?{urlencode(query_params)}'
340+
# TODO: remove the hack that pops config.
341+
request_dict.pop('config', None)
342+
343+
http_options = None
344+
if isinstance(config, dict):
345+
http_options = config.get('http_options', None)
346+
elif hasattr(config, 'http_options'):
347+
http_options = config.http_options
348+
349+
request_dict = _common.convert_to_dict(request_dict)
350+
request_dict = _common.encode_unserializable_types(request_dict)
351+
352+
response_dict = await self._api_client.async_request(
353+
'post', path, request_dict, http_options
354+
)
355+
356+
if self._api_client.vertexai:
357+
response_dict = _Operation_from_vertex(self._api_client, response_dict)
358+
else:
359+
response_dict = _Operation_from_mldev(self._api_client, response_dict)
360+
361+
return_value = types.Operation._from_response(
362+
response=response_dict, kwargs=parameter_model
363+
)
364+
self._api_client._verify_response(return_value)
365+
return return_value

0 commit comments

Comments
 (0)