diff --git a/.github/workflows/rust-tests.yml b/.github/workflows/rust-tests.yml new file mode 100644 index 0000000..5f2217e --- /dev/null +++ b/.github/workflows/rust-tests.yml @@ -0,0 +1,47 @@ +name: Rust Tests + +on: + push: + branches: main + pull_request: + branches: main + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v3 + + - name: Set up Rust + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + profile: minimal + override: true + + - name: Cache Cargo registry + uses: actions/cache@v3 + with: + path: ~/.cargo/registry + key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo-registry- + + - name: Cache Cargo build + uses: actions/cache@v3 + with: + path: target + key: ${{ runner.os }}-cargo-build-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo-build- + + - name: Install dependencies + run: cargo fetch + + - name: Run `clippy` linter + run: cargo clippy --all-targets -- -D warnings + + - name: Run tests + run: cargo test --all diff --git a/.gitignore b/.gitignore index b6f2e35..cb73f1e 100644 --- a/.gitignore +++ b/.gitignore @@ -50,3 +50,9 @@ runs/ # Mkdocs site/ + + +# Added by cargo + +/target +*.so \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..b905b04 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,403 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "autocfg" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" + +[[package]] +name = "bitflags" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "indoc" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" + +[[package]] +name = "libc" +version = "0.2.155" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" + +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "num" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c165a9ab64cf766f73521c0dd2cfdff64f488b8f0b3e621face3462d3db536d7" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-iter" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "parking_lot" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + +[[package]] +name = "portable-atomic" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" + +[[package]] +name = "proc-macro2" +version = "1.0.85" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22244ce15aa966053a896d1accb3a6e68469b97c7f33f284b99f0d576879fc23" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "parking_lot", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7883df5835fafdad87c0d888b266c8ec0f4c9ca48a5bed6bbb592e8dedee1b50" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01be5843dc60b916ab4dad1dca6d20b9b4e6ddc8e15f50c47fe6d85f1fb97403" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77b34069fc0682e11b31dbd10321cbf94808394c56fd996796ce45217dfac53c" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn", +] + +[[package]] +name = "qadence2_expressions" +version = "0.1.0" +dependencies = [ + "num", + "num-traits", + "pyo3", + "strum", + "strum_macros", +] + +[[package]] +name = "quote" +version = "1.0.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "redox_syscall" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "469052894dcb553421e483e4209ee581a45100d31b4018de03e5a7ad86374a7e" +dependencies = [ + "bitflags", +] + +[[package]] +name = "rustversion" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "smallvec" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" + +[[package]] +name = "strum" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29" + +[[package]] +name = "strum_macros" +version = "0.26.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7993a8e3a9e88a00351486baae9522c91b123a088f76469e5bd5cc17198ea87" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "rustversion", + "syn", +] + +[[package]] +name = "syn" +version = "2.0.66" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c42f3f41a2de00b01c0aaad383c5a45241efc8b2d1eda5661812fda5f3cdcff5" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.12.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f" + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "unindent" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" + +[[package]] +name = "windows-targets" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..994e4d1 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "qadence2_expressions" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lib] +# The name of the native library. This is the name which will be used in Python to import the +# library (i.e. `import string_sum`). If you change this, you must also change the name of the +# `#[pymodule]` in `src/lib.rs`. +name = "pyexpression" +# "cdylib" is necessary to produce a shared library for Python to import from. +# +# Downstream Rust code (including code in `bin/`, `examples/`, and `tests/`) will not be able +# to `use string_sum;` unless the "rlib" or "lib" crate type is also included, e.g.: +# crate-type = ["cdylib", "rlib"] +crate-type = ["cdylib"] + +[dependencies] +num = "0.4.3" +num-traits = "0.2.19" +pyo3 = "0.21.2" +strum = "0.26.2" +strum_macros = "0.26.3" diff --git a/src/expression.rs b/src/expression.rs new file mode 100644 index 0000000..4f5861b --- /dev/null +++ b/src/expression.rs @@ -0,0 +1,304 @@ +use num_traits::pow::Pow; + +use crate::operator::Operator; +use crate::symbols::Numerical; + +use std::ops::{Add, Div, Mul, Sub, Neg}; + + +#[derive(Clone, Debug, PartialEq)] +pub enum Expression { + Symbol(&'static str), + Value(Numerical), + Expr { head: Operator, args: Vec }, +} + +// Implement helper functions to create different types of Expressions. +impl Expression { + pub fn symbol(name: &'static str) -> Self { + Expression::Symbol(name) + } + + pub fn float(value: f64) -> Self { + Expression::Value(Numerical::Float(value)) + } + + pub fn complex(real: f64, imag: f64) -> Self { + Expression::Value(Numerical::complex(real, imag)) + } +} + +impl Neg for Expression { + type Output = Expression; + + fn neg(self) -> Self::Output { + use Expression::{Expr, Value, Symbol}; + use Operator::MUL; + + match self { + // Negating a symbol directly isn't well-defined, but for the sake of + // completeness, we could wrap it in an Expr with multiplication by -1. + Symbol(s) => Expr { + head: MUL, + args: vec![ + Expression::float(-1.), + Expression::symbol(s), + ], + }, + + // Negate the numerical value. + Value(v) => { + Value(-v) + }, + + // Negate the entire expression by multiplying by -1 + Expr { head, args } => Expr { + head: MUL, + args: vec![ + Expression::float(-1.), + Expr { + head, + args, + } + ] + }, + } + } +} + +impl Pow for Expression { + type Output = Expression; + + fn pow(self, rhs: Self) -> Self::Output { + use Expression::{Expr, Value}; + use Operator::POW; + + match (self, rhs) { + // Numerical values are operated directly. + (Value(lhs), Value(rhs)) => Value(lhs.pow(rhs)), + + // If the left side is already a power expression, chain the exponent. + (Expr { head: POW, args: mut args_lhs }, rhs) => { + args_lhs.push(rhs); + Expr { head: POW, args: args_lhs } + }, + + // Otherwise, create a new power expression. + (lhs, rhs) => Expr { + head: POW, + args: vec![lhs, rhs], + } + } + } +} + +macro_rules! impl_binary_operator_for_expression { + ($trait:ident, $method:ident, $operator:path) => { + impl $trait for Expression { + type Output = Self; + + fn $method(self, other: Self) -> Self { + use Expression::*; + + match (self, other) { + (Value(x), Value(y)) => Value(x.$method(y)), + + (Expr {head: $operator, args: args_lhs}, Expr {head: $operator, args: args_rhs}) => { + let args = args_lhs.into_iter().chain(args_rhs.into_iter()).collect(); + Expr{head: $operator, args} + }, + + (Expr {head: $operator, args: mut args_lhs}, rhs) => { + args_lhs.push(rhs); + Expr {head: $operator, args: args_lhs} + }, + + (lhs, Expr {head: $operator, args: mut args_rhs}) => { + args_rhs.push(lhs); + Expr {head: $operator, args: args_rhs} + }, + + (lhs, rhs) => Expr{head: $operator, args: vec![lhs, rhs]}, + } + } + } + }; + + ($trait:ident, $method:ident, $operator:path, $inv:expr) => { + impl $trait for Expression { + type Output = Self; + + fn $method(self, other: Self) -> Self { + use Expression::*; + + match (self, other) { + (Value(x), Value(y)) => Value(x.$method(y)), + (lhs, rhs) => Expr { + head: $operator, + args: vec![lhs, $inv(rhs)] + }, + } + } + } + } +} + +impl_binary_operator_for_expression!(Add, add, Operator::ADD); +impl_binary_operator_for_expression!(Mul, mul, Operator::MUL); +impl_binary_operator_for_expression!(Sub, sub, Operator:: ADD, |x: Expression| { x.neg() }); +impl_binary_operator_for_expression!(Div, div, Operator:: MUL, |x: Expression| { x.pow(Expression::float(-1.0)) }); + + +#[cfg(test)] +mod tests { + use super::*; // This imports everything from the parent module + use num::Complex; + use Expression::Expr; + use Operator::{ADD, MUL}; + + #[test] + fn test_symbol_expression() { + let symbol_expr = Expression::symbol("x"); + assert_eq!(symbol_expr, Expression::Symbol("x")); + } + + #[test] + fn test_neg_expression() { + let expr = Expression::symbol("y"); + let neg_expr = -expr; + assert_eq!( + neg_expr, + Expr { + head: MUL, + args: vec![ + Expression::float(-1.), + Expression::symbol("y"), + ] + } + ); + } + + #[test] + fn test_float_expression() { + let float_expr = Expression::float(std::f64::consts::FRAC_1_PI); + assert_eq!(float_expr, Expression::Value(Numerical::Float(std::f64::consts::FRAC_1_PI))); + } + + #[test] + fn test_neg_float_expr() { + let value = Expression::float(10.); + let neg_value = -value; + assert_eq!(neg_value, Expression::float(-10.)); + } + + #[test] + fn test_complex_expression() { + let complex_expr = Expression::complex(1.0, 2.0); + assert_eq!(complex_expr, Expression::Value(Numerical::Complex(Complex::new(1.0, 2.0)))); + } + + #[test] + fn test_neg_complex_expr() { + let value = Expression::complex(1., 2.); + let neg_value = -value; + assert_eq!(neg_value, Expression::complex(-1., -2.)); + } + + #[test] + fn test_mixed_types_expression_add() { + let symbol_expr = Expression::symbol("x"); + let mixed_expr = Expression::float(1.) + symbol_expr; + assert_eq!( + mixed_expr, + Expr { + head: ADD, + args: vec![ + Expression::float(1.), + Expression::symbol("x"), + ] + } + ); + } + + #[test] + fn test_mixed_types_expression_sub() { + let symbol_expr = Expression::symbol("x"); + let mixed_expr = Expression::complex(1.0, 2.0) - symbol_expr; + assert_eq!( + mixed_expr, + Expr { + head: ADD, + args: vec![ + Expression::complex(1.0, 2.0), + Expression::Expr { + head: MUL, + args: vec![ + Expression::float(-1.), + Expression::symbol("x"), + ] + } + ] + } + ); + } + + #[test] + fn test_expression_binary_ops_float_to_float() { + let expr1 = Expression::float(1.0); + let expr2 = Expression::float(2.0); + let n1 = expr1.clone() + expr2.clone(); + let n2 = expr1.clone() - expr2.clone(); + let n3 = expr1.clone() * expr2.clone(); + let n4 = expr1.clone() / expr2.clone(); + + assert_eq!(n1, Expression::Value(Numerical::Float(3.0))); + assert_eq!(n2, Expression::Value(Numerical::Float(-1.0))); + assert_eq!(n3, Expression::Value(Numerical::Float(2.0))); + assert_eq!(n4, Expression::Value(Numerical::Float(0.5))); + } + + #[test] + fn test_expression_binary_ops_float_to_complex() { + let expr1 = Expression::float(1.0); + let expr2 = Expression::complex(2.0, 4.0); + let n1 = expr1.clone() + expr2.clone(); + let n2 = expr1.clone() - expr2.clone(); + let n3 = expr1.clone() * expr2.clone(); + let n4 = expr1.clone() / expr2.clone(); + + assert_eq!(n1, Expression::Value(Numerical::Complex(Complex::new(3.0, 4.0)))); + assert_eq!(n2, Expression::Value(Numerical::Complex(Complex::new(-1.0, -4.0)))); + assert_eq!(n3, Expression::Value(Numerical::Complex(Complex::new(2.0, 4.0)))); + assert_eq!(n4, Expression::Value(Numerical::Complex(Complex::new(0.1, -0.2)))); + } + + #[test] + fn test_expression_binary_ops_complex_to_float() { + let expr1 = Expression::complex(1.0, 2.0); + let expr2 = Expression::float(2.0); + let n1 = expr1.clone() + expr2.clone(); + let n2 = expr1.clone() - expr2.clone(); + let n3 = expr1.clone() * expr2.clone(); + let n4 = expr1.clone() / expr2.clone(); + + assert_eq!(n1, Expression::Value(Numerical::Complex(Complex::new(3.0, 2.0)))); + assert_eq!(n2, Expression::Value(Numerical::Complex(Complex::new(-1.0, 2.0)))); + assert_eq!(n3, Expression::Value(Numerical::Complex(Complex::new(2.0, 4.0)))); + assert_eq!(n4, Expression::Value(Numerical::Complex(Complex::new(0.5, 1.0)))); + } + + #[test] + fn test_expression_add_complex_to_complex() { + let expr1 = Expression::complex(1.0, 2.0); + let expr2 = Expression::complex(3.0, 4.0); + let n1 = expr1.clone() + expr2.clone(); + let n2 = expr1.clone() - expr2.clone(); + let n3 = expr1.clone() * expr2.clone(); + let n4 = expr1.clone() / expr2.clone(); + + assert_eq!(n1, Expression::Value(Numerical::Complex(Complex::new(4.0, 6.0)))); + assert_eq!(n2, Expression::Value(Numerical::Complex(Complex::new(-2.0, -2.0)))); + assert_eq!(n3, Expression::Value(Numerical::Complex(Complex::new(-5.0, 10.0)))); + assert_eq!(n4, Expression::Value(Numerical::Complex(Complex::new(0.44, 0.08)))); + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..28af4aa --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,42 @@ +// use pyo3::prelude::*; + +pub mod operator; +pub mod symbols; +pub mod expression; + +// #[pyclass] +// pub enum Operator { +// ADD, +// MUL, +// NONCOMMUTE, +// POWER, +// CALL, +// } + +// impl Operator { +// pub fn as_str(&self) -> &'static str { +// match self { +// Operator::ADD => "+", +// Operator::MUL => "*", +// Operator::NONCOMMUTE => "@", +// Operator::POWER => "^", +// Operator::CALL => "call", +// } +// } +// } + + +// /// Formats the sum of two numbers as string. +// #[pyfunction] +// fn operator() -> PyResult<&'static str> { +// Ok(Operator::ADD.as_str()) +// } + +// /// A Python module implemented in Rust. The name of this function must match +// /// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to +// /// import the module. +// #[pymodule] +// fn pyexpression(m: &Bound<'_, PyModule>) -> PyResult<()> { +// m.add_function(wrap_pyfunction!(operator, m)?)?; +// Ok(()) +// } diff --git a/src/libs.rs b/src/libs.rs new file mode 100644 index 0000000..ce1206a --- /dev/null +++ b/src/libs.rs @@ -0,0 +1,16 @@ +use pyo3::prelude::*; + +/// Formats the sum of two numbers as string. +#[pyfunction] +fn sum_as_string(a: usize, b: usize) -> PyResult { + Ok((a + b).to_string()) +} + +/// A Python module implemented in Rust. The name of this function must match +/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to +/// import the module. +#[pymodule] +fn qadence2_expressions(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; + Ok(()) +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..63cf1fb --- /dev/null +++ b/src/main.rs @@ -0,0 +1,6 @@ +// use crate::operator::Operator; + +fn main() { + println!("Hello, world!"); + // println!("{}", Operator::ADD.as_str()); +} diff --git a/src/operator.rs b/src/operator.rs new file mode 100644 index 0000000..6c56dad --- /dev/null +++ b/src/operator.rs @@ -0,0 +1,36 @@ +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum Operator { + ADD, + CALL, + MUL, + NONCOMMUTE, + POW, +} + +impl Operator { + pub fn as_str(&self) -> &'static str { + match self { + Operator::ADD => "+", + Operator::CALL => "call", + Operator::MUL => "*", + Operator::NONCOMMUTE => "@", + Operator::POW => "^", + } + } +} + + +#[cfg(test)] +mod tests { + // Note this useful idiom: importing names from outer (for mod tests) scope. + use super::*; + + #[test] + fn test_operator_as_str() { + assert_eq!(Operator::ADD.as_str(), "+"); + assert_eq!(Operator::MUL.as_str(), "*"); + assert_eq!(Operator::NONCOMMUTE.as_str(), "@"); + assert_eq!(Operator::POW.as_str(), "^"); + assert_eq!(Operator::CALL.as_str(), "call"); + } +} diff --git a/src/symbols.rs b/src/symbols.rs new file mode 100644 index 0000000..4280393 --- /dev/null +++ b/src/symbols.rs @@ -0,0 +1,175 @@ +use num::Complex; +use num_traits::pow::Pow; +use std::ops::{Add, Div, Mul, Sub, Neg}; +use std::fmt; + + + +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum Numerical { + Float(f64), + Complex(Complex), +} + +impl Numerical { + /// Convenience method to create a Numerical::Complex + pub fn complex(re: f64, im: f64) -> Self { + Numerical::Complex(Complex::new(re, im)) + } +} + +impl fmt::Display for Numerical { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Numerical::Float(value) => write!(f, "{}", value), + Numerical::Complex(value) => write!(f, "{} + {}i", value.re, value.im), + } + } +} + +impl Neg for Numerical { + type Output = Numerical; + + fn neg(self) -> Self::Output { + use Numerical::*; + + match self { + Float(f) => Numerical::Float(-f), + Complex(c) => Numerical::complex(-c.re, -c.im), + } + } +} + +impl Pow for Numerical { + type Output = Numerical; + + fn pow(self, rhs: Numerical) -> Self::Output { + use Numerical::*; + + match (self, rhs) { + (Float(base), Float(exp)) => Float(base.powf(exp)), + (Complex(base), Float(exp)) => Complex(base.powf(exp)), + (Complex(base), Complex(exp)) => Complex(base.powc(exp)), + (Float(base), Complex(exp)) => Complex((num::Complex::new(base, 0.0)).powc(exp)), + } + } +} + +macro_rules! impl_binary_operator_for_numerical { + ($trait:ident, $method:ident) => { + impl $trait for Numerical { + type Output = Self; + + fn $method(self, other: Self) -> Self::Output { + use Numerical::*; + // To disambiguate with the enum variant. + use num::Complex as complex; + + match (self, other) { + // Complex and Complex + (Complex(a), Complex(b)) => Complex(a.$method(b)), + + // Complex with Float + (Complex(a), Float(b)) => Complex(a.$method(complex::from(b))), + (Float(a), Complex(b)) => Complex(complex::from(a).$method(b)), + + // Float and Float + (Float(a), Float(b)) => Float(a.$method(b)), + } + } + } + }; +} + +// Implement the binary operators for Numerical using the macro +impl_binary_operator_for_numerical!(Add, add); +impl_binary_operator_for_numerical!(Sub, sub); +impl_binary_operator_for_numerical!(Mul, mul); +impl_binary_operator_for_numerical!(Div, div); + +#[derive(Debug, PartialEq)] +pub struct Symbol (&'static str); + +#[cfg(test)] +mod tests { + // Note this useful idiom: importing names from outer (for mod tests) scope. + use super::*; + + // Approximate equality check for Complex numbers + fn approx_eq_complex(c1: &num::Complex, c2: &num::Complex, epsilon: f64) -> bool { + (c1.re - c2.re).abs() < epsilon && (c1.im - c2.im).abs() < epsilon + } + + #[test] + fn test_negation_float() { + let num_float = Numerical::Float(5.5); + assert_eq!(-num_float, Numerical::Float(-5.5)); + + let num_float_neg = Numerical::Float(-5.5); + assert_eq!(-num_float_neg, Numerical::Float(5.5)); + } + + #[test] + fn test_negation_complex() { + let num_complex = Numerical::Complex(Complex::new(3.0, 4.0)); + assert_eq!(-num_complex, Numerical::Complex(Complex::new(-3.0, -4.0))); + + let num_complex_neg = Numerical::Complex(Complex::new(-3.0, -4.0)); + assert_eq!(-num_complex_neg, Numerical::Complex(Complex::new(3.0, 4.0))); + } + + #[test] + fn test_numerical_pow() { + let n3 = Numerical::Float(2.0); + let n4 = Numerical::Float(3.0); + assert_eq!(n3.pow(n4), Numerical::Float(8.0)); + + let n5 = Numerical::Complex(num::Complex::new(2.0, 0.0)); + let n6 = Numerical::Complex(num::Complex::new(3.0, 0.0)); + if let Numerical::Complex(c) = n5.pow(n6) { + assert!(approx_eq_complex(&c, &Complex::new(8.0, 0.0), 1e-9)); + } else { + panic!("Expected complex result"); + } + } + + #[test] + fn test_numerical_binary_ops_float_to_float() { + let n1 = Numerical::Float(5.0); + let n2 = Numerical::Float(10.0); + assert_eq!(n1 + n2, Numerical::Float(15.0)); + assert_eq!(n1 - n2, Numerical::Float(-5.0)); + assert_eq!(n1 * n2, Numerical::Float(50.0)); + assert_eq!(n1 / n2, Numerical::Float(0.5)); + } + + #[test] + fn test_numerical_binary_ops_float_to_complex() { + let n1 = Numerical::Float(5.0); + let n2 = Numerical::complex(3.0, 4.0); + assert_eq!(n1 + n2, Numerical::Complex(Complex::new(8.0, 4.0))); + assert_eq!(n1 - n2, Numerical::Complex(Complex::new(2.0, -4.0))); + assert_eq!(n1 * n2, Numerical::Complex(Complex::new(15.0, 20.0))); + assert_eq!(n1 / n2, Numerical::Complex(Complex::new(15.0 / 25.0, -20.0 / 25.0))); + } + + #[test] + fn test_numerical_binary_ops_complex_to_float() { + let n1 = Numerical::complex(5.0, 4.0); + let n2 = Numerical::Float(3.0); + assert_eq!(n1 + n2, Numerical::Complex(Complex::new(8.0, 4.0))); + assert_eq!(n1 - n2, Numerical::Complex(Complex::new(2.0, 4.0))); + assert_eq!(n1 * n2, Numerical::Complex(Complex::new(15.0, 12.0))); + assert_eq!(n1 / n2, Numerical::Complex(Complex::new(5.0 / 3.0, 4.0 / 3.0))); + } + + #[test] + fn test_numerical_binary_ops_complex_to_complex() { + let n1 = Numerical::complex(5.0, 4.0); + let n2 = Numerical::complex(3.0, 2.0); + assert_eq!(n1 + n2, Numerical::Complex(Complex::new(8.0, 6.0))); + assert_eq!(n1 - n2, Numerical::Complex(Complex::new(2.0, 2.0))); + assert_eq!(n1 * n2, Numerical::Complex(Complex::new(7.0, 22.0))); + assert_eq!(n1 / n2, Numerical::Complex(Complex::new(23.0 / 13.0, 2.0 / 13.0))); + } +}