Skip to content
26 changes: 13 additions & 13 deletions src/setfit/exporters/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import torch
from sentence_transformers import SentenceTransformer, models
from sklearn.linear_model import LogisticRegression
from torch import nn
from transformers.modeling_utils import PreTrainedModel

from setfit.exporters.utils import mean_pooling
Expand All @@ -34,7 +33,7 @@ class OnnxSetFitModel(torch.nn.Module):
def __init__(
self,
model_body: PreTrainedModel,
pooler: Optional[Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]] = None,
pooler: Optional[Union[torch.nn.Module, Callable[[torch.Tensor], torch.Tensor]]] = None,
model_head: Optional[Union[torch.nn.Module, LogisticRegression]] = None,
):
super().__init__()
Expand Down Expand Up @@ -136,15 +135,20 @@ def export_sklearn_head_to_onnx(model_head: LogisticRegression, opset: int) -> o
"""
raise ImportError(msg)

# Check to see that the head has a coef_
if not hasattr(model_head, "coef_"):
# Determine the initial type and the shape of the output.
input_shape = (None, model_head.n_features_in_)
if hasattr(model_head, "coef_"):
dtype = guess_data_type(model_head.coef_, shape=input_shape)[0][1]
elif not hasattr(model_head, "coef_") and hasattr(model_head, "estimators_"):
if any([not hasattr(e, "coef_") for e in model_head.estimators_]):
raise ValueError(
"The model_head is a meta-estimator but not all of the estimators have a coef_ attribute."
)
dtype = guess_data_type(model_head.estimators_[0].coef_, shape=input_shape)[0][1]
else:
raise ValueError(
"Head must have coef_ attribute check that this is supported by your model and the model has been fit."
"The model_head either does not have a coef_ attribute or some estimators in model_head.estimators_ do not have a coef_ attribute. Conversion to ONNX only supports these cases."
)

# Determine the initial type and the shape of the output.
input_shape = (None, *model_head.coef_.shape[1:])
dtype = guess_data_type(model_head.coef_, shape=input_shape)[0][1]
dtype.shape = input_shape

# If the datatype of the model is double we need to cast the outputs
Expand Down Expand Up @@ -235,10 +239,6 @@ def export_onnx(
meta.value = str(value)

else:
# TODO:: Make this work for other sklearn models without coef_.
if not hasattr(model_head, "coef_"):
raise ValueError("Model head must have coef_ attribute for weights.")

# Export the sklearn head first to get the minimum opset. sklearn is behind
# in supported opsets.
# Hummingbird-ML can be used as an option to export to standard opset
Expand Down
14 changes: 11 additions & 3 deletions tests/exporters/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,18 @@
from setfit.trainer import SetFitTrainer


def test_export_onnx_sklearn_head():
@pytest.mark.parametrize(
"model_path, input_text",
[
("lewtun/my-awesome-setfit-model", ["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"]),
(
"lewtun/setfit-ethos-multilabel-example",
["I'm a really hateful guy!", "I hate this one person in particular!"],
),
],
)
def test_export_onnx_sklearn_head(model_path, input_text):
"""Test that the exported `ONNX` model returns the same predictions as the original model."""
model_path = "lewtun/my-awesome-setfit-model"
model = SetFitModel.from_pretrained(model_path)

# Export the sklearn based model
Expand All @@ -25,7 +34,6 @@ def test_export_onnx_sklearn_head():
assert output_path in os.listdir(), "Model not saved to output_path"

# Run inference using the original model.
input_text = ["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"]
pytorch_preds = model(input_text)

# Run inference using the exported onnx model.
Expand Down