Skip to content

Commit

Permalink
Fix batch-norm (bn) bugs in the last layer: apply batch normalization after activation and skip it (along with activation and dropout) when the layer is the final one (`is_last`).
Browse files Browse the repository at this point in the history
  • Loading branch information
yifanfeng97 committed Aug 29, 2023
1 parent 6278f0d commit 27f944f
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 37 deletions.
6 changes: 4 additions & 2 deletions dhg/nn/convs/graphs/gat_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ def forward(self, X: torch.Tensor, g: Graph) -> torch.Tensor:
g (``dhg.Graph``): The graph structure that contains :math:`N_v` vertices.
"""
X = self.theta(X)
if self.bn is not None:
X = self.bn(X)
x_for_src = self.atten_src(X)
x_for_dst = self.atten_dst(X)
e_atten_score = x_for_src[g.e_src] + x_for_dst[g.e_dst]
Expand All @@ -56,6 +54,10 @@ def forward(self, X: torch.Tensor, g: Graph) -> torch.Tensor:
e_atten_score = torch.clamp(e_atten_score, min=0.001, max=5)
# ================================================================================
X = g.v2v(X, aggr="softmax_then_sum", e_weight=e_atten_score)

if not self.is_last:
X = self.act(X)
if self.bn is not None:
X = self.bn(X)
X = self.drop(X)
return X
7 changes: 4 additions & 3 deletions dhg/nn/convs/graphs/gcn_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ def forward(self, X: torch.Tensor, g: Graph) -> torch.Tensor:
g (``dhg.Graph``): The graph structure that contains :math:`N` vertices.
"""
X = self.theta(X)
if self.bn is not None:
X = self.bn(X)
X = g.smoothing_with_GCN(X)
if not self.is_last:
X = self.drop(self.act(X))
X = self.act(X)
if self.bn is not None:
X = self.bn(X)
X = self.drop(X)
return X
7 changes: 4 additions & 3 deletions dhg/nn/convs/graphs/graphsage_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,9 @@ def forward(self, X: torch.Tensor, g: Graph) -> torch.Tensor:
else:
raise NotImplementedError()
X = self.theta(X)
if self.bn is not None:
X = self.bn(X)
if not self.is_last:
X = self.drop(self.act(X))
X = self.act(X)
if self.bn is not None:
X = self.bn(X)
X = self.drop(X)
return X
11 changes: 6 additions & 5 deletions dhg/nn/convs/hypergraphs/dhcf_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ def forward(self, X: torch.Tensor, hg: Hypergraph) -> torch.Tensor:
hg (``dhg.Hypergraph``): The hypergraph structure that contains :math:`N` vertices.
"""
X_ = self.theta(X)
if self.bn is not None:
X_ = self.bn(X_)
X_ = hg.smoothing_with_HGNN(X_) + X
X = hg.smoothing_with_HGNN(X_) + X
if not self.is_last:
X_ = self.drop(self.act(X_))
return X_
X = self.act(X)
if self.bn is not None:
X = self.bn(X)
X = self.drop(X)
return X
7 changes: 4 additions & 3 deletions dhg/nn/convs/hypergraphs/hgnn_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,10 @@ def forward(self, X: torch.Tensor, hg: Hypergraph) -> torch.Tensor:
hg (``dhg.Hypergraph``): The hypergraph structure that contains :math:`N` vertices.
"""
X = self.theta(X)
if self.bn is not None:
X = self.bn(X)
X = hg.smoothing_with_HGNN(X)
if not self.is_last:
X = self.drop(self.act(X))
X = self.act(X)
if self.bn is not None:
X = self.bn(X)
X = self.drop(X)
return X
7 changes: 4 additions & 3 deletions dhg/nn/convs/hypergraphs/hgnnp_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,10 @@ def forward(self, X: torch.Tensor, hg: Hypergraph) -> torch.Tensor:
hg (``dhg.Hypergraph``): The hypergraph structure that contains :math:`|\mathcal{V}|` vertices.
"""
X = self.theta(X)
if self.bn is not None:
X = self.bn(X)
X = hg.v2v(X, aggr="mean")
if not self.is_last:
X = self.drop(self.act(X))
X = self.act(X)
if self.bn is not None:
X = self.bn(X)
X = self.drop(X)
return X
7 changes: 4 additions & 3 deletions dhg/nn/convs/hypergraphs/hnhn_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,13 @@ def forward(self, X: torch.Tensor, hg: Hypergraph) -> torch.Tensor:
"""
# v -> e
X = self.theta_v2e(X)
if self.bn is not None:
X = self.bn(X)
Y = self.act(hg.v2e(X, aggr="mean"))
# e -> v
Y = self.theta_e2v(Y)
X = hg.e2v(Y, aggr="mean")
if not self.is_last:
X = self.drop(self.act(X))
X = self.act(X)
if self.bn is not None:
X = self.bn(X)
X = self.drop(X)
return X
7 changes: 4 additions & 3 deletions dhg/nn/convs/hypergraphs/hypergcn_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ def forward(
``cached_g`` (``dhg.Graph``): The pre-transformed graph structure from the hypergraph structure that contains :math:`N` vertices. If not provided, the graph structure will be transformed for each forward time. Defaults to ``None``.
"""
X = self.theta(X)
if self.bn is not None:
X = self.bn(X)
if cached_g is None:
g = Graph.from_hypergraph_hypergcn(
hg, X, self.use_mediator, device=X.device
Expand All @@ -59,5 +57,8 @@ def forward(
else:
X = cached_g.smoothing_with_GCN(X)
if not self.is_last:
X = self.drop(self.act(X))
X = self.act(X)
if self.bn is not None:
X = self.bn(X)
X = self.drop(X)
return X
29 changes: 17 additions & 12 deletions dhg/nn/convs/hypergraphs/unignn_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ def forward(self, X: torch.Tensor, hg: Hypergraph) -> torch.Tensor:
hg (``dhg.Hypergraph``): The hypergraph structure that contains :math:`|\mathcal{V}|` vertices.
"""
X = self.theta(X)
if self.bn is not None:
X = self.bn(X)
Y = hg.v2e(X, aggr="mean")
# ===============================================
# compute the special degree of hyperedges
Expand All @@ -71,8 +69,12 @@ def forward(self, X: torch.Tensor, hg: Hypergraph) -> torch.Tensor:
# ===============================================
X = hg.e2v(Y, aggr="sum")
X = torch.sparse.mm(hg.D_v_neg_1_2, X)

if not self.is_last:
X = self.drop(self.act(X))
X = self.act(X)
if self.bn is not None:
X = self.bn(X)
X = self.drop(X)
return X


Expand Down Expand Up @@ -128,8 +130,6 @@ def forward(self, X: torch.Tensor, hg: Hypergraph) -> torch.Tensor:
hg (``dhg.Hypergraph``): The hypergraph structure that contains :math:`|\mathcal{V}|` vertices.
"""
X = self.theta(X)
if self.bn is not None:
X = self.bn(X)
Y = hg.v2e(X, aggr="mean")
# ===============================================
alpha_e = self.atten_e(Y)
Expand All @@ -140,8 +140,12 @@ def forward(self, X: torch.Tensor, hg: Hypergraph) -> torch.Tensor:
e_atten_score = torch.clamp(e_atten_score, min=0.001, max=5)
# ================================================================================
X = hg.e2v(Y, aggr="softmax_then_sum", e2v_weight=e_atten_score)

if not self.is_last:
X = self.act(X)
if self.bn is not None:
X = self.bn(X)
X = self.drop(X)
return X


Expand Down Expand Up @@ -196,12 +200,13 @@ def forward(self, X: torch.Tensor, hg: Hypergraph) -> torch.Tensor:
hg (``dhg.Hypergraph``): The hypergraph structure that contains :math:`|\mathcal{V}|` vertices.
"""
X = self.theta(X)
if self.bn is not None:
X = self.bn(X)
Y = hg.v2e(X, aggr="mean")
X = hg.e2v(Y, aggr="sum") + X
if not self.is_last:
X = self.drop(self.act(X))
X = self.act(X)
if self.bn is not None:
X = self.bn(X)
X = self.drop(X)
return X


Expand Down Expand Up @@ -265,11 +270,11 @@ def forward(self, X: torch.Tensor, hg: Hypergraph) -> torch.Tensor:
hg (``dhg.Hypergraph``): The hypergraph structure that contains :math:`|\mathcal{V}|` vertices.
"""
X = self.theta(X)
if self.bn is not None:
X = self.bn(X)
Y = hg.v2e(X, aggr="mean")
X = (1 + self.eps) * hg.e2v(Y, aggr="sum") + X
if not self.is_last:
X = self.drop(self.act(X))
X = self.act(X)
if self.bn is not None:
X = self.bn(X)
X = self.drop(X)
return X

0 comments on commit 27f944f

Please sign in to comment.