aboutsummaryrefslogtreecommitdiff
path: root/broker
diff options
context:
space:
mode:
authorMichael Pu <michael.pu@uwaterloo.ca>2023-11-18 20:43:28 -0500
committerCecylia Bocovich <cohosh@torproject.org>2024-01-22 13:06:42 -0500
commit8fb17de1529281d30f1eb3c9d746de70673337fa (patch)
treeba08f5d26210de1633e890b0f7b4d187cdb6bacd /broker
parentd0529141acb706f64e4defebd22a7d8604d831db (diff)
downloadsnowflake-8fb17de1529281d30f1eb3c9d746de70673337fa.tar.gz
snowflake-8fb17de1529281d30f1eb3c9d746de70673337fa.zip
Implement SQS rendezvous in client and broker
This features adds an additional rendezvous method to send client offers and receive proxy answers through the use of Amazon SQS queues. https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/issues/26151
Diffstat (limited to 'broker')
-rw-r--r--broker/broker.go17
-rw-r--r--broker/sqs.go195
2 files changed, 211 insertions, 1 deletions
diff --git a/broker/broker.go b/broker/broker.go
index 33f45ab..06b530a 100644
--- a/broker/broker.go
+++ b/broker/broker.go
@@ -8,9 +8,9 @@ package main
import (
"bytes"
"container/heap"
+ "context"
"crypto/tls"
"flag"
- "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/bridgefingerprint"
"io"
"log"
"net/http"
@@ -21,6 +21,8 @@ import (
"syscall"
"time"
+ "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/bridgefingerprint"
+
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/namematcher"
@@ -190,6 +192,7 @@ func main() {
var geoipDatabase string
var geoip6Database string
var bridgeListFilePath, allowedRelayPattern, presumedPatternForLegacyClient string
+ var brokerSQSQueueName, brokerSQSQueueRegion string
var disableTLS bool
var certFilename, keyFilename string
var disableGeoip bool
@@ -207,6 +210,8 @@ func main() {
flag.StringVar(&bridgeListFilePath, "bridge-list-path", "", "file path for bridgeListFile")
flag.StringVar(&allowedRelayPattern, "allowed-relay-pattern", "", "allowed pattern for relay host name")
flag.StringVar(&presumedPatternForLegacyClient, "default-relay-pattern", "", "presumed pattern for legacy client")
+ flag.StringVar(&brokerSQSQueueName, "broker-sqs-name", "", "name of broker SQS queue to listen for incoming messages on")
+ flag.StringVar(&brokerSQSQueueRegion, "broker-sqs-region", "", "name of AWS region of broker SQS queue")
flag.BoolVar(&disableTLS, "disable-tls", false, "don't use HTTPS")
flag.BoolVar(&disableGeoip, "disable-geoip", false, "don't use geoip for stats collection")
flag.StringVar(&metricsFilename, "metrics-log", "", "path to metrics logging output")
@@ -276,6 +281,16 @@ func main() {
Addr: addr,
}
+ // Run SQS Handler to continuously poll and process messages from SQS
+ if brokerSQSQueueName != "" && brokerSQSQueueRegion != "" {
+ sqsHandlerContext := context.Background()
+ sqsHandler, err := newSQSHandler(sqsHandlerContext, brokerSQSQueueName, brokerSQSQueueRegion, i)
+ if err != nil {
+ log.Fatal(err)
+ }
+ go sqsHandler.PollAndHandleMessages(sqsHandlerContext)
+ }
+
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGHUP)
diff --git a/broker/sqs.go b/broker/sqs.go
new file mode 100644
index 0000000..f3d3e79
--- /dev/null
+++ b/broker/sqs.go
@@ -0,0 +1,195 @@
+package main
+
+import (
+ "context"
+ "log"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/config"
+ "github.com/aws/aws-sdk-go-v2/service/sqs"
+ "github.com/aws/aws-sdk-go-v2/service/sqs/types"
+ "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages"
+)
+
+const (
+ cleanupInterval = time.Second * 30
+ cleanupThreshold = -2 * time.Minute
+)
+
+type sqsHandler struct {
+ SQSClient *sqs.Client
+ SQSQueueURL *string
+ IPC *IPC
+}
+
+func (r *sqsHandler) pollMessages(context context.Context, chn chan<- *types.Message) {
+ for {
+ res, err := r.SQSClient.ReceiveMessage(context, &sqs.ReceiveMessageInput{
+ QueueUrl: r.SQSQueueURL,
+ MaxNumberOfMessages: 10,
+ WaitTimeSeconds: 15,
+ MessageAttributeNames: []string{
+ string(types.QueueAttributeNameAll),
+ },
+ })
+
+ if err != nil {
+ log.Printf("SQSHandler: encountered error while polling for messages: %v\n", err)
+ }
+
+ for _, message := range res.Messages {
+ chn <- &message
+ }
+ }
+}
+
+func (r *sqsHandler) cleanupClientQueues(context context.Context) {
+ for range time.Tick(cleanupInterval) {
+ // Runs at fixed intervals to clean up any client queues that were last changed more than 2 minutes ago
+ queueURLsList := []string{}
+ var nextToken *string
+ for {
+ res, err := r.SQSClient.ListQueues(context, &sqs.ListQueuesInput{
+ QueueNamePrefix: aws.String("snowflake-client-"),
+ MaxResults: aws.Int32(1000),
+ NextToken: nextToken,
+ })
+ if err != nil {
+ log.Printf("SQSHandler: encountered error while retrieving client queues to clean up: %v\n", err)
+ }
+ queueURLsList = append(queueURLsList, res.QueueUrls...)
+ if res.NextToken == nil {
+ break
+ } else {
+ nextToken = res.NextToken
+ }
+ }
+
+ numDeleted := 0
+ cleanupCutoff := time.Now().Add(cleanupThreshold)
+ for _, queueURL := range queueURLsList {
+ if !strings.Contains(queueURL, "snowflake-client-") {
+ continue
+ }
+ res, err := r.SQSClient.GetQueueAttributes(context, &sqs.GetQueueAttributesInput{
+ QueueUrl: aws.String(queueURL),
+ AttributeNames: []types.QueueAttributeName{types.QueueAttributeNameLastModifiedTimestamp},
+ })
+ if err != nil {
+ // According to the AWS SQS docs, the deletion process for a queue can take up to 60 seconds. So the queue
+ // can be in the process of being deleted, but will still be returned by the ListQueues operation, but
+ // fail when we try to GetQueueAttributes for the queue
+ log.Printf("SQSHandler: encountered error while getting attribute of client queue %s. queue may already be deleted.\n", queueURL)
+ continue
+ }
+ lastModifiedInt64, err := strconv.ParseInt(res.Attributes[string(types.QueueAttributeNameLastModifiedTimestamp)], 10, 64)
+ if err != nil {
+ log.Printf("SQSHandler: encountered invalid lastModifiedTimetamp value from client queue %s: %v\n", queueURL, err)
+ continue
+ }
+ lastModified := time.Unix(lastModifiedInt64, 0)
+ if lastModified.Before(cleanupCutoff) {
+ _, err := r.SQSClient.DeleteQueue(context, &sqs.DeleteQueueInput{
+ QueueUrl: aws.String(queueURL),
+ })
+ if err != nil {
+ log.Printf("SQSHandler: encountered error when deleting client queue %s: %v\n", queueURL, err)
+ continue
+ } else {
+ numDeleted += 1
+ }
+
+ }
+ }
+ log.Printf("SQSHandler: finished running iteration of client queue cleanup. found and deleted %d client queues.\n", numDeleted)
+ }
+}
+
+func (r *sqsHandler) handleMessage(context context.Context, message *types.Message) {
+ var encPollReq []byte
+ var response []byte
+ var err error
+
+ clientID := message.MessageAttributes["ClientID"].StringValue
+ if clientID == nil {
+ log.Println("SQSHandler: got SDP offer in SQS message with no client ID. ignoring this message.")
+ return
+ }
+
+ res, err := r.SQSClient.CreateQueue(context, &sqs.CreateQueueInput{
+ QueueName: aws.String("snowflake-client-" + *clientID),
+ })
+ answerSQSURL := res.QueueUrl
+ if err != nil {
+ log.Printf("SQSHandler: error encountered when creating answer queue for client %s: %v\n", *clientID, err)
+ }
+
+ encPollReq = []byte(*message.Body)
+ arg := messages.Arg{
+ Body: encPollReq,
+ RemoteAddr: "",
+ }
+ err = r.IPC.ClientOffers(arg, &response)
+
+ if err != nil {
+ log.Printf("SQSHandler: error encountered when handling message: %v\n", err)
+ return
+ }
+
+ r.SQSClient.SendMessage(context, &sqs.SendMessageInput{
+ QueueUrl: answerSQSURL,
+ MessageBody: aws.String(string(response)),
+ })
+}
+
+func (r *sqsHandler) deleteMessage(context context.Context, message *types.Message) {
+ r.SQSClient.DeleteMessage(context, &sqs.DeleteMessageInput{
+ QueueUrl: r.SQSQueueURL,
+ ReceiptHandle: message.ReceiptHandle,
+ })
+}
+
+func newSQSHandler(context context.Context, sqsQueueName string, region string, i *IPC) (*sqsHandler, error) {
+ log.Printf("Loading SQSHandler using SQS Queue %s in region %s\n", sqsQueueName, region)
+ cfg, err := config.LoadDefaultConfig(context, config.WithRegion(region))
+ if err != nil {
+ return nil, err
+ }
+
+ client := sqs.NewFromConfig(cfg)
+
+ // Creates the queue if a queue with the same name doesn't exist. If a queue with the same name and attributes
+ // already exists, then nothing will happen. If a queue with the same name, but different attributes exists, then
+ // an error will be returned
+ res, err := client.CreateQueue(context, &sqs.CreateQueueInput{
+ QueueName: aws.String(sqsQueueName),
+ Attributes: map[string]string{
+ "MessageRetentionPeriod": strconv.FormatInt(int64((5 * time.Minute).Seconds()), 10),
+ },
+ })
+
+ if err != nil {
+ return nil, err
+ }
+
+ return &sqsHandler{
+ SQSClient: client,
+ SQSQueueURL: res.QueueUrl,
+ IPC: i,
+ }, nil
+}
+
+func (r *sqsHandler) PollAndHandleMessages(context context.Context) {
+ log.Println("SQSHandler: Starting to poll for messages at: " + *r.SQSQueueURL)
+ messagesChn := make(chan *types.Message, 2)
+ go r.pollMessages(context, messagesChn)
+ go r.cleanupClientQueues(context)
+
+ for message := range messagesChn {
+ r.handleMessage(context, message)
+ r.deleteMessage(context, message)
+ }
+}