mirror of
https://github.com/TecharoHQ/anubis.git
synced 2026-04-11 19:18:46 +00:00
refactor: move CEL checker to its own package
Signed-off-by: Xe Iaso <me@xeiaso.net>
This commit is contained in:
86
lib/checker/expression/checker.go
Normal file
86
lib/checker/expression/checker.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package expression
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/TecharoHQ/anubis/internal"
|
||||
"github.com/TecharoHQ/anubis/lib/policy/expressions"
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/google/cel-go/common/types"
|
||||
)
|
||||
|
||||
type Checker struct {
|
||||
program cel.Program
|
||||
src string
|
||||
hash string
|
||||
}
|
||||
|
||||
func New(cfg *Config) (*Checker, error) {
|
||||
env, err := expressions.BotEnvironment()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
program, err := expressions.Compile(env, cfg.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't compile CEL program: %w", err)
|
||||
}
|
||||
|
||||
return &Checker{
|
||||
src: cfg.String(),
|
||||
hash: internal.FastHash(cfg.String()),
|
||||
program: program,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (cc *Checker) Hash() string {
|
||||
return cc.hash
|
||||
}
|
||||
|
||||
func (cc *Checker) Check(r *http.Request) (bool, error) {
|
||||
result, _, err := cc.program.ContextEval(r.Context(), &CELRequest{r})
|
||||
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if val, ok := result.(types.Bool); ok {
|
||||
return bool(val), nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
type CELRequest struct {
|
||||
*http.Request
|
||||
}
|
||||
|
||||
func (cr *CELRequest) Parent() cel.Activation { return nil }
|
||||
|
||||
func (cr *CELRequest) ResolveName(name string) (any, bool) {
|
||||
switch name {
|
||||
case "remoteAddress":
|
||||
return cr.Header.Get("X-Real-Ip"), true
|
||||
case "host":
|
||||
return cr.Host, true
|
||||
case "method":
|
||||
return cr.Method, true
|
||||
case "userAgent":
|
||||
return cr.UserAgent(), true
|
||||
case "path":
|
||||
return cr.URL.Path, true
|
||||
case "query":
|
||||
return expressions.URLValues{Values: cr.URL.Query()}, true
|
||||
case "headers":
|
||||
return expressions.HTTPHeaders{Header: cr.Header}, true
|
||||
case "load_1m":
|
||||
return expressions.Load1(), true
|
||||
case "load_5m":
|
||||
return expressions.Load5(), true
|
||||
case "load_15m":
|
||||
return expressions.Load15(), true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
130
lib/checker/expression/config.go
Normal file
130
lib/checker/expression/config.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package expression
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrExpressionOrListMustBeStringOrObject = errors.New("expression: this must be a string or an object")
|
||||
ErrExpressionEmpty = errors.New("expression: this expression is empty")
|
||||
ErrExpressionCantHaveBoth = errors.New("expression: expression block can't contain multiple expression types")
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Expression string `json:"-" yaml:"-"`
|
||||
All []string `json:"all,omitempty" yaml:"all,omitempty"`
|
||||
Any []string `json:"any,omitempty" yaml:"any,omitempty"`
|
||||
}
|
||||
|
||||
func (eol Config) String() string {
|
||||
switch {
|
||||
case len(eol.Expression) != 0:
|
||||
return eol.Expression
|
||||
case len(eol.All) != 0:
|
||||
var sb strings.Builder
|
||||
for i, pred := range eol.All {
|
||||
if i != 0 {
|
||||
fmt.Fprintf(&sb, " && ")
|
||||
}
|
||||
fmt.Fprintf(&sb, "( %s )", pred)
|
||||
}
|
||||
return sb.String()
|
||||
case len(eol.Any) != 0:
|
||||
var sb strings.Builder
|
||||
for i, pred := range eol.Any {
|
||||
if i != 0 {
|
||||
fmt.Fprintf(&sb, " || ")
|
||||
}
|
||||
fmt.Fprintf(&sb, "( %s )", pred)
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
panic("this should not happen")
|
||||
}
|
||||
|
||||
func (eol Config) Equal(rhs *Config) bool {
|
||||
if eol.Expression != rhs.Expression {
|
||||
return false
|
||||
}
|
||||
|
||||
if !slices.Equal(eol.All, rhs.All) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !slices.Equal(eol.Any, rhs.Any) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (eol *Config) MarshalYAML() (any, error) {
|
||||
switch {
|
||||
case len(eol.All) == 1 && len(eol.Any) == 0:
|
||||
eol.Expression = eol.All[0]
|
||||
eol.All = nil
|
||||
case len(eol.Any) == 1 && len(eol.All) == 0:
|
||||
eol.Expression = eol.Any[0]
|
||||
eol.Any = nil
|
||||
}
|
||||
|
||||
if eol.Expression != "" {
|
||||
return eol.Expression, nil
|
||||
}
|
||||
|
||||
type RawExpressionOrList Config
|
||||
return RawExpressionOrList(*eol), nil
|
||||
}
|
||||
|
||||
func (eol *Config) MarshalJSON() ([]byte, error) {
|
||||
switch {
|
||||
case len(eol.All) == 1 && len(eol.Any) == 0:
|
||||
eol.Expression = eol.All[0]
|
||||
eol.All = nil
|
||||
case len(eol.Any) == 1 && len(eol.All) == 0:
|
||||
eol.Expression = eol.Any[0]
|
||||
eol.Any = nil
|
||||
}
|
||||
|
||||
if eol.Expression != "" {
|
||||
return json.Marshal(string(eol.Expression))
|
||||
}
|
||||
|
||||
type RawExpressionOrList Config
|
||||
val := RawExpressionOrList(*eol)
|
||||
return json.Marshal(val)
|
||||
}
|
||||
|
||||
func (eol *Config) UnmarshalJSON(data []byte) error {
|
||||
switch string(data[0]) {
|
||||
case `"`: // string
|
||||
return json.Unmarshal(data, &eol.Expression)
|
||||
case "{": // object
|
||||
type RawExpressionOrList Config
|
||||
var val RawExpressionOrList
|
||||
if err := json.Unmarshal(data, &val); err != nil {
|
||||
return err
|
||||
}
|
||||
eol.All = val.All
|
||||
eol.Any = val.Any
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrExpressionOrListMustBeStringOrObject
|
||||
}
|
||||
|
||||
func (eol *Config) Valid() error {
|
||||
if eol.Expression == "" && len(eol.All) == 0 && len(eol.Any) == 0 {
|
||||
return ErrExpressionEmpty
|
||||
}
|
||||
if len(eol.All) != 0 && len(eol.Any) != 0 {
|
||||
return ErrExpressionCantHaveBoth
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
266
lib/checker/expression/config_test.go
Normal file
266
lib/checker/expression/config_test.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package expression
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
yaml "sigs.k8s.io/yaml/goyaml.v3"
|
||||
)
|
||||
|
||||
func TestExpressionOrListMarshalJSON(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
input *Config
|
||||
output []byte
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "single expression",
|
||||
input: &Config{
|
||||
Expression: "true",
|
||||
},
|
||||
output: []byte(`"true"`),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "all",
|
||||
input: &Config{
|
||||
All: []string{"true", "true"},
|
||||
},
|
||||
output: []byte(`{"all":["true","true"]}`),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "all one",
|
||||
input: &Config{
|
||||
All: []string{"true"},
|
||||
},
|
||||
output: []byte(`"true"`),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "any",
|
||||
input: &Config{
|
||||
Any: []string{"true", "false"},
|
||||
},
|
||||
output: []byte(`{"any":["true","false"]}`),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "any one",
|
||||
input: &Config{
|
||||
Any: []string{"true"},
|
||||
},
|
||||
output: []byte(`"true"`),
|
||||
err: nil,
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := json.Marshal(tt.input)
|
||||
if !errors.Is(err, tt.err) {
|
||||
t.Errorf("wanted marshal error: %v but got: %v", tt.err, err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(result, tt.output) {
|
||||
t.Logf("wanted: %s", string(tt.output))
|
||||
t.Logf("got: %s", string(result))
|
||||
t.Error("mismatched output")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpressionOrListMarshalYAML(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
input *Config
|
||||
output []byte
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "single expression",
|
||||
input: &Config{
|
||||
Expression: "true",
|
||||
},
|
||||
output: []byte(`"true"`),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "all",
|
||||
input: &Config{
|
||||
All: []string{"true", "true"},
|
||||
},
|
||||
output: []byte(`all:
|
||||
- "true"
|
||||
- "true"`),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "all one",
|
||||
input: &Config{
|
||||
All: []string{"true"},
|
||||
},
|
||||
output: []byte(`"true"`),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "any",
|
||||
input: &Config{
|
||||
Any: []string{"true", "false"},
|
||||
},
|
||||
output: []byte(`any:
|
||||
- "true"
|
||||
- "false"`),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "any one",
|
||||
input: &Config{
|
||||
Any: []string{"true"},
|
||||
},
|
||||
output: []byte(`"true"`),
|
||||
err: nil,
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := yaml.Marshal(tt.input)
|
||||
if !errors.Is(err, tt.err) {
|
||||
t.Errorf("wanted marshal error: %v but got: %v", tt.err, err)
|
||||
}
|
||||
|
||||
result = bytes.TrimSpace(result)
|
||||
|
||||
if !bytes.Equal(result, tt.output) {
|
||||
t.Logf("wanted: %q", string(tt.output))
|
||||
t.Logf("got: %q", string(result))
|
||||
t.Error("mismatched output")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpressionOrListUnmarshalJSON(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
err error
|
||||
validErr error
|
||||
result *Config
|
||||
name string
|
||||
inp string
|
||||
}{
|
||||
{
|
||||
name: "simple",
|
||||
inp: `"\"User-Agent\" in headers"`,
|
||||
result: &Config{
|
||||
Expression: `"User-Agent" in headers`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "object-and",
|
||||
inp: `{
|
||||
"all": ["\"User-Agent\" in headers"]
|
||||
}`,
|
||||
result: &Config{
|
||||
All: []string{
|
||||
`"User-Agent" in headers`,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "object-or",
|
||||
inp: `{
|
||||
"any": ["\"User-Agent\" in headers"]
|
||||
}`,
|
||||
result: &Config{
|
||||
Any: []string{
|
||||
`"User-Agent" in headers`,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "both-or-and",
|
||||
inp: `{
|
||||
"all": ["\"User-Agent\" in headers"],
|
||||
"any": ["\"User-Agent\" in headers"]
|
||||
}`,
|
||||
validErr: ErrExpressionCantHaveBoth,
|
||||
},
|
||||
{
|
||||
name: "expression-empty",
|
||||
inp: `{
|
||||
"any": []
|
||||
}`,
|
||||
validErr: ErrExpressionEmpty,
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var eol Config
|
||||
|
||||
if err := json.Unmarshal([]byte(tt.inp), &eol); !errors.Is(err, tt.err) {
|
||||
t.Errorf("wanted unmarshal error: %v but got: %v", tt.err, err)
|
||||
}
|
||||
|
||||
if tt.result != nil && !eol.Equal(tt.result) {
|
||||
t.Logf("wanted: %#v", tt.result)
|
||||
t.Logf("got: %#v", &eol)
|
||||
t.Fatal("parsed expression is not what was expected")
|
||||
}
|
||||
|
||||
if err := eol.Valid(); !errors.Is(err, tt.validErr) {
|
||||
t.Errorf("wanted validation error: %v but got: %v", tt.err, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpressionOrListString(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
in Config
|
||||
out string
|
||||
}{
|
||||
{
|
||||
name: "single expression",
|
||||
in: Config{
|
||||
Expression: "true",
|
||||
},
|
||||
out: "true",
|
||||
},
|
||||
{
|
||||
name: "all",
|
||||
in: Config{
|
||||
All: []string{"true"},
|
||||
},
|
||||
out: "( true )",
|
||||
},
|
||||
{
|
||||
name: "all with &&",
|
||||
in: Config{
|
||||
All: []string{"true", "true"},
|
||||
},
|
||||
out: "( true ) && ( true )",
|
||||
},
|
||||
{
|
||||
name: "any",
|
||||
in: Config{
|
||||
All: []string{"true"},
|
||||
},
|
||||
out: "( true )",
|
||||
},
|
||||
{
|
||||
name: "any with ||",
|
||||
in: Config{
|
||||
Any: []string{"true", "true"},
|
||||
},
|
||||
out: "( true ) || ( true )",
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.in.String()
|
||||
if result != tt.out {
|
||||
t.Errorf("wanted %q, got: %q", tt.out, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
43
lib/checker/expression/factory.go
Normal file
43
lib/checker/expression/factory.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package expression
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/TecharoHQ/anubis/lib/checker"
|
||||
)
|
||||
|
||||
func init() {
|
||||
checker.Register("expression", Factory{})
|
||||
}
|
||||
|
||||
type Factory struct{}
|
||||
|
||||
func (f Factory) Build(ctx context.Context, data json.RawMessage) (checker.Interface, error) {
|
||||
var fc = &Config{}
|
||||
|
||||
if err := json.Unmarshal([]byte(data), fc); err != nil {
|
||||
return nil, errors.Join(checker.ErrUnparseableConfig, err)
|
||||
}
|
||||
|
||||
if err := fc.Valid(); err != nil {
|
||||
return nil, errors.Join(checker.ErrInvalidConfig, err)
|
||||
}
|
||||
|
||||
return New(fc)
|
||||
}
|
||||
|
||||
func (f Factory) Valid(ctx context.Context, data json.RawMessage) error {
|
||||
var fc = &Config{}
|
||||
|
||||
if err := json.Unmarshal([]byte(data), fc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := fc.Valid(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user