Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
name: Run unit tests
strategy:
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10']
python-version: ['3.8', '3.9', '3.10', '3.11']
os: [ubuntu-latest, windows-latest]
requirements: ['.[tests]', '.[compat_tests]']
fail-fast: false
Expand All @@ -32,12 +32,20 @@ jobs:
with:
python-version: ${{ matrix.python-version }}

# Taken from https://github.com/actions/cache?tab=readme-ov-file#creating-a-cache-key
# Use date to invalidate cache every week
- name: Get date
id: get-date
run: |
echo "date=$(/bin/date -u "+%G%V")" >> $GITHUB_OUTPUT
shell: bash

- name: Try to load cached dependencies
uses: actions/cache@v3
id: restore-cache
with:
path: ${{ env.pythonLocation }}
key: python-dependencies-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.requirements }}-${{ hashFiles('setup.py') }}-${{ env.pythonLocation }}
key: python-dependencies-${{ matrix.os }}-${{ steps.get-date.outputs.date }}-${{ matrix.python-version }}-${{ matrix.requirements }}-${{ hashFiles('setup.py') }}-${{ env.pythonLocation }}

- name: Install external dependencies on cache miss
run: |
Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/installation.mdx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

# Installation

Before you start, you'll need to setup your environment and install the appropriate packages. 🤗 SetFit is tested on **Python 3.7+**.
Before you start, you'll need to setup your environment and install the appropriate packages. 🤗 SetFit is tested on **Python 3.8+**.

## pip

Expand Down
2 changes: 1 addition & 1 deletion scripts/perfect/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Follow the steps below to run the baselines based on the `PERFECT` paper: [_PERF
To get started, first create a Python virtual environment, e.g. with `conda`:

```
conda create -n baselines-perfect python=3.7 && conda activate baselines-perfect
conda create -n baselines-perfect python=3.10 && conda activate baselines-perfect
```

Next, clone [our fork](https://github.com/SetFit/perfect) of the [`PERFECT` codebase](https://github.com/facebookresearch/perfect), and install the required dependencies:
Expand Down
2 changes: 1 addition & 1 deletion scripts/tfew/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ These scripts run the baselines based on the `T-Few` paper: [_Few-Shot Parameter
To run the scripts, first create a Python virtual environment, e.g. with `conda`:

```
conda create -n baselines-tfew python=3.7 && conda activate baselines-tfew
conda create -n baselines-tfew python=3.10 && conda activate baselines-tfew
```

Next, clone our `T-Few` fork, and install the required dependencies:
Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from setuptools import find_packages, setup


README_TEXT = (Path(__file__).parent / "README.md").read_text(encoding="utf-8")

MAINTAINER = "Lewis Tunstall, Tom Aarsen"
Expand All @@ -14,7 +13,7 @@
"datasets>=2.3.0",
"sentence-transformers>=2.2.1",
"evaluate>=0.3.0",
"huggingface_hub>=0.13.0",
"huggingface_hub>=0.22.1",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

0.13 is a quite old version. I'm not sure everything works correctly with it.

"scikit-learn",
"packaging",
]
Expand Down Expand Up @@ -78,6 +77,8 @@ def combine_requirements(base_keys):
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
keywords="nlp, machine learning, fewshot learning, transformers",
Expand Down
13 changes: 2 additions & 11 deletions src/setfit/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,7 @@
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple, Union


# For Python 3.7 compatibility
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
from typing import Dict, List, Literal, Optional, Set, Tuple, Union

import joblib
import numpy as np
Expand All @@ -20,9 +13,8 @@
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from huggingface_hub.utils import validate_hf_hub_args
from packaging.version import Version, parse
from sentence_transformers import SentenceTransformer
from sentence_transformers import SentenceTransformer, models
from sentence_transformers import __version__ as sentence_transformers_version
from sentence_transformers import models
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multioutput import ClassifierChain, MultiOutputClassifier
Expand All @@ -36,7 +28,6 @@
from .model_card import SetFitModelCardData, generate_model_card
from .utils import set_docstring


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

Expand Down
31 changes: 20 additions & 11 deletions src/setfit/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,27 @@
import time
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Tuple,
Union,
)

import evaluate
import torch
from datasets import Dataset, DatasetDict
from sentence_transformers import InputExample, SentenceTransformer, losses
from sentence_transformers.datasets import SentenceLabelDataset
from sentence_transformers.losses.BatchHardTripletLoss import BatchHardTripletLossDistanceFunction
from sentence_transformers.losses.BatchHardTripletLoss import (
BatchHardTripletLossDistanceFunction,
)
from sentence_transformers.util import batch_to_device
from sklearn.preprocessing import LabelEncoder
from torch import nn
Expand Down Expand Up @@ -41,20 +54,16 @@
from setfit.model_card import ModelCardCallback

from . import logging
from .integrations import default_hp_search_backend, is_optuna_available, run_hp_search_optuna
from .integrations import (
default_hp_search_backend,
is_optuna_available,
run_hp_search_optuna,
)
from .losses import SupConLoss
from .sampler import ContrastiveDataset
from .training_args import TrainingArguments
from .utils import BestRun, default_hp_space_optuna


# For Python 3.7 compatibility
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal


if TYPE_CHECKING:
import optuna

Expand Down