Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions direct/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
"""DIRECT Command-line interface. This is the file which builds the main parser. Currently just a placeholder"""
"""DIRECT Command-line interface. This is the file which builds the main parser."""
import argparse
import sys


def main():
"""
Console script for dlup.
Console script for direct.
"""
# From https://stackoverflow.com/questions/17073688/how-to-use-argparse-subparsers-correctly
root_parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

root_subparsers = root_parser.add_subparsers(help="DIRECT utilities.")
root_subparsers = root_parser.add_subparsers(help="Direct CLI utils to run.")
root_subparsers.required = True
root_subparsers.dest = "subcommand"

# Prevent circular import
from direct.cli.train import register_parser as register_train_subcommand

# Training images related commands.
register_train_subcommand(root_subparsers)

args = root_parser.parse_args()
args.subcommand(args)
return 0


if __name__ == "__main__":
sys.exit(main()) # pragma: no cover
main()
67 changes: 67 additions & 0 deletions direct/cli/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
import argparse
import pathlib

from direct.cli.utils import file_or_url
from direct.environment import Args
from direct.train import train_from_argparse


def register_parser(parser: argparse._SubParsersAction):
"""Register wsi commands to a root parser."""

epilog = f"""
Examples:
---------
Run on single machine:
$ direct train training_set validation_set experiment_dir --num-gpus 8 --cfg cfg.yaml
Run on multiple machines:
(machine0)$ direct train training_set validation_set experiment_dir --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
(machine1)$ direct train training_set validation_set experiment_dir --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
"""
common_parser = Args(add_help=False)
train_parser = parser.add_parser(
"train",
help="Train models using direct.",
parents=[common_parser],
epilog=epilog,
formatter_class=argparse.RawDescriptionHelpFormatter,
)

train_parser.add_argument("training_root", type=pathlib.Path, help="Path to the training data.")
train_parser.add_argument("validation_root", type=pathlib.Path, help="Path to the validation data.")
train_parser.add_argument(
"experiment_dir",
type=pathlib.Path,
help="Path to the experiment directory.",
)
train_parser.add_argument(
"--cfg",
dest="cfg_file",
help="Config file for training. Can be either a local file or a remote URL.",
required=True,
type=file_or_url,
)
train_parser.add_argument(
"--initialization-checkpoint",
type=file_or_url,
help="If this value is set to a proper checkpoint when training starts, "
"the model will be initialized with the weights given. "
"No other keys in the checkpoint will be loaded. "
"When another checkpoint would be available and the --resume flag is used, "
"this flag is ignored. This can be a path to a file or an URL. "
"If a URL is given the checkpoint will first be downloaded to the environmental variable "
"`DIRECT_MODEL_DOWNLOAD_DIR` (default=current directory).",
)
train_parser.add_argument("--resume", help="Resume training if possible.", action="store_true")
train_parser.add_argument(
"--force-validation",
help="Start with a validation round, when recovering from a crash. "
"If you use this option, be aware that when combined with --resume, "
"each new run will start with a validation round.",
action="store_true",
)
train_parser.add_argument("--name", help="Run name.", required=False, type=str)

train_parser.set_defaults(subcommand=train_from_argparse)
6 changes: 6 additions & 0 deletions direct/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) DIRECT Contributors
import argparse
import pathlib
import sys

from direct.utils.io import check_is_valid_url

Expand All @@ -13,3 +14,8 @@ def file_or_url(path):
if path.is_file():
return path
raise argparse.ArgumentTypeError(f"{path} is not a valid file or url.")


def check_train_val(key, name):
if key is not None and len(key) != 2:
sys.exit(f"--{name} has to be of the form `train_folder, validation_folder` if a validation folder is set.")
4 changes: 2 additions & 2 deletions direct/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,12 +373,12 @@ class Args(argparse.ArgumentParser):
Defines global default arguments.
"""

def __init__(self, epilog=None, **overrides):
def __init__(self, epilog=None, add_help=True, **overrides):
"""
Args:
**overrides (dict, optional): Keyword arguments used to override default argument values
"""
super().__init__(epilog=epilog, formatter_class=argparse.RawDescriptionHelpFormatter)
super().__init__(epilog=epilog, formatter_class=argparse.RawDescriptionHelpFormatter, add_help=add_help)

self.add_argument(
"--device",
Expand Down
59 changes: 4 additions & 55 deletions tools/train_model.py → direct/train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
import argparse
import functools
import logging
import os
Expand All @@ -12,12 +13,12 @@
import numpy as np
import torch

from direct.cli.utils import file_or_url
from direct.cli.utils import check_train_val
from direct.common.subsample import build_masking_function
from direct.data.datasets import build_dataset_from_input
from direct.data.lr_scheduler import WarmupMultiStepLR
from direct.data.mri_transforms import build_mri_transforms
from direct.environment import Args, setup_training_environment
from direct.environment import setup_training_environment
from direct.launch import launch
from direct.utils import remove_keys, set_all_seeds, str_to_class
from direct.utils.dataset import get_filenames_for_datasets
Expand Down Expand Up @@ -258,12 +259,7 @@ def setup_train(
)


def check_train_val(key, name):
if key is not None and len(key) != 2:
sys.exit(f"--{name} has to be of the form `train_folder, validation_folder` if a validation folder is set.")


if __name__ == "__main__":
def train_from_argparse(args: argparse.Namespace):
# This sets MKL threads to 1.
# DataLoader can otherwise bring a l ot of difficulties when computing CPU FFTs in the transforms.
torch.set_num_threads(1)
Expand All @@ -272,53 +268,6 @@ def check_train_val(key, name):
# Remove warnings from named tensors being experimental
os.environ["PYTHONWARNINGS"] = "ignore"

epilog = f"""
Examples:
Run on single machine:
$ {sys.argv[0]} training_set validation_set experiment_dir --num-gpus 8 --cfg cfg.yaml
Run on multiple machines:
(machine0)$ {sys.argv[0]} training_set validation_set experiment_dir --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
(machine1)$ {sys.argv[0]} training_set validation_set experiment_dir --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
"""

parser = Args(epilog=epilog)
parser.add_argument("training_root", type=pathlib.Path, help="Path to the training data.")
parser.add_argument("validation_root", type=pathlib.Path, help="Path to the validation data.")
parser.add_argument(
"experiment_dir",
type=pathlib.Path,
help="Path to the experiment directory.",
)
parser.add_argument(
"--cfg",
dest="cfg_file",
help="Config file for training. Can be either a local file or a remote URL.",
required=True,
type=file_or_url,
)
parser.add_argument(
"--initialization-checkpoint",
type=file_or_url,
help="If this value is set to a proper checkpoint when training starts, "
"the model will be initialized with the weights given. "
"No other keys in the checkpoint will be loaded. "
"When another checkpoint would be available and the --resume flag is used, "
"this flag is ignored. This can be a path to a file or an URL. "
"If a URL is given the checkpoint will first be downloaded to the environmental variable "
"`DIRECT_MODEL_DOWNLOAD_DIR` (default=current directory).",
)
parser.add_argument("--resume", help="Resume training if possible.", action="store_true")
parser.add_argument(
"--force-validation",
help="Start with a validation round, when recovering from a crash. "
"If you use this option, be aware that when combined with --resume, "
"each new run will start with a validation round.",
action="store_true",
)
parser.add_argument("--name", help="Run name.", required=False, type=str)

args = parser.parse_args()

if args.initialization_images is not None and args.initialization_kspace is not None:
sys.exit("--initialization-images and --initialization-kspace are mutually exclusive.")
check_train_val(args.initialization_images, "initialization-images")
Expand Down
7 changes: 0 additions & 7 deletions tools/README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
# Direct tools

Scripts are provided:
- To train a model use `train_model.py`.
- To extract the best checkpoint based on `metrics.json`, use `parse_metrics_log.py`.


## Tips and tricks

- We are using a lot of experimental features in pytorch, to reduce such warnings you can use
`export PYTHONWARNINGS="ignore"` in the shell before execution.