Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions sagemaker-core/src/sagemaker/core/utils/install_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import re
import subprocess
import sys
import tempfile

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -111,6 +112,20 @@ def _login_awscli(region, account, domain, repo):
)


def _write_pip_config(index):
"""Write the authenticated package index to a temporary pip config file."""
config = tempfile.NamedTemporaryFile(
mode="w", prefix="sagemaker-pip-", suffix=".conf", delete=False
)
try:
config.write("[global]\n")
config.write(f"index-url = {index}\n")
finally:
config.close()
os.chmod(config.name, 0o600)
return config.name


def configure_pip(auth_method=CodeArtifactAuthMethod.AUTO):
"""Configure pip for CodeArtifact if ``CA_REPOSITORY_ARN`` is set.

Expand Down Expand Up @@ -183,10 +198,26 @@ def install_requirements(
python_executable = python_executable or sys.executable
pip_cmd = [python_executable, "-m", "pip", "install", "-r", requirements_file]
index = configure_pip(auth_method=auth_method)
env = None
pip_config = None
if index:
pip_cmd.extend(["-i", index])
pip_config = _write_pip_config(index)
env = os.environ.copy()
env["PIP_CONFIG_FILE"] = pip_config
logger.info("Running: %s", " ".join(pip_cmd))
subprocess.check_call(pip_cmd)
try:
if env is not None:
subprocess.check_call(pip_cmd, env=env)
else:
subprocess.check_call(pip_cmd)
finally:
if pip_config:
try:
os.remove(pip_config)
except FileNotFoundError:
pass
except OSError as e:
logger.warning("Failed to remove temporary pip config file %s: %s", pip_config, e)


def main():
Expand Down
28 changes: 27 additions & 1 deletion sagemaker-core/tests/unit/test_install_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import logging
import os
import subprocess
import sys
from unittest import mock
Expand Down Expand Up @@ -187,10 +189,34 @@ def test_without_codeartifact(self):
mock_call.assert_called_once_with(_pip_cmd())

def test_with_codeartifact_index(self):
captured = {}

def fake_check_call(cmd, env=None):
captured["cmd"] = cmd
captured["env"] = env
with open(env["PIP_CONFIG_FILE"], "r") as config:
captured["config"] = config.read()

with mock.patch(f"{_MODULE}.configure_pip", return_value=EXPECTED_INDEX):
with mock.patch("subprocess.check_call", side_effect=fake_check_call):
install_requirements("reqs.txt")

assert captured["cmd"] == _pip_cmd()
assert captured["env"]["PIP_CONFIG_FILE"]
assert EXPECTED_INDEX in captured["config"]
assert not os.path.exists(captured["env"]["PIP_CONFIG_FILE"])

def test_codeartifact_index_not_logged_or_passed_in_argv(self, caplog):
caplog.set_level(logging.INFO, logger=_MODULE)
with mock.patch(f"{_MODULE}.configure_pip", return_value=EXPECTED_INDEX):
with mock.patch("subprocess.check_call") as mock_call:
install_requirements("reqs.txt")
mock_call.assert_called_once_with(_pip_cmd("-i", EXPECTED_INDEX))

assert FAKE_TOKEN not in caplog.text
assert EXPECTED_INDEX not in caplog.text
argv = mock_call.call_args[0][0]
assert FAKE_TOKEN not in " ".join(argv)
assert EXPECTED_INDEX not in " ".join(argv)

def test_with_cli_fallback_no_index_flag(self):
with mock.patch(f"{_MODULE}.configure_pip", return_value=None):
Expand Down
Loading