diff --git a/linear_operator/operators/block_linear_operator.py b/linear_operator/operators/block_linear_operator.py
index 4a8e9e51..6c996050 100644
--- a/linear_operator/operators/block_linear_operator.py
+++ b/linear_operator/operators/block_linear_operator.py
@@ -88,15 +88,25 @@ def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: I
         # Let's make sure that the slice dimensions perfectly correspond with the number of
         # outputs per input that we have
         # Otherwise - its too complicated. We'll go with the base case
-        if (row_start % num_blocks) or (col_start % num_blocks) or (row_end % num_blocks) or (col_end % num_blocks):
+        block_size = num_rows // num_blocks
+        if (row_start % block_size) or (col_start % block_size) or (row_end % block_size) or (col_end % block_size):
             return super()._getitem(row_index, col_index, *batch_indices)
 
-        # Otherwise - let's divide the slices by the number of outputs per input
-        row_index = slice(row_start // num_blocks, row_end // num_blocks, None)
-        col_index = slice(col_start // num_blocks, col_end // num_blocks, None)
+        # Compute block-level indices
+        block_row_idx = slice(row_start // block_size, row_end // block_size, None)
+        block_col_idx = slice(col_start // block_size, col_end // block_size, None)
 
-        # Now we can try the super call!
-        new_base_linear_op = self.base_linear_op._getitem(row_index, col_index, *batch_indices, _noop_index)
+        # If the row and column block ranges differ, this is a cross-block slice.
+        # The block-diagonal structure is lost, so fall back to the general case.
+        # (row_index/col_index have not been reassigned at this point, so pass
+        # the caller's original slices straight through.)
+        if block_row_idx != block_col_idx:
+            return super()._getitem(row_index, col_index, *batch_indices)
+
+        # Select blocks from the base operator's batch dimension.
+        # block_row_idx selects which blocks to keep; row/col are all (keep per-block matrix intact).
+        new_base_linear_op = self.base_linear_op._getitem(
+            slice(None), slice(None), *batch_indices, block_row_idx
+        )
 
         # Now construct a kernel with those indices
         return self.__class__(new_base_linear_op, block_dim=-3)
diff --git a/test/operators/test_block_diag_linear_operator.py b/test/operators/test_block_diag_linear_operator.py
index fb6b8f1a..8195c855 100644
--- a/test/operators/test_block_diag_linear_operator.py
+++ b/test/operators/test_block_diag_linear_operator.py
@@ -115,3 +115,38 @@ def test_metaclass_constructor(self):
 
+
+class TestBlockDiagCrossBlockSlicing(unittest.TestCase):
+    # NOTE: this class is inserted *before* the __main__ guard so that running
+    # the file as a script (python test_block_diag_linear_operator.py) still
+    # discovers it; unittest.main() exits before anything defined after it.
+    def test_cross_block_getitem(self):
+        T, n_blocks = 12, 200
+        blocks = torch.randn(n_blocks, T, T)
+        dense = DenseLinearOperator(blocks)
+        bd = BlockDiagLinearOperator(dense)
+
+        total = n_blocks * T  # 2400
+
+        # Cross-block slices (row blocks != col blocks) should work
+        n_train = 150
+        sliced = bd[n_train * T:, :n_train * T]
+        self.assertEqual(sliced.shape, (total - n_train * T, n_train * T))
+
+        # Same-range block-aligned slices (fixed by the block_size fix)
+        sliced2 = bd[:T, :T]
+        self.assertEqual(sliced2.shape, (T, T))
+
+        # Same-range slice spanning multiple blocks; 500 is not a multiple of
+        # T, so this exercises the non-aligned fallback path
+        sliced3 = bd[500:600, 500:600]
+        self.assertEqual(sliced3.shape, (100, 100))
+
+        # Non-aligned, non-square
+        sliced4 = bd[5:100, 50:200]
+        self.assertEqual(sliced4.shape, (95, 150))
+
+        # Single element
+        sliced5 = bd[0:1, 0:1]
+        self.assertEqual(sliced5.shape, (1, 1))
+
+
 if __name__ == "__main__":
     unittest.main()