Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions homeassistant/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,11 +330,40 @@ def _async_stop_handler(self, *args):
self.exit_code = 0
self.loop.create_task(self.async_stop())

@asyncio.coroutine
def _async_check_config_and_restart(self):
"""Restart Home Assistant if config is valid.

This method is a coroutine.
"""
proc = yield from asyncio.create_subprocess_exec(
sys.argv[0],
'--script',
'check_config',
stdout=asyncio.subprocess.PIPE)
# Wait for the subprocess exit
(stdout_data, dummy) = yield from proc.communicate()
result = yield from proc.wait()
if result:
_LOGGER.error("check_config failed. Not restarting.")
content = re.sub(r'\033\[[^m]*m', '', str(stdout_data, 'utf-8'))
# Put error cleaned from color codes in the error log so it
# will be visible at the UI.
_LOGGER.error(content)
yield from self.services.async_call(
'persistent_notification', 'create', {
'message': 'Config error. See dev-info panel for details.',
'title': 'Restarting',
'notification_id': '{}.restart'.format(DOMAIN)})
return

self.exit_code = RESTART_EXIT_CODE
yield from self.async_stop()

@callback
def _async_restart_handler(self, *args):
"""Restart Home Assistant."""
self.exit_code = RESTART_EXIT_CODE
self.loop.create_task(self.async_stop())
self.loop.create_task(self._async_check_config_and_restart())


class EventOrigin(enum.Enum):
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/scripts/check_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ def run(script_args: List) -> int:
domain_info = args.info.split(',')

res = check(config_path)

if args.files:
print(color(C_HEAD, 'yaml files'), '(used /',
color('red', 'not used') + ')')
Expand Down Expand Up @@ -247,6 +246,7 @@ def mock_package_error( # pylint: disable=unused-variable
res['secret_cache'] = dict(yaml.__SECRET_CACHE)
except Exception as err: # pylint: disable=broad-except
print(color('red', 'Fatal error while loading config:'), str(err))
res['except'].setdefault(ERROR_STR, []).append(err)
finally:
# Stop all patches
for pat in PATCHES.values():
Expand Down
5 changes: 5 additions & 0 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,11 @@ def coro():
return coro


def mock_generator(return_value=None):
"""Helper method to return a coro generator that returns a value."""
return mock_coro(return_value)()


@contextmanager
def assert_setup_component(count, domain=None):
"""Collect valid configuration from setup_component.
Expand Down
38 changes: 36 additions & 2 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
import homeassistant.util.dt as dt_util
from homeassistant.util.unit_system import (METRIC_SYSTEM)
from homeassistant.const import (
__version__, EVENT_STATE_CHANGED, ATTR_FRIENDLY_NAME, CONF_UNIT_SYSTEM)
__version__, EVENT_STATE_CHANGED, ATTR_FRIENDLY_NAME, CONF_UNIT_SYSTEM,
SERVICE_HOMEASSISTANT_RESTART, RESTART_EXIT_CODE)

from tests.common import get_test_home_assistant
from tests.common import get_test_home_assistant, mock_generator

PST = pytz.timezone('America/Los_Angeles')

Expand Down Expand Up @@ -220,6 +221,39 @@ def test_add_job_with_none(self):
with pytest.raises(ValueError):
self.hass.add_job(None, 'test_arg')

@patch('asyncio.create_subprocess_exec')
def test_restart(self, mock_create):
"""Check that restart propagates to stop."""
process_mock = MagicMock()
attrs = {
'communicate.return_value': mock_generator(('output', 'error')),
'wait.return_value': mock_generator(0)}
process_mock.configure_mock(**attrs)
mock_create.return_value = mock_generator(process_mock)

self.hass.start()
with patch.object(self.hass, 'async_stop') as mock_stop:
self.hass.services.call(ha.DOMAIN, SERVICE_HOMEASSISTANT_RESTART)
mock_stop.assert_called_once_with()
self.assertEqual(RESTART_EXIT_CODE, self.hass.exit_code)

@patch('asyncio.create_subprocess_exec')
def test_restart_bad_config(self, mock_create):
"""Check that restart with a bad config doesn't propagate to stop."""
process_mock = MagicMock()
attrs = {
'communicate.return_value':
mock_generator((r'\033[hellom'.encode('utf-8'), 'error')),
'wait.return_value': mock_generator(1)}
process_mock.configure_mock(**attrs)
mock_create.return_value = mock_generator(process_mock)

self.hass.start()
with patch.object(self.hass, 'async_stop') as mock_stop:
self.hass.services.call(ha.DOMAIN, SERVICE_HOMEASSISTANT_RESTART)
mock_stop.assert_not_called()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tip for the future: don't use assert_not_called because if you would make a typo like calling mock_stop.assert_not_called_obviously_wrong() it will pass because it will return a mock object. Better to test explicitly:

assert mock_stop.called

self.assertEqual(None, self.hass.exit_code)


class TestEvent(unittest.TestCase):
"""A Test Event class."""
Expand Down