Skip to content
This repository has been archived by the owner on Jul 16, 2024. It is now read-only.

[Feature] Main Embedding and ConcretizedCallable logic #1

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
# qadence-embeddings
# qadence-embeddings

**qadence-embeddings** is a engine-agnostic parameter embedding library.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
**qadence-embeddings** is a engine-agnostic parameter embedding library.
**qadence-embeddings** is an engine-agnostic parameter embedding library.

Empty file added docs/index.md
Empty file.
71 changes: 71 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
site_name: qadence-embeddings
repo_url: "https://github.com/pasqal-io/qadence-embeddings"
repo_name: "qadence-embeddings"

nav:

- Embeddings in a nutshell: index.md


theme:
name: material
features:
- content.code.annotate
- navigation.indexes
- navigation.sections

palette:
- media: "(prefers-color-scheme: light)"
scheme: default
primary: light green
accent: purple
toggle:
icon: material/weather-sunny
name: Switch to dark mode
- media: "(prefers-color-scheme: dark)"
scheme: slate
primary: black
accent: light green
toggle:
icon: material/weather-night
name: Switch to light mode

markdown_extensions:
- admonition # for notes
- pymdownx.arithmatex: # for mathjax
generic: true
- pymdownx.highlight:
anchor_linenums: true
- pymdownx.inlinehilite
- pymdownx.snippets
- pymdownx.superfences

plugins:
- search
- section-index
- mkdocstrings:
default_handler: python
handlers:
python:
selection:
filters:
- "!^_" # exlude all members starting with _
- "^__init__$" # but always include __init__ modules and methods

- mkdocs-jupyter:
theme: light
- markdown-exec

extra_css:
- extras/css/mkdocstrings.css
- extras/css/colors.css
- extras/css/home.css

# For mathjax
extra_javascript:
- extras/javascripts/mathjax.js
- https://polyfill.io/v3/polyfill.min.js?features=es6
- https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js

watch:
- qadence_embeddings
107 changes: 107 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "qadence-embeddings"
description = "a engine-agnostic parameter embedding library."
authors = [
{ name = "Dominik Seitz", email = "dominik.seitz@pasqal.com" },
]
requires-python = ">=3.8,<3.13"
license = {text = "Apache 2.0"}
version = "1.3.0"
classifiers=[
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = ["torch", "numpy"]

[project.optional-dependencies]
dev = ["flaky","black", "pytest", "pytest-xdist", "pytest-cov", "flake8", "mypy", "pre-commit", "ruff", "nbconvert", "matplotlib", "qutip~=4.7.5"]

[tool.hatch.envs.tests]
features = [
"dev",
]

[tool.hatch.envs.tests.scripts]
test = "pytest -n auto {args}"
test-docs = "hatch -e docs run mkdocs build --clean --strict"
test-cov = "pytest -n auto --cov=qadence-embeddings tests/"

[tool.hatch.envs.docs]
dependencies = [
"mkdocs",
"mkdocs-material",
"mkdocstrings",
"mkdocstrings-python",
"mkdocs-section-index",
"mkdocs-jupyter",
"mkdocs-exclude",
"markdown-exec",
"mike",
"matplotlib",
]

[tool.hatch.envs.docs.scripts]
build = "mkdocs build --clean --strict"
serve = "mkdocs serve --dev-addr localhost:8000"

[tool.ruff]
lint.select = ["E", "F", "I", "Q"]
lint.extend-ignore = ["F841"]
line-length = 100

[tool.ruff.lint.isort]
required-imports = ["from __future__ import annotations"]

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401", "E402"]

[tool.ruff.lint.mccabe]
max-complexity = 15

[tool.ruff.lint.flake8-quotes]
docstring-quotes = "double"

[lint.black]
line-length = 100
include = '\.pyi?$'
exclude = '''
/(
\.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
)/
'''

[lint.isort]
line_length = 100
combine_as_imports = true
balanced_wrapping = true
lines_after_imports = 2
include_trailing_comma = true
multi_line_output = 5

[lint.mypy]
python_version = "3.10"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = true
no_implicit_optional = false
ignore_missing_imports = true
4 changes: 4 additions & 0 deletions qadence_embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from __future__ import annotations

from .callable import ConcretizedCallable
from .embedding import Embedding
83 changes: 83 additions & 0 deletions qadence_embeddings/callable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from __future__ import annotations

from importlib import import_module
from logging import getLogger
from typing import Callable

from numpy.typing import ArrayLike

logger = getLogger(__name__)

ARRAYLIKE_FN_MAP = {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ARRAYLIKE_FN_MAP = {
ARRAYLIKE_CONSTRUCTOR_MAP = {

Or something like this ?

"torch": ("torch", "tensor"),
"jax": ("jax.numpy", "array"),
"numpy": ("numpy", "array"),
}


def ConcretizedCallable(
call_name: str,
abstract_args: list[str | float | int],

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
abstract_args: list[str | float | int],
args: list[str | float | int],

instruction_mapping: dict[str, Callable] = dict(),
engine_name: str = "torch",

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No our beloved StrEnums ? ;)

) -> Callable[[dict, dict], ArrayLike]:
"""Convert a generic abstract function call and
a list of symbolic or constant parameters
into a concretized Callable in a particular engine.
which can be evaluated using
a vparams and inputs dict.

Arguments:
call_name: The name of the function
abstract_args: A list of strings (in the case of parameters) and numeric constants
denoting the arguments for `call_name`
instruction_mapping: A dict mapping from an abstract call_name to its name in an engine.
engine_name: The engine to use to create the callable.

Example:
```
In [11]: call = ConcretizedCallable('sin', ['x'], engine_name='numpy')
In [12]: call({'x': 0.5})
Out[12]: 0.479425538604203

In [13]: call = ConcretizedCallable('sin', ['x'], engine_name='torch')
In [14]: call({'x': torch.rand(1)})
Out[14]: tensor([0.5531])

In [15]: call = ConcretizedCallable('sin', ['x'], engine_name='jax')
In [16]: call({'x': 0.5})
Out[16]: Array(0.47942555, dtype=float32, weak_type=True)
```
"""
engine_call = None
engine = None
try:
engine_name, fn_name = ARRAYLIKE_FN_MAP[engine_name]
engine = import_module(engine_name)
arraylike_fn = getattr(engine, fn_name)
except (ModuleNotFoundError, ImportError) as e:
logger.error(f"Unable to import {engine_call} due to {e}.")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
logger.error(f"Unable to import {engine_call} due to {e}.")
logger.error(f"Unable to import {engine_name} due to {e}.")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it necessary to raise the error again ?


try:
engine_call = getattr(engine, call_name)
except ImportError:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it the right exception to catch ?

pass

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
try:
engine_call = getattr(engine, call_name)
except ImportError:
pass
engine_call = getattr(engine, call_name, None)

if engine_call is None:
try:
engine_call = instruction_mapping[call_name]
except KeyError as e:
logger.error(
f"Requested function {call_name} can not be imported from {engine_name} and is\
not in instruction_mapping {instruction_mapping} due to {e}."

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f"Requested function {call_name} can not be imported from {engine_name} and is\
not in instruction_mapping {instruction_mapping} due to {e}."
f"Requested function {call_name} can not be imported from {engine_name} and is"
f"not in instruction_mapping {instruction_mapping} due to {e}."

)

def evaluate(params: dict = dict(), inputs: dict = dict()) -> ArrayLike:
arraylike_args = []
for symbol_or_numeric in abstract_args:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for symbol_or_numeric in abstract_args:
for arg in args:

if isinstance(symbol_or_numeric, (float, int)):
arraylike_args.append(arraylike_fn(symbol_or_numeric))
elif isinstance(symbol_or_numeric, str):
arraylike_args.append({**params, **inputs}[symbol_or_numeric])
return engine_call(*arraylike_args) # type: ignore[misc]

return evaluate
38 changes: 38 additions & 0 deletions qadence_embeddings/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Callable, Optional


class ParameterType:
pass


class DType:
pass


class Embedding(ABC):
"""
A generic module class to hold and handle the parameters and expressions
functions coming from the `Model`. It may contain the list of user input
parameters, as well as the trainable variational parameters and the
evaluated functions from the data types being used, i.e. torch, numpy, etc.
"""

vparams: dict[str, ParameterType]
fparams: dict[str, Optional[ParameterType]]
mapped_vars: dict[str, Callable]
_dtype: DType

@abstractmethod
def __call__(self, *args: Any, **kwargs: Any) -> dict[str, Callable]:
raise NotImplementedError()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have to evaluate here ?


@abstractmethod
def name_mapping(self) -> dict:
raise NotImplementedError()

@property
def dtype(self) -> DType:
return self._dtype
Empty file added tests/test_callable.py
Empty file.
Empty file added tests/test_embedding.py
Empty file.