mirror of
https://github.com/TecharoHQ/anubis.git
synced 2026-04-11 19:18:46 +00:00
@@ -552,7 +552,7 @@ func (s *Server) check(r *http.Request) (policy.CheckResult, *policy.Bot, error)
|
||||
if matches {
|
||||
return cr("threshold/"+t.Name, t.Action, weight), &policy.Bot{
|
||||
Challenge: t.Challenge,
|
||||
Rules: &checker.List{},
|
||||
Rules: &checker.Any{},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
@@ -563,6 +563,6 @@ func (s *Server) check(r *http.Request) (policy.CheckResult, *policy.Bot, error)
|
||||
ReportAs: s.policy.DefaultDifficulty,
|
||||
Algorithm: config.DefaultAlgorithm,
|
||||
},
|
||||
Rules: &checker.List{},
|
||||
Rules: &checker.Any{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
35
lib/checker/all.go
Normal file
35
lib/checker/all.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package checker
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/TecharoHQ/anubis/internal"
|
||||
)
|
||||
|
||||
type All []Interface
|
||||
|
||||
func (a All) Check(r *http.Request) (bool, error) {
|
||||
for _, c := range a {
|
||||
match, err := c.Check(r)
|
||||
if err != nil {
|
||||
return match, err
|
||||
}
|
||||
if !match {
|
||||
return false, err // no match
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil // match
|
||||
}
|
||||
|
||||
func (a All) Hash() string {
|
||||
var sb strings.Builder
|
||||
|
||||
for _, c := range a {
|
||||
fmt.Fprintln(&sb, c.Hash())
|
||||
}
|
||||
|
||||
return internal.FastHash(sb.String())
|
||||
}
|
||||
70
lib/checker/all_test.go
Normal file
70
lib/checker/all_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package checker
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAll_Check(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
checkers []MockChecker
|
||||
want bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "All match",
|
||||
checkers: []MockChecker{
|
||||
{Result: true, Err: nil},
|
||||
{Result: true, Err: nil},
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "One not match",
|
||||
checkers: []MockChecker{
|
||||
{Result: true, Err: nil},
|
||||
{Result: false, Err: nil},
|
||||
},
|
||||
want: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "No match",
|
||||
checkers: []MockChecker{
|
||||
{Result: false, Err: nil},
|
||||
{Result: false, Err: nil},
|
||||
},
|
||||
want: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Error encountered",
|
||||
checkers: []MockChecker{
|
||||
{Result: true, Err: nil},
|
||||
{Result: false, Err: http.ErrNotSupported},
|
||||
},
|
||||
want: false,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var all All
|
||||
for _, mc := range tt.checkers {
|
||||
all = append(all, mc)
|
||||
}
|
||||
|
||||
got, err := all.Check(nil)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("All.Check() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("All.Check() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
35
lib/checker/any.go
Normal file
35
lib/checker/any.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package checker
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/TecharoHQ/anubis/internal"
|
||||
)
|
||||
|
||||
type Any []Interface
|
||||
|
||||
func (a Any) Check(r *http.Request) (bool, error) {
|
||||
for _, c := range a {
|
||||
match, err := c.Check(r)
|
||||
if err != nil {
|
||||
return match, err
|
||||
}
|
||||
if match {
|
||||
return true, err // match
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil // no match
|
||||
}
|
||||
|
||||
func (a Any) Hash() string {
|
||||
var sb strings.Builder
|
||||
|
||||
for _, c := range a {
|
||||
fmt.Fprintln(&sb, c.Hash())
|
||||
}
|
||||
|
||||
return internal.FastHash(sb.String())
|
||||
}
|
||||
83
lib/checker/any_test.go
Normal file
83
lib/checker/any_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package checker
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type MockChecker struct {
|
||||
Result bool
|
||||
Err error
|
||||
}
|
||||
|
||||
func (m MockChecker) Check(r *http.Request) (bool, error) {
|
||||
return m.Result, m.Err
|
||||
}
|
||||
|
||||
func (m MockChecker) Hash() string {
|
||||
return "mock-hash"
|
||||
}
|
||||
|
||||
func TestAny_Check(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
checkers []MockChecker
|
||||
want bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "All match",
|
||||
checkers: []MockChecker{
|
||||
{Result: true, Err: nil},
|
||||
{Result: true, Err: nil},
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "One match",
|
||||
checkers: []MockChecker{
|
||||
{Result: false, Err: nil},
|
||||
{Result: true, Err: nil},
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "No match",
|
||||
checkers: []MockChecker{
|
||||
{Result: false, Err: nil},
|
||||
{Result: false, Err: nil},
|
||||
},
|
||||
want: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Error encountered",
|
||||
checkers: []MockChecker{
|
||||
{Result: false, Err: nil},
|
||||
{Result: false, Err: http.ErrNotSupported},
|
||||
},
|
||||
want: false,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var any Any
|
||||
for _, mc := range tt.checkers {
|
||||
any = append(any, mc)
|
||||
}
|
||||
|
||||
got, err := any.Check(nil)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Any.Check() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("Any.Check() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -3,11 +3,7 @@ package checker
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/TecharoHQ/anubis/internal"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -19,29 +15,3 @@ type Interface interface {
|
||||
Check(*http.Request) (matches bool, err error)
|
||||
Hash() string
|
||||
}
|
||||
|
||||
type List []Interface
|
||||
|
||||
func (l List) Check(r *http.Request) (bool, error) {
|
||||
for _, c := range l {
|
||||
ok, err := c.Check(r)
|
||||
if err != nil {
|
||||
return ok, err
|
||||
}
|
||||
if ok {
|
||||
return ok, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (l List) Hash() string {
|
||||
var sb strings.Builder
|
||||
|
||||
for _, c := range l {
|
||||
fmt.Fprintln(&sb, c.Hash())
|
||||
}
|
||||
|
||||
return internal.FastHash(sb.String())
|
||||
}
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
package headermatches
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"regexp"
|
||||
|
||||
"github.com/TecharoHQ/anubis/lib/checker"
|
||||
)
|
||||
|
||||
type Checker struct {
|
||||
@@ -22,3 +26,21 @@ func (c *Checker) Check(r *http.Request) (bool, error) {
|
||||
func (c *Checker) Hash() string {
|
||||
return c.hash
|
||||
}
|
||||
|
||||
func New(key, valueRex string) (checker.Interface, error) {
|
||||
fc := fileConfig{
|
||||
Header: key,
|
||||
ValueRegex: valueRex,
|
||||
}
|
||||
|
||||
if err := fc.Valid(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data, err := json.Marshal(fc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return Factory{}.Build(context.Background(), json.RawMessage(data))
|
||||
}
|
||||
|
||||
@@ -1 +1,35 @@
|
||||
package headermatches
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/TecharoHQ/anubis/lib/checker"
|
||||
)
|
||||
|
||||
func ValidUserAgent(valueRex string) error {
|
||||
fc := fileConfig{
|
||||
Header: "User-Agent",
|
||||
ValueRegex: valueRex,
|
||||
}
|
||||
|
||||
return fc.Valid()
|
||||
}
|
||||
|
||||
func NewUserAgent(valueRex string) (checker.Interface, error) {
|
||||
fc := fileConfig{
|
||||
Header: "User-Agent",
|
||||
ValueRegex: valueRex,
|
||||
}
|
||||
|
||||
if err := fc.Valid(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data, err := json.Marshal(fc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return Factory{}.Build(context.Background(), json.RawMessage(data))
|
||||
}
|
||||
|
||||
@@ -83,6 +83,14 @@ func (fc fileConfig) Valid() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func Valid(cidrs []string) error {
|
||||
fc := fileConfig{
|
||||
RemoteAddr: cidrs,
|
||||
}
|
||||
|
||||
return fc.Valid()
|
||||
}
|
||||
|
||||
func New(cidrs []string) (checker.Interface, error) {
|
||||
fc := fileConfig{
|
||||
RemoteAddr: cidrs,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package remoteaddress
|
||||
package remoteaddress_test
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
@@ -8,17 +8,17 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/TecharoHQ/anubis/lib/checker"
|
||||
"github.com/TecharoHQ/anubis/lib/policy/config"
|
||||
"github.com/TecharoHQ/anubis/lib/checker/remoteaddress"
|
||||
)
|
||||
|
||||
func TestFactoryIsCheckerFactory(t *testing.T) {
|
||||
if _, ok := (any(Factory{})).(checker.Factory); !ok {
|
||||
if _, ok := (any(remoteaddress.Factory{})).(checker.Factory); !ok {
|
||||
t.Fatal("Factory is not an instance of checker.Factory")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFactoryValidateConfig(t *testing.T) {
|
||||
f := Factory{}
|
||||
f := remoteaddress.Factory{}
|
||||
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
@@ -36,14 +36,14 @@ func TestFactoryValidateConfig(t *testing.T) {
|
||||
{
|
||||
name: "not json",
|
||||
data: []byte(`]`),
|
||||
err: config.ErrUnparseableConfig,
|
||||
err: checker.ErrUnparseableConfig,
|
||||
},
|
||||
{
|
||||
name: "no cidr",
|
||||
data: []byte(`{
|
||||
"remote_addresses": []
|
||||
}`),
|
||||
err: ErrNoRemoteAddresses,
|
||||
err: remoteaddress.ErrNoRemoteAddresses,
|
||||
},
|
||||
{
|
||||
name: "bad cidr",
|
||||
@@ -52,7 +52,7 @@ func TestFactoryValidateConfig(t *testing.T) {
|
||||
"according to all laws of aviation"
|
||||
]
|
||||
}`),
|
||||
err: config.ErrInvalidCIDR,
|
||||
err: remoteaddress.ErrInvalidCIDR,
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
@@ -68,7 +68,7 @@ func TestFactoryValidateConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFactoryCreate(t *testing.T) {
|
||||
f := Factory{}
|
||||
f := remoteaddress.Factory{}
|
||||
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
@@ -94,7 +94,7 @@ func TestFactoryCreate(t *testing.T) {
|
||||
"according to all laws of aviation"
|
||||
]
|
||||
}`),
|
||||
err: config.ErrUnparseableConfig,
|
||||
err: checker.ErrUnparseableConfig,
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
@@ -3,150 +3,39 @@ package policy
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/TecharoHQ/anubis"
|
||||
"github.com/TecharoHQ/anubis/internal"
|
||||
"github.com/TecharoHQ/anubis/lib/checker"
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/TecharoHQ/anubis/lib/checker/headerexists"
|
||||
"github.com/TecharoHQ/anubis/lib/checker/headermatches"
|
||||
)
|
||||
|
||||
type RemoteAddrChecker struct {
|
||||
prefixTable *bart.Lite
|
||||
hash string
|
||||
}
|
||||
|
||||
func NewRemoteAddrChecker(cidrs []string) (checker.Interface, error) {
|
||||
table := new(bart.Lite)
|
||||
|
||||
for _, cidr := range cidrs {
|
||||
prefix, err := netip.ParsePrefix(cidr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: range %s not parsing: %w", anubis.ErrMisconfiguration, cidr, err)
|
||||
}
|
||||
|
||||
table.Insert(prefix)
|
||||
}
|
||||
|
||||
return &RemoteAddrChecker{
|
||||
prefixTable: table,
|
||||
hash: internal.FastHash(strings.Join(cidrs, ",")),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (rac *RemoteAddrChecker) Check(r *http.Request) (bool, error) {
|
||||
host := r.Header.Get("X-Real-Ip")
|
||||
if host == "" {
|
||||
return false, fmt.Errorf("%w: header X-Real-Ip is not set", anubis.ErrMisconfiguration)
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("%w: %s is not an IP address: %w", anubis.ErrMisconfiguration, host, err)
|
||||
}
|
||||
|
||||
return rac.prefixTable.Contains(addr), nil
|
||||
}
|
||||
|
||||
func (rac *RemoteAddrChecker) Hash() string {
|
||||
return rac.hash
|
||||
}
|
||||
|
||||
type HeaderMatchesChecker struct {
|
||||
header string
|
||||
regexp *regexp.Regexp
|
||||
hash string
|
||||
}
|
||||
|
||||
func NewUserAgentChecker(rexStr string) (checker.Interface, error) {
|
||||
return NewHeaderMatchesChecker("User-Agent", rexStr)
|
||||
}
|
||||
|
||||
func NewHeaderMatchesChecker(header, rexStr string) (checker.Interface, error) {
|
||||
rex, err := regexp.Compile(strings.TrimSpace(rexStr))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: regex %s failed parse: %w", anubis.ErrMisconfiguration, rexStr, err)
|
||||
}
|
||||
return &HeaderMatchesChecker{strings.TrimSpace(header), rex, internal.FastHash(header + ": " + rexStr)}, nil
|
||||
}
|
||||
|
||||
func (hmc *HeaderMatchesChecker) Check(r *http.Request) (bool, error) {
|
||||
if hmc.regexp.MatchString(r.Header.Get(hmc.header)) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (hmc *HeaderMatchesChecker) Hash() string {
|
||||
return hmc.hash
|
||||
}
|
||||
|
||||
type PathChecker struct {
|
||||
regexp *regexp.Regexp
|
||||
hash string
|
||||
}
|
||||
|
||||
func NewPathChecker(rexStr string) (checker.Interface, error) {
|
||||
rex, err := regexp.Compile(strings.TrimSpace(rexStr))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: regex %s failed parse: %w", anubis.ErrMisconfiguration, rexStr, err)
|
||||
}
|
||||
return &PathChecker{rex, internal.FastHash(rexStr)}, nil
|
||||
}
|
||||
|
||||
func (pc *PathChecker) Check(r *http.Request) (bool, error) {
|
||||
if pc.regexp.MatchString(r.URL.Path) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (pc *PathChecker) Hash() string {
|
||||
return pc.hash
|
||||
}
|
||||
|
||||
func NewHeaderExistsChecker(key string) checker.Interface {
|
||||
return headerExistsChecker{strings.TrimSpace(key)}
|
||||
}
|
||||
|
||||
type headerExistsChecker struct {
|
||||
header string
|
||||
}
|
||||
|
||||
func (hec headerExistsChecker) Check(r *http.Request) (bool, error) {
|
||||
if r.Header.Get(hec.header) != "" {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (hec headerExistsChecker) Hash() string {
|
||||
return internal.FastHash(hec.header)
|
||||
}
|
||||
|
||||
func NewHeadersChecker(headermap map[string]string) (checker.Interface, error) {
|
||||
var result checker.List
|
||||
var result checker.All
|
||||
var errs []error
|
||||
|
||||
for key, rexStr := range headermap {
|
||||
var keys []string
|
||||
for key := range headermap {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, key := range keys {
|
||||
rexStr := headermap[key]
|
||||
if rexStr == ".*" {
|
||||
result = append(result, headerExistsChecker{strings.TrimSpace(key)})
|
||||
result = append(result, headerexists.New(strings.TrimSpace(key)))
|
||||
continue
|
||||
}
|
||||
|
||||
rex, err := regexp.Compile(strings.TrimSpace(rexStr))
|
||||
c, err := headermatches.New(key, rexStr)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("while compiling header %s regex %s: %w", key, rexStr, err))
|
||||
errs = append(errs, fmt.Errorf("while parsing header %s regex %s: %w", key, rexStr, err))
|
||||
continue
|
||||
}
|
||||
|
||||
result = append(result, &HeaderMatchesChecker{key, rex, internal.FastHash(key + ": " + rexStr)})
|
||||
result = append(result, c)
|
||||
}
|
||||
|
||||
if len(errs) != 0 {
|
||||
|
||||
@@ -1,202 +0,0 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/TecharoHQ/anubis"
|
||||
)
|
||||
|
||||
func TestRemoteAddrChecker(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
err error
|
||||
name string
|
||||
ip string
|
||||
cidrs []string
|
||||
ok bool
|
||||
}{
|
||||
{
|
||||
name: "match_ipv4",
|
||||
cidrs: []string{"0.0.0.0/0"},
|
||||
ip: "1.1.1.1",
|
||||
ok: true,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "match_ipv6",
|
||||
cidrs: []string{"::/0"},
|
||||
ip: "cafe:babe::",
|
||||
ok: true,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "not_match_ipv4",
|
||||
cidrs: []string{"1.1.1.1/32"},
|
||||
ip: "1.1.1.2",
|
||||
ok: false,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "not_match_ipv6",
|
||||
cidrs: []string{"cafe:babe::/128"},
|
||||
ip: "cafe:babe:4::/128",
|
||||
ok: false,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "no_ip_set",
|
||||
cidrs: []string{"::/0"},
|
||||
ok: false,
|
||||
err: anubis.ErrMisconfiguration,
|
||||
},
|
||||
{
|
||||
name: "invalid_ip",
|
||||
cidrs: []string{"::/0"},
|
||||
ip: "According to all natural laws of aviation",
|
||||
ok: false,
|
||||
err: anubis.ErrMisconfiguration,
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rac, err := NewRemoteAddrChecker(tt.cidrs)
|
||||
if err != nil && !errors.Is(err, tt.err) {
|
||||
t.Fatalf("creating RemoteAddrChecker failed: %v", err)
|
||||
}
|
||||
|
||||
r, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("can't make request: %v", err)
|
||||
}
|
||||
|
||||
if tt.ip != "" {
|
||||
r.Header.Add("X-Real-Ip", tt.ip)
|
||||
}
|
||||
|
||||
ok, err := rac.Check(r)
|
||||
|
||||
if tt.ok != ok {
|
||||
t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
|
||||
}
|
||||
|
||||
if err != nil && tt.err != nil && !errors.Is(err, tt.err) {
|
||||
t.Errorf("err: %v, wanted: %v", err, tt.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderMatchesChecker(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
err error
|
||||
name string
|
||||
header string
|
||||
rexStr string
|
||||
reqHeaderKey string
|
||||
reqHeaderValue string
|
||||
ok bool
|
||||
}{
|
||||
{
|
||||
name: "match",
|
||||
header: "Cf-Worker",
|
||||
rexStr: ".*",
|
||||
reqHeaderKey: "Cf-Worker",
|
||||
reqHeaderValue: "true",
|
||||
ok: true,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "not_match",
|
||||
header: "Cf-Worker",
|
||||
rexStr: "false",
|
||||
reqHeaderKey: "Cf-Worker",
|
||||
reqHeaderValue: "true",
|
||||
ok: false,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "not_present",
|
||||
header: "Cf-Worker",
|
||||
rexStr: "foobar",
|
||||
reqHeaderKey: "Something-Else",
|
||||
reqHeaderValue: "true",
|
||||
ok: false,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid_regex",
|
||||
rexStr: "a(b",
|
||||
err: anubis.ErrMisconfiguration,
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hmc, err := NewHeaderMatchesChecker(tt.header, tt.rexStr)
|
||||
if err != nil && !errors.Is(err, tt.err) {
|
||||
t.Fatalf("creating HeaderMatchesChecker failed")
|
||||
}
|
||||
|
||||
if tt.err != nil && hmc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
r, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("can't make request: %v", err)
|
||||
}
|
||||
|
||||
r.Header.Set(tt.reqHeaderKey, tt.reqHeaderValue)
|
||||
|
||||
ok, err := hmc.Check(r)
|
||||
|
||||
if tt.ok != ok {
|
||||
t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
|
||||
}
|
||||
|
||||
if err != nil && tt.err != nil && !errors.Is(err, tt.err) {
|
||||
t.Errorf("err: %v, wanted: %v", err, tt.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderExistsChecker(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
header string
|
||||
reqHeader string
|
||||
ok bool
|
||||
}{
|
||||
{
|
||||
name: "match",
|
||||
header: "Authorization",
|
||||
reqHeader: "Authorization",
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "not_match",
|
||||
header: "Authorization",
|
||||
reqHeader: "Authentication",
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hec := headerExistsChecker{tt.header}
|
||||
|
||||
r, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("can't make request: %v", err)
|
||||
}
|
||||
|
||||
r.Header.Set(tt.reqHeader, "hunter2")
|
||||
|
||||
ok, err := hec.Check(r)
|
||||
|
||||
if tt.ok != ok {
|
||||
t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("err: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
@@ -13,6 +12,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/TecharoHQ/anubis/data"
|
||||
"github.com/TecharoHQ/anubis/lib/checker/headermatches"
|
||||
"github.com/TecharoHQ/anubis/lib/checker/path"
|
||||
"github.com/TecharoHQ/anubis/lib/checker/remoteaddress"
|
||||
"k8s.io/apimachinery/pkg/util/yaml"
|
||||
)
|
||||
|
||||
@@ -25,7 +27,6 @@ var (
|
||||
ErrInvalidUserAgentRegex = errors.New("config.Bot: invalid user agent regex")
|
||||
ErrInvalidPathRegex = errors.New("config.Bot: invalid path regex")
|
||||
ErrInvalidHeadersRegex = errors.New("config.Bot: invalid headers regex")
|
||||
ErrInvalidCIDR = errors.New("config.Bot: invalid CIDR")
|
||||
ErrRegexEndsWithNewline = errors.New("config.Bot: regular expression ends with newline (try >- instead of > in yaml)")
|
||||
ErrInvalidImportStatement = errors.New("config.ImportStatement: invalid source file")
|
||||
ErrCantSetBotAndImportValuesAtOnce = errors.New("config.BotOrImport: can't set bot rules and import values at the same time")
|
||||
@@ -119,7 +120,7 @@ func (b *BotConfig) Valid() error {
|
||||
errs = append(errs, fmt.Errorf("%w: user agent regex: %q", ErrRegexEndsWithNewline, *b.UserAgentRegex))
|
||||
}
|
||||
|
||||
if _, err := regexp.Compile(*b.UserAgentRegex); err != nil {
|
||||
if err := headermatches.ValidUserAgent(*b.UserAgentRegex); err != nil {
|
||||
errs = append(errs, ErrInvalidUserAgentRegex, err)
|
||||
}
|
||||
}
|
||||
@@ -129,7 +130,7 @@ func (b *BotConfig) Valid() error {
|
||||
errs = append(errs, fmt.Errorf("%w: path regex: %q", ErrRegexEndsWithNewline, *b.PathRegex))
|
||||
}
|
||||
|
||||
if _, err := regexp.Compile(*b.PathRegex); err != nil {
|
||||
if err := path.Valid(*b.PathRegex); err != nil {
|
||||
errs = append(errs, ErrInvalidPathRegex, err)
|
||||
}
|
||||
}
|
||||
@@ -151,10 +152,8 @@ func (b *BotConfig) Valid() error {
|
||||
}
|
||||
|
||||
if len(b.RemoteAddr) > 0 {
|
||||
for _, cidr := range b.RemoteAddr {
|
||||
if _, _, err := net.ParseCIDR(cidr); err != nil {
|
||||
errs = append(errs, ErrInvalidCIDR, err)
|
||||
}
|
||||
if err := remoteaddress.Valid(b.RemoteAddr); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/TecharoHQ/anubis/data"
|
||||
"github.com/TecharoHQ/anubis/lib/checker/remoteaddress"
|
||||
. "github.com/TecharoHQ/anubis/lib/policy/config"
|
||||
)
|
||||
|
||||
@@ -137,7 +138,7 @@ func TestBotValid(t *testing.T) {
|
||||
Action: RuleAllow,
|
||||
RemoteAddr: []string{"0.0.0.0/33"},
|
||||
},
|
||||
err: ErrInvalidCIDR,
|
||||
err: remoteaddress.ErrInvalidCIDR,
|
||||
},
|
||||
{
|
||||
name: "only filter by IP range",
|
||||
|
||||
@@ -9,6 +9,9 @@ import (
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/TecharoHQ/anubis/lib/checker"
|
||||
"github.com/TecharoHQ/anubis/lib/checker/headermatches"
|
||||
"github.com/TecharoHQ/anubis/lib/checker/path"
|
||||
"github.com/TecharoHQ/anubis/lib/checker/remoteaddress"
|
||||
"github.com/TecharoHQ/anubis/lib/policy/config"
|
||||
"github.com/TecharoHQ/anubis/lib/store"
|
||||
"github.com/TecharoHQ/anubis/lib/thoth"
|
||||
@@ -73,10 +76,10 @@ func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDiffic
|
||||
Action: b.Action,
|
||||
}
|
||||
|
||||
cl := checker.List{}
|
||||
cl := checker.Any{}
|
||||
|
||||
if len(b.RemoteAddr) > 0 {
|
||||
c, err := NewRemoteAddrChecker(b.RemoteAddr)
|
||||
c, err := remoteaddress.New(b.RemoteAddr)
|
||||
if err != nil {
|
||||
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s remote addr set: %w", b.Name, err))
|
||||
} else {
|
||||
@@ -85,7 +88,7 @@ func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDiffic
|
||||
}
|
||||
|
||||
if b.UserAgentRegex != nil {
|
||||
c, err := NewUserAgentChecker(*b.UserAgentRegex)
|
||||
c, err := headermatches.NewUserAgent(*b.UserAgentRegex)
|
||||
if err != nil {
|
||||
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s user agent regex: %w", b.Name, err))
|
||||
} else {
|
||||
@@ -94,7 +97,7 @@ func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDiffic
|
||||
}
|
||||
|
||||
if b.PathRegex != nil {
|
||||
c, err := NewPathChecker(*b.PathRegex)
|
||||
c, err := path.New(*b.PathRegex)
|
||||
if err != nil {
|
||||
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s path regex: %w", b.Name, err))
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user