mirror of
https://github.com/TecharoHQ/anubis.git
synced 2026-04-22 08:06:41 +00:00
feat(internal): move SetupListener from main
Signed-off-by: Xe Iaso <me@xeiaso.net>
This commit is contained in:
@@ -0,0 +1,92 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// parseBindNetFromAddr determine bind network and address based on the given network and address.
|
||||||
|
func parseBindNetFromAddr(address string) (string, string, error) {
|
||||||
|
defaultScheme := "http://"
|
||||||
|
if !strings.Contains(address, "://") {
|
||||||
|
if strings.HasPrefix(address, ":") {
|
||||||
|
address = defaultScheme + "localhost" + address
|
||||||
|
} else {
|
||||||
|
address = defaultScheme + address
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bindUri, err := url.Parse(address)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("failed to parse bind URL: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch bindUri.Scheme {
|
||||||
|
case "unix":
|
||||||
|
return "unix", bindUri.Path, nil
|
||||||
|
case "tcp", "http", "https":
|
||||||
|
return "tcp", bindUri.Host, nil
|
||||||
|
default:
|
||||||
|
return "", "", fmt.Errorf("unsupported network scheme %s in address %s", bindUri.Scheme, address)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupListener sets up a network listener based on the input from configuration
|
||||||
|
// envvars. It returns a network listener and the URL to that listener or an error.
|
||||||
|
func SetupListener(network, address, socketMode string) (net.Listener, string, error) {
|
||||||
|
formattedAddress := ""
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if network == "" {
|
||||||
|
// keep compatibility
|
||||||
|
network, address, err = parseBindNetFromAddr(address)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("can't parse bind and network: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch network {
|
||||||
|
case "unix":
|
||||||
|
formattedAddress = "unix:" + address
|
||||||
|
case "tcp":
|
||||||
|
if strings.HasPrefix(address, ":") { // assume it's just a port e.g. :4259
|
||||||
|
formattedAddress = "http://localhost" + address
|
||||||
|
} else {
|
||||||
|
formattedAddress = "http://" + address
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
formattedAddress = fmt.Sprintf(`(%s) %s`, network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
ln, err := net.Listen(network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("failed to bind to %s: %w", formattedAddress, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// additional permission handling for unix sockets
|
||||||
|
if network == "unix" {
|
||||||
|
mode, err := strconv.ParseUint(socketMode, 8, 0)
|
||||||
|
if err != nil {
|
||||||
|
ln.Close()
|
||||||
|
return nil, "", fmt.Errorf("could not parse socket mode %s: %w", socketMode, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = os.Chmod(address, os.FileMode(mode))
|
||||||
|
if err != nil {
|
||||||
|
err := fmt.Errorf("could not change socket mode: %w", err)
|
||||||
|
clErr := ln.Close()
|
||||||
|
if clErr != nil {
|
||||||
|
return nil, "", errors.Join(err, clErr)
|
||||||
|
}
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ln, formattedAddress, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,180 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseBindNetFromAddr(t *testing.T) {
|
||||||
|
for _, tt := range []struct {
|
||||||
|
name string
|
||||||
|
address string
|
||||||
|
wantErr bool
|
||||||
|
network string
|
||||||
|
bind string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple tcp",
|
||||||
|
address: "localhost:9090",
|
||||||
|
wantErr: false,
|
||||||
|
network: "tcp",
|
||||||
|
bind: "localhost:9090",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "simple unix",
|
||||||
|
address: "unix:///tmp/foo.sock",
|
||||||
|
wantErr: false,
|
||||||
|
network: "unix",
|
||||||
|
bind: "/tmp/foo.sock",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid network",
|
||||||
|
address: "foo:///tmp/bar.sock",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tcp uri",
|
||||||
|
address: "tcp://[::]:9090",
|
||||||
|
wantErr: false,
|
||||||
|
network: "tcp",
|
||||||
|
bind: "[::]:9090",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "http uri",
|
||||||
|
address: "http://[::]:9090",
|
||||||
|
wantErr: false,
|
||||||
|
network: "tcp",
|
||||||
|
bind: "[::]:9090",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "https uri",
|
||||||
|
address: "https://[::]:9090",
|
||||||
|
wantErr: false,
|
||||||
|
network: "tcp",
|
||||||
|
bind: "[::]:9090",
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
network, bind, err := parseBindNetFromAddr(tt.address)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case tt.wantErr && err == nil:
|
||||||
|
t.Errorf("parseBindNetFromAddr(%q) should have errored but did not", tt.address)
|
||||||
|
case !tt.wantErr && err != nil:
|
||||||
|
t.Errorf("parseBindNetFromAddr(%q) threw an error: %v", tt.address, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if network != tt.network {
|
||||||
|
t.Errorf("parseBindNetFromAddr(%q) wanted network: %q, got: %q", tt.address, tt.network, network)
|
||||||
|
}
|
||||||
|
|
||||||
|
if bind != tt.bind {
|
||||||
|
t.Errorf("parseBindNetFromAddr(%q) wanted bind: %q, got: %q", tt.address, tt.bind, bind)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetupListener(t *testing.T) {
|
||||||
|
td := t.TempDir()
|
||||||
|
|
||||||
|
for _, tt := range []struct {
|
||||||
|
name string
|
||||||
|
network, address, socketMode string
|
||||||
|
wantErr bool
|
||||||
|
socketURLPrefix string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple tcp",
|
||||||
|
network: "",
|
||||||
|
address: ":0",
|
||||||
|
wantErr: false,
|
||||||
|
socketURLPrefix: "http://localhost:",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "simple unix",
|
||||||
|
network: "",
|
||||||
|
address: "unix://" + filepath.Join(td, "a"),
|
||||||
|
socketMode: "0770",
|
||||||
|
wantErr: false,
|
||||||
|
socketURLPrefix: "unix:" + filepath.Join(td, "a"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tcp",
|
||||||
|
network: "tcp",
|
||||||
|
address: ":0",
|
||||||
|
wantErr: false,
|
||||||
|
socketURLPrefix: "http://localhost:",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "udp",
|
||||||
|
network: "udp",
|
||||||
|
address: ":0",
|
||||||
|
wantErr: true,
|
||||||
|
socketURLPrefix: "http://localhost:",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unix socket",
|
||||||
|
network: "unix",
|
||||||
|
socketMode: "0770",
|
||||||
|
address: filepath.Join(td, "a"),
|
||||||
|
wantErr: false,
|
||||||
|
socketURLPrefix: "unix:" + filepath.Join(td, "a"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid socket mode",
|
||||||
|
network: "unix",
|
||||||
|
socketMode: "taco bell",
|
||||||
|
address: filepath.Join(td, "a"),
|
||||||
|
wantErr: true,
|
||||||
|
socketURLPrefix: "unix:" + filepath.Join(td, "a"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty socket mode",
|
||||||
|
network: "unix",
|
||||||
|
socketMode: "",
|
||||||
|
address: filepath.Join(td, "a"),
|
||||||
|
wantErr: true,
|
||||||
|
socketURLPrefix: "unix:" + filepath.Join(td, "a"),
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ln, socketURL, err := SetupListener(tt.network, tt.address, tt.socketMode)
|
||||||
|
switch {
|
||||||
|
case tt.wantErr && err == nil:
|
||||||
|
t.Errorf("SetupListener(%q, %q, %q) should have errored but did not", tt.network, tt.address, tt.socketMode)
|
||||||
|
case !tt.wantErr && err != nil:
|
||||||
|
t.Fatalf("SetupListener(%q, %q, %q) threw an error: %v", tt.network, tt.address, tt.socketMode, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ln != nil {
|
||||||
|
defer ln.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tt.wantErr && !strings.HasPrefix(socketURL, tt.socketURLPrefix) {
|
||||||
|
t.Errorf("SetupListener(%q, %q, %q) should have returned a URL with prefix %q but got: %q", tt.network, tt.address, tt.socketMode, tt.socketURLPrefix, socketURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.socketMode != "" {
|
||||||
|
mode, err := strconv.ParseUint(tt.socketMode, 8, 0)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sockPath := strings.TrimPrefix(socketURL, "unix:")
|
||||||
|
st, err := os.Stat(sockPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("can't os.Stat(%q): %v", sockPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if st.Mode().Perm() != fs.FileMode(mode) {
|
||||||
|
t.Errorf("file mode of %q should be %s but is actually %s", sockPath, strconv.FormatUint(mode, 8), strconv.FormatUint(uint64(st.Mode()), 8))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user