-
Notifications
You must be signed in to change notification settings - Fork 32.6k
Expand file tree
/
Copy pathutils.py
More file actions
3870 lines (3468 loc) · 205 KB
/
utils.py
File metadata and controls
3870 lines (3468 loc) · 205 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import functools
import inspect
import os
import warnings
from collections.abc import Callable
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, cast
import torch
import torch.distributed as dist
from torch import nn
from ..cache_utils import (
Cache,
DynamicCache,
EncoderDecoderCache,
QuantizedCache,
StaticCache,
)
from ..dynamic_module_utils import (
check_python_requirements,
get_cached_module_file,
get_class_in_module,
resolve_trust_remote_code,
)
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..integrations.fsdp import is_fsdp_managed_module
from ..masking_utils import create_masks_for_generate
from ..tokenization_python import ExtensionsTrie
from ..utils import (
ModelOutput,
TransformersKwargs,
is_accelerate_available,
logging,
)
from ..utils.generic import is_flash_attention_requested
from .candidate_generator import (
AssistantVocabTranslatorCache,
AssistedCandidateGenerator,
AssistedCandidateGeneratorDifferentTokenizers,
CandidateGenerator,
EarlyExitCandidateGenerator,
PromptLookupCandidateGenerator,
UniversalSpeculativeDecodingGenerator,
_prepare_attention_mask,
_prepare_position_ids,
_prepare_token_type_ids,
)
from .configuration_utils import (
ALL_STATIC_CACHE_IMPLEMENTATIONS,
DEPRECATED_STATIC_CACHE_IMPLEMENTATIONS,
STATIC_CACHE_IMPLEMENTATIONS,
GenerationConfig,
GenerationMode,
)
from .continuous_batching import ContinuousMixin
from .logits_process import (
EncoderNoRepeatNGramLogitsProcessor,
EncoderRepetitionPenaltyLogitsProcessor,
EpsilonLogitsWarper,
EtaLogitsWarper,
ExponentialDecayLengthPenalty,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitNormalization,
LogitsProcessorList,
MinLengthLogitsProcessor,
MinNewTokensLengthLogitsProcessor,
MinPLogitsWarper,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
SequenceBiasLogitsProcessor,
SuppressTokensAtBeginLogitsProcessor,
SuppressTokensLogitsProcessor,
TemperatureLogitsWarper,
TopHLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
)
from .stopping_criteria import (
ConfidenceCriteria,
EosTokenCriteria,
MaxLengthCriteria,
MaxTimeCriteria,
StoppingCriteria,
StoppingCriteriaList,
StopStringCriteria,
)
if TYPE_CHECKING:
from .._typing import GenerativePreTrainedModel
from ..modeling_utils import PreTrainedModel
from ..tokenization_utils_base import PreTrainedTokenizerBase
from .streamers import BaseStreamer
logger = logging.get_logger(__name__)
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, add_hook_to_module
# Names of the variables that may hold the cache at generation time (which one is used depends on the model family)
ALL_CACHE_NAMES = [
    "past_key_values",  # default
    "cache_params",  # mamba-based models
    "state",  # rwkv
    "mems",  # xlnet
    "past_buckets_states",  # reformer
]
# Maps each `GenerationMode` to a string target. Entries starting with `_` name decoding methods (sampling and
# greedy search share `_sample`; both beam modes share `_beam_search`).
GENERATION_MODES_MAPPING = {
    GenerationMode.SAMPLE: "_sample",
    GenerationMode.GREEDY_SEARCH: "_sample",
    GenerationMode.BEAM_SEARCH: "_beam_search",
    GenerationMode.BEAM_SAMPLE: "_beam_search",
    GenerationMode.ASSISTED_GENERATION: "_assisted_decoding",
    # Deprecated methods: these map to `transformers-community/...` identifiers instead of a method name
    # NOTE(review): presumably Hub repositories hosting the custom generation code -- confirm against the caller
    GenerationMode.DOLA_GENERATION: "transformers-community/dola",
    GenerationMode.CONTRASTIVE_SEARCH: "transformers-community/contrastive-search",
    GenerationMode.GROUP_BEAM_SEARCH: "transformers-community/group-beam-search",
    GenerationMode.CONSTRAINED_BEAM_SEARCH: "transformers-community/constrained-beam-search",
}
@dataclass
class GenerateDecoderOnlyOutput(ModelOutput):
    """
    Result container for decoder-only generation models, when using non-beam methods.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. `sequence_length` is either equal to `max_length` or shorter if all batches
            finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before
            SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements
            (one element for each generated token), each of shape `(batch_size, config.vocab_size)`.
        logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before
            SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements
            (one element for each generated token), each of shape `(batch_size, config.vocab_size)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            One tuple per generated token, each holding one `torch.FloatTensor` per decoder layer, of shape
            `(batch_size, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
            One tuple per generated token, each holding one `torch.FloatTensor` per decoder layer, of shape
            `(batch_size, generated_length, hidden_size)`.
        past_key_values (`Cache`, *optional*, returned when `use_cache=True`):
            The model cache, used to speed up decoding. Different models have a different cache format, check
            the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
    """

    sequences: torch.LongTensor
    scores: tuple[torch.FloatTensor] | None = None
    logits: tuple[torch.FloatTensor] | None = None
    attentions: tuple[tuple[torch.FloatTensor]] | None = None
    hidden_states: tuple[tuple[torch.FloatTensor]] | None = None
    past_key_values: Cache | None = None
@dataclass
class GenerateEncoderDecoderOutput(ModelOutput):
    """
    Result container for encoder-decoder generation models, when using non-beam methods.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. `sequence_length` is either equal to `max_length` or shorter if all batches
            finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before
            SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements
            (one element for each generated token), each of shape `(batch_size, config.vocab_size)`.
        logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before
            SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements
            (one element for each generated token), each of shape `(batch_size, config.vocab_size)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape `(batch_size, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            One tuple per generated token, each holding one `torch.FloatTensor` per decoder layer, of shape
            `(batch_size, num_heads, generated_length, sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            One tuple per generated token, each holding one `torch.FloatTensor` per decoder layer, of shape
            `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
            One tuple per generated token, each holding one `torch.FloatTensor` per decoder layer, of shape
            `(batch_size, generated_length, hidden_size)`.
        past_key_values (`Cache`, *optional*, returned when `use_cache=True`):
            The model cache, used to speed up decoding. Different models have a different cache format, check
            the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
    """

    sequences: torch.LongTensor
    scores: tuple[torch.FloatTensor] | None = None
    logits: tuple[torch.FloatTensor] | None = None
    encoder_attentions: tuple[torch.FloatTensor] | None = None
    encoder_hidden_states: tuple[torch.FloatTensor] | None = None
    decoder_attentions: tuple[tuple[torch.FloatTensor]] | None = None
    cross_attentions: tuple[tuple[torch.FloatTensor]] | None = None
    decoder_hidden_states: tuple[tuple[torch.FloatTensor]] | None = None
    past_key_values: Cache | None = None
@dataclass
class GenerateBeamDecoderOnlyOutput(ModelOutput):
    """
    Result container for decoder-only generation models, when using beam methods.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. `sequence_length` is either equal to `max_length` or shorter if all batches
            finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step. Beam transition scores
            consist of log probabilities of tokens conditioned on log softmax of previously generated tokens in
            this beam. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each
            generated token), each of shape `(batch_size*num_beams, config.vocab_size)`.
        logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before
            SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements
            (one element for each generated token), each of shape `(batch_size*num_beams, config.vocab_size)`.
        beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`):
            Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
            `(batch_size*num_return_sequences, sequence_length)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            One tuple per generated token, each holding one `torch.FloatTensor` per decoder layer, of shape
            `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
            One tuple per generated token, each holding one `torch.FloatTensor` per decoder layer, of shape
            `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
        past_key_values (`Cache`, *optional*, returned when `use_cache=True`):
            The model cache, used to speed up decoding. Different models have a different cache format, check
            the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
    """

    sequences: torch.LongTensor
    sequences_scores: torch.FloatTensor | None = None
    scores: tuple[torch.FloatTensor] | None = None
    logits: tuple[torch.FloatTensor] | None = None
    beam_indices: torch.LongTensor | None = None
    attentions: tuple[tuple[torch.FloatTensor]] | None = None
    hidden_states: tuple[tuple[torch.FloatTensor]] | None = None
    past_key_values: Cache | None = None
@dataclass
class GenerateBeamEncoderDecoderOutput(ModelOutput):
    """
    Result container for encoder-decoder generation models, when using beam methods.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. `sequence_length` is either equal to `max_length` or shorter if all batches
            finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step. Beam transition scores
            consist of log probabilities of tokens conditioned on log softmax of previously generated tokens in
            this beam. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each
            generated token), each of shape `(batch_size*num_beams, config.vocab_size)`.
        logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before
            SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements
            (one element for each generated token), each of shape `(batch_size*num_beams, config.vocab_size)`.
        beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`):
            Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
            `(batch_size*num_return_sequences, sequence_length)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            One tuple per generated token, each holding one `torch.FloatTensor` per decoder layer, of shape
            `(batch_size*num_beams*num_return_sequences, num_heads, generated_length, sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            One tuple per generated token, each holding one `torch.FloatTensor` per decoder layer, of shape
            `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
            One tuple per generated token, each holding one `torch.FloatTensor` per decoder layer, of shape
            `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
        past_key_values (`Cache`, *optional*, returned when `use_cache=True`):
            The model cache, used to speed up decoding. Different models have a different cache format, check
            the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
    """

    sequences: torch.LongTensor
    sequences_scores: torch.FloatTensor | None = None
    scores: tuple[torch.FloatTensor] | None = None
    logits: tuple[torch.FloatTensor] | None = None
    beam_indices: torch.LongTensor | None = None
    encoder_attentions: tuple[torch.FloatTensor] | None = None
    encoder_hidden_states: tuple[torch.FloatTensor] | None = None
    decoder_attentions: tuple[tuple[torch.FloatTensor]] | None = None
    cross_attentions: tuple[tuple[torch.FloatTensor]] | None = None
    decoder_hidden_states: tuple[tuple[torch.FloatTensor]] | None = None
    past_key_values: Cache | None = None
# Typing shortcuts: union aliases over the generation output dataclasses.
# - `GenerateNonBeamOutput`: output of non-beam methods (decoder-only or encoder-decoder model)
# - `GenerateBeamOutput`: output of beam methods (decoder-only or encoder-decoder model)
# - `GenerateOutput`: any of the four output classes
GenerateNonBeamOutput = GenerateDecoderOnlyOutput | GenerateEncoderDecoderOutput
GenerateBeamOutput = GenerateBeamDecoderOnlyOutput | GenerateBeamEncoderDecoderOutput
GenerateOutput = GenerateNonBeamOutput | GenerateBeamOutput
class GenerationMixin(ContinuousMixin):
"""
A class containing all functions for auto-regressive text generation, to be used as a mixin in model classes.
Inheriting from this class causes the model to have special generation-related behavior, such as loading a
`GenerationConfig` at initialization time or ensuring `generate`-related tests are run in `transformers` CI.
A model class should inherit from `GenerationMixin` to enable calling methods like `generate`, or when it
has defined a custom `generate` method that relies on `GenerationMixin`, directly or indirectly, which
approximately shares the same interface to public methods like `generate`. Three examples:
- `LlamaForCausalLM` should inherit from `GenerationMixin` to enable calling `generate` and other public
methods in the mixin;
- `BlipForQuestionAnswering` has a custom `generate` method that approximately shares the same interface as
`GenerationMixin.generate` (it has a few extra arguments, and the same output). That function also calls
`GenerationMixin.generate` indirectly, through an inner model. As such, `BlipForQuestionAnswering` should
inherit from `GenerationMixin` to benefit from all generation-related automation in our codebase;
- `BarkModel` has a custom `generate` method and one of its inner models calls `GenerationMixin.generate`.
However, its `generate` does not share the same interface as `GenerationMixin.generate`. In this case,
`BarkModel` should NOT inherit from `GenerationMixin`, as it breaks the `generate` interface.
The class exposes [`~generation.GenerationMixin.generate`], which can be used for:
- *greedy decoding* if `num_beams=1` and `do_sample=False`
- *multinomial sampling* if `num_beams=1` and `do_sample=True`
- *beam-search decoding* if `num_beams>1` and `do_sample=False`
- *beam-search multinomial sampling* if `num_beams>1` and `do_sample=True`
- *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
"""
# Should be overwritten by models that can generate non-text output
output_modalities = ("text",)
def adjust_generation_fn(
    self: "GenerativePreTrainedModel",
    generation_config,
    from_auto_class,
    from_pipeline,
    pretrained_model_name_or_path,
    cache_dir,
    force_download,
    proxies,
    local_files_only,
    token,
    revision,
    subfolder,
    trust_remote_code,
    **kwargs,
):
    """
    Set up the generation config (and possibly a custom `generate`) of a model that can generate.

    Nothing is done when `self.can_generate()` is falsy. Otherwise:

    - if `generation_config` is given, `self.generation_config` is rebuilt from its dict form;
    - else, if `pretrained_model_name_or_path` is given, the generation config is loaded from that location
      (falling back to the model's `config.json` when the first attempt raises `OSError`). Then, when
      `trust_remote_code` is truthy and `self` has `load_custom_generate`, a custom generate function is loaded
      from the same location and bound over `self.generate`; a missing custom function (`OSError`) is ignored.

    Args:
        generation_config: Generation config to copy from (it must expose `to_dict()`), or `None`.
        from_auto_class: Forwarded to `GenerationConfig.from_pretrained` as `_from_auto`.
        from_pipeline: Forwarded to `GenerationConfig.from_pretrained` as `_from_pipeline`.
        pretrained_model_name_or_path: Location to load the generation files from, or `None`.
        cache_dir, force_download, proxies, local_files_only, token, revision, subfolder:
            Loading options, forwarded as-is together with `**kwargs`.
        trust_remote_code: Gate for loading the custom generate function; also forwarded to the loader.
        **kwargs: Extra loading options, merged into the forwarded loading kwargs.
    """
    if self.can_generate() and generation_config is not None:
        # Round trip through the dict form, so that `self` ends up with its own config object
        self.generation_config = self.generation_config.from_dict(generation_config.to_dict())
        return

    if not (self.can_generate() and pretrained_model_name_or_path is not None):
        return

    # Options shared by every loading call below
    hub_kwargs = {
        "cache_dir": cache_dir,
        "force_download": force_download,
        "proxies": proxies,
        "local_files_only": local_files_only,
        "token": token,
        "revision": revision,
        "subfolder": subfolder,
        **kwargs,
    }

    # Generation config: dedicated file first, model config as a fallback
    try:
        self.generation_config = GenerationConfig.from_pretrained(
            pretrained_model_name_or_path,
            _from_auto=from_auto_class,
            _from_pipeline=from_pipeline,
            **hub_kwargs,
        )
    except OSError:
        # `self` already has a generation config created from model config, but model config will
        # not contain any generation-specific params. These are popped at config's `__init__`.
        # Thus we have to load from `config.json` and create a generation config from it (for BART)
        logger.info("Generation config file not found, using a generation config created from the model config.")
        self.generation_config = GenerationConfig.from_pretrained(
            pretrained_model_name_or_path,
            config_file_name="config.json",
            _from_auto=from_auto_class,
            _from_pipeline=from_pipeline,
            _from_model_config=True,
            **hub_kwargs,
        )

    # Custom generate function: only attempted on opt-in, and it overrides `generate` when found
    if not (hasattr(self, "load_custom_generate") and trust_remote_code):
        return
    try:
        custom_generate = self.load_custom_generate(
            pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **hub_kwargs
        )
    except OSError:  # there is no custom generate function
        pass
    else:
        self.generate = functools.partial(custom_generate, model=self)
def load_custom_generate(
    self,
    pretrained_model_name_or_path: str | os.PathLike | None = None,
    trust_remote_code: bool | None = None,
    **kwargs,
) -> Callable:
    """
    Loads and returns a custom generate function, given a model repo.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            Can be either:
                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                - A path to a *directory* containing model weights saved using
                  [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
        trust_remote_code (`bool`, *optional*):
            Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
            should only be set to `True` for repositories you trust and in which you have read the code, as it will
            execute code present on the Hub on your local machine.
        **kwargs:
            Additional keyword arguments for remote code loading.

    Raises:
        OSError: If `pretrained_model_name_or_path` does not contain a `custom_generate` subdirectory. The
            original lookup error is attached as the exception's cause.

    Returns:
        A callable that can be used to generate text.
    """
    # Fetches the generate.py file from the model repo. If it doesn't exist, a file in `.no_exist` cache directory
    # is created (preventing future hub requests), and an OSError is raised.
    try:
        module = get_cached_module_file(
            pretrained_model_name_or_path, module_file="custom_generate/generate.py", **kwargs
        )
    except OSError as err:
        # Chain explicitly: the friendlier message is raised, the underlying failure stays as `__cause__`
        raise OSError(
            f"`{pretrained_model_name_or_path}` does not contain a `custom_generate` subdirectory with a "
            "`generate.py` file, can't load the custom generate function."
        ) from err

    # Handle opt-in `trust_remote_code` and related exceptions. The code counts as local when the given
    # location exists on disk, and as remote otherwise.
    is_local_code = os.path.exists(pretrained_model_name_or_path)
    error_message = (
        f"The repository `{pretrained_model_name_or_path}` contains custom generation code that will override "
        "the default `generate` method."
    )
    resolve_trust_remote_code(
        trust_remote_code,
        pretrained_model_name_or_path,
        has_local_code=is_local_code,
        has_remote_code=not is_local_code,
        error_message=error_message,
    )

    # Check the requirements file shipped next to the custom code, then pull `generate` out of the module
    check_python_requirements(
        pretrained_model_name_or_path, requirements_file="custom_generate/requirements.txt", **kwargs
    )
    custom_generate_function = get_class_in_module("generate", module)
    return custom_generate_function
def prepare_inputs_for_generation(
    self: "GenerativePreTrainedModel",
    input_ids: torch.LongTensor,
    next_sequence_length: int | None = None,
    past_key_values: Cache | None = None,
    attention_mask: torch.LongTensor | None = None,
    inputs_embeds: torch.FloatTensor | None = None,
    is_first_iteration: bool | None = False,
    **kwargs,
):
    """
    Build the dictionary of model inputs for one generation step.

    The steps are: pick the main input (`inputs_embeds` on the first iteration of a decoder-only model when they
    are given, the input ids otherwise) and clone it; add the cache, position ids and token type ids when present;
    slice position ids / token type ids to the length of the main input; turn a 2D attention mask into the output
    of the mask-creation function when the cache is compileable; and finally forward every remaining keyword
    argument unchanged (except a few names that are never forwarded).

    For encoder-decoder models the decoder-prefixed keys are used (`decoder_input_ids`,
    `decoder_position_ids`, `decoder_attention_mask`), and `attention_mask` is passed through as the encoder mask.

    See the forward pass in the model documentation for expected arguments (different models might have different
    requirements for e.g. `past_key_values`). This function should work as is for most LLMs.

    Returns:
        `dict`: the keyword arguments to call the model with.
    """
    is_encoder_decoder = self.config.is_encoder_decoder
    if is_encoder_decoder:
        input_ids_key = "decoder_input_ids"
        position_ids_key = "decoder_position_ids"
        attention_mask_key = "decoder_attention_mask"
    else:
        input_ids_key = "input_ids"
        position_ids_key = "position_ids"
        attention_mask_key = "attention_mask"

    model_inputs = {}

    # 1. Main input. `inputs_embeds` are only used in the 1st generation step for every prompt.
    # The `clone` calls in this function ensure a consistent stride. See #32227
    if not is_encoder_decoder and inputs_embeds is not None and is_first_iteration:
        model_inputs[input_ids_key] = None
        prompt_embeds = inputs_embeds
        if next_sequence_length is not None:
            prompt_embeds = inputs_embeds[:, -next_sequence_length:, :]
        model_inputs["inputs_embeds"] = prompt_embeds.clone(memory_format=torch.contiguous_format)
        batch_size, sequence_length = prompt_embeds.shape[:2]
    else:
        if next_sequence_length is not None:
            input_ids = input_ids[:, -next_sequence_length:]
        model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
        # only the first two dims are read, as some models may have 3D input ids
        batch_size, sequence_length = input_ids.shape[:2]

    # 2. Cache, position ids and token type ids (the latter two are taken out of `kwargs`)
    if past_key_values is not None:
        model_inputs["past_key_values"] = past_key_values
    position_ids = kwargs.pop(position_ids_key, None)
    if position_ids is not None:
        model_inputs[position_ids_key] = position_ids
    token_type_ids = kwargs.pop("token_type_ids", None)
    if token_type_ids is not None:
        model_inputs["token_type_ids"] = token_type_ids

    # 3. Inputs that must have the same length as the main input are cut down to their last `sequence_length`
    # entries. They can be 2D or 3D, the slice is always on the last dim.
    for name in (position_ids_key, "token_type_ids"):
        tensor = model_inputs.get(name)
        if tensor is None or tensor.shape[-1] == sequence_length:
            continue
        model_inputs[name] = tensor[..., -sequence_length:].clone(memory_format=torch.contiguous_format)

    # 4. Attention masks. With a compileable cache, a 2D mask is replaced by the output of the mask-creation
    # function (important for a performant compiled forward pass).
    if is_encoder_decoder:
        encoder_attention_mask = attention_mask
        attention_mask = kwargs.pop("decoder_attention_mask", None)
    else:
        encoder_attention_mask = None

    needs_mask_creation = (
        isinstance(past_key_values, Cache)
        and past_key_values.is_compileable
        and attention_mask is not None
        and attention_mask.ndim == 2
    )
    if needs_mask_creation:
        # A model-specific `create_masks_for_generate` takes precedence over the general one
        build_masks = getattr(self, "create_masks_for_generate", create_masks_for_generate)
        # Only batch size, seq length, dtype and device are needed, hence the 0-sized tensor carrying the metadata
        metadata_only_embeds = torch.empty(
            (batch_size, sequence_length, 0), dtype=self.dtype, device=input_ids.device
        )
        attention_mask = build_masks(
            config=self.config,
            inputs_embeds=metadata_only_embeds,
            attention_mask=attention_mask,
            past_key_values=model_inputs.get("past_key_values"),
            position_ids=model_inputs.get(position_ids_key),
            token_type_ids=model_inputs.get("token_type_ids"),
            is_first_iteration=is_first_iteration,
        )

    if attention_mask is not None:
        model_inputs[attention_mask_key] = attention_mask
    if encoder_attention_mask is not None:
        model_inputs["attention_mask"] = encoder_attention_mask

    # 5. Every remaining kwarg (e.g. `use_cache`) is forwarded, unless its key is already set or is excluded
    kwargs_to_avoid_forwarding = ("labels", "next_sequence_length")
    model_inputs.update(
        {
            key: value
            for key, value in kwargs.items()
            if key not in model_inputs and key not in kwargs_to_avoid_forwarding
        }
    )

    # BC for remote code models only: `cache_position` is created on the fly here, as it is not maintained in
    # kwargs between `forward`s
    if self.is_remote_code() and "cache_position" in inspect.signature(self.forward).parameters:
        logger.warning_once(
            "The remote code model you are currently using seems to expect `cache_position`. This arg has been "
            "removed from the Transformers library, and will stop being created in `generate` even for remote code models "
            "in a future release. Please open a PR on the remote code hub repo to remove any usage of `cache_position`."
        )
        past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
        model_inputs["cache_position"] = torch.arange(sequence_length, device=input_ids.device) + past_seen_tokens

    return model_inputs
def _prepare_model_inputs(
self: "GenerativePreTrainedModel",
inputs: torch.Tensor | None,
bos_token_id: torch.Tensor | None,
model_kwargs: dict[str, torch.Tensor],
) -> tuple[torch.Tensor, str | None, dict[str, torch.Tensor]]:
"""
This function extracts the model-specific `inputs` for generation.
"""
# 1. retrieve all kwargs that are non-None or non-model input related.
# some encoder-decoder models have different names for model and encoder
if (
self.config.is_encoder_decoder
and hasattr(self, "encoder")
and self.encoder.main_input_name != self.main_input_name
):
input_name = self.encoder.main_input_name
else:
input_name = self.main_input_name
# 2. check whether model_input_name is passed as kwarg
# if yes and `inputs` is None use kwarg inputs
inputs_kwarg = model_kwargs.pop(input_name, None)
if inputs_kwarg is not None and inputs is not None:
raise ValueError(
f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed. "
f"Make sure to either pass {inputs} or {input_name}=..."
)
elif inputs_kwarg is not None:
inputs = inputs_kwarg
# 3. In the presence of `inputs_embeds` for text models:
# - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model
# doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with
# input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`)
# - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
# pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
if model_kwargs["inputs_embeds"] is None:
model_kwargs.pop("inputs_embeds")
elif not self.config.is_encoder_decoder:
has_inputs_embeds_forwarding = "inputs_embeds" in set(
inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
)
if not has_inputs_embeds_forwarding:
raise ValueError(
f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} "
"doesn't have its forwarding implemented. See the GPT2 implementation for an example "
"(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
)
# In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
# the attention mask) can rely on the actual model input.
model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
inputs, bos_token_id, model_kwargs=model_kwargs
)
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
else:
if inputs is not None:
raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
# 4. if `inputs` is still None, try to create `input_ids` from BOS token
inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
return inputs, input_name, model_kwargs
def _maybe_initialize_input_ids_for_generation(
self: "GenerativePreTrainedModel",
inputs: torch.Tensor | None,
bos_token_id: torch.Tensor | None,
model_kwargs: dict[str, torch.Tensor],
) -> torch.LongTensor:
"""Initializes input ids for generation, if necessary."""
if inputs is not None:
return inputs
encoder_outputs = model_kwargs.get("encoder_outputs")
last_hidden_state = getattr(encoder_outputs, "last_hidden_state", None)
if self.config.is_encoder_decoder and last_hidden_state is not None:
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
shape = last_hidden_state.size()[:-1]
return torch.ones(shape, dtype=torch.long, device=self.device) * -100
# If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
# soft-prompting or in multimodal implementations built on top of decoder-only language models.
batch_size = 1
for value in model_kwargs.values():
if isinstance(value, torch.Tensor):
batch_size = value.shape[0]
break
if "inputs_embeds" in model_kwargs:
return torch.ones((batch_size, 0), dtype=torch.long, device=self.device)
if bos_token_id is None:
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
def _prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs):
    """
    Infer position ids from the attention mask or, without one, from the input and cache lengths.

    With an `attention_mask` in `model_kwargs`, positions are the running count of attended tokens minus one,
    with masked-out positions set to 0. Without it, positions are `0 .. seq_length + past_length - 1` with a
    leading dimension of size 1, where `past_length` comes from `past_key_values.get_seq_length()` when a
    cache is present in `model_kwargs`.
    """
    # `input_ids` in the kwargs take precedence over the main input when they are non-empty (the main input
    # may be something else, e.g. for multimodal models).
    if "input_ids" in model_kwargs:
        kwarg_input_ids = model_kwargs["input_ids"]
        if kwarg_input_ids.shape[1] > 0:
            inputs_tensor = kwarg_input_ids

    seq_length = inputs_tensor.shape[1]

    attention_mask = model_kwargs.get("attention_mask")
    if attention_mask is not None:
        attended_count = attention_mask.long().cumsum(-1) - 1
        # Positions before the first attended token would otherwise be -1
        return attended_count.masked_fill(attention_mask == 0, 0)

    cache = model_kwargs.get("past_key_values")
    past_length = 0 if cache is None else cache.get_seq_length()
    all_positions = torch.arange(seq_length + past_length, dtype=torch.long, device=inputs_tensor.device)
    return all_positions.unsqueeze(0)
def _prepare_attention_mask_for_generation(
    self,
    inputs_tensor: torch.Tensor,
    generation_config: GenerationConfig,
    model_kwargs: dict[str, Any],
) -> torch.LongTensor:
    """
    Build an attention mask for the prompt, inferring padding from the pad token when that is possible.

    Args:
        inputs_tensor: Main model input. Replaced by `model_kwargs["input_ids"]` when those are non-empty.
        generation_config: Read for `_pad_token_tensor` and `_eos_token_tensor`.
        model_kwargs: Keyword arguments of the generation call (only `input_ids` is looked at).

    Returns:
        A `torch.long` mask shaped like the first two dimensions of the input. It is all ones unless the input
        is a 2D integer tensor that contains the pad token and the pad token is not among the EOS tokens, in
        which case pad positions are 0.
    """
    pad_token_id = generation_config._pad_token_tensor
    eos_token_id = generation_config._eos_token_tensor

    # `input_ids` in the kwargs take precedence over the main input when they are non-empty (the main input
    # may be something else, e.g. for multimodal models).
    if "input_ids" in model_kwargs:
        kwarg_input_ids = model_kwargs["input_ids"]
        if kwarg_input_ids.shape[1] > 0:
            inputs_tensor = kwarg_input_ids

    # Fallback used whenever nothing can be inferred about padding
    all_ones_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)

    looks_like_token_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long]
    if pad_token_id is None or not looks_like_token_ids:
        return all_ones_mask

    # The conditions below stay tensors and are combined arithmetically instead of with Python branching.
    pad_is_present = torch.isin(inputs_tensor, pad_token_id).any()
    if eos_token_id is None:
        pad_differs_from_eos = True
    else:
        pad_differs_from_eos = ~(torch.isin(eos_token_id, pad_token_id).any())
    can_infer_attention_mask = pad_is_present * pad_differs_from_eos

    non_pad_mask = inputs_tensor.ne(pad_token_id).long()
    return non_pad_mask * can_infer_attention_mask + all_ones_mask * ~can_infer_attention_mask
def _prepare_encoder_decoder_kwargs_for_generation(
    self: "GenerativePreTrainedModel",
    inputs_tensor: torch.Tensor,
    model_kwargs,
    model_input_name: str | None,
    generation_config: GenerationConfig,
) -> dict[str, Any]:
    """
    Run the encoder once and store its output under `model_kwargs["encoder_outputs"]`.

    Args:
        inputs_tensor: Main encoder input, passed to the encoder under `model_input_name`.
        model_kwargs: Keyword arguments of the generation call. Entries that are relevant to the encoder are
            forwarded to it; the dict is updated in place with `encoder_outputs`.
        model_input_name: Argument name for `inputs_tensor`; `self.main_input_name` is used when `None`.
        generation_config: Read for `output_attentions` and `output_hidden_states`.

    Returns:
        The same `model_kwargs` object, now containing `encoder_outputs`.
    """
    # 1. Fetch the encoder
    encoder = self.get_encoder()
    # Accelerate big model inference: the encoder must return its outputs on the same device as its inputs.
    if hasattr(self, "hf_device_map"):
        if hasattr(encoder, "_hf_hook"):
            encoder._hf_hook.io_same_device = True
        else:
            add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True))

    # 2. Select the kwargs to forward: drop everything whose name marks it as decoder/cache related and, unless
    # the encoder declares a catch-all parameter named `kwargs` or `model_kwargs`, everything its `forward`
    # does not name explicitly.
    non_encoder_prefixes = ("decoder_", "cross_attn", "use_cache", "past_key_values", "cache_params")
    forward_params = set(inspect.signature(encoder.forward).parameters)
    forwards_everything = "kwargs" in forward_params or "model_kwargs" in forward_params

    encoder_kwargs = {}
    for name, value in model_kwargs.items():
        if name.startswith(non_encoder_prefixes):
            continue
        if forwards_everything or name in forward_params:
            encoder_kwargs[name] = value

    encoder_kwargs["output_attentions"] = generation_config.output_attentions
    encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states

    # 3. Always request a `ModelOutput` (`return_dict=True`) and add the main input under its name
    if model_input_name is None:
        model_input_name = self.main_input_name
    encoder_kwargs["return_dict"] = True
    encoder_kwargs[model_input_name] = inputs_tensor

    model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs)
    return model_kwargs
def _prepare_decoder_input_ids_for_generation(
self: "GenerativePreTrainedModel",
batch_size: int,
model_input_name: str,
model_kwargs: dict[str, torch.Tensor],
decoder_start_token_id: torch.Tensor,
device: torch.device | None = None,
) -> tuple[torch.LongTensor, dict[str, torch.Tensor]]:
"""Prepares `decoder_input_ids` for generation with encoder-decoder models"""
# 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
# we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
decoder_input_ids = model_kwargs.pop("decoder_input_ids")
elif "input_ids" in model_kwargs and model_input_name != "input_ids":
decoder_input_ids = model_kwargs.pop("input_ids")
else:
decoder_input_ids = None
# 2. `decoder_start_token_id` must have shape (batch_size, 1)
if device is None:
device = self.device
if decoder_start_token_id.ndim == 1:
if decoder_start_token_id.shape[0] != batch_size:
raise ValueError(
f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}"
)
decoder_start_token_id = decoder_start_token_id.view(-1, 1)
else:
decoder_start_token_id = (
torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
)
# 3. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
# no user input -> use decoder_start_token_id as decoder_input_ids
if decoder_input_ids is None:
decoder_input_ids = decoder_start_token_id
# exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token. Note that the
# original checkpoints can't be detected through `self.__class__.__name__.lower()`, needing custom logic.
# See: https://github.com/huggingface/transformers/pull/31470
elif "donut" in self.__class__.__name__.lower() or (
self.config.model_type == "vision-encoder-decoder" and "donut" in self.config.encoder.model_type.lower()
):
pass
elif self.config.model_type == "whisper":
pass
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
# decoder_attention_mask if provided)
elif (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item():
decoder_input_ids = torch.cat([decoder_start_token_id, decoder_input_ids], dim=-1)
if "decoder_attention_mask" in model_kwargs:
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
decoder_attention_mask = torch.cat(
(torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
dim=-1,
)
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
return decoder_input_ids, model_kwargs
@staticmethod
def _expand_inputs_for_generation(
expand_size: int = 1,
is_encoder_decoder: bool = False,
input_ids: torch.LongTensor | None = None,
**model_kwargs,
) -> tuple[torch.LongTensor, dict[str, Any]]:
"""Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
# Do not call torch.repeat_interleave if expand_size is 1 because it clones
# the input tensor and thus requires more memory although no change is applied
if expand_size == 1:
return input_ids, model_kwargs
def _expand_dict_for_generation(dict_to_expand):
for key in dict_to_expand:
if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], torch.Tensor):
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
return dict_to_expand
if input_ids is not None:
input_ids = input_ids.repeat_interleave(expand_size, dim=0)
model_kwargs = _expand_dict_for_generation(model_kwargs)
if is_encoder_decoder:
if model_kwargs.get("encoder_outputs") is None:
raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
return input_ids, model_kwargs
def _update_model_kwargs_for_generation(
    self,
    outputs: ModelOutput,
    model_kwargs: dict[str, Any],
    is_encoder_decoder: bool = False,
    num_new_tokens: int = 1,
) -> dict[str, Any]:
    """
    Extend the model kwargs so they cover the `num_new_tokens` tokens that were just generated.

    The cache returned in `outputs` is stored back into `model_kwargs`, and `token_type_ids`, the position ids
    and the 2D attention mask are each grown by `num_new_tokens` entries along their last dimension when they
    are present. For encoder-decoder models the `decoder_position_ids` / `decoder_attention_mask` entries are
    the ones updated.

    Nothing is sliced here; slicing is performed in `prepare_inputs_for_generation`.

    Returns:
        The same `model_kwargs` object, updated in place.
    """
    # Cache: store the first cache found in the outputs, keeping the name used in model code.
    returned_cache_name = next((name for name in ALL_CACHE_NAMES if name in outputs), None)
    if returned_cache_name is not None:
        # TODO (joao): remove output/input mismatch when these old models (xlnet, reformer) are deprecated
        is_legacy_name = returned_cache_name in ("past_buckets_states", "mems")
        kwarg_name = "past_key_values" if is_legacy_name else returned_cache_name
        model_kwargs[kwarg_name] = getattr(outputs, returned_cache_name)

    # Token type ids: append a copy of the trailing `num_new_tokens` values
    token_type_ids = model_kwargs.get("token_type_ids")
    if token_type_ids is not None:
        trailing_types = token_type_ids[:, -num_new_tokens:]
        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, trailing_types], dim=-1)

    # Position ids (2D or 3D sometimes): continue counting from the last position
    position_ids_key = "decoder_position_ids" if is_encoder_decoder else "position_ids"
    position_ids = model_kwargs.get(position_ids_key)
    if position_ids is not None:
        # View the offsets with as many dims as `position_ids` so that they broadcast against it
        broadcast_shape = [1] * (position_ids.dim() - 1) + [-1]
        offsets = torch.arange(num_new_tokens, dtype=position_ids.dtype, device=position_ids.device)
        new_positions = offsets.view(*broadcast_shape) + position_ids[..., -1:] + 1
        model_kwargs[position_ids_key] = torch.cat([position_ids, new_positions], dim=-1)

    # Attention mask (always 2D here): the new tokens are attended to
    attention_mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
    attention_mask = model_kwargs.get(attention_mask_key)
    if attention_mask is not None:
        new_columns = attention_mask.new_ones((attention_mask.shape[0], num_new_tokens))
        model_kwargs[attention_mask_key] = torch.cat([attention_mask, new_columns], dim=-1)

    return model_kwargs
def _get_candidate_generator(
self: "GenerativePreTrainedModel",
generation_config: GenerationConfig,
input_ids: torch.LongTensor,
inputs_tensor: torch.Tensor,
logits_processor: LogitsProcessorList,
model_kwargs: dict[str, Any],
assistant_model: Optional["PreTrainedModel"] = None,
target_tokenizer: Optional["PreTrainedTokenizerBase"] = None,
assistant_tokenizer: Optional["PreTrainedTokenizerBase"] = None,
) -> CandidateGenerator:
"""
Returns the candidate generator to be used in `assisted_generation`
"""
different_tokenizers = all(v is not None for v in (assistant_model, target_tokenizer, assistant_tokenizer))
if generation_config.assistant_early_exit is not None:
candidate_generator = EarlyExitCandidateGenerator(
input_ids=input_ids,
assistant_model=self,
generation_config=generation_config,
model_kwargs=model_kwargs,
inputs_tensor=inputs_tensor,
logits_processor=logits_processor,
)
elif generation_config.prompt_lookup_num_tokens is not None:
candidate_generator = PromptLookupCandidateGenerator(
eos_token_id=generation_config._eos_token_tensor,
num_output_tokens=generation_config.prompt_lookup_num_tokens,
max_matching_ngram_size=generation_config.max_matching_ngram_size or 2,
max_length=generation_config.max_length,
logits_processor=logits_processor,
vocab_size=self.config.get_text_config().vocab_size,
)
elif different_tokenizers:
assistant_model = cast("PreTrainedModel", assistant_model)
target_tokenizer = cast("PreTrainedTokenizerBase", target_tokenizer)
assistant_tokenizer = cast("PreTrainedTokenizerBase", assistant_tokenizer)
if generation_config.do_sample is True:
atm_translator = AssistantVocabTranslatorCache.get_translator(
target_tokenizer,
assistant_tokenizer,
self.config.get_text_config().vocab_size,
assistant_model=assistant_model,
assistant_prune_lm_head=True, # prune LM head of assistant model
)
# Since we prune the LM head, we cannot use the repetition penalty on the assistant model due to mismatches between token ids and logits index
assistant_model.generation_config.repetition_penalty = None
candidate_generator = UniversalSpeculativeDecodingGenerator(
input_ids=input_ids,
assistant_model=assistant_model,
generation_config=generation_config,
model_kwargs=model_kwargs,
inputs_tensor=inputs_tensor,
logits_processor=logits_processor,
target_tokenizer=target_tokenizer,
assistant_tokenizer=assistant_tokenizer,
atm_translator=atm_translator,
)
elif generation_config.do_sample is False:
candidate_generator = AssistedCandidateGeneratorDifferentTokenizers(
input_ids=input_ids,
assistant_model=assistant_model,
generation_config=generation_config,