aboutsummaryrefslogtreecommitdiff
path: root/db_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'db_test.go')
-rw-r--r--db_test.go101
1 files changed, 78 insertions, 23 deletions
diff --git a/db_test.go b/db_test.go
index b429aaf..d2a1ee1 100644
--- a/db_test.go
+++ b/db_test.go
@@ -1,48 +1,103 @@
package main
import (
- "database/sql"
"io/ioutil"
"net/http"
"os"
"testing"
)
-var (
- db *sql.DB
- db_path string
-)
-
-func TestInitDB(t *testing.T) {
+func TestDB(t *testing.T) {
+ // initDB()
tmpDB, _ := ioutil.TempFile("", "tmp-*.db")
- db_path = tmpDB.Name()
+ db_path := tmpDB.Name()
os.Remove(db_path)
- db = initDB(db_path)
-}
+ db := &SqliteDB{db: initDB(db_path)}
-func TestAddArchived(t *testing.T) {
+ // Cleanup temporary DB when test completes
+ t.Cleanup(func() {
+ os.Remove(db_path)
+ })
+ // AddArchived()
m := Message{
URL: "http://example.com/",
Author: "000000000000000000",
- Guild: "000000000000000000",
- Channel: "000000000000000000",
+ Guild: "222222222222222222",
+ Channel: "111111111111111111",
}
- addArchived(db, &m, 200)
-}
-
-func TestIsCached(t *testing.T) {
+ db.AddArchived(&m, 200)
+ m = Message{
+ URL: "http://example.net/",
+ Author: "111111111111111111",
+ Guild: "222222222222222222",
+ Channel: "333333333333333333",
+ }
+ db.AddArchived(&m, 404)
+ // IsCached()
url := "http://example.com/"
- cached, status_code := isCached(db, url)
+ cached, status_code := db.IsCached(url)
if status_code != http.StatusOK || cached != true {
- t.Errorf("Received %t, %d; wanted %t, %d", cached, status_code, true,
- http.StatusOK)
+ t.Errorf("IsCached(): Received %t, %d; wanted %t, %d", cached,
+ status_code, true, http.StatusOK)
+ }
+ url = "http://example.org/"
+ cached, status_code = db.IsCached(url)
+ if status_code != 0 || cached != false {
+ t.Errorf("IsCached(): Received %t, %d; wanted %t, %d", cached,
+ status_code, true, http.StatusOK)
}
-}
-func TestDBCleanup(t *testing.T) {
+ // ListEntries()
+ e, err := db.ListEntries(10, 0, "", "", "")
+ if err != nil {
+ t.Error(err)
+ }
+ if len(*e) != 2 {
+ t.Errorf("ListEntries(): Recieved length %d; wanted %d", len(*e), 2)
+ }
+ e, err = db.ListEntries(10, 0, "000000000000000000", "", "")
+ if len(*e) != 1 {
+ t.Errorf("ListEntries(): Recieved length %d; wanted %d", len(*e), 2)
+ }
+ e, err = db.ListEntries(10, 0, "", "222222222222222222", "")
+ if len(*e) != 2 {
+ t.Errorf("ListEntries(): Recieved length %d; wanted %d", len(*e), 2)
+ }
+ e, err = db.ListEntries(10, 0, "", "", "333333333333333333")
+ if len(*e) != 1 {
+ t.Errorf("ListEntries(): Recieved length %d; wanted %d", len(*e), 2)
+ }
+ e, err = db.ListEntries(10, 0, "111111111111111111", "222222222222222222", "")
+ if len(*e) != 1 {
+ t.Errorf("ListEntries(): Recieved length %d; wanted %d", len(*e), 2)
+ }
+ e, err = db.ListEntries(10, 0, "111111111111111111", "", "333333333333333333")
+ if len(*e) != 1 {
+ t.Errorf("ListEntries(): Recieved length %d; wanted %d", len(*e), 2)
+ }
+ if len(*e) != 1 {
+ t.Errorf("ListEntries(): Recieved length %d; wanted %d", len(*e), 2)
+ }
+ e, err = db.ListEntries(10, 0, "111111111111111111", "222222222222222222", "333333333333333333")
+ if len(*e) != 1 {
+ t.Errorf("ListEntries(): Recieved length %d; wanted %d", len(*e), 2)
+ }
- os.Remove(db_path)
+ // Stats()
+ stats, err := db.Stats()
+ if err != nil {
+ t.Fatal(err)
+ }
+ statsExpected := &Stats{
+ URLs: 2,
+ Users: 2,
+ Guilds: 1,
+ Channels: 2,
+ }
+ if stats == statsExpected {
+ t.Errorf("Stats(): Received %v; wanted %v", stats, statsExpected)
+ }
}