diff --git a/main.go b/main.go index 9318b11..a87412b 100644 --- a/main.go +++ b/main.go @@ -32,9 +32,57 @@ const ( asnDBURL = "https://download.db-ip.com/free/dbip-asn-lite-%s.mmdb.gz" ) +func main() { + initDatabases() + go startUpdater() + startServer() +} + +func initDatabases() { + var err error + + cityDB, err = maxminddb.Open(currCityFilename) + if err != nil { + if os.IsNotExist(err) { + currCityFilename = "" + doUpdate() + if cityDB == nil { + log.Fatalf("Failed to initialize city database: %v", err) + } + } else { + log.Fatalf("Error opening city database: %v", err) + } + } + + asnDB, err = maxminddb.Open(currASNFilename) + if err != nil { + if os.IsNotExist(err) { + currASNFilename = "" + doUpdate() + if asnDB == nil { + log.Fatalf("Failed to initialize ASN database: %v", err) + } + } else { + log.Fatalf("Error opening ASN database: %v", err) + } + } +} + +func startUpdater() { + for range time.Tick(time.Hour * 24 * 7) { + doUpdate() + } +} + +func startServer() { + log.Println("Server listening on :3000") + http.HandleFunc("/", handler) + log.Fatal(http.ListenAndServe(":3000", nil)) +} + // Fetch and update the GeoIP databases. func doUpdate() { - fmt.Fprintln(os.Stderr, "Fetching updates...") + log.Println("Fetching updates...") currMonth := time.Now().Format("2006-01") newCityFilename := currMonth + "-city.mmdb" newASNFilename := currMonth + "-asn.mmdb" @@ -47,7 +95,7 @@ func doUpdate() { } cityDB = newDB currCityFilename = newCityFilename - fmt.Fprintf(os.Stderr, "City GeoIP database updated to %s\n", currMonth) + log.Printf("City GeoIP database updated to %s\n", currMonth) }) updateDatabase(asnDBURL, newASNFilename, func(newDB *maxminddb.Reader) { @@ -58,93 +106,51 @@ func doUpdate() { } asnDB = newDB currASNFilename = newASNFilename - fmt.Fprintf(os.Stderr, "ASN GeoIP database updated to %s\n", currMonth) + log.Printf("ASN GeoIP database updated to %s\n", currMonth) }) } -// Download and update the database file. func updateDatabase(urlTemplate, dstFilename string, updateFunc func(*maxminddb.Reader)) { resp, err := http.Get(fmt.Sprintf(urlTemplate, time.Now().Format("2006-01"))) if err != nil { - fmt.Fprintf(os.Stderr, "Error fetching the updated DB: %v\n", err) + log.Printf("Error fetching the updated DB: %v\n", err) return } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - fmt.Fprintf(os.Stderr, "Non-200 status code (%d), retry later...\n", resp.StatusCode) + log.Printf("Non-200 status code (%d), retry later...\n", resp.StatusCode) return } dst, err := os.Create(dstFilename) if err != nil { - fmt.Fprintf(os.Stderr, "Error creating file: %v\n", err) + log.Printf("Error creating file: %v\n", err) return } defer dst.Close() r, err := gzip.NewReader(resp.Body) if err != nil { - fmt.Fprintf(os.Stderr, "Error creating gzip reader: %v\n", err) + log.Printf("Error creating gzip reader: %v\n", err) return } defer r.Close() - fmt.Fprintln(os.Stderr, "Copying new database...") + log.Println("Copying new database...") if _, err = io.Copy(dst, r); err != nil { - fmt.Fprintf(os.Stderr, "Error copying file: %v\n", err) + log.Printf("Error copying file: %v\n", err) return } newDB, err := maxminddb.Open(dstFilename) if err != nil { - fmt.Fprintf(os.Stderr, "Error opening new DB: %v\n", err) + log.Printf("Error opening new DB: %v\n", err) return } updateFunc(newDB) } -// Periodically update the GeoIP databases every week. -func updater() { - for range time.Tick(time.Hour * 24 * 7) { - doUpdate() - } -} - -func main() { - var err error - cityDB, err = maxminddb.Open(currCityFilename) - if err != nil { - if os.IsNotExist(err) { - currCityFilename = "" - doUpdate() - if cityDB == nil { - os.Exit(1) - } - } else { - log.Fatal(err) - } - } - - asnDB, err = maxminddb.Open(currASNFilename) - if err != nil { - if os.IsNotExist(err) { - currASNFilename = "" - doUpdate() - if asnDB == nil { - os.Exit(1) - } - } else { - log.Fatal(err) - } - } - - go updater() - - log.Println("Server listening on :3000") - http.ListenAndServe(":3000", http.HandlerFunc(handler)) -} - var invalidIPBytes = []byte("Please provide a valid IP address.") type dataStruct struct { @@ -170,13 +176,7 @@ func handler(w http.ResponseWriter, r *http.Request) { } if IPAddress == "" || IPAddress == "self" { - if realIP := r.Header.Get("CF-Connecting-IP"); realIP != "" { - IPAddress = realIP - } else if realIP := r.Header.Get("X-Forwarded-For"); realIP != "" { - IPAddress = strings.Split(realIP, ",")[0] - } else { - IPAddress = extractIP(r.RemoteAddr) - } + IPAddress = getRealIP(r) } ip := net.ParseIP(IPAddress) if ip == nil { @@ -185,7 +185,44 @@ func handler(w http.ResponseWriter, r *http.Request) { return } + data := lookupIPData(ip) + if data == nil { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Write(invalidIPBytes) + return + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + callback := r.URL.Query().Get("callback") + enableJSONP := callback != "" && len(callback) < 2000 && callbackJSONP.MatchString(callback) + if enableJSONP { + jsonData, _ := json.MarshalIndent(data, "", " ") + response := fmt.Sprintf("/**/ typeof %s === 'function' && %s(%s);", callback, callback, jsonData) + w.Write([]byte(response)) + } else { + enc := json.NewEncoder(w) + enc.SetIndent("", " ") + if r.URL.Query().Get("compact") == "true" { + enc.SetIndent("", "") + } + enc.Encode(data) + } +} + +func getRealIP(r *http.Request) string { + if realIP := r.Header.Get("CF-Connecting-IP"); realIP != "" { + return realIP + } else if realIP := r.Header.Get("X-Forwarded-For"); realIP != "" { + return strings.Split(realIP, ",")[0] + } else { + return extractIP(r.RemoteAddr) + } +} + +func lookupIPData(ip net.IP) *dataStruct { dbMtx.RLock() + defer dbMtx.RUnlock() + var cityRecord struct { Country struct { IsoCode string `maxminddb:"iso_code"` @@ -207,20 +244,19 @@ func handler(w http.ResponseWriter, r *http.Request) { } `maxminddb:"location"` } err := cityDB.Lookup(ip, &cityRecord) - dbMtx.RUnlock() if err != nil { - log.Fatal(err) + log.Printf("Error looking up city data: %v\n", err) + return nil } - dbMtx.RLock() var asnRecord struct { AutonomousSystemNumber uint `maxminddb:"autonomous_system_number"` AutonomousSystemOrganization string `maxminddb:"autonomous_system_organization"` } err = asnDB.Lookup(ip, &asnRecord) - dbMtx.RUnlock() if err != nil { - log.Fatal(err) + log.Printf("Error looking up ASN data: %v\n", err) + return nil } hostname, err := net.LookupAddr(ip.String()) @@ -233,7 +269,7 @@ func handler(w http.ResponseWriter, r *http.Request) { sd = cityRecord.Subdivisions[0].Names["en"] } - d := dataStruct{ + return &dataStruct{ IP: ip.String(), Hostname: strings.TrimSuffix(hostname[0], "."), ASN: fmt.Sprintf("%d", asnRecord.AutonomousSystemNumber), @@ -246,22 +282,6 @@ func handler(w http.ResponseWriter, r *http.Request) { ContinentFull: cityRecord.Continent.Names["en"], Loc: fmt.Sprintf("%.4f,%.4f", cityRecord.Location.Latitude, cityRecord.Location.Longitude), } - - w.Header().Set("Content-Type", "application/json; charset=utf-8") - callback := r.URL.Query().Get("callback") - enableJSONP := callback != "" && len(callback) < 2000 && callbackJSONP.MatchString(callback) - if enableJSONP { - jsonData, _ := json.MarshalIndent(d, "", " ") - response := fmt.Sprintf("/**/ typeof %s === 'function' && %s(%s);", callback, callback, jsonData) - w.Write([]byte(response)) - } else { - enc := json.NewEncoder(w) - enc.SetIndent("", " ") - if r.URL.Query().Get("compact") == "true" { - enc.SetIndent("", "") - } - enc.Encode(d) - } } var callbackJSONP = regexp.MustCompile(`^[a-zA-Z_\$][a-zA-Z0-9_\$]*$`)