Skip to content

Commit

Permalink
fix: correct credentials for both IAM and SSO
Browse files Browse the repository at this point in the history
  • Loading branch information
MoHermes committed Dec 16, 2024
1 parent 38d9b5c commit cfd665b
Showing 1 changed file with 60 additions and 45 deletions.
105 changes: 60 additions & 45 deletions mpqp/execution/connection/aws_connection.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import os
from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING, Any, Optional, Union

from termcolor import colored
from typeguard import typechecked

if TYPE_CHECKING:
from braket.devices.device import Device as BraketDevice

from configparser import ConfigParser
from getpass import getpass
from pathlib import Path

from mpqp.execution.connection.env_manager import get_env_variable, save_env_variable
from mpqp.execution.devices import AWSDevice
Expand Down Expand Up @@ -63,64 +65,22 @@ def setup_aws_braket_account() -> tuple[str, list[Any]]:
return "", []


def configure_account_iam() -> tuple[str, list[Any]]:
"""Configure IAM authentication for Amazon Braket.
This function guides the user through the Amazon Braket IAM configuration process.
"""
print("Configuring IAM authentication for Amazon Braket...")
os.system("aws configure")

print("IAM authentication configured successfully.")
save_env_variable("BRAKET_AUTH_METHOD", "IAM")
save_env_variable("BRAKET_CONFIGURED", "True")

return "IAM configuration successful.", []


def get_user_sso_credentials() -> Union[dict[str, str], None]:

print("Please enter your AWS SSO credentials (inputs are hidden):")

try:
access_key_id = input("Enter AWS access key ID: ").strip()
secret_access_key = getpass("Enter AWS secret access key: ").strip()
session_token = getpass(
"Enter AWS session token (if applicable, press Enter to skip): "
).strip()
region = input("Enter AWS region: ").strip()

return {
"access_key_id": access_key_id,
"secret_access_key": secret_access_key,
"session_token": session_token,
"region": region,
}

except Exception as e:
print(f"An error occurred while getting credentials: {e}")
return None


def update_aws_credentials_file(
profile_name: str,
access_key_id: str,
secret_access_key: str,
session_token: str,
session_token: Optional[str],
region: str,
):
"""Update .aws/credentials file with the latest SSO credentials."""

from configparser import ConfigParser
from pathlib import Path

credentials_file = Path.home() / ".aws" / "credentials"

credentials_dir = credentials_file.parent
if not credentials_dir.exists():
credentials_dir.mkdir(parents=True, exist_ok=True)

config = ConfigParser()

if credentials_file.exists():
config.read(credentials_file)

Expand All @@ -130,14 +90,68 @@ def update_aws_credentials_file(
config.add_section(profile_name)
config[profile_name]["aws_access_key_id"] = access_key_id
config[profile_name]["aws_secret_access_key"] = secret_access_key

if session_token:
config[profile_name]["aws_session_token"] = session_token

config[profile_name]["region"] = region

with open(credentials_file, "w") as f:
config.write(f)


def configure_account_iam() -> tuple[str, list[Any]]:
"""Configure IAM authentication for Amazon Braket."""

print("Configuring IAM authentication for Amazon Braket...")
os.system("aws configure")

print("IAM authentication configured successfully.")
save_env_variable("BRAKET_AUTH_METHOD", "IAM")
save_env_variable("BRAKET_CONFIGURED", "True")

credentials_file = Path.home() / ".aws" / "credentials"

config = ConfigParser()
config.read(credentials_file)

access_key_id = config.get("default", "aws_access_key_id", fallback="")
secret_access_key = config.get("default", "aws_secret_access_key", fallback="")
region = config.get("default", "region", fallback="us-east-1")

update_aws_credentials_file(
profile_name="default",
access_key_id=access_key_id,
secret_access_key=secret_access_key,
session_token=None,
region=region,
)

return "IAM configuration successful.", []


def get_user_sso_credentials() -> Union[dict[str, str], None]:

print("Please enter your AWS SSO credentials (inputs are hidden):")

try:
access_key_id = input("Enter AWS access key ID: ").strip()
secret_access_key = getpass("Enter AWS secret access key: ").strip()
session_token = getpass("Enter AWS session token: ").strip()
region = input("Enter AWS region: ").strip()

return {
"access_key_id": access_key_id,
"secret_access_key": secret_access_key,
"session_token": session_token,
"region": region,
}

except Exception as e:
print(f"An error occurred while getting credentials: {e}")
return None


def configure_account_sso() -> tuple[str, list[Any]]:
"""Configure SSO authentication for Amazon Braket.
This function guides the user through the Amazon Braket SSO configuration process.
Expand All @@ -154,7 +168,7 @@ def configure_account_sso() -> tuple[str, list[Any]]:
profile_name="default",
access_key_id=sso_credentials["access_key_id"],
secret_access_key=sso_credentials["secret_access_key"],
session_token=sso_credentials.get("session_token", ""),
session_token=sso_credentials["session_token"],
region=sso_credentials["region"],
)

Expand Down Expand Up @@ -199,6 +213,7 @@ def get_aws_braket_account_info() -> str:
access_key_id = credentials.access_key
secret_access_key = credentials.secret_key
obfuscate_key = secret_access_key[:5] + "*" * (len(secret_access_key) - 5)

region_name = session.boto_session.region_name
except Exception as e:
raise AWSBraketRemoteExecutionError(
Expand Down

0 comments on commit cfd665b

Please sign in to comment.