diff --git a/src/Network/GRPC/Client/Connection.hs b/src/Network/GRPC/Client/Connection.hs index 678280a1..37e0de02 100644 --- a/src/Network/GRPC/Client/Connection.hs +++ b/src/Network/GRPC/Client/Connection.hs @@ -112,6 +112,19 @@ data ConnParams = ConnParams { -- (this is not conform gRPC spec). , connContentType :: Maybe ContentType + -- | Should we verify all request headers? + -- + -- This is the client analogue of + -- 'Network.GRPC.Server.Context.serverVerifyHeaders'. + -- + -- Arguably, it is less essential to verify headers on the client: a + -- server must deal with all kinds of different clients, and might want to + -- know if any of those clients has expectations that it cannot fulfill. A + -- client however connects to a known server, and knows what information + -- it wants from the server. It is also a bit more awkward to implement, + -- since the client is more asynchronous than the server handler. + , connVerifyHeaders :: Bool + -- | Optionally set the initial compression algorithm -- -- Under normal circumstances, the @grapesy@ client will only start using @@ -129,12 +142,13 @@ data ConnParams = ConnParams { instance Default ConnParams where def = ConnParams { - connCompression = def - , connDefaultTimeout = Nothing - , connReconnectPolicy = def - , connContentType = Just ContentTypeDefault - , connInitCompression = Nothing - , connHTTP2Settings = def + connCompression = def + , connDefaultTimeout = Nothing + , connReconnectPolicy = def + , connContentType = Just ContentTypeDefault + , connVerifyHeaders = False + , connInitCompression = Nothing + , connHTTP2Settings = def } {------------------------------------------------------------------------------- diff --git a/src/Network/GRPC/Client/Session.hs b/src/Network/GRPC/Client/Session.hs index 4cf19338..5e1332b3 100644 --- a/src/Network/GRPC/Client/Session.hs +++ b/src/Network/GRPC/Client/Session.hs @@ -21,6 +21,7 @@ import Network.GRPC.Common import Network.GRPC.Common.Compression qualified as Compr import Network.GRPC.Spec import Network.GRPC.Util.Session +import Network.GRPC.Util.HKD qualified as HKD {------------------------------------------------------------------------------- Definition @@ -126,17 +127,26 @@ processResponseHeaders :: -> IO Compression processResponseHeaders (ClientSession conn) responseHeaders' = do Connection.updateConnectionMeta conn responseHeaders' - case responseCompression responseHeaders' of - Left err -> throwIO $ CallSetupInvalidResponseHeaders err - Right Nothing -> return noCompression - Right (Just cid) -> - case Compr.getSupported (connCompression connParams) cid of - Just compr -> return compr - Nothing -> throwIO $ CallSetupUnsupportedCompression cid + + if connVerifyHeaders connParams then + case HKD.sequence responseHeaders' of + Left err -> throwIO $ CallSetupInvalidResponseHeaders err + Right hdrs -> getCompression $ responseCompression hdrs + else + case responseCompression responseHeaders' of + Left err -> throwIO $ CallSetupInvalidResponseHeaders err + Right mcid -> getCompression mcid where connParams :: ConnParams connParams = Connection.connParams conn + getCompression :: Maybe CompressionId -> IO Compression + getCompression Nothing = return noCompression + getCompression (Just cid) = + case Compr.getSupported (connCompression connParams) cid of + Just compr -> return compr + Nothing -> throwIO $ CallSetupUnsupportedCompression cid + {------------------------------------------------------------------------------- Exceptions -------------------------------------------------------------------------------} diff --git a/test-grapesy/Test/Driver/ClientServer.hs b/test-grapesy/Test/Driver/ClientServer.hs index 3fbcbf6a..ffa1cd87 100644 --- a/test-grapesy/Test/Driver/ClientServer.hs +++ b/test-grapesy/Test/Driver/ClientServer.hs @@ -525,6 +525,8 @@ runTestClient cfg firstTestFailure port clientRun = do connCompression = clientCompr cfg , connInitCompression = clientInitCompr cfg , connDefaultTimeout = Nothing + , connVerifyHeaders = True + , connHTTP2Settings = defaultHTTP2Settings -- Content-type , connContentType = @@ -540,7 +542,6 @@ runTestClient cfg firstTestFailure port clientRun = do Client.ReconnectAfter $ do threadDelay 100_000 return Client.DontReconnect - , connHTTP2Settings = defaultHTTP2Settings } clientServer :: Client.Server diff --git a/test-grapesy/Test/Sanity/BrokenDeployments.hs b/test-grapesy/Test/Sanity/BrokenDeployments.hs index 6b9fc92c..fdd61541 100644 --- a/test-grapesy/Test/Sanity/BrokenDeployments.hs +++ b/test-grapesy/Test/Sanity/BrokenDeployments.hs @@ -25,7 +25,8 @@ import Proto.API.Ping tests :: TestTree tests = testGroup "Test.Sanity.BrokenDeployments" [ - testCase "non200" test_non200 + testCase "non200" test_non200 + , testCase "nonGrpcContentType" test_nonGrpcContentType ] {------------------------------------------------------------------------------- @@ -37,9 +38,9 @@ tests = testGroup "Test.Sanity.BrokenDeployments" [ -- We don't test all codes here; we'd just end up duplicating the logic in -- 'classifyServerResponse'. We just check one representative value. test_non200 :: Assertion -test_non200 = respondWith response $ \addr -> do +test_non200 = respondWith response $ \addr -> do mResp :: Either GrpcException (Proto PongMessage) <- try $ - Client.withConnection def (Client.ServerInsecure addr) $ \conn -> + Client.withConnection connParams (Client.ServerInsecure addr) $ \conn -> Client.withRPC conn def (Proxy @Ping) $ \call -> do Client.sendFinalInput call defMessage fst <$> Client.recvFinalOutput call @@ -54,6 +55,33 @@ test_non200 = respondWith response $ \addr -> do responseStatus = HTTP.badRequest400 } +test_nonGrpcContentType :: Assertion +test_nonGrpcContentType = respondWith response $ \addr -> do + mResp <- try $ + Client.withConnection connParams (Client.ServerInsecure addr) $ \conn -> + Client.withRPC conn def (Proxy @Ping) $ \call -> do + Client.sendFinalInput call defMessage + fst <$> Client.recvFinalOutput call + case mResp of + -- TODO: + -- We should get a gRPC exception here instead. + Left Client.CallSetupInvalidResponseHeaders{} -> + return () + _otherwise -> + assertFailure $ "Unexpected response: " ++ show mResp + where + response :: Response + response = def { + responseHeaders = [ + ("content-type", "someInvalidContentType") + ] + } + +connParams :: Client.ConnParams +connParams = def { + Client.connVerifyHeaders = True + } + {------------------------------------------------------------------------------- Test server