Skip to content

Commit e4b313d

Browse files
authored
Auto3DSeg continue training (skip trained algos) (#6310)
Second PR for issue #6291 Since the previous PR #6290 was reverted #6295 Allows skipping the already trained algos, and continuing training only for the non-trained ones. After this PR, the default option AutoRunner(train=None) will have this behavior, whereas manually setting AutoRunner(train=True/False) will always train all or skip all training. Previously we could only train all or skip all (without any option to resume). I changed import_bundle_algo_history() to return a better algo_dict. Previously it returned "list[dict(name: algo)]" - a list of dicts, but each dict had to have a single key name "name => algo". Now it returns a list of dicts, each with several keys dict(AlgoEnsembleKeys.ID: name, AlgoEnsembleKeys.ALGO: algo, "is_trained": bool, etc). This allows putting additional metadata inside each algo_dict, and it's easier to read it back. Previously, to get a name we had to use "name = history[0].keys()[0]"; now it's more elegant: "name = history[0][AlgoEnsembleKeys.ID]". This however required changing many files, everywhere import_bundle_algo_history and export_bundle_algo_history were used. All the tests have passed, except for the "integration GPU utilization tests", but those errors seem unrelated. After this PR, tutorials need to be updated too: Project-MONAI/tutorials#1288 --------- Signed-off-by: myron <amyronenko@nvidia.com>
1 parent 06defb7 commit e4b313d

File tree

11 files changed

+162
-106
lines changed

11 files changed

+162
-106
lines changed

.github/workflows/integration.yml

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,27 +41,33 @@ jobs:
4141
python -m pip install --upgrade torch torchvision torchaudio
4242
python -m pip install -r requirements-dev.txt
4343
rm -rf /github/home/.cache/torch/hub/mmars/
44-
- name: Run integration tests
44+
- name: Clean directory
4545
run: |
4646
python -m pip list
4747
git config --global --add safe.directory /__w/MONAI/MONAI
4848
git clean -ffdx
4949
nvidia-smi
5050
export CUDA_VISIBLE_DEVICES=$(python -m tests.utils -c 1 | tail -n 1)
5151
echo $CUDA_VISIBLE_DEVICES
52-
trap 'if pgrep python; then pkill python; fi;' ERR
53-
python -c $'import torch\na=[torch.zeros(1,device=f"cuda:{i}") for i in range(torch.cuda.device_count())];\nwhile True:print(a)' > /dev/null &
5452
python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))"
5553
python -c 'import torch; print(torch.rand(5,3, device=torch.device("cuda:0")))'
5654
57-
# test auto3dseg
58-
echo "test tag algo"
55+
- name: Auto3dseg tag algo
56+
shell: bash
57+
env:
58+
BUILD_MONAI: 0
59+
run: |
5960
BUILD_MONAI=0 ./runtests.sh --build
6061
python -m tests.test_auto3dseg_ensemble
6162
python -m tests.test_auto3dseg_hpo
6263
python -m tests.test_integration_autorunner
6364
python -m tests.test_integration_gpu_customization
6465
66+
- name: Auto3dseg latest algo
67+
shell: bash
68+
env:
69+
BUILD_MONAI: 0
70+
run: |
6571
# test latest template
6672
echo "test latest algo"
6773
cd ../
@@ -81,14 +87,24 @@ jobs:
8187
python -m tests.test_integration_autorunner
8288
python -m tests.test_integration_gpu_customization
8389
84-
# the other tests
85-
echo "the other tests"
90+
- name: Integration tests
91+
shell: bash
92+
env:
93+
BUILD_MONAI: 1
94+
run: |
8695
pwd
8796
ls -ll
88-
BUILD_MONAI=1 ./runtests.sh --build --net
89-
BUILD_MONAI=1 ./runtests.sh --build --unittests
90-
if pgrep python; then pkill python; fi
97+
./runtests.sh --build --net
98+
99+
- name: Unit tests
91100
shell: bash
101+
env:
102+
BUILD_MONAI: 1
103+
run: |
104+
pwd
105+
ls -ll
106+
./runtests.sh --unittests
107+
92108
- name: Add reaction
93109
uses: peter-evans/create-or-update-comment@v2
94110
if: github.event.pull_request.number != ''

monai/apps/auto3dseg/auto_runner.py

Lines changed: 71 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def __init__(
281281
# determine if we need to analyze, algo_gen or train from cache, unless manually provided
282282
self.analyze = not self.cache["analyze"] if analyze is None else analyze
283283
self.algo_gen = not self.cache["algo_gen"] if algo_gen is None else algo_gen
284-
self.train = not self.cache["train"] if train is None else train
284+
self.train = train
285285
self.ensemble = ensemble # last step, no need to check
286286

287287
self.set_training_params()
@@ -635,13 +635,15 @@ def _train_algo_in_sequence(self, history: list[dict[str, Any]]) -> None:
635635
folders under the working directory. The results include the model checkpoints, a
636636
progress.yaml, accuracies in CSV and a pickle file of the Algo object.
637637
"""
638-
for task in history:
639-
for _, algo in task.items():
640-
algo.train(self.train_params)
641-
acc = algo.get_score()
642-
algo_to_pickle(algo, template_path=algo.template_path, best_metrics=acc)
638+
for algo_dict in history:
639+
algo = algo_dict[AlgoEnsembleKeys.ALGO]
640+
algo.train(self.train_params)
641+
acc = algo.get_score()
643642

644-
def _train_algo_in_nni(self, history):
643+
algo_meta_data = {str(AlgoEnsembleKeys.SCORE): acc}
644+
algo_to_pickle(algo, template_path=algo.template_path, **algo_meta_data)
645+
646+
def _train_algo_in_nni(self, history: list[dict[str, Any]]) -> None:
645647
"""
646648
Train the Algos using HPO.
647649
@@ -672,40 +674,41 @@ def _train_algo_in_nni(self, history):
672674

673675
last_total_tasks = len(import_bundle_algo_history(self.work_dir, only_trained=True))
674676
mode_dry_run = self.hpo_params.pop("nni_dry_run", False)
675-
for task in history:
676-
for name, algo in task.items():
677-
nni_gen = NNIGen(algo=algo, params=self.hpo_params)
678-
obj_filename = nni_gen.get_obj_filename()
679-
nni_config = deepcopy(default_nni_config)
680-
# override the default nni config with the same key in hpo_params
681-
for key in self.hpo_params:
682-
if key in nni_config:
683-
nni_config[key] = self.hpo_params[key]
684-
nni_config.update({"experimentName": name})
685-
nni_config.update({"search_space": self.search_space})
686-
trial_cmd = "python -m monai.apps.auto3dseg NNIGen run_algo " + obj_filename + " " + self.work_dir
687-
nni_config.update({"trialCommand": trial_cmd})
688-
nni_config_filename = os.path.abspath(os.path.join(self.work_dir, f"{name}_nni_config.yaml"))
689-
ConfigParser.export_config_file(nni_config, nni_config_filename, fmt="yaml", default_flow_style=None)
690-
691-
max_trial = min(self.hpo_tasks, cast(int, default_nni_config["maxTrialNumber"]))
692-
cmd = "nnictl create --config " + nni_config_filename + " --port 8088"
693-
694-
if mode_dry_run:
695-
logger.info(f"AutoRunner HPO is in dry-run mode. Please manually launch: {cmd}")
696-
continue
697-
698-
subprocess.run(cmd.split(), check=True)
699-
677+
for algo_dict in history:
678+
name = algo_dict[AlgoEnsembleKeys.ID]
679+
algo = algo_dict[AlgoEnsembleKeys.ALGO]
680+
nni_gen = NNIGen(algo=algo, params=self.hpo_params)
681+
obj_filename = nni_gen.get_obj_filename()
682+
nni_config = deepcopy(default_nni_config)
683+
# override the default nni config with the same key in hpo_params
684+
for key in self.hpo_params:
685+
if key in nni_config:
686+
nni_config[key] = self.hpo_params[key]
687+
nni_config.update({"experimentName": name})
688+
nni_config.update({"search_space": self.search_space})
689+
trial_cmd = "python -m monai.apps.auto3dseg NNIGen run_algo " + obj_filename + " " + self.work_dir
690+
nni_config.update({"trialCommand": trial_cmd})
691+
nni_config_filename = os.path.abspath(os.path.join(self.work_dir, f"{name}_nni_config.yaml"))
692+
ConfigParser.export_config_file(nni_config, nni_config_filename, fmt="yaml", default_flow_style=None)
693+
694+
max_trial = min(self.hpo_tasks, cast(int, default_nni_config["maxTrialNumber"]))
695+
cmd = "nnictl create --config " + nni_config_filename + " --port 8088"
696+
697+
if mode_dry_run:
698+
logger.info(f"AutoRunner HPO is in dry-run mode. Please manually launch: {cmd}")
699+
continue
700+
701+
subprocess.run(cmd.split(), check=True)
702+
703+
n_trainings = len(import_bundle_algo_history(self.work_dir, only_trained=True))
704+
while n_trainings - last_total_tasks < max_trial:
705+
sleep(1)
700706
n_trainings = len(import_bundle_algo_history(self.work_dir, only_trained=True))
701-
while n_trainings - last_total_tasks < max_trial:
702-
sleep(1)
703-
n_trainings = len(import_bundle_algo_history(self.work_dir, only_trained=True))
704707

705-
cmd = "nnictl stop --all"
706-
subprocess.run(cmd.split(), check=True)
707-
logger.info(f"NNI completes HPO on {name}")
708-
last_total_tasks = n_trainings
708+
cmd = "nnictl stop --all"
709+
subprocess.run(cmd.split(), check=True)
710+
logger.info(f"NNI completes HPO on {name}")
711+
last_total_tasks = n_trainings
709712

710713
def run(self):
711714
"""
@@ -758,7 +761,8 @@ def run(self):
758761
logger.info("Skipping algorithm generation...")
759762

760763
# step 3: algo training
761-
if self.train:
764+
auto_train_choice = self.train is None
765+
if self.train or (auto_train_choice and not self.cache["train"]):
762766
history = import_bundle_algo_history(self.work_dir, only_trained=False)
763767

764768
if len(history) == 0:
@@ -767,20 +771,40 @@ def run(self):
767771
"Possibly the required algorithms generation step was not completed."
768772
)
769773

770-
if not self.hpo:
771-
self._train_algo_in_sequence(history)
772-
else:
773-
self._train_algo_in_nni(history)
774+
if auto_train_choice:
775+
skip_algos = [h[AlgoEnsembleKeys.ID] for h in history if h["is_trained"]]
776+
if len(skip_algos) > 0:
777+
logger.info(
778+
f"Skipping already trained algos {skip_algos}."
779+
"Set option train=True to always retrain all algos."
780+
)
781+
history = [h for h in history if not h["is_trained"]]
782+
783+
if len(history) > 0:
784+
if not self.hpo:
785+
self._train_algo_in_sequence(history)
786+
else:
787+
self._train_algo_in_nni(history)
788+
774789
self.export_cache(train=True)
775790
else:
776791
logger.info("Skipping algorithm training...")
777792

778793
# step 4: model ensemble and write the prediction to disks.
779794
if self.ensemble:
780-
history = import_bundle_algo_history(self.work_dir, only_trained=True)
795+
history = import_bundle_algo_history(self.work_dir, only_trained=False)
796+
797+
history_untrained = [h for h in history if not h["is_trained"]]
798+
if len(history_untrained) > 0:
799+
warnings.warn(
800+
f"Ensembling step will skip {[h['name'] for h in history_untrained]} untrained algos."
801+
"Generally it means these algos did not complete training."
802+
)
803+
history = [h for h in history if h["is_trained"]]
804+
781805
if len(history) == 0:
782806
raise ValueError(
783-
f"Could not find the trained results in {self.work_dir}. "
807+
f"Could not find any trained algos in {self.work_dir}. "
784808
"Possibly the required training step was not completed."
785809
)
786810

@@ -798,4 +822,4 @@ def run(self):
798822
self.save_image(pred)
799823
logger.info(f"Auto3Dseg ensemble prediction outputs are saved in {self.output_dir}.")
800824

801-
logger.info("Auto3Dseg pipeline is complete successfully.")
825+
logger.info("Auto3Dseg pipeline is completed successfully.")

monai/apps/auto3dseg/bundle_gen.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from monai.auto3dseg.utils import algo_to_pickle
3434
from monai.bundle.config_parser import ConfigParser
3535
from monai.utils import ensure_tuple
36+
from monai.utils.enums import AlgoEnsembleKeys
3637

3738
logger = get_logger(module_name=__name__)
3839
ALGO_HASH = os.environ.get("MONAI_ALGO_HASH", "7758ad1")
@@ -537,4 +538,6 @@ def generate(
537538
gen_algo.export_to_disk(output_folder, name, fold=f_id)
538539

539540
algo_to_pickle(gen_algo, template_path=algo.template_path)
540-
self.history.append({name: gen_algo}) # track the previous, may create a persistent history
541+
self.history.append(
542+
{AlgoEnsembleKeys.ID: name, AlgoEnsembleKeys.ALGO: gen_algo}
543+
) # track the previous, may create a persistent history

monai/apps/auto3dseg/ensemble_builder.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -267,22 +267,20 @@ class AlgoEnsembleBuilder:
267267
268268
"""
269269

270-
def __init__(self, history: Sequence[dict], data_src_cfg_filename: str | None = None):
270+
def __init__(self, history: Sequence[dict[str, Any]], data_src_cfg_filename: str | None = None):
271271
self.infer_algos: list[dict[AlgoEnsembleKeys, Any]] = []
272272
self.ensemble: AlgoEnsemble
273273
self.data_src_cfg = ConfigParser(globals=False)
274274

275275
if data_src_cfg_filename is not None and os.path.exists(str(data_src_cfg_filename)):
276276
self.data_src_cfg.read_config(data_src_cfg_filename)
277277

278-
for h in history:
278+
for algo_dict in history:
279279
# load inference_config_paths
280-
# raise warning/error if not found
281-
if len(h) > 1:
282-
raise ValueError(f"{h} should only contain one set of genAlgo key-value")
283280

284-
name = list(h.keys())[0]
285-
gen_algo = h[name]
281+
name = algo_dict[AlgoEnsembleKeys.ID]
282+
gen_algo = algo_dict[AlgoEnsembleKeys.ALGO]
283+
286284
best_metric = gen_algo.get_score()
287285
algo_path = gen_algo.output_path
288286
infer_path = os.path.join(algo_path, "scripts", "infer.py")

monai/apps/auto3dseg/hpo_gen.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from monai.bundle.config_parser import ConfigParser
2424
from monai.config import PathLike
2525
from monai.utils import optional_import
26+
from monai.utils.enums import AlgoEnsembleKeys
2627

2728
nni, has_nni = optional_import("nni")
2829
optuna, has_optuna = optional_import("optuna")
@@ -98,8 +99,8 @@ class NNIGen(HPOGen):
9899
# Bundle Algorithms are already generated by BundleGen in work_dir
99100
import_bundle_algo_history(work_dir, only_trained=False)
100101
algo_dict = self.history[0] # pick the first algorithm
101-
algo_name = list(algo_dict.keys())[0]
102-
onealgo = algo_dict[algo_name]
102+
algo_name = algo_dict[AlgoEnsembleKeys.ID]
103+
onealgo = algo_dict[AlgoEnsembleKeys.ALGO]
103104
nni_gen = NNIGen(algo=onealgo)
104105
nni_gen.print_bundle_algo_instruction()
105106
@@ -237,10 +238,12 @@ def run_algo(self, obj_filename: str, output_folder: str = ".", template_path: P
237238
self.algo.train(self.params)
238239
# step 4 report validation acc to controller
239240
acc = self.algo.get_score()
241+
algo_meta_data = {str(AlgoEnsembleKeys.SCORE): acc}
242+
240243
if isinstance(self.algo, BundleAlgo):
241-
algo_to_pickle(self.algo, template_path=self.algo.template_path, best_metrics=acc)
244+
algo_to_pickle(self.algo, template_path=self.algo.template_path, **algo_meta_data)
242245
else:
243-
algo_to_pickle(self.algo, best_metrics=acc)
246+
algo_to_pickle(self.algo, **algo_meta_data)
244247
self.set_score(acc)
245248

246249

@@ -408,8 +411,9 @@ def run_algo(self, obj_filename: str, output_folder: str = ".", template_path: P
408411
self.algo.train(self.params)
409412
# step 4 report validation acc to controller
410413
acc = self.algo.get_score()
414+
algo_meta_data = {str(AlgoEnsembleKeys.SCORE): acc}
411415
if isinstance(self.algo, BundleAlgo):
412-
algo_to_pickle(self.algo, template_path=self.algo.template_path, best_metrics=acc)
416+
algo_to_pickle(self.algo, template_path=self.algo.template_path, **algo_meta_data)
413417
else:
414-
algo_to_pickle(self.algo, best_metrics=acc)
418+
algo_to_pickle(self.algo, **algo_meta_data)
415419
self.set_score(acc)

monai/apps/auto3dseg/utils.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515

1616
from monai.apps.auto3dseg.bundle_gen import BundleAlgo
1717
from monai.auto3dseg import algo_from_pickle, algo_to_pickle
18+
from monai.utils.enums import AlgoEnsembleKeys
1819

1920

2021
def import_bundle_algo_history(
2122
output_folder: str = ".", template_path: str | None = None, only_trained: bool = True
2223
) -> list:
2324
"""
24-
import the history of the bundleAlgo object with their names/identifiers
25+
import the history of the bundleAlgo objects as a list of algo dicts.
26+
each algo_dict has keys name (folder name), algo (bundleAlgo), is_trained (bool),
2527
2628
Args:
2729
output_folder: the root path of the algorithms templates.
@@ -47,11 +49,18 @@ def import_bundle_algo_history(
4749
if isinstance(algo, BundleAlgo): # algo's template path needs override
4850
algo.template_path = algo_meta_data["template_path"]
4951

50-
if only_trained:
51-
if "best_metrics" in algo_meta_data:
52-
history.append({name: algo})
53-
else:
54-
history.append({name: algo})
52+
best_metric = algo_meta_data.get(AlgoEnsembleKeys.SCORE, None)
53+
is_trained = best_metric is not None
54+
55+
if (only_trained and is_trained) or not only_trained:
56+
history.append(
57+
{
58+
AlgoEnsembleKeys.ID: name,
59+
AlgoEnsembleKeys.ALGO: algo,
60+
AlgoEnsembleKeys.SCORE: best_metric,
61+
"is_trained": is_trained,
62+
}
63+
)
5564

5665
return history
5766

@@ -63,6 +72,6 @@ def export_bundle_algo_history(history: list[dict[str, BundleAlgo]]) -> None:
6372
Args:
6473
history: a List of Bundle. Typically, the history can be obtained from BundleGen get_history method
6574
"""
66-
for task in history:
67-
for _, algo in task.items():
68-
algo_to_pickle(algo, template_path=algo.template_path)
75+
for algo_dict in history:
76+
algo = algo_dict[AlgoEnsembleKeys.ALGO]
77+
algo_to_pickle(algo, template_path=algo.template_path)

0 commit comments

Comments (0)