2424from collections import OrderedDict
2525from copy import deepcopy
2626from functools import cmp_to_key
27- from typing import Any , Dict , Generator , List , Optional , Union
27+ from typing import Any , Dict , Generator , Iterable , List , Optional , Tuple , Union
2828
2929from sparseml .optim .modifier import BaseModifier , BaseObject , ModifierProp
3030from sparseml .sparsification .types import SparsificationTypes
@@ -343,47 +343,15 @@ def min_epochs(self) -> int:
343343 """
344344 :return: the minimum epochs required by any of the modifiers under the manager
345345 """
346- vals = []
347- vals .extend (
348- [
349- math .floor (mod .start_epoch )
350- for mod in self .iter_modifiers ()
351- if mod .start_epoch > - 1
352- ]
353- )
354- vals .extend (
355- [
356- math .floor (mod .end_epoch )
357- for mod in self .iter_modifiers ()
358- if mod .end_epoch > - 1
359- ]
360- )
361-
362- return min (vals ) if len (vals ) > 0 else - 1
346+ return _min_modifier_epoch (self .iter_modifiers ())
363347
364348 @ModifierProp (serializable = False )
365349 def max_epochs (self ) -> int :
366350 """
367351 :return: the maximum number of epochs required by any of the modifiers
368352 under the manager
369353 """
370- vals = []
371- vals .extend (
372- [
373- math .ceil (mod .start_epoch )
374- for mod in self .iter_modifiers ()
375- if mod .start_epoch > - 1
376- ]
377- )
378- vals .extend (
379- [
380- math .ceil (mod .end_epoch )
381- for mod in self .iter_modifiers ()
382- if mod .end_epoch > - 1
383- ]
384- )
385-
386- return max (vals ) if len (vals ) > 0 else - 1
354+ return _max_modifier_epoch (self .iter_modifiers ())
387355
388356 def save (self , file_path : str , include_metadata : bool = True ):
389357 """
@@ -561,6 +529,44 @@ def qat_active(self, epoch: float) -> bool:
561529 else False
562530 )
563531
532+ def get_start_end_epochs (self ) -> Dict [str , Tuple [float , float ]]:
533+ """
534+ Return an OrderedDict mapping each stage to its min and max epoch. If not a
535+ staged manager, map 'all' to the the min and max epochs
536+ """
537+ if isinstance (self .modifiers , List ):
538+ return OrderedDict ({"all" : (self .min_epochs , self .max_epochs )})
539+ else :
540+ stage_max_min = OrderedDict ()
541+ for stage , mod_list in self .modifiers .items ():
542+ epoch_min = _min_modifier_epoch (mod_list )
543+ epoch_max = _max_modifier_epoch (mod_list )
544+ stage_max_min [stage ] = (epoch_min , epoch_max )
545+
546+ # post-process to replace -1's with their real values
547+ epochs_list = list (stage_max_min .values ())
548+ for i , (stage , epochs ) in enumerate (stage_max_min .items ()):
549+ # replace start epochs that are -1 with the last epoch of the previous
550+ # stage, or 0 if it's the first stage
551+ if epochs [0 ] == - 1 :
552+ stage_max_min [stage ][0 ] = epochs_list [i - 1 ][1 ] if i > 0 else 0
553+ # replace end epochs that are -1 with the next stage's start epoch,
554+ # unless it's the last stage
555+ if epochs [1 ] == - 1 and i < len (epochs_list ) - 1 :
556+ stage_max_min [stage ][1 ] = epochs_list [i + 1 ][0 ]
557+
558+ return stage_max_min
559+
560+ def get_last_start_epoch (self ) -> float :
561+ """
562+ Return the start epoch of the last stage in the recipe. Useful for applying
563+ recipes at the correct epoch in a staged run
564+ """
565+ stage_max_min = self .get_start_end_epochs ()
566+ last_stage_epochs = stage_max_min [next (reversed (stage_max_min ))]
567+ last_start_epoch = last_stage_epochs [0 ]
568+ return last_start_epoch if last_start_epoch > - 1 else 0
569+
564570 def _info_log_metadata (self ):
565571 metadata_str = json .dumps (self ._metadata , indent = 1 )
566572 _LOGGER .debug (f"Created recipe manager with metadata: { metadata_str } " )
@@ -586,3 +592,28 @@ def _nested_dict_to_lines(
586592 # reached maximum nesting level.
587593 yaml_str_lines .append (indentation * nesting_depth + f"{ key } : { value } " )
588594 return yaml_str_lines
595+
596+
597+ def _min_modifier_epoch (modifiers : Iterable [BaseModifier ]) -> float :
598+ """
599+ :return: the minimum epochs required by any of the modifiers provided
600+ """
601+ vals = [math .floor (mod .start_epoch ) for mod in modifiers if mod .start_epoch > - 1 ]
602+
603+ return min (vals ) if len (vals ) > 0 else - 1
604+
605+
606+ def _max_modifier_epoch (modifiers : Iterable [BaseModifier ]) -> float :
607+ """
608+ :return: the maximum number of epochs required by any of the modifiers provided
609+ """
610+ # save modifiers as list so it can iterated over multiple times
611+ modifiers = [mod for mod in modifiers ]
612+
613+ vals = []
614+ vals .extend (
615+ [math .ceil (mod .start_epoch ) for mod in modifiers if mod .start_epoch > - 1 ]
616+ )
617+ vals .extend ([math .ceil (mod .end_epoch ) for mod in modifiers if mod .end_epoch > - 1 ])
618+
619+ return max (vals ) if len (vals ) > 0 else - 1
0 commit comments