|
25 | 25 | from typing import Any, Literal, Optional, Union |
26 | 26 |
|
27 | 27 | import google.auth |
28 | | -from pydantic import BaseModel |
29 | 28 | from requests.exceptions import HTTPError |
30 | 29 |
|
31 | 30 | from . import errors |
|
34 | 33 | from ._api_client import HttpRequest |
35 | 34 | from ._api_client import HttpResponse |
36 | 35 | from ._api_client import RequestJsonEncoder |
| 36 | +from ._common import BaseModel |
37 | 37 |
|
38 | 38 | def _redact_version_numbers(version_string: str) -> str: |
39 | 39 | """Redacts version numbers in the form x.y.z from a string.""" |
@@ -264,18 +264,9 @@ def close(self): |
264 | 264 | replay_file_path = self._get_replay_file_path() |
265 | 265 | os.makedirs(os.path.dirname(replay_file_path), exist_ok=True) |
266 | 266 | with open(replay_file_path, 'w') as f: |
267 | | - replay_session_dict = self.replay_session.model_dump() |
268 | | - # Use for non-utf-8 bytes in image/video... output. |
269 | | - for interaction in replay_session_dict['interactions']: |
270 | | - segments = [] |
271 | | - for response in interaction['response']['sdk_response_segments']: |
272 | | - segments.append(json.loads(json.dumps( |
273 | | - response, cls=ResponseJsonEncoder |
274 | | - ))) |
275 | | - interaction['response']['sdk_response_segments'] = segments |
276 | 267 | f.write( |
277 | 268 | json.dumps( |
278 | | - replay_session_dict, indent=2, cls=RequestJsonEncoder |
| 269 | + self.replay_session.model_dump(mode='json'), indent=2, cls=ResponseJsonEncoder |
279 | 270 | ) |
280 | 271 | ) |
281 | 272 | self.replay_session = None |
@@ -376,15 +367,8 @@ def _verify_response(self, response_model: BaseModel): |
376 | 367 | if isinstance(response_model, list): |
377 | 368 | response_model = response_model[0] |
378 | 369 | print('response_model: ', response_model.model_dump(exclude_none=True)) |
379 | | - actual = json.dumps( |
380 | | - response_model.model_dump(exclude_none=True), |
381 | | - cls=ResponseJsonEncoder, |
382 | | - sort_keys=True, |
383 | | - ) |
384 | | - expected = json.dumps( |
385 | | - interaction.response.sdk_response_segments[self._sdk_response_index], |
386 | | - sort_keys=True, |
387 | | - ) |
| 370 | + actual = response_model.model_dump(exclude_none=True, mode='json') |
| 371 | + expected = interaction.response.sdk_response_segments[self._sdk_response_index] |
388 | 372 | assert ( |
389 | 373 | actual == expected |
390 | 374 | ), f'SDK response mismatch:\nActual: {actual}\nExpected: {expected}' |
@@ -437,36 +421,12 @@ def upload_file(self, file_path: str, upload_url: str, upload_size: int): |
437 | 421 | return self._build_response_from_replay(request).text |
438 | 422 |
|
439 | 423 |
|
| 424 | +# TODO(b/389693448): Cleanup datetime hacks. |
440 | 425 | class ResponseJsonEncoder(json.JSONEncoder): |
441 | 426 | """The replay test json encoder for response. |
442 | | -
|
443 | | - We need RequestJsonEncoder and ResponseJsonEncoder because: |
444 | | - 1. In production, we only need RequestJsonEncoder to help json module |
445 | | - to convert non-stringable and stringable types to json string. Especially |
446 | | - for bytes type, the value of bytes field is encoded to base64 string so it |
447 | | - is always stringable and the RequestJsonEncoder doesn't have to deal with |
448 | | - utf-8 JSON broken issue. |
449 | | - 2. In replay test, we also need ResponseJsonEncoder to help json module |
450 | | - convert non-stringable and stringable types to json string. But response |
451 | | - object returned from SDK method is different from the request api_client |
452 | | - sent to server. For the bytes type, there is no base64 string in response |
453 | | - anymore, because SDK handles it internally. So bytes type in Response is |
454 | | - non-stringable. The ResponseJsonEncoder uses different encoding |
455 | | - strategy than the RequestJsonEncoder to deal with utf-8 JSON broken issue. |
456 | 427 | """ |
457 | 428 | def default(self, o): |
458 | | - if isinstance(o, bytes): |
459 | | - # Use base64.b64encode() to encode bytes to string so that the media bytes |
460 | | - # fields are serializable. |
461 | | - # o.decode(encoding='utf-8', errors='replace') doesn't work because it |
462 | | - # uses a fixed error string `\ufffd` for all non-utf-8 characters, |
463 | | - # which cannot be converted back to original bytes. And other languages |
464 | | - # only have the original bytes to compare with. |
465 | | - # Since we use base64.b64encoding() in replay test, a change that breaks |
466 | | - # native bytes can be captured by |
467 | | - # test_compute_tokens.py::test_token_bytes_deserialization. |
468 | | - return base64.b64encode(o).decode(encoding='utf-8') |
469 | | - elif isinstance(o, datetime.datetime): |
| 429 | + if isinstance(o, datetime.datetime): |
470 | 430 | # dt.isoformat() prints "2024-11-15T23:27:45.624657+00:00" |
471 | 431 | # but replay files want "2024-11-15T23:27:45.624657Z" |
472 | 432 | if o.isoformat().endswith('+00:00'): |
|
0 commit comments