feat(internal): move SetupListener from main

Signed-off-by: Xe Iaso <me@xeiaso.net>
This commit is contained in:
Xe Iaso
2026-04-21 14:57:59 -04:00
parent 8af9845117
commit 298fc0b847
2 changed files with 272 additions and 0 deletions
+92
View File
@@ -0,0 +1,92 @@
package internal
import (
"errors"
"fmt"
"net"
"net/url"
"os"
"strconv"
"strings"
)
// parseBindNetFromAddr determine bind network and address based on the given network and address.
func parseBindNetFromAddr(address string) (string, string, error) {
defaultScheme := "http://"
if !strings.Contains(address, "://") {
if strings.HasPrefix(address, ":") {
address = defaultScheme + "localhost" + address
} else {
address = defaultScheme + address
}
}
bindUri, err := url.Parse(address)
if err != nil {
return "", "", fmt.Errorf("failed to parse bind URL: %w", err)
}
switch bindUri.Scheme {
case "unix":
return "unix", bindUri.Path, nil
case "tcp", "http", "https":
return "tcp", bindUri.Host, nil
default:
return "", "", fmt.Errorf("unsupported network scheme %s in address %s", bindUri.Scheme, address)
}
}
// SetupListener sets up a network listener based on the input from configuration
// envvars. It returns a network listener and the URL to that listener or an error.
func SetupListener(network, address, socketMode string) (net.Listener, string, error) {
formattedAddress := ""
var err error
if network == "" {
// keep compatibility
network, address, err = parseBindNetFromAddr(address)
}
if err != nil {
return nil, "", fmt.Errorf("can't parse bind and network: %w", err)
}
switch network {
case "unix":
formattedAddress = "unix:" + address
case "tcp":
if strings.HasPrefix(address, ":") { // assume it's just a port e.g. :4259
formattedAddress = "http://localhost" + address
} else {
formattedAddress = "http://" + address
}
default:
formattedAddress = fmt.Sprintf(`(%s) %s`, network, address)
}
ln, err := net.Listen(network, address)
if err != nil {
return nil, "", fmt.Errorf("failed to bind to %s: %w", formattedAddress, err)
}
// additional permission handling for unix sockets
if network == "unix" {
mode, err := strconv.ParseUint(socketMode, 8, 0)
if err != nil {
ln.Close()
return nil, "", fmt.Errorf("could not parse socket mode %s: %w", socketMode, err)
}
err = os.Chmod(address, os.FileMode(mode))
if err != nil {
err := fmt.Errorf("could not change socket mode: %w", err)
clErr := ln.Close()
if clErr != nil {
return nil, "", errors.Join(err, clErr)
}
return nil, "", err
}
}
return ln, formattedAddress, nil
}
+180
View File
@@ -0,0 +1,180 @@
package internal
import (
"io/fs"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
)
func TestParseBindNetFromAddr(t *testing.T) {
for _, tt := range []struct {
name string
address string
wantErr bool
network string
bind string
}{
{
name: "simple tcp",
address: "localhost:9090",
wantErr: false,
network: "tcp",
bind: "localhost:9090",
},
{
name: "simple unix",
address: "unix:///tmp/foo.sock",
wantErr: false,
network: "unix",
bind: "/tmp/foo.sock",
},
{
name: "invalid network",
address: "foo:///tmp/bar.sock",
wantErr: true,
},
{
name: "tcp uri",
address: "tcp://[::]:9090",
wantErr: false,
network: "tcp",
bind: "[::]:9090",
},
{
name: "http uri",
address: "http://[::]:9090",
wantErr: false,
network: "tcp",
bind: "[::]:9090",
},
{
name: "https uri",
address: "https://[::]:9090",
wantErr: false,
network: "tcp",
bind: "[::]:9090",
},
} {
t.Run(tt.name, func(t *testing.T) {
network, bind, err := parseBindNetFromAddr(tt.address)
switch {
case tt.wantErr && err == nil:
t.Errorf("parseBindNetFromAddr(%q) should have errored but did not", tt.address)
case !tt.wantErr && err != nil:
t.Errorf("parseBindNetFromAddr(%q) threw an error: %v", tt.address, err)
}
if network != tt.network {
t.Errorf("parseBindNetFromAddr(%q) wanted network: %q, got: %q", tt.address, tt.network, network)
}
if bind != tt.bind {
t.Errorf("parseBindNetFromAddr(%q) wanted bind: %q, got: %q", tt.address, tt.bind, bind)
}
})
}
}
func TestSetupListener(t *testing.T) {
td := t.TempDir()
for _, tt := range []struct {
name string
network, address, socketMode string
wantErr bool
socketURLPrefix string
}{
{
name: "simple tcp",
network: "",
address: ":0",
wantErr: false,
socketURLPrefix: "http://localhost:",
},
{
name: "simple unix",
network: "",
address: "unix://" + filepath.Join(td, "a"),
socketMode: "0770",
wantErr: false,
socketURLPrefix: "unix:" + filepath.Join(td, "a"),
},
{
name: "tcp",
network: "tcp",
address: ":0",
wantErr: false,
socketURLPrefix: "http://localhost:",
},
{
name: "udp",
network: "udp",
address: ":0",
wantErr: true,
socketURLPrefix: "http://localhost:",
},
{
name: "unix socket",
network: "unix",
socketMode: "0770",
address: filepath.Join(td, "a"),
wantErr: false,
socketURLPrefix: "unix:" + filepath.Join(td, "a"),
},
{
name: "invalid socket mode",
network: "unix",
socketMode: "taco bell",
address: filepath.Join(td, "a"),
wantErr: true,
socketURLPrefix: "unix:" + filepath.Join(td, "a"),
},
{
name: "empty socket mode",
network: "unix",
socketMode: "",
address: filepath.Join(td, "a"),
wantErr: true,
socketURLPrefix: "unix:" + filepath.Join(td, "a"),
},
} {
t.Run(tt.name, func(t *testing.T) {
ln, socketURL, err := SetupListener(tt.network, tt.address, tt.socketMode)
switch {
case tt.wantErr && err == nil:
t.Errorf("SetupListener(%q, %q, %q) should have errored but did not", tt.network, tt.address, tt.socketMode)
case !tt.wantErr && err != nil:
t.Fatalf("SetupListener(%q, %q, %q) threw an error: %v", tt.network, tt.address, tt.socketMode, err)
}
if ln != nil {
defer ln.Close()
}
if !tt.wantErr && !strings.HasPrefix(socketURL, tt.socketURLPrefix) {
t.Errorf("SetupListener(%q, %q, %q) should have returned a URL with prefix %q but got: %q", tt.network, tt.address, tt.socketMode, tt.socketURLPrefix, socketURL)
}
if tt.socketMode != "" {
mode, err := strconv.ParseUint(tt.socketMode, 8, 0)
if err != nil {
return
}
sockPath := strings.TrimPrefix(socketURL, "unix:")
st, err := os.Stat(sockPath)
if err != nil {
t.Fatalf("can't os.Stat(%q): %v", sockPath, err)
}
if st.Mode().Perm() != fs.FileMode(mode) {
t.Errorf("file mode of %q should be %s but is actually %s", sockPath, strconv.FormatUint(mode, 8), strconv.FormatUint(uint64(st.Mode()), 8))
}
}
})
}
}