From e9adb679ecf61e43bd16e14929d38eb19a061abd Mon Sep 17 00:00:00 2001 From: Vladimir Magamedov Date: Sun, 21 Jul 2024 20:27:24 +0300 Subject: [PATCH] Fixed Channel to construct valid authority header when host is the IPv6 address, closes #197 --- grpclib/client.py | 13 +++++++++++-- tests/test_functional.py | 30 ++++++++++++++++++++++++------ 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/grpclib/client.py b/grpclib/client.py index 80af3e7..95091f3 100644 --- a/grpclib/client.py +++ b/grpclib/client.py @@ -4,6 +4,7 @@ import time import asyncio import warnings +import ipaddress from types import TracebackType from typing import Generic, Optional, Union, Type, List, Sequence, Any, cast @@ -683,9 +684,8 @@ def __init__( self._codec = codec self._status_details_codec = status_details_codec self._ssl = ssl or None - self._authority = '{}:{}'.format(self._host, self._port) self._scheme = 'https' if self._ssl else 'http' - self._authority = '{}:{}'.format(self._host, self._port) + self._authority = self._get_authority(self._host, self._port) self._h2_config = H2Configuration( client_side=True, header_encoding='ascii', @@ -779,6 +779,15 @@ def _get_default_ssl_context( ctx.set_alpn_protocols(['h2']) return ctx + def _get_authority(self, host: str, port: int) -> str: + try: + ipv6_address = ipaddress.IPv6Address(host) + except ipaddress.AddressValueError: + pass + else: + host = f"[{ipv6_address}]" + return "{}:{}".format(host, port) + def request( self, name: str, diff --git a/tests/test_functional.py b/tests/test_functional.py index 8cc9894..f904c1e 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,6 +1,7 @@ import os import socket import tempfile +import ipaddress import pytest @@ -46,19 +47,27 @@ class ClientServer: channel = None channel_ctx = None + def __init__(self, *, host="127.0.0.1"): + self.host = host + async def __aenter__(self): - host = '127.0.0.1' - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(('127.0.0.1', 0)) - _, port = s.getsockname() + try: + ipaddress.IPv6Address(self.host) + except ipaddress.AddressValueError: + family = socket.AF_INET + else: + family = socket.AF_INET6 + with socket.socket(family, socket.SOCK_STREAM) as s: + s.bind((self.host, 0)) + _, port, *_ = s.getsockname() dummy_service = DummyService() self.server = Server([dummy_service]) - await self.server.start(host, port) + await self.server.start(self.host, port) self.server_ctx = await self.server.__aenter__() - self.channel = Channel(host=host, port=port) + self.channel = Channel(host=self.host, port=port) self.channel_ctx = await self.channel.__aenter__() dummy_stub = DummyServiceStub(self.channel) return dummy_service, dummy_stub @@ -211,3 +220,12 @@ async def test_stream_stream_advanced(): assert await stream.recv_message() == DummyReply(value='baz') assert await stream.recv_message() is None + + +@pytest.mark.asyncio +@pytest.mark.skipif(not socket.has_ipv6, reason="No IPv6 support") +async def test_ipv6(): + async with ClientServer(host="::1") as (handler, stub): + reply = await stub.UnaryUnary(DummyRequest(value='ping')) + assert reply == DummyReply(value='pong') + assert handler.log == [DummyRequest(value='ping')]