mirror of
https://github.com/TecharoHQ/anubis.git
synced 2026-04-11 02:58:49 +00:00
175 lines
3.9 KiB
Go
175 lines
3.9 KiB
Go
package entrypoint
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/TecharoHQ/anubis/cmd/osiris/internal/config"
|
|
"github.com/TecharoHQ/anubis/internal/fingerprint"
|
|
"github.com/lum8rjack/go-ja4h"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
|
)
|
|
|
|
var (
|
|
ErrTargetInvalid = errors.New("[unexpected] target invalid")
|
|
ErrNoHandler = errors.New("[unexpected] no handler for domain")
|
|
ErrInvalidTLSKeypair = errors.New("[unexpected] invalid TLS keypair")
|
|
ErrNoCert = errors.New("this server does not have a certificate for that domain")
|
|
|
|
requestsPerDomain = promauto.NewGaugeVec(prometheus.GaugeOpts{
|
|
Namespace: "techaro",
|
|
Subsystem: "osiris",
|
|
Name: "request_count",
|
|
}, []string{"domain"})
|
|
|
|
unresolvedRequests = promauto.NewGauge(prometheus.GaugeOpts{
|
|
Namespace: "techaro",
|
|
Subsystem: "osiris",
|
|
Name: "unresolved_requests",
|
|
})
|
|
)
|
|
|
|
// Router dispatches incoming HTTP requests and TLS handshakes using per-domain
// configuration. A zero Router has no routes or certificates; NewRouter builds
// a populated one.
type Router struct {
	// routes maps a domain name to the handler that serves its traffic.
	routes map[string]http.Handler
	// tlsCerts maps a domain name to the certificate presented for it.
	tlsCerts map[string]*tls.Certificate

	// lock guards routes and tlsCerts; setConfig swaps both maps under it.
	lock sync.RWMutex
}
|
|
|
|
func (rtr *Router) setConfig(c config.Toplevel) error {
|
|
var errs []error
|
|
newMap := map[string]http.Handler{}
|
|
newCerts := map[string]*tls.Certificate{}
|
|
|
|
for _, d := range c.Domains {
|
|
var domainErrs []error
|
|
|
|
u, err := url.Parse(d.Target)
|
|
if err != nil {
|
|
domainErrs = append(domainErrs, fmt.Errorf("%w %q: %v", ErrTargetInvalid, d.Target, err))
|
|
}
|
|
|
|
var h http.Handler
|
|
|
|
switch u.Scheme {
|
|
case "http", "https":
|
|
h = httputil.NewSingleHostReverseProxy(u)
|
|
case "h2c":
|
|
h = newH2CReverseProxy(u)
|
|
case "unix":
|
|
h = &httputil.ReverseProxy{
|
|
Transport: &http.Transport{
|
|
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
|
|
return net.Dial("unix", strings.TrimPrefix(d.Target, "unix://"))
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
if h == nil {
|
|
domainErrs = append(domainErrs, ErrNoHandler)
|
|
}
|
|
|
|
newMap[d.Name] = h
|
|
|
|
cert, err := tls.LoadX509KeyPair(d.TLS.Cert, d.TLS.Key)
|
|
if err != nil {
|
|
domainErrs = append(domainErrs, fmt.Errorf("%w: %w", ErrInvalidTLSKeypair, err))
|
|
}
|
|
|
|
newCerts[d.Name] = &cert
|
|
|
|
if len(domainErrs) != 0 {
|
|
errs = append(errs, fmt.Errorf("invalid domain %s: %w", d.Name, errors.Join(domainErrs...)))
|
|
}
|
|
}
|
|
|
|
if len(errs) != 0 {
|
|
return fmt.Errorf("can't compile config to routing map: %w", errors.Join(errs...))
|
|
}
|
|
|
|
rtr.lock.Lock()
|
|
rtr.routes = newMap
|
|
rtr.tlsCerts = newCerts
|
|
rtr.lock.Unlock()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (rtr *Router) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
rtr.lock.RLock()
|
|
cert, ok := rtr.tlsCerts[hello.ServerName]
|
|
rtr.lock.RUnlock()
|
|
|
|
if !ok {
|
|
return nil, ErrNoCert
|
|
}
|
|
|
|
return cert, nil
|
|
}
|
|
|
|
func NewRouter(c config.Toplevel) (*Router, error) {
|
|
result := &Router{
|
|
routes: map[string]http.Handler{},
|
|
}
|
|
|
|
if err := result.setConfig(c); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func (rtr *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
var host = r.Host
|
|
|
|
if strings.Contains(host, ":") {
|
|
host, _, _ = net.SplitHostPort(host)
|
|
}
|
|
|
|
requestsPerDomain.WithLabelValues(host).Inc()
|
|
|
|
var h http.Handler
|
|
var ok bool
|
|
|
|
ja4hFP := ja4h.JA4H(r)
|
|
|
|
slog.Info("got request", "method", r.Method, "host", host, "path", r.URL.Path)
|
|
|
|
rtr.lock.RLock()
|
|
h, ok = rtr.routes[host]
|
|
rtr.lock.RUnlock()
|
|
|
|
if !ok {
|
|
unresolvedRequests.Inc()
|
|
http.NotFound(w, r) // TODO(Xe): brand this
|
|
return
|
|
}
|
|
|
|
r.Header.Set("X-Http-Ja4h-Fingerprint", ja4hFP)
|
|
|
|
if fp := fingerprint.GetTLSFingerprint(r); fp != nil {
|
|
if ja3n := fp.JA3N(); ja3n != nil {
|
|
r.Header.Set("X-Tls-Ja3n-Fingerprint", ja3n.String())
|
|
}
|
|
if ja4 := fp.JA4(); ja4 != nil {
|
|
r.Header.Set("X-Tls-Ja4-Fingerprint", ja4.String())
|
|
}
|
|
}
|
|
|
|
if tcpFP := fingerprint.GetTCPFingerprint(r); tcpFP != nil {
|
|
r.Header.Set("X-Tcp-Ja4t-Fingerprint", tcpFP.String())
|
|
}
|
|
|
|
h.ServeHTTP(w, r)
|
|
}
|