Skip to content

Commit

Permalink
feat(event_handler): add support for OpenAPI security schemes (#4103)
Browse files Browse the repository at this point in the history
  • Loading branch information
rubenfonseca authored Apr 18, 2024
1 parent 1e7b3ab commit 55713ce
Show file tree
Hide file tree
Showing 15 changed files with 862 additions and 37 deletions.
149 changes: 125 additions & 24 deletions aws_lambda_powertools/event_handler/api_gateway.py

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions aws_lambda_powertools/event_handler/bedrock_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ def get( # type: ignore[override]
include_in_schema: bool = True,
middlewares: Optional[List[Callable[..., Any]]] = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
security = None

return super(BedrockAgentResolver, self).get(
rule,
cors,
Expand All @@ -114,6 +116,7 @@ def get( # type: ignore[override]
tags,
operation_id,
include_in_schema,
security,
middlewares,
)

Expand All @@ -134,6 +137,8 @@ def post( # type: ignore[override]
include_in_schema: bool = True,
middlewares: Optional[List[Callable[..., Any]]] = None,
):
security = None

return super().post(
rule,
cors,
Expand All @@ -146,6 +151,7 @@ def post( # type: ignore[override]
tags,
operation_id,
include_in_schema,
security,
middlewares,
)

Expand All @@ -166,6 +172,8 @@ def put( # type: ignore[override]
include_in_schema: bool = True,
middlewares: Optional[List[Callable[..., Any]]] = None,
):
security = None

return super().put(
rule,
cors,
Expand All @@ -178,6 +186,7 @@ def put( # type: ignore[override]
tags,
operation_id,
include_in_schema,
security,
middlewares,
)

Expand All @@ -198,6 +207,8 @@ def patch( # type: ignore[override]
include_in_schema: bool = True,
middlewares: Optional[List[Callable]] = None,
):
security = None

return super().patch(
rule,
cors,
Expand All @@ -210,6 +221,7 @@ def patch( # type: ignore[override]
tags,
operation_id,
include_in_schema,
security,
middlewares,
)

Expand All @@ -230,6 +242,8 @@ def delete( # type: ignore[override]
include_in_schema: bool = True,
middlewares: Optional[List[Callable[..., Any]]] = None,
):
security = None

return super().delete(
rule,
cors,
Expand All @@ -242,6 +256,7 @@ def delete( # type: ignore[override]
tags,
operation_id,
include_in_schema,
security,
middlewares,
)

Expand Down
3 changes: 2 additions & 1 deletion aws_lambda_powertools/event_handler/openapi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,12 +441,13 @@ class SecurityBase(BaseModel):
description: Optional[str] = None

if PYDANTIC_V2:
model_config = {"extra": "allow"}
model_config = {"extra": "allow", "populate_by_name": True}

else:

class Config:
extra = "allow"
allow_population_by_field_name = True


class APIKeyIn(Enum):
Expand Down
13 changes: 13 additions & 0 deletions aws_lambda_powertools/event_handler/openapi/swagger_ui/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from aws_lambda_powertools.event_handler.openapi.swagger_ui.html import (
generate_swagger_html,
)
from aws_lambda_powertools.event_handler.openapi.swagger_ui.oauth2 import (
OAuth2Config,
generate_oauth2_redirect_html,
)

__all__ = [
"generate_swagger_html",
"generate_oauth2_redirect_html",
"OAuth2Config",
]
39 changes: 33 additions & 6 deletions aws_lambda_powertools/event_handler/openapi/swagger_ui/html.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,16 @@
def generate_swagger_html(spec: str, path: str, swagger_js: str, swagger_css: str, swagger_base_url: str) -> str:
from typing import Optional

from aws_lambda_powertools.event_handler.openapi.swagger_ui.oauth2 import OAuth2Config


def generate_swagger_html(
spec: str,
path: str,
swagger_js: str,
swagger_css: str,
swagger_base_url: str,
oauth2_config: Optional[OAuth2Config],
) -> str:
"""
Generate Swagger UI HTML page
Expand All @@ -8,10 +20,14 @@ def generate_swagger_html(spec: str, path: str, swagger_js: str, swagger_css: st
The OpenAPI spec
path: str
The path to the Swagger documentation
js_url: str
The URL to the Swagger UI JavaScript file
css_url: str
The URL to the Swagger UI CSS file
swagger_js: str
Swagger UI JavaScript source code or URL
swagger_css: str
Swagger UI CSS source code or URL
swagger_base_url: str
The base URL for Swagger UI
oauth2_config: OAuth2Config, optional
The OAuth2 configuration.
"""

# If Swagger base URL is present, generate HTML content with linked CSS and JavaScript files
Expand All @@ -23,6 +39,11 @@ def generate_swagger_html(spec: str, path: str, swagger_js: str, swagger_css: st
swagger_css_content = f"<style>{swagger_css}</style>"
swagger_js_content = f"<script>{swagger_js}</script>"

# Prepare oauth2 config
oauth2_content = (
f"ui.initOAuth({oauth2_config.json(exclude_none=True, exclude_unset=True)});" if oauth2_config else ""
)

return f"""
<!DOCTYPE html>
<html>
Expand All @@ -45,6 +66,9 @@ def generate_swagger_html(spec: str, path: str, swagger_js: str, swagger_css: st
{swagger_js_content}
<script>
var currentUrl = new URL(window.location.href);
var baseUrl = currentUrl.protocol + "//" + currentUrl.host + currentUrl.pathname;
var swaggerUIOptions = {{
dom_id: "#swagger-ui",
docExpansion: "list",
Expand All @@ -60,11 +84,14 @@ def generate_swagger_html(spec: str, path: str, swagger_js: str, swagger_css: st
],
plugins: [
SwaggerUIBundle.plugins.DownloadUrl
]
],
withCredentials: true,
oauth2RedirectUrl: baseUrl + "?format=oauth2-redirect",
}}
var ui = SwaggerUIBundle(swaggerUIOptions)
ui.specActions.updateUrl('{path}?format=json');
{oauth2_content}
</script>
</html>
""".strip()
158 changes: 158 additions & 0 deletions aws_lambda_powertools/event_handler/openapi/swagger_ui/oauth2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# ruff: noqa: E501
import warnings
from typing import Dict, Optional, Sequence

from pydantic import BaseModel, Field, validator

from aws_lambda_powertools.event_handler.openapi.pydantic_loader import PYDANTIC_V2
from aws_lambda_powertools.shared.functions import powertools_dev_is_set


# Based on https://swagger.io/docs/open-source-tools/swagger-ui/usage/oauth2/
class OAuth2Config(BaseModel):
"""
OAuth2 configuration for Swagger UI
"""

# The client ID for the OAuth2 application
clientId: Optional[str] = Field(alias="client_id", default=None)

# The client secret for the OAuth2 application. This is sensitive information and requires the explicit presence
# of the POWERTOOLS_DEV environment variable.
clientSecret: Optional[str] = Field(alias="client_secret", default=None)

# The realm in which the OAuth2 application is registered. Optional.
realm: Optional[str] = Field(default=None)

# The name of the OAuth2 application
appName: str = Field(alias="app_name")

# The scopes that the OAuth2 application requires. Defaults to an empty list.
scopes: Sequence[str] = Field(default=[])

# Additional query string parameters to be included in the OAuth2 request. Defaults to an empty dictionary.
additionalQueryStringParams: Dict[str, str] = Field(alias="additional_query_string_params", default={})

# Whether to use basic authentication with the access code grant type. Defaults to False.
useBasicAuthenticationWithAccessCodeGrant: bool = Field(
alias="use_basic_authentication_with_access_code_grant",
default=False,
)

# Whether to use PKCE with the authorization code grant type. Defaults to False.
usePkceWithAuthorizationCodeGrant: bool = Field(alias="use_pkce_with_authorization_code_grant", default=False)

if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:

class Config:
extra = "allow"
allow_population_by_field_name = True

@validator("clientSecret", always=True)
def client_secret_only_on_dev(cls, v: Optional[str]) -> Optional[str]:
if not v:
return None

if not powertools_dev_is_set():
raise ValueError(
"cannot use client_secret without POWERTOOLS_DEV mode. See "
"https://docs.powertools.aws.dev/lambda/python/latest/#optimizing-for-non-production-environments",
)
else:
warnings.warn(
"OAuth2Config is using client_secret and POWERTOOLS_DEV is set. This reveals sensitive information. "
"DO NOT USE THIS OUTSIDE LOCAL DEVELOPMENT",
stacklevel=2,
)
return v


def generate_oauth2_redirect_html() -> str:
"""
Generates the HTML content for the OAuth2 redirect page.
Source: https://github.com/swagger-api/swagger-ui/blob/master/dist/oauth2-redirect.html
"""
return """
<!doctype html>
<html lang="en-US">
<head>
<title>Swagger UI: OAuth2 Redirect</title>
</head>
<body>
<script>
'use strict';
function run () {
var oauth2 = window.opener.swaggerUIRedirectOauth2;
var sentState = oauth2.state;
var redirectUrl = oauth2.redirectUrl;
var isValid, qp, arr;
if (/code|token|error/.test(window.location.hash)) {
qp = window.location.hash.substring(1).replace('?', '&');
} else {
qp = location.search.substring(1);
}
arr = qp.split("&");
arr.forEach(function (v,i,_arr) { _arr[i] = '"' + v.replace('=', '":"') + '"';});
qp = qp ? JSON.parse('{' + arr.join() + '}',
function (key, value) {
return key === "" ? value : decodeURIComponent(value);
}
) : {};
isValid = qp.state === sentState;
if ((
oauth2.auth.schema.get("flow") === "accessCode" ||
oauth2.auth.schema.get("flow") === "authorizationCode" ||
oauth2.auth.schema.get("flow") === "authorization_code"
) && !oauth2.auth.code) {
if (!isValid) {
oauth2.errCb({
authId: oauth2.auth.name,
source: "auth",
level: "warning",
message: "Authorization may be unsafe, passed state was changed in server. The passed state wasn't returned from auth server."
});
}
if (qp.code) {
delete oauth2.state;
oauth2.auth.code = qp.code;
oauth2.callback({auth: oauth2.auth, redirectUrl: redirectUrl});
} else {
let oauthErrorMsg;
if (qp.error) {
oauthErrorMsg = "["+qp.error+"]: " +
(qp.error_description ? qp.error_description+ ". " : "no accessCode received from the server. ") +
(qp.error_uri ? "More info: "+qp.error_uri : "");
}
oauth2.errCb({
authId: oauth2.auth.name,
source: "auth",
level: "error",
message: oauthErrorMsg || "[Authorization failed]: no accessCode received from the server."
});
}
} else {
oauth2.callback({auth: oauth2.auth, token: qp, isValid: isValid, redirectUrl: redirectUrl});
}
window.close();
}
if (document.readyState !== 'loading') {
run();
} else {
document.addEventListener('DOMContentLoaded', function () {
run();
});
}
</script>
</body>
</html>
""".strip()
Loading

0 comments on commit 55713ce

Please sign in to comment.