@@ -40,6 +40,7 @@ def __init__(
4040 loader : TorchDataLoader ,
4141 output_keys : Union [str , Sequence [str ]] = CommonKeys .PRED ,
4242 batch_keys : Union [str , Sequence [str ]] = CommonKeys .IMAGE ,
43+ meta_key_postfix : str = "meta_dict" ,
4344 collate_fn : Optional [Callable ] = no_collation ,
4445 postfix : str = "_inverted" ,
4546 nearest_interp : Union [bool , Sequence [bool ]] = True ,
@@ -56,6 +57,10 @@ def __init__(
5657 batch_keys: the key of input data in `ignite.engine.batch`. will get the applied transforms
5758 for this input data, then invert them for the expected data with `output_keys`.
5859 It can also be a list of keys, each matches to the `output_keys` data. default to "image".
60+ meta_key_postfix: use `{batch_key}_{postfix}` to to fetch the meta data according to the key data,
61+ default is `meta_dict`, the meta data is a dictionary object.
62+ For example, to handle key `image`, read/write affine matrices from the
63+ metadata `image_meta_dict` dictionary's `affine` field.
5964 postfix: will save the inverted result into `ignite.engine.output` with key `{output_key}{postfix}`.
6065 nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms,
6166 default to `True`. If `False`, use the same interpolation mode as the original transform.
@@ -74,6 +79,7 @@ def __init__(
7479 )
7580 self .output_keys = ensure_tuple (output_keys )
7681 self .batch_keys = ensure_tuple_rep (batch_keys , len (self .output_keys ))
82+ self .meta_key_postfix = ensure_tuple_rep (meta_key_postfix , len (output_keys ))
7783 self .postfix = postfix
7884 self .nearest_interp = ensure_tuple_rep (nearest_interp , len (self .output_keys ))
7985 self ._totensor = ToTensor ()
@@ -90,7 +96,9 @@ def __call__(self, engine: Engine) -> None:
9096 Args:
9197 engine: Ignite Engine, it can be a trainer, validator or evaluator.
9298 """
93- for output_key , batch_key , nearest_interp in zip (self .output_keys , self .batch_keys , self .nearest_interp ):
99+ for output_key , batch_key , nearest_interp , meta_key in zip (
100+ self .output_keys , self .batch_keys , self .nearest_interp , self .meta_key_postfix
101+ ):
94102 transform_key = batch_key + InverseKeys .KEY_SUFFIX
95103 if transform_key not in engine .state .batch :
96104 warnings .warn (f"all the transforms on `{ batch_key } ` are not InvertibleTransform." )
@@ -104,6 +112,9 @@ def __call__(self, engine: Engine) -> None:
104112 batch_key : engine .state .output [output_key ].detach ().cpu (),
105113 transform_key : transform_info ,
106114 }
115+ meta_dict_key = f"{ batch_key } _{ meta_key } "
116+ if meta_dict_key in engine .state .batch :
117+ segs_dict [meta_dict_key ] = engine .state .batch [meta_dict_key ]
107118
108119 with allow_missing_keys_mode (self .transform ): # type: ignore
109120 inverted_key = f"{ output_key } { self .postfix } "
0 commit comments