|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from typing import TYPE_CHECKING, Any |
| 4 | + |
| 5 | +from griffe._internal.docstrings.models import DocstringParameter, DocstringSectionParameters |
| 6 | +from griffe._internal.enumerations import DocstringSectionKind, ParameterKind |
| 7 | +from griffe._internal.expressions import Expr, ExprSubscript |
| 8 | +from griffe._internal.extensions.base import Extension |
| 9 | +from griffe._internal.models import Class, Docstring, Function, Parameter, Parameters |
| 10 | + |
| 11 | +if TYPE_CHECKING: |
| 12 | + from collections.abc import Iterable |
| 13 | + |
| 14 | + |
| 15 | +def _update_docstring(func: Function, parameters: Iterable[Parameter], kwparam: Parameter | None = None) -> None: |
| 16 | + if not func.docstring: |
| 17 | + func.docstring = Docstring("", parent=func) |
| 18 | + sections = func.docstring.parsed |
| 19 | + section_gen = (section for section in sections if section.kind is DocstringSectionKind.parameters) |
| 20 | + if kwparam and (params_section := next(section_gen, None)): |
| 21 | + # Remove the `**kwargs` entry. |
| 22 | + param_gen = (i for i, arg in enumerate(params_section.value) if arg.name.lstrip("*") == kwparam.name) |
| 23 | + if (kwarg_pos := next(param_gen, None)) is not None: |
| 24 | + params_section.value.pop(kwarg_pos) |
| 25 | + else: |
| 26 | + # Create a parameters section if none exists. |
| 27 | + params_section = DocstringSectionParameters([]) |
| 28 | + func.docstring.parsed.append(params_section) |
| 29 | + # Add entries for all TypedDict attributes. |
| 30 | + for param in parameters: |
| 31 | + if param.name != "self": |
| 32 | + params_section.value.append( |
| 33 | + DocstringParameter( |
| 34 | + name=param.name, |
| 35 | + description=param.docstring.value if param.docstring else "", |
| 36 | + annotation=param.annotation, |
| 37 | + value=param.default, |
| 38 | + ), |
| 39 | + ) |
| 40 | + |
| 41 | + |
| 42 | +def _params_from_attrs(attrs: Iterable[Any]) -> Parameters: |
| 43 | + return Parameters( |
| 44 | + Parameter(name="self", kind=ParameterKind.positional_or_keyword), |
| 45 | + *( |
| 46 | + Parameter( |
| 47 | + name=attr.name, |
| 48 | + annotation=attr.annotation, |
| 49 | + kind=ParameterKind.keyword_only, |
| 50 | + default=attr.value, |
| 51 | + docstring=attr.docstring, |
| 52 | + ) |
| 53 | + for attr in attrs |
| 54 | + ), |
| 55 | + ) |
| 56 | + |
| 57 | + |
| 58 | +class UnpackTypedDictExtension(Extension): |
| 59 | + """An extension to handle `Unpack[TypeDict]`.""" |
| 60 | + |
| 61 | + def on_class(self, *, cls: Class, **kwargs: Any) -> None: # noqa: ARG002 |
| 62 | + """Add an `__init__` method to `TypedDict` classes if missing.""" |
| 63 | + for base in cls.bases: |
| 64 | + if isinstance(base, Expr) and base.canonical_path in {"typing.TypedDict", "typing_extensions.TypedDict"}: |
| 65 | + cls.labels.add("typed-dict") |
| 66 | + break |
| 67 | + else: |
| 68 | + return |
| 69 | + |
| 70 | + attributes = cls.attributes.values() |
| 71 | + |
| 72 | + if "__init__" not in cls.members: |
| 73 | + # Build the `__init__` method and add it to the class. |
| 74 | + parameters = _params_from_attrs(attributes) |
| 75 | + init = Function(name="__init__", parameters=parameters, returns="None") |
| 76 | + cls.set_member("__init__", init) |
| 77 | + # Update the `__init__` docstring. |
| 78 | + _update_docstring(init, parameters) |
| 79 | + |
| 80 | + # Remove attributes from the class, as they are now in the `__init__` method. |
| 81 | + for attr in attributes: |
| 82 | + cls.del_member(attr.name) |
| 83 | + |
| 84 | + def on_function(self, *, func: Function, **kwargs: Any) -> None: # noqa: ARG002 |
| 85 | + """Replace `**kwargs: Unpack[TypedDict]` parameters with the actual TypedDict attributes.""" |
| 86 | + # Find any `**kwargs: Unpack[TypedDict]` parameter. |
| 87 | + for parameter in func.parameters: |
| 88 | + if parameter.kind is ParameterKind.var_keyword: |
| 89 | + annotation = parameter.annotation |
| 90 | + if isinstance(annotation, ExprSubscript) and annotation.canonical_path in { |
| 91 | + "typing.Annotated", |
| 92 | + "typing_extensions.Annotated", |
| 93 | + }: |
| 94 | + annotation = annotation.slice.elements[0] # type: ignore[union-attr] |
| 95 | + if isinstance(annotation, ExprSubscript) and annotation.canonical_path in { |
| 96 | + "typing.Unpack", |
| 97 | + "typing_extensions.Unpack", |
| 98 | + }: |
| 99 | + slice_path = annotation.slice.canonical_path # type: ignore[union-attr] |
| 100 | + typed_dict = func.modules_collection[slice_path] |
| 101 | + break |
| 102 | + else: |
| 103 | + return |
| 104 | + |
| 105 | + if "__init__" in typed_dict.members: |
| 106 | + # The `__init__` was already generated: use its parameters. |
| 107 | + parameters = typed_dict["__init__"].parameters |
| 108 | + else: |
| 109 | + # Fallback to building parameters from attributes. |
| 110 | + parameters = _params_from_attrs(typed_dict.attributes.values()) |
| 111 | + |
| 112 | + # Update any parameter section in the docstring. |
| 113 | + # We do this before updating the signature so that |
| 114 | + # parsing the docstring doesn't emit warnings. |
| 115 | + _update_docstring(func, parameters, parameter) |
| 116 | + |
| 117 | + # Update the function parameters. |
| 118 | + del func.parameters[parameter.name] |
| 119 | + for param in parameters: |
| 120 | + if param.name != "self": |
| 121 | + func.parameters[param.name] = Parameter( |
| 122 | + name=param.name, |
| 123 | + annotation=param.annotation, |
| 124 | + kind=ParameterKind.keyword_only, |
| 125 | + default=param.default, |
| 126 | + docstring=param.docstring, |
| 127 | + ) |
0 commit comments