Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ ci:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
rev: v6.0.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
Expand All @@ -26,18 +26,18 @@ repos:
args: ['--autofix', '--no-sort-keys', '--indent=4']
- id: end-of-file-fixer
- id: mixed-line-ending
- repo: https://github.com/psf/black
rev: "24.10.0"
- repo: https://github.com/psf/black-pre-commit-mirror
rev: "25.12.0"
hooks:
- id: black
- id: black-jupyter
- repo: https://github.com/pycqa/isort
rev: 5.13.2
rev: 7.0.0
hooks:
- id: isort
args: ["--profile", "black"]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.6
rev: v0.14.10
hooks:
- id: ruff
args: ['--fix']
Expand Down
10 changes: 5 additions & 5 deletions vista3d/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ mv model-zoo/models/vista3d vista3dbundle & rm -rf model-zoo
cd vista3dbundle
mkdir models
# minor model weights naming conversion due to monai version change
wget -O models/model.pt https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_vista3d.pt
wget -O models/model.pt https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_vista3d.pt
```
MONAI bundle accepts multiple json config files and input arguments. The latter configs/arguments will overide the previous configs/arguments if they have overlapping keys.
MONAI bundle accepts multiple json config files and input arguments. The latter configs/arguments will overide the previous configs/arguments if they have overlapping keys.
```python
# Automatic Segment everything
python -m monai.bundle run --config_file configs/inference.json --input_dict "{'image':'spleen_03.nii.gz'}
Expand All @@ -108,7 +108,7 @@ python -m monai.bundle run --config_file configs/inference.json --input_dict "{'
python -m monai.bundle run --config_file configs/inference.json --input_dict "{'image':'spleen_03.nii.gz','label_prompt':[3]}
```
```python
# Interactive segmentation
# Interactive segmentation
# Points must be three dimensional (x,y,z) in the shape of [[x,y,z],...,[x,y,z]]. Point labels can only be -1(ignore), 0(negative), 1(positive) and 2(negative for special overlaped class like tumor), 3(positive for special class). Only supporting 1 class per inference. The output 255 represents NaN value which means not processed region.
python -m monai.bundle run --config_file configs/inference.json --input_dict "{'image':'spleen_03.nii.gz','points':[[128,128,16], [100,100,16]],'point_labels':[1, 0]}"
```
Expand Down Expand Up @@ -158,7 +158,7 @@ python -m monai.bundle run --config_file="['configs/inference.json', 'configs/ba
### 1.1 Overlapped classes and postprocessing with [ShapeKit](https://arxiv.org/pdf/2506.24003)
VISTA3D is trained with binary segmentation, and may produce false positives due to weak false positive supervision. ShapeKit solves this problem with sophisticated postprocessing. ShapeKit requires segmentation mask for each class. VISTA3D by default performs argmax and collaps overlapping classes. Change the `monai.apps.vista3d.transforms.VistaPostTransformd` in `inference.json` to save each class segmentation as a separate channel. Then follow [ShapeKit](https://github.com/BodyMaps/ShapeKit) codebase for processing.
```json
{
{
"_target_": "Activationsd",
"sigmoid": true,
"keys": "pred"
Expand All @@ -180,7 +180,7 @@ To segment everything, run
```bash
export CUDA_VISIBLE_DEVICES=0; python -m scripts.infer --config_file 'configs/infer.yaml' - infer_everything --image_file 'example-1.nii.gz'
```
To segment based on point clicks, provide `point` and `point_label`.
To segment based on point clicks, provide `point` and `point_label`.
```bash
export CUDA_VISIBLE_DEVICES=0; python -m scripts.infer --config_file 'configs/infer.yaml' - infer --image_file 'example-1.nii.gz' --point "[[128,128,16],[100,100,6]]" --point_label "[1,0]" --save_mask true
```
Expand Down
1 change: 1 addition & 0 deletions vista3d/cvpr_workshop/infer_cvpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from train_cvpr import ROI_SIZE


def convert_clicks(alldata):
# indexes = list(alldata.keys())
# data = [alldata[i] for i in indexes]
Expand Down
30 changes: 20 additions & 10 deletions vista3d/cvpr_workshop/train_cvpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
import matplotlib.pyplot as plt

NUM_PATCHES_PER_IMAGE = 2
ROI_SIZE= [128, 128, 128]
ROI_SIZE = [128, 128, 128]


def plot_to_tensorboard(writer, epoch, inputs, labels, points, outputs):
"""
Expand Down Expand Up @@ -109,7 +110,7 @@ def __getitem__(self, idx):
keys=["image", "label"],
label_key="label",
num_classes=label.max() + 1,
ratios=tuple(float(i > 0) for i in range(label.max()+1)),
ratios=tuple(float(i > 0) for i in range(label.max() + 1)),
num_samples=NUM_PATCHES_PER_IMAGE,
),
monai.transforms.RandScaleIntensityd(
Expand Down Expand Up @@ -137,17 +138,19 @@ def __getitem__(self, idx):
mode=["constant", "constant"],
keys=["image", "label"],
spatial_size=ROI_SIZE,
)
),
]
)
data = transforms(data)
return data


import re


def get_latest_epoch(directory):
# Pattern to match filenames like 'model_epoch<number>.pth'
pattern = re.compile(r'model_epoch(\d+)\.pth')
pattern = re.compile(r"model_epoch(\d+)\.pth")
max_epoch = -1

for filename in os.listdir(directory):
Expand All @@ -159,6 +162,7 @@ def get_latest_epoch(directory):

return max_epoch if max_epoch != -1 else None


# Training function
def train():
json_file = "allset.json" # Update with your JSON file
Expand All @@ -169,7 +173,6 @@ def train():
start_epoch = get_latest_epoch(checkpoint_dir)
start_checkpoint = "./CPRR25_vista3D_model_final_10percent_data.pth"


os.makedirs(checkpoint_dir, exist_ok=True)
dist.init_process_group(backend="nccl")
world_size = int(os.environ["WORLD_SIZE"])
Expand All @@ -189,11 +192,12 @@ def train():
model.load_state_dict(pretrained_ckpt, strict=True)
else:
print(f"Resuming from epoch {start_epoch}")
pretrained_ckpt = torch.load(os.path.join(checkpoint_dir, f"model_epoch{start_epoch}.pth"))
model.load_state_dict(pretrained_ckpt['model'], strict=True)
pretrained_ckpt = torch.load(
os.path.join(checkpoint_dir, f"model_epoch{start_epoch}.pth")
)
model.load_state_dict(pretrained_ckpt["model"], strict=True)
model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)


optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1.0e-05)
lr_scheduler = monai.optimizers.WarmupCosineSchedule(
optimizer=optimizer,
Expand Down Expand Up @@ -265,10 +269,16 @@ def train():
if local_rank == 0:
writer.add_scalar("loss", loss.item(), step)
if local_rank == 0 and (epoch + 1) % save_interval == 0:
checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch{epoch + 1}.pth")
checkpoint_path = os.path.join(
checkpoint_dir, f"model_epoch{epoch + 1}.pth"
)
if world_size > 1:
torch.save(
{"model": model.module.state_dict(), "epoch": epoch + 1, "step": step},
{
"model": model.module.state_dict(),
"epoch": epoch + 1,
"step": step,
},
checkpoint_path,
)
print(
Expand Down