Skip to content
This repository has been archived by the owner on Jul 16, 2024. It is now read-only.

[Feature] Main Embedding and ConcretizedCallable logic #1

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from

Conversation

dominikandreasseitz
Copy link
Collaborator

@dominikandreasseitz dominikandreasseitz commented Jul 2, 2024

Porting ConcretizedCallable and Embedding to its own repo so they can be accessed by APIs and backends

@dominikandreasseitz dominikandreasseitz self-assigned this Jul 2, 2024
@dominikandreasseitz dominikandreasseitz added the enhancement New feature or request label Jul 2, 2024
@dominikandreasseitz dominikandreasseitz marked this pull request as draft July 2, 2024 14:47
Copy link

@RolandMacDoland RolandMacDoland left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @dominikandreasseitz some intermediate review.

# qadence-embeddings

**qadence-embeddings** is a engine-agnostic parameter embedding library.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
**qadence-embeddings** is a engine-agnostic parameter embedding library.
**qadence-embeddings** is an engine-agnostic parameter embedding library.

call_name: str,
abstract_args: list[str | float | int],
instruction_mapping: dict[str, Callable] = dict(),
engine_name: str = "torch",

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No our beloved StrEnums ? ;)


logger = getLogger(__name__)

ARRAYLIKE_FN_MAP = {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ARRAYLIKE_FN_MAP = {
ARRAYLIKE_CONSTRUCTOR_MAP = {

Or something like this ?

engine = import_module(engine_name)
arraylike_fn = getattr(engine, fn_name)
except (ModuleNotFoundError, ImportError) as e:
logger.error(f"Unable to import {engine_call} due to {e}.")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
logger.error(f"Unable to import {engine_call} due to {e}.")
logger.error(f"Unable to import {engine_name} due to {e}.")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it necessary to raise the error again ?


try:
engine_call = getattr(engine, call_name)
except ImportError:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it the right exception to catch ?

Comment on lines 61 to 64
try:
engine_call = getattr(engine, call_name)
except ImportError:
pass

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
try:
engine_call = getattr(engine, call_name)
except ImportError:
pass
engine_call = getattr(engine, call_name, None)

Comment on lines 70 to 71
f"Requested function {call_name} can not be imported from {engine_name} and is\
not in instruction_mapping {instruction_mapping} due to {e}."

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f"Requested function {call_name} can not be imported from {engine_name} and is\
not in instruction_mapping {instruction_mapping} due to {e}."
f"Requested function {call_name} can not be imported from {engine_name} and is"
f"not in instruction_mapping {instruction_mapping} due to {e}."


def ConcretizedCallable(
call_name: str,
abstract_args: list[str | float | int],

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
abstract_args: list[str | float | int],
args: list[str | float | int],


def evaluate(params: dict = dict(), inputs: dict = dict()) -> ArrayLike:
arraylike_args = []
for symbol_or_numeric in abstract_args:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for symbol_or_numeric in abstract_args:
for arg in args:


@abstractmethod
def __call__(self, *args: Any, **kwargs: Any) -> dict[str, Callable]:
raise NotImplementedError()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have to evaluate here ?

Comment on lines 18 to 26
DEFAULT_JAX_MAPPING = {"mul": ("jax.numpy", "multiply")}
DEFAULT_TORCH_MAPPING = {}
DEFAULT_NUMPY_MAPPING = {"mul": ("numpy", "multiply")}

DEFAULT_INSTRUCTION_MAPPING = {
"torch": DEFAULT_TORCH_MAPPING,
"jax": DEFAULT_JAX_MAPPING,
"numpy": DEFAULT_NUMPY_MAPPING,
}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would removing default mappings make it difficult to test?

Copy link

@RolandMacDoland RolandMacDoland left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Many thanks @dominikandreasseitz. Comments from my side.

engine = import_module(engine_name)
self.arraylike_fn = getattr(engine, fn_name)
except (ModuleNotFoundError, ImportError) as e:
logger.error(f"Unable to import {engine_call} due to {e}.")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At this point, engine_call is None. I think you meant engine_name.

Comment on lines +69 to +72
try:
self.engine_call = getattr(engine, call_name)
except AttributeError:
pass

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
try:
self.engine_call = getattr(engine, call_name)
except AttributeError:
pass
self.engine_call = getattr(engine, call_name, None)

Comment on lines +78 to +79
f"Requested function {call_name} can not be imported from {engine_name} and is\
not in instruction_mapping {instruction_mapping} due to {e}."

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f"Requested function {call_name} can not be imported from {engine_name} and is\
not in instruction_mapping {instruction_mapping} due to {e}."
f"Requested function {call_name} can not be imported from {engine_name} and is"
f" not in instruction_mapping {instruction_mapping} due to {e}."

def init_param(
engine_name: str, trainable: bool = True, device: str = "cpu"
) -> ArrayLike:
engine = import_module(engine_name)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would be a bit more defensive and try/catch this one.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe: since this engine import logic is used in both ConcretizedCallable here, I'd rather put it in some kind of utils.

self.fparam_names: list[str] = fparam_names
self.tparam_name = tparam_name
self.var_to_call: dict[str, ConcretizedCallable] = var_to_call
self._dtype: DTypeLike = None

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does None make sense wrt tDTypeLike ?

and recalculate the those which are dependent on the time parameter using the new value
`tparam_value`.
"""
assert self.tparam_name is not None

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The semantics here are that reembed_time is always called after the initial embedding and therefore self.tparam_name is set. I would raise an exception here with a proper message.

assert self.tparam_name is not None
embedded_params[self.tparam_name] = tparam_value
for time_dependent_param in self.time_dependent_vars:
embedded_params[time_dependent_param] = self.var_to_call[

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I get the idea that this is an embedded_params update but it is a bit confusing.

Comment on lines +102 to +103
"""Functional version of legacy embedding: Return a new dictionary\
with all embedded parameters."""

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""Functional version of legacy embedding: Return a new dictionary\
with all embedded parameters."""
"""Functional version of legacy embedding.
Return a new dictionary with all embedded parameters.
"""



@pytest.mark.parametrize(
"fn", ["sin", "cos", "log", "tanh", "tan", "acos", "sin", "sqrt", "square"]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe also add pow ?

v_params = ["theta"]
f_params = ["x"]
tparam = "t"
leaf0, native_call0 = "%0", ConcretizedCallable(

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to randomly draw from a poll of functions and variable names to construct these instructions ? I am a bit worried that we only cover a very reduced subset of the possibilities.

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants