Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion sagemaker-core/src/sagemaker/core/utils/install_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import re
import subprocess
import sys
from urllib.parse import urlsplit, urlunsplit

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -111,6 +112,27 @@ def _login_awscli(region, account, domain, repo):
)


def _redact_url_credentials(value):
"""Redact credentials embedded in a URL for safe logging."""
parts = urlsplit(value)
if not parts.scheme or not parts.netloc:
return value
if parts.username is None and parts.password is None:
return value

host = parts.hostname or ""
if parts.port is not None:
host = f"{host}:{parts.port}"

netloc = f"****@{host}"

return urlunsplit((parts.scheme, netloc, parts.path, parts.query, parts.fragment))


def _format_pip_cmd_for_log(pip_cmd):
return " ".join(_redact_url_credentials(arg) for arg in pip_cmd)


def configure_pip(auth_method=CodeArtifactAuthMethod.AUTO):
"""Configure pip for CodeArtifact if ``CA_REPOSITORY_ARN`` is set.

Expand Down Expand Up @@ -185,7 +207,7 @@ def install_requirements(
index = configure_pip(auth_method=auth_method)
if index:
pip_cmd.extend(["-i", index])
logger.info("Running: %s", " ".join(pip_cmd))
logger.info("Running: %s", _format_pip_cmd_for_log(pip_cmd))
subprocess.check_call(pip_cmd)


Expand Down
27 changes: 27 additions & 0 deletions sagemaker-core/tests/unit/test_install_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import logging
import subprocess
import sys
from unittest import mock
Expand All @@ -21,6 +22,7 @@
from sagemaker.core.utils.install_requirements import (
CA_REPOSITORY_ARN_ENV,
CodeArtifactAuthMethod,
_format_pip_cmd_for_log,
_parse_arn,
configure_pip,
install_requirements,
Expand Down Expand Up @@ -192,6 +194,31 @@ def test_with_codeartifact_index(self):
install_requirements("reqs.txt")
mock_call.assert_called_once_with(_pip_cmd("-i", EXPECTED_INDEX))

def test_codeartifact_index_token_is_redacted_from_logs(self, caplog):
    """pip is still called with EXPECTED_INDEX, but FAKE_TOKEN is absent from the log."""
    caplog.set_level(logging.INFO, logger=_MODULE)
    patched_configure = mock.patch(f"{_MODULE}.configure_pip", return_value=EXPECTED_INDEX)
    patched_check_call = mock.patch("subprocess.check_call")
    with patched_configure, patched_check_call as mock_call:
        install_requirements("reqs.txt")

    # The command handed to subprocess must carry the unmodified index URL.
    expected_cmd = _pip_cmd("-i", EXPECTED_INDEX)
    mock_call.assert_called_once_with(expected_cmd)

    # Only the logged rendering is masked.
    log_output = caplog.text
    assert FAKE_TOKEN not in log_output
    assert "https://****@" in log_output

def test_url_username_is_redacted_from_logged_command(self):
    """A URL carrying only a username component is masked in the formatted command."""
    index_url = "https://username-only-token@example.com/simple/"
    pip_args = [sys.executable, "-m", "pip", "install", "-i", index_url]

    logged_cmd = _format_pip_cmd_for_log(pip_args)

    assert "username-only-token" not in logged_cmd
    assert "https://****@example.com/simple/" in logged_cmd

def test_with_cli_fallback_no_index_flag(self):
with mock.patch(f"{_MODULE}.configure_pip", return_value=None):
with mock.patch("subprocess.check_call") as mock_call:
Expand Down
Loading