Skip to content

Commit

Permalink
fix hypergraph D_v bugs ( h[v, e] -> w[e]*h[v, e] )
Browse files Browse the repository at this point in the history
  • Loading branch information
yifanfeng97 committed Dec 27, 2022
1 parent 7293527 commit 7786e2d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
6 changes: 5 additions & 1 deletion dhg/structure/hypergraphs/hypergraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,7 +847,11 @@ def D_v_of_group(self, group_name: str) -> torch.Tensor:
"""
assert group_name in self.group_names, f"The specified {group_name} is not in existing hyperedge groups."
if self.group_cache[group_name].get("D_v") is None:
_tmp = torch.sparse.sum(self.H_of_group(group_name), dim=1).to_dense().clone().view(-1)
H = self.H_of_group(group_name).clone()
w_e = self.W_e_of_group(group_name)._values().clone()
val = w_e[H._indices()[1]] * H._values()
H_ = torch.sparse_coo_tensor(H._indices(), val, size=H.shape, device=self.device).coalesce()
_tmp = torch.sparse.sum(H_, dim=1).to_dense().clone().view(-1)
_num_v = _tmp.size(0)
self.group_cache[group_name]["D_v"] = torch.sparse_coo_tensor(
torch.arange(0, _num_v).view(1, -1).repeat(2, 1),
Expand Down
8 changes: 4 additions & 4 deletions tests/structure/test_hypergraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def test_add_and_remove_group(g1):
def test_deg(g1, g2):
assert g1.deg_v == [2, 2, 2, 1, 1, 1]
assert g1.deg_e == [4, 2, 3]
assert g2.deg_v == [2, 3, 3, 4, 1]
assert g2.deg_v == [1.5, 2, 2, 3, 1]
assert g2.deg_e == [3, 3, 2, 3, 2]


Expand Down Expand Up @@ -463,7 +463,7 @@ def test_W_e_group(g2):
def test_D(g1, g2):
assert (g1.D_v.cpu()._values() == torch.tensor([2, 2, 2, 1, 1, 1])).all()
assert (g1.D_e.cpu()._values() == torch.tensor([4, 2, 3])).all()
assert (g2.D_v.cpu()._values() == torch.tensor([2, 3, 3, 4, 1])).all()
assert (g2.D_v.cpu()._values() == torch.tensor([1.5, 2, 2, 3, 1])).all()
assert (g2.D_e.cpu()._values() == torch.tensor([3, 3, 2, 3, 2])).all()


Expand All @@ -483,11 +483,11 @@ def test_D_neg(g1, g2):
# -1
assert (g1.D_v_neg_1.cpu()._values() == torch.tensor([2, 2, 2, 1, 1, 1]) ** (-1.0)).all()
assert (g1.D_e_neg_1.cpu()._values() == torch.tensor([4, 2, 3]) ** (-1.0)).all()
assert (g2.D_v_neg_1.cpu()._values() == torch.tensor([2, 3, 3, 4, 1]) ** (-1.0)).all()
assert (g2.D_v_neg_1.cpu()._values() == torch.tensor([1.5, 2, 2, 3, 1]) ** (-1.0)).all()
assert (g2.D_e_neg_1.cpu()._values() == torch.tensor([3, 3, 2, 3, 2]) ** (-1.0)).all()
# -1/2
assert (g1.D_v_neg_1_2.cpu()._values() == torch.tensor([2, 2, 2, 1, 1, 1]) ** (-0.5)).all()
assert (g2.D_v_neg_1_2.cpu()._values() == torch.tensor([2, 3, 3, 4, 1]) ** (-0.5)).all()
assert (g2.D_v_neg_1_2.cpu()._values() == torch.tensor([1.5, 2, 2, 3, 1]) ** (-0.5)).all()
# isolated vertex
g3 = Hypergraph(3, [0, 1])
assert (g3.D_v_neg_1.cpu()._values() == torch.tensor([1, 1, 0])).all()
Expand Down

0 comments on commit 7786e2d

Please sign in to comment.