Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions doc/concept.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,15 @@ Therefore, we need to let the model generator know how many arguments / symbolic

### Constraining viable ranks of different input tensors

Operators like `Concat` require input tensors to have the same ranks. Therefore, we need to somehow constrain the input ranks for `self.inp_ranks`, as by default ranks in `self.inp_ranks` are *independent*.
Operators like `Concat` require input tensors to have the same rank (and matching dimension sizes on every axis except the concatenation axis).

To do so, we set `self.same_inp_dims = True` in initializer:
To do so, we overload the `same_inp_dims` class variable to `True`:

```python
def __init__(...):
super().__init__()
class Concat(AbsOpBase):
...
self.same_inp_dims = True # But this is not True for Pool2D and many binary operators.
same_inp_dims = True
...
def __init__(...):
...
```
9 changes: 3 additions & 6 deletions nnsmith/abstract/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ def broadcast_to_cons(*shapes: List[Union[z3.ExprRef, int]]) -> List[z3.BoolRef]
class AbsOpBase(ABC):
# number of parameters; None means it's fixed that can be inferred through `signature`.
num_var_param = None
# Require the input dimension sizes to be equivalent.
same_inp_dims = False
# whether this op is broadcastable or not
bcastable = False
# input dtypes: enumerates all possible input dtype combinations. Size of the list is the number of combinations.
Expand All @@ -281,8 +283,6 @@ def __init__(self):
# NOTE: the concrete values of out_ranks are not useful. Just make sure the length is correct.
# NOTE: the output shape of input dimensions should be concretized during the execution.
self.out_ranks = []
# Require the input dimension sizes to be equivalent.
self.same_inp_dims = False
# NOTE: the input of operator constructors are all Union[int, z3.ExprRef].
self.extra_attrs = {}

Expand Down Expand Up @@ -374,7 +374,6 @@ def concretize_op(op: AbsOpBase, model: Optional[z3.ModelRef]) -> AbsOpBase:
concrete_op = type(op)(*values)
concrete_op.inp_ranks = op.inp_ranks
concrete_op.out_ranks = op.out_ranks
concrete_op.same_inp_dims = op.same_inp_dims
concrete_op.extra_attrs = op.extra_attrs

return concrete_op
Expand Down Expand Up @@ -439,7 +438,6 @@ class BcastBinaryOp(BinaryOpBase):
def __init__(self):
super().__init__()
self.inp_ranks = [int_all(), int_all()]
self.same_inp_dims = False
self.bcastable = True

def type_transfer(self, input_shapes: List[AbsTensor]) -> List[AbsTensor]:
Expand Down Expand Up @@ -501,7 +499,6 @@ class Where(TernaryOpBase):
def __init__(self):
super().__init__()
self.inp_ranks = [int_all(), int_all(), int_all()]
self.same_inp_dims = False
self.same_inp_dtypes = True
self.bcastable = True

Expand Down Expand Up @@ -1859,6 +1856,7 @@ class Concat(AbsOpBase):
MAX_ARITY = 5
MAX_RANK = 5
out_dtypes = [(i,) for i in DTYPE_ALL]
same_inp_dims = True

def __str__(self) -> str:
return "Concat " + str(self.extra_attrs).replace(":", "=")
Expand All @@ -1869,7 +1867,6 @@ def __init__(self, arity):
self.arity = arity
self.inp_ranks = [(int_from(1))] * arity
self.out_ranks = [(int_from(1))]
self.same_inp_dims = True

def _init_concat_axis(self, input_shapes: List[AbsTensor]) -> int:
if "axis" not in self.extra_attrs:
Expand Down