Commit a6a3756

Author: Ashley Scillitoe
Fixes to ContextMMDDrift mypy errors and examples (#466)
1 parent: ffb00b2

File tree: 3 files changed (+12 −11 lines)

alibi_detect/cd/pytorch/context_aware.py (4 additions, 3 deletions)

```diff
@@ -135,13 +135,14 @@ def score(self,  # type: ignore[override]
         inds_held = np.random.choice(n, n_held, replace=False)
         inds_test = np.setdiff1d(np.arange(n), inds_held)
         c_held = torch.as_tensor(c[inds_held]).to(self.device)
-        c, x = torch.as_tensor(c[inds_test]).to(self.device), torch.as_tensor(x[inds_test]).to(self.device)
+        c = torch.as_tensor(c[inds_test]).to(self.device)  # type: ignore[assignment]
+        x = torch.as_tensor(x[inds_test]).to(self.device)  # type: ignore[assignment]
         n_ref, n_test = len(x_ref), len(x)
         bools = torch.cat([torch.zeros(n_ref), torch.ones(n_test)]).to(self.device)

         # Compute kernel matrices
-        x_all = torch.cat([x_ref, x], dim=0)
-        c_all = torch.cat([c_ref, c], dim=0)
+        x_all = torch.cat([x_ref, x], dim=0)  # type: ignore[list-item]
+        c_all = torch.cat([c_ref, c], dim=0)  # type: ignore[list-item]
         K = self.x_kernel(x_all, x_all)
         L = self.c_kernel(c_all, c_all)
         L_held = self.c_kernel(c_held, c_all)
```
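The mypy errors fixed in this file arise because `c` and `x` enter `score` as numpy arrays and are then rebound to `torch.Tensor`s. With the original combined tuple assignment, a single `# type: ignore` could not target each rebinding separately, so the commit splits it into two statements, each carrying a narrow `# type: ignore[assignment]` error code. A minimal stdlib sketch of the same pattern (hypothetical names, not alibi-detect code):

```python
from typing import List


def as_tuple(values: List[int]) -> tuple:
    # `values` is inferred as List[int] from the signature. Rebinding it to a
    # tuple would be flagged by mypy as an [assignment] error; the narrow
    # error code silences only that check, and runtime behaviour is unchanged.
    values = tuple(values)  # type: ignore[assignment]
    return values


print(as_tuple([1, 2, 3]))  # (1, 2, 3)
```

Per-line error codes (`[assignment]`, `[list-item]`) are preferable to a bare `# type: ignore` because they keep all other checks active on the same line.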

doc/source/examples/cd_context_20newsgroup.ipynb (4 additions, 4 deletions)

```diff
@@ -7,7 +7,7 @@
    "source": [
     "# Context-aware drift detection on news articles\n",
     "\n",
-    "### Introduction\n",
+    "## Introduction\n",
     "\n",
     "In this notebook we show how to **detect drift on text data given a specific context** using the [context-aware MMD detector](https://docs.seldon.io/projects/alibi-detect/en/latest/cd/methods/contextmmddrift.html) ([Cobb and Van Looveren, 2022](https://arxiv.org/abs/2203.08644)). Consider the following simple example: the upcoming elections result in an increase of political news articles compared to other topics such as sports or science. Given the context (the elections), it is however not surprising that we observe this uptick. Moreover, assume we have a machine learning model which is trained to classify news topics, and this model performs well on political articles. So given that we fully expect this uptick to occur given the context, and that our model performs fine on the political news articles, we do not want to flag this type of drift in the data. **This setting corresponds more closely to many real-life settings than traditional drift detection where we make the assumption that both the reference and test data are i.i.d. samples from their underlying distributions.**\n",
     "\n",
@@ -28,11 +28,11 @@
     "\n",
     "Under setting 1. we want our detector to be **well-calibrated** (a controlled False Positive Rate (FPR) and more generally a p-value which is uniformly distributed between 0 and 1) while under settings 2. and 3. we want our detector to be **powerful** and flag the drift. Lastly, we show how the detector can help you to **understand the connection between the reference and test data distributions** better.\n",
     "\n",
-    "### Data\n",
+    "## Data\n",
     "\n",
     "We use the [20 newsgroup dataset](https://scikit-learn.org/0.19/datasets/twenty_newsgroups.html) which contains about 18,000 newsgroups post across 20 topics, including politics, science sports or religion.\n",
     "\n",
-    "### Requirements\n",
+    "## Requirements\n",
     "\n",
     "The notebook requires the `umap-learn`, `torch`, `sentence-transformers`, `statsmodels`, `seaborn` and `datasets` packages to be installed, which can be done via `pip`:"
    ]
@@ -1525,7 +1525,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.11"
+   "version": "3.7.10"
   }
  },
  "nbformat": 4,
```

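The notebooks above showcase the context-aware MMD detector, whose `score` method builds the kernel matrices `K` (data) and `L` (context) shown in the first diff. As a rough, self-contained illustration of the plain (context-free) MMD statistic that such kernel matrices feed into, here is a numpy sketch; `rbf_kernel`, `mmd2`, and the fixed bandwidth are assumptions for illustration, not alibi-detect's implementation:

```python
import numpy as np


def rbf_kernel(a: np.ndarray, b: np.ndarray, sigma: float = 1.0) -> np.ndarray:
    """Gram matrix K[i, j] = exp(-||a_i - b_j||^2 / (2 * sigma^2))."""
    d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-d2 / (2 * sigma ** 2))


def mmd2(x_ref: np.ndarray, x: np.ndarray, sigma: float = 1.0) -> float:
    """Biased estimate of squared MMD between two samples."""
    k_xx = rbf_kernel(x_ref, x_ref, sigma).mean()
    k_yy = rbf_kernel(x, x, sigma).mean()
    k_xy = rbf_kernel(x_ref, x, sigma).mean()
    return float(k_xx + k_yy - 2 * k_xy)


rng = np.random.default_rng(0)
same = mmd2(rng.normal(size=(100, 2)), rng.normal(size=(100, 2)))
shifted = mmd2(rng.normal(size=(100, 2)), rng.normal(3.0, size=(100, 2)))
print(same < shifted)  # drifted samples give a larger statistic
```

The context-aware variant additionally conditions on the context kernel `L`, so that distribution changes explainable by a permissible change in context are not flagged.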
doc/source/examples/cd_context_ecg.ipynb (4 additions, 4 deletions)

```diff
@@ -7,7 +7,7 @@
    "source": [
     "# Context-aware drift detection on ECGs\n",
     "\n",
-    "### Introduction\n",
+    "## Introduction\n",
     "\n",
     "In this notebook we show how to **detect drift on ECG data given a specific context** using the [context-aware MMD detector](https://docs.seldon.io/projects/alibi-detect/en/latest/cd/methods/contextmmddrift.html) ([Cobb and Van Looveren, 2022](https://arxiv.org/abs/2203.08644)). Consider the following simple example: we have a heatbeat monitoring system which is trained on a wide variety of heartbeats sampled from people of all ages across a variety of activities (e.g. rest or running). Then we deploy the system to monitor individual people during certain activities. The distribution of the heartbeats monitored during deployment will then be drifting against the reference data which resembles the full training distribution, simply because only individual people in a specific setting are being tracked. However, this does not mean that the system is not working and requires re-training. We are instead interested in flagging drift given the relevant context such as the person's characteristics (e.g. age or medical history) and the activity. Traditional drift detectors cannot flexibly deal with this setting since they rely on the [i.i.d.]([i.i.d.](https://en.wikipedia.org/wiki/Independent_and_identically_distributed_random_variables)) assumption when sampling the reference and test sets. The context-aware detector however allows us to pass this context to the detector and flag drift appropriately. More generally, **the context-aware drift detector detects changes in the data distribution which cannot be attributed to a permissible change in the context variable**. On top of that, the detector allows you to understand which subpopulations are present in both the reference and test data which provides deeper insights into the distribution underlying the test data.\n",
     "\n",
@@ -27,11 +27,11 @@
     "\n",
     "Under setting 1. we want our detector to be **well-calibrated** (a controlled False Positive Rate (FPR) and more generally a p-value which is uniformly distributed between 0 and 1) while under setting 2. we want our detector to be **powerful** and flag drift. Lastly, we show how the detector can help you to **understand the connection between the reference and test data distributions** better.\n",
     "\n",
-    "### Data\n",
+    "## Data\n",
     "\n",
     "The dataset contains 5000 ECG’s, originally obtained from Physionet from the [BIDMC Congestive Heart Failure Database](https://www.physionet.org/content/chfdb/1.0.0/), record chf07. The data has been pre-processed in 2 steps: first each heartbeat is extracted, and then each beat is made equal length via interpolation. The data is labeled and contains 5 classes. The first class $N$ which contains almost 60% of the observations is seen as normal while the others are *supraventricular ectopic beats* ($S$), *ventricular ectopic beats* ($V$), *fusion beats* ($F$) and *unknown beats* ($Q$).\n",
     "\n",
-    "### Requirements\n",
+    "## Requirements\n",
     "\n",
     "The notebook requires the `torch` and `statsmodels` packages to be installed, which can be done via `pip`:"
    ]
@@ -1086,7 +1086,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.11"
+   "version": "3.7.10"
   }
  },
  "nbformat": 4,
```
