From 55df18ce4838d1614693e01686d3f2896eb6d94f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yifan=20Feng=28=E4=B8=B0=E4=B8=80=E5=B8=86=29?= Date: Fri, 15 Sep 2023 17:20:07 +0800 Subject: [PATCH] fix hypergraph device bugs --- dhg/structure/base.py | 8 ++++---- dhg/structure/hypergraphs/hypergraph.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dhg/structure/base.py b/dhg/structure/base.py index cf91bea..cc89e8b 100644 --- a/dhg/structure/base.py +++ b/dhg/structure/base.py @@ -562,7 +562,7 @@ def _fetch_H_of_group(self, direction: str, group_name: str): torch.ones(len(v_idx)), torch.Size([self.num_v, num_e]), device=self.device, - ).coalesce() + ).coalesce().to(self.device) return H def _fetch_R_of_group(self, direction: str, group_name: str): @@ -587,7 +587,7 @@ def _fetch_R_of_group(self, direction: str, group_name: str): w_list.extend(self._raw_groups[group_name][e][f"w_{direction}"]) R = torch.sparse_coo_tensor( torch.vstack([v_idx, e_idx]), torch.tensor(w_list), torch.Size([self.num_v, num_e]), device=self.device, - ).coalesce() + ).coalesce().to(self.device) return R def _fetch_W_of_group(self, group_name: str): @@ -598,7 +598,7 @@ def _fetch_W_of_group(self, group_name: str): """ assert group_name in self.group_names, f"The specified {group_name} is not in existing hyperedge groups." w_list = [content["w_e"] for content in self._raw_groups[group_name].values()] - W = torch.tensor(w_list, device=self.device).view((-1, 1)) + W = torch.tensor(w_list, device=self.device).view((-1, 1)).to(self.device) return W # some structure modification functions @@ -798,7 +798,7 @@ def W_v(self) -> torch.Tensor: r"""Return the vertex weight matrix of the hypergraph. """ if self.cache["W_v"] is None: - self.cache["W_v"] = torch.tensor(self.v_weight, dtype=torch.float, device=self.device).view(-1, 1) + self.cache["W_v"] = torch.tensor(self.v_weight, dtype=torch.float, device=self.device).view(-1, 1).to(self.device) return self.cache["W_v"] @property diff --git a/dhg/structure/hypergraphs/hypergraph.py b/dhg/structure/hypergraphs/hypergraph.py index 851a625..c8171ad 100644 --- a/dhg/structure/hypergraphs/hypergraph.py +++ b/dhg/structure/hypergraphs/hypergraph.py @@ -773,7 +773,7 @@ def H(self) -> torch.Tensor: r"""Return the hypergraph incidence matrix :math:`\mathbf{H}` with ``torch.sparse_coo_tensor`` format. """ if self.cache.get("H") is None: - self.cache["H"] = self.H_v2e + self.cache["H"] = self.H_v2e.to(self.device) return self.cache["H"] def H_of_group(self, group_name: str) -> torch.Tensor: @@ -792,7 +792,7 @@ def H_T(self) -> torch.Tensor: r"""Return the transpose of the hypergraph incidence matrix :math:`\mathbf{H}^\top` with ``torch.sparse_coo_tensor`` format. """ if self.cache.get("H_T") is None: - self.cache["H_T"] = self.H.t() + self.cache["H_T"] = self.H.t().to(self.device) return self.cache["H_T"] def H_T_of_group(self, group_name: str) -> torch.Tensor: