aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAnthony Chang <anthony20093@gmail.com>2024-01-12 23:23:33 -0500
committerCecylia Bocovich <cohosh@torproject.org>2024-01-22 13:11:03 -0500
commit32e864b71d19145096ed93dd0af4b8b900a67081 (patch)
tree515a7a784c9570234005c8830a663abe6aa76147
parentf3b062ddb2f1bc152702f562b0c4b2ed6db4d1aa (diff)
downloadsnowflake-32e864b71d19145096ed93dd0af4b8b900a67081.tar.gz
snowflake-32e864b71d19145096ed93dd0af4b8b900a67081.zip
Add unit tests for SQS rendezvous in client
Co-authored-by: Michael Pu <michael.pu@uwaterloo.ca>
-rw-r--r--client/lib/rendezvous_sqs.go19
-rw-r--r--client/lib/rendezvous_test.go126
-rw-r--r--client/lib/sqs_test.go30
3 files changed, 137 insertions, 38 deletions
diff --git a/client/lib/rendezvous_sqs.go b/client/lib/rendezvous_sqs.go
index 89f5694..00a69b3 100644
--- a/client/lib/rendezvous_sqs.go
+++ b/client/lib/rendezvous_sqs.go
@@ -23,6 +23,8 @@ type sqsRendezvous struct {
sqsClientID string
sqsClient sqsclient.SQSClient
sqsURL *url.URL
+ timeout time.Duration
+ numRetries int
}
func newSQSRendezvous(sqsQueue string, sqsAccessKeyId string, sqsSecretKey string, transport http.RoundTripper) (*sqsRendezvous, error) {
@@ -66,6 +68,8 @@ func newSQSRendezvous(sqsQueue string, sqsAccessKeyId string, sqsSecretKey strin
sqsClientID: clientID,
sqsClient: client,
sqsURL: sqsURL,
+ timeout: time.Second,
+ numRetries: 5,
}, nil
}
@@ -86,11 +90,10 @@ func (r *sqsRendezvous) Exchange(encPollReq []byte) ([]byte, error) {
return nil, err
}
- time.Sleep(time.Second) // wait for client queue to be created by the broker
+ time.Sleep(r.timeout) // wait for client queue to be created by the broker
- numRetries := 5
var responseQueueURL *string
- for i := 0; i < numRetries; i++ {
+ for i := 0; i < r.numRetries; i++ {
// The SQS queue corresponding to the client where the SDP Answer will be placed
// may not be created yet. We will retry up to 5 times before we error out.
var res *sqs.GetQueueUrlOutput
@@ -99,8 +102,8 @@ func (r *sqsRendezvous) Exchange(encPollReq []byte) ([]byte, error) {
})
if err != nil {
log.Println(err)
- log.Printf("Attempt %d of %d to retrieve URL of response SQS queue failed.\n", i+1, numRetries)
- time.Sleep(time.Second)
+ log.Printf("Attempt %d of %d to retrieve URL of response SQS queue failed.\n", i+1, r.numRetries)
+ time.Sleep(r.timeout)
} else {
responseQueueURL = res.QueueUrl
break
@@ -111,7 +114,7 @@ func (r *sqsRendezvous) Exchange(encPollReq []byte) ([]byte, error) {
}
var answer string
- for i := 0; i < numRetries; i++ {
+ for i := 0; i < r.numRetries; i++ {
// Waiting for SDP Answer from proxy to be placed in SQS queue.
// We will retry upt to 5 times before we error out.
res, err := r.sqsClient.ReceiveMessage(context.TODO(), &sqs.ReceiveMessageInput{
@@ -123,9 +126,9 @@ func (r *sqsRendezvous) Exchange(encPollReq []byte) ([]byte, error) {
return nil, err
}
if len(res.Messages) == 0 {
- log.Printf("Attempt %d of %d to receive message from response SQS queue failed. No message found in queue.\n", i+1, numRetries)
+ log.Printf("Attempt %d of %d to receive message from response SQS queue failed. No message found in queue.\n", i+1, r.numRetries)
delay := float64(i)/2.0 + 1
- time.Sleep(time.Duration(delay*1000) * time.Millisecond)
+ time.Sleep(time.Duration(delay*1000) * (r.timeout / 1000))
} else {
answer = *res.Messages[0].Body
break
diff --git a/client/lib/rendezvous_test.go b/client/lib/rendezvous_test.go
index 227d43c..5ff8ae9 100644
--- a/client/lib/rendezvous_test.go
+++ b/client/lib/rendezvous_test.go
@@ -7,12 +7,18 @@ import (
"io"
"io/ioutil"
"net/http"
+ "net/url"
"testing"
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/sqs"
+ "github.com/aws/aws-sdk-go-v2/service/sqs/types"
+ "github.com/golang/mock/gomock"
. "github.com/smartystreets/goconvey/convey"
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/amp"
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages"
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/nat"
+ "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/sqsclient"
)
// mockTransport's RoundTrip method returns a response with a fake status and
@@ -271,3 +277,123 @@ func TestAMPCacheRendezvous(t *testing.T) {
})
})
}
+
+func TestSQSRendezvous(t *testing.T) {
+ Convey("SQS Rendezvous", t, func() {
+
+ Convey("Construct SQS queue rendezvous", func() {
+ transport := &mockTransport{http.StatusOK, []byte{}}
+ rend, err := newSQSRendezvous("https://sqs.us-east-1.amazonaws.com", "some-access-key-id", "some-secret-key", transport)
+
+ So(err, ShouldBeNil)
+ So(rend.sqsClientID, ShouldNotBeNil)
+ So(rend.sqsClient, ShouldNotBeNil)
+ So(rend.sqsURL, ShouldNotBeNil)
+ So(rend.sqsURL.String(), ShouldResemble, "https://sqs.us-east-1.amazonaws.com")
+ })
+
+ ctrl := gomock.NewController(t)
+ mockSqsClient := sqsclient.NewMockSQSClient(ctrl)
+ responseQueueURL := "https://sqs.us-east-1.amazonaws.com/testing"
+ sqsClientID := "test123"
+ sqsUrl, _ := url.Parse("https://sqs.us-east-1.amazonaws.com/broker")
+ fakeEncPollResp := makeEncPollResp(
+ `{"answer": "{\"type\":\"answer\",\"sdp\":\"fake\"}" }`,
+ "",
+ )
+ sqsRendezvous := sqsRendezvous{
+ transport: &mockTransport{http.StatusOK, []byte{}},
+ sqsClientID: sqsClientID,
+ sqsClient: mockSqsClient,
+ sqsURL: sqsUrl,
+ timeout: 0,
+ numRetries: 5,
+ }
+
+ Convey("sqsRendezvous.Exchange responds with answer", func() {
+ mockSqsClient.EXPECT().SendMessage(gomock.Any(), &sqs.SendMessageInput{
+ MessageAttributes: map[string]types.MessageAttributeValue{
+ "ClientID": {
+ DataType: aws.String("String"),
+ StringValue: aws.String(sqsClientID),
+ },
+ },
+ MessageBody: aws.String(string(fakeEncPollResp)),
+ QueueUrl: aws.String(sqsUrl.String()),
+ })
+ mockSqsClient.EXPECT().GetQueueUrl(gomock.Any(), &sqs.GetQueueUrlInput{
+ QueueName: aws.String("snowflake-client-" + sqsClientID),
+ }).Return(&sqs.GetQueueUrlOutput{
+ QueueUrl: aws.String(responseQueueURL),
+ }, nil)
+ mockSqsClient.EXPECT().ReceiveMessage(gomock.Any(), gomock.Eq(&sqs.ReceiveMessageInput{
+ QueueUrl: &responseQueueURL,
+ MaxNumberOfMessages: 1,
+ WaitTimeSeconds: 20,
+ })).Return(&sqs.ReceiveMessageOutput{
+ Messages: []types.Message{{Body: aws.String("answer")}},
+ }, nil)
+
+ answer, err := sqsRendezvous.Exchange(fakeEncPollResp)
+
+ So(answer, ShouldEqual, []byte("answer"))
+ So(err, ShouldBeNil)
+ })
+
+ Convey("sqsRendezvous.Exchange cannot get queue url", func() {
+ mockSqsClient.EXPECT().SendMessage(gomock.Any(), &sqs.SendMessageInput{
+ MessageAttributes: map[string]types.MessageAttributeValue{
+ "ClientID": {
+ DataType: aws.String("String"),
+ StringValue: aws.String(sqsClientID),
+ },
+ },
+ MessageBody: aws.String(string(fakeEncPollResp)),
+ QueueUrl: aws.String(sqsUrl.String()),
+ })
+ for i := 0; i < sqsRendezvous.numRetries; i++ {
+ mockSqsClient.EXPECT().GetQueueUrl(gomock.Any(), &sqs.GetQueueUrlInput{
+ QueueName: aws.String("snowflake-client-" + sqsClientID),
+ }).Return(nil, errors.New("test error"))
+ }
+
+ answer, err := sqsRendezvous.Exchange(fakeEncPollResp)
+
+ So(answer, ShouldBeNil)
+ So(err, ShouldNotBeNil)
+ So(err, ShouldEqual, errors.New("test error"))
+ })
+
+ Convey("sqsRendezvous.Exchange does not receive answer", func() {
+ mockSqsClient.EXPECT().SendMessage(gomock.Any(), &sqs.SendMessageInput{
+ MessageAttributes: map[string]types.MessageAttributeValue{
+ "ClientID": {
+ DataType: aws.String("String"),
+ StringValue: aws.String(sqsClientID),
+ },
+ },
+ MessageBody: aws.String(string(fakeEncPollResp)),
+ QueueUrl: aws.String(sqsUrl.String()),
+ })
+ mockSqsClient.EXPECT().GetQueueUrl(gomock.Any(), &sqs.GetQueueUrlInput{
+ QueueName: aws.String("snowflake-client-" + sqsClientID),
+ }).Return(&sqs.GetQueueUrlOutput{
+ QueueUrl: aws.String(responseQueueURL),
+ }, nil)
+ for i := 0; i < sqsRendezvous.numRetries; i++ {
+ mockSqsClient.EXPECT().ReceiveMessage(gomock.Any(), gomock.Eq(&sqs.ReceiveMessageInput{
+ QueueUrl: &responseQueueURL,
+ MaxNumberOfMessages: 1,
+ WaitTimeSeconds: 20,
+ })).Return(&sqs.ReceiveMessageOutput{
+ Messages: []types.Message{},
+ }, nil)
+ }
+
+ answer, err := sqsRendezvous.Exchange(fakeEncPollResp)
+
+ So(answer, ShouldEqual, []byte{})
+ So(err, ShouldBeNil)
+ })
+ })
+}
diff --git a/client/lib/sqs_test.go b/client/lib/sqs_test.go
deleted file mode 100644
index 02da33f..0000000
--- a/client/lib/sqs_test.go
+++ /dev/null
@@ -1,30 +0,0 @@
-package snowflake_client
-
-import (
- "context"
- "testing"
-
- "github.com/aws/aws-sdk-go-v2/aws"
- "github.com/aws/aws-sdk-go-v2/service/sqs"
- "github.com/golang/mock/gomock"
- . "github.com/smartystreets/goconvey/convey"
- "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/sqsclient"
-)
-
-func TestExample(t *testing.T) {
- Convey("Test Example 1", t, func() {
- ctrl := gomock.NewController(t)
- mockSqsClient := sqsclient.NewMockSQSClient(ctrl)
- mockSqsClient.EXPECT().GetQueueUrl(gomock.Any(), gomock.Any()).Return(&sqs.GetQueueUrlOutput{
- QueueUrl: aws.String("https://wwww.google.com"),
- }, nil)
-
- output, err := mockSqsClient.GetQueueUrl(context.TODO(), &sqs.GetQueueUrlInput{
- QueueName: aws.String("testing"),
- })
- ShouldBeNil(err)
- ShouldEqual(output, sqs.GetQueueUrlOutput{
- QueueUrl: aws.String("https://wwww.google.com"),
- })
- })
-}