Serialisation of (online) state for online detectors #604
ascillitoe merged 48 commits into SeldonIO:master from ascillitoe:feature/save_state
Conversation
Nice idea with the name change, and yes I agree, "state" does not refer to any attributes set in init, as that would be "config" (with our definitions). My only concern is that users might expect a detector to give the same predictions as the original when loaded from a "checkpoint". Maybe the answer is just to make it clear in the docstrings though... as in any case the detector's behaviour should be statistically the same after the checkpoint, even if the exact predictions are not?
Codecov Report
Additional details and impacted files:

```
@@            Coverage Diff             @@
##           master     #604      +/-   ##
==========================================
+ Coverage   80.15%   80.32%   +0.17%
==========================================
  Files         133      137       +4
  Lines        9177     9292     +115
==========================================
+ Hits         7356     7464     +108
- Misses       1821     1828       +7
```
Flags with carried forward coverage won't be shown.
Edit: Resolved.
Regarding the codecov report, the …
@arnaudvl @ojcobb (and @jklaise/@mauicv) could do with your thoughts on this. In the latest implementation, I have removed the new …. Question: shall we keep a …? As Example 2 shows, determinism in the case of saving/loading of a detector is not affected by this decision anyway... Difference between …
Additional side-note: this issue with setting random seeds not giving deterministic behaviour for a given operation (in this case the …). The only solution I can think of for this is a scikit-learn style approach, where we accept …
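For illustration, the scikit-learn style approach would thread a `random_state` argument through to a per-instance generator, so repeatability does not depend on global seeds. This is a hypothetical sketch (the `ToyDetector` class and its attributes are invented for illustration, not part of alibi-detect):

```python
import numpy as np

class ToyDetector:
    """Hypothetical detector accepting a scikit-learn style `random_state`."""

    def __init__(self, n_ref: int = 10, random_state=None):
        # Each instance owns its generator, so the global np.random state is untouched
        self._rng = np.random.default_rng(random_state)
        # Randomness at init time (e.g. choosing a reference split) is now reproducible
        self.ref_inds = self._rng.permutation(n_ref)

d1 = ToyDetector(random_state=0)
d2 = ToyDetector(random_state=0)
assert (d1.ref_inds == d2.ref_inds).all()  # same seed, identical init
```

The design point is that two detectors constructed with the same `random_state` produce the same initialisation, regardless of any global seeding the user does elsewhere.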
@arnaudvl @ojcobb a possible alternative strategy to make …:

```python
def _configure_ref_subset(self):
    """
    Configure reference subset. If already configured, the stateful attributes `test_window` and `k_xtc` are
    reset without re-configuring a new reference subset.
    """
    etw_size = 2 * self.window_size - 1  # etw = extended test window
    nkc_size = self.n - self.n_kernel_centers  # nkc = non-kernel-centers
    rw_size = nkc_size - etw_size  # rw = ref-window
    # Check if already configured; re-initialise stateful attributes w/o searching for a new ref split if so
    configure_ref = self.init_test_inds is None
    if configure_ref:
        # Make split and ensure it doesn't cause an initial detection
        lsdd_init = None
        while lsdd_init is None or lsdd_init >= self.get_threshold(0):
            # Make split
            perm = torch.randperm(nkc_size)
            self.ref_inds, self.init_test_inds = perm[:rw_size], perm[-self.window_size:]
            self.test_window = self.x_ref_eff[self.init_test_inds]
            # Compute initial lsdd to check for initial detection
            self.c2s = self.k_xc[self.ref_inds].mean(0)  # (below Eqn 21)
            self.k_xtc = self.kernel(self.test_window, self.kernel_centers)
            h_init = self.c2s - self.k_xtc.mean(0)  # (Eqn 21)
            lsdd_init = h_init[None, :] @ self.H_lam_inv @ h_init[:, None]  # (Eqn 11)
    else:
        # Reset stateful attributes using the existing split
        self.test_window = self.x_ref_eff[self.init_test_inds]
        self.k_xtc = self.kernel(self.test_window, self.kernel_centers)
```

This seems like a reasonable compromise to me? However, the additional duplication/complexity is unnecessary if we truly don't care about repeatable predictions post-…
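A stripped-down illustration of the pattern in the snippet above (a toy numpy version, not the actual LSDD code): the random reference split is searched for only once and cached, so later calls just rebuild the stateful attributes from the cached split, making resets deterministic.

```python
import numpy as np

class ToyRefSubset:
    """Toy illustration: cache the random reference split, rebuild state from it."""

    def __init__(self, x_ref: np.ndarray, window_size: int):
        self.x_ref = x_ref
        self.window_size = window_size
        self.init_test_inds = None  # cached split, found once

    def configure_ref_subset(self):
        if self.init_test_inds is None:
            # First call: make the (random) split
            perm = np.random.permutation(len(self.x_ref))
            self.ref_inds = perm[:-self.window_size]
            self.init_test_inds = perm[-self.window_size:]
        # Both paths: reset stateful attributes from the stored split
        self.test_window = self.x_ref[self.init_test_inds]

d = ToyRefSubset(np.arange(20.0), window_size=5)
d.configure_ref_subset()
first_inds = d.init_test_inds.copy()
d.configure_ref_subset()  # e.g. called again by a reset
assert (d.init_test_inds == first_inds).all()  # split unchanged, so reset is deterministic
```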
alibi_detect/base.py
Outdated
```python
def save_state(self, filepath): ...

def load_state(self, filepath): ...
```
Should parameters have type hints?
Added in 15d1dfa. Note: I also updated the pre-existing `get_config` and `set_config` methods here.
Mmn good point, thinking again it doesn't seem ideal to have it outside of …. Re it becoming a public module, I suspect you're mostly right. We do …. Weirdly though, with our …
alibi_detect/cd/base_online.py
Outdated
```python
    filepath
        The directory to load state from.
    """
    self._set_state_dir(filepath)
```
Why is this necessary when loading?
Just so that `self.state_dir` is set (and converted from `str` to `pathlib.Path`) when `load_state` is called as well as when `save_state` is called.
I thought it might be helpful to have `state_dir` as a public attribute, so that a user could interrogate the detector to see where state was loaded from. Although thinking about it more, for the backend detectors one would have to do `detector._detector.state_dir` (access a private attribute) anyway. I guess we'd probably want to define a `@property` if we actually want to support this functionality properly...
Happy to just make it private if you think it's better though...
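If we did want to support it properly, the `@property` idea could look like this (hypothetical sketch; `_BackendDetector` and `DetectorWrapper` are stand-in names, not alibi-detect classes):

```python
from pathlib import Path

class _BackendDetector:
    """Stand-in for the private backend detector that owns state_dir."""

    def __init__(self):
        self.state_dir = Path('checkpoints/run1')

class DetectorWrapper:
    """Public wrapper exposing the backend's state_dir read-only."""

    def __init__(self):
        self._detector = _BackendDetector()

    @property
    def state_dir(self) -> Path:
        # Delegate to the private backend detector, avoiding detector._detector.state_dir
        return self._detector.state_dir

d = DetectorWrapper()
assert d.state_dir == Path('checkpoints/run1')
```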
alibi_detect/cd/base_online.py
Outdated
```python
    dirpath
        The directory to save state file inside.
    """
    self.state_dir = Path(dirpath)
```
Should this be a private attribute?
alibi_detect/cd/base_online.py
Outdated
```python
def _set_state_dir(self, dirpath: Union[str, os.PathLike]):
    """
    Set the directory path to store state in, and create an empty directory if it doesn't already exist.

    Parameters
    ----------
    dirpath
        The directory to save state file inside.
    """
    self.state_dir = Path(dirpath)
    self.state_dir.mkdir(parents=True, exist_ok=True)

def save_state(self, filepath: Union[str, os.PathLike]):
    """
    Save a detector's state to disk in order to generate a checkpoint.

    Parameters
    ----------
    filepath
        The directory to save state to.
    """
    self._set_state_dir(filepath)
    self._save_state()
    logger.info('Saved state for t={} to {}'.format(self.t, self.state_dir))

def load_state(self, filepath: Union[str, os.PathLike]):
    """
    Load the detector's state from disk, in order to restart from a checkpoint previously generated with
    `save_state`.

    Parameters
    ----------
    filepath
        The directory to load state from.
    """
    self._set_state_dir(filepath)
    self._load_state()
    logger.info('State loaded for t={} from {}'.format(self.t, self.state_dir))

def _save_state(self):
    """
    Private method to save a detector's state to disk.

    TODO - Method slightly verbose as designed to facilitate saving of "offline" state in follow-up PR.
    """
    filename = 'state'
    keys = self.online_state_keys
    save_state_dict(self, keys, self.state_dir.joinpath(filename + '.npz'))

def _load_state(self, offline: bool = False):
    """
    Private method to load a detector's state from disk.

    TODO - Method slightly verbose as designed to facilitate loading of "offline" state in follow-up PR.
    """
    filename = 'state'
    load_state_dict(self, self.state_dir.joinpath(filename + '.npz'), raise_error=True)
```
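The `save_state_dict`/`load_state_dict` helpers used above aren't shown in this diff; a minimal sketch of what such helpers might do, assuming plain numpy `.npz` storage (an assumption on my part, not confirmed by the diff):

```python
import numpy as np
from pathlib import Path

def save_state_dict(detector, keys, filepath: Path):
    # Gather the named stateful attributes and write them to a single .npz file
    state = {k: np.asarray(getattr(detector, k)) for k in keys}
    np.savez(filepath, **state)

def load_state_dict(detector, filepath: Path, raise_error: bool = True):
    # Restore each array in the .npz back onto the detector
    try:
        state = np.load(filepath)
    except FileNotFoundError:
        if raise_error:
            raise
        return
    for k in state.files:
        setattr(detector, k, state[k])

# Usage with a toy object standing in for a detector:
import tempfile, types
det = types.SimpleNamespace(t=3, test_window=np.arange(4.0))
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / 'state.npz'
    save_state_dict(det, ['t', 'test_window'], path)
    restored = types.SimpleNamespace()
    load_state_dict(restored, path)
```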
Seems like a lot of duplicated code that's exactly the same as for `BaseMultiDriftOnline`, which suggests we may want to refactor using functions instead of methods, or a mixin class? Or perhaps the class hierarchy needs to be updated.
Note that this seems to apply to other methods too, so perhaps is a more widespread problem requiring a refactoring later...
Fair point. Having a BaseDriftOnline class for generic methods, or a mix-in both seem much nicer than this current pattern. I'll have a rethink 👍🏻
To reduce duplication, 5daf1b1 adds a `BaseDriftOnline` class. @jklaise @mauicv could I get your thoughts on the design of `BaseDriftOnline` please? I've gone with a parent class rather than a mix-in, since it seems strange to define a mix-in in `alibi_detect/base.py` when it is only to be used in two classes (`BaseMultiDriftOnline` and `BaseUniDriftOnline`). I also decided to put it in `alibi_detect/cd/base_online.py` rather than `alibi_detect/base.py`, since at the moment the concept of "online" detectors is specific to drift (this may change if we decide stateful outlier detectors are in fact "online").
LGTM however noting that there's quite a few abstract methods, some of which (not all?) are implemented in the Multi/Uni abstract child classes, which come with their own set of abstract methods... Worried that this may become a bit tricky to keep track of. As a minimum, would group all abstract methods to come after each other and add docstrings on expected implementation and also, where valid, which of the Multi/Uni classes implement these methods (+ type hints as always).
See #604 (comment) wrt to type hints, not sure on best approach here.
Wrt the abstract methods, if they are missing from the Multi/Uni child classes that will be because they are instead defined in the next subclass down, i.e. `LSDDDriftOnlineTorch._initialise_state` or `CVMDriftOnline._configure_thresholds`...
We could move the abstract methods such as `_configure_thresholds` back to their respective Multi/Uni abstract class, at the cost of more duplication (but maybe less complexity?)
Removed the new base class, and moved state methods to `StateMixin`. See #604 (comment).
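For reference, a `StateMixin` along these lines might look roughly like the following (a hypothetical sketch with a toy detector; the PR's actual implementation may differ):

```python
import tempfile
from pathlib import Path

class StateMixin:
    """Mixin bundling the directory handling shared by save_state/load_state."""

    def _set_state_dir(self, dirpath):
        self.state_dir = Path(dirpath)
        self.state_dir.mkdir(parents=True, exist_ok=True)

    def save_state(self, filepath):
        self._set_state_dir(filepath)
        self._save_state()  # provided by the concrete detector

    def load_state(self, filepath):
        self._set_state_dir(filepath)
        self._load_state()  # provided by the concrete detector

class ToyOnlineDetector(StateMixin):
    """Toy detector whose only state is the timestep t."""

    def __init__(self):
        self.t = 0

    def _save_state(self):
        (self.state_dir / 'state.txt').write_text(str(self.t))

    def _load_state(self):
        self.t = int((self.state_dir / 'state.txt').read_text())

det = ToyOnlineDetector()
det.t = 7
with tempfile.TemporaryDirectory() as tmp:
    det.save_state(tmp)
    restored = ToyOnlineDetector()
    restored.load_state(tmp)
```

The mixin only carries the shared plumbing; each detector keeps its own `_save_state`/`_load_state`, which sidesteps the awkward extra layer of abstract base classes.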
```python
# Skip if backend not `tensorflow` or `pytorch`
if backend not in ('tensorflow', 'pytorch'):
    pytest.skip("Detector doesn't have this backend")
```
Is this due to some keops behaviour? Basically asking why skip here.
The `test_saving` file cycles through all possible backends:

```python
backends = ['tensorflow', 'pytorch', 'sklearn']
if has_keops:  # pykeops only installed in Linux CI
    backends.append('keops')
backend = param_fixture("backend", backends)
```

We have to skip tests if the associated detector doesn't have that backend. In this case, online detectors do not have a keops backend.
```python
@pytest.mark.parametrize('batch_size', batch_size)
@pytest.mark.parametrize('n_feat', n_features)
def test_cvmdriftonline(window_sizes, batch_size, n_feat, seed):
    with fixed_seed(seed):
```
Noting that the previous version of tests didn't have a fixed seed, presumably it wasn't needed in this setting as test suite has been passing. Is there a need to introduce a fixed seed here as it seems detrimental to the testing for this particular set of tests?
P.S. same comment applies to tests below and in other modules.
Mmn yeh, I figured it would be good to add since there is randomness in the initialization of these detectors (in `_configure_thresholds`, and in the generation of `x_ref`/`x_h0`/`x_h1`). Although the tests do currently pass without fixing the seed, this doesn't actually mean they pass for any random seed. I seem to recall that when I looked into this before, `np.random.seed` calls set in one test leaked into others. Presumably, this means we have been implicitly fixing the seed in these tests anyway.
Ideally (IMO), we'd get to a point where any random operations in tests are done inside `with fixed_seed(seed)` blocks; then if a new bug is introduced, we can go back and reproduce it with the same random seed.
I agree that in case a bug happens it's valuable to be able to reproduce it with the same seed. But, on the other hand, "any random operations in tests done inside `with fixed_seed(seed)`" sounds like the opposite of what we want (unless for tests where we compare outputs of the same seed), as stuff should pass most tests with any seed?
Yeh fair point, tests should generally pass with any seed, especially if they are unit type tests. The problem at the moment is we have lots of functional tests where we are testing a detector's predictions and checking things like Expected Runtime (ERT) for online detectors. We probably want more granular unit tests in lots of places...
Edit: by "any random operations in tests done inside `with fixed_seed(seed)`", I more meant any random operations that might for some reason affect the outcome of the test.
jklaise
left a comment
Overall LGTM, main question about default behaviour wrt state saving when save_detector is called on online detectors. Regardless of choice, I believe this should be prominent in saving and method docs.
alibi_detect/cd/base_online.py
Outdated
```python
@abstractmethod
def _update_state(self, x_t):
    pass
```
Type hints of parameters and return types required.
This is unfortunately necessary, since univariate online detectors have `def _update_state(self, x_t: np.ndarray):` whilst multivariate ones have `def _update_state(self, x_t: torch.Tensor):` (or `tf.Tensor`). We violate the Liskov substitution principle slightly.
I sort of think it's OK to not add type hints in the abstract method, since we only have it there to signal that sub-classes must have an `_update_state` method which takes an instance and updates online state, but we don't specify the exact type. However, we could also do `def _update_state(self, x_t: Union[np.ndarray, 'torch.Tensor', 'tf.Tensor']):` and then add `# type: ignore[override]` in the sub-class? This is actually what we did previously...
P.S. it's a similar story for the `get_threshold` method...
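The `Union` + `# type: ignore[override]` approach mentioned above, sketched with toy classes (illustrative only, not the alibi-detect class hierarchy; the string forward references assume `torch`/`tf` are only imported for type checking):

```python
from abc import ABC, abstractmethod
from typing import Union

import numpy as np

class BaseOnline(ABC):
    @abstractmethod
    def _update_state(self, x_t: Union[np.ndarray, 'torch.Tensor', 'tf.Tensor']):
        """Update time-dependent state with the new instance x_t."""
        ...

class UnivariateOnline(BaseOnline):
    def _update_state(self, x_t: np.ndarray):  # type: ignore[override]
        # Narrowing the parameter type violates Liskov substitution, hence the ignore
        self.t = getattr(self, 't', 0) + 1

d = UnivariateOnline()
d._update_state(np.array([1.0]))
assert d.t == 1
```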
d190589 removes …
jklaise
left a comment
There was a problem hiding this comment.
LGTM! Should we add some documentation somewhere that save will save the state by default, and if that's not desired one should call reset_state first? Or is it going to be too confusing for now?
Thanks! Will add this documentation now. Just realised I added it in #628 instead of here. Doh!
doc/source/overview/saving.md
Outdated
```diff
-[Online drift detectors](../cd/methods.md#online) are stateful, with their state updated upon each `predict` call.
-When saving an online detector, the `save_state` option controls whether to include the detector's state:
+[Online drift detectors](../cd/methods.md#online) are stateful, with their state updated each timestep `t` (each time
+`.predict()` or `.state()` is called). {func}`~alibi_detect.saving.save_detector` will save the state of online
```
Do you mean `score()` instead of `state()`? On that note, do we even document the usage/use cases of `score()`? If not, perhaps we should leave it out.
Good spot thanks. Also fair point about not really documenting it. I'll remove.
This PR implements the functionality to save and load state for online detectors. At a given time step, the `save_state` method can be called to create a "checkpoint". This can later be loaded via the `load_state` method. At any time, the `reset` method can be used to reset the state back to the `t=0` timestep.

## Scope

This PR deals with online state only. See #604 (comment) for a discussion on online versus offline state.

## Example(s)

### Saving and loading state

The state of online detectors can now be saved and loaded via `save_state` and `load_state`. The detector's state may be saved with the `save_state` method, and the previously saved state may then be loaded via the `load_state` method. At any point, the state may be reset with the `reset` method. Also see the colab notebook.

### Saving and loading detector with state

Calling `save_detector` with `save_state=True` will save an online detector's state to `state/` within the detector save directory. `load_detector` will simply attempt to load state if a `state/` directory exists.

## TODO's

- `save_detector` and `load_detector` functions, to allow state to be saved and loaded when the detector itself is serialized/unserialized.
- `test_saving.py` state test.

## Outstanding considerations (specific to LSDD for now but maybe more widely applicable)

There might be an open question to resolve regarding what we define "state" to be. This PR currently considers it to be only the attributes that are updated in `_update_state` (`self.t`, `self.test_window` and `self.k_xtc`). In other words, "state" is defined as any attribute that is dependent on time (updated when a new instance `x_t` is given via `score` or `predict`).

However, there is already a notion of "state" introduced when we initialise a detector (or reinitialise it via the `reset` method). Here, in addition to the attributes already mentioned, we set `self.ref_inds`, `self.c2s`, and `self.init_test_inds`. This leads to two considerations:

1. Will there be confusion between the `reset` and `reset_state` methods, and do we need to change the docstrings or names?
2. There is randomness involved in the initialisation of `LSDDDrift` (in `_configure_ref_subset`). It is likely that if the detector is instantiated later on, and `load_state` is used to restart from a checkpoint, predictions will still be different compared to those observed after `save_state` was called with the original detector. This would only be avoided if random seeds were set both times. With this in mind, do we want to change our definition of "state" to include `self.ref_inds`, `self.c2s`, and `self.init_test_inds`?
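The checkpointing semantics described in this PR can be summarised with a self-contained toy (a hypothetical detector whose "state" is just a timestep and a running sum; method names mirror the PR but the class is illustrative only, not alibi-detect code):

```python
import json
import tempfile
from pathlib import Path

class ToyOnlineDetector:
    """Toy stateful detector: state is (t, running total)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.t, self.total = 0, 0.0  # back to the t=0 timestep

    def predict(self, x_t: float):
        self.t += 1
        self.total += x_t
        return self.total / self.t  # running mean as a stand-in "score"

    def save_state(self, dirpath):
        d = Path(dirpath)
        d.mkdir(parents=True, exist_ok=True)
        (d / 'state.json').write_text(json.dumps({'t': self.t, 'total': self.total}))

    def load_state(self, dirpath):
        state = json.loads((Path(dirpath) / 'state.json').read_text())
        self.t, self.total = state['t'], state['total']

det = ToyOnlineDetector()
for x in [1.0, 2.0, 3.0]:
    det.predict(x)
with tempfile.TemporaryDirectory() as tmp:
    det.save_state(tmp)        # checkpoint at t=3
    det2 = ToyOnlineDetector()
    det2.load_state(tmp)       # restart a fresh detector from the checkpoint
    assert det2.t == 3 and det2.total == 6.0
det.reset()
assert det.t == 0              # reset returns to the t=0 timestep
```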