diff --git a/lib/policy/checker/celutils.go b/lib/policy/checker/celutils.go new file mode 100644 index 00000000..b260c622 --- /dev/null +++ b/lib/policy/checker/celutils.go @@ -0,0 +1,61 @@ +package checker + +import ( + "errors" + "maps" + "reflect" + "slices" + + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" +) + +var ErrNotImplemented = errors.New("cel: functionality not implemented") + +type stringSliceIterator struct { + keys []string + idx int +} + +func (s *stringSliceIterator) Value() any { + return s +} + +func (s *stringSliceIterator) ConvertToNative(typeDesc reflect.Type) (any, error) { + return nil, ErrNotImplemented +} + +func (s *stringSliceIterator) ConvertToType(typeValue ref.Type) ref.Val { + return types.NewErr("can't convert from %q to %q", types.IteratorType, typeValue) +} + +func (s *stringSliceIterator) Equal(other ref.Val) ref.Val { + return types.NewErr("can't compare %q to %q", types.IteratorType, other.Type()) + +} + +func (s *stringSliceIterator) Type() ref.Type { + return types.IteratorType +} + +func (s *stringSliceIterator) HasNext() ref.Val { + return types.Bool(s.idx < len(s.keys)) +} + +func (s *stringSliceIterator) Next() ref.Val { + if s.HasNext() != types.True { + return nil + } + + val := s.keys[s.idx] + s.idx++ + return types.String(val) +} + +func NewMapIterator(m map[string][]string) traits.Iterator { + return &stringSliceIterator{ + keys: slices.Collect(maps.Keys(m)), + idx: 0, + } +} diff --git a/lib/policy/expressions/http_headers.go b/lib/policy/expressions/http_headers.go index 57fcc841..735f4516 100644 --- a/lib/policy/expressions/http_headers.go +++ b/lib/policy/expressions/http_headers.go @@ -5,6 +5,7 @@ import ( "reflect" "strings" + "github.com/TecharoHQ/anubis/lib/policy/checker" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" @@ -66,7 +67,9 @@ func (h HTTPHeaders) Get(key ref.Val) ref.Val { return result } -func (h HTTPHeaders) Iterator() traits.Iterator { panic("TODO(Xe): implement me") } +func (h HTTPHeaders) Iterator() traits.Iterator { + return checker.NewMapIterator(h.Header) +} func (h HTTPHeaders) IsZeroValue() bool { return len(h.Header) == 0 diff --git a/lib/policy/expressions/url_values.go b/lib/policy/expressions/url_values.go index a4c63519..9051da25 100644 --- a/lib/policy/expressions/url_values.go +++ b/lib/policy/expressions/url_values.go @@ -6,6 +6,7 @@ import ( "reflect" "strings" + "github.com/TecharoHQ/anubis/lib/policy/checker" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" @@ -69,7 +70,9 @@ func (u URLValues) Get(key ref.Val) ref.Val { return result } -func (u URLValues) Iterator() traits.Iterator { panic("TODO(Xe): implement me") } +func (u URLValues) Iterator() traits.Iterator { + return checker.NewMapIterator(u.Values) +} func (u URLValues) IsZeroValue() bool { return len(u.Values) == 0