1515import os
1616import re
1717import shutil
18- from dataclasses import dataclass
1918from datetime import timedelta
2019from pathlib import Path
2120from typing import Any , Dict , Iterable , Optional , Union
2726from pytorch_lightning .callbacks .model_checkpoint import _is_local_file_protocol
2827from pytorch_lightning .utilities import rank_zero_info
2928
30- from nemo .collections .common .callbacks import EMA
3129from nemo .utils import logging
3230from nemo .utils .app_state import AppState
33- from nemo .utils .exp_manager import get_git_diff , get_git_hash
34- from nemo .utils .get_rank import is_global_rank_zero
35- from nemo .utils .lightning_logger_patch import add_filehandlers_to_pl_logger
3631from nemo .utils .model_utils import ckpt_to_dir
3732
3833
@@ -74,6 +69,10 @@ def __init__(
7469 )
7570
7671 def on_train_start (self , trainer , pl_module ):
72+ from nemo .utils .exp_manager import get_git_diff , get_git_hash
73+ from nemo .utils .get_rank import is_global_rank_zero
74+ from nemo .utils .lightning_logger_patch import add_filehandlers_to_pl_logger
75+
7776 app_state = AppState ()
7877 if self .save_top_k != - 1 and app_state .restore :
7978 logging .debug ("Checking previous runs" )
@@ -205,6 +204,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
205204 self ._remove_invalid_entries_from_topk ()
206205
207206 def setup (self , * args , ** kwargs ) -> None :
207+ from nemo .utils .get_rank import is_global_rank_zero
208+
208209 if is_global_rank_zero ():
209210 logging .debug ("Removing unfinished checkpoints if any..." )
210211 ModelCheckpoint ._remove_unfinished_checkpoints (self .dirpath )
@@ -260,6 +261,7 @@ def on_train_end(self, trainer, pl_module):
260261 trainer ._checkpoint_connector .restore (self .best_model_path )
261262
262263 def _del_model_without_trainer (self , filepath : str ) -> None :
264+ from nemo .utils .get_rank import is_global_rank_zero
263265
264266 filepath = Path (filepath )
265267
@@ -273,7 +275,9 @@ def _del_model_without_trainer(self, filepath: str) -> None:
273275 if torch .distributed .is_initialized ():
274276 torch .distributed .barrier ()
275277
276- def _ema_callback (self , trainer : 'pytorch_lightning.Trainer' ) -> Optional [EMA ]:
278+ def _ema_callback (self , trainer : 'pytorch_lightning.Trainer' ):
279+ from nemo .collections .common .callbacks import EMA
280+
277281 ema_callback = None
278282 for callback in trainer .callbacks :
279283 if isinstance (callback , EMA ):
@@ -321,6 +325,8 @@ def set_checkpoint_unfinished_marker(checkpoint_path: Union[Path, str], barrier_
321325 barrier_after: Synchronize ranks after writing the marker file.
322326 Defaults to False.
323327 """
328+ from nemo .utils .get_rank import is_global_rank_zero
329+
324330 if is_global_rank_zero ():
325331 marker_path = ModelCheckpoint .format_checkpoint_unfinished_marker_path (checkpoint_path )
326332 marker_path .parent .mkdir (parents = True , exist_ok = True )
@@ -338,6 +344,8 @@ def remove_checkpoint_unfinished_marker(checkpoint_path: Union[Path, str], barri
338344 barrier_before: Synchronize ranks before removing the marker file.
339345 Defaults to False.
340346 """
347+ from nemo .utils .get_rank import is_global_rank_zero
348+
341349 try :
342350 if barrier_before and torch .distributed .is_initialized ():
343351 torch .distributed .barrier ()
@@ -434,6 +442,7 @@ def _saved_checkpoint_paths(self) -> Iterable[Path]:
434442
435443 @staticmethod
436444 def _remove_unfinished_checkpoints (checkpoint_dir : Union [Path , str ]) -> None :
445+ from nemo .utils .get_rank import is_global_rank_zero
437446
438447 # Delete unfinished checkpoints from the filesystems.
439448 # "Unfinished marker" files are removed as well.
0 commit comments