Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 61 additions & 1 deletion hydra/_internal/config_loader_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,16 @@
import sys
import warnings
from textwrap import dedent
from typing import Any, List, MutableSequence, Optional, Tuple
from typing import (
Any,
List,
MutableSequence,
Optional,
Tuple,
Union,
get_args,
get_origin,
)

from omegaconf import Container, DictConfig, OmegaConf, flag_override, open_dict
from omegaconf.errors import (
Expand Down Expand Up @@ -548,6 +557,11 @@ def _compose_config_from_defaults_list(
for default in defaults:
loaded = self._load_single_config(default=default, repo=repo)
try:
if isinstance(cfg, DictConfig) and isinstance(
loaded.config, DictConfig
):
self._materialize_structures(cfg, loaded.config)

cfg.merge_with(loaded.config)
except OmegaConfBaseException as e:
raise ConfigCompositionException(
Expand All @@ -572,6 +586,52 @@ def strip_defaults(cfg: Any) -> None:

return cfg

def _materialize_structures(self, dest: DictConfig, src: DictConfig) -> None:
"""
Recursively materialize None-valued DictConfig nodes in dest if src has a corresponding DictConfig override.
Uses _metadata.ref_type (internal API) to inspect the underlying structured type of a None node.
"""
if not isinstance(dest, DictConfig) or not isinstance(src, DictConfig):
return

for key in src:
try:
dest_node = dest._get_node(key)
except (ConfigKeyError, ConfigAttributeError):
continue
except AttributeError:
# AttributeError observed in Hydra tests when dest contains unresolved interpolation nodes;
# skipping materialization avoids raising during compose.
continue

if not isinstance(dest_node, DictConfig):
continue

if dest_node._is_none():
src_node = src._get_node(key)
if not isinstance(src_node, DictConfig):
continue

ref_type = dest_node._metadata.ref_type
if ref_type is not Any:
if get_origin(ref_type) is Union:
args = [a for a in get_args(ref_type) if a is not type(None)]
if len(args) == 1:
ref_type = args[0]

try:
if ref_type is not Any:
dest[key] = OmegaConf.structured(ref_type)
else:
dest[key] = {}
except Exception:
dest[key] = {}

else:
src_node = src._get_node(key)
if isinstance(src_node, DictConfig):
self._materialize_structures(dest_node, src_node)

def get_sources(self) -> List[ConfigSource]:
return self.repository.get_sources()

Expand Down
20 changes: 20 additions & 0 deletions tests/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,3 +894,23 @@ class Conf:

cfg = compose("conf")
assert cfg == {"enum_dict": {}, "int_dict": {}, "str_dict": {}}


def test_compose_merge_into_none_structured_node(hydra_restore_singletons: Any) -> None:
@dataclass
class Child:
pass

@dataclass
class Config:
child: Optional[Child] = None

ConfigStore.instance().store(name="config_2502", node=Config)
ConfigStore.instance().store(
group="group_2502", name="option", node={"child": {}}, package="_global_"
)

with initialize(version_base=None):
cfg = compose(config_name="config_2502", overrides=["+group_2502=option"])
assert cfg.child is not None
assert OmegaConf.get_type(cfg.child) is Child