diff options
author | Anthony Chang <anthony20093@gmail.com> | 2024-01-12 23:23:33 -0500 |
---|---|---|
committer | Cecylia Bocovich <cohosh@torproject.org> | 2024-01-22 13:11:03 -0500 |
commit | 32e864b71d19145096ed93dd0af4b8b900a67081 (patch) | |
tree | 515a7a784c9570234005c8830a663abe6aa76147 | |
parent | f3b062ddb2f1bc152702f562b0c4b2ed6db4d1aa (diff) | |
download | snowflake-32e864b71d19145096ed93dd0af4b8b900a67081.tar.gz snowflake-32e864b71d19145096ed93dd0af4b8b900a67081.zip |
Add unit tests for SQS rendezvous in client
Co-authored-by: Michael Pu <michael.pu@uwaterloo.ca>
-rw-r--r-- | client/lib/rendezvous_sqs.go | 19 | ||||
-rw-r--r-- | client/lib/rendezvous_test.go | 126 | ||||
-rw-r--r-- | client/lib/sqs_test.go | 30 |
3 files changed, 137 insertions, 38 deletions
diff --git a/client/lib/rendezvous_sqs.go b/client/lib/rendezvous_sqs.go index 89f5694..00a69b3 100644 --- a/client/lib/rendezvous_sqs.go +++ b/client/lib/rendezvous_sqs.go @@ -23,6 +23,8 @@ type sqsRendezvous struct { sqsClientID string sqsClient sqsclient.SQSClient sqsURL *url.URL + timeout time.Duration + numRetries int } func newSQSRendezvous(sqsQueue string, sqsAccessKeyId string, sqsSecretKey string, transport http.RoundTripper) (*sqsRendezvous, error) { @@ -66,6 +68,8 @@ func newSQSRendezvous(sqsQueue string, sqsAccessKeyId string, sqsSecretKey strin sqsClientID: clientID, sqsClient: client, sqsURL: sqsURL, + timeout: time.Second, + numRetries: 5, }, nil } @@ -86,11 +90,10 @@ func (r *sqsRendezvous) Exchange(encPollReq []byte) ([]byte, error) { return nil, err } - time.Sleep(time.Second) // wait for client queue to be created by the broker + time.Sleep(r.timeout) // wait for client queue to be created by the broker - numRetries := 5 var responseQueueURL *string - for i := 0; i < numRetries; i++ { + for i := 0; i < r.numRetries; i++ { // The SQS queue corresponding to the client where the SDP Answer will be placed // may not be created yet. We will retry up to 5 times before we error out. var res *sqs.GetQueueUrlOutput @@ -99,8 +102,8 @@ func (r *sqsRendezvous) Exchange(encPollReq []byte) ([]byte, error) { }) if err != nil { log.Println(err) - log.Printf("Attempt %d of %d to retrieve URL of response SQS queue failed.\n", i+1, numRetries) - time.Sleep(time.Second) + log.Printf("Attempt %d of %d to retrieve URL of response SQS queue failed.\n", i+1, r.numRetries) + time.Sleep(r.timeout) } else { responseQueueURL = res.QueueUrl break @@ -111,7 +114,7 @@ func (r *sqsRendezvous) Exchange(encPollReq []byte) ([]byte, error) { } var answer string - for i := 0; i < numRetries; i++ { + for i := 0; i < r.numRetries; i++ { // Waiting for SDP Answer from proxy to be placed in SQS queue. // We will retry upt to 5 times before we error out. res, err := r.sqsClient.ReceiveMessage(context.TODO(), &sqs.ReceiveMessageInput{ @@ -123,9 +126,9 @@ func (r *sqsRendezvous) Exchange(encPollReq []byte) ([]byte, error) { return nil, err } if len(res.Messages) == 0 { - log.Printf("Attempt %d of %d to receive message from response SQS queue failed. No message found in queue.\n", i+1, numRetries) + log.Printf("Attempt %d of %d to receive message from response SQS queue failed. No message found in queue.\n", i+1, r.numRetries) delay := float64(i)/2.0 + 1 - time.Sleep(time.Duration(delay*1000) * time.Millisecond) + time.Sleep(time.Duration(delay*1000) * (r.timeout / 1000)) } else { answer = *res.Messages[0].Body break diff --git a/client/lib/rendezvous_test.go b/client/lib/rendezvous_test.go index 227d43c..5ff8ae9 100644 --- a/client/lib/rendezvous_test.go +++ b/client/lib/rendezvous_test.go @@ -7,12 +7,18 @@ import ( "io" "io/ioutil" "net/http" + "net/url" "testing" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sqs" + "github.com/aws/aws-sdk-go-v2/service/sqs/types" + "github.com/golang/mock/gomock" . "github.com/smartystreets/goconvey/convey" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/amp" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/nat" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/sqsclient" ) // mockTransport's RoundTrip method returns a response with a fake status and @@ -271,3 +277,123 @@ func TestAMPCacheRendezvous(t *testing.T) { }) }) } + +func TestSQSRendezvous(t *testing.T) { + Convey("SQS Rendezvous", t, func() { + + Convey("Construct SQS queue rendezvous", func() { + transport := &mockTransport{http.StatusOK, []byte{}} + rend, err := newSQSRendezvous("https://sqs.us-east-1.amazonaws.com", "some-access-key-id", "some-secret-key", transport) + + So(err, ShouldBeNil) + So(rend.sqsClientID, ShouldNotBeNil) + So(rend.sqsClient, ShouldNotBeNil) + So(rend.sqsURL, ShouldNotBeNil) + So(rend.sqsURL.String(), ShouldResemble, "https://sqs.us-east-1.amazonaws.com") + }) + + ctrl := gomock.NewController(t) + mockSqsClient := sqsclient.NewMockSQSClient(ctrl) + responseQueueURL := "https://sqs.us-east-1.amazonaws.com/testing" + sqsClientID := "test123" + sqsUrl, _ := url.Parse("https://sqs.us-east-1.amazonaws.com/broker") + fakeEncPollResp := makeEncPollResp( + `{"answer": "{\"type\":\"answer\",\"sdp\":\"fake\"}" }`, + "", + ) + sqsRendezvous := sqsRendezvous{ + transport: &mockTransport{http.StatusOK, []byte{}}, + sqsClientID: sqsClientID, + sqsClient: mockSqsClient, + sqsURL: sqsUrl, + timeout: 0, + numRetries: 5, + } + + Convey("sqsRendezvous.Exchange responds with answer", func() { + mockSqsClient.EXPECT().SendMessage(gomock.Any(), &sqs.SendMessageInput{ + MessageAttributes: map[string]types.MessageAttributeValue{ + "ClientID": { + DataType: aws.String("String"), + StringValue: aws.String(sqsClientID), + }, + }, + MessageBody: aws.String(string(fakeEncPollResp)), + QueueUrl: aws.String(sqsUrl.String()), + }) + mockSqsClient.EXPECT().GetQueueUrl(gomock.Any(), &sqs.GetQueueUrlInput{ + QueueName: aws.String("snowflake-client-" + sqsClientID), + }).Return(&sqs.GetQueueUrlOutput{ + QueueUrl: aws.String(responseQueueURL), + }, nil) + mockSqsClient.EXPECT().ReceiveMessage(gomock.Any(), gomock.Eq(&sqs.ReceiveMessageInput{ + QueueUrl: &responseQueueURL, + MaxNumberOfMessages: 1, + WaitTimeSeconds: 20, + })).Return(&sqs.ReceiveMessageOutput{ + Messages: []types.Message{{Body: aws.String("answer")}}, + }, nil) + + answer, err := sqsRendezvous.Exchange(fakeEncPollResp) + + So(answer, ShouldEqual, []byte("answer")) + So(err, ShouldBeNil) + }) + + Convey("sqsRendezvous.Exchange cannot get queue url", func() { + mockSqsClient.EXPECT().SendMessage(gomock.Any(), &sqs.SendMessageInput{ + MessageAttributes: map[string]types.MessageAttributeValue{ + "ClientID": { + DataType: aws.String("String"), + StringValue: aws.String(sqsClientID), + }, + }, + MessageBody: aws.String(string(fakeEncPollResp)), + QueueUrl: aws.String(sqsUrl.String()), + }) + for i := 0; i < sqsRendezvous.numRetries; i++ { + mockSqsClient.EXPECT().GetQueueUrl(gomock.Any(), &sqs.GetQueueUrlInput{ + QueueName: aws.String("snowflake-client-" + sqsClientID), + }).Return(nil, errors.New("test error")) + } + + answer, err := sqsRendezvous.Exchange(fakeEncPollResp) + + So(answer, ShouldBeNil) + So(err, ShouldNotBeNil) + So(err, ShouldEqual, errors.New("test error")) + }) + + Convey("sqsRendezvous.Exchange does not receive answer", func() { + mockSqsClient.EXPECT().SendMessage(gomock.Any(), &sqs.SendMessageInput{ + MessageAttributes: map[string]types.MessageAttributeValue{ + "ClientID": { + DataType: aws.String("String"), + StringValue: aws.String(sqsClientID), + }, + }, + MessageBody: aws.String(string(fakeEncPollResp)), + QueueUrl: aws.String(sqsUrl.String()), + }) + mockSqsClient.EXPECT().GetQueueUrl(gomock.Any(), &sqs.GetQueueUrlInput{ + QueueName: aws.String("snowflake-client-" + sqsClientID), + }).Return(&sqs.GetQueueUrlOutput{ + QueueUrl: aws.String(responseQueueURL), + }, nil) + for i := 0; i < sqsRendezvous.numRetries; i++ { + mockSqsClient.EXPECT().ReceiveMessage(gomock.Any(), gomock.Eq(&sqs.ReceiveMessageInput{ + QueueUrl: &responseQueueURL, + MaxNumberOfMessages: 1, + WaitTimeSeconds: 20, + })).Return(&sqs.ReceiveMessageOutput{ + Messages: []types.Message{}, + }, nil) + } + + answer, err := sqsRendezvous.Exchange(fakeEncPollResp) + + So(answer, ShouldEqual, []byte{}) + So(err, ShouldBeNil) + }) + }) +} diff --git a/client/lib/sqs_test.go b/client/lib/sqs_test.go deleted file mode 100644 index 02da33f..0000000 --- a/client/lib/sqs_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package snowflake_client - -import ( - "context" - "testing" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/sqs" - "github.com/golang/mock/gomock" - . "github.com/smartystreets/goconvey/convey" - "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/sqsclient" -) - -func TestExample(t *testing.T) { - Convey("Test Example 1", t, func() { - ctrl := gomock.NewController(t) - mockSqsClient := sqsclient.NewMockSQSClient(ctrl) - mockSqsClient.EXPECT().GetQueueUrl(gomock.Any(), gomock.Any()).Return(&sqs.GetQueueUrlOutput{ - QueueUrl: aws.String("https://wwww.google.com"), - }, nil) - - output, err := mockSqsClient.GetQueueUrl(context.TODO(), &sqs.GetQueueUrlInput{ - QueueName: aws.String("testing"), - }) - ShouldBeNil(err) - ShouldEqual(output, sqs.GetQueueUrlOutput{ - QueueUrl: aws.String("https://wwww.google.com"), - }) - }) -} |