Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 5 additions & 21 deletions direct/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,42 +230,26 @@ def safe_divide(input_tensor: torch.Tensor, other_tensor: torch.Tensor) -> torch
return data


def modulus(data: torch.Tensor) -> torch.Tensor:
def modulus(data: torch.Tensor, complex_axis: int = -1) -> torch.Tensor:
"""Compute modulus of complex input data. Assumes there is a complex axis (of dimension 2) in the data.

Parameters
----------
data: torch.Tensor
complex_axis: int
Complex dimension along which the modulus will be calculated. Default: -1.

Returns
-------
output_data: torch.Tensor
Modulus of data.
"""
# TODO: fix to specify dim of complex axis or make it work with complex_last=True.

assert_complex(data, complex_last=False)
complex_axis = -1 if data.size(-1) == 2 else 1
assert_complex(data, complex_axis)

return (data**2).sum(complex_axis).sqrt() # noqa


def modulus_if_complex(data: torch.Tensor) -> torch.Tensor:
"""Compute modulus if complex tensor (has complex axis).

Parameters
----------
data: torch.Tensor

Returns
-------
torch.Tensor
"""
if is_complex_data(data, complex_last=False):
return modulus(data)
return data


def roll_one_dim(data: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
"""Similar to roll but only for one dim

Expand Down Expand Up @@ -539,7 +523,7 @@ def tensor_to_complex_numpy(data: torch.Tensor) -> np.ndarray:
out: np.array
Complex valued np.ndarray
"""
assert_complex(data)
assert_complex(data, complex_last=True)
data = data.detach().cpu().numpy()
return data[..., 0] + 1j * data[..., 1]

Expand Down
2 changes: 1 addition & 1 deletion direct/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def training_loop(

metrics_dict = evaluate_dict(
metric_fns,
T.modulus_if_complex(output.detach()),
output.detach(),
data["target"].detach().to(self.device),
reduction="mean",
)
Expand Down
12 changes: 3 additions & 9 deletions direct/nn/mri_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,7 @@ def l1_loss(source, reduction="mean", **data):
L1 loss.
"""
resolution = get_resolution(**data)
l1_loss = F.l1_loss(
*_crop_volume(T.modulus_if_complex(source), data["target"], resolution), reduction=reduction
)
l1_loss = F.l1_loss(*_crop_volume(source, data["target"], resolution), reduction=reduction)

return l1_loss

Expand All @@ -155,9 +153,7 @@ def l2_loss(source, reduction="mean", **data):
L2 loss.
"""
resolution = get_resolution(**data)
l2_loss = F.mse_loss(
*_crop_volume(T.modulus_if_complex(source), data["target"], resolution), reduction=reduction
)
l2_loss = F.mse_loss(*_crop_volume(source, data["target"], resolution), reduction=reduction)

return l2_loss

Expand All @@ -182,7 +178,7 @@ def ssim_loss(source, reduction="mean", **data):
f"SSIM loss can only be computed with reduction == 'mean'." f" Got reduction == {reduction}."
)

source_abs, target_abs = _crop_volume(T.modulus_if_complex(source), data["target"], resolution)
source_abs, target_abs = _crop_volume(source, data["target"], resolution)
data_range = torch.tensor([target_abs.max()], device=target_abs.device)

ssim_loss = SSIMLoss().to(source_abs.device).forward(source_abs, target_abs, data_range=data_range)
Expand Down Expand Up @@ -461,8 +457,6 @@ def _process_output(data, scaling_factors=None, resolution=None):
if scaling_factors is not None:
data = data * scaling_factors.view(-1, *((1,) * (len(data.shape) - 1))).to(data.device)

data = T.modulus_if_complex(data)

if len(data.shape) == 3: # (batch, height, width)
data = data.unsqueeze(1) # Added channel dimension.

Expand Down
9 changes: 6 additions & 3 deletions direct/nn/rim/rim_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch.cuda.amp import autocast

from direct.config import BaseConfig
from direct.data.transforms import modulus
from direct.engine import DoIterationOutput
from direct.nn.mri_models import MRIModelEngine
from direct.utils import detach_dict, dict_to_device, reduce_list_of_dicts
Expand Down Expand Up @@ -36,6 +37,7 @@ def __init__(
mixed_precision=mixed_precision,
**models,
)
self._complex_dim = 1

def _do_iteration(
self,
Expand Down Expand Up @@ -90,7 +92,8 @@ def _do_iteration(
# reconstruction_iter: list with tensors of shape (batch, complex=2, height, width)
# hidden_state has shape: (batch, num_hidden_channels, height, width, depth)

output_image = reconstruction_iter[-1] # shape (batch, complex=2, height, width)
output_image = reconstruction_iter[-1] # shape (batch, complex=2, height, width)
output_image = modulus(output_image, complex_axis=self._complex_dim) # shape (batch, height, width)

loss_dict = {
k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()
Expand All @@ -103,14 +106,14 @@ def _do_iteration(
for output_image_iter in reconstruction_iter:
for key, value in loss_dict.items():
loss_dict[key] = value + loss_fns[key](
output_image_iter,
modulus(output_image_iter, complex_axis=self._complex_dim),
**data,
reduction="mean",
)

for key, value in regularizer_dict.items():
regularizer_dict[key] = value + regularizer_fns[key](
output_image_iter,
modulus(output_image_iter, complex_axis=self._complex_dim),
**data,
)

Expand Down
46 changes: 11 additions & 35 deletions direct/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,50 +19,26 @@
logger = logging.getLogger(__name__)


def is_complex_data(data: torch.Tensor, complex_last: bool = True) -> bool:
"""Returns True if data is a complex tensor, i.e. has a complex axis of dimension 2, and False otherwise.
COMPLEX_DIM = 2


def is_complex_data(data: torch.Tensor, complex_axis: int = -1) -> bool:
"""Returns True if data is a complex tensor at a specified dimension, i.e. complex_axis of data is of size 2,
corresponding to real and imaginary channels..

Parameters
----------
data: torch.Tensor
For 2D data the shape is assumed ([batch], [coil], height, width, [complex])
or ([batch], [coil], [complex], height, width).
For 3D data the shape is assumed ([batch], [coil], slice, height, width, [complex])
or ([batch], [coil], [complex], slice, height, width).
complex_last: bool
If true, will require complex axis to be at the last axis. Default: True.
complex_axis: int
Complex dimension along which the check will be done. Default: -1 (last).

Returns
-------
bool
True if data is a complex tensor.
"""
if 2 not in data.shape:
return False
if complex_last:
if data.size(-1) != 2:
return False
else:
if data.ndim == 6:
if data.size(2) != 2 and data.size(-1) != 2: # (B, C, 2, S, H, 2) or (B, C, S, H, W, 2)
return False

elif data.ndim == 5:
# (B, 2, S, H, W) or (B, C, 2, H, W) or (B, S, H, W, 2) or (B, C, H, W, 2)
if data.size(1) != 2 and data.size(2) != 2 and data.size(-1) != 2:
return False

elif data.ndim == 4:
if data.size(1) != 2 and data.size(-1) != 2: # (B, 2, H, W) or (B, H, W, 2) or (S, H, W, 2)
return False

elif data.ndim == 3:
if data.size(-1) != 2: # (H, W, 2)
return False

else:
raise ValueError(f"Not compatible number of dimensions for complex data. Got {data.ndim}.")

return True

return data.size(complex_axis) == COMPLEX_DIM


def is_power_of_two(number: int) -> bool:
Expand Down
25 changes: 12 additions & 13 deletions direct/utils/asserts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
import inspect
from typing import List
from typing import List, Optional

import torch

Expand Down Expand Up @@ -43,21 +43,20 @@ def assert_same_shape(data_list: List[torch.Tensor]):
raise ValueError(f"All inputs are expected to have the same shape. Got {shape_list}.")


def assert_complex(data: torch.Tensor, complex_last: bool = True) -> None:
"""Assert if a tensor is a complex tensor.
def assert_complex(data: torch.Tensor, complex_axis: int = -1, complex_last: Optional[bool] = None) -> None:
"""Assert if a tensor is complex (has complex dimension of size 2 corresponding to real and imaginary channels).

Parameters
----------
data: torch.Tensor
For 2D data the shape is assumed ([batch], [coil], height, width, [complex])
or ([batch], [coil], [complex], height, width).
For 3D data the shape is assumed ([batch], [coil], slice, height, width, [complex])
or ([batch], [coil], [complex], slice, height, width).
complex_last: bool
If true, will require complex axis to be at the last axis.
Returns
-------
complex_axis: int
Complex dimension along which the assertion will be done. Default: -1 (last).
complex_last: Optional[bool]
If true, will override complex_axis with -1 (last). Default: None.
"""
# TODO: This is because ifft and fft or torch expect the last dimension to represent the complex axis.
if not is_complex_data(data, complex_last):
raise ValueError(f"Complex dimension assumed to be 2 (complex valued), but not found in shape {data.shape}.")
if complex_last:
complex_axis = -1
assert is_complex_data(
data, complex_axis
), f"Complex dimension assumed to be 2 (complex valued), but not found in shape {data.shape}."
18 changes: 0 additions & 18 deletions tests/tests_data/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,24 +86,6 @@ def test_modulus(shape):
assert np.allclose(out_torch, out_numpy)


@pytest.mark.parametrize(
"shape",
[
[3, 3],
[4, 6],
[10, 8, 4],
[3, 4, 3, 5],
],
)
@pytest.mark.parametrize("complex", [True, False])
def test_modulus_if_complex(shape, complex):
if complex:
shape += [
2,
]
test_modulus(shape)


@pytest.mark.parametrize(
"shape, dims",
[
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_nn/test_rim_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,4 @@ def test_lpd_engine(shape, loss_fns, length, depth, scale_log):
)
loss_fns = engine.build_loss()
out = engine._do_iteration(data, loss_fns)
assert out.output_image.shape == (shape[0],) + (2,) + tuple(shape[2:-1])
assert out.output_image.shape == (shape[0],) + tuple(shape[2:-1])
25 changes: 13 additions & 12 deletions tests/tests_utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import pytest
import torch

from direct.utils import is_complex_data, is_power_of_two, normalize_image, remove_keys, set_all_seeds
from direct.utils import is_power_of_two, normalize_image, remove_keys, set_all_seeds
from direct.utils.asserts import assert_complex
from direct.utils.bbox import crop_to_largest
from direct.utils.dataset import get_filenames_for_datasets

Expand All @@ -31,22 +32,22 @@ def __init__(self, *args, **kwargs):


@pytest.mark.parametrize(
"shape",
"shape, complex_axis, complex_last",
[
[3, 3, 2],
[5, 8, 4, 2],
[5, 2, 8, 4],
[3, 5, 8, 4, 2],
[3, 5, 2, 8, 4],
[3, 2, 5, 8, 4],
[3, 3, 5, 8, 4, 2],
[3, 3, 2, 5, 8, 4],
[[3, 3, 2], None, True],
[[5, 8, 4, 2], -1, None],
[[5, 2, 8, 4], 1, None],
[[3, 5, 8, 4, 2], None, True],
[[3, 5, 2, 8, 4], -3, None],
[[3, 2, 5, 8, 4], 1, None],
[[3, 3, 5, 8, 4, 2], None, True],
[[3, 3, 2, 5, 8, 4], 2, None],
],
)
def test_is_complex_data(shape):
def test_is_complex_data(shape, complex_axis, complex_last):
data = create_input(shape)

assert is_complex_data(data, False)
assert_complex(data, complex_axis=complex_axis, complex_last=complex_last)


@pytest.mark.parametrize(
Expand Down