Skip to content

Commit

Permalink
Fixed Channel to construct valid authority header when host is the IP…
Browse files Browse the repository at this point in the history
…v6 address, closes #197
  • Loading branch information
vmagamedov committed Jul 21, 2024
1 parent b98d2a0 commit e9adb67
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 8 deletions.
13 changes: 11 additions & 2 deletions grpclib/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
import asyncio
import warnings
import ipaddress

from types import TracebackType
from typing import Generic, Optional, Union, Type, List, Sequence, Any, cast
Expand Down Expand Up @@ -683,9 +684,8 @@ def __init__(
self._codec = codec
self._status_details_codec = status_details_codec
self._ssl = ssl or None
self._authority = '{}:{}'.format(self._host, self._port)
self._scheme = 'https' if self._ssl else 'http'
self._authority = '{}:{}'.format(self._host, self._port)
self._authority = self._get_authority(self._host, self._port)
self._h2_config = H2Configuration(
client_side=True,
header_encoding='ascii',
Expand Down Expand Up @@ -779,6 +779,15 @@ def _get_default_ssl_context(
ctx.set_alpn_protocols(['h2'])
return ctx

def _get_authority(self, host: str, port: int) -> str:
try:
ipv6_address = ipaddress.IPv6Address(host)
except ipaddress.AddressValueError:
pass
else:
host = f"[{ipv6_address}]"
return "{}:{}".format(host, port)

def request(
self,
name: str,
Expand Down
30 changes: 24 additions & 6 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import socket
import tempfile
import ipaddress

import pytest

Expand Down Expand Up @@ -46,19 +47,27 @@ class ClientServer:
channel = None
channel_ctx = None

def __init__(self, *, host="127.0.0.1"):
self.host = host

async def __aenter__(self):
host = '127.0.0.1'
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('127.0.0.1', 0))
_, port = s.getsockname()
try:
ipaddress.IPv6Address(self.host)
except ipaddress.AddressValueError:
family = socket.AF_INET
else:
family = socket.AF_INET6
with socket.socket(family, socket.SOCK_STREAM) as s:
s.bind((self.host, 0))
_, port, *_ = s.getsockname()

dummy_service = DummyService()

self.server = Server([dummy_service])
await self.server.start(host, port)
await self.server.start(self.host, port)
self.server_ctx = await self.server.__aenter__()

self.channel = Channel(host=host, port=port)
self.channel = Channel(host=self.host, port=port)
self.channel_ctx = await self.channel.__aenter__()
dummy_stub = DummyServiceStub(self.channel)
return dummy_service, dummy_stub
Expand Down Expand Up @@ -211,3 +220,12 @@ async def test_stream_stream_advanced():
assert await stream.recv_message() == DummyReply(value='baz')

assert await stream.recv_message() is None


@pytest.mark.asyncio
@pytest.mark.skipif(not socket.has_ipv6, reason="No IPv6 support")
async def test_ipv6():
async with ClientServer(host="::1") as (handler, stub):
reply = await stub.UnaryUnary(DummyRequest(value='ping'))
assert reply == DummyReply(value='pong')
assert handler.log == [DummyRequest(value='ping')]

0 comments on commit e9adb67

Please sign in to comment.