mirror of
https://github.com/TecharoHQ/anubis.git
synced 2026-04-14 12:38:45 +00:00
refactor: get rid of package expressions by moving the code into package expression
Signed-off-by: Xe Iaso <me@xeiaso.net>
This commit is contained in:
3
lib/checker/expression/README.md
Normal file
3
lib/checker/expression/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Expressions support
|
||||
|
||||
The expressions support is based on ideas from [go-away](https://git.gammaspectra.live/git/go-away) but with different opinions about how things should be done.
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
|
||||
"github.com/TecharoHQ/anubis/internal"
|
||||
"github.com/TecharoHQ/anubis/lib/checker/expression/environment"
|
||||
"github.com/TecharoHQ/anubis/lib/policy/expressions"
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/google/cel-go/common/types"
|
||||
)
|
||||
@@ -72,15 +71,15 @@ func (cr *CELRequest) ResolveName(name string) (any, bool) {
|
||||
case "path":
|
||||
return cr.URL.Path, true
|
||||
case "query":
|
||||
return expressions.URLValues{Values: cr.URL.Query()}, true
|
||||
return URLValues{Values: cr.URL.Query()}, true
|
||||
case "headers":
|
||||
return expressions.HTTPHeaders{Header: cr.Header}, true
|
||||
return HTTPHeaders{Header: cr.Header}, true
|
||||
case "load_1m":
|
||||
return expressions.Load1(), true
|
||||
return Load1(), true
|
||||
case "load_5m":
|
||||
return expressions.Load5(), true
|
||||
return Load5(), true
|
||||
case "load_15m":
|
||||
return expressions.Load15(), true
|
||||
return Load15(), true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
|
||||
75
lib/checker/expression/http_headers.go
Normal file
75
lib/checker/expression/http_headers.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package expression
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/google/cel-go/common/types"
|
||||
"github.com/google/cel-go/common/types/ref"
|
||||
"github.com/google/cel-go/common/types/traits"
|
||||
)
|
||||
|
||||
// HTTPHeaders is a type wrapper to expose HTTP headers into CEL programs.
|
||||
type HTTPHeaders struct {
|
||||
http.Header
|
||||
}
|
||||
|
||||
func (h HTTPHeaders) ConvertToNative(typeDesc reflect.Type) (any, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
func (h HTTPHeaders) ConvertToType(typeVal ref.Type) ref.Val {
|
||||
switch typeVal {
|
||||
case types.MapType:
|
||||
return h
|
||||
case types.TypeType:
|
||||
return types.MapType
|
||||
}
|
||||
|
||||
return types.NewErr("can't convert from %q to %q", types.MapType, typeVal)
|
||||
}
|
||||
|
||||
func (h HTTPHeaders) Equal(other ref.Val) ref.Val {
|
||||
return types.Bool(false) // We don't want to compare header maps
|
||||
}
|
||||
|
||||
func (h HTTPHeaders) Type() ref.Type {
|
||||
return types.MapType
|
||||
}
|
||||
|
||||
func (h HTTPHeaders) Value() any { return h }
|
||||
|
||||
func (h HTTPHeaders) Find(key ref.Val) (ref.Val, bool) {
|
||||
k, ok := key.(types.String)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if _, ok := h.Header[string(k)]; !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return types.String(strings.Join(h.Header.Values(string(k)), ",")), true
|
||||
}
|
||||
|
||||
func (h HTTPHeaders) Contains(key ref.Val) ref.Val {
|
||||
_, ok := h.Find(key)
|
||||
return types.Bool(ok)
|
||||
}
|
||||
|
||||
func (h HTTPHeaders) Get(key ref.Val) ref.Val {
|
||||
result, ok := h.Find(key)
|
||||
if !ok {
|
||||
return types.ValOrErr(result, "no such key: %v", key)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (h HTTPHeaders) Iterator() traits.Iterator { panic("TODO(Xe): implement me") }
|
||||
|
||||
func (h HTTPHeaders) IsZeroValue() bool {
|
||||
return len(h.Header) == 0
|
||||
}
|
||||
|
||||
func (h HTTPHeaders) Size() ref.Val { return types.Int(len(h.Header)) }
|
||||
52
lib/checker/expression/http_headers_test.go
Normal file
52
lib/checker/expression/http_headers_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package expression
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/google/cel-go/common/types"
|
||||
)
|
||||
|
||||
func TestHTTPHeaders(t *testing.T) {
|
||||
headers := HTTPHeaders{
|
||||
Header: http.Header{
|
||||
"Content-Type": {"application/json"},
|
||||
"Cf-Worker": {"true"},
|
||||
"User-Agent": {"Go-http-client/2"},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("contains-existing-header", func(t *testing.T) {
|
||||
resp := headers.Contains(types.String("User-Agent"))
|
||||
if !resp.(types.Bool) {
|
||||
t.Fatal("headers does not contain User-Agent")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("not-contains-missing-header", func(t *testing.T) {
|
||||
resp := headers.Contains(types.String("Xxx-Random-Header"))
|
||||
if resp.(types.Bool) {
|
||||
t.Fatal("headers does not contain User-Agent")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("get-existing-header", func(t *testing.T) {
|
||||
val := headers.Get(types.String("User-Agent"))
|
||||
switch val.(type) {
|
||||
case types.String:
|
||||
// ok
|
||||
default:
|
||||
t.Fatalf("result was wrong type %T", val)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("not-get-missing-header", func(t *testing.T) {
|
||||
val := headers.Get(types.String("Xxx-Random-Header"))
|
||||
switch val.(type) {
|
||||
case *types.Err:
|
||||
// ok
|
||||
default:
|
||||
t.Fatalf("result was wrong type %T", val)
|
||||
}
|
||||
})
|
||||
}
|
||||
69
lib/checker/expression/loadavg.go
Normal file
69
lib/checker/expression/loadavg.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package expression
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/shirou/gopsutil/v4/load"
|
||||
)
|
||||
|
||||
type loadAvg struct {
|
||||
lock sync.RWMutex
|
||||
data *load.AvgStat
|
||||
}
|
||||
|
||||
func (l *loadAvg) updateThread(ctx context.Context) {
|
||||
ticker := time.NewTicker(15 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
l.update()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
l.update()
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *loadAvg) update() {
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
var err error
|
||||
l.data, err = load.Avg()
|
||||
if err != nil {
|
||||
slog.Debug("can't get load average", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
globalLoadAvg *loadAvg
|
||||
)
|
||||
|
||||
func init() {
|
||||
globalLoadAvg = &loadAvg{}
|
||||
go globalLoadAvg.updateThread(context.Background())
|
||||
}
|
||||
|
||||
func Load1() float64 {
|
||||
globalLoadAvg.lock.RLock()
|
||||
defer globalLoadAvg.lock.RUnlock()
|
||||
return globalLoadAvg.data.Load1
|
||||
}
|
||||
|
||||
func Load5() float64 {
|
||||
globalLoadAvg.lock.RLock()
|
||||
defer globalLoadAvg.lock.RUnlock()
|
||||
return globalLoadAvg.data.Load5
|
||||
}
|
||||
|
||||
func Load15() float64 {
|
||||
globalLoadAvg.lock.RLock()
|
||||
defer globalLoadAvg.lock.RUnlock()
|
||||
return globalLoadAvg.data.Load15
|
||||
}
|
||||
78
lib/checker/expression/url_values.go
Normal file
78
lib/checker/expression/url_values.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package expression
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/google/cel-go/common/types"
|
||||
"github.com/google/cel-go/common/types/ref"
|
||||
"github.com/google/cel-go/common/types/traits"
|
||||
)
|
||||
|
||||
var ErrNotImplemented = errors.New("expressions: not implemented")
|
||||
|
||||
// URLValues is a type wrapper to expose url.Values into CEL programs.
|
||||
type URLValues struct {
|
||||
url.Values
|
||||
}
|
||||
|
||||
func (u URLValues) ConvertToNative(typeDesc reflect.Type) (any, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
func (u URLValues) ConvertToType(typeVal ref.Type) ref.Val {
|
||||
switch typeVal {
|
||||
case types.MapType:
|
||||
return u
|
||||
case types.TypeType:
|
||||
return types.MapType
|
||||
}
|
||||
|
||||
return types.NewErr("can't convert from %q to %q", types.MapType, typeVal)
|
||||
}
|
||||
|
||||
func (u URLValues) Equal(other ref.Val) ref.Val {
|
||||
return types.Bool(false) // We don't want to compare header maps
|
||||
}
|
||||
|
||||
func (u URLValues) Type() ref.Type {
|
||||
return types.MapType
|
||||
}
|
||||
|
||||
func (u URLValues) Value() any { return u }
|
||||
|
||||
func (u URLValues) Find(key ref.Val) (ref.Val, bool) {
|
||||
k, ok := key.(types.String)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if _, ok := u.Values[string(k)]; !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return types.String(strings.Join(u.Values[string(k)], ",")), true
|
||||
}
|
||||
|
||||
func (u URLValues) Contains(key ref.Val) ref.Val {
|
||||
_, ok := u.Find(key)
|
||||
return types.Bool(ok)
|
||||
}
|
||||
|
||||
func (u URLValues) Get(key ref.Val) ref.Val {
|
||||
result, ok := u.Find(key)
|
||||
if !ok {
|
||||
return types.ValOrErr(result, "no such key: %v", key)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (u URLValues) Iterator() traits.Iterator { panic("TODO(Xe): implement me") }
|
||||
|
||||
func (u URLValues) IsZeroValue() bool {
|
||||
return len(u.Values) == 0
|
||||
}
|
||||
|
||||
func (u URLValues) Size() ref.Val { return types.Int(len(u.Values)) }
|
||||
50
lib/checker/expression/url_values_test.go
Normal file
50
lib/checker/expression/url_values_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package expression
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/google/cel-go/common/types"
|
||||
)
|
||||
|
||||
func TestURLValues(t *testing.T) {
|
||||
headers := URLValues{
|
||||
Values: url.Values{
|
||||
"format": {"json"},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("contains-existing-key", func(t *testing.T) {
|
||||
resp := headers.Contains(types.String("format"))
|
||||
if !resp.(types.Bool) {
|
||||
t.Fatal("headers does not contain User-Agent")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("not-contains-missing-key", func(t *testing.T) {
|
||||
resp := headers.Contains(types.String("not-there"))
|
||||
if resp.(types.Bool) {
|
||||
t.Fatal("headers does not contain User-Agent")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("get-existing-key", func(t *testing.T) {
|
||||
val := headers.Get(types.String("format"))
|
||||
switch val.(type) {
|
||||
case types.String:
|
||||
// ok
|
||||
default:
|
||||
t.Fatalf("result was wrong type %T", val)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("not-get-missing-key", func(t *testing.T) {
|
||||
val := headers.Get(types.String("not-there"))
|
||||
switch val.(type) {
|
||||
case *types.Err:
|
||||
// ok
|
||||
default:
|
||||
t.Fatalf("result was wrong type %T", val)
|
||||
}
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user