mirror of
https://github.com/TecharoHQ/anubis.git
synced 2026-04-14 12:38:45 +00:00
refactor: move cel environment creation to a subpackage
Signed-off-by: Xe Iaso <me@xeiaso.net>
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/TecharoHQ/anubis/internal"
|
||||
"github.com/TecharoHQ/anubis/lib/checker/expression/environment"
|
||||
"github.com/TecharoHQ/anubis/lib/policy/expressions"
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/google/cel-go/common/types"
|
||||
@@ -17,12 +18,12 @@ type Checker struct {
|
||||
}
|
||||
|
||||
func New(cfg *Config) (*Checker, error) {
|
||||
env, err := expressions.BotEnvironment()
|
||||
env, err := environment.Bot()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
program, err := expressions.Compile(env, cfg.String())
|
||||
program, err := environment.Compile(env, cfg.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't compile CEL program: %w", err)
|
||||
}
|
||||
|
||||
119
lib/checker/expression/environment/environment.go
Normal file
119
lib/checker/expression/environment/environment.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package environment
|
||||
|
||||
import (
|
||||
"math/rand/v2"
|
||||
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/google/cel-go/common/types"
|
||||
"github.com/google/cel-go/common/types/ref"
|
||||
"github.com/google/cel-go/common/types/traits"
|
||||
"github.com/google/cel-go/ext"
|
||||
)
|
||||
|
||||
// Bot creates a new CEL environment, this is the set of variables and
|
||||
// functions that are passed into the CEL scope so that Anubis can fail
|
||||
// loudly and early when something is invalid instead of blowing up at
|
||||
// runtime.
|
||||
func Bot() (*cel.Env, error) {
|
||||
return New(
|
||||
// Variables exposed to CEL programs:
|
||||
cel.Variable("remoteAddress", cel.StringType),
|
||||
cel.Variable("host", cel.StringType),
|
||||
cel.Variable("method", cel.StringType),
|
||||
cel.Variable("userAgent", cel.StringType),
|
||||
cel.Variable("path", cel.StringType),
|
||||
cel.Variable("query", cel.MapType(cel.StringType, cel.StringType)),
|
||||
cel.Variable("headers", cel.MapType(cel.StringType, cel.StringType)),
|
||||
cel.Variable("load_1m", cel.DoubleType),
|
||||
cel.Variable("load_5m", cel.DoubleType),
|
||||
cel.Variable("load_15m", cel.DoubleType),
|
||||
|
||||
// Bot-specific functions:
|
||||
cel.Function("missingHeader",
|
||||
cel.Overload("missingHeader_map_string_string_string",
|
||||
[]*cel.Type{cel.MapType(cel.StringType, cel.StringType), cel.StringType},
|
||||
cel.BoolType,
|
||||
cel.BinaryBinding(func(headers, key ref.Val) ref.Val {
|
||||
// Convert headers to a trait that supports Find
|
||||
headersMap, ok := headers.(traits.Indexer)
|
||||
if !ok {
|
||||
return types.ValOrErr(headers, "headers is not a map, but is %T", headers)
|
||||
}
|
||||
|
||||
keyStr, ok := key.(types.String)
|
||||
if !ok {
|
||||
return types.ValOrErr(key, "key is not a string, but is %T", key)
|
||||
}
|
||||
|
||||
val := headersMap.Get(keyStr)
|
||||
// Check if the key is missing by testing for an error
|
||||
if types.IsError(val) {
|
||||
return types.Bool(true) // header is missing
|
||||
}
|
||||
return types.Bool(false) // header is present
|
||||
}),
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// Threshold creates a new CEL environment for threshold checking.
|
||||
func Threshold() (*cel.Env, error) {
|
||||
return New(
|
||||
cel.Variable("weight", cel.IntType),
|
||||
)
|
||||
}
|
||||
|
||||
// New creates a new base CEL environment.
|
||||
func New(opts ...cel.EnvOption) (*cel.Env, error) {
|
||||
args := []cel.EnvOption{
|
||||
ext.Strings(
|
||||
ext.StringsLocale("en_US"),
|
||||
ext.StringsValidateFormatCalls(true),
|
||||
),
|
||||
|
||||
// default all timestamps to UTC
|
||||
cel.DefaultUTCTimeZone(true),
|
||||
|
||||
// Functions exposed to all CEL programs:
|
||||
cel.Function("randInt",
|
||||
cel.Overload("randInt_int",
|
||||
[]*cel.Type{cel.IntType},
|
||||
cel.IntType,
|
||||
cel.UnaryBinding(func(val ref.Val) ref.Val {
|
||||
n, ok := val.(types.Int)
|
||||
if !ok {
|
||||
return types.ValOrErr(val, "value is not an integer, but is %T", val)
|
||||
}
|
||||
|
||||
return types.Int(rand.IntN(int(n)))
|
||||
}),
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
args = append(args, opts...)
|
||||
return cel.NewEnv(args...)
|
||||
}
|
||||
|
||||
// Compile takes a CEL environment and syntax tree then emits an optimized
|
||||
// Program for execution.
|
||||
func Compile(env *cel.Env, src string) (cel.Program, error) {
|
||||
intermediate, iss := env.Compile(src)
|
||||
if iss != nil {
|
||||
return nil, iss.Err()
|
||||
}
|
||||
|
||||
ast, iss := env.Check(intermediate)
|
||||
if iss != nil {
|
||||
return nil, iss.Err()
|
||||
}
|
||||
|
||||
return env.Program(
|
||||
ast,
|
||||
cel.EvalOptions(
|
||||
// optimize regular expressions right now instead of on the fly
|
||||
cel.OptOptimize,
|
||||
),
|
||||
)
|
||||
}
|
||||
269
lib/checker/expression/environment/environment_test.go
Normal file
269
lib/checker/expression/environment/environment_test.go
Normal file
@@ -0,0 +1,269 @@
|
||||
package environment
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/cel-go/common/types"
|
||||
)
|
||||
|
||||
func TestBot(t *testing.T) {
|
||||
env, err := Bot()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create bot environment: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
expression string
|
||||
headers map[string]string
|
||||
expected types.Bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "missing-header",
|
||||
expression: `missingHeader(headers, "Missing-Header")`,
|
||||
headers: map[string]string{
|
||||
"User-Agent": "test-agent",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
expected: types.Bool(true),
|
||||
description: "should return true when header is missing",
|
||||
},
|
||||
{
|
||||
name: "existing-header",
|
||||
expression: `missingHeader(headers, "User-Agent")`,
|
||||
headers: map[string]string{
|
||||
"User-Agent": "test-agent",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
expected: types.Bool(false),
|
||||
description: "should return false when header exists",
|
||||
},
|
||||
{
|
||||
name: "case-sensitive",
|
||||
expression: `missingHeader(headers, "user-agent")`,
|
||||
headers: map[string]string{
|
||||
"User-Agent": "test-agent",
|
||||
},
|
||||
expected: types.Bool(true),
|
||||
description: "should be case-sensitive (user-agent != User-Agent)",
|
||||
},
|
||||
{
|
||||
name: "empty-headers",
|
||||
expression: `missingHeader(headers, "Any-Header")`,
|
||||
headers: map[string]string{},
|
||||
expected: types.Bool(true),
|
||||
description: "should return true for any header when map is empty",
|
||||
},
|
||||
{
|
||||
name: "real-world-sec-ch-ua",
|
||||
expression: `missingHeader(headers, "Sec-Ch-Ua")`,
|
||||
headers: map[string]string{
|
||||
"User-Agent": "curl/7.68.0",
|
||||
"Accept": "*/*",
|
||||
"Host": "example.com",
|
||||
},
|
||||
expected: types.Bool(true),
|
||||
description: "should detect missing browser-specific headers from bots",
|
||||
},
|
||||
{
|
||||
name: "browser-with-sec-ch-ua",
|
||||
expression: `missingHeader(headers, "Sec-Ch-Ua")`,
|
||||
headers: map[string]string{
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
|
||||
"Sec-Ch-Ua": `"Chrome"; v="91", "Not A Brand"; v="99"`,
|
||||
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
|
||||
},
|
||||
expected: types.Bool(false),
|
||||
description: "should return false when browser sends Sec-Ch-Ua header",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
prog, err := Compile(env, tt.expression)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to compile expression %q: %v", tt.expression, err)
|
||||
}
|
||||
|
||||
result, _, err := prog.Eval(map[string]interface{}{
|
||||
"headers": tt.headers,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to evaluate expression %q: %v", tt.expression, err)
|
||||
}
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("%s: expected %v, got %v", tt.description, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("function-compilation", func(t *testing.T) {
|
||||
src := `missingHeader(headers, "Test-Header")`
|
||||
_, err := Compile(env, src)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to compile missingHeader expression: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestThreshold(t *testing.T) {
|
||||
env, err := Threshold()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create threshold environment: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
expression string
|
||||
variables map[string]interface{}
|
||||
expected types.Bool
|
||||
description string
|
||||
shouldCompile bool
|
||||
}{
|
||||
{
|
||||
name: "weight-variable-available",
|
||||
expression: `weight > 100`,
|
||||
variables: map[string]interface{}{"weight": 150},
|
||||
expected: types.Bool(true),
|
||||
description: "should support weight variable in expressions",
|
||||
shouldCompile: true,
|
||||
},
|
||||
{
|
||||
name: "weight-variable-false-case",
|
||||
expression: `weight > 100`,
|
||||
variables: map[string]interface{}{"weight": 50},
|
||||
expected: types.Bool(false),
|
||||
description: "should correctly evaluate weight comparisons",
|
||||
shouldCompile: true,
|
||||
},
|
||||
{
|
||||
name: "missingHeader-not-available",
|
||||
expression: `missingHeader(headers, "Test")`,
|
||||
variables: map[string]interface{}{},
|
||||
expected: types.Bool(false), // not used
|
||||
description: "should not have missingHeader function available",
|
||||
shouldCompile: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
prog, err := Compile(env, tt.expression)
|
||||
|
||||
if !tt.shouldCompile {
|
||||
if err == nil {
|
||||
t.Fatalf("%s: expected compilation to fail but it succeeded", tt.description)
|
||||
}
|
||||
return // Test passed - compilation failed as expected
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("failed to compile expression %q: %v", tt.expression, err)
|
||||
}
|
||||
|
||||
result, _, err := prog.Eval(tt.variables)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to evaluate expression %q: %v", tt.expression, err)
|
||||
}
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("%s: expected %v, got %v", tt.description, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewEnvironment(t *testing.T) {
|
||||
env, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create new environment: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
expression string
|
||||
variables map[string]interface{}
|
||||
expectBool *bool // nil if we just want to test compilation or non-bool result
|
||||
description string
|
||||
shouldCompile bool
|
||||
}{
|
||||
{
|
||||
name: "randInt-function-compilation",
|
||||
expression: `randInt(10)`,
|
||||
variables: map[string]interface{}{},
|
||||
expectBool: nil, // Don't check result, just compilation
|
||||
description: "should compile randInt function",
|
||||
shouldCompile: true,
|
||||
},
|
||||
{
|
||||
name: "randInt-range-validation",
|
||||
expression: `randInt(10) >= 0 && randInt(10) < 10`,
|
||||
variables: map[string]interface{}{},
|
||||
expectBool: boolPtr(true),
|
||||
description: "should return values in correct range",
|
||||
shouldCompile: true,
|
||||
},
|
||||
{
|
||||
name: "strings-extension-size",
|
||||
expression: `"hello".size() == 5`,
|
||||
variables: map[string]interface{}{},
|
||||
expectBool: boolPtr(true),
|
||||
description: "should support string extension functions",
|
||||
shouldCompile: true,
|
||||
},
|
||||
{
|
||||
name: "strings-extension-contains",
|
||||
expression: `"hello world".contains("world")`,
|
||||
variables: map[string]interface{}{},
|
||||
expectBool: boolPtr(true),
|
||||
description: "should support string contains function",
|
||||
shouldCompile: true,
|
||||
},
|
||||
{
|
||||
name: "strings-extension-startsWith",
|
||||
expression: `"hello world".startsWith("hello")`,
|
||||
variables: map[string]interface{}{},
|
||||
expectBool: boolPtr(true),
|
||||
description: "should support string startsWith function",
|
||||
shouldCompile: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
prog, err := Compile(env, tt.expression)
|
||||
|
||||
if !tt.shouldCompile {
|
||||
if err == nil {
|
||||
t.Fatalf("%s: expected compilation to fail but it succeeded", tt.description)
|
||||
}
|
||||
return // Test passed - compilation failed as expected
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("failed to compile expression %q: %v", tt.expression, err)
|
||||
}
|
||||
|
||||
// If we only want to test compilation, skip evaluation
|
||||
if tt.expectBool == nil {
|
||||
return
|
||||
}
|
||||
|
||||
result, _, err := prog.Eval(tt.variables)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to evaluate expression %q: %v", tt.expression, err)
|
||||
}
|
||||
|
||||
if result != types.Bool(*tt.expectBool) {
|
||||
t.Errorf("%s: expected %v, got %v", tt.description, *tt.expectBool, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to create bool pointers
|
||||
func boolPtr(b bool) *bool {
|
||||
return &b
|
||||
}
|
||||
Reference in New Issue
Block a user