diff --git a/doc/concept.md b/doc/concept.md index f7ec20c..a525a68 100644 --- a/doc/concept.md +++ b/doc/concept.md @@ -85,13 +85,15 @@ Therefore, we need to let the model generator know how many arguments / symbolic ### Constraining viable ranks of different input tensors -Operators like `Concat` require input tensors to have the same ranks. Therefore, we need to somehow constraint the input ranks for `self.inp_ranks` as by default ranks in `self.inp_ranks` are *independent*. +Operators like `Concat` require input tensors to have the same shape (and thus ranks). -To do so, we set `self.same_inp_dims = True` in initializer: +To do so, we overload the `same_inp_dims` class variable to `True`: ```python -def __init__(...): - super().__init__() +class Concat(AbsOpBase) ... - self.same_inp_dims = True # But this is not True for Pool2D and many binary operators. + same_inp_dims = True + ... + def __init__(...): + ... ``` diff --git a/nnsmith/abstract/op.py b/nnsmith/abstract/op.py index a07591e..fe47f60 100644 --- a/nnsmith/abstract/op.py +++ b/nnsmith/abstract/op.py @@ -257,6 +257,8 @@ def broadcast_to_cons(*shapes: List[Union[z3.ExprRef, int]]) -> List[z3.BoolRef] class AbsOpBase(ABC): # number of parameters; None means it's fixed that can be inferred through `signature`. num_var_param = None + # Require the input dimension sizes to be equivalent. + same_inp_dims = False # whether this op is broadcastable or not bcastable = False # input dtypes: enumerates all possible input dtype combinations. Size of the list is the number of combinations. @@ -281,8 +283,6 @@ def __init__(self): # NOTE: the concrete values of out_ranks are not useful. Just make sure the length is correct. # NOTE: the output shape of input dimensions should be concretized during the execution. self.out_ranks = [] - # Require the input dimension sizes to be equivalent. - self.same_inp_dims = False # NOTE: the input of operator constructors are all Union[int, z3.ExprRef]. self.extra_attrs = {} @@ -374,7 +374,6 @@ def concretize_op(op: AbsOpBase, model: Optional[z3.ModelRef]) -> AbsOpBase: concrete_op = type(op)(*values) concrete_op.inp_ranks = op.inp_ranks concrete_op.out_ranks = op.out_ranks - concrete_op.same_inp_dims = op.same_inp_dims concrete_op.extra_attrs = op.extra_attrs return concrete_op @@ -439,7 +438,6 @@ class BcastBinaryOp(BinaryOpBase): def __init__(self): super().__init__() self.inp_ranks = [int_all(), int_all()] - self.same_inp_dims = False self.bcastable = True def type_transfer(self, input_shapes: List[AbsTensor]) -> List[AbsTensor]: @@ -501,7 +499,6 @@ class Where(TernaryOpBase): def __init__(self): super().__init__() self.inp_ranks = [int_all(), int_all(), int_all()] - self.same_inp_dims = False self.same_inp_dtypes = True self.bcastable = True @@ -1859,6 +1856,7 @@ class Concat(AbsOpBase): MAX_ARITY = 5 MAX_RANK = 5 out_dtypes = [(i,) for i in DTYPE_ALL] + same_inp_dims = True def __str__(self) -> str: return "Concat " + str(self.extra_attrs).replace(":", "=") @@ -1869,7 +1867,6 @@ def __init__(self, arity): self.arity = arity self.inp_ranks = [(int_from(1))] * arity self.out_ranks = [(int_from(1))] - self.same_inp_dims = True def _init_concat_axis(self, input_shapes: List[AbsTensor]) -> int: if "axis" not in self.extra_attrs: