Skip to content

Commit

Permalink
Let 'ReconnectPolicy' specify new server address
Browse files Browse the repository at this point in the history
The 'ReconnectAfter' constructor of reconnect policy now holds an optional
'Server' argument, allowing reconnect policies to specify new server addresses
to attempt reconnection to. This makes it possible to fall back to redundant
servers, without needing to completely throw away a connection on the client.
  • Loading branch information
FinleyMcIlwaine committed Aug 28, 2024
1 parent 5834d3d commit 9710050
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 57 deletions.
2 changes: 1 addition & 1 deletion grapesy.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,6 @@ test-suite test-grapesy
, bytestring >= 0.10 && < 0.13
, case-insensitive >= 1.2 && < 1.3
, containers >= 0.6 && < 0.8
, directory >= 1.3 && < 1.4
, exceptions >= 0.10 && < 0.11
, http-types >= 0.12 && < 0.13
, http2 >= 5.3.1 && < 5.4
Expand All @@ -365,6 +364,7 @@ test-suite test-grapesy
, tasty >= 1.4 && < 1.6
, tasty-hunit >= 0.10 && < 0.11
, tasty-quickcheck >= 0.10 && < 0.12
, temporary >= 1.3 && < 1.4
, text >= 1.2 && < 2.2
, tls >= 1.7 && < 2.2
, tree-diff >= 0.3 && < 0.4
Expand Down
24 changes: 15 additions & 9 deletions src/Network/GRPC/Client/Connection.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import Control.Concurrent.STM
import Control.Monad
import Control.Monad.Catch
import Data.Default
import Data.Maybe
import GHC.Stack
import Network.HPACK qualified as HPACK
import Network.HTTP2.Client qualified as HTTP2.Client
Expand Down Expand Up @@ -165,13 +166,18 @@ data ReconnectPolicy =
-- connection), do not attempt to connect again.
DontReconnect

-- | Reconnect after random delay after the IO action returns
-- | Reconnect to the (potentially different) server after the IO action
-- returns
--
-- If the 'Maybe' is 'Just', we'll attempt to reconnect to a server at the
-- new address. If 'Nothing', we'll attempt to connect to the original
-- server that 'withConnection' was given.
--
-- This is a very general API: typically the IO action will call
-- 'threadDelay' after some amount of time (which will typically involve
-- some randomness), but it can be used to do things such as display a
-- message to the user somewhere that the client is reconnecting.
| ReconnectAfter (IO ReconnectPolicy)
| ReconnectAfter (Maybe Server) (IO ReconnectPolicy)

-- | The default policy is 'DontReconnect'
--
Expand Down Expand Up @@ -207,7 +213,7 @@ exponentialBackoff waitFor e = go
where
go :: (Double, Double) -> Word -> ReconnectPolicy
go _ 0 = DontReconnect
go (lo, hi) n = ReconnectAfter $ do
go (lo, hi) n = ReconnectAfter Nothing $ do
delay <- randomRIO (lo, hi)
waitFor $ round $ delay * 1_000_000
return $ go (lo * e, hi * e) (pred n)
Expand Down Expand Up @@ -378,11 +384,11 @@ stayConnected ::
-> TVar ConnectionState
-> MVar ()
-> IO ()
stayConnected connParams server connStateVar connOutOfScope =
loop (connReconnectPolicy connParams)
stayConnected connParams initialServer connStateVar connOutOfScope = do
loop initialServer (connReconnectPolicy connParams)
where
loop :: ReconnectPolicy -> IO ()
loop remainingReconnectPolicy = do
loop :: Server -> ReconnectPolicy -> IO ()
loop server remainingReconnectPolicy = do
-- Start new attempt (this just allocates some internal state)
attempt <- newConnectionAttempt connParams connStateVar connOutOfScope

Expand Down Expand Up @@ -425,9 +431,9 @@ stayConnected connParams server connStateVar connOutOfScope =
atomically $ writeTVar connStateVar $ ConnectionAbandoned err
(False, DontReconnect) -> do
atomically $ writeTVar connStateVar $ ConnectionAbandoned err
(False, ReconnectAfter f) -> do
(False, ReconnectAfter mNewServer f) -> do
atomically $ writeTVar connStateVar $ ConnectionNotReady
loop =<< f
loop (fromMaybe initialServer mNewServer) =<< f

-- | Insecure connection (no TLS)
connectInsecure :: ConnParams -> Attempt -> Address -> IO ()
Expand Down
2 changes: 1 addition & 1 deletion test-grapesy/Test/Driver/ClientServer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ runTestClient cfg firstTestFailure port clientRun = do
-- This avoids a race condition between the server starting first
-- and the client starting first.
, connReconnectPolicy =
Client.ReconnectAfter $ do
Client.ReconnectAfter Nothing $ do
threadDelay 100_000
return Client.DontReconnect
}
Expand Down
59 changes: 29 additions & 30 deletions test-grapesy/Test/Sanity/Disconnect.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import Data.Either
import Data.IORef
import Data.Word
import Foreign.C.Types (CInt(..))
import Network.Socket
import System.Posix
import Test.Tasty
import Test.Tasty.HUnit
Expand Down Expand Up @@ -59,7 +60,7 @@ test_clientDisconnect = do
]

portSignal <- newEmptyMVar
void $ forkIO $ rawTestServer (pure Nothing) (putMVar portSignal) server
void $ forkIO $ rawTestServer (putMVar portSignal) server

-- Start server
serverPort <- readMVar portSignal
Expand Down Expand Up @@ -126,24 +127,19 @@ test_serverDisconnect :: Assertion
test_serverDisconnect = withTemporaryFile $ \ipcFile -> do
-- We use a temporary file as a very rudimentary means of inter-process
-- communication so the server (which runs in a separate process) can make
-- the client aware of the port it is assigned by the OS. This also helps us
-- make sure the server binds to the same port when it comes back up for
-- reconnect purposes.
let ipcWrite :: String -> IO ()
ipcWrite msg = do
writeFile ipcFile ""
writeFile ipcFile msg
-- the client aware of the port it is assigned by the OS.
let ipcWrite :: PortNumber -> IO ()
ipcWrite port = do
writeFile ipcFile (show port)

ipcRead :: IO String
ipcRead = readFile ipcFile

ipcWaitRead :: IO String
ipcWaitRead = do
ipcRead >>= \case
"" -> do
threadDelay 10000 >> ipcWaitRead
msg -> do
return msg
ipcRead :: IO PortNumber
ipcRead = do
fmap (readMaybe @PortNumber) (readFile ipcFile) >>= \case
Nothing -> do
ipcRead
Just p -> do
writeFile ipcFile ""
return p

-- Create the server
server <-
Expand All @@ -153,22 +149,22 @@ test_serverDisconnect = withTemporaryFile $ \ipcFile -> do
]

let -- Starts the server in a new process. Gives back an action that kills
-- the server process.
-- the created server process.
startServer :: IO (IO ())
startServer = do
serverPid <-
forkProcess $
rawTestServer (readMaybe <$> ipcRead) (ipcWrite . show) server
return $ c_kill (fromIntegral serverPid) sigKILL
rawTestServer ipcWrite server
return $ signalProcess sigKILL serverPid

-- Start server, get the port
killServer <- startServer
serverPort <- read <$> ipcWaitRead
port1 <- ipcRead
signalRestart <- newEmptyMVar
let serverAddress =
let serverAddress port =
Client.ServerInsecure Client.Address {
addressHost = "127.0.0.1"
, addressPort = serverPort
, addressPort = port
, addressAuthority = Nothing
}

Expand All @@ -178,19 +174,23 @@ test_serverDisconnect = withTemporaryFile $ \ipcFile -> do
go :: Int -> Client.ReconnectPolicy
go n
| n == 5
= Client.ReconnectAfter $ do
= Client.ReconnectAfter Nothing $ do
killRestarted <- startServer
port2 <- ipcRead
putMVar signalRestart killRestarted
return $ Client.exponentialBackoff threadDelay 1 (1, 1) 100
return $
Client.ReconnectAfter
(Just $ serverAddress port2)
(pure Client.DontReconnect)
| otherwise
= Client.ReconnectAfter $ do
= Client.ReconnectAfter Nothing $ do
threadDelay 10000
return $ go (n + 1)

connParams :: Client.ConnParams
connParams = def { Client.connReconnectPolicy = reconnectPolicy }

Client.withConnection connParams serverAddress $ \conn -> do
Client.withConnection connParams (serverAddress port1) $ \conn -> do
-- Make 50 concurrent calls. 49 of them sending infinite messages. One
-- of them kills the server after 100 messages.
let numCalls = 50
Expand All @@ -216,7 +216,7 @@ test_serverDisconnect = withTemporaryFile $ \ipcFile -> do
killRestarted <- takeMVar signalRestart
result <- Client.withRPC conn def (Proxy @Trivial) $
countUntil (pure . (>= 100))
assertBool "" (result == 100)
assertEqual "" 100 result

-- Do not leave the server process hanging around
killRestarted
Expand Down Expand Up @@ -274,7 +274,6 @@ echoHandler disconnectCounter call = trackDisconnects disconnectCounter $ do
Auxiliary
-------------------------------------------------------------------------------}

foreign import ccall unsafe "kill" c_kill :: CInt -> CInt -> IO ()
foreign import ccall unsafe "exit" c_exit :: CInt -> IO ()

type Trivial = RawRpc "trivial" "trivial"
Expand Down
13 changes: 3 additions & 10 deletions test-grapesy/Test/Util.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ import Control.Exception
import Control.Monad.Catch
import Control.Monad.IO.Class
import GHC.Stack
import System.Directory
import System.IO
import System.IO.Temp

{-------------------------------------------------------------------------------
Timeouts
Expand Down Expand Up @@ -51,12 +51,5 @@ within t info io = do
generalBracket startTimer stopTimer $ \_ -> io

withTemporaryFile :: (FilePath -> IO a) -> IO a
withTemporaryFile k = do
tmpDir <- getTemporaryDirectory
Control.Exception.bracket
(openTempFile tmpDir "grapesy-test-suite.txt")
(removeFile . fst)
( \(fp, h) -> do
hClose h
k fp
)
withTemporaryFile k =
withSystemTempFile "grapesy-test-suite.txt" (\fp h -> hClose h >> k fp)
10 changes: 4 additions & 6 deletions test-grapesy/Test/Util/RawTestServer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import Network.Socket
import Network.GRPC.Client qualified as Client
import Network.HTTP.Types qualified as HTTP
import Network.GRPC.Common
import Data.Maybe

{-------------------------------------------------------------------------------
Raw test server
Expand All @@ -35,10 +34,9 @@ import Data.Maybe
-- This also allows us to avoid binding to a specific port in the tests (which
-- might already be in use on the machine running the tests, leading to spurious
-- test failures).
rawTestServer :: IO (Maybe PortNumber) -> (PortNumber -> IO ()) -> HTTP2.Server -> IO ()
rawTestServer getPort signalPort server = do
mPortIn <- fromMaybe 0 <$> getPort
addr <- NetworkRun.resolve Stream (Just "127.0.0.1") (show mPortIn) [AI_PASSIVE]
rawTestServer :: (PortNumber -> IO ()) -> HTTP2.Server -> IO ()
rawTestServer signalPort server = do
addr <- NetworkRun.resolve Stream (Just "127.0.0.1") "0" [AI_PASSIVE]
bracket (NetworkRun.openTCPServerSocket addr) close $ \listenSock -> do
addr' <- getSocketName listenSock
portOut <- case addr' of
Expand All @@ -56,7 +54,7 @@ rawTestServer getPort signalPort server = do
withTestServer :: HTTP2.Server -> (Client.Address -> IO a) -> IO a
withTestServer server k = do
serverPort <- newEmptyMVar
withAsync (rawTestServer (pure Nothing) (putMVar serverPort) server) $
withAsync (rawTestServer (putMVar serverPort) server) $
\_serverThread -> do
port <- readMVar serverPort
let addr :: Client.Address
Expand Down

0 comments on commit 9710050

Please sign in to comment.