Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions nnsmith/abstract/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,18 @@ def __call__(self, op_type: Type["AbsOpBase"]) -> Type["AbsOpBase"]:


class mark_materialize:
def __init__(self, dialect: str):
def __init__(self, dialect: str, limit_domain=False):
self.limit_domain = limit_domain
self.dialect = dialect

def __call__(self, op_type: Type["AbsOpBase"]) -> Type["AbsOpBase"]:
op_list = FULL_OPERATOR_SETS.setdefault(self.dialect, [])

if op_type not in op_list:
op_type = mark_abstract(self.dialect)(op_type)
op_type.limit_domain = self.limit_domain
op_list.append(op_type)
op_list.sort(key=lambda x: x.__name__)
op_type = mark_abstract(self.dialect)(op_type)

return op_type

Expand Down Expand Up @@ -265,8 +267,8 @@ class AbsOpBase(ABC):
# this op can accept one of float32xfloat32, float64xfloat64, and int32xint32 as input dtypes.
in_dtypes: List[Tuple[DType, ...]] = None # Overwrite me!
out_dtypes: List[Tuple[DType, ...]] = None
# whether to disable the op during graph generation
_skip = False

limit_domain = False

dialect = None

Expand Down Expand Up @@ -686,7 +688,7 @@ class Div(BcastBinaryOp):
out_dtypes = [(i,) for i in DTYPE_FLOATS]


@mark_materialize("core")
@mark_materialize("core", limit_domain=True)
class Pow(BcastBinaryOp):
in_dtypes = [(i, i) for i in DTYPE_FLOATS]
out_dtypes = [(i,) for i in DTYPE_FLOATS]
Expand Down Expand Up @@ -737,13 +739,13 @@ class Cos(TrigonometricOp):
out_dtypes = [(i,) for i in DTYPE_FLOATS]


@mark_materialize("core")
@mark_materialize("core", limit_domain=True)
class Asin(TrigonometricOp):
in_dtypes = [(i,) for i in DTYPE_FLOATS]
out_dtypes = [(i,) for i in DTYPE_FLOATS]


@mark_materialize("core")
@mark_materialize("core", limit_domain=True)
class Acos(TrigonometricOp):
in_dtypes = [(i,) for i in DTYPE_FLOATS]
out_dtypes = [(i,) for i in DTYPE_FLOATS]
Expand Down Expand Up @@ -795,13 +797,13 @@ class Round(ElementWiseUnaryOp):
out_dtypes = [(i,) for i in DTYPE_FLOATS]


@mark_materialize("core")
@mark_materialize("core", limit_domain=True)
class Sqrt(ElementWiseUnaryOp):
in_dtypes = [(i,) for i in DTYPE_FLOATS]
out_dtypes = [(i,) for i in DTYPE_FLOATS]


@mark_materialize("core")
@mark_materialize("core", limit_domain=True)
class Log2(ElementWiseUnaryOp):
in_dtypes = [(i,) for i in DTYPE_FLOATS]
out_dtypes = [(i,) for i in DTYPE_FLOATS]
Expand Down
4 changes: 3 additions & 1 deletion nnsmith/cli/fuzz.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ def __init__(
model_cfg["type"], backend_target=cfg["backend"]["target"]
)
self.ModelType.add_seed_setter()
self.opset = auto_opset(self.ModelType, self.factory)
self.opset = auto_opset(
self.ModelType, self.factory, vulops=cfg["mgen"]["vulops"]
)

seed = cfg["fuzz"]["seed"] or random.getrandbits(32)
set_seed(seed)
Expand Down
2 changes: 1 addition & 1 deletion nnsmith/cli/model_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def main(cfg: DictConfig):
factory = None

gen = random_model_gen(
opset=auto_opset(ModelType, factory),
opset=auto_opset(ModelType, factory, vulops=mgen_cfg["vulops"]),
init_rank=mgen_cfg["init_rank"],
seed=seed,
max_nodes=mgen_cfg["max_nodes"],
Expand Down
1 change: 1 addition & 0 deletions nnsmith/config/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ mgen: # model gen.
init_rank: 4
max_nodes: 5
timeout_ms: 50000
vulops: False
save: "nnsmith_output"
seed: null

Expand Down
6 changes: 5 additions & 1 deletion nnsmith/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ def filter_nan(report: BugReport) -> bool: # True means filter;

# numpy.assert_allclose style.
# TODO(ganler): can we use more well-formed checking? say directly checking the results?
return "nan location mismatch" in report.log
return (
"nan location mismatch" in report.log
or "-9223372036854775808" in report.log # tf.cast(nan, int) is UB.
or "-2147483648" in report.log
)


@filter("inf")
Expand Down
8 changes: 6 additions & 2 deletions nnsmith/narrow_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,12 +252,16 @@ def auto_opconfig(
return opset


def auto_opset(model_cls: Type[Model], factory: Optional[BackendFactory] = None):
def auto_opset(
model_cls: Type[Model],
factory: Optional[BackendFactory] = None,
vulops: bool = False,
) -> List[Type[AbsOpBase]]:
# None means only test model exportation.
topset_config = auto_opconfig(model_cls, factory)
opset = []
for op in model_cls.operators():
if op.name() not in topset_config:
if op.name() not in topset_config or (vulops == False and op.limit_domain):
continue
op.in_dtypes = topset_config[op.name()].in_dtypes
op.out_dtypes = topset_config[op.name()].out_dtypes
Expand Down