-
Notifications
You must be signed in to change notification settings - Fork 48
Implement recon models #123
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7202996
bfa201e
7b22fbf
95492c4
1fa094e
1ead985
e0713e0
9efd45a
a4f600b
8396472
8ad63d5
9eac5dc
2cc2401
5688fd8
f04c7bf
c955f84
b341c90
70459b8
5c07761
7e5bf37
a77083a
9c163ea
37bf359
4334d2b
89d265f
9be039b
e31e854
3f68f81
b9b7dc4
a4e3e5e
0a8c07e
59c2137
c28db14
da9642c
f83a47c
c4a250e
521c3b7
05d5dc2
c37ebfe
9279eb8
597cdf9
36f4fa0
dfa0729
4b43bad
3dd2004
bd89f92
028565c
4aaf108
a13f565
694ae41
edb95be
b9c6a3c
1424355
26eef4a
3d6808e
8b612a5
5448450
ed5ddda
7629af6
6411b33
23a40f7
122a3e0
3167544
5d3d058
55cd4e9
ac60ada
44ccfc3
771b01d
281ed60
713cb8d
e7cb662
3fa6edf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,10 +10,9 @@ | |
| import numpy as np | ||
| import torch | ||
| import torch.fft | ||
| from packaging import version | ||
|
|
||
| from direct.data.bbox import crop_to_bbox | ||
| from direct.utils import ensure_list, is_power_of_two | ||
| from direct.utils import ensure_list, is_power_of_two, is_complex_data | ||
| from direct.utils.asserts import assert_complex, assert_same_shape | ||
|
|
||
|
|
||
|
|
@@ -222,7 +221,6 @@ def safe_divide(input_tensor: torch.Tensor, other_tensor: torch.Tensor) -> torch | |
| torch.tensor([0.0], dtype=input_tensor.dtype).to(input_tensor.device), | ||
| input_tensor / other_tensor, | ||
| ) | ||
|
|
||
| return data | ||
|
|
||
|
|
||
|
|
@@ -266,8 +264,7 @@ def align_as(input_tensor: torch.Tensor, other: torch.Tensor) -> torch.Tensor: | |
| input_shape = list(input_tensor.shape) | ||
| other_shape = torch.tensor(other.shape, dtype=int) | ||
| out_shape = torch.ones(len(other.shape), dtype=int) | ||
| # TODO(gy): Fix to ensure complex_last when [2,..., 2] or [..., N,..., N,...] in other.shape, | ||
| # "-input_shape.count(dim):" is a hack and might cause problems. | ||
|
|
||
| for dim in np.sort(np.unique(input_tensor.shape)): | ||
| ind = torch.where(other_shape == dim)[0][-input_shape.count(dim) :] | ||
| out_shape[ind] = dim | ||
|
|
@@ -292,7 +289,6 @@ def modulus(data: torch.Tensor) -> torch.Tensor: | |
| complex_axis = -1 if data.size(-1) == 2 else 1 | ||
|
|
||
| return (data ** 2).sum(complex_axis).sqrt() # noqa | ||
| # return torch.view_as_complex(data).abs() | ||
|
|
||
|
|
||
| def modulus_if_complex(data: torch.Tensor) -> torch.Tensor: | ||
|
|
@@ -307,11 +303,9 @@ def modulus_if_complex(data: torch.Tensor) -> torch.Tensor: | |
| ------- | ||
| torch.Tensor | ||
| """ | ||
| # TODO: This can be merged with modulus if the tensor is real. | ||
| try: | ||
| if is_complex_data(data, complex_last=False): | ||
| return modulus(data) | ||
| except ValueError: | ||
| return data | ||
| return data | ||
|
|
||
|
|
||
| def roll( | ||
|
|
@@ -436,7 +430,7 @@ def _complex_matrix_multiplication(input_tensor, other_tensor, mult_func): | |
|
|
||
| Parameters | ||
| ---------- | ||
| x : torch.Tensor | ||
| input_tensor : torch.Tensor | ||
| other_tensor : torch.Tensor | ||
| mult_func : Callable | ||
| Multiplication function e.g. torch.bmm or torch.mm | ||
|
|
@@ -466,6 +460,7 @@ def complex_mm(input_tensor, other_tensor): | |
| ---------- | ||
| input_tensor : torch.Tensor | ||
| other_tensor : torch.Tensor | ||
|
|
||
| Returns | ||
| ------- | ||
| torch.Tensor | ||
|
|
@@ -501,10 +496,6 @@ def conjugate(data: torch.Tensor) -> torch.Tensor: | |
| ------- | ||
| torch.Tensor | ||
| """ | ||
| # assert_complex(data, complex_last=True) | ||
| # data = torch.view_as_real( | ||
| # torch.view_as_complex(data).conj() | ||
| # ) | ||
| assert_complex(data, complex_last=True) | ||
| data = data.clone() # Clone is required as the data in the next line is changed in-place. | ||
| data[..., 1] = data[..., 1] * -1.0 | ||
|
|
@@ -574,11 +565,12 @@ def tensor_to_complex_numpy(data: torch.Tensor) -> np.ndarray: | |
| return data[..., 0] + 1j * data[..., 1] | ||
|
|
||
|
|
||
| def root_sum_of_squares(data: torch.Tensor, dim: int = 0) -> torch.Tensor: | ||
| r""" | ||
| def root_sum_of_squares(data: torch.Tensor, dim: int = 0, complex_dim: int = -1) -> torch.Tensor: | ||
| """ | ||
| Compute the root sum of squares (RSS) transform along a given dimension of the input tensor. | ||
|
|
||
| $$x_{\textrm{rss}} = \sqrt{\sum_{i \in \textrm{coil}} |x_i|^2}$$ | ||
| .. math:: | ||
| x_{\textrm{rss}} = \sqrt{\sum_{i \in \textrm{coil}} |x_i|^2} | ||
|
|
||
| Parameters | ||
| ---------- | ||
|
|
@@ -588,16 +580,16 @@ def root_sum_of_squares(data: torch.Tensor, dim: int = 0) -> torch.Tensor: | |
| dim : int | ||
| Coil dimension. Default is 0 as the first dimension is always the coil dimension. | ||
|
|
||
| complex_dim : int | ||
| Complex channel dimension. Default is -1. If data not complex this is ignored. | ||
| Returns | ||
| ------- | ||
| torch.Tensor : RSS of the input tensor. | ||
| """ | ||
| try: | ||
| assert_complex(data, complex_last=True) | ||
| complex_index = -1 | ||
| return torch.sqrt((data ** 2).sum(complex_index).sum(dim)) | ||
| except ValueError: | ||
| return torch.sqrt((data ** 2).sum(dim)) | ||
| if is_complex_data(data): | ||
| return torch.sqrt((data ** 2).sum(complex_dim).sum(dim)) | ||
|
|
||
| return torch.sqrt((data ** 2).sum(dim)) | ||
|
|
||
|
|
||
| def center_crop(data: torch.Tensor, shape: Tuple[int, int]) -> torch.Tensor: | ||
|
|
@@ -760,5 +752,71 @@ def complex_random_crop( | |
|
|
||
| if len(output) == 1: | ||
| return output[0] | ||
|
|
||
| return output | ||
|
|
||
|
|
||
| def reduce_operator( | ||
| coil_data: torch.Tensor, | ||
| sensitivity_map: torch.Tensor, | ||
| dim: int = 0, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Given zero-filled reconstructions from multiple coils :math: \{x_i\}_{i=1}^{N_c} and coil sensitivity maps | ||
| :math: \{S_i\}_{i=1}^{N_c} it returns | ||
| .. math:: | ||
| R(x_1, .., x_{N_c}, S_1, .., S_{N_c}) = \sum_{i=1}^{N_c} {S_i}^{*} \times x_i. | ||
|
|
||
| From paper End-to-End Variational Networks for Accelerated MRI Reconstruction. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| coil_data : torch.Tensor | ||
| Zero-filled reconstructions from coils. Should be a complex tensor (with complex dim of size 2). | ||
| sensitivity_map: torch.Tensor | ||
| Coil sensitivity maps. Should be complex tensor (with complex dim of size 2). | ||
| dim: int | ||
| Coil dimension. Default: 0. | ||
|
|
||
| Returns | ||
| ------- | ||
| torch.Tensor: | ||
| Combined individual coil images. | ||
| """ | ||
|
|
||
| assert_complex(coil_data, complex_last=True) | ||
| assert_complex(sensitivity_map, complex_last=True) | ||
|
Comment on lines
+786
to
+787
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will require to change the whole software & models so let's keep this as a future change |
||
|
|
||
| return complex_multiplication(conjugate(sensitivity_map), coil_data).sum(dim) | ||
|
|
||
|
|
||
| def expand_operator( | ||
| data: torch.Tensor, | ||
| sensitivity_map: torch.Tensor, | ||
| dim: int = 0, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Given a reconstructed image x and coil sensitivity maps :math: \{S_i\}_{i=1}^{N_c}, it returns | ||
| .. math:: | ||
| \Epsilon(x) = (S_1 \times x, .., S_{N_c} \times x) = (x_1, .., x_{N_c}). | ||
|
|
||
| From paper End-to-End Variational Networks for Accelerated MRI Reconstruction. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| data : torch.Tensor | ||
| Image data. Should be a complex tensor (with complex dim of size 2). | ||
| sensitivity_map: torch.Tensor | ||
| Coil sensitivity maps. Should be complex tensor (with complex dim of size 2). | ||
| dim: int | ||
| Coil dimension. Default: 0. | ||
|
|
||
| Returns | ||
| ------- | ||
| torch.Tensor: | ||
| Zero-filled reconstructions from each coil. | ||
| """ | ||
|
|
||
| assert_complex(data, complex_last=True) | ||
| assert_complex(sensitivity_map, complex_last=True) | ||
|
|
||
| return complex_multiplication(sensitivity_map, data.unsqueeze(dim)) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,7 +11,7 @@ | |
| import warnings | ||
| from abc import ABC, abstractmethod | ||
| from collections import namedtuple | ||
| from typing import Callable, Dict, List, Optional, TypedDict, Union | ||
| from typing import Callable, Dict, List, Optional, Union | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
@@ -278,13 +278,6 @@ def training_loop( | |
| fail_counter = 0 | ||
| for data, iter_idx in zip(data_loader, range(start_iter, total_iter)): | ||
|
|
||
| # 2D data is batched and contains keys: | ||
| # "filename_slice", "slice_no" | ||
| # "sampling_mask" of shape: (batch, 1, height, width, 1) | ||
| # "sensitivity_map" of shape: (batch, coil, height, width, complex=2) | ||
| # "target" of shape: (batch, height, width) | ||
| # "masked_kspace" of shape: (batch, coil, height, width, complex=2) | ||
|
|
||
|
Comment on lines
-281
to
-287
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It feels we should write this down somewhere. Maybe in the documentation?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Documentation of engine? |
||
| if iter_idx == 0: | ||
| self.log_first_training_example_and_model(data) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| # coding=utf-8 | ||
| # Copyright (c) DIRECT Contributors |
Uh oh!
There was an error while loading. Please reload this page.