refactor: move CEL checker to its own package

Signed-off-by: Xe Iaso <me@xeiaso.net>
This commit is contained in:
Xe Iaso
2025-07-25 19:52:07 +00:00
parent 590d8303ad
commit e98d749bf2
13 changed files with 135 additions and 78 deletions

View File

@@ -1,85 +0,0 @@
package policy
import (
"fmt"
"net/http"
"github.com/TecharoHQ/anubis/internal"
"github.com/TecharoHQ/anubis/lib/policy/config"
"github.com/TecharoHQ/anubis/lib/policy/expressions"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
)
type CELChecker struct {
program cel.Program
src string
}
func NewCELChecker(cfg *config.ExpressionOrList) (*CELChecker, error) {
env, err := expressions.BotEnvironment()
if err != nil {
return nil, err
}
program, err := expressions.Compile(env, cfg.String())
if err != nil {
return nil, fmt.Errorf("can't compile CEL program: %w", err)
}
return &CELChecker{
src: cfg.String(),
program: program,
}, nil
}
func (cc *CELChecker) Hash() string {
return internal.FastHash(cc.src)
}
func (cc *CELChecker) Check(r *http.Request) (bool, error) {
result, _, err := cc.program.ContextEval(r.Context(), &CELRequest{r})
if err != nil {
return false, err
}
if val, ok := result.(types.Bool); ok {
return bool(val), nil
}
return false, nil
}
type CELRequest struct {
*http.Request
}
func (cr *CELRequest) Parent() cel.Activation { return nil }
func (cr *CELRequest) ResolveName(name string) (any, bool) {
switch name {
case "remoteAddress":
return cr.Header.Get("X-Real-Ip"), true
case "host":
return cr.Host, true
case "method":
return cr.Method, true
case "userAgent":
return cr.UserAgent(), true
case "path":
return cr.URL.Path, true
case "query":
return expressions.URLValues{Values: cr.URL.Query()}, true
case "headers":
return expressions.HTTPHeaders{Header: cr.Header}, true
case "load_1m":
return expressions.Load1(), true
case "load_5m":
return expressions.Load5(), true
case "load_15m":
return expressions.Load15(), true
default:
return nil, false
}
}

View File

@@ -12,6 +12,7 @@ import (
"time"
"github.com/TecharoHQ/anubis/data"
"github.com/TecharoHQ/anubis/lib/checker/expression"
"github.com/TecharoHQ/anubis/lib/checker/headermatches"
"github.com/TecharoHQ/anubis/lib/checker/path"
"github.com/TecharoHQ/anubis/lib/checker/remoteaddress"
@@ -58,15 +59,15 @@ func (r Rule) Valid() error {
const DefaultAlgorithm = "fast"
type BotConfig struct {
UserAgentRegex *string `json:"user_agent_regex,omitempty" yaml:"user_agent_regex,omitempty"`
PathRegex *string `json:"path_regex,omitempty" yaml:"path_regex,omitempty"`
HeadersRegex map[string]string `json:"headers_regex,omitempty" yaml:"headers_regex,omitempty"`
Expression *ExpressionOrList `json:"expression,omitempty" yaml:"expression,omitempty"`
Challenge *ChallengeRules `json:"challenge,omitempty" yaml:"challenge,omitempty"`
Weight *Weight `json:"weight,omitempty" yaml:"weight,omitempty"`
Name string `json:"name" yaml:"name"`
Action Rule `json:"action" yaml:"action"`
RemoteAddr []string `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"`
UserAgentRegex *string `json:"user_agent_regex,omitempty" yaml:"user_agent_regex,omitempty"`
PathRegex *string `json:"path_regex,omitempty" yaml:"path_regex,omitempty"`
HeadersRegex map[string]string `json:"headers_regex,omitempty" yaml:"headers_regex,omitempty"`
Expression *expression.Config `json:"expression,omitempty" yaml:"expression,omitempty"`
Challenge *ChallengeRules `json:"challenge,omitempty" yaml:"challenge,omitempty"`
Weight *Weight `json:"weight,omitempty" yaml:"weight,omitempty"`
Name string `json:"name" yaml:"name"`
Action Rule `json:"action" yaml:"action"`
RemoteAddr []string `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"`
// Thoth features
GeoIP *GeoIP `json:"geoip,omitempty"`

View File

@@ -1,130 +0,0 @@
package config
import (
"encoding/json"
"errors"
"fmt"
"slices"
"strings"
)
var (
ErrExpressionOrListMustBeStringOrObject = errors.New("config: this must be a string or an object")
ErrExpressionEmpty = errors.New("config: this expression is empty")
ErrExpressionCantHaveBoth = errors.New("config: expression block can't contain multiple expression types")
)
type ExpressionOrList struct {
Expression string `json:"-" yaml:"-"`
All []string `json:"all,omitempty" yaml:"all,omitempty"`
Any []string `json:"any,omitempty" yaml:"any,omitempty"`
}
func (eol ExpressionOrList) String() string {
switch {
case len(eol.Expression) != 0:
return eol.Expression
case len(eol.All) != 0:
var sb strings.Builder
for i, pred := range eol.All {
if i != 0 {
fmt.Fprintf(&sb, " && ")
}
fmt.Fprintf(&sb, "( %s )", pred)
}
return sb.String()
case len(eol.Any) != 0:
var sb strings.Builder
for i, pred := range eol.Any {
if i != 0 {
fmt.Fprintf(&sb, " || ")
}
fmt.Fprintf(&sb, "( %s )", pred)
}
return sb.String()
}
panic("this should not happen")
}
func (eol ExpressionOrList) Equal(rhs *ExpressionOrList) bool {
if eol.Expression != rhs.Expression {
return false
}
if !slices.Equal(eol.All, rhs.All) {
return false
}
if !slices.Equal(eol.Any, rhs.Any) {
return false
}
return true
}
func (eol *ExpressionOrList) MarshalYAML() (any, error) {
switch {
case len(eol.All) == 1 && len(eol.Any) == 0:
eol.Expression = eol.All[0]
eol.All = nil
case len(eol.Any) == 1 && len(eol.All) == 0:
eol.Expression = eol.Any[0]
eol.Any = nil
}
if eol.Expression != "" {
return eol.Expression, nil
}
type RawExpressionOrList ExpressionOrList
return RawExpressionOrList(*eol), nil
}
func (eol *ExpressionOrList) MarshalJSON() ([]byte, error) {
switch {
case len(eol.All) == 1 && len(eol.Any) == 0:
eol.Expression = eol.All[0]
eol.All = nil
case len(eol.Any) == 1 && len(eol.All) == 0:
eol.Expression = eol.Any[0]
eol.Any = nil
}
if eol.Expression != "" {
return json.Marshal(string(eol.Expression))
}
type RawExpressionOrList ExpressionOrList
val := RawExpressionOrList(*eol)
return json.Marshal(val)
}
func (eol *ExpressionOrList) UnmarshalJSON(data []byte) error {
switch string(data[0]) {
case `"`: // string
return json.Unmarshal(data, &eol.Expression)
case "{": // object
type RawExpressionOrList ExpressionOrList
var val RawExpressionOrList
if err := json.Unmarshal(data, &val); err != nil {
return err
}
eol.All = val.All
eol.Any = val.Any
return nil
}
return ErrExpressionOrListMustBeStringOrObject
}
func (eol *ExpressionOrList) Valid() error {
if eol.Expression == "" && len(eol.All) == 0 && len(eol.Any) == 0 {
return ErrExpressionEmpty
}
if len(eol.All) != 0 && len(eol.Any) != 0 {
return ErrExpressionCantHaveBoth
}
return nil
}

View File

@@ -1,266 +0,0 @@
package config
import (
"bytes"
"encoding/json"
"errors"
"testing"
yaml "sigs.k8s.io/yaml/goyaml.v3"
)
func TestExpressionOrListMarshalJSON(t *testing.T) {
for _, tt := range []struct {
name string
input *ExpressionOrList
output []byte
err error
}{
{
name: "single expression",
input: &ExpressionOrList{
Expression: "true",
},
output: []byte(`"true"`),
err: nil,
},
{
name: "all",
input: &ExpressionOrList{
All: []string{"true", "true"},
},
output: []byte(`{"all":["true","true"]}`),
err: nil,
},
{
name: "all one",
input: &ExpressionOrList{
All: []string{"true"},
},
output: []byte(`"true"`),
err: nil,
},
{
name: "any",
input: &ExpressionOrList{
Any: []string{"true", "false"},
},
output: []byte(`{"any":["true","false"]}`),
err: nil,
},
{
name: "any one",
input: &ExpressionOrList{
Any: []string{"true"},
},
output: []byte(`"true"`),
err: nil,
},
} {
t.Run(tt.name, func(t *testing.T) {
result, err := json.Marshal(tt.input)
if !errors.Is(err, tt.err) {
t.Errorf("wanted marshal error: %v but got: %v", tt.err, err)
}
if !bytes.Equal(result, tt.output) {
t.Logf("wanted: %s", string(tt.output))
t.Logf("got: %s", string(result))
t.Error("mismatched output")
}
})
}
}
func TestExpressionOrListMarshalYAML(t *testing.T) {
for _, tt := range []struct {
name string
input *ExpressionOrList
output []byte
err error
}{
{
name: "single expression",
input: &ExpressionOrList{
Expression: "true",
},
output: []byte(`"true"`),
err: nil,
},
{
name: "all",
input: &ExpressionOrList{
All: []string{"true", "true"},
},
output: []byte(`all:
- "true"
- "true"`),
err: nil,
},
{
name: "all one",
input: &ExpressionOrList{
All: []string{"true"},
},
output: []byte(`"true"`),
err: nil,
},
{
name: "any",
input: &ExpressionOrList{
Any: []string{"true", "false"},
},
output: []byte(`any:
- "true"
- "false"`),
err: nil,
},
{
name: "any one",
input: &ExpressionOrList{
Any: []string{"true"},
},
output: []byte(`"true"`),
err: nil,
},
} {
t.Run(tt.name, func(t *testing.T) {
result, err := yaml.Marshal(tt.input)
if !errors.Is(err, tt.err) {
t.Errorf("wanted marshal error: %v but got: %v", tt.err, err)
}
result = bytes.TrimSpace(result)
if !bytes.Equal(result, tt.output) {
t.Logf("wanted: %q", string(tt.output))
t.Logf("got: %q", string(result))
t.Error("mismatched output")
}
})
}
}
func TestExpressionOrListUnmarshalJSON(t *testing.T) {
for _, tt := range []struct {
err error
validErr error
result *ExpressionOrList
name string
inp string
}{
{
name: "simple",
inp: `"\"User-Agent\" in headers"`,
result: &ExpressionOrList{
Expression: `"User-Agent" in headers`,
},
},
{
name: "object-and",
inp: `{
"all": ["\"User-Agent\" in headers"]
}`,
result: &ExpressionOrList{
All: []string{
`"User-Agent" in headers`,
},
},
},
{
name: "object-or",
inp: `{
"any": ["\"User-Agent\" in headers"]
}`,
result: &ExpressionOrList{
Any: []string{
`"User-Agent" in headers`,
},
},
},
{
name: "both-or-and",
inp: `{
"all": ["\"User-Agent\" in headers"],
"any": ["\"User-Agent\" in headers"]
}`,
validErr: ErrExpressionCantHaveBoth,
},
{
name: "expression-empty",
inp: `{
"any": []
}`,
validErr: ErrExpressionEmpty,
},
} {
t.Run(tt.name, func(t *testing.T) {
var eol ExpressionOrList
if err := json.Unmarshal([]byte(tt.inp), &eol); !errors.Is(err, tt.err) {
t.Errorf("wanted unmarshal error: %v but got: %v", tt.err, err)
}
if tt.result != nil && !eol.Equal(tt.result) {
t.Logf("wanted: %#v", tt.result)
t.Logf("got: %#v", &eol)
t.Fatal("parsed expression is not what was expected")
}
if err := eol.Valid(); !errors.Is(err, tt.validErr) {
t.Errorf("wanted validation error: %v but got: %v", tt.err, err)
}
})
}
}
func TestExpressionOrListString(t *testing.T) {
for _, tt := range []struct {
name string
in ExpressionOrList
out string
}{
{
name: "single expression",
in: ExpressionOrList{
Expression: "true",
},
out: "true",
},
{
name: "all",
in: ExpressionOrList{
All: []string{"true"},
},
out: "( true )",
},
{
name: "all with &&",
in: ExpressionOrList{
All: []string{"true", "true"},
},
out: "( true ) && ( true )",
},
{
name: "any",
in: ExpressionOrList{
All: []string{"true"},
},
out: "( true )",
},
{
name: "any with ||",
in: ExpressionOrList{
Any: []string{"true", "true"},
},
out: "( true ) || ( true )",
},
} {
t.Run(tt.name, func(t *testing.T) {
result := tt.in.String()
if result != tt.out {
t.Errorf("wanted %q, got: %q", tt.out, result)
}
})
}
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"github.com/TecharoHQ/anubis"
"github.com/TecharoHQ/anubis/lib/checker/expression"
)
var (
@@ -17,7 +18,7 @@ var (
DefaultThresholds = []Threshold{
{
Name: "legacy-anubis-behaviour",
Expression: &ExpressionOrList{
Expression: &expression.Config{
Expression: "weight > 0",
},
Action: RuleChallenge,
@@ -31,10 +32,10 @@ var (
)
type Threshold struct {
Name string `json:"name" yaml:"name"`
Expression *ExpressionOrList `json:"expression" yaml:"expression"`
Action Rule `json:"action" yaml:"action"`
Challenge *ChallengeRules `json:"challenge" yaml:"challenge"`
Name string `json:"name" yaml:"name"`
Expression *expression.Config `json:"expression" yaml:"expression"`
Action Rule `json:"action" yaml:"action"`
Challenge *ChallengeRules `json:"challenge" yaml:"challenge"`
}
func (t Threshold) Valid() error {

View File

@@ -6,6 +6,8 @@ import (
"os"
"path/filepath"
"testing"
"github.com/TecharoHQ/anubis/lib/checker/expression"
)
func TestThresholdValid(t *testing.T) {
@@ -18,7 +20,7 @@ func TestThresholdValid(t *testing.T) {
name: "basic allow",
input: &Threshold{
Name: "basic-allow",
Expression: &ExpressionOrList{Expression: "true"},
Expression: &expression.Config{Expression: "true"},
Action: RuleAllow,
},
err: nil,
@@ -27,7 +29,7 @@ func TestThresholdValid(t *testing.T) {
name: "basic challenge",
input: &Threshold{
Name: "basic-challenge",
Expression: &ExpressionOrList{Expression: "true"},
Expression: &expression.Config{Expression: "true"},
Action: RuleChallenge,
Challenge: &ChallengeRules{
Algorithm: "fast",
@@ -50,9 +52,9 @@ func TestThresholdValid(t *testing.T) {
{
name: "invalid expression",
input: &Threshold{
Expression: &ExpressionOrList{},
Expression: &expression.Config{},
},
err: ErrExpressionEmpty,
err: expression.ErrExpressionEmpty,
},
{
name: "invalid action",

View File

@@ -9,6 +9,7 @@ import (
"sync/atomic"
"github.com/TecharoHQ/anubis/lib/checker"
"github.com/TecharoHQ/anubis/lib/checker/expression"
"github.com/TecharoHQ/anubis/lib/checker/headermatches"
"github.com/TecharoHQ/anubis/lib/checker/path"
"github.com/TecharoHQ/anubis/lib/checker/remoteaddress"
@@ -115,7 +116,7 @@ func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDiffic
}
if b.Expression != nil {
c, err := NewCELChecker(b.Expression)
c, err := expression.New(b.Expression)
if err != nil {
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s expressions: %w", b.Name, err))
} else {