Files
anubis-mirror/cmd/osiris/internal/entrypoint/router_test.go
2025-07-18 22:58:03 +00:00

320 lines
7.2 KiB
Go

package entrypoint
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/TecharoHQ/anubis/cmd/osiris/internal/config"
"github.com/hashicorp/hcl/v2/hclsimple"
)
func loadConfig(t *testing.T, fname string) config.Toplevel {
t.Helper()
var cfg config.Toplevel
if err := hclsimple.DecodeFile(fname, nil, &cfg); err != nil {
t.Fatalf("can't read configuration file %s: %v", fname, err)
}
if err := cfg.Valid(); err != nil {
t.Errorf("configuration file %s is invalid: %v", "./testdata/selfsigned.hcl", err)
}
return cfg
}
func newRouter(t *testing.T, cfg config.Toplevel) *Router {
t.Helper()
rtr, err := NewRouter(cfg)
if err != nil {
t.Fatal(err)
}
return rtr
}
func TestNewRouter(t *testing.T) {
cfg := loadConfig(t, "./testdata/good/selfsigned.hcl")
rtr := newRouter(t, cfg)
srv := httptest.NewServer(rtr)
defer srv.Close()
}
func TestNewRouterFails(t *testing.T) {
cfg := loadConfig(t, "./testdata/good/selfsigned.hcl")
cfg.Domains = append(cfg.Domains, config.Domain{
Name: "test1.internal",
TLS: config.TLS{
Cert: "./testdata/tls/invalid.crt",
Key: "./testdata/tls/invalid.key",
},
Target: cfg.Domains[0].Target,
HealthTarget: cfg.Domains[0].HealthTarget,
})
rtr, err := NewRouter(cfg)
if err == nil {
t.Fatal("wanted an error but got none")
}
srv := httptest.NewServer(rtr)
defer srv.Close()
}
func TestRouterSetConfig(t *testing.T) {
for _, tt := range []struct {
name string
configFname string
mutation func(cfg config.Toplevel) config.Toplevel
err error
}{
{
name: "basic",
configFname: "./testdata/good/selfsigned.hcl",
mutation: func(cfg config.Toplevel) config.Toplevel {
return cfg
},
},
{
name: "all schemes",
configFname: "./testdata/good/selfsigned.hcl",
mutation: func(cfg config.Toplevel) config.Toplevel {
cfg.Domains = append(cfg.Domains, config.Domain{
Name: "http.internal",
TLS: cfg.Domains[0].TLS,
Target: "http://[::1]:3000",
HealthTarget: cfg.Domains[0].HealthTarget,
})
cfg.Domains = append(cfg.Domains, config.Domain{
Name: "https.internal",
TLS: cfg.Domains[0].TLS,
Target: "https://[::1]:3000",
HealthTarget: cfg.Domains[0].HealthTarget,
})
cfg.Domains = append(cfg.Domains, config.Domain{
Name: "h2c.internal",
TLS: cfg.Domains[0].TLS,
Target: "h2c://[::1]:3000",
HealthTarget: cfg.Domains[0].HealthTarget,
})
cfg.Domains = append(cfg.Domains, config.Domain{
Name: "unix.internal",
TLS: cfg.Domains[0].TLS,
Target: "unix://foo.sock",
HealthTarget: cfg.Domains[0].HealthTarget,
})
return cfg
},
},
{
name: "invalid TLS",
configFname: "./testdata/good/selfsigned.hcl",
mutation: func(cfg config.Toplevel) config.Toplevel {
cfg.Domains = append(cfg.Domains, config.Domain{
Name: "test1.internal",
TLS: config.TLS{
Cert: "./testdata/tls/invalid.crt",
Key: "./testdata/tls/invalid.key",
},
Target: cfg.Domains[0].Target,
HealthTarget: cfg.Domains[0].HealthTarget,
})
return cfg
},
err: ErrInvalidTLSKeypair,
},
{
name: "target is not a valid URL",
configFname: "./testdata/good/selfsigned.hcl",
mutation: func(cfg config.Toplevel) config.Toplevel {
cfg.Domains = append(cfg.Domains, config.Domain{
Name: "test1.internal",
TLS: cfg.Domains[0].TLS,
Target: "http://[::1:443",
HealthTarget: cfg.Domains[0].HealthTarget,
})
return cfg
},
err: ErrTargetInvalid,
},
{
name: "invalid target scheme",
configFname: "./testdata/good/selfsigned.hcl",
mutation: func(cfg config.Toplevel) config.Toplevel {
cfg.Domains = append(cfg.Domains, config.Domain{
Name: "test1.internal",
TLS: cfg.Domains[0].TLS,
Target: "foo://",
HealthTarget: cfg.Domains[0].HealthTarget,
})
return cfg
},
err: ErrNoHandler,
},
} {
t.Run(tt.name, func(t *testing.T) {
cfg := loadConfig(t, tt.configFname)
rtr := newRouter(t, cfg)
cfg = tt.mutation(cfg)
if err := rtr.setConfig(cfg); !errors.Is(err, tt.err) {
t.Logf("want: %v", tt.err)
t.Logf("got: %v", err)
t.Error("got wrong error from rtr.setConfig function")
}
})
}
}
type ackHandler struct {
ack bool
}
func (ah *ackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ah.ack = true
fmt.Fprintln(w, "OK")
}
func (ah *ackHandler) Reset() {
ah.ack = false
}
func newUnixServer(t *testing.T, h http.Handler) string {
sockName := filepath.Join(t.TempDir(), "s")
ln, err := net.Listen("unix", sockName)
if err != nil {
t.Fatalf("can't listen on %s: %v", sockName, err)
}
t.Cleanup(func() {
ln.Close()
os.Remove(sockName)
})
go func(ctx context.Context) {
srv := &http.Server{
Handler: h,
}
go func() {
<-ctx.Done()
srv.Close()
}()
srv.Serve(ln)
}(t.Context())
return "unix://" + sockName
}
func TestRouterGetCertificate(t *testing.T) {
cfg := loadConfig(t, "./testdata/good/selfsigned.hcl")
rtr := newRouter(t, cfg)
for _, tt := range []struct {
domainName string
err error
}{
{
domainName: "osiris.local.cetacean.club",
},
{
domainName: "whacky-fun.local",
err: ErrNoCert,
},
} {
t.Run(tt.domainName, func(t *testing.T) {
if _, err := rtr.GetCertificate(&tls.ClientHelloInfo{ServerName: tt.domainName}); !errors.Is(err, tt.err) {
t.Logf("want: %v", tt.err)
t.Logf("got: %v", err)
t.Error("got wrong error from rtr.GetCertificate")
}
})
}
}
func TestRouterServeAllProtocols(t *testing.T) {
cfg := loadConfig(t, "./testdata/good/all_protocols.hcl")
httpAckHandler := &ackHandler{}
httpsAckHandler := &ackHandler{}
h2cAckHandler := &ackHandler{}
unixAckHandler := &ackHandler{}
httpSrv := httptest.NewServer(httpAckHandler)
httpsSrv := httptest.NewTLSServer(httpsAckHandler)
h2cSrv := newH2cServer(t, h2cAckHandler)
unixPath := newUnixServer(t, unixAckHandler)
cfg.Domains[0].Target = httpSrv.URL
cfg.Domains[1].Target = httpsSrv.URL
cfg.Domains[2].Target = strings.ReplaceAll(h2cSrv.URL, "http:", "h2c:")
cfg.Domains[3].Target = unixPath
// enc := json.NewEncoder(os.Stderr)
// enc.SetIndent("", " ")
// enc.Encode(cfg)
rtr := newRouter(t, cfg)
cli := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
}
t.Run("plain http", func(t *testing.T) {
ln, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
ln.Close()
})
go rtr.HandleHTTP(t.Context(), ln)
serverURL := "http://" + ln.Addr().String()
t.Log(serverURL)
for _, d := range cfg.Domains {
t.Run(d.Name, func(t *testing.T) {
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, serverURL, nil)
if err != nil {
t.Fatal(err)
}
req.Host = d.Name
resp, err := cli.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("wrong status code %d", resp.StatusCode)
}
})
}
})
}