From b9aaa8977ec3ea1a67ad4305ffceac8af0457884 Mon Sep 17 00:00:00 2001 From: robrui <128964181+robrui@users.noreply.github.com> Date: Fri, 8 May 2026 02:27:31 +0000 Subject: [PATCH] Fix BlockLinearOperator._getitem slicing with non-square block indices The modulo check used num_blocks instead of block_size, so slices aligned to individual blocks (e.g. bd[:12, :12] with block_size=12) were incorrectly handled. Also, cross-block slices (row blocks != col blocks) fell through to block-level indexing instead of the generic path, producing non-square sub-blocks. Fix uses block_size for alignment checks, detects cross-block slices and falls back to the generic super()._getitem(), and selects blocks from the batch dimension rather than passing block indices as matrix row/col indices. --- .../operators/block_linear_operator.py | 23 ++++++++++---- .../test_block_diag_linear_operator.py | 31 +++++++++++++++++++ 2 files changed, 48 insertions(+), 6 deletions(-) diff --git a/linear_operator/operators/block_linear_operator.py b/linear_operator/operators/block_linear_operator.py index 4a8e9e51..6c996050 100644 --- a/linear_operator/operators/block_linear_operator.py +++ b/linear_operator/operators/block_linear_operator.py @@ -88,15 +88,26 @@ def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: I # Let's make sure that the slice dimensions perfectly correspond with the number of # outputs per input that we have # Otherwise - its too complicated. We'll go with the base case - if (row_start % num_blocks) or (col_start % num_blocks) or (row_end % num_blocks) or (col_end % num_blocks): + block_size = num_rows // num_blocks + if (row_start % block_size) or (col_start % block_size) or (row_end % block_size) or (col_end % block_size): return super()._getitem(row_index, col_index, *batch_indices) - # Otherwise - let's divide the slices by the number of outputs per input - row_index = slice(row_start // num_blocks, row_end // num_blocks, None) - col_index = slice(col_start // num_blocks, col_end // num_blocks, None) + # Compute block-level indices + block_row_idx = slice(row_start // block_size, row_end // block_size, None) + block_col_idx = slice(col_start // block_size, col_end // block_size, None) - # Now we can try the super call! - new_base_linear_op = self.base_linear_op._getitem(row_index, col_index, *batch_indices, _noop_index) + # If the row and column block ranges differ, this is a cross-block slice. + # The block-diagonal structure is lost, so fall back to the general case. + if block_row_idx != block_col_idx: + row_index = slice(row_start, row_end, row_step) + col_index = slice(col_start, col_end, col_step) + return super()._getitem(row_index, col_index, *batch_indices) + + # Select blocks from the base operator's batch dimension. + # block_row_idx selects which blocks to keep; row/col are all (keep per-block matrix intact). + new_base_linear_op = self.base_linear_op._getitem( + slice(None), slice(None), *batch_indices, block_row_idx + ) # Now construct a kernel with those indices return self.__class__(new_base_linear_op, block_dim=-3) diff --git a/test/operators/test_block_diag_linear_operator.py b/test/operators/test_block_diag_linear_operator.py index fb6b8f1a..8195c855 100644 --- a/test/operators/test_block_diag_linear_operator.py +++ b/test/operators/test_block_diag_linear_operator.py @@ -115,3 +115,34 @@ def test_metaclass_constructor(self): if __name__ == "__main__": unittest.main() + + +class TestBlockDiagCrossBlockSlicing(unittest.TestCase): + def test_cross_block_getitem(self): + T, n_blocks = 12, 200 + blocks = torch.randn(n_blocks, T, T) + dense = DenseLinearOperator(blocks) + bd = BlockDiagLinearOperator(dense) + + total = n_blocks * T # 2400 + + # Cross-block slices (row blocks != col blocks) should work + n_train = 150 + sliced = bd[n_train * T:, :n_train * T] + self.assertEqual(sliced.shape, (total - n_train * T, n_train * T)) + + # Same-range block-aligned slices (fixed by the block_size fix) + sliced2 = bd[:T, :T] + self.assertEqual(sliced2.shape, (T, T)) + + # Same-range block-aligned, multiple blocks + sliced3 = bd[500:600, 500:600] + self.assertEqual(sliced3.shape, (100, 100)) + + # Non-aligned, non-square + sliced4 = bd[5:100, 50:200] + self.assertEqual(sliced4.shape, (95, 150)) + + # Single element + sliced5 = bd[0:1, 0:1] + self.assertEqual(sliced5.shape, (1, 1))