mirror of
https://github.com/TecharoHQ/anubis.git
synced 2026-04-15 04:58:43 +00:00
refactor: raise checker to be a subpackage of lib
Signed-off-by: Xe Iaso <me@xeiaso.net>
This commit is contained in:
47
lib/checker/checker.go
Normal file
47
lib/checker/checker.go
Normal file
@@ -0,0 +1,47 @@
|
||||
// Package checker defines the Checker interface and a helper utility to avoid import cycles.
|
||||
package checker
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/TecharoHQ/anubis/internal"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUnparseableConfig = errors.New("checker: config is unparseable")
|
||||
ErrInvalidConfig = errors.New("checker: config is invalid")
|
||||
)
|
||||
|
||||
type Interface interface {
|
||||
Check(*http.Request) (matches bool, err error)
|
||||
Hash() string
|
||||
}
|
||||
|
||||
type List []Interface
|
||||
|
||||
func (l List) Check(r *http.Request) (bool, error) {
|
||||
for _, c := range l {
|
||||
ok, err := c.Check(r)
|
||||
if err != nil {
|
||||
return ok, err
|
||||
}
|
||||
if ok {
|
||||
return ok, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (l List) Hash() string {
|
||||
var sb strings.Builder
|
||||
|
||||
for _, c := range l {
|
||||
fmt.Fprintln(&sb, c.Hash())
|
||||
}
|
||||
|
||||
return internal.FastHash(sb.String())
|
||||
}
|
||||
43
lib/checker/registry.go
Normal file
43
lib/checker/registry.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package checker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sort"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Factory interface {
|
||||
Build(context.Context, json.RawMessage) (Interface, error)
|
||||
Valid(context.Context, json.RawMessage) error
|
||||
}
|
||||
|
||||
var (
|
||||
registry map[string]Factory = map[string]Factory{}
|
||||
regLock sync.RWMutex
|
||||
)
|
||||
|
||||
func Register(name string, factory Factory) {
|
||||
regLock.Lock()
|
||||
defer regLock.Unlock()
|
||||
|
||||
registry[name] = factory
|
||||
}
|
||||
|
||||
func Get(name string) (Factory, bool) {
|
||||
regLock.RLock()
|
||||
defer regLock.RUnlock()
|
||||
result, ok := registry[name]
|
||||
return result, ok
|
||||
}
|
||||
|
||||
func Methods() []string {
|
||||
regLock.RLock()
|
||||
defer regLock.RUnlock()
|
||||
var result []string
|
||||
for method := range registry {
|
||||
result = append(result, method)
|
||||
}
|
||||
sort.Strings(result)
|
||||
return result
|
||||
}
|
||||
119
lib/checker/remoteaddress/remoteaddress.go
Normal file
119
lib/checker/remoteaddress/remoteaddress.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package remoteaddress
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
|
||||
"github.com/TecharoHQ/anubis"
|
||||
"github.com/TecharoHQ/anubis/internal"
|
||||
"github.com/TecharoHQ/anubis/lib/checker"
|
||||
"github.com/gaissmai/bart"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNoRemoteAddresses = errors.New("remoteaddress: no remote addresses defined")
|
||||
ErrInvalidCIDR = errors.New("remoteaddress: invalid CIDR")
|
||||
)
|
||||
|
||||
func init() {
|
||||
checker.Register("remote_address", Factory{})
|
||||
}
|
||||
|
||||
type Factory struct{}
|
||||
|
||||
func (Factory) Valid(_ context.Context, inp json.RawMessage) error {
|
||||
var fc fileConfig
|
||||
if err := json.Unmarshal([]byte(inp), &fc); err != nil {
|
||||
return fmt.Errorf("%w: %w", checker.ErrUnparseableConfig, err)
|
||||
}
|
||||
|
||||
if err := fc.Valid(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (Factory) Build(_ context.Context, inp json.RawMessage) (checker.Interface, error) {
|
||||
c := struct {
|
||||
RemoteAddr []netip.Prefix `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"`
|
||||
}{}
|
||||
|
||||
if err := json.Unmarshal([]byte(inp), &c); err != nil {
|
||||
return nil, fmt.Errorf("%w: %w", checker.ErrUnparseableConfig, err)
|
||||
}
|
||||
|
||||
table := new(bart.Lite)
|
||||
|
||||
for _, cidr := range c.RemoteAddr {
|
||||
table.Insert(cidr)
|
||||
}
|
||||
|
||||
return &RemoteAddrChecker{
|
||||
prefixTable: table,
|
||||
hash: internal.FastHash(string(inp)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type fileConfig struct {
|
||||
RemoteAddr []string `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"`
|
||||
}
|
||||
|
||||
func (fc fileConfig) Valid() error {
|
||||
var errs []error
|
||||
|
||||
if len(fc.RemoteAddr) == 0 {
|
||||
errs = append(errs, ErrNoRemoteAddresses)
|
||||
}
|
||||
|
||||
for _, cidr := range fc.RemoteAddr {
|
||||
if _, err := netip.ParsePrefix(cidr); err != nil {
|
||||
errs = append(errs, fmt.Errorf("%w: cidr %q is invalid: %w", ErrInvalidCIDR, cidr, err))
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) != 0 {
|
||||
return fmt.Errorf("%w: %w", checker.ErrInvalidConfig, errors.Join(errs...))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func New(cidrs []string) (checker.Interface, error) {
|
||||
fc := fileConfig{
|
||||
RemoteAddr: cidrs,
|
||||
}
|
||||
data, err := json.Marshal(fc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return Factory{}.Build(context.Background(), json.RawMessage(data))
|
||||
}
|
||||
|
||||
type RemoteAddrChecker struct {
|
||||
prefixTable *bart.Lite
|
||||
hash string
|
||||
}
|
||||
|
||||
func (rac *RemoteAddrChecker) Check(r *http.Request) (bool, error) {
|
||||
host := r.Header.Get("X-Real-Ip")
|
||||
if host == "" {
|
||||
return false, fmt.Errorf("%w: header X-Real-Ip is not set", anubis.ErrMisconfiguration)
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("%w: %s is not an IP address: %w", anubis.ErrMisconfiguration, host, err)
|
||||
}
|
||||
|
||||
return rac.prefixTable.Contains(addr), nil
|
||||
}
|
||||
|
||||
func (rac *RemoteAddrChecker) Hash() string {
|
||||
return rac.hash
|
||||
}
|
||||
138
lib/checker/remoteaddress/remoteaddress_test.go
Normal file
138
lib/checker/remoteaddress/remoteaddress_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package remoteaddress
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/TecharoHQ/anubis/lib/checker"
|
||||
"github.com/TecharoHQ/anubis/lib/policy/config"
|
||||
)
|
||||
|
||||
func TestFactoryIsCheckerFactory(t *testing.T) {
|
||||
if _, ok := (any(Factory{})).(checker.Factory); !ok {
|
||||
t.Fatal("Factory is not an instance of checker.Factory")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFactoryValidateConfig(t *testing.T) {
|
||||
f := Factory{}
|
||||
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
data []byte
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "basic valid",
|
||||
data: []byte(`{
|
||||
"remote_addresses": [
|
||||
"1.1.1.1/32"
|
||||
]
|
||||
}`),
|
||||
},
|
||||
{
|
||||
name: "not json",
|
||||
data: []byte(`]`),
|
||||
err: config.ErrUnparseableConfig,
|
||||
},
|
||||
{
|
||||
name: "no cidr",
|
||||
data: []byte(`{
|
||||
"remote_addresses": []
|
||||
}`),
|
||||
err: ErrNoRemoteAddresses,
|
||||
},
|
||||
{
|
||||
name: "bad cidr",
|
||||
data: []byte(`{
|
||||
"remote_addresses": [
|
||||
"according to all laws of aviation"
|
||||
]
|
||||
}`),
|
||||
err: config.ErrInvalidCIDR,
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data := json.RawMessage(tt.data)
|
||||
|
||||
if err := f.Valid(t.Context(), data); !errors.Is(err, tt.err) {
|
||||
t.Logf("want: %v", tt.err)
|
||||
t.Logf("got: %v", err)
|
||||
t.Fatal("validation didn't do what was expected")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFactoryCreate(t *testing.T) {
|
||||
f := Factory{}
|
||||
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
data []byte
|
||||
err error
|
||||
ip string
|
||||
match bool
|
||||
}{
|
||||
{
|
||||
name: "basic valid",
|
||||
data: []byte(`{
|
||||
"remote_addresses": [
|
||||
"1.1.1.1/32"
|
||||
]
|
||||
}`),
|
||||
ip: "1.1.1.1",
|
||||
match: true,
|
||||
},
|
||||
{
|
||||
name: "bad cidr",
|
||||
data: []byte(`{
|
||||
"remote_addresses": [
|
||||
"according to all laws of aviation"
|
||||
]
|
||||
}`),
|
||||
err: config.ErrUnparseableConfig,
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data := json.RawMessage(tt.data)
|
||||
|
||||
impl, err := f.Build(t.Context(), data)
|
||||
if !errors.Is(err, tt.err) {
|
||||
t.Logf("want: %v", tt.err)
|
||||
t.Logf("got: %v", err)
|
||||
t.Fatal("creation didn't do what was expected")
|
||||
}
|
||||
|
||||
if tt.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
r, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("can't make request: %v", err)
|
||||
}
|
||||
|
||||
if tt.ip != "" {
|
||||
r.Header.Add("X-Real-Ip", tt.ip)
|
||||
}
|
||||
|
||||
match, err := impl.Check(r)
|
||||
|
||||
if tt.match != match {
|
||||
t.Errorf("match: %v, wanted: %v", match, tt.match)
|
||||
}
|
||||
|
||||
if err != nil && tt.err != nil && !errors.Is(err, tt.err) {
|
||||
t.Errorf("err: %v, wanted: %v", err, tt.err)
|
||||
}
|
||||
|
||||
if impl.Hash() == "" {
|
||||
t.Error("hash method returns empty string")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
5
lib/checker/remoteaddress/testdata/invalid_bad_cidr.json
vendored
Normal file
5
lib/checker/remoteaddress/testdata/invalid_bad_cidr.json
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"remote_addresses": [
|
||||
"according to all laws of aviation"
|
||||
]
|
||||
}
|
||||
3
lib/checker/remoteaddress/testdata/invalid_no_cidr.json
vendored
Normal file
3
lib/checker/remoteaddress/testdata/invalid_no_cidr.json
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"remote_addresses": []
|
||||
}
|
||||
1
lib/checker/remoteaddress/testdata/invalid_not_json.json
vendored
Normal file
1
lib/checker/remoteaddress/testdata/invalid_not_json.json
vendored
Normal file
@@ -0,0 +1 @@
|
||||
]
|
||||
5
lib/checker/remoteaddress/testdata/valid_addresses.json
vendored
Normal file
5
lib/checker/remoteaddress/testdata/valid_addresses.json
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"remote_addresses": [
|
||||
"1.1.1.1/32"
|
||||
]
|
||||
}
|
||||
Reference in New Issue
Block a user