
BERTVision

Parameter-efficient fine-tuning using BERT's hidden state activations



Authors

William Casey King, PhD · Siduo (Stone) Jiang · Cristopher Benge · Andrew Fogarty

UC Berkeley
UC Berkeley · Masters in Information & Data Science
Spring 2021 Capstone · W210 with Alberto Todeschini, PhD & Puya H. Vahabi, PhD


Key Results

BERTVision achieves BERT-level performance with a fraction of the trainable parameters:

| Model | Params | SQuAD 2.0 EM | SQuAD 2.0 F1 |
|---|---|---|---|
| BERT-large (1 epoch) | 335M (100%) | 72.8 | 77.7 |
| BERT-large (2 epochs, best) | 335M (100%) | 74.7 | 79.2 |
| BERTVision (1 epoch) | ~0.4M (0.12%) | 74.9 | 79.0 |
| Ensemble (BERT + BERTVision) | - | 75.6 | 79.8 |

Highlights:

  • ~800x fewer trainable parameters than full BERT fine-tuning
  • Matches BERT's 2-epoch performance using embeddings from a single fine-tuning epoch
  • Ensemble surpasses BERT by +0.9 EM and +0.6 F1
  • Works for both span annotation (QA) and binary classification
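The "~800x" and "0.12%" figures in the highlights follow directly from the rounded parameter counts in the table above; a quick arithmetic check:

```python
# Rounded parameter counts from the results table above.
bert_large_params = 335_000_000   # full BERT-large fine-tuning
bertvision_params = 400_000       # ~0.4M trainable adapter parameters

ratio = bert_large_params / bertvision_params
fraction = bertvision_params / bert_large_params

print(f"{ratio:.0f}x fewer trainable parameters")  # ~838x, i.e. "~800x"
print(f"{fraction:.2%} of BERT-large")             # 0.12%
```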

Performance comparison


Architecture

BERTVision Architecture

Left: Standard BERT inference discards hidden layer outputs.
Right: BERTVision captures all 24 layers, feeding them to a lightweight adapter for span prediction.
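With Hugging Face Transformers, per-layer hidden states can be requested via `output_hidden_states=True`. The sketch below substitutes random tensors for real BERT outputs so it runs without downloading weights; the shapes are assumptions matching BERT-large as used here (24 layers plus the embedding output, hidden size 1024, 386-token sequences):

```python
import torch

# Assumed BERT-large geometry: 24 transformer layers plus the embedding
# layer, hidden size 1024, and the 386-token sequences used in the paper.
NUM_LAYERS, SEQ_LEN, HIDDEN = 24, 386, 1024
batch = 2

# Stand-in for `model(**inputs, output_hidden_states=True).hidden_states`:
# a tuple of (NUM_LAYERS + 1) tensors, each (batch, seq_len, hidden).
hidden_states = tuple(
    torch.randn(batch, SEQ_LEN, HIDDEN) for _ in range(NUM_LAYERS + 1)
)

# Drop the embedding output and stack the 24 transformer layers into one
# tensor the adapter can consume: (batch, layers, seq_len, hidden).
embeddings = torch.stack(hidden_states[1:], dim=1)
print(embeddings.shape)  # torch.Size([2, 24, 386, 1024])
```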

Binary Classification Pipeline

Binary Classification Pipeline: For tasks like answerability detection, BERT-large extracts CLS token embeddings from all 24 transformer layers. These embeddings are pooled using learned weights and fed into a lightweight dense layer for Answer/No Answer classification.
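One minimal way to realize this pipeline is a softmax-weighted sum over the per-layer [CLS] vectors followed by a single linear layer. The module and names below are an illustrative sketch, not the repository's exact adapter:

```python
import torch
import torch.nn as nn

class CLSLayerPoolingClassifier(nn.Module):
    """Sketch of learned pooling over per-layer [CLS] embeddings."""
    def __init__(self, num_layers=24, hidden=1024, num_classes=2):
        super().__init__()
        # One learnable scalar weight per transformer layer.
        self.layer_logits = nn.Parameter(torch.zeros(num_layers))
        self.classifier = nn.Linear(hidden, num_classes)

    def forward(self, cls_embeddings):
        # cls_embeddings: (batch, num_layers, hidden)
        weights = torch.softmax(self.layer_logits, dim=0)        # (num_layers,)
        pooled = (weights[None, :, None] * cls_embeddings).sum(dim=1)
        return self.classifier(pooled)                           # (batch, num_classes)

model = CLSLayerPoolingClassifier()
logits = model(torch.randn(2, 24, 1024))
print(logits.shape)  # torch.Size([2, 2])
```

Note that only the 24 layer weights and the final dense layer train here, which is what keeps the adapter's parameter count in the sub-million range.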

Span Annotation Pipeline

Span Annotation Pipeline: For question answering, embeddings are extracted for all 386 tokens across all transformer layers. The learned pooling mechanism combines layer outputs, and a dense layer predicts span start and end positions for answer extraction.
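The span pipeline can be sketched analogously: pool all 24 layer outputs for every token, then project each pooled token vector to start/end logits. The simple linear head and names below are assumptions for illustration, not the repo's exact implementation:

```python
import torch
import torch.nn as nn

class SpanAnnotationHead(nn.Module):
    """Sketch of a span head: learned layer pooling + start/end projection."""
    def __init__(self, num_layers=24, hidden=1024):
        super().__init__()
        self.layer_logits = nn.Parameter(torch.zeros(num_layers))
        self.qa_outputs = nn.Linear(hidden, 2)  # start and end logits per token

    def forward(self, embeddings):
        # embeddings: (batch, num_layers, seq_len, hidden)
        weights = torch.softmax(self.layer_logits, dim=0)
        pooled = (weights[None, :, None, None] * embeddings).sum(dim=1)
        start_logits, end_logits = self.qa_outputs(pooled).split(1, dim=-1)
        return start_logits.squeeze(-1), end_logits.squeeze(-1)

head = SpanAnnotationHead()
start, end = head(torch.randn(2, 24, 386, 1024))
print(start.shape, end.shape)  # torch.Size([2, 386]) torch.Size([2, 386])
```

At inference, the predicted span is the (start, end) pair maximizing the summed logits, exactly as in standard BERT QA heads.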


Quick Start

Installation

# Clone the repository
git clone https://github.com/cbenge509/BERTVision.git
cd BERTVision

# Option 1: Conda (recommended)
conda create -n bertvision python=3.7 pytorch torchvision torchaudio cudatoolkit=10.1 -c pytorch
conda activate bertvision
conda install -c conda-forge jupyter numpy pandas matplotlib scikit-learn transformers datasets tqdm
pip install loguru hyperopt h5py

# Option 2: Pipenv
pipenv install
pipenv shell

Usage

All models run from code/torch/:

cd code/torch

# Fine-tune BERT on SQuAD 2.0
python -m models.bert_squad --model SQuAD --checkpoint bert-large-uncased \
    --lr 2e-5 --max-seq-length 384 --batch-size 8

# Train BERTVision adapter on pre-extracted embeddings
python -m models.ap_squad --model AP_SQuAD --checkpoint bert-large-uncased \
    --lr 2e-5 --max-seq-length 384 --batch-size 8

See code/torch/README.md for full documentation.


Supported Tasks

SQuAD 2.0

Question answering with span annotation and unanswerable question detection.

GLUE Benchmark

| Task | Type | Labels | Description |
|---|---|---|---|
| CoLA | Classification | 2 | Linguistic acceptability |
| SST-2 | Classification | 2 | Sentiment analysis |
| MRPC | Paraphrase | 2 | Paraphrase detection |
| QQP | Paraphrase | 2 | Question pair similarity |
| STS-B | Regression | 1 | Semantic textual similarity |
| MNLI | NLI | 3 | Natural language inference |
| QNLI | NLI | 2 | Question NLI |
| RTE | NLI | 2 | Textual entailment |
| WNLI | NLI | 2 | Winograd NLI |

Citation

If you use BERTVision in your research, please cite:

@article{jiang2020bertvision,
  title={BERTVision: A Parameter-Efficient Approach for Question Answering},
  author={Jiang, Siduo and Benge, Cristopher and King, William Casey},
  journal={UC Berkeley},
  year={2020},
  url={https://github.com/cbenge509/BERTVision}
}

License

This project is licensed under the MIT License - see LICENSE.txt for details.

