Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Terrafrom] Add rate limiting #1084

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 87 additions & 24 deletions integrations/terraform-cloud/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio
import time
from typing import Any, AsyncGenerator, Optional
from port_ocean.utils import http_async_client
import httpx
Expand All @@ -23,6 +25,12 @@ class CacheKeys(StrEnum):
PAGE_SIZE = 100


# https://developer.hashicorp.com/terraform/cloud-docs/api-docs#rate-limiting
RATE_LIMIT_PER_SECOND = 30
RATE_LIMIT_BUFFER = 5 # Buffer to avoid hitting the exact limit
MAX_CONCURRENT_REQUESTS = 10
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why 10 ?



class TerraformClient:
def __init__(self, terraform_base_url: str, auth_token: str) -> None:
self.terraform_base_url = terraform_base_url
Expand All @@ -32,7 +40,30 @@ def __init__(self, terraform_base_url: str, auth_token: str) -> None:
}
self.api_url = f"{self.terraform_base_url}/api/v2"
self.client = http_async_client
self.client.headers.update(self.base_headers)

self.rate_limit = RATE_LIMIT_PER_SECOND
self.rate_limit_remaining = RATE_LIMIT_PER_SECOND
self.rate_limit_reset: float = 0.0
self.last_request_time = time.time()
self.request_times: list[float] = []
self.semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
self.rate_limit_lock = asyncio.Lock()

async def wait_for_rate_limit(self) -> None:
async with self.rate_limit_lock:
current_time = time.time()
self.request_times = [t for t in self.request_times if current_time - t < 1]

if len(self.request_times) >= RATE_LIMIT_PER_SECOND:
wait_time = 1 - (current_time - self.request_times[0])
if wait_time > 0:
logger.info(
f"Rate limit reached, waiting for {wait_time:.2f} seconds"
)
await asyncio.sleep(wait_time)
self.request_times = self.request_times[1:]

self.request_times.append(current_time)
Comment on lines +43 to +66
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


async def send_api_request(
self,
Expand All @@ -41,32 +72,64 @@ async def send_api_request(
query_params: Optional[dict[str, Any]] = None,
json_data: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
logger.info(f"Requesting Terraform Cloud data for endpoint: {endpoint}")
try:
url = f"{self.api_url}/{endpoint}"
logger.info(
f"URL: {url}, Method: {method}, Params: {query_params}, Body: {json_data}"
)
response = await self.client.request(
method=method,
url=url,
params=query_params,
json=json_data,
)
response.raise_for_status()
async with self.semaphore:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a concurrency constraint on the Terraform Cloud API as well ?

await self.wait_for_rate_limit()

logger.info(f"Successfully retrieved data for endpoint: {endpoint}")
logger.info(f"Requesting Terraform Cloud data for endpoint: {endpoint}")
try:
url = f"{self.api_url}/{endpoint}"
logger.info(
f"URL: {url}, Method: {method}, Params: {query_params}, Body: {json_data}"
)
response = await self.client.request(
method=method,
url=url,
params=query_params,
json=json_data,
headers=self.base_headers,
)
response.raise_for_status()

return response.json()
async with self.rate_limit_lock:
self.rate_limit = int(
response.headers.get("x-ratelimit-limit", RATE_LIMIT_PER_SECOND)
)
self.rate_limit_remaining = int(
response.headers.get(
"x-ratelimit-remaining", RATE_LIMIT_PER_SECOND
)
)
self.rate_limit_reset = float(
response.headers.get("x-ratelimit-reset", "0")
)

except httpx.HTTPStatusError as e:
logger.error(
f"HTTP error on {endpoint}: {e.response.status_code} - {e.response.text}"
)
raise
except httpx.HTTPError as e:
logger.error(f"HTTP error on {endpoint}: {str(e)}")
raise
logger.debug(f"Successfully retrieved data for endpoint: {endpoint}")
logger.debug(
f"Rate limit: {self.rate_limit_remaining}/{self.rate_limit}"
)
logger.debug(f"Rate limit reset: {self.rate_limit_reset}")

return response.json()

except httpx.HTTPStatusError as e:
if e.response.status_code == 429:
retry_after = float(
e.response.headers.get("x-ratelimit-reset", "1")
)
logger.warning(
f"Rate limit exceeded. Waiting for {retry_after} seconds before retrying."
)
await asyncio.sleep(retry_after)
return await self.send_api_request(
endpoint, method, query_params, json_data
)
logger.error(
f"HTTP error on {endpoint}: {e.response.status_code} - {e.response.text}"
)
raise
except httpx.HTTPError as e:
logger.error(f"HTTP error on {endpoint}: {str(e)}")
raise

async def get_paginated_resources(
self, endpoint: str, params: Optional[dict[str, Any]] = None
Expand Down
94 changes: 44 additions & 50 deletions integrations/terraform-cloud/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
from asyncio import gather
from enum import StrEnum
from typing import Any, List, Dict
Expand All @@ -19,7 +18,7 @@ class ObjectKind(StrEnum):


SKIP_WEBHOOK_CREATION = False
SEMAPHORE_LIMIT = 30
CHUNK_SIZE = 10


def init_terraform_client() -> TerraformClient:
Expand All @@ -40,51 +39,25 @@ def init_terraform_client() -> TerraformClient:
async def enrich_state_versions_with_output_data(
http_client: TerraformClient, state_versions: List[dict[str, Any]]
) -> list[dict[str, Any]]:
async with asyncio.BoundedSemaphore(SEMAPHORE_LIMIT):
tasks = [
http_client.get_state_version_output(state_version["id"])
for state_version in state_versions
]

output_batches = []
for completed_task in asyncio.as_completed(tasks):
output = await completed_task
output_batches.append(output)
async def get_output(state_version: dict[str, Any]) -> dict[str, Any]:
output = await http_client.get_state_version_output(state_version["id"])
return {**state_version, "__output": output}

enriched_state_versions = [
{**state_version, "__output": output}
for state_version, output in zip(state_versions, output_batches)
]
enriched_versions = []
for chunk in [
state_versions[i : i + CHUNK_SIZE]
for i in range(0, len(state_versions), CHUNK_SIZE)
]:
chunk_results = await gather(*[get_output(sv) for sv in chunk])
enriched_versions.extend(chunk_results)

return enriched_state_versions
return enriched_versions


async def enrich_workspaces_with_tags(
http_client: TerraformClient, workspaces: List[dict[str, Any]]
) -> list[dict[str, Any]]:
async def get_tags_for_workspace(workspace: dict[str, Any]) -> dict[str, Any]:
async with asyncio.BoundedSemaphore(SEMAPHORE_LIMIT):
try:
tags = []
async for tag_batch in http_client.get_workspace_tags(workspace["id"]):
tags.extend(tag_batch)
return {**workspace, "__tags": tags}
except Exception as e:
logger.warning(
f"Failed to fetch tags for workspace {workspace['id']}: {e}"
)
return {**workspace, "__tags": []}

tasks = [get_tags_for_workspace(workspace) for workspace in workspaces]
enriched_workspaces = [await task for task in asyncio.as_completed(tasks)]

return enriched_workspaces


async def enrich_workspace_with_tags(
http_client: TerraformClient, workspace: dict[str, Any]
) -> dict[str, Any]:
async with asyncio.BoundedSemaphore(SEMAPHORE_LIMIT):
try:
tags = []
async for tag_batch in http_client.get_workspace_tags(workspace["id"]):
Expand All @@ -94,6 +67,28 @@ async def enrich_workspace_with_tags(
logger.warning(f"Failed to fetch tags for workspace {workspace['id']}: {e}")
return {**workspace, "__tags": []}

enriched_workspaces = []
for chunk in [
workspaces[i : i + CHUNK_SIZE] for i in range(0, len(workspaces), CHUNK_SIZE)
]:
chunk_results = await gather(*[get_tags_for_workspace(w) for w in chunk])
enriched_workspaces.extend(chunk_results)

return enriched_workspaces


async def enrich_workspace_with_tags(
http_client: TerraformClient, workspace: dict[str, Any]
) -> dict[str, Any]:
try:
tags = []
async for tag_batch in http_client.get_workspace_tags(workspace["id"]):
tags.extend(tag_batch)
return {**workspace, "__tags": tags}
except Exception as e:
logger.warning(f"Failed to fetch tags for workspace {workspace['id']}: {e}")
return {**workspace, "__tags": []}


@ocean.on_resync(ObjectKind.ORGANIZATION)
async def resync_organizations(kind: str) -> ASYNC_GENERATOR_RESYNC_TYPE:
Expand Down Expand Up @@ -136,21 +131,20 @@ async def fetch_runs_for_workspace(
)
]

async def fetch_runs_for_all_workspaces() -> ASYNC_GENERATOR_RESYNC_TYPE:
async for workspaces in terraform_client.get_paginated_workspaces():
logger.info(
f"Received {len(workspaces)} batch workspaces... fetching its associated {kind}"
)
async for workspaces in terraform_client.get_paginated_workspaces():
logger.info(
f"Received {len(workspaces)} batch workspaces... fetching its associated {kind}"
)

tasks = [fetch_runs_for_workspace(workspace) for workspace in workspaces]
for completed_task in asyncio.as_completed(tasks):
workspace_runs = await completed_task
for chunk in [
workspaces[i : i + CHUNK_SIZE]
for i in range(0, len(workspaces), CHUNK_SIZE)
]:
chunk_results = await gather(*[fetch_runs_for_workspace(w) for w in chunk])
for workspace_runs in chunk_results:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a specific reason for replacing as_completed with gather ? waiting for all tasks within a chunk to complete before moving on to the next chunk appears to impede performance in this context. Please correct me.

for runs in workspace_runs:
yield runs

async for run_batch in fetch_runs_for_all_workspaces():
yield run_batch


@ocean.on_resync(ObjectKind.STATE_VERSION)
async def resync_state_versions(kind: str) -> ASYNC_GENERATOR_RESYNC_TYPE:
Expand Down
107 changes: 107 additions & 0 deletions integrations/terraform-cloud/tests/test_terraform_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import asyncio
from typing import AsyncGenerator
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from client import TerraformClient, RATE_LIMIT_PER_SECOND
import time
import httpx


@pytest.fixture
def mock_http_client() -> AsyncMock:
return AsyncMock(spec=httpx.AsyncClient)


@pytest.fixture
async def terraform_client(
mock_http_client: AsyncMock,
) -> AsyncGenerator[TerraformClient, None]:
with patch("client.http_async_client", mock_http_client):
client = TerraformClient("https://app.terraform.io", "test_token")
client.rate_limit_lock = asyncio.Lock()
# Manually set the headers to avoid the coroutine warning
client.client.headers = httpx.Headers(client.base_headers)
yield client


@pytest.mark.asyncio
async def test_wait_for_rate_limit(terraform_client: TerraformClient) -> None:
current_time = time.time()
with patch("time.time", side_effect=[current_time, current_time + 0.1]):
with patch.object(asyncio, "sleep", new_callable=AsyncMock) as mock_sleep:
# Simulate rate limit not reached
terraform_client.request_times = [current_time - 0.1] * (
RATE_LIMIT_PER_SECOND - 1
)
await terraform_client.wait_for_rate_limit()
mock_sleep.assert_not_called()

# Simulate rate limit reached
terraform_client.request_times = [
current_time - 0.1
] * RATE_LIMIT_PER_SECOND
await terraform_client.wait_for_rate_limit()
mock_sleep.assert_called_once()
assert mock_sleep.call_args[0][0] > 0 # Ensure sleep time is positive

# Test when wait time is not needed
current_time = time.time()
with patch("time.time", return_value=current_time):
with patch.object(asyncio, "sleep", new_callable=AsyncMock) as mock_sleep:
terraform_client.request_times = [
current_time - 1.1
] * RATE_LIMIT_PER_SECOND
await terraform_client.wait_for_rate_limit()
mock_sleep.assert_not_called()


@pytest.mark.asyncio
async def test_send_api_request(
terraform_client: TerraformClient, mock_http_client: AsyncMock
) -> None:
mock_response = MagicMock()
mock_response.json.return_value = {"data": [{"id": "test"}]}
mock_response.headers = {
"x-ratelimit-limit": "30",
"x-ratelimit-remaining": "29",
"x-ratelimit-reset": "1.0",
}
mock_http_client.request.return_value = mock_response

result = await terraform_client.send_api_request("test_endpoint")

expected_headers = {
"Authorization": "Bearer test_token",
"Content-Type": "application/vnd.api+json",
}

mock_http_client.request.assert_called_once_with(
method="GET",
url="https://app.terraform.io/api/v2/test_endpoint",
params=None,
json=None,
headers=expected_headers,
)

assert result == {"data": [{"id": "test"}]}


@pytest.mark.asyncio
async def test_get_paginated_resources(terraform_client: TerraformClient) -> None:
mock_responses = [
{"data": [{"id": "1"}, {"id": "2"}], "links": {"next": "page2"}},
{"data": [{"id": "3"}, {"id": "4"}], "links": {"next": None}},
]

with patch.object(
terraform_client, "send_api_request", side_effect=mock_responses
) as mock_send:
results = []
async for resources in terraform_client.get_paginated_resources(
"test_endpoint"
):
results.extend(resources)

assert len(results) == 4
assert [r["id"] for r in results] == ["1", "2", "3", "4"]
assert mock_send.call_count == 2
Loading