diff --git a/docs/docs/CHANGELOG.md b/docs/docs/CHANGELOG.md index 88823bd6..a96b8e6f 100644 --- a/docs/docs/CHANGELOG.md +++ b/docs/docs/CHANGELOG.md @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed mixed tab/space indentation in Caddy documentation code block +- Fix CEL internal errors when iterating `headers`/`query` map wrappers by implementing map iterators for `HTTPHeaders` and `URLValues` ([#1465](https://github.com/TecharoHQ/anubis/pull/1465)). ## v1.25.0: Necron diff --git a/lib/policy/celchecker_test.go b/lib/policy/celchecker_test.go new file mode 100644 index 00000000..61c90819 --- /dev/null +++ b/lib/policy/celchecker_test.go @@ -0,0 +1,44 @@ +package policy + +import ( + "net/http" + "testing" + + "github.com/TecharoHQ/anubis/internal/dns" + "github.com/TecharoHQ/anubis/lib/config" + "github.com/TecharoHQ/anubis/lib/store/memory" +) + +func newTestDNS(t *testing.T) *dns.Dns { + t.Helper() + + ctx := t.Context() + memStore := memory.New(ctx) + cache := dns.NewDNSCache(300, 300, memStore) + return dns.New(ctx, cache) +} + +func TestCELChecker_MapIterationWrappers(t *testing.T) { + cfg := &config.ExpressionOrList{ + Expression: `headers.exists(k, k == "Accept") && query.exists(k, k == "format")`, + } + + checker, err := NewCELChecker(cfg, newTestDNS(t)) + if err != nil { + t.Fatalf("creating CEL checker failed: %v", err) + } + + req, err := http.NewRequest(http.MethodGet, "https://example.com/?format=json", nil) + if err != nil { + t.Fatalf("making request failed: %v", err) + } + req.Header.Set("Accept", "application/json") + + got, err := checker.Check(req) + if err != nil { + t.Fatalf("checking expression failed: %v", err) + } + if !got { + t.Fatal("expected expression to evaluate true") + } +} diff --git a/lib/policy/expressions/http_headers.go b/lib/policy/expressions/http_headers.go index 57fcc841..c4342dc3 100644 --- a/lib/policy/expressions/http_headers.go +++ b/lib/policy/expressions/http_headers.go @@ -66,7 +66,9 @@ func (h HTTPHeaders) Get(key ref.Val) ref.Val { return result } -func (h HTTPHeaders) Iterator() traits.Iterator { panic("TODO(Xe): implement me") } +func (h HTTPHeaders) Iterator() traits.Iterator { + return newMapIterator(h.Header) +} func (h HTTPHeaders) IsZeroValue() bool { return len(h.Header) == 0 diff --git a/lib/policy/expressions/map_iterator.go b/lib/policy/expressions/map_iterator.go new file mode 100644 index 00000000..4a86269f --- /dev/null +++ b/lib/policy/expressions/map_iterator.go @@ -0,0 +1,60 @@ +package expressions + +import ( + "errors" + "maps" + "reflect" + "slices" + + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" +) + +var ErrNotImplemented = errors.New("expressions: not implemented") + +type stringSliceIterator struct { + keys []string + idx int +} + +func (s *stringSliceIterator) Value() any { + return s +} + +func (s *stringSliceIterator) ConvertToNative(typeDesc reflect.Type) (any, error) { + return nil, ErrNotImplemented +} + +func (s *stringSliceIterator) ConvertToType(typeValue ref.Type) ref.Val { + return types.NewErr("can't convert from %q to %q", types.IteratorType, typeValue) +} + +func (s *stringSliceIterator) Equal(other ref.Val) ref.Val { + return types.NewErr("can't compare %q to %q", types.IteratorType, other.Type()) +} + +func (s *stringSliceIterator) Type() ref.Type { + return types.IteratorType +} + +func (s *stringSliceIterator) HasNext() ref.Val { + return types.Bool(s.idx < len(s.keys)) +} + +func (s *stringSliceIterator) Next() ref.Val { + if s.HasNext() != types.True { + return nil + } + + val := s.keys[s.idx] + s.idx++ + return types.String(val) +} + +func newMapIterator(m map[string][]string) traits.Iterator { + return &stringSliceIterator{ + keys: slices.Collect(maps.Keys(m)), + idx: 0, + } +} diff --git a/lib/policy/expressions/url_values.go b/lib/policy/expressions/url_values.go index a4c63519..21d25d5b 100644 --- a/lib/policy/expressions/url_values.go +++ b/lib/policy/expressions/url_values.go @@ -1,7 +1,6 @@ package expressions import ( - "errors" "net/url" "reflect" "strings" @@ -11,8 +10,6 @@ import ( "github.com/google/cel-go/common/types/traits" ) -var ErrNotImplemented = errors.New("expressions: not implemented") - // URLValues is a type wrapper to expose url.Values into CEL programs. type URLValues struct { url.Values @@ -69,7 +66,9 @@ func (u URLValues) Get(key ref.Val) ref.Val { return result } -func (u URLValues) Iterator() traits.Iterator { panic("TODO(Xe): implement me") } +func (u URLValues) Iterator() traits.Iterator { + return newMapIterator(u.Values) +} func (u URLValues) IsZeroValue() bool { return len(u.Values) == 0