diff --git a/src/otx/algorithms/visual_prompting/tasks/inference.py b/src/otx/algorithms/visual_prompting/tasks/inference.py index b6da5606d84..6ff23ee9050 100644 --- a/src/otx/algorithms/visual_prompting/tasks/inference.py +++ b/src/otx/algorithms/visual_prompting/tasks/inference.py @@ -106,6 +106,7 @@ def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str] self.optimization_type = ModelOptimizationType.MO self.trainer: Trainer + self._model_ckpt: Optional[str] = None self.timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime()) @@ -134,17 +135,11 @@ def get_config(self) -> Union[DictConfig, ListConfig]: resume_from_checkpoint: Optional[str] = None if self.mode == "train" and self.task_environment.model is not None: # when args.load_weights or args.resume_from is set - resume_from_checkpoint = model_checkpoint = self.task_environment.model.model_adapters.get("path", None) # type: ignore # noqa: E501 + checkpoint_path = str(self.task_environment.model.model_adapters.get("path", None)) if self.task_environment.model.model_adapters.get("resume", False): - if resume_from_checkpoint.endswith(".pth"): # type: ignore - logger.info("[*] Pytorch checkpoint cannot be used for resuming. It will be supported.") - resume_from_checkpoint = None - else: - model_checkpoint = None + resume_from_checkpoint = checkpoint_path else: - # If not resuming, set resume_from_checkpoint to None to avoid training in resume environment - # and saving to configuration. - resume_from_checkpoint = None + model_checkpoint = checkpoint_path config = get_visual_promtping_config( task_name=self.model_name, @@ -191,18 +186,23 @@ def get_model(config: DictConfig, state_dict: Optional[OrderedDict] = None): "No trained model in project yet. Created new model with '%s'", self.model_name, ) - elif ("path" in otx_model.model_adapters) and ( - otx_model.model_adapters.get("path").endswith(".ckpt") # type: ignore[attr-defined] - ): - # pytorch lightning checkpoint - if not otx_model.model_adapters.get("resume"): - # If not resuming, just load weights in LightningModule - logger.info("Load pytorch lightning checkpoint.") + elif otx_model.model_adapters.get("resume", False): + # If resuming, pass this part to load checkpoint in Trainer + logger.info(f"To resume {otx_model.model_adapters.get('path')}, the checkpoint will be loaded in Trainer.") + else: - # pytorch checkpoint saved by otx + # Load state_dict buffer = io.BytesIO(otx_model.get_data("weights.pth")) model_data = torch.load(buffer, map_location=torch.device("cpu")) - if model_data.get("model", None) and model_data.get("config", None): + if model_data.get("state_dict", None) and model_data.get("pytorch-lightning_version", None): + # Load state_dict from pytorch lightning checkpoint or weights.pth saved by visual prompting task + # In pytorch lightning checkpoint, there are metas: epoch, global_step, pytorch-lightning_version, + # state_dict, loops, callbacks, optimizer_states, lr_schedulers, hparams_name, hyper_parameters. + # To confirm if it is from pytorch lightning, check if one or two of them is in model_data. + state_dict = model_data["state_dict"] + + elif model_data.get("model", None) and model_data.get("config", None): + # Load state_dict from checkpoint saved by otx other tasks if model_data["config"]["model"]["backbone"] != self.config["model"]["backbone"]: logger.warning( "Backbone of the model in the Task Environment is different from the one in the template. " @@ -210,10 +210,10 @@ def get_model(config: DictConfig, state_dict: Optional[OrderedDict] = None): ) self.config["model"]["backbone"] = model_data["config"]["model"]["backbone"] state_dict = model_data["model"] - logger.info("Load pytorch checkpoint from weights.pth.") + else: + # Load state_dict from naive pytorch checkpoint state_dict = model_data - logger.info("Load pytorch checkpoint.") try: model = get_model(config=self.config, state_dict=state_dict) @@ -406,11 +406,10 @@ def model_info(self) -> Dict: Returns: Dict: Model info. """ - return { - "model": self.model.state_dict(), - "config": self.get_config(), - "version": self.trainer.logger.version, - } + if not self._model_ckpt: + logger.warn("model checkpoint is not set, return empty dictionary.") + return {} + return torch.load(self._model_ckpt, map_location="cpu") def save_model(self, output_model: ModelEntity) -> None: """Save the model after training is completed. diff --git a/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/config/test_visual_prompting_config.py b/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/config/test_visual_prompting_config.py index c61e6b46589..105047526b8 100644 --- a/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/config/test_visual_prompting_config.py +++ b/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/config/test_visual_prompting_config.py @@ -48,12 +48,12 @@ def test_get_visual_promtping_config( assert config.get("optimizer", False) assert config.get("callback", False) assert config.get("trainer", False) - if model_checkpoint is not None: - if mode == "train": - assert config.get("model").get("checkpoint") == model_checkpoint + if mode == "train": + if model_checkpoint: + assert config.get("model").get("checkpoint", None) == model_checkpoint else: - assert config.get("model").get("checkpoint") != model_checkpoint - assert config.get("trainer").get("resume_from_checkpoint", None) == resume_from_checkpoint + assert config.get("model").get("checkpoint", None) != model_checkpoint + assert config.get("trainer").get("resume_from_checkpoint", None) == resume_from_checkpoint @e2e_pytest_unit diff --git a/tests/unit/algorithms/visual_prompting/tasks/test_inference.py b/tests/unit/algorithms/visual_prompting/tasks/test_inference.py index 3ca7915b2c6..b89d59f86a8 100644 --- a/tests/unit/algorithms/visual_prompting/tasks/test_inference.py +++ b/tests/unit/algorithms/visual_prompting/tasks/test_inference.py @@ -5,7 +5,7 @@ # from collections import OrderedDict -from typing import Optional +from typing import Optional, Dict, Any import pytest from omegaconf import DictConfig @@ -59,17 +59,9 @@ def test_get_config_train(self, mocker, load_inference_task, path: Optional[str] assert isinstance(inference_task.config, DictConfig) assert inference_task.config.dataset.task == "visual_prompting" if path: - if path.endswith(".pth"): - # TODO (sungchul): when applying resume - # pytorch weights - assert inference_task.config.model.checkpoint == path - assert inference_task.config.trainer.resume_from_checkpoint is None - elif path.endswith(".ckpt") and resume: - # resume with pytorch lightning weights - assert inference_task.config.model.checkpoint != path # use default checkpoint + if resume: assert inference_task.config.trainer.resume_from_checkpoint == path else: - # just train with pytorch lightning weights assert inference_task.config.model.checkpoint == path assert inference_task.config.trainer.resume_from_checkpoint is None @@ -84,40 +76,37 @@ def test_get_config_eval(self, mocker, load_inference_task): assert inference_task.config.dataset.task == "visual_prompting" @e2e_pytest_unit - @pytest.mark.parametrize("path", [None, "checkpoint.ckpt"]) - @pytest.mark.parametrize("resume", [True, False]) - def test_load_model_without_otx_model_or_with_lightning_ckpt( - self, mocker, load_inference_task, path: str, resume: bool - ): - """Test load_model to resume.""" - mocker_segment_anything = mocker.patch( - "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.SegmentAnything" - ) - - inference_task = load_inference_task(path=path, resume=resume) - inference_task.load_model(otx_model=inference_task.task_environment.model) - - mocker_segment_anything.assert_called_once() - - @e2e_pytest_unit + @pytest.mark.parametrize("path", [None, "checkpoint.ckpt", "checkpoint.pth"]) @pytest.mark.parametrize("resume", [True, False]) - def test_load_model_with_pytorch_pth(self, mocker, load_inference_task, resume: bool): - """Test load_model with otx_model.""" + @pytest.mark.parametrize( + "load_return_value", + [ + {"state_dict": {"layer": "weights"}, "pytorch-lightning_version": "version"}, + {"model": {"layer": "weights"}, "config": {"model": {"backbone": "sam_vit_b"}}}, + {}, + ], + ) + def test_load_model(self, mocker, load_inference_task, path: str, resume: bool, load_return_value: Dict[str, Any]): + """Test load_model.""" mocker_segment_anything = mocker.patch( "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.SegmentAnything" ) mocker_io_bytes_io = mocker.patch("io.BytesIO") mocker_torch_load = mocker.patch( "torch.load", - return_value=dict(config=dict(model=dict(backbone="sam_vit_b")), model={}), + return_value=load_return_value, ) - inference_task = load_inference_task(path="checkpoint.pth", resume=resume) + inference_task = load_inference_task(path=path, resume=resume) inference_task.load_model(otx_model=inference_task.task_environment.model) mocker_segment_anything.assert_called_once() - mocker_io_bytes_io.assert_called_once() - mocker_torch_load.assert_called_once() + if resume or path is None: + mocker_io_bytes_io.assert_not_called() + mocker_torch_load.assert_not_called() + else: + mocker_io_bytes_io.assert_called_once() + mocker_torch_load.assert_called_once() @e2e_pytest_unit def test_infer(self, mocker, load_inference_task): @@ -162,19 +151,17 @@ def test_model_info(self, mocker, load_inference_task): mocker.patch( "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.load_checkpoint" ) + mocker.patch("torch.load", return_value={"state_dict": {"layer": "weights"}, "epoch": 0}) inference_task = load_inference_task(output_path=None) - inference_task.model = inference_task.load_model(otx_model=inference_task.task_environment.model) - setattr(inference_task, "trainer", None) - mocker.patch.object(inference_task, "trainer") + inference_task._model_ckpt = "checkpoint" model_info = inference_task.model_info() - assert "model" in model_info - assert isinstance(model_info["model"], OrderedDict) - assert "config" in model_info - assert isinstance(model_info["config"], DictConfig) - assert "version" in model_info + assert isinstance(model_info.get("state_dict", None), dict) + assert model_info.get("state_dict", None)["layer"] == "weights" + assert isinstance(model_info.get("epoch", None), int) + assert model_info.get("epoch", None) == 0 @e2e_pytest_unit def test_save_model(self, mocker, load_inference_task): diff --git a/tests/unit/algorithms/visual_prompting/tasks/test_train.py b/tests/unit/algorithms/visual_prompting/tasks/test_train.py index e0fb804ee6b..5d1ea57c6b2 100644 --- a/tests/unit/algorithms/visual_prompting/tasks/test_train.py +++ b/tests/unit/algorithms/visual_prompting/tasks/test_train.py @@ -30,6 +30,7 @@ def test_train(self, mocker): """Test train.""" mocker_trainer = mocker.patch("otx.algorithms.visual_prompting.tasks.train.Trainer") mocker_save = mocker.patch("torch.save") + mocker.patch.object(self.training_task, "model_info") dataset = generate_visual_prompting_dataset() output_model = ModelEntity(