Skip to content

Commit 9e7b3d6

Browse files
committed
changes after code review
Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com>
1 parent ad1d780 commit 9e7b3d6

File tree

6 files changed

+108
-42
lines changed

6 files changed

+108
-42
lines changed

monai/data/meta_tensor.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717

1818
import torch
1919

20+
from monai.config.type_definitions import NdarrayTensor
2021
from monai.data.meta_obj import MetaObj, get_track_meta, get_track_transforms
2122
from monai.data.utils import decollate_batch, list_data_collate
23+
from monai.transforms.utils import remove_extra_metadata
2224
from monai.utils.enums import PostFix
2325

2426
__all__ = ["MetaTensor"]
@@ -232,3 +234,33 @@ def affine(self) -> torch.Tensor:
232234
def affine(self, d: torch.Tensor) -> None:
233235
"""Set the affine."""
234236
self.meta["affine"] = d
237+
238+
@staticmethod
239+
def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict):
240+
"""
241+
Convert the image to `torch.Tensor`. If `affine` is in the `meta` dictionary,
242+
convert that to `torch.Tensor`, too. Remove any superfluous metadata.
243+
244+
Args:
245+
im: Input image (`np.ndarray` or `torch.Tensor`)
246+
meta: Metadata dictionary.
247+
248+
Returns:
249+
By default, a `MetaTensor` is returned.
250+
However, if `get_track_meta()` is `False`, a `torch.Tensor` is returned.
251+
"""
252+
img = torch.as_tensor(im)
253+
254+
# if not tracking metadata, return `torch.Tensor`
255+
if not get_track_meta() or meta is None:
256+
return img
257+
258+
# ensure affine is of type `torch.Tensor`
259+
if "affine" in meta:
260+
meta["affine"] = torch.as_tensor(meta["affine"])
261+
262+
# remove any superfluous metadata.
263+
remove_extra_metadata(meta)
264+
265+
# return the `MetaTensor`
266+
return MetaTensor(img, meta=meta)

monai/transforms/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,7 @@
569569
generate_label_classes_crop_centers,
570570
generate_pos_neg_label_crop_centers,
571571
generate_spatial_bounding_box,
572+
get_extra_metadata_keys,
572573
get_extreme_points,
573574
get_largest_connected_component_mask,
574575
get_number_image_type_conversions,
@@ -582,6 +583,8 @@
582583
map_spatial_axes,
583584
print_transform_backends,
584585
rand_choice,
586+
remove_extra_metadata,
587+
remove_keys,
585588
rescale_array,
586589
rescale_array_int_max,
587590
rescale_instance_array,

monai/transforms/io/array.py

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from monai.data import image_writer
3030
from monai.data.folder_layout import FolderLayout
3131
from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader
32-
from monai.data.meta_obj import get_track_meta
3332
from monai.data.meta_tensor import MetaTensor
3433
from monai.transforms.transform import Transform
3534
from monai.transforms.utility.array import EnsureChannelFirst
@@ -247,45 +246,11 @@ def __call__(self, filename: Union[Sequence[PathLike], PathLike], reader: Option
247246
meta_data = switch_endianness(meta_data, "<")
248247

249248
meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader
250-
img = self.join_im_and_meta(img_array, meta_data)
249+
img = MetaTensor.ensure_torch_and_prune_meta(img_array, meta_data)
251250
if self.ensure_channel_first:
252251
img = EnsureChannelFirst()(img)
253252
return img
254253

255-
@staticmethod
256-
def join_im_and_meta(im, meta: dict):
257-
img = torch.as_tensor(im)
258-
259-
# if not tracking metadata, return torch.Tensor
260-
if not get_track_meta() or meta is None:
261-
return img
262-
263-
if "affine" in meta:
264-
meta["affine"] = torch.as_tensor(meta["affine"])
265-
266-
# TODO: delete extra metadata
267-
for i in range(8):
268-
for k in ("dim", "pixdim"):
269-
if f"{k}[{i}]" in meta:
270-
del meta[f"{k}[{i}]"]
271-
for k in (
272-
# "original_affine",
273-
# "spatial_shape",
274-
# "spacing",
275-
"srow_x",
276-
"srow_y",
277-
"srow_z",
278-
"quatern_b",
279-
"quatern_c",
280-
"quatern_d",
281-
"qoffset_x",
282-
"qoffset_y",
283-
"qoffset_z",
284-
):
285-
if k in meta:
286-
del meta[k]
287-
return MetaTensor(img, meta=meta)
288-
289254

290255
class SaveImage(Transform):
291256
"""

monai/transforms/utility/dictionary.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -427,11 +427,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
427427
if dim > 0: # don't update affine if channel dim
428428
affine = d[split_meta_key]["affine"] # type: ignore
429429
ndim = len(affine)
430-
shift = (
431-
torch.eye(ndim, device=affine.device, dtype=affine.dtype)
432-
if isinstance(affine, torch.Tensor)
433-
else np.eye(ndim)
434-
)
430+
shift: NdarrayOrTensor
431+
if isinstance(affine, torch.Tensor):
432+
shift = torch.eye(ndim, device=affine.device, dtype=affine.dtype)
433+
else:
434+
shift = np.eye(ndim)
435435
shift[dim - 1, -1] = i # type: ignore
436436
d[split_meta_key]["affine"] = d[split_meta_key]["affine"] @ shift # type: ignore
437437

monai/transforms/utils.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@
105105
"convert_pad_mode",
106106
"convert_to_contiguous",
107107
"get_unique_labels",
108+
"remove_keys",
109+
"remove_extra_metadata",
110+
"get_extra_metadata_keys",
108111
]
109112

110113

@@ -1573,5 +1576,68 @@ def convert_to_contiguous(data, **kwargs):
15731576
return data
15741577

15751578

1579+
def remove_keys(data: dict, keys: List[str]) -> None:
1580+
"""
1581+
Remove keys from a dictionary. Operates in-place so nothing is returned.
1582+
1583+
Args:
1584+
data: dictionary to be modified.
1585+
keys: keys to be deleted from dictionary.
1586+
1587+
Returns:
1588+
`None`
1589+
"""
1590+
for k in keys:
1591+
_ = data.pop(k, None)
1592+
1593+
1594+
def remove_extra_metadata(meta: dict) -> None:
1595+
"""
1596+
Remove extra metadata from the dictionary. Operates in-place so nothing is returned.
1597+
1598+
Args:
1599+
meta: dictionary containing metadata to be modified.
1600+
1601+
Returns:
1602+
`None`
1603+
"""
1604+
keys = get_extra_metadata_keys()
1605+
remove_keys(data=meta, keys=keys)
1606+
1607+
1608+
def get_extra_metadata_keys() -> List[str]:
1609+
"""
1610+
Get a list of unnecessary keys for metadata that can be removed.
1611+
1612+
Returns:
1613+
List of keys to be removed.
1614+
"""
1615+
keys = [
1616+
"srow_x",
1617+
"srow_y",
1618+
"srow_z",
1619+
"quatern_b",
1620+
"quatern_c",
1621+
"quatern_d",
1622+
"qoffset_x",
1623+
"qoffset_y",
1624+
"qoffset_z",
1625+
"dim",
1626+
"pixdim",
1627+
*[f"dim[{i}]" for i in range(8)],
1628+
*[f"pixdim[{i}]" for i in range(8)],
1629+
]
1630+
1631+
# TODO: it would be good to remove these, but they are currently being used in the
1632+
# codebase.
1633+
# keys += [
1634+
# "original_affine",
1635+
# "spatial_shape",
1636+
# "spacing",
1637+
# ]
1638+
1639+
return keys
1640+
1641+
15761642
if __name__ == "__main__":
15771643
print_transform_backends()

tests/test_load_image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def test_kwargs(self):
242242
reader = ITKReader()
243243
img = reader.read(filename, fallback_only=False)
244244
result_raw = reader.get_data(img)
245-
result_raw = LoadImage.join_im_and_meta(*result_raw)
245+
result_raw = MetaTensor.ensure_torch_and_prune_meta(*result_raw)
246246
self.assertTupleEqual(result.shape, result_raw.shape)
247247

248248
def test_my_reader(self):

0 commit comments

Comments
 (0)