aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJakob Borg <jakob@kastelo.net>2023-07-10 08:27:12 +0200
committerJakob Borg <jakob@kastelo.net>2023-07-10 08:27:16 +0200
commitbf61e485a6e3b51b87b97c183325302696508241 (patch)
tree7416cb8b71d2dab213fbd846176cfac97b759f46
parentb2886f11b1682bf5f45439c185651b2ff9417002 (diff)
downloadsyncthing-bf61e485a6e3b51b87b97c183325302696508241.tar.gz
syncthing-bf61e485a6e3b51b87b97c183325302696508241.zip
cmd/ursrv: Refactor to use CLI options, fewer global vars
-rw-r--r--cmd/ursrv/main.go149
1 files changed, 82 insertions, 67 deletions
diff --git a/cmd/ursrv/main.go b/cmd/ursrv/main.go
index c69559464..76014d268 100644
--- a/cmd/ursrv/main.go
+++ b/cmd/ursrv/main.go
@@ -25,6 +25,7 @@ import (
"time"
"unicode"
+ "github.com/alecthomas/kong"
_ "github.com/lib/pq" // PostgreSQL driver
"github.com/oschwald/geoip2-golang"
"golang.org/x/text/cases"
@@ -34,14 +35,17 @@ import (
"github.com/syncthing/syncthing/lib/ur/contract"
)
+type CLI struct {
+ UseHTTP bool `env:"UR_USE_HTTP"`
+ Debug bool `env:"UR_DEBUG"`
+ KeyFile string `env:"UR_KEY_FILE" default:"key.pem"`
+ CertFile string `env:"UR_CRT_FILE" default:"crt.pem"`
+ DBConn string `env:"UR_DB_URL" default:"postgres://user:password@localhost/ur?sslmode=disable"`
+ Listen string `env:"UR_LISTEN" default:"0.0.0.0:8443"`
+ GeoIPPath string `env:"UR_GEOIP" default:"GeoLite2-City.mmdb"`
+}
+
var (
- useHTTP = os.Getenv("UR_USE_HTTP") != ""
- debug = os.Getenv("UR_DEBUG") != ""
- keyFile = getEnvDefault("UR_KEY_FILE", "key.pem")
- certFile = getEnvDefault("UR_CRT_FILE", "crt.pem")
- dbConn = getEnvDefault("UR_DB_URL", "postgres://user:password@localhost/ur?sslmode=disable")
- listenAddr = getEnvDefault("UR_LISTEN", "0.0.0.0:8443")
- geoIPPath = getEnvDefault("UR_GEOIP", "GeoLite2-City.mmdb")
tpl *template.Template
compilerRe = regexp.MustCompile(`\(([A-Za-z0-9()., -]+) \w+-\w+(?:| android| default)\) ([\w@.-]+)`)
progressBarClass = []string{"", "progress-bar-success", "progress-bar-info", "progress-bar-warning", "progress-bar-danger"}
@@ -159,6 +163,9 @@ func main() {
log.SetFlags(log.Ltime | log.Ldate | log.Lshortfile)
log.SetOutput(os.Stdout)
+ var cli CLI
+ kong.Parse(&cli)
+
// Template
fd, err := os.Open("static/index.html")
@@ -174,7 +181,7 @@ func main() {
// DB
- db, err := sql.Open("postgres", dbConn)
+ db, err := sql.Open("postgres", cli.DBConn)
if err != nil {
log.Fatalln("database:", err)
}
@@ -186,11 +193,11 @@ func main() {
// TLS & Listening
var listener net.Listener
- if useHTTP {
- listener, err = net.Listen("tcp", listenAddr)
+ if cli.UseHTTP {
+ listener, err = net.Listen("tcp", cli.Listen)
} else {
var cert tls.Certificate
- cert, err = tls.LoadX509KeyPair(certFile, keyFile)
+ cert, err = tls.LoadX509KeyPair(cli.CertFile, cli.KeyFile)
if err != nil {
log.Fatalln("tls:", err)
}
@@ -199,81 +206,89 @@ func main() {
Certificates: []tls.Certificate{cert},
SessionTicketsDisabled: true,
}
- listener, err = tls.Listen("tcp", listenAddr, cfg)
+ listener, err = tls.Listen("tcp", cli.Listen, cfg)
}
if err != nil {
log.Fatalln("listen:", err)
}
- srv := http.Server{
- ReadTimeout: 5 * time.Second,
- WriteTimeout: 15 * time.Second,
+ srv := &server{
+ db: db,
+ debug: cli.Debug,
+ geoIPPath: cli.GeoIPPath,
}
-
- http.HandleFunc("/", withDB(db, rootHandler))
- http.HandleFunc("/newdata", withDB(db, newDataHandler))
- http.HandleFunc("/summary.json", withDB(db, summaryHandler))
- http.HandleFunc("/movement.json", withDB(db, movementHandler))
- http.HandleFunc("/performance.json", withDB(db, performanceHandler))
- http.HandleFunc("/blockstats.json", withDB(db, blockStatsHandler))
- http.HandleFunc("/locations.json", withDB(db, locationsHandler))
+ http.HandleFunc("/", srv.rootHandler)
+ http.HandleFunc("/newdata", srv.newDataHandler)
+ http.HandleFunc("/summary.json", srv.summaryHandler)
+ http.HandleFunc("/movement.json", srv.movementHandler)
+ http.HandleFunc("/performance.json", srv.performanceHandler)
+ http.HandleFunc("/blockstats.json", srv.blockStatsHandler)
+ http.HandleFunc("/locations.json", srv.locationsHandler)
http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("static"))))
- go cacheRefresher(db)
+ go srv.cacheRefresher()
- err = srv.Serve(listener)
+ httpSrv := http.Server{
+ ReadTimeout: 5 * time.Second,
+ WriteTimeout: 15 * time.Second,
+ }
+ err = httpSrv.Serve(listener)
if err != nil {
log.Fatalln("https:", err)
}
}
-var (
+type server struct {
+ debug bool
+ db *sql.DB
+ geoIPPath string
+
+ cacheMut sync.Mutex
cachedIndex []byte
cachedLocations []byte
cacheTime time.Time
- cacheMut sync.Mutex
-)
+}
const maxCacheTime = 15 * time.Minute
-func cacheRefresher(db *sql.DB) {
+func (s *server) cacheRefresher() {
ticker := time.NewTicker(maxCacheTime - time.Minute)
defer ticker.Stop()
for ; true; <-ticker.C {
- cacheMut.Lock()
- if err := refreshCacheLocked(db); err != nil {
+ s.cacheMut.Lock()
+ if err := s.refreshCacheLocked(); err != nil {
log.Println(err)
}
- cacheMut.Unlock()
+ s.cacheMut.Unlock()
}
}
-func refreshCacheLocked(db *sql.DB) error {
- rep := getReport(db)
+func (s *server) refreshCacheLocked() error {
+ rep := getReport(s.db, s.geoIPPath)
buf := new(bytes.Buffer)
err := tpl.Execute(buf, rep)
if err != nil {
return err
}
- cachedIndex = buf.Bytes()
- cacheTime = time.Now()
+ s.cachedIndex = buf.Bytes()
+ s.cacheTime = time.Now()
locs := rep["locations"].(map[location]int)
wlocs := make([]weightedLocation, 0, len(locs))
for loc, w := range locs {
wlocs = append(wlocs, weightedLocation{loc, w})
}
- cachedLocations, _ = json.Marshal(wlocs)
+ s.cachedLocations, _ = json.Marshal(wlocs)
return nil
}
-func rootHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) {
+func (s *server) rootHandler(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" || r.URL.Path == "/index.html" {
- cacheMut.Lock()
- defer cacheMut.Unlock()
+ s.cacheMut.Lock()
+ defer s.cacheMut.Unlock()
- if time.Since(cacheTime) > maxCacheTime {
- if err := refreshCacheLocked(db); err != nil {
+ if time.Since(s.cacheTime) > maxCacheTime {
+ if err := s.refreshCacheLocked(); err != nil {
log.Println(err)
http.Error(w, "Template Error", http.StatusInternalServerError)
return
@@ -281,19 +296,19 @@ func rootHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) {
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
- w.Write(cachedIndex)
+ w.Write(s.cachedIndex)
} else {
http.Error(w, "Not found", 404)
return
}
}
-func locationsHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) {
- cacheMut.Lock()
- defer cacheMut.Unlock()
+func (s *server) locationsHandler(w http.ResponseWriter, _ *http.Request) {
+ s.cacheMut.Lock()
+ defer s.cacheMut.Unlock()
- if time.Since(cacheTime) > maxCacheTime {
- if err := refreshCacheLocked(db); err != nil {
+ if time.Since(s.cacheTime) > maxCacheTime {
+ if err := s.refreshCacheLocked(); err != nil {
log.Println(err)
http.Error(w, "Template Error", http.StatusInternalServerError)
return
@@ -301,10 +316,10 @@ func locationsHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) {
}
w.Header().Set("Content-Type", "application/json; charset=utf-8")
- w.Write(cachedLocations)
+ w.Write(s.cachedLocations)
}
-func newDataHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) {
+func (s *server) newDataHandler(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
addr := r.Header.Get("X-Forwarded-For")
@@ -330,7 +345,7 @@ func newDataHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) {
bs, _ := io.ReadAll(lr)
if err := json.Unmarshal(bs, &rep); err != nil {
log.Println("decode:", err)
- if debug {
+ if s.debug {
log.Printf("%s", bs)
}
http.Error(w, "JSON Decode Error", http.StatusInternalServerError)
@@ -339,21 +354,21 @@ func newDataHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) {
if err := rep.Validate(); err != nil {
log.Println("validate:", err)
- if debug {
+ if s.debug {
log.Printf("%#v", rep)
}
http.Error(w, "Validation Error", http.StatusInternalServerError)
return
}
- if err := insertReport(db, rep); err != nil {
+ if err := insertReport(s.db, rep); err != nil {
if err.Error() == `pq: duplicate key value violates unique constraint "uniqueidjsonindex"` {
// We already have a report today for the same unique ID; drop
// this one without complaining.
return
}
log.Println("insert:", err)
- if debug {
+ if s.debug {
log.Printf("%#v", rep)
}
http.Error(w, "Database Error", http.StatusInternalServerError)
@@ -361,16 +376,16 @@ func newDataHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) {
}
}
-func summaryHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) {
+func (s *server) summaryHandler(w http.ResponseWriter, r *http.Request) {
min, _ := strconv.Atoi(r.URL.Query().Get("min"))
- s, err := getSummary(db, min)
+ sum, err := getSummary(s.db, min)
if err != nil {
log.Println("summaryHandler:", err)
http.Error(w, "Database Error", http.StatusInternalServerError)
return
}
- bs, err := s.MarshalJSON()
+ bs, err := sum.MarshalJSON()
if err != nil {
log.Println("summaryHandler:", err)
http.Error(w, "JSON Encode Error", http.StatusInternalServerError)
@@ -381,15 +396,15 @@ func summaryHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) {
w.Write(bs)
}
-func movementHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) {
- s, err := getMovement(db)
+func (s *server) movementHandler(w http.ResponseWriter, _ *http.Request) {
+ mov, err := getMovement(s.db)
if err != nil {
log.Println("movementHandler:", err)
http.Error(w, "Database Error", http.StatusInternalServerError)
return
}
- bs, err := json.Marshal(s)
+ bs, err := json.Marshal(mov)
if err != nil {
log.Println("movementHandler:", err)
http.Error(w, "JSON Encode Error", http.StatusInternalServerError)
@@ -400,15 +415,15 @@ func movementHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) {
w.Write(bs)
}
-func performanceHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) {
- s, err := getPerformance(db)
+func (s *server) performanceHandler(w http.ResponseWriter, _ *http.Request) {
+ perf, err := getPerformance(s.db)
if err != nil {
log.Println("performanceHandler:", err)
http.Error(w, "Database Error", http.StatusInternalServerError)
return
}
- bs, err := json.Marshal(s)
+ bs, err := json.Marshal(perf)
if err != nil {
log.Println("performanceHandler:", err)
http.Error(w, "JSON Encode Error", http.StatusInternalServerError)
@@ -419,15 +434,15 @@ func performanceHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) {
w.Write(bs)
}
-func blockStatsHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) {
- s, err := getBlockStats(db)
+func (s *server) blockStatsHandler(w http.ResponseWriter, _ *http.Request) {
+ blocks, err := getBlockStats(s.db)
if err != nil {
log.Println("blockStatsHandler:", err)
http.Error(w, "Database Error", http.StatusInternalServerError)
return
}
- bs, err := json.Marshal(s)
+ bs, err := json.Marshal(blocks)
if err != nil {
log.Println("blockStatsHandler:", err)
http.Error(w, "JSON Encode Error", http.StatusInternalServerError)
@@ -513,7 +528,7 @@ type weightedLocation struct {
Weight int `json:"weight"`
}
-func getReport(db *sql.DB) map[string]interface{} {
+func getReport(db *sql.DB, geoIPPath string) map[string]interface{} {
geoip, err := geoip2.Open(geoIPPath)
if err != nil {
log.Println("opening geoip db", err)