Skip to content

Commit b618909

Browse files
committed
Preserve block order before _JIT_ENTRY
1 parent a20783f commit b618909

1 file changed

Lines changed: 35 additions & 20 deletions

File tree

Tools/jit/_optimizers.py

Lines changed: 35 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -505,6 +505,11 @@ def _ensure_hot_fallthrough(self, block: _Block) -> None:
505505
self._insert_fallthrough_bridge(block, fallthrough)
506506

507507
def _layout_units(self) -> list[tuple[bool, list[_Block]]]:
508+
return self._layout_units_from(list(self._layout_blocks()))
509+
510+
def _layout_units_from(
511+
self, blocks: list[_Block]
512+
) -> list[tuple[bool, list[_Block]]]:
508513
continuation = self._continuation()
509514
cold_start = self._cold_start_block()
510515
units: list[tuple[bool, list[_Block]]] = []
@@ -519,7 +524,7 @@ def finish_unit() -> None:
519524
units.append((layout_hot, unit))
520525
unit = []
521526

522-
for block in self._layout_blocks():
527+
for block in blocks:
523528
if block is continuation or block is cold_start:
524529
finish_unit()
525530
continue
@@ -538,16 +543,23 @@ def _relink_blocks(self, blocks: list[_Block]) -> None:
538543
if blocks:
539544
blocks[-1].link = None
540545

541-
def _partition_hot_cold_blocks(self) -> None:
542-
# The entry point must remain in the hot layout, even when it can't
543-
# reach _JIT_CONTINUE. The stencil parser expects _JIT_ENTRY at code
544-
# offset 0.
545-
entry_label = f"{self.symbol_prefix}_JIT_ENTRY"
546-
for block in self._layout_blocks():
547-
if block.label == entry_label:
548-
block.layout_hot = True
546+
def _split_at_entry(self) -> tuple[list[_Block], list[_Block]]:
547+
entry = self._lookup_label(f"{self.symbol_prefix}_JIT_ENTRY")
548+
layout_blocks = list(self._layout_blocks())
549+
index = layout_blocks.index(entry)
550+
return layout_blocks[:index], layout_blocks[index:]
549551

550-
for block in list(self._layout_blocks()):
552+
def _partition_hot_cold_blocks(self) -> None:
553+
# The entry point must remain first in the partitioned JIT body, even
554+
# when it can't reach _JIT_CONTINUE. Blocks before _JIT_ENTRY are
555+
# assembler prefix material; keep their original order before
556+
# partitioning the JIT body.
557+
prefix_blocks, body_blocks = self._split_at_entry()
558+
for block in prefix_blocks:
559+
block.layout_hot = True
560+
body_blocks[0].layout_hot = True
561+
562+
for block in list(body_blocks):
551563
self._ensure_hot_fallthrough(block)
552564

553565
continuation = self._continuation()
@@ -557,27 +569,24 @@ def _partition_hot_cold_blocks(self) -> None:
557569
cold_start.layout_hot = False
558570
cold_start.fallthrough = True
559571

560-
units = self._layout_units()
572+
prefix_blocks, body_blocks = self._split_at_entry()
573+
for block in prefix_blocks:
574+
block.layout_hot = True
575+
units = self._layout_units_from(body_blocks)
561576
hot_blocks = [
562577
block for layout_hot, unit in units if layout_hot for block in unit
563578
]
564579
cold_blocks = [
565580
block for layout_hot, unit in units if not layout_hot for block in unit
566581
]
567582
blocks = [
583+
*prefix_blocks,
568584
*hot_blocks,
569585
continuation,
570586
cold_start,
571587
*cold_blocks,
572588
*self._metadata_blocks(),
573589
]
574-
if self._root in blocks and blocks[0] is not self._root:
575-
assert self._root.label is None
576-
assert not self._root.instructions
577-
assert self._root.target is None
578-
assert self._root.fallthrough
579-
blocks.remove(self._root)
580-
blocks.insert(0, self._root)
581590
self._relink_blocks(blocks)
582591
linked_blocks = list(self._blocks())
583592
assert linked_blocks[0] is self._root
@@ -762,25 +771,31 @@ def _find_live_blocks(self) -> set[_Block]:
762771
def _remove_unreachable(self) -> None:
763772
live = self._find_live_blocks()
764773
continuation = self._continuation()
774+
entry = self._lookup_label(f"{self.symbol_prefix}_JIT_ENTRY")
765775
cont_or_cold_blocks = {continuation}
766776
if self._cold_start is not None:
767777
cont_or_cold_blocks.add(self._cold_start)
768778
# Keep only the original assembler tail. Cold code after _JIT_CONTINUE
769779
# is ordinary code and can be removed when unreachable.
770780
prev: _Block | None = None
771-
block = self._root
781+
block: _Block | None = self._root
782+
seen_entry = False
772783
# We now walk the whole list, so keep explicit sentinel checks in place
773784
# of the old "stop at _JIT_CONTINUE" loop invariant.
774785
seen_continuation = False
775786
seen_cold_start = self._cold_start is None
776787
while block is not None:
788+
preserve_prefix = not seen_entry
789+
if block is entry:
790+
seen_entry = True
777791
if block is continuation:
778792
seen_continuation = True
779793
if block is self._cold_start:
780794
seen_cold_start = True
781795
next = block.link
782796
if (
783-
block not in live
797+
not preserve_prefix
798+
and block not in live
784799
and block not in cont_or_cold_blocks
785800
and not block.is_metadata
786801
and prev is not None

0 commit comments

Comments
 (0)