1 change: 1 addition & 0 deletions deepmd/common.py
@@ -54,6 +54,7 @@
"gelu",
"gelu_tf",
"silu",
"silut",
"none",
"linear",
]
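The `silut` activation registered above (also accepted as `custom_silu`, optionally with a threshold suffix such as `silut:3.0`) is a SiLU whose right tail is replaced by a saturating tanh. Reconstructed from the implementations in this PR, with σ the sigmoid and T the threshold (default 3.0):

silut_T(x) = x·σ(x)                for x < T
silut_T(x) = tanh(s·(x − T)) + c   for x ≥ T

where s = σ(T)·(1 + T·(1 − σ(T))) = silu′(T) and c = T·σ(T) = silu(T). The two branches match in value and first derivative at T, so the function is C¹ and bounded above by c + 1.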
31 changes: 31 additions & 0 deletions deepmd/dpmodel/utils/network.py
@@ -325,6 +325,37 @@ def fn(x):
# generated by GitHub Copilot
return x / (1 + xp.exp(-x))

return fn
elif activation_function.startswith("silut") or activation_function.startswith(
"custom_silu"
):

def sigmoid(x):
return 1 / (1 + np.exp(-x))

def silu(x):
return x * sigmoid(x)

def silu_grad(x):
sig = sigmoid(x)
return sig + x * sig * (1 - sig)

threshold = (
float(activation_function.split(":")[-1])
if ":" in activation_function
else 3.0
)
slope = float(silu_grad(threshold))
const = float(silu(threshold))

def fn(x):
xp = array_api_compat.array_namespace(x)
return xp.where(
x < threshold,
x * (1 / (1 + xp.exp(-x))),
xp.tanh(slope * (x - threshold)) + const,
)

return fn
elif activation_function.lower() in ("none", "linear"):

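As a sanity check on those branch constants, here is a standalone NumPy sketch (not part of the diff; `make_silut` is a name invented for this example) mirroring `fn` above:

import numpy as np

def make_silut(threshold: float = 3.0):
    # mirrors the dpmodel implementation above
    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))
    slope = float(sigmoid(threshold) * (1.0 + threshold * (1.0 - sigmoid(threshold))))  # silu'(T)
    const = float(threshold * sigmoid(threshold))  # silu(T)
    def fn(x):
        return np.where(x < threshold, x * sigmoid(x), np.tanh(slope * (x - threshold)) + const)
    return fn

silut = make_silut(3.0)
# value continuity across the threshold ...
assert abs(silut(3.0 - 1e-7) - silut(3.0 + 1e-7)) < 1e-6
# ... and the tail saturates at silu(3.0) + 1 for large inputs
assert abs(silut(1e3) - (silut(3.0) + 1.0)) < 1e-6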
42 changes: 42 additions & 0 deletions deepmd/pd/utils/utils.py
@@ -32,10 +32,47 @@
)


class SiLUT(paddle.nn.Layer):
def __init__(self, threshold=3.0):
super().__init__()

def sigmoid(x):
return 1 / (1 + np.exp(-x))

def silu(x):
return x * sigmoid(x)

def silu_grad(x):
sig = sigmoid(x)
return sig + x * sig * (1 - sig)

self.threshold = threshold
self.slope = float(silu_grad(threshold))
self.const = float(silu(threshold))

def forward(self, x: paddle.Tensor) -> paddle.Tensor:
silu_part = F.silu(x)
mask = x >= self.threshold
if paddle.any(mask):
tanh_part = paddle.tanh(self.slope * (x - self.threshold)) + self.const
return paddle.where(x < self.threshold, silu_part, tanh_part)
else:
return silu_part


class ActivationFn(paddle.nn.Layer):
def __init__(self, activation: str | None):
super().__init__()
self.activation: str = activation if activation is not None else "linear"
if self.activation.lower().startswith(
"silut"
) or self.activation.lower().startswith("custom_silu"):
threshold = (
float(self.activation.split(":")[-1]) if ":" in self.activation else 3.0
)
self.silut = SiLUT(threshold=threshold)
else:
self.silut = None

def forward(self, x: paddle.Tensor) -> paddle.Tensor:
"""Returns the tensor after applying activation function corresponding to `activation`."""
@@ -53,6 +90,11 @@ def forward(self, x: paddle.Tensor) -> paddle.Tensor:
return F.sigmoid(x)
elif self.activation.lower() == "silu":
return F.silu(x)
elif self.activation.lower().startswith(
"silut"
) or self.activation.lower().startswith("custom_silu"):
assert self.silut is not None
return self.silut(x)
elif self.activation.lower() == "linear" or self.activation.lower() == "none":
return x
else:
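Usage on the Paddle side goes through `ActivationFn`; the optional `:<threshold>` suffix is parsed in `__init__` and defaults to 3.0. A minimal sketch, assuming a working Paddle install:

import paddle
from deepmd.pd.utils.utils import ActivationFn

act = ActivationFn("silut:2.0")  # plain "silut" uses the default threshold of 3.0
x = paddle.linspace(-4.0, 8.0, 7)
y = act(x)  # SiLU below 2.0; bounded tanh tail at and above it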
1 change: 1 addition & 0 deletions deepmd/pt/entrypoints/main.py
@@ -249,6 +249,7 @@ def train(
output: str = "out.json",
) -> None:
log.info("Configuration path: %s", input_file)
env.CUSTOM_OP_USE_JIT = True
if LOCAL_RANK == 0:
SummaryPrinter()()
with open(input_file) as fin:
2 changes: 2 additions & 0 deletions deepmd/pt/utils/env.py
@@ -34,6 +34,7 @@
JIT = False
CACHE_PER_SYS = 5 # keep at most so many sets per sys in memory
ENERGY_BIAS_TRAINABLE = True
CUSTOM_OP_USE_JIT = False

PRECISION_DICT = {
"float16": torch.float16,
@@ -76,6 +77,7 @@

__all__ = [
"CACHE_PER_SYS",
"CUSTOM_OP_USE_JIT",
"DEFAULT_PRECISION",
"DEVICE",
"ENERGY_BIAS_TRAINABLE",
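`CUSTOM_OP_USE_JIT` defaults to `False` so that a freshly constructed `ActivationFn` stays TorchScript-freezable; the training entry point flips it to `True` (see `deepmd/pt/entrypoints/main.py` above) to select the scripted-autograd `SiLUTScript` path defined below. Since the flag is read when `ActivationFn` is constructed, it must be set before the model is built:

from deepmd.pt.utils import env
env.CUSTOM_OP_USE_JIT = True  # must precede model construction to take effect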
157 changes: 157 additions & 0 deletions deepmd/pt/utils/utils.py
@@ -11,17 +11,169 @@
import torch.nn.functional as F

from deepmd.dpmodel.common import PRECISION_DICT as NP_PRECISION_DICT
from deepmd.pt.utils import (
env,
)

from .env import (
DEVICE,
)
from .env import PRECISION_DICT as PT_PRECISION_DICT


def silut_forward(
x: torch.Tensor, threshold: float, slope: float, const_val: float
) -> torch.Tensor:
sig = torch.sigmoid(x)
silu = x * sig
tanh_part = torch.tanh(slope * (x - threshold)) + const_val
return torch.where(x >= threshold, tanh_part, silu)


def silut_backward(
x: torch.Tensor, grad_output: torch.Tensor, threshold: float, slope: float
):
sig = torch.sigmoid(x)
grad_silu = sig * (1 + x * (1 - sig))

tanh_term = torch.tanh(slope * (x - threshold))
grad_tanh = slope * (1 - tanh_term.pow(2))

grad = torch.where(x >= threshold, grad_tanh, grad_silu)
return grad * grad_output, grad


def silut_double_backward(
x: torch.Tensor,
grad_grad_output: torch.Tensor,
grad_output: torch.Tensor,
threshold: float,
slope: float,
) -> torch.Tensor:
# Tanh branch
tanh_term = torch.tanh(slope * (x - threshold))
grad_grad = -2 * slope * slope * tanh_term * (1 - tanh_term * tanh_term)

# SiLU branch
sig = 1.0 / (1.0 + torch.exp(-x))
sig_prime = sig * (1 - sig)
silu_term = sig_prime * (2 + x * (1 - 2 * sig))

grad_grad = torch.where(x >= threshold, grad_grad, silu_term)

return grad_output * grad_grad * grad_grad_output


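# For reference, the derivatives implemented by the two backward kernels above,
# writing σ for the sigmoid and t = tanh(slope * (x - threshold)):
#   silu'(x)  = σ(x) * (1 + x * (1 - σ(x)))                 -> `grad_silu`
#   silu''(x) = σ(x) * (1 - σ(x)) * (2 + x * (1 - 2·σ(x)))  -> `silu_term`
#   tail'(x)  = slope * (1 - t**2)                          -> `grad_tanh`
#   tail''(x) = -2 * slope**2 * t * (1 - t**2)              -> tanh-branch `grad_grad`
# `silut_backward` also returns the raw `grad` so the double-backward pass can reuse it.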
class SiLUTScript(torch.nn.Module):
def __init__(self, threshold: float = 3.0):
super().__init__()
self.threshold = threshold

# Precompute parameters for the tanh replacement
sigmoid_threshold = 1 / (1 + np.exp(-threshold))
self.slope = float(
sigmoid_threshold + threshold * sigmoid_threshold * (1 - sigmoid_threshold)
)
self.const_val = float(threshold * sigmoid_threshold)
self.get_script_code()

def get_script_code(self):
silut_forward_script = torch.jit.script(silut_forward)
silut_backward_script = torch.jit.script(silut_backward)
silut_double_backward_script = torch.jit.script(silut_double_backward)

class SiLUTFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, threshold, slope, const_val):
ctx.save_for_backward(x)
ctx.threshold = threshold
ctx.slope = slope
ctx.const_val = const_val
return silut_forward_script(x, threshold, slope, const_val)

@staticmethod
def backward(ctx, grad_output):
(x,) = ctx.saved_tensors
threshold = ctx.threshold
slope = ctx.slope

grad_input = SiLUTGradFunction.apply(x, grad_output, threshold, slope)
return grad_input, None, None, None

class SiLUTGradFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, grad_output, threshold, slope):
ctx.threshold = threshold
ctx.slope = slope
grad_input, grad = silut_backward_script(
x, grad_output, threshold, slope
)
ctx.save_for_backward(x, grad_output, grad)
return grad_input

@staticmethod
def backward(ctx, grad_grad_output):
(x, grad_output, grad) = ctx.saved_tensors
threshold = ctx.threshold
slope = ctx.slope

grad_input = silut_double_backward_script(
x, grad_grad_output, grad_output, threshold, slope
)
return grad_input, grad * grad_grad_output, None, None

self.SiLUTFunction = SiLUTFunction

def forward(self, x):
return self.SiLUTFunction.apply(x, self.threshold, self.slope, self.const_val)


class SiLUT(torch.nn.Module):
def __init__(self, threshold=3.0):
super().__init__()

def sigmoid(x):
return 1 / (1 + np.exp(-x))

def silu(x):
return x * sigmoid(x)

def silu_grad(x):
sig = sigmoid(x)
return sig + x * sig * (1 - sig)

self.threshold = threshold
self.slope = float(silu_grad(threshold))
self.const = float(silu(threshold))

def forward(self, x: torch.Tensor) -> torch.Tensor:
silu_part = F.silu(x)
mask = x >= self.threshold
if torch.any(mask):
tanh_part = torch.tanh(self.slope * (x - self.threshold)) + self.const
return torch.where(x < self.threshold, silu_part, tanh_part)
else:
return silu_part


class ActivationFn(torch.nn.Module):
def __init__(self, activation: Optional[str]) -> None:
super().__init__()
self.activation: str = activation if activation is not None else "linear"
if self.activation.lower().startswith(
"silut"
) or self.activation.lower().startswith("custom_silu"):
threshold = (
float(self.activation.split(":")[-1]) if ":" in self.activation else 3.0
)
if env.CUSTOM_OP_USE_JIT:
# scripted-autograd path: faster for training, but cannot be JIT-compiled
self.silut = SiLUTScript(threshold=threshold)
else:
# TorchScript-compatible path, used when freezing the model
self.silut = SiLUT(threshold=threshold)
else:
self.silut = None

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Returns the tensor after applying activation function corresponding to `activation`."""
@@ -41,6 +193,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.sigmoid(x)
elif self.activation.lower() == "silu":
return F.silu(x)
elif self.activation.lower().startswith(
"silut"
) or self.activation.lower().startswith("custom_silu"):
assert self.silut is not None
return self.silut(x)
elif self.activation.lower() == "linear" or self.activation.lower() == "none":
return x
else:
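A consistency check one might run against this file (a sketch, assuming both classes are importable as above; the grid avoids landing exactly on the threshold, where the second derivative jumps):

import torch
from deepmd.pt.utils.utils import SiLUT, SiLUTScript

x = torch.linspace(-6.0, 6.0, 31, dtype=torch.float64, requires_grad=True)
ref, fast = SiLUT(3.0), SiLUTScript(3.0)
assert torch.allclose(ref(x), fast(x))    # identical forward values
torch.autograd.gradcheck(fast, (x,))      # custom backward vs. numerical gradient
torch.autograd.gradgradcheck(fast, (x,))  # custom double backward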
44 changes: 44 additions & 0 deletions deepmd/tf/common.py
@@ -144,6 +144,47 @@ def silu(x: tf.Tensor) -> tf.Tensor:
return x * tf.sigmoid(x)


def get_silut(activation_function: str = "silut"):
import numpy as np

def sigmoid(x):
return 1 / (1 + np.exp(-x))

def silu(x):
return x * sigmoid(x)

def silu_grad(x):
sig = sigmoid(x)
return sig + x * sig * (1 - sig)

threshold = (
float(activation_function.split(":")[-1]) if ":" in activation_function else 3.0
)
slope = float(silu_grad(threshold))
const = float(silu(threshold))

def silut(x: tf.Tensor) -> tf.Tensor:
"""The customized sigmoid-weighted linear unit with tanh.

Parameters
----------
x : tf.Tensor
float Tensor to perform activation

Returns
-------
tf.Tensor
`x` with the SiLUT activation applied
"""
return tf.where(
x < threshold,
x * tf.sigmoid(x),
tf.nn.tanh(slope * (x - threshold)) + const,
)

return silut


ACTIVATION_FN_DICT = {
"relu": tf.nn.relu,
"relu6": tf.nn.relu6,
@@ -153,6 +194,7 @@ def silu(x: tf.Tensor) -> tf.Tensor:
"gelu": gelu,
"gelu_tf": gelu_tf,
"silu": silu,
"silut": get_silut("silut"),
"linear": lambda x: x,
"none": lambda x: x,
}
@@ -182,6 +224,8 @@ def get_activation_func(
if activation_fn is None:
activation_fn = "none"
assert activation_fn is not None
if activation_fn.lower().startswith("silut"):
ACTIVATION_FN_DICT[activation_fn.lower()] = get_silut(activation_fn.lower())
if activation_fn.lower() not in ACTIVATION_FN_DICT:
raise RuntimeError(f"{activation_fn} is not a valid activation function")
return ACTIVATION_FN_DICT[activation_fn.lower()]
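On the TensorFlow side, thresholded variants are registered lazily: `get_activation_func` adds a `silut:<threshold>` entry to `ACTIVATION_FN_DICT` the first time it sees one. A minimal sketch (assuming the TF backend is installed; under deepmd's graph mode the result must still be evaluated in a session):

from deepmd.tf.common import get_activation_func
from deepmd.tf.env import tf

act = get_activation_func("silut:2.0")    # registers "silut:2.0" on first use
y = act(tf.constant([-1.0, 0.0, 5.0]))    # SiLU below 2.0, tanh tail above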