diff --git a/internal/setuplistener.go b/internal/setuplistener.go new file mode 100644 index 00000000..fc076657 --- /dev/null +++ b/internal/setuplistener.go @@ -0,0 +1,92 @@ +package internal + +import ( + "errors" + "fmt" + "net" + "net/url" + "os" + "strconv" + "strings" +) + +// parseBindNetFromAddr determine bind network and address based on the given network and address. +func parseBindNetFromAddr(address string) (string, string, error) { + defaultScheme := "http://" + if !strings.Contains(address, "://") { + if strings.HasPrefix(address, ":") { + address = defaultScheme + "localhost" + address + } else { + address = defaultScheme + address + } + } + + bindUri, err := url.Parse(address) + if err != nil { + return "", "", fmt.Errorf("failed to parse bind URL: %w", err) + } + + switch bindUri.Scheme { + case "unix": + return "unix", bindUri.Path, nil + case "tcp", "http", "https": + return "tcp", bindUri.Host, nil + default: + return "", "", fmt.Errorf("unsupported network scheme %s in address %s", bindUri.Scheme, address) + } +} + +// SetupListener sets up a network listener based on the input from configuration +// envvars. It returns a network listener and the URL to that listener or an error. +func SetupListener(network, address, socketMode string) (net.Listener, string, error) { + formattedAddress := "" + var err error + + if network == "" { + // keep compatibility + network, address, err = parseBindNetFromAddr(address) + } + + if err != nil { + return nil, "", fmt.Errorf("can't parse bind and network: %w", err) + } + + switch network { + case "unix": + formattedAddress = "unix:" + address + case "tcp": + if strings.HasPrefix(address, ":") { // assume it's just a port e.g. :4259 + formattedAddress = "http://localhost" + address + } else { + formattedAddress = "http://" + address + } + default: + formattedAddress = fmt.Sprintf(`(%s) %s`, network, address) + } + + ln, err := net.Listen(network, address) + if err != nil { + return nil, "", fmt.Errorf("failed to bind to %s: %w", formattedAddress, err) + } + + // additional permission handling for unix sockets + if network == "unix" { + mode, err := strconv.ParseUint(socketMode, 8, 0) + if err != nil { + ln.Close() + return nil, "", fmt.Errorf("could not parse socket mode %s: %w", socketMode, err) + } + + err = os.Chmod(address, os.FileMode(mode)) + if err != nil { + err := fmt.Errorf("could not change socket mode: %w", err) + clErr := ln.Close() + if clErr != nil { + return nil, "", errors.Join(err, clErr) + } + return nil, "", err + } + } + + return ln, formattedAddress, nil +} diff --git a/internal/setuplistener_test.go b/internal/setuplistener_test.go new file mode 100644 index 00000000..026be531 --- /dev/null +++ b/internal/setuplistener_test.go @@ -0,0 +1,180 @@ +package internal + +import ( + "io/fs" + "os" + "path/filepath" + "strconv" + "strings" + "testing" +) + +func TestParseBindNetFromAddr(t *testing.T) { + for _, tt := range []struct { + name string + address string + wantErr bool + network string + bind string + }{ + { + name: "simple tcp", + address: "localhost:9090", + wantErr: false, + network: "tcp", + bind: "localhost:9090", + }, + { + name: "simple unix", + address: "unix:///tmp/foo.sock", + wantErr: false, + network: "unix", + bind: "/tmp/foo.sock", + }, + { + name: "invalid network", + address: "foo:///tmp/bar.sock", + wantErr: true, + }, + { + name: "tcp uri", + address: "tcp://[::]:9090", + wantErr: false, + network: "tcp", + bind: "[::]:9090", + }, + { + name: "http uri", + address: "http://[::]:9090", + wantErr: false, + network: "tcp", + bind: "[::]:9090", + }, + { + name: "https uri", + address: "https://[::]:9090", + wantErr: false, + network: "tcp", + bind: "[::]:9090", + }, + } { + t.Run(tt.name, func(t *testing.T) { + network, bind, err := parseBindNetFromAddr(tt.address) + + switch { + case tt.wantErr && err == nil: + t.Errorf("parseBindNetFromAddr(%q) should have errored but did not", tt.address) + case !tt.wantErr && err != nil: + t.Errorf("parseBindNetFromAddr(%q) threw an error: %v", tt.address, err) + } + + if network != tt.network { + t.Errorf("parseBindNetFromAddr(%q) wanted network: %q, got: %q", tt.address, tt.network, network) + } + + if bind != tt.bind { + t.Errorf("parseBindNetFromAddr(%q) wanted bind: %q, got: %q", tt.address, tt.bind, bind) + } + }) + } +} + +func TestSetupListener(t *testing.T) { + td := t.TempDir() + + for _, tt := range []struct { + name string + network, address, socketMode string + wantErr bool + socketURLPrefix string + }{ + { + name: "simple tcp", + network: "", + address: ":0", + wantErr: false, + socketURLPrefix: "http://localhost:", + }, + { + name: "simple unix", + network: "", + address: "unix://" + filepath.Join(td, "a"), + socketMode: "0770", + wantErr: false, + socketURLPrefix: "unix:" + filepath.Join(td, "a"), + }, + { + name: "tcp", + network: "tcp", + address: ":0", + wantErr: false, + socketURLPrefix: "http://localhost:", + }, + { + name: "udp", + network: "udp", + address: ":0", + wantErr: true, + socketURLPrefix: "http://localhost:", + }, + { + name: "unix socket", + network: "unix", + socketMode: "0770", + address: filepath.Join(td, "a"), + wantErr: false, + socketURLPrefix: "unix:" + filepath.Join(td, "a"), + }, + { + name: "invalid socket mode", + network: "unix", + socketMode: "taco bell", + address: filepath.Join(td, "a"), + wantErr: true, + socketURLPrefix: "unix:" + filepath.Join(td, "a"), + }, + { + name: "empty socket mode", + network: "unix", + socketMode: "", + address: filepath.Join(td, "a"), + wantErr: true, + socketURLPrefix: "unix:" + filepath.Join(td, "a"), + }, + } { + t.Run(tt.name, func(t *testing.T) { + ln, socketURL, err := SetupListener(tt.network, tt.address, tt.socketMode) + switch { + case tt.wantErr && err == nil: + t.Errorf("SetupListener(%q, %q, %q) should have errored but did not", tt.network, tt.address, tt.socketMode) + case !tt.wantErr && err != nil: + t.Fatalf("SetupListener(%q, %q, %q) threw an error: %v", tt.network, tt.address, tt.socketMode, err) + } + + if ln != nil { + defer ln.Close() + } + + if !tt.wantErr && !strings.HasPrefix(socketURL, tt.socketURLPrefix) { + t.Errorf("SetupListener(%q, %q, %q) should have returned a URL with prefix %q but got: %q", tt.network, tt.address, tt.socketMode, tt.socketURLPrefix, socketURL) + } + + if tt.socketMode != "" { + mode, err := strconv.ParseUint(tt.socketMode, 8, 0) + if err != nil { + return + } + + sockPath := strings.TrimPrefix(socketURL, "unix:") + st, err := os.Stat(sockPath) + if err != nil { + t.Fatalf("can't os.Stat(%q): %v", sockPath, err) + } + + if st.Mode().Perm() != fs.FileMode(mode) { + t.Errorf("file mode of %q should be %s but is actually %s", sockPath, strconv.FormatUint(mode, 8), strconv.FormatUint(uint64(st.Mode()), 8)) + } + } + }) + } +}