Prevent TypeError on model.predict when using string labels
#331
Hello!
Pull Request overview
Prevent TypeError on `model.predict` when using string labels.
Details
When training with string labels (which is not strictly recommended, but possible), `model.predict` breaks as of the latest version. See the following script to reproduce:
Reproduction
This resulted in a TypeError.
See also #329, which shows this same issue, but for `evaluate` (which calls `predict` behind the scenes).
Why do we get this error?
Consider the following lines in the `predict` method:
setfit/src/setfit/modeling.py
Lines 414 to 421 in 0420165
And consider the scenario with the (default) non-differentiable head and `as_numpy=False`. In this case, we reach line 419 and call `torch.from_numpy`. However, `outputs` has dtype `<U8`, where the `U` indicates that the type is a unicode string. There is no Torch tensor equivalent of this type, and thus we get the error shown above.
The fix
The fix is simply to avoid calling `torch.from_numpy` if the head outputs a numpy array of strings.
Note
The issue from #329 isn't exactly fixed: calling `evaluate` using string labels still fails, as the evaluate library does not support string labels in its `accuracy` metric. This can be counteracted by supplying a different `metric`, e.g. a function, that computes some metric with support for strings.
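To make the two points above concrete, here is a minimal sketch rather than the actual patch. `is_string_array`, `convert_outputs` and `string_accuracy` are hypothetical names chosen for illustration. The first two mimic the idea of the fix: skip the tensor conversion when the head returns strings. The third is one possible string-aware metric; it assumes the metric callable receives the predictions and the reference labels and may return a dict, which may not match what your version of the trainer expects.

```python
import numpy as np


def is_string_array(outputs: np.ndarray) -> bool:
    """Return True if the array holds unicode ("U") or byte ("S") strings."""
    return outputs.dtype.kind in {"U", "S"}


def convert_outputs(outputs: np.ndarray, as_numpy: bool = False):
    """Hypothetical stand-in for the output conversion at the end of predict.

    Numeric arrays become torch tensors unless as_numpy is set. String
    arrays are returned unchanged, because torch has no string dtype.
    """
    if as_numpy or is_string_array(outputs):
        return outputs
    # Imported lazily in this sketch: only the numeric path needs torch.
    import torch

    return torch.from_numpy(outputs)


def string_accuracy(y_pred, y_test) -> dict:
    """Accuracy based on ==, so it works for string labels as well as ints."""
    y_pred, y_test = list(y_pred), list(y_test)
    if len(y_pred) != len(y_test):
        raise ValueError("y_pred and y_test must have the same length")
    if not y_test:
        raise ValueError("cannot compute accuracy on empty inputs")
    correct = sum(1 for p, t in zip(y_pred, y_test) if p == t)
    return {"accuracy": correct / len(y_test)}


# String predictions pass through the conversion untouched ...
preds = convert_outputs(np.array(["positive", "negative", "positive"]))
# ... and can be scored without going through a numeric-only metric.
print(string_accuracy(preds, ["positive", "negative", "negative"]))
```

Checking `dtype.kind` also covers byte strings (`S`), which is slightly broader than the `<U8` case described above. Whether object arrays should be skipped as well is left open here.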