diff --git a/sagemaker-train/src/sagemaker/train/dpo_trainer.py b/sagemaker-train/src/sagemaker/train/dpo_trainer.py index 7e8b4747c5..bd5d9a11bd 100644 --- a/sagemaker-train/src/sagemaker/train/dpo_trainer.py +++ b/sagemaker-train/src/sagemaker/train/dpo_trainer.py @@ -181,7 +181,8 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, - wait_timeout: Optional[int] = None): + wait_timeout: Optional[int] = None, + poll: int = 5): """Execute the DPO training job. Parameters: @@ -196,6 +197,8 @@ def train(self, wait_timeout (Optional[int]): Maximum time in seconds to wait for the training job to complete. Only used when wait=True. If None, uses the default timeout from the wait utility. + poll (int): + Polling interval in seconds for checking training job status. Defaults to 5. Returns: TrainingJob: The SageMaker training job object. @@ -283,6 +286,7 @@ def train(self, wait_kwargs = {} if wait_timeout is not None: wait_kwargs['timeout'] = wait_timeout + wait_kwargs['poll'] = poll _wait(training_job, **wait_kwargs) except TimeoutExceededError as e: logger.error("Error: %s", e) diff --git a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py index a9136e2742..f2d8460989 100644 --- a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py @@ -197,7 +197,7 @@ def _validate_reward_model_id(self, reward_model_id): @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLAIFTrainer.train") - def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, wait_timeout: Optional[int] = None): + def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, wait_timeout: Optional[int] = None, poll: int = 5): """Execute the RLAIF training job. Parameters: @@ -212,6 +212,8 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati wait_timeout (Optional[int]): Maximum time in seconds to wait for the training job to complete. Only used when wait=True. If None, uses the default timeout from the wait utility. + poll (int): + Polling interval in seconds for checking training job status. Defaults to 5. Returns: TrainingJob: The SageMaker training job object. @@ -301,6 +303,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati wait_kwargs = {} if wait_timeout is not None: wait_kwargs['timeout'] = wait_timeout + wait_kwargs['poll'] = poll _wait(training_job, **wait_kwargs) except TimeoutExceededError as e: logger.error("Error: %s", e) diff --git a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py index c496222bf4..49b35f124e 100644 --- a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py @@ -183,7 +183,7 @@ def _process_hyperparameters(self): @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLVRTrainer.train") def train(self, training_dataset: Optional[Union[str, DataSet]] = None, - validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, wait_timeout: Optional[int] = None): + validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, wait_timeout: Optional[int] = None, poll: int = 5): """Execute the RLVR training job. Parameters: @@ -198,6 +198,8 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, wait_timeout (Optional[int]): Maximum time in seconds to wait for the training job to complete. Only used when wait=True. If None, uses the default timeout from the wait utility. + poll (int): + Polling interval in seconds for checking training job status. Defaults to 5. Returns: TrainingJob: The SageMaker training job object. @@ -289,6 +291,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, wait_kwargs = {} if wait_timeout is not None: wait_kwargs['timeout'] = wait_timeout + wait_kwargs['poll'] = poll _wait(training_job, **wait_kwargs) except TimeoutExceededError as e: logger.error("Error: %s", e) diff --git a/sagemaker-train/src/sagemaker/train/sft_trainer.py b/sagemaker-train/src/sagemaker/train/sft_trainer.py index 136231bd6f..233f169d0f 100644 --- a/sagemaker-train/src/sagemaker/train/sft_trainer.py +++ b/sagemaker-train/src/sagemaker/train/sft_trainer.py @@ -180,7 +180,7 @@ def _process_hyperparameters(self): self.hyperparameters._specs.pop('validation_data_path', None) @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="SFTTrainer.train") - def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, wait_timeout: Optional[int] = None): + def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, wait_timeout: Optional[int] = None, poll: int = 5): """Execute the SFT training job. Parameters: @@ -195,6 +195,8 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati wait_timeout (Optional[int]): Maximum time in seconds to wait for the training job to complete. Only used when wait=True. If None, uses the default timeout from the wait utility. + poll (int): + Polling interval in seconds for checking training job status. Defaults to 5. Returns: TrainingJob: The SageMaker training job object. @@ -283,6 +285,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati wait_kwargs = {} if wait_timeout is not None: wait_kwargs['timeout'] = wait_timeout + wait_kwargs['poll'] = poll _wait(training_job, **wait_kwargs) except TimeoutExceededError as e: logger.error("Error: %s", e) diff --git a/sagemaker-train/tests/unit/train/test_dpo_trainer.py b/sagemaker-train/tests/unit/train/test_dpo_trainer.py index 2d5cf2246a..1b70e0bf89 100644 --- a/sagemaker-train/tests/unit/train/test_dpo_trainer.py +++ b/sagemaker-train/tests/unit/train/test_dpo_trainer.py @@ -420,7 +420,7 @@ def test_train_passes_wait_timeout(self, mock_training_job_create, mock_model_pa trainer = DPOTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=True, wait_timeout=600) - mock_wait.assert_called_once_with(mock_training_job, timeout=600) + mock_wait.assert_called_once_with(mock_training_job, timeout=600, poll=5) @patch('sagemaker.train.common_utils.trainer_wait.wait') @patch('sagemaker.train.dpo_trainer._resolve_model_and_name') @@ -463,7 +463,7 @@ def test_train_without_wait_timeout_uses_default(self, mock_training_job_create, trainer = DPOTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=True) - mock_wait.assert_called_once_with(mock_training_job) + mock_wait.assert_called_once_with(mock_training_job, poll=5) @patch('sagemaker.train.common_utils.trainer_wait.wait') @patch('sagemaker.train.dpo_trainer._resolve_model_and_name') diff --git a/sagemaker-train/tests/unit/train/test_rlaif_trainer.py b/sagemaker-train/tests/unit/train/test_rlaif_trainer.py index 24448ebbe6..e5666883e8 100644 --- a/sagemaker-train/tests/unit/train/test_rlaif_trainer.py +++ b/sagemaker-train/tests/unit/train/test_rlaif_trainer.py @@ -596,7 +596,7 @@ def test_train_passes_wait_timeout(self, mock_training_job_create, mock_model_pa trainer = RLAIFTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=True, wait_timeout=600) - mock_wait.assert_called_once_with(mock_training_job, timeout=600) + mock_wait.assert_called_once_with(mock_training_job, timeout=600, poll=5) @patch('sagemaker.train.common_utils.trainer_wait.wait') @patch('sagemaker.train.common_utils.finetune_utils._get_beta_session') @@ -639,7 +639,7 @@ def test_train_without_wait_timeout_uses_default(self, mock_training_job_create, trainer = RLAIFTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=True) - mock_wait.assert_called_once_with(mock_training_job) + mock_wait.assert_called_once_with(mock_training_job, poll=5) @patch('sagemaker.train.common_utils.trainer_wait.wait') @patch('sagemaker.train.common_utils.finetune_utils._get_beta_session') diff --git a/sagemaker-train/tests/unit/train/test_rlvr_trainer.py b/sagemaker-train/tests/unit/train/test_rlvr_trainer.py index e16c5c1c69..320b81555d 100644 --- a/sagemaker-train/tests/unit/train/test_rlvr_trainer.py +++ b/sagemaker-train/tests/unit/train/test_rlvr_trainer.py @@ -423,7 +423,7 @@ def test_train_passes_wait_timeout(self, mock_training_job_create, mock_model_pa trainer = RLVRTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=True, wait_timeout=600) - mock_wait.assert_called_once_with(mock_training_job, timeout=600) + mock_wait.assert_called_once_with(mock_training_job, timeout=600, poll=5) @patch('sagemaker.train.common_utils.trainer_wait.wait') @patch('sagemaker.train.common_utils.finetune_utils._get_beta_session') @@ -466,7 +466,7 @@ def test_train_without_wait_timeout_uses_default(self, mock_training_job_create, trainer = RLVRTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=True) - mock_wait.assert_called_once_with(mock_training_job) + mock_wait.assert_called_once_with(mock_training_job, poll=5) @patch('sagemaker.train.common_utils.trainer_wait.wait') @patch('sagemaker.train.common_utils.finetune_utils._get_beta_session') diff --git a/sagemaker-train/tests/unit/train/test_sft_trainer.py b/sagemaker-train/tests/unit/train/test_sft_trainer.py index a2473ebfd0..108990f839 100644 --- a/sagemaker-train/tests/unit/train/test_sft_trainer.py +++ b/sagemaker-train/tests/unit/train/test_sft_trainer.py @@ -434,7 +434,7 @@ def test_train_passes_wait_timeout(self, mock_training_job_create, mock_model_pa trainer = SFTTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=True, wait_timeout=600) - mock_wait.assert_called_once_with(mock_training_job, timeout=600) + mock_wait.assert_called_once_with(mock_training_job, timeout=600, poll=5) @patch('sagemaker.train.common_utils.trainer_wait.wait') @patch('sagemaker.train.common_utils.finetune_utils._get_beta_session') @@ -477,7 +477,7 @@ def test_train_without_wait_timeout_uses_default(self, mock_training_job_create, trainer = SFTTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=True) - mock_wait.assert_called_once_with(mock_training_job) + mock_wait.assert_called_once_with(mock_training_job, poll=5) @patch('sagemaker.train.common_utils.trainer_wait.wait') @patch('sagemaker.train.common_utils.finetune_utils._get_beta_session')