Skip to content

Commit 1bf2ccd

Browse files
committed
follow up of #1992
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent b3eb73a commit 1bf2ccd

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

monai/data/test_time_augmentation.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ class TestTimeAugmentation:
6262
device: device on which to perform inference.
6363
image_key: key used to extract image from input dictionary.
6464
label_key: key used to extract label from input dictionary.
65+
meta_key_postfix: use `key_{postfix}` to to fetch the meta data according to the key data,
66+
default is `meta_dict`, the meta data is a dictionary object.
67+
For example, to handle key `image`, read/write affine matrices from the
68+
metadata `image_meta_dict` dictionary's `affine` field.
6569
return_full_data: normally, metrics are returned (mode, mean, std, vvc). Setting this flag to `True` will return the
6670
full data. Dimensions will be same size as when passing a single image through `inferrer_fn`, with a dimension appended
6771
equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`.
@@ -88,6 +92,7 @@ def __init__(
8892
device: Optional[Union[str, torch.device]] = "cuda" if torch.cuda.is_available() else "cpu",
8993
image_key=CommonKeys.IMAGE,
9094
label_key=CommonKeys.LABEL,
95+
meta_key_postfix="meta_dict",
9196
return_full_data: bool = False,
9297
progress: bool = True,
9398
) -> None:
@@ -98,6 +103,7 @@ def __init__(
98103
self.device = device
99104
self.image_key = image_key
100105
self.label_key = label_key
106+
self.meta_key_postfix = meta_key_postfix
101107
self.return_full_data = return_full_data
102108
self.progress = progress
103109

@@ -168,7 +174,7 @@ def __call__(
168174
# create a dictionary containing the inferred batch and their transforms
169175
inferred_dict = {self.label_key: batch_output, label_transform_key: batch_data[label_transform_key]}
170176
# if meta dict is present, add that too (required for some inverse transforms)
171-
label_meta_dict_key = self.label_key + "_meta_dict"
177+
label_meta_dict_key = f"{self.label_key}_{self.meta_key_postfix}"
172178
if label_meta_dict_key in batch_data:
173179
inferred_dict[label_meta_dict_key] = batch_data[label_meta_dict_key]
174180

monai/handlers/transform_inverter.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(
4040
loader: TorchDataLoader,
4141
output_keys: Union[str, Sequence[str]] = CommonKeys.PRED,
4242
batch_keys: Union[str, Sequence[str]] = CommonKeys.IMAGE,
43+
meta_key_postfix: str = "meta_dict",
4344
collate_fn: Optional[Callable] = no_collation,
4445
postfix: str = "_inverted",
4546
nearest_interp: Union[bool, Sequence[bool]] = True,
@@ -56,6 +57,10 @@ def __init__(
5657
batch_keys: the key of input data in `ignite.engine.batch`. will get the applied transforms
5758
for this input data, then invert them for the expected data with `output_keys`.
5859
It can also be a list of keys, each matches to the `output_keys` data. default to "image".
60+
meta_key_postfix: use `{batch_key}_{postfix}` to to fetch the meta data according to the key data,
61+
default is `meta_dict`, the meta data is a dictionary object.
62+
For example, to handle key `image`, read/write affine matrices from the
63+
metadata `image_meta_dict` dictionary's `affine` field.
5964
postfix: will save the inverted result into `ignite.engine.output` with key `{output_key}{postfix}`.
6065
nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms,
6166
default to `True`. If `False`, use the same interpolation mode as the original transform.
@@ -74,6 +79,7 @@ def __init__(
7479
)
7580
self.output_keys = ensure_tuple(output_keys)
7681
self.batch_keys = ensure_tuple_rep(batch_keys, len(self.output_keys))
82+
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(output_keys))
7783
self.postfix = postfix
7884
self.nearest_interp = ensure_tuple_rep(nearest_interp, len(self.output_keys))
7985
self._totensor = ToTensor()
@@ -90,7 +96,9 @@ def __call__(self, engine: Engine) -> None:
9096
Args:
9197
engine: Ignite Engine, it can be a trainer, validator or evaluator.
9298
"""
93-
for output_key, batch_key, nearest_interp in zip(self.output_keys, self.batch_keys, self.nearest_interp):
99+
for output_key, batch_key, nearest_interp, meta_key in zip(
100+
self.output_keys, self.batch_keys, self.nearest_interp, self.meta_key_postfix
101+
):
94102
transform_key = batch_key + InverseKeys.KEY_SUFFIX
95103
if transform_key not in engine.state.batch:
96104
warnings.warn(f"all the transforms on `{batch_key}` are not InvertibleTransform.")
@@ -104,6 +112,9 @@ def __call__(self, engine: Engine) -> None:
104112
batch_key: engine.state.output[output_key].detach().cpu(),
105113
transform_key: transform_info,
106114
}
115+
meta_dict_key = f"{batch_key}_{meta_key}"
116+
if meta_dict_key in engine.state.batch:
117+
segs_dict[meta_dict_key] = engine.state.batch[meta_dict_key]
107118

108119
with allow_missing_keys_mode(self.transform): # type: ignore
109120
inverted_key = f"{output_key}{self.postfix}"

tests/test_handler_transform_inverter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
RandZoomd,
3333
ResizeWithPadOrCropd,
3434
ScaleIntensityd,
35+
Spacingd,
3536
ToTensord,
3637
)
3738
from monai.utils.misc import set_determinism
@@ -48,6 +49,7 @@ def test_invert(self):
4849
[
4950
LoadImaged(KEYS),
5051
AddChanneld(KEYS),
52+
Spacingd(KEYS, pixdim=(1.1, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
5153
ScaleIntensityd("image", minv=1, maxv=10),
5254
RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
5355
RandAxisFlipd(KEYS, prob=0.5),
@@ -102,7 +104,8 @@ def _train_func(engine, batch):
102104
reverted_name = engine.state.output["label_meta_dict"]["filename_or_obj"][-1]
103105
original_name = data[-1]["label"]
104106
self.assertEqual(reverted_name, original_name)
105-
self.assertTrue((reverted.size - n_good) in (0, 23641), "diff. in two possible values")
107+
print("invert diff", reverted.size - n_good)
108+
self.assertTrue((reverted.size - n_good) in (0, 981), "diff. in two possible values")
106109

107110

108111
if __name__ == "__main__":

0 commit comments

Comments
 (0)