Skip to content

Commit

Permalink
Introduce connVerifyHeaders
Browse files Browse the repository at this point in the history
Closes #175.
  • Loading branch information
edsko committed Jul 24, 2024
1 parent 16f0761 commit cb7e24e
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 17 deletions.
26 changes: 20 additions & 6 deletions src/Network/GRPC/Client/Connection.hs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,19 @@ data ConnParams = ConnParams {
-- (this is not conform gRPC spec).
, connContentType :: Maybe ContentType

-- | Should we verify all request headers?
--
-- This is the client analogue of
-- 'Network.GRPC.Server.Context.serverVerifyHeaders'.
--
-- Arguably, it is less essential to verify headers on the client: a
-- server must deal with all kinds of different clients, and might want to
-- know if any of those clients has expectations that it cannot fulfill. A
-- client however connects to a known server, and knows what information
-- it wants from the server. It is also a bit more awkward to implement,
-- since the client is more asynchronous than the server handler.
, connVerifyHeaders :: Bool

-- | Optionally set the initial compression algorithm
--
-- Under normal circumstances, the @grapesy@ client will only start using
Expand All @@ -129,12 +142,13 @@ data ConnParams = ConnParams {

instance Default ConnParams where
def = ConnParams {
connCompression = def
, connDefaultTimeout = Nothing
, connReconnectPolicy = def
, connContentType = Just ContentTypeDefault
, connInitCompression = Nothing
, connHTTP2Settings = def
connCompression = def
, connDefaultTimeout = Nothing
, connReconnectPolicy = def
, connContentType = Just ContentTypeDefault
, connVerifyHeaders = False
, connInitCompression = Nothing
, connHTTP2Settings = def
}

{-------------------------------------------------------------------------------
Expand Down
24 changes: 17 additions & 7 deletions src/Network/GRPC/Client/Session.hs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import Network.GRPC.Common
import Network.GRPC.Common.Compression qualified as Compr
import Network.GRPC.Spec
import Network.GRPC.Util.Session
import Network.GRPC.Util.HKD qualified as HKD

{-------------------------------------------------------------------------------
Definition
Expand Down Expand Up @@ -126,17 +127,26 @@ processResponseHeaders ::
-> IO Compression
processResponseHeaders (ClientSession conn) responseHeaders' = do
Connection.updateConnectionMeta conn responseHeaders'
case responseCompression responseHeaders' of
Left err -> throwIO $ CallSetupInvalidResponseHeaders err
Right Nothing -> return noCompression
Right (Just cid) ->
case Compr.getSupported (connCompression connParams) cid of
Just compr -> return compr
Nothing -> throwIO $ CallSetupUnsupportedCompression cid

if connVerifyHeaders connParams then
case HKD.sequence responseHeaders' of
Left err -> throwIO $ CallSetupInvalidResponseHeaders err
Right hdrs -> getCompression $ responseCompression hdrs
else
case responseCompression responseHeaders' of
Left err -> throwIO $ CallSetupInvalidResponseHeaders err
Right mcid -> getCompression mcid
where
connParams :: ConnParams
connParams = Connection.connParams conn

getCompression :: Maybe CompressionId -> IO Compression
getCompression Nothing = return noCompression
getCompression (Just cid) =
case Compr.getSupported (connCompression connParams) cid of
Just compr -> return compr
Nothing -> throwIO $ CallSetupUnsupportedCompression cid

{-------------------------------------------------------------------------------
Exceptions
-------------------------------------------------------------------------------}
Expand Down
3 changes: 2 additions & 1 deletion test-grapesy/Test/Driver/ClientServer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,8 @@ runTestClient cfg firstTestFailure port clientRun = do
connCompression = clientCompr cfg
, connInitCompression = clientInitCompr cfg
, connDefaultTimeout = Nothing
, connVerifyHeaders = True
, connHTTP2Settings = defaultHTTP2Settings

-- Content-type
, connContentType =
Expand All @@ -540,7 +542,6 @@ runTestClient cfg firstTestFailure port clientRun = do
Client.ReconnectAfter $ do
threadDelay 100_000
return Client.DontReconnect
, connHTTP2Settings = defaultHTTP2Settings
}

clientServer :: Client.Server
Expand Down
34 changes: 31 additions & 3 deletions test-grapesy/Test/Sanity/BrokenDeployments.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ import Proto.API.Ping

tests :: TestTree
tests = testGroup "Test.Sanity.BrokenDeployments" [
testCase "non200" test_non200
testCase "non200" test_non200
, testCase "nonGrpcContentType" test_nonGrpcContentType
]

{-------------------------------------------------------------------------------
Expand All @@ -37,9 +38,9 @@ tests = testGroup "Test.Sanity.BrokenDeployments" [
-- We don't test all codes here; we'd just end up duplicating the logic in
-- 'classifyServerResponse'. We just check one representative value.
test_non200 :: Assertion
test_non200 = respondWith response $ \addr -> do
test_non200 = respondWith response $ \addr -> do
mResp :: Either GrpcException (Proto PongMessage) <- try $
Client.withConnection def (Client.ServerInsecure addr) $ \conn ->
Client.withConnection connParams (Client.ServerInsecure addr) $ \conn ->
Client.withRPC conn def (Proxy @Ping) $ \call -> do
Client.sendFinalInput call defMessage
fst <$> Client.recvFinalOutput call
Expand All @@ -54,6 +55,33 @@ test_non200 = respondWith response $ \addr -> do
responseStatus = HTTP.badRequest400
}

test_nonGrpcContentType :: Assertion
test_nonGrpcContentType = respondWith response $ \addr -> do
mResp <- try $
Client.withConnection connParams (Client.ServerInsecure addr) $ \conn ->
Client.withRPC conn def (Proxy @Ping) $ \call -> do
Client.sendFinalInput call defMessage
fst <$> Client.recvFinalOutput call
case mResp of
-- TODO: <https://github.com/well-typed/grapesy/issues/22>
-- We should get a gRPC exception here instead.
Left Client.CallSetupInvalidResponseHeaders{} ->
return ()
_otherwise ->
assertFailure $ "Unexpected response: " ++ show mResp
where
response :: Response
response = def {
responseHeaders = [
("content-type", "someInvalidContentType")
]
}

connParams :: Client.ConnParams
connParams = def {
Client.connVerifyHeaders = True
}

{-------------------------------------------------------------------------------
Test server
Expand Down

0 comments on commit cb7e24e

Please sign in to comment.