chore(lib/policy): move Checker to its own package to avoid import cycles

Signed-off-by: Xe Iaso <me@xeiaso.net>
This commit is contained in:
Xe Iaso
2025-05-22 11:19:01 -04:00
parent 9bb38d6ad0
commit 86ee5697f3
13 changed files with 218 additions and 58 deletions
+2 -2
View File
@@ -5,10 +5,10 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/TecharoHQ/anubis/lib/policy" "github.com/TecharoHQ/anubis/lib/policy/checker"
) )
var _ policy.Checker = &ASNChecker{} var _ checker.Impl = &ASNChecker{}
func TestGeoIPChecker(t *testing.T) { func TestGeoIPChecker(t *testing.T) {
cli := loadSecrets(t) cli := loadSecrets(t)
+18 -2
View File
@@ -8,7 +8,7 @@ import (
"time" "time"
"github.com/TecharoHQ/anubis/internal" "github.com/TecharoHQ/anubis/internal"
"github.com/TecharoHQ/anubis/lib/policy" "github.com/TecharoHQ/anubis/lib/policy/checker"
iptoasnv1 "github.com/TecharoHQ/thoth-proto/gen/techaro/thoth/iptoasn/v1" iptoasnv1 "github.com/TecharoHQ/thoth-proto/gen/techaro/thoth/iptoasn/v1"
grpcprom "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus" grpcprom "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/timeout" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/timeout"
@@ -74,7 +74,7 @@ func (c *Client) Close() error {
return c.conn.Close() return c.conn.Close()
} }
func (c *Client) ASNCheckerFor(asns []uint32) policy.Checker { func (c *Client) ASNCheckerFor(asns []uint32) checker.Impl {
asnMap := map[uint32]struct{}{} asnMap := map[uint32]struct{}{}
var sb strings.Builder var sb strings.Builder
fmt.Fprintln(&sb, "ASNChecker") fmt.Fprintln(&sb, "ASNChecker")
@@ -89,3 +89,19 @@ func (c *Client) ASNCheckerFor(asns []uint32) policy.Checker {
hash: internal.SHA256sum(sb.String()), hash: internal.SHA256sum(sb.String()),
} }
} }
func (c *Client) GeoIPCheckerFor(countries []string) checker.Impl {
countryMap := map[string]struct{}{}
var sb strings.Builder
fmt.Fprintln(&sb, "GeoIPChecker")
for _, cc := range countries {
countryMap[cc] = struct{}{}
fmt.Fprintln(&sb, cc)
}
return &GeoIPChecker{
iptoasn: c.iptoasn,
countries: countryMap,
hash: sb.String(),
}
}
+3 -3
View File
@@ -21,7 +21,7 @@ import (
func loadPolicies(t *testing.T, fname string) *policy.ParsedConfig { func loadPolicies(t *testing.T, fname string) *policy.ParsedConfig {
t.Helper() t.Helper()
anubisPolicy, err := LoadPoliciesOrDefault(fname, anubis.DefaultDifficulty) anubisPolicy, err := LoadPoliciesOrDefault(t.Context(), fname, anubis.DefaultDifficulty)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -118,7 +118,7 @@ func TestLoadPolicies(t *testing.T) {
} }
defer fin.Close() defer fin.Close()
if _, err := policy.ParseConfig(fin, fname, 4); err != nil { if _, err := policy.ParseConfig(t.Context(), fin, fname, 4); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}) })
@@ -268,7 +268,7 @@ func TestCheckDefaultDifficultyMatchesPolicy(t *testing.T) {
for i := 1; i < 10; i++ { for i := 1; i < 10; i++ {
t.Run(fmt.Sprint(i), func(t *testing.T) { t.Run(fmt.Sprint(i), func(t *testing.T) {
anubisPolicy, err := LoadPoliciesOrDefault("", i) anubisPolicy, err := LoadPoliciesOrDefault(t.Context(), "", i)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
+3 -2
View File
@@ -1,6 +1,7 @@
package lib package lib
import ( import (
"context"
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/rand"
"fmt" "fmt"
@@ -40,7 +41,7 @@ type Options struct {
ServeRobotsTXT bool ServeRobotsTXT bool
} }
func LoadPoliciesOrDefault(fname string, defaultDifficulty int) (*policy.ParsedConfig, error) { func LoadPoliciesOrDefault(ctx context.Context, fname string, defaultDifficulty int) (*policy.ParsedConfig, error) {
var fin io.ReadCloser var fin io.ReadCloser
var err error var err error
@@ -64,7 +65,7 @@ func LoadPoliciesOrDefault(fname string, defaultDifficulty int) (*policy.ParsedC
} }
}(fin) }(fin)
anubisPolicy, err := policy.ParseConfig(fin, fname, defaultDifficulty) anubisPolicy, err := policy.ParseConfig(ctx, fin, fname, defaultDifficulty)
return anubisPolicy, err return anubisPolicy, err
} }
+2 -1
View File
@@ -4,11 +4,12 @@ import (
"fmt" "fmt"
"github.com/TecharoHQ/anubis/internal" "github.com/TecharoHQ/anubis/internal"
"github.com/TecharoHQ/anubis/lib/policy/checker"
"github.com/TecharoHQ/anubis/lib/policy/config" "github.com/TecharoHQ/anubis/lib/policy/config"
) )
type Bot struct { type Bot struct {
Rules Checker Rules checker.Impl
Challenge *config.ChallengeRules Challenge *config.ChallengeRules
Name string Name string
Action config.Rule Action config.Rule
+8 -38
View File
@@ -9,6 +9,7 @@ import (
"strings" "strings"
"github.com/TecharoHQ/anubis/internal" "github.com/TecharoHQ/anubis/internal"
"github.com/TecharoHQ/anubis/lib/policy/checker"
"github.com/yl2chen/cidranger" "github.com/yl2chen/cidranger"
) )
@@ -16,43 +17,12 @@ var (
ErrMisconfiguration = errors.New("[unexpected] policy: administrator misconfiguration") ErrMisconfiguration = errors.New("[unexpected] policy: administrator misconfiguration")
) )
type Checker interface {
Check(*http.Request) (bool, error)
Hash() string
}
type CheckerList []Checker
func (cl CheckerList) Check(r *http.Request) (bool, error) {
for _, c := range cl {
ok, err := c.Check(r)
if err != nil {
return ok, err
}
if ok {
return ok, nil
}
}
return false, nil
}
func (cl CheckerList) Hash() string {
var sb strings.Builder
for _, c := range cl {
fmt.Fprintln(&sb, c.Hash())
}
return internal.SHA256sum(sb.String())
}
type RemoteAddrChecker struct { type RemoteAddrChecker struct {
ranger cidranger.Ranger ranger cidranger.Ranger
hash string hash string
} }
func NewRemoteAddrChecker(cidrs []string) (Checker, error) { func NewRemoteAddrChecker(cidrs []string) (checker.Impl, error) {
ranger := cidranger.NewPCTrieRanger() ranger := cidranger.NewPCTrieRanger()
var sb strings.Builder var sb strings.Builder
@@ -105,11 +75,11 @@ type HeaderMatchesChecker struct {
hash string hash string
} }
func NewUserAgentChecker(rexStr string) (Checker, error) { func NewUserAgentChecker(rexStr string) (checker.Impl, error) {
return NewHeaderMatchesChecker("User-Agent", rexStr) return NewHeaderMatchesChecker("User-Agent", rexStr)
} }
func NewHeaderMatchesChecker(header, rexStr string) (Checker, error) { func NewHeaderMatchesChecker(header, rexStr string) (checker.Impl, error) {
rex, err := regexp.Compile(strings.TrimSpace(rexStr)) rex, err := regexp.Compile(strings.TrimSpace(rexStr))
if err != nil { if err != nil {
return nil, fmt.Errorf("%w: regex %s failed parse: %w", ErrMisconfiguration, rexStr, err) return nil, fmt.Errorf("%w: regex %s failed parse: %w", ErrMisconfiguration, rexStr, err)
@@ -134,7 +104,7 @@ type PathChecker struct {
hash string hash string
} }
func NewPathChecker(rexStr string) (Checker, error) { func NewPathChecker(rexStr string) (checker.Impl, error) {
rex, err := regexp.Compile(strings.TrimSpace(rexStr)) rex, err := regexp.Compile(strings.TrimSpace(rexStr))
if err != nil { if err != nil {
return nil, fmt.Errorf("%w: regex %s failed parse: %w", ErrMisconfiguration, rexStr, err) return nil, fmt.Errorf("%w: regex %s failed parse: %w", ErrMisconfiguration, rexStr, err)
@@ -154,7 +124,7 @@ func (pc *PathChecker) Hash() string {
return pc.hash return pc.hash
} }
func NewHeaderExistsChecker(key string) Checker { func NewHeaderExistsChecker(key string) checker.Impl {
return headerExistsChecker{strings.TrimSpace(key)} return headerExistsChecker{strings.TrimSpace(key)}
} }
@@ -174,8 +144,8 @@ func (hec headerExistsChecker) Hash() string {
return internal.SHA256sum(hec.header) return internal.SHA256sum(hec.header)
} }
func NewHeadersChecker(headermap map[string]string) (Checker, error) { func NewHeadersChecker(headermap map[string]string) (checker.Impl, error) {
var result CheckerList var result checker.List
var errs []error var errs []error
for key, rexStr := range headermap { for key, rexStr := range headermap {
+41
View File
@@ -0,0 +1,41 @@
// Package checker defines the Checker interface and a helper utility to avoid import cycles.
package checker
import (
"fmt"
"net/http"
"strings"
"github.com/TecharoHQ/anubis/internal"
)
type Impl interface {
Check(*http.Request) (bool, error)
Hash() string
}
type List []Impl
func (l List) Check(r *http.Request) (bool, error) {
for _, c := range l {
ok, err := c.Check(r)
if err != nil {
return ok, err
}
if ok {
return ok, nil
}
}
return false, nil
}
func (l List) Hash() string {
var sb strings.Builder
for _, c := range l {
fmt.Fprintln(&sb, c.Hash())
}
return internal.SHA256sum(sb.String())
}
+44
View File
@@ -0,0 +1,44 @@
package config
import (
"errors"
"fmt"
)
var (
ErrPrivateASN = errors.New("bot.ASNs: you have specified a private use ASN")
)
type ASNs struct {
Match []uint32 `json:"match"`
}
func (a *ASNs) Valid() error {
var errs []error
for _, asn := range a.Match {
if isPrivateASN(asn) {
errs = append(errs, fmt.Errorf("%w: %d is private (see RFC 6996)", ErrPrivateASN, asn))
}
}
if len(errs) != 0 {
return fmt.Errorf("bot.ASNs: invalid ASN settings: %w", errors.Join(errs...))
}
return nil
}
// isPrivateASN checks if an ASN is in the private use area.
//
// Based on RFC 6996 and IANA allocations.
func isPrivateASN(asn uint32) bool {
switch {
case asn >= 64512 && asn <= 65534:
return true
case asn >= 4200000000 && asn <= 4294967294:
return true
default:
return false
}
}
+7 -5
View File
@@ -51,14 +51,16 @@ const (
) )
type BotConfig struct { type BotConfig struct {
UserAgentRegex *string `json:"user_agent_regex"` UserAgentRegex *string `json:"user_agent_regex,omitempty"`
PathRegex *string `json:"path_regex"` PathRegex *string `json:"path_regex,omitempty"`
HeadersRegex map[string]string `json:"headers_regex"` HeadersRegex map[string]string `json:"headers_regex,omitempty"`
Expression *ExpressionOrList `json:"expression"` Expression *ExpressionOrList `json:"expression,omitempty"`
Challenge *ChallengeRules `json:"challenge,omitempty"` Challenge *ChallengeRules `json:"challenge,omitempty"`
GeoIP *GeoIP `json:"geoip,omitempty"`
ASNs *ASNs `json:"asns,omitempty"`
Name string `json:"name"` Name string `json:"name"`
Action Rule `json:"action"` Action Rule `json:"action"`
RemoteAddr []string `json:"remote_addresses"` RemoteAddr []string `json:"remote_addresses,omitempty"`
} }
func (b BotConfig) Zero() bool { func (b BotConfig) Zero() bool {
+36
View File
@@ -0,0 +1,36 @@
package config
import (
"errors"
"fmt"
"regexp"
"strings"
)
var (
countryCodeRegexp = regexp.MustCompile(`^\w{2}$`)
ErrNotCountryCode = errors.New("config.Bot: invalid country code")
)
type GeoIP struct {
Countries []string `json:"countries"`
}
func (g *GeoIP) Valid() error {
var errs []error
for i, cc := range g.Countries {
if !countryCodeRegexp.MatchString(cc) {
errs = append(errs, fmt.Errorf("%w: %s", ErrNotCountryCode, cc))
}
g.Countries[i] = strings.ToLower(cc)
}
if len(errs) != 0 {
return fmt.Errorf("bot.GeoIP: invalid GeoIP settings: %w", errors.Join(errs...))
}
return nil
}
+33
View File
@@ -0,0 +1,33 @@
package config
import (
"errors"
"testing"
)
func TestGeoIPValid(t *testing.T) {
for _, cs := range []struct {
name string
countries []string
err error
}{
{
name: "basic-working",
countries: []string{"US", "Ca", "mx"},
err: nil,
},
} {
t.Run(cs.name, func(t *testing.T) {
g := &GeoIP{
Countries: cs.countries,
}
err := g.Valid()
if !errors.Is(err, cs.err) {
t.Fatalf("wanted error %v but got: %v", cs.err, err)
}
if err == nil && cs.err != nil {
t.Fatalf("wanted error %v but got none", cs.err)
}
})
}
}
+18 -2
View File
@@ -1,6 +1,7 @@
package policy package policy
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -8,6 +9,8 @@ import (
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promauto"
"github.com/TecharoHQ/anubis/internal/thoth"
"github.com/TecharoHQ/anubis/lib/policy/checker"
"github.com/TecharoHQ/anubis/lib/policy/config" "github.com/TecharoHQ/anubis/lib/policy/config"
) )
@@ -16,6 +19,8 @@ var (
Name: "anubis_policy_results", Name: "anubis_policy_results",
Help: "The results of each policy rule", Help: "The results of each policy rule",
}, []string{"rule", "action"}) }, []string{"rule", "action"})
ErrNoThothClient = errors.New("config: you have specified Thoth related checks but have no active Thoth client")
) )
type ParsedConfig struct { type ParsedConfig struct {
@@ -34,7 +39,7 @@ func NewParsedConfig(orig *config.Config) *ParsedConfig {
} }
} }
func ParseConfig(fin io.Reader, fname string, defaultDifficulty int) (*ParsedConfig, error) { func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDifficulty int) (*ParsedConfig, error) {
c, err := config.Load(fin, fname) c, err := config.Load(fin, fname)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -42,6 +47,8 @@ func ParseConfig(fin io.Reader, fname string, defaultDifficulty int) (*ParsedCon
var validationErrs []error var validationErrs []error
tc, hasThothClient := thoth.FromContext(ctx)
result := NewParsedConfig(c) result := NewParsedConfig(c)
result.DefaultDifficulty = defaultDifficulty result.DefaultDifficulty = defaultDifficulty
@@ -56,7 +63,7 @@ func ParseConfig(fin io.Reader, fname string, defaultDifficulty int) (*ParsedCon
Action: b.Action, Action: b.Action,
} }
cl := CheckerList{} cl := checker.List{}
if len(b.RemoteAddr) > 0 { if len(b.RemoteAddr) > 0 {
c, err := NewRemoteAddrChecker(b.RemoteAddr) c, err := NewRemoteAddrChecker(b.RemoteAddr)
@@ -103,6 +110,15 @@ func ParseConfig(fin io.Reader, fname string, defaultDifficulty int) (*ParsedCon
} }
} }
if b.ASNs != nil {
if !hasThothClient {
validationErrs = append(validationErrs, fmt.Errorf("%w: %w", ErrMisconfiguration, ErrNoThothClient))
continue
}
cl = append(cl, tc.ASNCheckerFor(b.ASNs.Match))
}
if b.Challenge == nil { if b.Challenge == nil {
parsedBot.Challenge = &config.ChallengeRules{ parsedBot.Challenge = &config.ChallengeRules{
Difficulty: defaultDifficulty, Difficulty: defaultDifficulty,
+3 -3
View File
@@ -16,7 +16,7 @@ func TestDefaultPolicyMustParse(t *testing.T) {
} }
defer fin.Close() defer fin.Close()
if _, err := ParseConfig(fin, "botPolicies.json", anubis.DefaultDifficulty); err != nil { if _, err := ParseConfig(t.Context(), fin, "botPolicies.json", anubis.DefaultDifficulty); err != nil {
t.Fatalf("can't parse config: %v", err) t.Fatalf("can't parse config: %v", err)
} }
} }
@@ -36,7 +36,7 @@ func TestGoodConfigs(t *testing.T) {
} }
defer fin.Close() defer fin.Close()
if _, err := ParseConfig(fin, fin.Name(), anubis.DefaultDifficulty); err != nil { if _, err := ParseConfig(t.Context(), fin, fin.Name(), anubis.DefaultDifficulty); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}) })
@@ -58,7 +58,7 @@ func TestBadConfigs(t *testing.T) {
} }
defer fin.Close() defer fin.Close()
if _, err := ParseConfig(fin, fin.Name(), anubis.DefaultDifficulty); err == nil { if _, err := ParseConfig(t.Context(), fin, fin.Name(), anubis.DefaultDifficulty); err == nil {
t.Fatal(err) t.Fatal(err)
} else { } else {
t.Log(err) t.Log(err)