-
Notifications
You must be signed in to change notification settings - Fork 0
[Feature] Main Embedding and ConcretizedCallable logic #1
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @dominikandreasseitz some intermediate review.
# qadence-embeddings | ||
|
||
**qadence-embeddings** is a engine-agnostic parameter embedding library. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
**qadence-embeddings** is a engine-agnostic parameter embedding library. | |
**qadence-embeddings** is an engine-agnostic parameter embedding library. |
qadence_embeddings/callable.py
Outdated
call_name: str, | ||
abstract_args: list[str | float | int], | ||
instruction_mapping: dict[str, Callable] = dict(), | ||
engine_name: str = "torch", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No our beloved StrEnums ? ;)
|
||
logger = getLogger(__name__) | ||
|
||
ARRAYLIKE_FN_MAP = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ARRAYLIKE_FN_MAP = { | |
ARRAYLIKE_CONSTRUCTOR_MAP = { |
Or something like this ?
qadence_embeddings/callable.py
Outdated
engine = import_module(engine_name) | ||
arraylike_fn = getattr(engine, fn_name) | ||
except (ModuleNotFoundError, ImportError) as e: | ||
logger.error(f"Unable to import {engine_call} due to {e}.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logger.error(f"Unable to import {engine_call} due to {e}.") | |
logger.error(f"Unable to import {engine_name} due to {e}.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it necessary to raise the error again ?
qadence_embeddings/callable.py
Outdated
|
||
try: | ||
engine_call = getattr(engine, call_name) | ||
except ImportError: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it the right exception to catch ?
qadence_embeddings/callable.py
Outdated
try: | ||
engine_call = getattr(engine, call_name) | ||
except ImportError: | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
try: | |
engine_call = getattr(engine, call_name) | |
except ImportError: | |
pass | |
engine_call = getattr(engine, call_name, None) |
qadence_embeddings/callable.py
Outdated
f"Requested function {call_name} can not be imported from {engine_name} and is\ | ||
not in instruction_mapping {instruction_mapping} due to {e}." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
f"Requested function {call_name} can not be imported from {engine_name} and is\ | |
not in instruction_mapping {instruction_mapping} due to {e}." | |
f"Requested function {call_name} can not be imported from {engine_name} and is" | |
f"not in instruction_mapping {instruction_mapping} due to {e}." |
qadence_embeddings/callable.py
Outdated
|
||
def ConcretizedCallable( | ||
call_name: str, | ||
abstract_args: list[str | float | int], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
abstract_args: list[str | float | int], | |
args: list[str | float | int], |
qadence_embeddings/callable.py
Outdated
|
||
def evaluate(params: dict = dict(), inputs: dict = dict()) -> ArrayLike: | ||
arraylike_args = [] | ||
for symbol_or_numeric in abstract_args: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for symbol_or_numeric in abstract_args: | |
for arg in args: |
qadence_embeddings/embedding.py
Outdated
|
||
@abstractmethod | ||
def __call__(self, *args: Any, **kwargs: Any) -> dict[str, Callable]: | ||
raise NotImplementedError() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we have to evaluate here ?
qadence_embeddings/callable.py
Outdated
DEFAULT_JAX_MAPPING = {"mul": ("jax.numpy", "multiply")} | ||
DEFAULT_TORCH_MAPPING = {} | ||
DEFAULT_NUMPY_MAPPING = {"mul": ("numpy", "multiply")} | ||
|
||
DEFAULT_INSTRUCTION_MAPPING = { | ||
"torch": DEFAULT_TORCH_MAPPING, | ||
"jax": DEFAULT_JAX_MAPPING, | ||
"numpy": DEFAULT_NUMPY_MAPPING, | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would removing default mappings make it difficult to test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Many thanks @dominikandreasseitz. Comments from my side.
engine = import_module(engine_name) | ||
self.arraylike_fn = getattr(engine, fn_name) | ||
except (ModuleNotFoundError, ImportError) as e: | ||
logger.error(f"Unable to import {engine_call} due to {e}.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At this point, engine_call
is None
. I think you meant engine_name
.
try: | ||
self.engine_call = getattr(engine, call_name) | ||
except AttributeError: | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
try: | |
self.engine_call = getattr(engine, call_name) | |
except AttributeError: | |
pass | |
self.engine_call = getattr(engine, call_name, None) |
f"Requested function {call_name} can not be imported from {engine_name} and is\ | ||
not in instruction_mapping {instruction_mapping} due to {e}." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
f"Requested function {call_name} can not be imported from {engine_name} and is\ | |
not in instruction_mapping {instruction_mapping} due to {e}." | |
f"Requested function {call_name} can not be imported from {engine_name} and is" | |
f" not in instruction_mapping {instruction_mapping} due to {e}." |
def init_param( | ||
engine_name: str, trainable: bool = True, device: str = "cpu" | ||
) -> ArrayLike: | ||
engine = import_module(engine_name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would be a bit more defensive and try/catch this one.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe: since this engine import logic is used in both ConcretizedCallable
here, I'd rather put it in some kind of utils
.
self.fparam_names: list[str] = fparam_names | ||
self.tparam_name = tparam_name | ||
self.var_to_call: dict[str, ConcretizedCallable] = var_to_call | ||
self._dtype: DTypeLike = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does None
make sense wrt tDTypeLike
?
and recalculate the those which are dependent on the time parameter using the new value | ||
`tparam_value`. | ||
""" | ||
assert self.tparam_name is not None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The semantics here are that reembed_time
is always called after the initial embedding and therefore self.tparam_name
is set. I would raise an exception here with a proper message.
assert self.tparam_name is not None | ||
embedded_params[self.tparam_name] = tparam_value | ||
for time_dependent_param in self.time_dependent_vars: | ||
embedded_params[time_dependent_param] = self.var_to_call[ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I get the idea that this is an embedded_params
update but it is a bit confusing.
"""Functional version of legacy embedding: Return a new dictionary\ | ||
with all embedded parameters.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"""Functional version of legacy embedding: Return a new dictionary\ | |
with all embedded parameters.""" | |
"""Functional version of legacy embedding. | |
Return a new dictionary with all embedded parameters. | |
""" |
|
||
|
||
@pytest.mark.parametrize( | ||
"fn", ["sin", "cos", "log", "tanh", "tan", "acos", "sin", "sqrt", "square"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe also add pow
?
v_params = ["theta"] | ||
f_params = ["x"] | ||
tparam = "t" | ||
leaf0, native_call0 = "%0", ConcretizedCallable( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it make sense to randomly draw from a poll of functions and variable names to construct these instructions ? I am a bit worried that we only cover a very reduced subset of the possibilities.
Porting
ConcretizedCallable
andEmbedding
to its own repo so they can be accessed by APIs and backends