-
Notifications
You must be signed in to change notification settings - Fork 243
Expand file tree
/
Copy pathevaluate_deepct.py
More file actions
157 lines (128 loc) · 6.94 KB
/
evaluate_deepct.py
File metadata and controls
157 lines (128 loc) · 6.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""
This example shows how to evaluate DeepCT (using Anserini) in BEIR.
For more details on DeepCT, refer here: https://arxiv.org/abs/1910.10687
The original DeepCT repository is not modularised and only works with Tensorflow 1.x (1.15).
We modified the DeepCT repository to work with Tensorflow latest (2.x).
We do not change the core-prediction code, only few input/output file format and structure to adapt to BEIR formats.
For more details on changes, check: https://github.com/NThakur20/DeepCT and compare it with original repo!
Please follow the steps below to install DeepCT:
1. git clone https://github.com/NThakur20/DeepCT.git
Since Anserini uses Java-11, we would advise you to use docker for running Pyserini.
To be able to run the code below you must have docker locally installed in your machine.
To install docker on your local machine, please refer here: https://docs.docker.com/get-docker/
After docker installation, please follow the steps below to get docker container up and running:
1. docker pull docker pull beir/pyserini-fastapi
2. docker build -t pyserini-fastapi .
3. docker run -p 8000:8000 -it --rm pyserini-fastapi
Usage: python evaluate_deepct.py
"""
import json
import logging
import os
import pathlib
import random
import requests
from DeepCT.deepct import (
run_deepct,
)
# git clone https://github.com/NThakur20/DeepCT.git
from beir import LoggingHandler, util
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
#### Just some code to print debug information to stdout
logging.basicConfig(
format="%(asctime)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=logging.INFO,
handlers=[LoggingHandler()],
)
#### /print debug information to stdout
#### Download scifact.zip dataset and unzip the dataset
dataset = "scifact"
url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
data_path = util.download_and_unzip(url, out_dir)
corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")
#### 1. Download Google BERT-BASE, Uncased model ####
# Ref: https://github.com/google-research/bert
base_model_url = "https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip"
out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "models")
bert_base_dir = util.download_and_unzip(base_model_url, out_dir)
#### 2. Download DeepCT MSMARCO Trained BERT checkpoint ####
# Credits to DeepCT authors: Zhuyun Dai, Jamie Callan, (https://github.com/AdeDZY/DeepCT)
model_url = "http://boston.lti.cs.cmu.edu/appendices/arXiv2019-DeepCT-Zhuyun-Dai/outputs/marco.zip"
out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "models")
checkpoint_dir = util.download_and_unzip(model_url, out_dir)
##################################################
#### 3. Configure Params for DeepCT inference ####
##################################################
# We cannot use the original Repo (https://github.com/AdeDZY/DeepCT) as it only runs with TF 1.15.
# We reformatted the code (https://github.com/NThakur20/DeepCT) and made it working with latest TF 2.X!
if not os.path.isfile(os.path.join(data_path, "deepct.jsonl")):
################################
#### Command-Line Arugments ####
################################
run_deepct.FLAGS.task_name = "beir" # Defined a seperate BEIR task in DeepCT. Check out run_deepct.
run_deepct.FLAGS.do_train = False # We only want to use the code for inference.
run_deepct.FLAGS.do_eval = False # No evaluation.
run_deepct.FLAGS.do_predict = True # True, as we would use DeepCT model for only prediction.
run_deepct.FLAGS.data_dir = os.path.join(
data_path, "corpus.jsonl"
) # Provide original path to corpus data, follow beir format.
run_deepct.FLAGS.vocab_file = os.path.join(
bert_base_dir, "vocab.txt"
) # Provide bert-base-uncased model vocabulary.
run_deepct.FLAGS.bert_config_file = os.path.join(
bert_base_dir, "bert_config.json"
) # Provide bert-base-uncased config.json file.
run_deepct.FLAGS.init_checkpoint = os.path.join(
checkpoint_dir, "model.ckpt-65816"
) # Provide DeepCT MSMARCO model (bert-base-uncased) checkpoint file.
run_deepct.FLAGS.max_seq_length = 350 # Provide Max Sequence Length used for consideration. (Max: 512)
run_deepct.FLAGS.train_batch_size = 128 # Inference batch size, Larger more Memory but faster!
run_deepct.FLAGS.output_dir = (
data_path # Output directory, this will contain two files: deepct.jsonl (output-file) and predict.tf_record
)
run_deepct.FLAGS.output_file = "deepct.jsonl" # Output file for storing final DeepCT produced corpus.
run_deepct.FLAGS.m = 100 # Scaling parameter for DeepCT weights: scaling parameter > 0, recommend 100
run_deepct.FLAGS.smoothing = "sqrt" # Use sqrt to smooth weights. DeepCT Paper uses None.
run_deepct.FLAGS.keep_all_terms = True # Do not allow DeepCT to delete terms.
# Runs DeepCT model on the corpus.jsonl
run_deepct.main()
#### Download Docker Image beir/pyserini-fastapi ####
#### Locally run the docker Image + FastAPI ####
docker_beir_pyserini = "http://127.0.0.1:8000"
#### Upload Multipart-encoded files ####
with open(os.path.join(data_path, "deepct.jsonl"), "rb") as fIn:
r = requests.post(docker_beir_pyserini + "/upload/", files={"file": fIn}, verify=False)
#### Index documents to Pyserini #####
index_name = "beir/you-index-name" # beir/scifact
r = requests.get(docker_beir_pyserini + "/index/", params={"index_name": index_name})
######################################
#### 2. Pyserini-Retrieval (BM25) ####
######################################
#### Retrieve documents from Pyserini #####
retriever = EvaluateRetrieval()
qids = list(queries)
query_texts = [queries[qid] for qid in qids]
payload = {
"queries": query_texts,
"qids": qids,
"k": max(retriever.k_values),
"fields": {"contents": 1.0},
"bm25": {"k1": 18, "b": 0.7},
}
#### Retrieve pyserini results (format of results is identical to qrels)
results = json.loads(requests.post(docker_beir_pyserini + "/lexical/batch_search/", json=payload).text)["results"]
#### Retrieve RM3 expanded pyserini results (format of results is identical to qrels)
# results = json.loads(requests.post(docker_beir_pyserini + "/lexical/rm3/batch_search/", json=payload).text)["results"]
#### Evaluate your retrieval using NDCG@k, MAP@K ...
logging.info(f"Retriever evaluation for k in: {retriever.k_values}")
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
#### Retrieval Example ####
query_id, scores_dict = random.choice(list(results.items()))
logging.info(f"Query : {queries[query_id]}\n")
scores = sorted(scores_dict.items(), key=lambda item: item[1], reverse=True)
for rank in range(10):
doc_id = scores[rank][0]
logging.info(f"Rank {rank + 1}: {doc_id} [{corpus[doc_id].get('title')}] - {corpus[doc_id].get('text')}\n")