-
Notifications
You must be signed in to change notification settings - Fork 0
[Feature] Main Embedding and ConcretizedCallable logic #1
base: main
Are you sure you want to change the base?
Changes from 1 commit
01635e6
0296b59
e6dd59f
0a7fbea
2586c0e
8457f98
5ac4327
081d737
329aed0
99b7ef6
81ac1c5
98eee1b
deb1c2f
2e00a5e
360d7f7
4a30d06
ff826b9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -15,17 +15,63 @@ | |||||||||||
} | ||||||||||||
|
||||||||||||
|
||||||||||||
def ConcretizedCallable( | ||||||||||||
call_name: str, | ||||||||||||
abstract_args: list[str | float | int], | ||||||||||||
instruction_mapping: dict[str, Callable] = dict(), | ||||||||||||
engine_name: str = "torch", | ||||||||||||
) -> Callable[[dict, dict], ArrayLike]: | ||||||||||||
DEFAULT_JAX_MAPPING = {"mul": ("jax.numpy", "multiply")} | ||||||||||||
DEFAULT_TORCH_MAPPING = {} | ||||||||||||
DEFAULT_NUMPY_MAPPING = {"mul": ("numpy", "multiply")} | ||||||||||||
|
||||||||||||
DEFAULT_INSTRUCTION_MAPPING = { | ||||||||||||
"torch": DEFAULT_TORCH_MAPPING, | ||||||||||||
"jax": DEFAULT_JAX_MAPPING, | ||||||||||||
"numpy": DEFAULT_NUMPY_MAPPING, | ||||||||||||
} | ||||||||||||
|
||||||||||||
|
||||||||||||
class ConcretizedCallable: | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But no strong opinion. |
||||||||||||
|
||||||||||||
def __init__( | ||||||||||||
self, | ||||||||||||
call_name: str, | ||||||||||||
abstract_args: list[str | float | int], | ||||||||||||
instruction_mapping: dict[str, Callable] = dict(), | ||||||||||||
engine_name: str = "torch", | ||||||||||||
) -> None: | ||||||||||||
instruction_mapping = { | ||||||||||||
**instruction_mapping, | ||||||||||||
**DEFAULT_INSTRUCTION_MAPPING[engine_name], | ||||||||||||
} | ||||||||||||
self.call_name = call_name | ||||||||||||
self.abstract_args = abstract_args | ||||||||||||
self.engine_name = engine_name | ||||||||||||
self.engine_call = None | ||||||||||||
engine_call = None | ||||||||||||
engine = None | ||||||||||||
try: | ||||||||||||
engine_name, fn_name = ARRAYLIKE_FN_MAP[engine_name] | ||||||||||||
engine = import_module(engine_name) | ||||||||||||
self.arraylike_fn = getattr(engine, fn_name) | ||||||||||||
except (ModuleNotFoundError, ImportError) as e: | ||||||||||||
logger.error(f"Unable to import {engine_call} due to {e}.") | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. At this point, |
||||||||||||
|
||||||||||||
try: | ||||||||||||
# breakpoint() | ||||||||||||
try: | ||||||||||||
self.engine_call = getattr(engine, call_name) | ||||||||||||
except AttributeError: | ||||||||||||
pass | ||||||||||||
Comment on lines
+69
to
+72
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
if self.engine_call is None: | ||||||||||||
mod, fn = instruction_mapping[call_name] | ||||||||||||
self.engine_call = getattr(import_module(mod), fn) | ||||||||||||
except (ImportError, KeyError) as e: | ||||||||||||
logger.error( | ||||||||||||
f"Requested function {call_name} can not be imported from {engine_name} and is\ | ||||||||||||
not in instruction_mapping {instruction_mapping} due to {e}." | ||||||||||||
Comment on lines
+78
to
+79
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
) | ||||||||||||
|
||||||||||||
"""Convert a generic abstract function call and | ||||||||||||
a list of symbolic or constant parameters | ||||||||||||
into a concretized Callable in a particular engine. | ||||||||||||
which can be evaluated using | ||||||||||||
a vparams and inputs dict. | ||||||||||||
a inputs dict. | ||||||||||||
|
||||||||||||
Arguments: | ||||||||||||
call_name: The name of the function | ||||||||||||
|
@@ -49,35 +95,15 @@ def ConcretizedCallable( | |||||||||||
Out[16]: Array(0.47942555, dtype=float32, weak_type=True) | ||||||||||||
``` | ||||||||||||
""" | ||||||||||||
engine_call = None | ||||||||||||
engine = None | ||||||||||||
try: | ||||||||||||
engine_name, fn_name = ARRAYLIKE_FN_MAP[engine_name] | ||||||||||||
engine = import_module(engine_name) | ||||||||||||
arraylike_fn = getattr(engine, fn_name) | ||||||||||||
except (ModuleNotFoundError, ImportError) as e: | ||||||||||||
logger.error(f"Unable to import {engine_call} due to {e}.") | ||||||||||||
|
||||||||||||
try: | ||||||||||||
engine_call = getattr(engine, call_name) | ||||||||||||
except ImportError: | ||||||||||||
pass | ||||||||||||
if engine_call is None: | ||||||||||||
try: | ||||||||||||
engine_call = instruction_mapping[call_name] | ||||||||||||
except KeyError as e: | ||||||||||||
logger.error( | ||||||||||||
f"Requested function {call_name} can not be imported from {engine_name} and is\ | ||||||||||||
not in instruction_mapping {instruction_mapping} due to {e}." | ||||||||||||
) | ||||||||||||
|
||||||||||||
def evaluate(params: dict = dict(), inputs: dict = dict()) -> ArrayLike: | ||||||||||||
def evaluate(self, inputs: dict[str, ArrayLike] = dict()) -> ArrayLike: | ||||||||||||
arraylike_args = [] | ||||||||||||
for symbol_or_numeric in abstract_args: | ||||||||||||
for symbol_or_numeric in self.abstract_args: | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
if isinstance(symbol_or_numeric, (float, int)): | ||||||||||||
arraylike_args.append(arraylike_fn(symbol_or_numeric)) | ||||||||||||
arraylike_args.append(self.arraylike_fn(symbol_or_numeric)) | ||||||||||||
elif isinstance(symbol_or_numeric, str): | ||||||||||||
arraylike_args.append({**params, **inputs}[symbol_or_numeric]) | ||||||||||||
return engine_call(*arraylike_args) # type: ignore[misc] | ||||||||||||
arraylike_args.append(inputs[symbol_or_numeric]) | ||||||||||||
return self.engine_call(*arraylike_args) # type: ignore[misc] | ||||||||||||
|
||||||||||||
return evaluate | ||||||||||||
def __call__(self, inputs: dict[str, ArrayLike] = dict()) -> ArrayLike: | ||||||||||||
return self.evaluate(inputs) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from __future__ import annotations | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from qadence_embeddings.callable import ConcretizedCallable | ||
|
||
|
||
def test_sin() -> None: | ||
results = [] | ||
x = np.random.randn(1) | ||
for engine_name in ["jax", "torch", "numpy"]: | ||
native_call = ConcretizedCallable("sin", ["x"], {}, engine_name) | ||
native_result = native_call( | ||
{"x": (torch.tensor(x) if engine_name == "torch" else x)} | ||
) | ||
results.append(native_result.item()) | ||
assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) | ||
|
||
|
||
def test_log() -> None: | ||
results = [] | ||
x = np.random.uniform(0, 5) | ||
for engine_name in ["jax", "torch", "numpy"]: | ||
native_call = ConcretizedCallable("log", ["x"], {}, engine_name) | ||
native_result = native_call( | ||
{"x": (torch.tensor(x) if engine_name == "torch" else x)} | ||
) | ||
results.append(native_result.item()) | ||
|
||
assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) | ||
|
||
|
||
def test_mul() -> None: | ||
results = [] | ||
x = np.random.randn(1) | ||
y = np.random.randn(1) | ||
for engine_name in ["jax", "torch", "numpy"]: | ||
native_call = ConcretizedCallable("mul", ["x", "y"], {}, engine_name) | ||
native_result = native_call( | ||
{ | ||
"x": torch.tensor(x) if engine_name == "torch" else x, | ||
"y": torch.tensor(y) if engine_name == "torch" else y, | ||
} | ||
) | ||
results.append(native_result.item()) | ||
assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) | ||
|
||
|
||
def test_add() -> None: | ||
results = [] | ||
x = np.random.randn(1) | ||
y = np.random.randn(1) | ||
for engine_name in ["jax", "torch", "numpy"]: | ||
native_call = ConcretizedCallable("add", ["x", "y"], {}, engine_name) | ||
native_result = native_call( | ||
{ | ||
"x": torch.tensor(x) if engine_name == "torch" else x, | ||
"y": torch.tensor(y) if engine_name == "torch" else y, | ||
} | ||
) | ||
results.append(native_result.item()) | ||
assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would removing default mappings make it difficult to test?