Skip to content

Commit

Permalink
Refactored server close logic to gracefully exit without using GOAWAY…
Browse files Browse the repository at this point in the history
… frames
  • Loading branch information
vmagamedov committed May 19, 2024
1 parent 5916cba commit 6cd97ca
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 23 deletions.
9 changes: 6 additions & 3 deletions grpclib/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@


class Handler(AbstractHandler):
connection_lost = False
closing = False

def connection_made(self, connection: Any) -> None:
pass

def accept(self, stream: Any, headers: Any, release_stream: Any) -> None:
raise NotImplementedError('Client connection can not accept requests')
Expand All @@ -71,7 +74,7 @@ def cancel(self, stream: Any) -> None:
pass

def close(self) -> None:
self.connection_lost = True
self.closing = True


class Stream(StreamIterator[_RecvType], Generic[_SendType, _RecvType]):
Expand Down Expand Up @@ -737,7 +740,7 @@ async def _create_connection(self) -> H2Protocol:
@property
def _connected(self) -> bool:
return (self._protocol is not None
and not self._protocol.handler.connection_lost)
and not cast(Handler, self._protocol.handler).closing)

async def __connect__(self) -> H2Protocol:
if not self._connected:
Expand Down
5 changes: 5 additions & 0 deletions grpclib/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,10 @@ def closable(self) -> bool:

class AbstractHandler(ABC):

@abstractmethod
def connection_made(self, connection: Connection) -> None:
pass

@abstractmethod
def accept(
self,
Expand Down Expand Up @@ -709,6 +713,7 @@ def connection_made(self, transport: BaseTransport) -> None:
self.connection.flush()
self.connection.initialize()

self.handler.connection_made(self.connection)
self.processor = EventsProcessor(self.handler, self.connection)

def data_received(self, data: bytes) -> None:
Expand Down
53 changes: 33 additions & 20 deletions grpclib/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import asyncio
import warnings
from functools import partial

from types import TracebackType
from typing import TYPE_CHECKING, Optional, Collection, Generic, Type, cast
Expand All @@ -12,6 +13,7 @@

import h2.config
import h2.exceptions
from h2.errors import ErrorCodes

from multidict import MultiDict

Expand All @@ -24,7 +26,7 @@
from .metadata import Deadline, encode_grpc_message, _Metadata
from .metadata import encode_metadata, decode_metadata, _MetadataLike
from .metadata import _STATUS_DETAILS_KEY, encode_bin_value
from .protocol import H2Protocol, AbstractHandler
from .protocol import H2Protocol, AbstractHandler, Connection
from .exceptions import GRPCError, ProtocolError, StreamTerminatedError
from .encoding.base import GRPC_CONTENT_TYPE, CodecBase, StatusDetailsCodecBase
from .encoding.proto import ProtoCodec, ProtoStatusDetailsCodec
Expand Down Expand Up @@ -496,6 +498,7 @@ def __gc_step__(self) -> None:
class Handler(_GC, AbstractHandler):
__gc_interval__ = 10

connection: Connection
closing = False

def __init__(
Expand All @@ -511,44 +514,54 @@ def __init__(
self.dispatch = dispatch
self.loop = asyncio.get_event_loop()
self._tasks: Dict['protocol.Stream', 'asyncio.Task[None]'] = {}
self._cancelled: Set['asyncio.Task[None]'] = set()

def __gc_collect__(self) -> None:
self._tasks = {s: t for s, t in self._tasks.items()
if not t.done()}
self._cancelled = {t for t in self._cancelled
if not t.done()}
self._tasks = {s: t for s, t in self._tasks.items() if not t.done()}

def connection_made(self, connection: Connection) -> None:
self.connection = connection

def handler_done(self, stream: 'protocol.Stream', _: Any) -> None:
self._tasks.pop(stream, None)
if not self._tasks:
self.connection.close()

def accept(
self,
stream: 'protocol.Stream',
headers: _Headers,
release_stream: Callable[[], Any],
) -> None:
self.__gc_step__()
self._tasks[stream] = self.loop.create_task(request_handler(
self.mapping, stream, headers, self.codec,
self.status_details_codec, self.dispatch, release_stream,
))
if self.closing:
stream.reset_nowait(ErrorCodes.REFUSED_STREAM)
release_stream()
else:
self.__gc_step__()
self._tasks[stream] = self.loop.create_task(request_handler(
self.mapping, stream, headers, self.codec,
self.status_details_codec, self.dispatch, release_stream,
))

def cancel(self, stream: 'protocol.Stream') -> None:
task = self._tasks.pop(stream)
task.cancel()
self._cancelled.add(task)
self._tasks[stream].cancel()

def close(self) -> None:
for task in self._tasks.values():
self.__gc_collect__()
for stream, task in self._tasks.items():
task.add_done_callback(partial(self.handler_done, stream))
task.cancel()
self._cancelled.update(self._tasks.values())
self.closing = True

async def wait_closed(self) -> None:
if self._cancelled:
await asyncio.wait(self._cancelled)
self.__gc_collect__()
if self._tasks:
await asyncio.wait(self._tasks.values())
else:
self.connection.close()

def check_closed(self) -> bool:
self.__gc_collect__()
return not self._tasks and not self._cancelled
return not self._tasks


class Server(_GC):
Expand Down Expand Up @@ -737,11 +750,11 @@ async def wait_closed(self) -> None:
if self._server is None or self._server_closed_fut is None:
raise RuntimeError('Server is not started')
await self._server_closed_fut
await self._server.wait_closed()
if self._handlers:
await asyncio.wait({
self._loop.create_task(h.wait_closed()) for h in self._handlers
})
await self._server.wait_closed()

async def __aenter__(self) -> 'Server':
return self
Expand Down
3 changes: 3 additions & 0 deletions tests/stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ class DummyHandler(AbstractHandler):
headers = None
release_stream = None

def connection_made(self, connection):
pass

def accept(self, stream, headers, release_stream):
self.stream = stream
self.headers = headers
Expand Down

0 comments on commit 6cd97ca

Please sign in to comment.