diff --git a/.cargo/config b/.cargo/config index 761408102b..a29c398498 100644 --- a/.cargo/config +++ b/.cargo/config @@ -7,4 +7,5 @@ xclippy = [ "clippy", "--workspace", "--all-targets", "--", "-Wclippy::all", "-Wclippy::disallowed_methods", + "-Wclippy::match_same_arms", ] diff --git a/Cargo.lock b/Cargo.lock index f0e0a7473d..93b03a1467 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1539,6 +1539,7 @@ dependencies = [ name = "lurk-macros" version = "0.1.0" dependencies = [ + "anyhow", "bincode", "lurk", "pasta_curves", diff --git a/lurk-macros/Cargo.toml b/lurk-macros/Cargo.toml index 7296f73027..21a8fcfc64 100644 --- a/lurk-macros/Cargo.toml +++ b/lurk-macros/Cargo.toml @@ -19,6 +19,7 @@ proptest-derive = { workspace = true } serde = { workspace = true, features = ["derive"] } [dev-dependencies] +anyhow.workspace = true bincode = { workspace = true } lurk_crate = { path = "../", package = "lurk" } pasta_curves = { workspace = true, features = ["repr-c", "serde"] } diff --git a/lurk-macros/src/lib.rs b/lurk-macros/src/lib.rs index 768d7bd29e..18ba126e7a 100644 --- a/lurk-macros/src/lib.rs +++ b/lurk-macros/src/lib.rs @@ -19,7 +19,7 @@ use proc_macro2::Span; use quote::{quote, ToTokens}; use syn::{ parse_macro_input, AttributeArgs, Data, DataEnum, DeriveInput, Ident, Item, Lit, Meta, - MetaList, NestedMeta, Type, + MetaList, NestedMeta, Path, Type, }; #[proc_macro_derive(Coproc)] @@ -375,3 +375,102 @@ fn parse_type(m: &NestedMeta) -> Type { } } } + +fn try_from_match_arms( + name: &Ident, + variant_names: &[&Ident], + ty: syn::Path, +) -> proc_macro2::TokenStream { + let mut match_arms = quote! {}; + for variant in variant_names { + match_arms.extend(quote! { + x if x == #name::#variant as #ty => Ok(#name::#variant), + }); + } + match_arms +} + +fn get_type_from_attrs(attrs: &[syn::Attribute], attr_name: &str) -> syn::Result { + let Some(nested_arg) = attrs.iter().find_map(|arg| { + let Ok(Meta::List(MetaList { path, nested, .. })) = arg.parse_meta() else { + return None; + }; + if !path.is_ident(attr_name) { + return None; + } + nested.first().cloned() + }) else { + return Err(syn::Error::new( + proc_macro2::Span::call_site(), + format!("Could not find attribute {}", attr_name), + )); + }; + + match nested_arg { + NestedMeta::Meta(Meta::Path(path)) => Ok(path), + bad => Err(syn::Error::new_spanned( + bad, + &format!("Could not parse {} attribute", attr_name)[..], + )), + } +} + +/// This macro derives an impl of TryFrom for an enum type T with `#[repr(foo)]`. +/// +/// # Example +/// ``` +/// use lurk_macros::TryFromRepr; +/// +/// #[derive(TryFromRepr)] +/// #[repr(u8)] +/// enum Foo { +/// Bar = 0, +/// Baz +/// } +/// ``` +/// +/// This will derive the natural impl that compares the input representation type to +/// the automatic conversions of each variant into that representation type. +#[proc_macro_derive(TryFromRepr)] +pub fn derive_try_from_repr(input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as DeriveInput); + let res_ty = get_type_from_attrs(&ast.attrs, "repr"); + + let name = &ast.ident; + let variants = match ast.data { + Data::Enum(ref variants) => variants + .variants + .iter() + .map(|v| &v.ident) + .collect::>(), + Data::Struct(_) | Data::Union(_) => { + panic!("#[derive(TryFromRepr)] is only defined for enums") + } + }; + + match res_ty { + Err(e) => { + // If no explicit repr were given for us, we can't pursue + panic!( + "TryFromRepr macro requires a repr parameter, which couldn't be parsed: {:?}", + e + ); + } + Ok(ty) => { + let match_arms = try_from_match_arms(name, &variants, ty.clone()); + let name_str = name.to_string(); + quote! { + impl std::convert::TryFrom<#ty> for #name { + type Error = anyhow::Error; + fn try_from(v: #ty) -> Result>::Error> { + match v { + #match_arms + _ => Err(anyhow::anyhow!("invalid variant for enum {}", #name_str)), + } + } + } + } + } + } + .into() +} diff --git a/src/cli/repl.rs b/src/cli/repl.rs index 5c697f4118..368c4e29b6 100644 --- a/src/cli/repl.rs +++ b/src/cli/repl.rs @@ -533,11 +533,11 @@ impl Repl { .eval_expr(second) .with_context(|| "evaluating second arg")?; let Some(secret) = self.store.fetch_num(&first_io.expr) else { - bail!( - "Secret must be a number. Got {}", - first_io.expr.fmt_to_string(&self.store) - ) - }; + bail!( + "Secret must be a number. Got {}", + first_io.expr.fmt_to_string(&self.store) + ) + }; self.hide(secret.into_scalar(), second_io.expr)?; } "fetch" => { diff --git a/src/eval/reduction.rs b/src/eval/reduction.rs index 263f1b8623..19dbde6fd2 100644 --- a/src/eval/reduction.rs +++ b/src/eval/reduction.rs @@ -38,9 +38,9 @@ enum Control { impl Control { fn into_results(self, store: &mut Store) -> (Ptr, Ptr, ContPtr) { match self { - Self::Return(expr, env, cont) => (expr, env, cont), - Self::MakeThunk(expr, env, cont) => (expr, env, cont), - Self::ApplyContinuation(expr, env, cont) => (expr, env, cont), + Self::Return(expr, env, cont) + | Self::MakeThunk(expr, env, cont) + | Self::ApplyContinuation(expr, env, cont) => (expr, env, cont), Self::Error(expr, env) => (expr, env, store.intern_cont_error()), } } diff --git a/src/hash_witness.rs b/src/hash_witness.rs index 3385236855..52926b8bb0 100644 --- a/src/hash_witness.rs +++ b/src/hash_witness.rs @@ -82,6 +82,7 @@ pub trait HashName { impl HashName for ConsName { fn index(&self) -> usize { + #[allow(clippy::match_same_arms)] match self { Self::NeverUsed => MAX_CONSES_PER_REDUCTION + 1, Self::Expr => 0, @@ -131,6 +132,7 @@ pub enum ContName { impl HashName for ContName { fn index(&self) -> usize { + #[allow(clippy::match_same_arms)] match self { Self::NeverUsed => MAX_CONTS_PER_REDUCTION + 1, Self::ApplyContinuation => 0, diff --git a/src/lem/circuit.rs b/src/lem/circuit.rs index 3c81a257ac..2dee50e33f 100644 --- a/src/lem/circuit.rs +++ b/src/lem/circuit.rs @@ -839,15 +839,7 @@ impl Func { Op::Cast(_tgt, tag, _src) => { globals.insert(FWrap(tag.to_field())); } - Op::Add(_, _, _) => { - globals.insert(FWrap(Tag::Expr(Num).to_field())); - num_constraints += 1; - } - Op::Sub(_, _, _) => { - globals.insert(FWrap(Tag::Expr(Num).to_field())); - num_constraints += 1; - } - Op::Mul(_, _, _) => { + Op::Add(_, _, _) | Op::Sub(_, _, _) | Op::Mul(_, _, _) => { globals.insert(FWrap(Tag::Expr(Num).to_field())); num_constraints += 1; } @@ -877,11 +869,7 @@ impl Func { // one constraint for the image's hash num_constraints += 1; } - Op::Hide(..) => { - // TODO - globals.insert(FWrap(F::ZERO)); - } - Op::Open(..) => { + Op::Hide(..) | Op::Open(..) => { // TODO globals.insert(FWrap(F::ZERO)); } diff --git a/src/lem/pointers.rs b/src/lem/pointers.rs index b93a288b83..e46713f410 100644 --- a/src/lem/pointers.rs +++ b/src/lem/pointers.rs @@ -35,10 +35,7 @@ impl std::hash::Hash for Ptr { impl Ptr { pub fn tag(&self) -> &Tag { match self { - Ptr::Leaf(tag, _) => tag, - Ptr::Tree2(tag, _) => tag, - Ptr::Tree3(tag, _) => tag, - Ptr::Tree4(tag, _) => tag, + Ptr::Leaf(tag, _) | Ptr::Tree2(tag, _) | Ptr::Tree3(tag, _) | Ptr::Tree4(tag, _) => tag, } } diff --git a/src/parser/syntax.rs b/src/parser/syntax.rs index 0cc12004ad..bc217eeb98 100644 --- a/src/parser/syntax.rs +++ b/src/parser/syntax.rs @@ -348,17 +348,7 @@ pub mod tests { { match (expected, p.parse(Span::<'a>::new(i))) { (Some(expected), Ok((_, x))) if x == expected => true, - (Some(_), Ok(..)) => { - // println!("input: {:?}", i); - // println!("expected: {} {:?}", expected.clone(), expected); - // println!("detected: {} {:?}", x.clone(), x); - false - } - (Some(..), Err(_)) => { - // println!("{}", e); - false - } - (None, Ok(..)) => { + (Some(_), Ok(..)) | (Some(..), Err(_)) | (None, Ok(..)) => { // println!("input: {:?}", i); // println!("expected parse error"); // println!("detected: {:?}", x); diff --git a/src/store.rs b/src/store.rs index f7a85698df..cce31bffe5 100644 --- a/src/store.rs +++ b/src/store.rs @@ -663,13 +663,7 @@ impl Store { // fetch a symbol cons or keyword cons pub fn fetch_symcons(&self, ptr: &Ptr) -> Option<(Ptr, Ptr)> { match (ptr.tag, ptr.raw) { - (ExprTag::Sym, RawPtr::Null) => None, - (ExprTag::Key, RawPtr::Null) => None, - (ExprTag::Sym, RawPtr::Index(x)) => { - let (car, cdr) = self.sym_store.get_index(x)?; - Some((*car, *cdr)) - } - (ExprTag::Key, RawPtr::Index(x)) => { + (ExprTag::Sym, RawPtr::Index(x)) | (ExprTag::Key, RawPtr::Index(x)) => { let (car, cdr) = self.sym_store.get_index(x)?; Some((*car, *cdr)) } @@ -711,7 +705,6 @@ impl Store { pub fn fetch_strcons(&self, ptr: &Ptr) -> Option<(Ptr, Ptr)> { match (ptr.tag, ptr.raw) { - (ExprTag::Str, RawPtr::Null) => None, (ExprTag::Str, RawPtr::Index(x)) => { let (car, cdr) = self.str_store.get_index(x)?; Some((*car, *cdr)) diff --git a/src/tag.rs b/src/tag.rs index 352cf9209b..7fc8eb1d31 100644 --- a/src/tag.rs +++ b/src/tag.rs @@ -1,4 +1,4 @@ -use anyhow::anyhow; +use lurk_macros::TryFromRepr; #[cfg(not(target_arch = "wasm32"))] use proptest_derive::Arbitrary; use serde_repr::{Deserialize_repr, Serialize_repr}; @@ -20,7 +20,9 @@ pub trait Tag: Into + TryFrom + Copy + Sized + Eq + fmt::Debug { } /// A tag for expressions. Note that ExprTag, ContTag, Op1, Op2 all live in the same u16 namespace -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize_repr, Deserialize_repr)] +#[derive( + Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize_repr, Deserialize_repr, TryFromRepr, +)] #[cfg_attr(not(target_arch = "wasm32"), derive(Arbitrary))] #[repr(u16)] pub enum ExprTag { @@ -49,27 +51,6 @@ impl From for u64 { } } -impl TryFrom for ExprTag { - type Error = anyhow::Error; - - fn try_from(x: u16) -> Result>::Error> { - match x { - f if f == ExprTag::Nil as u16 => Ok(ExprTag::Nil), - f if f == ExprTag::Cons as u16 => Ok(ExprTag::Cons), - f if f == ExprTag::Sym as u16 => Ok(ExprTag::Sym), - f if f == ExprTag::Fun as u16 => Ok(ExprTag::Fun), - f if f == ExprTag::Thunk as u16 => Ok(ExprTag::Thunk), - f if f == ExprTag::Num as u16 => Ok(ExprTag::Num), - f if f == ExprTag::Str as u16 => Ok(ExprTag::Str), - f if f == ExprTag::Char as u16 => Ok(ExprTag::Char), - f if f == ExprTag::Comm as u16 => Ok(ExprTag::Comm), - f if f == ExprTag::U64 as u16 => Ok(ExprTag::U64), - f if f == ExprTag::Key as u16 => Ok(ExprTag::Key), - f => Err(anyhow!("Invalid ExprTag value: {}", f)), - } - } -} - impl fmt::Display for ExprTag { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -94,17 +75,15 @@ impl TypePredicates for ExprTag { } fn is_self_evaluating(&self) -> bool { match self { - Self::Cons => false, - Self::Thunk => false, - Self::Sym => false, - Self::Nil => true, - Self::Fun => true, - Self::Num => true, - Self::Str => true, - Self::Char => true, - Self::Comm => true, - Self::U64 => true, - Self::Key => true, + Self::Cons | Self::Thunk | Self::Sym => false, + Self::Nil + | Self::Fun + | Self::Num + | Self::Str + | Self::Char + | Self::Comm + | Self::U64 + | Self::Key => true, } } @@ -131,7 +110,9 @@ impl Tag for ExprTag { } /// A tag for continuations. Note that ExprTag, ContTag, Op1, Op2 all live in the same u16 namespace -#[derive(Serialize_repr, Deserialize_repr, Debug, Copy, Clone, PartialEq, Eq, Hash)] +#[derive( + Serialize_repr, Deserialize_repr, Debug, Copy, Clone, PartialEq, Eq, Hash, TryFromRepr, +)] #[cfg_attr(not(target_arch = "wasm32"), derive(Arbitrary))] #[repr(u16)] pub enum ContTag { @@ -165,32 +146,6 @@ impl From for u64 { } } -impl TryFrom for ContTag { - type Error = anyhow::Error; - - fn try_from(x: u16) -> Result>::Error> { - match x { - f if f == ContTag::Outermost as u16 => Ok(ContTag::Outermost), - f if f == ContTag::Call0 as u16 => Ok(ContTag::Call0), - f if f == ContTag::Call as u16 => Ok(ContTag::Call), - f if f == ContTag::Call2 as u16 => Ok(ContTag::Call2), - f if f == ContTag::Tail as u16 => Ok(ContTag::Tail), - f if f == ContTag::Error as u16 => Ok(ContTag::Error), - f if f == ContTag::Lookup as u16 => Ok(ContTag::Lookup), - f if f == ContTag::Unop as u16 => Ok(ContTag::Unop), - f if f == ContTag::Binop as u16 => Ok(ContTag::Binop), - f if f == ContTag::Binop2 as u16 => Ok(ContTag::Binop2), - f if f == ContTag::If as u16 => Ok(ContTag::If), - f if f == ContTag::Let as u16 => Ok(ContTag::Let), - f if f == ContTag::LetRec as u16 => Ok(ContTag::LetRec), - f if f == ContTag::Dummy as u16 => Ok(ContTag::Dummy), - f if f == ContTag::Terminal as u16 => Ok(ContTag::Terminal), - f if f == ContTag::Emit as u16 => Ok(ContTag::Emit), - f => Err(anyhow!("Invalid ContTag value: {}", f)), - } - } -} - impl Tag for ContTag { fn from_field(f: &F) -> Option { Self::try_from(f.to_u16()?).ok() @@ -231,7 +186,18 @@ impl fmt::Display for ContTag { } } -#[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Eq, Hash, Serialize_repr, Deserialize_repr)] +#[derive( + Copy, + Clone, + Debug, + PartialEq, + PartialOrd, + Eq, + Hash, + Serialize_repr, + Deserialize_repr, + TryFromRepr, +)] #[cfg_attr(not(target_arch = "wasm32"), derive(Arbitrary))] #[repr(u16)] pub enum Op1 { @@ -261,28 +227,6 @@ impl From for u64 { } } -impl TryFrom for Op1 { - type Error = anyhow::Error; - - fn try_from(x: u16) -> Result>::Error> { - match x { - f if f == Op1::Car as u16 => Ok(Op1::Car), - f if f == Op1::Cdr as u16 => Ok(Op1::Cdr), - f if f == Op1::Atom as u16 => Ok(Op1::Atom), - f if f == Op1::Emit as u16 => Ok(Op1::Emit), - f if f == Op1::Open as u16 => Ok(Op1::Open), - f if f == Op1::Secret as u16 => Ok(Op1::Secret), - f if f == Op1::Commit as u16 => Ok(Op1::Commit), - f if f == Op1::Num as u16 => Ok(Op1::Num), - f if f == Op1::Comm as u16 => Ok(Op1::Comm), - f if f == Op1::Char as u16 => Ok(Op1::Char), - f if f == Op1::Eval as u16 => Ok(Op1::Eval), - f if f == Op1::U64 as u16 => Ok(Op1::U64), - f => Err(anyhow!("Invalid Op1 value: {}", f)), - } - } -} - pub trait Op where Self: 'static, @@ -371,7 +315,18 @@ impl fmt::Display for Op1 { } } -#[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Eq, Hash, Serialize_repr, Deserialize_repr)] +#[derive( + Copy, + Clone, + Debug, + PartialEq, + PartialOrd, + Eq, + Hash, + Serialize_repr, + Deserialize_repr, + TryFromRepr, +)] #[cfg_attr(not(target_arch = "wasm32"), derive(Arbitrary))] #[repr(u16)] pub enum Op2 { @@ -405,32 +360,6 @@ impl From for u64 { } } -impl TryFrom for Op2 { - type Error = anyhow::Error; - - fn try_from(x: u16) -> Result>::Error> { - match x { - f if f == Op2::Sum as u16 => Ok(Op2::Sum), - f if f == Op2::Diff as u16 => Ok(Op2::Diff), - f if f == Op2::Product as u16 => Ok(Op2::Product), - f if f == Op2::Quotient as u16 => Ok(Op2::Quotient), - f if f == Op2::Equal as u16 => Ok(Op2::Equal), - f if f == Op2::NumEqual as u16 => Ok(Op2::NumEqual), - f if f == Op2::Less as u16 => Ok(Op2::Less), - f if f == Op2::Greater as u16 => Ok(Op2::Greater), - f if f == Op2::LessEqual as u16 => Ok(Op2::LessEqual), - f if f == Op2::GreaterEqual as u16 => Ok(Op2::GreaterEqual), - f if f == Op2::Cons as u16 => Ok(Op2::Cons), - f if f == Op2::StrCons as u16 => Ok(Op2::StrCons), - f if f == Op2::Begin as u16 => Ok(Op2::Begin), - f if f == Op2::Hide as u16 => Ok(Op2::Hide), - f if f == Op2::Modulo as u16 => Ok(Op2::Modulo), - f if f == Op2::Eval as u16 => Ok(Op2::Eval), - f => Err(anyhow!("Invalid Op2 value: {}", f)), - } - } -} - impl Tag for Op2 { fn from_field(f: &F) -> Option { Self::try_from(f.to_u16()?).ok() diff --git a/src/z_data/z_cont.rs b/src/z_data/z_cont.rs index 6091afa806..92050450b2 100644 --- a/src/z_data/z_cont.rs +++ b/src/z_data/z_cont.rs @@ -90,20 +90,7 @@ impl ZCont { /// Creates a list of field elements corresponding to the `ZCont` for hashing pub fn hash_components(&self) -> [F; 8] { match self { - Self::Outermost => [F::ZERO; 8], - Self::Call0 { - saved_env, - continuation, - } => [ - saved_env.0.to_field(), - saved_env.1, - continuation.0.to_field(), - continuation.1, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - ], + Self::Outermost | Self::Error | Self::Dummy | Self::Terminal => [F::ZERO; 8], Self::Call { saved_env, unevaled_arg, @@ -132,21 +119,15 @@ impl ZCont { F::ZERO, F::ZERO, ], - Self::Tail { + Self::Call0 { saved_env, continuation, - } => [ - saved_env.0.to_field(), - saved_env.1, - continuation.0.to_field(), - continuation.1, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - ], - Self::Error => [F::ZERO; 8], - Self::Lookup { + } + | Self::Tail { + saved_env, + continuation, + } + | Self::Lookup { saved_env, continuation, } => [ @@ -219,17 +200,8 @@ impl ZCont { body, saved_env, continuation, - } => [ - var.0.to_field(), - var.1, - body.0.to_field(), - body.1, - saved_env.0.to_field(), - saved_env.1, - continuation.0.to_field(), - continuation.1, - ], - Self::LetRec { + } + | Self::LetRec { var, body, saved_env, @@ -254,8 +226,6 @@ impl ZCont { F::ZERO, F::ZERO, ], - Self::Dummy => [F::ZERO; 8], - Self::Terminal => [F::ZERO; 8], } }