|
31 | 31 | from superduper.base.metadata import Job |
32 | 32 |
|
33 | 33 |
|
| 34 | +def ensure_setup(func): |
| 35 | + """Decorator to ensure that the model is initialized before calling the function. |
| 36 | +
|
| 37 | + :param func: Decorator function. |
| 38 | + """ |
| 39 | + |
| 40 | + @wraps(func) |
| 41 | + def wrapper(self, *args, **kwargs): |
| 42 | + if not getattr(self, "_is_setup", False): |
| 43 | + model_message = f"{self.__class__.__name__} : {self.identifier}" |
| 44 | + logging.debug(f"Initializing {model_message}") |
| 45 | + self.setup() |
| 46 | + self._is_setup = True |
| 47 | + logging.debug(f"Initialized {model_message} successfully") |
| 48 | + return func(self, *args, **kwargs) |
| 49 | + |
| 50 | + return wrapper |
| 51 | + |
| 52 | + |
34 | 53 | def propagate_failure(f): |
35 | 54 | """Propagate failure decorator. |
36 | 55 |
|
@@ -180,6 +199,7 @@ def __post_init__(self, db: t.Optional['Datalayer'] = None): |
180 | 199 | self.postinit() |
181 | 200 |
|
182 | 201 | @property |
| 202 | + @ensure_setup |
183 | 203 | def metadata(self): |
184 | 204 | """Get metadata of the component.""" |
185 | 205 | return {k: getattr(self, k) for k in self.metadata_fields} |
@@ -308,10 +328,14 @@ def get_children(self, deep: bool = False) -> t.List["Component"]: |
308 | 328 |
|
309 | 329 | :param deep: If set `True` get all recursively. |
310 | 330 | """ |
311 | | - from superduper.base.datatype import Saveable |
| 331 | + from superduper.base.datatype import ComponentRef, Saveable |
312 | 332 |
|
313 | 333 | r = self.dict().encode(leaves_to_keep=(Component, Saveable)) |
314 | | - out = [v for v in r['_builds'].values() if isinstance(v, Component)] |
| 334 | + out = [ |
| 335 | + v.setup() |
| 336 | + for v in r['_builds'].values() |
| 337 | + if isinstance(v, (Component, ComponentRef)) |
| 338 | + ] |
315 | 339 | lookup = {} |
316 | 340 | for v in out: |
317 | 341 | lookup[id(v)] = v |
@@ -725,22 +749,3 @@ def hash(self): |
725 | 749 | breaking = hash_item(breaking_hashes) |
726 | 750 | non_breaking = hash_item(non_breaking_hashes) |
727 | 751 | return breaking[:32] + non_breaking[:32] |
728 | | - |
729 | | - |
730 | | -def ensure_setup(func): |
731 | | - """Decorator to ensure that the model is initialized before calling the function. |
732 | | -
|
733 | | - :param func: Decorator function. |
734 | | - """ |
735 | | - |
736 | | - @wraps(func) |
737 | | - def wrapper(self, *args, **kwargs): |
738 | | - if not getattr(self, "_is_setup", False): |
739 | | - model_message = f"{self.__class__.__name__} : {self.identifier}" |
740 | | - logging.debug(f"Initializing {model_message}") |
741 | | - self.setup() |
742 | | - self._is_setup = True |
743 | | - logging.debug(f"Initialized {model_message} successfully") |
744 | | - return func(self, *args, **kwargs) |
745 | | - |
746 | | - return wrapper |
|
0 commit comments