Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions .github/workflows/linter.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Lint the code base with Ruff on pushes and pull requests that target main.
name: Ruff

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

jobs:
  ruff:
    runs-on: ubuntu-latest
    steps:
      # The v2 releases of these actions run on a deprecated Node.js runtime;
      # use the currently maintained major versions instead.
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          # Quoted so YAML never reads the version as a float (3.10 would become 3.1).
          python-version: "3.9"
      - run: pip install ruff
      - run: ruff check .
140 changes: 140 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

*.test
*.onx
*.qonx
*.DS_Store
*.pyc
*.ipynb_checkpoints
*.pickle
*.pkl
*.icloud
cache/
# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
test.ipynb
3 changes: 1 addition & 2 deletions giga_cherche/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from .ColBERTDataCollator import ColBERTDataCollator
__all__ = ["models", "losses", "util", "evaluation", "ColBERTDataCollator" ]
__all__ = ["models", "losses", "scores", "evaluation", "data_collator"]
3 changes: 3 additions & 0 deletions giga_cherche/__version__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Version components as a tuple of integers; the dotted string below is derived from it.
VERSION = (0, 0, 1)

# Dotted string form of VERSION (each component converted with str and joined by ".").
__version__ = ".".join(str(part) for part in VERSION)
3 changes: 3 additions & 0 deletions giga_cherche/data_collator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Re-export ``ColBERTDataCollator`` from the ``colbert_data_collator`` submodule."""

from .colbert_data_collator import ColBERTDataCollator

# Names exported by ``from <this package> import *``.
__all__ = ["ColBERTDataCollator"]
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,13 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
for column in columns:
# We tokenize the query differently than the documents, TODO: define a parameter "query_column"
if "query" in column or "anchor" in column:
tokenized = self.tokenize_fn([row[column] for row in features], is_query=True)
tokenized = self.tokenize_fn(
[row[column] for row in features], is_query=True
)
else:
tokenized = self.tokenize_fn([row[column] for row in features], is_query=False)
tokenized = self.tokenize_fn(
[row[column] for row in features], is_query=False
)
for key, value in tokenized.items():
batch[f"{column}_{key}"] = value
return batch
return batch
4 changes: 2 additions & 2 deletions giga_cherche/evaluation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .ColBERTTripletEvaluator import ColBERTTripletEvaluator
from .colbert_triplet_evaluator import ColBERTTripletEvaluator

__all__ = ["ColBERTTripletEvaluator"]
__all__ = ["ColBERTTripletEvaluator"]
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, List, Optional, Union

from giga_cherche.util import colbert_pairwise_score, colbert_score
import numpy as np

import torch
from sentence_transformers.evaluation.SentenceEvaluator import SentenceEvaluator
from sentence_transformers.readers import InputExample
from sentence_transformers.similarity_functions import SimilarityFunction

import torch
from giga_cherche.scores.colbert_score import colbert_score

if TYPE_CHECKING:
from sentence_transformers.SentenceTransformer import SentenceTransformer
Expand Down Expand Up @@ -97,17 +96,30 @@ def __init__(
assert len(self.anchors) == len(self.positives)
assert len(self.anchors) == len(self.negatives)

self.main_distance_function = SimilarityFunction(main_distance_function) if main_distance_function else None
self.main_distance_function = (
SimilarityFunction(main_distance_function)
if main_distance_function
else None
)

self.batch_size = batch_size
if show_progress_bar is None:
show_progress_bar = (
logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG
logger.getEffectiveLevel() == logging.INFO
or logger.getEffectiveLevel() == logging.DEBUG
)
self.show_progress_bar = show_progress_bar

self.csv_file: str = "triplet_evaluation" + ("_" + name if name else "") + "_results.csv"
self.csv_headers = ["epoch", "steps", "accuracy_cosinus", "accuracy_manhattan", "accuracy_euclidean"]
self.csv_file: str = (
"triplet_evaluation" + ("_" + name if name else "") + "_results.csv"
)
self.csv_headers = [
"epoch",
"steps",
"accuracy_cosinus",
"accuracy_manhattan",
"accuracy_euclidean",
]
self.write_csv = write_csv

@classmethod
Expand All @@ -121,9 +133,14 @@ def from_input_examples(cls, examples: List[InputExample], **kwargs):
positives.append(example.texts[1])
negatives.append(example.texts[2])
return cls(anchors, positives, negatives, **kwargs)
#TODO: add mAP and other metrics

# TODO: add mAP and other metrics
def __call__(
self, model: "SentenceTransformer", output_path: str = None, epoch: int = -1, steps: int = -1
self,
model: "SentenceTransformer",
output_path: str = None,
epoch: int = -1,
steps: int = -1,
) -> Dict[str, float]:
if epoch != -1:
if steps == -1:
Expand All @@ -135,14 +152,18 @@ def __call__(
if self.truncate_dim is not None:
out_txt += f" (truncated to {self.truncate_dim})"

logger.info(f"TripletEvaluator: Evaluating the model on the {self.name} dataset{out_txt}:")
logger.info(
f"TripletEvaluator: Evaluating the model on the {self.name} dataset{out_txt}:"
)

num_triplets = 0
(
num_correct_colbert_triplets
) = 0
(num_correct_colbert_triplets) = 0

with nullcontext() if self.truncate_dim is None else model.truncate_sentence_embeddings(self.truncate_dim):
with (
nullcontext()
if self.truncate_dim is None
else model.truncate_sentence_embeddings(self.truncate_dim)
):
embeddings_anchors = model.encode(
self.anchors,
batch_size=self.batch_size,
Expand All @@ -166,18 +187,28 @@ def __call__(
)

# TODO: do the padding in encode()?
embeddings_positives = torch.nn.utils.rnn.pad_sequence(embeddings_positives, batch_first=True, padding_value=0)
embeddings_positives = torch.nn.utils.rnn.pad_sequence(
embeddings_positives, batch_first=True, padding_value=0
)
attention_mask_positives = (embeddings_positives.sum(dim=-1) != 0).float()

embeddings_negatives = torch.nn.utils.rnn.pad_sequence(embeddings_negatives, batch_first=True, padding_value=0)

embeddings_negatives = torch.nn.utils.rnn.pad_sequence(
embeddings_negatives, batch_first=True, padding_value=0
)
attention_mask_negatives = (embeddings_negatives.sum(dim=-1) != 0).float()

# Colbert distance
# pos_colbert_distances = colbert_pairwise_score(embeddings_anchors, embeddings_positives)
# neg_colbert_distances = colbert_pairwise_score(embeddings_anchors, embeddings_negatives)
pos_colbert_distances_full = colbert_score(embeddings_anchors, embeddings_positives, attention_mask_positives)
neg_colbert_distances_full = colbert_score(embeddings_anchors, embeddings_negatives, attention_mask_negatives)
distances_full = torch.cat([pos_colbert_distances_full, neg_colbert_distances_full], dim=1)
pos_colbert_distances_full = colbert_score(
embeddings_anchors, embeddings_positives, attention_mask_positives
)
neg_colbert_distances_full = colbert_score(
embeddings_anchors, embeddings_negatives, attention_mask_negatives
)
distances_full = torch.cat(
[pos_colbert_distances_full, neg_colbert_distances_full], dim=1
)
# print(distances_full.shape)
labels = np.arange(0, len(embeddings_anchors))
indices = np.argsort(-distances_full.cpu().numpy(), axis=1)
Expand All @@ -192,7 +223,7 @@ def __call__(
num_triplets += 1
if pos_colbert_distances[idx] > neg_colbert_distances[idx]:
num_correct_colbert_triplets += 1

accuracy_colbert = num_correct_colbert_triplets / num_triplets

logger.info("Accuracy Colbert: \t{:.2f}".format(accuracy_colbert * 100))
Expand All @@ -218,10 +249,10 @@ def __call__(
}.get(self.main_distance_function, "max_accuracy")
metrics = {
"colbert_accuracy": accuracy_colbert,
"hits@1": np.sum(ranks <= 1)/len(ranks),
"hits@5": np.sum(ranks <= 5)/len(ranks),
"hits@10": np.sum(ranks <= 10)/len(ranks),
"hits@25": np.sum(ranks <= 25)/len(ranks),
"hits@1": np.sum(ranks <= 1) / len(ranks),
"hits@5": np.sum(ranks <= 5) / len(ranks),
"hits@10": np.sum(ranks <= 10) / len(ranks),
"hits@25": np.sum(ranks <= 25) / len(ranks),
# "max_accuracy": max(accuracy_cos, accuracy_manhattan, accuracy_euclidean),
}
metrics = self.prefix_name_to_metrics(metrics, self.name)
Expand Down
4 changes: 2 additions & 2 deletions giga_cherche/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .ColBERTLoss import ColBERTLoss
from .colbert_loss import ColBERTLoss

__all__ = ["ColBERTLoss"]
__all__ = ["ColBERTLoss"]
Loading