2626import sys
2727import types
2828from ast import literal_eval
29+ from collections .abc import Sequence
2930from contextlib import suppress
3031from datetime import date , datetime , timedelta , timezone
3132from pathlib import Path
33+ from typing import Any , Optional
3234
3335import libcst as cst
3436from libcst import matchers as m
5658_leading_space_re = re .compile ("(^[ ]*)(?:[^ \n ])" , re .MULTILINE )
5759
5860
59- def dedent (text ) :
61+ def dedent (text : str ) -> tuple [ str , str ] :
6062 # Simplified textwrap.dedent, for valid Python source code only
6163 text = _space_only_re .sub ("" , text )
6264 prefix = min (_leading_space_re .findall (text ), key = len )
@@ -70,7 +72,14 @@ def indent(text: str, prefix: str) -> str:
7072class AddExamplesCodemod (VisitorBasedCodemodCommand ):
7173 DESCRIPTION = "Add explicit examples to failing tests."
7274
73- def __init__ (self , context , fn_examples , strip_via = (), dec = "example" , width = 88 ):
75+ def __init__ (
76+ self ,
77+ context : CodemodContext ,
78+ fn_examples : dict [str , list [tuple [cst .Call , str ]]],
79+ strip_via : tuple [str , ...] = (),
80+ decorator : str = "example" ,
81+ width : int = 88 ,
82+ ):
7483 """Add @example() decorator(s) for failing test(s).
7584
7685 `code` is the source code of the module where the test functions are defined.
@@ -79,21 +88,29 @@ def __init__(self, context, fn_examples, strip_via=(), dec="example", width=88):
7988 assert fn_examples , "This codemod does nothing without fn_examples."
8089 super ().__init__ (context )
8190
82- self .decorator_func = cst .parse_expression (dec )
91+ self .decorator_func = cst .parse_expression (decorator )
8392 self .line_length = width
84- value_in_strip_via = m .MatchIfTrue (lambda x : literal_eval (x .value ) in strip_via )
93+ value_in_strip_via : Any = m .MatchIfTrue (
94+ lambda x : literal_eval (x .value ) in strip_via
95+ )
8596 self .strip_matching = m .Call (
8697 m .Attribute (m .Call (), m .Name ("via" )),
8798 [m .Arg (m .SimpleString () & value_in_strip_via )],
8899 )
89100
90101 # Codemod the failing examples to Call nodes usable as decorators
91102 self .fn_examples = {
92- k : tuple (d for x in nodes if (d := self .__call_node_to_example_dec (* x )))
103+ k : tuple (
104+ d
105+ for (node , via ) in nodes
106+ if (d := self .__call_node_to_example_dec (node , via ))
107+ )
93108 for k , nodes in fn_examples .items ()
94109 }
95110
96- def __call_node_to_example_dec (self , node , via ):
111+ def __call_node_to_example_dec (
112+ self , node : cst .Call , via : str
113+ ) -> Optional [cst .Decorator ]:
97114 # If we have black installed, remove trailing comma, _unless_ there's a comment
98115 node = node .with_changes (
99116 func = self .decorator_func ,
@@ -112,7 +129,7 @@ def __call_node_to_example_dec(self, node, via):
112129 else node .args
113130 ),
114131 )
115- via = cst .Call (
132+ via : cst . BaseExpression = cst .Call (
116133 func = cst .Attribute (node , cst .Name ("via" )),
117134 args = [cst .Arg (cst .SimpleString (repr (via )))],
118135 )
@@ -127,7 +144,9 @@ def __call_node_to_example_dec(self, node, via):
127144 via = cst .parse_expression (pretty .strip ())
128145 return cst .Decorator (via )
129146
130- def leave_FunctionDef (self , _ , updated_node ):
147+ def leave_FunctionDef (
148+ self , _original_node : cst .FunctionDef , updated_node : cst .FunctionDef
149+ ) -> cst .FunctionDef :
131150 return updated_node .with_changes (
132151 # TODO: improve logic for where in the list to insert this decorator
133152 decorators = tuple (
@@ -140,11 +159,16 @@ def leave_FunctionDef(self, _, updated_node):
140159 )
141160
142161
143- def get_patch_for (func , failing_examples , * , strip_via = ()):
162+ def get_patch_for (
163+ func : Any ,
164+ examples : Sequence [tuple [str , str ]],
165+ * ,
166+ strip_via : tuple [str , ...] = (),
167+ ) -> Optional [tuple [str , str , str ]]:
144168 # Skip this if we're unable to find the location or source of this function.
145169 try :
146170 module = sys .modules [func .__module__ ]
147- fname = Path (module .__file__ ).relative_to (Path .cwd ())
171+ fname = Path (module .__file__ ).relative_to (Path .cwd ()) # type: ignore
148172 before = inspect .getsource (func )
149173 except Exception :
150174 return None
@@ -160,10 +184,10 @@ def get_patch_for(func, failing_examples, *, strip_via=()):
160184
161185 # The printed examples might include object reprs which are invalid syntax,
162186 # so we parse here and skip over those. If _none_ are valid, there's no patch.
163- call_nodes = []
164- for ex , via in set (failing_examples ):
187+ call_nodes : list [ tuple [ cst . Call , str ]] = []
188+ for ex , via in set (examples ):
165189 with suppress (Exception ):
166- node = cst .parse_module (ex )
190+ node : Any = cst .parse_module (ex )
167191 the_call = node .body [0 ].body [0 ].value
168192 assert isinstance (the_call , cst .Call ), the_call
169193 # Check for st.data(), which doesn't support explicit examples
@@ -194,14 +218,15 @@ def get_patch_for(func, failing_examples, *, strip_via=()):
194218 with suppress (Exception ):
195219 wrapper = cst .metadata .MetadataWrapper (node )
196220 kwarg_names = {
197- a .keyword for a in m .findall (wrapper , m .Arg (keyword = m .Name ()))
221+ node .keyword # type: ignore
222+ for node in m .findall (wrapper , m .Arg (keyword = m .Name ()))
198223 }
199224 node = m .replace (
200225 wrapper ,
201226 m .Name (value = m .MatchIfTrue (names .__contains__ ))
202227 & m .MatchMetadata (ExpressionContextProvider , ExpressionContext .LOAD )
203- & m .MatchIfTrue (lambda n , k = kwarg_names : n not in k ),
204- replacement = lambda node , _ , ns = names : ns [node .value ],
228+ & m .MatchIfTrue (lambda n , k = kwarg_names : n not in k ), # type: ignore
229+ replacement = lambda node , _ , ns = names : ns [node .value ], # type: ignore
205230 )
206231 node = node .body [0 ].body [0 ].value
207232 assert isinstance (node , cst .Call ), node
@@ -229,18 +254,23 @@ def get_patch_for(func, failing_examples, *, strip_via=()):
229254 CodemodContext (),
230255 fn_examples = {func .__name__ : call_nodes },
231256 strip_via = strip_via ,
232- dec = decorator_func ,
257+ decorator = decorator_func ,
233258 width = 88 - len (prefix ), # to match Black's default formatting
234259 ).transform_module (node )
235260 return (str (fname ), before , indent (after .code , prefix = prefix ))
236261
237262
238- def make_patch (triples , * , msg = "Hypothesis: add explicit examples" , when = None ):
263+ def make_patch (
264+ triples : Sequence [tuple [str , str , str ]],
265+ * ,
266+ msg : str = "Hypothesis: add explicit examples" ,
267+ when : Optional [datetime ] = None ,
268+ ) -> str :
239269 """Create a patch for (fname, before, after) triples."""
240270 assert triples , "attempted to create empty patch"
241271 when = when or datetime .now (tz = timezone .utc )
242272
243- by_fname = {}
273+ by_fname : dict [ Path , list [ tuple [ str , str ]]] = {}
244274 for fname , before , after in triples :
245275 by_fname .setdefault (Path (fname ), []).append ((before , after ))
246276
0 commit comments