chore: Refactor IP address retrieval logic in main.go

This commit is contained in:
skidoodle 2024-07-11 08:50:04 +02:00
parent 1f0f5668fe
commit ea53882eea

105
main.go
View file

@ -154,25 +154,30 @@ func updateDatabase(urlTemplate, dstFilename string, updateFunc func(*maxminddb.
var invalidIPBytes = []byte("Please provide a valid IP address.") var invalidIPBytes = []byte("Please provide a valid IP address.")
type dataStruct struct { type dataStruct struct {
IP string `json:"ip"` IP *string `json:"ip"`
Hostname string `json:"hostname"` Hostname *string `json:"hostname"`
ASN string `json:"asn"` ASN *string `json:"asn"`
Organization string `json:"organization"` Organization *string `json:"organization"`
City string `json:"city"` City *string `json:"city"`
Region string `json:"region"` Region *string `json:"region"`
Country string `json:"country"` Country *string `json:"country"`
CountryFull string `json:"country_full"` CountryFull *string `json:"country_full"`
Continent string `json:"continent"` Continent *string `json:"continent"`
ContinentFull string `json:"continent_full"` ContinentFull *string `json:"continent_full"`
Loc string `json:"loc"` Loc *string `json:"loc"`
} }
func handler(w http.ResponseWriter, r *http.Request) { func handler(w http.ResponseWriter, r *http.Request) {
requestedThings := strings.Split(r.URL.Path, "/") requestedThings := strings.Split(r.URL.Path, "/")
var IPAddress string var IPAddress, field string
if len(requestedThings) > 1 { if len(requestedThings) > 1 && net.ParseIP(requestedThings[1]) != nil {
IPAddress = requestedThings[1] IPAddress = requestedThings[1]
if len(requestedThings) > 2 {
field = requestedThings[2]
}
} else if len(requestedThings) > 1 {
field = requestedThings[1]
} }
if IPAddress == "" || IPAddress == "self" { if IPAddress == "" || IPAddress == "self" {
@ -192,6 +197,19 @@ func handler(w http.ResponseWriter, r *http.Request) {
return return
} }
if field != "" {
value := getField(data, field)
if value != nil {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
json.NewEncoder(w).Encode(map[string]*string{field: value})
return
} else {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
json.NewEncoder(w).Encode(map[string]*string{field: nil})
return
}
}
w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Header().Set("Content-Type", "application/json; charset=utf-8")
callback := r.URL.Query().Get("callback") callback := r.URL.Query().Get("callback")
enableJSONP := callback != "" && len(callback) < 2000 && callbackJSONP.MatchString(callback) enableJSONP := callback != "" && len(callback) < 2000 && callbackJSONP.MatchString(callback)
@ -209,6 +227,35 @@ func handler(w http.ResponseWriter, r *http.Request) {
} }
} }
func getField(data *dataStruct, field string) *string {
switch field {
case "ip":
return data.IP
case "hostname":
return data.Hostname
case "asn":
return data.ASN
case "organization":
return data.Organization
case "city":
return data.City
case "region":
return data.Region
case "country":
return data.Country
case "country_full":
return data.CountryFull
case "continent":
return data.Continent
case "continent_full":
return data.ContinentFull
case "loc":
return data.Loc
default:
return nil
}
}
func getRealIP(r *http.Request) string { func getRealIP(r *http.Request) string {
if realIP := r.Header.Get("CF-Connecting-IP"); realIP != "" { if realIP := r.Header.Get("CF-Connecting-IP"); realIP != "" {
return realIP return realIP
@ -264,26 +311,34 @@ func lookupIPData(ip net.IP) *dataStruct {
hostname = []string{""} hostname = []string{""}
} }
var sd string var sd *string
if len(cityRecord.Subdivisions) > 0 { if len(cityRecord.Subdivisions) > 0 {
sd = cityRecord.Subdivisions[0].Names["en"] name := cityRecord.Subdivisions[0].Names["en"]
sd = &name
} }
return &dataStruct{ return &dataStruct{
IP: ip.String(), IP: toPtr(ip.String()),
Hostname: strings.TrimSuffix(hostname[0], "."), Hostname: toPtr(strings.TrimSuffix(hostname[0], ".")),
ASN: fmt.Sprintf("%d", asnRecord.AutonomousSystemNumber), ASN: toPtr(fmt.Sprintf("%d", asnRecord.AutonomousSystemNumber)),
Organization: asnRecord.AutonomousSystemOrganization, Organization: toPtr(asnRecord.AutonomousSystemOrganization),
Country: cityRecord.Country.IsoCode, Country: toPtr(cityRecord.Country.IsoCode),
CountryFull: cityRecord.Country.Names["en"], CountryFull: toPtr(cityRecord.Country.Names["en"]),
City: cityRecord.City.Names["en"], City: toPtr(cityRecord.City.Names["en"]),
Region: sd, Region: sd,
Continent: cityRecord.Continent.Code, Continent: toPtr(cityRecord.Continent.Code),
ContinentFull: cityRecord.Continent.Names["en"], ContinentFull: toPtr(cityRecord.Continent.Names["en"]),
Loc: fmt.Sprintf("%.4f,%.4f", cityRecord.Location.Latitude, cityRecord.Location.Longitude), Loc: toPtr(fmt.Sprintf("%.4f,%.4f", cityRecord.Location.Latitude, cityRecord.Location.Longitude)),
} }
} }
func toPtr(s string) *string {
if s == "" {
return nil
}
return &s
}
var callbackJSONP = regexp.MustCompile(`^[a-zA-Z_\$][a-zA-Z0-9_\$]*$`) var callbackJSONP = regexp.MustCompile(`^[a-zA-Z_\$][a-zA-Z0-9_\$]*$`)
// Extract the IP address from a string, removing unwanted characters. // Extract the IP address from a string, removing unwanted characters.