diff --git a/nnsmith/materialize/torch/numeric.py b/nnsmith/materialize/torch/numeric.py index cf78307..55fe996 100644 --- a/nnsmith/materialize/torch/numeric.py +++ b/nnsmith/materialize/torch/numeric.py @@ -12,6 +12,8 @@ def numeric_valid(outputs) -> bool: # generalized loss fn def smoothed_relu(x): + if x.dtype == torch.float16: + return torch.relu(x.float()).half() return torch.relu(x)