Skip to content

Commit 5f00bf1

Browse files
ryancinsightpre-commit-ci[bot]monai-bot
authored
Update dice.py (#4234)
* Update dice.py reduce redundant operations in DiceFocalLoss, initially caused oom Signed-off-by: Ryan Clanton <55164720+ryancinsight@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Ryan Clanton <55164720+ryancinsight@users.noreply.github.com> * [MONAI] python code formatting Signed-off-by: monai-bot <monai.miccai2019@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: monai-bot <monai.miccai2019@gmail.com>
1 parent 66f3280 commit 5f00bf1

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

monai/losses/dice.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -796,8 +796,6 @@ def __init__(
796796
"""
797797
super().__init__()
798798
self.dice = DiceLoss(
799-
include_background=include_background,
800-
to_onehot_y=to_onehot_y,
801799
sigmoid=sigmoid,
802800
softmax=softmax,
803801
other_act=other_act,
@@ -808,19 +806,15 @@ def __init__(
808806
smooth_dr=smooth_dr,
809807
batch=batch,
810808
)
811-
self.focal = FocalLoss(
812-
include_background=include_background,
813-
to_onehot_y=to_onehot_y,
814-
gamma=gamma,
815-
weight=focal_weight,
816-
reduction=reduction,
817-
)
809+
self.focal = FocalLoss(gamma=gamma, weight=focal_weight, reduction=reduction)
818810
if lambda_dice < 0.0:
819811
raise ValueError("lambda_dice should be no less than 0.0.")
820812
if lambda_focal < 0.0:
821813
raise ValueError("lambda_focal should be no less than 0.0.")
822814
self.lambda_dice = lambda_dice
823815
self.lambda_focal = lambda_focal
816+
self.to_onehot_y = to_onehot_y
817+
self.include_background = include_background
824818

825819
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
826820
"""
@@ -837,6 +831,22 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
837831
if len(input.shape) != len(target.shape):
838832
raise ValueError("the number of dimensions for input and target should be the same.")
839833

834+
n_pred_ch = input.shape[1]
835+
836+
if self.to_onehot_y:
837+
if n_pred_ch == 1:
838+
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
839+
else:
840+
target = one_hot(target, num_classes=n_pred_ch)
841+
842+
if not self.include_background:
843+
if n_pred_ch == 1:
844+
warnings.warn("single channel prediction, `include_background=False` ignored.")
845+
else:
846+
# if skipping background, removing first channel
847+
target = target[:, 1:]
848+
input = input[:, 1:]
849+
840850
dice_loss = self.dice(input, target)
841851
focal_loss = self.focal(input, target)
842852
total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_focal * focal_loss

0 commit comments

Comments
 (0)