Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 11 additions & 13 deletions pre_commit_hooks/fix_encoding_pragma.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
from typing import Sequence
from typing import Union

DEFAULT_PRAGMA = b'# -*- coding: utf-8 -*-\n'
DEFAULT_PRAGMA = b'# -*- coding: utf-8 -*-'


def has_coding(line): # type: (bytes) -> bool
if not line.strip():
return False
return (
line.lstrip()[0:1] == b'#' and (
line.lstrip()[:1] == b'#' and (
b'unicode' in line or
b'encoding' in line or
b'coding:' in line or
Expand All @@ -26,7 +26,7 @@ def has_coding(line): # type: (bytes) -> bool


class ExpectedContents(collections.namedtuple(
'ExpectedContents', ('shebang', 'rest', 'pragma_status'),
'ExpectedContents', ('shebang', 'rest', 'pragma_status', 'ending'),
)):
"""
pragma_status:
Expand All @@ -47,6 +47,8 @@ def is_expected_pragma(self, remove): # type: (bool) -> bool

def _get_expected_contents(first_line, second_line, rest, expected_pragma):
# type: (bytes, bytes, bytes, bytes) -> ExpectedContents
ending = b'\r\n' if first_line.endswith(b'\r\n') else b'\n'

if first_line.startswith(b'#!'):
shebang = first_line
potential_coding = second_line
Expand All @@ -55,7 +57,7 @@ def _get_expected_contents(first_line, second_line, rest, expected_pragma):
potential_coding = first_line
rest = second_line + rest

if potential_coding == expected_pragma:
if potential_coding.rstrip(b'\r\n') == expected_pragma:
pragma_status = True # type: Optional[bool]
elif has_coding(potential_coding):
pragma_status = None
Expand All @@ -64,7 +66,7 @@ def _get_expected_contents(first_line, second_line, rest, expected_pragma):
rest = potential_coding + rest

return ExpectedContents(
shebang=shebang, rest=rest, pragma_status=pragma_status,
shebang=shebang, rest=rest, pragma_status=pragma_status, ending=ending,
)


Expand Down Expand Up @@ -93,7 +95,7 @@ def fix_encoding_pragma(f, remove=False, expected_pragma=DEFAULT_PRAGMA):
f.truncate()
f.write(expected.shebang)
if not remove:
f.write(expected_pragma)
f.write(expected_pragma + expected.ending)
f.write(expected.rest)

return 1
Expand All @@ -102,11 +104,7 @@ def fix_encoding_pragma(f, remove=False, expected_pragma=DEFAULT_PRAGMA):
def _normalize_pragma(pragma): # type: (Union[bytes, str]) -> bytes
if not isinstance(pragma, bytes):
pragma = pragma.encode('UTF-8')
return pragma.rstrip() + b'\n'


def _to_disp(pragma): # type: (bytes) -> str
return pragma.decode().rstrip()
return pragma.rstrip()


def main(argv=None): # type: (Optional[Sequence[str]]) -> int
Expand All @@ -117,7 +115,7 @@ def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser.add_argument(
'--pragma', default=DEFAULT_PRAGMA, type=_normalize_pragma,
help='The encoding pragma to use. Default: {}'.format(
_to_disp(DEFAULT_PRAGMA),
DEFAULT_PRAGMA.decode(),
),
)
parser.add_argument(
Expand All @@ -141,7 +139,7 @@ def main(argv=None): # type: (Optional[Sequence[str]]) -> int
retv |= file_ret
if file_ret:
print(fmt.format(
pragma=_to_disp(args.pragma), filename=filename,
pragma=args.pragma.decode(), filename=filename,
))

return retv
Expand Down
23 changes: 18 additions & 5 deletions tests/fix_encoding_pragma_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,15 @@ def test_not_ok_inputs(input_str, output):
def test_ok_input_alternate_pragma():
input_s = b'# coding: utf-8\nx = 1\n'
bytesio = io.BytesIO(input_s)
ret = fix_encoding_pragma(bytesio, expected_pragma=b'# coding: utf-8\n')
ret = fix_encoding_pragma(bytesio, expected_pragma=b'# coding: utf-8')
assert ret == 0
bytesio.seek(0)
assert bytesio.read() == input_s


def test_not_ok_input_alternate_pragma():
bytesio = io.BytesIO(b'x = 1\n')
ret = fix_encoding_pragma(bytesio, expected_pragma=b'# coding: utf-8\n')
ret = fix_encoding_pragma(bytesio, expected_pragma=b'# coding: utf-8')
assert ret == 1
bytesio.seek(0)
assert bytesio.read() == b'# coding: utf-8\nx = 1\n'
Expand All @@ -130,11 +130,11 @@ def test_not_ok_input_alternate_pragma():
('input_s', 'expected'),
(
# Python 2 cli parameters are bytes
(b'# coding: utf-8', b'# coding: utf-8\n'),
(b'# coding: utf-8', b'# coding: utf-8'),
# Python 3 cli parameters are text
('# coding: utf-8', b'# coding: utf-8\n'),
('# coding: utf-8', b'# coding: utf-8'),
# trailing whitespace
('# coding: utf-8\n', b'# coding: utf-8\n'),
('# coding: utf-8\n', b'# coding: utf-8'),
),
)
def test_normalize_pragma(input_s, expected):
Expand All @@ -150,3 +150,16 @@ def test_integration_alternate_pragma(tmpdir, capsys):
assert f.read() == '# coding: utf-8\nx = 1\n'
out, _ = capsys.readouterr()
assert out == 'Added `# coding: utf-8` to {}\n'.format(f.strpath)


def test_crlf_ok(tmpdir):
f = tmpdir.join('f.py')
f.write_binary(b'# -*- coding: utf-8 -*-\r\nx = 1\r\n')
assert not main((f.strpath,))


def test_crfl_adds(tmpdir):
f = tmpdir.join('f.py')
f.write_binary(b'x = 1\r\n')
assert main((f.strpath,))
assert f.read_binary() == b'# -*- coding: utf-8 -*-\r\nx = 1\r\n'