diff --git a/pkg/aws/sso.go b/pkg/aws/sso.go index 45f3737..6a1a003 100644 --- a/pkg/aws/sso.go +++ b/pkg/aws/sso.go @@ -4,6 +4,12 @@ import ( "context" "errors" "fmt" + "log" + "os" + "path/filepath" + "strings" + "time" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/sso" @@ -11,10 +17,6 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ssooidc" ssoOidcTypes "github.com/aws/aws-sdk-go-v2/service/ssooidc/types" "github.com/skratchdot/open-golang/open" - "log" - "os" - "strings" - "time" ) func GenerateConfig(ssoURL, region string, overwrite bool) error { @@ -38,10 +40,10 @@ func GenerateConfig(ssoURL, region string, overwrite bool) error { if err != nil { return err } - configPath := homeDir + "/.aws/config" + configPath := filepath.Join(homeDir, ".aws", "config") err = os.Remove(configPath) - if err != nil { - fmt.Println("File not found. Continue...", err) + if err != nil && !os.IsNotExist(err) { + return err } for _, acc := range accounts { roles := GetRolesByAccount(ctx, client, acc, token.AccessToken) diff --git a/pkg/commands/generate.go b/pkg/commands/generate.go index a6820da..e5c43ab 100644 --- a/pkg/commands/generate.go +++ b/pkg/commands/generate.go @@ -3,6 +3,8 @@ package commands import ( "fmt" "log" + "os" + "path/filepath" "github.com/Gympass/aws-vault-scg/pkg/aws" "github.com/urfave/cli/v2" @@ -32,31 +34,66 @@ var Generate = &cli.Command{ region = "us-east-1" } - var o string - fmt.Print("Overwrite current config(~/.aws/config)[y/N]? ") - _, err := fmt.Scanf("%s", &o) - if err != nil { - fmt.Println("Default option") - } - switch { - case o == "y" || o == "Y": - err := aws.GenerateConfig(ssoURL, region, true) - if err != nil { - log.Fatalf("Error to generate config file: %v", err) - } - case o == "n" || o == "N": - err := aws.GenerateConfig(ssoURL, region, false) - fmt.Println("You can use this profile values to update your config file(~/.aws/config)") + configDir := configDirName() + configFile := filepath.Join(configDir, "config") + + if !fileExists(configFile) { + // no config file, make sure path exists + createPath(configDir) + generateConfig(ssoURL, region) + } else { + // config already exists, check if the user wants to overwrite + var selectedOption string + fmt.Print("Overwrite current config(~/.aws/config)[y/N]? ") + _, err := fmt.Scanf("%s", &selectedOption) if err != nil { - log.Fatalf("Error to print config: %v", err) + fmt.Println("Default option") } - default: - err := aws.GenerateConfig(ssoURL, region, false) - fmt.Println("You can use this profile values to update your config file(~/.aws/config)") - if err != nil { - log.Fatalf("Error to print config: %v", err) + + if selectedOption == "y" || selectedOption == "Y" { + generateConfig(ssoURL, region) + } else { + printConfig(ssoURL, region) } } + return nil }, } + +func configDirName() string { + homeDir, err := os.UserHomeDir() + if err != nil { + log.Fatalf("Error identifying the home dir: %v", err) + } + return filepath.Join(homeDir, ".aws") +} + +func fileExists(fileName string) bool { + if _, err := os.Stat(fileName); err == nil { + return true + } else { + return false + } +} + +func createPath(path string) { + if err := os.MkdirAll(path, os.ModePerm); err != nil { + log.Fatalf("Error creating %v: %v", path, err) + } +} + +func generateConfig(ssoURL string, region string) { + err := aws.GenerateConfig(ssoURL, region, true) + if err != nil { + log.Fatalf("Error to generate config file: %v", err) + } +} + +func printConfig(ssoURL string, region string) { + err := aws.GenerateConfig(ssoURL, region, false) + fmt.Println("You can use this profile values to update your config file(~/.aws/config)") + if err != nil { + log.Fatalf("Error to print config: %v", err) + } +}