Skip to content

Commit bfc472e

Browse files
committed
Add more typing to safer
1 parent 1d481eb commit bfc472e

File tree

1 file changed

+57
-44
lines changed

1 file changed

+57
-44
lines changed

safer/__init__.py

Lines changed: 57 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -147,32 +147,27 @@
147147
import contextlib
148148
import functools
149149
import io
150-
import json
151150
import os
152151
import shutil
153152
import sys
154153
import tempfile
155154
import traceback
155+
import typing as t
156156
from pathlib import Path
157-
from typing import IO, Callable, Optional, Union
158-
159-
# There's an edge case in #23 I can't yet fix, so I fail
160-
# deliberately
161-
BUG_MESSAGE = 'Sorry, safer.writer fails if temp_file (#23)'
162157

163158
__all__ = 'writer', 'open', 'closer', 'dump', 'printer'
164159

165160

166161
def writer(
167-
stream: Union[Callable, None, IO, Path, str] = None,
168-
is_binary: Optional[bool] = None,
162+
stream: t.Union[t.Callable, None, t.IO, Path, str] = None,
163+
is_binary: t.Optional[bool] = None,
169164
close_on_exit: bool = False,
170165
temp_file: bool = False,
171166
chunk_size: int = 0x100000,
172167
delete_failures: bool = True,
173-
dry_run: Union[bool, Callable] = False,
168+
dry_run: t.Union[bool, t.Callable] = False,
174169
enabled: bool = True,
175-
) -> Union[Callable, IO]:
170+
) -> t.Union[t.Callable, t.IO]:
176171
"""
177172
Write safely to file streams, sockets and callables.
178173
@@ -233,7 +228,7 @@ def writer(
233228
if not enabled:
234229
return stream
235230

236-
write: Optional[Callable]
231+
write: t.Optional[t.Callable]
237232

238233
if callable(dry_run):
239234
write, dry_run = dry_run, True
@@ -307,21 +302,26 @@ def write(v):
307302
return closer.fp
308303

309304

305+
# There's an edge case in #23 I can't yet fix, so I fail
306+
# deliberately
307+
BUG_MESSAGE = 'Sorry, safer.writer fails if temp_file (#23)'
308+
309+
310310
def open(
311-
name: Union[Path, str],
311+
name: t.Union[Path, str],
312312
mode: str = 'r',
313313
buffering: int = -1,
314-
encoding: Optional[str] = None,
315-
errors: Optional[str] = None,
316-
newline: Optional[str] = None,
314+
encoding: t.Optional[str] = None,
315+
errors: t.Optional[str] = None,
316+
newline: t.Optional[str] = None,
317317
closefd: bool = True,
318-
opener: Optional[Callable] = None,
318+
opener: t.Optional[t.Callable] = None,
319319
make_parents: bool = False,
320320
delete_failures: bool = True,
321321
temp_file: bool = False,
322-
dry_run: Union[bool, Callable] = False,
322+
dry_run: t.Union[bool, t.Callable] = False,
323323
enabled: bool = True,
324-
) -> IO:
324+
) -> t.IO:
325325
"""
326326
Args:
327327
make_parents: If true, create the parent directory of the file if needed
@@ -443,7 +443,9 @@ def simple_write(value):
443443
return closer._make_stream(buffering, mode, **kwargs)
444444

445445

446-
def closer(stream, is_binary=None, close_on_exit=True, **kwds):
446+
def closer(
447+
stream: t.IO, is_binary: t.Optional[bool] = None, close_on_exit: bool = True, **kwds
448+
) -> t.Union[t.Callable, t.IO]:
447449
"""
448450
Like `safer.writer()` but with `close_on_exit=True` by default
449451
@@ -453,7 +455,12 @@ def closer(stream, is_binary=None, close_on_exit=True, **kwds):
453455
return writer(stream, is_binary, close_on_exit, **kwds)
454456

455457

456-
def dump(obj, stream=None, dump=None, **kwargs):
458+
def dump(
459+
obj,
460+
stream: t.Union[t.Callable, None, t.IO, Path, str] = None,
461+
dump: t.Any = None,
462+
**kwargs,
463+
) -> t.Any:
457464
"""
458465
Safely serialize `obj` as a formatted stream to `fp`` (a
459466
`.write()`-supporting file-like object, or a filename),
@@ -476,23 +483,34 @@ def dump(obj, stream=None, dump=None, **kwargs):
476483
kwargs:
477484
Additional arguments to `dump`.
478485
"""
479-
if isinstance(stream, str):
480-
name = stream
481-
is_binary = False
482-
else:
483-
name = getattr(stream, 'name', None)
484-
mode = getattr(stream, 'mode', None)
486+
if not isinstance(stream, str):
487+
name = getattr(stream, 'name', '')
488+
mode = getattr(stream, 'mode', '')
485489
if name and mode:
486490
is_binary = 'b' in mode
487491
else:
488492
is_binary = hasattr(stream, 'recv') and hasattr(stream, 'send')
493+
else:
494+
name = stream
495+
is_binary = False
489496

490-
if name and not dump:
491-
dump = Path(name).suffix[1:] or None
492-
if dump == 'yml':
493-
dump = 'yaml'
497+
dump = _get_dumper(dump or Path(name).suffix[1:])
498+
499+
with t.cast(t.IO, writer(stream)) as fp:
500+
if is_binary:
501+
write = fp.write
502+
fp.write = lambda s: write(s.encode('utf-8')) # type: ignore
494503

504+
return dump(obj, fp)
505+
506+
507+
def _get_dumper(dump: t.Any) -> t.Callable:
495508
if isinstance(dump, str):
509+
if not dump:
510+
dump = 'json'
511+
elif dump == 'yml':
512+
dump = 'yaml'
513+
496514
try:
497515
dump = __import__(dump)
498516
except ImportError:
@@ -501,24 +519,19 @@ def dump(obj, stream=None, dump=None, **kwargs):
501519
mod, name = dump.rsplit('.', maxsplit=1)
502520
dump = getattr(__import__(mod), name)
503521

504-
if dump is None:
505-
dump = json.dump
506-
507-
elif not callable(dump):
508-
try:
509-
dump = dump.safe_dump
510-
except AttributeError:
511-
dump = dump.dump
522+
if callable(dump):
523+
return dump
512524

513-
with writer(stream) as fp:
514-
if is_binary:
515-
write = fp.write
516-
fp.write = lambda s: write(s.encode('utf-8'))
517-
return dump(obj, fp)
525+
try:
526+
return dump.safe_dump
527+
except AttributeError:
528+
return dump.dump
518529

519530

520531
@contextlib.contextmanager
521-
def printer(name, mode='w', *args, **kwargs):
532+
def printer(
533+
name: t.Union[Path, str], mode: str = 'w', *args, **kwargs
534+
) -> t.Generator[t.Callable, None, None]:
522535
"""
523536
A context manager that yields a function that prints to the opened file,
524537
only writing to the original file at the exit of the context,

0 commit comments

Comments
 (0)