From 9fe2ca58a024afe7172258d5c7ff0803bd08b57e Mon Sep 17 00:00:00 2001 From: Michael Pu Date: Sat, 2 Mar 2024 16:37:51 -0500 Subject: Switch to sqscreds param for passing in SQS credentials --- client/lib/rendezvous.go | 6 +++--- client/lib/rendezvous_sqs.go | 10 ++++++++-- client/lib/rendezvous_test.go | 2 +- client/lib/snowflake.go | 5 ++--- client/snowflake.go | 13 ++++--------- common/sqscreds/generate_creds.go | 36 ++++++++++++++++++++++++++++++++++++ common/sqscreds/lib/sqs_creds.go | 35 +++++++++++++++++++++++++++++++++++ 7 files changed, 89 insertions(+), 18 deletions(-) create mode 100644 common/sqscreds/generate_creds.go create mode 100644 common/sqscreds/lib/sqs_creds.go diff --git a/client/lib/rendezvous.go b/client/lib/rendezvous.go index 908b2ff..f25212c 100644 --- a/client/lib/rendezvous.go +++ b/client/lib/rendezvous.go @@ -94,11 +94,11 @@ func newBrokerChannelFromConfig(config ClientConfig) (*BrokerChannel, error) { if config.AmpCacheURL != "" || config.BrokerURL != "" { log.Fatalln("Multiple rendezvous methods specified. " + rendezvousErrorMsg) } - if config.SQSAccessKeyID == "" || config.SQSSecretKey == "" { - log.Fatalln("sqsakid and sqsskey must be specified to use SQS rendezvous method.") + if config.SQSCredsStr == "" { + log.Fatalln("sqscreds must be specified to use SQS rendezvous method.") } log.Println("Through SQS queue at:", config.SQSQueueURL) - rendezvous, err = newSQSRendezvous(config.SQSQueueURL, config.SQSAccessKeyID, config.SQSSecretKey, brokerTransport) + rendezvous, err = newSQSRendezvous(config.SQSQueueURL, config.SQSCredsStr, brokerTransport) } else if config.AmpCacheURL != "" && config.BrokerURL != "" { log.Println("Through AMP cache at:", config.AmpCacheURL) rendezvous, err = newAMPCacheRendezvous( diff --git a/client/lib/rendezvous_sqs.go b/client/lib/rendezvous_sqs.go index 423545f..6b1c073 100644 --- a/client/lib/rendezvous_sqs.go +++ b/client/lib/rendezvous_sqs.go @@ -16,6 +16,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/sqs" "github.com/aws/aws-sdk-go-v2/service/sqs/types" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/sqsclient" + sqscreds "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/sqscreds/lib" ) type sqsRendezvous struct { @@ -26,12 +27,17 @@ type sqsRendezvous struct { numRetries int } -func newSQSRendezvous(sqsQueue string, sqsAccessKeyId string, sqsSecretKey string, transport http.RoundTripper) (*sqsRendezvous, error) { +func newSQSRendezvous(sqsQueue string, sqsCredsStr string, transport http.RoundTripper) (*sqsRendezvous, error) { sqsURL, err := url.Parse(sqsQueue) if err != nil { return nil, err } + sqsCreds, err := sqscreds.AwsCredsFromBase64(sqsCredsStr) + if err != nil { + return nil, err + } + queueURL := sqsURL.String() hostName := sqsURL.Hostname() @@ -43,7 +49,7 @@ func newSQSRendezvous(sqsQueue string, sqsAccessKeyId string, sqsSecretKey strin region := res[1] cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithCredentialsProvider( - credentials.NewStaticCredentialsProvider(sqsAccessKeyId, sqsSecretKey, ""), + credentials.NewStaticCredentialsProvider(sqsCreds.AwsAccessKeyId, sqsCreds.AwsSecretKey, ""), ), config.WithRegion(region), ) diff --git a/client/lib/rendezvous_test.go b/client/lib/rendezvous_test.go index 6589593..d2460d3 100644 --- a/client/lib/rendezvous_test.go +++ b/client/lib/rendezvous_test.go @@ -284,7 +284,7 @@ func TestSQSRendezvous(t *testing.T) { Convey("Construct SQS queue rendezvous", func() { transport := &mockTransport{http.StatusOK, []byte{}} - rend, err := newSQSRendezvous("https://sqs.us-east-1.amazonaws.com", "some-access-key-id", "some-secret-key", transport) + rend, err := newSQSRendezvous("https://sqs.us-east-1.amazonaws.com", "eyJhd3MtYWNjZXNzLWtleS1pZCI6InRlc3QtYWNjZXNzLWtleSIsImF3cy1zZWNyZXQta2V5IjoidGVzdC1zZWNyZXQta2V5In0=", transport) So(err, ShouldBeNil) So(rend.sqsClient, ShouldNotBeNil) diff --git a/client/lib/snowflake.go b/client/lib/snowflake.go index cf9bd6e..7d405cd 100644 --- a/client/lib/snowflake.go +++ b/client/lib/snowflake.go @@ -89,9 +89,8 @@ type ClientConfig struct { // SQSQueueURL is the full URL of an AWS SQS Queue. A nonzero value indicates // that SQS queue will be used as the rendezvous method with the broker. SQSQueueURL string - // Access Key ID and Secret Key of the credentials used to access the AWS SQS Qeueue - SQSAccessKeyID string - SQSSecretKey string + // Base64 encoded string of the credentials containing access Key ID and secret key used to access the AWS SQS Qeueue + SQSCredsStr string // FrontDomain is the full URL of an optional front domain that can be used with either // the AMP cache or HTTP domain fronting rendezvous method. FrontDomain string diff --git a/client/snowflake.go b/client/snowflake.go index 7ff0a35..4ebbaf5 100644 --- a/client/snowflake.go +++ b/client/snowflake.go @@ -84,11 +84,8 @@ func socksAcceptLoop(ln *pt.SocksListener, config sf.ClientConfig, shutdown chan if arg, ok := conn.Req.Args.Get("sqsqueue"); ok { config.SQSQueueURL = arg } - if arg, ok := conn.Req.Args.Get("sqsakid"); ok { - config.SQSAccessKeyID = arg - } - if arg, ok := conn.Req.Args.Get("sqsskey"); ok { - config.SQSSecretKey = arg + if arg, ok := conn.Req.Args.Get("sqscreds"); ok { + config.SQSCredsStr = arg } if arg, ok := conn.Req.Args.Get("fronts"); ok { if arg != "" { @@ -169,8 +166,7 @@ func main() { frontDomainsCommas := flag.String("fronts", "", "comma-separated list of front domains") ampCacheURL := flag.String("ampcache", "", "URL of AMP cache to use as a proxy for signaling") sqsQueueURL := flag.String("sqsqueue", "", "URL of SQS Queue to use as a proxy for signaling") - sqsAccessKeyId := flag.String("sqsakid", "", "Access Key ID for credentials to access SQS Queue ") - sqsSecretKey := flag.String("sqsskey", "", "Secret Key for credentials to access SQS Queue") + sqsCredsStr := flag.String("sqscreds", "", "credentials to access SQS Queue") logFilename := flag.String("log", "", "name of log file") logToStateDir := flag.Bool("log-to-state-dir", false, "resolve the log file relative to tor's pt state dir") keepLocalAddresses := flag.Bool("keep-local-addresses", false, "keep local LAN address ICE candidates") @@ -239,8 +235,7 @@ func main() { BrokerURL: *brokerURL, AmpCacheURL: *ampCacheURL, SQSQueueURL: *sqsQueueURL, - SQSAccessKeyID: *sqsAccessKeyId, - SQSSecretKey: *sqsSecretKey, + SQSCredsStr: *sqsCredsStr, FrontDomains: frontDomains, ICEAddresses: iceAddresses, KeepLocalAddresses: *keepLocalAddresses || *oldKeepLocalAddresses, diff --git a/common/sqscreds/generate_creds.go b/common/sqscreds/generate_creds.go new file mode 100644 index 0000000..0f89225 --- /dev/null +++ b/common/sqscreds/generate_creds.go @@ -0,0 +1,36 @@ +package main + +import ( + "fmt" + + sqscreds "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/sqscreds/lib" +) + +// This script can be run to generate the encoded SQS credentials to pass as a CLI param or SOCKS option to the client +func main() { + var accessKey, secretKey string + + fmt.Print("Enter Access Key: ") + _, err := fmt.Scanln(&accessKey) + if err != nil { + fmt.Println("Error reading access key:", err) + return + } + + fmt.Print("Enter Secret Key: ") + _, err = fmt.Scanln(&secretKey) + if err != nil { + fmt.Println("Error reading access key:", err) + return + } + + awsCreds := sqscreds.AwsCreds{AwsAccessKeyId: accessKey, AwsSecretKey: secretKey} + println() + println("Encoded Credentials:") + res, err := awsCreds.Base64() + if err != nil { + fmt.Println("Error encoding credentials:", err) + return + } + println(res) +} diff --git a/common/sqscreds/lib/sqs_creds.go b/common/sqscreds/lib/sqs_creds.go new file mode 100644 index 0000000..dba1828 --- /dev/null +++ b/common/sqscreds/lib/sqs_creds.go @@ -0,0 +1,35 @@ +package sqscreds + +import ( + "encoding/base64" + "encoding/json" +) + +type AwsCreds struct { + AwsAccessKeyId string `json:"aws-access-key-id"` + AwsSecretKey string `json:"aws-secret-key"` +} + +func (awsCreds AwsCreds) Base64() (string, error) { + jsonData, err := json.Marshal(awsCreds) + if err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(jsonData), nil +} + +func AwsCredsFromBase64(base64Str string) (AwsCreds, error) { + var awsCreds AwsCreds + + jsonData, err := base64.StdEncoding.DecodeString(base64Str) + if err != nil { + return awsCreds, err + } + + err = json.Unmarshal(jsonData, &awsCreds) + if err != nil { + return awsCreds, err + } + + return awsCreds, nil +} -- cgit v1.2.3-54-g00ecf