From 49ab76c9ddf840458ad12200bcbc4a7da87303dd Mon Sep 17 00:00:00 2001 From: Xe Iaso Date: Thu, 22 May 2025 10:48:15 -0400 Subject: [PATCH] feat(thoth): add GeoIP checker Signed-off-by: Xe Iaso --- internal/thoth/geoipchecker.go | 36 +++++++++++++++++ internal/thoth/geoipchecker_test.go | 63 +++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+) create mode 100644 internal/thoth/geoipchecker.go create mode 100644 internal/thoth/geoipchecker_test.go diff --git a/internal/thoth/geoipchecker.go b/internal/thoth/geoipchecker.go new file mode 100644 index 00000000..3b0aaea2 --- /dev/null +++ b/internal/thoth/geoipchecker.go @@ -0,0 +1,36 @@ +package thoth + +import ( + "context" + "net/http" + "strings" + "time" + + iptoasnv1 "github.com/TecharoHQ/thoth-proto/gen/techaro/thoth/iptoasn/v1" +) + +type GeoIPChecker struct { + iptoasn iptoasnv1.IpToASNServiceClient + countries map[string]struct{} + hash string +} + +func (gipc *GeoIPChecker) Check(r *http.Request) (bool, error) { + ctx, cancel := context.WithTimeout(r.Context(), 50*time.Millisecond) + defer cancel() + + ipInfo, err := gipc.iptoasn.Lookup(ctx, &iptoasnv1.LookupRequest{ + IpAddress: r.Header.Get("X-Real-Ip"), + }) + if err != nil { + return false, err + } + + _, ok := gipc.countries[strings.ToLower(ipInfo.GetCountryCode())] + + return ok, nil +} + +func (gipc *GeoIPChecker) Hash() string { + return gipc.hash +} diff --git a/internal/thoth/geoipchecker_test.go b/internal/thoth/geoipchecker_test.go new file mode 100644 index 00000000..86203b84 --- /dev/null +++ b/internal/thoth/geoipchecker_test.go @@ -0,0 +1,63 @@ +package thoth + +import ( + "fmt" + "net/http/httptest" + "testing" + + "github.com/TecharoHQ/anubis/lib/policy" +) + +var _ policy.Checker = &ASNChecker{} + +func TestGeoIPChecker(t *testing.T) { + cli := loadSecrets(t) + + asnc := &GeoIPChecker{ + iptoasn: cli.iptoasn, + countries: map[string]struct{}{ + "us": {}, + }, + hash: "foobar", + } + + for _, cs := range []struct { + ipAddress string + wantMatch bool + wantError bool + }{ + { + ipAddress: "1.1.1.1", + wantMatch: true, + wantError: false, + }, + { + ipAddress: "70.31.0.1", + wantMatch: false, + wantError: false, + }, + { + ipAddress: "taco", + wantMatch: false, + wantError: true, + }, + } { + t.Run(fmt.Sprintf("%v", cs), func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("X-Real-Ip", cs.ipAddress) + + match, err := asnc.Check(req) + + if match != cs.wantMatch { + t.Errorf("Wanted match: %v, got: %v", cs.wantMatch, match) + } + + switch { + case err != nil && !cs.wantError: + t.Errorf("Did not want error but got: %v", err) + case err == nil && cs.wantError: + t.Error("Wanted error but got none") + } + }) + } +}