Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small Lem refactors #673

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .cargo/config
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ xclippy = [
"-Wclippy::dbg_macro",
"-Wclippy::disallowed_methods",
"-Wclippy::derive_partial_eq_without_eq",
"-Wclippy::enum_glob_use",
"-Wclippy::filter_map_next",
"-Wclippy::flat_map_option",
"-Wclippy::inefficient_to_string",
Expand All @@ -22,5 +23,6 @@ xclippy = [
"-Wclippy::needless_borrow",
"-Wclippy::checked_conversions",
"-Wrust_2018_idioms",
"-Wtrivial_numeric_casts",
"-Wunused_lifetimes",
]
2 changes: 1 addition & 1 deletion clutch/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ impl ClutchState<F, Coproc<F>> {
.ok_or_else(|| anyhow!("proof not found: {zptr_string}"))?;

let pp = public_params(self.reduction_count, true, self.lang(), &public_param_dir())?;
let result = proof.verify(&pp, &self.lang()).unwrap();
let result = proof.verify(&pp, &self.lang())?;

if result.verified {
Ok(Some(lurk_sym_ptr!(store, t)))
Expand Down
2 changes: 1 addition & 1 deletion src/cli/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl Backend {
}

fn compatible_fields(&self) -> Vec<LanguageField> {
use LanguageField::*;
use LanguageField::{Pallas, Vesta, BLS12_381};
match self {
Self::Nova => vec![Pallas, Vesta],
Self::SnarkPackPlus => vec![BLS12_381],
Expand Down
2 changes: 1 addition & 1 deletion src/lem/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use crate::circuit::gadgets::{

use crate::{
field::{FWrap, LurkField},
tag::ExprTag::*,
tag::ExprTag::{Comm, Nil, Num, Sym},
};

use super::{
Expand Down
7 changes: 6 additions & 1 deletion src/lem/eval.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
use anyhow::Result;
use once_cell::sync::OnceCell;

use crate::{field::LurkField, func, state::initial_lurk_state, tag::ContTag::*};
use crate::{
field::LurkField,
func,
state::initial_lurk_state,
tag::ContTag::{Error, Outermost, Terminal},
};

use super::{interpreter::Frame, pointers::Ptr, store::Store, Func, Tag};

Expand Down
53 changes: 26 additions & 27 deletions src/lem/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@ use std::collections::VecDeque;

use super::{path::Path, pointers::Ptr, store::Store, var_map::VarMap, Block, Ctrl, Func, Op, Tag};

use crate::{field::LurkField, num::Num, state::initial_lurk_state, tag::ExprTag::*};
use crate::{
field::LurkField,
num::Num as BaseNum,
state::initial_lurk_state,
tag::ExprTag::{Comm, Nil, Num, Sym},
};

#[derive(Clone, Debug)]
pub enum PreimageData<F: LurkField> {
Expand Down Expand Up @@ -211,8 +216,8 @@ impl Block {
let b = bindings.get(b)?;
let c = if let (Ptr::Atom(_, f), Ptr::Atom(_, g)) = (a, b) {
preimages.less_than.push(Some(PreimageData::FPair(*f, *g)));
let f = Num::Scalar(*f);
let g = Num::Scalar(*g);
let f = BaseNum::Scalar(*f);
let g = BaseNum::Scalar(*g);
let b = if f < g { F::ONE } else { F::ZERO };
Ptr::Atom(Tag::Expr(Num), b)
} else {
Expand Down Expand Up @@ -363,18 +368,15 @@ impl Block {
Ctrl::MatchTag(match_var, cases, def) => {
let ptr = bindings.get(match_var)?;
let tag = ptr.tag();
match cases.get(tag) {
Some(block) => {
path.push_tag_inplace(*tag);
block.run(input, store, bindings, preimages, path, emitted)
}
None => {
path.push_default_inplace();
match def {
Some(def) => def.run(input, store, bindings, preimages, path, emitted),
None => bail!("No match for tag {}", tag),
}
}
if let Some(block) = cases.get(tag) {
path.push_tag_inplace(*tag);
block.run(input, store, bindings, preimages, path, emitted)
} else {
path.push_default_inplace();
let Some(def) = def else {
bail!("No match for tag {}", tag)
};
def.run(input, store, bindings, preimages, path, emitted)
}
}
Ctrl::MatchSymbol(match_var, cases, def) => {
Expand All @@ -385,18 +387,15 @@ impl Block {
let Some(sym) = store.fetch_symbol(ptr) else {
bail!("Symbol bound to {match_var} wasn't interned");
};
match cases.get(&sym) {
Some(block) => {
path.push_symbol_inplace(sym);
block.run(input, store, bindings, preimages, path, emitted)
}
None => {
path.push_default_inplace();
match def {
Some(def) => def.run(input, store, bindings, preimages, path, emitted),
None => bail!("No match for symbol {sym}"),
}
}
if let Some(block) = cases.get(&sym) {
path.push_symbol_inplace(sym);
block.run(input, store, bindings, preimages, path, emitted)
} else {
path.push_default_inplace();
let Some(def) = def else {
bail!("No match for symbol {sym}")
};
def.run(input, store, bindings, preimages, path, emitted)
}
}
Ctrl::IfEq(x, y, eq_block, else_block) => {
Expand Down
2 changes: 1 addition & 1 deletion src/lem/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ mod tests {
use crate::{
lem::{Block, Ctrl, Op, Tag, Var},
state::lurk_sym,
tag::ExprTag::*,
tag::ExprTag::{Char, Num, Str},
Symbol,
};

Expand Down
20 changes: 11 additions & 9 deletions src/lem/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,16 @@ pub enum Tag {
Cont(ContTag),
}

impl From<u16> for Tag {
fn from(val: u16) -> Self {
impl TryFrom<u16> for Tag {
type Error = anyhow::Error;

fn try_from(val: u16) -> Result<Self, Self::Error> {
if let Ok(tag) = ExprTag::try_from(val) {
Tag::Expr(tag)
Ok(Tag::Expr(tag))
} else if let Ok(tag) = ContTag::try_from(val) {
Tag::Cont(tag)
Ok(Tag::Cont(tag))
} else {
panic!("Invalid u16 for Tag: {val}")
bail!("Invalid u16 for Tag: {val}")
}
}
}
Expand Down Expand Up @@ -157,7 +159,7 @@ impl Tag {

impl std::fmt::Display for Tag {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use Tag::*;
use Tag::{Cont, Expr};
match self {
Expr(tag) => write!(f, "expr.{}", tag),
Cont(tag) => write!(f, "cont.{}", tag),
Expand Down Expand Up @@ -196,8 +198,8 @@ impl Lit {
}

pub fn from_ptr<F: LurkField>(ptr: &Ptr<F>, store: &Store<F>) -> Option<Self> {
use ExprTag::*;
use Tag::*;
use ExprTag::{Num, Str, Sym};
use Tag::Expr;
match ptr.tag() {
Expr(Num) => match ptr {
Ptr::Atom(_, f) => {
Expand Down Expand Up @@ -794,7 +796,7 @@ mod tests {
/// provided expressions.
/// - `expected_slots` gives the number of expected slots for each type of hash.
fn synthesize_test_helper(func: &Func, inputs: Vec<Ptr<Fr>>, expected_num_slots: SlotsCounter) {
use crate::tag::ContTag::*;
use crate::tag::ContTag::{Error, Outermost, Terminal};
let store = &mut func.init_store();
let outermost = Ptr::null(Tag::Cont(Outermost));
let terminal = Ptr::null(Tag::Cont(Terminal));
Expand Down
6 changes: 2 additions & 4 deletions src/lem/path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,8 @@ impl Path {

/// Computes the number of different paths taken given a list of paths
pub fn num_paths_taken(paths: &[Self]) -> usize {
let mut all_paths: HashSet<Self> = HashSet::default();
paths.iter().for_each(|path| {
all_paths.insert(path.clone());
});
let mut all_paths: HashSet<&Self> = HashSet::default();
all_paths.extend(paths);
all_paths.len()
}
}
Expand Down
5 changes: 4 additions & 1 deletion src/lem/pointers.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use serde::{Deserialize, Serialize};

use crate::{field::*, tag::ExprTag::*};
use crate::{
field::*,
tag::ExprTag::{Char, Comm, Nil, Num, U64},
};

use super::Tag;

Expand Down
35 changes: 19 additions & 16 deletions src/lem/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ use crate::{
field::{FWrap, LurkField},
hash::PoseidonCache,
lem::Tag,
parser::*,
parser::{syntax, Error, Span},
state::{lurk_sym, State},
symbol::Symbol,
syntax::Syntax,
tag::ExprTag::*,
tag::ExprTag::{Char, Comm, Cons, Fun, Key, Nil, Num, Str, Sym, Thunk, U64},
uint::UInt,
};

Expand Down Expand Up @@ -279,7 +279,7 @@ impl<F: LurkField> Store<F> {
None
}
}
Ptr::Tuple2(Tag::Expr(Sym), idx) | Ptr::Tuple2(Tag::Expr(Nil), idx) => {
Ptr::Tuple2(Tag::Expr(Sym | Nil), idx) => {
let path = self.fetch_symbol_path(*idx)?;
let sym = Symbol::sym_from_vec(path);
self.ptr_symbol_cache.insert(*ptr, Box::new(sym.clone()));
Expand Down Expand Up @@ -438,7 +438,7 @@ impl<F: LurkField> Store<F> {
&mut self,
state: Rc<RefCell<State>>,
input: &'a str,
) -> Result<(Span<'a>, Ptr<F>, bool), crate::parser::Error> {
) -> Result<(Span<'a>, Ptr<F>, bool), Error> {
match preceded(syntax::parse_space, syntax::parse_maybe_meta(state, false))
.parse(input.into())
{
Expand All @@ -463,9 +463,10 @@ impl<F: LurkField> Store<F> {
pub fn hash_ptr(&self, ptr: &Ptr<F>) -> Result<ZPtr<F>> {
match ptr {
Ptr::Atom(tag, x) => Ok(ZPtr::from_parts(*tag, *x)),
Ptr::Tuple2(tag, idx) => match self.z_cache.get(ptr) {
Some(z_ptr) => Ok(*z_ptr),
None => {
Ptr::Tuple2(tag, idx) => {
if let Some(z_ptr) = self.z_cache.get(ptr) {
Ok(*z_ptr)
} else {
let Some((a, b)) = self.fetch_2_ptrs(*idx) else {
bail!("Index {idx} not found on tuple2")
};
Expand All @@ -483,10 +484,11 @@ impl<F: LurkField> Store<F> {
self.z_cache.insert(*ptr, Box::new(z_ptr));
Ok(z_ptr)
}
},
Ptr::Tuple3(tag, idx) => match self.z_cache.get(ptr) {
Some(z_ptr) => Ok(*z_ptr),
None => {
}
Ptr::Tuple3(tag, idx) => {
if let Some(z_ptr) = self.z_cache.get(ptr) {
Ok(*z_ptr)
} else {
let Some((a, b, c)) = self.fetch_3_ptrs(*idx) else {
bail!("Index {idx} not found on tuple3")
};
Expand All @@ -507,10 +509,11 @@ impl<F: LurkField> Store<F> {
self.z_cache.insert(*ptr, Box::new(z_ptr));
Ok(z_ptr)
}
},
Ptr::Tuple4(tag, idx) => match self.z_cache.get(ptr) {
Some(z_ptr) => Ok(*z_ptr),
None => {
}
Ptr::Tuple4(tag, idx) => {
if let Some(z_ptr) = self.z_cache.get(ptr) {
Ok(*z_ptr)
} else {
let Some((a, b, c, d)) = self.fetch_4_ptrs(*idx) else {
bail!("Index {idx} not found on tuple4")
};
Expand All @@ -534,7 +537,7 @@ impl<F: LurkField> Store<F> {
self.z_cache.insert(*ptr, Box::new(z_ptr));
Ok(z_ptr)
}
},
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/lem/var_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use super::Var;
/// to be more ergonomic under the assumption that a LEM must always define
/// variables before using them, so we don't expect to need some piece of
/// information from a variable that hasn't been defined.
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct VarMap<V>(HashMap<Var, V>);

impl<V> VarMap<V> {
Expand Down
2 changes: 1 addition & 1 deletion src/lem/zstore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use super::{
store::Store,
};

#[derive(Default, Serialize, Deserialize)]
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct ZStore<F: LurkField> {
dag: BTreeMap<ZPtr<F>, ZChildren<F>>,
comms: BTreeMap<FWrap<F>, (F, ZPtr<F>)>,
Expand Down
19 changes: 15 additions & 4 deletions src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,10 @@ impl<F: LurkField> Store<F> {
}

pub fn fetch_cont(&self, ptr: &ContPtr<F>) -> Option<Continuation<F>> {
use ContTag::*;
use ContTag::{
Binop, Binop2, Call, Call0, Call2, Dummy, Emit, Error, If, Let, LetRec, Lookup,
Outermost, Tail, Terminal, Unop,
};
match ptr.tag {
Outermost => Some(Continuation::Outermost),
Call0 => self
Expand Down Expand Up @@ -1414,7 +1417,9 @@ impl<F: LurkField> Store<F> {
if let Some(ptr) = self.fetch_z_expr_ptr(z_ptr) {
Some(ptr)
} else {
use ZExpr::*;
use ZExpr::{
Char, Comm, Cons, EmptyStr, Fun, Key, Nil, Num, RootSym, Str, Sym, Thunk, UInt,
};
match (z_ptr.tag(), z_store.get_expr(z_ptr)) {
(ExprTag::Nil, Some(Nil)) => {
let ptr = lurk_sym_ptr!(self, nil);
Expand Down Expand Up @@ -1526,7 +1531,10 @@ impl<F: LurkField> Store<F> {
z_ptr: &ZContPtr<F>,
z_store: &ZStore<F>,
) -> Option<ContPtr<F>> {
use ZCont::*;
use ZCont::{
Binop, Binop2, Call, Call0, Call2, Dummy, Emit, Error, If, Let, LetRec, Lookup,
Outermost, Tail, Terminal, Unop,
};
let tag: ContTag = z_ptr.tag();

if let Some(cont) = z_store.get_cont(z_ptr) {
Expand Down Expand Up @@ -1892,7 +1900,10 @@ pub mod test {

#[test]
fn cont_tag_vals() {
use super::ContTag::*;
use super::ContTag::{
Binop, Binop2, Call, Call0, Call2, Dummy, Emit, Error, If, Let, LetRec, Lookup,
Outermost, Tail, Terminal, Unop,
};

assert_eq!(0b0001_0000_0000_0000, Outermost as u16);
assert_eq!(0b0001_0000_0000_0001, Call0 as u16);
Expand Down
8 changes: 4 additions & 4 deletions src/syntax_macros.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#[macro_export]
macro_rules! num {
($f:ty, $i:literal) => {
$crate::syntax::Syntax::<$f>::Num(Pos::No, ($i as u64).into())
$crate::syntax::Syntax::<$f>::Num(Pos::No, ($i).into())
};
($i:literal) => {
$crate::syntax::Syntax::Num(Pos::No, ($i as u64).into())
$crate::syntax::Syntax::Num(Pos::No, ($i).into())
};
($i:expr) => {
$crate::syntax::Syntax::Num(Pos::No, $i)
Expand All @@ -14,10 +14,10 @@ macro_rules! num {
#[macro_export]
macro_rules! uint {
($f:ty, $i:literal) => {
$crate::syntax::Syntax::<$f>::UInt(Pos::No, $crate::uint::UInt::U64($i as u64))
$crate::syntax::Syntax::<$f>::UInt(Pos::No, $crate::uint::UInt::U64($i))
};
($i:literal) => {
$crate::syntax::Syntax::UInt(Pos::No, $crate::uint::UInt::U64($i as u64))
$crate::syntax::Syntax::UInt(Pos::No, $crate::uint::UInt::U64($i))
};
}

Expand Down
Loading
Loading