255 changes: 255 additions & 0 deletions src/lmflow/pipeline/inferencer.py
@@ -12,9 +12,13 @@
import datetime
import json
import time
import logging
from typing import Dict, List
from concurrent.futures import ThreadPoolExecutor

from transformers import AutoConfig
import torch.distributed as dist
import torch.nn.functional as F

from lmflow.args import DatasetArguments
from lmflow.datasets.dataset import Dataset
@@ -31,6 +35,8 @@ def rstrip_partial_utf8(string):
"image_text",
]

logger = logging.getLogger(__name__)

class Inferencer(BasePipeline):
"""
Initializes the `Inferencer` class with given arguments.
@@ -267,3 +273,252 @@ def stream_inference(
response = response[:index]

yield response, flag_break


class SpeculativeInferencer(Inferencer):
"""
Ref: [arXiv:2211.17192v2](https://arxiv.org/abs/2211.17192)

Parameters
------------
target_model_args : ModelArguments object.
Contains the arguments required to load the target model.

draft_model_args : ModelArguments object.
Contains the arguments required to load the draft model.

data_args : DatasetArguments object.
Contains the arguments required to load the dataset.

inferencer_args : InferencerArguments object.
Contains the arguments required to perform inference.


"""
def __init__(self, model_args, draft_model_args, data_args, inferencer_args):
super().__init__(model_args, data_args, inferencer_args)
self.draft_model_args = draft_model_args

self.draft_config = AutoConfig.from_pretrained(draft_model_args.model_name_or_path, trust_remote_code=True)
try:
self.draft_model_hidden_size = self.draft_config.hidden_size
except AttributeError:
logger.warning("Failed to read hidden_size from the draft model config, falling back to the default size 768")
self.draft_model_hidden_size = 768


@staticmethod
def score_to_prob(scores: torch.Tensor,
temperature: float = 1.,
top_p: float = 1.,) -> torch.Tensor:
"""Convert scores (NOT softmaxed tensor) to probabilities with support for temperature, top-p sampling, and argmax.

Parameters
----------
scores : torch.Tensor
Input scores.
temperature : float, optional
Temperature parameter for controlling randomness. Higher values make the distribution more uniform,
lower values make it peakier. When temperature <= 1e-6, argmax is used. Defaults to 1.0.
top_p : float, optional
Top-p sampling parameter controlling the cumulative probability threshold. Defaults to 1.0 (no threshold).

Returns
-------
torch.Tensor
Probability distribution after adjustments.
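
Examples
--------
A rough illustration (hypothetical values, assuming a 1 x vocab_size tensor):
``score_to_prob(torch.tensor([[2.0, 1.0, 0.1]]), temperature=1e-7)`` collapses to the
one-hot argmax distribution, while ``top_p=0.9`` keeps only the highest-scoring tokens
whose cumulative probability stays within 0.9 and renormalizes over them.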
"""
assert temperature >= 0.0
assert 0.0 < top_p <= 1.0

if temperature <= 1e-6:
final_prob = F.one_hot(scores.argmax(dim=1), num_classes=scores.size(1)).float()
else:
scores = scores / temperature  # avoid modifying the caller's tensor in place
if top_p < 1.0:
sorted_scores, sorted_indices = torch.sort(scores, descending=True)
probs = sorted_scores.softmax(dim=1)
cumulative_probs = torch.cumsum(probs, dim=1)
mask = cumulative_probs <= top_p
if mask.any():
thresholded_probs = probs * mask
thresholded_probs = thresholded_probs / thresholded_probs.sum(dim=1, keepdim=True)
final_prob = torch.zeros_like(scores)
# scatter the renormalized probabilities back to their original vocabulary positions
final_prob.scatter_(1, sorted_indices, thresholded_probs)
else:
final_prob = scores.softmax(dim=1)

else:
final_prob = scores.softmax(dim=1)

return final_prob


@staticmethod
def sample(prob: torch.Tensor, num_samples: int = 1) -> Dict:
"""Sample from a tensor of probabilities
"""
sampled_indices = torch.multinomial(prob, num_samples=num_samples, replacement=True)
return {'sampled_token': sampled_indices, 'sampled_prob': prob.gather(dim=1, index=sampled_indices), 'all_prob': prob}


@staticmethod
def predict_next_token(model: HFDecoderModel, input_ids: torch.Tensor, num_new_tokens: int = 1):
"""Predict the next token given the input_ids.
"""
output = model.inference(input_ids,
use_accelerator=True,
max_new_tokens=num_new_tokens,
return_dict_in_generate=True,
output_scores=True,
do_sample=True,
num_beams=1)
return output


def autoregressive_sampling(self, input_ids: torch.Tensor, model, num_new_tokens: int = 5) -> Dict:
"""Ref: [arXiv:2211.17192v2](https://arxiv.org/abs/2211.17192) Section 2.2
"""
sequence = input_ids
new_tokens = []

for _ in range(num_new_tokens):
pred = self.predict_next_token(model=model, input_ids=sequence, num_new_tokens=1) # predict next one token
prob = self.score_to_prob(pred.scores[0])
sampled = self.sample(prob=prob, num_samples=1)
new_tokens.append(sampled)
sequence = torch.cat([sequence, sampled['sampled_token']], dim=1)

return {"sequence": sequence, "new_tokens": new_tokens}


def inference(
self,
model: HFDecoderModel,
draft_model: HFDecoderModel,
input: str,
gamma: int = 5,
max_new_tokens: int = 100,
):
"""
Perform inference for a model

Parameters
------------
model : HFDecoderModel object.
TunableModel to verify tokens generated by the draft model.

draft_model : HFDecoderModel object.
TunableModel that provides approximations of the target model.

input : str.
The input text (i.e., the prompt) for the model.

gamma : int.
The number of tokens to be generated by the draft model in each iteration.

max_new_tokens : int.
The maximum number of tokens to be generated by the target model.


Returns
-------
output: str.
The output text generated by the model.
"""
assert gamma > 0

if self.inferencer_args.device == "gpu":
inputs = model.encode(input, return_tensors="pt").to(device=self.local_rank)
elif self.inferencer_args.device == "cpu":
inputs = model.encode(input, return_tensors="pt").to(device='cpu')
else:
raise NotImplementedError(
f"device \"{self.inferencer_args.device}\" is not supported"
)


def speculative_sampling(input_ids: torch.Tensor,
model: HFDecoderModel,
draft_model: HFDecoderModel) -> torch.Tensor:
"""Ref: [arXiv:2211.17192v2](https://arxiv.org/abs/2211.17192)

Parameters
----------
input_ids : torch.Tensor
draft_model : TunableModel object
model_list : List[TunableModel object]

Returns
-------
torch.Tensor
"""
len_input_ids = input_ids.shape[1]
logger.debug(f"len of input_ids: {len_input_ids}")

# STEP 1: Sample γ guesses x1, ..., xγ from Mq (draft model) autoregressively
output_draft = self.autoregressive_sampling(input_ids=input_ids, model=draft_model, num_new_tokens=gamma)
logger.debug(f"draft result: {output_draft['sequence']}")
logger.debug(f"draft result decoded: {draft_model.decode(output_draft['sequence'][0])}")


# STEP 2: Run Mp (target model) in parallel
# generate sequences [prefix, x1, x2, ..., xγ]
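# A single forward pass over [prefix, x1, ..., xγ] yields the target model's logits at every
# draft position at once, which is what makes verification cheaper than generating γ tokens
# autoregressively with the target model.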
output = model.get_backend_model()(input_ids=output_draft['sequence'], return_dict=True)
logger.debug(f'shape of output: {output.logits.shape}')


# STEP 3: Determine the number of accepted guesses n
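# Each draft token x is accepted with probability min(1, p(x) / q(x)), where q is the draft
# model's distribution and p the target model's; everything after the first rejection is discarded.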
accepted = [False] * gamma
for i in range(gamma):
draft_sampled_token_id = output_draft['new_tokens'][i]['sampled_token']
draft_sampled_token_prob = output_draft['new_tokens'][i]['sampled_prob']
token_prob = self.score_to_prob(output.logits[:,len_input_ids+i-1,:])[0, draft_sampled_token_id]

# reject the sample with probability 1 - p(x)/q(x)
if torch.rand_like(token_prob) > token_prob/draft_sampled_token_prob:
break
else:
accepted[i] = True

logger.debug(f"Speculative Sampling: Accepted: {sum(accepted)}/{gamma}")


# STEP 4: Adjust the distribution from Mp if needed
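# On rejection, the replacement token is drawn from the residual distribution
# norm(max(0, p(x) - q(x))), which (per the paper) keeps the output distributed exactly as if
# it were sampled from the target model alone; if all γ drafts were accepted, one extra token
# is sampled directly from p at the last position.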
if not all(accepted):
all_prob = self.score_to_prob(output.logits[:,len_input_ids+i-1,:])
draft_all_prob = output_draft['new_tokens'][i]['all_prob']
adjusted_prob = torch.max(torch.zeros_like(all_prob), all_prob - draft_all_prob)
prob = adjusted_prob / adjusted_prob.sum(dim=1, keepdim=True)
else:
prob = self.score_to_prob(output.logits[:,-1,:])


# STEP 5: Return n tokens from Mq, and one token from Mp
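# Each iteration therefore yields between 1 and γ + 1 new tokens: the accepted draft tokens
# plus one token sampled from the target model.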
token_from_target_model = self.sample(prob)['sampled_token']
final_sequence = torch.concat([output_draft['sequence'][:,:len_input_ids+sum(accepted)], token_from_target_model], dim=1)

return final_sequence


num_generated_new_tokens = 0
len_raw_input = len(inputs[0])
while num_generated_new_tokens < max_new_tokens:
logger.debug(f'===== New iter =====')
logger.debug(f"input_ids: {inputs}")
sampling_result = speculative_sampling(input_ids=inputs,
model=model,
draft_model=draft_model)
logger.debug(f'sampling result: {sampling_result}')
logger.debug(f'sampling result decoded: {model.decode(sampling_result[0])}')
num_generated_new_tokens += len(sampling_result[0]) - len(inputs[0])
inputs = sampling_result


# The last iteration may overshoot max_new_tokens (e.g., 19 tokens already generated plus 3
# accepted in the final iteration gives 22), so truncate to exactly max_new_tokens when decoding.
return model.decode(inputs[0,:len_raw_input+max_new_tokens])


def stream_inference(self):
raise NotImplementedError("Streaming output for SpeculativeInferencer is not supported yet")
20 changes: 20 additions & 0 deletions tests/pipeline/test_spec_inf.py
@@ -0,0 +1,20 @@
from lmflow.args import InferencerArguments
from lmflow.args import ModelArguments
from lmflow.args import DatasetArguments
from lmflow.models import hf_decoder_model
from lmflow.pipeline.inferencer import SpeculativeInferencer
import logging

logging.basicConfig(level=logging.DEBUG)

model_args = ModelArguments(model_name_or_path='gpt2-large')
model = hf_decoder_model.HFDecoderModel(model_args)
draft_model_args = ModelArguments(model_name_or_path='gpt2')
draft_model = hf_decoder_model.HFDecoderModel(draft_model_args)

inferencer_args = InferencerArguments()
data_args = DatasetArguments()

specinf = SpeculativeInferencer(model_args, draft_model_args, data_args, inferencer_args)

specinf.inference(model, draft_model, 'Hello, how are you', gamma=3, max_new_tokens=10)
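# Printing the returned text makes the result visible outside of the DEBUG logs, e.g.:
# print(specinf.inference(model, draft_model, 'Hello, how are you', gamma=3, max_new_tokens=10))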