-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathmodel.py
More file actions
107 lines (99 loc) · 3.71 KB
/
model.py
File metadata and controls
107 lines (99 loc) · 3.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch.nn as nn
import torch
import torch.nn.functional as F
import torch_sparse
from torch import FloatTensor
class H2GCN(nn.Module):
    """H2GCN-style node classifier built on sparse 1-hop / 2-hop propagation.

    The model embeds the node features with one linear map, then performs
    ``k`` propagation rounds. Each round multiplies the previous
    representation by two normalised sparse matrices (``a1`` and ``a2``,
    derived from the adjacency passed to the first ``forward`` call) and
    concatenates the two results. All intermediate representations are
    concatenated and fed to a final linear classifier followed by softmax.

    Args:
        feat_dim: Number of input features per node.
        hidden_dim: Width of the initial embedding.
        class_dim: Number of output classes.
        k: Number of propagation rounds.
        dropout: Dropout probability applied to the concatenated
            representation before classification.
        use_relu: If true, ReLU is applied after the embedding and after
            every propagation round; otherwise no non-linearity is used.
    """

    def __init__(
            self,
            feat_dim: int,
            hidden_dim: int,
            class_dim: int,
            k: int = 2,
            dropout: float = 0.5,
            use_relu: bool = True
    ):
        super().__init__()
        self.dropout = dropout
        self.k = k
        # A named function (instead of a lambda) keeps the module picklable.
        self.act = F.relu if use_relu else self._identity
        self.use_relu = use_relu
        self.w_embed = nn.Parameter(
            torch.zeros(size=(feat_dim, hidden_dim)),
            requires_grad=True
        )
        # Round i produces 2**i * hidden_dim columns, so the concatenation of
        # rounds 0..k has (2**(k + 1) - 1) * hidden_dim columns.
        self.w_classify = nn.Parameter(
            torch.zeros(size=((2 ** (self.k + 1) - 1) * hidden_dim, class_dim)),
            requires_grad=True
        )
        self.params = [self.w_embed, self.w_classify]
        # Propagation matrices are built lazily on the first forward pass.
        self.initialized = False
        self.a1 = None
        self.a2 = None
        self.reset_parameter()

    def reset_parameter(self):
        """Re-initialise both weight matrices with Xavier-uniform values."""
        nn.init.xavier_uniform_(self.w_embed)
        nn.init.xavier_uniform_(self.w_classify)

    @staticmethod
    def _identity(x):
        """Return ``x`` unchanged (activation used when ``use_relu`` is false)."""
        return x

    @staticmethod
    def _indicator(sp_tensor: torch.Tensor) -> torch.Tensor:
        """Return a float sparse tensor holding 1.0 wherever ``sp_tensor`` > 0.

        Entries that are zero or negative after coalescing are dropped from
        the result instead of being stored as explicit zeros; the dense
        equivalent is the same 0/1 matrix.
        """
        csp = sp_tensor.coalesce()
        positive = csp.values() > 0
        indices = csp.indices()[:, positive]
        values = torch.ones(indices.size(1), dtype=torch.float, device=indices.device)
        return torch.sparse_coo_tensor(indices, values, size=csp.size())

    @staticmethod
    def _spspmm(sp1: torch.Tensor, sp2: torch.Tensor) -> torch.Tensor:
        """Multiply two sparse COO matrices and return a float sparse COO matrix.

        Raises:
            AssertionError: If the inner dimensions do not match.
        """
        assert sp1.shape[1] == sp2.shape[0], 'Cannot multiply size %s with %s' % (sp1.shape, sp2.shape)
        sp1, sp2 = sp1.coalesce(), sp2.coalesce()
        index1, value1 = sp1.indices(), sp1.values()
        index2, value2 = sp2.indices(), sp2.values()
        m, n, k = sp1.shape[0], sp1.shape[1], sp2.shape[1]
        indices, values = torch_sparse.spspmm(index1, value1, index2, value2, m, n, k)
        return torch.sparse_coo_tensor(
            indices=indices,
            values=values,
            size=(m, k),
            dtype=torch.float
        )

    @classmethod
    def _adj_norm(cls, adj: torch.Tensor) -> torch.Tensor:
        """Return ``D^-1/2 · adj · D^-1/2`` where ``D`` holds the row sums of ``adj``.

        Rows whose sum is zero (including rows with no stored entries at all)
        get a scaling factor of 0 rather than infinity. ``adj`` is expected to
        be a square sparse COO matrix.
        """
        adj = adj.coalesce()
        indices = adj.indices()
        row, col = indices[0], indices[1]
        values = adj.values().to(torch.float)
        # Accumulate into a dense vector so every row gets a degree, even
        # rows that have no stored entry.
        degree = torch.zeros(adj.size(0), dtype=torch.float, device=values.device)
        degree.index_add_(0, row, values)
        d_inv_sqrt = torch.pow(degree, -0.5)
        d_inv_sqrt = torch.where(
            torch.isinf(d_inv_sqrt), torch.zeros_like(d_inv_sqrt), d_inv_sqrt
        )
        # Scaling entry (i, j) by d[i] * d[j] is the diagonal sandwich product
        # without materialising the diagonal matrices.
        scaled = d_inv_sqrt[row] * values * d_inv_sqrt[col]
        return torch.sparse_coo_tensor(indices, scaled, size=adj.size(), dtype=torch.float)

    def _prepare_prop(self, adj):
        """Build and cache the normalised propagation matrices ``a1`` and ``a2``.

        ``a1`` marks the positive entries of ``adj - I``; ``a2`` marks the
        positive entries of ``adj @ adj - adj - I``. Both are then
        symmetrically normalised.
        """
        n = adj.size(0)
        device = adj.device
        diag = torch.arange(n, device=device)
        sp_eye = torch.sparse_coo_tensor(
            torch.stack([diag, diag]),
            torch.ones(n, dtype=torch.float, device=device),
            size=(n, n)
        )
        # initialize A1, A2
        a1 = self._indicator(adj - sp_eye)
        a2 = self._indicator(self._spspmm(adj, adj) - adj - sp_eye)
        # norm A1 A2
        self.a1 = self._adj_norm(a1)
        self.a2 = self._adj_norm(a2)
        # Only flag success once both matrices exist, so a failure above does
        # not leave the model "initialized" with a1/a2 still None.
        self.initialized = True

    def forward(self, adj: torch.Tensor, x: FloatTensor) -> FloatTensor:
        """Return per-node class probabilities (softmax over dim 1).

        ``adj`` is only consulted on the first call; afterwards the cached
        ``a1`` / ``a2`` are reused and the ``adj`` argument is ignored.
        """
        if not self.initialized:
            self._prepare_prop(adj)
        # H2GCN propagation
        rs = [self.act(torch.mm(x, self.w_embed))]
        for _ in range(self.k):
            r_last = rs[-1]
            r1 = torch.spmm(self.a1, r_last)
            r2 = torch.spmm(self.a2, r_last)
            rs.append(self.act(torch.cat([r1, r2], dim=1)))
        r_final = torch.cat(rs, dim=1)
        r_final = F.dropout(r_final, self.dropout, training=self.training)
        return torch.softmax(torch.mm(r_final, self.w_classify), dim=1)