This repository has been archived by the owner on Jul 16, 2024. It is now read-only.

[Feature] Main Embedding and ConcretizedCallable logic #1

Status: Draft
Wants to merge 17 commits into base: main
Changes from 1 commit
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -23,7 +23,7 @@ classifiers=[
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = ["torch", "numpy"]
dependencies = ["torch", "numpy", "jax"]

[project.optional-dependencies]
dev = ["flaky","black", "pytest", "pytest-xdist", "pytest-cov", "flake8", "mypy", "pre-commit", "ruff", "nbconvert", "matplotlib", "qutip~=4.7.5"]
94 changes: 60 additions & 34 deletions qadence_embeddings/callable.py
@@ -15,17 +15,63 @@
}


def ConcretizedCallable(
    call_name: str,
    abstract_args: list[str | float | int],
    instruction_mapping: dict[str, Callable] = dict(),
    engine_name: str = "torch",
) -> Callable[[dict, dict], ArrayLike]:
DEFAULT_JAX_MAPPING = {"mul": ("jax.numpy", "multiply")}
DEFAULT_TORCH_MAPPING = {}
DEFAULT_NUMPY_MAPPING = {"mul": ("numpy", "multiply")}

DEFAULT_INSTRUCTION_MAPPING = {
"torch": DEFAULT_TORCH_MAPPING,
"jax": DEFAULT_JAX_MAPPING,
"numpy": DEFAULT_NUMPY_MAPPING,
}

Reviewer comment: Would removing default mappings make it difficult to test?
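For orientation, a minimal sketch (not part of the diff) of how an entry in these mapping dicts gets resolved, mirroring the `mod, fn = instruction_mapping[call_name]` lookup in `__init__` further down:

```python
from importlib import import_module

# Same shape as DEFAULT_JAX_MAPPING above: call name -> (module path, function name).
jax_mapping = {"mul": ("jax.numpy", "multiply")}

mod_name, fn_name = jax_mapping["mul"]
engine_call = getattr(import_module(mod_name), fn_name)  # resolves to jax.numpy.multiply
```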



class ConcretizedCallable:

Reviewer comment (suggested change): rename `class ConcretizedCallable:` to `class CallableMap:`. But no strong opinion.


    def __init__(
        self,
        call_name: str,
        abstract_args: list[str | float | int],
        instruction_mapping: dict[str, Callable] = dict(),
        engine_name: str = "torch",
    ) -> None:
        instruction_mapping = {
            **instruction_mapping,
            **DEFAULT_INSTRUCTION_MAPPING[engine_name],
        }
        self.call_name = call_name
        self.abstract_args = abstract_args
        self.engine_name = engine_name
        self.engine_call = None
        engine_call = None
        engine = None
        try:
            engine_name, fn_name = ARRAYLIKE_FN_MAP[engine_name]
            engine = import_module(engine_name)
            self.arraylike_fn = getattr(engine, fn_name)
        except (ModuleNotFoundError, ImportError) as e:
            logger.error(f"Unable to import {engine_call} due to {e}.")

Reviewer comment: At this point, engine_call is None. I think you meant engine_name.
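A standalone sketch of what this comment points at; the forced exception below is only for illustration, and the names mirror the diff:

```python
import logging

logger = logging.getLogger(__name__)

engine_call = None  # still None at this point in __init__
engine_name = "jax"
try:
    raise ModuleNotFoundError("No module named 'jax'")  # stand-in for a failed import
except (ModuleNotFoundError, ImportError) as e:
    logger.error(f"Unable to import {engine_call} due to {e}.")  # logs "Unable to import None ..."
    logger.error(f"Unable to import {engine_name} due to {e}.")  # what was presumably intended
```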


        try:
            # breakpoint()
            try:
                self.engine_call = getattr(engine, call_name)
            except AttributeError:
                pass
Reviewer comment on lines +69 to +72 (suggested change): replace the inner try/except with a single lookup that falls back to None:
            self.engine_call = getattr(engine, call_name, None)
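The suggestion relies on the three-argument form of `getattr`, which returns the default instead of raising; a quick standalone illustration (the `math` module here is arbitrary):

```python
import math

# With a default, getattr never raises AttributeError, so the inner
# try/except around the lookup becomes unnecessary.
assert getattr(math, "sin") is math.sin
assert getattr(math, "not_a_function", None) is None
```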

            if self.engine_call is None:
                mod, fn = instruction_mapping[call_name]
                self.engine_call = getattr(import_module(mod), fn)
        except (ImportError, KeyError) as e:
            logger.error(
                f"Requested function {call_name} can not be imported from {engine_name} and is\
                not in instruction_mapping {instruction_mapping} due to {e}."
Reviewer comment on lines +78 to +79 (suggested change): use two adjacent f-string literals instead of a backslash continuation:
                f"Requested function {call_name} can not be imported from {engine_name} and is"
                f" not in instruction_mapping {instruction_mapping} due to {e}."

            )

"""Convert a generic abstract function call and
a list of symbolic or constant parameters
into a concretized Callable in a particular engine.
which can be evaluated using
a vparams and inputs dict.
a inputs dict.

Arguments:
call_name: The name of the function
@@ -49,35 +95,15 @@ def ConcretizedCallable(
    Out[16]: Array(0.47942555, dtype=float32, weak_type=True)
    ```
    """
    engine_call = None
    engine = None
    try:
        engine_name, fn_name = ARRAYLIKE_FN_MAP[engine_name]
        engine = import_module(engine_name)
        arraylike_fn = getattr(engine, fn_name)
    except (ModuleNotFoundError, ImportError) as e:
        logger.error(f"Unable to import {engine_call} due to {e}.")

    try:
        engine_call = getattr(engine, call_name)
    except ImportError:
        pass
    if engine_call is None:
        try:
            engine_call = instruction_mapping[call_name]
        except KeyError as e:
            logger.error(
                f"Requested function {call_name} can not be imported from {engine_name} and is\
                not in instruction_mapping {instruction_mapping} due to {e}."
            )

    def evaluate(params: dict = dict(), inputs: dict = dict()) -> ArrayLike:
    def evaluate(self, inputs: dict[str, ArrayLike] = dict()) -> ArrayLike:
        arraylike_args = []
        for symbol_or_numeric in abstract_args:
        for symbol_or_numeric in self.abstract_args:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for symbol_or_numeric in self.abstract_args:
for symbol_or_number in self.abstract_args:

            if isinstance(symbol_or_numeric, (float, int)):
                arraylike_args.append(arraylike_fn(symbol_or_numeric))
                arraylike_args.append(self.arraylike_fn(symbol_or_numeric))
            elif isinstance(symbol_or_numeric, str):
                arraylike_args.append({**params, **inputs}[symbol_or_numeric])
        return engine_call(*arraylike_args)  # type: ignore[misc]
                arraylike_args.append(inputs[symbol_or_numeric])
        return self.engine_call(*arraylike_args)  # type: ignore[misc]

    return evaluate
    def __call__(self, inputs: dict[str, ArrayLike] = dict()) -> ArrayLike:
        return self.evaluate(inputs)
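For context, a minimal usage sketch of the class-based API introduced in this diff, patterned on the tests added below (output values in the comments are approximate):

```python
import torch

from qadence_embeddings.callable import ConcretizedCallable

# "sin" resolves directly against the chosen engine; "x" stays symbolic until call time.
native_sin = ConcretizedCallable("sin", ["x"], {}, "torch")
print(native_sin({"x": torch.tensor(0.5)}))  # tensor(0.4794)

# numpy has no top-level "mul", so the lookup falls back to the instruction
# mapping defined above ("mul" -> numpy.multiply).
native_mul = ConcretizedCallable("mul", ["x", "y"], {}, "numpy")
print(native_mul({"x": 2.0, "y": 3.0}))  # 6.0
```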
63 changes: 63 additions & 0 deletions tests/test_callable.py
@@ -0,0 +1,63 @@
from __future__ import annotations

import numpy as np
import torch

from qadence_embeddings.callable import ConcretizedCallable


def test_sin() -> None:
    results = []
    x = np.random.randn(1)
    for engine_name in ["jax", "torch", "numpy"]:
        native_call = ConcretizedCallable("sin", ["x"], {}, engine_name)
        native_result = native_call(
            {"x": (torch.tensor(x) if engine_name == "torch" else x)}
        )
        results.append(native_result.item())
    assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2])


def test_log() -> None:
    results = []
    x = np.random.uniform(0, 5)
    for engine_name in ["jax", "torch", "numpy"]:
        native_call = ConcretizedCallable("log", ["x"], {}, engine_name)
        native_result = native_call(
            {"x": (torch.tensor(x) if engine_name == "torch" else x)}
        )
        results.append(native_result.item())

    assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2])


def test_mul() -> None:
    results = []
    x = np.random.randn(1)
    y = np.random.randn(1)
    for engine_name in ["jax", "torch", "numpy"]:
        native_call = ConcretizedCallable("mul", ["x", "y"], {}, engine_name)
        native_result = native_call(
            {
                "x": torch.tensor(x) if engine_name == "torch" else x,
                "y": torch.tensor(y) if engine_name == "torch" else y,
            }
        )
        results.append(native_result.item())
    assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2])


def test_add() -> None:
    results = []
    x = np.random.randn(1)
    y = np.random.randn(1)
    for engine_name in ["jax", "torch", "numpy"]:
        native_call = ConcretizedCallable("add", ["x", "y"], {}, engine_name)
        native_result = native_call(
            {
                "x": torch.tensor(x) if engine_name == "torch" else x,
                "y": torch.tensor(y) if engine_name == "torch" else y,
            }
        )
        results.append(native_result.item())
    assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2])
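These tests can be run with pytest, which is listed in the dev extras of pyproject.toml; one way to invoke them programmatically from the repository root (a sketch, path assumed):

```python
import pytest

# Equivalent to running `pytest -q tests/test_callable.py` from the repo root.
raise SystemExit(pytest.main(["-q", "tests/test_callable.py"]))
```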