Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,15 @@
config_for_framework,
)
from sagemaker.core.workflow.utilities import override_pipeline_parameter_var
from sagemaker.core.config.config_schema import IMAGE_RETRIEVER, MODULES, SAGEMAKER, _simple_path
from sagemaker.core.config.config_schema import IMAGE_RETRIEVER, MODULES, PYTHON_SDK, SAGEMAKER, _simple_path
from sagemaker.core.config.config_manager import SageMakerConfig


def _to_pascal_case(name):
"""Convert snake_case to PascalCase."""
camel = to_camel_case(name)
return camel[0].upper() + camel[1:] if camel else camel

ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}"
HUGGING_FACE_FRAMEWORK = "huggingface"
PYTORCH_FRAMEWORK = "pytorch"
Expand Down Expand Up @@ -114,11 +120,25 @@ def retrieve_hugging_face_uri(
if name in CONFIGURABLE_ATTRIBUTES and not val:
default_value = ImageRetriever._config.resolve_value_from_config(
config_path=_simple_path(
SAGEMAKER, MODULES, IMAGE_RETRIEVER, to_camel_case(name)
SAGEMAKER, PYTHON_SDK, MODULES, IMAGE_RETRIEVER, _to_pascal_case(name)
)
)
if default_value is not None:
locals()[name] = default_value
args[name] = default_value

# Apply resolved defaults back to local variables
version = args.get("version", version)
py_version = args.get("py_version", py_version)
instance_type = args.get("instance_type", instance_type)
accelerator_type = args.get("accelerator_type", accelerator_type)
image_scope = args.get("image_scope", image_scope)
container_version = args.get("container_version", container_version)
distributed = args.get("distributed", distributed)
base_framework_version = args.get("base_framework_version", base_framework_version)
training_compiler_config = args.get("training_compiler_config", training_compiler_config)
sdk_version = args.get("sdk_version", sdk_version)
inference_tool = args.get("inference_tool", inference_tool)
serverless_inference_config = args.get("serverless_inference_config", serverless_inference_config)

if training_compiler_config:
final_image_scope = image_scope
Expand Down Expand Up @@ -503,11 +523,28 @@ def retrieve(
if name in CONFIGURABLE_ATTRIBUTES and not val:
default_value = ImageRetriever._config.resolve_value_from_config(
config_path=_simple_path(
SAGEMAKER, MODULES, IMAGE_RETRIEVER, to_camel_case(name)
SAGEMAKER, PYTHON_SDK, MODULES, IMAGE_RETRIEVER, _to_pascal_case(name)
)
)
if default_value is not None:
locals()[name] = default_value
args[name] = default_value

# Apply resolved defaults back to local variables
version = args.get("version", version)
py_version = args.get("py_version", py_version)
instance_type = args.get("instance_type", instance_type)
accelerator_type = args.get("accelerator_type", accelerator_type)
image_scope = args.get("image_scope", image_scope)
container_version = args.get("container_version", container_version)
distributed = args.get("distributed", distributed)
smp = args.get("smp", smp)
base_framework_version = args.get("base_framework_version", base_framework_version)
training_compiler_config = args.get("training_compiler_config", training_compiler_config)
model_id = args.get("model_id", model_id)
model_version = args.get("model_version", model_version)
sdk_version = args.get("sdk_version", sdk_version)
inference_tool = args.get("inference_tool", inference_tool)
serverless_inference_config = args.get("serverless_inference_config", serverless_inference_config)

for name, val in args.items():
if is_pipeline_variable(val):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_retrieve_base_python_image_uri():
assert image_uri == "236514542706.dkr.ecr.us-west-2.amazonaws.com/sagemaker-base-python-310:1.0"


@pytest.mark.skip(reason="Test is failing due to locals()[name] = default_value in Image Retriever")
# @pytest.mark.skip(reason="Test is failing due to locals()[name] = default_value in Image Retriever")
@patch.object(SageMakerConfig, "resolve_value_from_config")
def test_retrieve_image_uri_intelligent_default(mock_load_config):
def custom_return(config_path=None, **kwargs):
Expand All @@ -116,5 +116,5 @@ def custom_return(config_path=None, **kwargs):
)
assert (
image_uri
== "053634841547.dkr.ecr.us-west-1.amazonaws.com/sagemaker-distribution-prod:3.0.0-gpu"
== "053634841547.dkr.ecr.us-west-1.amazonaws.com/sagemaker-distribution-prod:3.2.0-gpu"
)
6 changes: 3 additions & 3 deletions sagemaker-core/tests/integ/remote_function/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def divide(x, y):


# TODO: add VPC settings, update SageMakerRole with KMS permissions
@pytest.mark.skip
# @pytest.mark.skip
def test_advanced_job_setting(
sagemaker_session, dummy_container_without_error, cpu_instance_type, s3_kms_key
):
Expand Down Expand Up @@ -573,7 +573,7 @@ def my_func():
assert client_error_message in str(error)


@pytest.mark.skip
# @pytest.mark.skip
def test_decorator_with_spark_job(sagemaker_session, cpu_instance_type):
@remote(
role=ROLE,
Expand All @@ -599,7 +599,7 @@ def test_spark_transform():
test_spark_transform()


@pytest.mark.skip
# @pytest.mark.skip
def test_decorator_auto_capture(sagemaker_session, auto_capture_test_container):
"""
This test runs a docker container. The Container invocation will execute a python script
Expand Down
Loading