Skip to content

Commit 05e8fb0

Browse files
authored
Merge pull request #4436 from tybug/typing-patching
Add type hints to `_patching.py`
2 parents 324cfc6 + 033b4be commit 05e8fb0

2 files changed

Lines changed: 52 additions & 19 deletions

File tree

hypothesis-python/RELEASE.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
RELEASE_TYPE: patch
2+
3+
Add type hints to internal code for patching.

hypothesis-python/src/hypothesis/extra/_patching.py

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@
2626
import sys
2727
import types
2828
from ast import literal_eval
29+
from collections.abc import Sequence
2930
from contextlib import suppress
3031
from datetime import date, datetime, timedelta, timezone
3132
from pathlib import Path
33+
from typing import Any, Optional
3234

3335
import libcst as cst
3436
from libcst import matchers as m
@@ -56,7 +58,7 @@
5658
_leading_space_re = re.compile("(^[ ]*)(?:[^ \n])", re.MULTILINE)
5759

5860

59-
def dedent(text):
61+
def dedent(text: str) -> tuple[str, str]:
6062
# Simplified textwrap.dedent, for valid Python source code only
6163
text = _space_only_re.sub("", text)
6264
prefix = min(_leading_space_re.findall(text), key=len)
@@ -70,7 +72,14 @@ def indent(text: str, prefix: str) -> str:
7072
class AddExamplesCodemod(VisitorBasedCodemodCommand):
7173
DESCRIPTION = "Add explicit examples to failing tests."
7274

73-
def __init__(self, context, fn_examples, strip_via=(), dec="example", width=88):
75+
def __init__(
76+
self,
77+
context: CodemodContext,
78+
fn_examples: dict[str, list[tuple[cst.Call, str]]],
79+
strip_via: tuple[str, ...] = (),
80+
decorator: str = "example",
81+
width: int = 88,
82+
):
7483
"""Add @example() decorator(s) for failing test(s).
7584
7685
`code` is the source code of the module where the test functions are defined.
@@ -79,21 +88,29 @@ def __init__(self, context, fn_examples, strip_via=(), dec="example", width=88):
7988
assert fn_examples, "This codemod does nothing without fn_examples."
8089
super().__init__(context)
8190

82-
self.decorator_func = cst.parse_expression(dec)
91+
self.decorator_func = cst.parse_expression(decorator)
8392
self.line_length = width
84-
value_in_strip_via = m.MatchIfTrue(lambda x: literal_eval(x.value) in strip_via)
93+
value_in_strip_via: Any = m.MatchIfTrue(
94+
lambda x: literal_eval(x.value) in strip_via
95+
)
8596
self.strip_matching = m.Call(
8697
m.Attribute(m.Call(), m.Name("via")),
8798
[m.Arg(m.SimpleString() & value_in_strip_via)],
8899
)
89100

90101
# Codemod the failing examples to Call nodes usable as decorators
91102
self.fn_examples = {
92-
k: tuple(d for x in nodes if (d := self.__call_node_to_example_dec(*x)))
103+
k: tuple(
104+
d
105+
for (node, via) in nodes
106+
if (d := self.__call_node_to_example_dec(node, via))
107+
)
93108
for k, nodes in fn_examples.items()
94109
}
95110

96-
def __call_node_to_example_dec(self, node, via):
111+
def __call_node_to_example_dec(
112+
self, node: cst.Call, via: str
113+
) -> Optional[cst.Decorator]:
97114
# If we have black installed, remove trailing comma, _unless_ there's a comment
98115
node = node.with_changes(
99116
func=self.decorator_func,
@@ -112,7 +129,7 @@ def __call_node_to_example_dec(self, node, via):
112129
else node.args
113130
),
114131
)
115-
via = cst.Call(
132+
via: cst.BaseExpression = cst.Call(
116133
func=cst.Attribute(node, cst.Name("via")),
117134
args=[cst.Arg(cst.SimpleString(repr(via)))],
118135
)
@@ -127,7 +144,9 @@ def __call_node_to_example_dec(self, node, via):
127144
via = cst.parse_expression(pretty.strip())
128145
return cst.Decorator(via)
129146

130-
def leave_FunctionDef(self, _, updated_node):
147+
def leave_FunctionDef(
148+
self, _original_node: cst.FunctionDef, updated_node: cst.FunctionDef
149+
) -> cst.FunctionDef:
131150
return updated_node.with_changes(
132151
# TODO: improve logic for where in the list to insert this decorator
133152
decorators=tuple(
@@ -140,11 +159,16 @@ def leave_FunctionDef(self, _, updated_node):
140159
)
141160

142161

143-
def get_patch_for(func, failing_examples, *, strip_via=()):
162+
def get_patch_for(
163+
func: Any,
164+
examples: Sequence[tuple[str, str]],
165+
*,
166+
strip_via: tuple[str, ...] = (),
167+
) -> Optional[tuple[str, str, str]]:
144168
# Skip this if we're unable to find the location or source of this function.
145169
try:
146170
module = sys.modules[func.__module__]
147-
fname = Path(module.__file__).relative_to(Path.cwd())
171+
fname = Path(module.__file__).relative_to(Path.cwd()) # type: ignore
148172
before = inspect.getsource(func)
149173
except Exception:
150174
return None
@@ -160,10 +184,10 @@ def get_patch_for(func, failing_examples, *, strip_via=()):
160184

161185
# The printed examples might include object reprs which are invalid syntax,
162186
# so we parse here and skip over those. If _none_ are valid, there's no patch.
163-
call_nodes = []
164-
for ex, via in set(failing_examples):
187+
call_nodes: list[tuple[cst.Call, str]] = []
188+
for ex, via in set(examples):
165189
with suppress(Exception):
166-
node = cst.parse_module(ex)
190+
node: Any = cst.parse_module(ex)
167191
the_call = node.body[0].body[0].value
168192
assert isinstance(the_call, cst.Call), the_call
169193
# Check for st.data(), which doesn't support explicit examples
@@ -194,14 +218,15 @@ def get_patch_for(func, failing_examples, *, strip_via=()):
194218
with suppress(Exception):
195219
wrapper = cst.metadata.MetadataWrapper(node)
196220
kwarg_names = {
197-
a.keyword for a in m.findall(wrapper, m.Arg(keyword=m.Name()))
221+
node.keyword # type: ignore
222+
for node in m.findall(wrapper, m.Arg(keyword=m.Name()))
198223
}
199224
node = m.replace(
200225
wrapper,
201226
m.Name(value=m.MatchIfTrue(names.__contains__))
202227
& m.MatchMetadata(ExpressionContextProvider, ExpressionContext.LOAD)
203-
& m.MatchIfTrue(lambda n, k=kwarg_names: n not in k),
204-
replacement=lambda node, _, ns=names: ns[node.value],
228+
& m.MatchIfTrue(lambda n, k=kwarg_names: n not in k), # type: ignore
229+
replacement=lambda node, _, ns=names: ns[node.value], # type: ignore
205230
)
206231
node = node.body[0].body[0].value
207232
assert isinstance(node, cst.Call), node
@@ -229,18 +254,23 @@ def get_patch_for(func, failing_examples, *, strip_via=()):
229254
CodemodContext(),
230255
fn_examples={func.__name__: call_nodes},
231256
strip_via=strip_via,
232-
dec=decorator_func,
257+
decorator=decorator_func,
233258
width=88 - len(prefix), # to match Black's default formatting
234259
).transform_module(node)
235260
return (str(fname), before, indent(after.code, prefix=prefix))
236261

237262

238-
def make_patch(triples, *, msg="Hypothesis: add explicit examples", when=None):
263+
def make_patch(
264+
triples: Sequence[tuple[str, str, str]],
265+
*,
266+
msg: str = "Hypothesis: add explicit examples",
267+
when: Optional[datetime] = None,
268+
) -> str:
239269
"""Create a patch for (fname, before, after) triples."""
240270
assert triples, "attempted to create empty patch"
241271
when = when or datetime.now(tz=timezone.utc)
242272

243-
by_fname = {}
273+
by_fname: dict[Path, list[tuple[str, str]]] = {}
244274
for fname, before, after in triples:
245275
by_fname.setdefault(Path(fname), []).append((before, after))
246276

0 commit comments

Comments
 (0)