diff --git a/pre_commit_hooks/fix_encoding_pragma.py b/pre_commit_hooks/fix_encoding_pragma.py index bde4e78a..23fc79fd 100644 --- a/pre_commit_hooks/fix_encoding_pragma.py +++ b/pre_commit_hooks/fix_encoding_pragma.py @@ -9,14 +9,14 @@ from typing import Sequence from typing import Union -DEFAULT_PRAGMA = b'# -*- coding: utf-8 -*-\n' +DEFAULT_PRAGMA = b'# -*- coding: utf-8 -*-' def has_coding(line): # type: (bytes) -> bool if not line.strip(): return False return ( - line.lstrip()[0:1] == b'#' and ( + line.lstrip()[:1] == b'#' and ( b'unicode' in line or b'encoding' in line or b'coding:' in line or @@ -26,7 +26,7 @@ def has_coding(line): # type: (bytes) -> bool class ExpectedContents(collections.namedtuple( - 'ExpectedContents', ('shebang', 'rest', 'pragma_status'), + 'ExpectedContents', ('shebang', 'rest', 'pragma_status', 'ending'), )): """ pragma_status: @@ -47,6 +47,8 @@ def is_expected_pragma(self, remove): # type: (bool) -> bool def _get_expected_contents(first_line, second_line, rest, expected_pragma): # type: (bytes, bytes, bytes, bytes) -> ExpectedContents + ending = b'\r\n' if first_line.endswith(b'\r\n') else b'\n' + if first_line.startswith(b'#!'): shebang = first_line potential_coding = second_line @@ -55,7 +57,7 @@ def _get_expected_contents(first_line, second_line, rest, expected_pragma): potential_coding = first_line rest = second_line + rest - if potential_coding == expected_pragma: + if potential_coding.rstrip(b'\r\n') == expected_pragma: pragma_status = True # type: Optional[bool] elif has_coding(potential_coding): pragma_status = None @@ -64,7 +66,7 @@ def _get_expected_contents(first_line, second_line, rest, expected_pragma): rest = potential_coding + rest return ExpectedContents( - shebang=shebang, rest=rest, pragma_status=pragma_status, + shebang=shebang, rest=rest, pragma_status=pragma_status, ending=ending, ) @@ -93,7 +95,7 @@ def fix_encoding_pragma(f, remove=False, expected_pragma=DEFAULT_PRAGMA): f.truncate() f.write(expected.shebang) 
if not remove: - f.write(expected_pragma) + f.write(expected_pragma + expected.ending) f.write(expected.rest) return 1 @@ -102,11 +104,7 @@ def fix_encoding_pragma(f, remove=False, expected_pragma=DEFAULT_PRAGMA): def _normalize_pragma(pragma): # type: (Union[bytes, str]) -> bytes if not isinstance(pragma, bytes): pragma = pragma.encode('UTF-8') - return pragma.rstrip() + b'\n' - - -def _to_disp(pragma): # type: (bytes) -> str - return pragma.decode().rstrip() + return pragma.rstrip() def main(argv=None): # type: (Optional[Sequence[str]]) -> int @@ -117,7 +115,7 @@ def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser.add_argument( '--pragma', default=DEFAULT_PRAGMA, type=_normalize_pragma, help='The encoding pragma to use. Default: {}'.format( - _to_disp(DEFAULT_PRAGMA), + DEFAULT_PRAGMA.decode(), ), ) parser.add_argument( @@ -141,7 +139,7 @@ def main(argv=None): # type: (Optional[Sequence[str]]) -> int retv |= file_ret if file_ret: print(fmt.format( - pragma=_to_disp(args.pragma), filename=filename, + pragma=args.pragma.decode(), filename=filename, )) return retv diff --git a/tests/fix_encoding_pragma_test.py b/tests/fix_encoding_pragma_test.py index 7288bfa1..d94b7256 100644 --- a/tests/fix_encoding_pragma_test.py +++ b/tests/fix_encoding_pragma_test.py @@ -112,7 +112,7 @@ def test_not_ok_inputs(input_str, output): def test_ok_input_alternate_pragma(): input_s = b'# coding: utf-8\nx = 1\n' bytesio = io.BytesIO(input_s) - ret = fix_encoding_pragma(bytesio, expected_pragma=b'# coding: utf-8\n') + ret = fix_encoding_pragma(bytesio, expected_pragma=b'# coding: utf-8') assert ret == 0 bytesio.seek(0) assert bytesio.read() == input_s @@ -120,7 +120,7 @@ def test_ok_input_alternate_pragma(): def test_not_ok_input_alternate_pragma(): bytesio = io.BytesIO(b'x = 1\n') - ret = fix_encoding_pragma(bytesio, expected_pragma=b'# coding: utf-8\n') + ret = fix_encoding_pragma(bytesio, expected_pragma=b'# coding: utf-8') assert ret == 1 bytesio.seek(0) assert 
bytesio.read() == b'# coding: utf-8\nx = 1\n' @@ -130,11 +130,11 @@ def test_not_ok_input_alternate_pragma(): ('input_s', 'expected'), ( # Python 2 cli parameters are bytes - (b'# coding: utf-8', b'# coding: utf-8\n'), + (b'# coding: utf-8', b'# coding: utf-8'), # Python 3 cli parameters are text - ('# coding: utf-8', b'# coding: utf-8\n'), + ('# coding: utf-8', b'# coding: utf-8'), # trailing whitespace - ('# coding: utf-8\n', b'# coding: utf-8\n'), + ('# coding: utf-8\n', b'# coding: utf-8'), ), ) def test_normalize_pragma(input_s, expected): @@ -150,3 +150,16 @@ def test_integration_alternate_pragma(tmpdir, capsys): assert f.read() == '# coding: utf-8\nx = 1\n' out, _ = capsys.readouterr() assert out == 'Added `# coding: utf-8` to {}\n'.format(f.strpath) + + +def test_crlf_ok(tmpdir): + f = tmpdir.join('f.py') + f.write_binary(b'# -*- coding: utf-8 -*-\r\nx = 1\r\n') + assert not main((f.strpath,)) + + +def test_crlf_adds(tmpdir): + f = tmpdir.join('f.py') + f.write_binary(b'x = 1\r\n') + assert main((f.strpath,)) + assert f.read_binary() == b'# -*- coding: utf-8 -*-\r\nx = 1\r\n'