aboutsummaryrefslogtreecommitdiff
path: root/broker/sqs.go
blob: 42e5dd38c58890955c2d15bdfcb7af7a2f70bb55 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
package main

import (
	"context"
	"log"
	"strconv"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/sqs"
	"github.com/aws/aws-sdk-go-v2/service/sqs/types"
	"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages"
	"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/sqsclient"
)

// cleanupThreshold is the (negative) age offset used by the client queue
// cleanup: queues last modified before time.Now().Add(cleanupThreshold),
// i.e. more than two minutes ago, are deleted.
const cleanupThreshold = -2 * time.Minute

// sqsHandler implements the SQS rendezvous method for the broker: it polls a
// broker queue for client offers, answers each client on a per-client
// "snowflake-client-<ClientID>" queue, and periodically deletes stale client
// queues.
type sqsHandler struct {
	// SQSClient is used for every SQS API call made by the handler.
	SQSClient       sqsclient.SQSClient
	// SQSQueueURL is the URL of the broker's incoming queue; it is polled for
	// offers and used when deleting processed messages.
	SQSQueueURL     *string
	// IPC receives client offers via ClientOffers and produces the answers.
	IPC             *IPC
	// cleanupInterval is the period between client queue cleanup passes
	// (newSQSHandler sets it to 30 seconds).
	cleanupInterval time.Duration
}

// pollMessages long-polls the broker queue at r.SQSQueueURL for up to 10
// messages per request (waiting up to 15 seconds each time) and forwards every
// received message on chn. It returns once ctx is cancelled, including while
// blocked sending to chn. Receive errors are logged and the poll is retried.
func (r *sqsHandler) pollMessages(ctx context.Context, chn chan<- *types.Message) {
	for {
		select {
		case <-ctx.Done():
			// if context is cancelled
			return
		default:
		}

		res, err := r.SQSClient.ReceiveMessage(ctx, &sqs.ReceiveMessageInput{
			QueueUrl:            r.SQSQueueURL,
			MaxNumberOfMessages: 10,
			WaitTimeSeconds:     15,
			MessageAttributeNames: []string{
				string(types.QueueAttributeNameAll),
			},
		})
		if err != nil {
			log.Printf("SQSHandler: encountered error while polling for messages: %v\n", err)
			continue
		}

		for i := range res.Messages {
			// Send a pointer into the result slice rather than the address of a
			// range variable: before Go 1.22 the range variable is shared across
			// iterations, so every pointer sent would alias the same storage and
			// be overwritten while the consumer is still using it.
			select {
			case chn <- &res.Messages[i]:
			case <-ctx.Done():
				// Don't block forever on a consumer that has stopped reading.
				return
			}
		}
	}
}

// cleanupClientQueues runs a client queue cleanup pass every r.cleanupInterval
// until ctx is cancelled. Each pass deletes "snowflake-client-" queues that
// were last modified more than -cleanupThreshold ago.
func (r *sqsHandler) cleanupClientQueues(ctx context.Context) {
	ticker := time.NewTicker(r.cleanupInterval)
	// Release the ticker when this goroutine exits.
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			// if context is cancelled
			return
		case <-ticker.C:
			// A tick and cancellation can be ready together and select picks
			// randomly; prefer stopping over starting another pass.
			if ctx.Err() != nil {
				return
			}
			r.runClientQueueCleanup(ctx)
		}
	}
}

// runClientQueueCleanup performs one cleanup pass: it lists all client queues
// and deletes those last modified before the cleanup cutoff, then logs how
// many were deleted.
func (r *sqsHandler) runClientQueueCleanup(ctx context.Context) {
	queueURLsList := r.listClientQueueURLs(ctx)

	numDeleted := 0
	cleanupCutoff := time.Now().Add(cleanupThreshold)
	for _, queueURL := range queueURLsList {
		if !strings.Contains(queueURL, "snowflake-client-") {
			continue
		}
		if r.deleteClientQueueIfStale(ctx, queueURL, cleanupCutoff) {
			numDeleted++
		}
	}
	log.Printf("SQSHandler: finished running iteration of client queue cleanup. found and deleted %d client queues.\n", numDeleted)
}

// listClientQueueURLs returns the URLs of all queues whose name starts with
// "snowflake-client-", following ListQueues pagination. On an error it logs
// and returns whatever was collected so far; remaining queues are picked up by
// a later pass.
func (r *sqsHandler) listClientQueueURLs(ctx context.Context) []string {
	var queueURLs []string
	var nextToken *string
	for {
		res, err := r.SQSClient.ListQueues(ctx, &sqs.ListQueuesInput{
			QueueNamePrefix: aws.String("snowflake-client-"),
			MaxResults:      aws.Int32(1000),
			NextToken:       nextToken,
		})
		if err != nil {
			log.Printf("SQSHandler: encountered error while retrieving client queues to clean up: %v\n", err)
			// client queues will be cleaned up the next time the cleanup operation is triggered automatically
			return queueURLs
		}
		queueURLs = append(queueURLs, res.QueueUrls...)
		if res.NextToken == nil {
			return queueURLs
		}
		nextToken = res.NextToken
	}
}

// deleteClientQueueIfStale deletes the queue at queueURL if its
// LastModifiedTimestamp (Unix seconds) is before cutoff. It reports whether
// the queue was deleted; lookup, parse and delete errors are logged and
// reported as false.
func (r *sqsHandler) deleteClientQueueIfStale(ctx context.Context, queueURL string, cutoff time.Time) bool {
	res, err := r.SQSClient.GetQueueAttributes(ctx, &sqs.GetQueueAttributesInput{
		QueueUrl:       aws.String(queueURL),
		AttributeNames: []types.QueueAttributeName{types.QueueAttributeNameLastModifiedTimestamp},
	})
	if err != nil {
		// According to the AWS SQS docs, the deletion process for a queue can take up to 60 seconds. So the queue
		// can be in the process of being deleted, but will still be returned by the ListQueues operation, but
		// fail when we try to GetQueueAttributes for the queue
		log.Printf("SQSHandler: encountered error while getting attribute of client queue %s. queue may already be deleted.\n", queueURL)
		return false
	}

	lastModifiedInt64, err := strconv.ParseInt(res.Attributes[string(types.QueueAttributeNameLastModifiedTimestamp)], 10, 64)
	if err != nil {
		log.Printf("SQSHandler: encountered invalid lastModifiedTimetamp value from client queue %s: %v\n", queueURL, err)
		return false
	}
	if !time.Unix(lastModifiedInt64, 0).Before(cutoff) {
		return false
	}

	if _, err := r.SQSClient.DeleteQueue(ctx, &sqs.DeleteQueueInput{
		QueueUrl: aws.String(queueURL),
	}); err != nil {
		log.Printf("SQSHandler: encountered error when deleting client queue %s: %v\n", queueURL, err)
		return false
	}
	return true
}

// handleMessage processes one client offer taken from the broker queue.
//
// It reads the client ID from the "ClientID" message attribute and creates the
// answer queue "snowflake-client-<ClientID>". It then passes the message body
// to the broker via IPC.ClientOffers (rendezvous method SQS) and sends the
// resulting answer to the client's queue. Messages without a client ID or body
// are ignored. Every failure is logged, and nothing is returned to the caller.
func (r *sqsHandler) handleMessage(ctx context.Context, message *types.Message) {
	clientID := message.MessageAttributes["ClientID"].StringValue
	if clientID == nil {
		log.Println("SQSHandler: got SDP offer in SQS message with no client ID. ignoring this message.")
		return
	}
	// Checked before creating the answer queue so a malformed message neither
	// panics on the dereference below nor leaves an unused queue behind.
	if message.Body == nil {
		log.Printf("SQSHandler: got SQS message with no body from client %s. ignoring this message.\n", *clientID)
		return
	}

	res, err := r.SQSClient.CreateQueue(ctx, &sqs.CreateQueueInput{
		QueueName: aws.String("snowflake-client-" + *clientID),
	})
	if err != nil {
		log.Printf("SQSHandler: error encountered when creating answer queue for client %s: %v\n", *clientID, err)
		return
	}
	answerSQSURL := res.QueueUrl

	arg := messages.Arg{
		Body:             []byte(*message.Body),
		RemoteAddr:       "",
		RendezvousMethod: messages.RendezvousSqs,
	}
	var response []byte
	if err := r.IPC.ClientOffers(arg, &response); err != nil {
		log.Printf("SQSHandler: error encountered when handling message: %v\n", err)
		return
	}

	if _, err := r.SQSClient.SendMessage(ctx, &sqs.SendMessageInput{
		QueueUrl:    answerSQSURL,
		MessageBody: aws.String(string(response)),
	}); err != nil {
		log.Printf("SQSHandler: error encountered when sending answer to client %s: %v\n", *clientID, err)
	}
}

// deleteMessage removes a processed message from the broker queue at
// r.SQSQueueURL using its receipt handle. A failure is logged rather than
// ignored; the message would then become visible again on the queue until
// its retention period expires.
func (r *sqsHandler) deleteMessage(ctx context.Context, message *types.Message) {
	if _, err := r.SQSClient.DeleteMessage(ctx, &sqs.DeleteMessageInput{
		QueueUrl:      r.SQSQueueURL,
		ReceiptHandle: message.ReceiptHandle,
	}); err != nil {
		log.Printf("SQSHandler: error encountered when deleting message from queue: %v\n", err)
	}
}

// newSQSHandler ensures the broker queue sqsQueueName exists, with a
// five-minute message retention period. It returns a handler that polls that
// queue and forwards offers to i. Any CreateQueue error is returned unchanged.
//
// NOTE(review): region is not used here; presumably client is already
// configured for the right region. Confirm before relying on it.
func newSQSHandler(ctx context.Context, client sqsclient.SQSClient, sqsQueueName string, region string, i *IPC) (*sqsHandler, error) {
	const (
		// Offers older than this are dropped by SQS itself.
		messageRetentionPeriod = 5 * time.Minute
		// How often stale client answer queues are cleaned up.
		clientQueueCleanupInterval = 30 * time.Second
	)

	// Creates the queue if a queue with the same name doesn't exist. If a queue with the same name and attributes
	// already exists, then nothing will happen. If a queue with the same name, but different attributes exists, then
	// an error will be returned
	res, err := client.CreateQueue(ctx, &sqs.CreateQueueInput{
		QueueName: aws.String(sqsQueueName),
		Attributes: map[string]string{
			"MessageRetentionPeriod": strconv.FormatInt(int64(messageRetentionPeriod/time.Second), 10),
		},
	})
	if err != nil {
		return nil, err
	}

	return &sqsHandler{
		SQSClient:       client,
		SQSQueueURL:     res.QueueUrl,
		IPC:             i,
		cleanupInterval: clientQueueCleanupInterval,
	}, nil
}

// PollAndHandleMessages starts background goroutines that poll the broker
// queue and periodically clean up client queues. It then handles received
// messages one at a time until ctx is cancelled. Each message is handled and
// then deleted from the broker queue, whether or not handling succeeded.
func (r *sqsHandler) PollAndHandleMessages(ctx context.Context) {
	log.Println("SQSHandler: Starting to poll for messages at: " + *r.SQSQueueURL)
	messagesChn := make(chan *types.Message, 2)
	go r.pollMessages(ctx, messagesChn)
	go r.cleanupClientQueues(ctx)

	for {
		// Select on ctx as well as the channel: messagesChn is never closed, so
		// ranging over it would block forever after cancellation.
		select {
		case <-ctx.Done():
			// if context is cancelled
			return
		case message, ok := <-messagesChn:
			if !ok {
				return
			}
			// A message and cancellation may be ready together and select picks
			// randomly; don't start new work once cancelled.
			if ctx.Err() != nil {
				return
			}
			r.handleMessage(ctx, message)
			r.deleteMessage(ctx, message)
		}
	}
}