Skip to content

Commit 617c1be

Browse files
authored
6873 data analyzer histogram_only=True fix (#6874)
Fixes #6873 ### Description - fixes data analyzer - replace `"image_stats"` with `DataStatsKeys.IMAGE_STATS` ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent e24b969 commit 617c1be

File tree

4 files changed

+32
-10
lines changed

4 files changed

+32
-10
lines changed

monai/apps/auto3dseg/data_analyzer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ def _check_data_uniformity(keys: list[str], result: dict) -> bool:
161161
162162
"""
163163

164+
if DataStatsKeys.SUMMARY not in result or DataStatsKeys.IMAGE_STATS not in result[DataStatsKeys.SUMMARY]:
165+
return True
164166
constant_props = [result[DataStatsKeys.SUMMARY][DataStatsKeys.IMAGE_STATS][key] for key in keys]
165167
for prop in constant_props:
166168
if "stdev" in prop and np.any(prop["stdev"]):
@@ -358,10 +360,11 @@ def _get_all_case_stats(
358360
stats_by_cases = {
359361
DataStatsKeys.BY_CASE_IMAGE_PATH: d[DataStatsKeys.BY_CASE_IMAGE_PATH],
360362
DataStatsKeys.BY_CASE_LABEL_PATH: d[DataStatsKeys.BY_CASE_LABEL_PATH],
361-
DataStatsKeys.IMAGE_STATS: d[DataStatsKeys.IMAGE_STATS],
362363
}
364+
if not self.histogram_only:
365+
stats_by_cases[DataStatsKeys.IMAGE_STATS] = d[DataStatsKeys.IMAGE_STATS]
363366
if self.hist_bins != 0:
364-
stats_by_cases.update({DataStatsKeys.IMAGE_HISTOGRAM: d[DataStatsKeys.IMAGE_HISTOGRAM]})
367+
stats_by_cases[DataStatsKeys.IMAGE_HISTOGRAM] = d[DataStatsKeys.IMAGE_HISTOGRAM]
365368

366369
if self.label_key is not None:
367370
stats_by_cases.update(

monai/auto3dseg/analyzer.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ class ImageStats(Analyzer):
198198
199199
"""
200200

201-
def __init__(self, image_key: str, stats_name: str = "image_stats") -> None:
201+
def __init__(self, image_key: str, stats_name: str = DataStatsKeys.IMAGE_STATS) -> None:
202202
if not isinstance(image_key, str):
203203
raise ValueError("image_key input must be str")
204204

@@ -296,7 +296,7 @@ class FgImageStats(Analyzer):
296296
297297
"""
298298

299-
def __init__(self, image_key: str, label_key: str, stats_name: str = "image_foreground_stats"):
299+
def __init__(self, image_key: str, label_key: str, stats_name: str = DataStatsKeys.FG_IMAGE_STATS):
300300
self.image_key = image_key
301301
self.label_key = label_key
302302

@@ -378,7 +378,9 @@ class LabelStats(Analyzer):
378378
379379
"""
380380

381-
def __init__(self, image_key: str, label_key: str, stats_name: str = "label_stats", do_ccp: bool | None = True):
381+
def __init__(
382+
self, image_key: str, label_key: str, stats_name: str = DataStatsKeys.LABEL_STATS, do_ccp: bool | None = True
383+
):
382384
self.image_key = image_key
383385
self.label_key = label_key
384386
self.do_ccp = do_ccp
@@ -533,7 +535,7 @@ class ImageStatsSumm(Analyzer):
533535
534536
"""
535537

536-
def __init__(self, stats_name: str = "image_stats", average: bool | None = True):
538+
def __init__(self, stats_name: str = DataStatsKeys.IMAGE_STATS, average: bool | None = True):
537539
self.summary_average = average
538540
report_format = {
539541
ImageStatsKeys.SHAPE: None,
@@ -623,7 +625,7 @@ class FgImageStatsSumm(Analyzer):
623625
624626
"""
625627

626-
def __init__(self, stats_name: str = "image_foreground_stats", average: bool | None = True):
628+
def __init__(self, stats_name: str = DataStatsKeys.FG_IMAGE_STATS, average: bool | None = True):
627629
self.summary_average = average
628630

629631
report_format = {ImageStatsKeys.INTENSITY: None}
@@ -687,7 +689,9 @@ class LabelStatsSumm(Analyzer):
687689
688690
"""
689691

690-
def __init__(self, stats_name: str = "label_stats", average: bool | None = True, do_ccp: bool | None = True):
692+
def __init__(
693+
self, stats_name: str = DataStatsKeys.LABEL_STATS, average: bool | None = True, do_ccp: bool | None = True
694+
):
691695
self.summary_average = average
692696
self.do_ccp = do_ccp
693697

monai/auto3dseg/seg_summarizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,9 @@ def __init__(
100100
self.summary_analyzers: list[Any] = []
101101
super().__init__()
102102

103+
self.add_analyzer(FilenameStats(image_key, DataStatsKeys.BY_CASE_IMAGE_PATH), None)
104+
self.add_analyzer(FilenameStats(label_key, DataStatsKeys.BY_CASE_LABEL_PATH), None)
103105
if not self.histogram_only:
104-
self.add_analyzer(FilenameStats(image_key, DataStatsKeys.BY_CASE_IMAGE_PATH), None)
105-
self.add_analyzer(FilenameStats(label_key, DataStatsKeys.BY_CASE_LABEL_PATH), None)
106106
self.add_analyzer(ImageStats(image_key), ImageStatsSumm(average=average))
107107

108108
if label_key is None:

tests/test_auto3dseg.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,21 @@ def test_data_analyzer_cpu(self, input_params):
190190

191191
assert len(datastat["stats_by_cases"]) == len(sim_datalist["training"])
192192

193+
def test_data_analyzer_histogram(self):
194+
create_sim_data(
195+
self.dataroot_dir, sim_datalist, [32] * 3, image_only=True, rad_max=8, rad_min=1, num_seg_classes=1
196+
)
197+
analyser = DataAnalyzer(
198+
self.datalist_file,
199+
self.dataroot_dir,
200+
output_path=self.datastat_file,
201+
label_key=None,
202+
device=device,
203+
histogram_only=True,
204+
)
205+
datastat = analyser.get_all_case_stats()
206+
assert len(datastat["stats_by_cases"]) == len(sim_datalist["training"])
207+
193208
@parameterized.expand(SIM_GPU_TEST_CASES)
194209
@skip_if_no_cuda
195210
def test_data_analyzer_gpu(self, input_params):

0 commit comments

Comments
 (0)