Commit bf01e57

Add CachedPredictor sharding — v5.6.0 (#133) (closes #128)
Final part of #128: pluggable prediction sources reach feature-complete.

- concat([caches]) + from_directory(path) merge multiple CachedPredictors through the core (name, version) invariant.
- Overlap resolution: 'raise' (default) / 'last' / 'first' / callable(row_a, row_b) -> row.
- v5.6.0 bump accumulates #131 (vaxrank polish), #132 (NetMHC loaders), and this PR (sharding).

Closes #128.
1 parent c55e477 commit bf01e57

6 files changed

Lines changed: 358 additions & 2 deletions

CHANGELOG.md

Lines changed: 52 additions & 0 deletions
@@ -1,5 +1,57 @@
 # Changelog

+## 5.6.0
+
+**Closes #128 - `CachedPredictor` reaches feature-complete.**
+
+**New loaders for the DTU NetMHC suite (#132):**
+
+- `CachedPredictor.from_netmhcpan_stdout(path, mode=…)` - auto-detects
+  NetMHCpan 2.8 / 3 / 4 / 4.1. `mode` selects `"binding_affinity"` or
+  `"elution_score"` for 4+.
+- `CachedPredictor.from_netmhc_stdout(path, version=…)` - classic
+  NetMHC 3 / 4 / 4.1.
+- `CachedPredictor.from_netmhcpan_cons_stdout(path)` - NetMHCcons.
+- `CachedPredictor.from_netmhciipan_stdout(path, version=…)` -
+  NetMHCIIpan legacy / 4 / 4.3.
+- `CachedPredictor.from_netmhcstabpan_stdout(path)` - NetMHCstabpan
+  pMHC-stability predictor.
+
+Each loader wraps an existing `mhctools.parsing.*_stdout` function
+(zero new parsing code) and parses the tool version out of the
+stdout preamble onto `predictor_version`. Parses stdout text, not
+the `-xlsfile` tab-delimited variant - flagged in `docs/cached.md`.
+
+**Sharding - `concat` + `from_directory`:**
+
+- `CachedPredictor.concat([caches], on_overlap=…)` - merge several
+  caches into one. All shards must share `(name, version)` per the
+  core invariant.
+- `CachedPredictor.from_directory(path, pattern="*", on_overlap=…)` -
+  glob a directory and concat every matching file through
+  `from_topiary_output`.
+- Overlap resolution policies (`on_overlap`): `"raise"` (default - fail
+  if any `(peptide, allele, peptide_length)` appears in more than one
+  shard), `"last"` (later shard wins), `"first"` (earlier wins), or a
+  user-supplied `callable(row_a, row_b) -> row` resolver.
+
+**Polish from vaxrank-consumer review on #130 (#131):**
+
+- `_fallback_resolve` filters fallback output to keys not already in
+  the index before merging, so a partial-allele cache (peptide P
+  present for allele A, missing for B) doesn't see its `(P, A)` row
+  silently overwritten by the fallback's all-alleles response.
+- Class docstring now flags silent peptide-length lock-in and
+  non-thread-safety.
+- `save()` raises on an empty never-queried cache with no identity,
+  so users don't write schema-only files that can't be round-tripped.
+
+**Tests:**
+
+- 59 tests in `tests/test_cached_predictor.py` (up from 41): 6 NetMHC
+  loader tests, 12 sharding tests. Full suite 1111 passed (up from
+  1093).
+
 ## 5.5.0

 **New feature - `CachedPredictor`:**
docs/api.md

Lines changed: 2 additions & 0 deletions
@@ -41,6 +41,8 @@ pre-computed table. Pass as `models=cache` to `TopiaryPredictor`. See
 | `CachedPredictor.from_netmhcpan_cons_stdout(path)` | NetMHCcons stdout. |
 | `CachedPredictor.from_netmhciipan_stdout(path, version=...)` | NetMHCIIpan stdout (legacy / 4 / 4.3). |
 | `CachedPredictor.from_netmhcstabpan_stdout(path)` | NetMHCstabpan stdout (pMHC stability). |
+| `CachedPredictor.concat([caches], on_overlap=...)` | Merge shards (all must share name+version). `on_overlap`: `"raise"` / `"last"` / `"first"` / callable. |
+| `CachedPredictor.from_directory(path, pattern="*", on_overlap=...)` | Glob a dir and concat every matching file. |
 | `CachedPredictor.from_tsv(path, columns=..., prediction_method_name=..., predictor_version=...)` | Generic tab- or comma-delimited. |
 | `CachedPredictor.from_dataframe(df, ...)` | In-memory DataFrame. |
 | `CachedPredictor(fallback=live_predictor)` | Empty cache, lazy identity discovery - pure read-through over a live model. |
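
Note on the `concat` row above: the combination of the `(name, version)` invariant and the four `on_overlap` policies can be sketched as follows. This is a minimal model under stated assumptions, not the actual implementation. `concat_shards` is a hypothetical function, shards are modeled as `(name, version, rows)` tuples of plain dicts, and the error wording only loosely mirrors the patterns matched in the tests.

```python
KEY_FIELDS = ("peptide", "allele", "peptide_length")


def concat_shards(shards, on_overlap="raise"):
    """Merge (name, version, rows) shards; hypothetical stand-in for concat."""
    if not shards:
        raise ValueError("no caches given")
    if not (on_overlap in ("raise", "last", "first") or callable(on_overlap)):
        raise ValueError(f"invalid on_overlap: {on_overlap!r}")

    # Invariant: every shard carries the same (name, version) identity.
    identities = {(name, version) for name, version, _ in shards}
    if len(identities) > 1:
        raise ValueError(f"multiple identities across shards: {identities}")

    merged, conflicts = {}, []
    for _, _, rows in shards:
        for row in rows:
            key = tuple(row[field] for field in KEY_FIELDS)
            if key not in merged:
                merged[key] = row
            elif on_overlap == "raise":
                conflicts.append(key)  # report a sample after the scan
            elif on_overlap == "last":
                merged[key] = row
            elif on_overlap == "first":
                pass  # keep the earlier row
            else:
                # Custom resolver folds duplicates pairwise, left to right.
                merged[key] = on_overlap(merged[key], row)

    if conflicts:
        raise ValueError(f"overlapping keys, e.g. {conflicts[:5]}")
    (name, version), = identities
    return name, version, list(merged.values())


def row(peptide, affinity):
    return {"peptide": peptide, "allele": "HLA-A*02:01",
            "peptide_length": len(peptide), "affinity": affinity}


shard_a = ("random", "1.0", [row("SIINFEKLA", 100.0)])
shard_b = ("random", "1.0", [row("SIINFEKLA", 999.0), row("GILGFVFTL", 50.0)])

_, _, last = concat_shards([shard_a, shard_b], on_overlap="last")
_, _, stronger = concat_shards(
    [shard_a, shard_b],
    on_overlap=lambda a, b: a if a["affinity"] <= b["affinity"] else b,
)
```

With these sample shards, `"last"` keeps the 999.0 row for the shared peptide, the resolver keeps the 100.0 row, and the default policy raises because the key appears in both shards.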

docs/cached.md

Lines changed: 41 additions & 0 deletions
@@ -116,6 +116,47 @@ when the file doesn't embed that identity.

 Pass `sep=","` for CSV files.

+### Sharding: merge multiple caches
+
+Predict per-allele or per-sample in parallel, persist each shard
+separately, then merge them:
+
+```python
+cache = CachedPredictor.concat([shard_a, shard_b, shard_c])
+
+# Or: load every matching file from a directory
+cache = CachedPredictor.from_directory(
+    "caches/",
+    pattern="*.parquet",
+)
+```
+
+Every shard must share the same
+`(prediction_method_name, predictor_version)` - the core invariant
+applies across shards the same way it applies inside one.
+
+**Overlap resolution** (`on_overlap=`):
+
+- `"raise"` (default) - fail if any `(peptide, allele, peptide_length)`
+  appears in more than one shard. A sample of conflicting keys is
+  included in the error. Use this if shards should be disjoint.
+- `"last"` - later shard in the input list wins. Useful when the
+  sort order represents "newer overwrites older."
+- `"first"` - earlier shard wins.
+- `callable(row_a, row_b) -> row` - custom resolver. Called pairwise
+  per duplicate group. Pattern for "keep stronger binder":
+
+```python
+def keep_lower_affinity(a, b):
+    return a if a["affinity"] <= b["affinity"] else b
+
+cache = CachedPredictor.concat(shards, on_overlap=keep_lower_affinity)
+```
+
+`from_directory` passes `on_overlap` through to `concat`; file order
+is sorted lexicographically, so `shard_a.tsv` is always earlier than
+`shard_b.tsv`.
+
 ### From an in-memory DataFrame

 ```python

tests/test_cached_predictor.py

Lines changed: 127 additions & 0 deletions
@@ -644,3 +644,130 @@ def test_from_netmhciipan_stdout_unsupported_version_rejects(self, tmp_path):
         path = self._write(tmp_path, "ii.out", _NETMHCPAN_41_STDOUT)
         with pytest.raises(ValueError, match="not supported"):
             CachedPredictor.from_netmhciipan_stdout(path, version="99")
+
+
+# ---------------------------------------------------------------------------
+# Sharding: concat + from_directory
+# ---------------------------------------------------------------------------
+
+
+class TestConcat:
+    def test_concat_merges_disjoint_shards(self):
+        a = CachedPredictor.from_dataframe(_df([_row(peptide="SIINFEKLA")]))
+        b = CachedPredictor.from_dataframe(_df([_row(peptide="GILGFVFTL")]))
+        merged = CachedPredictor.concat([a, b])
+        assert set(merged._df["peptide"]) == {"SIINFEKLA", "GILGFVFTL"}
+        assert merged.prediction_method_name == "random"
+        assert merged.predictor_version == "1.0"
+
+    def test_concat_empty_list_rejects(self):
+        with pytest.raises(ValueError, match="no caches given"):
+            CachedPredictor.concat([])
+
+    def test_concat_mixed_versions_rejects(self):
+        a = CachedPredictor.from_dataframe(
+            _df([_row(peptide="SIINFEKLA", predictor_version="1.0")]),
+        )
+        b = CachedPredictor.from_dataframe(
+            _df([_row(peptide="GILGFVFTL", predictor_version="2.0")]),
+        )
+        with pytest.raises(ValueError, match="multiple"):
+            CachedPredictor.concat([a, b])
+
+    def test_concat_overlap_raises_by_default(self):
+        a = CachedPredictor.from_dataframe(
+            _df([_row(peptide="SIINFEKLA", affinity=100.0)]),
+        )
+        b = CachedPredictor.from_dataframe(
+            _df([_row(peptide="SIINFEKLA", affinity=999.0)]),
+        )
+        with pytest.raises(ValueError, match="overlapping"):
+            CachedPredictor.concat([a, b])
+
+    def test_concat_overlap_last_wins(self):
+        a = CachedPredictor.from_dataframe(
+            _df([_row(peptide="SIINFEKLA", affinity=100.0)]),
+        )
+        b = CachedPredictor.from_dataframe(
+            _df([_row(peptide="SIINFEKLA", affinity=999.0)]),
+        )
+        merged = CachedPredictor.concat([a, b], on_overlap="last")
+        assert merged._df.iloc[0]["affinity"] == 999.0
+
+    def test_concat_overlap_first_wins(self):
+        a = CachedPredictor.from_dataframe(
+            _df([_row(peptide="SIINFEKLA", affinity=100.0)]),
+        )
+        b = CachedPredictor.from_dataframe(
+            _df([_row(peptide="SIINFEKLA", affinity=999.0)]),
+        )
+        merged = CachedPredictor.concat([a, b], on_overlap="first")
+        assert merged._df.iloc[0]["affinity"] == 100.0
+
+    def test_concat_overlap_callable_resolver(self):
+        a = CachedPredictor.from_dataframe(
+            _df([_row(peptide="SIINFEKLA", affinity=100.0)]),
+        )
+        b = CachedPredictor.from_dataframe(
+            _df([_row(peptide="SIINFEKLA", affinity=999.0)]),
+        )
+        # Keep the lower affinity (stronger binder) - caller's choice.
+        def keep_lower(x, y):
+            return x if x["affinity"] <= y["affinity"] else y
+        merged = CachedPredictor.concat([a, b], on_overlap=keep_lower)
+        assert merged._df.iloc[0]["affinity"] == 100.0
+
+    def test_concat_overlap_invalid_policy_rejects(self):
+        a = CachedPredictor.from_dataframe(
+            _df([_row(peptide="SIINFEKLA", affinity=100.0)]),
+        )
+        b = CachedPredictor.from_dataframe(
+            _df([_row(peptide="SIINFEKLA", affinity=999.0)]),
+        )
+        with pytest.raises(ValueError, match="on_overlap"):
+            CachedPredictor.concat([a, b], on_overlap="random-policy")
+
+
+class TestFromDirectory:
+    def test_from_directory_loads_all_shards(self, tmp_path):
+        # Two shards with disjoint peptides
+        shard_a = CachedPredictor.from_dataframe(
+            _df([_row(peptide="SIINFEKLA")]),
+        )
+        shard_b = CachedPredictor.from_dataframe(
+            _df([_row(peptide="GILGFVFTL")]),
+        )
+        shard_a.save(tmp_path / "shard_a.tsv")
+        shard_b.save(tmp_path / "shard_b.tsv")
+
+        merged = CachedPredictor.from_directory(tmp_path, pattern="*.tsv")
+        assert set(merged._df["peptide"]) == {"SIINFEKLA", "GILGFVFTL"}
+
+    def test_from_directory_nonexistent_raises(self, tmp_path):
+        missing = tmp_path / "nope"
+        with pytest.raises(ValueError, match="not a directory"):
+            CachedPredictor.from_directory(missing)
+
+    def test_from_directory_no_matching_files_raises(self, tmp_path):
+        # Empty dir with a pattern that can't match
+        with pytest.raises(ValueError, match="no files matching"):
+            CachedPredictor.from_directory(tmp_path, pattern="*.parquet")
+
+    def test_from_directory_propagates_on_overlap(self, tmp_path):
+        shard_a = CachedPredictor.from_dataframe(
+            _df([_row(peptide="SIINFEKLA", affinity=100.0)]),
+        )
+        shard_b = CachedPredictor.from_dataframe(
+            _df([_row(peptide="SIINFEKLA", affinity=999.0)]),
+        )
+        shard_a.save(tmp_path / "a.tsv")
+        shard_b.save(tmp_path / "b.tsv")
+        # Default policy raises
+        with pytest.raises(ValueError, match="overlapping"):
+            CachedPredictor.from_directory(tmp_path, pattern="*.tsv")
+        # last-wins policy succeeds
+        merged = CachedPredictor.from_directory(
+            tmp_path, pattern="*.tsv", on_overlap="last",
+        )
+        # sorted(["a.tsv", "b.tsv"]) → b wins (affinity=999.0)
+        assert merged._df.iloc[0]["affinity"] == 999.0
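
Note on shard ordering in the `from_directory` tests above: `"last"` and `"first"` depend on file order, which the docs describe as lexicographic. The sketch below shows what that implies for numbered shard names, assuming names are compared as plain strings. `shard_paths` is a hypothetical helper built on the standard library, not the library's own code, and its error messages are only modeled on the patterns the tests match.

```python
import tempfile
from pathlib import Path


def shard_paths(directory, pattern="*"):
    """Return files matching `pattern`, sorted lexicographically by name."""
    directory = Path(directory)
    if not directory.is_dir():
        raise ValueError(f"not a directory: {directory}")
    paths = sorted(
        (p for p in directory.glob(pattern) if p.is_file()),
        key=lambda p: p.name,
    )
    if not paths:
        raise ValueError(f"no files matching {pattern!r} in {directory}")
    return paths


with tempfile.TemporaryDirectory() as tmp:
    for name in ["shard_b.tsv", "shard_2.tsv", "shard_a.tsv", "shard_10.tsv"]:
        (Path(tmp) / name).write_text("")
    ordered = [p.name for p in shard_paths(tmp, "*.tsv")]

print(ordered)
```

Under string comparison `shard_10.tsv` sorts before `shard_2.tsv`, so if numbered shards are meant to apply in numeric order with `on_overlap="last"`, zero-padded names such as `shard_02.tsv` and `shard_10.tsv` keep the two orders in agreement.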

topiary/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@
 from .result import TopiaryResult, concat
 from .wide import detect_form, from_wide, to_wide

-__version__ = "5.5.0"
+__version__ = "5.6.0"

 __all__ = [
     "TopiaryPredictor",
