aboutsummaryrefslogtreecommitdiff
path: root/client/lib/rendezvous_sqs.go
blob: 00a69b3dba4ea4f59b331a5e64ac343794b05ee0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package snowflake_client

import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"regexp"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/credentials"
	"github.com/aws/aws-sdk-go-v2/service/sqs"
	"github.com/aws/aws-sdk-go-v2/service/sqs/types"
	"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/sqsclient"
)

// sqsRendezvous carries the state needed to exchange a poll request and its
// answer through Amazon SQS queues.
type sqsRendezvous struct {
	// sqsURL is the parsed URL of the queue that poll requests are sent to.
	sqsURL *url.URL

	// sqsClient performs the SQS API calls (send, queue lookup, receive).
	sqsClient sqsclient.SQSClient

	// sqsClientID is sent as the "ClientID" message attribute and forms the
	// suffix of the per-client response queue name.
	sqsClientID string

	// timeout is the base sleep interval used while waiting for the response
	// queue and between receive attempts.
	timeout time.Duration

	// numRetries bounds the attempts made for both the response-queue lookup
	// and the receive loop.
	numRetries int

	// transport is stored by the constructor; it is not read by the code
	// visible in this file.
	transport http.RoundTripper
}

// newSQSRendezvous builds an SQS-backed rendezvous for the queue at sqsQueue.
//
// sqsQueue must be a URL whose host has the form sqs.<region>.amazonaws.com;
// the AWS region is taken from that host. sqsAccessKeyId and sqsSecretKey are
// used as static AWS credentials. transport is stored on the returned value.
// A fresh random 8-byte client ID (hex encoded) is generated on every call.
//
// An error is returned, rather than terminating the process, when the URL
// cannot be parsed, the client ID cannot be generated, the region cannot be
// extracted from the URL host, or the AWS configuration cannot be loaded.
func newSQSRendezvous(sqsQueue string, sqsAccessKeyId string, sqsSecretKey string, transport http.RoundTripper) (*sqsRendezvous, error) {
	sqsURL, err := url.Parse(sqsQueue)
	if err != nil {
		return nil, err
	}

	var id [8]byte
	if _, err := rand.Read(id[:]); err != nil {
		return nil, fmt.Errorf("generating SQS client ID: %w", err)
	}
	clientID := hex.EncodeToString(id[:])

	queueURL := sqsURL.String()
	hostName := sqsURL.Hostname()

	// The pattern is a constant, so MustCompile cannot panic here; the single
	// capture group is the AWS region.
	regionRegex := regexp.MustCompile(`^sqs\.([\w-]+)\.amazonaws\.com$`)
	res := regionRegex.FindStringSubmatch(hostName)
	if len(res) < 2 {
		return nil, fmt.Errorf("could not extract AWS region from SQS URL host %q; ensure that the SQS queue URL provided is valid", hostName)
	}
	region := res[1]

	cfg, err := config.LoadDefaultConfig(context.TODO(),
		config.WithCredentialsProvider(
			credentials.NewStaticCredentialsProvider(sqsAccessKeyId, sqsSecretKey, ""),
		),
		config.WithRegion(region),
	)
	if err != nil {
		return nil, fmt.Errorf("loading AWS configuration: %w", err)
	}
	client := sqs.NewFromConfig(cfg)

	log.Println("Queue URL: ", queueURL)
	log.Println("SQS Client ID: ", clientID)

	return &sqsRendezvous{
		transport:   transport,
		sqsClientID: clientID,
		sqsClient:   client,
		sqsURL:      sqsURL,
		timeout:     time.Second,
		numRetries:  5,
	}, nil
}

// Exchange sends encPollReq to the rendezvous SQS queue, tagged with this
// client's ID, and then waits for the answer to appear in the per-client
// response queue named "snowflake-client-<ClientID>".
//
// Both the lookup of the response queue URL and the receipt of the answer are
// attempted up to r.numRetries times. The body of the first message received
// is returned. An error is returned if sending fails, if the response queue
// URL cannot be obtained, if receiving fails, or if no message arrives within
// the allowed number of attempts.
func (r *sqsRendezvous) Exchange(encPollReq []byte) ([]byte, error) {
	log.Println("Negotiating via SQS Queue rendezvous...")

	_, err := r.sqsClient.SendMessage(context.TODO(), &sqs.SendMessageInput{
		MessageAttributes: map[string]types.MessageAttributeValue{
			"ClientID": {
				DataType:    aws.String("String"),
				StringValue: aws.String(r.sqsClientID),
			},
		},
		MessageBody: aws.String(string(encPollReq)),
		QueueUrl:    aws.String(r.sqsURL.String()),
	})
	if err != nil {
		return nil, err
	}

	time.Sleep(r.timeout) // wait for client queue to be created by the broker

	var responseQueueURL *string
	for i := 0; i < r.numRetries; i++ {
		// The SQS queue corresponding to the client where the SDP Answer will be placed
		// may not be created yet. We retry up to r.numRetries times before we error out.
		var res *sqs.GetQueueUrlOutput
		res, err = r.sqsClient.GetQueueUrl(context.TODO(), &sqs.GetQueueUrlInput{
			QueueName: aws.String("snowflake-client-" + r.sqsClientID),
		})
		if err == nil {
			responseQueueURL = res.QueueUrl
			break
		}
		log.Println(err)
		log.Printf("Attempt %d of %d to retrieve URL of response SQS queue failed.\n", i+1, r.numRetries)
		// No point in sleeping after the final attempt; we are about to give up.
		if i+1 < r.numRetries {
			time.Sleep(r.timeout)
		}
	}
	if err != nil {
		return nil, err
	}
	// Guards against numRetries <= 0 and against a response without a URL, so
	// that a nil queue URL is never handed to ReceiveMessage.
	if responseQueueURL == nil {
		return nil, fmt.Errorf("could not determine URL of response SQS queue for client %s", r.sqsClientID)
	}

	for i := 0; i < r.numRetries; i++ {
		// Waiting for SDP Answer from proxy to be placed in SQS queue.
		// We retry up to r.numRetries times before we error out.
		res, err := r.sqsClient.ReceiveMessage(context.TODO(), &sqs.ReceiveMessageInput{
			QueueUrl:            responseQueueURL,
			MaxNumberOfMessages: 1,
			WaitTimeSeconds:     20,
		})
		if err != nil {
			return nil, err
		}
		if len(res.Messages) > 0 {
			// aws.ToString yields "" for a nil Body instead of panicking.
			return []byte(aws.ToString(res.Messages[0].Body)), nil
		}

		log.Printf("Attempt %d of %d to receive message from response SQS queue failed. No message found in queue.\n", i+1, r.numRetries)
		if i+1 < r.numRetries {
			// Linearly growing pause: 1, 1.5, 2, ... times r.timeout,
			// computed in thousandths of r.timeout.
			delay := float64(i)/2.0 + 1
			time.Sleep(time.Duration(delay*1000) * (r.timeout / 1000))
		}
	}

	return nil, fmt.Errorf("no answer received from response SQS queue after %d attempts", r.numRetries)
}