From b12c9756598073d5aa9f75dc732fcf9e0ab6b0be Mon Sep 17 00:00:00 2001 From: sendqueery <47192407+sendqueery@users.noreply.github.com> Date: Thu, 16 Jul 2020 12:18:27 -0400 Subject: [PATCH 01/12] Add method to create SSM targets from kv strings --- util/common.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/util/common.go b/util/common.go index 9732e3a..7a55229 100644 --- a/util/common.go +++ b/util/common.go @@ -2,6 +2,9 @@ package util import ( "strings" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/ssm" ) // CommaSplit is a function used to split a comma-delimited list of strings into a slice of strings @@ -24,3 +27,17 @@ func SliceToMap(kvslice []string, filterMap *map[string]string) { (*filterMap)[elements[0]] = elements[1] } } + +func SliceToTargets(kvslice []string) (targets []*ssm.Target) { + var elements []string + + for i := 0; i < len(kvslice); i++ { + elements = strings.Split(kvslice[i], "=") + targets = append(targets, &ssm.Target{ + Key: aws.String(elements[0]), + Values: aws.StringSlice([]string{elements[1]}), + }) + } + + return targets +} From 599a551164e38f43ec74f09d71b38a784da15cdd Mon Sep 17 00:00:00 2001 From: sendqueery <47192407+sendqueery@users.noreply.github.com> Date: Thu, 16 Jul 2020 12:20:13 -0400 Subject: [PATCH 02/12] Update ssm-run to rely on SSM for targeting --- cmd/run.go | 97 +++++++++++--------------------- ssm/helpers.go | 103 +++++++++++++++++++--------------- ssm/invocation/runner.go | 118 ++++++++++++--------------------------- 3 files changed, 127 insertions(+), 191 deletions(-) diff --git a/cmd/run.go b/cmd/run.go index 19ae99b..802896e 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -5,7 +5,6 @@ import ( "os" "runtime" "strings" - "sync" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ssm" @@ -43,11 +42,19 @@ func runCommand(cmd *cobra.Command, args []string) { profileList := cmdutil.GetFlagStringSlice(cmd.Parent(), "profile") regionList := cmdutil.GetFlagStringSlice(cmd.Parent(), "region") filterList := cmdutil.GetFlagStringSlice(cmd.Parent(), "filter") - limitFlag := cmdutil.GetFlagInt(cmd, "limit") instanceList := cmdutil.GetFlagStringSlice(cmd, "instance") // Get the number of cores available for parallelization runtime.GOMAXPROCS(runtime.NumCPU()) + if len(instanceList) > 0 && len(filterList) > 0 { + cmdutil.UsageError(cmd, "The --filter and --instance flags cannot be used simultaneously.") + os.Exit(1) + } + + if len(instanceList) > 50 { + cmdutil.UsageError(cmd, "The --instance flag can only be used to specify a maximum of 50 instances.") + } + // If the --commands and --file options are specified, we append the script contents to the specified commands if inputFile := cmdutil.GetFlagString(cmd, "file"); inputFile != "" { // Open our file for reading @@ -75,17 +82,6 @@ func runCommand(cmd *cobra.Command, args []string) { os.Exit(1) } - // ssm.SendCommandInput objects require parameters for the DocumentName chosen - params := &invocation.RunShellScriptParameters{ - /* - For AWS-RunShellScript, the only required parameter is "commands", - which is the shell command to be executed on the target. To emulate - the original script, we also set "executionTimeout" to 10 minutes. - */ - "commands": aws.StringSlice(commandList), - "executionTimeout": aws.StringSlice([]string{"600"}), - } - log.Info("Command(s) to be executed: ", strings.Join(commandList, ",")) if len(profileList) == 0 { @@ -116,71 +112,46 @@ func runCommand(cmd *cobra.Command, args []string) { // Set up our AWS session for each permutation of profile + region sessionPool := session.NewPoolSafe(profileList, regionList) - // Set up our filters - var filterMaps []map[string]string - // Convert the filter slice to a map - filterMap := make(map[string]string) + targets := []*ssm.Target{} if len(filterList) > 0 { - util.SliceToMap(filterList, &filterMap) - filterMaps = append(filterMaps, filterMap) + targets = util.SliceToTargets(filterList) } - var completedInvocations invocation.ResultSafe - var wg sync.WaitGroup + // ssm.SendCommandInput objects require parameters for the DocumentName chosen + params := &invocation.RunShellScriptParameters{ + /* + For AWS-RunShellScript, the only required parameter is "commands", + which is the shell command to be executed on the target. To emulate + the original script, we also set "executionTimeout" to 10 minutes. + */ + "commands": aws.StringSlice(commandList), + "executionTimeout": aws.StringSlice([]string{"600"}), + } - for _, sess := range sessionPool.Sessions { - wg.Add(1) - go func(sess *session.Pool, completedInvocations *invocation.ResultSafe) { - defer wg.Done() - instanceChan := make(chan []*ssm.InstanceInformation) - errChan := make(chan error) - svc := ssm.New(sess.Session) - - go ssmx.GetInstanceList(svc, filterMaps, instanceList, false, instanceChan, errChan) - info, err := <-instanceChan, <-errChan - - if err != nil { - log.Debugf("AWS Session Parameters: %s, %s", *sess.Session.Config.Region, sess.ProfileName) - log.Error(err) - } - - if len(info) == 0 { - return - } - - if len(info) > 0 { - log.Infof("Fetched %d instances for account [%s] in [%s].", len(info), sess.ProfileName, *sess.Session.Config.Region) - if dryRunFlag { - log.Info("Targeted instances:") - for _, instance := range info { - log.Infof("%s", *instance.InstanceId) - } - } - } - - if limitFlag == 0 || limitFlag > len(info) { - limitFlag = len(info) - } - - if err = ssmx.RunInvocations(sess, svc, info[:limitFlag], params, dryRunFlag, completedInvocations); err != nil { - log.Error(err) - } - }(sess, &completedInvocations) + sciInput := &ssm.SendCommandInput{ + InstanceIds: aws.StringSlice(instanceList), + Targets: targets, + DocumentName: aws.String("AWS-RunShellScript"), + Parameters: *params, } - wg.Wait() + ec := make(chan error) + var output invocation.ResultSafe + + for _, sess := range sessionPool.Sessions { + ssmx.RunInvocations(sess, sciInput, &output, ec) + } // Hide results if --verbose is set to quiet or terse if !dryRunFlag { log.Infof("%-24s %-15s %-15s %s\n", "Instance ID", "Region", "Profile", "Status") } - var successCounter int - var failedCounter int + var successCounter, failedCounter int - for _, v := range completedInvocations.InvocationResults { + for _, v := range output.InvocationResults { // Hide results if --verbose is set to quiet or terse if v.Status != "Success" { diff --git a/ssm/helpers.go b/ssm/helpers.go index ff7cdbc..5c24d46 100644 --- a/ssm/helpers.go +++ b/ssm/helpers.go @@ -1,9 +1,12 @@ package ssm import ( + "fmt" + "time" + "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ssm" - "github.com/hashicorp/go-multierror" + "github.com/aws/aws-sdk-go/service/ssm/ssmiface" log "github.com/sirupsen/logrus" "github.com/disneystreaming/ssm-helpers/aws/session" @@ -53,64 +56,74 @@ func addInstanceInfo(instanceID *string, tags []ec2helpers.InstanceTags, instanc } } -// RunInvocations invokes an SSM document with given parameters on the provided slice of instances -func RunInvocations(sp *session.Pool, sess *ssm.SSM, instances []*ssm.InstanceInformation, params *invocation.RunShellScriptParameters, dryRun bool, resultsPool *invocation.ResultSafe) (err error) { - var commandOutput invocation.CommandOutputSafe - var invError error - - scoChan := make(chan *ssm.SendCommandOutput) - errChan := make(chan error) +func checkInvocationStatus(ctx ssmiface.SSMAPI, commandID *string) (done bool, err error) { + var invocation *ssm.ListCommandsOutput + if invocation, err = ctx.ListCommands(&ssm.ListCommandsInput{ + CommandId: commandID, + }); err != nil { + return true, err + } - /* - In a standard deployment, SSM allows us run commands on a maximum of - up to 50 instances simultaneously. + if len(invocation.Commands) != 1 { + return true, fmt.Errorf("Incorrect number of invocations returned for given command ID; expected 1, got %d", len(invocation.Commands)) + } - (Technically, it does an exponential deployment, where it deploys to n^2 - instances at a time (up to 50), where n is the last number of instances - on which the command completed.) + switch *invocation.Commands[0].Status { + case "Pending", "InProgress": + return false, nil + default: + return true, nil + } +} - To speed up execution, we can split the instances into arbitrarily-sized - batches and run the command on every batch concurrently. In the current - implementation, we are effectively using a batch size of 1 for maximum - concurrency. - */ +// RunInvocations invokes an SSM document with given parameters on the provided slice of instances +func RunInvocations(sess *session.Pool, input *ssm.SendCommandInput, results *invocation.ResultSafe, ec chan error) { + oc := make(chan *ssm.GetCommandInvocationOutput) + svc := ssm.New(sess.Session) - for _, instance := range instances { - go invocation.RunSSMCommand(sess, params, dryRun, scoChan, errChan, *instance.InstanceId) - output, err := <-scoChan, <-errChan + if scOutput, err := svc.SendCommand(input); err == nil { - if err != nil { - invError = multierror.Append(invError, err) + // Watch status of invocation to see when it's done and we can get the output + for done := false; !done; time.Sleep(2 * time.Second) { + if done, err = checkInvocationStatus(svc, scOutput.Command.CommandId); err != nil { + ec <- err + break + } } - if output != nil { - addInvocationInfo(output, &commandOutput) + lciInput := &ssm.ListCommandInvocationsInput{ + CommandId: scOutput.Command.CommandId, } - } - // Fetch the results of our invocation for all provided instances - invocationStatus, err := invocation.GetCommandInvocationResult(sess, commandOutput.Output...) - if err != nil { - // If we somehow throw an error here, something has gone screwy with our invocation or the target instance - // See the docs on ssm.GetCommandInvocation() for error details - invError = multierror.Append(invError, err) - } + if err := svc.ListCommandInvocationsPages(lciInput, func(page *ssm.ListCommandInvocationsOutput, lastPage bool) bool { + for _, entry := range page.CommandInvocations { + // Fetch the results of our invocation for all provided instances + go invocation.GetResult(svc, scOutput.Command.CommandId, entry.InstanceId, oc, ec) - // Iterate through all retrieved invocation results to add some extra context - addInvocationResults(invocationStatus, resultsPool, sp) - return invError -} + // Wait for results to return until the combined total of results and errors + select { + case result := <-oc: + addInvocationResults(results, sess, result) + } + } -func addInvocationInfo(info *ssm.SendCommandOutput, infoPool *invocation.CommandOutputSafe) { - if info != nil { - infoPool.Lock() - infoPool.Output = append(infoPool.Output, info) - infoPool.Unlock() - } + // Last page, break out + if page.NextToken == nil { + return false + } + lciInput.SetNextToken(*page.NextToken) + return true + }); err != nil { + ec <- err + } + + } else { + ec <- err + } } -func addInvocationResults(info []*ssm.GetCommandInvocationOutput, results *invocation.ResultSafe, session *session.Pool) { +func addInvocationResults(results *invocation.ResultSafe, session *session.Pool, info ...*ssm.GetCommandInvocationOutput) { for _, v := range info { var result = &invocation.Result{ InvocationResult: v, diff --git a/ssm/invocation/runner.go b/ssm/invocation/runner.go index 83619ce..c98756b 100644 --- a/ssm/invocation/runner.go +++ b/ssm/invocation/runner.go @@ -1,111 +1,63 @@ package invocation import ( - "sync" + "fmt" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/service/ssm" "github.com/aws/aws-sdk-go/service/ssm/ssmiface" - log "github.com/sirupsen/logrus" ) // RunSSMCommand uses an SSM session, pre-defined SSM document parameters, the dry run flag, and any number of instance IDs and executes the given command // using the AWS-RunShellScript SSM document. It returns an *ssm.SendCommandOutput object, which contains the execution ID of the command, which we use to // check the progress/status of the invocation. -func RunSSMCommand(session ssmiface.SSMAPI, params *RunShellScriptParameters, dryRunFlag bool, resultChan chan *ssm.SendCommandOutput, errChan chan error, instanceID ...string) { - var err error - var output *ssm.SendCommandOutput - - ssmCommandInput := &ssm.SendCommandInput{ - DocumentName: aws.String("AWS-RunShellScript"), - InstanceIds: aws.StringSlice(instanceID), - Parameters: *params} +func RunSSMCommand(session ssmiface.SSMAPI, input *ssm.SendCommandInput, dryRunFlag bool) (scOutput *ssm.SendCommandOutput, err error) { if !dryRunFlag { - output, err = session.SendCommand(ssmCommandInput) + return session.SendCommand(input) } - resultChan <- output - errChan <- err + return } -// GetCommandInvocationResult takes an SSM context and any number of *ssm.SendCommandOutput objects and iterates through them until the invocation is complete. -// Each invocation is checked concurrently, but the method as a whole is blocking until all invocations have returned a finishing result, whether successful or not. -func GetCommandInvocationResult(context ssmiface.SSMAPI, jobs ...*ssm.SendCommandOutput) (invocationStatus []*ssm.GetCommandInvocationOutput, err error) { - // We're creating this here as well as in main() because otherwise we don't have the appropriate logging context - errLog := log.New() - errLog.SetFormatter(&log.TextFormatter{ - // Disable level truncation, timestamp, and pad out the level text to even it up - DisableLevelTruncation: true, - DisableTimestamp: true, - }) +func GetTargets(ctx ssmiface.SSMAPI, commandID *string) (targets []*string, err error) { + var out *ssm.ListCommandInvocationsOutput - wg := sync.WaitGroup{} + // Try a few times to get the invocation data, because it takes a little bit to have any information + for i := 0; i < 3; i++ { + time.Sleep(1 * time.Second) + if out, err = ctx.ListCommandInvocations(&ssm.ListCommandInvocationsInput{ + CommandId: commandID, + }); err != nil { + return nil, err + } - type resultsSafe struct { - sync.Mutex - results []*ssm.GetCommandInvocationOutput + if len(out.CommandInvocations) > 0 { + break + } } - var results resultsSafe - - // Concurrently iterate through all items in []instanceIDs and get the invocation status - for _, v := range jobs { - if v.Command != nil { - for _, i := range v.Command.InstanceIds { - wg.Add(1) - go func(v *ssm.SendCommandOutput, i *string, context ssmiface.SSMAPI) { - defer wg.Done() - /* - GetCommandInvocation() requires a GetCommandInvocationInput object, which - has required parameters CommandId and InstanceId. It is important to note - that unlike the execution of the command, you can only retrieve the invocation - results for one instance+command at a time. - */ - gciInput := &ssm.GetCommandInvocationInput{ - CommandId: v.Command.CommandId, - InstanceId: i, - } - - // Retrieve the status of the command invocation - status, err := context.GetCommandInvocation(gciInput) - - // If we get "InvocationDoesNotExist", it just means we tried to check the results too quickly - for awsErr, ok := err.(awserr.Error); ok && err != nil && awsErr.Code() == "InvocationDoesNotExist"; { - time.Sleep(1000 * time.Millisecond) - status, err = context.GetCommandInvocation(gciInput) - } + if len(out.CommandInvocations) == 0 { + return nil, fmt.Errorf("API response contained no invocations") + } - // If we somehow throw a real error here, something has gone screwy with our invocation or the target instance - // See the docs on ssm.GetCommandInvocation() for error details - if err != nil { - errLog.Errorln(err) - return - } + for _, inv := range out.CommandInvocations { + targets = append(targets, inv.InstanceId) + } - // If the invocation is in a pending state, we sleep for a couple seconds before retrying the query - // NOTE: This may need to change based on API limits, but as there is no documentation, we'll have to wait and see. - for *status.StatusDetails == "InProgress" || *status.StatusDetails == "Pending" { - status, err = context.GetCommandInvocation(gciInput) - time.Sleep(2000 * time.Millisecond) - } + return targets, nil +} - if err != nil { - errLog.Errorln(err) - return - } +func GetResult(ctx ssmiface.SSMAPI, commandID *string, instanceID *string, gci chan *ssm.GetCommandInvocationOutput, ec chan error) { + status, err := ctx.GetCommandInvocation(&ssm.GetCommandInvocationInput{ + CommandId: commandID, + InstanceId: instanceID, + }) - // Append the result to our slice of results - results.Lock() - results.results = append(results.results, status) - results.Unlock() - }(v, i, context) - } - } + switch { + case err != nil: + ec <- err + case status != nil: + gci <- status } - wg.Wait() - // Return - return results.results, err } From 057ca642be00261e3b9aa085176eefbec02ddb8f Mon Sep 17 00:00:00 2001 From: sendqueery <47192407+sendqueery@users.noreply.github.com> Date: Thu, 16 Jul 2020 12:21:06 -0400 Subject: [PATCH 03/12] Update gomod --- go.mod | 11 ++++++----- go.sum | 12 ++++++++++++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index 8320d61..9bb1882 100644 --- a/go.mod +++ b/go.mod @@ -3,15 +3,16 @@ module github.com/disneystreaming/ssm-helpers go 1.13 require ( - github.com/AlecAivazis/survey/v2 v2.0.7 - github.com/aws/aws-sdk-go v1.32.3 + github.com/AlecAivazis/survey/v2 v2.0.8 + github.com/aws/aws-sdk-go v1.33.6 github.com/disneystreaming/gomux v0.0.0-20200305000114-de122d6df124 - github.com/hashicorp/go-multierror v1.1.0 - github.com/mattn/go-colorable v0.1.6 // indirect + github.com/mattn/go-colorable v0.1.7 // indirect + github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d // indirect github.com/mitchellh/go-homedir v1.1.0 github.com/sirupsen/logrus v1.6.0 github.com/spf13/cobra v1.0.0 github.com/stretchr/testify v1.6.1 - golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1 // indirect + golang.org/x/crypto v0.0.0-20200709230013-948cd5f35899 // indirect + golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae // indirect golang.org/x/text v0.3.3 // indirect ) diff --git a/go.sum b/go.sum index 3448ef1..9956f13 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/AlecAivazis/survey/v2 v2.0.7 h1:+f825XHLse/hWd2tE/V5df04WFGimk34Eyg/z35w/rc= github.com/AlecAivazis/survey/v2 v2.0.7/go.mod h1:mlizQTaPjnR4jcpwRSaSlkbsRfYFEyKgLQvYTzxxiHA= +github.com/AlecAivazis/survey/v2 v2.0.8 h1:zVjWKN+JIAfmrq6nGWG3DfLS8ypEBhxYy0p7FM+riFk= +github.com/AlecAivazis/survey/v2 v2.0.8/go.mod h1:9FJRdMdDm8rnT+zHVbvQT2RTSTLq0Ttd6q3Vl2fahjk= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/Netflix/go-expect v0.0.0-20180615182759-c93bf25de8e8 h1:xzYJEypr/85nBpB11F9br+3HUrpgb+fcm5iADzXXYEw= github.com/Netflix/go-expect v0.0.0-20180615182759-c93bf25de8e8/go.mod h1:oX5x61PbNXchhh0oikYAH+4Pcfw5LKv21+Jnpr6r6Pc= @@ -10,6 +12,8 @@ github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRF github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= github.com/aws/aws-sdk-go v1.32.3 h1:E3OciOGVlJrv1gQ2T7/Oou+I9nGPB2j978THQjvZBf0= github.com/aws/aws-sdk-go v1.32.3/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= +github.com/aws/aws-sdk-go v1.33.6 h1:YLoUeMSx05kHwhS+HLDSpdYYpPzJMyp6hn1cWsJ6a+U= +github.com/aws/aws-sdk-go v1.33.6/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= @@ -80,6 +84,8 @@ github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.6 h1:6Su7aK7lXmJ/U79bYtBjLNaha4Fs1Rg9plHpcH+vvnE= github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-colorable v0.1.7 h1:bQGKb3vps/j0E9GfJQ03JyhRuxsvdAanXlT9BTw3mdw= +github.com/mattn/go-colorable v0.1.7/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE= github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= @@ -87,6 +93,8 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4= github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= +github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQE9x6ikvDFZS2mDVS3drnohI= +github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= @@ -141,6 +149,8 @@ golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnf golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190530122614-20be4c3c3ed5 h1:8dUaAV7K4uHsF56JQWkprecIQKdPHtR9jCHF5nB8uzc= golang.org/x/crypto v0.0.0-20190530122614-20be4c3c3ed5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200709230013-948cd5f35899 h1:DZhuSZLsGlFL4CmhA8BcRA0mnthyA/nZ00AqCUo7vHg= +golang.org/x/crypto v0.0.0-20200709230013-948cd5f35899/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -169,6 +179,8 @@ golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1 h1:ogLJMz+qpzav7lGMh10LMvAkM/fAoGlaiiHYiFYdm80= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae h1:Ih9Yo4hSPImZOpfGuA4bR/ORKTAbhZo2AbWNRCnevdo= +golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= From 834a2a266cc8b8e97bbe054520d0121198f95f07 Mon Sep 17 00:00:00 2001 From: sendqueery <47192407+sendqueery@users.noreply.github.com> Date: Thu, 16 Jul 2020 13:43:08 -0400 Subject: [PATCH 04/12] Call RunInvocations() as a goroutine and add waitgroup --- cmd/run.go | 13 ++++++++++++- ssm/helpers.go | 4 +++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/cmd/run.go b/cmd/run.go index 802896e..66b0da5 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -5,6 +5,7 @@ import ( "os" "runtime" "strings" + "sync" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ssm" @@ -139,11 +140,21 @@ func runCommand(cmd *cobra.Command, args []string) { ec := make(chan error) var output invocation.ResultSafe + var wg sync.WaitGroup for _, sess := range sessionPool.Sessions { - ssmx.RunInvocations(sess, sciInput, &output, ec) + wg.Add(1) + go ssmx.RunInvocations(sess, &wg, sciInput, &output, ec) } + select { + case err := <-ec: + log.Error(err) + default: + } + + wg.Wait() + // Hide results if --verbose is set to quiet or terse if !dryRunFlag { log.Infof("%-24s %-15s %-15s %s\n", "Instance ID", "Region", "Profile", "Status") diff --git a/ssm/helpers.go b/ssm/helpers.go index 5c24d46..e5e38de 100644 --- a/ssm/helpers.go +++ b/ssm/helpers.go @@ -2,6 +2,7 @@ package ssm import ( "fmt" + "sync" "time" "github.com/aws/aws-sdk-go/service/ec2" @@ -77,7 +78,8 @@ func checkInvocationStatus(ctx ssmiface.SSMAPI, commandID *string) (done bool, e } // RunInvocations invokes an SSM document with given parameters on the provided slice of instances -func RunInvocations(sess *session.Pool, input *ssm.SendCommandInput, results *invocation.ResultSafe, ec chan error) { +func RunInvocations(sess *session.Pool, wg *sync.WaitGroup, input *ssm.SendCommandInput, results *invocation.ResultSafe, ec chan error) { + defer wg.Done() oc := make(chan *ssm.GetCommandInvocationOutput) svc := ssm.New(sess.Session) From c0f8f946053353e08eb36f497b0e7b2909a20186 Mon Sep 17 00:00:00 2001 From: sendqueery <47192407+sendqueery@users.noreply.github.com> Date: Thu, 16 Jul 2020 15:50:58 -0400 Subject: [PATCH 05/12] Update RunInvocations() to use mockable SSM client --- cmd/run.go | 4 +++- ssm/helpers.go | 45 +++++++++++++++++++++++++++------------------ 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/cmd/run.go b/cmd/run.go index 66b0da5..bdaaaa8 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -144,7 +144,9 @@ func runCommand(cmd *cobra.Command, args []string) { for _, sess := range sessionPool.Sessions { wg.Add(1) - go ssmx.RunInvocations(sess, &wg, sciInput, &output, ec) + ssmClient := ssm.New(sess.Session) + log.Debugf("Starting invocation targeting account %s in %s", sess.ProfileName, *sess.Session.Config.Region) + go ssmx.RunInvocations(sess, ssmClient, &wg, sciInput, &output, ec) } select { diff --git a/ssm/helpers.go b/ssm/helpers.go index e5e38de..9465f9c 100644 --- a/ssm/helpers.go +++ b/ssm/helpers.go @@ -78,29 +78,41 @@ func checkInvocationStatus(ctx ssmiface.SSMAPI, commandID *string) (done bool, e } // RunInvocations invokes an SSM document with given parameters on the provided slice of instances -func RunInvocations(sess *session.Pool, wg *sync.WaitGroup, input *ssm.SendCommandInput, results *invocation.ResultSafe, ec chan error) { +func RunInvocations(sess *session.Pool, ctx ssmiface.SSMAPI, wg *sync.WaitGroup, input *ssm.SendCommandInput, results *invocation.ResultSafe, ec chan error) { defer wg.Done() + oc := make(chan *ssm.GetCommandInvocationOutput) - svc := ssm.New(sess.Session) + var scOutput *ssm.SendCommandOutput + var err error - if scOutput, err := svc.SendCommand(input); err == nil { + // Send our command input to SSM + if scOutput, err = ctx.SendCommand(input); err != nil { + ec <- err + return + } - // Watch status of invocation to see when it's done and we can get the output - for done := false; !done; time.Sleep(2 * time.Second) { - if done, err = checkInvocationStatus(svc, scOutput.Command.CommandId); err != nil { - ec <- err - break - } - } + commandID := scOutput.Command.CommandId - lciInput := &ssm.ListCommandInvocationsInput{ - CommandId: scOutput.Command.CommandId, + // Watch status of invocation to see when it's done and we can get the output + for done := false; !done; time.Sleep(2 * time.Second) { + if done, err = checkInvocationStatus(ctx, commandID); err != nil { + ec <- err + break } + } - if err := svc.ListCommandInvocationsPages(lciInput, func(page *ssm.ListCommandInvocationsOutput, lastPage bool) bool { + // Set up our LCI input object + lciInput := &ssm.ListCommandInvocationsInput{ + CommandId: commandID, + } + + // Iterate through the details of the invocations returned + if err = ctx.ListCommandInvocationsPages( + lciInput, + func(page *ssm.ListCommandInvocationsOutput, lastPage bool) bool { for _, entry := range page.CommandInvocations { // Fetch the results of our invocation for all provided instances - go invocation.GetResult(svc, scOutput.Command.CommandId, entry.InstanceId, oc, ec) + go invocation.GetResult(ctx, commandID, entry.InstanceId, oc, ec) // Wait for results to return until the combined total of results and errors select { @@ -117,12 +129,9 @@ func RunInvocations(sess *session.Pool, wg *sync.WaitGroup, input *ssm.SendComma lciInput.SetNextToken(*page.NextToken) return true }); err != nil { - ec <- err - } - - } else { ec <- err } + } func addInvocationResults(results *invocation.ResultSafe, session *session.Pool, info ...*ssm.GetCommandInvocationOutput) { From b9619a2f5fc40d95cff8c13e2220c8aa62471ffb Mon Sep 17 00:00:00 2001 From: sendqueery <47192407+sendqueery@users.noreply.github.com> Date: Thu, 16 Jul 2020 15:52:41 -0400 Subject: [PATCH 06/12] Clean up output logging --- cmd/run.go | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/cmd/run.go b/cmd/run.go index bdaaaa8..45d3466 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -164,23 +164,22 @@ func runCommand(cmd *cobra.Command, args []string) { var successCounter, failedCounter int + var successCounter, failedCounter int for _, v := range output.InvocationResults { - - // Hide results if --verbose is set to quiet or terse - if v.Status != "Success" { - // Always output error info to stderr - log.Errorf("%-24s %-15s %-15s %s", *v.InvocationResult.InstanceId, v.Region, v.ProfileName, *v.InvocationResult.StatusDetails) - log.Error(*v.InvocationResult.StandardErrorContent) - - failedCounter++ - } else { - // Output stdout from invocations to stdout + switch v.Status { + case "Success": log.Infof("%-24s %-15s %-15s %s", *v.InvocationResult.InstanceId, v.Region, v.ProfileName, *v.InvocationResult.StatusDetails) log.Info(*v.InvocationResult.StandardOutputContent) - successCounter++ + case "Failed": + log.Errorf("%-24s %-15s %-15s %s", *v.InvocationResult.InstanceId, v.Region, v.ProfileName, *v.InvocationResult.StatusDetails) + log.Error(*v.InvocationResult.StandardErrorContent) + failedCounter++ + default: + // Non-"Failed" statuses are failures, but don't have any output + log.Errorf("%-24s %-15s %-15s %s", *v.InvocationResult.InstanceId, v.Region, v.ProfileName, *v.InvocationResult.StatusDetails) + failedCounter++ } - } if !dryRunFlag { From 8b4d5860828f2c13a9656e1ed0b47cc9081520ce Mon Sep 17 00:00:00 2001 From: sendqueery <47192407+sendqueery@users.noreply.github.com> Date: Thu, 16 Jul 2020 15:53:10 -0400 Subject: [PATCH 07/12] Remove --dry-run flag from run command --- cmd/run.go | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/cmd/run.go b/cmd/run.go index 45d3466..e14868f 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -39,7 +39,6 @@ func runCommand(cmd *cobra.Command, args []string) { cmdutil.ValidateArgs(cmd, args) commandList := cmdutil.GetCommandFlagStringSlice(cmd) - dryRunFlag := cmdutil.GetFlagBool(cmd.Parent(), "dry-run") profileList := cmdutil.GetFlagStringSlice(cmd.Parent(), "profile") regionList := cmdutil.GetFlagStringSlice(cmd.Parent(), "region") filterList := cmdutil.GetFlagStringSlice(cmd.Parent(), "filter") @@ -158,11 +157,8 @@ func runCommand(cmd *cobra.Command, args []string) { wg.Wait() // Hide results if --verbose is set to quiet or terse - if !dryRunFlag { - log.Infof("%-24s %-15s %-15s %s\n", "Instance ID", "Region", "Profile", "Status") - } - var successCounter, failedCounter int + log.Infof("%-24s %-15s %-15s %s\n", "Instance ID", "Region", "Profile", "Status") var successCounter, failedCounter int for _, v := range output.InvocationResults { @@ -182,12 +178,10 @@ func runCommand(cmd *cobra.Command, args []string) { } } - if !dryRunFlag { - log.Infof("Execution results: %d SUCCESS, %d FAILED", successCounter, failedCounter) - if failedCounter > 0 { - // Exit code 1 to indicate that there was some sort of error returned from invocation - os.Exit(1) - } + log.Infof("Execution results: %d SUCCESS, %d FAILED", successCounter, failedCounter) + if failedCounter > 0 { + // Exit code 1 to indicate that there was some sort of error returned from invocation + os.Exit(1) } return From 0ece678d49927fb68ba8227100cfc1ed69f4b31a Mon Sep 17 00:00:00 2001 From: sendqueery <47192407+sendqueery@users.noreply.github.com> Date: Mon, 20 Jul 2020 13:48:03 -0400 Subject: [PATCH 08/12] Clean up invocation logging --- aws/session/session.go | 13 +++---- aws/session/types.go | 2 ++ cmd/run.go | 74 +++++++++++++++++++++------------------- cmd/session.go | 2 +- ssm/helpers.go | 15 ++++---- ssm/invocation/runner.go | 6 +++- 6 files changed, 62 insertions(+), 50 deletions(-) diff --git a/aws/session/session.go b/aws/session/session.go index 6d75720..e90cfc4 100644 --- a/aws/session/session.go +++ b/aws/session/session.go @@ -5,19 +5,20 @@ import ( "sync" "github.com/aws/aws-sdk-go/aws/session" + "github.com/sirupsen/logrus" "github.com/disneystreaming/ssm-helpers/aws/config" ) // NewPoolSafe is used to create a pool of AWS sessions with different profile/region permutations -func NewPoolSafe(profiles []string, regions []string) (allSessions *PoolSafe) { +func NewPoolSafe(profiles []string, regions []string, logger *logrus.Logger) (allSessions *PoolSafe) { wg := sync.WaitGroup{} sp := &PoolSafe{ Sessions: make(map[string]*Pool), } - if regions != nil { + if len(regions) == 0 { wg.Add(len(profiles) * len(regions)) for _, p := range profiles { for _, r := range regions { @@ -29,8 +30,9 @@ func NewPoolSafe(profiles []string, regions []string) (allSessions *PoolSafe) { sp.Lock() session := Pool{ - Session: newSession, + Logger: logger, ProfileName: p, + Session: newSession, } sp.Sessions[fmt.Sprintf("%s-%s", p, r)] = &session //sp.Sessions = append(sp.Sessions, &session) @@ -49,12 +51,11 @@ func NewPoolSafe(profiles []string, regions []string) (allSessions *PoolSafe) { sp.Lock() session := Pool{ - Session: newSession, + Logger: logger, ProfileName: p, + Session: newSession, } sp.Sessions[fmt.Sprintf("%s", p)] = &session - - //sp.Sessions = append(sp.Sessions, &session) defer sp.Unlock() }(p) } diff --git a/aws/session/types.go b/aws/session/types.go index 5172408..feb67fc 100644 --- a/aws/session/types.go +++ b/aws/session/types.go @@ -4,10 +4,12 @@ import ( "sync" "github.com/aws/aws-sdk-go/aws/session" + "github.com/sirupsen/logrus" ) // Pool is a type that holds an instance of an AWS session as well as the profile name used to initialize it type Pool struct { + Logger *logrus.Logger Session *session.Session ProfileName string } diff --git a/cmd/run.go b/cmd/run.go index e14868f..a9183fe 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -119,68 +119,70 @@ func runCommand(cmd *cobra.Command, args []string) { targets = util.SliceToTargets(filterList) } - // ssm.SendCommandInput objects require parameters for the DocumentName chosen - params := &invocation.RunShellScriptParameters{ - /* - For AWS-RunShellScript, the only required parameter is "commands", - which is the shell command to be executed on the target. To emulate - the original script, we also set "executionTimeout" to 10 minutes. - */ - "commands": aws.StringSlice(commandList), - "executionTimeout": aws.StringSlice([]string{"600"}), - } + log.Info("Command(s) to be executed:\n", strings.Join(commandList, "\n")) sciInput := &ssm.SendCommandInput{ InstanceIds: aws.StringSlice(instanceList), Targets: targets, DocumentName: aws.String("AWS-RunShellScript"), - Parameters: *params, + Parameters: map[string][]*string{ + /* + ssm.SendCommandInput objects require parameters for the DocumentName chosen + + For AWS-RunShellScript, the only required parameter is "commands", + which is the shell command to be executed on the target. To emulate + the original script, we also set "executionTimeout" to 10 minutes. + */ + "commands": aws.StringSlice(commandList), + "executionTimeout": aws.StringSlice([]string{"600"}), + }, } - ec := make(chan error) - var output invocation.ResultSafe - var wg sync.WaitGroup + // Set up our AWS session for each permutation of profile + region + sessionPool := session.NewPoolSafe(profileList, regionList, log) + wg, output := sync.WaitGroup{}, invocation.ResultSafe{} for _, sess := range sessionPool.Sessions { wg.Add(1) ssmClient := ssm.New(sess.Session) log.Debugf("Starting invocation targeting account %s in %s", sess.ProfileName, *sess.Session.Config.Region) - go ssmx.RunInvocations(sess, ssmClient, &wg, sciInput, &output, ec) - } - - select { - case err := <-ec: - log.Error(err) - default: + go ssmx.RunInvocations(sess, ssmClient, &wg, sciInput, &output) } - wg.Wait() - - // Hide results if --verbose is set to quiet or terse - - log.Infof("%-24s %-15s %-15s %s\n", "Instance ID", "Region", "Profile", "Status") + wg.Wait() // Wait for each account/region combo to finish + resultFormat := "%-24s %-15s %-15s %s" var successCounter, failedCounter int + + // Output our results + log.Infof(resultFormat, "Instance ID", "Region", "Profile", "Status") for _, v := range output.InvocationResults { switch v.Status { case "Success": - log.Infof("%-24s %-15s %-15s %s", *v.InvocationResult.InstanceId, v.Region, v.ProfileName, *v.InvocationResult.StatusDetails) - log.Info(*v.InvocationResult.StandardOutputContent) + log.Infof(resultFormat, *v.InvocationResult.InstanceId, v.Region, v.ProfileName, *v.InvocationResult.StatusDetails) successCounter++ - case "Failed": - log.Errorf("%-24s %-15s %-15s %s", *v.InvocationResult.InstanceId, v.Region, v.ProfileName, *v.InvocationResult.StatusDetails) - log.Error(*v.InvocationResult.StandardErrorContent) - failedCounter++ default: - // Non-"Failed" statuses are failures, but don't have any output - log.Errorf("%-24s %-15s %-15s %s", *v.InvocationResult.InstanceId, v.Region, v.ProfileName, *v.InvocationResult.StatusDetails) + log.Errorf(resultFormat, *v.InvocationResult.InstanceId, v.Region, v.ProfileName, *v.InvocationResult.StatusDetails) failedCounter++ } + + // stdout is always written back at info level + if *v.InvocationResult.StandardOutputContent != "" { + log.Info(*v.InvocationResult.StandardOutputContent) + } + + // stderr is written back at warn if the invocation was successful, and error if not + if *v.InvocationResult.StandardErrorContent != "" { + if v.Status == "Success" { + log.Warn(*v.InvocationResult.StandardErrorContent) + } else { + log.Error(*v.InvocationResult.StandardErrorContent) + } + } } log.Infof("Execution results: %d SUCCESS, %d FAILED", successCounter, failedCounter) - if failedCounter > 0 { - // Exit code 1 to indicate that there was some sort of error returned from invocation + if failedCounter > 0 { // Exit code 1 to indicate that there was some sort of error returned from invocation os.Exit(1) } diff --git a/cmd/session.go b/cmd/session.go index c1449d4..cd263e1 100644 --- a/cmd/session.go +++ b/cmd/session.go @@ -82,7 +82,7 @@ func startSessionCommand(cmd *cobra.Command, args []string) { } // Set up our AWS session for each permutation of profile + region - sessionPool := session.NewPoolSafe(profileList, regionList) + sessionPool := session.NewPoolSafe(profileList, regionList, log) // Set up our filters var filterMaps []map[string]string diff --git a/ssm/helpers.go b/ssm/helpers.go index 9465f9c..b37e983 100644 --- a/ssm/helpers.go +++ b/ssm/helpers.go @@ -62,7 +62,7 @@ func checkInvocationStatus(ctx ssmiface.SSMAPI, commandID *string) (done bool, e if invocation, err = ctx.ListCommands(&ssm.ListCommandsInput{ CommandId: commandID, }); err != nil { - return true, err + return true, fmt.Errorf("Encountered an error when trying to call the ListCommands API with CommandId: %v\n%v", *commandID, err) } if len(invocation.Commands) != 1 { @@ -78,16 +78,17 @@ func checkInvocationStatus(ctx ssmiface.SSMAPI, commandID *string) (done bool, e } // RunInvocations invokes an SSM document with given parameters on the provided slice of instances -func RunInvocations(sess *session.Pool, ctx ssmiface.SSMAPI, wg *sync.WaitGroup, input *ssm.SendCommandInput, results *invocation.ResultSafe, ec chan error) { +func RunInvocations(sess *session.Pool, ctx ssmiface.SSMAPI, wg *sync.WaitGroup, input *ssm.SendCommandInput, results *invocation.ResultSafe) { defer wg.Done() oc := make(chan *ssm.GetCommandInvocationOutput) + ec := make(chan error) var scOutput *ssm.SendCommandOutput var err error // Send our command input to SSM if scOutput, err = ctx.SendCommand(input); err != nil { - ec <- err + sess.Logger.Errorf("Error when calling the SendCommand API for account %v in %v\n%v", sess.ProfileName, *sess.Session.Config.Region, err) return } @@ -96,8 +97,8 @@ func RunInvocations(sess *session.Pool, ctx ssmiface.SSMAPI, wg *sync.WaitGroup, // Watch status of invocation to see when it's done and we can get the output for done := false; !done; time.Sleep(2 * time.Second) { if done, err = checkInvocationStatus(ctx, commandID); err != nil { - ec <- err - break + sess.Logger.Error(err) + return } } @@ -118,6 +119,8 @@ func RunInvocations(sess *session.Pool, ctx ssmiface.SSMAPI, wg *sync.WaitGroup, select { case result := <-oc: addInvocationResults(results, sess, result) + case err := <-ec: + sess.Logger.Error(err) } } @@ -129,7 +132,7 @@ func RunInvocations(sess *session.Pool, ctx ssmiface.SSMAPI, wg *sync.WaitGroup, lciInput.SetNextToken(*page.NextToken) return true }); err != nil { - ec <- err + sess.Logger.Error(fmt.Errorf("Error when calling ListCommandInvocations API\n%v", err)) } } diff --git a/ssm/invocation/runner.go b/ssm/invocation/runner.go index c98756b..d341ec6 100644 --- a/ssm/invocation/runner.go +++ b/ssm/invocation/runner.go @@ -55,7 +55,11 @@ func GetResult(ctx ssmiface.SSMAPI, commandID *string, instanceID *string, gci c switch { case err != nil: - ec <- err + ec <- fmt.Errorf( + `Error when calling GetCommandInvocation API with args:\n + CommandId: %v\n + InstanceId: %v\n%v`, + *commandID, *instanceID, err) case status != nil: gci <- status } From 06bbf06666f401d799b625cb69bc4a98d342da1d Mon Sep 17 00:00:00 2001 From: sendqueery <47192407+sendqueery@users.noreply.github.com> Date: Mon, 20 Jul 2020 13:48:15 -0400 Subject: [PATCH 09/12] Clean up args and validation --- cmd/run.go | 98 ++++++++++++++++++-------------------------------- util/common.go | 26 ++++++++++++++ 2 files changed, 61 insertions(+), 63 deletions(-) diff --git a/cmd/run.go b/cmd/run.go index a9183fe..0dc69d5 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -1,7 +1,6 @@ package cmd import ( - "bufio" "os" "runtime" "strings" @@ -36,87 +35,60 @@ func newCommandSSMRun() *cobra.Command { } func runCommand(cmd *cobra.Command, args []string) { + // Get all of our CLI flag values cmdutil.ValidateArgs(cmd, args) - commandList := cmdutil.GetCommandFlagStringSlice(cmd) profileList := cmdutil.GetFlagStringSlice(cmd.Parent(), "profile") regionList := cmdutil.GetFlagStringSlice(cmd.Parent(), "region") filterList := cmdutil.GetFlagStringSlice(cmd.Parent(), "filter") instanceList := cmdutil.GetFlagStringSlice(cmd, "instance") + allProfilesFlag := cmdutil.GetFlagBool(cmd, "all-profiles") + // Get the number of cores available for parallelization runtime.GOMAXPROCS(runtime.NumCPU()) - if len(instanceList) > 0 && len(filterList) > 0 { - cmdutil.UsageError(cmd, "The --filter and --instance flags cannot be used simultaneously.") - os.Exit(1) - } - - if len(instanceList) > 50 { - cmdutil.UsageError(cmd, "The --instance flag can only be used to specify a maximum of 50 instances.") - } - // If the --commands and --file options are specified, we append the script contents to the specified commands if inputFile := cmdutil.GetFlagString(cmd, "file"); inputFile != "" { - // Open our file for reading - file, err := os.Open(inputFile) - if err != nil { - log.Fatalf("Could not open file at %s\n%s", inputFile, err) - } - - defer file.Close() - - // Grab each line of the script and append it to the command slice - // Scripts using a line continuation character (\) will work fine here too! - scanner := bufio.NewScanner(file) - for scanner.Scan() { - commandList = append(commandList, scanner.Text()) - } - - if err := scanner.Err(); err != nil { - log.Fatalf("Issue when trying to read input file\n%s", err) + if err := util.ReadScriptFile(inputFile, &commandList); err != nil { + log.Fatal(err) } } - if commandList == nil || len(commandList) == 0 { + switch { // These cases are all fatal for our invocations or result in undefined behavior + case len(instanceList) > 0 && len(filterList) > 0: + cmdutil.UsageError(cmd, "The --filter and --instance flags cannot be used simultaneously.") + case len(instanceList) > 50: + cmdutil.UsageError(cmd, "The --instance flag can only be used to specify a maximum of 50 instances.") + case len(commandList) == 0: cmdutil.UsageError(cmd, "Please specify a command to be run on your instances.") - os.Exit(1) + case len(profileList) > 0 && allProfilesFlag: + cmdutil.UsageError(cmd, "The --profile and --all-profiles flags cannot be used simultaneously.") } - log.Info("Command(s) to be executed: ", strings.Join(commandList, ",")) - - if len(profileList) == 0 { - env, exists := os.LookupEnv("AWS_PROFILE") - if exists { - profileList = []string{env} - } else { - profileList = []string{"default"} - } - } - - if len(regionList) == 0 { - env, exists := os.LookupEnv("AWS_REGION") - if exists == false { - regionList = []string{env} - } - } - - // If --all-profiles is set, we call getAWSProfiles() and iterate through the user's ~/.aws/config - if allProfilesFlag := cmdutil.GetFlagBool(cmd, "all-profiles"); allProfilesFlag { - profileList, err := awsx.GetAWSProfiles() - if profileList == nil || err != nil { - log.Error("Could not load profiles.", err) - os.Exit(1) - } - } - - // Set up our AWS session for each permutation of profile + region - sessionPool := session.NewPoolSafe(profileList, regionList) - - // Convert the filter slice to a map targets := []*ssm.Target{} - if len(filterList) > 0 { - targets = util.SliceToTargets(filterList) +args: + for { + switch { + case allProfilesFlag: // If --all-profiles is set, we call getAWSProfiles() and iterate through the user's ~/.aws/config + if profileList, err := awsx.GetAWSProfiles(); profileList == nil || err != nil { + log.Fatalf("Could not load profiles.\n%v", err) + } + case len(filterList) > 0 && len(targets) == 0: // Convert the filter slice to a map + targets = util.SliceToTargets(filterList) + case len(profileList) == 0: // If no profile is specified, look it up or fall back to "default" + if env, exists := os.LookupEnv("AWS_PROFILE"); !exists { + profileList = []string{env} + } else { + profileList = []string{"default"} + } + case len(regionList) == 0: // If no region is specified, attempt to look it up + if env, exists := os.LookupEnv("AWS_REGION"); !exists { + regionList = []string{env} + } + default: + break args + } } log.Info("Command(s) to be executed:\n", strings.Join(commandList, "\n")) diff --git a/util/common.go b/util/common.go index 7a55229..0edea37 100644 --- a/util/common.go +++ b/util/common.go @@ -1,6 +1,9 @@ package util import ( + "bufio" + "fmt" + "os" "strings" "github.com/aws/aws-sdk-go/aws" @@ -41,3 +44,26 @@ func SliceToTargets(kvslice []string) (targets []*ssm.Target) { return targets } + +func ReadScriptFile(inputFile string, commandList *[]string) error { + // Open our file for reading + file, err := os.Open(inputFile) + if err != nil { + return fmt.Errorf("Could not open file at %s\n%s", inputFile, err) + } + + defer file.Close() + + // Grab each line of the script and append it to the command slice + // Scripts using a line continuation character (\) will work fine here too! + scanner := bufio.NewScanner(file) + for scanner.Scan() { + *commandList = append(*commandList, scanner.Text()) + } + + if err = scanner.Err(); err != nil { + return fmt.Errorf("Issue when trying to read input file\n%s", err) + } + + return nil +} From bc3b3f1031018fcae6135d5d2967d080fb665a88 Mon Sep 17 00:00:00 2001 From: sendqueery <47192407+sendqueery@users.noreply.github.com> Date: Tue, 4 Aug 2020 10:02:20 -0400 Subject: [PATCH 10/12] Clean up flag validation (session needs to be fixed) --- cmd/cmdutil/helpers.go | 127 +++++++++++++++--------------------- cmd/flags.go | 131 +++++++++++++++++++++++++++++++++++++ cmd/root.go | 18 ++---- cmd/run.go | 82 +++++++++-------------- cmd/session.go | 144 ++++++++++++++++++----------------------- 5 files changed, 281 insertions(+), 221 deletions(-) create mode 100644 cmd/flags.go diff --git a/cmd/cmdutil/helpers.go b/cmd/cmdutil/helpers.go index 414dfbe..e08e513 100644 --- a/cmd/cmdutil/helpers.go +++ b/cmd/cmdutil/helpers.go @@ -2,21 +2,19 @@ package cmdutil import ( "fmt" - "os" "strings" - log "github.com/sirupsen/logrus" "github.com/spf13/cobra" ) // AddProfileFlag adds --profile to command func AddProfileFlag(cmd *cobra.Command) { - cmd.PersistentFlags().StringSliceP("profile", "p", nil, "Specify a specific profile to use with your API calls.\nMultiple allowed, delimited by commas (e.g. --profile profile1,profile2)") + cmd.Flags().StringSliceP("profile", "p", nil, "Specify a specific profile to use with your API calls.\nMultiple allowed, delimited by commas (e.g. --profile profile1,profile2)") } // AddRegionFlag adds --region to command func AddRegionFlag(cmd *cobra.Command) { - cmd.PersistentFlags().StringSliceP("region", "r", nil, "Specify a specific region to use with your API calls.\n"+ + cmd.Flags().StringSliceP("region", "r", nil, "Specify a specific region to use with your API calls.\n"+ "This option will override any profile settings in your config file.\n"+ "Multiple allowed, delimited by commas (e.g. --region us-east-1,us-west-2)\n\n"+ "[NOTE] Mixing --profile and --region will result in your command targeting every matching instance in the selected profiles and regions.\n"+ @@ -29,33 +27,32 @@ func AddRegionFlag(cmd *cobra.Command) { // AddFilterFlag adds --filter to command func AddFilterFlag(cmd *cobra.Command) { - cmd.PersistentFlags().StringSliceP("filter", "f", nil, "Filter instances based on tag value. Tags are evaluated with logical AND (instances must match all tags).\nMultiple allowed, delimited by commas (e.g. env=dev,foo=bar)") + cmd.Flags().StringSliceP("filter", "f", nil, "Filter instances based on tag value. Tags are evaluated with logical AND (instances must match all tags).\nMultiple allowed, delimited by commas (e.g. env=dev,foo=bar)") } // AddDryRunFlag adds --dry-run to command func AddDryRunFlag(cmd *cobra.Command) { - cmd.PersistentFlags().Bool("dry-run", false, "Retrieve the list of profiles, regions, and instances your command(s) would target") + cmd.Flags().Bool("dry-run", false, "Retrieve the list of profiles, regions, and instances your command(s) would target") } // AddVerboseFlag adds --verbose to command func AddVerboseFlag(cmd *cobra.Command) { - cmd.PersistentFlags().IntP("verbose", "v", 2, "Sets verbosity of output:\n0 = quiet, 1 = terse, 2 = standard, 3 = debug") + cmd.Flags().IntP("verbose", "v", 2, "Sets verbosity of output:\n0 = quiet, 1 = terse, 2 = standard, 3 = debug") } // AddLimitFlag adds --limit to command func AddLimitFlag(cmd *cobra.Command, limit int, desc string) { - cmd.PersistentFlags().IntP("limit", "l", limit, desc) + cmd.Flags().IntP("limit", "l", limit, desc) } // AddInstanceFlag adds --instance to command func AddInstanceFlag(cmd *cobra.Command) { - cmd.PersistentFlags().StringSliceP("instance", "i", nil, "Specify what instance IDs you want to target.\nMultiple allowed, delimited by commas (e.g. --instance i-12345,i-23456)") + cmd.Flags().StringSliceP("instance", "i", nil, "Specify what instance IDs you want to target.\nMultiple allowed, delimited by commas (e.g. --instance i-12345,i-23456)") } // AddAllProfilesFlag adds --all-profiles to command func AddAllProfilesFlag(cmd *cobra.Command) { - cmd.PersistentFlags().Bool("all-profiles", false, "[USE WITH CAUTION] Parse through ~/.aws/config to target all profiles.") - + cmd.Flags().Bool("all-profiles", false, "[USE WITH CAUTION] Parse through ~/.aws/config to target all profiles.") } // AddCommandFlag adds --command to command @@ -79,99 +76,81 @@ func AddSessionNameFlag(cmd *cobra.Command, defaultName string) { } // ValidateArgs makes sure nothing extra was passed on CLI -func ValidateArgs(cmd *cobra.Command, args []string) { +func ValidateArgs(cmd *cobra.Command, args []string) error { if len(args) != 0 { - UsageError(cmd, "Unexpected args: %v", strings.Join(args, " ")) + return UsageError(cmd, "Unexpected args: %v", strings.Join(args, " ")) } + return nil } // UsageError Prints error and tells users to use -h -func UsageError(cmd *cobra.Command, format string, args ...interface{}) { +func UsageError(cmd *cobra.Command, format string, args ...interface{}) error { msg := fmt.Sprintf(format, args...) - fmt.Printf("%s\nSee '%s -h' for help and examples.\n", msg, cmd.CommandPath()) - os.Exit(1) + return fmt.Errorf("%s\nSee '%s -h' for help and examples", msg, cmd.CommandPath()) } // GetCommandFlagStringSlice returns the []string value of a String() flag, delimited by semicolons -func GetCommandFlagStringSlice(cmd *cobra.Command) []string { - commandString, err := cmd.Flags().GetString("command") - if err != nil { - log.WithError(err). - WithFields(log.Fields{ - "flag": "command", - "command": cmd.Name(), - }). - Error("could not fetch flag") +func GetCommandFlagStringSlice(cmd *cobra.Command) (cs []string, err error) { + var s string + if s, err = cmd.Flags().GetString("command"); err != nil { + return nil, fmt.Errorf("Could not fetch flag %v for command %v\n%v", "command", cmd.Name(), err) } - return readAsSSV(commandString) + return readAsSSV(s), nil } // GetFlagStringSlice returns the []string value of a StringSlice() flag -func GetFlagStringSlice(cmd *cobra.Command, flag string) []string { - s, err := cmd.Flags().GetStringSlice(flag) - if err != nil { - log.WithError(err). - WithFields(log.Fields{ - "flag": flag, - "command": cmd.Name(), - }). - Error("could not fetch flag") +func GetFlagStringSlice(cmd *cobra.Command, flag string) (s []string, err error) { + if s, err = cmd.Flags().GetStringSlice(flag); err != nil { + return nil, fmt.Errorf("Could not fetch flag %v for command %v\n%v", flag, cmd.Name(), err) } - return s + return s, nil } // GetFlagString returns the string value of a String() flag -func GetFlagString(cmd *cobra.Command, flag string) string { - s, err := cmd.Flags().GetString(flag) - if err != nil { - log.WithError(err). - WithFields(log.Fields{ - "flag": flag, - "command": cmd.Name(), - }). - Error("could not fetch flag") +func GetFlagString(cmd *cobra.Command, flag string) (s string, err error) { + if s, err = cmd.Flags().GetString(flag); err != nil { + return "", fmt.Errorf("Could not fetch flag %v for command %v\n%v", flag, cmd.Name(), err) } - return s + + return s, nil } // GetFlagBool returns the bool value from a Bool() flag -func GetFlagBool(cmd *cobra.Command, flag string) bool { - s, err := cmd.Flags().GetBool(flag) - if err != nil { - log.WithError(err). - WithFields(log.Fields{ - "flag": flag, - "command": cmd.Name(), - }). - Error("could not fetch flag") +func GetFlagBool(cmd *cobra.Command, flag string) (b bool, err error) { + if b, err = cmd.Flags().GetBool(flag); err != nil { + return b, fmt.Errorf("Could not fetch flag %v for command %v\n%v", flag, cmd.Name(), err) } - return s + + return b, nil } // GetFlagInt returns the integer value from an Int() flag -func GetFlagInt(cmd *cobra.Command, flag string) int { - s, err := cmd.Flags().GetInt(flag) - if err != nil { - log.WithError(err). - WithFields(log.Fields{ - "flag": flag, - "command": cmd.Name(), - }). - Error("could not fetch flag") +func GetFlagInt(cmd *cobra.Command, flag string) (i int, err error) { + if i, err = cmd.Flags().GetInt(flag); err != nil { + return i, fmt.Errorf("Could not fetch flag %v for command %v\n%v", flag, cmd.Name(), err) } - return s + + return i, nil } // GetMapFromStringSlice returns a k,v map from a StringSlice() flag -func GetMapFromStringSlice(cmd *cobra.Command, flag string) map[string]string { +func GetMapFromStringSlice(cmd *cobra.Command, flag string) (map[string]string, error) { m := make(map[string]string) - slice := GetFlagStringSlice(cmd, flag) - squashSlice := squashParamsSlice(slice, cmd) + + slice, err := GetFlagStringSlice(cmd, flag) + if err != nil { + return nil, err + } + + squashSlice, err := squashParamsSlice(slice, cmd) + if err != nil { + return nil, err + } for _, v := range squashSlice { if !strings.Contains(v, "=") { - UsageError(cmd, "Invalid Parameter format: %s\n", v) + return nil, UsageError(cmd, "Invalid Parameter format: %s\n", v) } // Only split to retun a max of 2 values. This will take string // key=value= and return ["key", "value="] @@ -179,20 +158,20 @@ func GetMapFromStringSlice(cmd *cobra.Command, flag string) map[string]string { m[kv[0]] = kv[1] } - return m + return m, nil } // Reformats CSV params passed via CLI // e.g. Input: ["Env=dev", "ElbSecurityGroups=sg-1234", "sg-5678", "App=grafana"] // Output: ["Env=dev", "ElbSecurityGroups=sg-1234,sg-5678", "App=grafana"] -func squashParamsSlice(slice []string, cmd *cobra.Command) []string { +func squashParamsSlice(slice []string, cmd *cobra.Command) ([]string, error) { sqS := make([]string, 0, len(slice)) index := -1 if len(slice) != 0 { for _, v := range slice { if !strings.Contains(v, "=") { if index < 0 { - UsageError(cmd, "Invalid Parameter format:%s\n", v) + return nil, UsageError(cmd, "Invalid Parameter format:%s\n", v) } sqS[index] = sqS[index] + "," + v } else { @@ -202,5 +181,5 @@ func squashParamsSlice(slice []string, cmd *cobra.Command) []string { } } - return sqS + return sqS, nil } diff --git a/cmd/flags.go b/cmd/flags.go new file mode 100644 index 0000000..37708d0 --- /dev/null +++ b/cmd/flags.go @@ -0,0 +1,131 @@ +package cmd + +import ( + "fmt" + "os" + + "github.com/aws/aws-sdk-go/service/ssm" + awsx "github.com/disneystreaming/ssm-helpers/aws" + "github.com/disneystreaming/ssm-helpers/cmd/cmdutil" + "github.com/disneystreaming/ssm-helpers/cmd/logutil" + "github.com/disneystreaming/ssm-helpers/util" + "github.com/sirupsen/logrus" + "github.com/spf13/cobra" +) + +func addBaseFlags(cmd *cobra.Command) { + cmdutil.AddAllProfilesFlag(cmd) + cmdutil.AddDryRunFlag(cmd) + cmdutil.AddFilterFlag(cmd) + cmdutil.AddInstanceFlag(cmd) + cmdutil.AddProfileFlag(cmd) + cmdutil.AddRegionFlag(cmd) +} + +func addRunFlags(cmd *cobra.Command) { + cmdutil.AddCommandFlag(cmd) + cmdutil.AddFileFlag(cmd, "Specify the path to a shell script to use as input for the AWS-RunShellScript document.\nThis can be used in combination with the --commands/-c flag, and will be run after the specified commands.") +} + +func getCommandList(cmd *cobra.Command) (commandList []string, err error) { + if commandList, err = cmdutil.GetCommandFlagStringSlice(cmd); err != nil { + return nil, err + } + + // If the --commands and --file options are specified, we append the script contents to the specified commands + if inputFile, err := cmdutil.GetFlagString(cmd, "file"); inputFile != "" && err == nil { + if err = util.ReadScriptFile(inputFile, &commandList); err != nil { + return nil, err + } + } else if err != nil { + return nil, err + } + + return commandList, nil +} + +func getRegionList(cmd *cobra.Command) (regionList []string, err error) { + if regionList, err = cmdutil.GetFlagStringSlice(cmd, "region"); err != nil { + return nil, err + } + + if len(regionList) == 0 { // If no region is specified, attempt to look it up + if env, exists := os.LookupEnv("AWS_REGION"); exists { + return []string{env}, nil + } + } + + return regionList, nil +} + +func getFilterList(cmd *cobra.Command) (targets []*ssm.Target, err error) { + var filterList []string + if filterList, err = cmdutil.GetFlagStringSlice(cmd, "filter"); err != nil { + return nil, err + } + + return util.SliceToTargets(filterList), nil + +} + +func getProfileList(cmd *cobra.Command) (profileList []string, err error) { + if profileList, err = cmdutil.GetFlagStringSlice(cmd, "profile"); err != nil { + return nil, err + } + + var allProfilesFlag bool + if allProfilesFlag, err = cmdutil.GetFlagBool(cmd, "all-profiles"); err != nil { + return nil, err + } + + if len(profileList) > 0 && allProfilesFlag { + return nil, cmdutil.UsageError(cmd, "The --profile and --all-profiles flags cannot be used simultaneously.") + } + + if allProfilesFlag { // If --all-profiles is set, we call getAWSProfiles() and iterate through the user's ~/.aws/config + if profileList, err = awsx.GetAWSProfiles(); profileList == nil || err != nil { + return nil, fmt.Errorf("Could not load profiles.\n%v", err) + } + } + + if len(profileList) == 0 { + if env, exists := os.LookupEnv("AWS_PROFILE"); exists { + profileList = []string{env} + } else { + profileList = []string{"default"} + } + } + + return profileList, nil +} + +// validateRunFlags validates the usage of certain flags required by the run subcommand +func validateRunFlags(cmd *cobra.Command, instanceList []string, commandList []string, filterList []*ssm.Target) error { + if len(instanceList) > 0 && len(filterList) > 0 { + return cmdutil.UsageError(cmd, "The --filter and --instance flags cannot be used simultaneously.") + } + + if len(instanceList) == 0 && len(filterList) == 0 { + return cmdutil.UsageError(cmd, "You must supply target arguments using either the --filter or --instance flags.") + } + + if len(instanceList) > 50 { + return cmdutil.UsageError(cmd, "The --instance flag can only be used to specify a maximum of 50 instances.") + } + + if len(commandList) == 0 { + return cmdutil.UsageError(cmd, "Please specify a command to be run on your instances.") + } + + return nil +} + +func setLogLevel(cmd *cobra.Command, log *logrus.Logger) (err error) { + v, err := cmdutil.GetFlagInt(cmd, "verbose") + if err != nil { + return err + } + + log.SetLevel(logutil.IntToLogLevel(v)) + return nil +} diff --git a/cmd/root.go b/cmd/root.go index afbe903..a63cd0a 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -2,7 +2,6 @@ package cmd import ( "fmt" - "os" "github.com/sirupsen/logrus" "github.com/spf13/cobra" @@ -32,19 +31,15 @@ func newRootCmd() *cobra.Command { love by spf13 and friends in Go. Complete documentation is available at http://hugo.spf13.com`, PersistentPreRun: func(cmd *cobra.Command, args []string) { - setLogLevel(cmd, args) + if err := setLogLevel(cmd, log); err != nil { + log.Fatal(err) + } logutil.SetLogSplitOutput(log) }, Version: fmt.Sprintf("%s\ngit commit hash %s", version, commit), } - cmdutil.AddProfileFlag(cmd) - cmdutil.AddRegionFlag(cmd) - cmdutil.AddInstanceFlag(cmd) - cmdutil.AddDryRunFlag(cmd) cmdutil.AddVerboseFlag(cmd) - cmdutil.AddAllProfilesFlag(cmd) - cmdutil.AddFilterFlag(cmd) cmdgroup := &builder.SubCommandGroup{ Commands: []*cobra.Command{ @@ -57,14 +52,9 @@ func newRootCmd() *cobra.Command { return cmd } -func setLogLevel(cmd *cobra.Command, args []string) { - log.Level = logutil.IntToLogLevel(cmdutil.GetFlagInt(cmd, "verbose")) -} - // Execute provides an entrypoint into the commands from main() func Execute() { if err := rootCmd.Execute(); err != nil { - fmt.Println(err) - os.Exit(1) + log.Fatal(err) } } diff --git a/cmd/run.go b/cmd/run.go index 0dc69d5..da08551 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -10,12 +10,10 @@ import ( "github.com/aws/aws-sdk-go/service/ssm" "github.com/spf13/cobra" - awsx "github.com/disneystreaming/ssm-helpers/aws" "github.com/disneystreaming/ssm-helpers/aws/session" "github.com/disneystreaming/ssm-helpers/cmd/cmdutil" ssmx "github.com/disneystreaming/ssm-helpers/ssm" "github.com/disneystreaming/ssm-helpers/ssm/invocation" - "github.com/disneystreaming/ssm-helpers/util" ) func newCommandSSMRun() *cobra.Command { @@ -28,69 +26,46 @@ func newCommandSSMRun() *cobra.Command { }, } - cmdutil.AddCommandFlag(cmd) - cmdutil.AddFileFlag(cmd, "Specify the path to a shell script to use as input for the AWS-RunShellScript document.\nThis can be used in combination with the --commands/-c flag, and will be run after the specified commands.") - cmdutil.AddLimitFlag(cmd, 0, "Set a limit for the number of instance results returned per profile/region combination (0 = no limit)") + addBaseFlags(cmd) + addRunFlags(cmd) + return cmd } func runCommand(cmd *cobra.Command, args []string) { - // Get all of our CLI flag values - cmdutil.ValidateArgs(cmd, args) - commandList := cmdutil.GetCommandFlagStringSlice(cmd) - profileList := cmdutil.GetFlagStringSlice(cmd.Parent(), "profile") - regionList := cmdutil.GetFlagStringSlice(cmd.Parent(), "region") - filterList := cmdutil.GetFlagStringSlice(cmd.Parent(), "filter") - instanceList := cmdutil.GetFlagStringSlice(cmd, "instance") - allProfilesFlag := cmdutil.GetFlagBool(cmd, "all-profiles") - - // Get the number of cores available for parallelization - runtime.GOMAXPROCS(runtime.NumCPU()) + var err error + var instanceList, commandList, profileList, regionList []string + var targets []*ssm.Target - // If the --commands and --file options are specified, we append the script contents to the specified commands - if inputFile := cmdutil.GetFlagString(cmd, "file"); inputFile != "" { - if err := util.ReadScriptFile(inputFile, &commandList); err != nil { - log.Fatal(err) - } + // Get all of our CLI flag values + if err = cmdutil.ValidateArgs(cmd, args); err != nil { + log.Fatal(err) } - switch { // These cases are all fatal for our invocations or result in undefined behavior - case len(instanceList) > 0 && len(filterList) > 0: - cmdutil.UsageError(cmd, "The --filter and --instance flags cannot be used simultaneously.") - case len(instanceList) > 50: - cmdutil.UsageError(cmd, "The --instance flag can only be used to specify a maximum of 50 instances.") - case len(commandList) == 0: - cmdutil.UsageError(cmd, "Please specify a command to be run on your instances.") - case len(profileList) > 0 && allProfilesFlag: - cmdutil.UsageError(cmd, "The --profile and --all-profiles flags cannot be used simultaneously.") + if instanceList, err = cmdutil.GetFlagStringSlice(cmd, "instance"); err != nil { + log.Fatal(err) + } + if commandList, err = getCommandList(cmd); err != nil { + log.Fatal(err) + } + if targets, err = getFilterList(cmd); err != nil { + log.Fatal(err) } - targets := []*ssm.Target{} + if err := validateRunFlags(cmd, instanceList, commandList, targets); err != nil { + log.Fatal(err) + } -args: - for { - switch { - case allProfilesFlag: // If --all-profiles is set, we call getAWSProfiles() and iterate through the user's ~/.aws/config - if profileList, err := awsx.GetAWSProfiles(); profileList == nil || err != nil { - log.Fatalf("Could not load profiles.\n%v", err) - } - case len(filterList) > 0 && len(targets) == 0: // Convert the filter slice to a map - targets = util.SliceToTargets(filterList) - case len(profileList) == 0: // If no profile is specified, look it up or fall back to "default" - if env, exists := os.LookupEnv("AWS_PROFILE"); !exists { - profileList = []string{env} - } else { - profileList = []string{"default"} - } - case len(regionList) == 0: // If no region is specified, attempt to look it up - if env, exists := os.LookupEnv("AWS_REGION"); !exists { - regionList = []string{env} - } - default: - break args - } + if profileList, err = getProfileList(cmd); err != nil { + log.Fatal(err) + } + if regionList, err = getRegionList(cmd); err != nil { + log.Fatal(err) } + // Get the number of cores available for parallelization + runtime.GOMAXPROCS(runtime.NumCPU()) + log.Info("Command(s) to be executed:\n", strings.Join(commandList, "\n")) sciInput := &ssm.SendCommandInput{ @@ -112,6 +87,7 @@ args: // Set up our AWS session for each permutation of profile + region sessionPool := session.NewPoolSafe(profileList, regionList, log) + wg, output := sync.WaitGroup{}, invocation.ResultSafe{} for _, sess := range sessionPool.Sessions { diff --git a/cmd/session.go b/cmd/session.go index cd263e1..54c0b2a 100644 --- a/cmd/session.go +++ b/cmd/session.go @@ -17,7 +17,6 @@ import ( "github.com/disneystreaming/gomux" - awsx "github.com/disneystreaming/ssm-helpers/aws" "github.com/disneystreaming/ssm-helpers/aws/session" "github.com/disneystreaming/ssm-helpers/cmd/cmdutil" ssmx "github.com/disneystreaming/ssm-helpers/ssm" @@ -42,44 +41,28 @@ func newCommandSSMSession() *cobra.Command { } func startSessionCommand(cmd *cobra.Command, args []string) { - cmdutil.ValidateArgs(cmd, args) + var err error + var instanceList, profileList, regionList, filterList, tagList []string - dryRunFlag := cmdutil.GetFlagBool(cmd.Parent(), "dry-run") - profileList := cmdutil.GetFlagStringSlice(cmd.Parent(), "profile") - regionList := cmdutil.GetFlagStringSlice(cmd.Parent(), "region") - filterList := cmdutil.GetFlagStringSlice(cmd.Parent(), "filter") - tagList := cmdutil.GetFlagStringSlice(cmd, "tag") - limitFlag := cmdutil.GetFlagInt(cmd, "limit") - instanceList := cmdutil.GetFlagStringSlice(cmd, "instance") - sessionName := cmdutil.GetFlagString(cmd, "session-name") - - // Get the number of cores available for parallelization - runtime.GOMAXPROCS(runtime.NumCPU()) - - if len(profileList) == 0 { - env, exists := os.LookupEnv("AWS_PROFILE") - if exists { - profileList = []string{env} - } else { - profileList = []string{"default"} - } + // Get all of our CLI flag values + if err = cmdutil.ValidateArgs(cmd, args); err != nil { + log.Fatal(err) } - if len(regionList) == 0 { - env, exists := os.LookupEnv("AWS_REGION") - if exists == false { - regionList = []string{env} - } + if profileList, err = getProfileList(cmd); err != nil { + log.Fatal(err) } - - // If --all-profiles is set, we call getAWSProfiles() and iterate through the user's ~/.aws/config - if allProfilesFlag := cmdutil.GetFlagBool(cmd, "all-profiles"); allProfilesFlag { - profileList, err := awsx.GetAWSProfiles() - if profileList == nil || err != nil { - log.Error("Could not load profiles.", err) - os.Exit(1) - } + if regionList, err = getRegionList(cmd); err != nil { + log.Fatal(err) } + dryRunFlag, err := cmdutil.GetFlagBool(cmd.Parent(), "dry-run") + filterList, err = cmdutil.GetFlagStringSlice(cmd.Parent(), "filter") + tagList, err = cmdutil.GetFlagStringSlice(cmd, "tag") + limitFlag, err := cmdutil.GetFlagInt(cmd, "limit") + sessionName, err := cmdutil.GetFlagString(cmd, "session-name") + + // Get the number of cores available for parallelization + runtime.GOMAXPROCS(runtime.NumCPU()) // Set up our AWS session for each permutation of profile + region sessionPool := session.NewPoolSafe(profileList, regionList, log) @@ -136,64 +119,65 @@ func startSessionCommand(cmd *cobra.Command, args []string) { } // If -i flag is set, don't prompt for instance selection - if !dryRunFlag { - // Single instance specified or found, starting session in current terminal (non-multiplexed) - if len(instanceList) == 1 { - for _, v := range instancePool.AllInstances { - if err := startSSMSession(v.Profile, v.Region, v.InstanceID); err != nil { - log.Errorf("Failed to start ssm-session for instance %s\n%s", v.InstanceID, err) - } + if dryRunFlag { + return + } + // Single instance specified or found, starting session in current terminal (non-multiplexed) + if len(instanceList) == 1 { + for _, v := range instancePool.AllInstances { + if err := startSSMSession(v.Profile, v.Region, v.InstanceID); err != nil { + log.Errorf("Failed to start ssm-session for instance %s\n%s", v.InstanceID, err) } - return } + return + } - // Multiple instances specified or found, check to see if we're in a tmux session to avoid nesting - if len(instanceList) > 1 && len(instancePool.AllInstances) > 1 { - var instances []instance.InstanceInfo - for _, v := range instancePool.AllInstances { - instances = append(instances, v) - } + // Multiple instances specified or found, check to see if we're in a tmux session to avoid nesting + if len(instanceList) > 1 && len(instancePool.AllInstances) > 1 { + var instances []instance.InstanceInfo + for _, v := range instancePool.AllInstances { + instances = append(instances, v) + } - if err := configTmuxSession(sessionName, instances); err != nil { - log.Fatal(err) - } - } else { - // If -i was not specified, go to a selection prompt before starting sessions - selectedInstances, err := startSelectionPrompt(&instancePool, totalInstances, tagList) - if err != nil { - if err == terminal.InterruptErr { - log.Info("Instance selection interrupted.") - os.Exit(0) - } - log.Errorf("Error during instance selection\n%s", err) - os.Exit(1) + if err := configTmuxSession(sessionName, instances); err != nil { + log.Fatal(err) + } + } else { + // If -i was not specified, go to a selection prompt before starting sessions + selectedInstances, err := startSelectionPrompt(&instancePool, totalInstances, tagList) + if err != nil { + if err == terminal.InterruptErr { + log.Info("Instance selection interrupted.") + os.Exit(0) } + log.Errorf("Error during instance selection\n%s", err) + os.Exit(1) + } - // If only one instance was selected, don't bother with a tmux session - if len(selectedInstances) == 1 { - for _, v := range selectedInstances { - if err := startSSMSession(v.Profile, v.Region, v.InstanceID); err != nil { - log.Errorf("Failed to start ssm-session for instance %s\n%s", v.InstanceID, err) - } + // If only one instance was selected, don't bother with a tmux session + if len(selectedInstances) == 1 { + for _, v := range selectedInstances { + if err := startSSMSession(v.Profile, v.Region, v.InstanceID); err != nil { + log.Errorf("Failed to start ssm-session for instance %s\n%s", v.InstanceID, err) } - return } + return + } - if err = configTmuxSession(sessionName, selectedInstances); err != nil { - log.Fatal(err) - } + if err = configTmuxSession(sessionName, selectedInstances); err != nil { + log.Fatal(err) } + } - // Make sure we aren't going to nest tmux sessions - currentTmuxSocket := os.Getenv("TMUX") - if len(currentTmuxSocket) == 0 { - if err := attachTmuxSession(sessionName); err != nil { - log.Errorf("Could not attach to tmux session '%s'\n%s", sessionName, err) - } - } else { - log.Info("To force nested Tmux sessions unset $TMUX") - log.Infof("Attach to the session with `tmux attach -t %s`", sessionName) + // Make sure we aren't going to nest tmux sessions + currentTmuxSocket := os.Getenv("TMUX") + if len(currentTmuxSocket) == 0 { + if err := attachTmuxSession(sessionName); err != nil { + log.Errorf("Could not attach to tmux session '%s'\n%s", sessionName, err) } + } else { + log.Info("To force nested Tmux sessions unset $TMUX") + log.Infof("Attach to the session with `tmux attach -t %s`", sessionName) } } From 5cca56895011368754ec823e9ce94587af604ebd Mon Sep 17 00:00:00 2001 From: sendqueery <47192407+sendqueery@users.noreply.github.com> Date: Tue, 4 Aug 2020 10:03:11 -0400 Subject: [PATCH 11/12] Change instances of "ctx" to "client" --- ssm/helpers.go | 15 ++++++++------- ssm/invocation/runner.go | 8 ++++---- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/ssm/helpers.go b/ssm/helpers.go index b37e983..a73ea84 100644 --- a/ssm/helpers.go +++ b/ssm/helpers.go @@ -57,9 +57,9 @@ func addInstanceInfo(instanceID *string, tags []ec2helpers.InstanceTags, instanc } } -func checkInvocationStatus(ctx ssmiface.SSMAPI, commandID *string) (done bool, err error) { +func checkInvocationStatus(client ssmiface.SSMAPI, commandID *string) (done bool, err error) { var invocation *ssm.ListCommandsOutput - if invocation, err = ctx.ListCommands(&ssm.ListCommandsInput{ + if invocation, err = client.ListCommands(&ssm.ListCommandsInput{ CommandId: commandID, }); err != nil { return true, fmt.Errorf("Encountered an error when trying to call the ListCommands API with CommandId: %v\n%v", *commandID, err) @@ -78,7 +78,7 @@ func checkInvocationStatus(ctx ssmiface.SSMAPI, commandID *string) (done bool, e } // RunInvocations invokes an SSM document with given parameters on the provided slice of instances -func RunInvocations(sess *session.Pool, ctx ssmiface.SSMAPI, wg *sync.WaitGroup, input *ssm.SendCommandInput, results *invocation.ResultSafe) { +func RunInvocations(sess *session.Pool, client ssmiface.SSMAPI, wg *sync.WaitGroup, input *ssm.SendCommandInput, results *invocation.ResultSafe) { defer wg.Done() oc := make(chan *ssm.GetCommandInvocationOutput) @@ -87,16 +87,17 @@ func RunInvocations(sess *session.Pool, ctx ssmiface.SSMAPI, wg *sync.WaitGroup, var err error // Send our command input to SSM - if scOutput, err = ctx.SendCommand(input); err != nil { + if scOutput, err = client.SendCommand(input); err != nil { sess.Logger.Errorf("Error when calling the SendCommand API for account %v in %v\n%v", sess.ProfileName, *sess.Session.Config.Region, err) return } commandID := scOutput.Command.CommandId + sess.Logger.Infof("Started invocation %v for %v in %v", *commandID, sess.ProfileName, *sess.Session.Config.Region) // Watch status of invocation to see when it's done and we can get the output for done := false; !done; time.Sleep(2 * time.Second) { - if done, err = checkInvocationStatus(ctx, commandID); err != nil { + if done, err = checkInvocationStatus(client, commandID); err != nil { sess.Logger.Error(err) return } @@ -108,12 +109,12 @@ func RunInvocations(sess *session.Pool, ctx ssmiface.SSMAPI, wg *sync.WaitGroup, } // Iterate through the details of the invocations returned - if err = ctx.ListCommandInvocationsPages( + if err = client.ListCommandInvocationsPages( lciInput, func(page *ssm.ListCommandInvocationsOutput, lastPage bool) bool { for _, entry := range page.CommandInvocations { // Fetch the results of our invocation for all provided instances - go invocation.GetResult(ctx, commandID, entry.InstanceId, oc, ec) + go invocation.GetResult(client, commandID, entry.InstanceId, oc, ec) // Wait for results to return until the combined total of results and errors select { diff --git a/ssm/invocation/runner.go b/ssm/invocation/runner.go index d341ec6..12f084b 100644 --- a/ssm/invocation/runner.go +++ b/ssm/invocation/runner.go @@ -19,13 +19,13 @@ func RunSSMCommand(session ssmiface.SSMAPI, input *ssm.SendCommandInput, dryRunF return } -func GetTargets(ctx ssmiface.SSMAPI, commandID *string) (targets []*string, err error) { +func GetTargets(client ssmiface.SSMAPI, commandID *string) (targets []*string, err error) { var out *ssm.ListCommandInvocationsOutput // Try a few times to get the invocation data, because it takes a little bit to have any information for i := 0; i < 3; i++ { time.Sleep(1 * time.Second) - if out, err = ctx.ListCommandInvocations(&ssm.ListCommandInvocationsInput{ + if out, err = client.ListCommandInvocations(&ssm.ListCommandInvocationsInput{ CommandId: commandID, }); err != nil { return nil, err @@ -47,8 +47,8 @@ func GetTargets(ctx ssmiface.SSMAPI, commandID *string) (targets []*string, err return targets, nil } -func GetResult(ctx ssmiface.SSMAPI, commandID *string, instanceID *string, gci chan *ssm.GetCommandInvocationOutput, ec chan error) { - status, err := ctx.GetCommandInvocation(&ssm.GetCommandInvocationInput{ +func GetResult(client ssmiface.SSMAPI, commandID *string, instanceID *string, gci chan *ssm.GetCommandInvocationOutput, ec chan error) { + status, err := client.GetCommandInvocation(&ssm.GetCommandInvocationInput{ CommandId: commandID, InstanceId: instanceID, }) From f51e92d59b1eae0b737c445cf6c4ac020f105c32 Mon Sep 17 00:00:00 2001 From: sendqueery <47192407+sendqueery@users.noreply.github.com> Date: Tue, 4 Aug 2020 10:04:11 -0400 Subject: [PATCH 12/12] Fix session validation --- aws/config/config.go | 20 +++--------- aws/session/session.go | 72 ++++++++++++++++++++++-------------------- util/httpx/httpx.go | 6 ++++ 3 files changed, 48 insertions(+), 50 deletions(-) diff --git a/aws/config/config.go b/aws/config/config.go index 6020904..734ed12 100644 --- a/aws/config/config.go +++ b/aws/config/config.go @@ -15,23 +15,13 @@ import ( // // This means we can change our sessions to be `session.New(, session.NewDefaultConfig()) // If we need to override it we can swap order (last config's value wins) -func NewDefaultConfig() *aws.Config { +func NewDefaultConfig(region string) *aws.Config { return &aws.Config{ - HTTPClient: httpx.NewDefaultClient(), + CredentialsChainVerboseErrors: aws.Bool(true), + Region: aws.String(region), + HTTPClient: httpx.NewDefaultClient(), Retryer: &client.DefaultRetryer{ - NumMaxRetries: 10, - MaxThrottleDelay: 1500 * time.Millisecond, - }, - } -} - -// NewDefaultConfigWithRegion returns the same config as NewDefaultConfig, but allows a user to specify a region as well -func NewDefaultConfigWithRegion(region string) *aws.Config { - return &aws.Config{ - Region: aws.String(region), - HTTPClient: httpx.NewDefaultClient(), - Retryer: &client.DefaultRetryer{ - NumMaxRetries: 10, + NumMaxRetries: 3, MaxThrottleDelay: 1500 * time.Millisecond, }, } diff --git a/aws/session/session.go b/aws/session/session.go index e90cfc4..3c07cb5 100644 --- a/aws/session/session.go +++ b/aws/session/session.go @@ -5,13 +5,13 @@ import ( "sync" "github.com/aws/aws-sdk-go/aws/session" - "github.com/sirupsen/logrus" + log "github.com/sirupsen/logrus" "github.com/disneystreaming/ssm-helpers/aws/config" ) // NewPoolSafe is used to create a pool of AWS sessions with different profile/region permutations -func NewPoolSafe(profiles []string, regions []string, logger *logrus.Logger) (allSessions *PoolSafe) { +func NewPoolSafe(profiles []string, regions []string, logger *log.Logger) (allSessions *PoolSafe) { wg := sync.WaitGroup{} sp := &PoolSafe{ @@ -19,76 +19,78 @@ func NewPoolSafe(profiles []string, regions []string, logger *logrus.Logger) (al } if len(regions) == 0 { - wg.Add(len(profiles) * len(regions)) for _, p := range profiles { for _, r := range regions { // Wait until we have the session for each permutation of profiles and regions go func(p string, r string) { defer wg.Done() - newSession := newSession(p, r) - sp.Lock() + s, err := newSession(p, r) + if err != nil { + logger.Fatalf("Error when trying to create session:\n%v", err) + } + + if err := validateSessionCreds(s); s != nil { + logger.Fatal(err) + } session := Pool{ Logger: logger, ProfileName: p, - Session: newSession, + Session: s, } sp.Sessions[fmt.Sprintf("%s-%s", p, r)] = &session - //sp.Sessions = append(sp.Sessions, &session) - defer sp.Unlock() }(p, r) } } } else { - wg.Add(len(profiles)) for _, p := range profiles { // Wait until we have the session for each profile go func(p string) { defer wg.Done() - newSession := newSession(p, "") - sp.Lock() + s, err := newSession(p, "") + if err != nil { + logger.Fatalf("Error when trying to create session:\n%v", err) + } + + if err := validateSessionCreds(s); s != nil { + logger.Fatal(err) + } session := Pool{ Logger: logger, ProfileName: p, - Session: newSession, + Session: s, } sp.Sessions[fmt.Sprintf("%s", p)] = &session - defer sp.Unlock() }(p) } } - // Wait until all sessions have been initialized wg.Wait() return sp } -// createSession uses a given profile and region to call NewSessionWithOptions() to initialize an instance of the AWS client with the given settings. +func validateSessionCreds(session *session.Session) (err error) { + creds := session.Config.Credentials + if _, err := creds.Get(); err != nil { + return fmt.Errorf("Error when validating credentials:\n%v", err) + } + + return nil +} + +// newSession uses a given profile and region to call NewSessionWithOptions() to initialize an instance of the AWS client with the given settings. // If the region is nil, it defaults to the default region in the ~/.aws/config file or the AWS_REGION environment variable. -func newSession(profile string, region string) (newSession *session.Session) { +func newSession(profile string, region string) (newSession *session.Session, err error) { // Create AWS session from shared config // This will import the AWS_PROFILE envvar from your console, if set - if region != "" { - newSession = session.Must( - session.NewSessionWithOptions( - session.Options{ - Config: *config.NewDefaultConfigWithRegion(region), - Profile: profile, - SharedConfigState: session.SharedConfigEnable, - })) - } else { - newSession = session.Must( - session.NewSessionWithOptions( - session.Options{ - Config: *config.NewDefaultConfig(), - Profile: profile, - SharedConfigState: session.SharedConfigEnable, - })) - } - - return newSession + return session.NewSessionWithOptions( + session.Options{ + Config: *config.NewDefaultConfig(region), + Profile: profile, + SharedConfigState: session.SharedConfigEnable, + }) } diff --git a/util/httpx/httpx.go b/util/httpx/httpx.go index da2b204..ee471f7 100644 --- a/util/httpx/httpx.go +++ b/util/httpx/httpx.go @@ -1,6 +1,7 @@ package httpx import ( + "net" "net/http" "time" ) @@ -10,5 +11,10 @@ import ( func NewDefaultClient() *http.Client { return &http.Client{ Timeout: 10 * time.Second, + Transport: &http.Transport{ + Dial: (&net.Dialer{ + Timeout: 2 * time.Second, + }).Dial, + }, } }