mirror of
https://github.com/TecharoHQ/anubis.git
synced 2026-04-17 05:44:57 +00:00
36
internal/thoth/geoipchecker.go
Normal file
36
internal/thoth/geoipchecker.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package thoth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
iptoasnv1 "github.com/TecharoHQ/thoth-proto/gen/techaro/thoth/iptoasn/v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GeoIPChecker struct {
|
||||||
|
iptoasn iptoasnv1.IpToASNServiceClient
|
||||||
|
countries map[string]struct{}
|
||||||
|
hash string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (gipc *GeoIPChecker) Check(r *http.Request) (bool, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 50*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
ipInfo, err := gipc.iptoasn.Lookup(ctx, &iptoasnv1.LookupRequest{
|
||||||
|
IpAddress: r.Header.Get("X-Real-Ip"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ok := gipc.countries[strings.ToLower(ipInfo.GetCountryCode())]
|
||||||
|
|
||||||
|
return ok, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (gipc *GeoIPChecker) Hash() string {
|
||||||
|
return gipc.hash
|
||||||
|
}
|
||||||
63
internal/thoth/geoipchecker_test.go
Normal file
63
internal/thoth/geoipchecker_test.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package thoth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/TecharoHQ/anubis/lib/policy"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ policy.Checker = &ASNChecker{}
|
||||||
|
|
||||||
|
func TestGeoIPChecker(t *testing.T) {
|
||||||
|
cli := loadSecrets(t)
|
||||||
|
|
||||||
|
asnc := &GeoIPChecker{
|
||||||
|
iptoasn: cli.iptoasn,
|
||||||
|
countries: map[string]struct{}{
|
||||||
|
"us": {},
|
||||||
|
},
|
||||||
|
hash: "foobar",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, cs := range []struct {
|
||||||
|
ipAddress string
|
||||||
|
wantMatch bool
|
||||||
|
wantError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
ipAddress: "1.1.1.1",
|
||||||
|
wantMatch: true,
|
||||||
|
wantError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ipAddress: "70.31.0.1",
|
||||||
|
wantMatch: false,
|
||||||
|
wantError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ipAddress: "taco",
|
||||||
|
wantMatch: false,
|
||||||
|
wantError: true,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(fmt.Sprintf("%v", cs), func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
req.Header.Set("X-Real-Ip", cs.ipAddress)
|
||||||
|
|
||||||
|
match, err := asnc.Check(req)
|
||||||
|
|
||||||
|
if match != cs.wantMatch {
|
||||||
|
t.Errorf("Wanted match: %v, got: %v", cs.wantMatch, match)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case err != nil && !cs.wantError:
|
||||||
|
t.Errorf("Did not want error but got: %v", err)
|
||||||
|
case err == nil && cs.wantError:
|
||||||
|
t.Error("Wanted error but got none")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user