mirror of
https://github.com/TecharoHQ/anubis.git
synced 2026-04-17 13:54:59 +00:00
@@ -552,7 +552,7 @@ func (s *Server) check(r *http.Request) (policy.CheckResult, *policy.Bot, error)
|
|||||||
if matches {
|
if matches {
|
||||||
return cr("threshold/"+t.Name, t.Action, weight), &policy.Bot{
|
return cr("threshold/"+t.Name, t.Action, weight), &policy.Bot{
|
||||||
Challenge: t.Challenge,
|
Challenge: t.Challenge,
|
||||||
Rules: &checker.List{},
|
Rules: &checker.Any{},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -563,6 +563,6 @@ func (s *Server) check(r *http.Request) (policy.CheckResult, *policy.Bot, error)
|
|||||||
ReportAs: s.policy.DefaultDifficulty,
|
ReportAs: s.policy.DefaultDifficulty,
|
||||||
Algorithm: config.DefaultAlgorithm,
|
Algorithm: config.DefaultAlgorithm,
|
||||||
},
|
},
|
||||||
Rules: &checker.List{},
|
Rules: &checker.Any{},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
35
lib/checker/all.go
Normal file
35
lib/checker/all.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package checker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/TecharoHQ/anubis/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
type All []Interface
|
||||||
|
|
||||||
|
func (a All) Check(r *http.Request) (bool, error) {
|
||||||
|
for _, c := range a {
|
||||||
|
match, err := c.Check(r)
|
||||||
|
if err != nil {
|
||||||
|
return match, err
|
||||||
|
}
|
||||||
|
if !match {
|
||||||
|
return false, err // no match
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil // match
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a All) Hash() string {
|
||||||
|
var sb strings.Builder
|
||||||
|
|
||||||
|
for _, c := range a {
|
||||||
|
fmt.Fprintln(&sb, c.Hash())
|
||||||
|
}
|
||||||
|
|
||||||
|
return internal.FastHash(sb.String())
|
||||||
|
}
|
||||||
70
lib/checker/all_test.go
Normal file
70
lib/checker/all_test.go
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
package checker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAll_Check(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
checkers []MockChecker
|
||||||
|
want bool
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "All match",
|
||||||
|
checkers: []MockChecker{
|
||||||
|
{Result: true, Err: nil},
|
||||||
|
{Result: true, Err: nil},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "One not match",
|
||||||
|
checkers: []MockChecker{
|
||||||
|
{Result: true, Err: nil},
|
||||||
|
{Result: false, Err: nil},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No match",
|
||||||
|
checkers: []MockChecker{
|
||||||
|
{Result: false, Err: nil},
|
||||||
|
{Result: false, Err: nil},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Error encountered",
|
||||||
|
checkers: []MockChecker{
|
||||||
|
{Result: true, Err: nil},
|
||||||
|
{Result: false, Err: http.ErrNotSupported},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var all All
|
||||||
|
for _, mc := range tt.checkers {
|
||||||
|
all = append(all, mc)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := all.Check(nil)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("All.Check() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("All.Check() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
35
lib/checker/any.go
Normal file
35
lib/checker/any.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package checker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/TecharoHQ/anubis/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Any []Interface
|
||||||
|
|
||||||
|
func (a Any) Check(r *http.Request) (bool, error) {
|
||||||
|
for _, c := range a {
|
||||||
|
match, err := c.Check(r)
|
||||||
|
if err != nil {
|
||||||
|
return match, err
|
||||||
|
}
|
||||||
|
if match {
|
||||||
|
return true, err // match
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil // no match
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a Any) Hash() string {
|
||||||
|
var sb strings.Builder
|
||||||
|
|
||||||
|
for _, c := range a {
|
||||||
|
fmt.Fprintln(&sb, c.Hash())
|
||||||
|
}
|
||||||
|
|
||||||
|
return internal.FastHash(sb.String())
|
||||||
|
}
|
||||||
83
lib/checker/any_test.go
Normal file
83
lib/checker/any_test.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
package checker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MockChecker struct {
|
||||||
|
Result bool
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MockChecker) Check(r *http.Request) (bool, error) {
|
||||||
|
return m.Result, m.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MockChecker) Hash() string {
|
||||||
|
return "mock-hash"
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAny_Check(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
checkers []MockChecker
|
||||||
|
want bool
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "All match",
|
||||||
|
checkers: []MockChecker{
|
||||||
|
{Result: true, Err: nil},
|
||||||
|
{Result: true, Err: nil},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "One match",
|
||||||
|
checkers: []MockChecker{
|
||||||
|
{Result: false, Err: nil},
|
||||||
|
{Result: true, Err: nil},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No match",
|
||||||
|
checkers: []MockChecker{
|
||||||
|
{Result: false, Err: nil},
|
||||||
|
{Result: false, Err: nil},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Error encountered",
|
||||||
|
checkers: []MockChecker{
|
||||||
|
{Result: false, Err: nil},
|
||||||
|
{Result: false, Err: http.ErrNotSupported},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var any Any
|
||||||
|
for _, mc := range tt.checkers {
|
||||||
|
any = append(any, mc)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := any.Check(nil)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Any.Check() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("Any.Check() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,11 +3,7 @@ package checker
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/TecharoHQ/anubis/internal"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -19,29 +15,3 @@ type Interface interface {
|
|||||||
Check(*http.Request) (matches bool, err error)
|
Check(*http.Request) (matches bool, err error)
|
||||||
Hash() string
|
Hash() string
|
||||||
}
|
}
|
||||||
|
|
||||||
type List []Interface
|
|
||||||
|
|
||||||
func (l List) Check(r *http.Request) (bool, error) {
|
|
||||||
for _, c := range l {
|
|
||||||
ok, err := c.Check(r)
|
|
||||||
if err != nil {
|
|
||||||
return ok, err
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
return ok, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l List) Hash() string {
|
|
||||||
var sb strings.Builder
|
|
||||||
|
|
||||||
for _, c := range l {
|
|
||||||
fmt.Fprintln(&sb, c.Hash())
|
|
||||||
}
|
|
||||||
|
|
||||||
return internal.FastHash(sb.String())
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,8 +1,12 @@
|
|||||||
package headermatches
|
package headermatches
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Checker struct {
|
type Checker struct {
|
||||||
@@ -22,3 +26,21 @@ func (c *Checker) Check(r *http.Request) (bool, error) {
|
|||||||
func (c *Checker) Hash() string {
|
func (c *Checker) Hash() string {
|
||||||
return c.hash
|
return c.hash
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func New(key, valueRex string) (checker.Interface, error) {
|
||||||
|
fc := fileConfig{
|
||||||
|
Header: key,
|
||||||
|
ValueRegex: valueRex,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := fc.Valid(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(fc)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return Factory{}.Build(context.Background(), json.RawMessage(data))
|
||||||
|
}
|
||||||
|
|||||||
@@ -1 +1,35 @@
|
|||||||
package headermatches
|
package headermatches
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ValidUserAgent(valueRex string) error {
|
||||||
|
fc := fileConfig{
|
||||||
|
Header: "User-Agent",
|
||||||
|
ValueRegex: valueRex,
|
||||||
|
}
|
||||||
|
|
||||||
|
return fc.Valid()
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUserAgent(valueRex string) (checker.Interface, error) {
|
||||||
|
fc := fileConfig{
|
||||||
|
Header: "User-Agent",
|
||||||
|
ValueRegex: valueRex,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := fc.Valid(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(fc)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return Factory{}.Build(context.Background(), json.RawMessage(data))
|
||||||
|
}
|
||||||
|
|||||||
@@ -83,6 +83,14 @@ func (fc fileConfig) Valid() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Valid(cidrs []string) error {
|
||||||
|
fc := fileConfig{
|
||||||
|
RemoteAddr: cidrs,
|
||||||
|
}
|
||||||
|
|
||||||
|
return fc.Valid()
|
||||||
|
}
|
||||||
|
|
||||||
func New(cidrs []string) (checker.Interface, error) {
|
func New(cidrs []string) (checker.Interface, error) {
|
||||||
fc := fileConfig{
|
fc := fileConfig{
|
||||||
RemoteAddr: cidrs,
|
RemoteAddr: cidrs,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package remoteaddress
|
package remoteaddress_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
_ "embed"
|
_ "embed"
|
||||||
@@ -8,17 +8,17 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/TecharoHQ/anubis/lib/checker"
|
"github.com/TecharoHQ/anubis/lib/checker"
|
||||||
"github.com/TecharoHQ/anubis/lib/policy/config"
|
"github.com/TecharoHQ/anubis/lib/checker/remoteaddress"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestFactoryIsCheckerFactory(t *testing.T) {
|
func TestFactoryIsCheckerFactory(t *testing.T) {
|
||||||
if _, ok := (any(Factory{})).(checker.Factory); !ok {
|
if _, ok := (any(remoteaddress.Factory{})).(checker.Factory); !ok {
|
||||||
t.Fatal("Factory is not an instance of checker.Factory")
|
t.Fatal("Factory is not an instance of checker.Factory")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFactoryValidateConfig(t *testing.T) {
|
func TestFactoryValidateConfig(t *testing.T) {
|
||||||
f := Factory{}
|
f := remoteaddress.Factory{}
|
||||||
|
|
||||||
for _, tt := range []struct {
|
for _, tt := range []struct {
|
||||||
name string
|
name string
|
||||||
@@ -36,14 +36,14 @@ func TestFactoryValidateConfig(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "not json",
|
name: "not json",
|
||||||
data: []byte(`]`),
|
data: []byte(`]`),
|
||||||
err: config.ErrUnparseableConfig,
|
err: checker.ErrUnparseableConfig,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "no cidr",
|
name: "no cidr",
|
||||||
data: []byte(`{
|
data: []byte(`{
|
||||||
"remote_addresses": []
|
"remote_addresses": []
|
||||||
}`),
|
}`),
|
||||||
err: ErrNoRemoteAddresses,
|
err: remoteaddress.ErrNoRemoteAddresses,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "bad cidr",
|
name: "bad cidr",
|
||||||
@@ -52,7 +52,7 @@ func TestFactoryValidateConfig(t *testing.T) {
|
|||||||
"according to all laws of aviation"
|
"according to all laws of aviation"
|
||||||
]
|
]
|
||||||
}`),
|
}`),
|
||||||
err: config.ErrInvalidCIDR,
|
err: remoteaddress.ErrInvalidCIDR,
|
||||||
},
|
},
|
||||||
} {
|
} {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
@@ -68,7 +68,7 @@ func TestFactoryValidateConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFactoryCreate(t *testing.T) {
|
func TestFactoryCreate(t *testing.T) {
|
||||||
f := Factory{}
|
f := remoteaddress.Factory{}
|
||||||
|
|
||||||
for _, tt := range []struct {
|
for _, tt := range []struct {
|
||||||
name string
|
name string
|
||||||
@@ -94,7 +94,7 @@ func TestFactoryCreate(t *testing.T) {
|
|||||||
"according to all laws of aviation"
|
"according to all laws of aviation"
|
||||||
]
|
]
|
||||||
}`),
|
}`),
|
||||||
err: config.ErrUnparseableConfig,
|
err: checker.ErrUnparseableConfig,
|
||||||
},
|
},
|
||||||
} {
|
} {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|||||||
@@ -3,150 +3,39 @@ package policy
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"sort"
|
||||||
"net/netip"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/TecharoHQ/anubis"
|
|
||||||
"github.com/TecharoHQ/anubis/internal"
|
|
||||||
"github.com/TecharoHQ/anubis/lib/checker"
|
"github.com/TecharoHQ/anubis/lib/checker"
|
||||||
"github.com/gaissmai/bart"
|
"github.com/TecharoHQ/anubis/lib/checker/headerexists"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker/headermatches"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RemoteAddrChecker struct {
|
|
||||||
prefixTable *bart.Lite
|
|
||||||
hash string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewRemoteAddrChecker(cidrs []string) (checker.Interface, error) {
|
|
||||||
table := new(bart.Lite)
|
|
||||||
|
|
||||||
for _, cidr := range cidrs {
|
|
||||||
prefix, err := netip.ParsePrefix(cidr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("%w: range %s not parsing: %w", anubis.ErrMisconfiguration, cidr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
table.Insert(prefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &RemoteAddrChecker{
|
|
||||||
prefixTable: table,
|
|
||||||
hash: internal.FastHash(strings.Join(cidrs, ",")),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rac *RemoteAddrChecker) Check(r *http.Request) (bool, error) {
|
|
||||||
host := r.Header.Get("X-Real-Ip")
|
|
||||||
if host == "" {
|
|
||||||
return false, fmt.Errorf("%w: header X-Real-Ip is not set", anubis.ErrMisconfiguration)
|
|
||||||
}
|
|
||||||
|
|
||||||
addr, err := netip.ParseAddr(host)
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("%w: %s is not an IP address: %w", anubis.ErrMisconfiguration, host, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return rac.prefixTable.Contains(addr), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rac *RemoteAddrChecker) Hash() string {
|
|
||||||
return rac.hash
|
|
||||||
}
|
|
||||||
|
|
||||||
type HeaderMatchesChecker struct {
|
|
||||||
header string
|
|
||||||
regexp *regexp.Regexp
|
|
||||||
hash string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewUserAgentChecker(rexStr string) (checker.Interface, error) {
|
|
||||||
return NewHeaderMatchesChecker("User-Agent", rexStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewHeaderMatchesChecker(header, rexStr string) (checker.Interface, error) {
|
|
||||||
rex, err := regexp.Compile(strings.TrimSpace(rexStr))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("%w: regex %s failed parse: %w", anubis.ErrMisconfiguration, rexStr, err)
|
|
||||||
}
|
|
||||||
return &HeaderMatchesChecker{strings.TrimSpace(header), rex, internal.FastHash(header + ": " + rexStr)}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hmc *HeaderMatchesChecker) Check(r *http.Request) (bool, error) {
|
|
||||||
if hmc.regexp.MatchString(r.Header.Get(hmc.header)) {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hmc *HeaderMatchesChecker) Hash() string {
|
|
||||||
return hmc.hash
|
|
||||||
}
|
|
||||||
|
|
||||||
type PathChecker struct {
|
|
||||||
regexp *regexp.Regexp
|
|
||||||
hash string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPathChecker(rexStr string) (checker.Interface, error) {
|
|
||||||
rex, err := regexp.Compile(strings.TrimSpace(rexStr))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("%w: regex %s failed parse: %w", anubis.ErrMisconfiguration, rexStr, err)
|
|
||||||
}
|
|
||||||
return &PathChecker{rex, internal.FastHash(rexStr)}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pc *PathChecker) Check(r *http.Request) (bool, error) {
|
|
||||||
if pc.regexp.MatchString(r.URL.Path) {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pc *PathChecker) Hash() string {
|
|
||||||
return pc.hash
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewHeaderExistsChecker(key string) checker.Interface {
|
|
||||||
return headerExistsChecker{strings.TrimSpace(key)}
|
|
||||||
}
|
|
||||||
|
|
||||||
type headerExistsChecker struct {
|
|
||||||
header string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hec headerExistsChecker) Check(r *http.Request) (bool, error) {
|
|
||||||
if r.Header.Get(hec.header) != "" {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hec headerExistsChecker) Hash() string {
|
|
||||||
return internal.FastHash(hec.header)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewHeadersChecker(headermap map[string]string) (checker.Interface, error) {
|
func NewHeadersChecker(headermap map[string]string) (checker.Interface, error) {
|
||||||
var result checker.List
|
var result checker.All
|
||||||
var errs []error
|
var errs []error
|
||||||
|
|
||||||
for key, rexStr := range headermap {
|
var keys []string
|
||||||
|
for key := range headermap {
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Strings(keys)
|
||||||
|
|
||||||
|
for _, key := range keys {
|
||||||
|
rexStr := headermap[key]
|
||||||
if rexStr == ".*" {
|
if rexStr == ".*" {
|
||||||
result = append(result, headerExistsChecker{strings.TrimSpace(key)})
|
result = append(result, headerexists.New(strings.TrimSpace(key)))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
rex, err := regexp.Compile(strings.TrimSpace(rexStr))
|
c, err := headermatches.New(key, rexStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errs = append(errs, fmt.Errorf("while compiling header %s regex %s: %w", key, rexStr, err))
|
errs = append(errs, fmt.Errorf("while parsing header %s regex %s: %w", key, rexStr, err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
result = append(result, &HeaderMatchesChecker{key, rex, internal.FastHash(key + ": " + rexStr)})
|
result = append(result, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(errs) != 0 {
|
if len(errs) != 0 {
|
||||||
|
|||||||
@@ -1,202 +0,0 @@
|
|||||||
package policy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"net/http"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/TecharoHQ/anubis"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestRemoteAddrChecker(t *testing.T) {
|
|
||||||
for _, tt := range []struct {
|
|
||||||
err error
|
|
||||||
name string
|
|
||||||
ip string
|
|
||||||
cidrs []string
|
|
||||||
ok bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "match_ipv4",
|
|
||||||
cidrs: []string{"0.0.0.0/0"},
|
|
||||||
ip: "1.1.1.1",
|
|
||||||
ok: true,
|
|
||||||
err: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "match_ipv6",
|
|
||||||
cidrs: []string{"::/0"},
|
|
||||||
ip: "cafe:babe::",
|
|
||||||
ok: true,
|
|
||||||
err: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "not_match_ipv4",
|
|
||||||
cidrs: []string{"1.1.1.1/32"},
|
|
||||||
ip: "1.1.1.2",
|
|
||||||
ok: false,
|
|
||||||
err: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "not_match_ipv6",
|
|
||||||
cidrs: []string{"cafe:babe::/128"},
|
|
||||||
ip: "cafe:babe:4::/128",
|
|
||||||
ok: false,
|
|
||||||
err: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no_ip_set",
|
|
||||||
cidrs: []string{"::/0"},
|
|
||||||
ok: false,
|
|
||||||
err: anubis.ErrMisconfiguration,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid_ip",
|
|
||||||
cidrs: []string{"::/0"},
|
|
||||||
ip: "According to all natural laws of aviation",
|
|
||||||
ok: false,
|
|
||||||
err: anubis.ErrMisconfiguration,
|
|
||||||
},
|
|
||||||
} {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
rac, err := NewRemoteAddrChecker(tt.cidrs)
|
|
||||||
if err != nil && !errors.Is(err, tt.err) {
|
|
||||||
t.Fatalf("creating RemoteAddrChecker failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r, err := http.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("can't make request: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.ip != "" {
|
|
||||||
r.Header.Add("X-Real-Ip", tt.ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err := rac.Check(r)
|
|
||||||
|
|
||||||
if tt.ok != ok {
|
|
||||||
t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil && tt.err != nil && !errors.Is(err, tt.err) {
|
|
||||||
t.Errorf("err: %v, wanted: %v", err, tt.err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHeaderMatchesChecker(t *testing.T) {
|
|
||||||
for _, tt := range []struct {
|
|
||||||
err error
|
|
||||||
name string
|
|
||||||
header string
|
|
||||||
rexStr string
|
|
||||||
reqHeaderKey string
|
|
||||||
reqHeaderValue string
|
|
||||||
ok bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "match",
|
|
||||||
header: "Cf-Worker",
|
|
||||||
rexStr: ".*",
|
|
||||||
reqHeaderKey: "Cf-Worker",
|
|
||||||
reqHeaderValue: "true",
|
|
||||||
ok: true,
|
|
||||||
err: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "not_match",
|
|
||||||
header: "Cf-Worker",
|
|
||||||
rexStr: "false",
|
|
||||||
reqHeaderKey: "Cf-Worker",
|
|
||||||
reqHeaderValue: "true",
|
|
||||||
ok: false,
|
|
||||||
err: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "not_present",
|
|
||||||
header: "Cf-Worker",
|
|
||||||
rexStr: "foobar",
|
|
||||||
reqHeaderKey: "Something-Else",
|
|
||||||
reqHeaderValue: "true",
|
|
||||||
ok: false,
|
|
||||||
err: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid_regex",
|
|
||||||
rexStr: "a(b",
|
|
||||||
err: anubis.ErrMisconfiguration,
|
|
||||||
},
|
|
||||||
} {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
hmc, err := NewHeaderMatchesChecker(tt.header, tt.rexStr)
|
|
||||||
if err != nil && !errors.Is(err, tt.err) {
|
|
||||||
t.Fatalf("creating HeaderMatchesChecker failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.err != nil && hmc == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
r, err := http.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("can't make request: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Header.Set(tt.reqHeaderKey, tt.reqHeaderValue)
|
|
||||||
|
|
||||||
ok, err := hmc.Check(r)
|
|
||||||
|
|
||||||
if tt.ok != ok {
|
|
||||||
t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil && tt.err != nil && !errors.Is(err, tt.err) {
|
|
||||||
t.Errorf("err: %v, wanted: %v", err, tt.err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHeaderExistsChecker(t *testing.T) {
|
|
||||||
for _, tt := range []struct {
|
|
||||||
name string
|
|
||||||
header string
|
|
||||||
reqHeader string
|
|
||||||
ok bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "match",
|
|
||||||
header: "Authorization",
|
|
||||||
reqHeader: "Authorization",
|
|
||||||
ok: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "not_match",
|
|
||||||
header: "Authorization",
|
|
||||||
reqHeader: "Authentication",
|
|
||||||
},
|
|
||||||
} {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
hec := headerExistsChecker{tt.header}
|
|
||||||
|
|
||||||
r, err := http.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("can't make request: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Header.Set(tt.reqHeader, "hunter2")
|
|
||||||
|
|
||||||
ok, err := hec.Check(r)
|
|
||||||
|
|
||||||
if tt.ok != ok {
|
|
||||||
t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("err: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
@@ -13,6 +12,9 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/TecharoHQ/anubis/data"
|
"github.com/TecharoHQ/anubis/data"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker/headermatches"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker/path"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker/remoteaddress"
|
||||||
"k8s.io/apimachinery/pkg/util/yaml"
|
"k8s.io/apimachinery/pkg/util/yaml"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -25,7 +27,6 @@ var (
|
|||||||
ErrInvalidUserAgentRegex = errors.New("config.Bot: invalid user agent regex")
|
ErrInvalidUserAgentRegex = errors.New("config.Bot: invalid user agent regex")
|
||||||
ErrInvalidPathRegex = errors.New("config.Bot: invalid path regex")
|
ErrInvalidPathRegex = errors.New("config.Bot: invalid path regex")
|
||||||
ErrInvalidHeadersRegex = errors.New("config.Bot: invalid headers regex")
|
ErrInvalidHeadersRegex = errors.New("config.Bot: invalid headers regex")
|
||||||
ErrInvalidCIDR = errors.New("config.Bot: invalid CIDR")
|
|
||||||
ErrRegexEndsWithNewline = errors.New("config.Bot: regular expression ends with newline (try >- instead of > in yaml)")
|
ErrRegexEndsWithNewline = errors.New("config.Bot: regular expression ends with newline (try >- instead of > in yaml)")
|
||||||
ErrInvalidImportStatement = errors.New("config.ImportStatement: invalid source file")
|
ErrInvalidImportStatement = errors.New("config.ImportStatement: invalid source file")
|
||||||
ErrCantSetBotAndImportValuesAtOnce = errors.New("config.BotOrImport: can't set bot rules and import values at the same time")
|
ErrCantSetBotAndImportValuesAtOnce = errors.New("config.BotOrImport: can't set bot rules and import values at the same time")
|
||||||
@@ -119,7 +120,7 @@ func (b *BotConfig) Valid() error {
|
|||||||
errs = append(errs, fmt.Errorf("%w: user agent regex: %q", ErrRegexEndsWithNewline, *b.UserAgentRegex))
|
errs = append(errs, fmt.Errorf("%w: user agent regex: %q", ErrRegexEndsWithNewline, *b.UserAgentRegex))
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := regexp.Compile(*b.UserAgentRegex); err != nil {
|
if err := headermatches.ValidUserAgent(*b.UserAgentRegex); err != nil {
|
||||||
errs = append(errs, ErrInvalidUserAgentRegex, err)
|
errs = append(errs, ErrInvalidUserAgentRegex, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -129,7 +130,7 @@ func (b *BotConfig) Valid() error {
|
|||||||
errs = append(errs, fmt.Errorf("%w: path regex: %q", ErrRegexEndsWithNewline, *b.PathRegex))
|
errs = append(errs, fmt.Errorf("%w: path regex: %q", ErrRegexEndsWithNewline, *b.PathRegex))
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := regexp.Compile(*b.PathRegex); err != nil {
|
if err := path.Valid(*b.PathRegex); err != nil {
|
||||||
errs = append(errs, ErrInvalidPathRegex, err)
|
errs = append(errs, ErrInvalidPathRegex, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -151,10 +152,8 @@ func (b *BotConfig) Valid() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(b.RemoteAddr) > 0 {
|
if len(b.RemoteAddr) > 0 {
|
||||||
for _, cidr := range b.RemoteAddr {
|
if err := remoteaddress.Valid(b.RemoteAddr); err != nil {
|
||||||
if _, _, err := net.ParseCIDR(cidr); err != nil {
|
errs = append(errs, err)
|
||||||
errs = append(errs, ErrInvalidCIDR, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/TecharoHQ/anubis/data"
|
"github.com/TecharoHQ/anubis/data"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker/remoteaddress"
|
||||||
. "github.com/TecharoHQ/anubis/lib/policy/config"
|
. "github.com/TecharoHQ/anubis/lib/policy/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -137,7 +138,7 @@ func TestBotValid(t *testing.T) {
|
|||||||
Action: RuleAllow,
|
Action: RuleAllow,
|
||||||
RemoteAddr: []string{"0.0.0.0/33"},
|
RemoteAddr: []string{"0.0.0.0/33"},
|
||||||
},
|
},
|
||||||
err: ErrInvalidCIDR,
|
err: remoteaddress.ErrInvalidCIDR,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "only filter by IP range",
|
name: "only filter by IP range",
|
||||||
|
|||||||
@@ -9,6 +9,9 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/TecharoHQ/anubis/lib/checker"
|
"github.com/TecharoHQ/anubis/lib/checker"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker/headermatches"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker/path"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker/remoteaddress"
|
||||||
"github.com/TecharoHQ/anubis/lib/policy/config"
|
"github.com/TecharoHQ/anubis/lib/policy/config"
|
||||||
"github.com/TecharoHQ/anubis/lib/store"
|
"github.com/TecharoHQ/anubis/lib/store"
|
||||||
"github.com/TecharoHQ/anubis/lib/thoth"
|
"github.com/TecharoHQ/anubis/lib/thoth"
|
||||||
@@ -73,10 +76,10 @@ func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDiffic
|
|||||||
Action: b.Action,
|
Action: b.Action,
|
||||||
}
|
}
|
||||||
|
|
||||||
cl := checker.List{}
|
cl := checker.Any{}
|
||||||
|
|
||||||
if len(b.RemoteAddr) > 0 {
|
if len(b.RemoteAddr) > 0 {
|
||||||
c, err := NewRemoteAddrChecker(b.RemoteAddr)
|
c, err := remoteaddress.New(b.RemoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s remote addr set: %w", b.Name, err))
|
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s remote addr set: %w", b.Name, err))
|
||||||
} else {
|
} else {
|
||||||
@@ -85,7 +88,7 @@ func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDiffic
|
|||||||
}
|
}
|
||||||
|
|
||||||
if b.UserAgentRegex != nil {
|
if b.UserAgentRegex != nil {
|
||||||
c, err := NewUserAgentChecker(*b.UserAgentRegex)
|
c, err := headermatches.NewUserAgent(*b.UserAgentRegex)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s user agent regex: %w", b.Name, err))
|
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s user agent regex: %w", b.Name, err))
|
||||||
} else {
|
} else {
|
||||||
@@ -94,7 +97,7 @@ func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDiffic
|
|||||||
}
|
}
|
||||||
|
|
||||||
if b.PathRegex != nil {
|
if b.PathRegex != nil {
|
||||||
c, err := NewPathChecker(*b.PathRegex)
|
c, err := path.New(*b.PathRegex)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s path regex: %w", b.Name, err))
|
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s path regex: %w", b.Name, err))
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
Reference in New Issue
Block a user