Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
b453383
Move incorrectly placed comment in _saving.py
Sep 21, 2022
f1ecce5
Change to tf SavedModel format and remove custom_objects config support
Sep 21, 2022
a355503
Fixed legacy save/load
Sep 22, 2022
b70f5b4
Check that dir contains a SavedModel in test_save_model
Sep 22, 2022
853d76c
Add documentation on supported models
Sep 22, 2022
fac6f1b
Call tf models to set input_shape prior to saving, and small docs edit
Sep 26, 2022
129ece9
Update get_input_shape warning and fix handling of input_shape in tests
Sep 26, 2022
2a98085
Reword get_input_shape warning again.
Sep 26, 2022
630478c
Merge master into feature/tf_SavedModel
Jan 17, 2023
11010ff
Fix deepkernel and save_model tests
Jan 17, 2023
4fc8585
Add custom_object support, and subclassed models tests
Jan 17, 2023
46c3ed6
Update saving/loading page
Jan 18, 2023
a47ebf5
Remove online state section from docs (wrong PR)
Jan 18, 2023
5abeda4
Add space in warning
Jan 18, 2023
f61af31
Remove dummy model call and just raise error if not called
Jan 19, 2023
97866c7
Check the tf model for possible problems during detector loading
Jan 26, 2023
1034cf8
Test new error and warning
Jan 26, 2023
733418b
Merge master into feature/tf_SavedModel
Jan 26, 2023
8f70da4
Fix tests
Jan 27, 2023
6df68ef
Tidy tf save_model
Jan 27, 2023
73c32b7
Remove addition of model to LARGE_ARTEFACTS (as in #723)
Jan 27, 2023
51e9bbd
Revert "Remove addition of model to LARGE_ARTEFACTS (as in #723)"
Jan 27, 2023
7f2553b
Edit error message
Jan 27, 2023
5eaafbe
Update saving docs w/ more prominent warning
Jan 27, 2023
d39ff25
Change UserWarning to ValueError
Jan 30, 2023
1dc7b61
Add a test of save/loading with a subclassed model in preprocess_fn
Jan 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion alibi_detect/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ def infer_threshold(self, X: np.ndarray) -> None:
# "Large artefacts" - to save memory these are skipped in _set_config(), but added back in get_config()
# Note: The current implementation assumes the artefact is stored as a class attribute, and as a config field under
# the same name. Refactoring will be required if this assumption is to be broken.
LARGE_ARTEFACTS = ['x_ref', 'c_ref', 'preprocess_fn']
# Note: The above procedure is not followed for `model` in the `UncertaintyDrift` detectors, since these do not store
# the attribute `self.model`.
LARGE_ARTEFACTS = ['x_ref', 'c_ref', 'preprocess_fn', 'model']


class DriftConfigMixin:
Expand Down
7 changes: 6 additions & 1 deletion alibi_detect/saving/_pytorch/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

def load_model(filepath: Union[str, os.PathLike],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extra functionality sneaking into this PR... Worth adding changelog entries to this PR so everything is documented and not missed upon release?

Copy link
Contributor Author

@ascillitoe ascillitoe Jan 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not 100% clear what you mean here? The functionality of passing kwargs to load_detector? It is extra functionality but is interlinked with the PR, as custom_objects needs to be passed to load_detector.

Edit: reading again, I see what you mean, since we also pass kwargs to pytorch. I could factor this out to a separate PR if preferred...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need, but would appreciate a changelog as part of the PR.

layer: Optional[int] = None,
**kwargs
) -> nn.Module:
"""
Load PyTorch model.
Expand All @@ -29,13 +30,17 @@ def load_model(filepath: Union[str, os.PathLike],
layer
Optional index of a hidden layer to extract. If not `None`, a
:py:class:`~alibi_detect.cd.pytorch.HiddenOutput` model is returned.
kwargs
Additional keyword arguments to be passed to :func:`torch.load`.

Returns
-------
Loaded model.
"""
filepath = Path(filepath).joinpath('model.pt')
model = torch.load(filepath, pickle_module=dill)
if 'pickle_module' not in kwargs:
kwargs['pickle_module'] = dill
model = torch.load(filepath, **kwargs)
# Optionally extract hidden layer
if isinstance(layer, int):
model = HiddenOutput(model, layer=layer)
Expand Down
79 changes: 58 additions & 21 deletions alibi_detect/saving/_tensorflow/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from alibi_detect.models.tensorflow.autoencoder import (AE, AEGMM, VAE, VAEGMM,
DecoderLSTM,
EncoderLSTM, Seq2Seq)
from alibi_detect.utils.tensorflow.misc import clone_model
from alibi_detect.od import (LLR, IForest, Mahalanobis, OutlierAE,
OutlierAEGMM, OutlierProphet, OutlierSeq2Seq,
OutlierVAE, OutlierVAEGMM, SpectralResidual)
Expand All @@ -34,38 +35,38 @@

logger = logging.getLogger(__name__)

MODEL_ERROR = "The TensorFlow model may have been loaded incorrectly. This could be because `get_config` " \
"and/or `from_config` methods were defined incorrectly. Otherwise, it could be because custom objects " \
"have not been provided. For more guidance see the TensorFlow tab at " \
"https://docs.seldon.io/projects/alibi-detect/en/stable/overview/saving.html#supported-ml-models."


def load_model(filepath: Union[str, os.PathLike],
filename: str = 'model',
custom_objects: dict = None,
layer: Optional[int] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor, but I don't fully agree with omitting filename from these internal save/load functions. What we save in the function signature, we pay for at every call site, having to remember to do .joinpath(filename). (OTOH in the old behaviour, having a default model name is also likely not desirable, as forgetting to set it would result in a perhaps unexpected default.)

Copy link
Contributor Author

@ascillitoe ascillitoe Jan 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pretty much the only reason I made this change is that the filename is only there for legacy loading (legacy as in, saving to .h5, and legacy as in loading files with different names such as encoder.h5). For the "modern" loading we simply do:

    if flavour == Framework.TENSORFLOW:
        model = load_model_tf(src, layer=layer, **kwargs)

So I saw it as a trade-off wrt to carrying around more complexity in load_model to facilitate the legacy functionality in load_detector_legacy, or add some complexity to the calls in load_detector_legacy to simplify the load_model function... (which just happens to be used for modern and legacy loading).

Similar story for save_model...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, noting that there's a few subtle changes in the behaviour of internal functions for saving/loading, need to be extra vigilant new bugs haven't been introduced.

p.s. below is very true though... tweaking anything to do with legacy save/load does bring the potential for bugs like the ones v0.10.5 fixed. I've run the same tests of loading old artefacts we ran in #732 and everything passes, but that isn't 100% comprehensive...

**kwargs
) -> tf.keras.Model:
"""
Load TensorFlow model.

Parameters
----------
filepath
Saved model directory.
filename
Name of saved model within the filepath directory.
custom_objects
Optional custom objects when loading the TensorFlow model.
Saved model filepath. This should be a directory if the model is in the `SavedModel` format. Otherwise, it
should be a path to a `.h5` file.
layer
Optional index of a hidden layer to extract. If not `None`, a
:py:class:`~alibi_detect.cd.tensorflow.HiddenOutput` model is returned.

kwargs
Additional keyword arguments to be passed to :func:`tf.keras.models.load_model`.
Returns
-------
Loaded model.
"""
# TODO - update this to accept tf format - later PR.
model_dir = Path(filepath)
model_name = filename + '.h5'
# Check if model exists
if model_name not in [f.name for f in model_dir.glob('[!.]*.h5')]:
raise FileNotFoundError(f'{model_name} not found in {model_dir.resolve()}.')
model = tf.keras.models.load_model(model_dir.joinpath(model_name), custom_objects=custom_objects)
# Load model
model = tf.keras.models.load_model(filepath, **kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like we're throwing away the validation code for the existence of the model? Or is it done from higher up in another caller?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not actually done from higher up. Rather, I realised that the validation might be superfluous since tf.keras.models.load_model already raises OSError: No file or directory found at test.h5 if a filepath to a .h5 model is passed and one doesn't exist, and OSError: SavedModel file does not exist at: test//{saved_model.pbtxt|saved_model.pb} if a directory is passed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK that makes sense.

# Check the loaded model for problems
check_model(model)

# Optionally extract hidden layer
if isinstance(layer, int):
model = HiddenOutput(model, layer=layer)
Expand Down Expand Up @@ -265,13 +266,13 @@ def load_detector_legacy(filepath: Union[str, os.PathLike], suffix: str, **kwarg
elif detector_name == 'AdversarialAE':
ae = load_tf_ae(filepath)
custom_objects = kwargs['custom_objects'] if 'custom_objects' in k else None
model = load_model(model_dir, custom_objects=custom_objects)
model = load_model(model_dir.joinpath('model.h5'), custom_objects=custom_objects)
model_hl = load_tf_hl(filepath, model, state_dict)
detector = init_ad_ae(state_dict, ae, model, model_hl)
elif detector_name == 'ModelDistillation':
md = load_model(model_dir, filename='distilled_model')
md = load_model(model_dir.joinpath('distilled_model.h5'))
custom_objects = kwargs['custom_objects'] if 'custom_objects' in k else None
model = load_model(model_dir, custom_objects=custom_objects)
model = load_model(model_dir.joinpath('model.h5'), custom_objects=custom_objects)
detector = init_ad_md(state_dict, md, model)
elif detector_name == 'OutlierProphet':
detector = init_od_prophet(state_dict) # type: ignore[assignment]
Expand All @@ -285,8 +286,8 @@ def load_detector_legacy(filepath: Union[str, os.PathLike], suffix: str, **kwarg
if state_dict['other']['load_text_embedding']:
emb, tokenizer = load_text_embed(filepath)
try: # legacy load_model behaviour was to return None if not found. Now it raises error, hence need try-except.
model = load_model(model_dir, filename='encoder')
except FileNotFoundError:
model = load_model(model_dir.joinpath('encoder.h5'))
except OSError:
logger.warning('No model found in {}, setting `model` to `None`.'.format(model_dir))
model = None
if detector_name == 'KSDrift':
Expand All @@ -299,7 +300,7 @@ def load_detector_legacy(filepath: Union[str, os.PathLike], suffix: str, **kwarg
load_fn = init_cd_tabulardrift # type: ignore[assignment]
elif detector_name == 'ClassifierDriftTF':
# Don't need try-except here since model is not optional for ClassifierDrift
clf_drift = load_model(model_dir, filename='clf_drift')
clf_drift = load_model(model_dir.joinpath('clf_drift.h5'))
load_fn = partial(init_cd_classifierdrift, clf_drift) # type: ignore[assignment]
else:
raise NotImplementedError
Expand Down Expand Up @@ -1018,3 +1019,39 @@ def init_od_llr(state_dict: Dict, models: tuple) -> LLR:
od.model_s = models[2]
od.model_b = models[3]
return od


def check_model(model: tf.keras.Model) -> None:
    """
    Function to check that a TensorFlow model has been loaded correctly. Specifically, this checks that the model
    can be cloned, since this isn't possible if the model is a subclassed `tf.keras.Model` and custom objects
    were not provided at load time. Additionally, in some cases (dependent on exact model and tf version) the model can
    have problems being called if custom objects were not provided. As a general check for both cases, the model is
    examined to see if it is a `keras.saving.saved_model.load.RevivedNetwork`. If it is, an error is also raised.

    Parameters
    ----------
    model
        The model to be checked.

    Raises
    ------
    ValueError
        Raised if the model appears to be a `keras.saving.saved_model.load.RevivedNetwork`, indicating that some
        custom objects were not provided at load time.
    """
    try:
        # Check if model is a `RevivedNetwork` rather than the real original model (this occurs when subclassed models
        # are loaded without all custom objects being provided)
        # Note, could also do `if model.__class__.__base__.__name__ == 'RevivedNetwork':`
        if model.__class__.__module__ == 'keras.saving.saved_model.load':
            # Raise error (this will be caught and re-raised below with the generic MODEL_ERROR guidance)
            raise ValueError('The model appears to be a `keras.saving.saved_model.load.RevivedNetwork`. This '
                             'suggests a subclassed model has been loaded without all custom objects being '
                             'provided.')

        # Check model cloning doesn't raise error (cloning fails for improperly revived subclassed models)
        clone_model(model)

    # Capture any errors and display custom error
    except Exception as error:
        raise ValueError(MODEL_ERROR) from error
44 changes: 19 additions & 25 deletions alibi_detect/saving/_tensorflow/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from alibi_detect.od import (LLR, IForest, Mahalanobis, OutlierAE,
OutlierAEGMM, OutlierProphet, OutlierSeq2Seq,
OutlierVAE, OutlierVAEGMM, SpectralResidual)
from alibi_detect.utils._types import Literal
from alibi_detect.utils.tensorflow.kernels import GaussianRBF
from alibi_detect.utils.missing_optional_dependency import MissingDependency
from alibi_detect.utils.frameworks import Framework
Expand All @@ -29,7 +28,7 @@

def save_model_config(model: Callable,
base_path: Path,
input_shape: Optional[tuple],
input_shape: Optional[tuple] = None,
local_path: Path = Path('.')) -> Tuple[dict, Optional[dict]]:
"""
Save a TensorFlow model to a config dictionary. When a model has a text embedding model contained within it,
Expand Down Expand Up @@ -90,9 +89,7 @@ def save_model_config(model: Callable,


def save_model(model: tf.keras.Model,
filepath: Union[str, os.PathLike],
filename: str = 'model',
save_format: Literal['tf', 'h5'] = 'h5') -> None: # TODO - change to tf, later PR
filepath: Union[str, os.PathLike]) -> None:
"""
Save TensorFlow model.

Expand All @@ -101,24 +98,21 @@ def save_model(model: tf.keras.Model,
model
The tf.keras.Model to save.
filepath
Save directory.
filename
Name of file to save to within the filepath directory.
save_format
The format to save to. 'tf' to save to the newer SavedModel format, 'h5' to save to the lighter-weight
legacy hdf5 format.
File path to save to. If it refers to a `.h5` file, the model is saved in `.h5` format. Otherwise, the model
is saved in `SavedModel` format.
"""
# create folder to save model in
model_path = Path(filepath)
if not model_path.is_dir():
logger.warning('Directory {} does not exist and is now created.'.format(model_path))
model_path.mkdir(parents=True, exist_ok=True)

filepath = Path(filepath)
# Determine file format to save in
save_format = 'h5' if filepath.suffix == '.h5' else 'tf'
# save model
model_path = model_path.joinpath(filename + '.h5') if save_format == 'h5' else model_path

if isinstance(model, tf.keras.Model):
model.save(model_path, save_format=save_format)
try:
model.save(filepath, save_format=save_format)
except ValueError as error:
raise ValueError("Saving of the `tf.keras.Model` failed. If the model is a subclassed tensorflow model, "
"this might be because the model's input shape is not available. To specify an input "
"shape call the model (on actual data) before passing it to the detector, or pass actual "
"data to the detector's `predict` method.") from error
else:
raise ValueError('The extracted model to save is not a `tf.keras.Model`. Cannot save.')

Expand Down Expand Up @@ -261,24 +255,24 @@ def save_detector_legacy(detector, filepath):
save_tf_vae(detector, filepath)
elif isinstance(detector, (ChiSquareDrift, ClassifierDrift, KSDrift, MMDDrift, TabularDrift)):
if model is not None:
save_model(model, model_dir, filename='encoder')
save_model(model, model_dir.joinpath('encoder.h5'))
if embed is not None:
save_embedding_legacy(embed, embed_args, filepath)
if tokenizer is not None:
tokenizer.save_pretrained(filepath.joinpath('model'))
if detector_name == 'ClassifierDriftTF':
save_model(clf_drift, model_dir, filename='clf_drift')
save_model(clf_drift, model_dir.joinpath('clf_drift.h5'))
elif isinstance(detector, OutlierAEGMM):
save_tf_aegmm(detector, filepath)
elif isinstance(detector, OutlierVAEGMM):
save_tf_vaegmm(detector, filepath)
elif isinstance(detector, AdversarialAE):
save_tf_ae(detector, filepath)
save_model(detector.model, model_dir)
save_model(detector.model, model_dir.joinpath('model.h5'))
save_tf_hl(detector.model_hl, filepath)
elif isinstance(detector, ModelDistillation):
save_model(detector.distilled_model, model_dir, filename='distilled_model')
save_model(detector.model, model_dir, filename='model')
save_model(detector.distilled_model, model_dir.joinpath('distilled_model.h5'))
save_model(detector.model, model_dir.joinpath('model.h5'))
elif isinstance(detector, OutlierSeq2Seq):
save_tf_s2s(detector, filepath)
elif isinstance(detector, LLR):
Expand Down
Loading