From 4e1db3842e52e35c0c505b0741ef65005aed7c88 Mon Sep 17 00:00:00 2001 From: Xe Iaso Date: Sun, 27 Apr 2025 22:49:44 -0400 Subject: [PATCH] fix(policy/expressions): do not Contains missing keys Signed-off-by: Xe Iaso --- lib/policy/expressions/http_headers.go | 4 ++ lib/policy/expressions/http_headers_test.go | 52 +++++++++++++++++++++ lib/policy/expressions/url_values.go | 4 ++ lib/policy/expressions/url_values_test.go | 50 ++++++++++++++++++++ 4 files changed, 110 insertions(+) create mode 100644 lib/policy/expressions/http_headers_test.go create mode 100644 lib/policy/expressions/url_values_test.go diff --git a/lib/policy/expressions/http_headers.go b/lib/policy/expressions/http_headers.go index 45bf0aac..57fcc841 100644 --- a/lib/policy/expressions/http_headers.go +++ b/lib/policy/expressions/http_headers.go @@ -46,6 +46,10 @@ func (h HTTPHeaders) Find(key ref.Val) (ref.Val, bool) { return nil, false } + if _, ok := h.Header[string(k)]; !ok { + return nil, false + } + return types.String(strings.Join(h.Header.Values(string(k)), ",")), true } diff --git a/lib/policy/expressions/http_headers_test.go b/lib/policy/expressions/http_headers_test.go new file mode 100644 index 00000000..d56f65ca --- /dev/null +++ b/lib/policy/expressions/http_headers_test.go @@ -0,0 +1,52 @@ +package expressions + +import ( + "net/http" + "testing" + + "github.com/google/cel-go/common/types" +) + +func TestHTTPHeaders(t *testing.T) { + headers := HTTPHeaders{ + Header: http.Header{ + "Content-Type": {"application/json"}, + "Cf-Worker": {"true"}, + "User-Agent": {"Go-http-client/2"}, + }, + } + + t.Run("contains-existing-header", func(t *testing.T) { + resp := headers.Contains(types.String("User-Agent")) + if !bool(resp.(types.Bool)) { + t.Fatal("headers does not contain User-Agent") + } + }) + + t.Run("not-contains-missing-header", func(t *testing.T) { + resp := headers.Contains(types.String("Xxx-Random-Header")) + if bool(resp.(types.Bool)) { + t.Fatal("headers does not contain User-Agent") + } + }) + + t.Run("get-existing-header", func(t *testing.T) { + val := headers.Get(types.String("User-Agent")) + switch val.(type) { + case types.String: + // ok + default: + t.Fatalf("result was wrong type %T", val) + } + }) + + t.Run("not-get-missing-header", func(t *testing.T) { + val := headers.Get(types.String("Xxx-Random-Header")) + switch val.(type) { + case *types.Err: + // ok + default: + t.Fatalf("result was wrong type %T", val) + } + }) +} diff --git a/lib/policy/expressions/url_values.go b/lib/policy/expressions/url_values.go index 83694cea..a4c63519 100644 --- a/lib/policy/expressions/url_values.go +++ b/lib/policy/expressions/url_values.go @@ -49,6 +49,10 @@ func (u URLValues) Find(key ref.Val) (ref.Val, bool) { return nil, false } + if _, ok := u.Values[string(k)]; !ok { + return nil, false + } + return types.String(strings.Join(u.Values[string(k)], ",")), true } diff --git a/lib/policy/expressions/url_values_test.go b/lib/policy/expressions/url_values_test.go new file mode 100644 index 00000000..49d27b72 --- /dev/null +++ b/lib/policy/expressions/url_values_test.go @@ -0,0 +1,50 @@ +package expressions + +import ( + "net/url" + "testing" + + "github.com/google/cel-go/common/types" +) + +func TestURLValues(t *testing.T) { + headers := URLValues{ + Values: url.Values{ + "format": {"json"}, + }, + } + + t.Run("contains-existing-key", func(t *testing.T) { + resp := headers.Contains(types.String("format")) + if !bool(resp.(types.Bool)) { + t.Fatal("headers does not contain User-Agent") + } + }) + + t.Run("not-contains-missing-key", func(t *testing.T) { + resp := headers.Contains(types.String("not-there")) + if bool(resp.(types.Bool)) { + t.Fatal("headers does not contain User-Agent") + } + }) + + t.Run("get-existing-key", func(t *testing.T) { + val := headers.Get(types.String("format")) + switch val.(type) { + case types.String: + // ok + default: + t.Fatalf("result was wrong type %T", val) + } + }) + + t.Run("not-get-missing-key", func(t *testing.T) { + val := headers.Get(types.String("not-there")) + switch val.(type) { + case *types.Err: + // ok + default: + t.Fatalf("result was wrong type %T", val) + } + }) +}