-
Notifications
You must be signed in to change notification settings - Fork 0
[Feature] Main Embedding and ConcretizedCallable logic #1
base: main
Are you sure you want to change the base?
Changes from 2 commits
01635e6
0296b59
e6dd59f
0a7fbea
2586c0e
8457f98
5ac4327
081d737
329aed0
99b7ef6
81ac1c5
98eee1b
deb1c2f
2e00a5e
360d7f7
4a30d06
ff826b9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
# qadence-embeddings | ||
# qadence-embeddings | ||
|
||
**qadence-embeddings** is a engine-agnostic parameter embedding library. | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
site_name: qadence-embeddings | ||
repo_url: "https://github.com/pasqal-io/qadence-embeddings" | ||
repo_name: "qadence-embeddings" | ||
|
||
nav: | ||
|
||
- Embeddings in a nutshell: index.md | ||
|
||
|
||
theme: | ||
name: material | ||
features: | ||
- content.code.annotate | ||
- navigation.indexes | ||
- navigation.sections | ||
|
||
palette: | ||
- media: "(prefers-color-scheme: light)" | ||
scheme: default | ||
primary: light green | ||
accent: purple | ||
toggle: | ||
icon: material/weather-sunny | ||
name: Switch to dark mode | ||
- media: "(prefers-color-scheme: dark)" | ||
scheme: slate | ||
primary: black | ||
accent: light green | ||
toggle: | ||
icon: material/weather-night | ||
name: Switch to light mode | ||
|
||
markdown_extensions: | ||
- admonition # for notes | ||
- pymdownx.arithmatex: # for mathjax | ||
generic: true | ||
- pymdownx.highlight: | ||
anchor_linenums: true | ||
- pymdownx.inlinehilite | ||
- pymdownx.snippets | ||
- pymdownx.superfences | ||
|
||
plugins: | ||
- search | ||
- section-index | ||
- mkdocstrings: | ||
default_handler: python | ||
handlers: | ||
python: | ||
selection: | ||
filters: | ||
- "!^_" # exlude all members starting with _ | ||
- "^__init__$" # but always include __init__ modules and methods | ||
|
||
- mkdocs-jupyter: | ||
theme: light | ||
- markdown-exec | ||
|
||
extra_css: | ||
- extras/css/mkdocstrings.css | ||
- extras/css/colors.css | ||
- extras/css/home.css | ||
|
||
# For mathjax | ||
extra_javascript: | ||
- extras/javascripts/mathjax.js | ||
- https://polyfill.io/v3/polyfill.min.js?features=es6 | ||
- https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js | ||
|
||
watch: | ||
- qadence_embeddings |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
[build-system] | ||
requires = ["hatchling"] | ||
build-backend = "hatchling.build" | ||
|
||
[project] | ||
name = "qadence-embeddings" | ||
description = "a engine-agnostic parameter embedding library." | ||
authors = [ | ||
{ name = "Dominik Seitz", email = "dominik.seitz@pasqal.com" }, | ||
] | ||
requires-python = ">=3.8,<3.13" | ||
license = {text = "Apache 2.0"} | ||
version = "1.3.0" | ||
classifiers=[ | ||
"License :: OSI Approved :: Apache Software License", | ||
"Programming Language :: Python", | ||
"Programming Language :: Python :: 3", | ||
"Programming Language :: Python :: 3.8", | ||
"Programming Language :: Python :: 3.9", | ||
"Programming Language :: Python :: 3.10", | ||
"Programming Language :: Python :: 3.11", | ||
"Programming Language :: Python :: 3.12", | ||
"Programming Language :: Python :: Implementation :: CPython", | ||
"Programming Language :: Python :: Implementation :: PyPy", | ||
] | ||
dependencies = ["torch", "numpy", "jax"] | ||
|
||
[project.optional-dependencies] | ||
dev = ["flaky","black", "pytest", "pytest-xdist", "pytest-cov", "flake8", "mypy", "pre-commit", "ruff", "nbconvert", "matplotlib", "qutip~=4.7.5"] | ||
|
||
[tool.hatch.envs.tests] | ||
features = [ | ||
"dev", | ||
] | ||
|
||
[tool.hatch.envs.tests.scripts] | ||
test = "pytest -n auto {args}" | ||
test-docs = "hatch -e docs run mkdocs build --clean --strict" | ||
test-cov = "pytest -n auto --cov=qadence-embeddings tests/" | ||
|
||
[tool.hatch.envs.docs] | ||
dependencies = [ | ||
"mkdocs", | ||
"mkdocs-material", | ||
"mkdocstrings", | ||
"mkdocstrings-python", | ||
"mkdocs-section-index", | ||
"mkdocs-jupyter", | ||
"mkdocs-exclude", | ||
"markdown-exec", | ||
"mike", | ||
"matplotlib", | ||
] | ||
|
||
[tool.hatch.envs.docs.scripts] | ||
build = "mkdocs build --clean --strict" | ||
serve = "mkdocs serve --dev-addr localhost:8000" | ||
|
||
[tool.ruff] | ||
lint.select = ["E", "F", "I", "Q"] | ||
lint.extend-ignore = ["F841"] | ||
line-length = 100 | ||
|
||
[tool.ruff.lint.isort] | ||
required-imports = ["from __future__ import annotations"] | ||
|
||
[tool.ruff.lint.per-file-ignores] | ||
"__init__.py" = ["F401", "E402"] | ||
|
||
[tool.ruff.lint.mccabe] | ||
max-complexity = 15 | ||
|
||
[tool.ruff.lint.flake8-quotes] | ||
docstring-quotes = "double" | ||
|
||
[lint.black] | ||
line-length = 100 | ||
include = '\.pyi?$' | ||
exclude = ''' | ||
/( | ||
\.git | ||
| \.hg | ||
| \.mypy_cache | ||
| \.tox | ||
| \.venv | ||
| _build | ||
| buck-out | ||
| build | ||
| dist | ||
)/ | ||
''' | ||
|
||
[lint.isort] | ||
line_length = 100 | ||
combine_as_imports = true | ||
balanced_wrapping = true | ||
lines_after_imports = 2 | ||
include_trailing_comma = true | ||
multi_line_output = 5 | ||
|
||
[lint.mypy] | ||
python_version = "3.10" | ||
warn_return_any = true | ||
warn_unused_configs = true | ||
disallow_untyped_defs = true | ||
no_implicit_optional = false | ||
ignore_missing_imports = true |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from __future__ import annotations | ||
|
||
from .callable import ConcretizedCallable | ||
from .embedding import Embedding |
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,109 @@ | ||||||||||||
from __future__ import annotations | ||||||||||||
|
||||||||||||
from importlib import import_module | ||||||||||||
from logging import getLogger | ||||||||||||
from typing import Callable | ||||||||||||
|
||||||||||||
from numpy.typing import ArrayLike | ||||||||||||
|
||||||||||||
logger = getLogger(__name__) | ||||||||||||
|
||||||||||||
ARRAYLIKE_FN_MAP = { | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Or something like this ? |
||||||||||||
"torch": ("torch", "tensor"), | ||||||||||||
"jax": ("jax.numpy", "array"), | ||||||||||||
"numpy": ("numpy", "array"), | ||||||||||||
} | ||||||||||||
|
||||||||||||
|
||||||||||||
DEFAULT_JAX_MAPPING = {"mul": ("jax.numpy", "multiply")} | ||||||||||||
DEFAULT_TORCH_MAPPING = {} | ||||||||||||
DEFAULT_NUMPY_MAPPING = {"mul": ("numpy", "multiply")} | ||||||||||||
|
||||||||||||
DEFAULT_INSTRUCTION_MAPPING = { | ||||||||||||
"torch": DEFAULT_TORCH_MAPPING, | ||||||||||||
"jax": DEFAULT_JAX_MAPPING, | ||||||||||||
"numpy": DEFAULT_NUMPY_MAPPING, | ||||||||||||
} | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would removing default mappings make it difficult to test? |
||||||||||||
|
||||||||||||
|
||||||||||||
class ConcretizedCallable: | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But no strong opinion. |
||||||||||||
|
||||||||||||
def __init__( | ||||||||||||
self, | ||||||||||||
call_name: str, | ||||||||||||
abstract_args: list[str | float | int], | ||||||||||||
instruction_mapping: dict[str, Callable] = dict(), | ||||||||||||
engine_name: str = "torch", | ||||||||||||
) -> None: | ||||||||||||
instruction_mapping = { | ||||||||||||
**instruction_mapping, | ||||||||||||
**DEFAULT_INSTRUCTION_MAPPING[engine_name], | ||||||||||||
} | ||||||||||||
self.call_name = call_name | ||||||||||||
self.abstract_args = abstract_args | ||||||||||||
self.engine_name = engine_name | ||||||||||||
self.engine_call = None | ||||||||||||
engine_call = None | ||||||||||||
engine = None | ||||||||||||
try: | ||||||||||||
engine_name, fn_name = ARRAYLIKE_FN_MAP[engine_name] | ||||||||||||
engine = import_module(engine_name) | ||||||||||||
self.arraylike_fn = getattr(engine, fn_name) | ||||||||||||
except (ModuleNotFoundError, ImportError) as e: | ||||||||||||
logger.error(f"Unable to import {engine_call} due to {e}.") | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. At this point, |
||||||||||||
|
||||||||||||
try: | ||||||||||||
# breakpoint() | ||||||||||||
try: | ||||||||||||
self.engine_call = getattr(engine, call_name) | ||||||||||||
except AttributeError: | ||||||||||||
pass | ||||||||||||
Comment on lines
+69
to
+72
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
if self.engine_call is None: | ||||||||||||
mod, fn = instruction_mapping[call_name] | ||||||||||||
self.engine_call = getattr(import_module(mod), fn) | ||||||||||||
except (ImportError, KeyError) as e: | ||||||||||||
logger.error( | ||||||||||||
f"Requested function {call_name} can not be imported from {engine_name} and is\ | ||||||||||||
not in instruction_mapping {instruction_mapping} due to {e}." | ||||||||||||
Comment on lines
+78
to
+79
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
) | ||||||||||||
|
||||||||||||
"""Convert a generic abstract function call and | ||||||||||||
a list of symbolic or constant parameters | ||||||||||||
into a concretized Callable in a particular engine. | ||||||||||||
which can be evaluated using | ||||||||||||
a inputs dict. | ||||||||||||
|
||||||||||||
Arguments: | ||||||||||||
call_name: The name of the function | ||||||||||||
abstract_args: A list of strings (in the case of parameters) and numeric constants | ||||||||||||
denoting the arguments for `call_name` | ||||||||||||
instruction_mapping: A dict mapping from an abstract call_name to its name in an engine. | ||||||||||||
engine_name: The engine to use to create the callable. | ||||||||||||
|
||||||||||||
Example: | ||||||||||||
``` | ||||||||||||
In [11]: call = ConcretizedCallable('sin', ['x'], engine_name='numpy') | ||||||||||||
In [12]: call({'x': 0.5}) | ||||||||||||
Out[12]: 0.479425538604203 | ||||||||||||
|
||||||||||||
In [13]: call = ConcretizedCallable('sin', ['x'], engine_name='torch') | ||||||||||||
In [14]: call({'x': torch.rand(1)}) | ||||||||||||
Out[14]: tensor([0.5531]) | ||||||||||||
|
||||||||||||
In [15]: call = ConcretizedCallable('sin', ['x'], engine_name='jax') | ||||||||||||
In [16]: call({'x': 0.5}) | ||||||||||||
Out[16]: Array(0.47942555, dtype=float32, weak_type=True) | ||||||||||||
``` | ||||||||||||
""" | ||||||||||||
|
||||||||||||
def evaluate(self, inputs: dict[str, ArrayLike] = dict()) -> ArrayLike: | ||||||||||||
arraylike_args = [] | ||||||||||||
for symbol_or_numeric in self.abstract_args: | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
if isinstance(symbol_or_numeric, (float, int)): | ||||||||||||
arraylike_args.append(self.arraylike_fn(symbol_or_numeric)) | ||||||||||||
elif isinstance(symbol_or_numeric, str): | ||||||||||||
arraylike_args.append(inputs[symbol_or_numeric]) | ||||||||||||
return self.engine_call(*arraylike_args) # type: ignore[misc] | ||||||||||||
|
||||||||||||
def __call__(self, inputs: dict[str, ArrayLike] = dict()) -> ArrayLike: | ||||||||||||
return self.evaluate(inputs) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from __future__ import annotations | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import Any, Callable, Optional | ||
|
||
|
||
class ParameterType: | ||
pass | ||
|
||
|
||
class DType: | ||
pass | ||
|
||
|
||
class Embedding(ABC): | ||
""" | ||
A generic module class to hold and handle the parameters and expressions | ||
functions coming from the `Model`. It may contain the list of user input | ||
parameters, as well as the trainable variational parameters and the | ||
evaluated functions from the data types being used, i.e. torch, numpy, etc. | ||
""" | ||
|
||
vparams: dict[str, ParameterType] | ||
fparams: dict[str, Optional[ParameterType]] | ||
mapped_vars: dict[str, Callable] | ||
_dtype: DType | ||
|
||
@abstractmethod | ||
def __call__(self, *args: Any, **kwargs: Any) -> dict[str, Callable]: | ||
raise NotImplementedError() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we have to evaluate here ? |
||
|
||
@abstractmethod | ||
def name_mapping(self) -> dict: | ||
raise NotImplementedError() | ||
|
||
@property | ||
def dtype(self) -> DType: | ||
return self._dtype |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.