Skip to content

Commit

Permalink
fix hypergraph device bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
yifanfeng97 committed Sep 15, 2023
1 parent aa25822 commit 55df18c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions dhg/structure/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def _fetch_H_of_group(self, direction: str, group_name: str):
torch.ones(len(v_idx)),
torch.Size([self.num_v, num_e]),
device=self.device,
).coalesce()
).coalesce().to(self.device)
return H

def _fetch_R_of_group(self, direction: str, group_name: str):
Expand All @@ -587,7 +587,7 @@ def _fetch_R_of_group(self, direction: str, group_name: str):
w_list.extend(self._raw_groups[group_name][e][f"w_{direction}"])
R = torch.sparse_coo_tensor(
torch.vstack([v_idx, e_idx]), torch.tensor(w_list), torch.Size([self.num_v, num_e]), device=self.device,
).coalesce()
).coalesce().to(self.device)
return R

def _fetch_W_of_group(self, group_name: str):
Expand All @@ -598,7 +598,7 @@ def _fetch_W_of_group(self, group_name: str):
"""
assert group_name in self.group_names, f"The specified {group_name} is not in existing hyperedge groups."
w_list = [content["w_e"] for content in self._raw_groups[group_name].values()]
W = torch.tensor(w_list, device=self.device).view((-1, 1))
W = torch.tensor(w_list, device=self.device).view((-1, 1)).to(self.device)
return W

# some structure modification functions
Expand Down Expand Up @@ -798,7 +798,7 @@ def W_v(self) -> torch.Tensor:
r"""Return the vertex weight matrix of the hypergraph.
"""
if self.cache["W_v"] is None:
self.cache["W_v"] = torch.tensor(self.v_weight, dtype=torch.float, device=self.device).view(-1, 1)
self.cache["W_v"] = torch.tensor(self.v_weight, dtype=torch.float, device=self.device).view(-1, 1).to(self.device)
return self.cache["W_v"]

@property
Expand Down
4 changes: 2 additions & 2 deletions dhg/structure/hypergraphs/hypergraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,7 @@ def H(self) -> torch.Tensor:
r"""Return the hypergraph incidence matrix :math:`\mathbf{H}` with ``torch.sparse_coo_tensor`` format.
"""
if self.cache.get("H") is None:
self.cache["H"] = self.H_v2e
self.cache["H"] = self.H_v2e.to(self.device)
return self.cache["H"]

def H_of_group(self, group_name: str) -> torch.Tensor:
Expand All @@ -792,7 +792,7 @@ def H_T(self) -> torch.Tensor:
r"""Return the transpose of the hypergraph incidence matrix :math:`\mathbf{H}^\top` with ``torch.sparse_coo_tensor`` format.
"""
if self.cache.get("H_T") is None:
self.cache["H_T"] = self.H.t()
self.cache["H_T"] = self.H.t().to(self.device)
return self.cache["H_T"]

def H_T_of_group(self, group_name: str) -> torch.Tensor:
Expand Down

0 comments on commit 55df18c

Please sign in to comment.