1 change: 1 addition & 0 deletions docs/source/Instruction/支持的模型和数据集.md
@@ -693,6 +693,7 @@
|[AIDC-AI/Ovis2-34B](https://modelscope.cn/models/AIDC-AI/Ovis2-34B)|ovis2|ovis2|transformers>=4.46.2, moviepy<2|&#x2718;|vision|[AIDC-AI/Ovis2-34B](https://huggingface.co/AIDC-AI/Ovis2-34B)|
|[XiaomiMiMo/MiMo-VL-7B-SFT](https://modelscope.cn/models/XiaomiMiMo/MiMo-VL-7B-SFT)|mimo_vl|mimo_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|&#x2718;|vision, video|[XiaomiMiMo/MiMo-VL-7B-SFT](https://huggingface.co/XiaomiMiMo/MiMo-VL-7B-SFT)|
|[XiaomiMiMo/MiMo-VL-7B-RL](https://modelscope.cn/models/XiaomiMiMo/MiMo-VL-7B-RL)|mimo_vl|mimo_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|&#x2718;|vision, video|[XiaomiMiMo/MiMo-VL-7B-RL](https://huggingface.co/XiaomiMiMo/MiMo-VL-7B-RL)|
|[mispeech/midashenglm-7b](https://modelscope.cn/models/mispeech/midashenglm-7b)|midashenglm|midashenglm|transformers>=4.52, soundfile|&#x2718;|audio|[mispeech/midashenglm-7b](https://huggingface.co/mispeech/midashenglm-7b)|
|[ZhipuAI/glm-4v-9b](https://modelscope.cn/models/ZhipuAI/glm-4v-9b)|glm4v|glm4v|transformers>=4.42,<4.45|&#x2718;|-|[THUDM/glm-4v-9b](https://huggingface.co/THUDM/glm-4v-9b)|
|[ZhipuAI/cogagent-9b-20241220](https://modelscope.cn/models/ZhipuAI/cogagent-9b-20241220)|glm4v|glm4v|transformers>=4.42|&#x2718;|-|[THUDM/cogagent-9b-20241220](https://huggingface.co/THUDM/cogagent-9b-20241220)|
|[ZhipuAI/GLM-4.1V-9B-Base](https://modelscope.cn/models/ZhipuAI/GLM-4.1V-9B-Base)|glm4_1v|glm4_1v|transformers>=4.53|&#x2718;|-|[THUDM/GLM-4.1V-9B-Base](https://huggingface.co/THUDM/GLM-4.1V-9B-Base)|
1 change: 1 addition & 0 deletions docs/source_en/Instruction/Supported-models-and-datasets.md
@@ -693,6 +693,7 @@ The table below introduces the models integrated with ms-swift:
|[AIDC-AI/Ovis2-34B](https://modelscope.cn/models/AIDC-AI/Ovis2-34B)|ovis2|ovis2|transformers>=4.46.2, moviepy<2|&#x2718;|vision|[AIDC-AI/Ovis2-34B](https://huggingface.co/AIDC-AI/Ovis2-34B)|
|[XiaomiMiMo/MiMo-VL-7B-SFT](https://modelscope.cn/models/XiaomiMiMo/MiMo-VL-7B-SFT)|mimo_vl|mimo_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|&#x2718;|vision, video|[XiaomiMiMo/MiMo-VL-7B-SFT](https://huggingface.co/XiaomiMiMo/MiMo-VL-7B-SFT)|
|[XiaomiMiMo/MiMo-VL-7B-RL](https://modelscope.cn/models/XiaomiMiMo/MiMo-VL-7B-RL)|mimo_vl|mimo_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|&#x2718;|vision, video|[XiaomiMiMo/MiMo-VL-7B-RL](https://huggingface.co/XiaomiMiMo/MiMo-VL-7B-RL)|
|[mispeech/midashenglm-7b](https://modelscope.cn/models/mispeech/midashenglm-7b)|midashenglm|midashenglm|transformers>=4.52, soundfile|&#x2718;|audio|[mispeech/midashenglm-7b](https://huggingface.co/mispeech/midashenglm-7b)|
|[ZhipuAI/glm-4v-9b](https://modelscope.cn/models/ZhipuAI/glm-4v-9b)|glm4v|glm4v|transformers>=4.42,<4.45|&#x2718;|-|[THUDM/glm-4v-9b](https://huggingface.co/THUDM/glm-4v-9b)|
|[ZhipuAI/cogagent-9b-20241220](https://modelscope.cn/models/ZhipuAI/cogagent-9b-20241220)|glm4v|glm4v|transformers>=4.42|&#x2718;|-|[THUDM/cogagent-9b-20241220](https://huggingface.co/THUDM/cogagent-9b-20241220)|
|[ZhipuAI/GLM-4.1V-9B-Base](https://modelscope.cn/models/ZhipuAI/GLM-4.1V-9B-Base)|glm4_1v|glm4_1v|transformers>=4.53|&#x2718;|-|[THUDM/GLM-4.1V-9B-Base](https://huggingface.co/THUDM/GLM-4.1V-9B-Base)|
1 change: 1 addition & 0 deletions swift/llm/model/constant.py
@@ -158,6 +158,7 @@ class MLLMModelType:
    ovis1_6_llama3 = 'ovis1_6_llama3'
    ovis2 = 'ovis2'
    mimo_vl = 'mimo_vl'
    midashenglm = 'midashenglm'

    glm4v = 'glm4v'
    glm4_1v = 'glm4_1v'
23 changes: 23 additions & 0 deletions swift/llm/model/model/qwen.py
@@ -800,6 +800,29 @@ def get_model_tokenizer_qwen2_5_omni(model_dir, *args, **kwargs):
    ))


def get_model_tokenizer_midashenglm(*args, **kwargs):
    model, tokenizer = get_model_tokenizer_multimodal(*args, **kwargs)
    if model is not None:
        model.audio_encoder.float()
        patch_output_clone(model.decoder.model.embed_tokens)
    return model, tokenizer


register_model(
    ModelMeta(
        MLLMModelType.midashenglm,
        [ModelGroup([
            Model('mispeech/midashenglm-7b', 'mispeech/midashenglm-7b'),
        ])],
        TemplateType.midashenglm,
        get_model_tokenizer_midashenglm,
        model_arch=ModelArch.midashenglm,
        architectures=['MiDashengLMModel'],
        requires=['transformers>=4.52', 'soundfile'],
        tags=['audio'],
    ))
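
Once registered like this, the model id resolves through ms-swift's standard inference entry points. A minimal sketch, assuming the documented PtEngine/InferRequest API; the audio path and prompt below are placeholders, not part of this PR:

# Minimal sketch, assuming ms-swift's documented PtEngine/InferRequest API.
# 'path/to/audio.wav' is a placeholder; <audio> marks where the clip is bound.
from swift.llm import InferRequest, PtEngine, RequestConfig

engine = PtEngine('mispeech/midashenglm-7b')  # resolved via the register_model metadata above
request = InferRequest(
    messages=[{'role': 'user', 'content': '<audio>Transcribe this clip.'}],
    audios=['path/to/audio.wav'])
resp = engine.infer([request], RequestConfig(max_tokens=128))
print(resp[0].choices[0].message.content)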


def get_model_tokenizer_qwen2_audio(*args, **kwargs):
    from transformers import Qwen2AudioForConditionalGeneration
    kwargs['automodel_class'] = kwargs['automodel_class'] or Qwen2AudioForConditionalGeneration
10 changes: 10 additions & 0 deletions swift/llm/model/model_arch.py
@@ -76,6 +76,8 @@ class MLLMModelArch:
    mistral_2503 = 'mistral_2503'
    keye_vl = 'keye_vl'

    midashenglm = 'midashenglm'


class ModelArch(LLMModelArch, MLLMModelArch):
    pass
@@ -517,6 +519,14 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> None:
        generator=['talker', 'token2wav'],
    ))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.midashenglm,
        language_model='decoder',
        aligner=['audio_projector'],
        vision_tower=['audio_encoder'],
    ))
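
This mapping tells ms-swift where each submodule lives on the loaded model: the Qwen-based decoder acts as the language model, the audio projector as the aligner, and the audio encoder fills the vision_tower slot used by the freezing and target-module utilities. A hedged sketch of resolving such dotted paths; the helper below is illustrative, not ms-swift API:

# Illustrative only (not ms-swift API): resolve a dotted submodule path like
# 'decoder' or 'audio_encoder' as declared in MultiModelKeys, e.g. to freeze it.
from operator import attrgetter

def resolve_submodule(model, dotted_path: str):
    return attrgetter(dotted_path)(model)

# Example: freeze the audio tower while tuning the decoder.
# for param in resolve_submodule(model, 'audio_encoder').parameters():
#     param.requires_grad_(False)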

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.glm4v,
22 changes: 7 additions & 15 deletions swift/llm/template/base.py
@@ -298,31 +298,23 @@ def _replace_start_image_tags(inputs: StdTemplateInputs):
         inputs.generate_mode = generate_mode

     @staticmethod
-    def _extend_tokens(input_ids: List[int], labels: Optional[List[int]], replace_idx_list: List[int],
-                       get_new_tokens: Callable[[int], List[int]]) -> Tuple[List[int], Optional[List[int]]]:
+    def _extend_tokens(
+            input_ids: List[int], labels: Optional[List[int]], loss_scale: Optional[List[float]],
+            replace_idx_list: List[int],
+            get_new_tokens: Callable[[int], List[int]]) -> Tuple[List[int], Optional[List[int]], Optional[List[float]]]:
         added_tokens_len = 0
         for i, idx in enumerate(replace_idx_list):
             new_tokens = get_new_tokens(i)
             token_len = len(new_tokens)
             input_ids = input_ids[:idx + added_tokens_len] + new_tokens + input_ids[added_tokens_len + idx + 1:]
             if labels:
                 labels = labels[:idx + added_tokens_len] + [-100] * token_len + labels[added_tokens_len + idx + 1:]
-            added_tokens_len += token_len - 1
-        return input_ids, labels
-
-    @staticmethod
-    def _extend_loss_scale(loss_scale: Optional[List[float]], replace_idx_list: List[int],
-                           get_new_tokens: Callable[[int], List[int]]) -> Optional[List[float]]:
-        if loss_scale:
-            added_tokens_len = 0
-            for i, idx in enumerate(replace_idx_list):
-                new_tokens = get_new_tokens(i)
-                token_len = len(new_tokens)
+            if loss_scale:
+                scale_idx = loss_scale[idx + added_tokens_len]
+                loss_scale = loss_scale[:idx + added_tokens_len] + [scale_idx] * token_len + loss_scale[added_tokens_len
+                                                                                                        + idx + 1:]
             added_tokens_len += token_len - 1
-        return loss_scale
+        return input_ids, labels, loss_scale
Comment on lines +301 to +317 (reviewer comment, severity: high):
The current implementation of _extend_tokens uses list slicing and concatenation inside a loop (list = list[:idx] + ...). This can be inefficient for large lists as it creates a new list in each iteration, leading to quadratic complexity in the worst case (O(M*N) where M is the number of replacements and N is the list length). A more performant approach would be to build the new lists by appending segments, which would be closer to linear time complexity.

    def _extend_tokens(
            input_ids: List[int], labels: Optional[List[int]], loss_scale: Optional[List[float]],
            replace_idx_list: List[int],
            get_new_tokens: Callable[[int], List[int]]) -> Tuple[List[int], Optional[List[int]], Optional[List[float]]]:
        if not replace_idx_list:
            return input_ids, labels, loss_scale

        new_input_ids = []
        new_labels = [] if labels is not None else None
        new_loss_scale = [] if loss_scale is not None else None

        last_idx = 0
        for i, idx in enumerate(replace_idx_list):
            new_tokens = get_new_tokens(i)

            new_input_ids.extend(input_ids[last_idx:idx])
            if labels is not None:
                new_labels.extend(labels[last_idx:idx])
            if loss_scale is not None:
                new_loss_scale.extend(loss_scale[last_idx:idx])

            new_input_ids.extend(new_tokens)
            if labels is not None:
                new_labels.extend([-100] * len(new_tokens))
            if loss_scale is not None:
                scale_val = loss_scale[idx]
                new_loss_scale.extend([scale_val] * len(new_tokens))

            last_idx = idx + 1

        new_input_ids.extend(input_ids[last_idx:])
        if labels is not None:
            new_labels.extend(labels[last_idx:])
        if loss_scale is not None:
            new_loss_scale.extend(loss_scale[last_idx:])

        return new_input_ids, new_labels, new_loss_scale
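
To make the splice semantics concrete, a small self-contained check (not part of the PR): each placeholder expands in place, labels are masked with -100 over the inserted span, and loss_scale copies the placeholder's scale across it.

# Self-contained check of the splice semantics (mirrors the merged helper above).
def extend_tokens_demo(input_ids, labels, loss_scale, replace_idx_list, get_new_tokens):
    added = 0
    for i, idx in enumerate(replace_idx_list):
        new_tokens = get_new_tokens(i)
        n = len(new_tokens)
        input_ids = input_ids[:idx + added] + new_tokens + input_ids[idx + added + 1:]
        if labels:
            labels = labels[:idx + added] + [-100] * n + labels[idx + added + 1:]
        if loss_scale:
            scale = loss_scale[idx + added]
            loss_scale = loss_scale[:idx + added] + [scale] * n + loss_scale[idx + added + 1:]
        added += n - 1
    return input_ids, labels, loss_scale

# Token 9 is a placeholder at index 1; it expands to three tokens.
ids, lbs, ls = extend_tokens_demo([1, 9, 2], [10, -100, 20], [1.0, 0.5, 1.0], [1], lambda i: [7, 7, 7])
assert ids == [1, 7, 7, 7, 2]
assert lbs == [10, -100, -100, -100, 20]
assert ls == [1.0, 0.5, 0.5, 0.5, 1.0]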


    def forward_context(self, model, inputs):
        return nullcontext()
1 change: 1 addition & 0 deletions swift/llm/template/constant.py
@@ -117,6 +117,7 @@ class MLLMTemplateType:
    ovis1_6_llama3 = 'ovis1_6_llama3'
    ovis2 = 'ovis2'
    mimo_vl = 'mimo_vl'
    midashenglm = 'midashenglm'

    llama3_1_omni = 'llama3_1_omni'
    llama3_2_vision = 'llama3_2_vision'
4 changes: 2 additions & 2 deletions swift/llm/template/template/__init__.py
@@ -1,3 +1,3 @@
 from . import (baidu, bert, deepseek, emu3, gemma, glm, idefics3, internlm, internvl, kwai, llama, llava, llm, megrez,
-               microsoft, minicpm, minimax, mistral, molmo, moonshot, mplug, openbuddy, pixtral, qwen, stepfun, valley,
-               yi)
+               microsoft, midashenglm, minicpm, minimax, mistral, molmo, moonshot, mplug, openbuddy, pixtral, qwen,
+               stepfun, valley, yi)
4 changes: 2 additions & 2 deletions swift/llm/template/template/emu3.py
@@ -179,8 +179,8 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
             image_prompts.append(self.tokenizer.encode(image_prompt))

         # Insert image tokens into input_ids
-        input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, lambda i: image_prompts[i])
-        loss_scale = self._extend_loss_scale(loss_scale, idx_list, lambda i: image_prompts[i])
+        input_ids, labels, loss_scale = self._extend_tokens(input_ids, labels, loss_scale, idx_list,
+                                                            lambda i: image_prompts[i])
         return {'input_ids': input_ids, 'labels': labels, 'loss_scale': loss_scale}


12 changes: 6 additions & 6 deletions swift/llm/template/template/gemma.py
@@ -112,8 +112,8 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         loss_scale = encoded.get('loss_scale', None)
         idx_list = findall(input_ids, self.boi_token_id)
         img_tokens = self._tokenize(self.processor.full_image_sequence)
-        input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, lambda _: img_tokens)
-        loss_scale = self._extend_loss_scale(loss_scale, idx_list, lambda _: img_tokens)
+        input_ids, labels, loss_scale = self._extend_tokens(input_ids, labels, loss_scale, idx_list,
+                                                            lambda _: img_tokens)

         # TODO: customize
         processor_kwargs = Gemma3ProcessorKwargs._defaults['images_kwargs']
@@ -171,8 +171,8 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         if inputs.images:
             idx_list = findall(input_ids, self.boi_token_id)
             img_tokens = self._tokenize(processor.full_image_sequence)
-            input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, lambda _: img_tokens)
-            loss_scale = self._extend_loss_scale(loss_scale, idx_list, lambda _: img_tokens)
+            input_ids, labels, loss_scale = self._extend_tokens(input_ids, labels, loss_scale, idx_list,
+                                                                lambda _: img_tokens)

             # Process images
             processor_kwargs = Gemma3nProcessorKwargs._defaults.get('images_kwargs', {})
@@ -188,8 +188,8 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         if audio_idx_list:
             # Get audio token sequence from processor
             audio_tokens = self._tokenize(processor.full_audio_sequence)
-            input_ids, labels = self._extend_tokens(input_ids, labels, audio_idx_list, lambda _: audio_tokens)
-            loss_scale = self._extend_loss_scale(loss_scale, audio_idx_list, lambda _: audio_tokens)
+            input_ids, labels, loss_scale = self._extend_tokens(input_ids, labels, loss_scale, audio_idx_list,
+                                                                lambda _: audio_tokens)

             # Process audios
             processor_kwargs = Gemma3nProcessorKwargs._defaults.get('audio_kwargs', {})
4 changes: 2 additions & 2 deletions swift/llm/template/template/internvl.py
@@ -146,8 +146,8 @@ def _get_new_tokens(i):
                 '<IMG_CONTEXT>', add_special_tokens=False) * self.num_image_token * num_patches[i]
             return img_tokens

-        encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
-        encoded['loss_scale'] = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
+        encoded['input_ids'], encoded['labels'], encoded['loss_scale'] = self._extend_tokens(
+            input_ids, labels, loss_scale, idx_list, _get_new_tokens)
         encoded['pixel_values'] = pixel_values
         return encoded

4 changes: 2 additions & 2 deletions swift/llm/template/template/kwai.py
@@ -84,8 +84,8 @@ def _get_new_tokens(i):
             token_len = (media_grid_thw[i].prod() // merge_length)
             return [media_token] * token_len

-        input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
-        loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
+        input_ids, labels, loss_scale = self._extend_tokens(input_ids, labels, loss_scale, idx_list,
+                                                            _get_new_tokens)
         encoded.update(media_inputs)

         encoded['input_ids'] = input_ids
5 changes: 2 additions & 3 deletions swift/llm/template/template/llama.py
@@ -138,9 +138,8 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
                 return_tensors='pt')
             splited_tokens = self._split_list(media_inputs['input_ids'][0].tolist(), split_token)

-            encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list,
-                                                                          lambda i: splited_tokens[i])
-            encoded['loss_scale'] = self._extend_loss_scale(loss_scale, idx_list, lambda i: splited_tokens[i])
+            encoded['input_ids'], encoded['labels'], encoded['loss_scale'] = self._extend_tokens(
+                input_ids, labels, loss_scale, idx_list, lambda i: splited_tokens[i])
             encoded['pixel_values'] = media_inputs['pixel_values']
         return encoded

4 changes: 2 additions & 2 deletions swift/llm/template/template/megrez.py
@@ -70,8 +70,8 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         def _get_new_tokens(i):
             return self._tokenize(padding[i])

-        input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
-        loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
+        input_ids, labels, loss_scale = self._extend_tokens(input_ids, labels, loss_scale, idx_list,
+                                                            _get_new_tokens)

         encoded['input_ids'] = input_ids
         encoded['labels'] = labels
5 changes: 2 additions & 3 deletions swift/llm/template/template/microsoft.py
@@ -182,9 +182,8 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         def _get_new_tokens(i):
             return placeholders[i]

-        encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, images_idx + audios_idx,
-                                                                      _get_new_tokens)
-        encoded['loss_scale'] = self._extend_loss_scale(loss_scale, images_idx + audios_idx, _get_new_tokens)
+        encoded['input_ids'], encoded['labels'], encoded['loss_scale'] = self._extend_tokens(
+            input_ids, labels, loss_scale, images_idx + audios_idx, _get_new_tokens)
         new_encoded.pop('attention_mask')
         encoded.update(new_encoded)
         return encoded
63 changes: 63 additions & 0 deletions swift/llm/template/template/midashenglm.py
@@ -0,0 +1,63 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Dict, List, Literal, Optional

import torch
import torch.nn.functional as F

from swift.utils import get_env_args
from ..base import Template
from ..constant import MLLMTemplateType
from ..register import register_template
from ..template_inputs import StdTemplateInputs
from ..utils import Context, Word, findall
from ..vision_utils import load_batch
from .qwen import QwenTemplateMeta


class MiDashengLMTemplate(Template):
    placeholder_tokens = ['<|AUDIO|>']
    skip_prompt = False

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        assert media_type == 'audio'
        return ['<|AUDIO|>']

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        from transformers.audio_utils import load_audio
        encoded = super()._encode(inputs)
        input_ids = encoded['input_ids']
        sampling_rate = get_env_args('sampling_rate', int, 16000)
        inputs.audios = load_batch(inputs.audios, partial(load_audio, sampling_rate=sampling_rate))
        audio_token = self._tokenize('<|AUDIO|>')[0]
        idx_list = findall(input_ids, audio_token)
        if idx_list:
            split_token = self._tokenize('\n')[0]
            audio_inputs = self.processor(text='\n'.join(['<|AUDIO|>'] * len(inputs.audios)), audio=inputs.audios)
            splited_tokens = self._split_list(audio_inputs['input_ids'][0].tolist(), split_token)

            encoded['input_ids'], encoded['labels'], encoded['loss_scale'] = self._extend_tokens(
                input_ids, encoded['labels'], encoded['loss_scale'], idx_list, lambda i: splited_tokens[i])
            encoded['input_values'] = audio_inputs['input_values']
            encoded['audio_length'] = audio_inputs['audio_length']
        return encoded

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        res = super()._data_collator(batch, padding_to=padding_to)

        input_values = [b['input_values'] for b in batch if b.get('input_values') is not None]
        audio_lengths = [b['audio_length'] for b in batch if b.get('audio_length') is not None]
Comment on lines +50 to +51 (reviewer comment, severity: medium):
This part of the code iterates over the batch list twice, once for input_values and once for audio_lengths. For large batches, this could be slightly inefficient. You can combine these into a single loop to improve performance.

Suggested change:
-        input_values = [b['input_values'] for b in batch if b.get('input_values') is not None]
-        audio_lengths = [b['audio_length'] for b in batch if b.get('audio_length') is not None]
+        input_values = []
+        audio_lengths = []
+        for b in batch:
+            iv = b.get('input_values')
+            if iv is not None:
+                input_values.append(iv)
+            al = b.get('audio_length')
+            if al is not None:
+                audio_lengths.append(al)


        if input_values:
            res['audio_length'] = torch.concat(audio_lengths)
            for i in range(len(input_values)):
                pad_len = (res['audio_length'].max() - input_values[i].shape[1]).item()
                input_values[i] = F.pad(input_values[i], (0, pad_len), 'constant', 0)
            res['input_values'] = torch.concat(input_values)

        return res


register_template(QwenTemplateMeta(MLLMTemplateType.midashenglm, template_cls=MiDashengLMTemplate))
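
To illustrate the collation above with concrete shapes, a standalone sketch (not part of the PR): each sample's input_values is right-padded with zeros to the longest audio in the batch, then concatenated along the batch dimension.

# Standalone sketch of the padding scheme in _data_collator above.
import torch
import torch.nn.functional as F

batch_values = [torch.ones(1, 5), torch.ones(1, 3)]  # two variable-length audio feature rows
audio_length = torch.tensor([5, 3])
pad_to = audio_length.max()
# Right-pad each row with zeros to the longest audio, then stack on the batch dim.
padded = [F.pad(v, (0, (pad_to - v.shape[1]).item()), 'constant', 0) for v in batch_values]
stacked = torch.concat(padded)
assert stacked.shape == (2, 5)
assert stacked[1, 3:].eq(0).all()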
3 changes: 1 addition & 2 deletions swift/llm/template/template/minicpm.py
@@ -193,8 +193,7 @@ def _get_new_tokens(i):
                 placeholder += '\n'
             return self.processor.encode(placeholder, add_special_tokens=False)

-        input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
-        loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
+        input_ids, labels, loss_scale = self._extend_tokens(input_ids, labels, loss_scale, idx_list, _get_new_tokens)

         if inputs.images:
             input_tensor_ids = torch.tensor(input_ids)
4 changes: 2 additions & 2 deletions swift/llm/template/template/mistral.py
@@ -45,8 +45,8 @@ def _get_new_tokens(i):
             replace_str = ''.join(replace_tokens)
             return processor.encode(replace_str, add_special_tokens=False)

-        encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
-        encoded['loss_scale'] = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
+        encoded['input_ids'], encoded['labels'], encoded['loss_scale'] = self._extend_tokens(
+            input_ids, labels, loss_scale, idx_list, _get_new_tokens)

         return encoded

4 changes: 2 additions & 2 deletions swift/llm/template/template/moonshot.py
@@ -50,8 +50,8 @@ def _get_new_tokens(i):
             token_len = (image_grid_hws[i].prod() // merge_length)
             return [media_token] * token_len

-        input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
-        loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
+        input_ids, labels, loss_scale = self._extend_tokens(input_ids, labels, loss_scale, idx_list,
+                                                            _get_new_tokens)

         encoded['loss_scale'] = loss_scale
         encoded['input_ids'] = input_ids
4 changes: 2 additions & 2 deletions swift/llm/template/template/mplug.py
@@ -108,8 +108,8 @@ def _get_new_tokens(i):
                 token_list = image_token_list
             return token_list

-        input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
-        loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
+        input_ids, labels, loss_scale = self._extend_tokens(input_ids, labels, loss_scale, idx_list,
+                                                            _get_new_tokens)

        image_token_idx = torch.tensor(findall(input_ids, image_token_list))
        if self.version == '241101':
4 changes: 2 additions & 2 deletions swift/llm/template/template/pixtral.py
@@ -39,8 +39,8 @@ def _get_new_tokens(i):
             img_tokens: List[int] = self.processor.encode(replace_str, add_special_tokens=False)
             return img_tokens

-        encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
-        encoded['loss_scale'] = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
+        encoded['input_ids'], encoded['labels'], encoded['loss_scale'] = self._extend_tokens(
+            input_ids, labels, loss_scale, idx_list, _get_new_tokens)

         return encoded
