diff --git a/dns_server.go b/dns_server.go index 1b97d7a4f..710f6ed85 100644 --- a/dns_server.go +++ b/dns_server.go @@ -22,26 +22,39 @@ var dnsAddr string type dnsRecords struct { sync.RWMutex - dnsMap map[string]string + l *logrus.Logger + dnsMap4 map[string]netip.Addr + dnsMap6 map[string]netip.Addr hostMap *HostMap myVpnAddrsTable *bart.Table[struct{}] } -func newDnsRecords(cs *CertState, hostMap *HostMap) *dnsRecords { +func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords { return &dnsRecords{ - dnsMap: make(map[string]string), + l: l, + dnsMap4: make(map[string]netip.Addr), + dnsMap6: make(map[string]netip.Addr), hostMap: hostMap, myVpnAddrsTable: cs.myVpnAddrsTable, } } -func (d *dnsRecords) Query(data string) string { +func (d *dnsRecords) Query(q uint16, data string) netip.Addr { + data = strings.ToLower(data) d.RLock() defer d.RUnlock() - if r, ok := d.dnsMap[strings.ToLower(data)]; ok { - return r + switch q { + case dns.TypeA: + if r, ok := d.dnsMap4[data]; ok { + return r + } + case dns.TypeAAAA: + if r, ok := d.dnsMap6[data]; ok { + return r + } } - return "" + + return netip.Addr{} } func (d *dnsRecords) QueryCert(data string) string { @@ -67,48 +80,62 @@ func (d *dnsRecords) QueryCert(data string) string { return string(b) } -func (d *dnsRecords) Add(host, data string) { +// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host` +func (d *dnsRecords) Add(host string, addresses []netip.Addr) { + host = strings.ToLower(host) d.Lock() defer d.Unlock() - d.dnsMap[strings.ToLower(host)] = data + haveV4 := false + haveV6 := false + for _, addr := range addresses { + if addr.Is4() && !haveV4 { + d.dnsMap4[host] = addr + haveV4 = true + } else if addr.Is6() && !haveV6 { + d.dnsMap6[host] = addr + haveV6 = true + } + if haveV4 && haveV6 { + break + } + } } -func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) { +func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool { + a, _, _ := net.SplitHostPort(addr) + b, err := netip.ParseAddr(a) + if err != nil { + return false + } + + if b.IsLoopback() { + return true + } + + _, found := d.myVpnAddrsTable.Lookup(b) + return found //if we found it in this table, it's good +} + +func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) { for _, q := range m.Question { switch q.Qtype { - case dns.TypeA: - l.Debugf("Query for A %s", q.Name) - ip := dnsR.Query(q.Name) - if ip != "" { - rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip)) - if err == nil { - m.Answer = append(m.Answer, rr) - } - } - case dns.TypeAAAA: - l.Debugf("Query for AAAA %s", q.Name) - ip := dnsR.Query(q.Name) - if ip != "" { - rr, err := dns.NewRR(fmt.Sprintf("%s AAAA %s", q.Name, ip)) + case dns.TypeA, dns.TypeAAAA: + qType := dns.TypeToString[q.Qtype] + d.l.Debugf("Query for %s %s", qType, q.Name) + ip := d.Query(q.Qtype, q.Name) + if ip.IsValid() { + rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip)) if err == nil { m.Answer = append(m.Answer, rr) } } case dns.TypeTXT: - a, _, _ := net.SplitHostPort(w.RemoteAddr().String()) - b, err := netip.ParseAddr(a) - if err != nil { + // We only answer these queries from nebula nodes or localhost + if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) { return } - - // We don't answer these queries from non nebula nodes or localhost - //l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR) - _, found := dnsR.myVpnAddrsTable.Lookup(b) - if !found && a != "127.0.0.1" { - return - } - l.Debugf("Query for TXT %s", q.Name) - ip := dnsR.QueryCert(q.Name) + d.l.Debugf("Query for TXT %s", q.Name) + ip := d.QueryCert(q.Name) if ip != "" { rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip)) if err == nil { @@ -123,26 +150,24 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) { } } -func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) { +func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { m := new(dns.Msg) m.SetReply(r) m.Compress = false switch r.Opcode { case dns.OpcodeQuery: - parseQuery(l, m, w) + d.parseQuery(m, w) } w.WriteMsg(m) } func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() { - dnsR = newDnsRecords(cs, hostMap) + dnsR = newDnsRecords(l, cs, hostMap) // attach request handler func - dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) { - handleDnsRequest(l, w, r) - }) + dns.HandleFunc(".", dnsR.handleDnsRequest) c.RegisterReloadCallback(func(c *config.C) { reloadDns(l, c) diff --git a/dns_server_test.go b/dns_server_test.go index ce0f419ac..f4643a36a 100644 --- a/dns_server_test.go +++ b/dns_server_test.go @@ -1,23 +1,38 @@ package nebula import ( + "net/netip" "testing" "github.com/miekg/dns" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/stretchr/testify/assert" ) func TestParsequery(t *testing.T) { - //TODO: This test is basically pointless + l := logrus.New() hostMap := &HostMap{} - ds := newDnsRecords(&CertState{}, hostMap) - ds.Add("test.com.com", "1.2.3.4") + ds := newDnsRecords(l, &CertState{}, hostMap) + addrs := []netip.Addr{ + netip.MustParseAddr("1.2.3.4"), + netip.MustParseAddr("1.2.3.5"), + netip.MustParseAddr("fd01::24"), + netip.MustParseAddr("fd01::25"), + } + ds.Add("test.com.com", addrs) - m := new(dns.Msg) + m := &dns.Msg{} m.SetQuestion("test.com.com", dns.TypeA) + ds.parseQuery(m, nil) + assert.NotNil(t, m.Answer) + assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String()) - //parseQuery(m) + m = &dns.Msg{} + m.SetQuestion("test.com.com", dns.TypeAAAA) + ds.parseQuery(m, nil) + assert.NotNil(t, m.Answer) + assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String()) } func Test_getDnsServerAddr(t *testing.T) { diff --git a/hostmap.go b/hostmap.go index 824d72251..d5fa032e2 100644 --- a/hostmap.go +++ b/hostmap.go @@ -489,6 +489,10 @@ func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostI // unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps. // If an entry exists for the Hosts table (vpnIp -> hostinfo) then the provided hostinfo will be made primary func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { + if f.serveDns { + remoteCert := hostinfo.ConnectionState.peerCert + dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs) + } for _, addr := range hostinfo.vpnAddrs { hm.unlockedInnerAddHostInfo(addr, hostinfo, f) } @@ -504,11 +508,6 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { } func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo, f *Interface) { - if f.serveDns { - remoteCert := hostinfo.ConnectionState.peerCert - dnsR.Add(remoteCert.Certificate.Name()+".", vpnAddr.String()) - } - existing := hm.Hosts[vpnAddr] hm.Hosts[vpnAddr] = hostinfo