Hi, thanks for sharing the work.
In the VSA encoder, the code is:
```python
x = self.pre_mlp(inp)
# encoder
attn = torch_scatter.scatter_softmax(self.score(x), inverse, dim=0)
dot = (attn[:, :, None] * x.view(-1, 1, self.dim)).view(-1, self.dim*self.k)
x_ = torch_scatter.scatter_sum(dot, inverse, dim=0)
```
I'm wondering why you write `dot = (attn[:, :, None] * x.view(-1, 1, self.dim)).view(-1, self.dim*self.k)`. According to the paper, A's shape is [n, k] and V's shape is [n, d], so A^T V should have shape [k, d]. Why does the code instead produce a result of shape [n, k, d]?
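To make the shape question concrete, here is a minimal NumPy sketch of the broadcasting involved (the sizes n, k, d are made up for illustration; this only reproduces the shapes, not the actual VSA computation):

```python
import numpy as np

# Hypothetical sizes: n points, k attention scores per point, d feature channels.
n, k, d = 6, 4, 3
attn = np.random.rand(n, k)   # per-point attention weights, shape [n, k]
x = np.random.rand(n, d)      # per-point features, shape [n, d]

# attn[:, :, None] has shape [n, k, 1]; x.reshape(n, 1, d) has shape [n, 1, d].
# Broadcasting their element-wise product yields shape [n, k, d],
# i.e. every point keeps its own k-by-d outer-product-style block,
# rather than being reduced to a single [k, d] matrix.
dot = attn[:, :, None] * x.reshape(n, 1, d)
print(dot.shape)              # (6, 4, 3)

# The trailing .view(-1, self.dim * self.k) in the original code just
# flattens the last two axes per point:
flat = dot.reshape(n, k * d)
print(flat.shape)             # (6, 12)
```

Presumably the reduction from [n, k, d] down to a per-group [k, d] only happens afterwards, in the `scatter_sum` over `inverse`, which sums the per-point blocks within each voxel group.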
I don't understand the meaning of this `dot` computation, which differs from the Set Transformer's ISAB.