Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions aws_signing_helper/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ import (
)

const DefaultPort = 9911
const LocalHostAddress = "127.0.0.1"
const DefaultBindAddr = "127.0.0.1"

var AllowedBindAddrs = []string{"127.0.0.1", "169.254.169.254"}

var RefreshTime = time.Minute * time.Duration(5)

Expand All @@ -35,9 +37,10 @@ type RefreshableCred struct {
}

type Endpoint struct {
PortNum int
Server *http.Server
TmpCred RefreshableCred
BindAddr string
PortNum int
Server *http.Server
TmpCred RefreshableCred
}

type SessionToken struct {
Expand Down Expand Up @@ -259,9 +262,14 @@ func AllIssuesHandlers(cred *RefreshableCred, roleName string, opts *Credentials
return putTokenHandler, getRoleNameHandler, getCredentialsHandler
}

func Serve(port int, credentialsOptions CredentialsOpts) {
func Serve(bindAddr string, port int, credentialsOptions CredentialsOpts) {
var refreshableCred = RefreshableCred{}

if !bindAddrAllowed(bindAddr) {
log.Printf("bind address not in %s: ", AllowedBindAddrs)
os.Exit(1)
}

roleArn, err := arn.Parse(credentialsOptions.RoleArn)
if err != nil {
log.Println("invalid role ARN")
Expand All @@ -276,7 +284,7 @@ func Serve(port int, credentialsOptions CredentialsOpts) {
refreshableCred.Code = REFRESHABLE_CRED_CODE
refreshableCred.LastUpdated = time.Now()
refreshableCred.Type = REFRESHABLE_CRED_TYPE
endpoint := &Endpoint{PortNum: port, TmpCred: refreshableCred}
endpoint := &Endpoint{BindAddr: bindAddr, PortNum: port, TmpCred: refreshableCred}
endpoint.Server = &http.Server{}
roleResourceParts := strings.Split(roleArn.Resource, "/")
roleName := roleResourceParts[len(roleResourceParts)-1] // Find role name without path
Expand All @@ -303,17 +311,27 @@ func Serve(port int, credentialsOptions CredentialsOpts) {
}()

// Start the credentials endpoint
listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", LocalHostAddress, endpoint.PortNum))
listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", endpoint.BindAddr, endpoint.PortNum))
if err != nil {
log.Println("failed to create listener")
os.Exit(1)
}
endpoint.PortNum = listener.Addr().(*net.TCPAddr).Port
log.Println("Local server started on port:", endpoint.PortNum)
log.Println("Make it available to the sdk by running:")
log.Printf("export AWS_EC2_METADATA_SERVICE_ENDPOINT=http://%s:%d/", LocalHostAddress, endpoint.PortNum)
log.Printf("export AWS_EC2_METADATA_SERVICE_ENDPOINT=http://%s:%d/", endpoint.BindAddr, endpoint.PortNum)
if err := endpoint.Server.Serve(listener); err != nil {
log.Println("Httpserver: ListenAndServe() error")
os.Exit(1)
}
}

func bindAddrAllowed(addr string) bool {
for _, v := range AllowedBindAddrs {
if addr == v {
return true
}
}

return false
}
25 changes: 25 additions & 0 deletions aws_signing_helper/serve_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package aws_signing_helper

import "testing"

func Test_bindAddrAllowed(t *testing.T) {
type args struct {
addr string
}
tests := []struct {
name string
args args
want bool
}{
{name: "open", args: args{addr: "0.0.0.0"}, want: false},
{name: "localhost", args: args{addr: "127.0.0.1"}, want: true},
{name: "loopback", args: args{addr: "127.0.0.1"}, want: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := bindAddrAllowed(tt.args.addr); got != tt.want {
t.Errorf("bindAddrAllowed() = %v, want %v", got, tt.want)
}
})
}
}
8 changes: 5 additions & 3 deletions cmd/aws_signing_helper/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ var (
profile string
once bool

port int
port int
bindAddr string

credentialProcessCmd = flag.NewFlagSet("credential-process", flag.ExitOnError)
signStringCmd = flag.NewFlagSet("sign-string", flag.ExitOnError)
Expand Down Expand Up @@ -117,7 +118,8 @@ func setupFlags() {
fs.StringVar(&profile, "profile", "default", "The aws profile to use (default 'default')")
fs.BoolVar(&once, "once", false, "Update the credentials once")
} else if command == "serve" {
fs.IntVar(&port, "port", helper.DefaultPort, "The port used to run local server (default: 9911)")
fs.IntVar(&port, "port", helper.DefaultPort, fmt.Sprintf("The port used to run local server (default: '%d')", helper.DefaultPort))
fs.StringVar(&bindAddr, "bind-addr", helper.DefaultBindAddr, fmt.Sprintf("The address used to run local server. Must be in %+v", helper.AllowedBindAddrs))
}
}
}
Expand Down Expand Up @@ -270,7 +272,7 @@ func main() {
log.Println(msg)
os.Exit(1)
}
helper.Serve(port, credentialsOptions)
helper.Serve(bindAddr, port, credentialsOptions)
case "":
log.Println("No command provided")
os.Exit(1)
Expand Down