Skip to content

Commit

Permalink
cert-v2: dns support for v4 and v6 addresses (#1249)
Browse files Browse the repository at this point in the history
* dns support for v4 and v6 addresses

* fix comment
  • Loading branch information
JackDoanRivian authored Oct 11, 2024
1 parent d6f1b51 commit 8ccdced
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 52 deletions.
109 changes: 67 additions & 42 deletions dns_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,39 @@ var dnsAddr string

type dnsRecords struct {
sync.RWMutex
dnsMap map[string]string
l *logrus.Logger
dnsMap4 map[string]netip.Addr
dnsMap6 map[string]netip.Addr
hostMap *HostMap
myVpnAddrsTable *bart.Table[struct{}]
}

func newDnsRecords(cs *CertState, hostMap *HostMap) *dnsRecords {
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
return &dnsRecords{
dnsMap: make(map[string]string),
l: l,
dnsMap4: make(map[string]netip.Addr),
dnsMap6: make(map[string]netip.Addr),
hostMap: hostMap,
myVpnAddrsTable: cs.myVpnAddrsTable,
}
}

func (d *dnsRecords) Query(data string) string {
func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
data = strings.ToLower(data)
d.RLock()
defer d.RUnlock()
if r, ok := d.dnsMap[strings.ToLower(data)]; ok {
return r
switch q {
case dns.TypeA:
if r, ok := d.dnsMap4[data]; ok {
return r
}
case dns.TypeAAAA:
if r, ok := d.dnsMap6[data]; ok {
return r
}
}
return ""

return netip.Addr{}
}

func (d *dnsRecords) QueryCert(data string) string {
Expand All @@ -67,48 +80,62 @@ func (d *dnsRecords) QueryCert(data string) string {
return string(b)
}

func (d *dnsRecords) Add(host, data string) {
// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
host = strings.ToLower(host)
d.Lock()
defer d.Unlock()
d.dnsMap[strings.ToLower(host)] = data
haveV4 := false
haveV6 := false
for _, addr := range addresses {
if addr.Is4() && !haveV4 {
d.dnsMap4[host] = addr
haveV4 = true
} else if addr.Is6() && !haveV6 {
d.dnsMap6[host] = addr
haveV6 = true
}
if haveV4 && haveV6 {
break
}
}
}

func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
a, _, _ := net.SplitHostPort(addr)
b, err := netip.ParseAddr(a)
if err != nil {
return false
}

if b.IsLoopback() {
return true
}

_, found := d.myVpnAddrsTable.Lookup(b)
return found //if we found it in this table, it's good
}

func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
for _, q := range m.Question {
switch q.Qtype {
case dns.TypeA:
l.Debugf("Query for A %s", q.Name)
ip := dnsR.Query(q.Name)
if ip != "" {
rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip))
if err == nil {
m.Answer = append(m.Answer, rr)
}
}
case dns.TypeAAAA:
l.Debugf("Query for AAAA %s", q.Name)
ip := dnsR.Query(q.Name)
if ip != "" {
rr, err := dns.NewRR(fmt.Sprintf("%s AAAA %s", q.Name, ip))
case dns.TypeA, dns.TypeAAAA:
qType := dns.TypeToString[q.Qtype]
d.l.Debugf("Query for %s %s", qType, q.Name)
ip := d.Query(q.Qtype, q.Name)
if ip.IsValid() {
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
if err == nil {
m.Answer = append(m.Answer, rr)
}
}
case dns.TypeTXT:
a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
b, err := netip.ParseAddr(a)
if err != nil {
// We only answer these queries from nebula nodes or localhost
if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
return
}

// We don't answer these queries from non nebula nodes or localhost
//l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
_, found := dnsR.myVpnAddrsTable.Lookup(b)
if !found && a != "127.0.0.1" {
return
}
l.Debugf("Query for TXT %s", q.Name)
ip := dnsR.QueryCert(q.Name)
d.l.Debugf("Query for TXT %s", q.Name)
ip := d.QueryCert(q.Name)
if ip != "" {
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
if err == nil {
Expand All @@ -123,26 +150,24 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
}
}

func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Compress = false

switch r.Opcode {
case dns.OpcodeQuery:
parseQuery(l, m, w)
d.parseQuery(m, w)
}

w.WriteMsg(m)
}

func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
dnsR = newDnsRecords(cs, hostMap)
dnsR = newDnsRecords(l, cs, hostMap)

// attach request handler func
dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
handleDnsRequest(l, w, r)
})
dns.HandleFunc(".", dnsR.handleDnsRequest)

c.RegisterReloadCallback(func(c *config.C) {
reloadDns(l, c)
Expand Down
25 changes: 20 additions & 5 deletions dns_server_test.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,38 @@
package nebula

import (
"net/netip"
"testing"

"github.com/miekg/dns"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/stretchr/testify/assert"
)

func TestParsequery(t *testing.T) {
//TODO: This test is basically pointless
l := logrus.New()
hostMap := &HostMap{}
ds := newDnsRecords(&CertState{}, hostMap)
ds.Add("test.com.com", "1.2.3.4")
ds := newDnsRecords(l, &CertState{}, hostMap)
addrs := []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
netip.MustParseAddr("1.2.3.5"),
netip.MustParseAddr("fd01::24"),
netip.MustParseAddr("fd01::25"),
}
ds.Add("test.com.com", addrs)

m := new(dns.Msg)
m := &dns.Msg{}
m.SetQuestion("test.com.com", dns.TypeA)
ds.parseQuery(m, nil)
assert.NotNil(t, m.Answer)
assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String())

//parseQuery(m)
m = &dns.Msg{}
m.SetQuestion("test.com.com", dns.TypeAAAA)
ds.parseQuery(m, nil)
assert.NotNil(t, m.Answer)
assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String())
}

func Test_getDnsServerAddr(t *testing.T) {
Expand Down
9 changes: 4 additions & 5 deletions hostmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,10 @@ func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostI
// unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps.
// If an entry exists for the Hosts table (vpnIp -> hostinfo) then the provided hostinfo will be made primary
func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
if f.serveDns {
remoteCert := hostinfo.ConnectionState.peerCert
dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
}
for _, addr := range hostinfo.vpnAddrs {
hm.unlockedInnerAddHostInfo(addr, hostinfo, f)
}
Expand All @@ -504,11 +508,6 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
}

func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo, f *Interface) {
if f.serveDns {
remoteCert := hostinfo.ConnectionState.peerCert
dnsR.Add(remoteCert.Certificate.Name()+".", vpnAddr.String())
}

existing := hm.Hosts[vpnAddr]
hm.Hosts[vpnAddr] = hostinfo

Expand Down

0 comments on commit 8ccdced

Please sign in to comment.