From 86ee5697f32c954658285cf4fafe1bd0e4480522 Mon Sep 17 00:00:00 2001 From: Xe Iaso Date: Thu, 22 May 2025 11:19:01 -0400 Subject: [PATCH] chore(lib/policy): move Checker to its own package to avoid import cycles Signed-off-by: Xe Iaso --- internal/thoth/geoipchecker_test.go | 4 +-- internal/thoth/thoth.go | 20 +++++++++++-- lib/anubis_test.go | 6 ++-- lib/config.go | 5 ++-- lib/policy/bot.go | 3 +- lib/policy/checker.go | 46 +++++------------------------ lib/policy/checker/checker.go | 41 +++++++++++++++++++++++++ lib/policy/config/asn.go | 44 +++++++++++++++++++++++++++ lib/policy/config/config.go | 12 ++++---- lib/policy/config/geoip.go | 36 ++++++++++++++++++++++ lib/policy/config/geoip_test.go | 33 +++++++++++++++++++++ lib/policy/policy.go | 20 +++++++++++-- lib/policy/policy_test.go | 6 ++-- 13 files changed, 218 insertions(+), 58 deletions(-) create mode 100644 lib/policy/checker/checker.go create mode 100644 lib/policy/config/asn.go create mode 100644 lib/policy/config/geoip.go create mode 100644 lib/policy/config/geoip_test.go diff --git a/internal/thoth/geoipchecker_test.go b/internal/thoth/geoipchecker_test.go index 86203b84..68aa1935 100644 --- a/internal/thoth/geoipchecker_test.go +++ b/internal/thoth/geoipchecker_test.go @@ -5,10 +5,10 @@ import ( "net/http/httptest" "testing" - "github.com/TecharoHQ/anubis/lib/policy" + "github.com/TecharoHQ/anubis/lib/policy/checker" ) -var _ policy.Checker = &ASNChecker{} +var _ checker.Impl = &ASNChecker{} func TestGeoIPChecker(t *testing.T) { cli := loadSecrets(t) diff --git a/internal/thoth/thoth.go b/internal/thoth/thoth.go index 43f2f27d..1cf76939 100644 --- a/internal/thoth/thoth.go +++ b/internal/thoth/thoth.go @@ -8,7 +8,7 @@ import ( "time" "github.com/TecharoHQ/anubis/internal" - "github.com/TecharoHQ/anubis/lib/policy" + "github.com/TecharoHQ/anubis/lib/policy/checker" iptoasnv1 "github.com/TecharoHQ/thoth-proto/gen/techaro/thoth/iptoasn/v1" grpcprom "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/timeout" @@ -74,7 +74,7 @@ func (c *Client) Close() error { return c.conn.Close() } -func (c *Client) ASNCheckerFor(asns []uint32) policy.Checker { +func (c *Client) ASNCheckerFor(asns []uint32) checker.Impl { asnMap := map[uint32]struct{}{} var sb strings.Builder fmt.Fprintln(&sb, "ASNChecker") @@ -89,3 +89,19 @@ func (c *Client) ASNCheckerFor(asns []uint32) policy.Checker { hash: internal.SHA256sum(sb.String()), } } + +func (c *Client) GeoIPCheckerFor(countries []string) checker.Impl { + countryMap := map[string]struct{}{} + var sb strings.Builder + fmt.Fprintln(&sb, "GeoIPChecker") + for _, cc := range countries { + countryMap[cc] = struct{}{} + fmt.Fprintln(&sb, cc) + } + + return &GeoIPChecker{ + iptoasn: c.iptoasn, + countries: countryMap, + hash: sb.String(), + } +} diff --git a/lib/anubis_test.go b/lib/anubis_test.go index 55aa09a4..c09452f6 100644 --- a/lib/anubis_test.go +++ b/lib/anubis_test.go @@ -21,7 +21,7 @@ import ( func loadPolicies(t *testing.T, fname string) *policy.ParsedConfig { t.Helper() - anubisPolicy, err := LoadPoliciesOrDefault(fname, anubis.DefaultDifficulty) + anubisPolicy, err := LoadPoliciesOrDefault(t.Context(), fname, anubis.DefaultDifficulty) if err != nil { t.Fatal(err) } @@ -118,7 +118,7 @@ func TestLoadPolicies(t *testing.T) { } defer fin.Close() - if _, err := policy.ParseConfig(fin, fname, 4); err != nil { + if _, err := policy.ParseConfig(t.Context(), fin, fname, 4); err != nil { t.Fatal(err) } }) @@ -268,7 +268,7 @@ func TestCheckDefaultDifficultyMatchesPolicy(t *testing.T) { for i := 1; i < 10; i++ { t.Run(fmt.Sprint(i), func(t *testing.T) { - anubisPolicy, err := LoadPoliciesOrDefault("", i) + anubisPolicy, err := LoadPoliciesOrDefault(t.Context(), "", i) if err != nil { t.Fatal(err) } diff --git a/lib/config.go b/lib/config.go index 7d08e20a..f7785094 100644 --- a/lib/config.go +++ b/lib/config.go @@ -1,6 +1,7 @@ package lib import ( + "context" "crypto/ed25519" "crypto/rand" "fmt" @@ -40,7 +41,7 @@ type Options struct { ServeRobotsTXT bool } -func LoadPoliciesOrDefault(fname string, defaultDifficulty int) (*policy.ParsedConfig, error) { +func LoadPoliciesOrDefault(ctx context.Context, fname string, defaultDifficulty int) (*policy.ParsedConfig, error) { var fin io.ReadCloser var err error @@ -64,7 +65,7 @@ func LoadPoliciesOrDefault(fname string, defaultDifficulty int) (*policy.ParsedC } }(fin) - anubisPolicy, err := policy.ParseConfig(fin, fname, defaultDifficulty) + anubisPolicy, err := policy.ParseConfig(ctx, fin, fname, defaultDifficulty) return anubisPolicy, err } diff --git a/lib/policy/bot.go b/lib/policy/bot.go index 3e7a63ad..fdcc8f1c 100644 --- a/lib/policy/bot.go +++ b/lib/policy/bot.go @@ -4,11 +4,12 @@ import ( "fmt" "github.com/TecharoHQ/anubis/internal" + "github.com/TecharoHQ/anubis/lib/policy/checker" "github.com/TecharoHQ/anubis/lib/policy/config" ) type Bot struct { - Rules Checker + Rules checker.Impl Challenge *config.ChallengeRules Name string Action config.Rule diff --git a/lib/policy/checker.go b/lib/policy/checker.go index 51223cdb..2c3828cc 100644 --- a/lib/policy/checker.go +++ b/lib/policy/checker.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/TecharoHQ/anubis/internal" + "github.com/TecharoHQ/anubis/lib/policy/checker" "github.com/yl2chen/cidranger" ) @@ -16,43 +17,12 @@ var ( ErrMisconfiguration = errors.New("[unexpected] policy: administrator misconfiguration") ) -type Checker interface { - Check(*http.Request) (bool, error) - Hash() string -} - -type CheckerList []Checker - -func (cl CheckerList) Check(r *http.Request) (bool, error) { - for _, c := range cl { - ok, err := c.Check(r) - if err != nil { - return ok, err - } - if ok { - return ok, nil - } - } - - return false, nil -} - -func (cl CheckerList) Hash() string { - var sb strings.Builder - - for _, c := range cl { - fmt.Fprintln(&sb, c.Hash()) - } - - return internal.SHA256sum(sb.String()) -} - type RemoteAddrChecker struct { ranger cidranger.Ranger hash string } -func NewRemoteAddrChecker(cidrs []string) (Checker, error) { +func NewRemoteAddrChecker(cidrs []string) (checker.Impl, error) { ranger := cidranger.NewPCTrieRanger() var sb strings.Builder @@ -105,11 +75,11 @@ type HeaderMatchesChecker struct { hash string } -func NewUserAgentChecker(rexStr string) (Checker, error) { +func NewUserAgentChecker(rexStr string) (checker.Impl, error) { return NewHeaderMatchesChecker("User-Agent", rexStr) } -func NewHeaderMatchesChecker(header, rexStr string) (Checker, error) { +func NewHeaderMatchesChecker(header, rexStr string) (checker.Impl, error) { rex, err := regexp.Compile(strings.TrimSpace(rexStr)) if err != nil { return nil, fmt.Errorf("%w: regex %s failed parse: %w", ErrMisconfiguration, rexStr, err) @@ -134,7 +104,7 @@ type PathChecker struct { hash string } -func NewPathChecker(rexStr string) (Checker, error) { +func NewPathChecker(rexStr string) (checker.Impl, error) { rex, err := regexp.Compile(strings.TrimSpace(rexStr)) if err != nil { return nil, fmt.Errorf("%w: regex %s failed parse: %w", ErrMisconfiguration, rexStr, err) @@ -154,7 +124,7 @@ func (pc *PathChecker) Hash() string { return pc.hash } -func NewHeaderExistsChecker(key string) Checker { +func NewHeaderExistsChecker(key string) checker.Impl { return headerExistsChecker{strings.TrimSpace(key)} } @@ -174,8 +144,8 @@ func (hec headerExistsChecker) Hash() string { return internal.SHA256sum(hec.header) } -func NewHeadersChecker(headermap map[string]string) (Checker, error) { - var result CheckerList +func NewHeadersChecker(headermap map[string]string) (checker.Impl, error) { + var result checker.List var errs []error for key, rexStr := range headermap { diff --git a/lib/policy/checker/checker.go b/lib/policy/checker/checker.go new file mode 100644 index 00000000..4d7b5c79 --- /dev/null +++ b/lib/policy/checker/checker.go @@ -0,0 +1,41 @@ +// Package checker defines the Checker interface and a helper utility to avoid import cycles. +package checker + +import ( + "fmt" + "net/http" + "strings" + + "github.com/TecharoHQ/anubis/internal" +) + +type Impl interface { + Check(*http.Request) (bool, error) + Hash() string +} + +type List []Impl + +func (l List) Check(r *http.Request) (bool, error) { + for _, c := range l { + ok, err := c.Check(r) + if err != nil { + return ok, err + } + if ok { + return ok, nil + } + } + + return false, nil +} + +func (l List) Hash() string { + var sb strings.Builder + + for _, c := range l { + fmt.Fprintln(&sb, c.Hash()) + } + + return internal.SHA256sum(sb.String()) +} diff --git a/lib/policy/config/asn.go b/lib/policy/config/asn.go new file mode 100644 index 00000000..2b92ff3a --- /dev/null +++ b/lib/policy/config/asn.go @@ -0,0 +1,44 @@ +package config + +import ( + "errors" + "fmt" +) + +var ( + ErrPrivateASN = errors.New("bot.ASNs: you have specified a private use ASN") +) + +type ASNs struct { + Match []uint32 `json:"match"` +} + +func (a *ASNs) Valid() error { + var errs []error + + for _, asn := range a.Match { + if isPrivateASN(asn) { + errs = append(errs, fmt.Errorf("%w: %d is private (see RFC 6996)", ErrPrivateASN, asn)) + } + } + + if len(errs) != 0 { + return fmt.Errorf("bot.ASNs: invalid ASN settings: %w", errors.Join(errs...)) + } + + return nil +} + +// isPrivateASN checks if an ASN is in the private use area. +// +// Based on RFC 6996 and IANA allocations. +func isPrivateASN(asn uint32) bool { + switch { + case asn >= 64512 && asn <= 65534: + return true + case asn >= 4200000000 && asn <= 4294967294: + return true + default: + return false + } +} diff --git a/lib/policy/config/config.go b/lib/policy/config/config.go index 4b6f6435..8e0f8ca5 100644 --- a/lib/policy/config/config.go +++ b/lib/policy/config/config.go @@ -51,14 +51,16 @@ const ( ) type BotConfig struct { - UserAgentRegex *string `json:"user_agent_regex"` - PathRegex *string `json:"path_regex"` - HeadersRegex map[string]string `json:"headers_regex"` - Expression *ExpressionOrList `json:"expression"` + UserAgentRegex *string `json:"user_agent_regex,omitempty"` + PathRegex *string `json:"path_regex,omitempty"` + HeadersRegex map[string]string `json:"headers_regex,omitempty"` + Expression *ExpressionOrList `json:"expression,omitempty"` Challenge *ChallengeRules `json:"challenge,omitempty"` + GeoIP *GeoIP `json:"geoip,omitempty"` + ASNs *ASNs `json:"asns,omitempty"` Name string `json:"name"` Action Rule `json:"action"` - RemoteAddr []string `json:"remote_addresses"` + RemoteAddr []string `json:"remote_addresses,omitempty"` } func (b BotConfig) Zero() bool { diff --git a/lib/policy/config/geoip.go b/lib/policy/config/geoip.go new file mode 100644 index 00000000..d62d0274 --- /dev/null +++ b/lib/policy/config/geoip.go @@ -0,0 +1,36 @@ +package config + +import ( + "errors" + "fmt" + "regexp" + "strings" +) + +var ( + countryCodeRegexp = regexp.MustCompile(`^\w{2}$`) + + ErrNotCountryCode = errors.New("config.Bot: invalid country code") +) + +type GeoIP struct { + Countries []string `json:"countries"` +} + +func (g *GeoIP) Valid() error { + var errs []error + + for i, cc := range g.Countries { + if !countryCodeRegexp.MatchString(cc) { + errs = append(errs, fmt.Errorf("%w: %s", ErrNotCountryCode, cc)) + } + + g.Countries[i] = strings.ToLower(cc) + } + + if len(errs) != 0 { + return fmt.Errorf("bot.GeoIP: invalid GeoIP settings: %w", errors.Join(errs...)) + } + + return nil +} diff --git a/lib/policy/config/geoip_test.go b/lib/policy/config/geoip_test.go new file mode 100644 index 00000000..e6108577 --- /dev/null +++ b/lib/policy/config/geoip_test.go @@ -0,0 +1,33 @@ +package config + +import ( + "errors" + "testing" +) + +func TestGeoIPValid(t *testing.T) { + for _, cs := range []struct { + name string + countries []string + err error + }{ + { + name: "basic-working", + countries: []string{"US", "Ca", "mx"}, + err: nil, + }, + } { + t.Run(cs.name, func(t *testing.T) { + g := &GeoIP{ + Countries: cs.countries, + } + err := g.Valid() + if !errors.Is(err, cs.err) { + t.Fatalf("wanted error %v but got: %v", cs.err, err) + } + if err == nil && cs.err != nil { + t.Fatalf("wanted error %v but got none", cs.err) + } + }) + } +} diff --git a/lib/policy/policy.go b/lib/policy/policy.go index 1dfeafbb..9d5980a5 100644 --- a/lib/policy/policy.go +++ b/lib/policy/policy.go @@ -1,6 +1,7 @@ package policy import ( + "context" "errors" "fmt" "io" @@ -8,6 +9,8 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/TecharoHQ/anubis/internal/thoth" + "github.com/TecharoHQ/anubis/lib/policy/checker" "github.com/TecharoHQ/anubis/lib/policy/config" ) @@ -16,6 +19,8 @@ var ( Name: "anubis_policy_results", Help: "The results of each policy rule", }, []string{"rule", "action"}) + + ErrNoThothClient = errors.New("config: you have specified Thoth related checks but have no active Thoth client") ) type ParsedConfig struct { @@ -34,7 +39,7 @@ func NewParsedConfig(orig *config.Config) *ParsedConfig { } } -func ParseConfig(fin io.Reader, fname string, defaultDifficulty int) (*ParsedConfig, error) { +func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDifficulty int) (*ParsedConfig, error) { c, err := config.Load(fin, fname) if err != nil { return nil, err @@ -42,6 +47,8 @@ func ParseConfig(fin io.Reader, fname string, defaultDifficulty int) (*ParsedCon var validationErrs []error + tc, hasThothClient := thoth.FromContext(ctx) + result := NewParsedConfig(c) result.DefaultDifficulty = defaultDifficulty @@ -56,7 +63,7 @@ func ParseConfig(fin io.Reader, fname string, defaultDifficulty int) (*ParsedCon Action: b.Action, } - cl := CheckerList{} + cl := checker.List{} if len(b.RemoteAddr) > 0 { c, err := NewRemoteAddrChecker(b.RemoteAddr) @@ -103,6 +110,15 @@ func ParseConfig(fin io.Reader, fname string, defaultDifficulty int) (*ParsedCon } } + if b.ASNs != nil { + if !hasThothClient { + validationErrs = append(validationErrs, fmt.Errorf("%w: %w", ErrMisconfiguration, ErrNoThothClient)) + continue + } + + cl = append(cl, tc.ASNCheckerFor(b.ASNs.Match)) + } + if b.Challenge == nil { parsedBot.Challenge = &config.ChallengeRules{ Difficulty: defaultDifficulty, diff --git a/lib/policy/policy_test.go b/lib/policy/policy_test.go index 16ca9c70..e6aec8b5 100644 --- a/lib/policy/policy_test.go +++ b/lib/policy/policy_test.go @@ -16,7 +16,7 @@ func TestDefaultPolicyMustParse(t *testing.T) { } defer fin.Close() - if _, err := ParseConfig(fin, "botPolicies.json", anubis.DefaultDifficulty); err != nil { + if _, err := ParseConfig(t.Context(), fin, "botPolicies.json", anubis.DefaultDifficulty); err != nil { t.Fatalf("can't parse config: %v", err) } } @@ -36,7 +36,7 @@ func TestGoodConfigs(t *testing.T) { } defer fin.Close() - if _, err := ParseConfig(fin, fin.Name(), anubis.DefaultDifficulty); err != nil { + if _, err := ParseConfig(t.Context(), fin, fin.Name(), anubis.DefaultDifficulty); err != nil { t.Fatal(err) } }) @@ -58,7 +58,7 @@ func TestBadConfigs(t *testing.T) { } defer fin.Close() - if _, err := ParseConfig(fin, fin.Name(), anubis.DefaultDifficulty); err == nil { + if _, err := ParseConfig(t.Context(), fin, fin.Name(), anubis.DefaultDifficulty); err == nil { t.Fatal(err) } else { t.Log(err)