diff --git a/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever.py b/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever.py index ca6295dd7f..c08857f01b 100644 --- a/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever.py +++ b/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever.py @@ -25,9 +25,15 @@ config_for_framework, ) from sagemaker.core.workflow.utilities import override_pipeline_parameter_var -from sagemaker.core.config.config_schema import IMAGE_RETRIEVER, MODULES, SAGEMAKER, _simple_path +from sagemaker.core.config.config_schema import IMAGE_RETRIEVER, MODULES, PYTHON_SDK, SAGEMAKER, _simple_path from sagemaker.core.config.config_manager import SageMakerConfig + +def _to_pascal_case(name): + """Convert snake_case to PascalCase.""" + camel = to_camel_case(name) + return camel[0].upper() + camel[1:] if camel else camel + ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}" HUGGING_FACE_FRAMEWORK = "huggingface" PYTORCH_FRAMEWORK = "pytorch" @@ -114,11 +120,25 @@ def retrieve_hugging_face_uri( if name in CONFIGURABLE_ATTRIBUTES and not val: default_value = ImageRetriever._config.resolve_value_from_config( config_path=_simple_path( - SAGEMAKER, MODULES, IMAGE_RETRIEVER, to_camel_case(name) + SAGEMAKER, PYTHON_SDK, MODULES, IMAGE_RETRIEVER, _to_pascal_case(name) ) ) if default_value is not None: - locals()[name] = default_value + args[name] = default_value + + # Apply resolved defaults back to local variables + version = args.get("version", version) + py_version = args.get("py_version", py_version) + instance_type = args.get("instance_type", instance_type) + accelerator_type = args.get("accelerator_type", accelerator_type) + image_scope = args.get("image_scope", image_scope) + container_version = args.get("container_version", container_version) + distributed = args.get("distributed", distributed) + base_framework_version = args.get("base_framework_version", base_framework_version) + training_compiler_config = args.get("training_compiler_config", training_compiler_config) + sdk_version = args.get("sdk_version", sdk_version) + inference_tool = args.get("inference_tool", inference_tool) + serverless_inference_config = args.get("serverless_inference_config", serverless_inference_config) if training_compiler_config: final_image_scope = image_scope @@ -503,11 +523,28 @@ def retrieve( if name in CONFIGURABLE_ATTRIBUTES and not val: default_value = ImageRetriever._config.resolve_value_from_config( config_path=_simple_path( - SAGEMAKER, MODULES, IMAGE_RETRIEVER, to_camel_case(name) + SAGEMAKER, PYTHON_SDK, MODULES, IMAGE_RETRIEVER, _to_pascal_case(name) ) ) if default_value is not None: - locals()[name] = default_value + args[name] = default_value + + # Apply resolved defaults back to local variables + version = args.get("version", version) + py_version = args.get("py_version", py_version) + instance_type = args.get("instance_type", instance_type) + accelerator_type = args.get("accelerator_type", accelerator_type) + image_scope = args.get("image_scope", image_scope) + container_version = args.get("container_version", container_version) + distributed = args.get("distributed", distributed) + smp = args.get("smp", smp) + base_framework_version = args.get("base_framework_version", base_framework_version) + training_compiler_config = args.get("training_compiler_config", training_compiler_config) + model_id = args.get("model_id", model_id) + model_version = args.get("model_version", model_version) + sdk_version = args.get("sdk_version", sdk_version) + inference_tool = args.get("inference_tool", inference_tool) + serverless_inference_config = args.get("serverless_inference_config", serverless_inference_config) for name, val in args.items(): if is_pipeline_variable(val): diff --git a/sagemaker-core/tests/integ/image_retriever/test_image_retriever.py b/sagemaker-core/tests/integ/image_retriever/test_image_retriever.py index 89127c1b48..9d07f5ceda 100644 --- a/sagemaker-core/tests/integ/image_retriever/test_image_retriever.py +++ b/sagemaker-core/tests/integ/image_retriever/test_image_retriever.py @@ -94,7 +94,7 @@ def test_retrieve_base_python_image_uri(): assert image_uri == "236514542706.dkr.ecr.us-west-2.amazonaws.com/sagemaker-base-python-310:1.0" -@pytest.mark.skip(reason="Test is failing due to locals()[name] = default_value in Image Retriever") +# @pytest.mark.skip(reason="Test is failing due to locals()[name] = default_value in Image Retriever") @patch.object(SageMakerConfig, "resolve_value_from_config") def test_retrieve_image_uri_intelligent_default(mock_load_config): def custom_return(config_path=None, **kwargs): @@ -116,5 +116,5 @@ def custom_return(config_path=None, **kwargs): ) assert ( image_uri - == "053634841547.dkr.ecr.us-west-1.amazonaws.com/sagemaker-distribution-prod:3.0.0-gpu" + == "053634841547.dkr.ecr.us-west-1.amazonaws.com/sagemaker-distribution-prod:3.2.0-gpu" ) diff --git a/sagemaker-core/tests/integ/remote_function/test_decorator.py b/sagemaker-core/tests/integ/remote_function/test_decorator.py index f4e45ae222..b9aae44729 100644 --- a/sagemaker-core/tests/integ/remote_function/test_decorator.py +++ b/sagemaker-core/tests/integ/remote_function/test_decorator.py @@ -122,7 +122,7 @@ def divide(x, y): # TODO: add VPC settings, update SageMakerRole with KMS permissions -@pytest.mark.skip +# @pytest.mark.skip def test_advanced_job_setting( sagemaker_session, dummy_container_without_error, cpu_instance_type, s3_kms_key ): @@ -573,7 +573,7 @@ def my_func(): assert client_error_message in str(error) -@pytest.mark.skip +# @pytest.mark.skip def test_decorator_with_spark_job(sagemaker_session, cpu_instance_type): @remote( role=ROLE, @@ -599,7 +599,7 @@ def test_spark_transform(): test_spark_transform() -@pytest.mark.skip +# @pytest.mark.skip def test_decorator_auto_capture(sagemaker_session, auto_capture_test_container): """ This test runs a docker container. The Container invocation will execute a python script