diff --git a/.travis.yml b/.travis.yml index b93bbb23..0926ba5b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -22,3 +22,4 @@ before_script: script: - go test -race . - go test -tags='cluster' -short -race -v ./... + - GOMODULE111=off go test . diff --git a/CHANGELOG.md b/CHANGELOG.md index 29c74d15..8076c26d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,12 @@ All notable changes to this project will be documented in this file. This project adheres to [Semantic Versioning](http://semver.org/). +## v6.2.0 - 2020-03-16 + +- Backoff v4 +- Reworked cluster discovery +- Fix rare connection goroutine leak + ## v6.1.0 - 2020-03-09 - Reworked and tested new connection pools with multiple queries per connection diff --git a/README.md b/README.md index 3b07fca3..366c3777 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ ![RethinkDB-go Logo](https://raw.github.com/wiki/rethinkdb/rethinkdb-go/gopher-and-thinker-s.png "Golang Gopher and RethinkDB Thinker") -Current version: v6.1.0 (RethinkDB v2.4) +Current version: v6.2.0 (RethinkDB v2.4) Please note that this version of the driver only supports versions of RethinkDB using the v0.4 protocol (any versions of the driver older than RethinkDB 2.0 will not work). diff --git a/cluster.go b/cluster.go index 3a8c0096..9b2d2e64 100644 --- a/cluster.go +++ b/cluster.go @@ -1,16 +1,25 @@ package rethinkdb import ( + "errors" "fmt" + "sort" "strings" "sync" "sync/atomic" "time" - backoff "github.com/cenkalti/backoff/v4" "github.com/hailocab/go-hostpool" "github.com/sirupsen/logrus" "golang.org/x/net/context" + "gopkg.in/cenkalti/backoff.v4" +) + +var errClusterClosed = errors.New("rethinkdb: cluster is closed") + +const ( + clusterWorking = 0 + clusterClosed = 1 ) // A Cluster represents a connection to a RethinkDB cluster, a cluster is created @@ -27,34 +36,46 @@ type Cluster struct { seeds []Host // Initial host nodes specified by user. hp hostpool.HostPool nodes map[string]*Node // Active nodes in cluster. - closed bool + closed int32 // 0 - working, 1 - closed - nodeIndex int64 + connFactory connFactory + + discoverInterval time.Duration } // NewCluster creates a new cluster by connecting to the given hosts. func NewCluster(hosts []Host, opts *ConnectOpts) (*Cluster, error) { c := &Cluster{ - hp: hostpool.NewEpsilonGreedy([]string{}, opts.HostDecayDuration, &hostpool.LinearEpsilonValueCalculator{}), - seeds: hosts, - opts: opts, + hp: newHostPool(opts), + seeds: hosts, + opts: opts, + closed: clusterWorking, + connFactory: NewConnection, } - // Attempt to connect to each host and discover any additional hosts if host - // discovery is enabled - if err := c.connectNodes(c.getSeeds()); err != nil { + err := c.run() + if err != nil { return nil, err } - if !c.IsConnected() { - return nil, ErrNoConnectionsStarted - } + return c, nil +} - if opts.DiscoverHosts { - go c.discover() +func newHostPool(opts *ConnectOpts) hostpool.HostPool { + return hostpool.NewEpsilonGreedy([]string{}, opts.HostDecayDuration, &hostpool.LinearEpsilonValueCalculator{}) +} + +func (c *Cluster) run() error { + // Attempt to connect to each host and discover any additional hosts if host + // discovery is enabled + if err := c.connectCluster(); err != nil { + return err } - return c, nil + if !c.IsConnected() { + return ErrNoConnectionsStarted + } + return nil } // Query executes a ReQL query using the cluster to connect to the database @@ -148,7 +169,7 @@ func (c *Cluster) SetMaxOpenConns(n int) { // Close closes the cluster func (c *Cluster) Close(optArgs ...CloseOpts) error { - if c.closed { + if c.isClosed() { return nil } @@ -160,24 +181,38 @@ func (c *Cluster) Close(optArgs ...CloseOpts) error { } c.hp.Close() - c.closed = true + atomic.StoreInt32(&c.closed, clusterClosed) return nil } +func (c *Cluster) isClosed() bool { + return atomic.LoadInt32(&c.closed) == clusterClosed +} + // discover attempts to find new nodes in the cluster using the current nodes func (c *Cluster) discover() { // Keep retrying with exponential backoff. b := backoff.NewExponentialBackOff() // Never finish retrying (max interval is still 60s) b.MaxElapsedTime = 0 + if c.discoverInterval != 0 { + b.InitialInterval = c.discoverInterval + } // Keep trying to discover new nodes for { - backoff.RetryNotify(func() error { + if c.isClosed() { + return + } + + _ = backoff.RetryNotify(func() error { + if c.isClosed() { + return backoff.Permanent(errClusterClosed) + } // If no hosts try seeding nodes if len(c.GetNodes()) == 0 { - c.connectNodes(c.getSeeds()) + return c.connectCluster() } return c.listenForNodeChanges() @@ -197,7 +232,7 @@ func (c *Cluster) listenForNodeChanges() error { } q, err := newQuery( - DB("rethinkdb").Table("server_status").Changes(), + DB(SystemDatabase).Table(ServerStatusSystemTable).Changes(ChangesOpts{IncludeInitial: true}), map[string]interface{}{}, c.opts, ) @@ -210,27 +245,28 @@ func (c *Cluster) listenForNodeChanges() error { hpr.Mark(err) return err } + defer func() { _ = cursor.Close() }() // Keep reading node status updates from changefeed var result struct { - NewVal nodeStatus `rethinkdb:"new_val"` - OldVal nodeStatus `rethinkdb:"old_val"` + NewVal *nodeStatus `rethinkdb:"new_val"` + OldVal *nodeStatus `rethinkdb:"old_val"` } for cursor.Next(&result) { addr := fmt.Sprintf("%s:%d", result.NewVal.Network.Hostname, result.NewVal.Network.ReqlPort) addr = strings.ToLower(addr) - switch result.NewVal.Status { - case "connected": - // Connect to node using exponential backoff (give up after waiting 5s) - // to give the node time to start-up. - b := backoff.NewExponentialBackOff() - b.MaxElapsedTime = time.Second * 5 - - backoff.Retry(func() error { - node, err := c.connectNodeWithStatus(result.NewVal) - if err == nil { - if !c.nodeExists(node) { + if result.NewVal != nil && result.OldVal == nil { + // added new node + if !c.nodeExists(result.NewVal.ID) { + // Connect to node using exponential backoff (give up after waiting 5s) + // to give the node time to start-up. + b := backoff.NewExponentialBackOff() + b.MaxElapsedTime = time.Second * 5 + + err = backoff.Retry(func() error { + node, err := c.connectNodeWithStatus(result.NewVal) + if err == nil { c.addNode(node) Log.WithFields(logrus.Fields{ @@ -238,10 +274,21 @@ func (c *Cluster) listenForNodeChanges() error { "host": node.Host.String(), }).Debug("Connected to node") } + return err + }, b) + if err != nil { + return err } - - return err - }, b) + } + } else if result.OldVal != nil && result.NewVal == nil { + // removed old node + oldNode := c.removeNode(result.OldVal.ID) + if oldNode != nil { + _ = oldNode.Close() + } + } else { + // node updated + // nothing to do - assuming node can't change it's hostname in a single Changes() message } } @@ -250,87 +297,46 @@ func (c *Cluster) listenForNodeChanges() error { return err } -func (c *Cluster) connectNodes(hosts []Host) error { - // Add existing nodes to map +func (c *Cluster) connectCluster() error { nodeSet := map[string]*Node{} - for _, node := range c.GetNodes() { - nodeSet[node.ID] = node - } - var attemptErr error // Attempt to connect to each seed host - for _, host := range hosts { - conn, err := NewConnection(host.String(), c.opts) + for _, host := range c.seeds { + conn, err := c.connFactory(host.String(), c.opts) if err != nil { attemptErr = err Log.Warnf("Error creating connection: %s", err.Error()) continue } - defer conn.Close() - - if c.opts.DiscoverHosts { - q, err := newQuery( - DB("rethinkdb").Table("server_status"), - map[string]interface{}{}, - c.opts, - ) - if err != nil { - Log.Warnf("Error building query: %s", err) - continue - } - _, cursor, err := conn.Query(nil, q) // nil = connection opts' timeout - if err != nil { - attemptErr = err - Log.Warnf("Error fetching cluster status: %s", err) - continue - } + svrRsp, err := conn.Server() + if err != nil { + attemptErr = err + Log.Warnf("Error fetching server ID: %s", err) + _ = conn.Close() - var results []nodeStatus - err = cursor.All(&results) - if err != nil { - attemptErr = err - continue - } + continue + } + _ = conn.Close() - for _, result := range results { - node, err := c.connectNodeWithStatus(result) - if err == nil { - if _, ok := nodeSet[node.ID]; !ok { - Log.WithFields(logrus.Fields{ - "id": node.ID, - "host": node.Host.String(), - }).Debug("Connected to node") - nodeSet[node.ID] = node - } - } else { - attemptErr = err - Log.Warnf("Error connecting to node: %s", err) - } - } - } else { - svrRsp, err := conn.Server() - if err != nil { - attemptErr = err - Log.Warnf("Error fetching server ID: %s", err) - continue - } + node, err := c.connectNode(svrRsp.ID, []Host{host}) + if err != nil { + attemptErr = err + Log.Warnf("Error connecting to node: %s", err) + continue + } - node, err := c.connectNode(svrRsp.ID, []Host{host}) - if err == nil { - if _, ok := nodeSet[node.ID]; !ok { - Log.WithFields(logrus.Fields{ - "id": node.ID, - "host": node.Host.String(), - }).Debug("Connected to node") + if _, ok := nodeSet[node.ID]; !ok { + Log.WithFields(logrus.Fields{ + "id": node.ID, + "host": node.Host.String(), + }).Debug("Connected to node") - nodeSet[node.ID] = node - } - } else { - attemptErr = err - Log.Warnf("Error connecting to node: %s", err) - } + nodeSet[node.ID] = node + } else { + // dublicate node + _ = node.Close() } } @@ -338,19 +344,26 @@ func (c *Cluster) connectNodes(hosts []Host) error { // include driver errors such as if there was an issue building the // query if len(nodeSet) == 0 { - return attemptErr + if attemptErr != nil { + return attemptErr + } + return ErrNoConnections } - nodes := []*Node{} + var nodes []*Node for _, node := range nodeSet { nodes = append(nodes, node) } - c.setNodes(nodes) + c.replaceNodes(nodes) + + if c.opts.DiscoverHosts { + go c.discover() + } return nil } -func (c *Cluster) connectNodeWithStatus(s nodeStatus) (*Node, error) { +func (c *Cluster) connectNodeWithStatus(s *nodeStatus) (*Node, error) { aliases := make([]Host, len(s.Network.CanonicalAddresses)) for i, aliasAddress := range s.Network.CanonicalAddresses { aliases[i] = NewHost(aliasAddress.Host, int(s.Network.ReqlPort)) @@ -364,7 +377,7 @@ func (c *Cluster) connectNode(id string, aliases []Host) (*Node, error) { var err error for len(aliases) > 0 { - pool, err = NewPool(aliases[0], c.opts) + pool, err = newPool(aliases[0], c.opts, c.connFactory) if err != nil { aliases = aliases[1:] continue @@ -387,31 +400,12 @@ func (c *Cluster) connectNode(id string, aliases []Host) (*Node, error) { return nil, ErrInvalidNode } - return newNode(id, aliases, c, pool), nil + return newNode(id, aliases, pool), nil } -// IsConnected returns true if cluster has nodes and is not already closed. +// IsConnected returns true if cluster has nodes and is not already connClosed. func (c *Cluster) IsConnected() bool { - c.mu.RLock() - closed := c.closed - c.mu.RUnlock() - - return (len(c.GetNodes()) > 0) && !closed -} - -// AddSeeds adds new seed hosts to the cluster. -func (c *Cluster) AddSeeds(hosts []Host) { - c.mu.Lock() - c.seeds = append(c.seeds, hosts...) - c.mu.Unlock() -} - -func (c *Cluster) getSeeds() []Host { - c.mu.RLock() - seeds := c.seeds - c.mu.RUnlock() - - return seeds + return (len(c.GetNodes()) > 0) && !c.isClosed() } // GetNextNode returns a random node on the cluster @@ -436,18 +430,20 @@ func (c *Cluster) GetNextNode() (*Node, hostpool.HostPoolResponse, error) { // GetNodes returns a list of all nodes in the cluster func (c *Cluster) GetNodes() []*Node { c.mu.RLock() + defer c.mu.RUnlock() nodes := make([]*Node, 0, len(c.nodes)) for _, n := range c.nodes { nodes = append(nodes, n) } - c.mu.RUnlock() return nodes } -func (c *Cluster) nodeExists(search *Node) bool { - for _, node := range c.GetNodes() { - if node.ID == search.ID { +func (c *Cluster) nodeExists(nodeID string) bool { + c.mu.RLock() + defer c.mu.RUnlock() + for _, node := range c.nodes { + if node.ID == nodeID { return true } } @@ -455,22 +451,24 @@ func (c *Cluster) nodeExists(search *Node) bool { } func (c *Cluster) addNode(node *Node) { - c.mu.RLock() - nodes := append(c.GetNodes(), node) - c.mu.RUnlock() - - c.setNodes(nodes) -} + host := node.Host.String() + c.mu.Lock() + defer c.mu.Unlock() + if _, exist := c.nodes[host]; exist { + // addNode() should be called only if the node doesn't exist + return + } -func (c *Cluster) addNodes(nodesToAdd []*Node) { - c.mu.RLock() - nodes := append(c.GetNodes(), nodesToAdd...) - c.mu.RUnlock() + c.nodes[host] = node - c.setNodes(nodes) + hosts := make([]string, 0, len(c.nodes)) + for _, n := range c.nodes { + hosts = append(hosts, n.Host.String()) + } + c.hp.SetHosts(hosts) } -func (c *Cluster) setNodes(nodes []*Node) { +func (c *Cluster) replaceNodes(nodes []*Node) { nodesMap := make(map[string]*Node, len(nodes)) hosts := make([]string, len(nodes)) for i, node := range nodes { @@ -480,38 +478,37 @@ func (c *Cluster) setNodes(nodes []*Node) { hosts[i] = host } + sort.Strings(hosts) // unit tests stability + c.mu.Lock() c.nodes = nodesMap c.hp.SetHosts(hosts) c.mu.Unlock() } -func (c *Cluster) removeNode(nodeID string) { - nodes := c.GetNodes() - nodeArray := make([]*Node, len(nodes)-1) - count := 0 - - // Add nodes that are not in remove list. - for _, n := range nodes { - if n.ID != nodeID { - nodeArray[count] = n - count++ +func (c *Cluster) removeNode(nodeID string) *Node { + c.mu.Lock() + defer c.mu.Unlock() + var rmNode *Node + for _, node := range c.nodes { + if node.ID == nodeID { + rmNode = node + break } } - - // Do sanity check to make sure assumptions are correct. - if count < len(nodeArray) { - // Resize array. - nodeArray2 := make([]*Node, count) - copy(nodeArray2, nodeArray) - nodeArray = nodeArray2 + if rmNode == nil { + return nil } - c.setNodes(nodeArray) -} + delete(c.nodes, rmNode.Host.String()) + + hosts := make([]string, 0, len(c.nodes)) + for _, n := range c.nodes { + hosts = append(hosts, n.Host.String()) + } + c.hp.SetHosts(hosts) -func (c *Cluster) nextNodeIndex() int64 { - return atomic.AddInt64(&c.nodeIndex, 1) + return rmNode } func (c *Cluster) numRetries() int { diff --git a/cluster_test.go b/cluster_test.go new file mode 100644 index 00000000..d55a83b2 --- /dev/null +++ b/cluster_test.go @@ -0,0 +1,571 @@ +package rethinkdb + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "fmt" + "github.com/stretchr/testify/mock" + test "gopkg.in/check.v1" + "gopkg.in/rethinkdb/rethinkdb-go.v6/encoding" + p "gopkg.in/rethinkdb/rethinkdb-go.v6/ql2" + "io" + "net" + "time" +) + +type ClusterSuite struct{} + +var _ = test.Suite(&ClusterSuite{}) + +func (s *ClusterSuite) TestCluster_NewSingle_NoDiscover_Ok(c *test.C) { + host1 := Host{Name: "host1", Port: 28015} + node1 := "node1" + + conn1 := &connMock{} + expectServerQuery(conn1, 1, node1) + conn1.onCloseReturn(nil) + conn2 := &connMock{} + conn2.onCloseReturn(nil) + + dialMock := &mockDial{} + dialMock.On("Dial", host1.String()).Return(conn1, nil).Once() + dialMock.On("Dial", host1.String()).Return(conn2, nil).Once() + + opts := &ConnectOpts{} + seeds := []Host{host1} + cluster := &Cluster{ + hp: newHostPool(opts), + seeds: seeds, + opts: opts, + closed: clusterWorking, + connFactory: mockedConnectionFactory(dialMock), + } + + err := cluster.run() + c.Assert(err, test.IsNil) + conn1.waitDial() + conn2.waitDial() + err = cluster.Close() + c.Assert(err, test.IsNil) + conn1.waitDone() + conn2.waitDone() + mock.AssertExpectationsForObjects(c, dialMock, conn1, conn2) +} + +func (s *ClusterSuite) TestCluster_NewMultiple_NoDiscover_Ok(c *test.C) { + host1 := Host{Name: "host1", Port: 28015} + host2 := Host{Name: "host2", Port: 28015} + node1 := "node1" + node2 := "node2" + + conn1 := &connMock{} + expectServerQuery(conn1, 1, node1) + conn1.onCloseReturn(nil) + conn2 := &connMock{} + conn2.onCloseReturn(nil) + conn3 := &connMock{} + expectServerQuery(conn3, 1, node2) + conn3.onCloseReturn(nil) + conn4 := &connMock{} + conn4.onCloseReturn(nil) + + dialMock := &mockDial{} + dialMock.On("Dial", host1.String()).Return(conn1, nil).Once() + dialMock.On("Dial", host1.String()).Return(conn2, nil).Once() + dialMock.On("Dial", host2.String()).Return(conn3, nil).Once() + dialMock.On("Dial", host2.String()).Return(conn4, nil).Once() + + opts := &ConnectOpts{} + seeds := []Host{host1, host2} + cluster := &Cluster{ + hp: newHostPool(opts), + seeds: seeds, + opts: opts, + closed: clusterWorking, + connFactory: mockedConnectionFactory(dialMock), + } + + err := cluster.run() + c.Assert(err, test.IsNil) + conn1.waitDial() + conn2.waitDial() + conn3.waitDial() + conn4.waitDial() + err = cluster.Close() + c.Assert(err, test.IsNil) + conn1.waitDone() + conn2.waitDone() + conn3.waitDone() + conn4.waitDone() + mock.AssertExpectationsForObjects(c, dialMock, conn1, conn2, conn3, conn4) +} + +func (s *ClusterSuite) TestCluster_NewSingle_NoDiscover_DialFail(c *test.C) { + host1 := Host{Name: "host1", Port: 28015} + + dialMock := &mockDial{} + dialMock.On("Dial", host1.String()).Return(nil, io.EOF).Once() + + opts := &ConnectOpts{} + seeds := []Host{host1} + cluster := &Cluster{ + hp: newHostPool(opts), + seeds: seeds, + opts: opts, + closed: clusterWorking, + connFactory: mockedConnectionFactory(dialMock), + } + + err := cluster.run() + c.Assert(err, test.Equals, io.EOF) + mock.AssertExpectationsForObjects(c, dialMock) +} + +func (s *ClusterSuite) TestCluster_NewMultiple_NoDiscover_DialHalfFail(c *test.C) { + host1 := Host{Name: "host1", Port: 28015} + host2 := Host{Name: "host2", Port: 28015} + node1 := "node1" + + conn1 := &connMock{} + expectServerQuery(conn1, 1, node1) + conn1.onCloseReturn(nil) + conn2 := &connMock{} + conn2.onCloseReturn(nil) + + dialMock := &mockDial{} + dialMock.On("Dial", host1.String()).Return(conn1, nil).Once() + dialMock.On("Dial", host1.String()).Return(conn2, nil).Once() + dialMock.On("Dial", host2.String()).Return(nil, io.EOF).Once() + + opts := &ConnectOpts{} + seeds := []Host{host1, host2} + cluster := &Cluster{ + hp: newHostPool(opts), + seeds: seeds, + opts: opts, + closed: clusterWorking, + connFactory: mockedConnectionFactory(dialMock), + } + + err := cluster.run() + c.Assert(err, test.IsNil) + conn1.waitDial() + conn2.waitDial() + err = cluster.Close() + c.Assert(err, test.IsNil) + conn1.waitDone() + conn2.waitDone() + mock.AssertExpectationsForObjects(c, dialMock, conn1, conn2) +} + +func (s *ClusterSuite) TestCluster_NewMultiple_NoDiscover_DialFail(c *test.C) { + host1 := Host{Name: "host1", Port: 28015} + host2 := Host{Name: "host2", Port: 28015} + + dialMock := &mockDial{} + dialMock.On("Dial", host1.String()).Return(nil, io.EOF).Once() + dialMock.On("Dial", host2.String()).Return(nil, io.EOF).Once() + + opts := &ConnectOpts{} + seeds := []Host{host1, host2} + cluster := &Cluster{ + hp: newHostPool(opts), + seeds: seeds, + opts: opts, + closed: clusterWorking, + connFactory: mockedConnectionFactory(dialMock), + } + + err := cluster.run() + c.Assert(err, test.Equals, io.EOF) + mock.AssertExpectationsForObjects(c, dialMock) +} + +func (s *ClusterSuite) TestCluster_NewSingle_NoDiscover_ServerFail(c *test.C) { + host1 := Host{Name: "host1", Port: 28015} + + conn1 := &connMock{} + expectServerQueryFail(conn1, 1, io.EOF) + conn1.onCloseReturn(nil) + + dialMock := &mockDial{} + dialMock.On("Dial", host1.String()).Return(conn1, nil).Once() + + opts := &ConnectOpts{} + seeds := []Host{host1} + cluster := &Cluster{ + hp: newHostPool(opts), + seeds: seeds, + opts: opts, + closed: clusterWorking, + connFactory: mockedConnectionFactory(dialMock), + } + + err := cluster.run() + c.Assert(err, test.NotNil) + if _, ok := err.(RQLConnectionError); ok { + c.Assert(err, test.Equals, RQLConnectionError{rqlError(io.EOF.Error())}) + } else { + c.Assert(err, test.Equals, ErrConnectionClosed) + } + conn1.waitDone() + mock.AssertExpectationsForObjects(c, dialMock, conn1) +} + +func (s *ClusterSuite) TestCluster_NewSingle_NoDiscover_PingFail(c *test.C) { + host1 := Host{Name: "host1", Port: 28015} + node1 := "node1" + + conn1 := &connMock{} + expectServerQuery(conn1, 1, node1) + conn1.onCloseReturn(nil) + + dialMock := &mockDial{} + dialMock.On("Dial", host1.String()).Return(conn1, nil).Once() + dialMock.On("Dial", host1.String()).Return(nil, io.EOF).Once() + + opts := &ConnectOpts{} + seeds := []Host{host1} + cluster := &Cluster{ + hp: newHostPool(opts), + seeds: seeds, + opts: opts, + closed: clusterWorking, + connFactory: mockedConnectionFactory(dialMock), + } + + err := cluster.run() + c.Assert(err, test.Equals, io.EOF) + conn1.waitDone() + mock.AssertExpectationsForObjects(c, dialMock, conn1) +} + +func (s *ClusterSuite) TestCluster_NewSingle_Discover_Ok(c *test.C) { + host1 := Host{Name: "host1", Port: 28015} + host2 := Host{Name: "1.1.1.1", Port: 2222} + host3 := Host{Name: "2.2.2.2", Port: 3333} + node1 := "node1" + node2 := "node2" + node3 := "node3" + + conn1 := &connMock{} + expectServerQuery(conn1, 1, node1) + conn1.onCloseReturn(nil) + conn2 := &connMock{} + expectServerStatus(conn2, 1, []string{node1, node2, node3}, []Host{host1, host2, host3}) + conn2.onCloseReturn(nil) + conn3 := &connMock{} + conn3.onCloseReturn(nil) + conn4 := &connMock{} // doesn't need call Server() due to it's known through ServerStatus() + conn4.onCloseReturn(nil) + + dialMock := &mockDial{} + dialMock.On("Dial", host1.String()).Return(conn1, nil).Once() + dialMock.On("Dial", host1.String()).Return(conn2, nil).Once() + dialMock.On("Dial", host2.String()).Return(conn3, nil).Once() + dialMock.On("Dial", host3.String()).Return(conn4, nil).Once() + + opts := &ConnectOpts{DiscoverHosts: true} + seeds := []Host{host1} + cluster := &Cluster{ + hp: newHostPool(opts), + seeds: seeds, + opts: opts, + closed: clusterWorking, + connFactory: mockedConnectionFactory(dialMock), + discoverInterval: 10 * time.Second, + } + + err := cluster.run() + c.Assert(err, test.IsNil) + conn1.waitDial() + conn2.waitDial() + conn3.waitDial() + conn4.waitDial() + for !cluster.nodeExists(node2) || !cluster.nodeExists(node3) { // wait node to be added to list to be closed with cluster + time.Sleep(time.Millisecond) + } + err = cluster.Close() + c.Assert(err, test.IsNil) + conn1.waitDone() + conn2.waitDone() + conn3.waitDone() + conn4.waitDone() + mock.AssertExpectationsForObjects(c, dialMock, conn1, conn2, conn3, conn4) +} + +func (s *ClusterSuite) TestCluster_NewMultiple_Discover_Ok(c *test.C) { + host1 := Host{Name: "host1", Port: 28015} + host2 := Host{Name: "host2", Port: 28016} + host3 := Host{Name: "2.2.2.2", Port: 3333} + node1 := "node1" + node2 := "node2" + node3 := "node3" + + conn1 := &connMock{} + expectServerQuery(conn1, 1, node1) + conn1.onCloseReturn(nil) + conn2 := &connMock{} + expectServerStatus(conn2, 1, []string{node1, node2, node3}, []Host{host1, host2, host3}) + conn2.onCloseReturn(nil) + conn3 := &connMock{} + expectServerQuery(conn3, 1, node2) + conn3.onCloseReturn(nil) + conn4 := &connMock{} + conn4.onCloseReturn(nil) + conn5 := &connMock{} // doesn't need call Server() due to it's known through ServerStatus() + conn5.onCloseReturn(nil) + + dialMock := &mockDial{} + dialMock.On("Dial", host1.String()).Return(conn1, nil).Once() + dialMock.On("Dial", host1.String()).Return(conn2, nil).Once() + dialMock.On("Dial", host2.String()).Return(conn3, nil).Once() + dialMock.On("Dial", host2.String()).Return(conn4, nil).Once() + dialMock.On("Dial", host3.String()).Return(conn5, nil).Once() + + opts := &ConnectOpts{DiscoverHosts: true} + seeds := []Host{host1, host2} + cluster := &Cluster{ + hp: newHostPool(opts), + seeds: seeds, + opts: opts, + closed: clusterWorking, + connFactory: mockedConnectionFactory(dialMock), + discoverInterval: 10 * time.Second, + } + + err := cluster.run() + c.Assert(err, test.IsNil) + conn1.waitDial() + conn2.waitDial() + conn3.waitDial() + conn4.waitDial() + conn5.waitDial() + for !cluster.nodeExists(node3) { // wait node to be added to list to be closed with cluster + time.Sleep(time.Millisecond) + } + err = cluster.Close() + c.Assert(err, test.IsNil) + conn1.waitDone() + conn2.waitDone() + conn3.waitDone() + conn4.waitDone() + conn5.waitDone() + mock.AssertExpectationsForObjects(c, dialMock, conn1, conn2, conn3, conn4, conn5) +} + +type mockDial struct { + mock.Mock +} + +func mockedConnectionFactory(dial *mockDial) connFactory { + return func(host string, opts *ConnectOpts) (connection *Connection, err error) { + args := dial.MethodCalled("Dial", host) + err = args.Error(1) + if err != nil { + return nil, err + } + + connection = newConnection(args.Get(0).(net.Conn), host, opts) + done := runConnection(connection) + + m := args.Get(0).(*connMock) + m.setDone(done) + + return connection, nil + } +} + +func expectServerQuery(conn *connMock, token int64, nodeID string) { + writeChan := make(chan struct{}) + readChan := make(chan struct{}) + + rawQ := makeServerQueryRaw(token) + conn.On("Write", rawQ).Return(0, nil, nil).Once().Run(func(args mock.Arguments) { + close(writeChan) + }) + + rawR := makeServerResponseRaw(token, nodeID) + rawH := makeResponseHeaderRaw(token, len(rawR)) + + conn.On("Read", respHeaderLen).Return(rawH, len(rawH), nil, nil).Once().Run(func(args mock.Arguments) { + <-writeChan + close(readChan) + }) + conn.On("Read", len(rawR)).Return(rawR, len(rawR), nil, nil).Once().Run(func(args mock.Arguments) { + <-readChan + }) +} + +func expectServerQueryFail(conn *connMock, token int64, err error) { + writeChan := make(chan struct{}) + + rawQ := makeServerQueryRaw(token) + conn.On("Write", rawQ).Return(0, nil, nil).Once().Run(func(args mock.Arguments) { + close(writeChan) + }) + + conn.On("Read", respHeaderLen).Return(nil, 0, err, nil).Once().Run(func(args mock.Arguments) { + <-writeChan + }) +} + +func makeServerQueryRaw(token int64) []byte { + buf := &bytes.Buffer{} + buf.Grow(respHeaderLen) + buf.Write(buf.Bytes()[:respHeaderLen]) + enc := json.NewEncoder(buf) + + q := Query{ + Token: token, + Type: p.Query_SERVER_INFO, + } + + err := enc.Encode(q.Build()) + if err != nil { + panic(fmt.Sprintf("must encode failed: %v", err)) + } + b := buf.Bytes() + binary.LittleEndian.PutUint64(b, uint64(q.Token)) + binary.LittleEndian.PutUint32(b[8:], uint32(len(b)-respHeaderLen)) + return b +} + +func makeResponseHeaderRaw(token int64, respLen int) []byte { + buf1 := &bytes.Buffer{} + buf1.Grow(respHeaderLen) + buf1.Write(buf1.Bytes()[:respHeaderLen]) // reserve for header + b1 := buf1.Bytes() + binary.LittleEndian.PutUint64(b1, uint64(token)) + binary.LittleEndian.PutUint32(b1[8:], uint32(respLen)) + return b1 +} + +func makeServerResponseRaw(token int64, nodeID string) []byte { + buf2 := &bytes.Buffer{} + enc := json.NewEncoder(buf2) + + coded, err := encoding.Encode(&ServerResponse{ID: nodeID}) + if err != nil { + panic(fmt.Sprintf("must encode response failed: %v", err)) + } + jresp, err := json.Marshal(coded) + if err != nil { + panic(fmt.Sprintf("must encode response failed: %v", err)) + } + + resp := Response{Token: token, Type: p.Response_SERVER_INFO, Responses: []json.RawMessage{jresp}} + err = enc.Encode(resp) + if err != nil { + panic(fmt.Sprintf("must encode failed: %v", err)) + } + + return buf2.Bytes() +} + +func expectServerStatus(conn *connMock, token int64, nodeIDs []string, hosts []Host) { + writeChan := make(chan struct{}) + readHChan := make(chan struct{}) + readRChan := make(chan struct{}) + + rawQ := makeServerStatusQueryRaw(token) + conn.On("Write", rawQ).Return(0, nil, nil).Once().Run(func(args mock.Arguments) { + close(writeChan) + }) + + rawR := makeServerStatusResponseRaw(token, nodeIDs, hosts) + rawH := makeResponseHeaderRaw(token, len(rawR)) + + conn.On("Read", respHeaderLen).Return(rawH, len(rawH), nil, nil).Once().Run(func(args mock.Arguments) { + <-writeChan + close(readHChan) + }) + conn.On("Read", len(rawR)).Return(rawR, len(rawR), nil, nil).Once().Run(func(args mock.Arguments) { + <-readHChan + close(readRChan) + }) + + rawQ2 := makeContinueQueryRaw(token) + // maybe - connection may be closed until cursor fetchs next batch + conn.On("Write", rawQ2).Return(0, nil, nil).Maybe().Run(func(args mock.Arguments) { + <-readRChan + }) +} + +func makeServerStatusQueryRaw(token int64) []byte { + buf := &bytes.Buffer{} + buf.Grow(respHeaderLen) + buf.Write(buf.Bytes()[:respHeaderLen]) // reserve for header + enc := json.NewEncoder(buf) + + t := DB(SystemDatabase).Table(ServerStatusSystemTable).Changes(ChangesOpts{IncludeInitial: true}) + q, err := newQuery(t, map[string]interface{}{}, &ConnectOpts{}) + if err != nil { + panic(fmt.Sprintf("must newQuery failed: %v", err)) + } + q.Token = token + + err = enc.Encode(q.Build()) + if err != nil { + panic(fmt.Sprintf("must encode failed: %v", err)) + } + + b := buf.Bytes() + binary.LittleEndian.PutUint64(b, uint64(q.Token)) + binary.LittleEndian.PutUint32(b[8:], uint32(len(b)-respHeaderLen)) + return b +} + +func makeServerStatusResponseRaw(token int64, nodeIDs []string, hosts []Host) []byte { + buf2 := &bytes.Buffer{} + enc := json.NewEncoder(buf2) + + type change struct { + NewVal *nodeStatus `rethinkdb:"new_val"` + OldVal *nodeStatus `rethinkdb:"old_val"` + } + jresps := make([]json.RawMessage, len(nodeIDs)) + for i := range nodeIDs { + status := &nodeStatus{ID: nodeIDs[i], Network: nodeStatusNetwork{ + ReqlPort: int64(hosts[i].Port), + CanonicalAddresses: []nodeStatusNetworkAddr{ + {Host: hosts[i].Name}, + }, + }} + + coded, err := encoding.Encode(&change{NewVal: status}) + if err != nil { + panic(fmt.Sprintf("must encode response failed: %v", err)) + } + jresps[i], err = json.Marshal(coded) + if err != nil { + panic(fmt.Sprintf("must encode response failed: %v", err)) + } + } + + resp := Response{Token: token, Type: p.Response_SUCCESS_PARTIAL, Responses: jresps} + err := enc.Encode(resp) + if err != nil { + panic(fmt.Sprintf("must encode failed: %v", err)) + } + return buf2.Bytes() +} + +func makeContinueQueryRaw(token int64) []byte { + buf := &bytes.Buffer{} + buf.Grow(respHeaderLen) + buf.Write(buf.Bytes()[:respHeaderLen]) // reserve for header + enc := json.NewEncoder(buf) + + q := Query{Token: token, Type: p.Query_CONTINUE} + err := enc.Encode(q.Build()) + if err != nil { + panic(fmt.Sprintf("must encode failed: %v", err)) + } + + b := buf.Bytes() + binary.LittleEndian.PutUint64(b, uint64(q.Token)) + binary.LittleEndian.PutUint32(b[8:], uint32(len(b)-respHeaderLen)) + return b +} diff --git a/connection.go b/connection.go index f5b760f8..6532cba3 100644 --- a/connection.go +++ b/connection.go @@ -22,11 +22,11 @@ const ( respHeaderLen = 12 defaultKeepAlivePeriod = time.Second * 30 - notBad = 0 - bad = 1 + connNotBad = 0 + connBad = 1 - working = 0 - closed = 1 + connWorking = 0 + connClosed = 1 ) // Response represents the raw response from a query, most of the time you @@ -49,15 +49,16 @@ type Connection struct { address string opts *ConnectOpts - _ [4]byte - token int64 - cursors map[int64]*Cursor - bad int32 // 0 - not bad, 1 - bad - closed int32 // 0 - working, 1 - closed - stopReadChan chan bool - readRequestsChan chan tokenAndPromise - responseChan chan responseAndError - mu sync.Mutex + _ [4]byte + token int64 + cursors map[int64]*Cursor + bad int32 // 0 - not bad, 1 - bad + closed int32 // 0 - working, 1 - closed + stopReadChan chan bool + readRequestsChan chan tokenAndPromise + responseChan chan responseAndError + stopProcessingChan chan struct{} + mu sync.Mutex } type responseAndError struct { @@ -110,31 +111,30 @@ func NewConnection(address string, opts *ConnectOpts) (*Connection, error) { return nil, err } - c.runConnection() + // NOTE: mock.go: Mock.Query() + // NOTE: connection_test.go: runConnection() + go c.readSocket() + go c.processResponses() return c, nil } func newConnection(conn net.Conn, address string, opts *ConnectOpts) *Connection { c := &Connection{ - Conn: conn, - address: address, - opts: opts, - cursors: make(map[int64]*Cursor), - stopReadChan: make(chan bool, 1), - bad: notBad, - closed: working, - readRequestsChan: make(chan tokenAndPromise, 16), - responseChan: make(chan responseAndError, 16), + Conn: conn, + address: address, + opts: opts, + cursors: make(map[int64]*Cursor), + stopReadChan: make(chan bool, 1), + bad: connNotBad, + closed: connWorking, + readRequestsChan: make(chan tokenAndPromise, 16), + responseChan: make(chan responseAndError, 16), + stopProcessingChan: make(chan struct{}), } return c } -func (c *Connection) runConnection() { - go c.readSocket() - go c.processResponses() -} - // Close closes the underlying net.Conn func (c *Connection) Close() error { var err error @@ -186,7 +186,7 @@ func (c *Connection) Query(ctx context.Context, q Query) (*Response, *Cursor, er parentSpan := opentracing.SpanFromContext(ctx) if parentSpan != nil { if q.Type == p.Query_START { - querySpan := c.startTracingSpan(parentSpan, &q) // will be Finished when cursor closed + querySpan := c.startTracingSpan(parentSpan, &q) // will be Finished when cursor connClosed parentSpan = querySpan ctx = opentracing.ContextWithSpan(ctx, querySpan) } @@ -224,13 +224,15 @@ func (c *Connection) Query(ctx context.Context, q Query) (*Response, *Cursor, er return future.response, future.cursor, future.err case <-ctx.Done(): return c.stopQuery(&q) + case <-c.stopProcessingChan: // connection readRequests processing stopped, promise can be never answered + return nil, nil, ErrConnectionClosed } } func (c *Connection) stopQuery(q *Query) (*Response, *Cursor, error) { - if q.Type != p.Query_STOP { + if q.Type != p.Query_STOP && !c.isClosed() && !c.isBad() { stopQuery := newStopQuery(q.Token) - c.Query(c.contextFromConnectionOpts(), stopQuery) + _, _, _ = c.Query(c.contextFromConnectionOpts(), stopQuery) } return nil, nil, ErrQueryTimeout } @@ -255,16 +257,16 @@ func (c *Connection) readSocket() { for { response, err := c.readResponse() - respPair := responseAndError{ + c.responseChan <- responseAndError{ response: response, err: err, } select { - case c.responseChan <- respPair: case <-c.stopReadChan: close(c.responseChan) return + default: } } } @@ -278,22 +280,21 @@ func (c *Connection) processResponses() { var ok bool select { - case respPair := <-c.responseChan: + case respPair, openned := <-c.responseChan: if respPair.err != nil { // Transport socket error, can't continue to work - // Don't know return to who - return to all - for _, rr := range readRequests { - if rr.promise != nil { - rr.promise <- responseAndCursor{err: respPair.err} - close(rr.promise) - } - } + // Don't know return to who (no token) - return to all + broadcastError(readRequests, respPair.err) readRequests = []tokenAndPromise{} - c.Close() + _ = c.Close() // next `if` will be called indirect cascade by closing chans continue } - if respPair.response == nil && respPair.err == nil { // responseChan is closed - continue + if !openned { // responseChan is connClosed (stopReadChan is closed too) + close(c.stopProcessingChan) + broadcastError(readRequests, ErrConnectionClosed) + c.cursors = nil + + return } response = respPair.response @@ -312,16 +313,6 @@ func (c *Connection) processResponses() { continue } responses = removeResponse(responses, readRequest.query.Token) - - case <-c.stopReadChan: - for _, rr := range readRequests { - if rr.promise != nil { - rr.promise <- responseAndCursor{err: ErrConnectionClosed} - close(rr.promise) - } - } - c.cursors = nil - return } response, cursor, err := c.processResponse(readRequest.ctx, *readRequest.query, response, readRequest.span) @@ -332,6 +323,15 @@ func (c *Connection) processResponses() { } } +func broadcastError(readRequests []tokenAndPromise, err error) { + for _, rr := range readRequests { + if rr.promise != nil { + rr.promise <- responseAndCursor{err: err} + close(rr.promise) + } + } +} + type ServerResponse struct { ID string `rethinkdb:"id"` Name string `rethinkdb:"name"` @@ -526,19 +526,19 @@ func (c *Connection) processWaitResponse(response *Response) (*Response, *Cursor } func (c *Connection) setBad() { - atomic.StoreInt32(&c.bad, bad) + atomic.StoreInt32(&c.bad, connBad) } func (c *Connection) isBad() bool { - return atomic.LoadInt32(&c.bad) == bad + return atomic.LoadInt32(&c.bad) == connBad } func (c *Connection) setClosed() { - atomic.StoreInt32(&c.closed, closed) + atomic.StoreInt32(&c.closed, connClosed) } func (c *Connection) isClosed() bool { - return atomic.LoadInt32(&c.closed) == closed + return atomic.LoadInt32(&c.closed) == connClosed } func getReadRequest(list []tokenAndPromise, token int64) (tokenAndPromise, bool) { diff --git a/connection_test.go b/connection_test.go index 15a5845f..99a0cebd 100644 --- a/connection_test.go +++ b/connection_test.go @@ -5,13 +5,35 @@ import ( "encoding/json" "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/mocktracer" + "github.com/stretchr/testify/mock" "golang.org/x/net/context" test "gopkg.in/check.v1" p "gopkg.in/rethinkdb/rethinkdb-go.v6/ql2" "io" + "sync" "time" ) +func runConnection(c *Connection) <-chan struct{} { + wg := &sync.WaitGroup{} + wg.Add(2) + go func() { + c.readSocket() + wg.Done() + }() + go func() { + c.processResponses() + wg.Done() + }() + + doneChan := make(chan struct{}) + go func() { + wg.Wait() + close(doneChan) + }() + return doneChan +} + type ConnectionSuite struct{} var _ = test.Suite(&ConnectionSuite{}) @@ -31,9 +53,10 @@ func (s *ConnectionSuite) TestConnection_Query_Ok(c *test.C) { conn.On("Close").Return(nil) connection := newConnection(conn, "addr", &ConnectOpts{}) - connection.runConnection() + closed := runConnection(connection) response, cursor, err := connection.Query(ctx, q) connection.Close() + <-closed c.Assert(response, test.NotNil) c.Assert(response.Token, test.Equals, token) @@ -66,9 +89,10 @@ func (s *ConnectionSuite) TestConnection_Query_DefaultDBOk(c *test.C) { conn.On("Close").Return(nil) connection := newConnection(conn, "addr", &ConnectOpts{Database: "db"}) - connection.runConnection() + done := runConnection(connection) response, cursor, err := connection.Query(ctx, q) connection.Close() + <-done c.Assert(response, test.NotNil) c.Assert(response.Token, test.Equals, token) @@ -132,10 +156,10 @@ func (s *ConnectionSuite) TestConnection_Query_NoReplyOk(c *test.C) { conn.On("Close").Return(nil) connection := newConnection(conn, "addr", &ConnectOpts{}) - connection.runConnection() + done := runConnection(connection) response, cursor, err := connection.Query(nil, q) - time.Sleep(5 * time.Millisecond) connection.Close() + <-done c.Assert(response, test.IsNil) c.Assert(cursor, test.IsNil) @@ -212,18 +236,24 @@ func (s *ConnectionSuite) TestConnection_processResponses_SocketErr(c *test.C) { promise3 := make(chan responseAndCursor, 1) conn := &connMock{} - conn.On("Close").Return(nil) - connection := newConnection(conn, "addr", &ConnectOpts{}) - go connection.processResponses() + conn.On("Close").Return(nil).Run(func(args mock.Arguments) { + close(connection.responseChan) + }) + + done := make(chan struct{}) + go func() { + connection.processResponses() + close(done) + }() connection.readRequestsChan <- tokenAndPromise{query: &Query{Token: 1}, promise: promise1} connection.readRequestsChan <- tokenAndPromise{query: &Query{Token: 2}, promise: promise2} connection.readRequestsChan <- tokenAndPromise{query: &Query{Token: 2}, promise: promise3} time.Sleep(5 * time.Millisecond) connection.responseChan <- responseAndError{err: io.EOF} - time.Sleep(5 * time.Millisecond) + <-done select { case f := <-promise1: @@ -254,13 +284,16 @@ func (s *ConnectionSuite) TestConnection_processResponses_StopOk(c *test.C) { connection := newConnection(nil, "addr", &ConnectOpts{}) - go connection.processResponses() + done := make(chan struct{}) + go func() { + connection.processResponses() + close(done) + }() connection.readRequestsChan <- tokenAndPromise{query: &Query{Token: 1}, promise: promise1} - close(connection.responseChan) - time.Sleep(5 * time.Millisecond) - close(connection.stopReadChan) time.Sleep(5 * time.Millisecond) + close(connection.responseChan) + <-done select { case f := <-promise1: diff --git a/cursor.go b/cursor.go index 6be9113d..170ca833 100644 --- a/cursor.go +++ b/cursor.go @@ -15,7 +15,7 @@ import ( var ( errNilCursor = errors.New("cursor is nil") - errCursorClosed = errors.New("connection closed, cannot read cursor") + errCursorClosed = errors.New("connection connClosed, cannot read cursor") ) func newCursor(ctx context.Context, conn *Connection, cursorType string, token int64, term *Term, opts map[string]interface{}) *Cursor { @@ -120,7 +120,7 @@ func (c *Cursor) Err() error { } // Close closes the cursor, preventing further enumeration. If the end is -// encountered, the cursor is closed automatically. Close is idempotent. +// encountered, the cursor is connClosed automatically. Close is idempotent. func (c *Cursor) Close() error { if c == nil { return errNilCursor @@ -131,7 +131,7 @@ func (c *Cursor) Close() error { var err error - // If cursor is already closed return immediately + // If cursor is already connClosed return immediately closed := c.closed if closed { return nil @@ -143,7 +143,7 @@ func (c *Cursor) Close() error { if conn == nil { return nil } - if conn.Conn == nil { + if conn.isClosed() { return nil } @@ -386,7 +386,7 @@ func (c *Cursor) All(result interface{}) error { resultv.Elem().Set(slicev.Slice(0, i)) if err := c.Err(); err != nil { - c.Close() + _ = c.Close() return err } @@ -601,7 +601,7 @@ func (c *Cursor) seekCursor(bufferResponse bool) error { } // Loop over loading data, applying skips as necessary and loading more data as needed - // until either the cursor is closed or finished, or we have applied all outstanding + // until either the cursor is connClosed or finished, or we have applied all outstanding // skips and data is available for { c.applyPendingSkips(bufferResponse) // if we are buffering the responses, skip can drain from the buffer diff --git a/errors.go b/errors.go index 42b98f47..ef31e350 100644 --- a/errors.go +++ b/errors.go @@ -22,7 +22,7 @@ var ( // ErrNoConnections is returned when there are no active connections in the // clusters connection pool. ErrNoConnections = errors.New("rethinkdb: no connections were available") - // ErrConnectionClosed is returned when trying to send a query with a closed + // ErrConnectionClosed is returned when trying to send a query with a connClosed // connection. ErrConnectionClosed = errors.New("rethinkdb: the connection is closed") // ErrQueryTimeout is returned when query context deadline exceeded. diff --git a/go.mod b/go.mod index 3fa4cb26..862f3ac0 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module gopkg.in/rethinkdb/rethinkdb-go.v6 require ( github.com/bitly/go-hostpool v0.1.0 // indirect github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect - github.com/cenkalti/backoff/v4 v4.0.0 + gopkg.in/cenkalti/backoff.v4 v4.0.0 github.com/golang/protobuf v1.3.4 github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed github.com/kr/pretty v0.1.0 // indirect @@ -23,4 +23,7 @@ require ( gopkg.in/yaml.v2 v2.2.8 // indirect ) +// gopath support +replace gopkg.in/cenkalti/backoff.v4 v4.0.0 => github.com/cenkalti/backoff/v4 v4.0.0 + go 1.14 diff --git a/mock.go b/mock.go index 733df587..82579739 100644 --- a/mock.go +++ b/mock.go @@ -4,7 +4,6 @@ import ( "encoding/binary" "encoding/json" "fmt" - "gopkg.in/check.v1" "gopkg.in/rethinkdb/rethinkdb-go.v6/encoding" "net" "reflect" @@ -63,7 +62,7 @@ type MockQuery struct { Repeatability int // Holds a channel that will be used to block the Return until it either - // recieves a message or is closed. nil means it returns immediately. + // recieves a message or is connClosed. nil means it returns immediately. WaitFor <-chan time.Time // Amount of times this query has been executed @@ -155,7 +154,7 @@ func (mq *MockQuery) Times(i int) *MockQuery { return mq } -// WaitUntil sets the channel that will block the mock's return until its closed +// WaitUntil sets the channel that will block the mock's return until its connClosed // or a message is received. // // mock.On(r.Table("test")).WaitUntil(time.After(time.Second)) @@ -363,7 +362,8 @@ func (m *Mock) Query(ctx context.Context, q Query) (*Cursor, error) { c.releaseConn = func() error { return conn.Close() } conn.cursors[query.Query.Token] = c - conn.runConnection() + go conn.readSocket() + go conn.processResponses() c.mu.Lock() err := c.fetchMore() @@ -424,7 +424,6 @@ func (m *Mock) queries() []MockQuery { } type mockConn struct { - c *check.C mu sync.Mutex value []byte tokens chan int64 @@ -513,7 +512,7 @@ func (c *mockConn) Read(b []byte) (n int, err error) { func (c *mockConn) Write(b []byte) (n int, err error) { if len(b) < 8 { - panic("bad socket write") + panic("connBad socket write") } token := int64(binary.LittleEndian.Uint64(b[:8])) c.tokens <- token diff --git a/mocks_test.go b/mocks_test.go index a86da3bd..76eea5a3 100644 --- a/mocks_test.go +++ b/mocks_test.go @@ -2,12 +2,40 @@ package rethinkdb import ( "github.com/stretchr/testify/mock" + "io" "net" "time" ) type connMock struct { mock.Mock + done <-chan struct{} + doneSet chan struct{} +} + +func (m *connMock) setDone(done <-chan struct{}) { + m.done = done + close(m.doneSet) +} + +func (m *connMock) waitDial() { + <-m.doneSet +} + +func (m *connMock) waitDone() { + <-m.done +} + +func (m *connMock) onCloseReturn(err error) { + closeChan := make(chan struct{}) + m.doneSet = make(chan struct{}) + // Maybe - Connection can be closed by Close() before Read() occurs when stopReadChan closed + m.On("Read", respHeaderLen).Return(nil, 0, io.EOF, nil).Maybe().Run(func(args mock.Arguments) { + <-closeChan + }) + m.On("Close").Return(err).Once().Run(func(args mock.Arguments) { + close(closeChan) + }) } func (m *connMock) Read(b []byte) (n int, err error) { diff --git a/node.go b/node.go index 73009665..5fe132bd 100644 --- a/node.go +++ b/node.go @@ -13,26 +13,24 @@ type Node struct { Host Host aliases []Host - cluster *Cluster - pool *Pool + pool *Pool mu sync.RWMutex closed bool } -func newNode(id string, aliases []Host, cluster *Cluster, pool *Pool) *Node { +func newNode(id string, aliases []Host, pool *Pool) *Node { node := &Node{ ID: id, Host: aliases[0], aliases: aliases, - cluster: cluster, pool: pool, } return node } -// Closed returns true if the node is closed +// Closed returns true if the node is connClosed func (n *Node) Closed() bool { n.mu.RLock() defer n.mu.RUnlock() @@ -119,16 +117,19 @@ func (n *Node) Server() (ServerResponse, error) { } type nodeStatus struct { - ID string `rethinkdb:"id"` - Name string `rethinkdb:"name"` - Status string `rethinkdb:"status"` - Network struct { - Hostname string `rethinkdb:"hostname"` - ClusterPort int64 `rethinkdb:"cluster_port"` - ReqlPort int64 `rethinkdb:"reql_port"` - CanonicalAddresses []struct { - Host string `rethinkdb:"host"` - Port int64 `rethinkdb:"port"` - } `rethinkdb:"canonical_addresses"` - } `rethinkdb:"network"` + ID string `rethinkdb:"id"` + Name string `rethinkdb:"name"` + Network nodeStatusNetwork `rethinkdb:"network"` +} + +type nodeStatusNetwork struct { + Hostname string `rethinkdb:"hostname"` + ClusterPort int64 `rethinkdb:"cluster_port"` + ReqlPort int64 `rethinkdb:"reql_port"` + CanonicalAddresses []nodeStatusNetworkAddr `rethinkdb:"canonical_addresses"` +} + +type nodeStatusNetworkAddr struct { + Host string `rethinkdb:"host"` + Port int64 `rethinkdb:"port"` } diff --git a/pool.go b/pool.go index 73a96e6f..3b13387b 100644 --- a/pool.go +++ b/pool.go @@ -17,6 +17,8 @@ const ( poolIsClosed int32 = 1 ) +type connFactory func(host string, opts *ConnectOpts) (*Connection, error) + // A Pool is used to store a pool of connections to a single RethinkDB server type Pool struct { host Host @@ -26,11 +28,17 @@ type Pool struct { pointer int32 closed int32 + connFactory connFactory + mu sync.Mutex // protects lazy creating connections } // NewPool creates a new connection pool for the given host func NewPool(host Host, opts *ConnectOpts) (*Pool, error) { + return newPool(host, opts, NewConnection) +} + +func newPool(host Host, opts *ConnectOpts, connFactory connFactory) (*Pool, error) { initialCap := opts.InitialCap if initialCap <= 0 { // Fallback to MaxIdle if InitialCap is zero, this should be removed @@ -46,18 +54,19 @@ func NewPool(host Host, opts *ConnectOpts) (*Pool, error) { conns := make([]*Connection, maxOpen) var err error for i := 0; i < opts.InitialCap; i++ { - conns[i], err = NewConnection(host.String(), opts) + conns[i], err = connFactory(host.String(), opts) if err != nil { return nil, err } } return &Pool{ - conns: conns, - pointer: -1, - host: host, - opts: opts, - closed: poolIsNotClosed, + conns: conns, + pointer: -1, + host: host, + opts: opts, + connFactory: connFactory, + closed: poolIsNotClosed, }, nil } @@ -108,17 +117,27 @@ func (p *Pool) conn() (*Connection, error) { } pos = pos % int32(len(p.conns)) + var err error + if p.conns[pos] == nil { p.mu.Lock() defer p.mu.Unlock() if p.conns[pos] == nil { - var err error - p.conns[pos], err = NewConnection(p.host.String(), p.opts) + p.conns[pos], err = p.connFactory(p.host.String(), p.opts) if err != nil { return nil, err } } + } else if p.conns[pos].isBad() { + // connBad connection needs to be reconnected + p.mu.Lock() + defer p.mu.Unlock() + + p.conns[pos], err = p.connFactory(p.host.String(), p.opts) + if err != nil { + return nil, err + } } return p.conns[pos], nil diff --git a/query_math.go b/query_math.go index 7f3c5c4e..4c5c836a 100644 --- a/query_math.go +++ b/query_math.go @@ -179,7 +179,7 @@ func (o RandomOpts) toMap() map[string]interface{} { // Note: Any integer responses can be be coerced to floating-points, when // unmarshaling to a Go floating-point type. The last argument given will always // be the ‘open’ side of the range, but when generating a floating-point -// number, the ‘open’ side may be less than the ‘closed’ side. +// number, the ‘open’ side may be less than the ‘connClosed’ side. func Random(args ...interface{}) Term { var opts = map[string]interface{}{} diff --git a/session.go b/session.go index 0e5c731a..2476e76a 100644 --- a/session.go +++ b/session.go @@ -70,6 +70,7 @@ type ConnectOpts struct { // NumRetries is the number of times a query is retried if a connection // error is detected, queries are not retried if RethinkDB returns a // runtime error. + // Default is 3. NumRetries int `json:"num_retries,omitempty"` // InitialCap is used by the internal connection pool and is used to