diff --git a/util/addr/addr.go b/util/addr/addr.go index 8928dc0e1a..b7cf8a2dbf 100644 --- a/util/addr/addr.go +++ b/util/addr/addr.go @@ -35,8 +35,8 @@ func IsLocal(addr string) bool { } // Extract returns a valid IP address. If the address provided is a valid -// address, it will be returned directly. Otherwise the available interfaces -// be itterated over to find an IP address, prefferably private. +// address, it will be returned directly. Otherwise, the available interfaces +// will be iterated over to find an IP address, preferably private. func Extract(addr string) (string, error) { // if addr is already specified then it's directly returned if len(addr) > 0 && (addr != "0.0.0.0" && addr != "[::]" && addr != "::") { @@ -115,10 +115,12 @@ func IPs() []string { return ipAddrs } -// findIP will return the first private IP available in the list, -// if no private IP is available it will return a public IP if present. +// findIP will return the first private IP available in the list. +// If no private IP is available it will return the first public IP, if present. +// If no public IP is available, it will return the first loopback IP, if present. func findIP(addresses []net.Addr) (net.IP, error) { var publicIP net.IP + var localIP net.IP for _, rawAddr := range addresses { var ip net.IP @@ -131,8 +133,17 @@ func findIP(addresses []net.Addr) (net.IP, error) { continue } + if ip.IsLoopback() { + if localIP == nil { + localIP = ip + } + continue + } + if !ip.IsPrivate() { - publicIP = ip + if publicIP == nil { + publicIP = ip + } continue } @@ -145,5 +156,10 @@ func findIP(addresses []net.Addr) (net.IP, error) { return publicIP, nil } + // Return local IP + if len(localIP) > 0 { + return localIP, nil + } + return nil, ErrIPNotFound } diff --git a/util/addr/addr_test.go b/util/addr/addr_test.go index aed1d1cf2b..4576ef904e 100644 --- a/util/addr/addr_test.go +++ b/util/addr/addr_test.go @@ -1,6 +1,7 @@ package addr import ( + "github.com/stretchr/testify/assert" "net" "testing" ) @@ -54,3 +55,78 @@ func TestExtractor(t *testing.T) { } } } + +func TestFindIP(t *testing.T) { + localhost, _ := net.ResolveIPAddr("ip", "127.0.0.1") + localhostIPv6, _ := net.ResolveIPAddr("ip", "::1") + privateIP, _ := net.ResolveIPAddr("ip", "10.0.0.1") + publicIP, _ := net.ResolveIPAddr("ip", "100.0.0.1") + publicIPv6, _ := net.ResolveIPAddr("ip", "2001:0db8:85a3:0000:0000:8a2e:0370:7334") + + testCases := []struct { + addrs []net.Addr + ip net.IP + errMsg string + }{ + { + addrs: []net.Addr{}, + ip: nil, + errMsg: ErrIPNotFound.Error(), + }, + { + addrs: []net.Addr{localhost}, + ip: localhost.IP, + }, + { + addrs: []net.Addr{localhost, localhostIPv6}, + ip: localhost.IP, + }, + { + addrs: []net.Addr{localhostIPv6}, + ip: localhostIPv6.IP, + }, + { + addrs: []net.Addr{privateIP, localhost}, + ip: privateIP.IP, + }, + { + addrs: []net.Addr{privateIP, publicIP, localhost}, + ip: privateIP.IP, + }, + { + addrs: []net.Addr{publicIP, privateIP, localhost}, + ip: privateIP.IP, + }, + { + addrs: []net.Addr{publicIP, localhost}, + ip: publicIP.IP, + }, + { + addrs: []net.Addr{publicIP, localhostIPv6}, + ip: publicIP.IP, + }, + { + addrs: []net.Addr{localhostIPv6, publicIP}, + ip: publicIP.IP, + }, + { + addrs: []net.Addr{localhostIPv6, publicIPv6, publicIP}, + ip: publicIPv6.IP, + }, + { + addrs: []net.Addr{publicIP, publicIPv6}, + ip: publicIP.IP, + }, + } + + for _, tc := range testCases { + ip, err := findIP(tc.addrs) + if tc.errMsg == "" { + assert.Nil(t, err) + assert.Equal(t, tc.ip.String(), ip.String()) + } else { + assert.NotNil(t, err) + assert.Equal(t, tc.errMsg, err.Error()) + } + } +}