Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions direct/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,11 @@ def load_from_path(
self._load_model(self.model, checkpoint["model"])

for key in checkpointable_objects:
if key not in checkpoint:
self.logger.warning(f"Requested to load {key}, but this was not stored.")
if only_models and not re.match(self.model_regex, key):
continue

if only_models and not re.match(self.model_regex, key):
if key not in checkpoint:
self.logger.warning(f"Requested to load {key}, but this was not stored.")
continue

self.logger.info(f"Loading {key}...")
Expand Down Expand Up @@ -209,7 +209,7 @@ def _load_checkpoint(self, checkpoint_path: PathOrString) -> Dict:
# Check if the path is an URL
if check_is_valid_url(str(checkpoint_path)):
self.logger.info(f"Initializing from remote checkpoint {checkpoint_path}...")
checkpoint_path = _download_or_load_from_cache(checkpoint_path)
checkpoint_path = self._download_or_load_from_cache(checkpoint_path)
self.logger.info(f"Loading downloaded checkpoint {checkpoint_path}.")

checkpoint_path = pathlib.Path(checkpoint_path)
Expand All @@ -231,13 +231,13 @@ def _load_checkpoint(self, checkpoint_path: PathOrString) -> Dict:

return checkpoint

@staticmethod
def _download_or_load_from_cache(url: str) -> pathlib.Path:
# Get final part of url.
file_path = urllib.parse.urlparse(url).path
filename = pathlib.Path(file_path).name

def _download_or_load_from_cache(url: str) -> pathlib.Path:
# Get final part of url.
file_path = urllib.parse.urlparse(url).path
filename = pathlib.Path(file_path).name

cache_path = DIRECT_MODEL_DOWNLOAD_DIR / filename
download_url(url, DIRECT_MODEL_DOWNLOAD_DIR, max_redirect_hops=3)
cache_path = DIRECT_MODEL_DOWNLOAD_DIR / filename
download_url(url, DIRECT_MODEL_DOWNLOAD_DIR, max_redirect_hops=3)

return cache_path
return cache_path
3 changes: 2 additions & 1 deletion direct/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def register_parser(parser: argparse._SubParsersAction):
"When another checkpoint would be available and the --resume flag is used, "
"this flag is ignored. This can be a path to a file or an URL. "
"If a URL is given the checkpoint will first be downloaded to the environmental variable "
"`DIRECT_MODEL_DOWNLOAD_DIR` (default=current directory).",
"`DIRECT_MODEL_DOWNLOAD_DIR` (default=current directory). Be aware that if `model_checkpoint` is "
"set in the configuration that this flag will overwrite the configuration value, also in the dumped config.",
)
train_parser.add_argument("--resume", help="Resume training if possible.", action="store_true")
train_parser.add_argument(
Expand Down
5 changes: 4 additions & 1 deletion direct/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) DIRECT Contributors

from dataclasses import dataclass, field
from typing import Any, List, Optional
from typing import Any, List, Optional, Union

from omegaconf import MISSING

Expand Down Expand Up @@ -43,6 +43,9 @@ class TrainingConfig(BaseConfig):
# Dataset
datasets: List[Any] = field(default_factory=lambda: [DatasetConfig()])

# model_checkpoint gives the checkpoint from which we can load the *model* weights.
model_checkpoint: Optional[str] = None

# Optimizer
optimizer: str = "Adam"
lr: float = 5e-4
Expand Down
16 changes: 13 additions & 3 deletions direct/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,19 @@ def setup_train(
# Just to make sure.
torch.cuda.empty_cache()

# Check the initialization checkpoint
if env.cfg.training.model_checkpoint:
if initialization_checkpoint:
logger.warning(
f"`--initialization-checkpoint is set, and config has a set `training.model_checkpoint`: "
f"{env.cfg.model_checkpoint}. Will overwrite config variable with the command line: "
f"{initialization_checkpoint}."
)
# Now overwrite this in the configuration, so the correct value is dumped.
env.cfg.training.model_checkpoint = str(initialization_checkpoint)
else:
initialization_checkpoint = env.cfg.training.model_checkpoint

env.engine.train(
optimizer,
lr_scheduler,
Expand All @@ -265,9 +278,6 @@ def train_from_argparse(args: argparse.Namespace):
torch.set_num_threads(1)
os.environ["OMP_NUM_THREADS"] = "1"

# Remove warnings from named tensors being experimental
os.environ["PYTHONWARNINGS"] = "ignore"

if args.initialization_images is not None and args.initialization_kspace is not None:
sys.exit("--initialization-images and --initialization-kspace are mutually exclusive.")
check_train_val(args.initialization_images, "initialization-images")
Expand Down