Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion packages/atproto_client/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import typing as t
import warnings

from pydantic import BaseModel, ConfigDict, alias_generators, model_validator
from pydantic import BaseModel, ConfigDict, Field, alias_generators, model_validator

from atproto_client.exceptions import ModelFieldNotFoundError

Expand Down Expand Up @@ -75,5 +75,9 @@ class UnknownRecord(ModelBase):
pass


class UnknownUnionModel(ModelBase):
py_type: str = Field(alias='$type')


class RecordModelBase(UnknownRecord):
pass
18 changes: 13 additions & 5 deletions packages/atproto_codegen/models/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,11 @@ def _get_ref_union_typehint(nsid: NSID, field_type_def: models.LexRefUnion, *, o
if not def_names and field_type_def.closed:
raise ValueError('The schema is invalid because union must have at least one type when it is closed')

if not def_names:
# actually it's a union of unknown types but it must have $type field.
# we do specify more correct type here for now
def_names.append('t.Any')
is_unknown_union = not def_names

if is_unknown_union:
# union of unknown types but it must have $type field.
def_names.append('base.UnknownUnionModel')

# unbelievable but it's true. If schema doesn't describe the right type in Union
# we should fall back to the plain data
Expand All @@ -217,7 +218,14 @@ def _get_ref_union_typehint(nsid: NSID, field_type_def: models.LexRefUnion, *, o
# append 't.Dict[str, t.Any]' to def_names # FIXME(MarshalX): support pydantic

def_names = ', '.join([f"'{name}'" for name in def_names])
def_field_meta = 'Field(default=None, discriminator="py_type")' if optional else 'Field(discriminator="py_type")'

if is_unknown_union:
# unknown type does not compatible with discriminator because $type is unknown :)
def_field_meta = 'Field(default=None)' if optional else 'Field()'
else:
def_field_meta = (
'Field(default=None, discriminator="py_type")' if optional else 'Field(discriminator="py_type")'
)

annotated_union = f'te.Annotated[t.Union[{def_names}], {def_field_meta}]'
return _get_optional_typehint(annotated_union, optional=optional)
Expand Down
25 changes: 25 additions & 0 deletions tests/test_atproto_client/models/tests/test_unknown_union_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pytest
from atproto_client.models.base import UnknownUnionModel
from pydantic import ValidationError


def test_unknown_union_model() -> None:
"""Test the UnknownUnionModel class."""
expected_type = 'app.bsky.embed.record#view'

model = UnknownUnionModel(py_type=expected_type)
assert model.py_type == expected_type

with pytest.raises(ValidationError):
# Attempt to create an instance without $type
UnknownUnionModel(blabla='blabla')

model = UnknownUnionModel.model_validate_json(f'{{"$type": "{expected_type}"}}')
assert model.py_type == expected_type

with pytest.raises(ValidationError):
# Attempt to create an instance without $type
UnknownUnionModel.model_validate_json('{"literallyNoTypeInfo": "lol"}')

model_with_extra = UnknownUnionModel(py_type=expected_type, extra_field='extra_value')
assert model_with_extra.extra_field == 'extra_value'
Loading