Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check input params of fetch_quote and add tests #58

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions nwc_backend/event_handlers/__tests__/fetch_quote_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright ©, 2022, Lightspark Group, Inc. - All Rights Reserved
# pyre-strict

import json
from datetime import datetime, timedelta, timezone
from secrets import token_hex
from unittest.mock import ANY, AsyncMock, Mock, patch

import aiohttp
import pytest

from nwc_backend.event_handlers.fetch_quote_handler import fetch_quote
from nwc_backend.exceptions import InvalidInputException, NotImplementedException
from nwc_backend.models.nip47_request import Nip47Request


@patch.object(aiohttp.ClientSession, "get")
async def test_fetch_quote_success(mock_get: Mock) -> None:
now = datetime.now(timezone.utc)
vasp_response = {
"sending_currency_code": "SAT",
"receiving_currency_code": "USD",
"payment_hash": token_hex(),
"expires_at": int((now + timedelta(minutes=5)).timestamp()),
"multiplier": 15351.4798,
"fees": 10,
"total_sending_amount": 1_000_000,
"total_receiving_amount": 65,
"created_at": int(now.timestamp()),
}
mock_response = AsyncMock()
mock_response.text = AsyncMock(return_value=json.dumps(vasp_response))
mock_response.raise_for_status = Mock()
mock_get.return_value.__aenter__.return_value = mock_response

receiver_address = "$alice@uma.me"
params = {
"receiver": {"lud16": receiver_address},
"sending_currency_code": "SAT",
"receiving_currency_code": "USD",
"locked_currency_amount": 1_000_000,
"locked_currency_side": "sending",
}
quote = await fetch_quote(
access_token=token_hex(),
request=Nip47Request(params=params),
)

params.pop("receiver")
params["receiver_address"] = receiver_address
mock_get.assert_called_once_with(url="/quote/lud16", params=params, headers=ANY)

assert quote.sending_currency_code == vasp_response["sending_currency_code"]
assert quote.receiving_currency_code == vasp_response["receiving_currency_code"]
assert quote.payment_hash == vasp_response["payment_hash"]
assert quote.expires_at == vasp_response["expires_at"]
assert quote.multiplier == vasp_response["multiplier"]
assert quote.fees == vasp_response["fees"]
assert quote.total_sending_amount == vasp_response["total_sending_amount"]
assert quote.created_at == vasp_response["created_at"]


async def test_fetch_quote_failure__invalid_input() -> None:
with pytest.raises(InvalidInputException):
await fetch_quote(
access_token=token_hex(),
request=Nip47Request(
params={
"receiver": {"lud16": "$alice@uma.me"},
"sending_currency_code": "SAT",
"receiving_currency_code": "USD",
"locked_currency_amount": 1_000_000,
"locked_currency_side": "send", # wrong enum value
}
),
)


async def test_fetch_quote_failure__unsupported_bolt12() -> None:
with pytest.raises(NotImplementedException):
await fetch_quote(
access_token=token_hex(),
request=Nip47Request(
params={
"receiver": {"bolt12": "$alice@uma.me"}, # bolt12 not supported
"sending_currency_code": "SAT",
"receiving_currency_code": "USD",
"locked_currency_amount": 1_000_000,
"locked_currency_side": "sending",
}
),
)
71 changes: 25 additions & 46 deletions nwc_backend/event_handlers/fetch_quote_handler.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,38 @@
# Copyright ©, 2022, Lightspark Group, Inc. - All Rights Reserved
# pyre-strict

import logging
from typing import Any

from aiohttp import ClientResponseError
from nostr_sdk import ErrorCode, Nip47Error
from uma_auth.models.quote import Quote

from nwc_backend.exceptions import InvalidInputException
from nwc_backend.event_handlers.input_validator import get_required_field
from nwc_backend.models.nip47_request import Nip47Request
from nwc_backend.vasp_client import (
AddressType,
LockedCurrencySide,
ReceivingAddress,
vasp_uma_client,
)


async def fetch_quote(
access_token: str, request: Nip47Request
) -> dict[str, Any] | Nip47Error:
try:
locked_currency_side = LockedCurrencySide(
request.params["locked_currency_side"].lower()
)
except ValueError:
return Nip47Error(
code=ErrorCode.OTHER,
message="Expect locked_currency_side to be either sending or receiving.",
)

try:
receiving_address = ReceivingAddress.from_dict(request.params["receiver"])
except InvalidInputException as ex:
return Nip47Error(
code=ErrorCode.OTHER,
message=ex.error_message,
)
if receiving_address.type == AddressType.BOLT12:
return Nip47Error(
code=ErrorCode.NOT_IMPLEMENTED,
message="Bolt12 is not yet supported.",
)

try:
response = await vasp_uma_client.fetch_quote(
access_token=access_token,
sending_currency_code=request.params["sending_currency_code"],
receiving_currency_code=request.params["receiving_currency_code"],
locked_currency_amount=request.params["locked_currency_amount"],
locked_currency_side=locked_currency_side,
receiving_address=receiving_address,
)
return response.to_dict()
except ClientResponseError as ex:
logging.exception("Request fetch_quote %s failed", str(request.id))
# TODO: more granular error code
return Nip47Error(code=ErrorCode.OTHER, message=ex.message)
async def fetch_quote(access_token: str, request: Nip47Request) -> Quote:
sending_currency_code = get_required_field(
request.params, "sending_currency_code", str
)
receiving_currency_code = get_required_field(
request.params, "receiving_currency_code", str
)
locked_currency_amount = get_required_field(
request.params, "locked_currency_amount", int
)
locked_currency_side = get_required_field(
request.params, "locked_currency_side", LockedCurrencySide
)
receiver = get_required_field(request.params, "receiver", dict)
receiving_address = ReceivingAddress.from_dict(receiver)
return await vasp_uma_client.fetch_quote(
access_token=access_token,
sending_currency_code=sending_currency_code,
receiving_currency_code=receiving_currency_code,
locked_currency_amount=locked_currency_amount,
locked_currency_side=locked_currency_side,
receiver_address=receiving_address,
)
Copy link
Contributor Author

@yunyuyunyu yunyuyunyu Aug 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm finally happy with how clean request handler looks like now. Will keep refactoring while adding tests.

31 changes: 20 additions & 11 deletions nwc_backend/event_handlers/nip47_event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
from datetime import datetime, timezone

from aiohttp import ClientResponseError
from nostr_sdk import ErrorCode, Event, Nip47Error, TagKind, nip04_decrypt

from nwc_backend.configs.nostr_config import nostr_config
Expand Down Expand Up @@ -100,7 +101,9 @@ async def handle_nip47_event(event: Event) -> None:
await execute_quote(uma_access_token, nip47_request)
).to_dict()
case Nip47RequestMethod.FETCH_QUOTE:
response = await fetch_quote(uma_access_token, nip47_request)
response = (
await fetch_quote(uma_access_token, nip47_request)
).to_dict()
case Nip47RequestMethod.GET_BALANCE:
response = await get_balance(uma_access_token, nip47_request)
case Nip47RequestMethod.GET_INFO:
Expand All @@ -124,18 +127,24 @@ async def handle_nip47_event(event: Event) -> None:
code=ErrorCode.NOT_IMPLEMENTED,
message=f"Method {method} is not supported.",
)
except Nip47RequestException as ex:
logging.exception("Request %s %s failed", method, str(nip47_request.id))
response = Nip47Error(
code=ex.error_code,
message=ex.error_message,
)
except Exception as ex:
logging.exception("Request %s %s failed", method, str(nip47_request.id))
response = Nip47Error(
code=ErrorCode.INTERNAL,
message=str(ex),
)

if isinstance(ex, Nip47RequestException):
response = Nip47Error(
code=ex.error_code,
message=ex.error_message,
)
elif isinstance(ex, ClientResponseError):
response = Nip47Error(
code=ErrorCode.OTHER,
message=str(ex),
)
else:
response = Nip47Error(
code=ErrorCode.INTERNAL,
message=str(ex),
)

if isinstance(response, Nip47Error):
response_event = create_nip47_error_response(
Expand Down
7 changes: 7 additions & 0 deletions nwc_backend/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,10 @@ def __init__(self, error_code: ErrorCode, error_message: str) -> None:
class InvalidInputException(Nip47RequestException):
def __init__(self, error_message: str) -> None:
super().__init__(error_code=ErrorCode.OTHER, error_message=error_message)


class NotImplementedException(Nip47RequestException):
def __init__(self, error_message: str) -> None:
super().__init__(
error_code=ErrorCode.NOT_IMPLEMENTED, error_message=error_message
)
21 changes: 13 additions & 8 deletions nwc_backend/vasp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from uma_auth.models.quote import Quote
from uma_auth.models.transaction import Transaction

from nwc_backend.exceptions import InvalidInputException
from nwc_backend.exceptions import InvalidInputException, NotImplementedException


class AddressType(Enum):
Expand All @@ -39,13 +39,18 @@ def from_dict(receiving_address: dict[str, str]) -> "ReceivingAddress":
"Expect `receiver` to contain exactly one address.",
)

address_type, address = receiving_address.popitem()
address_type, address = next(iter(receiving_address.items()))
try:
return ReceivingAddress(address=address, type=AddressType(address_type))
except ValueError:
address_type = AddressType(address_type)
except ValueError as ex:
raise InvalidInputException(
"Expect `receiver` to contain address type `bolt12` or `lud16`.",
)
) from ex

if address_type == AddressType.BOLT12:
raise NotImplementedException("Bolt12 is not yet supported.")

return ReceivingAddress(address=address, type=AddressType(address_type))


class LockedCurrencySide(Enum):
Expand Down Expand Up @@ -105,17 +110,17 @@ async def fetch_quote(
receiving_currency_code: str,
locked_currency_amount: int,
locked_currency_side: LockedCurrencySide,
receiving_address: ReceivingAddress,
receiver_address: ReceivingAddress,
) -> Quote:
params = {
"sending_currency_code": sending_currency_code,
"receiving_currency_code": receiving_currency_code,
"locked_currency_amount": locked_currency_amount,
"locked_currency_side": locked_currency_side.value,
"receiving_address": receiving_address.address,
"receiver_address": receiver_address.address,
}
result = await self._make_http_get(
path=f"/quote/{receiving_address.type.value}",
path=f"/quote/{receiver_address.type.value}",
access_token=access_token,
params=params,
)
Expand Down
Loading