diff --git a/grapesy.cabal b/grapesy.cabal index 7796bfb7..187e4c0a 100644 --- a/grapesy.cabal +++ b/grapesy.cabal @@ -346,7 +346,6 @@ test-suite test-grapesy , bytestring >= 0.10 && < 0.13 , case-insensitive >= 1.2 && < 1.3 , containers >= 0.6 && < 0.8 - , directory >= 1.3 && < 1.4 , exceptions >= 0.10 && < 0.11 , http-types >= 0.12 && < 0.13 , http2 >= 5.3.1 && < 5.4 @@ -365,6 +364,7 @@ test-suite test-grapesy , tasty >= 1.4 && < 1.6 , tasty-hunit >= 0.10 && < 0.11 , tasty-quickcheck >= 0.10 && < 0.12 + , temporary >= 1.3 && < 1.4 , text >= 1.2 && < 2.2 , tls >= 1.7 && < 2.2 , tree-diff >= 0.3 && < 0.4 diff --git a/src/Network/GRPC/Client/Connection.hs b/src/Network/GRPC/Client/Connection.hs index 745011c7..a9f4ffa5 100644 --- a/src/Network/GRPC/Client/Connection.hs +++ b/src/Network/GRPC/Client/Connection.hs @@ -29,6 +29,7 @@ import Control.Concurrent.STM import Control.Monad import Control.Monad.Catch import Data.Default +import Data.Maybe import GHC.Stack import Network.HPACK qualified as HPACK import Network.HTTP2.Client qualified as HTTP2.Client @@ -165,13 +166,18 @@ data ReconnectPolicy = -- connection), do not attempt to connect again. DontReconnect - -- | Reconnect after random delay after the IO action returns + -- | Reconnect to the (potentially different) server after the IO action + -- returns + -- + -- If the 'Maybe' is 'Just', we'll attempt to reconnect to a server at the + -- new address. If 'Nothing', we'll attempt to connect to the original + -- server that 'withConnection' was given. -- -- This is a very general API: typically the IO action will call -- 'threadDelay' after some amount of time (which will typically involve -- some randomness), but it can be used to do things such as display a -- message to the user somewhere that the client is reconnecting. - | ReconnectAfter (IO ReconnectPolicy) + | ReconnectAfter (Maybe Server) (IO ReconnectPolicy) -- | The default policy is 'DontReconnect' -- @@ -207,7 +213,7 @@ exponentialBackoff waitFor e = go where go :: (Double, Double) -> Word -> ReconnectPolicy go _ 0 = DontReconnect - go (lo, hi) n = ReconnectAfter $ do + go (lo, hi) n = ReconnectAfter Nothing $ do delay <- randomRIO (lo, hi) waitFor $ round $ delay * 1_000_000 return $ go (lo * e, hi * e) (pred n) @@ -378,11 +384,11 @@ stayConnected :: -> TVar ConnectionState -> MVar () -> IO () -stayConnected connParams server connStateVar connOutOfScope = - loop (connReconnectPolicy connParams) +stayConnected connParams initialServer connStateVar connOutOfScope = do + loop initialServer (connReconnectPolicy connParams) where - loop :: ReconnectPolicy -> IO () - loop remainingReconnectPolicy = do + loop :: Server -> ReconnectPolicy -> IO () + loop server remainingReconnectPolicy = do -- Start new attempt (this just allocates some internal state) attempt <- newConnectionAttempt connParams connStateVar connOutOfScope @@ -425,9 +431,9 @@ stayConnected connParams server connStateVar connOutOfScope = atomically $ writeTVar connStateVar $ ConnectionAbandoned err (False, DontReconnect) -> do atomically $ writeTVar connStateVar $ ConnectionAbandoned err - (False, ReconnectAfter f) -> do + (False, ReconnectAfter mNewServer f) -> do atomically $ writeTVar connStateVar $ ConnectionNotReady - loop =<< f + loop (fromMaybe initialServer mNewServer) =<< f -- | Insecure connection (no TLS) connectInsecure :: ConnParams -> Attempt -> Address -> IO () diff --git a/test-grapesy/Test/Driver/ClientServer.hs b/test-grapesy/Test/Driver/ClientServer.hs index 2678a31a..11388cad 100644 --- a/test-grapesy/Test/Driver/ClientServer.hs +++ b/test-grapesy/Test/Driver/ClientServer.hs @@ -533,7 +533,7 @@ runTestClient cfg firstTestFailure port clientRun = do -- This avoids a race condition between the server starting first -- and the client starting first. , connReconnectPolicy = - Client.ReconnectAfter $ do + Client.ReconnectAfter Nothing $ do threadDelay 100_000 return Client.DontReconnect } diff --git a/test-grapesy/Test/Sanity/Disconnect.hs b/test-grapesy/Test/Sanity/Disconnect.hs index 2d164a0d..d2ba6c3e 100644 --- a/test-grapesy/Test/Sanity/Disconnect.hs +++ b/test-grapesy/Test/Sanity/Disconnect.hs @@ -24,6 +24,7 @@ import Data.Either import Data.IORef import Data.Word import Foreign.C.Types (CInt(..)) +import Network.Socket import System.Posix import Test.Tasty import Test.Tasty.HUnit @@ -59,7 +60,7 @@ test_clientDisconnect = do ] portSignal <- newEmptyMVar - void $ forkIO $ rawTestServer (pure Nothing) (putMVar portSignal) server + void $ forkIO $ rawTestServer (putMVar portSignal) server -- Start server serverPort <- readMVar portSignal @@ -126,24 +127,19 @@ test_serverDisconnect :: Assertion test_serverDisconnect = withTemporaryFile $ \ipcFile -> do -- We use a temporary file as a very rudimentary means of inter-process -- communication so the server (which runs in a separate process) can make - -- the client aware of the port it is assigned by the OS. This also helps us - -- make sure the server binds to the same port when it comes back up for - -- reconnect purposes. - let ipcWrite :: String -> IO () - ipcWrite msg = do - writeFile ipcFile "" - writeFile ipcFile msg + -- the client aware of the port it is assigned by the OS. + let ipcWrite :: PortNumber -> IO () + ipcWrite port = do + writeFile ipcFile (show port) - ipcRead :: IO String - ipcRead = readFile ipcFile - - ipcWaitRead :: IO String - ipcWaitRead = do - ipcRead >>= \case - "" -> do - threadDelay 10000 >> ipcWaitRead - msg -> do - return msg + ipcRead :: IO PortNumber + ipcRead = do + fmap (readMaybe @PortNumber) (readFile ipcFile) >>= \case + Nothing -> do + ipcRead + Just p -> do + writeFile ipcFile "" + return p -- Create the server server <- @@ -153,22 +149,22 @@ test_serverDisconnect = withTemporaryFile $ \ipcFile -> do ] let -- Starts the server in a new process. Gives back an action that kills - -- the server process. + -- the created server process. startServer :: IO (IO ()) startServer = do serverPid <- forkProcess $ - rawTestServer (readMaybe <$> ipcRead) (ipcWrite . show) server - return $ c_kill (fromIntegral serverPid) sigKILL + rawTestServer ipcWrite server + return $ signalProcess sigKILL serverPid -- Start server, get the port killServer <- startServer - serverPort <- read <$> ipcWaitRead + port1 <- ipcRead signalRestart <- newEmptyMVar - let serverAddress = + let serverAddress port = Client.ServerInsecure Client.Address { addressHost = "127.0.0.1" - , addressPort = serverPort + , addressPort = port , addressAuthority = Nothing } @@ -178,19 +174,23 @@ test_serverDisconnect = withTemporaryFile $ \ipcFile -> do go :: Int -> Client.ReconnectPolicy go n | n == 5 - = Client.ReconnectAfter $ do + = Client.ReconnectAfter Nothing $ do killRestarted <- startServer + port2 <- ipcRead putMVar signalRestart killRestarted - return $ Client.exponentialBackoff threadDelay 1 (1, 1) 100 + return $ + Client.ReconnectAfter + (Just $ serverAddress port2) + (pure Client.DontReconnect) | otherwise - = Client.ReconnectAfter $ do + = Client.ReconnectAfter Nothing $ do threadDelay 10000 return $ go (n + 1) connParams :: Client.ConnParams connParams = def { Client.connReconnectPolicy = reconnectPolicy } - Client.withConnection connParams serverAddress $ \conn -> do + Client.withConnection connParams (serverAddress port1) $ \conn -> do -- Make 50 concurrent calls. 49 of them sending infinite messages. One -- of them kills the server after 100 messages. let numCalls = 50 @@ -216,7 +216,7 @@ test_serverDisconnect = withTemporaryFile $ \ipcFile -> do killRestarted <- takeMVar signalRestart result <- Client.withRPC conn def (Proxy @Trivial) $ countUntil (pure . (>= 100)) - assertBool "" (result == 100) + assertEqual "" 100 result -- Do not leave the server process hanging around killRestarted @@ -274,7 +274,6 @@ echoHandler disconnectCounter call = trackDisconnects disconnectCounter $ do Auxiliary -------------------------------------------------------------------------------} -foreign import ccall unsafe "kill" c_kill :: CInt -> CInt -> IO () foreign import ccall unsafe "exit" c_exit :: CInt -> IO () type Trivial = RawRpc "trivial" "trivial" diff --git a/test-grapesy/Test/Util.hs b/test-grapesy/Test/Util.hs index ef8e0b3e..c0356271 100644 --- a/test-grapesy/Test/Util.hs +++ b/test-grapesy/Test/Util.hs @@ -14,8 +14,8 @@ import Control.Exception import Control.Monad.Catch import Control.Monad.IO.Class import GHC.Stack -import System.Directory import System.IO +import System.IO.Temp {------------------------------------------------------------------------------- Timeouts @@ -51,12 +51,5 @@ within t info io = do generalBracket startTimer stopTimer $ \_ -> io withTemporaryFile :: (FilePath -> IO a) -> IO a -withTemporaryFile k = do - tmpDir <- getTemporaryDirectory - Control.Exception.bracket - (openTempFile tmpDir "grapesy-test-suite.txt") - (removeFile . fst) - ( \(fp, h) -> do - hClose h - k fp - ) +withTemporaryFile k = + withSystemTempFile "grapesy-test-suite.txt" (\fp h -> hClose h >> k fp) diff --git a/test-grapesy/Test/Util/RawTestServer.hs b/test-grapesy/Test/Util/RawTestServer.hs index 10d5708e..77410c8a 100644 --- a/test-grapesy/Test/Util/RawTestServer.hs +++ b/test-grapesy/Test/Util/RawTestServer.hs @@ -16,7 +16,6 @@ import Network.Socket import Network.GRPC.Client qualified as Client import Network.HTTP.Types qualified as HTTP import Network.GRPC.Common -import Data.Maybe {------------------------------------------------------------------------------- Raw test server @@ -35,10 +34,9 @@ import Data.Maybe -- This also allows us to avoid binding to a specific port in the tests (which -- might already be in use on the machine running the tests, leading to spurious -- test failures). -rawTestServer :: IO (Maybe PortNumber) -> (PortNumber -> IO ()) -> HTTP2.Server -> IO () -rawTestServer getPort signalPort server = do - mPortIn <- fromMaybe 0 <$> getPort - addr <- NetworkRun.resolve Stream (Just "127.0.0.1") (show mPortIn) [AI_PASSIVE] +rawTestServer :: (PortNumber -> IO ()) -> HTTP2.Server -> IO () +rawTestServer signalPort server = do + addr <- NetworkRun.resolve Stream (Just "127.0.0.1") "0" [AI_PASSIVE] bracket (NetworkRun.openTCPServerSocket addr) close $ \listenSock -> do addr' <- getSocketName listenSock portOut <- case addr' of @@ -56,7 +54,7 @@ rawTestServer getPort signalPort server = do withTestServer :: HTTP2.Server -> (Client.Address -> IO a) -> IO a withTestServer server k = do serverPort <- newEmptyMVar - withAsync (rawTestServer (pure Nothing) (putMVar serverPort) server) $ + withAsync (rawTestServer (putMVar serverPort) server) $ \_serverThread -> do port <- readMVar serverPort let addr :: Client.Address