1 change: 1 addition & 0 deletions README.md
@@ -11,6 +11,7 @@
| TensorRT | ✅ | |
| TFLite | | ⚠️ |
| XLA | | ⚠️ |
+| IREE | | ⚠️ |

✅: Supported; ⚠️: Beta support; Others are not supported yet -- Contributions are welcome!

2 changes: 1 addition & 1 deletion doc/cli.md
@@ -52,7 +52,7 @@ nnsmith.model_exec model.type=onnx backend.type=onnxruntime model.path=nnsmith_output/model.onnx
nnsmith.model_exec model.type=onnx \
backend.type=onnxruntime \
model.path=nnsmith_output/model.onnx \
-cmp.with='{type:tvm, optmax:true, device:cpu}'
+cmp.with='{type:tvm, optmax:true, target:cpu}'
```

## Data type testing
219 changes: 82 additions & 137 deletions nnsmith/abstract/op.py
@@ -168,7 +168,7 @@ def broadcast_shapes(
return max_shape


-def broadcast_cons(*shapes: List[Union[z3.ExprRef, int]]) -> List[z3.ExprRef]:
+def broadcast_cons(*shapes: List[Union[z3.ExprRef, int]]) -> List[z3.BoolRef]:
tgt_shape = broadcast_shapes(*shapes)
cons = []
max_dim = len(tgt_shape)
@@ -194,7 +194,7 @@ def broadcast_cons(*shapes: List[Union[z3.ExprRef, int]]) -> List[z3.ExprRef]:
return cons


-def broadcast_cons_binary(*shapes: List[Union[z3.ExprRef, int]]) -> List[z3.ExprRef]:
+def broadcast_cons_binary(*shapes: List[Union[z3.ExprRef, int]]) -> List[z3.BoolRef]:
SanityCheck.eq(len(shapes), 2)
tgt_shape = broadcast_shapes(*shapes)
cons = []
@@ -226,7 +226,7 @@ def broadcast_cons_binary(*shapes: List[Union[z3.ExprRef, int]]) -> List[z3.ExprRef]:
return cons


-def broadcast_to_cons(*shapes: List[Union[z3.ExprRef, int]]) -> List[z3.ExprRef]:
+def broadcast_to_cons(*shapes: List[Union[z3.ExprRef, int]]) -> List[z3.BoolRef]:
"""Unidirectional broadcast. Last input is the target shape.

Examples of valid unidirectional broadcast:
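(The docstring's own example list is collapsed out of this hunk.) Since these helpers now advertise `z3.BoolRef` returns, here is a minimal, self-contained sketch of the kind of constraints they build — the variable names and scenario are illustrative, not nnsmith code:

```python
import z3  # pip install z3-solver

# Bidirectional (NumPy-style): two dims match if they are equal or either is 1.
# Unidirectional (broadcast_to): only the *source* dim may stretch from 1.
a, b = z3.Int("a"), z3.Int("b")
bidir = z3.Or(a == b, a == 1, b == 1)
unidir = z3.Or(a == b, a == 1)  # broadcast a to target b

s = z3.Solver()
s.add(bidir, a == 3, b != 3)
print(s.check())     # sat
print(s.model()[b])  # 1 -- the only way a 3 can pair with a non-3 dim
```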
@@ -603,11 +603,6 @@ def deduct_inp_ranks_and_dtype(
)
)

-# TODO: support exactly what onnx spec says (e.g., int support in the rhs)
-# lhs_dtypes = (DType.int32, DType.int64, DType.float32, DType.float64)
-# rhs_dtypes = (DType.int32, DType.int64, DType.float32, DType.float64)
-# Pow.in_dtypes = itertools.product(lhs_dtypes, rhs_dtypes)


class Input(AbsOpBase):
in_dtypes = [()]
@@ -623,7 +618,7 @@ def type_transfer(self, input_shapes: List[AbsTensor]) -> List[AbsTensor]:
SanityCheck.eq(len(input_shapes), 0)
return [self.abs_tensor]

-def requires(self, input_shapes: List[AbsTensor]) -> List[z3.ExprRef]:
+def requires(self, input_shapes: List[AbsTensor]) -> List[Union[z3.BoolRef, bool]]:
SanityCheck.eq(len(input_shapes), 0)
return []

@@ -649,7 +644,7 @@ def type_transfer(self, input_shapes: List[AbsTensor]) -> List[AbsTensor]:
SanityCheck.eq(len(input_shapes), 0)
return [self.abs_tensor]

-def requires(self, input_shapes: List[AbsTensor]) -> List[z3.ExprRef]:
+def requires(self, input_shapes: List[AbsTensor]) -> List[Union[z3.BoolRef, bool]]:
SanityCheck.eq(len(input_shapes), 0)
return []

@@ -825,7 +820,7 @@ def __init__(self, dim: Union[int, z3.ExprRef]):
self.inp_ranks = [int_from(1)]
self.out_ranks = [int_from(1)]

-def requires(self, input_shapes: List[AbsTensor]) -> List[z3.ExprRef]:
+def requires(self, input_shapes: List[AbsTensor]) -> List[Union[z3.BoolRef, bool]]:
return [nnsmith_lt(self.dim, input_shapes[0].ndims), nnsmith_ge(self.dim, 0)]


@@ -1069,7 +1064,7 @@ def __init__(self, padding_list, pad_t):
len(self.padding_list) % 2 == 0
), f"padding_list must be even, got {self.padding_list}"

-def requires(self, input_shapes: List[AbsTensor]) -> List[z3.ExprRef]:
+def requires(self, input_shapes: List[AbsTensor]) -> List[Union[z3.BoolRef, bool]]:
pad = self.padding_list
isv = input_shapes[0].shape
cons = []
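For context on what the constraint loop (outside the visible hunk) must guard: `padding_list` holds left/right pad pairs for the trailing dims, and negative pads crop, so each padded extent has to stay non-negative. The framework behavior being modeled, illustrated with PyTorch:

```python
import torch
import torch.nn.functional as F

x = torch.ones(1, 3, 4)
y = F.pad(x, (1, -2))  # last dim: 4 + 1 - 2 = 3
print(y.shape)         # torch.Size([1, 3, 3])
# F.pad(x, (1, -6)) would request a negative extent and fail.
```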
@@ -1125,7 +1120,7 @@ def __init__(self, *padding_list):
self.inp_ranks = [int_range(len(padding_list) // 2 + 1, 4)]
self.out_ranks = [int_range(len(padding_list) // 2 + 1, 4)]

-def requires(self, input_shapes: List[AbsTensor]) -> List[z3.ExprRef]:
+def requires(self, input_shapes: List[AbsTensor]) -> List[Union[z3.BoolRef, bool]]:
cons = super().requires(input_shapes)
pad = self.padding_list
isv = input_shapes[0].shape
@@ -1244,7 +1239,7 @@ def deduct_inp_ranks_and_dtype(
) -> List[Tuple[int, DType]]:
return [(4, DType.float32)]

-def requires(self, input_shapes: List[AbsTensor]) -> List[z3.ExprRef]:
+def requires(self, input_shapes: List[AbsTensor]) -> List[Union[z3.BoolRef, bool]]:
return [
nnsmith_eq(self.nfeat, input_shapes[0].shape[1]),
nnsmith_ge(input_shapes[0].shape[0], 2),
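The hunk cuts the constraint list short, but the two visible conditions read like a BatchNorm-style op: `nfeat` must equal the channel dim and the batch must be at least 2, mirroring PyTorch's training-mode rule of more than one value per channel. A hypothetical repro (assuming the op maps to `torch.nn.BatchNorm2d`):

```python
import torch

bn = torch.nn.BatchNorm2d(num_features=3)  # training mode by default
print(bn(torch.randn(2, 3, 1, 1)).shape)   # ok: 2 values per channel
try:
    bn(torch.randn(1, 3, 1, 1))            # only 1 value per channel
except ValueError as err:
    print(err)  # "Expected more than 1 value per channel when training..."
```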
@@ -1736,12 +1731,10 @@ def type_transfer(self, input_shapes: List[AbsTensor]) -> List[AbsTensor]:
]

def requires(self, input_shapes: List[AbsTensor]):
-reduce_dim = self._init_reduce_dim(input_shapes[0].shape)
+self._init_reduce_dim(input_shapes[0].shape)
return []

def _get_irank(self, orank):
-# if orank == 0: # TVM bug ~ crash on scalar.min()
-# return random.randint(0, 1)
return orank + 1
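`_get_irank` encodes that a reduction consumes exactly one dimension, so the input rank is the output rank plus one. In framework terms (PyTorch used for illustration):

```python
import torch

x = torch.randn(2, 3, 4)
print(x.sum(dim=1).shape)  # torch.Size([2, 4]): rank 3 in, rank 2 out
```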

def deduct_inp_ranks_and_dtype(
@@ -1763,9 +1756,8 @@ def requires(self, input_shapes):

@mark_materialize("core")
class ReduceSum(ReduceBase):
-# pytorch exporter doesn't support int32
-in_dtypes = [(i,) for i in DTYPE_NON_BOOLS if i != DType.int32]
-out_dtypes = [(i,) for i in DTYPE_NON_BOOLS if i != DType.int32]
+in_dtypes = [(i,) for i in DTYPE_NON_BOOLS]
+out_dtypes = [(i,) for i in DTYPE_NON_BOOLS]


@mark_materialize("core")
@@ -1782,15 +1774,13 @@ class ReduceMax(ReduceBase):

@mark_materialize("core")
class ReduceMean(ReduceBase):
-in_dtypes = [(i,) for i in DTYPE_FLOATS]
-out_dtypes = [(i,) for i in DTYPE_FLOATS]
+in_dtypes = [(i,) for i in DTYPE_NON_BOOLS]
+out_dtypes = [(i,) for i in DTYPE_NON_BOOLS]


@mark_materialize("core")
class ArgMin(ReduceBase):
-# FIXME(JK): ints are somehow not supported in onnxruntime, which we use to gen inputs.
-# Make it include ints once we use other backends other than onnxruntime.
-in_dtypes = [(i,) for i in DTYPE_FLOATS]
+in_dtypes = [(i,) for i in DTYPE_NON_BOOLS]
out_dtypes = [(DType.int64,)]
_reduce_out_dtype = DType.int64

@@ -1804,9 +1794,7 @@ def deduct_inp_ranks_and_dtype(

@mark_materialize("core")
class ArgMax(ReduceBase):
-# FIXME(JK): ints are somehow not supported in onnxruntime, which we use to gen inputs.
-# Make it include ints once we use other backends other than onnxruntime.
-in_dtypes = [(i,) for i in DTYPE_FLOATS]
+in_dtypes = [(i,) for i in DTYPE_NON_BOOLS]
out_dtypes = [(DType.int64,)]
_reduce_out_dtype = DType.int64
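Whatever the input dtype, both Arg ops pin their output to int64 indices (`_reduce_out_dtype`), which is also what PyTorch returns; a quick illustration:

```python
import torch

x = torch.randn(3, 5)
idx = torch.argmax(x, dim=1)  # reduces dim 1 away
print(idx.shape, idx.dtype)   # torch.Size([3]) torch.int64
```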

@@ -1841,7 +1829,7 @@

@mark_materialize("core")
class Tril(TriBase):
-def requires(self, input_shapes: List[AbsTensor]) -> List[z3.ExprRef]:
+def requires(self, input_shapes: List[AbsTensor]) -> List[Union[z3.BoolRef, bool]]:
ConstraintCheck.true(input_shapes[0].ndims == 2)
nrow = input_shapes[0].shape[0]
ncol = input_shapes[0].shape[1]
Expand All @@ -1850,7 +1838,7 @@ def requires(self, input_shapes: List[AbsTensor]) -> List[z3.ExprRef]:

@mark_materialize("core")
class Triu(TriBase):
-def requires(self, input_shapes: List[AbsTensor]) -> List[z3.ExprRef]:
+def requires(self, input_shapes: List[AbsTensor]) -> List[Union[z3.BoolRef, bool]]:
ConstraintCheck.true(input_shapes[0].ndims == 2)
nrow = input_shapes[0].shape[0]
ncol = input_shapes[0].shape[1]
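The constraint bodies are clipped by the hunk, but they presumably bound a random `diagonal` offset (not shown here) by `nrow`/`ncol`. For reference, the semantics of the diagonal argument these ops model:

```python
import torch

x = torch.ones(3, 4)
print(torch.tril(x, diagonal=1))   # keeps entries with col <= row + 1
print(torch.triu(x, diagonal=-1))  # keeps entries with col >= row - 1
```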
@@ -1878,7 +1866,7 @@ def _init_concat_axis(self, input_shapes: List[AbsTensor]) -> int:
self.extra_attrs["axis"] = random.randint(0, input_shapes[0].ndims - 1)
return self.extra_attrs["axis"]

-def requires(self, input_shapes: List[AbsTensor]) -> List[z3.ExprRef]:
+def requires(self, input_shapes: List[AbsTensor]) -> List[Union[z3.BoolRef, bool]]:
ndims = input_shapes[0].ndims
SanityCheck.gt(ndims, self._init_concat_axis(input_shapes))
for s in input_shapes:
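The loop body is cut off by the hunk, but the rule it presumably encodes is the usual concatenation constraint: all inputs agree on every dim except the concat axis, whose sizes add up. Illustration:

```python
import torch

a, b = torch.zeros(2, 3), torch.zeros(2, 5)
print(torch.cat([a, b], dim=1).shape)  # torch.Size([2, 8])
```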
@@ -1961,7 +1949,7 @@ def __init__(self, dtype):
def __str__(self) -> str:
return "Cast " + str(self.extra_attrs).replace(":", "=")

-def requires(self, input_shapes: List[AbsTensor]) -> List[z3.ExprRef]:
+def requires(self, input_shapes: List[AbsTensor]) -> List[Union[z3.BoolRef, bool]]:
return []

def type_transfer(self, input_shapes: List[AbsTensor]) -> List[AbsTensor]:
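Cast constrains nothing (`requires` returns `[]`): any shape is castable, and only the dtype changes. Illustration:

```python
import torch

x = torch.randn(2, 3)
y = x.to(torch.int64)
print(y.shape, y.dtype)  # torch.Size([2, 3]) torch.int64
```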
@@ -2015,124 +2003,81 @@ def __init__(self):


@mark_materialize("core")
-class Gemm(TernaryOpBase):
-    # https://pytorch.org/docs/stable/generated/torch.addmm.html?highlight=addmm#torch.addmm
-    in_dtypes = [(i, i, i) for i in DTYPE_NON_BOOLS]
+class MatMul(BinaryOpBase):
+    in_dtypes = [(i, i) for i in DTYPE_NON_BOOLS]
    out_dtypes = [(i,) for i in DTYPE_NON_BOOLS]

    def __init__(self):
        super().__init__()
-        self.inp_ranks = [int_until(2), (2,), (2,)]
-        self.out_ranks = [(2,)]
+        # Consider at most 3D tensors (batched mm)
+        self.inp_ranks = [int_range(1, 3), int_range(1, 3)]
+        self.out_ranks = [int_until(3)]

-    def _set_or_get_extra_attrs(self, dtype=None):
-        if "alpha" not in self.extra_attrs:
-            assert (
-                dtype is not None
-            ), "dtype must be specified at the first time of this call"
-            alpha = random.uniform(-2, 2)
-            beta = random.uniform(-2, 2)
-            if dtype in DTYPE_INTS:
-                beta, alpha = int(beta), int(alpha)
-            self.extra_attrs["alpha"] = alpha
-            self.extra_attrs["beta"] = beta
-        return self.extra_attrs
-
-    def requires(self, input_shapes: List[AbsTensor]):
-        ConstraintCheck.true(input_shapes[0].ndims <= 2)
-        out_shape = self.checked_type_transfer(input_shapes)[0]
-        cons = broadcast_to_cons(input_shapes[0].shape, out_shape.shape)
-
-        # matmul constraint
-        mat1, mat2 = input_shapes[1], input_shapes[2]
-        cons.append(mat1.shape[1] == mat2.shape[0])
-        self._set_or_get_extra_attrs(input_shapes[0].dtype.torch())
-        if Z3_CONS_FLOPS:
-            cons.append(nnsmith_le(self.flops(input_shapes), FLOPS_LIM))
-        return cons
-
-    def type_transfer(self, input_shapes: List[AbsTensor]) -> List[AbsTensor]:
-        mat1, mat2 = input_shapes[1], input_shapes[2]
-        return [AbsTensor([mat1.shape[0], mat2.shape[1]], input_shapes[0].dtype)]
-
-    def flops(self, input_shapes):
-        mat1, mat2 = input_shapes[1], input_shapes[2]
-        return mat1.shape[0] * mat1.shape[1] * mat2.shape[1]
+    def type_transfer(self, input_shapes: List[AbsTensor]) -> List[AbsTensor]:
+        # https://pytorch.org/docs/stable/generated/torch.matmul.html#torch.matmul
+        # shape: [*batches(?), *rc (row and col)]
+        lhs = input_shapes[0].shape
+        rhs = input_shapes[1].shape
+
+        lrc = lhs[-2:]
+        rrc = rhs[-2:]
+        orc = [*lrc[:-1], *rrc[1:]]
+
+        lbatch = lhs[: -len(lrc)]
+        rbatch = rhs[: -len(rrc)]
+        batches = []
+        if len(lbatch) > len(rbatch):
+            batches = lbatch[: len(lbatch) - len(rbatch)]
+            for x, y in zip(lbatch[len(batches) :], rbatch):
+                batches.append(nnsmith_max(x, y))
+        elif len(lbatch) < len(rbatch):
+            batches = rbatch[: len(rbatch) - len(lbatch)]
+            for x, y in zip(lbatch, rbatch[len(batches) :]):
+                batches.append(nnsmith_max(x, y))
+
+        return [AbsTensor([*batches, *orc], input_shapes[0].dtype)]
+
+    def requires(self, input_shapes: List[AbsTensor]) -> List[Union[z3.BoolRef, bool]]:
+        cons = []
+
+        lhs = input_shapes[0].shape
+        rhs = input_shapes[1].shape
+
+        lrc = lhs[-2:]
+        rrc = rhs[-2:]
+
+        # CHECK: l.cols = r.rows
+        cons.append(lrc[-1] == rrc[0])
+
+        # CHECK: batch dim broadcastable
+        lbatch = lhs[: -len(lrc)]
+        rbatch = rhs[: -len(rrc)]
+        common_tail = min(len(lbatch), len(rbatch))
+        for x, y in zip(lbatch[-common_tail:], rbatch[-common_tail:]):
+            cons.append(nnsmith_or(x == y, nnsmith_or(x == 1, y == 1)))
+
+        return cons

    def deduct_inp_ranks_and_dtype(
        self, out_abs_tensor: List[AbsTensor]
    ) -> List[Tuple[int, DType]]:
-        return [
-            (random.randint(0, 2), out_abs_tensor[0].dtype),
-            (2, out_abs_tensor[0].dtype),
-            (2, out_abs_tensor[0].dtype),
-        ]
-
-
-ALL_OP_STR2TYPE = {c.__name__: c for c in FULL_OPERATOR_SETS["core"]}
-EXPANDED_OP_V0 = [Cast, Expand, TrigonometricOp, Comparator, Logical, InterpBase]
-# may also consider Concat, BcastBinaryOp1
-EXPANDED_OP = EXPANDED_OP_V0  # points to latest version
-
-
-def config_skip_op(skip_config):
-    SKIP_FOR_BKEND = {
-        "trt": [
-            # unsupported
-            "Xor",
-            "Equal:bool,bool",
-            "Gemm:int32,int32,int32",
-            # 'Acos:float64', 'Asin:float64', 'Atan:float64', 'Ceil:float64',
-            # 'Cos:float64', 'Sin:float64', 'Tan:float64', 'GELU:float64', 'LeakyReLU:float64',
-            # 'Abs:int64', 'Abs:int32',
-            # # buggy, see https://github.com/NVIDIA/TensorRT/issues/1781
-            # 'Less', 'Greater', 'Equal',
-            # buggy
-        ],
-        "tvm": [],
-        "tvm-cuda": [],
-        "ort": [],
-        "ort-cpu": [],
-        "xla": [],
-        "tch": [],
-        "dummy": [],
-    }
-    print("skip config:", skip_config)
-    skip_config = skip_config.split(",")
-    skip = []
-    for op in skip_config:
-        if op.startswith("backend:"):
-            skip.extend(SKIP_FOR_BKEND[op[len("backend:") :]])
-        else:
-            skip.append(op)
-    for op_name_pattern in skip:
-        skip_comb = None
-        if op_name_pattern.find(":") != -1:
-            op_name_pattern, skip_comb = op_name_pattern.split(":")
-            skip_comb = skip_comb.split(",")
-        op_name_pattern = op_name_pattern.lower()
-        for op_name in fnmatch.filter(
-            map(lambda x: x.__name__.lower(), FULL_OPERATOR_SETS["core"]),
-            op_name_pattern,
-        ):
-            op_id = [i.__name__.lower() for i in FULL_OPERATOR_SETS["core"]].index(
-                op_name
-            )
-            op = FULL_OPERATOR_SETS["core"][op_id]
-            msg = ["skip op:", op_name]
-            if skip_comb is not None:  # only skip some dtype combinations
-                skip_comb = tuple(map(DType.from_str, skip_comb))
-                msg += ["skip dtype combination:", skip_comb]
-                assert (
-                    skip_comb in op.in_dtypes
-                ), "combination {} not found in op({}).in_dtypes: {}".format(
-                    skip_comb, op_name, op.in_dtypes
-                )
-                op.in_dtypes.remove(skip_comb)
-            else:  # skip entire op
-                msg += ["skip entire"]
-                op._skip = True
-            print(*msg)
+        if out_abs_tensor[0].ndims == 1:
+            return [(1, out_abs_tensor[0].dtype), (1, out_abs_tensor[0].dtype)]
+        elif out_abs_tensor[0].ndims == 2:
+            # assume no batch.
+            return [(2, out_abs_tensor[0].dtype), (2, out_abs_tensor[0].dtype)]
+        else:
+            # assume no rank-1 tensor.
+            lranks = random.randint(2, out_abs_tensor[0].ndims)
+            rranks = random.randint(2, out_abs_tensor[0].ndims)
+            if lranks > rranks:
+                lranks = out_abs_tensor[0].ndims
+            else:
+                rranks = out_abs_tensor[0].ndims
+            return [
+                (lranks, out_abs_tensor[0].dtype),
+                (rranks, out_abs_tensor[0].dtype),
+            ]
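As a sanity check of the new shape rule (inner dims contract, leading batch dims broadcast), here is what `type_transfer` computes for a concrete rank-3 by rank-2 case, mirrored in PyTorch — the shapes are illustrative:

```python
import torch

lhs = torch.zeros(2, 3, 4)  # batch (2,), matrix (3, 4)
rhs = torch.zeros(4, 5)     # no batch,   matrix (4, 5)
out = torch.matmul(lhs, rhs)
print(out.shape)            # torch.Size([2, 3, 5])
# requires(): 4 == 4 (l.cols == r.rows); batch dims trivially broadcast.
```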


_PRAGMA_ONCE_CORE_OP = False