@@ -644,3 +644,130 @@ def test_from_netmhciipan_stdout_unsupported_version_rejects(self, tmp_path):
644644 path = self ._write (tmp_path , "ii.out" , _NETMHCPAN_41_STDOUT )
645645 with pytest .raises (ValueError , match = "not supported" ):
646646 CachedPredictor .from_netmhciipan_stdout (path , version = "99" )
647+
648+
649+ # ---------------------------------------------------------------------------
650+ # Sharding: concat + from_directory
651+ # ---------------------------------------------------------------------------
652+
653+
class TestConcat:
    """Exercise ``CachedPredictor.concat`` on disjoint and overlapping caches."""

    @staticmethod
    def _single_row_cache(**row_fields):
        """Build a one-row cache; keyword arguments are forwarded to ``_row``."""
        return CachedPredictor.from_dataframe(_df([_row(**row_fields)]))

    @classmethod
    def _conflicting_pair(cls):
        """Return two caches for the same peptide with affinities 100.0 and 999.0."""
        low = cls._single_row_cache(peptide="SIINFEKLA", affinity=100.0)
        high = cls._single_row_cache(peptide="SIINFEKLA", affinity=999.0)
        return [low, high]

    def test_concat_merges_disjoint_shards(self):
        """Two caches with different peptides merge into one holding both."""
        left = self._single_row_cache(peptide="SIINFEKLA")
        right = self._single_row_cache(peptide="GILGFVFTL")
        combined = CachedPredictor.concat([left, right])
        peptides = set(combined._df["peptide"])
        assert peptides == {"SIINFEKLA", "GILGFVFTL"}
        # Method name and version are presumably the defaults supplied by
        # ``_row`` -- verify against that helper.
        assert combined.prediction_method_name == "random"
        assert combined.predictor_version == "1.0"

    def test_concat_empty_list_rejects(self):
        """An empty list of caches is rejected with a ValueError."""
        with pytest.raises(ValueError, match="no caches given"):
            CachedPredictor.concat([])

    def test_concat_mixed_versions_rejects(self):
        """Caches built with different predictor versions are rejected."""
        old = self._single_row_cache(peptide="SIINFEKLA", predictor_version="1.0")
        new = self._single_row_cache(peptide="GILGFVFTL", predictor_version="2.0")
        with pytest.raises(ValueError, match="multiple"):
            CachedPredictor.concat([old, new])

    def test_concat_overlap_raises_by_default(self):
        """Without an ``on_overlap`` argument, overlapping rows raise."""
        shards = self._conflicting_pair()
        with pytest.raises(ValueError, match="overlapping"):
            CachedPredictor.concat(shards)

    def test_concat_overlap_last_wins(self):
        """``on_overlap="last"`` keeps the row from the later cache."""
        combined = CachedPredictor.concat(self._conflicting_pair(), on_overlap="last")
        kept = combined._df.iloc[0]
        assert kept["affinity"] == 999.0

    def test_concat_overlap_first_wins(self):
        """``on_overlap="first"`` keeps the row from the earlier cache."""
        combined = CachedPredictor.concat(self._conflicting_pair(), on_overlap="first")
        kept = combined._df.iloc[0]
        assert kept["affinity"] == 100.0

    def test_concat_overlap_callable_resolver(self):
        """A callable ``on_overlap`` chooses which of two conflicting rows survives."""

        # Keep the lower affinity (stronger binder) -- caller's choice.
        def keep_lower(x, y):
            if x["affinity"] <= y["affinity"]:
                return x
            return y

        combined = CachedPredictor.concat(
            self._conflicting_pair(),
            on_overlap=keep_lower,
        )
        kept = combined._df.iloc[0]
        assert kept["affinity"] == 100.0

    def test_concat_overlap_invalid_policy_rejects(self):
        """An unrecognised ``on_overlap`` string is rejected with a ValueError."""
        shards = self._conflicting_pair()
        with pytest.raises(ValueError, match="on_overlap"):
            CachedPredictor.concat(shards, on_overlap="random-policy")
729+
730+
class TestFromDirectory:
    """Exercise ``CachedPredictor.from_directory`` on shard files under ``tmp_path``."""

    @staticmethod
    def _shard(**row_fields):
        """Build a one-row cache; keyword arguments are forwarded to ``_row``."""
        return CachedPredictor.from_dataframe(_df([_row(**row_fields)]))

    def test_from_directory_loads_all_shards(self, tmp_path):
        """Every file matching the pattern contributes its rows to the result."""
        # Two shards with disjoint peptides, keyed by the file they are saved to.
        shards_by_name = {
            "shard_a.tsv": self._shard(peptide="SIINFEKLA"),
            "shard_b.tsv": self._shard(peptide="GILGFVFTL"),
        }
        for filename, shard in shards_by_name.items():
            shard.save(tmp_path / filename)

        combined = CachedPredictor.from_directory(tmp_path, pattern="*.tsv")
        peptides = set(combined._df["peptide"])
        assert peptides == {"SIINFEKLA", "GILGFVFTL"}

    def test_from_directory_nonexistent_raises(self, tmp_path):
        """A path that does not exist is rejected as not being a directory."""
        absent_dir = tmp_path / "nope"
        with pytest.raises(ValueError, match="not a directory"):
            CachedPredictor.from_directory(absent_dir)

    def test_from_directory_no_matching_files_raises(self, tmp_path):
        """A directory with no file matching the pattern is rejected."""
        # Nothing is written to tmp_path here, so the pattern cannot match.
        with pytest.raises(ValueError, match="no files matching"):
            CachedPredictor.from_directory(tmp_path, pattern="*.parquet")

    def test_from_directory_propagates_on_overlap(self, tmp_path):
        """The ``on_overlap`` argument governs how overlapping shards are merged."""
        shards_by_name = {
            "a.tsv": self._shard(peptide="SIINFEKLA", affinity=100.0),
            "b.tsv": self._shard(peptide="SIINFEKLA", affinity=999.0),
        }
        for filename, shard in shards_by_name.items():
            shard.save(tmp_path / filename)

        # Default policy raises.
        with pytest.raises(ValueError, match="overlapping"):
            CachedPredictor.from_directory(tmp_path, pattern="*.tsv")

        # Last-wins policy succeeds.
        combined = CachedPredictor.from_directory(
            tmp_path,
            pattern="*.tsv",
            on_overlap="last",
        )
        # sorted(["a.tsv", "b.tsv"]) -> b wins (affinity=999.0); this assumes
        # shards are loaded in sorted filename order.
        kept = combined._df.iloc[0]
        assert kept["affinity"] == 999.0
0 commit comments