Change tensorflow model format to SavedModel to support sub-classed models
#628
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,6 +21,7 @@ | |
| from alibi_detect.models.tensorflow.autoencoder import (AE, AEGMM, VAE, VAEGMM, | ||
| DecoderLSTM, | ||
| EncoderLSTM, Seq2Seq) | ||
| from alibi_detect.utils.tensorflow.misc import clone_model | ||
| from alibi_detect.od import (LLR, IForest, Mahalanobis, OutlierAE, | ||
| OutlierAEGMM, OutlierProphet, OutlierSeq2Seq, | ||
| OutlierVAE, OutlierVAEGMM, SpectralResidual) | ||
|
|
@@ -34,38 +35,38 @@ | |
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| MODEL_ERROR = "The TensorFlow model may have been loaded incorrectly. This could be because `get_config` " \ | ||
| "and/or `from_config` methods were defined incorrectly. Otherwise, it could be because custom objects " \ | ||
| "have not been provided. For more guidance see the TensorFlow tab at " \ | ||
| "https://docs.seldon.io/projects/alibi-detect/en/stable/overview/saving.html#supported-ml-models." | ||
|
|
||
|
|
||
| def load_model(filepath: Union[str, os.PathLike], | ||
| filename: str = 'model', | ||
| custom_objects: dict = None, | ||
| layer: Optional[int] = None, | ||
|
Contributor
Minor, but I don't fully agree with omitting
Contributor
Author
Pretty much the only reason I made this change is that the
if flavour == Framework.TENSORFLOW:
    model = load_model_tf(src, layer=layer, **kwargs)
So I saw it as a trade-off wrt carrying around more complexity in Similar story for
Contributor
Author
p.s. below is very true though... tweaking anything to do with legacy save/load does bring the potential for bugs like the ones |
||
| **kwargs | ||
| ) -> tf.keras.Model: | ||
| """ | ||
| Load TensorFlow model. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| filepath | ||
| Saved model directory. | ||
| filename | ||
| Name of saved model within the filepath directory. | ||
| custom_objects | ||
| Optional custom objects when loading the TensorFlow model. | ||
| Saved model filepath. This should be a directory if the model is in the `SavedModel` format. Otherwise, it | ||
| should be a path to a `.h5` file. | ||
| layer | ||
| Optional index of a hidden layer to extract. If not `None`, a | ||
| :py:class:`~alibi_detect.cd.tensorflow.HiddenOutput` model is returned. | ||
|
|
||
| kwargs | ||
| Additional keyword arguments to be passed to :func:`tf.keras.models.load_model`. | ||
| Returns | ||
| ------- | ||
| Loaded model. | ||
| """ | ||
| # TODO - update this to accept tf format - later PR. | ||
| model_dir = Path(filepath) | ||
| model_name = filename + '.h5' | ||
| # Check if model exists | ||
| if model_name not in [f.name for f in model_dir.glob('[!.]*.h5')]: | ||
| raise FileNotFoundError(f'{model_name} not found in {model_dir.resolve()}.') | ||
| model = tf.keras.models.load_model(model_dir.joinpath(model_name), custom_objects=custom_objects) | ||
| # Load model | ||
| model = tf.keras.models.load_model(filepath, **kwargs) | ||
|
Contributor
Seems like we're throwing away the validation code for the existence of the model? Or is it done from higher up in another caller?
Contributor
Author
It's not actually done from higher up. Rather, I realised that the validation might be superfluous since
Contributor
OK that makes sense. |
||
| # Check the loaded model for problems | ||
| check_model(model) | ||
|
|
||
| # Optionally extract hidden layer | ||
| if isinstance(layer, int): | ||
| model = HiddenOutput(model, layer=layer) | ||
|
|
@@ -265,13 +266,13 @@ def load_detector_legacy(filepath: Union[str, os.PathLike], suffix: str, **kwarg | |
| elif detector_name == 'AdversarialAE': | ||
| ae = load_tf_ae(filepath) | ||
| custom_objects = kwargs['custom_objects'] if 'custom_objects' in k else None | ||
| model = load_model(model_dir, custom_objects=custom_objects) | ||
| model = load_model(model_dir.joinpath('model.h5'), custom_objects=custom_objects) | ||
| model_hl = load_tf_hl(filepath, model, state_dict) | ||
| detector = init_ad_ae(state_dict, ae, model, model_hl) | ||
| elif detector_name == 'ModelDistillation': | ||
| md = load_model(model_dir, filename='distilled_model') | ||
| md = load_model(model_dir.joinpath('distilled_model.h5')) | ||
| custom_objects = kwargs['custom_objects'] if 'custom_objects' in k else None | ||
| model = load_model(model_dir, custom_objects=custom_objects) | ||
| model = load_model(model_dir.joinpath('model.h5'), custom_objects=custom_objects) | ||
| detector = init_ad_md(state_dict, md, model) | ||
| elif detector_name == 'OutlierProphet': | ||
| detector = init_od_prophet(state_dict) # type: ignore[assignment] | ||
|
|
@@ -285,8 +286,8 @@ def load_detector_legacy(filepath: Union[str, os.PathLike], suffix: str, **kwarg | |
| if state_dict['other']['load_text_embedding']: | ||
| emb, tokenizer = load_text_embed(filepath) | ||
| try: # legacy load_model behaviour was to return None if not found. Now it raises error, hence need try-except. | ||
| model = load_model(model_dir, filename='encoder') | ||
| except FileNotFoundError: | ||
| model = load_model(model_dir.joinpath('encoder.h5')) | ||
| except OSError: | ||
| logger.warning('No model found in {}, setting `model` to `None`.'.format(model_dir)) | ||
| model = None | ||
| if detector_name == 'KSDrift': | ||
|
|
@@ -299,7 +300,7 @@ def load_detector_legacy(filepath: Union[str, os.PathLike], suffix: str, **kwarg | |
| load_fn = init_cd_tabulardrift # type: ignore[assignment] | ||
| elif detector_name == 'ClassifierDriftTF': | ||
| # Don't need try-except here since model is not optional for ClassifierDrift | ||
| clf_drift = load_model(model_dir, filename='clf_drift') | ||
| clf_drift = load_model(model_dir.joinpath('clf_drift.h5')) | ||
| load_fn = partial(init_cd_classifierdrift, clf_drift) # type: ignore[assignment] | ||
| else: | ||
| raise NotImplementedError | ||
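The legacy call sites above now build the full `.h5` path themselves and, where the model is optional, catch `OSError` rather than `FileNotFoundError`. The sketch below illustrates that pattern with a stand-in loader. Both `fake_load_model` and `load_optional_model` are hypothetical names introduced for illustration; the real code calls `load_model` directly.

```python
import logging
from pathlib import Path
from typing import Any, Callable, Optional

logger = logging.getLogger(__name__)


def fake_load_model(filepath: Path) -> Any:
    """Stand-in for the real loader (assumption): raises OSError if the file is missing."""
    if not Path(filepath).is_file():
        raise OSError(f'No file or directory found at {filepath}')
    return object()


def load_optional_model(model_dir: Path, filename: str,
                        loader: Callable[[Path], Any] = fake_load_model) -> Optional[Any]:
    """Hypothetical wrapper: return None instead of raising when the model file is absent."""
    # The old `filename='encoder'` argument becomes an explicit <model_dir>/encoder.h5 path.
    filepath = Path(model_dir).joinpath(filename + '.h5')
    try:
        return loader(filepath)
    except OSError:
        logger.warning('No model found in %s, setting `model` to `None`.', model_dir)
        return None


missing = load_optional_model(Path('does_not_exist'), 'encoder')
print(missing)
```

Because `FileNotFoundError` is a subclass of `OSError`, the broader `except OSError` clause still covers a loader that raises the narrower exception.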
|
|
@@ -1018,3 +1019,39 @@ def init_od_llr(state_dict: Dict, models: tuple) -> LLR: | |
| od.model_s = models[2] | ||
| od.model_b = models[3] | ||
| return od | ||
|
|
||
|
|
||
| def check_model(model: tf.keras.Model) -> None: | ||
| """ | ||
| Function to check that a TensorFlow model has been loaded correctly. Specifically, this checks that the model | ||
| can be cloned, since this isn't possible if the model is a subclassed `tf.keras.Model` and custom objects | ||
| were not provided at load time. Additionally, in some cases (dependent on exact model and tf version) the model can | ||
| have problems being called if custom objects were not provided. As a general check for both cases, the model is | ||
| examined to see if it is a `keras.saving.saved_model.load.RevivedNetwork`. If it is, an error is also raised. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| model | ||
| The model to be checked. | ||
|
|
||
| Raises | ||
| ------ | ||
| ValueError | ||
| Raised if the model appears to be a `keras.saving.saved_model.load.RevivedNetwork`, indicating that some | ||
| custom objects were not provided at load time. | ||
| """ | ||
| try: | ||
| # Check if model is a `RevivedNetwork` rather than the real original model (this occurs when subclassed models | ||
| # are loaded without all custom objects being provided) | ||
| # Note, could also do `if model.__class__.__base__.__name__ == 'RevivedNetwork':` | ||
| if model.__class__.__module__ == 'keras.saving.saved_model.load': | ||
| # Raise error (this will be caught and re-raised below) | ||
| raise ValueError('The model appears to be a `keras.saving.saved_model.load.RevivedNetwork`. This suggests ' | ||
| 'a subclassed model has been loaded without all custom objects being provided.') | ||
|
|
||
| # Check model cloning doesn't raise error | ||
| clone_model(model) | ||
|
|
||
| # Capture any errors and display custom error | ||
| except Exception as error: | ||
| raise ValueError(MODEL_ERROR) from error | ||
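`check_model` combines two ideas: inspecting the `__module__` of the model's class, and re-raising with `raise ... from error` so the original cause stays attached. The sketch below shows both with plain Python stand-ins and no TensorFlow dependency. `RevivedStandIn`, `UserModel` and `check_model_sketch` are hypothetical, the error message is shortened, and the cloning step from the diff is omitted.

```python
REVIVED_MODULE = 'keras.saving.saved_model.load'  # module name checked in the diff above
MODEL_ERROR = 'The model may have been loaded incorrectly.'  # shortened stand-in message

# Stand-in classes (assumptions): one pretends to come from the Keras revival module.
RevivedStandIn = type('RevivedNetwork', (), {'__module__': REVIVED_MODULE})
UserModel = type('UserModel', (), {})


def check_model_sketch(model) -> None:
    """Raise ValueError if the object's class lives in the revival module."""
    try:
        if model.__class__.__module__ == REVIVED_MODULE:
            raise ValueError('Model appears to be a RevivedNetwork.')
    except Exception as error:
        # Chain the exceptions so the original reason is kept as __cause__.
        raise ValueError(MODEL_ERROR) from error


check_model_sketch(UserModel())  # an ordinary class passes silently

try:
    check_model_sketch(RevivedStandIn())
except ValueError as err:
    caught = err
```

The caller sees only the generic message, while the specific reason remains reachable through `caught.__cause__` and appears in the traceback.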
Extra functionality sneaking into this PR... Worth adding changelog entries to this PR so everything is documented and not missed upon release?
Not 100% clear what you mean here? The functionality of passing kwargs to `load_detector`? It is extra functionality but is interlinked with the PR, as `custom_objects` needs to be passed to `load_detector`.
Edit: reading again, I see what you mean. Since we also pass kwargs to PyTorch. I could factor this out to a separate PR if preferred...
No need, but would appreciate a changelog as part of the PR.