From fc8c98adeb3d77f70baddd4a796d24b69e8685da Mon Sep 17 00:00:00 2001 From: Emmanuel Badmus <41684809+emmanuelbadmus@users.noreply.github.com> Date: Wed, 1 Apr 2026 16:39:04 -0400 Subject: [PATCH] fix(normalizers): cast baseMVA to float32 to support MPS backend --- gridfm_graphkit/datasets/normalizers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridfm_graphkit/datasets/normalizers.py b/gridfm_graphkit/datasets/normalizers.py index 11601a66..2276a2f9 100644 --- a/gridfm_graphkit/datasets/normalizers.py +++ b/gridfm_graphkit/datasets/normalizers.py @@ -228,7 +228,7 @@ def transform(self, data: HeteroData): data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MIN] *= torch.pi / 180.0 data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MAX] *= torch.pi / 180.0 data.edge_attr_dict[("bus", "connects", "bus")][:, RATE_A] /= self.baseMVA - data.baseMVA = self.baseMVA + data.baseMVA = torch.tensor(self.baseMVA, dtype=torch.float32) data.is_normalized = True def inverse_transform(self, data: HeteroData):