Skip to content
This repository has been archived by the owner on Jul 16, 2024. It is now read-only.

[Feature] Main Embedding and ConcretizedCallable logic #1

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
# qadence-embeddings
# qadence-embeddings

**qadence-embeddings** is a engine-agnostic parameter embedding library.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
**qadence-embeddings** is a engine-agnostic parameter embedding library.
**qadence-embeddings** is an engine-agnostic parameter embedding library.

Empty file added docs/index.md
Empty file.
71 changes: 71 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
site_name: qadence-embeddings
repo_url: "https://github.com/pasqal-io/qadence-embeddings"
repo_name: "qadence-embeddings"

nav:

- Embeddings in a nutshell: index.md


theme:
name: material
features:
- content.code.annotate
- navigation.indexes
- navigation.sections

palette:
- media: "(prefers-color-scheme: light)"
scheme: default
primary: light green
accent: purple
toggle:
icon: material/weather-sunny
name: Switch to dark mode
- media: "(prefers-color-scheme: dark)"
scheme: slate
primary: black
accent: light green
toggle:
icon: material/weather-night
name: Switch to light mode

markdown_extensions:
- admonition # for notes
- pymdownx.arithmatex: # for mathjax
generic: true
- pymdownx.highlight:
anchor_linenums: true
- pymdownx.inlinehilite
- pymdownx.snippets
- pymdownx.superfences

plugins:
- search
- section-index
- mkdocstrings:
default_handler: python
handlers:
python:
selection:
filters:
- "!^_" # exlude all members starting with _
- "^__init__$" # but always include __init__ modules and methods

- mkdocs-jupyter:
theme: light
- markdown-exec

extra_css:
- extras/css/mkdocstrings.css
- extras/css/colors.css
- extras/css/home.css

# For mathjax
extra_javascript:
- extras/javascripts/mathjax.js
- https://polyfill.io/v3/polyfill.min.js?features=es6
- https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js

watch:
- qadence_embeddings
107 changes: 107 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "qadence-embeddings"
description = "a engine-agnostic parameter embedding library."
authors = [
{ name = "Dominik Seitz", email = "dominik.seitz@pasqal.com" },
]
requires-python = ">=3.8,<3.13"
license = {text = "Apache 2.0"}
version = "1.3.0"
classifiers=[
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = ["torch", "numpy", "jax"]

[project.optional-dependencies]
dev = ["flaky","black", "pytest", "pytest-xdist", "pytest-cov", "flake8", "mypy", "pre-commit", "ruff", "nbconvert", "matplotlib", "qutip~=4.7.5"]

[tool.hatch.envs.tests]
features = [
"dev",
]

[tool.hatch.envs.tests.scripts]
test = "pytest -n auto {args}"
test-docs = "hatch -e docs run mkdocs build --clean --strict"
test-cov = "pytest -n auto --cov=qadence-embeddings tests/"

[tool.hatch.envs.docs]
dependencies = [
"mkdocs",
"mkdocs-material",
"mkdocstrings",
"mkdocstrings-python",
"mkdocs-section-index",
"mkdocs-jupyter",
"mkdocs-exclude",
"markdown-exec",
"mike",
"matplotlib",
]

[tool.hatch.envs.docs.scripts]
build = "mkdocs build --clean --strict"
serve = "mkdocs serve --dev-addr localhost:8000"

[tool.ruff]
lint.select = ["E", "F", "I", "Q"]
lint.extend-ignore = ["F841"]
line-length = 100

[tool.ruff.lint.isort]
required-imports = ["from __future__ import annotations"]

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401", "E402"]

[tool.ruff.lint.mccabe]
max-complexity = 15

[tool.ruff.lint.flake8-quotes]
docstring-quotes = "double"

[lint.black]
line-length = 100
include = '\.pyi?$'
exclude = '''
/(
\.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
)/
'''

[lint.isort]
line_length = 100
combine_as_imports = true
balanced_wrapping = true
lines_after_imports = 2
include_trailing_comma = true
multi_line_output = 5

[lint.mypy]
python_version = "3.10"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = true
no_implicit_optional = false
ignore_missing_imports = true
4 changes: 4 additions & 0 deletions qadence_embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from __future__ import annotations

from .callable import ConcretizedCallable
from .embedding import Embedding
109 changes: 109 additions & 0 deletions qadence_embeddings/callable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from __future__ import annotations

from importlib import import_module
from logging import getLogger
from typing import Callable

from numpy.typing import ArrayLike

logger = getLogger(__name__)

ARRAYLIKE_FN_MAP = {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ARRAYLIKE_FN_MAP = {
ARRAYLIKE_CONSTRUCTOR_MAP = {

Or something like this ?

"torch": ("torch", "tensor"),
"jax": ("jax.numpy", "array"),
"numpy": ("numpy", "array"),
}


DEFAULT_JAX_MAPPING = {"mul": ("jax.numpy", "multiply")}
DEFAULT_TORCH_MAPPING = {}
DEFAULT_NUMPY_MAPPING = {"mul": ("numpy", "multiply")}

DEFAULT_INSTRUCTION_MAPPING = {
"torch": DEFAULT_TORCH_MAPPING,
"jax": DEFAULT_JAX_MAPPING,
"numpy": DEFAULT_NUMPY_MAPPING,
}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would removing default mappings make it difficult to test?



class ConcretizedCallable:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class ConcretizedCallable:
class CallableMap:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But no strong opinion.


def __init__(
self,
call_name: str,
abstract_args: list[str | float | int],
instruction_mapping: dict[str, Callable] = dict(),
engine_name: str = "torch",
) -> None:
instruction_mapping = {
**instruction_mapping,
**DEFAULT_INSTRUCTION_MAPPING[engine_name],
}
self.call_name = call_name
self.abstract_args = abstract_args
self.engine_name = engine_name
self.engine_call = None
engine_call = None
engine = None
try:
engine_name, fn_name = ARRAYLIKE_FN_MAP[engine_name]
engine = import_module(engine_name)
self.arraylike_fn = getattr(engine, fn_name)
except (ModuleNotFoundError, ImportError) as e:
logger.error(f"Unable to import {engine_call} due to {e}.")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At this point, engine_call is None. I think you meant engine_name.


try:
# breakpoint()
try:
self.engine_call = getattr(engine, call_name)
except AttributeError:
pass
Comment on lines +69 to +72

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
try:
self.engine_call = getattr(engine, call_name)
except AttributeError:
pass
self.engine_call = getattr(engine, call_name, None)

if self.engine_call is None:
mod, fn = instruction_mapping[call_name]
self.engine_call = getattr(import_module(mod), fn)
except (ImportError, KeyError) as e:
logger.error(
f"Requested function {call_name} can not be imported from {engine_name} and is\
not in instruction_mapping {instruction_mapping} due to {e}."
Comment on lines +78 to +79

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f"Requested function {call_name} can not be imported from {engine_name} and is\
not in instruction_mapping {instruction_mapping} due to {e}."
f"Requested function {call_name} can not be imported from {engine_name} and is"
f" not in instruction_mapping {instruction_mapping} due to {e}."

)

"""Convert a generic abstract function call and
a list of symbolic or constant parameters
into a concretized Callable in a particular engine.
which can be evaluated using
a inputs dict.

Arguments:
call_name: The name of the function
abstract_args: A list of strings (in the case of parameters) and numeric constants
denoting the arguments for `call_name`
instruction_mapping: A dict mapping from an abstract call_name to its name in an engine.
engine_name: The engine to use to create the callable.

Example:
```
In [11]: call = ConcretizedCallable('sin', ['x'], engine_name='numpy')
In [12]: call({'x': 0.5})
Out[12]: 0.479425538604203

In [13]: call = ConcretizedCallable('sin', ['x'], engine_name='torch')
In [14]: call({'x': torch.rand(1)})
Out[14]: tensor([0.5531])

In [15]: call = ConcretizedCallable('sin', ['x'], engine_name='jax')
In [16]: call({'x': 0.5})
Out[16]: Array(0.47942555, dtype=float32, weak_type=True)
```
"""

def evaluate(self, inputs: dict[str, ArrayLike] = dict()) -> ArrayLike:
arraylike_args = []
for symbol_or_numeric in self.abstract_args:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for symbol_or_numeric in self.abstract_args:
for symbol_or_number in self.abstract_args:

if isinstance(symbol_or_numeric, (float, int)):
arraylike_args.append(self.arraylike_fn(symbol_or_numeric))
elif isinstance(symbol_or_numeric, str):
arraylike_args.append(inputs[symbol_or_numeric])
return self.engine_call(*arraylike_args) # type: ignore[misc]

def __call__(self, inputs: dict[str, ArrayLike] = dict()) -> ArrayLike:
return self.evaluate(inputs)
38 changes: 38 additions & 0 deletions qadence_embeddings/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Callable, Optional


class ParameterType:
pass


class DType:
pass


class Embedding(ABC):
"""
A generic module class to hold and handle the parameters and expressions
functions coming from the `Model`. It may contain the list of user input
parameters, as well as the trainable variational parameters and the
evaluated functions from the data types being used, i.e. torch, numpy, etc.
"""

vparams: dict[str, ParameterType]
fparams: dict[str, Optional[ParameterType]]
mapped_vars: dict[str, Callable]
_dtype: DType

@abstractmethod
def __call__(self, *args: Any, **kwargs: Any) -> dict[str, Callable]:
raise NotImplementedError()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have to evaluate here ?


@abstractmethod
def name_mapping(self) -> dict:
raise NotImplementedError()

@property
def dtype(self) -> DType:
return self._dtype
Loading