Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 48 additions & 2 deletions aws_signing_helper/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@ import (
"log"
"net/http"
"runtime"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/arn"
awscredentials "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/rolesanywhere-credential-helper/rolesanywhere"
)

Expand All @@ -20,9 +23,10 @@ type CredentialsOpts struct {
CertificateId string
CertificateBundleId string
CertIdentifier CertIdentifier
RoleArn string
RoleArn []string
ProfileArnStr string
TrustAnchorArnStr string
RoleSessionName []string
SessionDuration int
Region string
Endpoint string
Expand Down Expand Up @@ -98,13 +102,16 @@ func GenerateCredentials(opts *CredentialsOpts, signer Signer, signatureAlgorith

certificateStr := base64.StdEncoding.EncodeToString(certificate.Raw)
durationSeconds := int64(opts.SessionDuration)
var firstRoleArn string
var remainingRoleArns []string
firstRoleArn, remainingRoleArns = opts.RoleArn[0], opts.RoleArn[1:]
createSessionRequest := rolesanywhere.CreateSessionInput{
Cert: &certificateStr,
ProfileArn: &opts.ProfileArnStr,
TrustAnchorArn: &opts.TrustAnchorArnStr,
DurationSeconds: &(durationSeconds),
InstanceProperties: nil,
RoleArn: &opts.RoleArn,
RoleArn: &firstRoleArn,
SessionName: nil,
}
output, err := rolesAnywhereClient.CreateSession(&createSessionRequest)
Expand All @@ -117,6 +124,45 @@ func GenerateCredentials(opts *CredentialsOpts, signer Signer, signatureAlgorith
return CredentialProcessOutput{}, errors.New(msg)
}
credentials := output.CredentialSet[0].Credentials
var currentRoleArn = firstRoleArn
for i := 0; i < len(remainingRoleArns); i++ {
if Debug {
log.Printf("using %s to assume %s\n",currentRoleArn,remainingRoleArns[i])
}

sess, err := session.NewSession(&aws.Config{
Region: &opts.Region,
Credentials: awscredentials.NewStaticCredentials(
*credentials.AccessKeyId,
*credentials.SecretAccessKey,
*credentials.SessionToken,
),
})
stsClient := sts.New(sess)
rsn := "my-session"
if len(opts.RoleSessionName) > i {
rsn = opts.RoleSessionName[i]
}
stsRequest := sts.AssumeRoleInput{
RoleArn: aws.String(remainingRoleArns[i]),
RoleSessionName: aws.String(rsn),
DurationSeconds: aws.Int64(durationSeconds), //min allowed
}

stsResponse, err := stsClient.AssumeRole(&stsRequest)
if err != nil {
return CredentialProcessOutput{}, err
}
xp := stsResponse.Credentials.Expiration.Format(time.RFC3339)
credentials = &rolesanywhere.Credentials{
AccessKeyId: stsResponse.Credentials.AccessKeyId,
SecretAccessKey: stsResponse.Credentials.SecretAccessKey,
SessionToken: stsResponse.Credentials.SessionToken,
Expiration: &xp,
}

currentRoleArn = remainingRoleArns[i]
}
credentialProcessOutput := CredentialProcessOutput{
Version: 1,
AccessKeyId: *credentials.AccessKeyId,
Expand Down
2 changes: 1 addition & 1 deletion aws_signing_helper/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ func AllIssuesHandlers(cred *RefreshableCred, roleName string, opts *Credentials
func Serve(port int, credentialsOptions CredentialsOpts) {
var refreshableCred = RefreshableCred{}

roleArn, err := arn.Parse(credentialsOptions.RoleArn)
roleArn, err := arn.Parse(credentialsOptions.RoleArn[0])
if err != nil {
log.Println("invalid role ARN")
os.Exit(1)
Expand Down
7 changes: 5 additions & 2 deletions cmd/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ import (
)

var (
roleArnStr string
roleArnStr []string
profileArnStr string
trustAnchorArnStr string
roleSessionName []string
sessionDuration int
region string
endpoint string
Expand Down Expand Up @@ -52,9 +53,10 @@ type MapEntry struct {
// Parses common flags for commands that vend credentials
func initCredentialsSubCommand(subCmd *cobra.Command) {
rootCmd.AddCommand(subCmd)
subCmd.PersistentFlags().StringVar(&roleArnStr, "role-arn", "", "Target role to assume")
subCmd.PersistentFlags().StringArrayVarP(&roleArnStr, "role-arn", "", []string{}, "Target role(s) to assume one-by-one, in order specified")
subCmd.PersistentFlags().StringVar(&profileArnStr, "profile-arn", "", "Profile to pull policies from")
subCmd.PersistentFlags().StringVar(&trustAnchorArnStr, "trust-anchor-arn", "", "Trust anchor to use for authentication")
subCmd.PersistentFlags().StringArrayVarP(&roleSessionName, "role-session-name", "", []string{}, "Session names for additional roles specified in --role-arn arguments")
subCmd.PersistentFlags().IntVar(&sessionDuration, "session-duration", 3600, "Duration, in seconds, for the resulting session")
subCmd.PersistentFlags().StringVar(&region, "region", "", "Signing region")
subCmd.PersistentFlags().StringVar(&endpoint, "endpoint", "", "Endpoint used to call CreateSession")
Expand Down Expand Up @@ -233,6 +235,7 @@ func PopulateCredentialsOptions() error {
RoleArn: roleArnStr,
ProfileArnStr: profileArnStr,
TrustAnchorArnStr: trustAnchorArnStr,
RoleSessionName: roleSessionName,
SessionDuration: sessionDuration,
Region: region,
Endpoint: endpoint,
Expand Down