Skip to content

Commit 3e3d828

Browse files
authored
Merge pull request #469 from tomaarsen/absa_predict_gold_aspects
[`ABSA`] Predict with a gold aspect dataset
2 parents 6ef9482 + 38e9075 commit 3e3d828

File tree

3 files changed

+185
-1
lines changed

3 files changed

+185
-1
lines changed

src/setfit/span/modeling.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import copy
22
import os
3+
import re
34
import tempfile
45
import types
6+
from collections import defaultdict
57
from dataclasses import dataclass, field
68
from pathlib import Path
79
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
810

911
import torch
12+
from datasets import Dataset
1013
from huggingface_hub.utils import SoftTemporaryDirectory
1114

1215
from setfit.utils import set_docstring
@@ -148,7 +151,99 @@ class AbsaModel:
148151
aspect_model: AspectModel
149152
polarity_model: PolarityModel
150153

151-
def predict(self, inputs: Union[str, List[str]]) -> List[Dict[str, Any]]:
154+
def gold_aspect_spans_to_aspects_list(self, inputs: Dataset) -> List[List[slice]]:
155+
# First group inputs by text
156+
grouped_data = defaultdict(list)
157+
for sample in inputs:
158+
text = sample.pop("text")
159+
grouped_data[text].append(sample)
160+
161+
# Get the spaCy docs
162+
docs, _ = self.aspect_extractor(grouped_data.keys())
163+
164+
# Get the aspect spans for each doc by matching gold spans to the spaCy tokens
165+
aspects_list = []
166+
index = -1
167+
skipped_indices = []
168+
for doc, samples in zip(docs, grouped_data.values()):
169+
aspects_list.append([])
170+
for sample in samples:
171+
index += 1
172+
match_objects = re.finditer(re.escape(sample["span"]), doc.text)
173+
for i, match in enumerate(match_objects):
174+
if i == sample["ordinal"]:
175+
char_idx_start = match.start()
176+
char_idx_end = match.end()
177+
span = doc.char_span(char_idx_start, char_idx_end)
178+
if span is None:
179+
logger.warning(
180+
f"Aspect term {sample['span']!r} with ordinal {sample['ordinal']}, isn't a token in {doc.text!r} according to spaCy. "
181+
"Skipping this sample."
182+
)
183+
skipped_indices.append(index)
184+
continue
185+
aspects_list[-1].append(slice(span.start, span.end))
186+
return docs, aspects_list, skipped_indices
187+
188+
def predict_dataset(self, inputs: Dataset) -> Dataset:
189+
if set(inputs.column_names) >= {"text", "span", "ordinal"}:
190+
pass
191+
elif set(inputs.column_names) >= {"text", "span"}:
192+
inputs = inputs.add_column("ordinal", [0] * len(inputs))
193+
else:
194+
raise ValueError(
195+
"`inputs` must be either a `str`, a `List[str]`, or a `datasets.Dataset` with columns `text` and `span` and optionally `ordinal`. "
196+
f"Found a dataset with these columns: {inputs.column_names}."
197+
)
198+
if "pred_polarity" in inputs.column_names:
199+
raise ValueError(
200+
"`predict_dataset` wants to add a `pred_polarity` column, but the input dataset already contains that column."
201+
)
202+
docs, aspects_list, skipped_indices = self.gold_aspect_spans_to_aspects_list(inputs)
203+
polarity_list = sum(self.polarity_model(docs, aspects_list), [])
204+
for index in skipped_indices:
205+
polarity_list.insert(index, None)
206+
return inputs.add_column("pred_polarity", polarity_list)
207+
208+
def predict(self, inputs: Union[str, List[str], Dataset]) -> Union[List[Dict[str, Any]], Dataset]:
209+
"""Predicts aspects & their polarities of the given inputs.
210+
211+
Example::
212+
213+
>>> from setfit import AbsaModel
214+
>>> model = AbsaModel.from_pretrained(
215+
... "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-aspect",
216+
... "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-polarity",
217+
... )
218+
>>> model.predict("The food and wine are just exquisite.")
219+
[{'span': 'food', 'polarity': 'positive'}, {'span': 'wine', 'polarity': 'positive'}]
220+
221+
>>> from setfit import AbsaModel
222+
>>> from datasets import load_dataset
223+
>>> model = AbsaModel.from_pretrained(
224+
... "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-aspect",
225+
... "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-polarity",
226+
... )
227+
>>> dataset = load_dataset("tomaarsen/setfit-absa-semeval-restaurants", split="train")
228+
>>> model.predict(dataset)
229+
Dataset({
230+
features: ['text', 'span', 'label', 'ordinal', 'pred_polarity'],
231+
num_rows: 3693
232+
})
233+
234+
Args:
235+
inputs (Union[str, List[str], Dataset]): Either a sentence, a list of sentences,
236+
or a dataset with columns `text` and `span` and optionally `ordinal`. This dataset
237+
contains gold aspects, and we only predict the polarities for them.
238+
239+
Returns:
240+
Union[List[Dict[str, Any]], Dataset]: Either a list of dictionaries with keys `span`
241+
and `polarity` if the input was a sentence or a list of sentences, or a dataset with
242+
columns `text`, `span`, `ordinal`, and `pred_polarity`.
243+
"""
244+
if isinstance(inputs, Dataset):
245+
return self.predict_dataset(inputs)
246+
152247
is_str = isinstance(inputs, str)
153248
inputs_list = [inputs] if is_str else inputs
154249
docs, aspects_list = self.aspect_extractor(inputs_list)

tests/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@ def absa_model() -> AbsaModel:
1414
return AbsaModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", spacy_model="en_core_web_sm")
1515

1616

17+
@pytest.fixture()
18+
def trained_absa_model() -> AbsaModel:
19+
return AbsaModel.from_pretrained(
20+
"tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-aspect",
21+
"tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-polarity",
22+
)
23+
24+
1725
@pytest.fixture()
1826
def absa_dataset() -> Dataset:
1927
texts = [

tests/span/test_modeling.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import json
2+
import re
23
from pathlib import Path
34
from tempfile import TemporaryDirectory
45

56
import pytest
67
import torch
8+
from datasets import Dataset
79
from pytest import LogCaptureFixture
810

911
from setfit import AbsaModel
@@ -144,3 +146,82 @@ def test_load_model_on_device(device):
144146
assert model.device.type == device
145147
assert model.polarity_model.device.type == device
146148
assert model.aspect_model.device.type == device
149+
150+
151+
def test_predict_dataset(trained_absa_model: AbsaModel):
152+
inputs = Dataset.from_dict(
153+
{
154+
"text": [
155+
"But the staff was so horrible to us.",
156+
"To be completely fair, the only redeeming factor was the food, which was above average, but couldn't make up for all the other deficiencies of Teodora.",
157+
"The food is uniformly exceptional, with a very capable kitchen which will proudly whip up whatever you feel like eating, whether it's on the menu or not.",
158+
"The food is uniformly exceptional, with a very capable kitchen which will proudly whip up whatever you feel like eating, whether it's on the menu or not.",
159+
"The food is uniformly exceptional, with a very capable kitchen which will proudly whip up whatever you feel like eating, whether it's on the menu or not.",
160+
],
161+
"span": ["staff", "food", "food", "kitchen", "menu"],
162+
"label": ["negative", "positive", "positive", "positive", "neutral"],
163+
"ordinal": [0, 0, 0, 0, 0],
164+
}
165+
)
166+
outputs = trained_absa_model.predict(inputs)
167+
assert isinstance(outputs, Dataset)
168+
assert set(outputs.column_names) == {"pred_polarity", "text", "span", "label", "ordinal"}
169+
170+
inputs = Dataset.from_dict(
171+
{
172+
"text": [
173+
"But the staff was so horrible to us.",
174+
"To be completely fair, the only redeeming factor was the food, which was above average, but couldn't make up for all the other deficiencies of Teodora.",
175+
"The food is uniformly exceptional, with a very capable kitchen which will proudly whip up whatever you feel like eating, whether it's on the menu or not.",
176+
"The food is uniformly exceptional, with a very capable kitchen which will proudly whip up whatever you feel like eating, whether it's on the menu or not.",
177+
"The food is uniformly exceptional, with a very capable kitchen which will proudly whip up whatever you feel like eating, whether it's on the menu or not.",
178+
],
179+
"span": ["staff", "food", "food", "kitchen", "menu"],
180+
}
181+
)
182+
outputs = trained_absa_model.predict(inputs)
183+
assert isinstance(outputs, Dataset)
184+
assert "pred_polarity" in outputs.column_names
185+
186+
187+
def test_predict_dataset_errors(trained_absa_model: AbsaModel):
188+
inputs = Dataset.from_dict(
189+
{
190+
"text": [
191+
"But the staff was so horrible to us.",
192+
"To be completely fair, the only redeeming factor was the food, which was above average, but couldn't make up for all the other deficiencies of Teodora.",
193+
"The food is uniformly exceptional, with a very capable kitchen which will proudly whip up whatever you feel like eating, whether it's on the menu or not.",
194+
"The food is uniformly exceptional, with a very capable kitchen which will proudly whip up whatever you feel like eating, whether it's on the menu or not.",
195+
"The food is uniformly exceptional, with a very capable kitchen which will proudly whip up whatever you feel like eating, whether it's on the menu or not.",
196+
],
197+
}
198+
)
199+
with pytest.raises(
200+
ValueError,
201+
match=re.escape(
202+
"`inputs` must be either a `str`, a `List[str]`, or a `datasets.Dataset` with columns `text` and `span` and optionally `ordinal`. "
203+
"Found a dataset with these columns: ['text']."
204+
),
205+
):
206+
trained_absa_model.predict(inputs)
207+
208+
inputs = Dataset.from_dict(
209+
{
210+
"text": [
211+
"But the staff was so horrible to us.",
212+
"To be completely fair, the only redeeming factor was the food, which was above average, but couldn't make up for all the other deficiencies of Teodora.",
213+
"The food is uniformly exceptional, with a very capable kitchen which will proudly whip up whatever you feel like eating, whether it's on the menu or not.",
214+
"The food is uniformly exceptional, with a very capable kitchen which will proudly whip up whatever you feel like eating, whether it's on the menu or not.",
215+
"The food is uniformly exceptional, with a very capable kitchen which will proudly whip up whatever you feel like eating, whether it's on the menu or not.",
216+
],
217+
"span": ["staff", "food", "food", "kitchen", "menu"],
218+
"pred_polarity": ["negative", "positive", "positive", "positive", "neutral"],
219+
}
220+
)
221+
with pytest.raises(
222+
ValueError,
223+
match=re.escape(
224+
"`predict_dataset` wants to add a `pred_polarity` column, but the input dataset already contains that column."
225+
),
226+
):
227+
trained_absa_model.predict(inputs)

0 commit comments

Comments
 (0)