 import copy
 import os
+import re
 import tempfile
 import types
+from collections import defaultdict
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
 
 import torch
+from datasets import Dataset
 from huggingface_hub.utils import SoftTemporaryDirectory
 
 from setfit.utils import set_docstring
@@ -148,7 +151,99 @@ class AbsaModel:
     aspect_model: AspectModel
     polarity_model: PolarityModel
 
-    def predict(self, inputs: Union[str, List[str]]) -> List[Dict[str, Any]]:
+    def gold_aspect_spans_to_aspects_list(self, inputs: Dataset) -> Tuple[List, List[List[slice]], List[int]]:
+        # First group inputs by text
+        grouped_data = defaultdict(list)
+        for sample in inputs:
+            text = sample.pop("text")
+            grouped_data[text].append(sample)
+
+        # Get the spaCy docs
+        docs, _ = self.aspect_extractor(grouped_data.keys())
+
+        # Get the aspect spans for each doc by matching gold spans to the spaCy tokens
+        aspects_list = []
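+        # `index` tracks the flat position of each sample across all texts, so that
+        # predictions for skipped samples can be backfilled later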
+        index = -1
+        skipped_indices = []
+        for doc, samples in zip(docs, grouped_data.values()):
+            aspects_list.append([])
+            for sample in samples:
+                index += 1
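+                # `ordinal` selects which occurrence of the gold span in the text is meant (0 = first)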
+                match_objects = re.finditer(re.escape(sample["span"]), doc.text)
+                for i, match in enumerate(match_objects):
+                    if i == sample["ordinal"]:
+                        char_idx_start = match.start()
+                        char_idx_end = match.end()
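+                        # `char_span` returns None if the character offsets do not
+                        # align with spaCy's token boundaries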
+                        span = doc.char_span(char_idx_start, char_idx_end)
+                        if span is None:
+                            logger.warning(
+                                f"Aspect term {sample['span']!r} with ordinal {sample['ordinal']} isn't a token in {doc.text!r} according to spaCy. "
+                                "Skipping this sample."
+                            )
+                            skipped_indices.append(index)
+                            continue
+                        aspects_list[-1].append(slice(span.start, span.end))
+        return docs, aspects_list, skipped_indices
+
+    def predict_dataset(self, inputs: Dataset) -> Dataset:
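+        # If the `ordinal` column is missing, default it to 0, i.e. the first
+        # occurrence of each span in its text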
+        if set(inputs.column_names) >= {"text", "span", "ordinal"}:
+            pass
+        elif set(inputs.column_names) >= {"text", "span"}:
+            inputs = inputs.add_column("ordinal", [0] * len(inputs))
+        else:
+            raise ValueError(
+                "`inputs` must be either a `str`, a `List[str]`, or a `datasets.Dataset` with columns `text` and `span` and optionally `ordinal`. "
+                f"Found a dataset with these columns: {inputs.column_names}."
+            )
+        if "pred_polarity" in inputs.column_names:
+            raise ValueError(
+                "`predict_dataset` wants to add a `pred_polarity` column, but the input dataset already contains that column."
+            )
+        docs, aspects_list, skipped_indices = self.gold_aspect_spans_to_aspects_list(inputs)
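+        # Flatten the per-document polarity predictions into one list and backfill
+        # None for samples whose gold span could not be aligned with spaCy tokens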
+        polarity_list = sum(self.polarity_model(docs, aspects_list), [])
+        for index in skipped_indices:
+            polarity_list.insert(index, None)
+        return inputs.add_column("pred_polarity", polarity_list)
+
+    def predict(self, inputs: Union[str, List[str], Dataset]) -> Union[List[Dict[str, Any]], Dataset]:
+        """Predicts aspects and their polarities for the given inputs.
+
+        Example::
+
+            >>> from setfit import AbsaModel
+            >>> model = AbsaModel.from_pretrained(
+            ...     "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-aspect",
+            ...     "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-polarity",
+            ... )
+            >>> model.predict("The food and wine are just exquisite.")
+            [{'span': 'food', 'polarity': 'positive'}, {'span': 'wine', 'polarity': 'positive'}]
+
+            >>> from setfit import AbsaModel
+            >>> from datasets import load_dataset
+            >>> model = AbsaModel.from_pretrained(
+            ...     "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-aspect",
+            ...     "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-polarity",
+            ... )
+            >>> dataset = load_dataset("tomaarsen/setfit-absa-semeval-restaurants", split="train")
+            >>> model.predict(dataset)
+            Dataset({
+                features: ['text', 'span', 'label', 'ordinal', 'pred_polarity'],
+                num_rows: 3693
+            })
+
+        Args:
+            inputs (Union[str, List[str], Dataset]): Either a sentence, a list of sentences,
+                or a dataset with columns `text` and `span` and optionally `ordinal`. This dataset
+                contains gold aspects, and we only predict the polarities for them.
+
+        Returns:
+            Union[List[Dict[str, Any]], Dataset]: Either a list of dictionaries with keys `span`
+                and `polarity` if the input was a sentence or a list of sentences, or a dataset with
+                columns `text`, `span`, `ordinal`, and `pred_polarity`.
+        """
+        if isinstance(inputs, Dataset):
+            return self.predict_dataset(inputs)
+
         is_str = isinstance(inputs, str)
         inputs_list = [inputs] if is_str else inputs
         docs, aspects_list = self.aspect_extractor(inputs_list)