
Commit a0b69b4

bogedy and tomaarsen authored
Remove coef_ requirement from ONNX exporter. (#361)
* Use ONNX exporter with classes without coef_
* Expanded ONNX exporter tests for multilabel
* debugging different probs
* Added exception to exporter and ran make style.
* Switched out hate speech example
* fixed quality check by updating black
* Wild guess to make github workflows pass.
* Run make style Black 23.1.0 & isort 5.12.0

---------

Co-authored-by: Tom Aarsen <[email protected]>
1 parent 1b076bd commit a0b69b4

File tree

2 files changed (+24 -16 lines changed):

  src/setfit/exporters/onnx.py
  tests/exporters/test_onnx.py


src/setfit/exporters/onnx.py

Lines changed: 13 additions & 13 deletions
@@ -7,7 +7,6 @@
 import torch
 from sentence_transformers import SentenceTransformer, models
 from sklearn.linear_model import LogisticRegression
-from torch import nn
 from transformers.modeling_utils import PreTrainedModel
 
 from setfit.exporters.utils import mean_pooling
@@ -34,7 +33,7 @@ class OnnxSetFitModel(torch.nn.Module):
     def __init__(
         self,
         model_body: PreTrainedModel,
-        pooler: Optional[Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]] = None,
+        pooler: Optional[Union[torch.nn.Module, Callable[[torch.Tensor], torch.Tensor]]] = None,
         model_head: Optional[Union[torch.nn.Module, LogisticRegression]] = None,
     ):
         super().__init__()
@@ -136,15 +135,20 @@ def export_sklearn_head_to_onnx(model_head: LogisticRegression, opset: int) -> o
         """
         raise ImportError(msg)
 
-    # Check to see that the head has a coef_
-    if not hasattr(model_head, "coef_"):
+    # Determine the initial type and the shape of the output.
+    input_shape = (None, model_head.n_features_in_)
+    if hasattr(model_head, "coef_"):
+        dtype = guess_data_type(model_head.coef_, shape=input_shape)[0][1]
+    elif not hasattr(model_head, "coef_") and hasattr(model_head, "estimators_"):
+        if any([not hasattr(e, "coef_") for e in model_head.estimators_]):
+            raise ValueError(
+                "The model_head is a meta-estimator but not all of the estimators have a coef_ attribute."
+            )
+        dtype = guess_data_type(model_head.estimators_[0].coef_, shape=input_shape)[0][1]
+    else:
         raise ValueError(
-            "Head must have coef_ attribute check that this is supported by your model and the model has been fit."
+            "The model_head either does not have a coef_ attribute or some estimators in model_head.estimators_ do not have a coef_ attribute. Conversion to ONNX only supports these cases."
         )
-
-    # Determine the initial type and the shape of the output.
-    input_shape = (None, *model_head.coef_.shape[1:])
-    dtype = guess_data_type(model_head.coef_, shape=input_shape)[0][1]
     dtype.shape = input_shape
 
     # If the datatype of the model is double we need to cast the outputs
@@ -235,10 +239,6 @@ def export_onnx(
             meta.value = str(value)
 
     else:
-        # TODO:: Make this work for other sklearn models without coef_.
-        if not hasattr(model_head, "coef_"):
-            raise ValueError("Model head must have coef_ attribute for weights.")
-
         # Export the sklearn head first to get the minimum opset. sklearn is behind
         # in supported opsets.
         # Hummingbird-ML can be used as an option to export to standard opset
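For context, the change above infers the ONNX input type from either the head's coef_ or, for meta-estimators, from the sub-estimators' coef_. The standalone sketch below mirrors that branching; it is illustrative only and not part of this commit, and the toy data and the OneVsRestClassifier stand-in for a multilabel SetFit head are assumptions.

    # Sketch of the new branching in export_sklearn_head_to_onnx, on toy heads (assumed example).
    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from sklearn.multiclass import OneVsRestClassifier

    rng = np.random.RandomState(0)
    X = rng.rand(8, 4)  # 8 samples with 4 "embedding" features (toy data)
    y_single = np.array([0, 1, 0, 1, 0, 1, 0, 1])
    y_multi = np.array(  # 3 independent binary labels, both classes present in every column
        [[0, 1, 0], [1, 0, 1], [1, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, 1], [1, 1, 1], [0, 0, 0]]
    )

    heads = [
        LogisticRegression().fit(X, y_single),                      # plain head with coef_
        OneVsRestClassifier(LogisticRegression()).fit(X, y_multi),  # meta-estimator head
    ]
    for head in heads:
        input_shape = (None, head.n_features_in_)
        if hasattr(head, "coef_"):
            coef = head.coef_
        elif hasattr(head, "estimators_") and all(hasattr(e, "coef_") for e in head.estimators_):
            coef = head.estimators_[0].coef_  # use the first sub-estimator's weights
        else:
            raise ValueError("No coef_ on the head or on all of its estimators_.")
        print(type(head).__name__, input_shape, coef.dtype)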

tests/exporters/test_onnx.py

Lines changed: 11 additions & 3 deletions
@@ -11,9 +11,18 @@
 from setfit.trainer import SetFitTrainer
 
 
-def test_export_onnx_sklearn_head():
+@pytest.mark.parametrize(
+    "model_path, input_text",
+    [
+        ("lewtun/my-awesome-setfit-model", ["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"]),
+        (
+            "lewtun/setfit-ethos-multilabel-example",
+            ["I'm a really hateful guy!", "I hate this one person in particular!"],
+        ),
+    ],
+)
+def test_export_onnx_sklearn_head(model_path, input_text):
     """Test that the exported `ONNX` model returns the same predictions as the original model."""
-    model_path = "lewtun/my-awesome-setfit-model"
     model = SetFitModel.from_pretrained(model_path)
 
     # Export the sklearn based model
@@ -25,7 +34,6 @@ def test_export_onnx_sklearn_head():
     assert output_path in os.listdir(), "Model not saved to output_path"
 
     # Run inference using the original model.
-    input_text = ["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"]
     pytorch_preds = model(input_text)
 
     # Run inference using the exported onnx model.
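The ONNX inference step itself is not shown in this diff; roughly, it could look like the sketch below. This is an assumption for illustration, not code from the test: the output file name model.onnx, the checkpoint, the int64 cast, and the CPU provider are guesses.

    # Hypothetical sketch of running the exported model with onnxruntime (not from this commit).
    import numpy as np
    import onnxruntime
    from transformers import AutoTokenizer

    output_path = "model.onnx"  # assumed name of the file written by the export step
    tokenizer = AutoTokenizer.from_pretrained("lewtun/my-awesome-setfit-model")
    encoded = tokenizer(
        ["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"],
        padding=True,
        truncation=True,
        return_tensors="np",
    )

    session = onnxruntime.InferenceSession(output_path, providers=["CPUExecutionProvider"])
    # Feed only the tensors the exported graph actually declares, cast to int64.
    feed = {
        arg.name: np.asarray(encoded[arg.name], dtype=np.int64)
        for arg in session.get_inputs()
        if arg.name in encoded
    }
    onnx_preds = session.run(None, feed)[0]
    # The test then checks these against pytorch_preds from the original model.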
