Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 50 additions & 3 deletions monai/losses/cldice.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@

from monai.losses.dice import DiceLoss
from monai.networks import one_hot
from monai.utils import LossReduction
from monai.utils import LossReduction, optional_import
from monai.utils.deprecate_utils import deprecated_arg

centerline_extraction_3d, _has_thinning = optional_import("centerline_extraction_3d_cuda")


def soft_erode(img: torch.Tensor) -> torch.Tensor: # type: ignore
"""
Expand Down Expand Up @@ -129,6 +131,8 @@ def __init__(
softmax: bool = False,
other_act: Callable | None = None,
reduction: LossReduction | str = LossReduction.MEAN,
use_hard_target: bool = False,
use_hard_prob: bool = False,
) -> None:
"""
Args:
Expand All @@ -151,6 +155,10 @@ def __init__(
- ``"none"``: no reduction will be applied.
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
- ``"sum"``: the output will be summed.
use_hard_target: if True, use the exact CUDA 3D binary thinning for the target skeleton instead of soft skeletonization.
Requires centerline_extraction_3d_cuda package and a CUDA 3D target. Defaults to False.
use_hard_prob: if True, use the CUDA 3D prob map thinning with backward for the prediction skeleton instead of soft skeletonization.
Requires centerline_extraction_3d_cuda package and a CUDA 3D input. Defaults to False.

Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
Expand Down Expand Up @@ -181,6 +189,8 @@ def __init__(
self.sigmoid = sigmoid
self.softmax = softmax
self.other_act = other_act
self.use_hard_target = use_hard_target
self.use_hard_prob = use_hard_prob

@deprecated_arg("y_pred", since="1.5", removed="1.8", new_name="input", msg_suffix="please use `input` instead.")
@deprecated_arg("y_true", since="1.5", removed="1.8", new_name="target", msg_suffix="please use `target` instead.")
Expand All @@ -193,6 +203,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Raises:
AssertionError: When input and target (after one hot transform if set)
have different shapes.
ValueError: When `use_hard_prob` or `use_hard_target` is enabled but the tensor is not 5D CUDA
or `centerline_extraction_3d_cuda` is unavailable.

"""
n_pred_ch = input.shape[1]
Expand Down Expand Up @@ -225,8 +237,33 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if target.shape != input.shape:
raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

skel_pred = soft_skel(input, self.iter)
skel_true = soft_skel(target, self.iter)
if self.use_hard_prob:
if not (input.dim() == 5 and _has_thinning and input.is_cuda):
raise ValueError(
"use_hard_prob=True but conditions not met. "
"Requires 5D CUDA tensor and centerline_extraction_3d_cuda package."
)
pred_mask = (input >= 0.5).to(torch.uint8).contiguous()
skel_pred = torch.zeros_like(input)
for b in range(input.shape[0]):
for c in range(input.shape[1]):
skel_pred[b, c] = centerline_extraction_3d.extract_centerline(pred_mask[b, c], input[b, c], 0)
else:
skel_pred = soft_skel(input, self.iter)

if self.use_hard_target:
if not (target.dim() == 5 and _has_thinning and target.is_cuda):
raise ValueError(
"use_hard_target=True but conditions not met. "
"Requires 5D CUDA tensor and centerline_extraction_3d_cuda package."
)
Comment thread
sychen52 marked this conversation as resolved.
skel_true = (target > 0).to(torch.uint8).contiguous()
for b in range(target.shape[0]):
for c in range(target.shape[1]):
centerline_extraction_3d.binary_thinning(skel_true[b, c], 0)
skel_true = skel_true.to(target.dtype)
else:
skel_true = soft_skel(target, self.iter)

# Compute per-batch clDice by reducing over channel and spatial dimensions
# reduce_axis includes all dimensions except batch (dim 0)
Expand Down Expand Up @@ -279,6 +316,8 @@ def __init__(
softmax: bool = False,
other_act: Callable | None = None,
reduction: LossReduction | str = LossReduction.MEAN,
use_hard_target: bool = False,
use_hard_prob: bool = False,
) -> None:
"""
Args:
Expand All @@ -304,6 +343,10 @@ def __init__(
- ``"none"``: no reduction will be applied.
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
- ``"sum"``: the output will be summed.
use_hard_target: if True, use the exact CUDA 3D binary thinning for the target skeleton instead of soft skeletonization.
Requires MONAI C++ extensions and a 3D target. Defaults to False.
use_hard_prob: if True, use the CUDA 3D prob map thinning with backward for the prediction skeleton instead of soft skeletonization.
Requires centerline_extraction_3d_cuda package and a CUDA 3D input. Defaults to False.

Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
Expand Down Expand Up @@ -336,6 +379,8 @@ def __init__(
softmax=softmax,
other_act=other_act,
reduction=reduction,
use_hard_target=use_hard_target,
use_hard_prob=use_hard_prob,
Comment thread
sychen52 marked this conversation as resolved.
)
self.alpha = alpha
self.to_onehot_y = to_onehot_y
Expand All @@ -351,6 +396,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Raises:
ValueError: When number of dimensions for input and target are different.
ValueError: When number of channels for target is neither 1 nor the same as input.
ValueError: When `use_hard_prob` or `use_hard_target` is enabled but the tensor is not 5D CUDA
or `centerline_extraction_3d_cuda` is unavailable.

"""
if input.dim() != target.dim():
Expand Down
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ all =
nvidia-ml-py
huggingface_hub
pyamg>=5.0.0, <5.3.0
centerline_extraction_3d_cuda
nibabel =
nibabel
ninja =
Expand Down Expand Up @@ -179,6 +180,8 @@ huggingface_hub =
huggingface_hub
pyamg =
pyamg>=5.0.0, <5.3.0
centerline_extraction =
centerline_extraction_3d_cuda
# segment-anything =
# segment-anything @ git+https://github.com/facebookresearch/segment-anything@6fdee8f2727f4506cfbbe553e23b895e27956588#egg=segment-anything

Expand Down
62 changes: 62 additions & 0 deletions tests/losses/test_cldice_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,39 @@ def test_cuda(self):
result = loss(ONES_2D["input"].cuda(), ONES_2D["target"].cuda())
np.testing.assert_allclose(result.detach().cpu().numpy(), 0.0, atol=1e-4)

@skip_if_no_cuda
def test_hard_target(self):
Comment thread
sychen52 marked this conversation as resolved.
"""Test SoftclDiceLoss with use_hard_target=True using binary thinning on 3D CUDA tensors."""
# Skip if thinning not available
from monai.losses.cldice import _has_thinning

if not _has_thinning:
self.skipTest("centerline_extraction_3d_cuda not available")

loss = SoftclDiceLoss(use_hard_target=True)
# MUST BE 3D for hard target logic to trigger! (shape: B, N, H, W, D)
result = loss(ONES_3D["input"].cuda(), ONES_3D["target"].cuda())
np.testing.assert_allclose(result.detach().cpu().numpy(), 0.0, atol=1e-4)
Comment thread
sychen52 marked this conversation as resolved.

@skip_if_no_cuda
def test_hard_prob(self):
"""Test SoftclDiceLoss with use_hard_prob=True using prob thinning on 3D CUDA tensors."""
# Skip if thinning not available
from monai.losses.cldice import _has_thinning

if not _has_thinning:
self.skipTest("centerline_extraction_3d_cuda not available")

loss = SoftclDiceLoss(use_hard_prob=True)
# MUST BE 3D for hard prob logic to trigger! (shape: B, N, H, W, D)
input_tensor = torch.ones_like(ONES_3D["input"]).cuda()
input_tensor.requires_grad = True
target = ONES_3D["target"].cuda()
result = loss(input_tensor, target)
np.testing.assert_allclose(result.detach().cpu().numpy(), 0.0, atol=1e-4)
result.backward()
self.assertIsNotNone(input_tensor.grad)

def test_reduction_shapes(self):
input_tensor = torch.ones((4, 2, 8, 8))
target = torch.ones((4, 2, 8, 8))
Expand Down Expand Up @@ -128,6 +161,35 @@ def test_cuda(self):
result = loss(ONES_2D["input"].cuda(), ONES_2D["target"].cuda())
np.testing.assert_allclose(result.detach().cpu().numpy(), 0.0, atol=1e-4)

@skip_if_no_cuda
def test_hard_target(self):
"""Test SoftDiceclDiceLoss with use_hard_target=True."""
from monai.losses.cldice import _has_thinning

if not _has_thinning:
self.skipTest("centerline_extraction_3d_cuda not available")

loss = SoftDiceclDiceLoss(use_hard_target=True)
result = loss(ONES_3D["input"].cuda(), ONES_3D["target"].cuda())
np.testing.assert_allclose(result.detach().cpu().numpy(), 0.0, atol=1e-4)

@skip_if_no_cuda
def test_hard_prob(self):
"""Test SoftDiceclDiceLoss with use_hard_prob=True."""
from monai.losses.cldice import _has_thinning

if not _has_thinning:
self.skipTest("centerline_extraction_3d_cuda not available")

loss = SoftDiceclDiceLoss(use_hard_prob=True)
input_tensor = torch.ones_like(ONES_3D["input"]).cuda()
input_tensor.requires_grad = True
target = ONES_3D["target"].cuda()
result = loss(input_tensor, target)
np.testing.assert_allclose(result.detach().cpu().numpy(), 0.0, atol=1e-4)
result.backward()
self.assertIsNotNone(input_tensor.grad)

def test_dimension_mismatch(self):
loss = SoftDiceclDiceLoss()
with self.assertRaises(ValueError):
Expand Down
Loading