@@ -505,6 +505,11 @@ def _ensure_hot_fallthrough(self, block: _Block) -> None:
505505 self ._insert_fallthrough_bridge (block , fallthrough )
506506
def _layout_units(self) -> list[tuple[bool, list[_Block]]]:
    """Return the layout units computed over every layout block.

    Snapshots the blocks yielded by ``_layout_blocks()`` into a list and
    hands that list to ``_layout_units_from``, whose result is returned
    unchanged.
    """
    all_blocks = [*self._layout_blocks()]
    return self._layout_units_from(all_blocks)
509+
510+ def _layout_units_from (
511+ self , blocks : list [_Block ]
512+ ) -> list [tuple [bool , list [_Block ]]]:
508513 continuation = self ._continuation ()
509514 cold_start = self ._cold_start_block ()
510515 units : list [tuple [bool , list [_Block ]]] = []
@@ -519,7 +524,7 @@ def finish_unit() -> None:
519524 units .append ((layout_hot , unit ))
520525 unit = []
521526
522- for block in self . _layout_blocks () :
527+ for block in blocks :
523528 if block is continuation or block is cold_start :
524529 finish_unit ()
525530 continue
@@ -538,16 +543,23 @@ def _relink_blocks(self, blocks: list[_Block]) -> None:
538543 if blocks :
539544 blocks [- 1 ].link = None
540545
541- def _partition_hot_cold_blocks (self ) -> None :
542- # The entry point must remain in the hot layout, even when it can't
543- # reach _JIT_CONTINUE. The stencil parser expects _JIT_ENTRY at code
544- # offset 0.
545- entry_label = f"{ self .symbol_prefix } _JIT_ENTRY"
546- for block in self ._layout_blocks ():
547- if block .label == entry_label :
548- block .layout_hot = True
def _split_at_entry(self) -> tuple[list[_Block], list[_Block]]:
    """Split the layout blocks at the ``_JIT_ENTRY`` block.

    Returns:
        A ``(before, from_entry)`` pair of new lists: ``before`` holds the
        layout blocks that precede the entry block, and ``from_entry``
        starts with the entry block itself and runs to the end.

    Raises:
        ValueError: propagated from ``list.index`` when the block returned
            for the entry label is not among the layout blocks.
    """
    # Resolve the label before snapshotting the blocks, matching the
    # original call order.
    entry_label = f"{self.symbol_prefix}_JIT_ENTRY"
    entry_block = self._lookup_label(entry_label)
    ordered = [*self._layout_blocks()]
    split = ordered.index(entry_block)
    before, from_entry = ordered[:split], ordered[split:]
    return before, from_entry
549551
550- for block in list (self ._layout_blocks ()):
552+ def _partition_hot_cold_blocks (self ) -> None :
553+ # The entry point must remain first in the partitioned JIT body, even
554+ # when it can't reach _JIT_CONTINUE. Blocks before _JIT_ENTRY are
555+ # assembler prefix material; keep their original order before
556+ # partitioning the JIT body.
557+ prefix_blocks , body_blocks = self ._split_at_entry ()
558+ for block in prefix_blocks :
559+ block .layout_hot = True
560+ body_blocks [0 ].layout_hot = True
561+
562+ for block in list (body_blocks ):
551563 self ._ensure_hot_fallthrough (block )
552564
553565 continuation = self ._continuation ()
@@ -557,27 +569,24 @@ def _partition_hot_cold_blocks(self) -> None:
557569 cold_start .layout_hot = False
558570 cold_start .fallthrough = True
559571
560- units = self ._layout_units ()
572+ prefix_blocks , body_blocks = self ._split_at_entry ()
573+ for block in prefix_blocks :
574+ block .layout_hot = True
575+ units = self ._layout_units_from (body_blocks )
561576 hot_blocks = [
562577 block for layout_hot , unit in units if layout_hot for block in unit
563578 ]
564579 cold_blocks = [
565580 block for layout_hot , unit in units if not layout_hot for block in unit
566581 ]
567582 blocks = [
583+ * prefix_blocks ,
568584 * hot_blocks ,
569585 continuation ,
570586 cold_start ,
571587 * cold_blocks ,
572588 * self ._metadata_blocks (),
573589 ]
574- if self ._root in blocks and blocks [0 ] is not self ._root :
575- assert self ._root .label is None
576- assert not self ._root .instructions
577- assert self ._root .target is None
578- assert self ._root .fallthrough
579- blocks .remove (self ._root )
580- blocks .insert (0 , self ._root )
581590 self ._relink_blocks (blocks )
582591 linked_blocks = list (self ._blocks ())
583592 assert linked_blocks [0 ] is self ._root
@@ -762,25 +771,31 @@ def _find_live_blocks(self) -> set[_Block]:
762771 def _remove_unreachable (self ) -> None :
763772 live = self ._find_live_blocks ()
764773 continuation = self ._continuation ()
774+ entry = self ._lookup_label (f"{ self .symbol_prefix } _JIT_ENTRY" )
765775 cont_or_cold_blocks = {continuation }
766776 if self ._cold_start is not None :
767777 cont_or_cold_blocks .add (self ._cold_start )
768778 # Keep only the original assembler tail. Cold code after _JIT_CONTINUE
769779 # is ordinary code and can be removed when unreachable.
770780 prev : _Block | None = None
771- block = self ._root
781+ block : _Block | None = self ._root
782+ seen_entry = False
772783 # We now walk the whole list, so keep explicit sentinel checks in place
773784 # of the old "stop at _JIT_CONTINUE" loop invariant.
774785 seen_continuation = False
775786 seen_cold_start = self ._cold_start is None
776787 while block is not None :
788+ preserve_prefix = not seen_entry
789+ if block is entry :
790+ seen_entry = True
777791 if block is continuation :
778792 seen_continuation = True
779793 if block is self ._cold_start :
780794 seen_cold_start = True
781795 next = block .link
782796 if (
783- block not in live
797+ not preserve_prefix
798+ and block not in live
784799 and block not in cont_or_cold_blocks
785800 and not block .is_metadata
786801 and prev is not None
0 commit comments