From 590d8303ad93d92be267db2d4f1a20462d363415 Mon Sep 17 00:00:00 2001 From: Xe Iaso Date: Fri, 25 Jul 2025 19:39:14 +0000 Subject: [PATCH] refactor: use new checker types Signed-off-by: Xe Iaso --- lib/anubis.go | 4 +- lib/checker/all.go | 35 +++ lib/checker/all_test.go | 70 ++++++ lib/checker/any.go | 35 +++ lib/checker/any_test.go | 83 +++++++ lib/checker/checker.go | 30 --- lib/checker/headermatches/checker.go | 22 ++ lib/checker/headermatches/useragent.go | 34 +++ lib/checker/remoteaddress/remoteaddress.go | 8 + .../remoteaddress/remoteaddress_test.go | 18 +- lib/policy/checker.go | 145 ++----------- lib/policy/checker_test.go | 202 ------------------ lib/policy/config/config.go | 15 +- lib/policy/config/config_test.go | 3 +- lib/policy/policy.go | 11 +- 15 files changed, 331 insertions(+), 384 deletions(-) create mode 100644 lib/checker/all.go create mode 100644 lib/checker/all_test.go create mode 100644 lib/checker/any.go create mode 100644 lib/checker/any_test.go delete mode 100644 lib/policy/checker_test.go diff --git a/lib/anubis.go b/lib/anubis.go index e375a212..b06d0162 100644 --- a/lib/anubis.go +++ b/lib/anubis.go @@ -552,7 +552,7 @@ func (s *Server) check(r *http.Request) (policy.CheckResult, *policy.Bot, error) if matches { return cr("threshold/"+t.Name, t.Action, weight), &policy.Bot{ Challenge: t.Challenge, - Rules: &checker.List{}, + Rules: &checker.Any{}, }, nil } } @@ -563,6 +563,6 @@ func (s *Server) check(r *http.Request) (policy.CheckResult, *policy.Bot, error) ReportAs: s.policy.DefaultDifficulty, Algorithm: config.DefaultAlgorithm, }, - Rules: &checker.List{}, + Rules: &checker.Any{}, }, nil } diff --git a/lib/checker/all.go b/lib/checker/all.go new file mode 100644 index 00000000..c922808e --- /dev/null +++ b/lib/checker/all.go @@ -0,0 +1,35 @@ +package checker + +import ( + "fmt" + "net/http" + "strings" + + "github.com/TecharoHQ/anubis/internal" +) + +type All []Interface + +func (a All) Check(r *http.Request) (bool, error) { + for _, c := range a { + match, err := c.Check(r) + if err != nil { + return match, err + } + if !match { + return false, err // no match + } + } + + return true, nil // match +} + +func (a All) Hash() string { + var sb strings.Builder + + for _, c := range a { + fmt.Fprintln(&sb, c.Hash()) + } + + return internal.FastHash(sb.String()) +} diff --git a/lib/checker/all_test.go b/lib/checker/all_test.go new file mode 100644 index 00000000..b0b9d4cb --- /dev/null +++ b/lib/checker/all_test.go @@ -0,0 +1,70 @@ +package checker + +import ( + "net/http" + "testing" +) + +func TestAll_Check(t *testing.T) { + tests := []struct { + name string + checkers []MockChecker + want bool + wantErr bool + }{ + { + name: "All match", + checkers: []MockChecker{ + {Result: true, Err: nil}, + {Result: true, Err: nil}, + }, + want: true, + wantErr: false, + }, + { + name: "One not match", + checkers: []MockChecker{ + {Result: true, Err: nil}, + {Result: false, Err: nil}, + }, + want: false, + wantErr: false, + }, + { + name: "No match", + checkers: []MockChecker{ + {Result: false, Err: nil}, + {Result: false, Err: nil}, + }, + want: false, + wantErr: false, + }, + { + name: "Error encountered", + checkers: []MockChecker{ + {Result: true, Err: nil}, + {Result: false, Err: http.ErrNotSupported}, + }, + want: false, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var all All + for _, mc := range tt.checkers { + all = append(all, mc) + } + + got, err := all.Check(nil) + if (err != nil) != tt.wantErr { + t.Errorf("All.Check() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("All.Check() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/lib/checker/any.go b/lib/checker/any.go new file mode 100644 index 00000000..eef4336b --- /dev/null +++ b/lib/checker/any.go @@ -0,0 +1,35 @@ +package checker + +import ( + "fmt" + "net/http" + "strings" + + "github.com/TecharoHQ/anubis/internal" +) + +type Any []Interface + +func (a Any) Check(r *http.Request) (bool, error) { + for _, c := range a { + match, err := c.Check(r) + if err != nil { + return match, err + } + if match { + return true, err // match + } + } + + return false, nil // no match +} + +func (a Any) Hash() string { + var sb strings.Builder + + for _, c := range a { + fmt.Fprintln(&sb, c.Hash()) + } + + return internal.FastHash(sb.String()) +} diff --git a/lib/checker/any_test.go b/lib/checker/any_test.go new file mode 100644 index 00000000..34732c4a --- /dev/null +++ b/lib/checker/any_test.go @@ -0,0 +1,83 @@ +package checker + +import ( + "net/http" + "testing" +) + +type MockChecker struct { + Result bool + Err error +} + +func (m MockChecker) Check(r *http.Request) (bool, error) { + return m.Result, m.Err +} + +func (m MockChecker) Hash() string { + return "mock-hash" +} + +func TestAny_Check(t *testing.T) { + tests := []struct { + name string + checkers []MockChecker + want bool + wantErr bool + }{ + { + name: "All match", + checkers: []MockChecker{ + {Result: true, Err: nil}, + {Result: true, Err: nil}, + }, + want: true, + wantErr: false, + }, + { + name: "One match", + checkers: []MockChecker{ + {Result: false, Err: nil}, + {Result: true, Err: nil}, + }, + want: true, + wantErr: false, + }, + { + name: "No match", + checkers: []MockChecker{ + {Result: false, Err: nil}, + {Result: false, Err: nil}, + }, + want: false, + wantErr: false, + }, + { + name: "Error encountered", + checkers: []MockChecker{ + {Result: false, Err: nil}, + {Result: false, Err: http.ErrNotSupported}, + }, + want: false, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var any Any + for _, mc := range tt.checkers { + any = append(any, mc) + } + + got, err := any.Check(nil) + if (err != nil) != tt.wantErr { + t.Errorf("Any.Check() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Any.Check() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/lib/checker/checker.go b/lib/checker/checker.go index 21922f8f..8d00d42d 100644 --- a/lib/checker/checker.go +++ b/lib/checker/checker.go @@ -3,11 +3,7 @@ package checker import ( "errors" - "fmt" "net/http" - "strings" - - "github.com/TecharoHQ/anubis/internal" ) var ( @@ -19,29 +15,3 @@ type Interface interface { Check(*http.Request) (matches bool, err error) Hash() string } - -type List []Interface - -func (l List) Check(r *http.Request) (bool, error) { - for _, c := range l { - ok, err := c.Check(r) - if err != nil { - return ok, err - } - if ok { - return ok, nil - } - } - - return false, nil -} - -func (l List) Hash() string { - var sb strings.Builder - - for _, c := range l { - fmt.Fprintln(&sb, c.Hash()) - } - - return internal.FastHash(sb.String()) -} diff --git a/lib/checker/headermatches/checker.go b/lib/checker/headermatches/checker.go index 0f52f29e..177b5836 100644 --- a/lib/checker/headermatches/checker.go +++ b/lib/checker/headermatches/checker.go @@ -1,8 +1,12 @@ package headermatches import ( + "context" + "encoding/json" "net/http" "regexp" + + "github.com/TecharoHQ/anubis/lib/checker" ) type Checker struct { @@ -22,3 +26,21 @@ func (c *Checker) Check(r *http.Request) (bool, error) { func (c *Checker) Hash() string { return c.hash } + +func New(key, valueRex string) (checker.Interface, error) { + fc := fileConfig{ + Header: key, + ValueRegex: valueRex, + } + + if err := fc.Valid(); err != nil { + return nil, err + } + + data, err := json.Marshal(fc) + if err != nil { + return nil, err + } + + return Factory{}.Build(context.Background(), json.RawMessage(data)) +} diff --git a/lib/checker/headermatches/useragent.go b/lib/checker/headermatches/useragent.go index 006a91da..c9c8c158 100644 --- a/lib/checker/headermatches/useragent.go +++ b/lib/checker/headermatches/useragent.go @@ -1 +1,35 @@ package headermatches + +import ( + "context" + "encoding/json" + + "github.com/TecharoHQ/anubis/lib/checker" +) + +func ValidUserAgent(valueRex string) error { + fc := fileConfig{ + Header: "User-Agent", + ValueRegex: valueRex, + } + + return fc.Valid() +} + +func NewUserAgent(valueRex string) (checker.Interface, error) { + fc := fileConfig{ + Header: "User-Agent", + ValueRegex: valueRex, + } + + if err := fc.Valid(); err != nil { + return nil, err + } + + data, err := json.Marshal(fc) + if err != nil { + return nil, err + } + + return Factory{}.Build(context.Background(), json.RawMessage(data)) +} diff --git a/lib/checker/remoteaddress/remoteaddress.go b/lib/checker/remoteaddress/remoteaddress.go index 836af27c..70b4ab47 100644 --- a/lib/checker/remoteaddress/remoteaddress.go +++ b/lib/checker/remoteaddress/remoteaddress.go @@ -83,6 +83,14 @@ func (fc fileConfig) Valid() error { return nil } +func Valid(cidrs []string) error { + fc := fileConfig{ + RemoteAddr: cidrs, + } + + return fc.Valid() +} + func New(cidrs []string) (checker.Interface, error) { fc := fileConfig{ RemoteAddr: cidrs, diff --git a/lib/checker/remoteaddress/remoteaddress_test.go b/lib/checker/remoteaddress/remoteaddress_test.go index 7fdc0f37..b217e7a8 100644 --- a/lib/checker/remoteaddress/remoteaddress_test.go +++ b/lib/checker/remoteaddress/remoteaddress_test.go @@ -1,4 +1,4 @@ -package remoteaddress +package remoteaddress_test import ( _ "embed" @@ -8,17 +8,17 @@ import ( "testing" "github.com/TecharoHQ/anubis/lib/checker" - "github.com/TecharoHQ/anubis/lib/policy/config" + "github.com/TecharoHQ/anubis/lib/checker/remoteaddress" ) func TestFactoryIsCheckerFactory(t *testing.T) { - if _, ok := (any(Factory{})).(checker.Factory); !ok { + if _, ok := (any(remoteaddress.Factory{})).(checker.Factory); !ok { t.Fatal("Factory is not an instance of checker.Factory") } } func TestFactoryValidateConfig(t *testing.T) { - f := Factory{} + f := remoteaddress.Factory{} for _, tt := range []struct { name string @@ -36,14 +36,14 @@ func TestFactoryValidateConfig(t *testing.T) { { name: "not json", data: []byte(`]`), - err: config.ErrUnparseableConfig, + err: checker.ErrUnparseableConfig, }, { name: "no cidr", data: []byte(`{ "remote_addresses": [] }`), - err: ErrNoRemoteAddresses, + err: remoteaddress.ErrNoRemoteAddresses, }, { name: "bad cidr", @@ -52,7 +52,7 @@ func TestFactoryValidateConfig(t *testing.T) { "according to all laws of aviation" ] }`), - err: config.ErrInvalidCIDR, + err: remoteaddress.ErrInvalidCIDR, }, } { t.Run(tt.name, func(t *testing.T) { @@ -68,7 +68,7 @@ func TestFactoryValidateConfig(t *testing.T) { } func TestFactoryCreate(t *testing.T) { - f := Factory{} + f := remoteaddress.Factory{} for _, tt := range []struct { name string @@ -94,7 +94,7 @@ func TestFactoryCreate(t *testing.T) { "according to all laws of aviation" ] }`), - err: config.ErrUnparseableConfig, + err: checker.ErrUnparseableConfig, }, } { t.Run(tt.name, func(t *testing.T) { diff --git a/lib/policy/checker.go b/lib/policy/checker.go index 5422872f..a39cb978 100644 --- a/lib/policy/checker.go +++ b/lib/policy/checker.go @@ -3,150 +3,39 @@ package policy import ( "errors" "fmt" - "net/http" - "net/netip" - "regexp" + "sort" "strings" - "github.com/TecharoHQ/anubis" - "github.com/TecharoHQ/anubis/internal" "github.com/TecharoHQ/anubis/lib/checker" - "github.com/gaissmai/bart" + "github.com/TecharoHQ/anubis/lib/checker/headerexists" + "github.com/TecharoHQ/anubis/lib/checker/headermatches" ) -type RemoteAddrChecker struct { - prefixTable *bart.Lite - hash string -} - -func NewRemoteAddrChecker(cidrs []string) (checker.Interface, error) { - table := new(bart.Lite) - - for _, cidr := range cidrs { - prefix, err := netip.ParsePrefix(cidr) - if err != nil { - return nil, fmt.Errorf("%w: range %s not parsing: %w", anubis.ErrMisconfiguration, cidr, err) - } - - table.Insert(prefix) - } - - return &RemoteAddrChecker{ - prefixTable: table, - hash: internal.FastHash(strings.Join(cidrs, ",")), - }, nil -} - -func (rac *RemoteAddrChecker) Check(r *http.Request) (bool, error) { - host := r.Header.Get("X-Real-Ip") - if host == "" { - return false, fmt.Errorf("%w: header X-Real-Ip is not set", anubis.ErrMisconfiguration) - } - - addr, err := netip.ParseAddr(host) - if err != nil { - return false, fmt.Errorf("%w: %s is not an IP address: %w", anubis.ErrMisconfiguration, host, err) - } - - return rac.prefixTable.Contains(addr), nil -} - -func (rac *RemoteAddrChecker) Hash() string { - return rac.hash -} - -type HeaderMatchesChecker struct { - header string - regexp *regexp.Regexp - hash string -} - -func NewUserAgentChecker(rexStr string) (checker.Interface, error) { - return NewHeaderMatchesChecker("User-Agent", rexStr) -} - -func NewHeaderMatchesChecker(header, rexStr string) (checker.Interface, error) { - rex, err := regexp.Compile(strings.TrimSpace(rexStr)) - if err != nil { - return nil, fmt.Errorf("%w: regex %s failed parse: %w", anubis.ErrMisconfiguration, rexStr, err) - } - return &HeaderMatchesChecker{strings.TrimSpace(header), rex, internal.FastHash(header + ": " + rexStr)}, nil -} - -func (hmc *HeaderMatchesChecker) Check(r *http.Request) (bool, error) { - if hmc.regexp.MatchString(r.Header.Get(hmc.header)) { - return true, nil - } - - return false, nil -} - -func (hmc *HeaderMatchesChecker) Hash() string { - return hmc.hash -} - -type PathChecker struct { - regexp *regexp.Regexp - hash string -} - -func NewPathChecker(rexStr string) (checker.Interface, error) { - rex, err := regexp.Compile(strings.TrimSpace(rexStr)) - if err != nil { - return nil, fmt.Errorf("%w: regex %s failed parse: %w", anubis.ErrMisconfiguration, rexStr, err) - } - return &PathChecker{rex, internal.FastHash(rexStr)}, nil -} - -func (pc *PathChecker) Check(r *http.Request) (bool, error) { - if pc.regexp.MatchString(r.URL.Path) { - return true, nil - } - - return false, nil -} - -func (pc *PathChecker) Hash() string { - return pc.hash -} - -func NewHeaderExistsChecker(key string) checker.Interface { - return headerExistsChecker{strings.TrimSpace(key)} -} - -type headerExistsChecker struct { - header string -} - -func (hec headerExistsChecker) Check(r *http.Request) (bool, error) { - if r.Header.Get(hec.header) != "" { - return true, nil - } - - return false, nil -} - -func (hec headerExistsChecker) Hash() string { - return internal.FastHash(hec.header) -} - func NewHeadersChecker(headermap map[string]string) (checker.Interface, error) { - var result checker.List + var result checker.All var errs []error - for key, rexStr := range headermap { + var keys []string + for key := range headermap { + keys = append(keys, key) + } + + sort.Strings(keys) + + for _, key := range keys { + rexStr := headermap[key] if rexStr == ".*" { - result = append(result, headerExistsChecker{strings.TrimSpace(key)}) + result = append(result, headerexists.New(strings.TrimSpace(key))) continue } - rex, err := regexp.Compile(strings.TrimSpace(rexStr)) + c, err := headermatches.New(key, rexStr) if err != nil { - errs = append(errs, fmt.Errorf("while compiling header %s regex %s: %w", key, rexStr, err)) + errs = append(errs, fmt.Errorf("while parsing header %s regex %s: %w", key, rexStr, err)) continue } - result = append(result, &HeaderMatchesChecker{key, rex, internal.FastHash(key + ": " + rexStr)}) + result = append(result, c) } if len(errs) != 0 { diff --git a/lib/policy/checker_test.go b/lib/policy/checker_test.go deleted file mode 100644 index 2414ccf5..00000000 --- a/lib/policy/checker_test.go +++ /dev/null @@ -1,202 +0,0 @@ -package policy - -import ( - "errors" - "net/http" - "testing" - - "github.com/TecharoHQ/anubis" -) - -func TestRemoteAddrChecker(t *testing.T) { - for _, tt := range []struct { - err error - name string - ip string - cidrs []string - ok bool - }{ - { - name: "match_ipv4", - cidrs: []string{"0.0.0.0/0"}, - ip: "1.1.1.1", - ok: true, - err: nil, - }, - { - name: "match_ipv6", - cidrs: []string{"::/0"}, - ip: "cafe:babe::", - ok: true, - err: nil, - }, - { - name: "not_match_ipv4", - cidrs: []string{"1.1.1.1/32"}, - ip: "1.1.1.2", - ok: false, - err: nil, - }, - { - name: "not_match_ipv6", - cidrs: []string{"cafe:babe::/128"}, - ip: "cafe:babe:4::/128", - ok: false, - err: nil, - }, - { - name: "no_ip_set", - cidrs: []string{"::/0"}, - ok: false, - err: anubis.ErrMisconfiguration, - }, - { - name: "invalid_ip", - cidrs: []string{"::/0"}, - ip: "According to all natural laws of aviation", - ok: false, - err: anubis.ErrMisconfiguration, - }, - } { - t.Run(tt.name, func(t *testing.T) { - rac, err := NewRemoteAddrChecker(tt.cidrs) - if err != nil && !errors.Is(err, tt.err) { - t.Fatalf("creating RemoteAddrChecker failed: %v", err) - } - - r, err := http.NewRequest(http.MethodGet, "/", nil) - if err != nil { - t.Fatalf("can't make request: %v", err) - } - - if tt.ip != "" { - r.Header.Add("X-Real-Ip", tt.ip) - } - - ok, err := rac.Check(r) - - if tt.ok != ok { - t.Errorf("ok: %v, wanted: %v", ok, tt.ok) - } - - if err != nil && tt.err != nil && !errors.Is(err, tt.err) { - t.Errorf("err: %v, wanted: %v", err, tt.err) - } - }) - } -} - -func TestHeaderMatchesChecker(t *testing.T) { - for _, tt := range []struct { - err error - name string - header string - rexStr string - reqHeaderKey string - reqHeaderValue string - ok bool - }{ - { - name: "match", - header: "Cf-Worker", - rexStr: ".*", - reqHeaderKey: "Cf-Worker", - reqHeaderValue: "true", - ok: true, - err: nil, - }, - { - name: "not_match", - header: "Cf-Worker", - rexStr: "false", - reqHeaderKey: "Cf-Worker", - reqHeaderValue: "true", - ok: false, - err: nil, - }, - { - name: "not_present", - header: "Cf-Worker", - rexStr: "foobar", - reqHeaderKey: "Something-Else", - reqHeaderValue: "true", - ok: false, - err: nil, - }, - { - name: "invalid_regex", - rexStr: "a(b", - err: anubis.ErrMisconfiguration, - }, - } { - t.Run(tt.name, func(t *testing.T) { - hmc, err := NewHeaderMatchesChecker(tt.header, tt.rexStr) - if err != nil && !errors.Is(err, tt.err) { - t.Fatalf("creating HeaderMatchesChecker failed") - } - - if tt.err != nil && hmc == nil { - return - } - - r, err := http.NewRequest(http.MethodGet, "/", nil) - if err != nil { - t.Fatalf("can't make request: %v", err) - } - - r.Header.Set(tt.reqHeaderKey, tt.reqHeaderValue) - - ok, err := hmc.Check(r) - - if tt.ok != ok { - t.Errorf("ok: %v, wanted: %v", ok, tt.ok) - } - - if err != nil && tt.err != nil && !errors.Is(err, tt.err) { - t.Errorf("err: %v, wanted: %v", err, tt.err) - } - }) - } -} - -func TestHeaderExistsChecker(t *testing.T) { - for _, tt := range []struct { - name string - header string - reqHeader string - ok bool - }{ - { - name: "match", - header: "Authorization", - reqHeader: "Authorization", - ok: true, - }, - { - name: "not_match", - header: "Authorization", - reqHeader: "Authentication", - }, - } { - t.Run(tt.name, func(t *testing.T) { - hec := headerExistsChecker{tt.header} - - r, err := http.NewRequest(http.MethodGet, "/", nil) - if err != nil { - t.Fatalf("can't make request: %v", err) - } - - r.Header.Set(tt.reqHeader, "hunter2") - - ok, err := hec.Check(r) - - if tt.ok != ok { - t.Errorf("ok: %v, wanted: %v", ok, tt.ok) - } - - if err != nil { - t.Errorf("err: %v", err) - } - }) - } -} diff --git a/lib/policy/config/config.go b/lib/policy/config/config.go index 7a71d908..4d67df16 100644 --- a/lib/policy/config/config.go +++ b/lib/policy/config/config.go @@ -5,7 +5,6 @@ import ( "fmt" "io" "io/fs" - "net" "net/http" "os" "regexp" @@ -13,6 +12,9 @@ import ( "time" "github.com/TecharoHQ/anubis/data" + "github.com/TecharoHQ/anubis/lib/checker/headermatches" + "github.com/TecharoHQ/anubis/lib/checker/path" + "github.com/TecharoHQ/anubis/lib/checker/remoteaddress" "k8s.io/apimachinery/pkg/util/yaml" ) @@ -25,7 +27,6 @@ var ( ErrInvalidUserAgentRegex = errors.New("config.Bot: invalid user agent regex") ErrInvalidPathRegex = errors.New("config.Bot: invalid path regex") ErrInvalidHeadersRegex = errors.New("config.Bot: invalid headers regex") - ErrInvalidCIDR = errors.New("config.Bot: invalid CIDR") ErrRegexEndsWithNewline = errors.New("config.Bot: regular expression ends with newline (try >- instead of > in yaml)") ErrInvalidImportStatement = errors.New("config.ImportStatement: invalid source file") ErrCantSetBotAndImportValuesAtOnce = errors.New("config.BotOrImport: can't set bot rules and import values at the same time") @@ -119,7 +120,7 @@ func (b *BotConfig) Valid() error { errs = append(errs, fmt.Errorf("%w: user agent regex: %q", ErrRegexEndsWithNewline, *b.UserAgentRegex)) } - if _, err := regexp.Compile(*b.UserAgentRegex); err != nil { + if err := headermatches.ValidUserAgent(*b.UserAgentRegex); err != nil { errs = append(errs, ErrInvalidUserAgentRegex, err) } } @@ -129,7 +130,7 @@ func (b *BotConfig) Valid() error { errs = append(errs, fmt.Errorf("%w: path regex: %q", ErrRegexEndsWithNewline, *b.PathRegex)) } - if _, err := regexp.Compile(*b.PathRegex); err != nil { + if err := path.Valid(*b.PathRegex); err != nil { errs = append(errs, ErrInvalidPathRegex, err) } } @@ -151,10 +152,8 @@ func (b *BotConfig) Valid() error { } if len(b.RemoteAddr) > 0 { - for _, cidr := range b.RemoteAddr { - if _, _, err := net.ParseCIDR(cidr); err != nil { - errs = append(errs, ErrInvalidCIDR, err) - } + if err := remoteaddress.Valid(b.RemoteAddr); err != nil { + errs = append(errs, err) } } diff --git a/lib/policy/config/config_test.go b/lib/policy/config/config_test.go index 40bb6b43..40efcaef 100644 --- a/lib/policy/config/config_test.go +++ b/lib/policy/config/config_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/TecharoHQ/anubis/data" + "github.com/TecharoHQ/anubis/lib/checker/remoteaddress" . "github.com/TecharoHQ/anubis/lib/policy/config" ) @@ -137,7 +138,7 @@ func TestBotValid(t *testing.T) { Action: RuleAllow, RemoteAddr: []string{"0.0.0.0/33"}, }, - err: ErrInvalidCIDR, + err: remoteaddress.ErrInvalidCIDR, }, { name: "only filter by IP range", diff --git a/lib/policy/policy.go b/lib/policy/policy.go index 0dfd2724..dc2c58fe 100644 --- a/lib/policy/policy.go +++ b/lib/policy/policy.go @@ -9,6 +9,9 @@ import ( "sync/atomic" "github.com/TecharoHQ/anubis/lib/checker" + "github.com/TecharoHQ/anubis/lib/checker/headermatches" + "github.com/TecharoHQ/anubis/lib/checker/path" + "github.com/TecharoHQ/anubis/lib/checker/remoteaddress" "github.com/TecharoHQ/anubis/lib/policy/config" "github.com/TecharoHQ/anubis/lib/store" "github.com/TecharoHQ/anubis/lib/thoth" @@ -73,10 +76,10 @@ func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDiffic Action: b.Action, } - cl := checker.List{} + cl := checker.Any{} if len(b.RemoteAddr) > 0 { - c, err := NewRemoteAddrChecker(b.RemoteAddr) + c, err := remoteaddress.New(b.RemoteAddr) if err != nil { validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s remote addr set: %w", b.Name, err)) } else { @@ -85,7 +88,7 @@ func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDiffic } if b.UserAgentRegex != nil { - c, err := NewUserAgentChecker(*b.UserAgentRegex) + c, err := headermatches.NewUserAgent(*b.UserAgentRegex) if err != nil { validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s user agent regex: %w", b.Name, err)) } else { @@ -94,7 +97,7 @@ func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDiffic } if b.PathRegex != nil { - c, err := NewPathChecker(*b.PathRegex) + c, err := path.New(*b.PathRegex) if err != nil { validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s path regex: %w", b.Name, err)) } else {