diff --git a/src/orion/core/worker/consumer.py b/src/orion/core/worker/consumer.py index 165227c3b..52085eb23 100644 --- a/src/orion/core/worker/consumer.py +++ b/src/orion/core/worker/consumer.py @@ -118,9 +118,6 @@ def __call__(self, trial, **kwargs): True if the trial was successfully executed. False if the trial is broken. """ - log.debug("Consumer context: %s", trial.working_dir) - os.makedirs(trial.working_dir, exist_ok=True) - results_file = self._consume(trial, trial.working_dir) log.debug("Parsing results from file and fill corresponding Trial object.") @@ -197,9 +194,16 @@ def get_execution_environment(self, trial, results_file="results.log"): return env - def _consume(self, trial, workdirname): + def _prepare_config(self, trial, workdirname): + log.debug("Consumer context: %s", trial.working_dir) + os.makedirs(trial.working_dir, exist_ok=True) + + if self.template_builder.file_config_path: + _, suffix = os.path.splitext(self.template_builder.file_config_path) + else: + suffix = ".conf" config_file = tempfile.NamedTemporaryFile( - mode="w", prefix="trial_", suffix=".conf", dir=workdirname, delete=False + mode="w", prefix="trial_", suffix=suffix, dir=workdirname, delete=False ) config_file.close() log.debug("New temp config file: %s", config_file.name) @@ -209,6 +213,10 @@ def _consume(self, trial, workdirname): results_file.close() log.debug("New temp results file: %s", results_file.name) + return config_file, results_file + + def _consume(self, trial, workdirname): + config_file, results_file = self._prepare_config(trial, workdirname) log.debug("Building command line argument and configuration for trial.") env = self.get_execution_environment(trial, results_file.name) cmd_args = self.template_builder.format( diff --git a/tests/unittests/core/worker/test_consumer.py b/tests/unittests/core/worker/test_consumer.py index 1a491a178..0bd1ef730 100644 --- a/tests/unittests/core/worker/test_consumer.py +++ b/tests/unittests/core/worker/test_consumer.py @@ -79,6 +79,40 @@ def test_trial_working_dir_is_created(config): shutil.rmtree(trial.working_dir) +@pytest.mark.usefixtures("storage") +@pytest.mark.parametrize( + "config_path_name", + ["yaml_sample_path", "json_sample_path", "unknown_type_template_path"], +) +def test_trial_config_file_is_created_with_correct_ext( + config, config_path_name, request +): + """Check that trial config file is created with correct extension.""" + config_path = request.getfixturevalue(config_path_name) + config["metadata"]["user_args"].insert(1, f"--config={config_path}") + backward.populate_space(config) + exp = experiment_builder.build(**config) + + trial = exp.space.sample()[0] + + exp.register_trial(trial, status="reserved") + + assert not os.path.exists(trial.working_dir) + + con = Consumer(exp) + config_file, results_file = con._prepare_config(trial, trial.working_dir) + + assert os.path.exists(trial.working_dir) + assert os.path.exists(config_file.name) + + _, original_ext = os.path.splitext(config_path) + _, tmp_ext = os.path.splitext(config_file.name) + + assert original_ext == tmp_ext + + shutil.rmtree(trial.working_dir) + + def setup_code_change_mock(config, monkeypatch, ignore_code_changes): """Mock create experiment and trials, and infer_versioning_metadata""" exp = experiment_builder.build(**config)