diff --git a/lib/policy/config/config.go b/lib/policy/config/config.go index d140549f..55fb3c04 100644 --- a/lib/policy/config/config.go +++ b/lib/policy/config/config.go @@ -46,15 +46,15 @@ const ( const DefaultAlgorithm = "fast" type BotConfig struct { - UserAgentRegex *string `json:"user_agent_regex,omitempty"` - PathRegex *string `json:"path_regex,omitempty"` - HeadersRegex map[string]string `json:"headers_regex,omitempty"` - Expression *ExpressionOrList `json:"expression,omitempty"` - Challenge *ChallengeRules `json:"challenge,omitempty"` - Weight *Weight `json:"weight,omitempty"` - Name string `json:"name"` - Action Rule `json:"action"` - RemoteAddr []string `json:"remote_addresses,omitempty"` + UserAgentRegex *string `json:"user_agent_regex,omitempty" yaml:"user_agent_regex,omitempty"` + PathRegex *string `json:"path_regex,omitempty" yaml:"path_regex,omitempty"` + HeadersRegex map[string]string `json:"headers_regex,omitempty" yaml:"headers_regex,omitempty"` + Expression *ExpressionOrList `json:"expression,omitempty" yaml:"expression,omitempty"` + Challenge *ChallengeRules `json:"challenge,omitempty" yaml:"challenge,omitempty"` + Weight *Weight `json:"weight,omitempty" yaml:"weight,omitempty"` + Name string `json:"name" yaml:"name"` + Action Rule `json:"action" yaml:"action"` + RemoteAddr []string `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"` } func (b BotConfig) Zero() bool { diff --git a/lib/policy/config/expressionorlist.go b/lib/policy/config/expressionorlist.go index 8851c5b1..68dafd2c 100644 --- a/lib/policy/config/expressionorlist.go +++ b/lib/policy/config/expressionorlist.go @@ -13,9 +13,9 @@ var ( ) type ExpressionOrList struct { - Expression string `json:"-"` - All []string `json:"all,omitempty"` - Any []string `json:"any,omitempty"` + Expression string `json:"-" yaml:"-"` + All []string `json:"all,omitempty" yaml:"all,omitempty"` + Any []string `json:"any,omitempty" yaml:"any,omitempty"` } func (eol ExpressionOrList) Equal(rhs *ExpressionOrList) bool { @@ -34,6 +34,43 @@ func (eol ExpressionOrList) Equal(rhs *ExpressionOrList) bool { return true } +func (eol *ExpressionOrList) MarshalYAML() (any, error) { + switch { + case len(eol.All) == 1 && len(eol.Any) == 0: + eol.Expression = eol.All[0] + eol.All = nil + case len(eol.Any) == 1 && len(eol.All) == 0: + eol.Expression = eol.Any[0] + eol.Any = nil + } + + if eol.Expression != "" { + return eol.Expression, nil + } + + type RawExpressionOrList ExpressionOrList + return RawExpressionOrList(*eol), nil +} + +func (eol *ExpressionOrList) MarshalJSON() ([]byte, error) { + switch { + case len(eol.All) == 1 && len(eol.Any) == 0: + eol.Expression = eol.All[0] + eol.All = nil + case len(eol.Any) == 1 && len(eol.All) == 0: + eol.Expression = eol.Any[0] + eol.Any = nil + } + + if eol.Expression != "" { + return json.Marshal(string(eol.Expression)) + } + + type RawExpressionOrList ExpressionOrList + val := RawExpressionOrList(*eol) + return json.Marshal(val) +} + func (eol *ExpressionOrList) UnmarshalJSON(data []byte) error { switch string(data[0]) { case `"`: // string diff --git a/lib/policy/config/expressionorlist_test.go b/lib/policy/config/expressionorlist_test.go index dbdda2d7..8d0c8436 100644 --- a/lib/policy/config/expressionorlist_test.go +++ b/lib/policy/config/expressionorlist_test.go @@ -1,12 +1,147 @@ package config import ( + "bytes" "encoding/json" "errors" "testing" + + yaml "sigs.k8s.io/yaml/goyaml.v3" ) -func TestExpressionOrListUnmarshal(t *testing.T) { +func TestExpressionOrListMarshalJSON(t *testing.T) { + for _, tt := range []struct { + name string + input *ExpressionOrList + output []byte + err error + }{ + { + name: "single expression", + input: &ExpressionOrList{ + Expression: "true", + }, + output: []byte(`"true"`), + err: nil, + }, + { + name: "all", + input: &ExpressionOrList{ + All: []string{"true", "true"}, + }, + output: []byte(`{"all":["true","true"]}`), + err: nil, + }, + { + name: "all one", + input: &ExpressionOrList{ + All: []string{"true"}, + }, + output: []byte(`"true"`), + err: nil, + }, + { + name: "any", + input: &ExpressionOrList{ + Any: []string{"true", "false"}, + }, + output: []byte(`{"any":["true","false"]}`), + err: nil, + }, + { + name: "any one", + input: &ExpressionOrList{ + Any: []string{"true"}, + }, + output: []byte(`"true"`), + err: nil, + }, + } { + t.Run(tt.name, func(t *testing.T) { + result, err := json.Marshal(tt.input) + if !errors.Is(err, tt.err) { + t.Errorf("wanted marshal error: %v but got: %v", tt.err, err) + } + + if !bytes.Equal(result, tt.output) { + t.Logf("wanted: %s", string(tt.output)) + t.Logf("got: %s", string(result)) + t.Error("mismatched output") + } + }) + } +} + +func TestExpressionOrListMarshalYAML(t *testing.T) { + for _, tt := range []struct { + name string + input *ExpressionOrList + output []byte + err error + }{ + { + name: "single expression", + input: &ExpressionOrList{ + Expression: "true", + }, + output: []byte(`"true"`), + err: nil, + }, + { + name: "all", + input: &ExpressionOrList{ + All: []string{"true", "true"}, + }, + output: []byte(`all: + - "true" + - "true"`), + err: nil, + }, + { + name: "all one", + input: &ExpressionOrList{ + All: []string{"true"}, + }, + output: []byte(`"true"`), + err: nil, + }, + { + name: "any", + input: &ExpressionOrList{ + Any: []string{"true", "false"}, + }, + output: []byte(`any: + - "true" + - "false"`), + err: nil, + }, + { + name: "any one", + input: &ExpressionOrList{ + Any: []string{"true"}, + }, + output: []byte(`"true"`), + err: nil, + }, + } { + t.Run(tt.name, func(t *testing.T) { + result, err := yaml.Marshal(tt.input) + if !errors.Is(err, tt.err) { + t.Errorf("wanted marshal error: %v but got: %v", tt.err, err) + } + + result = bytes.TrimSpace(result) + + if !bytes.Equal(result, tt.output) { + t.Logf("wanted: %q", string(tt.output)) + t.Logf("got: %q", string(result)) + t.Error("mismatched output") + } + }) + } +} + +func TestExpressionOrListUnmarshalJSON(t *testing.T) { for _, tt := range []struct { err error validErr error