Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/experiment_generator/perturbation_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,27 @@ def _select_from_list(row: list):
if isinstance(inner, list) and len(inner) == 0:
continue

# handle mixed positional inventories, such as [PRESERVE, {input: ...}]
if any(isinstance(d, Mapping) for d in inner):
slots = []
for i in inner:
# keep markers as is
if _is_removed_str(i) or _is_preserved_str(i):
slots.append(i)
continue

# recurse into dicts
if isinstance(i, Mapping):
keep_v, cleaned = _filter_value(
self._extract_run_specific_params(i, indx, total_exps)
)
slots.append(cleaned if keep_v else {})
continue

# scalar slots
slots.append(i)

# inner is all dicts
if all(isinstance(d, Mapping) for d in inner):
# recurse for each dict
items = []
Expand Down
5 changes: 5 additions & 0 deletions src/experiment_generator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,11 @@ def update_config_entries(
if not should_apply:
continue # no change for this key; keep existing value

# if the incoming value is a single value list but the base value is scalar,
# then treat it as a scalar.
if isinstance(v, list) and len(v) == 1 and not isinstance(base.get(k), list):
v = v[0]

key_path = _path_join(path, str(k))

if isinstance(v, Mapping) and isinstance(base.get(k), Mapping):
Expand Down
13 changes: 13 additions & 0 deletions tests/test_perturbation_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,19 @@ def test_manage_perturb_expt_creat_branches_applies_updates_and_commits(
# sequence branch: inner list is [PRESERVE]
({"queue3": [["PRESERVE"]]}, 0, 2, {"queue3": ["PRESERVE"]}),
({"queue3": [["PRESERVE"]]}, 1, 2, {"queue3": ["PRESERVE"]}),
# mixed positional list of lists with markers and mappings and scalars
(
{"submodels": [["PRESERVE", {}, "foo"]]},
0,
2,
{"submodels": ["PRESERVE", None, "foo"]},
),
(
{"submodels": [["REMOVE", {}, "foo"]]},
1,
2,
{"submodels": ["REMOVE", None, "foo"]},
),
],
)
def test_extract_run_specific_params_rules(tmp_repo_dir, indata, param_dict, indx, total, expected):
Expand Down
10 changes: 10 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,16 @@ def test_merge_lists_positional_mapping_branch_uses_current_mapping_slot():
assert out == [{"a": 1, "keep": 0, "b": 2}]


def test_update_config_entries_unwraps_single_item_list_to_scalar_when_base_is_scalar():
base = {"ncpus": 4}

changes = {"ncpus": [8]} # single-item list

update_config_entries(base, changes, pop_key=True)

assert base == {"ncpus": 8}


# def test_empty_mapping_in_change_keeps_existing_slot():
# base = {"lst": [{"a": 1}, {"b": 2}]}
# # empty mapping also means PRESERVE
Expand Down
Loading