diff --git a/posthog/temporal/common/clickhouse.py b/posthog/temporal/common/clickhouse.py index 147f09813e2c6..613b77525c519 100644 --- a/posthog/temporal/common/clickhouse.py +++ b/posthog/temporal/common/clickhouse.py @@ -2,6 +2,7 @@ import contextlib import datetime as dt import json +import ssl import typing import uuid @@ -76,6 +77,13 @@ def encode_clickhouse_data(data: typing.Any, quote_char="'") -> bytes: return f"{quote_char}{str_data}{quote_char}".encode() +class ClickHouseClientNotConnected(Exception): + """Exception raised when attempting to run an async query without connecting.""" + + def __init__(self): + super().__init__("ClickHouseClient is not connected. Are you running in a context manager?") + + class ClickHouseError(Exception): """Base Exception representing anything going wrong with ClickHouse.""" @@ -97,21 +105,21 @@ class ClickHouseClient: def __init__( self, - session: aiohttp.ClientSession | None = None, url: str = "http://localhost:8123", user: str = "default", password: str = "", database: str = "default", + timeout: None | aiohttp.ClientTimeout = None, + ssl: ssl.SSLContext | bool = True, **kwargs, ): - if session is None: - self.session = aiohttp.ClientSession() - else: - self.session = session - self.url = url self.headers = {} self.params = {} + self.timeout = timeout + self.ssl = ssl + self.connector: None | aiohttp.TCPConnector = None + self.session: None | aiohttp.ClientSession = None if user: self.headers["X-ClickHouse-User"] = user @@ -123,10 +131,9 @@ def __init__( self.params.update(kwargs) @classmethod - def from_posthog_settings(cls, session, settings, **kwargs): + def from_posthog_settings(cls, settings, **kwargs): """Initialize a ClickHouseClient from PostHog settings.""" return cls( - session=session, url=settings.CLICKHOUSE_URL, user=settings.CLICKHOUSE_USER, password=settings.CLICKHOUSE_PASSWORD, @@ -140,6 +147,9 @@ async def is_alive(self) -> bool: Returns: A boolean indicating whether the connection is alive. """ + if self.session is None: + raise ClickHouseClientNotConnected() + try: await self.session.get( url=self.url, @@ -217,6 +227,8 @@ async def aget_query( Returns: The response received from the ClickHouse HTTP interface. """ + if self.session is None: + raise ClickHouseClientNotConnected() params = {**self.params} if query_id is not None: @@ -245,6 +257,8 @@ async def apost_query( Returns: The response received from the ClickHouse HTTP interface. """ + if self.session is None: + raise ClickHouseClientNotConnected() params = {**self.params} if query_id is not None: @@ -378,11 +392,21 @@ async def astream_query_as_arrow( async def __aenter__(self): """Enter method part of the AsyncContextManager protocol.""" + self.connector = aiohttp.TCPConnector(ssl=self.ssl) + self.session = aiohttp.ClientSession(connector=self.connector, timeout=self.timeout) return self async def __aexit__(self, exc_type, exc_value, tb): """Exit method part of the AsyncContextManager protocol.""" - await self.session.close() + if self.session is not None: + await self.session.close() + + if self.connector is not None: + await self.connector.close() + + self.session = None + self.connector = None + return False @contextlib.asynccontextmanager @@ -427,19 +451,17 @@ async def get_client( team_id, settings.CLICKHOUSE_MAX_BLOCK_SIZE_DEFAULT ) - with aiohttp.TCPConnector(ssl=False) as connector: - async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: - async with ClickHouseClient( - session, - url=settings.CLICKHOUSE_OFFLINE_HTTP_URL, - user=settings.CLICKHOUSE_USER, - password=settings.CLICKHOUSE_PASSWORD, - database=settings.CLICKHOUSE_DATABASE, - max_execution_time=settings.CLICKHOUSE_MAX_EXECUTION_TIME, - max_memory_usage=settings.CLICKHOUSE_MAX_MEMORY_USAGE, - max_block_size=max_block_size, - cancel_http_readonly_queries_on_client_close=1, - output_format_arrow_string_as_string="true", - **kwargs, - ) as client: - yield client + async with ClickHouseClient( + url=settings.CLICKHOUSE_OFFLINE_HTTP_URL, + user=settings.CLICKHOUSE_USER, + password=settings.CLICKHOUSE_PASSWORD, + database=settings.CLICKHOUSE_DATABASE, + timeout=timeout, + ssl=False, + max_execution_time=settings.CLICKHOUSE_MAX_EXECUTION_TIME, + max_memory_usage=settings.CLICKHOUSE_MAX_MEMORY_USAGE, + max_block_size=max_block_size, + output_format_arrow_string_as_string="true", + **kwargs, + ) as client: + yield client diff --git a/posthog/temporal/tests/conftest.py b/posthog/temporal/tests/conftest.py index f7802d6252875..f88d74009385d 100644 --- a/posthog/temporal/tests/conftest.py +++ b/posthog/temporal/tests/conftest.py @@ -1,14 +1,14 @@ import asyncio import random +import psycopg import pytest import pytest_asyncio import temporalio.worker from asgiref.sync import sync_to_async from django.conf import settings -from temporalio.testing import ActivityEnvironment -import psycopg from psycopg import sql +from temporalio.testing import ActivityEnvironment from posthog.models import Organization, Team from posthog.temporal.common.clickhouse import ClickHouseClient @@ -65,10 +65,10 @@ def activity_environment(): return ActivityEnvironment() -@pytest.fixture(scope="module") -def clickhouse_client(): +@pytest_asyncio.fixture(scope="module") +async def clickhouse_client(): """Provide a ClickHouseClient to use in tests.""" - client = ClickHouseClient( + async with ClickHouseClient( url=settings.CLICKHOUSE_HTTP_URL, user=settings.CLICKHOUSE_USER, password=settings.CLICKHOUSE_PASSWORD, @@ -78,9 +78,8 @@ def clickhouse_client(): # Durting testing, it's useful to enable it to wait for mutations. # Otherwise, tests that rely on running a mutation may become flaky. mutations_sync=2, - ) - - yield client + ) as client: + yield client @pytest_asyncio.fixture