mirror of
https://github.com/TecharoHQ/anubis.git
synced 2026-04-08 09:38:45 +00:00
Compare commits
9 Commits
fix/nilpoi
...
Xe/checks-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9f3eb71ef6 | ||
|
|
a494d26708 | ||
|
|
e98d749bf2 | ||
|
|
590d8303ad | ||
|
|
88c30c70fc | ||
|
|
1c43349c4a | ||
|
|
178c60cf72 | ||
|
|
ecbbf77498 | ||
|
|
5307388841 |
@@ -31,6 +31,7 @@ import (
|
|||||||
"github.com/TecharoHQ/anubis/data"
|
"github.com/TecharoHQ/anubis/data"
|
||||||
"github.com/TecharoHQ/anubis/internal"
|
"github.com/TecharoHQ/anubis/internal"
|
||||||
libanubis "github.com/TecharoHQ/anubis/lib"
|
libanubis "github.com/TecharoHQ/anubis/lib"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker/headerexists"
|
||||||
botPolicy "github.com/TecharoHQ/anubis/lib/policy"
|
botPolicy "github.com/TecharoHQ/anubis/lib/policy"
|
||||||
"github.com/TecharoHQ/anubis/lib/policy/config"
|
"github.com/TecharoHQ/anubis/lib/policy/config"
|
||||||
"github.com/TecharoHQ/anubis/lib/thoth"
|
"github.com/TecharoHQ/anubis/lib/thoth"
|
||||||
@@ -323,7 +324,7 @@ func main() {
|
|||||||
if *debugBenchmarkJS {
|
if *debugBenchmarkJS {
|
||||||
policy.Bots = []botPolicy.Bot{{
|
policy.Bots = []botPolicy.Bot{{
|
||||||
Name: "",
|
Name: "",
|
||||||
Rules: botPolicy.NewHeaderExistsChecker("User-Agent"),
|
Rules: headerexists.New("User-Agent"),
|
||||||
Action: config.RuleBenchmark,
|
Action: config.RuleBenchmark,
|
||||||
}}
|
}}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker/expression"
|
||||||
"github.com/TecharoHQ/anubis/lib/policy/config"
|
"github.com/TecharoHQ/anubis/lib/policy/config"
|
||||||
|
|
||||||
"sigs.k8s.io/yaml"
|
"sigs.k8s.io/yaml"
|
||||||
@@ -37,11 +38,11 @@ type RobotsRule struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type AnubisRule struct {
|
type AnubisRule struct {
|
||||||
Expression *config.ExpressionOrList `yaml:"expression,omitempty" json:"expression,omitempty"`
|
Expression *expression.Config `yaml:"expression,omitempty" json:"expression,omitempty"`
|
||||||
Challenge *config.ChallengeRules `yaml:"challenge,omitempty" json:"challenge,omitempty"`
|
Challenge *config.ChallengeRules `yaml:"challenge,omitempty" json:"challenge,omitempty"`
|
||||||
Weight *config.Weight `yaml:"weight,omitempty" json:"weight,omitempty"`
|
Weight *config.Weight `yaml:"weight,omitempty" json:"weight,omitempty"`
|
||||||
Name string `yaml:"name" json:"name"`
|
Name string `yaml:"name" json:"name"`
|
||||||
Action string `yaml:"action" json:"action"`
|
Action string `yaml:"action" json:"action"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -224,11 +225,11 @@ func convertToAnubisRules(robotsRules []RobotsRule) []AnubisRule {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if userAgent == "*" {
|
if userAgent == "*" {
|
||||||
rule.Expression = &config.ExpressionOrList{
|
rule.Expression = &expression.Config{
|
||||||
All: []string{"true"}, // Always applies
|
All: []string{"true"}, // Always applies
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
rule.Expression = &config.ExpressionOrList{
|
rule.Expression = &expression.Config{
|
||||||
All: []string{fmt.Sprintf("userAgent.contains(%q)", userAgent)},
|
All: []string{fmt.Sprintf("userAgent.contains(%q)", userAgent)},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -249,11 +250,11 @@ func convertToAnubisRules(robotsRules []RobotsRule) []AnubisRule {
|
|||||||
rule.Name = fmt.Sprintf("%s-global-restriction-%d", *policyName, ruleCounter)
|
rule.Name = fmt.Sprintf("%s-global-restriction-%d", *policyName, ruleCounter)
|
||||||
rule.Action = "WEIGH"
|
rule.Action = "WEIGH"
|
||||||
rule.Weight = &config.Weight{Adjust: 20} // Increase difficulty significantly
|
rule.Weight = &config.Weight{Adjust: 20} // Increase difficulty significantly
|
||||||
rule.Expression = &config.ExpressionOrList{
|
rule.Expression = &expression.Config{
|
||||||
All: []string{"true"}, // Always applies
|
All: []string{"true"}, // Always applies
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
rule.Expression = &config.ExpressionOrList{
|
rule.Expression = &expression.Config{
|
||||||
All: []string{fmt.Sprintf("userAgent.contains(%q)", userAgent)},
|
All: []string{fmt.Sprintf("userAgent.contains(%q)", userAgent)},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -285,7 +286,7 @@ func convertToAnubisRules(robotsRules []RobotsRule) []AnubisRule {
|
|||||||
pathCondition := buildPathCondition(disallow)
|
pathCondition := buildPathCondition(disallow)
|
||||||
conditions = append(conditions, pathCondition)
|
conditions = append(conditions, pathCondition)
|
||||||
|
|
||||||
rule.Expression = &config.ExpressionOrList{
|
rule.Expression = &expression.Config{
|
||||||
All: conditions,
|
All: conditions,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
7
errors.go
Normal file
7
errors.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
package anubis
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrMisconfiguration = errors.New("[unexpected] policy: administrator misconfiguration")
|
||||||
|
)
|
||||||
@@ -28,15 +28,17 @@ import (
|
|||||||
"github.com/TecharoHQ/anubis/internal/dnsbl"
|
"github.com/TecharoHQ/anubis/internal/dnsbl"
|
||||||
"github.com/TecharoHQ/anubis/internal/ogtags"
|
"github.com/TecharoHQ/anubis/internal/ogtags"
|
||||||
"github.com/TecharoHQ/anubis/lib/challenge"
|
"github.com/TecharoHQ/anubis/lib/challenge"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker"
|
||||||
"github.com/TecharoHQ/anubis/lib/localization"
|
"github.com/TecharoHQ/anubis/lib/localization"
|
||||||
"github.com/TecharoHQ/anubis/lib/policy"
|
"github.com/TecharoHQ/anubis/lib/policy"
|
||||||
"github.com/TecharoHQ/anubis/lib/policy/checker"
|
|
||||||
"github.com/TecharoHQ/anubis/lib/policy/config"
|
"github.com/TecharoHQ/anubis/lib/policy/config"
|
||||||
"github.com/TecharoHQ/anubis/lib/store"
|
"github.com/TecharoHQ/anubis/lib/store"
|
||||||
|
|
||||||
|
// checker implementations
|
||||||
|
_ "github.com/TecharoHQ/anubis/lib/checker/all"
|
||||||
|
|
||||||
// challenge implementations
|
// challenge implementations
|
||||||
_ "github.com/TecharoHQ/anubis/lib/challenge/metarefresh"
|
_ "github.com/TecharoHQ/anubis/lib/challenge/all"
|
||||||
_ "github.com/TecharoHQ/anubis/lib/challenge/proofofwork"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -549,7 +551,7 @@ func (s *Server) check(r *http.Request) (policy.CheckResult, *policy.Bot, error)
|
|||||||
if matches {
|
if matches {
|
||||||
return cr("threshold/"+t.Name, t.Action, weight), &policy.Bot{
|
return cr("threshold/"+t.Name, t.Action, weight), &policy.Bot{
|
||||||
Challenge: t.Challenge,
|
Challenge: t.Challenge,
|
||||||
Rules: &checker.List{},
|
Rules: &checker.Any{},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -560,6 +562,6 @@ func (s *Server) check(r *http.Request) (policy.CheckResult, *policy.Bot, error)
|
|||||||
ReportAs: s.policy.DefaultDifficulty,
|
ReportAs: s.policy.DefaultDifficulty,
|
||||||
Algorithm: config.DefaultAlgorithm,
|
Algorithm: config.DefaultAlgorithm,
|
||||||
},
|
},
|
||||||
Rules: &checker.List{},
|
Rules: &checker.Any{},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
6
lib/challenge/all/all.go
Normal file
6
lib/challenge/all/all.go
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
package all
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "github.com/TecharoHQ/anubis/lib/challenge/metarefresh"
|
||||||
|
_ "github.com/TecharoHQ/anubis/lib/challenge/proofofwork"
|
||||||
|
)
|
||||||
35
lib/checker/all.go
Normal file
35
lib/checker/all.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package checker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/TecharoHQ/anubis/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
type All []Interface
|
||||||
|
|
||||||
|
func (a All) Check(r *http.Request) (bool, error) {
|
||||||
|
for _, c := range a {
|
||||||
|
match, err := c.Check(r)
|
||||||
|
if err != nil {
|
||||||
|
return match, err
|
||||||
|
}
|
||||||
|
if !match {
|
||||||
|
return false, err // no match
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil // match
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a All) Hash() string {
|
||||||
|
var sb strings.Builder
|
||||||
|
|
||||||
|
for _, c := range a {
|
||||||
|
fmt.Fprintln(&sb, c.Hash())
|
||||||
|
}
|
||||||
|
|
||||||
|
return internal.FastHash(sb.String())
|
||||||
|
}
|
||||||
10
lib/checker/all/all.go
Normal file
10
lib/checker/all/all.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
// Package all imports all of the standard checker types.
|
||||||
|
package all
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "github.com/TecharoHQ/anubis/lib/checker/expression"
|
||||||
|
_ "github.com/TecharoHQ/anubis/lib/checker/headerexists"
|
||||||
|
_ "github.com/TecharoHQ/anubis/lib/checker/headermatches"
|
||||||
|
_ "github.com/TecharoHQ/anubis/lib/checker/path"
|
||||||
|
_ "github.com/TecharoHQ/anubis/lib/checker/remoteaddress"
|
||||||
|
)
|
||||||
70
lib/checker/all_test.go
Normal file
70
lib/checker/all_test.go
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
package checker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAll_Check(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
checkers []MockChecker
|
||||||
|
want bool
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "All match",
|
||||||
|
checkers: []MockChecker{
|
||||||
|
{Result: true, Err: nil},
|
||||||
|
{Result: true, Err: nil},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "One not match",
|
||||||
|
checkers: []MockChecker{
|
||||||
|
{Result: true, Err: nil},
|
||||||
|
{Result: false, Err: nil},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No match",
|
||||||
|
checkers: []MockChecker{
|
||||||
|
{Result: false, Err: nil},
|
||||||
|
{Result: false, Err: nil},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Error encountered",
|
||||||
|
checkers: []MockChecker{
|
||||||
|
{Result: true, Err: nil},
|
||||||
|
{Result: false, Err: http.ErrNotSupported},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var all All
|
||||||
|
for _, mc := range tt.checkers {
|
||||||
|
all = append(all, mc)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := all.Check(nil)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("All.Check() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("All.Check() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
35
lib/checker/any.go
Normal file
35
lib/checker/any.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package checker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/TecharoHQ/anubis/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Any []Interface
|
||||||
|
|
||||||
|
func (a Any) Check(r *http.Request) (bool, error) {
|
||||||
|
for _, c := range a {
|
||||||
|
match, err := c.Check(r)
|
||||||
|
if err != nil {
|
||||||
|
return match, err
|
||||||
|
}
|
||||||
|
if match {
|
||||||
|
return true, err // match
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil // no match
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a Any) Hash() string {
|
||||||
|
var sb strings.Builder
|
||||||
|
|
||||||
|
for _, c := range a {
|
||||||
|
fmt.Fprintln(&sb, c.Hash())
|
||||||
|
}
|
||||||
|
|
||||||
|
return internal.FastHash(sb.String())
|
||||||
|
}
|
||||||
83
lib/checker/any_test.go
Normal file
83
lib/checker/any_test.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
package checker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MockChecker struct {
|
||||||
|
Result bool
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MockChecker) Check(r *http.Request) (bool, error) {
|
||||||
|
return m.Result, m.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MockChecker) Hash() string {
|
||||||
|
return "mock-hash"
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAny_Check(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
checkers []MockChecker
|
||||||
|
want bool
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "All match",
|
||||||
|
checkers: []MockChecker{
|
||||||
|
{Result: true, Err: nil},
|
||||||
|
{Result: true, Err: nil},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "One match",
|
||||||
|
checkers: []MockChecker{
|
||||||
|
{Result: false, Err: nil},
|
||||||
|
{Result: true, Err: nil},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No match",
|
||||||
|
checkers: []MockChecker{
|
||||||
|
{Result: false, Err: nil},
|
||||||
|
{Result: false, Err: nil},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Error encountered",
|
||||||
|
checkers: []MockChecker{
|
||||||
|
{Result: false, Err: nil},
|
||||||
|
{Result: false, Err: http.ErrNotSupported},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var any Any
|
||||||
|
for _, mc := range tt.checkers {
|
||||||
|
any = append(any, mc)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := any.Check(nil)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Any.Check() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("Any.Check() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
17
lib/checker/checker.go
Normal file
17
lib/checker/checker.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
// Package checker defines the Checker interface and a helper utility to avoid import cycles.
|
||||||
|
package checker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrUnparseableConfig = errors.New("checker: config is unparseable")
|
||||||
|
ErrInvalidConfig = errors.New("checker: config is invalid")
|
||||||
|
)
|
||||||
|
|
||||||
|
type Interface interface {
|
||||||
|
Check(*http.Request) (matches bool, err error)
|
||||||
|
Hash() string
|
||||||
|
}
|
||||||
@@ -1,43 +1,44 @@
|
|||||||
package policy
|
package expression
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/TecharoHQ/anubis/internal"
|
"github.com/TecharoHQ/anubis/internal"
|
||||||
"github.com/TecharoHQ/anubis/lib/policy/config"
|
"github.com/TecharoHQ/anubis/lib/checker/expression/environment"
|
||||||
"github.com/TecharoHQ/anubis/lib/policy/expressions"
|
|
||||||
"github.com/google/cel-go/cel"
|
"github.com/google/cel-go/cel"
|
||||||
"github.com/google/cel-go/common/types"
|
"github.com/google/cel-go/common/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
type CELChecker struct {
|
type Checker struct {
|
||||||
program cel.Program
|
program cel.Program
|
||||||
src string
|
src string
|
||||||
|
hash string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCELChecker(cfg *config.ExpressionOrList) (*CELChecker, error) {
|
func New(cfg *Config) (*Checker, error) {
|
||||||
env, err := expressions.BotEnvironment()
|
env, err := environment.Bot()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
program, err := expressions.Compile(env, cfg.String())
|
program, err := environment.Compile(env, cfg.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("can't compile CEL program: %w", err)
|
return nil, fmt.Errorf("can't compile CEL program: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &CELChecker{
|
return &Checker{
|
||||||
src: cfg.String(),
|
src: cfg.String(),
|
||||||
|
hash: internal.FastHash(cfg.String()),
|
||||||
program: program,
|
program: program,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cc *CELChecker) Hash() string {
|
func (cc *Checker) Hash() string {
|
||||||
return internal.FastHash(cc.src)
|
return cc.hash
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cc *CELChecker) Check(r *http.Request) (bool, error) {
|
func (cc *Checker) Check(r *http.Request) (bool, error) {
|
||||||
result, _, err := cc.program.ContextEval(r.Context(), &CELRequest{r})
|
result, _, err := cc.program.ContextEval(r.Context(), &CELRequest{r})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -70,15 +71,15 @@ func (cr *CELRequest) ResolveName(name string) (any, bool) {
|
|||||||
case "path":
|
case "path":
|
||||||
return cr.URL.Path, true
|
return cr.URL.Path, true
|
||||||
case "query":
|
case "query":
|
||||||
return expressions.URLValues{Values: cr.URL.Query()}, true
|
return URLValues{Values: cr.URL.Query()}, true
|
||||||
case "headers":
|
case "headers":
|
||||||
return expressions.HTTPHeaders{Header: cr.Header}, true
|
return HTTPHeaders{Header: cr.Header}, true
|
||||||
case "load_1m":
|
case "load_1m":
|
||||||
return expressions.Load1(), true
|
return Load1(), true
|
||||||
case "load_5m":
|
case "load_5m":
|
||||||
return expressions.Load5(), true
|
return Load5(), true
|
||||||
case "load_15m":
|
case "load_15m":
|
||||||
return expressions.Load15(), true
|
return Load15(), true
|
||||||
default:
|
default:
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package expression
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -9,18 +9,18 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrExpressionOrListMustBeStringOrObject = errors.New("config: this must be a string or an object")
|
ErrExpressionOrListMustBeStringOrObject = errors.New("expression: this must be a string or an object")
|
||||||
ErrExpressionEmpty = errors.New("config: this expression is empty")
|
ErrExpressionEmpty = errors.New("expression: this expression is empty")
|
||||||
ErrExpressionCantHaveBoth = errors.New("config: expression block can't contain multiple expression types")
|
ErrExpressionCantHaveBoth = errors.New("expression: expression block can't contain multiple expression types")
|
||||||
)
|
)
|
||||||
|
|
||||||
type ExpressionOrList struct {
|
type Config struct {
|
||||||
Expression string `json:"-" yaml:"-"`
|
Expression string `json:"-" yaml:"-"`
|
||||||
All []string `json:"all,omitempty" yaml:"all,omitempty"`
|
All []string `json:"all,omitempty" yaml:"all,omitempty"`
|
||||||
Any []string `json:"any,omitempty" yaml:"any,omitempty"`
|
Any []string `json:"any,omitempty" yaml:"any,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (eol ExpressionOrList) String() string {
|
func (eol Config) String() string {
|
||||||
switch {
|
switch {
|
||||||
case len(eol.Expression) != 0:
|
case len(eol.Expression) != 0:
|
||||||
return eol.Expression
|
return eol.Expression
|
||||||
@@ -46,7 +46,7 @@ func (eol ExpressionOrList) String() string {
|
|||||||
panic("this should not happen")
|
panic("this should not happen")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (eol ExpressionOrList) Equal(rhs *ExpressionOrList) bool {
|
func (eol Config) Equal(rhs *Config) bool {
|
||||||
if eol.Expression != rhs.Expression {
|
if eol.Expression != rhs.Expression {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -62,7 +62,7 @@ func (eol ExpressionOrList) Equal(rhs *ExpressionOrList) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (eol *ExpressionOrList) MarshalYAML() (any, error) {
|
func (eol *Config) MarshalYAML() (any, error) {
|
||||||
switch {
|
switch {
|
||||||
case len(eol.All) == 1 && len(eol.Any) == 0:
|
case len(eol.All) == 1 && len(eol.Any) == 0:
|
||||||
eol.Expression = eol.All[0]
|
eol.Expression = eol.All[0]
|
||||||
@@ -76,11 +76,11 @@ func (eol *ExpressionOrList) MarshalYAML() (any, error) {
|
|||||||
return eol.Expression, nil
|
return eol.Expression, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type RawExpressionOrList ExpressionOrList
|
type RawExpressionOrList Config
|
||||||
return RawExpressionOrList(*eol), nil
|
return RawExpressionOrList(*eol), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (eol *ExpressionOrList) MarshalJSON() ([]byte, error) {
|
func (eol *Config) MarshalJSON() ([]byte, error) {
|
||||||
switch {
|
switch {
|
||||||
case len(eol.All) == 1 && len(eol.Any) == 0:
|
case len(eol.All) == 1 && len(eol.Any) == 0:
|
||||||
eol.Expression = eol.All[0]
|
eol.Expression = eol.All[0]
|
||||||
@@ -94,17 +94,17 @@ func (eol *ExpressionOrList) MarshalJSON() ([]byte, error) {
|
|||||||
return json.Marshal(string(eol.Expression))
|
return json.Marshal(string(eol.Expression))
|
||||||
}
|
}
|
||||||
|
|
||||||
type RawExpressionOrList ExpressionOrList
|
type RawExpressionOrList Config
|
||||||
val := RawExpressionOrList(*eol)
|
val := RawExpressionOrList(*eol)
|
||||||
return json.Marshal(val)
|
return json.Marshal(val)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (eol *ExpressionOrList) UnmarshalJSON(data []byte) error {
|
func (eol *Config) UnmarshalJSON(data []byte) error {
|
||||||
switch string(data[0]) {
|
switch string(data[0]) {
|
||||||
case `"`: // string
|
case `"`: // string
|
||||||
return json.Unmarshal(data, &eol.Expression)
|
return json.Unmarshal(data, &eol.Expression)
|
||||||
case "{": // object
|
case "{": // object
|
||||||
type RawExpressionOrList ExpressionOrList
|
type RawExpressionOrList Config
|
||||||
var val RawExpressionOrList
|
var val RawExpressionOrList
|
||||||
if err := json.Unmarshal(data, &val); err != nil {
|
if err := json.Unmarshal(data, &val); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -118,7 +118,7 @@ func (eol *ExpressionOrList) UnmarshalJSON(data []byte) error {
|
|||||||
return ErrExpressionOrListMustBeStringOrObject
|
return ErrExpressionOrListMustBeStringOrObject
|
||||||
}
|
}
|
||||||
|
|
||||||
func (eol *ExpressionOrList) Valid() error {
|
func (eol *Config) Valid() error {
|
||||||
if eol.Expression == "" && len(eol.All) == 0 && len(eol.Any) == 0 {
|
if eol.Expression == "" && len(eol.All) == 0 && len(eol.Any) == 0 {
|
||||||
return ErrExpressionEmpty
|
return ErrExpressionEmpty
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package expression
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@@ -12,13 +12,13 @@ import (
|
|||||||
func TestExpressionOrListMarshalJSON(t *testing.T) {
|
func TestExpressionOrListMarshalJSON(t *testing.T) {
|
||||||
for _, tt := range []struct {
|
for _, tt := range []struct {
|
||||||
name string
|
name string
|
||||||
input *ExpressionOrList
|
input *Config
|
||||||
output []byte
|
output []byte
|
||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "single expression",
|
name: "single expression",
|
||||||
input: &ExpressionOrList{
|
input: &Config{
|
||||||
Expression: "true",
|
Expression: "true",
|
||||||
},
|
},
|
||||||
output: []byte(`"true"`),
|
output: []byte(`"true"`),
|
||||||
@@ -26,7 +26,7 @@ func TestExpressionOrListMarshalJSON(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "all",
|
name: "all",
|
||||||
input: &ExpressionOrList{
|
input: &Config{
|
||||||
All: []string{"true", "true"},
|
All: []string{"true", "true"},
|
||||||
},
|
},
|
||||||
output: []byte(`{"all":["true","true"]}`),
|
output: []byte(`{"all":["true","true"]}`),
|
||||||
@@ -34,7 +34,7 @@ func TestExpressionOrListMarshalJSON(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "all one",
|
name: "all one",
|
||||||
input: &ExpressionOrList{
|
input: &Config{
|
||||||
All: []string{"true"},
|
All: []string{"true"},
|
||||||
},
|
},
|
||||||
output: []byte(`"true"`),
|
output: []byte(`"true"`),
|
||||||
@@ -42,7 +42,7 @@ func TestExpressionOrListMarshalJSON(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "any",
|
name: "any",
|
||||||
input: &ExpressionOrList{
|
input: &Config{
|
||||||
Any: []string{"true", "false"},
|
Any: []string{"true", "false"},
|
||||||
},
|
},
|
||||||
output: []byte(`{"any":["true","false"]}`),
|
output: []byte(`{"any":["true","false"]}`),
|
||||||
@@ -50,7 +50,7 @@ func TestExpressionOrListMarshalJSON(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "any one",
|
name: "any one",
|
||||||
input: &ExpressionOrList{
|
input: &Config{
|
||||||
Any: []string{"true"},
|
Any: []string{"true"},
|
||||||
},
|
},
|
||||||
output: []byte(`"true"`),
|
output: []byte(`"true"`),
|
||||||
@@ -75,13 +75,13 @@ func TestExpressionOrListMarshalJSON(t *testing.T) {
|
|||||||
func TestExpressionOrListMarshalYAML(t *testing.T) {
|
func TestExpressionOrListMarshalYAML(t *testing.T) {
|
||||||
for _, tt := range []struct {
|
for _, tt := range []struct {
|
||||||
name string
|
name string
|
||||||
input *ExpressionOrList
|
input *Config
|
||||||
output []byte
|
output []byte
|
||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "single expression",
|
name: "single expression",
|
||||||
input: &ExpressionOrList{
|
input: &Config{
|
||||||
Expression: "true",
|
Expression: "true",
|
||||||
},
|
},
|
||||||
output: []byte(`"true"`),
|
output: []byte(`"true"`),
|
||||||
@@ -89,7 +89,7 @@ func TestExpressionOrListMarshalYAML(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "all",
|
name: "all",
|
||||||
input: &ExpressionOrList{
|
input: &Config{
|
||||||
All: []string{"true", "true"},
|
All: []string{"true", "true"},
|
||||||
},
|
},
|
||||||
output: []byte(`all:
|
output: []byte(`all:
|
||||||
@@ -99,7 +99,7 @@ func TestExpressionOrListMarshalYAML(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "all one",
|
name: "all one",
|
||||||
input: &ExpressionOrList{
|
input: &Config{
|
||||||
All: []string{"true"},
|
All: []string{"true"},
|
||||||
},
|
},
|
||||||
output: []byte(`"true"`),
|
output: []byte(`"true"`),
|
||||||
@@ -107,7 +107,7 @@ func TestExpressionOrListMarshalYAML(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "any",
|
name: "any",
|
||||||
input: &ExpressionOrList{
|
input: &Config{
|
||||||
Any: []string{"true", "false"},
|
Any: []string{"true", "false"},
|
||||||
},
|
},
|
||||||
output: []byte(`any:
|
output: []byte(`any:
|
||||||
@@ -117,7 +117,7 @@ func TestExpressionOrListMarshalYAML(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "any one",
|
name: "any one",
|
||||||
input: &ExpressionOrList{
|
input: &Config{
|
||||||
Any: []string{"true"},
|
Any: []string{"true"},
|
||||||
},
|
},
|
||||||
output: []byte(`"true"`),
|
output: []byte(`"true"`),
|
||||||
@@ -145,14 +145,14 @@ func TestExpressionOrListUnmarshalJSON(t *testing.T) {
|
|||||||
for _, tt := range []struct {
|
for _, tt := range []struct {
|
||||||
err error
|
err error
|
||||||
validErr error
|
validErr error
|
||||||
result *ExpressionOrList
|
result *Config
|
||||||
name string
|
name string
|
||||||
inp string
|
inp string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "simple",
|
name: "simple",
|
||||||
inp: `"\"User-Agent\" in headers"`,
|
inp: `"\"User-Agent\" in headers"`,
|
||||||
result: &ExpressionOrList{
|
result: &Config{
|
||||||
Expression: `"User-Agent" in headers`,
|
Expression: `"User-Agent" in headers`,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -161,7 +161,7 @@ func TestExpressionOrListUnmarshalJSON(t *testing.T) {
|
|||||||
inp: `{
|
inp: `{
|
||||||
"all": ["\"User-Agent\" in headers"]
|
"all": ["\"User-Agent\" in headers"]
|
||||||
}`,
|
}`,
|
||||||
result: &ExpressionOrList{
|
result: &Config{
|
||||||
All: []string{
|
All: []string{
|
||||||
`"User-Agent" in headers`,
|
`"User-Agent" in headers`,
|
||||||
},
|
},
|
||||||
@@ -172,7 +172,7 @@ func TestExpressionOrListUnmarshalJSON(t *testing.T) {
|
|||||||
inp: `{
|
inp: `{
|
||||||
"any": ["\"User-Agent\" in headers"]
|
"any": ["\"User-Agent\" in headers"]
|
||||||
}`,
|
}`,
|
||||||
result: &ExpressionOrList{
|
result: &Config{
|
||||||
Any: []string{
|
Any: []string{
|
||||||
`"User-Agent" in headers`,
|
`"User-Agent" in headers`,
|
||||||
},
|
},
|
||||||
@@ -195,7 +195,7 @@ func TestExpressionOrListUnmarshalJSON(t *testing.T) {
|
|||||||
},
|
},
|
||||||
} {
|
} {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
var eol ExpressionOrList
|
var eol Config
|
||||||
|
|
||||||
if err := json.Unmarshal([]byte(tt.inp), &eol); !errors.Is(err, tt.err) {
|
if err := json.Unmarshal([]byte(tt.inp), &eol); !errors.Is(err, tt.err) {
|
||||||
t.Errorf("wanted unmarshal error: %v but got: %v", tt.err, err)
|
t.Errorf("wanted unmarshal error: %v but got: %v", tt.err, err)
|
||||||
@@ -217,40 +217,40 @@ func TestExpressionOrListUnmarshalJSON(t *testing.T) {
|
|||||||
func TestExpressionOrListString(t *testing.T) {
|
func TestExpressionOrListString(t *testing.T) {
|
||||||
for _, tt := range []struct {
|
for _, tt := range []struct {
|
||||||
name string
|
name string
|
||||||
in ExpressionOrList
|
in Config
|
||||||
out string
|
out string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "single expression",
|
name: "single expression",
|
||||||
in: ExpressionOrList{
|
in: Config{
|
||||||
Expression: "true",
|
Expression: "true",
|
||||||
},
|
},
|
||||||
out: "true",
|
out: "true",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "all",
|
name: "all",
|
||||||
in: ExpressionOrList{
|
in: Config{
|
||||||
All: []string{"true"},
|
All: []string{"true"},
|
||||||
},
|
},
|
||||||
out: "( true )",
|
out: "( true )",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "all with &&",
|
name: "all with &&",
|
||||||
in: ExpressionOrList{
|
in: Config{
|
||||||
All: []string{"true", "true"},
|
All: []string{"true", "true"},
|
||||||
},
|
},
|
||||||
out: "( true ) && ( true )",
|
out: "( true ) && ( true )",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "any",
|
name: "any",
|
||||||
in: ExpressionOrList{
|
in: Config{
|
||||||
All: []string{"true"},
|
All: []string{"true"},
|
||||||
},
|
},
|
||||||
out: "( true )",
|
out: "( true )",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "any with ||",
|
name: "any with ||",
|
||||||
in: ExpressionOrList{
|
in: Config{
|
||||||
Any: []string{"true", "true"},
|
Any: []string{"true", "true"},
|
||||||
},
|
},
|
||||||
out: "( true ) || ( true )",
|
out: "( true ) || ( true )",
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package expressions
|
package environment
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
@@ -10,11 +10,11 @@ import (
|
|||||||
"github.com/google/cel-go/ext"
|
"github.com/google/cel-go/ext"
|
||||||
)
|
)
|
||||||
|
|
||||||
// BotEnvironment creates a new CEL environment, this is the set of
|
// Bot creates a new CEL environment, this is the set of variables and
|
||||||
// variables and functions that are passed into the CEL scope so that
|
// functions that are passed into the CEL scope so that Anubis can fail
|
||||||
// Anubis can fail loudly and early when something is invalid instead
|
// loudly and early when something is invalid instead of blowing up at
|
||||||
// of blowing up at runtime.
|
// runtime.
|
||||||
func BotEnvironment() (*cel.Env, error) {
|
func Bot() (*cel.Env, error) {
|
||||||
return New(
|
return New(
|
||||||
// Variables exposed to CEL programs:
|
// Variables exposed to CEL programs:
|
||||||
cel.Variable("remoteAddress", cel.StringType),
|
cel.Variable("remoteAddress", cel.StringType),
|
||||||
@@ -57,13 +57,14 @@ func BotEnvironment() (*cel.Env, error) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewThreshold creates a new CEL environment for threshold checking.
|
// Threshold creates a new CEL environment for threshold checking.
|
||||||
func ThresholdEnvironment() (*cel.Env, error) {
|
func Threshold() (*cel.Env, error) {
|
||||||
return New(
|
return New(
|
||||||
cel.Variable("weight", cel.IntType),
|
cel.Variable("weight", cel.IntType),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// New creates a new base CEL environment.
|
||||||
func New(opts ...cel.EnvOption) (*cel.Env, error) {
|
func New(opts ...cel.EnvOption) (*cel.Env, error) {
|
||||||
args := []cel.EnvOption{
|
args := []cel.EnvOption{
|
||||||
ext.Strings(
|
ext.Strings(
|
||||||
@@ -95,7 +96,7 @@ func New(opts ...cel.EnvOption) (*cel.Env, error) {
|
|||||||
return cel.NewEnv(args...)
|
return cel.NewEnv(args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compile takes CEL environment and syntax tree then emits an optimized
|
// Compile takes a CEL environment and syntax tree then emits an optimized
|
||||||
// Program for execution.
|
// Program for execution.
|
||||||
func Compile(env *cel.Env, src string) (cel.Program, error) {
|
func Compile(env *cel.Env, src string) (cel.Program, error) {
|
||||||
intermediate, iss := env.Compile(src)
|
intermediate, iss := env.Compile(src)
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package expressions
|
package environment
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
@@ -6,8 +6,8 @@ import (
|
|||||||
"github.com/google/cel-go/common/types"
|
"github.com/google/cel-go/common/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBotEnvironment(t *testing.T) {
|
func TestBot(t *testing.T) {
|
||||||
env, err := BotEnvironment()
|
env, err := Bot()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create bot environment: %v", err)
|
t.Fatalf("failed to create bot environment: %v", err)
|
||||||
}
|
}
|
||||||
@@ -108,8 +108,8 @@ func TestBotEnvironment(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestThresholdEnvironment(t *testing.T) {
|
func TestThreshold(t *testing.T) {
|
||||||
env, err := ThresholdEnvironment()
|
env, err := Threshold()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create threshold environment: %v", err)
|
t.Fatalf("failed to create threshold environment: %v", err)
|
||||||
}
|
}
|
||||||
43
lib/checker/expression/factory.go
Normal file
43
lib/checker/expression/factory.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package expression
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
checker.Register("expression", Factory{})
|
||||||
|
}
|
||||||
|
|
||||||
|
type Factory struct{}
|
||||||
|
|
||||||
|
func (f Factory) Build(ctx context.Context, data json.RawMessage) (checker.Interface, error) {
|
||||||
|
var fc = &Config{}
|
||||||
|
|
||||||
|
if err := json.Unmarshal([]byte(data), fc); err != nil {
|
||||||
|
return nil, errors.Join(checker.ErrUnparseableConfig, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := fc.Valid(); err != nil {
|
||||||
|
return nil, errors.Join(checker.ErrInvalidConfig, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return New(fc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Factory) Valid(ctx context.Context, data json.RawMessage) error {
|
||||||
|
var fc = &Config{}
|
||||||
|
|
||||||
|
if err := json.Unmarshal([]byte(data), fc); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := fc.Valid(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package expressions
|
package expression
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package expressions
|
package expression
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package expressions
|
package expression
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package expressions
|
package expression
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package expressions
|
package expression
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/url"
|
"net/url"
|
||||||
32
lib/checker/headerexists/checker.go
Normal file
32
lib/checker/headerexists/checker.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package headerexists
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/TecharoHQ/anubis/internal"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker"
|
||||||
|
)
|
||||||
|
|
||||||
|
func New(key string) checker.Interface {
|
||||||
|
return headerExistsChecker{
|
||||||
|
header: strings.TrimSpace(http.CanonicalHeaderKey(key)),
|
||||||
|
hash: internal.FastHash(key),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type headerExistsChecker struct {
|
||||||
|
header, hash string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hec headerExistsChecker) Check(r *http.Request) (bool, error) {
|
||||||
|
if r.Header.Get(hec.header) != "" {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hec headerExistsChecker) Hash() string {
|
||||||
|
return hec.hash
|
||||||
|
}
|
||||||
57
lib/checker/headerexists/checker_test.go
Normal file
57
lib/checker/headerexists/checker_test.go
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
package headerexists
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestChecker(t *testing.T) {
|
||||||
|
fac := Factory{}
|
||||||
|
|
||||||
|
for _, tt := range []struct {
|
||||||
|
name string
|
||||||
|
header string
|
||||||
|
reqHeader string
|
||||||
|
ok bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "match",
|
||||||
|
header: "Authorization",
|
||||||
|
reqHeader: "Authorization",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not_match",
|
||||||
|
header: "Authorization",
|
||||||
|
reqHeader: "Authentication",
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
hec, err := fac.Build(t.Context(), json.RawMessage(fmt.Sprintf("%q", tt.header)))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log(hec.Hash())
|
||||||
|
|
||||||
|
r, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("can't make request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Header.Set(tt.reqHeader, "hunter2")
|
||||||
|
|
||||||
|
ok, err := hec.Check(r)
|
||||||
|
|
||||||
|
if tt.ok != ok {
|
||||||
|
t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("err: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
40
lib/checker/headerexists/factory.go
Normal file
40
lib/checker/headerexists/factory.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package headerexists
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Factory struct{}
|
||||||
|
|
||||||
|
func (f Factory) Build(ctx context.Context, data json.RawMessage) (checker.Interface, error) {
|
||||||
|
var headerName string
|
||||||
|
|
||||||
|
if err := json.Unmarshal([]byte(data), &headerName); err != nil {
|
||||||
|
return nil, fmt.Errorf("%w: want string", checker.ErrUnparseableConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := f.Valid(ctx, data); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return New(http.CanonicalHeaderKey(headerName)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Factory) Valid(ctx context.Context, data json.RawMessage) error {
|
||||||
|
var headerName string
|
||||||
|
|
||||||
|
if err := json.Unmarshal([]byte(data), &headerName); err != nil {
|
||||||
|
return fmt.Errorf("%w: want string", checker.ErrUnparseableConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
if headerName == "" {
|
||||||
|
return fmt.Errorf("%w: string must not be empty", checker.ErrInvalidConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
60
lib/checker/headerexists/factory_test.go
Normal file
60
lib/checker/headerexists/factory_test.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
package headerexists
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFactoryGood(t *testing.T) {
|
||||||
|
files, err := os.ReadDir("./testdata/good")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fac := Factory{}
|
||||||
|
|
||||||
|
for _, fname := range files {
|
||||||
|
t.Run(fname.Name(), func(t *testing.T) {
|
||||||
|
data, err := os.ReadFile(filepath.Join("testdata", "good", fname.Name()))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := fac.Valid(t.Context(), json.RawMessage(data)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFactoryBad(t *testing.T) {
|
||||||
|
files, err := os.ReadDir("./testdata/bad")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fac := Factory{}
|
||||||
|
|
||||||
|
for _, fname := range files {
|
||||||
|
t.Run(fname.Name(), func(t *testing.T) {
|
||||||
|
data, err := os.ReadFile(filepath.Join("testdata", "bad", fname.Name()))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Build", func(t *testing.T) {
|
||||||
|
if _, err := fac.Build(t.Context(), json.RawMessage(data)); err == nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Valid", func(t *testing.T) {
|
||||||
|
if err := fac.Valid(t.Context(), json.RawMessage(data)); err == nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
1
lib/checker/headerexists/testdata/bad/empty.json
vendored
Normal file
1
lib/checker/headerexists/testdata/bad/empty.json
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
""
|
||||||
1
lib/checker/headerexists/testdata/bad/object.json
vendored
Normal file
1
lib/checker/headerexists/testdata/bad/object.json
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{}
|
||||||
1
lib/checker/headerexists/testdata/good/authorization.json
vendored
Normal file
1
lib/checker/headerexists/testdata/good/authorization.json
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"Authorization"
|
||||||
46
lib/checker/headermatches/checker.go
Normal file
46
lib/checker/headermatches/checker.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package headermatches
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Checker struct {
|
||||||
|
header string
|
||||||
|
regexp *regexp.Regexp
|
||||||
|
hash string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Checker) Check(r *http.Request) (bool, error) {
|
||||||
|
if c.regexp.MatchString(r.Header.Get(c.header)) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Checker) Hash() string {
|
||||||
|
return c.hash
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(key, valueRex string) (checker.Interface, error) {
|
||||||
|
fc := fileConfig{
|
||||||
|
Header: key,
|
||||||
|
ValueRegex: valueRex,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := fc.Valid(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(fc)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return Factory{}.Build(context.Background(), json.RawMessage(data))
|
||||||
|
}
|
||||||
98
lib/checker/headermatches/checker_test.go
Normal file
98
lib/checker/headermatches/checker_test.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package headermatches
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestChecker(t *testing.T) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeaderMatchesChecker(t *testing.T) {
|
||||||
|
fac := Factory{}
|
||||||
|
|
||||||
|
for _, tt := range []struct {
|
||||||
|
err error
|
||||||
|
name string
|
||||||
|
header string
|
||||||
|
rexStr string
|
||||||
|
reqHeaderKey string
|
||||||
|
reqHeaderValue string
|
||||||
|
ok bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "match",
|
||||||
|
header: "Cf-Worker",
|
||||||
|
rexStr: ".*",
|
||||||
|
reqHeaderKey: "Cf-Worker",
|
||||||
|
reqHeaderValue: "true",
|
||||||
|
ok: true,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not_match",
|
||||||
|
header: "Cf-Worker",
|
||||||
|
rexStr: "false",
|
||||||
|
reqHeaderKey: "Cf-Worker",
|
||||||
|
reqHeaderValue: "true",
|
||||||
|
ok: false,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not_present",
|
||||||
|
header: "Cf-Worker",
|
||||||
|
rexStr: "foobar",
|
||||||
|
reqHeaderKey: "Something-Else",
|
||||||
|
reqHeaderValue: "true",
|
||||||
|
ok: false,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid_regex",
|
||||||
|
rexStr: "a(b",
|
||||||
|
err: ErrInvalidRegex,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
fc := fileConfig{
|
||||||
|
Header: tt.header,
|
||||||
|
ValueRegex: tt.rexStr,
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(fc)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
hmc, err := fac.Build(t.Context(), json.RawMessage(data))
|
||||||
|
if err != nil && !errors.Is(err, tt.err) {
|
||||||
|
t.Fatalf("creating HeaderMatchesChecker failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.err != nil && hmc == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log(hmc.Hash())
|
||||||
|
|
||||||
|
r, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("can't make request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Header.Set(tt.reqHeaderKey, tt.reqHeaderValue)
|
||||||
|
|
||||||
|
ok, err := hmc.Check(r)
|
||||||
|
|
||||||
|
if tt.ok != ok {
|
||||||
|
t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil && tt.err != nil && !errors.Is(err, tt.err) {
|
||||||
|
t.Errorf("err: %v, wanted: %v", err, tt.err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
44
lib/checker/headermatches/config.go
Normal file
44
lib/checker/headermatches/config.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package headermatches
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrNoHeader = errors.New("headermatches: no header is configured")
|
||||||
|
ErrNoValueRegex = errors.New("headermatches: no value regex is configured")
|
||||||
|
ErrInvalidRegex = errors.New("headermatches: value regex is invalid")
|
||||||
|
)
|
||||||
|
|
||||||
|
type fileConfig struct {
|
||||||
|
Header string `json:"header" yaml:"header"`
|
||||||
|
ValueRegex string `json:"value_regex" yaml:"value_regex"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fc fileConfig) String() string {
|
||||||
|
return fmt.Sprintf("header=%q value_regex=%q", fc.Header, fc.ValueRegex)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fc fileConfig) Valid() error {
|
||||||
|
var errs []error
|
||||||
|
|
||||||
|
if fc.Header == "" {
|
||||||
|
errs = append(errs, ErrNoHeader)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fc.ValueRegex == "" {
|
||||||
|
errs = append(errs, ErrNoValueRegex)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := regexp.Compile(fc.ValueRegex); err != nil {
|
||||||
|
errs = append(errs, ErrInvalidRegex, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errs) != 0 {
|
||||||
|
return errors.Join(errs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
55
lib/checker/headermatches/config_test.go
Normal file
55
lib/checker/headermatches/config_test.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package headermatches
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFileConfigValid(t *testing.T) {
|
||||||
|
for _, tt := range []struct {
|
||||||
|
name, description string
|
||||||
|
in fileConfig
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple happy",
|
||||||
|
description: "the most common usecase",
|
||||||
|
in: fileConfig{
|
||||||
|
Header: "User-Agent",
|
||||||
|
ValueRegex: ".*",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no header",
|
||||||
|
description: "Header must be set, it is not",
|
||||||
|
in: fileConfig{
|
||||||
|
ValueRegex: ".*",
|
||||||
|
},
|
||||||
|
err: ErrNoHeader,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no value regex",
|
||||||
|
description: "ValueRegex must be set, it is not",
|
||||||
|
in: fileConfig{
|
||||||
|
Header: "User-Agent",
|
||||||
|
},
|
||||||
|
err: ErrNoValueRegex,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid regex",
|
||||||
|
description: "the user wrote an invalid value regular expression",
|
||||||
|
in: fileConfig{
|
||||||
|
Header: "User-Agent",
|
||||||
|
ValueRegex: "[a-z",
|
||||||
|
},
|
||||||
|
err: ErrInvalidRegex,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := tt.in.Valid(); !errors.Is(err, tt.err) {
|
||||||
|
t.Log(tt.description)
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
66
lib/checker/headermatches/factory.go
Normal file
66
lib/checker/headermatches/factory.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package headermatches
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
|
||||||
|
"github.com/TecharoHQ/anubis/internal"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
checker.Register("header_matches", Factory{})
|
||||||
|
checker.Register("user_agent", Factory{defaultHeader: "User-Agent"})
|
||||||
|
}
|
||||||
|
|
||||||
|
type Factory struct {
|
||||||
|
defaultHeader string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Factory) Build(ctx context.Context, data json.RawMessage) (checker.Interface, error) {
|
||||||
|
var fc fileConfig
|
||||||
|
|
||||||
|
if f.defaultHeader != "" {
|
||||||
|
fc.Header = f.defaultHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal([]byte(data), &fc); err != nil {
|
||||||
|
return nil, errors.Join(checker.ErrUnparseableConfig, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := fc.Valid(); err != nil {
|
||||||
|
return nil, errors.Join(checker.ErrInvalidConfig, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
valueRex, err := regexp.Compile(fc.ValueRegex)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Join(ErrInvalidRegex, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Checker{
|
||||||
|
header: http.CanonicalHeaderKey(fc.Header),
|
||||||
|
regexp: valueRex,
|
||||||
|
hash: internal.FastHash(fc.String()),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Factory) Valid(ctx context.Context, data json.RawMessage) error {
|
||||||
|
var fc fileConfig
|
||||||
|
|
||||||
|
if f.defaultHeader != "" {
|
||||||
|
fc.Header = f.defaultHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal([]byte(data), &fc); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := fc.Valid(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
52
lib/checker/headermatches/factory_test.go
Normal file
52
lib/checker/headermatches/factory_test.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package headermatches
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFactoryGood(t *testing.T) {
|
||||||
|
files, err := os.ReadDir("./testdata/good")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fac := Factory{}
|
||||||
|
|
||||||
|
for _, fname := range files {
|
||||||
|
t.Run(fname.Name(), func(t *testing.T) {
|
||||||
|
data, err := os.ReadFile(filepath.Join("testdata", "good", fname.Name()))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := fac.Valid(t.Context(), json.RawMessage(data)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFactoryBad(t *testing.T) {
|
||||||
|
files, err := os.ReadDir("./testdata/bad")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fac := Factory{}
|
||||||
|
|
||||||
|
for _, fname := range files {
|
||||||
|
t.Run(fname.Name(), func(t *testing.T) {
|
||||||
|
data, err := os.ReadFile(filepath.Join("testdata", "bad", fname.Name()))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := fac.Valid(t.Context(), json.RawMessage(data)); err == nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
1
lib/checker/headermatches/testdata/bad/invalid_config.json
vendored
Normal file
1
lib/checker/headermatches/testdata/bad/invalid_config.json
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
}
|
||||||
4
lib/checker/headermatches/testdata/bad/invalid_value_regex.json
vendored
Normal file
4
lib/checker/headermatches/testdata/bad/invalid_value_regex.json
vendored
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
{
|
||||||
|
"header": "User-Agent",
|
||||||
|
"value_regex": "a(b"
|
||||||
|
}
|
||||||
3
lib/checker/headermatches/testdata/bad/no_header.json
vendored
Normal file
3
lib/checker/headermatches/testdata/bad/no_header.json
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
{
|
||||||
|
"value_regex": "PaleMoon"
|
||||||
|
}
|
||||||
3
lib/checker/headermatches/testdata/bad/no_value_regex.json
vendored
Normal file
3
lib/checker/headermatches/testdata/bad/no_value_regex.json
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
{
|
||||||
|
"header": "User-Agent"
|
||||||
|
}
|
||||||
1
lib/checker/headermatches/testdata/bad/nothing.json
vendored
Normal file
1
lib/checker/headermatches/testdata/bad/nothing.json
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{}
|
||||||
4
lib/checker/headermatches/testdata/good/simple.json
vendored
Normal file
4
lib/checker/headermatches/testdata/good/simple.json
vendored
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
{
|
||||||
|
"header": "User-Agent",
|
||||||
|
"value_regex": "PaleMoon"
|
||||||
|
}
|
||||||
35
lib/checker/headermatches/useragent.go
Normal file
35
lib/checker/headermatches/useragent.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package headermatches
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ValidUserAgent(valueRex string) error {
|
||||||
|
fc := fileConfig{
|
||||||
|
Header: "User-Agent",
|
||||||
|
ValueRegex: valueRex,
|
||||||
|
}
|
||||||
|
|
||||||
|
return fc.Valid()
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUserAgent(valueRex string) (checker.Interface, error) {
|
||||||
|
fc := fileConfig{
|
||||||
|
Header: "User-Agent",
|
||||||
|
ValueRegex: valueRex,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := fc.Valid(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(fc)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return Factory{}.Build(context.Background(), json.RawMessage(data))
|
||||||
|
}
|
||||||
37
lib/checker/path/checker.go
Normal file
37
lib/checker/path/checker.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package path
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/TecharoHQ/anubis"
|
||||||
|
"github.com/TecharoHQ/anubis/internal"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker"
|
||||||
|
)
|
||||||
|
|
||||||
|
func New(rexStr string) (checker.Interface, error) {
|
||||||
|
rex, err := regexp.Compile(strings.TrimSpace(rexStr))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("%w: regex %s failed parse: %w", anubis.ErrMisconfiguration, rexStr, err)
|
||||||
|
}
|
||||||
|
return &Checker{rex, internal.FastHash(rexStr)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Checker struct {
|
||||||
|
regexp *regexp.Regexp
|
||||||
|
hash string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Checker) Check(r *http.Request) (bool, error) {
|
||||||
|
if c.regexp.MatchString(r.URL.Path) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Checker) Hash() string {
|
||||||
|
return c.hash
|
||||||
|
}
|
||||||
90
lib/checker/path/checker_test.go
Normal file
90
lib/checker/path/checker_test.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package path
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestChecker(t *testing.T) {
|
||||||
|
fac := Factory{}
|
||||||
|
|
||||||
|
for _, tt := range []struct {
|
||||||
|
err error
|
||||||
|
name string
|
||||||
|
rexStr string
|
||||||
|
reqPath string
|
||||||
|
ok bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "match",
|
||||||
|
rexStr: "^/api/.*",
|
||||||
|
reqPath: "/api/v1/users",
|
||||||
|
ok: true,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not_match",
|
||||||
|
rexStr: "^/api/.*",
|
||||||
|
reqPath: "/static/index.html",
|
||||||
|
ok: false,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard_match",
|
||||||
|
rexStr: ".*\\.json$",
|
||||||
|
reqPath: "/data/config.json",
|
||||||
|
ok: true,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard_not_match",
|
||||||
|
rexStr: ".*\\.json$",
|
||||||
|
reqPath: "/data/config.yaml",
|
||||||
|
ok: false,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid_regex",
|
||||||
|
rexStr: "a(b",
|
||||||
|
err: ErrInvalidRegex,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
fc := fileConfig{
|
||||||
|
Regex: tt.rexStr,
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(fc)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pc, err := fac.Build(t.Context(), json.RawMessage(data))
|
||||||
|
if err != nil && !errors.Is(err, tt.err) {
|
||||||
|
t.Fatalf("creating PathChecker failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.err != nil && pc == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log(pc.Hash())
|
||||||
|
|
||||||
|
r, err := http.NewRequest(http.MethodGet, tt.reqPath, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("can't make request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := pc.Check(r)
|
||||||
|
|
||||||
|
if tt.ok != ok {
|
||||||
|
t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil && tt.err != nil && !errors.Is(err, tt.err) {
|
||||||
|
t.Errorf("err: %v, wanted: %v", err, tt.err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
38
lib/checker/path/config.go
Normal file
38
lib/checker/path/config.go
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
package path
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrNoRegex = errors.New("path: no regex is configured")
|
||||||
|
ErrInvalidRegex = errors.New("path: regex is invalid")
|
||||||
|
)
|
||||||
|
|
||||||
|
type fileConfig struct {
|
||||||
|
Regex string `json:"regex" yaml:"regex"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fc fileConfig) String() string {
|
||||||
|
return fmt.Sprintf("regex=%q", fc.Regex)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fc fileConfig) Valid() error {
|
||||||
|
var errs []error
|
||||||
|
|
||||||
|
if fc.Regex == "" {
|
||||||
|
errs = append(errs, ErrNoRegex)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := regexp.Compile(fc.Regex); err != nil {
|
||||||
|
errs = append(errs, ErrInvalidRegex, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errs) != 0 {
|
||||||
|
return errors.Join(errs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
50
lib/checker/path/config_test.go
Normal file
50
lib/checker/path/config_test.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
package path
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFileConfigValid(t *testing.T) {
|
||||||
|
for _, tt := range []struct {
|
||||||
|
name, description string
|
||||||
|
in fileConfig
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple happy",
|
||||||
|
description: "the most common usecase",
|
||||||
|
in: fileConfig{
|
||||||
|
Regex: "^/api/.*",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard match",
|
||||||
|
description: "match files with specific extension",
|
||||||
|
in: fileConfig{
|
||||||
|
Regex: ".*[.]json$",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no regex",
|
||||||
|
description: "Regex must be set, it is not",
|
||||||
|
in: fileConfig{},
|
||||||
|
err: ErrNoRegex,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid regex",
|
||||||
|
description: "the user wrote an invalid regular expression",
|
||||||
|
in: fileConfig{
|
||||||
|
Regex: "[a-z",
|
||||||
|
},
|
||||||
|
err: ErrInvalidRegex,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := tt.in.Valid(); !errors.Is(err, tt.err) {
|
||||||
|
t.Log(tt.description)
|
||||||
|
t.Fatalf("got %v, wanted %v", err, tt.err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
58
lib/checker/path/factory.go
Normal file
58
lib/checker/path/factory.go
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
package path
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/TecharoHQ/anubis/internal"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
checker.Register("path", Factory{})
|
||||||
|
}
|
||||||
|
|
||||||
|
type Factory struct{}
|
||||||
|
|
||||||
|
func (f Factory) Build(ctx context.Context, data json.RawMessage) (checker.Interface, error) {
|
||||||
|
var fc fileConfig
|
||||||
|
|
||||||
|
if err := json.Unmarshal([]byte(data), &fc); err != nil {
|
||||||
|
return nil, errors.Join(checker.ErrUnparseableConfig, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := fc.Valid(); err != nil {
|
||||||
|
return nil, errors.Join(checker.ErrInvalidConfig, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pathRex, err := regexp.Compile(strings.TrimSpace(fc.Regex))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Join(ErrInvalidRegex, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Checker{
|
||||||
|
regexp: pathRex,
|
||||||
|
hash: internal.FastHash(fc.String()),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Factory) Valid(ctx context.Context, data json.RawMessage) error {
|
||||||
|
var fc fileConfig
|
||||||
|
|
||||||
|
if err := json.Unmarshal([]byte(data), &fc); err != nil {
|
||||||
|
return errors.Join(checker.ErrUnparseableConfig, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fc.Valid()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Valid(pathRex string) error {
|
||||||
|
fc := fileConfig{
|
||||||
|
Regex: pathRex,
|
||||||
|
}
|
||||||
|
|
||||||
|
return fc.Valid()
|
||||||
|
}
|
||||||
52
lib/checker/path/factory_test.go
Normal file
52
lib/checker/path/factory_test.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package path
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFactoryGood(t *testing.T) {
|
||||||
|
files, err := os.ReadDir("./testdata/good")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fac := Factory{}
|
||||||
|
|
||||||
|
for _, fname := range files {
|
||||||
|
t.Run(fname.Name(), func(t *testing.T) {
|
||||||
|
data, err := os.ReadFile(filepath.Join("testdata", "good", fname.Name()))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := fac.Valid(t.Context(), json.RawMessage(data)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFactoryBad(t *testing.T) {
|
||||||
|
files, err := os.ReadDir("./testdata/bad")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fac := Factory{}
|
||||||
|
|
||||||
|
for _, fname := range files {
|
||||||
|
t.Run(fname.Name(), func(t *testing.T) {
|
||||||
|
data, err := os.ReadFile(filepath.Join("testdata", "bad", fname.Name()))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := fac.Valid(t.Context(), json.RawMessage(data)); err == nil {
|
||||||
|
t.Fatal("expected validation to fail")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
3
lib/checker/path/testdata/bad/invalid_regex.json
vendored
Normal file
3
lib/checker/path/testdata/bad/invalid_regex.json
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
{
|
||||||
|
"regex": "a(b"
|
||||||
|
}
|
||||||
1
lib/checker/path/testdata/bad/nothing.json
vendored
Normal file
1
lib/checker/path/testdata/bad/nothing.json
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{}
|
||||||
3
lib/checker/path/testdata/good/simple.json
vendored
Normal file
3
lib/checker/path/testdata/good/simple.json
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
{
|
||||||
|
"regex": "^/api/.*"
|
||||||
|
}
|
||||||
3
lib/checker/path/testdata/good/wildcard.json
vendored
Normal file
3
lib/checker/path/testdata/good/wildcard.json
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
{
|
||||||
|
"regex": ".*\\.json$"
|
||||||
|
}
|
||||||
43
lib/checker/registry.go
Normal file
43
lib/checker/registry.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package checker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Factory interface {
|
||||||
|
Build(context.Context, json.RawMessage) (Interface, error)
|
||||||
|
Valid(context.Context, json.RawMessage) error
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
registry map[string]Factory = map[string]Factory{}
|
||||||
|
regLock sync.RWMutex
|
||||||
|
)
|
||||||
|
|
||||||
|
func Register(name string, factory Factory) {
|
||||||
|
regLock.Lock()
|
||||||
|
defer regLock.Unlock()
|
||||||
|
|
||||||
|
registry[name] = factory
|
||||||
|
}
|
||||||
|
|
||||||
|
func Get(name string) (Factory, bool) {
|
||||||
|
regLock.RLock()
|
||||||
|
defer regLock.RUnlock()
|
||||||
|
result, ok := registry[name]
|
||||||
|
return result, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func Methods() []string {
|
||||||
|
regLock.RLock()
|
||||||
|
defer regLock.RUnlock()
|
||||||
|
var result []string
|
||||||
|
for method := range registry {
|
||||||
|
result = append(result, method)
|
||||||
|
}
|
||||||
|
sort.Strings(result)
|
||||||
|
return result
|
||||||
|
}
|
||||||
127
lib/checker/remoteaddress/remoteaddress.go
Normal file
127
lib/checker/remoteaddress/remoteaddress.go
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
package remoteaddress
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/TecharoHQ/anubis"
|
||||||
|
"github.com/TecharoHQ/anubis/internal"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker"
|
||||||
|
"github.com/gaissmai/bart"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrNoRemoteAddresses = errors.New("remoteaddress: no remote addresses defined")
|
||||||
|
ErrInvalidCIDR = errors.New("remoteaddress: invalid CIDR")
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
checker.Register("remote_address", Factory{})
|
||||||
|
}
|
||||||
|
|
||||||
|
type Factory struct{}
|
||||||
|
|
||||||
|
func (Factory) Valid(_ context.Context, inp json.RawMessage) error {
|
||||||
|
var fc fileConfig
|
||||||
|
if err := json.Unmarshal([]byte(inp), &fc); err != nil {
|
||||||
|
return fmt.Errorf("%w: %w", checker.ErrUnparseableConfig, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := fc.Valid(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Factory) Build(_ context.Context, inp json.RawMessage) (checker.Interface, error) {
|
||||||
|
c := struct {
|
||||||
|
RemoteAddr []netip.Prefix `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"`
|
||||||
|
}{}
|
||||||
|
|
||||||
|
if err := json.Unmarshal([]byte(inp), &c); err != nil {
|
||||||
|
return nil, fmt.Errorf("%w: %w", checker.ErrUnparseableConfig, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
table := new(bart.Lite)
|
||||||
|
|
||||||
|
for _, cidr := range c.RemoteAddr {
|
||||||
|
table.Insert(cidr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RemoteAddrChecker{
|
||||||
|
prefixTable: table,
|
||||||
|
hash: internal.FastHash(string(inp)),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type fileConfig struct {
|
||||||
|
RemoteAddr []string `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fc fileConfig) Valid() error {
|
||||||
|
var errs []error
|
||||||
|
|
||||||
|
if len(fc.RemoteAddr) == 0 {
|
||||||
|
errs = append(errs, ErrNoRemoteAddresses)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, cidr := range fc.RemoteAddr {
|
||||||
|
if _, err := netip.ParsePrefix(cidr); err != nil {
|
||||||
|
errs = append(errs, fmt.Errorf("%w: cidr %q is invalid: %w", ErrInvalidCIDR, cidr, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errs) != 0 {
|
||||||
|
return fmt.Errorf("%w: %w", checker.ErrInvalidConfig, errors.Join(errs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Valid(cidrs []string) error {
|
||||||
|
fc := fileConfig{
|
||||||
|
RemoteAddr: cidrs,
|
||||||
|
}
|
||||||
|
|
||||||
|
return fc.Valid()
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(cidrs []string) (checker.Interface, error) {
|
||||||
|
fc := fileConfig{
|
||||||
|
RemoteAddr: cidrs,
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(fc)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return Factory{}.Build(context.Background(), json.RawMessage(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
type RemoteAddrChecker struct {
|
||||||
|
prefixTable *bart.Lite
|
||||||
|
hash string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rac *RemoteAddrChecker) Check(r *http.Request) (bool, error) {
|
||||||
|
host := r.Header.Get("X-Real-Ip")
|
||||||
|
if host == "" {
|
||||||
|
return false, fmt.Errorf("%w: header X-Real-Ip is not set", anubis.ErrMisconfiguration)
|
||||||
|
}
|
||||||
|
|
||||||
|
addr, err := netip.ParseAddr(host)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("%w: %s is not an IP address: %w", anubis.ErrMisconfiguration, host, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rac.prefixTable.Contains(addr), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rac *RemoteAddrChecker) Hash() string {
|
||||||
|
return rac.hash
|
||||||
|
}
|
||||||
138
lib/checker/remoteaddress/remoteaddress_test.go
Normal file
138
lib/checker/remoteaddress/remoteaddress_test.go
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
package remoteaddress_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "embed"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker/remoteaddress"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFactoryIsCheckerFactory(t *testing.T) {
|
||||||
|
if _, ok := (any(remoteaddress.Factory{})).(checker.Factory); !ok {
|
||||||
|
t.Fatal("Factory is not an instance of checker.Factory")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFactoryValidateConfig(t *testing.T) {
|
||||||
|
f := remoteaddress.Factory{}
|
||||||
|
|
||||||
|
for _, tt := range []struct {
|
||||||
|
name string
|
||||||
|
data []byte
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "basic valid",
|
||||||
|
data: []byte(`{
|
||||||
|
"remote_addresses": [
|
||||||
|
"1.1.1.1/32"
|
||||||
|
]
|
||||||
|
}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not json",
|
||||||
|
data: []byte(`]`),
|
||||||
|
err: checker.ErrUnparseableConfig,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no cidr",
|
||||||
|
data: []byte(`{
|
||||||
|
"remote_addresses": []
|
||||||
|
}`),
|
||||||
|
err: remoteaddress.ErrNoRemoteAddresses,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bad cidr",
|
||||||
|
data: []byte(`{
|
||||||
|
"remote_addresses": [
|
||||||
|
"according to all laws of aviation"
|
||||||
|
]
|
||||||
|
}`),
|
||||||
|
err: remoteaddress.ErrInvalidCIDR,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
data := json.RawMessage(tt.data)
|
||||||
|
|
||||||
|
if err := f.Valid(t.Context(), data); !errors.Is(err, tt.err) {
|
||||||
|
t.Logf("want: %v", tt.err)
|
||||||
|
t.Logf("got: %v", err)
|
||||||
|
t.Fatal("validation didn't do what was expected")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFactoryCreate(t *testing.T) {
|
||||||
|
f := remoteaddress.Factory{}
|
||||||
|
|
||||||
|
for _, tt := range []struct {
|
||||||
|
name string
|
||||||
|
data []byte
|
||||||
|
err error
|
||||||
|
ip string
|
||||||
|
match bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "basic valid",
|
||||||
|
data: []byte(`{
|
||||||
|
"remote_addresses": [
|
||||||
|
"1.1.1.1/32"
|
||||||
|
]
|
||||||
|
}`),
|
||||||
|
ip: "1.1.1.1",
|
||||||
|
match: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bad cidr",
|
||||||
|
data: []byte(`{
|
||||||
|
"remote_addresses": [
|
||||||
|
"according to all laws of aviation"
|
||||||
|
]
|
||||||
|
}`),
|
||||||
|
err: checker.ErrUnparseableConfig,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
data := json.RawMessage(tt.data)
|
||||||
|
|
||||||
|
impl, err := f.Build(t.Context(), data)
|
||||||
|
if !errors.Is(err, tt.err) {
|
||||||
|
t.Logf("want: %v", tt.err)
|
||||||
|
t.Logf("got: %v", err)
|
||||||
|
t.Fatal("creation didn't do what was expected")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("can't make request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.ip != "" {
|
||||||
|
r.Header.Add("X-Real-Ip", tt.ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
match, err := impl.Check(r)
|
||||||
|
|
||||||
|
if tt.match != match {
|
||||||
|
t.Errorf("match: %v, wanted: %v", match, tt.match)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil && tt.err != nil && !errors.Is(err, tt.err) {
|
||||||
|
t.Errorf("err: %v, wanted: %v", err, tt.err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if impl.Hash() == "" {
|
||||||
|
t.Error("hash method returns empty string")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
5
lib/checker/remoteaddress/testdata/invalid_bad_cidr.json
vendored
Normal file
5
lib/checker/remoteaddress/testdata/invalid_bad_cidr.json
vendored
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"remote_addresses": [
|
||||||
|
"according to all laws of aviation"
|
||||||
|
]
|
||||||
|
}
|
||||||
3
lib/checker/remoteaddress/testdata/invalid_no_cidr.json
vendored
Normal file
3
lib/checker/remoteaddress/testdata/invalid_no_cidr.json
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
{
|
||||||
|
"remote_addresses": []
|
||||||
|
}
|
||||||
1
lib/checker/remoteaddress/testdata/invalid_not_json.json
vendored
Normal file
1
lib/checker/remoteaddress/testdata/invalid_not_json.json
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
]
|
||||||
5
lib/checker/remoteaddress/testdata/valid_addresses.json
vendored
Normal file
5
lib/checker/remoteaddress/testdata/valid_addresses.json
vendored
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"remote_addresses": [
|
||||||
|
"1.1.1.1/32"
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -4,12 +4,12 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/TecharoHQ/anubis/internal"
|
"github.com/TecharoHQ/anubis/internal"
|
||||||
"github.com/TecharoHQ/anubis/lib/policy/checker"
|
"github.com/TecharoHQ/anubis/lib/checker"
|
||||||
"github.com/TecharoHQ/anubis/lib/policy/config"
|
"github.com/TecharoHQ/anubis/lib/policy/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Bot struct {
|
type Bot struct {
|
||||||
Rules checker.Impl
|
Rules checker.Interface
|
||||||
Challenge *config.ChallengeRules
|
Challenge *config.ChallengeRules
|
||||||
Weight *config.Weight
|
Weight *config.Weight
|
||||||
Name string
|
Name string
|
||||||
|
|||||||
@@ -3,153 +3,39 @@ package policy
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"sort"
|
||||||
"net/netip"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/TecharoHQ/anubis/internal"
|
"github.com/TecharoHQ/anubis/lib/checker"
|
||||||
"github.com/TecharoHQ/anubis/lib/policy/checker"
|
"github.com/TecharoHQ/anubis/lib/checker/headerexists"
|
||||||
"github.com/gaissmai/bart"
|
"github.com/TecharoHQ/anubis/lib/checker/headermatches"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
func NewHeadersChecker(headermap map[string]string) (checker.Interface, error) {
|
||||||
ErrMisconfiguration = errors.New("[unexpected] policy: administrator misconfiguration")
|
var result checker.All
|
||||||
)
|
|
||||||
|
|
||||||
type RemoteAddrChecker struct {
|
|
||||||
prefixTable *bart.Lite
|
|
||||||
hash string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewRemoteAddrChecker(cidrs []string) (checker.Impl, error) {
|
|
||||||
table := new(bart.Lite)
|
|
||||||
|
|
||||||
for _, cidr := range cidrs {
|
|
||||||
prefix, err := netip.ParsePrefix(cidr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("%w: range %s not parsing: %w", ErrMisconfiguration, cidr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
table.Insert(prefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &RemoteAddrChecker{
|
|
||||||
prefixTable: table,
|
|
||||||
hash: internal.FastHash(strings.Join(cidrs, ",")),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rac *RemoteAddrChecker) Check(r *http.Request) (bool, error) {
|
|
||||||
host := r.Header.Get("X-Real-Ip")
|
|
||||||
if host == "" {
|
|
||||||
return false, fmt.Errorf("%w: header X-Real-Ip is not set", ErrMisconfiguration)
|
|
||||||
}
|
|
||||||
|
|
||||||
addr, err := netip.ParseAddr(host)
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("%w: %s is not an IP address: %w", ErrMisconfiguration, host, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return rac.prefixTable.Contains(addr), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rac *RemoteAddrChecker) Hash() string {
|
|
||||||
return rac.hash
|
|
||||||
}
|
|
||||||
|
|
||||||
type HeaderMatchesChecker struct {
|
|
||||||
header string
|
|
||||||
regexp *regexp.Regexp
|
|
||||||
hash string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewUserAgentChecker(rexStr string) (checker.Impl, error) {
|
|
||||||
return NewHeaderMatchesChecker("User-Agent", rexStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewHeaderMatchesChecker(header, rexStr string) (checker.Impl, error) {
|
|
||||||
rex, err := regexp.Compile(strings.TrimSpace(rexStr))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("%w: regex %s failed parse: %w", ErrMisconfiguration, rexStr, err)
|
|
||||||
}
|
|
||||||
return &HeaderMatchesChecker{strings.TrimSpace(header), rex, internal.FastHash(header + ": " + rexStr)}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hmc *HeaderMatchesChecker) Check(r *http.Request) (bool, error) {
|
|
||||||
if hmc.regexp.MatchString(r.Header.Get(hmc.header)) {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hmc *HeaderMatchesChecker) Hash() string {
|
|
||||||
return hmc.hash
|
|
||||||
}
|
|
||||||
|
|
||||||
type PathChecker struct {
|
|
||||||
regexp *regexp.Regexp
|
|
||||||
hash string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPathChecker(rexStr string) (checker.Impl, error) {
|
|
||||||
rex, err := regexp.Compile(strings.TrimSpace(rexStr))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("%w: regex %s failed parse: %w", ErrMisconfiguration, rexStr, err)
|
|
||||||
}
|
|
||||||
return &PathChecker{rex, internal.FastHash(rexStr)}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pc *PathChecker) Check(r *http.Request) (bool, error) {
|
|
||||||
if pc.regexp.MatchString(r.URL.Path) {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pc *PathChecker) Hash() string {
|
|
||||||
return pc.hash
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewHeaderExistsChecker(key string) checker.Impl {
|
|
||||||
return headerExistsChecker{strings.TrimSpace(key)}
|
|
||||||
}
|
|
||||||
|
|
||||||
type headerExistsChecker struct {
|
|
||||||
header string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hec headerExistsChecker) Check(r *http.Request) (bool, error) {
|
|
||||||
if r.Header.Get(hec.header) != "" {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hec headerExistsChecker) Hash() string {
|
|
||||||
return internal.FastHash(hec.header)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewHeadersChecker(headermap map[string]string) (checker.Impl, error) {
|
|
||||||
var result checker.List
|
|
||||||
var errs []error
|
var errs []error
|
||||||
|
|
||||||
for key, rexStr := range headermap {
|
var keys []string
|
||||||
|
for key := range headermap {
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Strings(keys)
|
||||||
|
|
||||||
|
for _, key := range keys {
|
||||||
|
rexStr := headermap[key]
|
||||||
if rexStr == ".*" {
|
if rexStr == ".*" {
|
||||||
result = append(result, headerExistsChecker{strings.TrimSpace(key)})
|
result = append(result, headerexists.New(strings.TrimSpace(key)))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
rex, err := regexp.Compile(strings.TrimSpace(rexStr))
|
c, err := headermatches.New(key, rexStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errs = append(errs, fmt.Errorf("while compiling header %s regex %s: %w", key, rexStr, err))
|
errs = append(errs, fmt.Errorf("while parsing header %s regex %s: %w", key, rexStr, err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
result = append(result, &HeaderMatchesChecker{key, rex, internal.FastHash(key + ": " + rexStr)})
|
result = append(result, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(errs) != 0 {
|
if len(errs) != 0 {
|
||||||
|
|||||||
@@ -1,41 +0,0 @@
|
|||||||
// Package checker defines the Checker interface and a helper utility to avoid import cycles.
|
|
||||||
package checker
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/TecharoHQ/anubis/internal"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Impl interface {
|
|
||||||
Check(*http.Request) (bool, error)
|
|
||||||
Hash() string
|
|
||||||
}
|
|
||||||
|
|
||||||
type List []Impl
|
|
||||||
|
|
||||||
func (l List) Check(r *http.Request) (bool, error) {
|
|
||||||
for _, c := range l {
|
|
||||||
ok, err := c.Check(r)
|
|
||||||
if err != nil {
|
|
||||||
return ok, err
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
return ok, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l List) Hash() string {
|
|
||||||
var sb strings.Builder
|
|
||||||
|
|
||||||
for _, c := range l {
|
|
||||||
fmt.Fprintln(&sb, c.Hash())
|
|
||||||
}
|
|
||||||
|
|
||||||
return internal.FastHash(sb.String())
|
|
||||||
}
|
|
||||||
@@ -1,200 +0,0 @@
|
|||||||
package policy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"net/http"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestRemoteAddrChecker(t *testing.T) {
|
|
||||||
for _, tt := range []struct {
|
|
||||||
err error
|
|
||||||
name string
|
|
||||||
ip string
|
|
||||||
cidrs []string
|
|
||||||
ok bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "match_ipv4",
|
|
||||||
cidrs: []string{"0.0.0.0/0"},
|
|
||||||
ip: "1.1.1.1",
|
|
||||||
ok: true,
|
|
||||||
err: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "match_ipv6",
|
|
||||||
cidrs: []string{"::/0"},
|
|
||||||
ip: "cafe:babe::",
|
|
||||||
ok: true,
|
|
||||||
err: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "not_match_ipv4",
|
|
||||||
cidrs: []string{"1.1.1.1/32"},
|
|
||||||
ip: "1.1.1.2",
|
|
||||||
ok: false,
|
|
||||||
err: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "not_match_ipv6",
|
|
||||||
cidrs: []string{"cafe:babe::/128"},
|
|
||||||
ip: "cafe:babe:4::/128",
|
|
||||||
ok: false,
|
|
||||||
err: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no_ip_set",
|
|
||||||
cidrs: []string{"::/0"},
|
|
||||||
ok: false,
|
|
||||||
err: ErrMisconfiguration,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid_ip",
|
|
||||||
cidrs: []string{"::/0"},
|
|
||||||
ip: "According to all natural laws of aviation",
|
|
||||||
ok: false,
|
|
||||||
err: ErrMisconfiguration,
|
|
||||||
},
|
|
||||||
} {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
rac, err := NewRemoteAddrChecker(tt.cidrs)
|
|
||||||
if err != nil && !errors.Is(err, tt.err) {
|
|
||||||
t.Fatalf("creating RemoteAddrChecker failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r, err := http.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("can't make request: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.ip != "" {
|
|
||||||
r.Header.Add("X-Real-Ip", tt.ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err := rac.Check(r)
|
|
||||||
|
|
||||||
if tt.ok != ok {
|
|
||||||
t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil && tt.err != nil && !errors.Is(err, tt.err) {
|
|
||||||
t.Errorf("err: %v, wanted: %v", err, tt.err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHeaderMatchesChecker(t *testing.T) {
|
|
||||||
for _, tt := range []struct {
|
|
||||||
err error
|
|
||||||
name string
|
|
||||||
header string
|
|
||||||
rexStr string
|
|
||||||
reqHeaderKey string
|
|
||||||
reqHeaderValue string
|
|
||||||
ok bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "match",
|
|
||||||
header: "Cf-Worker",
|
|
||||||
rexStr: ".*",
|
|
||||||
reqHeaderKey: "Cf-Worker",
|
|
||||||
reqHeaderValue: "true",
|
|
||||||
ok: true,
|
|
||||||
err: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "not_match",
|
|
||||||
header: "Cf-Worker",
|
|
||||||
rexStr: "false",
|
|
||||||
reqHeaderKey: "Cf-Worker",
|
|
||||||
reqHeaderValue: "true",
|
|
||||||
ok: false,
|
|
||||||
err: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "not_present",
|
|
||||||
header: "Cf-Worker",
|
|
||||||
rexStr: "foobar",
|
|
||||||
reqHeaderKey: "Something-Else",
|
|
||||||
reqHeaderValue: "true",
|
|
||||||
ok: false,
|
|
||||||
err: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid_regex",
|
|
||||||
rexStr: "a(b",
|
|
||||||
err: ErrMisconfiguration,
|
|
||||||
},
|
|
||||||
} {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
hmc, err := NewHeaderMatchesChecker(tt.header, tt.rexStr)
|
|
||||||
if err != nil && !errors.Is(err, tt.err) {
|
|
||||||
t.Fatalf("creating HeaderMatchesChecker failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.err != nil && hmc == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
r, err := http.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("can't make request: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Header.Set(tt.reqHeaderKey, tt.reqHeaderValue)
|
|
||||||
|
|
||||||
ok, err := hmc.Check(r)
|
|
||||||
|
|
||||||
if tt.ok != ok {
|
|
||||||
t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil && tt.err != nil && !errors.Is(err, tt.err) {
|
|
||||||
t.Errorf("err: %v, wanted: %v", err, tt.err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHeaderExistsChecker(t *testing.T) {
|
|
||||||
for _, tt := range []struct {
|
|
||||||
name string
|
|
||||||
header string
|
|
||||||
reqHeader string
|
|
||||||
ok bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "match",
|
|
||||||
header: "Authorization",
|
|
||||||
reqHeader: "Authorization",
|
|
||||||
ok: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "not_match",
|
|
||||||
header: "Authorization",
|
|
||||||
reqHeader: "Authentication",
|
|
||||||
},
|
|
||||||
} {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
hec := headerExistsChecker{tt.header}
|
|
||||||
|
|
||||||
r, err := http.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("can't make request: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Header.Set(tt.reqHeader, "hunter2")
|
|
||||||
|
|
||||||
ok, err := hec.Check(r)
|
|
||||||
|
|
||||||
if tt.ok != ok {
|
|
||||||
t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("err: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
@@ -13,6 +12,10 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/TecharoHQ/anubis/data"
|
"github.com/TecharoHQ/anubis/data"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker/expression"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker/headermatches"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker/path"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker/remoteaddress"
|
||||||
"k8s.io/apimachinery/pkg/util/yaml"
|
"k8s.io/apimachinery/pkg/util/yaml"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -25,12 +28,12 @@ var (
|
|||||||
ErrInvalidUserAgentRegex = errors.New("config.Bot: invalid user agent regex")
|
ErrInvalidUserAgentRegex = errors.New("config.Bot: invalid user agent regex")
|
||||||
ErrInvalidPathRegex = errors.New("config.Bot: invalid path regex")
|
ErrInvalidPathRegex = errors.New("config.Bot: invalid path regex")
|
||||||
ErrInvalidHeadersRegex = errors.New("config.Bot: invalid headers regex")
|
ErrInvalidHeadersRegex = errors.New("config.Bot: invalid headers regex")
|
||||||
ErrInvalidCIDR = errors.New("config.Bot: invalid CIDR")
|
|
||||||
ErrRegexEndsWithNewline = errors.New("config.Bot: regular expression ends with newline (try >- instead of > in yaml)")
|
ErrRegexEndsWithNewline = errors.New("config.Bot: regular expression ends with newline (try >- instead of > in yaml)")
|
||||||
ErrInvalidImportStatement = errors.New("config.ImportStatement: invalid source file")
|
ErrInvalidImportStatement = errors.New("config.ImportStatement: invalid source file")
|
||||||
ErrCantSetBotAndImportValuesAtOnce = errors.New("config.BotOrImport: can't set bot rules and import values at the same time")
|
ErrCantSetBotAndImportValuesAtOnce = errors.New("config.BotOrImport: can't set bot rules and import values at the same time")
|
||||||
ErrMustSetBotOrImportRules = errors.New("config.BotOrImport: rule definition is invalid, you must set either bot rules or an import statement, not both")
|
ErrMustSetBotOrImportRules = errors.New("config.BotOrImport: rule definition is invalid, you must set either bot rules or an import statement, not both")
|
||||||
ErrStatusCodeNotValid = errors.New("config.StatusCode: status code not valid, must be between 100 and 599")
|
ErrStatusCodeNotValid = errors.New("config.StatusCode: status code not valid, must be between 100 and 599")
|
||||||
|
ErrUnparseableConfig = errors.New("config: can't parse configuration file")
|
||||||
)
|
)
|
||||||
|
|
||||||
type Rule string
|
type Rule string
|
||||||
@@ -56,15 +59,15 @@ func (r Rule) Valid() error {
|
|||||||
const DefaultAlgorithm = "fast"
|
const DefaultAlgorithm = "fast"
|
||||||
|
|
||||||
type BotConfig struct {
|
type BotConfig struct {
|
||||||
UserAgentRegex *string `json:"user_agent_regex,omitempty" yaml:"user_agent_regex,omitempty"`
|
UserAgentRegex *string `json:"user_agent_regex,omitempty" yaml:"user_agent_regex,omitempty"`
|
||||||
PathRegex *string `json:"path_regex,omitempty" yaml:"path_regex,omitempty"`
|
PathRegex *string `json:"path_regex,omitempty" yaml:"path_regex,omitempty"`
|
||||||
HeadersRegex map[string]string `json:"headers_regex,omitempty" yaml:"headers_regex,omitempty"`
|
HeadersRegex map[string]string `json:"headers_regex,omitempty" yaml:"headers_regex,omitempty"`
|
||||||
Expression *ExpressionOrList `json:"expression,omitempty" yaml:"expression,omitempty"`
|
Expression *expression.Config `json:"expression,omitempty" yaml:"expression,omitempty"`
|
||||||
Challenge *ChallengeRules `json:"challenge,omitempty" yaml:"challenge,omitempty"`
|
Challenge *ChallengeRules `json:"challenge,omitempty" yaml:"challenge,omitempty"`
|
||||||
Weight *Weight `json:"weight,omitempty" yaml:"weight,omitempty"`
|
Weight *Weight `json:"weight,omitempty" yaml:"weight,omitempty"`
|
||||||
Name string `json:"name" yaml:"name"`
|
Name string `json:"name" yaml:"name"`
|
||||||
Action Rule `json:"action" yaml:"action"`
|
Action Rule `json:"action" yaml:"action"`
|
||||||
RemoteAddr []string `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"`
|
RemoteAddr []string `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"`
|
||||||
|
|
||||||
// Thoth features
|
// Thoth features
|
||||||
GeoIP *GeoIP `json:"geoip,omitempty"`
|
GeoIP *GeoIP `json:"geoip,omitempty"`
|
||||||
@@ -118,7 +121,7 @@ func (b *BotConfig) Valid() error {
|
|||||||
errs = append(errs, fmt.Errorf("%w: user agent regex: %q", ErrRegexEndsWithNewline, *b.UserAgentRegex))
|
errs = append(errs, fmt.Errorf("%w: user agent regex: %q", ErrRegexEndsWithNewline, *b.UserAgentRegex))
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := regexp.Compile(*b.UserAgentRegex); err != nil {
|
if err := headermatches.ValidUserAgent(*b.UserAgentRegex); err != nil {
|
||||||
errs = append(errs, ErrInvalidUserAgentRegex, err)
|
errs = append(errs, ErrInvalidUserAgentRegex, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -128,7 +131,7 @@ func (b *BotConfig) Valid() error {
|
|||||||
errs = append(errs, fmt.Errorf("%w: path regex: %q", ErrRegexEndsWithNewline, *b.PathRegex))
|
errs = append(errs, fmt.Errorf("%w: path regex: %q", ErrRegexEndsWithNewline, *b.PathRegex))
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := regexp.Compile(*b.PathRegex); err != nil {
|
if err := path.Valid(*b.PathRegex); err != nil {
|
||||||
errs = append(errs, ErrInvalidPathRegex, err)
|
errs = append(errs, ErrInvalidPathRegex, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -150,10 +153,8 @@ func (b *BotConfig) Valid() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(b.RemoteAddr) > 0 {
|
if len(b.RemoteAddr) > 0 {
|
||||||
for _, cidr := range b.RemoteAddr {
|
if err := remoteaddress.Valid(b.RemoteAddr); err != nil {
|
||||||
if _, _, err := net.ParseCIDR(cidr); err != nil {
|
errs = append(errs, err)
|
||||||
errs = append(errs, ErrInvalidCIDR, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/TecharoHQ/anubis/data"
|
"github.com/TecharoHQ/anubis/data"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker/remoteaddress"
|
||||||
. "github.com/TecharoHQ/anubis/lib/policy/config"
|
. "github.com/TecharoHQ/anubis/lib/policy/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -137,7 +138,7 @@ func TestBotValid(t *testing.T) {
|
|||||||
Action: RuleAllow,
|
Action: RuleAllow,
|
||||||
RemoteAddr: []string{"0.0.0.0/33"},
|
RemoteAddr: []string{"0.0.0.0/33"},
|
||||||
},
|
},
|
||||||
err: ErrInvalidCIDR,
|
err: remoteaddress.ErrInvalidCIDR,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "only filter by IP range",
|
name: "only filter by IP range",
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/TecharoHQ/anubis"
|
"github.com/TecharoHQ/anubis"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker/expression"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -17,7 +18,7 @@ var (
|
|||||||
DefaultThresholds = []Threshold{
|
DefaultThresholds = []Threshold{
|
||||||
{
|
{
|
||||||
Name: "legacy-anubis-behaviour",
|
Name: "legacy-anubis-behaviour",
|
||||||
Expression: &ExpressionOrList{
|
Expression: &expression.Config{
|
||||||
Expression: "weight > 0",
|
Expression: "weight > 0",
|
||||||
},
|
},
|
||||||
Action: RuleChallenge,
|
Action: RuleChallenge,
|
||||||
@@ -31,10 +32,10 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Threshold struct {
|
type Threshold struct {
|
||||||
Name string `json:"name" yaml:"name"`
|
Name string `json:"name" yaml:"name"`
|
||||||
Expression *ExpressionOrList `json:"expression" yaml:"expression"`
|
Expression *expression.Config `json:"expression" yaml:"expression"`
|
||||||
Action Rule `json:"action" yaml:"action"`
|
Action Rule `json:"action" yaml:"action"`
|
||||||
Challenge *ChallengeRules `json:"challenge" yaml:"challenge"`
|
Challenge *ChallengeRules `json:"challenge" yaml:"challenge"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Threshold) Valid() error {
|
func (t Threshold) Valid() error {
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker/expression"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestThresholdValid(t *testing.T) {
|
func TestThresholdValid(t *testing.T) {
|
||||||
@@ -18,7 +20,7 @@ func TestThresholdValid(t *testing.T) {
|
|||||||
name: "basic allow",
|
name: "basic allow",
|
||||||
input: &Threshold{
|
input: &Threshold{
|
||||||
Name: "basic-allow",
|
Name: "basic-allow",
|
||||||
Expression: &ExpressionOrList{Expression: "true"},
|
Expression: &expression.Config{Expression: "true"},
|
||||||
Action: RuleAllow,
|
Action: RuleAllow,
|
||||||
},
|
},
|
||||||
err: nil,
|
err: nil,
|
||||||
@@ -27,7 +29,7 @@ func TestThresholdValid(t *testing.T) {
|
|||||||
name: "basic challenge",
|
name: "basic challenge",
|
||||||
input: &Threshold{
|
input: &Threshold{
|
||||||
Name: "basic-challenge",
|
Name: "basic-challenge",
|
||||||
Expression: &ExpressionOrList{Expression: "true"},
|
Expression: &expression.Config{Expression: "true"},
|
||||||
Action: RuleChallenge,
|
Action: RuleChallenge,
|
||||||
Challenge: &ChallengeRules{
|
Challenge: &ChallengeRules{
|
||||||
Algorithm: "fast",
|
Algorithm: "fast",
|
||||||
@@ -50,9 +52,9 @@ func TestThresholdValid(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "invalid expression",
|
name: "invalid expression",
|
||||||
input: &Threshold{
|
input: &Threshold{
|
||||||
Expression: &ExpressionOrList{},
|
Expression: &expression.Config{},
|
||||||
},
|
},
|
||||||
err: ErrExpressionEmpty,
|
err: expression.ErrExpressionEmpty,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "invalid action",
|
name: "invalid action",
|
||||||
|
|||||||
@@ -8,7 +8,11 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/TecharoHQ/anubis/lib/policy/checker"
|
"github.com/TecharoHQ/anubis/lib/checker"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker/expression"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker/headermatches"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker/path"
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker/remoteaddress"
|
||||||
"github.com/TecharoHQ/anubis/lib/policy/config"
|
"github.com/TecharoHQ/anubis/lib/policy/config"
|
||||||
"github.com/TecharoHQ/anubis/lib/store"
|
"github.com/TecharoHQ/anubis/lib/store"
|
||||||
"github.com/TecharoHQ/anubis/lib/thoth"
|
"github.com/TecharoHQ/anubis/lib/thoth"
|
||||||
@@ -73,10 +77,10 @@ func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDiffic
|
|||||||
Action: b.Action,
|
Action: b.Action,
|
||||||
}
|
}
|
||||||
|
|
||||||
cl := checker.List{}
|
cl := checker.Any{}
|
||||||
|
|
||||||
if len(b.RemoteAddr) > 0 {
|
if len(b.RemoteAddr) > 0 {
|
||||||
c, err := NewRemoteAddrChecker(b.RemoteAddr)
|
c, err := remoteaddress.New(b.RemoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s remote addr set: %w", b.Name, err))
|
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s remote addr set: %w", b.Name, err))
|
||||||
} else {
|
} else {
|
||||||
@@ -85,7 +89,7 @@ func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDiffic
|
|||||||
}
|
}
|
||||||
|
|
||||||
if b.UserAgentRegex != nil {
|
if b.UserAgentRegex != nil {
|
||||||
c, err := NewUserAgentChecker(*b.UserAgentRegex)
|
c, err := headermatches.NewUserAgent(*b.UserAgentRegex)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s user agent regex: %w", b.Name, err))
|
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s user agent regex: %w", b.Name, err))
|
||||||
} else {
|
} else {
|
||||||
@@ -94,7 +98,7 @@ func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDiffic
|
|||||||
}
|
}
|
||||||
|
|
||||||
if b.PathRegex != nil {
|
if b.PathRegex != nil {
|
||||||
c, err := NewPathChecker(*b.PathRegex)
|
c, err := path.New(*b.PathRegex)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s path regex: %w", b.Name, err))
|
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s path regex: %w", b.Name, err))
|
||||||
} else {
|
} else {
|
||||||
@@ -112,7 +116,7 @@ func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDiffic
|
|||||||
}
|
}
|
||||||
|
|
||||||
if b.Expression != nil {
|
if b.Expression != nil {
|
||||||
c, err := NewCELChecker(b.Expression)
|
c, err := expression.New(b.Expression)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s expressions: %w", b.Name, err))
|
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s expressions: %w", b.Name, err))
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
package policy
|
package policy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/TecharoHQ/anubis/lib/checker/expression/environment"
|
||||||
"github.com/TecharoHQ/anubis/lib/policy/config"
|
"github.com/TecharoHQ/anubis/lib/policy/config"
|
||||||
"github.com/TecharoHQ/anubis/lib/policy/expressions"
|
|
||||||
"github.com/google/cel-go/cel"
|
"github.com/google/cel-go/cel"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -16,12 +16,12 @@ func ParsedThresholdFromConfig(t config.Threshold) (*Threshold, error) {
|
|||||||
Threshold: t,
|
Threshold: t,
|
||||||
}
|
}
|
||||||
|
|
||||||
env, err := expressions.ThresholdEnvironment()
|
env, err := environment.Threshold()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
program, err := expressions.Compile(env, t.Expression.String())
|
program, err := environment.Compile(env, t.Expression.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,11 +10,11 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/TecharoHQ/anubis/internal"
|
"github.com/TecharoHQ/anubis/internal"
|
||||||
"github.com/TecharoHQ/anubis/lib/policy/checker"
|
"github.com/TecharoHQ/anubis/lib/checker"
|
||||||
iptoasnv1 "github.com/TecharoHQ/thoth-proto/gen/techaro/thoth/iptoasn/v1"
|
iptoasnv1 "github.com/TecharoHQ/thoth-proto/gen/techaro/thoth/iptoasn/v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *Client) ASNCheckerFor(asns []uint32) checker.Impl {
|
func (c *Client) ASNCheckerFor(asns []uint32) checker.Interface {
|
||||||
asnMap := map[uint32]struct{}{}
|
asnMap := map[uint32]struct{}{}
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
fmt.Fprintln(&sb, "ASNChecker")
|
fmt.Fprintln(&sb, "ASNChecker")
|
||||||
|
|||||||
@@ -5,12 +5,12 @@ import (
|
|||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/TecharoHQ/anubis/lib/policy/checker"
|
"github.com/TecharoHQ/anubis/lib/checker"
|
||||||
"github.com/TecharoHQ/anubis/lib/thoth"
|
"github.com/TecharoHQ/anubis/lib/thoth"
|
||||||
iptoasnv1 "github.com/TecharoHQ/thoth-proto/gen/techaro/thoth/iptoasn/v1"
|
iptoasnv1 "github.com/TecharoHQ/thoth-proto/gen/techaro/thoth/iptoasn/v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ checker.Impl = &thoth.ASNChecker{}
|
var _ checker.Interface = &thoth.ASNChecker{}
|
||||||
|
|
||||||
func TestASNChecker(t *testing.T) {
|
func TestASNChecker(t *testing.T) {
|
||||||
cli := loadSecrets(t)
|
cli := loadSecrets(t)
|
||||||
|
|||||||
@@ -9,11 +9,11 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/TecharoHQ/anubis/lib/policy/checker"
|
"github.com/TecharoHQ/anubis/lib/checker"
|
||||||
iptoasnv1 "github.com/TecharoHQ/thoth-proto/gen/techaro/thoth/iptoasn/v1"
|
iptoasnv1 "github.com/TecharoHQ/thoth-proto/gen/techaro/thoth/iptoasn/v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *Client) GeoIPCheckerFor(countries []string) checker.Impl {
|
func (c *Client) GeoIPCheckerFor(countries []string) checker.Interface {
|
||||||
countryMap := map[string]struct{}{}
|
countryMap := map[string]struct{}{}
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
fmt.Fprintln(&sb, "GeoIPChecker")
|
fmt.Fprintln(&sb, "GeoIPChecker")
|
||||||
|
|||||||
@@ -5,11 +5,11 @@ import (
|
|||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/TecharoHQ/anubis/lib/policy/checker"
|
"github.com/TecharoHQ/anubis/lib/checker"
|
||||||
"github.com/TecharoHQ/anubis/lib/thoth"
|
"github.com/TecharoHQ/anubis/lib/thoth"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ checker.Impl = &thoth.GeoIPChecker{}
|
var _ checker.Interface = &thoth.GeoIPChecker{}
|
||||||
|
|
||||||
func TestGeoIPChecker(t *testing.T) {
|
func TestGeoIPChecker(t *testing.T) {
|
||||||
cli := loadSecrets(t)
|
cli := loadSecrets(t)
|
||||||
|
|||||||
Reference in New Issue
Block a user