Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions deepspeed/moe/sharded_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from deepspeed.utils.timer import SynchronizedWallClockTimer
from deepspeed.utils import logger
from deepspeed.utils.bwc import bwc_tensor_model_parallel_world_size
from deepspeed.utils.torch import jit_script_compat
from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -157,7 +158,7 @@ def einsum(rule, a, b):
# includes stateful caching logic which is incompatible with ONNX.


@torch.jit.script
@jit_script_compat
def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor:
# gates has shape of SE
num_tokens = gates.shape[0]
Expand All @@ -170,12 +171,12 @@ def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> T
return capacity


@torch.jit.script
@jit_script_compat
def _top_idx(source, k):
return torch.topk(source, k=k, dim=0)[1]


@torch.jit.script
@jit_script_compat
def _one_hot_to_float(x, num_classes):
return F.one_hot(x, num_classes=num_classes).float()

Expand Down
4 changes: 2 additions & 2 deletions deepspeed/runtime/zero/mics_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,20 @@
from typing import List

import numpy as np
import torch
from torch import Tensor

from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.utils import logger
from deepspeed.utils.torch import jit_script_compat


def _log_rank0(msg):
if dist.get_rank() == 0:
logger.info(msg)


@torch.jit.script
@jit_script_compat
def scale_tensors(tensors: List[Tensor], scale: int):
for t in tensors:
t.div_(scale)
Expand Down
5 changes: 3 additions & 2 deletions deepspeed/sequence/fpdt_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from packaging import version
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.utils.torch import jit_script_compat

try:
import flash_attn
Expand Down Expand Up @@ -1040,12 +1041,12 @@ def forward(self,
return output, self.qkv_dense_bias if self.reture_bias else None


@torch.jit.script
@jit_script_compat
def bias_gelu(x):
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


@torch.jit.script
@jit_script_compat
def bias_gelu_back(g, x):
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
Expand Down
6 changes: 6 additions & 0 deletions deepspeed/utils/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,9 @@ def register_grad_hook(param, hook):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
return grad_acc.register_hook(hook)


def jit_script_compat(fn):
if required_torch_version(min_version=2.0) and hasattr(torch, "compile"):
return torch.compile(fn)
return torch.jit.script(fn)
Loading