Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 24 additions & 25 deletions src/otx/algorithms/visual_prompting/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str]
self.optimization_type = ModelOptimizationType.MO

self.trainer: Trainer
self._model_ckpt: Optional[str] = None

self.timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())

Expand Down Expand Up @@ -134,17 +135,11 @@ def get_config(self) -> Union[DictConfig, ListConfig]:
resume_from_checkpoint: Optional[str] = None
if self.mode == "train" and self.task_environment.model is not None:
# when args.load_weights or args.resume_from is set
resume_from_checkpoint = model_checkpoint = self.task_environment.model.model_adapters.get("path", None) # type: ignore # noqa: E501
checkpoint_path = str(self.task_environment.model.model_adapters.get("path", None))
if self.task_environment.model.model_adapters.get("resume", False):
if resume_from_checkpoint.endswith(".pth"): # type: ignore
logger.info("[*] Pytorch checkpoint cannot be used for resuming. It will be supported.")
resume_from_checkpoint = None
else:
model_checkpoint = None
resume_from_checkpoint = checkpoint_path
else:
# If not resuming, set resume_from_checkpoint to None to avoid training in resume environment
# and saving to configuration.
resume_from_checkpoint = None
model_checkpoint = checkpoint_path

config = get_visual_promtping_config(
task_name=self.model_name,
Expand Down Expand Up @@ -191,29 +186,34 @@ def get_model(config: DictConfig, state_dict: Optional[OrderedDict] = None):
"No trained model in project yet. Created new model with '%s'",
self.model_name,
)
elif ("path" in otx_model.model_adapters) and (
otx_model.model_adapters.get("path").endswith(".ckpt") # type: ignore[attr-defined]
):
# pytorch lightning checkpoint
if not otx_model.model_adapters.get("resume"):
# If not resuming, just load weights in LightningModule
logger.info("Load pytorch lightning checkpoint.")
elif otx_model.model_adapters.get("resume", False):
# If resuming, pass this part to load checkpoint in Trainer
logger.info(f"To resume {otx_model.model_adapters.get('path')}, the checkpoint will be loaded in Trainer.")

else:
# pytorch checkpoint saved by otx
# Load state_dict
buffer = io.BytesIO(otx_model.get_data("weights.pth"))
model_data = torch.load(buffer, map_location=torch.device("cpu"))
if model_data.get("model", None) and model_data.get("config", None):
if model_data.get("state_dict", None) and model_data.get("pytorch-lightning_version", None):
# Load state_dict from pytorch lightning checkpoint or weights.pth saved by visual prompting task
# A pytorch lightning checkpoint contains the following metadata keys: epoch, global_step,
# pytorch-lightning_version, state_dict, loops, callbacks, optimizer_states, lr_schedulers,
# hparams_name, hyper_parameters.
# To confirm that model_data comes from pytorch lightning, check that both "state_dict" and
# "pytorch-lightning_version" are present in it.
state_dict = model_data["state_dict"]

elif model_data.get("model", None) and model_data.get("config", None):
# Load state_dict from checkpoint saved by otx other tasks
if model_data["config"]["model"]["backbone"] != self.config["model"]["backbone"]:
logger.warning(
"Backbone of the model in the Task Environment is different from the one in the template. "
f"creating model with backbone={model_data['config']['model']['backbone']}"
)
self.config["model"]["backbone"] = model_data["config"]["model"]["backbone"]
state_dict = model_data["model"]
logger.info("Load pytorch checkpoint from weights.pth.")

else:
# Load state_dict from naive pytorch checkpoint
state_dict = model_data
logger.info("Load pytorch checkpoint.")

try:
model = get_model(config=self.config, state_dict=state_dict)
Expand Down Expand Up @@ -406,11 +406,10 @@ def model_info(self) -> Dict:
Returns:
Dict: Model info.
"""
return {
"model": self.model.state_dict(),
"config": self.get_config(),
"version": self.trainer.logger.version,
}
if not self._model_ckpt:
logger.warn("model checkpoint is not set, return empty dictionary.")
return {}
return torch.load(self._model_ckpt, map_location="cpu")

def save_model(self, output_model: ModelEntity) -> None:
"""Save the model after training is completed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ def test_get_visual_promtping_config(
assert config.get("optimizer", False)
assert config.get("callback", False)
assert config.get("trainer", False)
if model_checkpoint is not None:
if mode == "train":
assert config.get("model").get("checkpoint") == model_checkpoint
if mode == "train":
if model_checkpoint:
assert config.get("model").get("checkpoint", None) == model_checkpoint
else:
assert config.get("model").get("checkpoint") != model_checkpoint
assert config.get("trainer").get("resume_from_checkpoint", None) == resume_from_checkpoint
assert config.get("model").get("checkpoint", None) != model_checkpoint
assert config.get("trainer").get("resume_from_checkpoint", None) == resume_from_checkpoint


@e2e_pytest_unit
Expand Down
67 changes: 27 additions & 40 deletions tests/unit/algorithms/visual_prompting/tasks/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#

from collections import OrderedDict
from typing import Optional
from typing import Optional, Dict, Any

import pytest
from omegaconf import DictConfig
Expand Down Expand Up @@ -59,17 +59,9 @@ def test_get_config_train(self, mocker, load_inference_task, path: Optional[str]
assert isinstance(inference_task.config, DictConfig)
assert inference_task.config.dataset.task == "visual_prompting"
if path:
if path.endswith(".pth"):
# TODO (sungchul): when applying resume
# pytorch weights
assert inference_task.config.model.checkpoint == path
assert inference_task.config.trainer.resume_from_checkpoint is None
elif path.endswith(".ckpt") and resume:
# resume with pytorch lightning weights
assert inference_task.config.model.checkpoint != path # use default checkpoint
if resume:
assert inference_task.config.trainer.resume_from_checkpoint == path
else:
# just train with pytorch lightning weights
assert inference_task.config.model.checkpoint == path
assert inference_task.config.trainer.resume_from_checkpoint is None

Expand All @@ -84,40 +76,37 @@ def test_get_config_eval(self, mocker, load_inference_task):
assert inference_task.config.dataset.task == "visual_prompting"

@e2e_pytest_unit
@pytest.mark.parametrize("path", [None, "checkpoint.ckpt"])
@pytest.mark.parametrize("resume", [True, False])
def test_load_model_without_otx_model_or_with_lightning_ckpt(
self, mocker, load_inference_task, path: str, resume: bool
):
"""Test load_model to resume."""
mocker_segment_anything = mocker.patch(
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.SegmentAnything"
)

inference_task = load_inference_task(path=path, resume=resume)
inference_task.load_model(otx_model=inference_task.task_environment.model)

mocker_segment_anything.assert_called_once()

@e2e_pytest_unit
@pytest.mark.parametrize("path", [None, "checkpoint.ckpt", "checkpoint.pth"])
@pytest.mark.parametrize("resume", [True, False])
def test_load_model_with_pytorch_pth(self, mocker, load_inference_task, resume: bool):
"""Test load_model with otx_model."""
@pytest.mark.parametrize(
"load_return_value",
[
{"state_dict": {"layer": "weights"}, "pytorch-lightning_version": "version"},
{"model": {"layer": "weights"}, "config": {"model": {"backbone": "sam_vit_b"}}},
{},
],
)
def test_load_model(self, mocker, load_inference_task, path: str, resume: bool, load_return_value: Dict[str, Any]):
"""Test load_model."""
mocker_segment_anything = mocker.patch(
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.SegmentAnything"
)
mocker_io_bytes_io = mocker.patch("io.BytesIO")
mocker_torch_load = mocker.patch(
"torch.load",
return_value=dict(config=dict(model=dict(backbone="sam_vit_b")), model={}),
return_value=load_return_value,
)

inference_task = load_inference_task(path="checkpoint.pth", resume=resume)
inference_task = load_inference_task(path=path, resume=resume)
inference_task.load_model(otx_model=inference_task.task_environment.model)

mocker_segment_anything.assert_called_once()
mocker_io_bytes_io.assert_called_once()
mocker_torch_load.assert_called_once()
if resume or path is None:
mocker_io_bytes_io.assert_not_called()
mocker_torch_load.assert_not_called()
else:
mocker_io_bytes_io.assert_called_once()
mocker_torch_load.assert_called_once()

@e2e_pytest_unit
def test_infer(self, mocker, load_inference_task):
Expand Down Expand Up @@ -162,19 +151,17 @@ def test_model_info(self, mocker, load_inference_task):
mocker.patch(
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.load_checkpoint"
)
mocker.patch("torch.load", return_value={"state_dict": {"layer": "weights"}, "epoch": 0})

inference_task = load_inference_task(output_path=None)
inference_task.model = inference_task.load_model(otx_model=inference_task.task_environment.model)
setattr(inference_task, "trainer", None)
mocker.patch.object(inference_task, "trainer")
inference_task._model_ckpt = "checkpoint"

model_info = inference_task.model_info()

assert "model" in model_info
assert isinstance(model_info["model"], OrderedDict)
assert "config" in model_info
assert isinstance(model_info["config"], DictConfig)
assert "version" in model_info
assert isinstance(model_info.get("state_dict", None), dict)
assert model_info.get("state_dict", None)["layer"] == "weights"
assert isinstance(model_info.get("epoch", None), int)
assert model_info.get("epoch", None) == 0

@e2e_pytest_unit
def test_save_model(self, mocker, load_inference_task):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def test_train(self, mocker):
"""Test train."""
mocker_trainer = mocker.patch("otx.algorithms.visual_prompting.tasks.train.Trainer")
mocker_save = mocker.patch("torch.save")
mocker.patch.object(self.training_task, "model_info")

dataset = generate_visual_prompting_dataset()
output_model = ModelEntity(
Expand Down