Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/awscli_login/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ class ConfigError(AWSCLILogin):
pass


class UserExit(ConfigError):
code = 0

def __init__(self) -> None:
super().__init__("No role selected. Good bye!")


class AlreadyLoggedIn(ConfigError):
code = 2

Expand Down Expand Up @@ -89,12 +96,11 @@ def __init__(self, role: str) -> None:
super().__init__(mesg % role)


class InvalidSelection(ConfigError):
class TooManyInvalidSelections(ConfigError):
code = 11

def __init__(self) -> None:
mesg = "Invalid selection!\a"
super().__init__(mesg)
super().__init__("Too many invalid selections!")


class TooManyHttpTrafficFlags(ConfigError):
Expand Down
38 changes: 26 additions & 12 deletions src/awscli_login/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ class Session: # type: ignore
ConfigError,
CredentialProcessMisconfigured,
CredentialProcessNotSet,
InvalidSelection,
SAML,
TooManyHttpTrafficFlags,
TooManyInvalidSelections,
UserExit,
)
from ._typing import Role

Expand Down Expand Up @@ -75,9 +76,7 @@ def get_selection(role_arns: List[Role], profile_role: Optional[str] = None,
interactive: bool = True, aliases: Dict[str, str] = {}
) -> Role:
""" Interactively prompts the user for a role selection. """
i = 0
n = len(role_arns)
select: Dict[int, int] = {}

# Return profile_role if valid and set
if profile_role is not None:
Expand All @@ -90,9 +89,22 @@ def get_selection(role_arns: List[Role], profile_role: Optional[str] = None,
logger.error(ERROR_INVALID_PROFILE_ROLE % profile_role)

if n > 1:
print("Please choose the role you would like to assume:")
return prompt_for_role_arn(role_arns)
elif n == 1:
return role_arns[0]
else:
raise SAML("No roles returned!")


def prompt_for_role_arn(role_arns: List[Role], aliases: Dict[str, str] = {}):
""" Prompts user to select a role from the given list of roles. """
accounts = sort_roles(role_arns)

accounts = sort_roles(role_arns)
for _ in range(3):
i = 0
select: Dict[int, int] = {}

print("Please choose the role you would like to assume:")
for acct, roles in accounts:
name = f"{aliases[acct]} ({acct})" if acct in aliases else acct
print(' ' * 4, "Account:", name)
Expand All @@ -102,15 +114,17 @@ def get_selection(role_arns: List[Role], profile_role: Optional[str] = None,
select[i] = index
i += 1

print("Selection:\a ", end='')
print("Select a role or enter 'q' to quit:\a ", end='')
try:
return role_arns[select[int(input())]]
user_input = input()
return role_arns[select[int(user_input)]]
except (ValueError, KeyError):
raise InvalidSelection
elif n == 1:
return role_arns[0]
else:
raise SAML("No roles returned!")
if user_input == 'q':
raise UserExit
print("Invalid value. Try again.")
continue

raise TooManyInvalidSelections


def file2bytes(filename: str) -> bytes:
Expand Down
19 changes: 16 additions & 3 deletions src/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
from awscli_login.exceptions import (
CredentialProcessMisconfigured,
CredentialProcessNotSet,
InvalidSelection,
SAML,
TooManyHttpTrafficFlags,
TooManyInvalidSelections,
UserExit,
)
from awscli_login.util import (
config_vcr,
Expand Down Expand Up @@ -128,6 +129,18 @@ def test_get_2of2_selections(self, *args):

self.assertEqual(get_selection(roles), roles[1])

@patch('builtins.input', return_value='q')
@patch('sys.stdout', new=StringIO())
def test_user_exit(self, *args):
""" User exits by typing 'q' """
roles = [
('idp1', 'arn:aws:iam::123577191723:role/KalturaAdmin'),
('idp2', 'arn:aws:iam::271867855970:role/BoxAdmin'),
]

with self.assertRaises(UserExit):
get_selection(roles)

@patch('builtins.input', return_value=3)
@patch('sys.stdout', new=StringIO())
def test_get_bad_numeric_selection(self, *args):
Expand All @@ -137,7 +150,7 @@ def test_get_bad_numeric_selection(self, *args):
('idp2', 'arn:aws:iam::271867855970:role/BoxAdmin'),
]

with self.assertRaises(InvalidSelection):
with self.assertRaises(TooManyInvalidSelections):
get_selection(roles)

@patch('builtins.input', return_value="foo")
Expand All @@ -149,7 +162,7 @@ def test_get_bad_type_selection(self, *args):
('idp2', 'arn:aws:iam::271867855970:role/BoxAdmin'),
]

with self.assertRaises(InvalidSelection):
with self.assertRaises(TooManyInvalidSelections):
get_selection(roles)

@patch('builtins.input', return_value=1)
Expand Down