Compare commits

..

9 Commits

Author SHA1 Message Date
Xe Iaso
9f3eb71ef6 refactor: get rid of package expressions by moving the code into package expression
Signed-off-by: Xe Iaso <me@xeiaso.net>
2025-07-25 19:58:41 +00:00
Xe Iaso
a494d26708 refactor: move cel environment creation to a subpackage
Signed-off-by: Xe Iaso <me@xeiaso.net>
2025-07-25 19:55:56 +00:00
Xe Iaso
e98d749bf2 refactor: move CEL checker to its own package
Signed-off-by: Xe Iaso <me@xeiaso.net>
2025-07-25 19:52:07 +00:00
Xe Iaso
590d8303ad refactor: use new checker types
Signed-off-by: Xe Iaso <me@xeiaso.net>
2025-07-25 19:39:14 +00:00
Xe Iaso
88c30c70fc feat(checker): port path checker
Signed-off-by: Xe Iaso <me@xeiaso.net>
2025-07-25 19:11:16 +00:00
Xe Iaso
1c43349c4a feat(checker): port other checkers over
Signed-off-by: Xe Iaso <me@xeiaso.net>
2025-07-25 18:45:54 +00:00
Xe Iaso
178c60cf72 refactor: raise checker to be a subpackage of lib
Signed-off-by: Xe Iaso <me@xeiaso.net>
2025-07-25 18:45:40 +00:00
Xe Iaso
ecbbf77498 refactor: move ErrMisconfiguration to top level
Signed-off-by: Xe Iaso <me@xeiaso.net>
2025-07-25 18:44:27 +00:00
Xe Iaso
5307388841 chore: start refactor of checkers into separate packages
Signed-off-by: Xe Iaso <me@xeiaso.net>
2025-07-25 17:38:18 +00:00
79 changed files with 2001 additions and 967 deletions

View File

@@ -31,6 +31,7 @@ import (
"github.com/TecharoHQ/anubis/data" "github.com/TecharoHQ/anubis/data"
"github.com/TecharoHQ/anubis/internal" "github.com/TecharoHQ/anubis/internal"
libanubis "github.com/TecharoHQ/anubis/lib" libanubis "github.com/TecharoHQ/anubis/lib"
"github.com/TecharoHQ/anubis/lib/checker/headerexists"
botPolicy "github.com/TecharoHQ/anubis/lib/policy" botPolicy "github.com/TecharoHQ/anubis/lib/policy"
"github.com/TecharoHQ/anubis/lib/policy/config" "github.com/TecharoHQ/anubis/lib/policy/config"
"github.com/TecharoHQ/anubis/lib/thoth" "github.com/TecharoHQ/anubis/lib/thoth"
@@ -323,7 +324,7 @@ func main() {
if *debugBenchmarkJS { if *debugBenchmarkJS {
policy.Bots = []botPolicy.Bot{{ policy.Bots = []botPolicy.Bot{{
Name: "", Name: "",
Rules: botPolicy.NewHeaderExistsChecker("User-Agent"), Rules: headerexists.New("User-Agent"),
Action: config.RuleBenchmark, Action: config.RuleBenchmark,
}} }}
} }

View File

@@ -12,6 +12,7 @@ import (
"regexp" "regexp"
"strings" "strings"
"github.com/TecharoHQ/anubis/lib/checker/expression"
"github.com/TecharoHQ/anubis/lib/policy/config" "github.com/TecharoHQ/anubis/lib/policy/config"
"sigs.k8s.io/yaml" "sigs.k8s.io/yaml"
@@ -37,11 +38,11 @@ type RobotsRule struct {
} }
type AnubisRule struct { type AnubisRule struct {
Expression *config.ExpressionOrList `yaml:"expression,omitempty" json:"expression,omitempty"` Expression *expression.Config `yaml:"expression,omitempty" json:"expression,omitempty"`
Challenge *config.ChallengeRules `yaml:"challenge,omitempty" json:"challenge,omitempty"` Challenge *config.ChallengeRules `yaml:"challenge,omitempty" json:"challenge,omitempty"`
Weight *config.Weight `yaml:"weight,omitempty" json:"weight,omitempty"` Weight *config.Weight `yaml:"weight,omitempty" json:"weight,omitempty"`
Name string `yaml:"name" json:"name"` Name string `yaml:"name" json:"name"`
Action string `yaml:"action" json:"action"` Action string `yaml:"action" json:"action"`
} }
func init() { func init() {
@@ -224,11 +225,11 @@ func convertToAnubisRules(robotsRules []RobotsRule) []AnubisRule {
} }
if userAgent == "*" { if userAgent == "*" {
rule.Expression = &config.ExpressionOrList{ rule.Expression = &expression.Config{
All: []string{"true"}, // Always applies All: []string{"true"}, // Always applies
} }
} else { } else {
rule.Expression = &config.ExpressionOrList{ rule.Expression = &expression.Config{
All: []string{fmt.Sprintf("userAgent.contains(%q)", userAgent)}, All: []string{fmt.Sprintf("userAgent.contains(%q)", userAgent)},
} }
} }
@@ -249,11 +250,11 @@ func convertToAnubisRules(robotsRules []RobotsRule) []AnubisRule {
rule.Name = fmt.Sprintf("%s-global-restriction-%d", *policyName, ruleCounter) rule.Name = fmt.Sprintf("%s-global-restriction-%d", *policyName, ruleCounter)
rule.Action = "WEIGH" rule.Action = "WEIGH"
rule.Weight = &config.Weight{Adjust: 20} // Increase difficulty significantly rule.Weight = &config.Weight{Adjust: 20} // Increase difficulty significantly
rule.Expression = &config.ExpressionOrList{ rule.Expression = &expression.Config{
All: []string{"true"}, // Always applies All: []string{"true"}, // Always applies
} }
} else { } else {
rule.Expression = &config.ExpressionOrList{ rule.Expression = &expression.Config{
All: []string{fmt.Sprintf("userAgent.contains(%q)", userAgent)}, All: []string{fmt.Sprintf("userAgent.contains(%q)", userAgent)},
} }
} }
@@ -285,7 +286,7 @@ func convertToAnubisRules(robotsRules []RobotsRule) []AnubisRule {
pathCondition := buildPathCondition(disallow) pathCondition := buildPathCondition(disallow)
conditions = append(conditions, pathCondition) conditions = append(conditions, pathCondition)
rule.Expression = &config.ExpressionOrList{ rule.Expression = &expression.Config{
All: conditions, All: conditions,
} }

View File

@@ -1,4 +1,3 @@
- import: (data)/bots/cloudflare-workers.yaml - import: (data)/bots/cloudflare-workers.yaml
- import: (data)/bots/headless-browsers.yaml - import: (data)/bots/headless-browsers.yaml
- import: (data)/bots/us-ai-scraper.yaml - import: (data)/bots/us-ai-scraper.yaml
- import: (data)/bots/custom-async-http-client.yaml

View File

@@ -1,5 +0,0 @@
- name: "custom-async-http-client"
user_agent_regex: "Custom-AsyncHttpClient"
action: WEIGH
weight:
adjust: 10

View File

@@ -14,8 +14,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
<!-- This changes the project to: --> <!-- This changes the project to: -->
- The [Thoth client](https://anubis.techaro.lol/docs/admin/thoth) is now public in the repo instead of being an internal package. - The [Thoth client](https://anubis.techaro.lol/docs/admin/thoth) is now public in the repo instead of being an internal package.
- [Custom-AsyncHttpClient](https://github.com/AsyncHttpClient/async-http-client)'s default User-Agent has an increased weight by default ([#852](https://github.com/TecharoHQ/anubis/issues/852)).
- The [`segments`](./admin/configuration/expressions.mdx#segments) function was added for splitting a path into its slash-separated segments.
## v1.21.3: Minfilia Warde - Echo 3 ## v1.21.3: Minfilia Warde - Echo 3

View File

@@ -232,39 +232,6 @@ This is best applied when doing explicit block rules, eg:
It seems counter-intuitive to allow known bad clients through sometimes, but this allows you to confuse attackers by making Anubis' behavior random. Adjust the thresholds and numbers as facts and circumstances demand. It seems counter-intuitive to allow known bad clients through sometimes, but this allows you to confuse attackers by making Anubis' behavior random. Adjust the thresholds and numbers as facts and circumstances demand.
### `segments`
Available in `bot` expressions.
```ts
function segments(path: string): string[];
```
`segments` returns the number of slash-separated path segments, ignoring the leading slash. Here is what it will return with some common paths:
| Input | Output |
| :----------------------- | :--------------------- |
| `segments("/")` | `[""]` |
| `segments("/foo/bar")` | `["foo", "bar"] ` |
| `segments("/users/xe/")` | `["users", "xe", ""] ` |
:::note
If the path ends with a `/`, then the last element of the result will be an empty string. This is because `/users/xe` and `/users/xe/` are semantically different paths.
:::
This is useful if you want to write rules that allow requests that have no query parameters only if they have less than two path segments:
```yaml
- name: two-path-segments-no-query
action: ALLOW
expression:
all:
- size(query) == 0
- size(segments(path)) < 2
```
## Life advice ## Life advice
Expressions are very powerful. This is a benefit and a burden. If you are not careful with your expression targeting, you will be liable to get yourself into trouble. If you are at all in doubt, throw a `CHALLENGE` over a `DENY`. Legitimate users can easily work around a `CHALLENGE` result with a [proof of work challenge](../../design/why-proof-of-work.mdx). Bots are less likely to be able to do this. Expressions are very powerful. This is a benefit and a burden. If you are not careful with your expression targeting, you will be liable to get yourself into trouble. If you are at all in doubt, throw a `CHALLENGE` over a `DENY`. Legitimate users can easily work around a `CHALLENGE` result with a [proof of work challenge](../../design/why-proof-of-work.mdx). Bots are less likely to be able to do this.

7
errors.go Normal file
View File

@@ -0,0 +1,7 @@
package anubis
import "errors"
var (
ErrMisconfiguration = errors.New("[unexpected] policy: administrator misconfiguration")
)

View File

@@ -28,15 +28,17 @@ import (
"github.com/TecharoHQ/anubis/internal/dnsbl" "github.com/TecharoHQ/anubis/internal/dnsbl"
"github.com/TecharoHQ/anubis/internal/ogtags" "github.com/TecharoHQ/anubis/internal/ogtags"
"github.com/TecharoHQ/anubis/lib/challenge" "github.com/TecharoHQ/anubis/lib/challenge"
"github.com/TecharoHQ/anubis/lib/checker"
"github.com/TecharoHQ/anubis/lib/localization" "github.com/TecharoHQ/anubis/lib/localization"
"github.com/TecharoHQ/anubis/lib/policy" "github.com/TecharoHQ/anubis/lib/policy"
"github.com/TecharoHQ/anubis/lib/policy/checker"
"github.com/TecharoHQ/anubis/lib/policy/config" "github.com/TecharoHQ/anubis/lib/policy/config"
"github.com/TecharoHQ/anubis/lib/store" "github.com/TecharoHQ/anubis/lib/store"
// checker implementations
_ "github.com/TecharoHQ/anubis/lib/checker/all"
// challenge implementations // challenge implementations
_ "github.com/TecharoHQ/anubis/lib/challenge/metarefresh" _ "github.com/TecharoHQ/anubis/lib/challenge/all"
_ "github.com/TecharoHQ/anubis/lib/challenge/proofofwork"
) )
var ( var (
@@ -549,7 +551,7 @@ func (s *Server) check(r *http.Request) (policy.CheckResult, *policy.Bot, error)
if matches { if matches {
return cr("threshold/"+t.Name, t.Action, weight), &policy.Bot{ return cr("threshold/"+t.Name, t.Action, weight), &policy.Bot{
Challenge: t.Challenge, Challenge: t.Challenge,
Rules: &checker.List{}, Rules: &checker.Any{},
}, nil }, nil
} }
} }
@@ -560,6 +562,6 @@ func (s *Server) check(r *http.Request) (policy.CheckResult, *policy.Bot, error)
ReportAs: s.policy.DefaultDifficulty, ReportAs: s.policy.DefaultDifficulty,
Algorithm: config.DefaultAlgorithm, Algorithm: config.DefaultAlgorithm,
}, },
Rules: &checker.List{}, Rules: &checker.Any{},
}, nil }, nil
} }

6
lib/challenge/all/all.go Normal file
View File

@@ -0,0 +1,6 @@
package all
import (
_ "github.com/TecharoHQ/anubis/lib/challenge/metarefresh"
_ "github.com/TecharoHQ/anubis/lib/challenge/proofofwork"
)

35
lib/checker/all.go Normal file
View File

@@ -0,0 +1,35 @@
package checker
import (
"fmt"
"net/http"
"strings"
"github.com/TecharoHQ/anubis/internal"
)
type All []Interface
func (a All) Check(r *http.Request) (bool, error) {
for _, c := range a {
match, err := c.Check(r)
if err != nil {
return match, err
}
if !match {
return false, err // no match
}
}
return true, nil // match
}
func (a All) Hash() string {
var sb strings.Builder
for _, c := range a {
fmt.Fprintln(&sb, c.Hash())
}
return internal.FastHash(sb.String())
}

10
lib/checker/all/all.go Normal file
View File

@@ -0,0 +1,10 @@
// Package all imports all of the standard checker types.
package all
import (
_ "github.com/TecharoHQ/anubis/lib/checker/expression"
_ "github.com/TecharoHQ/anubis/lib/checker/headerexists"
_ "github.com/TecharoHQ/anubis/lib/checker/headermatches"
_ "github.com/TecharoHQ/anubis/lib/checker/path"
_ "github.com/TecharoHQ/anubis/lib/checker/remoteaddress"
)

70
lib/checker/all_test.go Normal file
View File

@@ -0,0 +1,70 @@
package checker
import (
"net/http"
"testing"
)
func TestAll_Check(t *testing.T) {
tests := []struct {
name string
checkers []MockChecker
want bool
wantErr bool
}{
{
name: "All match",
checkers: []MockChecker{
{Result: true, Err: nil},
{Result: true, Err: nil},
},
want: true,
wantErr: false,
},
{
name: "One not match",
checkers: []MockChecker{
{Result: true, Err: nil},
{Result: false, Err: nil},
},
want: false,
wantErr: false,
},
{
name: "No match",
checkers: []MockChecker{
{Result: false, Err: nil},
{Result: false, Err: nil},
},
want: false,
wantErr: false,
},
{
name: "Error encountered",
checkers: []MockChecker{
{Result: true, Err: nil},
{Result: false, Err: http.ErrNotSupported},
},
want: false,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var all All
for _, mc := range tt.checkers {
all = append(all, mc)
}
got, err := all.Check(nil)
if (err != nil) != tt.wantErr {
t.Errorf("All.Check() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("All.Check() = %v, want %v", got, tt.want)
}
})
}
}

35
lib/checker/any.go Normal file
View File

@@ -0,0 +1,35 @@
package checker
import (
"fmt"
"net/http"
"strings"
"github.com/TecharoHQ/anubis/internal"
)
type Any []Interface
func (a Any) Check(r *http.Request) (bool, error) {
for _, c := range a {
match, err := c.Check(r)
if err != nil {
return match, err
}
if match {
return true, err // match
}
}
return false, nil // no match
}
func (a Any) Hash() string {
var sb strings.Builder
for _, c := range a {
fmt.Fprintln(&sb, c.Hash())
}
return internal.FastHash(sb.String())
}

83
lib/checker/any_test.go Normal file
View File

@@ -0,0 +1,83 @@
package checker
import (
"net/http"
"testing"
)
type MockChecker struct {
Result bool
Err error
}
func (m MockChecker) Check(r *http.Request) (bool, error) {
return m.Result, m.Err
}
func (m MockChecker) Hash() string {
return "mock-hash"
}
func TestAny_Check(t *testing.T) {
tests := []struct {
name string
checkers []MockChecker
want bool
wantErr bool
}{
{
name: "All match",
checkers: []MockChecker{
{Result: true, Err: nil},
{Result: true, Err: nil},
},
want: true,
wantErr: false,
},
{
name: "One match",
checkers: []MockChecker{
{Result: false, Err: nil},
{Result: true, Err: nil},
},
want: true,
wantErr: false,
},
{
name: "No match",
checkers: []MockChecker{
{Result: false, Err: nil},
{Result: false, Err: nil},
},
want: false,
wantErr: false,
},
{
name: "Error encountered",
checkers: []MockChecker{
{Result: false, Err: nil},
{Result: false, Err: http.ErrNotSupported},
},
want: false,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var any Any
for _, mc := range tt.checkers {
any = append(any, mc)
}
got, err := any.Check(nil)
if (err != nil) != tt.wantErr {
t.Errorf("Any.Check() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("Any.Check() = %v, want %v", got, tt.want)
}
})
}
}

17
lib/checker/checker.go Normal file
View File

@@ -0,0 +1,17 @@
// Package checker defines the Checker interface and a helper utility to avoid import cycles.
package checker
import (
"errors"
"net/http"
)
var (
ErrUnparseableConfig = errors.New("checker: config is unparseable")
ErrInvalidConfig = errors.New("checker: config is invalid")
)
type Interface interface {
Check(*http.Request) (matches bool, err error)
Hash() string
}

View File

@@ -1,43 +1,44 @@
package policy package expression
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"github.com/TecharoHQ/anubis/internal" "github.com/TecharoHQ/anubis/internal"
"github.com/TecharoHQ/anubis/lib/policy/config" "github.com/TecharoHQ/anubis/lib/checker/expression/environment"
"github.com/TecharoHQ/anubis/lib/policy/expressions"
"github.com/google/cel-go/cel" "github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types"
) )
type CELChecker struct { type Checker struct {
program cel.Program program cel.Program
src string src string
hash string
} }
func NewCELChecker(cfg *config.ExpressionOrList) (*CELChecker, error) { func New(cfg *Config) (*Checker, error) {
env, err := expressions.BotEnvironment() env, err := environment.Bot()
if err != nil { if err != nil {
return nil, err return nil, err
} }
program, err := expressions.Compile(env, cfg.String()) program, err := environment.Compile(env, cfg.String())
if err != nil { if err != nil {
return nil, fmt.Errorf("can't compile CEL program: %w", err) return nil, fmt.Errorf("can't compile CEL program: %w", err)
} }
return &CELChecker{ return &Checker{
src: cfg.String(), src: cfg.String(),
hash: internal.FastHash(cfg.String()),
program: program, program: program,
}, nil }, nil
} }
func (cc *CELChecker) Hash() string { func (cc *Checker) Hash() string {
return internal.FastHash(cc.src) return cc.hash
} }
func (cc *CELChecker) Check(r *http.Request) (bool, error) { func (cc *Checker) Check(r *http.Request) (bool, error) {
result, _, err := cc.program.ContextEval(r.Context(), &CELRequest{r}) result, _, err := cc.program.ContextEval(r.Context(), &CELRequest{r})
if err != nil { if err != nil {
@@ -70,15 +71,15 @@ func (cr *CELRequest) ResolveName(name string) (any, bool) {
case "path": case "path":
return cr.URL.Path, true return cr.URL.Path, true
case "query": case "query":
return expressions.URLValues{Values: cr.URL.Query()}, true return URLValues{Values: cr.URL.Query()}, true
case "headers": case "headers":
return expressions.HTTPHeaders{Header: cr.Header}, true return HTTPHeaders{Header: cr.Header}, true
case "load_1m": case "load_1m":
return expressions.Load1(), true return Load1(), true
case "load_5m": case "load_5m":
return expressions.Load5(), true return Load5(), true
case "load_15m": case "load_15m":
return expressions.Load15(), true return Load15(), true
default: default:
return nil, false return nil, false
} }

View File

@@ -1,4 +1,4 @@
package config package expression
import ( import (
"encoding/json" "encoding/json"
@@ -9,18 +9,18 @@ import (
) )
var ( var (
ErrExpressionOrListMustBeStringOrObject = errors.New("config: this must be a string or an object") ErrExpressionOrListMustBeStringOrObject = errors.New("expression: this must be a string or an object")
ErrExpressionEmpty = errors.New("config: this expression is empty") ErrExpressionEmpty = errors.New("expression: this expression is empty")
ErrExpressionCantHaveBoth = errors.New("config: expression block can't contain multiple expression types") ErrExpressionCantHaveBoth = errors.New("expression: expression block can't contain multiple expression types")
) )
type ExpressionOrList struct { type Config struct {
Expression string `json:"-" yaml:"-"` Expression string `json:"-" yaml:"-"`
All []string `json:"all,omitempty" yaml:"all,omitempty"` All []string `json:"all,omitempty" yaml:"all,omitempty"`
Any []string `json:"any,omitempty" yaml:"any,omitempty"` Any []string `json:"any,omitempty" yaml:"any,omitempty"`
} }
func (eol ExpressionOrList) String() string { func (eol Config) String() string {
switch { switch {
case len(eol.Expression) != 0: case len(eol.Expression) != 0:
return eol.Expression return eol.Expression
@@ -46,7 +46,7 @@ func (eol ExpressionOrList) String() string {
panic("this should not happen") panic("this should not happen")
} }
func (eol ExpressionOrList) Equal(rhs *ExpressionOrList) bool { func (eol Config) Equal(rhs *Config) bool {
if eol.Expression != rhs.Expression { if eol.Expression != rhs.Expression {
return false return false
} }
@@ -62,7 +62,7 @@ func (eol ExpressionOrList) Equal(rhs *ExpressionOrList) bool {
return true return true
} }
func (eol *ExpressionOrList) MarshalYAML() (any, error) { func (eol *Config) MarshalYAML() (any, error) {
switch { switch {
case len(eol.All) == 1 && len(eol.Any) == 0: case len(eol.All) == 1 && len(eol.Any) == 0:
eol.Expression = eol.All[0] eol.Expression = eol.All[0]
@@ -76,11 +76,11 @@ func (eol *ExpressionOrList) MarshalYAML() (any, error) {
return eol.Expression, nil return eol.Expression, nil
} }
type RawExpressionOrList ExpressionOrList type RawExpressionOrList Config
return RawExpressionOrList(*eol), nil return RawExpressionOrList(*eol), nil
} }
func (eol *ExpressionOrList) MarshalJSON() ([]byte, error) { func (eol *Config) MarshalJSON() ([]byte, error) {
switch { switch {
case len(eol.All) == 1 && len(eol.Any) == 0: case len(eol.All) == 1 && len(eol.Any) == 0:
eol.Expression = eol.All[0] eol.Expression = eol.All[0]
@@ -94,17 +94,17 @@ func (eol *ExpressionOrList) MarshalJSON() ([]byte, error) {
return json.Marshal(string(eol.Expression)) return json.Marshal(string(eol.Expression))
} }
type RawExpressionOrList ExpressionOrList type RawExpressionOrList Config
val := RawExpressionOrList(*eol) val := RawExpressionOrList(*eol)
return json.Marshal(val) return json.Marshal(val)
} }
func (eol *ExpressionOrList) UnmarshalJSON(data []byte) error { func (eol *Config) UnmarshalJSON(data []byte) error {
switch string(data[0]) { switch string(data[0]) {
case `"`: // string case `"`: // string
return json.Unmarshal(data, &eol.Expression) return json.Unmarshal(data, &eol.Expression)
case "{": // object case "{": // object
type RawExpressionOrList ExpressionOrList type RawExpressionOrList Config
var val RawExpressionOrList var val RawExpressionOrList
if err := json.Unmarshal(data, &val); err != nil { if err := json.Unmarshal(data, &val); err != nil {
return err return err
@@ -118,7 +118,7 @@ func (eol *ExpressionOrList) UnmarshalJSON(data []byte) error {
return ErrExpressionOrListMustBeStringOrObject return ErrExpressionOrListMustBeStringOrObject
} }
func (eol *ExpressionOrList) Valid() error { func (eol *Config) Valid() error {
if eol.Expression == "" && len(eol.All) == 0 && len(eol.Any) == 0 { if eol.Expression == "" && len(eol.All) == 0 && len(eol.Any) == 0 {
return ErrExpressionEmpty return ErrExpressionEmpty
} }

View File

@@ -1,4 +1,4 @@
package config package expression
import ( import (
"bytes" "bytes"
@@ -12,13 +12,13 @@ import (
func TestExpressionOrListMarshalJSON(t *testing.T) { func TestExpressionOrListMarshalJSON(t *testing.T) {
for _, tt := range []struct { for _, tt := range []struct {
name string name string
input *ExpressionOrList input *Config
output []byte output []byte
err error err error
}{ }{
{ {
name: "single expression", name: "single expression",
input: &ExpressionOrList{ input: &Config{
Expression: "true", Expression: "true",
}, },
output: []byte(`"true"`), output: []byte(`"true"`),
@@ -26,7 +26,7 @@ func TestExpressionOrListMarshalJSON(t *testing.T) {
}, },
{ {
name: "all", name: "all",
input: &ExpressionOrList{ input: &Config{
All: []string{"true", "true"}, All: []string{"true", "true"},
}, },
output: []byte(`{"all":["true","true"]}`), output: []byte(`{"all":["true","true"]}`),
@@ -34,7 +34,7 @@ func TestExpressionOrListMarshalJSON(t *testing.T) {
}, },
{ {
name: "all one", name: "all one",
input: &ExpressionOrList{ input: &Config{
All: []string{"true"}, All: []string{"true"},
}, },
output: []byte(`"true"`), output: []byte(`"true"`),
@@ -42,7 +42,7 @@ func TestExpressionOrListMarshalJSON(t *testing.T) {
}, },
{ {
name: "any", name: "any",
input: &ExpressionOrList{ input: &Config{
Any: []string{"true", "false"}, Any: []string{"true", "false"},
}, },
output: []byte(`{"any":["true","false"]}`), output: []byte(`{"any":["true","false"]}`),
@@ -50,7 +50,7 @@ func TestExpressionOrListMarshalJSON(t *testing.T) {
}, },
{ {
name: "any one", name: "any one",
input: &ExpressionOrList{ input: &Config{
Any: []string{"true"}, Any: []string{"true"},
}, },
output: []byte(`"true"`), output: []byte(`"true"`),
@@ -75,13 +75,13 @@ func TestExpressionOrListMarshalJSON(t *testing.T) {
func TestExpressionOrListMarshalYAML(t *testing.T) { func TestExpressionOrListMarshalYAML(t *testing.T) {
for _, tt := range []struct { for _, tt := range []struct {
name string name string
input *ExpressionOrList input *Config
output []byte output []byte
err error err error
}{ }{
{ {
name: "single expression", name: "single expression",
input: &ExpressionOrList{ input: &Config{
Expression: "true", Expression: "true",
}, },
output: []byte(`"true"`), output: []byte(`"true"`),
@@ -89,7 +89,7 @@ func TestExpressionOrListMarshalYAML(t *testing.T) {
}, },
{ {
name: "all", name: "all",
input: &ExpressionOrList{ input: &Config{
All: []string{"true", "true"}, All: []string{"true", "true"},
}, },
output: []byte(`all: output: []byte(`all:
@@ -99,7 +99,7 @@ func TestExpressionOrListMarshalYAML(t *testing.T) {
}, },
{ {
name: "all one", name: "all one",
input: &ExpressionOrList{ input: &Config{
All: []string{"true"}, All: []string{"true"},
}, },
output: []byte(`"true"`), output: []byte(`"true"`),
@@ -107,7 +107,7 @@ func TestExpressionOrListMarshalYAML(t *testing.T) {
}, },
{ {
name: "any", name: "any",
input: &ExpressionOrList{ input: &Config{
Any: []string{"true", "false"}, Any: []string{"true", "false"},
}, },
output: []byte(`any: output: []byte(`any:
@@ -117,7 +117,7 @@ func TestExpressionOrListMarshalYAML(t *testing.T) {
}, },
{ {
name: "any one", name: "any one",
input: &ExpressionOrList{ input: &Config{
Any: []string{"true"}, Any: []string{"true"},
}, },
output: []byte(`"true"`), output: []byte(`"true"`),
@@ -145,14 +145,14 @@ func TestExpressionOrListUnmarshalJSON(t *testing.T) {
for _, tt := range []struct { for _, tt := range []struct {
err error err error
validErr error validErr error
result *ExpressionOrList result *Config
name string name string
inp string inp string
}{ }{
{ {
name: "simple", name: "simple",
inp: `"\"User-Agent\" in headers"`, inp: `"\"User-Agent\" in headers"`,
result: &ExpressionOrList{ result: &Config{
Expression: `"User-Agent" in headers`, Expression: `"User-Agent" in headers`,
}, },
}, },
@@ -161,7 +161,7 @@ func TestExpressionOrListUnmarshalJSON(t *testing.T) {
inp: `{ inp: `{
"all": ["\"User-Agent\" in headers"] "all": ["\"User-Agent\" in headers"]
}`, }`,
result: &ExpressionOrList{ result: &Config{
All: []string{ All: []string{
`"User-Agent" in headers`, `"User-Agent" in headers`,
}, },
@@ -172,7 +172,7 @@ func TestExpressionOrListUnmarshalJSON(t *testing.T) {
inp: `{ inp: `{
"any": ["\"User-Agent\" in headers"] "any": ["\"User-Agent\" in headers"]
}`, }`,
result: &ExpressionOrList{ result: &Config{
Any: []string{ Any: []string{
`"User-Agent" in headers`, `"User-Agent" in headers`,
}, },
@@ -195,7 +195,7 @@ func TestExpressionOrListUnmarshalJSON(t *testing.T) {
}, },
} { } {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
var eol ExpressionOrList var eol Config
if err := json.Unmarshal([]byte(tt.inp), &eol); !errors.Is(err, tt.err) { if err := json.Unmarshal([]byte(tt.inp), &eol); !errors.Is(err, tt.err) {
t.Errorf("wanted unmarshal error: %v but got: %v", tt.err, err) t.Errorf("wanted unmarshal error: %v but got: %v", tt.err, err)
@@ -217,40 +217,40 @@ func TestExpressionOrListUnmarshalJSON(t *testing.T) {
func TestExpressionOrListString(t *testing.T) { func TestExpressionOrListString(t *testing.T) {
for _, tt := range []struct { for _, tt := range []struct {
name string name string
in ExpressionOrList in Config
out string out string
}{ }{
{ {
name: "single expression", name: "single expression",
in: ExpressionOrList{ in: Config{
Expression: "true", Expression: "true",
}, },
out: "true", out: "true",
}, },
{ {
name: "all", name: "all",
in: ExpressionOrList{ in: Config{
All: []string{"true"}, All: []string{"true"},
}, },
out: "( true )", out: "( true )",
}, },
{ {
name: "all with &&", name: "all with &&",
in: ExpressionOrList{ in: Config{
All: []string{"true", "true"}, All: []string{"true", "true"},
}, },
out: "( true ) && ( true )", out: "( true ) && ( true )",
}, },
{ {
name: "any", name: "any",
in: ExpressionOrList{ in: Config{
All: []string{"true"}, All: []string{"true"},
}, },
out: "( true )", out: "( true )",
}, },
{ {
name: "any with ||", name: "any with ||",
in: ExpressionOrList{ in: Config{
Any: []string{"true", "true"}, Any: []string{"true", "true"},
}, },
out: "( true ) || ( true )", out: "( true ) || ( true )",

View File

@@ -1,8 +1,7 @@
package expressions package environment
import ( import (
"math/rand/v2" "math/rand/v2"
"strings"
"github.com/google/cel-go/cel" "github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types"
@@ -11,11 +10,11 @@ import (
"github.com/google/cel-go/ext" "github.com/google/cel-go/ext"
) )
// BotEnvironment creates a new CEL environment, this is the set of // Bot creates a new CEL environment, this is the set of variables and
// variables and functions that are passed into the CEL scope so that // functions that are passed into the CEL scope so that Anubis can fail
// Anubis can fail loudly and early when something is invalid instead // loudly and early when something is invalid instead of blowing up at
// of blowing up at runtime. // runtime.
func BotEnvironment() (*cel.Env, error) { func Bot() (*cel.Env, error) {
return New( return New(
// Variables exposed to CEL programs: // Variables exposed to CEL programs:
cel.Variable("remoteAddress", cel.StringType), cel.Variable("remoteAddress", cel.StringType),
@@ -55,38 +54,17 @@ func BotEnvironment() (*cel.Env, error) {
}), }),
), ),
), ),
cel.Function("segments",
cel.Overload("segments_string_list_string",
[]*cel.Type{cel.StringType},
cel.ListType(cel.StringType),
cel.UnaryBinding(func(path ref.Val) ref.Val {
pathStrType, ok := path.(types.String)
if !ok {
return types.ValOrErr(path, "path is not a string, but is %T", path)
}
pathStr := string(pathStrType)
if !strings.HasPrefix(pathStr, "/") {
return types.ValOrErr(path, "path does not start with /")
}
pathList := strings.Split(string(pathStr), "/")[1:]
return types.NewStringList(types.DefaultTypeAdapter, pathList)
}),
),
),
) )
} }
// NewThreshold creates a new CEL environment for threshold checking. // Threshold creates a new CEL environment for threshold checking.
func ThresholdEnvironment() (*cel.Env, error) { func Threshold() (*cel.Env, error) {
return New( return New(
cel.Variable("weight", cel.IntType), cel.Variable("weight", cel.IntType),
) )
} }
// New creates a new base CEL environment.
func New(opts ...cel.EnvOption) (*cel.Env, error) { func New(opts ...cel.EnvOption) (*cel.Env, error) {
args := []cel.EnvOption{ args := []cel.EnvOption{
ext.Strings( ext.Strings(
@@ -118,7 +96,7 @@ func New(opts ...cel.EnvOption) (*cel.Env, error) {
return cel.NewEnv(args...) return cel.NewEnv(args...)
} }
// Compile takes CEL environment and syntax tree then emits an optimized // Compile takes a CEL environment and syntax tree then emits an optimized
// Program for execution. // Program for execution.
func Compile(env *cel.Env, src string) (cel.Program, error) { func Compile(env *cel.Env, src string) (cel.Program, error) {
intermediate, iss := env.Compile(src) intermediate, iss := env.Compile(src)

View File

@@ -0,0 +1,269 @@
package environment
import (
"testing"
"github.com/google/cel-go/common/types"
)
func TestBot(t *testing.T) {
env, err := Bot()
if err != nil {
t.Fatalf("failed to create bot environment: %v", err)
}
tests := []struct {
name string
expression string
headers map[string]string
expected types.Bool
description string
}{
{
name: "missing-header",
expression: `missingHeader(headers, "Missing-Header")`,
headers: map[string]string{
"User-Agent": "test-agent",
"Content-Type": "application/json",
},
expected: types.Bool(true),
description: "should return true when header is missing",
},
{
name: "existing-header",
expression: `missingHeader(headers, "User-Agent")`,
headers: map[string]string{
"User-Agent": "test-agent",
"Content-Type": "application/json",
},
expected: types.Bool(false),
description: "should return false when header exists",
},
{
name: "case-sensitive",
expression: `missingHeader(headers, "user-agent")`,
headers: map[string]string{
"User-Agent": "test-agent",
},
expected: types.Bool(true),
description: "should be case-sensitive (user-agent != User-Agent)",
},
{
name: "empty-headers",
expression: `missingHeader(headers, "Any-Header")`,
headers: map[string]string{},
expected: types.Bool(true),
description: "should return true for any header when map is empty",
},
{
name: "real-world-sec-ch-ua",
expression: `missingHeader(headers, "Sec-Ch-Ua")`,
headers: map[string]string{
"User-Agent": "curl/7.68.0",
"Accept": "*/*",
"Host": "example.com",
},
expected: types.Bool(true),
description: "should detect missing browser-specific headers from bots",
},
{
name: "browser-with-sec-ch-ua",
expression: `missingHeader(headers, "Sec-Ch-Ua")`,
headers: map[string]string{
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
"Sec-Ch-Ua": `"Chrome"; v="91", "Not A Brand"; v="99"`,
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
},
expected: types.Bool(false),
description: "should return false when browser sends Sec-Ch-Ua header",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
prog, err := Compile(env, tt.expression)
if err != nil {
t.Fatalf("failed to compile expression %q: %v", tt.expression, err)
}
result, _, err := prog.Eval(map[string]interface{}{
"headers": tt.headers,
})
if err != nil {
t.Fatalf("failed to evaluate expression %q: %v", tt.expression, err)
}
if result != tt.expected {
t.Errorf("%s: expected %v, got %v", tt.description, tt.expected, result)
}
})
}
t.Run("function-compilation", func(t *testing.T) {
src := `missingHeader(headers, "Test-Header")`
_, err := Compile(env, src)
if err != nil {
t.Fatalf("failed to compile missingHeader expression: %v", err)
}
})
}
func TestThreshold(t *testing.T) {
env, err := Threshold()
if err != nil {
t.Fatalf("failed to create threshold environment: %v", err)
}
tests := []struct {
name string
expression string
variables map[string]interface{}
expected types.Bool
description string
shouldCompile bool
}{
{
name: "weight-variable-available",
expression: `weight > 100`,
variables: map[string]interface{}{"weight": 150},
expected: types.Bool(true),
description: "should support weight variable in expressions",
shouldCompile: true,
},
{
name: "weight-variable-false-case",
expression: `weight > 100`,
variables: map[string]interface{}{"weight": 50},
expected: types.Bool(false),
description: "should correctly evaluate weight comparisons",
shouldCompile: true,
},
{
name: "missingHeader-not-available",
expression: `missingHeader(headers, "Test")`,
variables: map[string]interface{}{},
expected: types.Bool(false), // not used
description: "should not have missingHeader function available",
shouldCompile: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
prog, err := Compile(env, tt.expression)
if !tt.shouldCompile {
if err == nil {
t.Fatalf("%s: expected compilation to fail but it succeeded", tt.description)
}
return // Test passed - compilation failed as expected
}
if err != nil {
t.Fatalf("failed to compile expression %q: %v", tt.expression, err)
}
result, _, err := prog.Eval(tt.variables)
if err != nil {
t.Fatalf("failed to evaluate expression %q: %v", tt.expression, err)
}
if result != tt.expected {
t.Errorf("%s: expected %v, got %v", tt.description, tt.expected, result)
}
})
}
}
func TestNewEnvironment(t *testing.T) {
env, err := New()
if err != nil {
t.Fatalf("failed to create new environment: %v", err)
}
tests := []struct {
name string
expression string
variables map[string]interface{}
expectBool *bool // nil if we just want to test compilation or non-bool result
description string
shouldCompile bool
}{
{
name: "randInt-function-compilation",
expression: `randInt(10)`,
variables: map[string]interface{}{},
expectBool: nil, // Don't check result, just compilation
description: "should compile randInt function",
shouldCompile: true,
},
{
name: "randInt-range-validation",
expression: `randInt(10) >= 0 && randInt(10) < 10`,
variables: map[string]interface{}{},
expectBool: boolPtr(true),
description: "should return values in correct range",
shouldCompile: true,
},
{
name: "strings-extension-size",
expression: `"hello".size() == 5`,
variables: map[string]interface{}{},
expectBool: boolPtr(true),
description: "should support string extension functions",
shouldCompile: true,
},
{
name: "strings-extension-contains",
expression: `"hello world".contains("world")`,
variables: map[string]interface{}{},
expectBool: boolPtr(true),
description: "should support string contains function",
shouldCompile: true,
},
{
name: "strings-extension-startsWith",
expression: `"hello world".startsWith("hello")`,
variables: map[string]interface{}{},
expectBool: boolPtr(true),
description: "should support string startsWith function",
shouldCompile: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
prog, err := Compile(env, tt.expression)
if !tt.shouldCompile {
if err == nil {
t.Fatalf("%s: expected compilation to fail but it succeeded", tt.description)
}
return // Test passed - compilation failed as expected
}
if err != nil {
t.Fatalf("failed to compile expression %q: %v", tt.expression, err)
}
// If we only want to test compilation, skip evaluation
if tt.expectBool == nil {
return
}
result, _, err := prog.Eval(tt.variables)
if err != nil {
t.Fatalf("failed to evaluate expression %q: %v", tt.expression, err)
}
if result != types.Bool(*tt.expectBool) {
t.Errorf("%s: expected %v, got %v", tt.description, *tt.expectBool, result)
}
})
}
}
// Helper function to create bool pointers
func boolPtr(b bool) *bool {
return &b
}

View File

@@ -0,0 +1,43 @@
package expression
import (
"context"
"encoding/json"
"errors"
"github.com/TecharoHQ/anubis/lib/checker"
)
func init() {
checker.Register("expression", Factory{})
}
type Factory struct{}
func (f Factory) Build(ctx context.Context, data json.RawMessage) (checker.Interface, error) {
var fc = &Config{}
if err := json.Unmarshal([]byte(data), fc); err != nil {
return nil, errors.Join(checker.ErrUnparseableConfig, err)
}
if err := fc.Valid(); err != nil {
return nil, errors.Join(checker.ErrInvalidConfig, err)
}
return New(fc)
}
func (f Factory) Valid(ctx context.Context, data json.RawMessage) error {
var fc = &Config{}
if err := json.Unmarshal([]byte(data), fc); err != nil {
return err
}
if err := fc.Valid(); err != nil {
return err
}
return nil
}

View File

@@ -1,4 +1,4 @@
package expressions package expression
import ( import (
"net/http" "net/http"

View File

@@ -1,4 +1,4 @@
package expressions package expression
import ( import (
"net/http" "net/http"

View File

@@ -1,4 +1,4 @@
package expressions package expression
import ( import (
"context" "context"

View File

@@ -1,4 +1,4 @@
package expressions package expression
import ( import (
"errors" "errors"

View File

@@ -1,4 +1,4 @@
package expressions package expression
import ( import (
"net/url" "net/url"

View File

@@ -0,0 +1,32 @@
package headerexists
import (
"net/http"
"strings"
"github.com/TecharoHQ/anubis/internal"
"github.com/TecharoHQ/anubis/lib/checker"
)
func New(key string) checker.Interface {
return headerExistsChecker{
header: strings.TrimSpace(http.CanonicalHeaderKey(key)),
hash: internal.FastHash(key),
}
}
type headerExistsChecker struct {
header, hash string
}
func (hec headerExistsChecker) Check(r *http.Request) (bool, error) {
if r.Header.Get(hec.header) != "" {
return true, nil
}
return false, nil
}
func (hec headerExistsChecker) Hash() string {
return hec.hash
}

View File

@@ -0,0 +1,57 @@
package headerexists
import (
"encoding/json"
"fmt"
"net/http"
"testing"
)
func TestChecker(t *testing.T) {
fac := Factory{}
for _, tt := range []struct {
name string
header string
reqHeader string
ok bool
}{
{
name: "match",
header: "Authorization",
reqHeader: "Authorization",
ok: true,
},
{
name: "not_match",
header: "Authorization",
reqHeader: "Authentication",
},
} {
t.Run(tt.name, func(t *testing.T) {
hec, err := fac.Build(t.Context(), json.RawMessage(fmt.Sprintf("%q", tt.header)))
if err != nil {
t.Fatal(err)
}
t.Log(hec.Hash())
r, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatalf("can't make request: %v", err)
}
r.Header.Set(tt.reqHeader, "hunter2")
ok, err := hec.Check(r)
if tt.ok != ok {
t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
}
if err != nil {
t.Errorf("err: %v", err)
}
})
}
}

View File

@@ -0,0 +1,40 @@
package headerexists
import (
"context"
"encoding/json"
"fmt"
"net/http"
"github.com/TecharoHQ/anubis/lib/checker"
)
type Factory struct{}
func (f Factory) Build(ctx context.Context, data json.RawMessage) (checker.Interface, error) {
var headerName string
if err := json.Unmarshal([]byte(data), &headerName); err != nil {
return nil, fmt.Errorf("%w: want string", checker.ErrUnparseableConfig)
}
if err := f.Valid(ctx, data); err != nil {
return nil, err
}
return New(http.CanonicalHeaderKey(headerName)), nil
}
func (Factory) Valid(ctx context.Context, data json.RawMessage) error {
var headerName string
if err := json.Unmarshal([]byte(data), &headerName); err != nil {
return fmt.Errorf("%w: want string", checker.ErrUnparseableConfig)
}
if headerName == "" {
return fmt.Errorf("%w: string must not be empty", checker.ErrInvalidConfig)
}
return nil
}

View File

@@ -0,0 +1,60 @@
package headerexists
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
func TestFactoryGood(t *testing.T) {
files, err := os.ReadDir("./testdata/good")
if err != nil {
t.Fatal(err)
}
fac := Factory{}
for _, fname := range files {
t.Run(fname.Name(), func(t *testing.T) {
data, err := os.ReadFile(filepath.Join("testdata", "good", fname.Name()))
if err != nil {
t.Fatal(err)
}
if err := fac.Valid(t.Context(), json.RawMessage(data)); err != nil {
t.Fatal(err)
}
})
}
}
func TestFactoryBad(t *testing.T) {
files, err := os.ReadDir("./testdata/bad")
if err != nil {
t.Fatal(err)
}
fac := Factory{}
for _, fname := range files {
t.Run(fname.Name(), func(t *testing.T) {
data, err := os.ReadFile(filepath.Join("testdata", "bad", fname.Name()))
if err != nil {
t.Fatal(err)
}
t.Run("Build", func(t *testing.T) {
if _, err := fac.Build(t.Context(), json.RawMessage(data)); err == nil {
t.Fatal(err)
}
})
t.Run("Valid", func(t *testing.T) {
if err := fac.Valid(t.Context(), json.RawMessage(data)); err == nil {
t.Fatal(err)
}
})
})
}
}

View File

@@ -0,0 +1 @@
""

View File

@@ -0,0 +1 @@
{}

View File

@@ -0,0 +1 @@
"Authorization"

View File

@@ -0,0 +1,46 @@
package headermatches
import (
"context"
"encoding/json"
"net/http"
"regexp"
"github.com/TecharoHQ/anubis/lib/checker"
)
type Checker struct {
header string
regexp *regexp.Regexp
hash string
}
func (c *Checker) Check(r *http.Request) (bool, error) {
if c.regexp.MatchString(r.Header.Get(c.header)) {
return true, nil
}
return false, nil
}
func (c *Checker) Hash() string {
return c.hash
}
func New(key, valueRex string) (checker.Interface, error) {
fc := fileConfig{
Header: key,
ValueRegex: valueRex,
}
if err := fc.Valid(); err != nil {
return nil, err
}
data, err := json.Marshal(fc)
if err != nil {
return nil, err
}
return Factory{}.Build(context.Background(), json.RawMessage(data))
}

View File

@@ -0,0 +1,98 @@
package headermatches
import (
"encoding/json"
"errors"
"net/http"
"testing"
)
func TestChecker(t *testing.T) {
}
func TestHeaderMatchesChecker(t *testing.T) {
fac := Factory{}
for _, tt := range []struct {
err error
name string
header string
rexStr string
reqHeaderKey string
reqHeaderValue string
ok bool
}{
{
name: "match",
header: "Cf-Worker",
rexStr: ".*",
reqHeaderKey: "Cf-Worker",
reqHeaderValue: "true",
ok: true,
err: nil,
},
{
name: "not_match",
header: "Cf-Worker",
rexStr: "false",
reqHeaderKey: "Cf-Worker",
reqHeaderValue: "true",
ok: false,
err: nil,
},
{
name: "not_present",
header: "Cf-Worker",
rexStr: "foobar",
reqHeaderKey: "Something-Else",
reqHeaderValue: "true",
ok: false,
err: nil,
},
{
name: "invalid_regex",
rexStr: "a(b",
err: ErrInvalidRegex,
},
} {
t.Run(tt.name, func(t *testing.T) {
fc := fileConfig{
Header: tt.header,
ValueRegex: tt.rexStr,
}
data, err := json.Marshal(fc)
if err != nil {
t.Fatal(err)
}
hmc, err := fac.Build(t.Context(), json.RawMessage(data))
if err != nil && !errors.Is(err, tt.err) {
t.Fatalf("creating HeaderMatchesChecker failed")
}
if tt.err != nil && hmc == nil {
return
}
t.Log(hmc.Hash())
r, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatalf("can't make request: %v", err)
}
r.Header.Set(tt.reqHeaderKey, tt.reqHeaderValue)
ok, err := hmc.Check(r)
if tt.ok != ok {
t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
}
if err != nil && tt.err != nil && !errors.Is(err, tt.err) {
t.Errorf("err: %v, wanted: %v", err, tt.err)
}
})
}
}

View File

@@ -0,0 +1,44 @@
package headermatches
import (
"errors"
"fmt"
"regexp"
)
var (
ErrNoHeader = errors.New("headermatches: no header is configured")
ErrNoValueRegex = errors.New("headermatches: no value regex is configured")
ErrInvalidRegex = errors.New("headermatches: value regex is invalid")
)
type fileConfig struct {
Header string `json:"header" yaml:"header"`
ValueRegex string `json:"value_regex" yaml:"value_regex"`
}
func (fc fileConfig) String() string {
return fmt.Sprintf("header=%q value_regex=%q", fc.Header, fc.ValueRegex)
}
func (fc fileConfig) Valid() error {
var errs []error
if fc.Header == "" {
errs = append(errs, ErrNoHeader)
}
if fc.ValueRegex == "" {
errs = append(errs, ErrNoValueRegex)
}
if _, err := regexp.Compile(fc.ValueRegex); err != nil {
errs = append(errs, ErrInvalidRegex, err)
}
if len(errs) != 0 {
return errors.Join(errs...)
}
return nil
}

View File

@@ -0,0 +1,55 @@
package headermatches
import (
"errors"
"testing"
)
func TestFileConfigValid(t *testing.T) {
for _, tt := range []struct {
name, description string
in fileConfig
err error
}{
{
name: "simple happy",
description: "the most common usecase",
in: fileConfig{
Header: "User-Agent",
ValueRegex: ".*",
},
},
{
name: "no header",
description: "Header must be set, it is not",
in: fileConfig{
ValueRegex: ".*",
},
err: ErrNoHeader,
},
{
name: "no value regex",
description: "ValueRegex must be set, it is not",
in: fileConfig{
Header: "User-Agent",
},
err: ErrNoValueRegex,
},
{
name: "invalid regex",
description: "the user wrote an invalid value regular expression",
in: fileConfig{
Header: "User-Agent",
ValueRegex: "[a-z",
},
err: ErrInvalidRegex,
},
} {
t.Run(tt.name, func(t *testing.T) {
if err := tt.in.Valid(); !errors.Is(err, tt.err) {
t.Log(tt.description)
t.Fatal(err)
}
})
}
}

View File

@@ -0,0 +1,66 @@
package headermatches
import (
"context"
"encoding/json"
"errors"
"net/http"
"regexp"
"github.com/TecharoHQ/anubis/internal"
"github.com/TecharoHQ/anubis/lib/checker"
)
func init() {
checker.Register("header_matches", Factory{})
checker.Register("user_agent", Factory{defaultHeader: "User-Agent"})
}
type Factory struct {
defaultHeader string
}
func (f Factory) Build(ctx context.Context, data json.RawMessage) (checker.Interface, error) {
var fc fileConfig
if f.defaultHeader != "" {
fc.Header = f.defaultHeader
}
if err := json.Unmarshal([]byte(data), &fc); err != nil {
return nil, errors.Join(checker.ErrUnparseableConfig, err)
}
if err := fc.Valid(); err != nil {
return nil, errors.Join(checker.ErrInvalidConfig, err)
}
valueRex, err := regexp.Compile(fc.ValueRegex)
if err != nil {
return nil, errors.Join(ErrInvalidRegex, err)
}
return &Checker{
header: http.CanonicalHeaderKey(fc.Header),
regexp: valueRex,
hash: internal.FastHash(fc.String()),
}, nil
}
func (f Factory) Valid(ctx context.Context, data json.RawMessage) error {
var fc fileConfig
if f.defaultHeader != "" {
fc.Header = f.defaultHeader
}
if err := json.Unmarshal([]byte(data), &fc); err != nil {
return err
}
if err := fc.Valid(); err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,52 @@
package headermatches
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
func TestFactoryGood(t *testing.T) {
files, err := os.ReadDir("./testdata/good")
if err != nil {
t.Fatal(err)
}
fac := Factory{}
for _, fname := range files {
t.Run(fname.Name(), func(t *testing.T) {
data, err := os.ReadFile(filepath.Join("testdata", "good", fname.Name()))
if err != nil {
t.Fatal(err)
}
if err := fac.Valid(t.Context(), json.RawMessage(data)); err != nil {
t.Fatal(err)
}
})
}
}
func TestFactoryBad(t *testing.T) {
files, err := os.ReadDir("./testdata/bad")
if err != nil {
t.Fatal(err)
}
fac := Factory{}
for _, fname := range files {
t.Run(fname.Name(), func(t *testing.T) {
data, err := os.ReadFile(filepath.Join("testdata", "bad", fname.Name()))
if err != nil {
t.Fatal(err)
}
if err := fac.Valid(t.Context(), json.RawMessage(data)); err == nil {
t.Fatal(err)
}
})
}
}

View File

@@ -0,0 +1 @@
}

View File

@@ -0,0 +1,4 @@
{
"header": "User-Agent",
"value_regex": "a(b"
}

View File

@@ -0,0 +1,3 @@
{
"value_regex": "PaleMoon"
}

View File

@@ -0,0 +1,3 @@
{
"header": "User-Agent"
}

View File

@@ -0,0 +1 @@
{}

View File

@@ -0,0 +1,4 @@
{
"header": "User-Agent",
"value_regex": "PaleMoon"
}

View File

@@ -0,0 +1,35 @@
package headermatches
import (
"context"
"encoding/json"
"github.com/TecharoHQ/anubis/lib/checker"
)
func ValidUserAgent(valueRex string) error {
fc := fileConfig{
Header: "User-Agent",
ValueRegex: valueRex,
}
return fc.Valid()
}
func NewUserAgent(valueRex string) (checker.Interface, error) {
fc := fileConfig{
Header: "User-Agent",
ValueRegex: valueRex,
}
if err := fc.Valid(); err != nil {
return nil, err
}
data, err := json.Marshal(fc)
if err != nil {
return nil, err
}
return Factory{}.Build(context.Background(), json.RawMessage(data))
}

View File

@@ -0,0 +1,37 @@
package path
import (
"fmt"
"net/http"
"regexp"
"strings"
"github.com/TecharoHQ/anubis"
"github.com/TecharoHQ/anubis/internal"
"github.com/TecharoHQ/anubis/lib/checker"
)
func New(rexStr string) (checker.Interface, error) {
rex, err := regexp.Compile(strings.TrimSpace(rexStr))
if err != nil {
return nil, fmt.Errorf("%w: regex %s failed parse: %w", anubis.ErrMisconfiguration, rexStr, err)
}
return &Checker{rex, internal.FastHash(rexStr)}, nil
}
type Checker struct {
regexp *regexp.Regexp
hash string
}
func (c *Checker) Check(r *http.Request) (bool, error) {
if c.regexp.MatchString(r.URL.Path) {
return true, nil
}
return false, nil
}
func (c *Checker) Hash() string {
return c.hash
}

View File

@@ -0,0 +1,90 @@
package path
import (
"encoding/json"
"errors"
"net/http"
"testing"
)
func TestChecker(t *testing.T) {
fac := Factory{}
for _, tt := range []struct {
err error
name string
rexStr string
reqPath string
ok bool
}{
{
name: "match",
rexStr: "^/api/.*",
reqPath: "/api/v1/users",
ok: true,
err: nil,
},
{
name: "not_match",
rexStr: "^/api/.*",
reqPath: "/static/index.html",
ok: false,
err: nil,
},
{
name: "wildcard_match",
rexStr: ".*\\.json$",
reqPath: "/data/config.json",
ok: true,
err: nil,
},
{
name: "wildcard_not_match",
rexStr: ".*\\.json$",
reqPath: "/data/config.yaml",
ok: false,
err: nil,
},
{
name: "invalid_regex",
rexStr: "a(b",
err: ErrInvalidRegex,
},
} {
t.Run(tt.name, func(t *testing.T) {
fc := fileConfig{
Regex: tt.rexStr,
}
data, err := json.Marshal(fc)
if err != nil {
t.Fatal(err)
}
pc, err := fac.Build(t.Context(), json.RawMessage(data))
if err != nil && !errors.Is(err, tt.err) {
t.Fatalf("creating PathChecker failed")
}
if tt.err != nil && pc == nil {
return
}
t.Log(pc.Hash())
r, err := http.NewRequest(http.MethodGet, tt.reqPath, nil)
if err != nil {
t.Fatalf("can't make request: %v", err)
}
ok, err := pc.Check(r)
if tt.ok != ok {
t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
}
if err != nil && tt.err != nil && !errors.Is(err, tt.err) {
t.Errorf("err: %v, wanted: %v", err, tt.err)
}
})
}
}

View File

@@ -0,0 +1,38 @@
package path
import (
"errors"
"fmt"
"regexp"
)
var (
ErrNoRegex = errors.New("path: no regex is configured")
ErrInvalidRegex = errors.New("path: regex is invalid")
)
type fileConfig struct {
Regex string `json:"regex" yaml:"regex"`
}
func (fc fileConfig) String() string {
return fmt.Sprintf("regex=%q", fc.Regex)
}
func (fc fileConfig) Valid() error {
var errs []error
if fc.Regex == "" {
errs = append(errs, ErrNoRegex)
}
if _, err := regexp.Compile(fc.Regex); err != nil {
errs = append(errs, ErrInvalidRegex, err)
}
if len(errs) != 0 {
return errors.Join(errs...)
}
return nil
}

View File

@@ -0,0 +1,50 @@
package path
import (
"errors"
"testing"
)
func TestFileConfigValid(t *testing.T) {
for _, tt := range []struct {
name, description string
in fileConfig
err error
}{
{
name: "simple happy",
description: "the most common usecase",
in: fileConfig{
Regex: "^/api/.*",
},
},
{
name: "wildcard match",
description: "match files with specific extension",
in: fileConfig{
Regex: ".*[.]json$",
},
},
{
name: "no regex",
description: "Regex must be set, it is not",
in: fileConfig{},
err: ErrNoRegex,
},
{
name: "invalid regex",
description: "the user wrote an invalid regular expression",
in: fileConfig{
Regex: "[a-z",
},
err: ErrInvalidRegex,
},
} {
t.Run(tt.name, func(t *testing.T) {
if err := tt.in.Valid(); !errors.Is(err, tt.err) {
t.Log(tt.description)
t.Fatalf("got %v, wanted %v", err, tt.err)
}
})
}
}

View File

@@ -0,0 +1,58 @@
package path
import (
"context"
"encoding/json"
"errors"
"regexp"
"strings"
"github.com/TecharoHQ/anubis/internal"
"github.com/TecharoHQ/anubis/lib/checker"
)
func init() {
checker.Register("path", Factory{})
}
type Factory struct{}
func (f Factory) Build(ctx context.Context, data json.RawMessage) (checker.Interface, error) {
var fc fileConfig
if err := json.Unmarshal([]byte(data), &fc); err != nil {
return nil, errors.Join(checker.ErrUnparseableConfig, err)
}
if err := fc.Valid(); err != nil {
return nil, errors.Join(checker.ErrInvalidConfig, err)
}
pathRex, err := regexp.Compile(strings.TrimSpace(fc.Regex))
if err != nil {
return nil, errors.Join(ErrInvalidRegex, err)
}
return &Checker{
regexp: pathRex,
hash: internal.FastHash(fc.String()),
}, nil
}
func (f Factory) Valid(ctx context.Context, data json.RawMessage) error {
var fc fileConfig
if err := json.Unmarshal([]byte(data), &fc); err != nil {
return errors.Join(checker.ErrUnparseableConfig, err)
}
return fc.Valid()
}
func Valid(pathRex string) error {
fc := fileConfig{
Regex: pathRex,
}
return fc.Valid()
}

View File

@@ -0,0 +1,52 @@
package path
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
func TestFactoryGood(t *testing.T) {
files, err := os.ReadDir("./testdata/good")
if err != nil {
t.Fatal(err)
}
fac := Factory{}
for _, fname := range files {
t.Run(fname.Name(), func(t *testing.T) {
data, err := os.ReadFile(filepath.Join("testdata", "good", fname.Name()))
if err != nil {
t.Fatal(err)
}
if err := fac.Valid(t.Context(), json.RawMessage(data)); err != nil {
t.Fatal(err)
}
})
}
}
func TestFactoryBad(t *testing.T) {
files, err := os.ReadDir("./testdata/bad")
if err != nil {
t.Fatal(err)
}
fac := Factory{}
for _, fname := range files {
t.Run(fname.Name(), func(t *testing.T) {
data, err := os.ReadFile(filepath.Join("testdata", "bad", fname.Name()))
if err != nil {
t.Fatal(err)
}
if err := fac.Valid(t.Context(), json.RawMessage(data)); err == nil {
t.Fatal("expected validation to fail")
}
})
}
}

View File

@@ -0,0 +1,3 @@
{
"regex": "a(b"
}

View File

@@ -0,0 +1 @@
{}

View File

@@ -0,0 +1,3 @@
{
"regex": "^/api/.*"
}

View File

@@ -0,0 +1,3 @@
{
"regex": ".*\\.json$"
}

43
lib/checker/registry.go Normal file
View File

@@ -0,0 +1,43 @@
package checker
import (
"context"
"encoding/json"
"sort"
"sync"
)
type Factory interface {
Build(context.Context, json.RawMessage) (Interface, error)
Valid(context.Context, json.RawMessage) error
}
var (
registry map[string]Factory = map[string]Factory{}
regLock sync.RWMutex
)
func Register(name string, factory Factory) {
regLock.Lock()
defer regLock.Unlock()
registry[name] = factory
}
func Get(name string) (Factory, bool) {
regLock.RLock()
defer regLock.RUnlock()
result, ok := registry[name]
return result, ok
}
func Methods() []string {
regLock.RLock()
defer regLock.RUnlock()
var result []string
for method := range registry {
result = append(result, method)
}
sort.Strings(result)
return result
}

View File

@@ -0,0 +1,127 @@
package remoteaddress
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/netip"
"github.com/TecharoHQ/anubis"
"github.com/TecharoHQ/anubis/internal"
"github.com/TecharoHQ/anubis/lib/checker"
"github.com/gaissmai/bart"
)
var (
ErrNoRemoteAddresses = errors.New("remoteaddress: no remote addresses defined")
ErrInvalidCIDR = errors.New("remoteaddress: invalid CIDR")
)
func init() {
checker.Register("remote_address", Factory{})
}
type Factory struct{}
func (Factory) Valid(_ context.Context, inp json.RawMessage) error {
var fc fileConfig
if err := json.Unmarshal([]byte(inp), &fc); err != nil {
return fmt.Errorf("%w: %w", checker.ErrUnparseableConfig, err)
}
if err := fc.Valid(); err != nil {
return err
}
return nil
}
func (Factory) Build(_ context.Context, inp json.RawMessage) (checker.Interface, error) {
c := struct {
RemoteAddr []netip.Prefix `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"`
}{}
if err := json.Unmarshal([]byte(inp), &c); err != nil {
return nil, fmt.Errorf("%w: %w", checker.ErrUnparseableConfig, err)
}
table := new(bart.Lite)
for _, cidr := range c.RemoteAddr {
table.Insert(cidr)
}
return &RemoteAddrChecker{
prefixTable: table,
hash: internal.FastHash(string(inp)),
}, nil
}
type fileConfig struct {
RemoteAddr []string `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"`
}
func (fc fileConfig) Valid() error {
var errs []error
if len(fc.RemoteAddr) == 0 {
errs = append(errs, ErrNoRemoteAddresses)
}
for _, cidr := range fc.RemoteAddr {
if _, err := netip.ParsePrefix(cidr); err != nil {
errs = append(errs, fmt.Errorf("%w: cidr %q is invalid: %w", ErrInvalidCIDR, cidr, err))
}
}
if len(errs) != 0 {
return fmt.Errorf("%w: %w", checker.ErrInvalidConfig, errors.Join(errs...))
}
return nil
}
func Valid(cidrs []string) error {
fc := fileConfig{
RemoteAddr: cidrs,
}
return fc.Valid()
}
func New(cidrs []string) (checker.Interface, error) {
fc := fileConfig{
RemoteAddr: cidrs,
}
data, err := json.Marshal(fc)
if err != nil {
return nil, err
}
return Factory{}.Build(context.Background(), json.RawMessage(data))
}
type RemoteAddrChecker struct {
prefixTable *bart.Lite
hash string
}
func (rac *RemoteAddrChecker) Check(r *http.Request) (bool, error) {
host := r.Header.Get("X-Real-Ip")
if host == "" {
return false, fmt.Errorf("%w: header X-Real-Ip is not set", anubis.ErrMisconfiguration)
}
addr, err := netip.ParseAddr(host)
if err != nil {
return false, fmt.Errorf("%w: %s is not an IP address: %w", anubis.ErrMisconfiguration, host, err)
}
return rac.prefixTable.Contains(addr), nil
}
func (rac *RemoteAddrChecker) Hash() string {
return rac.hash
}

View File

@@ -0,0 +1,138 @@
package remoteaddress_test
import (
_ "embed"
"encoding/json"
"errors"
"net/http"
"testing"
"github.com/TecharoHQ/anubis/lib/checker"
"github.com/TecharoHQ/anubis/lib/checker/remoteaddress"
)
func TestFactoryIsCheckerFactory(t *testing.T) {
if _, ok := (any(remoteaddress.Factory{})).(checker.Factory); !ok {
t.Fatal("Factory is not an instance of checker.Factory")
}
}
func TestFactoryValidateConfig(t *testing.T) {
f := remoteaddress.Factory{}
for _, tt := range []struct {
name string
data []byte
err error
}{
{
name: "basic valid",
data: []byte(`{
"remote_addresses": [
"1.1.1.1/32"
]
}`),
},
{
name: "not json",
data: []byte(`]`),
err: checker.ErrUnparseableConfig,
},
{
name: "no cidr",
data: []byte(`{
"remote_addresses": []
}`),
err: remoteaddress.ErrNoRemoteAddresses,
},
{
name: "bad cidr",
data: []byte(`{
"remote_addresses": [
"according to all laws of aviation"
]
}`),
err: remoteaddress.ErrInvalidCIDR,
},
} {
t.Run(tt.name, func(t *testing.T) {
data := json.RawMessage(tt.data)
if err := f.Valid(t.Context(), data); !errors.Is(err, tt.err) {
t.Logf("want: %v", tt.err)
t.Logf("got: %v", err)
t.Fatal("validation didn't do what was expected")
}
})
}
}
func TestFactoryCreate(t *testing.T) {
f := remoteaddress.Factory{}
for _, tt := range []struct {
name string
data []byte
err error
ip string
match bool
}{
{
name: "basic valid",
data: []byte(`{
"remote_addresses": [
"1.1.1.1/32"
]
}`),
ip: "1.1.1.1",
match: true,
},
{
name: "bad cidr",
data: []byte(`{
"remote_addresses": [
"according to all laws of aviation"
]
}`),
err: checker.ErrUnparseableConfig,
},
} {
t.Run(tt.name, func(t *testing.T) {
data := json.RawMessage(tt.data)
impl, err := f.Build(t.Context(), data)
if !errors.Is(err, tt.err) {
t.Logf("want: %v", tt.err)
t.Logf("got: %v", err)
t.Fatal("creation didn't do what was expected")
}
if tt.err != nil {
return
}
r, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatalf("can't make request: %v", err)
}
if tt.ip != "" {
r.Header.Add("X-Real-Ip", tt.ip)
}
match, err := impl.Check(r)
if tt.match != match {
t.Errorf("match: %v, wanted: %v", match, tt.match)
}
if err != nil && tt.err != nil && !errors.Is(err, tt.err) {
t.Errorf("err: %v, wanted: %v", err, tt.err)
}
if impl.Hash() == "" {
t.Error("hash method returns empty string")
}
})
}
}

View File

@@ -0,0 +1,5 @@
{
"remote_addresses": [
"according to all laws of aviation"
]
}

View File

@@ -0,0 +1,3 @@
{
"remote_addresses": []
}

View File

@@ -0,0 +1 @@
]

View File

@@ -0,0 +1,5 @@
{
"remote_addresses": [
"1.1.1.1/32"
]
}

View File

@@ -4,12 +4,12 @@ import (
"fmt" "fmt"
"github.com/TecharoHQ/anubis/internal" "github.com/TecharoHQ/anubis/internal"
"github.com/TecharoHQ/anubis/lib/policy/checker" "github.com/TecharoHQ/anubis/lib/checker"
"github.com/TecharoHQ/anubis/lib/policy/config" "github.com/TecharoHQ/anubis/lib/policy/config"
) )
type Bot struct { type Bot struct {
Rules checker.Impl Rules checker.Interface
Challenge *config.ChallengeRules Challenge *config.ChallengeRules
Weight *config.Weight Weight *config.Weight
Name string Name string

View File

@@ -3,153 +3,39 @@ package policy
import ( import (
"errors" "errors"
"fmt" "fmt"
"net/http" "sort"
"net/netip"
"regexp"
"strings" "strings"
"github.com/TecharoHQ/anubis/internal" "github.com/TecharoHQ/anubis/lib/checker"
"github.com/TecharoHQ/anubis/lib/policy/checker" "github.com/TecharoHQ/anubis/lib/checker/headerexists"
"github.com/gaissmai/bart" "github.com/TecharoHQ/anubis/lib/checker/headermatches"
) )
var ( func NewHeadersChecker(headermap map[string]string) (checker.Interface, error) {
ErrMisconfiguration = errors.New("[unexpected] policy: administrator misconfiguration") var result checker.All
)
type RemoteAddrChecker struct {
prefixTable *bart.Lite
hash string
}
func NewRemoteAddrChecker(cidrs []string) (checker.Impl, error) {
table := new(bart.Lite)
for _, cidr := range cidrs {
prefix, err := netip.ParsePrefix(cidr)
if err != nil {
return nil, fmt.Errorf("%w: range %s not parsing: %w", ErrMisconfiguration, cidr, err)
}
table.Insert(prefix)
}
return &RemoteAddrChecker{
prefixTable: table,
hash: internal.FastHash(strings.Join(cidrs, ",")),
}, nil
}
func (rac *RemoteAddrChecker) Check(r *http.Request) (bool, error) {
host := r.Header.Get("X-Real-Ip")
if host == "" {
return false, fmt.Errorf("%w: header X-Real-Ip is not set", ErrMisconfiguration)
}
addr, err := netip.ParseAddr(host)
if err != nil {
return false, fmt.Errorf("%w: %s is not an IP address: %w", ErrMisconfiguration, host, err)
}
return rac.prefixTable.Contains(addr), nil
}
func (rac *RemoteAddrChecker) Hash() string {
return rac.hash
}
type HeaderMatchesChecker struct {
header string
regexp *regexp.Regexp
hash string
}
func NewUserAgentChecker(rexStr string) (checker.Impl, error) {
return NewHeaderMatchesChecker("User-Agent", rexStr)
}
func NewHeaderMatchesChecker(header, rexStr string) (checker.Impl, error) {
rex, err := regexp.Compile(strings.TrimSpace(rexStr))
if err != nil {
return nil, fmt.Errorf("%w: regex %s failed parse: %w", ErrMisconfiguration, rexStr, err)
}
return &HeaderMatchesChecker{strings.TrimSpace(header), rex, internal.FastHash(header + ": " + rexStr)}, nil
}
func (hmc *HeaderMatchesChecker) Check(r *http.Request) (bool, error) {
if hmc.regexp.MatchString(r.Header.Get(hmc.header)) {
return true, nil
}
return false, nil
}
func (hmc *HeaderMatchesChecker) Hash() string {
return hmc.hash
}
type PathChecker struct {
regexp *regexp.Regexp
hash string
}
func NewPathChecker(rexStr string) (checker.Impl, error) {
rex, err := regexp.Compile(strings.TrimSpace(rexStr))
if err != nil {
return nil, fmt.Errorf("%w: regex %s failed parse: %w", ErrMisconfiguration, rexStr, err)
}
return &PathChecker{rex, internal.FastHash(rexStr)}, nil
}
func (pc *PathChecker) Check(r *http.Request) (bool, error) {
if pc.regexp.MatchString(r.URL.Path) {
return true, nil
}
return false, nil
}
func (pc *PathChecker) Hash() string {
return pc.hash
}
func NewHeaderExistsChecker(key string) checker.Impl {
return headerExistsChecker{strings.TrimSpace(key)}
}
type headerExistsChecker struct {
header string
}
func (hec headerExistsChecker) Check(r *http.Request) (bool, error) {
if r.Header.Get(hec.header) != "" {
return true, nil
}
return false, nil
}
func (hec headerExistsChecker) Hash() string {
return internal.FastHash(hec.header)
}
func NewHeadersChecker(headermap map[string]string) (checker.Impl, error) {
var result checker.List
var errs []error var errs []error
for key, rexStr := range headermap { var keys []string
for key := range headermap {
keys = append(keys, key)
}
sort.Strings(keys)
for _, key := range keys {
rexStr := headermap[key]
if rexStr == ".*" { if rexStr == ".*" {
result = append(result, headerExistsChecker{strings.TrimSpace(key)}) result = append(result, headerexists.New(strings.TrimSpace(key)))
continue continue
} }
rex, err := regexp.Compile(strings.TrimSpace(rexStr)) c, err := headermatches.New(key, rexStr)
if err != nil { if err != nil {
errs = append(errs, fmt.Errorf("while compiling header %s regex %s: %w", key, rexStr, err)) errs = append(errs, fmt.Errorf("while parsing header %s regex %s: %w", key, rexStr, err))
continue continue
} }
result = append(result, &HeaderMatchesChecker{key, rex, internal.FastHash(key + ": " + rexStr)}) result = append(result, c)
} }
if len(errs) != 0 { if len(errs) != 0 {

View File

@@ -1,41 +0,0 @@
// Package checker defines the Checker interface and a helper utility to avoid import cycles.
package checker
import (
"fmt"
"net/http"
"strings"
"github.com/TecharoHQ/anubis/internal"
)
type Impl interface {
Check(*http.Request) (bool, error)
Hash() string
}
type List []Impl
func (l List) Check(r *http.Request) (bool, error) {
for _, c := range l {
ok, err := c.Check(r)
if err != nil {
return ok, err
}
if ok {
return ok, nil
}
}
return false, nil
}
func (l List) Hash() string {
var sb strings.Builder
for _, c := range l {
fmt.Fprintln(&sb, c.Hash())
}
return internal.FastHash(sb.String())
}

View File

@@ -1,200 +0,0 @@
package policy
import (
"errors"
"net/http"
"testing"
)
func TestRemoteAddrChecker(t *testing.T) {
for _, tt := range []struct {
err error
name string
ip string
cidrs []string
ok bool
}{
{
name: "match_ipv4",
cidrs: []string{"0.0.0.0/0"},
ip: "1.1.1.1",
ok: true,
err: nil,
},
{
name: "match_ipv6",
cidrs: []string{"::/0"},
ip: "cafe:babe::",
ok: true,
err: nil,
},
{
name: "not_match_ipv4",
cidrs: []string{"1.1.1.1/32"},
ip: "1.1.1.2",
ok: false,
err: nil,
},
{
name: "not_match_ipv6",
cidrs: []string{"cafe:babe::/128"},
ip: "cafe:babe:4::/128",
ok: false,
err: nil,
},
{
name: "no_ip_set",
cidrs: []string{"::/0"},
ok: false,
err: ErrMisconfiguration,
},
{
name: "invalid_ip",
cidrs: []string{"::/0"},
ip: "According to all natural laws of aviation",
ok: false,
err: ErrMisconfiguration,
},
} {
t.Run(tt.name, func(t *testing.T) {
rac, err := NewRemoteAddrChecker(tt.cidrs)
if err != nil && !errors.Is(err, tt.err) {
t.Fatalf("creating RemoteAddrChecker failed: %v", err)
}
r, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatalf("can't make request: %v", err)
}
if tt.ip != "" {
r.Header.Add("X-Real-Ip", tt.ip)
}
ok, err := rac.Check(r)
if tt.ok != ok {
t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
}
if err != nil && tt.err != nil && !errors.Is(err, tt.err) {
t.Errorf("err: %v, wanted: %v", err, tt.err)
}
})
}
}
func TestHeaderMatchesChecker(t *testing.T) {
for _, tt := range []struct {
err error
name string
header string
rexStr string
reqHeaderKey string
reqHeaderValue string
ok bool
}{
{
name: "match",
header: "Cf-Worker",
rexStr: ".*",
reqHeaderKey: "Cf-Worker",
reqHeaderValue: "true",
ok: true,
err: nil,
},
{
name: "not_match",
header: "Cf-Worker",
rexStr: "false",
reqHeaderKey: "Cf-Worker",
reqHeaderValue: "true",
ok: false,
err: nil,
},
{
name: "not_present",
header: "Cf-Worker",
rexStr: "foobar",
reqHeaderKey: "Something-Else",
reqHeaderValue: "true",
ok: false,
err: nil,
},
{
name: "invalid_regex",
rexStr: "a(b",
err: ErrMisconfiguration,
},
} {
t.Run(tt.name, func(t *testing.T) {
hmc, err := NewHeaderMatchesChecker(tt.header, tt.rexStr)
if err != nil && !errors.Is(err, tt.err) {
t.Fatalf("creating HeaderMatchesChecker failed")
}
if tt.err != nil && hmc == nil {
return
}
r, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatalf("can't make request: %v", err)
}
r.Header.Set(tt.reqHeaderKey, tt.reqHeaderValue)
ok, err := hmc.Check(r)
if tt.ok != ok {
t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
}
if err != nil && tt.err != nil && !errors.Is(err, tt.err) {
t.Errorf("err: %v, wanted: %v", err, tt.err)
}
})
}
}
func TestHeaderExistsChecker(t *testing.T) {
for _, tt := range []struct {
name string
header string
reqHeader string
ok bool
}{
{
name: "match",
header: "Authorization",
reqHeader: "Authorization",
ok: true,
},
{
name: "not_match",
header: "Authorization",
reqHeader: "Authentication",
},
} {
t.Run(tt.name, func(t *testing.T) {
hec := headerExistsChecker{tt.header}
r, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatalf("can't make request: %v", err)
}
r.Header.Set(tt.reqHeader, "hunter2")
ok, err := hec.Check(r)
if tt.ok != ok {
t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
}
if err != nil {
t.Errorf("err: %v", err)
}
})
}
}

View File

@@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"io" "io"
"io/fs" "io/fs"
"net"
"net/http" "net/http"
"os" "os"
"regexp" "regexp"
@@ -13,6 +12,10 @@ import (
"time" "time"
"github.com/TecharoHQ/anubis/data" "github.com/TecharoHQ/anubis/data"
"github.com/TecharoHQ/anubis/lib/checker/expression"
"github.com/TecharoHQ/anubis/lib/checker/headermatches"
"github.com/TecharoHQ/anubis/lib/checker/path"
"github.com/TecharoHQ/anubis/lib/checker/remoteaddress"
"k8s.io/apimachinery/pkg/util/yaml" "k8s.io/apimachinery/pkg/util/yaml"
) )
@@ -25,12 +28,12 @@ var (
ErrInvalidUserAgentRegex = errors.New("config.Bot: invalid user agent regex") ErrInvalidUserAgentRegex = errors.New("config.Bot: invalid user agent regex")
ErrInvalidPathRegex = errors.New("config.Bot: invalid path regex") ErrInvalidPathRegex = errors.New("config.Bot: invalid path regex")
ErrInvalidHeadersRegex = errors.New("config.Bot: invalid headers regex") ErrInvalidHeadersRegex = errors.New("config.Bot: invalid headers regex")
ErrInvalidCIDR = errors.New("config.Bot: invalid CIDR")
ErrRegexEndsWithNewline = errors.New("config.Bot: regular expression ends with newline (try >- instead of > in yaml)") ErrRegexEndsWithNewline = errors.New("config.Bot: regular expression ends with newline (try >- instead of > in yaml)")
ErrInvalidImportStatement = errors.New("config.ImportStatement: invalid source file") ErrInvalidImportStatement = errors.New("config.ImportStatement: invalid source file")
ErrCantSetBotAndImportValuesAtOnce = errors.New("config.BotOrImport: can't set bot rules and import values at the same time") ErrCantSetBotAndImportValuesAtOnce = errors.New("config.BotOrImport: can't set bot rules and import values at the same time")
ErrMustSetBotOrImportRules = errors.New("config.BotOrImport: rule definition is invalid, you must set either bot rules or an import statement, not both") ErrMustSetBotOrImportRules = errors.New("config.BotOrImport: rule definition is invalid, you must set either bot rules or an import statement, not both")
ErrStatusCodeNotValid = errors.New("config.StatusCode: status code not valid, must be between 100 and 599") ErrStatusCodeNotValid = errors.New("config.StatusCode: status code not valid, must be between 100 and 599")
ErrUnparseableConfig = errors.New("config: can't parse configuration file")
) )
type Rule string type Rule string
@@ -56,15 +59,15 @@ func (r Rule) Valid() error {
const DefaultAlgorithm = "fast" const DefaultAlgorithm = "fast"
type BotConfig struct { type BotConfig struct {
UserAgentRegex *string `json:"user_agent_regex,omitempty" yaml:"user_agent_regex,omitempty"` UserAgentRegex *string `json:"user_agent_regex,omitempty" yaml:"user_agent_regex,omitempty"`
PathRegex *string `json:"path_regex,omitempty" yaml:"path_regex,omitempty"` PathRegex *string `json:"path_regex,omitempty" yaml:"path_regex,omitempty"`
HeadersRegex map[string]string `json:"headers_regex,omitempty" yaml:"headers_regex,omitempty"` HeadersRegex map[string]string `json:"headers_regex,omitempty" yaml:"headers_regex,omitempty"`
Expression *ExpressionOrList `json:"expression,omitempty" yaml:"expression,omitempty"` Expression *expression.Config `json:"expression,omitempty" yaml:"expression,omitempty"`
Challenge *ChallengeRules `json:"challenge,omitempty" yaml:"challenge,omitempty"` Challenge *ChallengeRules `json:"challenge,omitempty" yaml:"challenge,omitempty"`
Weight *Weight `json:"weight,omitempty" yaml:"weight,omitempty"` Weight *Weight `json:"weight,omitempty" yaml:"weight,omitempty"`
Name string `json:"name" yaml:"name"` Name string `json:"name" yaml:"name"`
Action Rule `json:"action" yaml:"action"` Action Rule `json:"action" yaml:"action"`
RemoteAddr []string `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"` RemoteAddr []string `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"`
// Thoth features // Thoth features
GeoIP *GeoIP `json:"geoip,omitempty"` GeoIP *GeoIP `json:"geoip,omitempty"`
@@ -118,7 +121,7 @@ func (b *BotConfig) Valid() error {
errs = append(errs, fmt.Errorf("%w: user agent regex: %q", ErrRegexEndsWithNewline, *b.UserAgentRegex)) errs = append(errs, fmt.Errorf("%w: user agent regex: %q", ErrRegexEndsWithNewline, *b.UserAgentRegex))
} }
if _, err := regexp.Compile(*b.UserAgentRegex); err != nil { if err := headermatches.ValidUserAgent(*b.UserAgentRegex); err != nil {
errs = append(errs, ErrInvalidUserAgentRegex, err) errs = append(errs, ErrInvalidUserAgentRegex, err)
} }
} }
@@ -128,7 +131,7 @@ func (b *BotConfig) Valid() error {
errs = append(errs, fmt.Errorf("%w: path regex: %q", ErrRegexEndsWithNewline, *b.PathRegex)) errs = append(errs, fmt.Errorf("%w: path regex: %q", ErrRegexEndsWithNewline, *b.PathRegex))
} }
if _, err := regexp.Compile(*b.PathRegex); err != nil { if err := path.Valid(*b.PathRegex); err != nil {
errs = append(errs, ErrInvalidPathRegex, err) errs = append(errs, ErrInvalidPathRegex, err)
} }
} }
@@ -150,10 +153,8 @@ func (b *BotConfig) Valid() error {
} }
if len(b.RemoteAddr) > 0 { if len(b.RemoteAddr) > 0 {
for _, cidr := range b.RemoteAddr { if err := remoteaddress.Valid(b.RemoteAddr); err != nil {
if _, _, err := net.ParseCIDR(cidr); err != nil { errs = append(errs, err)
errs = append(errs, ErrInvalidCIDR, err)
}
} }
} }

View File

@@ -8,6 +8,7 @@ import (
"testing" "testing"
"github.com/TecharoHQ/anubis/data" "github.com/TecharoHQ/anubis/data"
"github.com/TecharoHQ/anubis/lib/checker/remoteaddress"
. "github.com/TecharoHQ/anubis/lib/policy/config" . "github.com/TecharoHQ/anubis/lib/policy/config"
) )
@@ -137,7 +138,7 @@ func TestBotValid(t *testing.T) {
Action: RuleAllow, Action: RuleAllow,
RemoteAddr: []string{"0.0.0.0/33"}, RemoteAddr: []string{"0.0.0.0/33"},
}, },
err: ErrInvalidCIDR, err: remoteaddress.ErrInvalidCIDR,
}, },
{ {
name: "only filter by IP range", name: "only filter by IP range",

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"github.com/TecharoHQ/anubis" "github.com/TecharoHQ/anubis"
"github.com/TecharoHQ/anubis/lib/checker/expression"
) )
var ( var (
@@ -17,7 +18,7 @@ var (
DefaultThresholds = []Threshold{ DefaultThresholds = []Threshold{
{ {
Name: "legacy-anubis-behaviour", Name: "legacy-anubis-behaviour",
Expression: &ExpressionOrList{ Expression: &expression.Config{
Expression: "weight > 0", Expression: "weight > 0",
}, },
Action: RuleChallenge, Action: RuleChallenge,
@@ -31,10 +32,10 @@ var (
) )
type Threshold struct { type Threshold struct {
Name string `json:"name" yaml:"name"` Name string `json:"name" yaml:"name"`
Expression *ExpressionOrList `json:"expression" yaml:"expression"` Expression *expression.Config `json:"expression" yaml:"expression"`
Action Rule `json:"action" yaml:"action"` Action Rule `json:"action" yaml:"action"`
Challenge *ChallengeRules `json:"challenge" yaml:"challenge"` Challenge *ChallengeRules `json:"challenge" yaml:"challenge"`
} }
func (t Threshold) Valid() error { func (t Threshold) Valid() error {

View File

@@ -6,6 +6,8 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/TecharoHQ/anubis/lib/checker/expression"
) )
func TestThresholdValid(t *testing.T) { func TestThresholdValid(t *testing.T) {
@@ -18,7 +20,7 @@ func TestThresholdValid(t *testing.T) {
name: "basic allow", name: "basic allow",
input: &Threshold{ input: &Threshold{
Name: "basic-allow", Name: "basic-allow",
Expression: &ExpressionOrList{Expression: "true"}, Expression: &expression.Config{Expression: "true"},
Action: RuleAllow, Action: RuleAllow,
}, },
err: nil, err: nil,
@@ -27,7 +29,7 @@ func TestThresholdValid(t *testing.T) {
name: "basic challenge", name: "basic challenge",
input: &Threshold{ input: &Threshold{
Name: "basic-challenge", Name: "basic-challenge",
Expression: &ExpressionOrList{Expression: "true"}, Expression: &expression.Config{Expression: "true"},
Action: RuleChallenge, Action: RuleChallenge,
Challenge: &ChallengeRules{ Challenge: &ChallengeRules{
Algorithm: "fast", Algorithm: "fast",
@@ -50,9 +52,9 @@ func TestThresholdValid(t *testing.T) {
{ {
name: "invalid expression", name: "invalid expression",
input: &Threshold{ input: &Threshold{
Expression: &ExpressionOrList{}, Expression: &expression.Config{},
}, },
err: ErrExpressionEmpty, err: expression.ErrExpressionEmpty,
}, },
{ {
name: "invalid action", name: "invalid action",

View File

@@ -1,398 +0,0 @@
package expressions
import (
"testing"
"github.com/google/cel-go/common/types"
)
func TestBotEnvironment(t *testing.T) {
env, err := BotEnvironment()
if err != nil {
t.Fatalf("failed to create bot environment: %v", err)
}
t.Run("missingHeader", func(t *testing.T) {
tests := []struct {
name string
expression string
headers map[string]string
expected types.Bool
description string
}{
{
name: "missing-header",
expression: `missingHeader(headers, "Missing-Header")`,
headers: map[string]string{
"User-Agent": "test-agent",
"Content-Type": "application/json",
},
expected: types.Bool(true),
description: "should return true when header is missing",
},
{
name: "existing-header",
expression: `missingHeader(headers, "User-Agent")`,
headers: map[string]string{
"User-Agent": "test-agent",
"Content-Type": "application/json",
},
expected: types.Bool(false),
description: "should return false when header exists",
},
{
name: "case-sensitive",
expression: `missingHeader(headers, "user-agent")`,
headers: map[string]string{
"User-Agent": "test-agent",
},
expected: types.Bool(true),
description: "should be case-sensitive (user-agent != User-Agent)",
},
{
name: "empty-headers",
expression: `missingHeader(headers, "Any-Header")`,
headers: map[string]string{},
expected: types.Bool(true),
description: "should return true for any header when map is empty",
},
{
name: "real-world-sec-ch-ua",
expression: `missingHeader(headers, "Sec-Ch-Ua")`,
headers: map[string]string{
"User-Agent": "curl/7.68.0",
"Accept": "*/*",
"Host": "example.com",
},
expected: types.Bool(true),
description: "should detect missing browser-specific headers from bots",
},
{
name: "browser-with-sec-ch-ua",
expression: `missingHeader(headers, "Sec-Ch-Ua")`,
headers: map[string]string{
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
"Sec-Ch-Ua": `"Chrome"; v="91", "Not A Brand"; v="99"`,
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
},
expected: types.Bool(false),
description: "should return false when browser sends Sec-Ch-Ua header",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
prog, err := Compile(env, tt.expression)
if err != nil {
t.Fatalf("failed to compile expression %q: %v", tt.expression, err)
}
result, _, err := prog.Eval(map[string]interface{}{
"headers": tt.headers,
})
if err != nil {
t.Fatalf("failed to evaluate expression %q: %v", tt.expression, err)
}
if result != tt.expected {
t.Errorf("%s: expected %v, got %v", tt.description, tt.expected, result)
}
})
}
t.Run("function-compilation", func(t *testing.T) {
src := `missingHeader(headers, "Test-Header")`
_, err := Compile(env, src)
if err != nil {
t.Fatalf("failed to compile missingHeader expression: %v", err)
}
})
})
t.Run("segments", func(t *testing.T) {
for _, tt := range []struct {
name string
description string
expression string
path string
expected types.Bool
}{
{
name: "simple",
description: "/ should have one path segment",
expression: `size(segments(path)) == 1`,
path: "/",
expected: types.Bool(true),
},
{
name: "two segments without trailing slash",
description: "/user/foo should have two segments",
expression: `size(segments(path)) == 2`,
path: "/user/foo",
expected: types.Bool(true),
},
{
name: "at least two segments",
description: "/foo/bar/ should have at least two path segments",
expression: `size(segments(path)) >= 2`,
path: "/foo/bar/",
expected: types.Bool(true),
},
{
name: "at most two segments",
description: "/foo/bar/ does not have less than two path segments",
expression: `size(segments(path)) < 2`,
path: "/foo/bar/",
expected: types.Bool(false),
},
} {
t.Run(tt.name, func(t *testing.T) {
prog, err := Compile(env, tt.expression)
if err != nil {
t.Fatalf("failed to compile expression %q: %v", tt.expression, err)
}
result, _, err := prog.Eval(map[string]interface{}{
"path": tt.path,
})
if err != nil {
t.Fatalf("failed to evaluate expression %q: %v", tt.expression, err)
}
if result != tt.expected {
t.Errorf("%s: expected %v, got %v", tt.description, tt.expected, result)
}
})
}
t.Run("invalid", func(t *testing.T) {
for _, tt := range []struct {
name string
description string
expression string
env any
wantFailCompile bool
wantFailEval bool
}{
{
name: "segments of headers",
description: "headers are not a path list",
expression: `segments(headers)`,
env: map[string]any{
"headers": map[string]string{
"foo": "bar",
},
},
wantFailCompile: true,
},
{
name: "invalid path type",
description: "a path should be a sting",
expression: `size(segments(path)) != 0`,
env: map[string]any{
"path": 4,
},
wantFailEval: true,
},
{
name: "invalid path",
description: "a path should start with a leading slash",
expression: `size(segments(path)) != 0`,
env: map[string]any{
"path": "foo",
},
wantFailEval: true,
},
} {
t.Run(tt.name, func(t *testing.T) {
prog, err := Compile(env, tt.expression)
if err != nil {
if !tt.wantFailCompile {
t.Log(tt.description)
t.Fatalf("failed to compile expression %q: %v", tt.expression, err)
} else {
return
}
}
_, _, err = prog.Eval(tt.env)
if err == nil {
t.Log(tt.description)
t.Fatal("wanted an error but got none")
}
t.Log(err)
})
}
})
t.Run("function-compilation", func(t *testing.T) {
src := `size(segments(path)) <= 2`
_, err := Compile(env, src)
if err != nil {
t.Fatalf("failed to compile missingHeader expression: %v", err)
}
})
})
}
func TestThresholdEnvironment(t *testing.T) {
env, err := ThresholdEnvironment()
if err != nil {
t.Fatalf("failed to create threshold environment: %v", err)
}
tests := []struct {
name string
expression string
variables map[string]interface{}
expected types.Bool
description string
shouldCompile bool
}{
{
name: "weight-variable-available",
expression: `weight > 100`,
variables: map[string]interface{}{"weight": 150},
expected: types.Bool(true),
description: "should support weight variable in expressions",
shouldCompile: true,
},
{
name: "weight-variable-false-case",
expression: `weight > 100`,
variables: map[string]interface{}{"weight": 50},
expected: types.Bool(false),
description: "should correctly evaluate weight comparisons",
shouldCompile: true,
},
{
name: "missingHeader-not-available",
expression: `missingHeader(headers, "Test")`,
variables: map[string]interface{}{},
expected: types.Bool(false), // not used
description: "should not have missingHeader function available",
shouldCompile: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
prog, err := Compile(env, tt.expression)
if !tt.shouldCompile {
if err == nil {
t.Fatalf("%s: expected compilation to fail but it succeeded", tt.description)
}
return // Test passed - compilation failed as expected
}
if err != nil {
t.Fatalf("failed to compile expression %q: %v", tt.expression, err)
}
result, _, err := prog.Eval(tt.variables)
if err != nil {
t.Fatalf("failed to evaluate expression %q: %v", tt.expression, err)
}
if result != tt.expected {
t.Errorf("%s: expected %v, got %v", tt.description, tt.expected, result)
}
})
}
}
func TestNewEnvironment(t *testing.T) {
env, err := New()
if err != nil {
t.Fatalf("failed to create new environment: %v", err)
}
tests := []struct {
name string
expression string
variables map[string]interface{}
expectBool *bool // nil if we just want to test compilation or non-bool result
description string
shouldCompile bool
}{
{
name: "randInt-function-compilation",
expression: `randInt(10)`,
variables: map[string]interface{}{},
expectBool: nil, // Don't check result, just compilation
description: "should compile randInt function",
shouldCompile: true,
},
{
name: "randInt-range-validation",
expression: `randInt(10) >= 0 && randInt(10) < 10`,
variables: map[string]interface{}{},
expectBool: boolPtr(true),
description: "should return values in correct range",
shouldCompile: true,
},
{
name: "strings-extension-size",
expression: `"hello".size() == 5`,
variables: map[string]interface{}{},
expectBool: boolPtr(true),
description: "should support string extension functions",
shouldCompile: true,
},
{
name: "strings-extension-contains",
expression: `"hello world".contains("world")`,
variables: map[string]interface{}{},
expectBool: boolPtr(true),
description: "should support string contains function",
shouldCompile: true,
},
{
name: "strings-extension-startsWith",
expression: `"hello world".startsWith("hello")`,
variables: map[string]interface{}{},
expectBool: boolPtr(true),
description: "should support string startsWith function",
shouldCompile: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
prog, err := Compile(env, tt.expression)
if !tt.shouldCompile {
if err == nil {
t.Fatalf("%s: expected compilation to fail but it succeeded", tt.description)
}
return // Test passed - compilation failed as expected
}
if err != nil {
t.Fatalf("failed to compile expression %q: %v", tt.expression, err)
}
// If we only want to test compilation, skip evaluation
if tt.expectBool == nil {
return
}
result, _, err := prog.Eval(tt.variables)
if err != nil {
t.Fatalf("failed to evaluate expression %q: %v", tt.expression, err)
}
if result != types.Bool(*tt.expectBool) {
t.Errorf("%s: expected %v, got %v", tt.description, *tt.expectBool, result)
}
})
}
}
// Helper function to create bool pointers
func boolPtr(b bool) *bool {
return &b
}

View File

@@ -8,7 +8,11 @@ import (
"log/slog" "log/slog"
"sync/atomic" "sync/atomic"
"github.com/TecharoHQ/anubis/lib/policy/checker" "github.com/TecharoHQ/anubis/lib/checker"
"github.com/TecharoHQ/anubis/lib/checker/expression"
"github.com/TecharoHQ/anubis/lib/checker/headermatches"
"github.com/TecharoHQ/anubis/lib/checker/path"
"github.com/TecharoHQ/anubis/lib/checker/remoteaddress"
"github.com/TecharoHQ/anubis/lib/policy/config" "github.com/TecharoHQ/anubis/lib/policy/config"
"github.com/TecharoHQ/anubis/lib/store" "github.com/TecharoHQ/anubis/lib/store"
"github.com/TecharoHQ/anubis/lib/thoth" "github.com/TecharoHQ/anubis/lib/thoth"
@@ -73,10 +77,10 @@ func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDiffic
Action: b.Action, Action: b.Action,
} }
cl := checker.List{} cl := checker.Any{}
if len(b.RemoteAddr) > 0 { if len(b.RemoteAddr) > 0 {
c, err := NewRemoteAddrChecker(b.RemoteAddr) c, err := remoteaddress.New(b.RemoteAddr)
if err != nil { if err != nil {
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s remote addr set: %w", b.Name, err)) validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s remote addr set: %w", b.Name, err))
} else { } else {
@@ -85,7 +89,7 @@ func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDiffic
} }
if b.UserAgentRegex != nil { if b.UserAgentRegex != nil {
c, err := NewUserAgentChecker(*b.UserAgentRegex) c, err := headermatches.NewUserAgent(*b.UserAgentRegex)
if err != nil { if err != nil {
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s user agent regex: %w", b.Name, err)) validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s user agent regex: %w", b.Name, err))
} else { } else {
@@ -94,7 +98,7 @@ func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDiffic
} }
if b.PathRegex != nil { if b.PathRegex != nil {
c, err := NewPathChecker(*b.PathRegex) c, err := path.New(*b.PathRegex)
if err != nil { if err != nil {
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s path regex: %w", b.Name, err)) validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s path regex: %w", b.Name, err))
} else { } else {
@@ -112,7 +116,7 @@ func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDiffic
} }
if b.Expression != nil { if b.Expression != nil {
c, err := NewCELChecker(b.Expression) c, err := expression.New(b.Expression)
if err != nil { if err != nil {
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s expressions: %w", b.Name, err)) validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s expressions: %w", b.Name, err))
} else { } else {

View File

@@ -1,8 +1,8 @@
package policy package policy
import ( import (
"github.com/TecharoHQ/anubis/lib/checker/expression/environment"
"github.com/TecharoHQ/anubis/lib/policy/config" "github.com/TecharoHQ/anubis/lib/policy/config"
"github.com/TecharoHQ/anubis/lib/policy/expressions"
"github.com/google/cel-go/cel" "github.com/google/cel-go/cel"
) )
@@ -16,12 +16,12 @@ func ParsedThresholdFromConfig(t config.Threshold) (*Threshold, error) {
Threshold: t, Threshold: t,
} }
env, err := expressions.ThresholdEnvironment() env, err := environment.Threshold()
if err != nil { if err != nil {
return nil, err return nil, err
} }
program, err := expressions.Compile(env, t.Expression.String()) program, err := environment.Compile(env, t.Expression.String())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -10,11 +10,11 @@ import (
"time" "time"
"github.com/TecharoHQ/anubis/internal" "github.com/TecharoHQ/anubis/internal"
"github.com/TecharoHQ/anubis/lib/policy/checker" "github.com/TecharoHQ/anubis/lib/checker"
iptoasnv1 "github.com/TecharoHQ/thoth-proto/gen/techaro/thoth/iptoasn/v1" iptoasnv1 "github.com/TecharoHQ/thoth-proto/gen/techaro/thoth/iptoasn/v1"
) )
func (c *Client) ASNCheckerFor(asns []uint32) checker.Impl { func (c *Client) ASNCheckerFor(asns []uint32) checker.Interface {
asnMap := map[uint32]struct{}{} asnMap := map[uint32]struct{}{}
var sb strings.Builder var sb strings.Builder
fmt.Fprintln(&sb, "ASNChecker") fmt.Fprintln(&sb, "ASNChecker")

View File

@@ -5,12 +5,12 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/TecharoHQ/anubis/lib/policy/checker" "github.com/TecharoHQ/anubis/lib/checker"
"github.com/TecharoHQ/anubis/lib/thoth" "github.com/TecharoHQ/anubis/lib/thoth"
iptoasnv1 "github.com/TecharoHQ/thoth-proto/gen/techaro/thoth/iptoasn/v1" iptoasnv1 "github.com/TecharoHQ/thoth-proto/gen/techaro/thoth/iptoasn/v1"
) )
var _ checker.Impl = &thoth.ASNChecker{} var _ checker.Interface = &thoth.ASNChecker{}
func TestASNChecker(t *testing.T) { func TestASNChecker(t *testing.T) {
cli := loadSecrets(t) cli := loadSecrets(t)

View File

@@ -9,11 +9,11 @@ import (
"strings" "strings"
"time" "time"
"github.com/TecharoHQ/anubis/lib/policy/checker" "github.com/TecharoHQ/anubis/lib/checker"
iptoasnv1 "github.com/TecharoHQ/thoth-proto/gen/techaro/thoth/iptoasn/v1" iptoasnv1 "github.com/TecharoHQ/thoth-proto/gen/techaro/thoth/iptoasn/v1"
) )
func (c *Client) GeoIPCheckerFor(countries []string) checker.Impl { func (c *Client) GeoIPCheckerFor(countries []string) checker.Interface {
countryMap := map[string]struct{}{} countryMap := map[string]struct{}{}
var sb strings.Builder var sb strings.Builder
fmt.Fprintln(&sb, "GeoIPChecker") fmt.Fprintln(&sb, "GeoIPChecker")

View File

@@ -5,11 +5,11 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/TecharoHQ/anubis/lib/policy/checker" "github.com/TecharoHQ/anubis/lib/checker"
"github.com/TecharoHQ/anubis/lib/thoth" "github.com/TecharoHQ/anubis/lib/thoth"
) )
var _ checker.Impl = &thoth.GeoIPChecker{} var _ checker.Interface = &thoth.GeoIPChecker{}
func TestGeoIPChecker(t *testing.T) { func TestGeoIPChecker(t *testing.T) {
cli := loadSecrets(t) cli := loadSecrets(t)