From ea53882eeaa7250940402ecffd51b09df2d667c4 Mon Sep 17 00:00:00 2001 From: skidoodle Date: Thu, 11 Jul 2024 08:50:04 +0200 Subject: [PATCH] chore: Refactor IP address retrieval logic in main.go --- main.go | 105 ++++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 80 insertions(+), 25 deletions(-) diff --git a/main.go b/main.go index a87412b..645e472 100644 --- a/main.go +++ b/main.go @@ -154,25 +154,30 @@ func updateDatabase(urlTemplate, dstFilename string, updateFunc func(*maxminddb. var invalidIPBytes = []byte("Please provide a valid IP address.") type dataStruct struct { - IP string `json:"ip"` - Hostname string `json:"hostname"` - ASN string `json:"asn"` - Organization string `json:"organization"` - City string `json:"city"` - Region string `json:"region"` - Country string `json:"country"` - CountryFull string `json:"country_full"` - Continent string `json:"continent"` - ContinentFull string `json:"continent_full"` - Loc string `json:"loc"` + IP *string `json:"ip"` + Hostname *string `json:"hostname"` + ASN *string `json:"asn"` + Organization *string `json:"organization"` + City *string `json:"city"` + Region *string `json:"region"` + Country *string `json:"country"` + CountryFull *string `json:"country_full"` + Continent *string `json:"continent"` + ContinentFull *string `json:"continent_full"` + Loc *string `json:"loc"` } func handler(w http.ResponseWriter, r *http.Request) { requestedThings := strings.Split(r.URL.Path, "/") - var IPAddress string - if len(requestedThings) > 1 { + var IPAddress, field string + if len(requestedThings) > 1 && net.ParseIP(requestedThings[1]) != nil { IPAddress = requestedThings[1] + if len(requestedThings) > 2 { + field = requestedThings[2] + } + } else if len(requestedThings) > 1 { + field = requestedThings[1] } if IPAddress == "" || IPAddress == "self" { @@ -192,6 +197,19 @@ func handler(w http.ResponseWriter, r *http.Request) { return } + if field != "" { + value := getField(data, field) + if value != nil { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + json.NewEncoder(w).Encode(map[string]*string{field: value}) + return + } else { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + json.NewEncoder(w).Encode(map[string]*string{field: nil}) + return + } + } + w.Header().Set("Content-Type", "application/json; charset=utf-8") callback := r.URL.Query().Get("callback") enableJSONP := callback != "" && len(callback) < 2000 && callbackJSONP.MatchString(callback) @@ -209,6 +227,35 @@ func handler(w http.ResponseWriter, r *http.Request) { } } +func getField(data *dataStruct, field string) *string { + switch field { + case "ip": + return data.IP + case "hostname": + return data.Hostname + case "asn": + return data.ASN + case "organization": + return data.Organization + case "city": + return data.City + case "region": + return data.Region + case "country": + return data.Country + case "country_full": + return data.CountryFull + case "continent": + return data.Continent + case "continent_full": + return data.ContinentFull + case "loc": + return data.Loc + default: + return nil + } +} + func getRealIP(r *http.Request) string { if realIP := r.Header.Get("CF-Connecting-IP"); realIP != "" { return realIP @@ -264,26 +311,34 @@ func lookupIPData(ip net.IP) *dataStruct { hostname = []string{""} } - var sd string + var sd *string if len(cityRecord.Subdivisions) > 0 { - sd = cityRecord.Subdivisions[0].Names["en"] + name := cityRecord.Subdivisions[0].Names["en"] + sd = &name } return &dataStruct{ - IP: ip.String(), - Hostname: strings.TrimSuffix(hostname[0], "."), - ASN: fmt.Sprintf("%d", asnRecord.AutonomousSystemNumber), - Organization: asnRecord.AutonomousSystemOrganization, - Country: cityRecord.Country.IsoCode, - CountryFull: cityRecord.Country.Names["en"], - City: cityRecord.City.Names["en"], + IP: toPtr(ip.String()), + Hostname: toPtr(strings.TrimSuffix(hostname[0], ".")), + ASN: toPtr(fmt.Sprintf("%d", asnRecord.AutonomousSystemNumber)), + Organization: toPtr(asnRecord.AutonomousSystemOrganization), + Country: toPtr(cityRecord.Country.IsoCode), + CountryFull: toPtr(cityRecord.Country.Names["en"]), + City: toPtr(cityRecord.City.Names["en"]), Region: sd, - Continent: cityRecord.Continent.Code, - ContinentFull: cityRecord.Continent.Names["en"], - Loc: fmt.Sprintf("%.4f,%.4f", cityRecord.Location.Latitude, cityRecord.Location.Longitude), + Continent: toPtr(cityRecord.Continent.Code), + ContinentFull: toPtr(cityRecord.Continent.Names["en"]), + Loc: toPtr(fmt.Sprintf("%.4f,%.4f", cityRecord.Location.Latitude, cityRecord.Location.Longitude)), } } +func toPtr(s string) *string { + if s == "" { + return nil + } + return &s +} + var callbackJSONP = regexp.MustCompile(`^[a-zA-Z_\$][a-zA-Z0-9_\$]*$`) // Extract the IP address from a string, removing unwanted characters.