diff --git a/lib/policy/checker/checker.go b/lib/policy/checker/checker.go index 1ee276aa..dcb1c0c0 100644 --- a/lib/policy/checker/checker.go +++ b/lib/policy/checker/checker.go @@ -10,7 +10,7 @@ import ( ) type Impl interface { - Check(*http.Request) (bool, error) + Check(*http.Request) (matches bool, err error) Hash() string } diff --git a/lib/policy/checker/registry.go b/lib/policy/checker/registry.go new file mode 100644 index 00000000..d618c06d --- /dev/null +++ b/lib/policy/checker/registry.go @@ -0,0 +1,42 @@ +package checker + +import ( + "encoding/json" + "sort" + "sync" +) + +type Factory interface { + ValidateConfig(json.RawMessage) error + Create(json.RawMessage) (Impl, error) +} + +var ( + registry map[string]Factory = map[string]Factory{} + regLock sync.RWMutex +) + +func Register(name string, factory Factory) { + regLock.Lock() + defer regLock.Unlock() + + registry[name] = factory +} + +func Get(name string) (Factory, bool) { + regLock.RLock() + defer regLock.RUnlock() + result, ok := registry[name] + return result, ok +} + +func Methods() []string { + regLock.RLock() + defer regLock.RUnlock() + var result []string + for method := range registry { + result = append(result, method) + } + sort.Strings(result) + return result +} diff --git a/lib/policy/checker/remoteaddress/remoteaddress.go b/lib/policy/checker/remoteaddress/remoteaddress.go new file mode 100644 index 00000000..b05d68dd --- /dev/null +++ b/lib/policy/checker/remoteaddress/remoteaddress.go @@ -0,0 +1,104 @@ +package remoteaddress + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "net/netip" + + "github.com/TecharoHQ/anubis/internal" + "github.com/TecharoHQ/anubis/lib/policy" + "github.com/TecharoHQ/anubis/lib/policy/checker" + "github.com/TecharoHQ/anubis/lib/policy/config" + "github.com/gaissmai/bart" +) + +var ( + ErrNoRemoteAddresses = errors.New("remoteaddress: no remote addresses defined") +) + +func init() {} + +type Factory struct{} + +func (Factory) ValidateConfig(inp json.RawMessage) error { + var fc fileConfig + if err := json.Unmarshal([]byte(inp), &fc); err != nil { + return fmt.Errorf("%w: %w", config.ErrUnparseableConfig, err) + } + + if err := fc.Valid(); err != nil { + return err + } + + return nil +} + +func (Factory) Create(inp json.RawMessage) (checker.Impl, error) { + c := struct { + RemoteAddr []netip.Prefix `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"` + }{} + + if err := json.Unmarshal([]byte(inp), &c); err != nil { + return nil, fmt.Errorf("%w: %w", config.ErrUnparseableConfig, err) + } + + table := new(bart.Lite) + + for _, cidr := range c.RemoteAddr { + table.Insert(cidr) + } + + return &RemoteAddrChecker{ + prefixTable: table, + hash: internal.FastHash(string(inp)), + }, nil +} + +type fileConfig struct { + RemoteAddr []string `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"` +} + +func (fc fileConfig) Valid() error { + var errs []error + + if len(fc.RemoteAddr) == 0 { + errs = append(errs, ErrNoRemoteAddresses) + } + + for _, cidr := range fc.RemoteAddr { + if _, err := netip.ParsePrefix(cidr); err != nil { + errs = append(errs, fmt.Errorf("%w: cidr %q is invalid: %w", config.ErrInvalidCIDR, cidr, err)) + } + } + + if len(errs) != 0 { + return fmt.Errorf("%w: %w", policy.ErrMisconfiguration, errors.Join(errs...)) + } + + return nil +} + +type RemoteAddrChecker struct { + prefixTable *bart.Lite + hash string +} + +func (rac *RemoteAddrChecker) Check(r *http.Request) (bool, error) { + host := r.Header.Get("X-Real-Ip") + if host == "" { + return false, fmt.Errorf("%w: header X-Real-Ip is not set", policy.ErrMisconfiguration) + } + + addr, err := netip.ParseAddr(host) + if err != nil { + return false, fmt.Errorf("%w: %s is not an IP address: %w", policy.ErrMisconfiguration, host, err) + } + + return rac.prefixTable.Contains(addr), nil +} + +func (rac *RemoteAddrChecker) Hash() string { + return rac.hash +} diff --git a/lib/policy/checker/remoteaddress/remoteaddress_test.go b/lib/policy/checker/remoteaddress/remoteaddress_test.go new file mode 100644 index 00000000..f3bb9a3e --- /dev/null +++ b/lib/policy/checker/remoteaddress/remoteaddress_test.go @@ -0,0 +1,216 @@ +package remoteaddress + +import ( + _ "embed" + "encoding/json" + "errors" + "net/http" + "testing" + + "github.com/TecharoHQ/anubis/lib/policy/checker" + "github.com/TecharoHQ/anubis/lib/policy/config" +) + +func TestFactoryIsCheckerFactory(t *testing.T) { + if _, ok := (any(Factory{})).(checker.Factory); !ok { + t.Fatal("Factory is not an instance of checker.Factory") + } +} + +func TestFactoryValidateConfig(t *testing.T) { + f := Factory{} + + for _, tt := range []struct { + name string + data []byte + err error + }{ + { + name: "basic valid", + data: []byte(`{ + "remote_addresses": [ + "1.1.1.1/32" + ] +}`), + }, + { + name: "not json", + data: []byte(`]`), + err: config.ErrUnparseableConfig, + }, + { + name: "no cidr", + data: []byte(`{ + "remote_addresses": [] +}`), + err: ErrNoRemoteAddresses, + }, + { + name: "bad cidr", + data: []byte(`{ + "remote_addresses": [ + "according to all laws of aviation" + ] +}`), + err: config.ErrInvalidCIDR, + }, + } { + t.Run(tt.name, func(t *testing.T) { + data := json.RawMessage(tt.data) + + if err := f.ValidateConfig(data); !errors.Is(err, tt.err) { + t.Logf("want: %v", tt.err) + t.Logf("got: %v", err) + t.Fatal("validation didn't do what was expected") + } + }) + } +} + +func TestFactoryCreate(t *testing.T) { + f := Factory{} + + for _, tt := range []struct { + name string + data []byte + err error + ip string + match bool + }{ + { + name: "basic valid", + data: []byte(`{ + "remote_addresses": [ + "1.1.1.1/32" + ] +}`), + ip: "1.1.1.1", + match: true, + }, + { + name: "bad cidr", + data: []byte(`{ + "remote_addresses": [ + "according to all laws of aviation" + ] +}`), + err: config.ErrUnparseableConfig, + }, + } { + t.Run(tt.name, func(t *testing.T) { + data := json.RawMessage(tt.data) + + impl, err := f.Create(data) + if !errors.Is(err, tt.err) { + t.Logf("want: %v", tt.err) + t.Logf("got: %v", err) + t.Fatal("creation didn't do what was expected") + } + + if tt.err != nil { + return + } + + r, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatalf("can't make request: %v", err) + } + + if tt.ip != "" { + r.Header.Add("X-Real-Ip", tt.ip) + } + + match, err := impl.Check(r) + + if tt.match != match { + t.Errorf("match: %v, wanted: %v", match, tt.match) + } + + if err != nil && tt.err != nil && !errors.Is(err, tt.err) { + t.Errorf("err: %v, wanted: %v", err, tt.err) + } + + if impl.Hash() == "" { + t.Error("hash method returns empty string") + } + }) + } +} + +// func TestRemoteAddrChecker(t *testing.T) { +// for _, tt := range []struct { +// err error +// name string +// ip string +// cidrs []string +// ok bool +// }{ +// { +// name: "match_ipv4", +// cidrs: []string{"0.0.0.0/0"}, +// ip: "1.1.1.1", +// ok: true, +// err: nil, +// }, +// { +// name: "match_ipv6", +// cidrs: []string{"::/0"}, +// ip: "cafe:babe::", +// ok: true, +// err: nil, +// }, +// { +// name: "not_match_ipv4", +// cidrs: []string{"1.1.1.1/32"}, +// ip: "1.1.1.2", +// ok: false, +// err: nil, +// }, +// { +// name: "not_match_ipv6", +// cidrs: []string{"cafe:babe::/128"}, +// ip: "cafe:babe:4::/128", +// ok: false, +// err: nil, +// }, +// { +// name: "no_ip_set", +// cidrs: []string{"::/0"}, +// ok: false, +// err: policy.ErrMisconfiguration, +// }, +// { +// name: "invalid_ip", +// cidrs: []string{"::/0"}, +// ip: "According to all natural laws of aviation", +// ok: false, +// err: policy.ErrMisconfiguration, +// }, +// } { +// t.Run(tt.name, func(t *testing.T) { +// rac, err := NewRemoteAddrChecker(tt.cidrs) +// if err != nil && !errors.Is(err, tt.err) { +// t.Fatalf("creating RemoteAddrChecker failed: %v", err) +// } + +// r, err := http.NewRequest(http.MethodGet, "/", nil) +// if err != nil { +// t.Fatalf("can't make request: %v", err) +// } + +// if tt.ip != "" { +// r.Header.Add("X-Real-Ip", tt.ip) +// } + +// ok, err := rac.Check(r) + +// if tt.ok != ok { +// t.Errorf("ok: %v, wanted: %v", ok, tt.ok) +// } + +// if err != nil && tt.err != nil && !errors.Is(err, tt.err) { +// t.Errorf("err: %v, wanted: %v", err, tt.err) +// } +// }) +// } +// } diff --git a/lib/policy/checker/remoteaddress/testdata/invalid_bad_cidr.json b/lib/policy/checker/remoteaddress/testdata/invalid_bad_cidr.json new file mode 100644 index 00000000..09eb5a81 --- /dev/null +++ b/lib/policy/checker/remoteaddress/testdata/invalid_bad_cidr.json @@ -0,0 +1,5 @@ +{ + "remote_addresses": [ + "according to all laws of aviation" + ] +} \ No newline at end of file diff --git a/lib/policy/checker/remoteaddress/testdata/invalid_no_cidr.json b/lib/policy/checker/remoteaddress/testdata/invalid_no_cidr.json new file mode 100644 index 00000000..0c979eea --- /dev/null +++ b/lib/policy/checker/remoteaddress/testdata/invalid_no_cidr.json @@ -0,0 +1,3 @@ +{ + "remote_addresses": [] +} \ No newline at end of file diff --git a/lib/policy/checker/remoteaddress/testdata/invalid_not_json.json b/lib/policy/checker/remoteaddress/testdata/invalid_not_json.json new file mode 100644 index 00000000..54caf60b --- /dev/null +++ b/lib/policy/checker/remoteaddress/testdata/invalid_not_json.json @@ -0,0 +1 @@ +] \ No newline at end of file diff --git a/lib/policy/checker/remoteaddress/testdata/valid_addresses.json b/lib/policy/checker/remoteaddress/testdata/valid_addresses.json new file mode 100644 index 00000000..53d59dd0 --- /dev/null +++ b/lib/policy/checker/remoteaddress/testdata/valid_addresses.json @@ -0,0 +1,5 @@ +{ + "remote_addresses": [ + "1.1.1.1/32" + ] +} \ No newline at end of file diff --git a/lib/policy/config/config.go b/lib/policy/config/config.go index 18aceb09..f495a3bd 100644 --- a/lib/policy/config/config.go +++ b/lib/policy/config/config.go @@ -31,6 +31,7 @@ var ( ErrCantSetBotAndImportValuesAtOnce = errors.New("config.BotOrImport: can't set bot rules and import values at the same time") ErrMustSetBotOrImportRules = errors.New("config.BotOrImport: rule definition is invalid, you must set either bot rules or an import statement, not both") ErrStatusCodeNotValid = errors.New("config.StatusCode: status code not valid, must be between 100 and 599") + ErrUnparseableConfig = errors.New("config: can't parse configuration file") ) type Rule string