diff options
author | Michael Pu <michael.pu@uwaterloo.ca> | 2023-11-18 20:43:28 -0500 |
---|---|---|
committer | Cecylia Bocovich <cohosh@torproject.org> | 2024-01-22 13:06:42 -0500 |
commit | 8fb17de1529281d30f1eb3c9d746de70673337fa (patch) | |
tree | ba08f5d26210de1633e890b0f7b4d187cdb6bacd /broker | |
parent | d0529141acb706f64e4defebd22a7d8604d831db (diff) | |
download | snowflake-8fb17de1529281d30f1eb3c9d746de70673337fa.tar.gz snowflake-8fb17de1529281d30f1eb3c9d746de70673337fa.zip |
Implement SQS rendezvous in client and broker
This features adds an additional rendezvous method to send client offers
and receive proxy answers through the use of Amazon SQS queues.
https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/issues/26151
Diffstat (limited to 'broker')
-rw-r--r-- | broker/broker.go | 17 | ||||
-rw-r--r-- | broker/sqs.go | 195 |
2 files changed, 211 insertions, 1 deletions
diff --git a/broker/broker.go b/broker/broker.go index 33f45ab..06b530a 100644 --- a/broker/broker.go +++ b/broker/broker.go @@ -8,9 +8,9 @@ package main import ( "bytes" "container/heap" + "context" "crypto/tls" "flag" - "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/bridgefingerprint" "io" "log" "net/http" @@ -21,6 +21,8 @@ import ( "syscall" "time" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/bridgefingerprint" + "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/namematcher" @@ -190,6 +192,7 @@ func main() { var geoipDatabase string var geoip6Database string var bridgeListFilePath, allowedRelayPattern, presumedPatternForLegacyClient string + var brokerSQSQueueName, brokerSQSQueueRegion string var disableTLS bool var certFilename, keyFilename string var disableGeoip bool @@ -207,6 +210,8 @@ func main() { flag.StringVar(&bridgeListFilePath, "bridge-list-path", "", "file path for bridgeListFile") flag.StringVar(&allowedRelayPattern, "allowed-relay-pattern", "", "allowed pattern for relay host name") flag.StringVar(&presumedPatternForLegacyClient, "default-relay-pattern", "", "presumed pattern for legacy client") + flag.StringVar(&brokerSQSQueueName, "broker-sqs-name", "", "name of broker SQS queue to listen for incoming messages on") + flag.StringVar(&brokerSQSQueueRegion, "broker-sqs-region", "", "name of AWS region of broker SQS queue") flag.BoolVar(&disableTLS, "disable-tls", false, "don't use HTTPS") flag.BoolVar(&disableGeoip, "disable-geoip", false, "don't use geoip for stats collection") flag.StringVar(&metricsFilename, "metrics-log", "", "path to metrics logging output") @@ -276,6 +281,16 @@ func main() { Addr: addr, } + // Run SQS Handler to continuously poll and process messages from SQS + if brokerSQSQueueName != "" && brokerSQSQueueRegion != "" { + sqsHandlerContext := context.Background() + sqsHandler, err := newSQSHandler(sqsHandlerContext, brokerSQSQueueName, brokerSQSQueueRegion, i) + if err != nil { + log.Fatal(err) + } + go sqsHandler.PollAndHandleMessages(sqsHandlerContext) + } + sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGHUP) diff --git a/broker/sqs.go b/broker/sqs.go new file mode 100644 index 0000000..f3d3e79 --- /dev/null +++ b/broker/sqs.go @@ -0,0 +1,195 @@ +package main + +import ( + "context" + "log" + "strconv" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/sqs" + "github.com/aws/aws-sdk-go-v2/service/sqs/types" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages" +) + +const ( + cleanupInterval = time.Second * 30 + cleanupThreshold = -2 * time.Minute +) + +type sqsHandler struct { + SQSClient *sqs.Client + SQSQueueURL *string + IPC *IPC +} + +func (r *sqsHandler) pollMessages(context context.Context, chn chan<- *types.Message) { + for { + res, err := r.SQSClient.ReceiveMessage(context, &sqs.ReceiveMessageInput{ + QueueUrl: r.SQSQueueURL, + MaxNumberOfMessages: 10, + WaitTimeSeconds: 15, + MessageAttributeNames: []string{ + string(types.QueueAttributeNameAll), + }, + }) + + if err != nil { + log.Printf("SQSHandler: encountered error while polling for messages: %v\n", err) + } + + for _, message := range res.Messages { + chn <- &message + } + } +} + +func (r *sqsHandler) cleanupClientQueues(context context.Context) { + for range time.Tick(cleanupInterval) { + // Runs at fixed intervals to clean up any client queues that were last changed more than 2 minutes ago + queueURLsList := []string{} + var nextToken *string + for { + res, err := r.SQSClient.ListQueues(context, &sqs.ListQueuesInput{ + QueueNamePrefix: aws.String("snowflake-client-"), + MaxResults: aws.Int32(1000), + NextToken: nextToken, + }) + if err != nil { + log.Printf("SQSHandler: encountered error while retrieving client queues to clean up: %v\n", err) + } + queueURLsList = append(queueURLsList, res.QueueUrls...) + if res.NextToken == nil { + break + } else { + nextToken = res.NextToken + } + } + + numDeleted := 0 + cleanupCutoff := time.Now().Add(cleanupThreshold) + for _, queueURL := range queueURLsList { + if !strings.Contains(queueURL, "snowflake-client-") { + continue + } + res, err := r.SQSClient.GetQueueAttributes(context, &sqs.GetQueueAttributesInput{ + QueueUrl: aws.String(queueURL), + AttributeNames: []types.QueueAttributeName{types.QueueAttributeNameLastModifiedTimestamp}, + }) + if err != nil { + // According to the AWS SQS docs, the deletion process for a queue can take up to 60 seconds. So the queue + // can be in the process of being deleted, but will still be returned by the ListQueues operation, but + // fail when we try to GetQueueAttributes for the queue + log.Printf("SQSHandler: encountered error while getting attribute of client queue %s. queue may already be deleted.\n", queueURL) + continue + } + lastModifiedInt64, err := strconv.ParseInt(res.Attributes[string(types.QueueAttributeNameLastModifiedTimestamp)], 10, 64) + if err != nil { + log.Printf("SQSHandler: encountered invalid lastModifiedTimetamp value from client queue %s: %v\n", queueURL, err) + continue + } + lastModified := time.Unix(lastModifiedInt64, 0) + if lastModified.Before(cleanupCutoff) { + _, err := r.SQSClient.DeleteQueue(context, &sqs.DeleteQueueInput{ + QueueUrl: aws.String(queueURL), + }) + if err != nil { + log.Printf("SQSHandler: encountered error when deleting client queue %s: %v\n", queueURL, err) + continue + } else { + numDeleted += 1 + } + + } + } + log.Printf("SQSHandler: finished running iteration of client queue cleanup. found and deleted %d client queues.\n", numDeleted) + } +} + +func (r *sqsHandler) handleMessage(context context.Context, message *types.Message) { + var encPollReq []byte + var response []byte + var err error + + clientID := message.MessageAttributes["ClientID"].StringValue + if clientID == nil { + log.Println("SQSHandler: got SDP offer in SQS message with no client ID. ignoring this message.") + return + } + + res, err := r.SQSClient.CreateQueue(context, &sqs.CreateQueueInput{ + QueueName: aws.String("snowflake-client-" + *clientID), + }) + answerSQSURL := res.QueueUrl + if err != nil { + log.Printf("SQSHandler: error encountered when creating answer queue for client %s: %v\n", *clientID, err) + } + + encPollReq = []byte(*message.Body) + arg := messages.Arg{ + Body: encPollReq, + RemoteAddr: "", + } + err = r.IPC.ClientOffers(arg, &response) + + if err != nil { + log.Printf("SQSHandler: error encountered when handling message: %v\n", err) + return + } + + r.SQSClient.SendMessage(context, &sqs.SendMessageInput{ + QueueUrl: answerSQSURL, + MessageBody: aws.String(string(response)), + }) +} + +func (r *sqsHandler) deleteMessage(context context.Context, message *types.Message) { + r.SQSClient.DeleteMessage(context, &sqs.DeleteMessageInput{ + QueueUrl: r.SQSQueueURL, + ReceiptHandle: message.ReceiptHandle, + }) +} + +func newSQSHandler(context context.Context, sqsQueueName string, region string, i *IPC) (*sqsHandler, error) { + log.Printf("Loading SQSHandler using SQS Queue %s in region %s\n", sqsQueueName, region) + cfg, err := config.LoadDefaultConfig(context, config.WithRegion(region)) + if err != nil { + return nil, err + } + + client := sqs.NewFromConfig(cfg) + + // Creates the queue if a queue with the same name doesn't exist. If a queue with the same name and attributes + // already exists, then nothing will happen. If a queue with the same name, but different attributes exists, then + // an error will be returned + res, err := client.CreateQueue(context, &sqs.CreateQueueInput{ + QueueName: aws.String(sqsQueueName), + Attributes: map[string]string{ + "MessageRetentionPeriod": strconv.FormatInt(int64((5 * time.Minute).Seconds()), 10), + }, + }) + + if err != nil { + return nil, err + } + + return &sqsHandler{ + SQSClient: client, + SQSQueueURL: res.QueueUrl, + IPC: i, + }, nil +} + +func (r *sqsHandler) PollAndHandleMessages(context context.Context) { + log.Println("SQSHandler: Starting to poll for messages at: " + *r.SQSQueueURL) + messagesChn := make(chan *types.Message, 2) + go r.pollMessages(context, messagesChn) + go r.cleanupClientQueues(context) + + for message := range messagesChn { + r.handleMessage(context, message) + r.deleteMessage(context, message) + } +} |