diff --git a/topomodelx/nn/hypergraph/hypergat_layer.py b/topomodelx/nn/hypergraph/hypergat_layer.py index 6f303aa5..1f2d769f 100644 --- a/topomodelx/nn/hypergraph/hypergat_layer.py +++ b/topomodelx/nn/hypergraph/hypergat_layer.py @@ -113,7 +113,7 @@ def attention( Attention weights: one scalar per message between a source and a target cell. """ if mechanism == "node-level": - x_source_per_message = x_source[self.target_index_i] + x_source_per_message = x_source[self.source_index_j] return torch.nn.functional.softmax( torch.matmul( torch.nn.functional.leaky_relu(x_source_per_message), @@ -122,11 +122,11 @@ def attention( dim=1, ) - x_source_per_message = x_source[self.source_index_j] + x_source_per_message = x_source[self.target_index_i] x_target_per_message = ( x_source[self.target_index_i] if x_target is None - else x_target[self.target_index_i] + else x_target[self.source_index_j] ) x_source_target_per_message = torch.nn.functional.leaky_relu( @@ -183,7 +183,7 @@ def forward(self, x_0, incidence_1): Output hyperedge features. """ intra_aggregation = incidence_1.t() @ (x_0 @ self.weight1) - + self.target_index_i, self.source_index_j = incidence_1.indices() attention_values = self.attention(intra_aggregation).squeeze() @@ -200,7 +200,7 @@ def forward(self, x_0, incidence_1): inter_aggregation = incidence_1 @ (messages_on_edges @ self.weight2) attention_values = self.attention( - inter_aggregation, intra_aggregation + inter_aggregation, intra_aggregation, "edge-level" ).squeeze() incidence_with_attention = torch.sparse_coo_tensor( indices=incidence_1.indices(),