refactor: use new checker types

Signed-off-by: Xe Iaso <me@xeiaso.net>
This commit is contained in:
Xe Iaso
2025-07-25 19:39:14 +00:00
parent 88c30c70fc
commit 590d8303ad
15 changed files with 331 additions and 384 deletions
+35
View File
@@ -0,0 +1,35 @@
package checker
import (
"fmt"
"net/http"
"strings"
"github.com/TecharoHQ/anubis/internal"
)
type All []Interface
func (a All) Check(r *http.Request) (bool, error) {
for _, c := range a {
match, err := c.Check(r)
if err != nil {
return match, err
}
if !match {
return false, err // no match
}
}
return true, nil // match
}
func (a All) Hash() string {
var sb strings.Builder
for _, c := range a {
fmt.Fprintln(&sb, c.Hash())
}
return internal.FastHash(sb.String())
}
+70
View File
@@ -0,0 +1,70 @@
package checker
import (
"net/http"
"testing"
)
func TestAll_Check(t *testing.T) {
tests := []struct {
name string
checkers []MockChecker
want bool
wantErr bool
}{
{
name: "All match",
checkers: []MockChecker{
{Result: true, Err: nil},
{Result: true, Err: nil},
},
want: true,
wantErr: false,
},
{
name: "One not match",
checkers: []MockChecker{
{Result: true, Err: nil},
{Result: false, Err: nil},
},
want: false,
wantErr: false,
},
{
name: "No match",
checkers: []MockChecker{
{Result: false, Err: nil},
{Result: false, Err: nil},
},
want: false,
wantErr: false,
},
{
name: "Error encountered",
checkers: []MockChecker{
{Result: true, Err: nil},
{Result: false, Err: http.ErrNotSupported},
},
want: false,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var all All
for _, mc := range tt.checkers {
all = append(all, mc)
}
got, err := all.Check(nil)
if (err != nil) != tt.wantErr {
t.Errorf("All.Check() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("All.Check() = %v, want %v", got, tt.want)
}
})
}
}
+35
View File
@@ -0,0 +1,35 @@
package checker
import (
"fmt"
"net/http"
"strings"
"github.com/TecharoHQ/anubis/internal"
)
type Any []Interface
func (a Any) Check(r *http.Request) (bool, error) {
for _, c := range a {
match, err := c.Check(r)
if err != nil {
return match, err
}
if match {
return true, err // match
}
}
return false, nil // no match
}
func (a Any) Hash() string {
var sb strings.Builder
for _, c := range a {
fmt.Fprintln(&sb, c.Hash())
}
return internal.FastHash(sb.String())
}
+83
View File
@@ -0,0 +1,83 @@
package checker
import (
"net/http"
"testing"
)
type MockChecker struct {
Result bool
Err error
}
func (m MockChecker) Check(r *http.Request) (bool, error) {
return m.Result, m.Err
}
func (m MockChecker) Hash() string {
return "mock-hash"
}
func TestAny_Check(t *testing.T) {
tests := []struct {
name string
checkers []MockChecker
want bool
wantErr bool
}{
{
name: "All match",
checkers: []MockChecker{
{Result: true, Err: nil},
{Result: true, Err: nil},
},
want: true,
wantErr: false,
},
{
name: "One match",
checkers: []MockChecker{
{Result: false, Err: nil},
{Result: true, Err: nil},
},
want: true,
wantErr: false,
},
{
name: "No match",
checkers: []MockChecker{
{Result: false, Err: nil},
{Result: false, Err: nil},
},
want: false,
wantErr: false,
},
{
name: "Error encountered",
checkers: []MockChecker{
{Result: false, Err: nil},
{Result: false, Err: http.ErrNotSupported},
},
want: false,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var any Any
for _, mc := range tt.checkers {
any = append(any, mc)
}
got, err := any.Check(nil)
if (err != nil) != tt.wantErr {
t.Errorf("Any.Check() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("Any.Check() = %v, want %v", got, tt.want)
}
})
}
}
-30
View File
@@ -3,11 +3,7 @@ package checker
import (
"errors"
"fmt"
"net/http"
"strings"
"github.com/TecharoHQ/anubis/internal"
)
var (
@@ -19,29 +15,3 @@ type Interface interface {
Check(*http.Request) (matches bool, err error)
Hash() string
}
type List []Interface
func (l List) Check(r *http.Request) (bool, error) {
for _, c := range l {
ok, err := c.Check(r)
if err != nil {
return ok, err
}
if ok {
return ok, nil
}
}
return false, nil
}
func (l List) Hash() string {
var sb strings.Builder
for _, c := range l {
fmt.Fprintln(&sb, c.Hash())
}
return internal.FastHash(sb.String())
}
+22
View File
@@ -1,8 +1,12 @@
package headermatches
import (
"context"
"encoding/json"
"net/http"
"regexp"
"github.com/TecharoHQ/anubis/lib/checker"
)
type Checker struct {
@@ -22,3 +26,21 @@ func (c *Checker) Check(r *http.Request) (bool, error) {
func (c *Checker) Hash() string {
return c.hash
}
func New(key, valueRex string) (checker.Interface, error) {
fc := fileConfig{
Header: key,
ValueRegex: valueRex,
}
if err := fc.Valid(); err != nil {
return nil, err
}
data, err := json.Marshal(fc)
if err != nil {
return nil, err
}
return Factory{}.Build(context.Background(), json.RawMessage(data))
}
+34
View File
@@ -1 +1,35 @@
package headermatches
import (
"context"
"encoding/json"
"github.com/TecharoHQ/anubis/lib/checker"
)
func ValidUserAgent(valueRex string) error {
fc := fileConfig{
Header: "User-Agent",
ValueRegex: valueRex,
}
return fc.Valid()
}
func NewUserAgent(valueRex string) (checker.Interface, error) {
fc := fileConfig{
Header: "User-Agent",
ValueRegex: valueRex,
}
if err := fc.Valid(); err != nil {
return nil, err
}
data, err := json.Marshal(fc)
if err != nil {
return nil, err
}
return Factory{}.Build(context.Background(), json.RawMessage(data))
}
@@ -83,6 +83,14 @@ func (fc fileConfig) Valid() error {
return nil
}
func Valid(cidrs []string) error {
fc := fileConfig{
RemoteAddr: cidrs,
}
return fc.Valid()
}
func New(cidrs []string) (checker.Interface, error) {
fc := fileConfig{
RemoteAddr: cidrs,
@@ -1,4 +1,4 @@
package remoteaddress
package remoteaddress_test
import (
_ "embed"
@@ -8,17 +8,17 @@ import (
"testing"
"github.com/TecharoHQ/anubis/lib/checker"
"github.com/TecharoHQ/anubis/lib/policy/config"
"github.com/TecharoHQ/anubis/lib/checker/remoteaddress"
)
func TestFactoryIsCheckerFactory(t *testing.T) {
if _, ok := (any(Factory{})).(checker.Factory); !ok {
if _, ok := (any(remoteaddress.Factory{})).(checker.Factory); !ok {
t.Fatal("Factory is not an instance of checker.Factory")
}
}
func TestFactoryValidateConfig(t *testing.T) {
f := Factory{}
f := remoteaddress.Factory{}
for _, tt := range []struct {
name string
@@ -36,14 +36,14 @@ func TestFactoryValidateConfig(t *testing.T) {
{
name: "not json",
data: []byte(`]`),
err: config.ErrUnparseableConfig,
err: checker.ErrUnparseableConfig,
},
{
name: "no cidr",
data: []byte(`{
"remote_addresses": []
}`),
err: ErrNoRemoteAddresses,
err: remoteaddress.ErrNoRemoteAddresses,
},
{
name: "bad cidr",
@@ -52,7 +52,7 @@ func TestFactoryValidateConfig(t *testing.T) {
"according to all laws of aviation"
]
}`),
err: config.ErrInvalidCIDR,
err: remoteaddress.ErrInvalidCIDR,
},
} {
t.Run(tt.name, func(t *testing.T) {
@@ -68,7 +68,7 @@ func TestFactoryValidateConfig(t *testing.T) {
}
func TestFactoryCreate(t *testing.T) {
f := Factory{}
f := remoteaddress.Factory{}
for _, tt := range []struct {
name string
@@ -94,7 +94,7 @@ func TestFactoryCreate(t *testing.T) {
"according to all laws of aviation"
]
}`),
err: config.ErrUnparseableConfig,
err: checker.ErrUnparseableConfig,
},
} {
t.Run(tt.name, func(t *testing.T) {