From b5993e09c3107c6f4ee119e775a141a0a15805c7 Mon Sep 17 00:00:00 2001 From: Tiantian Date: Tue, 10 Jun 2025 16:53:33 +1200 Subject: [PATCH 1/2] Fix bug: make second attention calculation in HyperGAT layer edge-level instead of node-level --- topomodelx/nn/hypergraph/hypergat_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/topomodelx/nn/hypergraph/hypergat_layer.py b/topomodelx/nn/hypergraph/hypergat_layer.py index 6f303aa5..a6390318 100644 --- a/topomodelx/nn/hypergraph/hypergat_layer.py +++ b/topomodelx/nn/hypergraph/hypergat_layer.py @@ -200,7 +200,7 @@ def forward(self, x_0, incidence_1): inter_aggregation = incidence_1 @ (messages_on_edges @ self.weight2) attention_values = self.attention( - inter_aggregation, intra_aggregation + inter_aggregation, intra_aggregation, "edge-level" ).squeeze() incidence_with_attention = torch.sparse_coo_tensor( indices=incidence_1.indices(), From aa9ecd20a742154006521ff19a619a26ff0ff10e Mon Sep 17 00:00:00 2001 From: Tiantian Date: Wed, 11 Jun 2025 13:42:11 +1200 Subject: [PATCH 2/2] Fix swapped source and target indices in attention function --- topomodelx/nn/hypergraph/hypergat_layer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/topomodelx/nn/hypergraph/hypergat_layer.py b/topomodelx/nn/hypergraph/hypergat_layer.py index a6390318..1f2d769f 100644 --- a/topomodelx/nn/hypergraph/hypergat_layer.py +++ b/topomodelx/nn/hypergraph/hypergat_layer.py @@ -113,7 +113,7 @@ def attention( Attention weights: one scalar per message between a source and a target cell. 
""" if mechanism == "node-level": - x_source_per_message = x_source[self.target_index_i] + x_source_per_message = x_source[self.source_index_j] return torch.nn.functional.softmax( torch.matmul( torch.nn.functional.leaky_relu(x_source_per_message), @@ -122,11 +122,11 @@ def attention( dim=1, ) - x_source_per_message = x_source[self.source_index_j] + x_source_per_message = x_source[self.target_index_i] x_target_per_message = ( x_source[self.target_index_i] if x_target is None - else x_target[self.target_index_i] + else x_target[self.source_index_j] ) x_source_target_per_message = torch.nn.functional.leaky_relu( @@ -183,7 +183,7 @@ def forward(self, x_0, incidence_1): Output hyperedge features. """ intra_aggregation = incidence_1.t() @ (x_0 @ self.weight1) - + self.target_index_i, self.source_index_j = incidence_1.indices() attention_values = self.attention(intra_aggregation).squeeze()