diff --git a/.gitignore b/.gitignore index 2b18eda..336e7d6 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ config/local.toml config/prod.toml test-data/user-data*/ my_permanent_leveldb/ -*level_db*/ \ No newline at end of file +*level_db*/ +*.json \ No newline at end of file diff --git a/README.md b/README.md index 6fa580d..6aeaaed 100644 --- a/README.md +++ b/README.md @@ -35,14 +35,27 @@ output_proof_path="global_proof.json" cargo run --release --package zk-por-cli --bin zk-por-cli prove --cfg-path ${cfg_dir_path} --output-path ${output_proof_path} ``` + +- get-merkle-proof +``` +cargo run --release --package zk-por-cli --bin zk-por-cli get-merkle-proof --account-path account.json --output-path merkle_proof.json --cfg-path config +``` + - verify ``` global_root_path="global_proof.json" # optional. If not provided, will skip verifying the inclusion -arg_inclusion_proof_path="--inclusion-proof-path inclusion_proof.json" +inclusion_proof_path="merkle_proof.json" -cargo run --features zk-por-core/verifier --release --package zk-por-cli --bin zk-por-cli verify --global-proof-path ${global_root_path} ${arg_inclusion_proof_path} +cargo run --features zk-por-core/verifier --release --package zk-por-cli --bin zk-por-cli verify --global-proof-path ${global_root_path} --inclusion-proof-path ${inclusion_proof_path} --root 11288199779358641579,2344540219612146741,6809171731163302525,17936043556479519168 +``` + +## cli +``` +./target/release/zk-por-cli --help +./target/release/zk-por-cli prove --cfg-path ${cfg_dir_path} --output-path ${output_proof_path} +./target/release/zk-por-cli get-merkle-proof --account-path account.json --output-path merkle_proof.json --cfg-path config ``` ## code coverage diff --git a/crates/zk-por-cli/src/lib.rs b/crates/zk-por-cli/src/lib.rs index 7d62afa..9794d33 100644 --- a/crates/zk-por-cli/src/lib.rs +++ b/crates/zk-por-cli/src/lib.rs @@ -1,3 +1,4 @@ mod constant; +pub mod merkle_proof; pub mod prover; pub mod verifier; diff --git 
a/crates/zk-por-cli/src/main.rs b/crates/zk-por-cli/src/main.rs index aae41d0..8e1e26a 100644 --- a/crates/zk-por-cli/src/main.rs +++ b/crates/zk-por-cli/src/main.rs @@ -1,8 +1,8 @@ use std::{path::PathBuf, str::FromStr}; use clap::{Parser, Subcommand}; -use zk_por_cli::{prover::prove, verifier::verify}; -use zk_por_core::error::PoRError; +use zk_por_cli::{merkle_proof::get_merkle_proof, prover::prove, verifier::verify}; +use zk_por_core::{config::ProverConfig, error::PoRError}; #[derive(Parser)] #[command(version, about, long_about = None)] @@ -25,13 +25,19 @@ pub enum ZkPorCommitCommands { }, GetMerkleProof { #[arg(short, long)] - user_id: String, + cfg_path: String, // path to config file + #[arg(short, long)] + account_path: String, + #[arg(short, long)] + output_path: String, // path to output file }, Verify { #[arg(short, long)] global_proof_path: String, #[arg(short, long)] inclusion_proof_path: Option, + #[arg(short, long)] + root: Option, }, } @@ -45,16 +51,17 @@ impl Execute for ZkPorCommitCommands { let output_file = PathBuf::from_str(&output_path).unwrap(); prove(prover_cfg, output_file) } - ZkPorCommitCommands::GetMerkleProof { user_id } => { - // TODO: implement this - _ = user_id; - Ok(()) + ZkPorCommitCommands::GetMerkleProof { cfg_path, account_path, output_path } => { + let cfg = zk_por_core::config::ProverConfig::load(&cfg_path) + .map_err(|e| PoRError::ConfigError(e))?; + let prover_cfg: ProverConfig = cfg.try_deserialize().unwrap(); + get_merkle_proof(account_path.to_string(), prover_cfg, output_path.to_string()) } - ZkPorCommitCommands::Verify { global_proof_path, inclusion_proof_path } => { + ZkPorCommitCommands::Verify { global_proof_path, inclusion_proof_path, root } => { let global_proof_path = PathBuf::from_str(&global_proof_path).unwrap(); let inclusion_proof_path = inclusion_proof_path.as_ref().map(|p| PathBuf::from_str(&p).unwrap()); - verify(global_proof_path, inclusion_proof_path) + verify(global_proof_path, inclusion_proof_path, 
root.clone()) } } } diff --git a/crates/zk-por-cli/src/merkle_proof.rs b/crates/zk-por-cli/src/merkle_proof.rs new file mode 100644 index 0000000..015b249 --- /dev/null +++ b/crates/zk-por-cli/src/merkle_proof.rs @@ -0,0 +1,67 @@ +use std::{fs::File, io::Write, str::FromStr}; + +use serde_json::json; +use zk_por_core::{ + account::Account, + config::ProverConfig, + database::{DataBase, DbOption}, + error::PoRError, + global::GlobalConfig, + merkle_proof::MerkleProof, + parser::{AccountParser, FileAccountReader, FileManager, FilesCfg}, +}; + +use crate::constant::RECURSION_BRANCHOUT_NUM; + +pub fn get_merkle_proof( + account_path: String, + cfg: ProverConfig, + output_path: String, +) -> Result<(), PoRError> { + let database = DataBase::new(DbOption { + user_map_dir: cfg.db.level_db_user_path.to_string(), + gmst_dir: cfg.db.level_db_gmst_path.to_string(), + }); + + let batch_size = cfg.prover.batch_size as usize; + let token_num = cfg.prover.num_of_tokens as usize; + + // the path to dump the final generated proof + let file_manager = FileManager {}; + let account_parser = FileAccountReader::new( + FilesCfg { + dir: std::path::PathBuf::from_str(&cfg.prover.user_data_path).unwrap(), + batch_size: cfg.prover.batch_size, + num_of_tokens: cfg.prover.num_of_tokens, + }, + &file_manager, + ); + account_parser.log_state(); + // let mut account_parser: Box = Box::new(parser); + + let batch_num = account_parser.total_num_of_users().div_ceil(batch_size); + + let global_cfg = GlobalConfig { + num_of_tokens: token_num, + num_of_batches: batch_num, + batch_size: batch_size, + recursion_branchout_num: RECURSION_BRANCHOUT_NUM, + }; + + // the account json format is a map of tokens to token values, whereas the format in the merkle proof is given by a vec of token values. 
+ let account = Account::new_from_file_path(account_path); + + if account.is_err() { + return Err(account.unwrap_err()); + } + + let merkle_proof = MerkleProof::new_from_account(&account.unwrap(), &database, &global_cfg) + .expect("Unable to generate merkle proof"); + + let mut file = File::create(output_path.clone()) + .expect(format!("fail to create proof file at {:#?}", output_path).as_str()); + file.write_all(json!(merkle_proof).to_string().as_bytes()) + .expect("fail to write proof to file"); + + Ok(()) +} diff --git a/crates/zk-por-cli/src/prover.rs b/crates/zk-por-cli/src/prover.rs index 88352fb..4f4183d 100644 --- a/crates/zk-por-cli/src/prover.rs +++ b/crates/zk-por-cli/src/prover.rs @@ -110,7 +110,7 @@ pub fn prove(cfg: ProverConfig, proof_output_path: PathBuf) -> Result<(), PoRErr accounts.resize(account_num + pad_num, Account::get_empty_account(token_num)); } - assert_eq!(account_num % batch_size, 0); + assert_eq!(accounts.len() % batch_size, 0); tracing::debug!( "parse {} times, with number of accounts {}, number of batches {}", @@ -307,7 +307,10 @@ pub fn prove(cfg: ProverConfig, proof_output_path: PathBuf) -> Result<(), PoRErr // persist gmst to database let global_mst = GLOBAL_MST.get().unwrap(); + let _g = global_mst.read().expect("unable to get a lock"); + let root_hash = _g.get_root().expect("no root"); + tracing::info!("root hash is {:?}", root_hash); let start = std::time::Instant::now(); _g.persist(&mut database); tracing::info!("persist gmst to db in {:?}", start.elapsed()); diff --git a/crates/zk-por-cli/src/verifier.rs b/crates/zk-por-cli/src/verifier.rs index 8da3bf4..d716c2a 100644 --- a/crates/zk-por-cli/src/verifier.rs +++ b/crates/zk-por-cli/src/verifier.rs @@ -7,12 +7,15 @@ use zk_por_core::{ circuit_config::{get_recursive_circuit_configs, STANDARD_CONFIG}, circuit_registry::registry::CircuitRegistry, error::PoRError, + merkle_proof::MerkleProof, + util::get_hash_from_hash_string, Proof, }; pub fn verify( global_proof_path: PathBuf, 
merkle_inclusion_path: Option, + root: Option, ) -> Result<(), PoRError> { let proof_file = File::open(&global_proof_path).unwrap(); let reader = std::io::BufReader::new(proof_file); @@ -63,10 +66,28 @@ pub fn verify( } println!("successfully verify the global proof for round {}", round_num); - // TODO: verify the inclusion proof if let Some(merkle_inclusion_path) = merkle_inclusion_path { - _ = merkle_inclusion_path; - println!("successfully verify the inclusion proof for user for round {}", round_num); + let merkle_path = File::open(&merkle_inclusion_path).unwrap(); + let reader = std::io::BufReader::new(merkle_path); + + // Parse the JSON as Proof + let proof: MerkleProof = from_reader(reader).unwrap(); + + if root.is_none() { + return Err(PoRError::InvalidParameter( + "Require root for merkle proof verification".to_string(), + )); + } + + let res = proof.verify_merkle_proof(get_hash_from_hash_string(root.unwrap())); + + if res.is_err() { + let res_err = res.unwrap_err(); + return Err(res_err); + } else { + println!("successfully verify the inclusion proof for user for round {}", round_num); + return Ok(()); + } } Ok(()) diff --git a/crates/zk-por-core/src/account.rs b/crates/zk-por-core/src/account.rs index 5d98838..01462c9 100644 --- a/crates/zk-por-core/src/account.rs +++ b/crates/zk-por-core/src/account.rs @@ -3,15 +3,20 @@ use plonky2::{ plonk::config::Hasher, }; use plonky2_field::types::Field; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::{collections::BTreeMap, fs::File, io::BufReader}; use crate::{ database::{DataBase, UserId}, + error::PoRError, + parser::parse_account_state, types::F, }; use rand::Rng; /// A struct representing a users account. It represents their equity and debt as a Vector of goldilocks field elements. 
-#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Account { pub id: String, // 256 bit hex string pub equity: Vec, @@ -34,6 +39,14 @@ impl Account { hash } + pub fn get_empty_account_with_user_id(user_id: String, num_of_tokens: usize) -> Account { + Self { + id: user_id, + equity: vec![F::default(); num_of_tokens], + debt: vec![F::default(); num_of_tokens], + } + } + pub fn get_empty_account(num_of_tokens: usize) -> Account { Self { id: "0".repeat(64), @@ -59,6 +72,27 @@ impl Account { .map(|seg| F::from_canonical_u64(u64::from_str_radix(seg, 16).unwrap())) .collect::>() } + + /// Get a new account from a file path + pub fn new_from_file_path(path: String) -> Result { + let file_res = File::open(path); + + if file_res.is_err() { + return Err(PoRError::InvalidParameter("Invalid account json file path".to_string())); + } + + let reader = BufReader::new(file_res.unwrap()); + // Deserialize the json data to a struct + let account_map_res: Result, _> = serde_json::from_reader(reader); + + if account_map_res.is_err() { + return Err(PoRError::InvalidParameter("Invalid account json".to_string())); + } + + let account = parse_account_state(&account_map_res.unwrap()); + + Ok(account) + } } pub fn persist_account_id_to_gmst_pos( @@ -70,12 +104,9 @@ pub fn persist_account_id_to_gmst_pos( .iter() .enumerate() .map(|(i, acct)| { - let hex_decode = hex::decode(&acct.id).unwrap(); - assert_eq!(hex_decode.len(), 32); - let mut array = [0u8; 32]; - array.copy_from_slice(&hex_decode); - - (UserId(array), (i + start_idx) as u32) + let user_id = UserId::from_hex_string(acct.id.to_string()).unwrap(); + // tracing::debug!("persist account {:?} with index: {:?}", acct.id, i + start_idx); + (user_id, (i + start_idx) as u32) }) .collect::>(); db.add_batch_users(user_batch); @@ -105,6 +136,10 @@ pub fn gen_accounts_with_random_data(num_accounts: usize, num_assets: usize) -> } pub fn gen_empty_accounts(batch_size: usize, num_assets: usize) -> Vec { - let 
accounts = vec![Account::get_empty_account(num_assets); batch_size]; + let accounts = + vec![ + Account::get_empty_account_with_user_id(UserId::rand().to_string(), num_assets); + batch_size + ]; accounts } diff --git a/crates/zk-por-core/src/config.rs b/crates/zk-por-core/src/config.rs index 5180c70..a40e3e9 100644 --- a/crates/zk-por-core/src/config.rs +++ b/crates/zk-por-core/src/config.rs @@ -40,6 +40,22 @@ pub struct ConfigDb { pub level_db_gmst_path: String, } +impl ConfigDb { + pub fn load(dir: &str) -> Result { + let env = std::env::var("ENV").unwrap_or("default".into()); + Config::builder() + // .add_source(File::with_name(&format!("{}/default", dir))) + .add_source(File::with_name(&format!("{}/{}", dir, env)).required(false)) + .add_source(File::with_name(&format!("{}/local", dir)).required(false)) + .add_source(config::Environment::with_prefix("ZKPOR")) + .build() + } + pub fn try_new() -> Result { + let config = Self::load("config")?; + config.try_deserialize() + } +} + #[derive(Debug, Clone, Deserialize)] pub struct ProverConfig { pub log: ConfigLog, diff --git a/crates/zk-por-core/src/database.rs b/crates/zk-por-core/src/database.rs index d0fdfe4..da9fcc9 100644 --- a/crates/zk-por-core/src/database.rs +++ b/crates/zk-por-core/src/database.rs @@ -1,10 +1,11 @@ use std::str::FromStr; +use hex::ToHex; use plonky2::{hash::hash_types::HashOut, plonk::config::GenericHashOut}; use rand::Rng; use zk_por_db::LevelDb; -use crate::types::F; +use crate::{error::PoRError, types::F}; #[derive(Debug, Clone, Copy)] pub struct UserId(pub [u8; 32]); @@ -16,6 +17,29 @@ impl UserId { rng.fill(&mut bytes); Self(bytes) } + + pub fn to_string(&self) -> String { + self.0.encode_hex() + } + + pub fn from_hex_string(hex_str: String) -> Result { + if hex_str.len() != 64 { + tracing::error!("User Id: {:?} is not a valid id, length is not 256 bits", hex_str); + return Err(PoRError::InvalidParameter(hex_str)); + } + + let decode_res = hex::decode(hex_str.clone()); + + if 
decode_res.is_err() { + tracing::error!("User Id: {:?} is not a valid id", hex_str); + return Err(PoRError::InvalidParameter(hex_str)); + } + + let mut arr = [0u8; 32]; + arr.copy_from_slice(&decode_res.unwrap()); + + Ok(UserId { 0: arr }) + } } impl db_key::Key for UserId { diff --git a/crates/zk-por-core/src/error.rs b/crates/zk-por-core/src/error.rs index 24e3c21..5f50a49 100644 --- a/crates/zk-por-core/src/error.rs +++ b/crates/zk-por-core/src/error.rs @@ -6,6 +6,9 @@ pub enum PoRError { #[error("Proof is not valid")] InvalidProof, + #[error("Merkle proof is not valid")] + InvalidMerkleProof, + #[error("config error: {0}")] ConfigError(#[from] ConfigError), @@ -20,4 +23,7 @@ pub enum PoRError { #[error("The verification circuit digest does not match the prover. ")] CircuitDigestMismatch, + + #[error("User is not valid")] + InvalidUser, } diff --git a/crates/zk-por-core/src/global.rs b/crates/zk-por-core/src/global.rs index 1a198d4..0346cac 100644 --- a/crates/zk-por-core/src/global.rs +++ b/crates/zk-por-core/src/global.rs @@ -10,7 +10,7 @@ use plonky2::{hash::hash_types::HashOut, util::log2_strict}; use std::{ops::Div, sync::RwLock}; use tracing::debug; -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] pub struct GlobalConfig { pub num_of_tokens: usize, pub num_of_batches: usize, @@ -21,7 +21,7 @@ pub struct GlobalConfig { pub static GLOBAL_MST: OnceCell> = OnceCell::new(); pub struct GlobalMst { - inner: Vec>, + pub inner: Vec>, top_recursion_level: usize, pub cfg: GlobalConfig, } @@ -34,7 +34,7 @@ impl GlobalMst { let mst_vec = vec![HashOut::default(); 0]; // will resize later let mut mst = Self { inner: mst_vec, top_recursion_level: top_level, cfg: cfg }; // the number of hash is one smaller to the index of the root node of the last recursion level. 
- let root_node_idx = mst.get_recursive_global_index(top_level, 0); + let root_node_idx = GlobalMst::get_recursive_global_index(&cfg, top_level, 0); let tree_size = root_node_idx + 1; mst.inner.resize(tree_size, HashOut::default()); mst @@ -44,28 +44,36 @@ impl GlobalMst { self.inner.len() } - pub fn get_num_of_leaves(&self) -> usize { - self.cfg.batch_size * self.cfg.num_of_batches + pub fn get_num_of_leaves(cfg: &GlobalConfig) -> usize { + cfg.batch_size * cfg.num_of_batches } pub fn get_nodes(&self, range: std::ops::Range) -> &[HashOut] { &self.inner[range] } + pub fn get_root(&self) -> Option<&HashOut> { + self.inner.last() + } + /// convert a mst node inner index to global index in gmst. /// For a mst, the inner index is level-by-level, e.g., /// 14 /// 12 13 /// 8-9, 10-11 /// 0 - 3, 4 - 7 - pub fn get_batch_tree_global_index(&self, batch_idx: usize, inner_tree_idx: usize) -> usize { - let batch_size = self.cfg.batch_size; + pub fn get_batch_tree_global_index( + cfg: &GlobalConfig, + batch_idx: usize, + inner_tree_idx: usize, + ) -> usize { + let batch_size = cfg.batch_size; let tree_depth = log2_strict(batch_size); let batch_tree_level = get_node_level(batch_size, inner_tree_idx); let level_from_bottom = tree_depth - batch_tree_level; - let numeritor = 2 * batch_size * self.cfg.num_of_batches; + let numeritor = 2 * batch_size * cfg.num_of_batches; let global_tree_vertical_offset = numeritor - numeritor.div(1 << level_from_bottom); // the gmst idx of the first node at {level_from_bottom} level let level_node_counts = batch_size.div(1 << level_from_bottom); @@ -83,21 +91,21 @@ impl GlobalMst { // mst root node at level 0, pub fn get_recursive_global_index( - &self, + cfg: &GlobalConfig, recursive_level: usize, inner_level_idx: usize, ) -> usize { - let mst_node_num = 2 * self.cfg.batch_size - 1; - let batch_num = self.cfg.num_of_batches; - let branchout_num = self.cfg.recursion_branchout_num; + let mst_node_num = 2 * cfg.batch_size - 1; + let batch_num = 
cfg.num_of_batches; + let branchout_num = cfg.recursion_branchout_num; if recursive_level == 0 { // level of merkle sum tree root - if inner_level_idx < self.cfg.num_of_batches { + if inner_level_idx < cfg.num_of_batches { // the global index of the root of the batch tree let mst_root_idx = mst_node_num - 1; - return self.get_batch_tree_global_index(inner_level_idx, mst_root_idx); + return GlobalMst::get_batch_tree_global_index(cfg, inner_level_idx, mst_root_idx); } else { - return batch_num * mst_node_num + (inner_level_idx - self.cfg.num_of_batches); + return batch_num * mst_node_num + (inner_level_idx - cfg.num_of_batches); } } @@ -115,9 +123,9 @@ impl GlobalMst { let mut level = recursive_level; while level > 1 { - let mut this_level_node_num = last_level_node_num / self.cfg.recursion_branchout_num; + let mut this_level_node_num = last_level_node_num / cfg.recursion_branchout_num; this_level_node_num = - pad_to_multiple_of(this_level_node_num, self.cfg.recursion_branchout_num); + pad_to_multiple_of(this_level_node_num, cfg.recursion_branchout_num); recursive_offset += this_level_node_num; @@ -132,14 +140,18 @@ impl GlobalMst { /// `batch_idx`: index indicating the batch index /// `i`: the sub batch tree index; e.g the batch tree is of size 1<<10; i \in [0, 2*batch_size) pub fn set_batch_hash(&mut self, batch_idx: usize, i: usize, hash: HashOut) { - let global_mst_idx = self.get_batch_tree_global_index(batch_idx, i); + let global_mst_idx = GlobalMst::get_batch_tree_global_index(&self.cfg, batch_idx, i); self.inner[global_mst_idx] = hash; } pub fn get_batch_root_hash(&self, batch_idx: usize) -> HashOut { debug!("get batch root hash, batch_idx: {:?}", batch_idx); assert!(batch_idx < self.cfg.num_of_batches); - let root_idx = self.get_batch_tree_global_index(batch_idx, 2 * self.cfg.batch_size - 2); + let root_idx = GlobalMst::get_batch_tree_global_index( + &self.cfg, + batch_idx, + 2 * self.cfg.batch_size - 2, + ); self.inner[root_idx] } @@ -149,7 +161,7 @@ impl 
GlobalMst { "set_recursive_hash, recursive_level: {:?}, index: {:?}, hash: {:?}", recursive_level, index, hash ); - let idx = self.get_recursive_global_index(recursive_level, index); + let idx = GlobalMst::get_recursive_global_index(&self.cfg, recursive_level, index); self.inner[idx] = hash; } @@ -164,11 +176,18 @@ impl GlobalMst { let inner_left_child_idx = 2 * (inner_tree_idx - leaf_size); let inner_right_child_idx = 2 * (inner_tree_idx - leaf_size) + 1; - let global_parent_idx = self.get_batch_tree_global_index(tree_idx, inner_tree_idx); - let global_left_child_idx = - self.get_batch_tree_global_index(tree_idx, inner_left_child_idx); - let global_right_child_idx = - self.get_batch_tree_global_index(tree_idx, inner_right_child_idx); + let global_parent_idx = + GlobalMst::get_batch_tree_global_index(&self.cfg, tree_idx, inner_tree_idx); + let global_left_child_idx = GlobalMst::get_batch_tree_global_index( + &self.cfg, + tree_idx, + inner_left_child_idx, + ); + let global_right_child_idx = GlobalMst::get_batch_tree_global_index( + &self.cfg, + tree_idx, + inner_right_child_idx, + ); visited_global_idx[global_left_child_idx] = true; visited_global_idx[global_right_child_idx] = true; @@ -191,11 +210,12 @@ impl GlobalMst { let inner_child_indexes = (0..branchout_num) .map(|i| inner_idx * branchout_num + i) .collect::>(); - let global_idx = self.get_recursive_global_index(level, inner_idx); + let global_idx = GlobalMst::get_recursive_global_index(&self.cfg, level, inner_idx); let global_child_indexes = inner_child_indexes .iter() .map(|&i| { - let child_global_idx = self.get_recursive_global_index(level - 1, i); + let child_global_idx = + GlobalMst::get_recursive_global_index(&self.cfg, level - 1, i); visited_global_idx[child_global_idx] = true; child_global_idx }) @@ -215,7 +235,8 @@ impl GlobalMst { last_level_node_count = pad_to_multiple_of(this_level_node_count, branchout_num); } } - let global_root_idx = self.get_recursive_global_index(self.top_recursion_level, 
0); + let global_root_idx = + GlobalMst::get_recursive_global_index(&self.cfg, self.top_recursion_level, 0); visited_global_idx[global_root_idx] = true; visited_global_idx.iter().all(|&v| v) @@ -232,7 +253,7 @@ impl GlobalMst { let batches = (i..end) .into_iter() .enumerate() - .map(|(chunk_idx, j)| ((i + j).try_into().unwrap(), nodes[chunk_idx])) + .map(|(chunk_idx, j)| ((j).try_into().unwrap(), nodes[chunk_idx])) .collect::)>>(); db.add_batch_gmst_nodes(batches); i += chunk_size; @@ -274,20 +295,20 @@ mod test { assert_eq!(total_len, 97); assert_eq!(gmst.top_recursion_level, 2); - assert_eq!(gmst.get_batch_tree_global_index(0, 1), 1); - assert_eq!(gmst.get_batch_tree_global_index(0, 14), 84); - assert_eq!(gmst.get_batch_tree_global_index(1, 1), 9); - assert_eq!(gmst.get_batch_tree_global_index(1, 14), 85); - assert_eq!(gmst.get_batch_tree_global_index(5, 7), 47); - assert_eq!(gmst.get_batch_tree_global_index(5, 14), 89); - - assert_eq!(gmst.get_recursive_global_index(0, 7), 91); - assert_eq!(gmst.get_recursive_global_index(0, 1), 85); - assert_eq!(gmst.get_recursive_global_index(1, 0), 92); - assert_eq!(gmst.get_recursive_global_index(1, 1), 93); - assert_eq!(gmst.get_recursive_global_index(1, 2), 94); - assert_eq!(gmst.get_recursive_global_index(1, 3), 95); - assert_eq!(gmst.get_recursive_global_index(2, 0), 96); + assert_eq!(GlobalMst::get_batch_tree_global_index(&gmst.cfg, 0, 1), 1); + assert_eq!(GlobalMst::get_batch_tree_global_index(&gmst.cfg, 0, 14), 84); + assert_eq!(GlobalMst::get_batch_tree_global_index(&gmst.cfg, 1, 1), 9); + assert_eq!(GlobalMst::get_batch_tree_global_index(&gmst.cfg, 1, 14), 85); + assert_eq!(GlobalMst::get_batch_tree_global_index(&gmst.cfg, 5, 7), 47); + assert_eq!(GlobalMst::get_batch_tree_global_index(&gmst.cfg, 5, 14), 89); + + assert_eq!(GlobalMst::get_recursive_global_index(&gmst.cfg, 0, 7), 91); + assert_eq!(GlobalMst::get_recursive_global_index(&gmst.cfg, 0, 1), 85); + 
assert_eq!(GlobalMst::get_recursive_global_index(&gmst.cfg, 1, 0), 92); + assert_eq!(GlobalMst::get_recursive_global_index(&gmst.cfg, 1, 1), 93); + assert_eq!(GlobalMst::get_recursive_global_index(&gmst.cfg, 1, 2), 94); + assert_eq!(GlobalMst::get_recursive_global_index(&gmst.cfg, 1, 3), 95); + assert_eq!(GlobalMst::get_recursive_global_index(&gmst.cfg, 2, 0), 96); } #[test] @@ -331,8 +352,11 @@ mod test { for inner_idx in 0..this_level_node_count { let children_hashes = (0..branchout_num) .map(|i| { - let child_global_idx = gmst - .get_recursive_global_index(level - 1, inner_idx * branchout_num + i); + let child_global_idx = GlobalMst::get_recursive_global_index( + &gmst.cfg, + level - 1, + inner_idx * branchout_num + i, + ); gmst.inner[child_global_idx] }) .collect::>>(); diff --git a/crates/zk-por-core/src/merkle_proof.rs b/crates/zk-por-core/src/merkle_proof.rs index cbb5561..049aff3 100644 --- a/crates/zk-por-core/src/merkle_proof.rs +++ b/crates/zk-por-core/src/merkle_proof.rs @@ -1,15 +1,37 @@ +use itertools::Itertools; +use plonky2::{ + hash::{hash_types::HashOut, poseidon::PoseidonHash}, + plonk::config::Hasher, +}; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; +use serde::{Deserialize, Serialize}; + +use crate::{ + account::Account, + database::{DataBase, UserId}, + error::PoRError, + global::{GlobalConfig, GlobalMst}, + merkle_sum_prover::utils::hash_2_subhashes, + types::{D, F}, +}; + +/// We use this wrapper struct for the left and right indexes of our recursive siblings. This is needed so a user knows the position of +/// their own hash when hashing. 
+#[derive(Debug, Clone, PartialEq)] +pub struct RecursiveIndex { + left_indexes: Vec, + right_indexes: Vec, +} -use crate::global::GlobalMst; - -#[derive(Debug, Clone)] +/// Indexes for a given users merkle proof of inclusion siblings in the Global Merkle Sum Tree +#[derive(Debug, Clone, PartialEq)] pub struct MerkleProofIndex { pub sum_tree_siblings: Vec, - pub recursive_tree_siblings: Vec>, + pub recursive_tree_siblings: Vec, } impl MerkleProofIndex { - pub fn new_from_user_index(user_index: usize, global_mst: &GlobalMst) -> MerkleProofIndex { + pub fn new_from_user_index(user_index: usize, global_mst: &GlobalConfig) -> MerkleProofIndex { let sum_tree_siblings = get_mst_siblings_index(user_index, global_mst); let recursive_tree_siblings = get_recursive_siblings_index(user_index, global_mst); @@ -19,17 +41,17 @@ impl MerkleProofIndex { /// Get the siblings index for the merkle proof of inclusion given a leaf index of a binary merkle sum tree. /// We get the parent index of a leaf using the formula: parent = index / 2 + num_leaves -pub fn get_mst_siblings_index(global_index: usize, global_mst: &GlobalMst) -> Vec { +pub fn get_mst_siblings_index(global_leaf_index: usize, cfg: &GlobalConfig) -> Vec { // Make sure our global index is within the number of leaves - assert!(global_index < global_mst.get_num_of_leaves()); + assert!(global_leaf_index < GlobalMst::get_num_of_leaves(cfg)); - let batch_idx = global_index / global_mst.cfg.batch_size; + let batch_id = global_leaf_index / cfg.batch_size; let mut siblings = Vec::new(); // This is the index in the local mst tree - let mut local_index = global_index % global_mst.cfg.batch_size; + let mut local_index = global_leaf_index % cfg.batch_size; - while local_index < (global_mst.cfg.batch_size * 2 - 2) { + while local_index < (cfg.batch_size * 2 - 2) { if local_index % 2 == 1 { let sibling_index = local_index - 1; siblings.push(sibling_index); @@ -38,78 +60,230 @@ pub fn get_mst_siblings_index(global_index: usize, 
global_mst: &GlobalMst) -> Ve siblings.push(sibling_index); } - let parent = local_index / 2 + global_mst.cfg.batch_size; - local_index = parent; + let local_parent_index = local_index / 2 + cfg.batch_size; + local_index = local_parent_index; } - siblings.par_iter().map(|x| global_mst.get_batch_tree_global_index(batch_idx, *x)).collect() + siblings.par_iter().map(|x| GlobalMst::get_batch_tree_global_index(cfg, batch_id, *x)).collect() } /// Gets the recursive siblings indexes (recursive tree is n-ary tree) as a Vec of vecs, each inner vec is one layer of siblings. pub fn get_recursive_siblings_index( global_index: usize, - global_mst: &GlobalMst, -) -> Vec> { + cfg: &GlobalConfig, +) -> Vec { // Make sure our global index is within the number of leaves - assert!(global_index < global_mst.get_num_of_leaves()); + assert!(global_index < GlobalMst::get_num_of_leaves(cfg)); let mut siblings = Vec::new(); - let local_mst_root_index = global_mst.cfg.batch_size * 2 - 2; - let mst_batch_idx = global_index / global_mst.cfg.batch_size; + let local_mst_root_index = cfg.batch_size * 2 - 2; + let mst_batch_idx = global_index / cfg.batch_size; let this_mst_root_idx = - global_mst.get_batch_tree_global_index(mst_batch_idx, local_mst_root_index); + GlobalMst::get_batch_tree_global_index(cfg, mst_batch_idx, local_mst_root_index); - let first_mst_root_idx = global_mst.get_batch_tree_global_index(0, local_mst_root_index); + let first_mst_root_idx = GlobalMst::get_batch_tree_global_index(cfg, 0, local_mst_root_index); assert!(this_mst_root_idx >= first_mst_root_idx); let this_mst_root_offset = this_mst_root_idx - first_mst_root_idx; - let mut recursive_idx = this_mst_root_offset / global_mst.cfg.recursion_branchout_num; - let mut recursive_offset = this_mst_root_offset % global_mst.cfg.recursion_branchout_num; + let mut recursive_idx = this_mst_root_offset / cfg.recursion_branchout_num; + let mut recursive_offset = this_mst_root_offset % cfg.recursion_branchout_num; - let layers = 
(global_mst.cfg.num_of_batches.next_power_of_two() as f64) - .log(global_mst.cfg.recursion_branchout_num as f64) + let layers = (cfg.num_of_batches.next_power_of_two() as f64) + .log(cfg.recursion_branchout_num as f64) .ceil() as usize; for i in 0..layers { - let mut layer = Vec::new(); + let mut left_layer = Vec::new(); + let mut right_layer = Vec::new(); if i == 0 { - for j in 0..global_mst.cfg.recursion_branchout_num { - if j != recursive_offset { - let index = first_mst_root_idx - + (global_mst.cfg.recursion_branchout_num * recursive_idx) - + j; - layer.push(index); + for j in 0..cfg.recursion_branchout_num { + if j < recursive_offset { + let index = + first_mst_root_idx + (cfg.recursion_branchout_num * recursive_idx) + j; + left_layer.push(index); + } + + if j > recursive_offset { + let index = + first_mst_root_idx + (cfg.recursion_branchout_num * recursive_idx) + j; + right_layer.push(index); } } } else { - for j in 0..global_mst.cfg.recursion_branchout_num { - if j != recursive_offset { - let index = global_mst.get_recursive_global_index( + for j in 0..cfg.recursion_branchout_num { + if j < recursive_offset { + let index = GlobalMst::get_recursive_global_index( + cfg, i, - recursive_idx * global_mst.cfg.recursion_branchout_num + j, + recursive_idx * cfg.recursion_branchout_num + j, ); - layer.push(index); + left_layer.push(index); + } + + if j > recursive_offset { + let index = GlobalMst::get_recursive_global_index( + cfg, + i, + recursive_idx * cfg.recursion_branchout_num + j, + ); + right_layer.push(index); } } } - siblings.push(layer); + siblings.push(RecursiveIndex { left_indexes: left_layer, right_indexes: right_layer }); - recursive_offset = recursive_idx % global_mst.cfg.recursion_branchout_num; - recursive_idx = recursive_idx / global_mst.cfg.recursion_branchout_num; + recursive_offset = recursive_idx % cfg.recursion_branchout_num; + recursive_idx = recursive_idx / cfg.recursion_branchout_num; } siblings } +/// We use this wrapper struct for the left 
and right hashes of our recursive siblings. This is needed so a user knows the position of +/// their own hash when hashing. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct RecursiveHashes { + left_hashes: Vec>, + right_hashes: Vec>, +} + +impl RecursiveHashes { + pub fn new_from_index(indexes: &RecursiveIndex, db: &DataBase) -> Self { + let left_hashes = indexes + .left_indexes + .iter() + .map(|y| db.get_gmst_node_hash(*y as i32).unwrap()) + .collect_vec(); + let right_hashes = indexes + .right_indexes + .iter() + .map(|y| db.get_gmst_node_hash(*y as i32).unwrap()) + .collect_vec(); + RecursiveHashes { left_hashes, right_hashes } + } + + /// Calculated Hash = Left hashes || own hash || Right hashes + pub fn get_calculated_hash(self, own_hash: HashOut) -> HashOut { + let mut hash_inputs = self.left_hashes; + hash_inputs.push(own_hash); + hash_inputs.extend(self.right_hashes); + + let inputs: Vec = hash_inputs.iter().map(|x| x.elements).flatten().collect(); + + PoseidonHash::hash_no_pad(inputs.as_slice()) + } +} + +/// Hashes for a given users merkle proof of inclusion siblings in the Global Merkle Sum Tree, also includes account data as it is needed for the verification +/// of the merkle proof (needed to calculate own hash) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MerkleProof { + pub account: Account, + pub index: usize, + pub sum_tree_siblings: Vec>, + pub recursive_tree_siblings: Vec, +} + +impl MerkleProof { + pub fn new_from_account( + account: &Account, + db: &DataBase, + cfg: &GlobalConfig, + ) -> Result { + let user_id_res = UserId::from_hex_string(account.id.clone()); + if user_id_res.is_err() { + return Err(user_id_res.unwrap_err()); + } + + let user_id = user_id_res.unwrap(); + + let user_index = db.get_user_index(user_id.clone()); + if user_index.is_none() { + tracing::error!("User with id: {:?} does not exist", user_id.to_string()); + return Err(PoRError::InvalidParameter(user_id.to_string())); + } + + let 
merkle_proof_indexes = + MerkleProofIndex::new_from_user_index(user_index.unwrap() as usize, cfg); + let merkle_proof = get_merkle_proof_hashes_from_indexes( + account, + &merkle_proof_indexes, + user_index.unwrap() as usize, + db, + ); + Ok(merkle_proof) + } + + pub fn verify_merkle_proof(&self, gmst_root: HashOut) -> Result<(), PoRError> { + let account_hash = self.account.get_hash(); + + let mut index = self.index; + + let calculated_mst_hash = self.sum_tree_siblings.iter().fold(account_hash, |acc, x| { + if index % 2 == 0 { + index /= 2; + hash_2_subhashes::(&acc, x) + } else { + index /= 2; + hash_2_subhashes::(x, &acc) + } + }); + + let calculated_hash = self + .recursive_tree_siblings + .iter() + .fold(calculated_mst_hash, |acc, x| x.clone().get_calculated_hash(acc)); + + if calculated_hash == gmst_root { + Ok(()) + } else { + Err(PoRError::InvalidMerkleProof) + } + } +} + +/// Given the indexes for the MST siblings, get the hashes from the database for the merkle proof of inclusion. 
+pub fn get_merkle_proof_hashes_from_indexes( + account: &Account, + indexes: &MerkleProofIndex, + user_index: usize, + db: &DataBase, +) -> MerkleProof { + let mst_hashes: Vec> = indexes + .sum_tree_siblings + .iter() + .map(|x| db.get_gmst_node_hash(*x as i32).unwrap()) + .collect(); + + let recursive_hashes: Vec = indexes + .recursive_tree_siblings + .iter() + .map(|x| RecursiveHashes::new_from_index(x, db)) + .collect(); + + MerkleProof { + account: account.clone(), + sum_tree_siblings: mst_hashes, + recursive_tree_siblings: recursive_hashes, + index: user_index, + } +} + #[cfg(test)] pub mod test { + use itertools::Itertools; + use plonky2::hash::hash_types::HashOut; + use crate::{ + account::Account, global::{GlobalConfig, GlobalMst}, - merkle_proof::get_recursive_siblings_index, + merkle_proof::{get_recursive_siblings_index, MerkleProofIndex, RecursiveIndex}, + types::F, }; + use plonky2_field::types::Field; - use super::get_mst_siblings_index; + use super::{get_mst_siblings_index, MerkleProof, RecursiveHashes}; #[test] pub fn test_get_siblings_index() { @@ -122,7 +296,7 @@ pub mod test { let global_index = 0; - let siblings = get_mst_siblings_index(global_index, &gmst); + let siblings = get_mst_siblings_index(global_index, &gmst.cfg); assert_eq!(siblings, vec![1, 33, 49]); let gmst = GlobalMst::new(GlobalConfig { @@ -134,7 +308,7 @@ pub mod test { let global_index = 0; - let siblings = get_mst_siblings_index(global_index, &gmst); + let siblings = get_mst_siblings_index(global_index, &gmst.cfg); assert_eq!(siblings, vec![1, 65, 97]); let gmst = GlobalMst::new(GlobalConfig { @@ -146,7 +320,7 @@ pub mod test { let global_index = 0; - let siblings = get_mst_siblings_index(global_index, &gmst); + let siblings = get_mst_siblings_index(global_index, &gmst.cfg); assert_eq!(siblings, vec![1, 49, 73]); } @@ -161,8 +335,15 @@ pub mod test { let global_index = 0; - let siblings = get_recursive_siblings_index(global_index, &gmst); - assert_eq!(siblings, vec![vec![91, 
92, 93], vec![107, 108, 109]]); + let siblings = get_recursive_siblings_index(global_index, &gmst.cfg); + + assert_eq!( + siblings, + vec![ + RecursiveIndex { left_indexes: vec![], right_indexes: vec![91, 92, 93] }, + RecursiveIndex { left_indexes: vec![], right_indexes: vec![107, 108, 109] } + ] + ); let gmst = GlobalMst::new(GlobalConfig { num_of_tokens: 100, @@ -173,11 +354,18 @@ pub mod test { let global_index = 163; - let siblings = get_recursive_siblings_index(global_index, &gmst); - assert_eq!(siblings, vec![vec![441, 442, 443], vec![456, 458, 459], vec![460, 462, 463]]); + let siblings = get_recursive_siblings_index(global_index, &gmst.cfg); + assert_eq!( + siblings, + vec![ + RecursiveIndex { left_indexes: vec![], right_indexes: vec![441, 442, 443] }, + RecursiveIndex { left_indexes: vec![456], right_indexes: vec![458, 459] }, + RecursiveIndex { left_indexes: vec![460], right_indexes: vec![462, 463] } + ] + ); let gmst = GlobalMst::new(GlobalConfig { - num_of_tokens: 100, + num_of_tokens: 10, num_of_batches: 6, batch_size: 4, recursion_branchout_num: 4, @@ -185,7 +373,236 @@ pub mod test { let global_index = 20; - let siblings = get_recursive_siblings_index(global_index, &gmst); - assert_eq!(siblings, vec![[40, 42, 43], [44, 46, 47]]); + let siblings = get_recursive_siblings_index(global_index, &gmst.cfg); + assert_eq!( + siblings, + vec![ + RecursiveIndex { left_indexes: vec![40], right_indexes: vec![42, 43] }, + RecursiveIndex { left_indexes: vec![44], right_indexes: vec![46, 47] }, + ] + ); + } + + #[test] + pub fn test_get_new_merkle_index_from_user_index() { + let gmst = GlobalMst::new(GlobalConfig { + num_of_tokens: 100, + num_of_batches: 15, + batch_size: 4, + recursion_branchout_num: 4, + }); + + let global_index = 0; + + let merkle_proof_indexes = MerkleProofIndex::new_from_user_index(global_index, &gmst.cfg); + + assert_eq!( + merkle_proof_indexes, + MerkleProofIndex { + sum_tree_siblings: vec![1, 61], + recursive_tree_siblings: vec![ + 
RecursiveIndex { left_indexes: vec![], right_indexes: vec![91, 92, 93] }, + RecursiveIndex { left_indexes: vec![], right_indexes: vec![107, 108, 109] } + ], + } + ); + } + + #[test] + pub fn test_verify_merkle_proof() { + let _gmst = GlobalMst::new(GlobalConfig { + num_of_tokens: 3, + num_of_batches: 4, + batch_size: 2, + recursion_branchout_num: 4, + }); + + let equity = vec![3, 3, 3].iter().map(|x| F::from_canonical_u32(*x)).collect_vec(); + let debt = vec![1, 1, 1].iter().map(|x| F::from_canonical_u32(*x)).collect_vec(); + + let sum_tree_siblings = vec![HashOut::from_vec( + vec![ + 7609058119952049295, + 8895839458156070742, + 1052773619972611009, + 6038312163525827182, + ] + .iter() + .map(|x| F::from_canonical_u64(*x)) + .collect::>(), + )]; + + let recursive_tree_siblings = vec![RecursiveHashes { + left_hashes: vec![], + right_hashes: vec![ + HashOut::from_vec( + vec![ + 15026394135096265436, + 13313300609834454638, + 10151802728958521275, + 6200471959130767555, + ] + .iter() + .map(|x| F::from_canonical_u64(*x)) + .collect::>(), + ), + HashOut::from_vec( + vec![ + 2010803994799996791, + 568450490466247075, + 18209684900543488748, + 7678193912819861368, + ] + .iter() + .map(|x| F::from_canonical_u64(*x)) + .collect::>(), + ), + HashOut::from_vec( + vec![ + 13089029781628355232, + 10704046654659337561, + 15794212269117984095, + 15948192230150472783, + ] + .iter() + .map(|x| F::from_canonical_u64(*x)) + .collect::>(), + ), + ], + }]; + + let account = Account { + id: "320b5ea99e653bc2b593db4130d10a4efd3a0b4cc2e1a6672b678d71dfbd33ad".to_string(), + equity: equity.clone(), + debt: debt.clone(), + }; + + let merkle_proof = + MerkleProof { account, sum_tree_siblings, recursive_tree_siblings, index: 0 }; + + let root = HashOut::from_vec( + vec![ + 10628303359772907103, + 7478459528589413745, + 12007196562137971174, + 2652030368197917032, + ] + .iter() + .map(|x| F::from_canonical_u64(*x)) + .collect::>(), + ); + + let res = merkle_proof.verify_merkle_proof(root); + 
+ res.unwrap(); } + + // THIS IS THE TEST DATA FOR VERIFY + // #[test] + // pub fn poseidon_hash() { + // let equity = vec![3,3,3,].iter().map(|x| F::from_canonical_u32(*x)).collect_vec(); + // let debt = vec![1,1,1,].iter().map(|x| F::from_canonical_u32(*x)).collect_vec(); + + // let accounts = vec![ + // Account{ + // id: "320b5ea99e653bc2b593db4130d10a4efd3a0b4cc2e1a6672b678d71dfbd33ad".to_string(), + // equity: equity.clone(), + // debt: debt.clone(), + // }, + // Account{ + // id: "320b5ea99e653bc2b593db4130d10a4efd3a0b4cc2e1a6672b678d71dfbd33ac".to_string(), + // equity: equity.clone(), + // debt: debt.clone(), + // }, + // Account{ + // id: "320b5ea99e653bc2b593db4130d10a4efd3a0b4cc2e1a6672b678d71dfbd33ab".to_string(), + // equity: equity.clone(), + // debt: debt.clone(), + // }, + // Account{ + // id: "320b5ea99e653bc2b593db4130d10a4efd3a0b4cc2e1a6672b678d71dfbd33aa".to_string(), + // equity: equity.clone(), + // debt: debt.clone(), + // }, + // Account{ + // id: "320b5ea99e653bc2b593db4130d10a4efd3a0b4cc2e1a6672b678d71dfbd33a1".to_string(), + // equity: equity.clone(), + // debt: debt.clone(), + // }, + // Account{ + // id: "320b5ea99e653bc2b593db4130d10a4efd3a0b4cc2e1a6672b678d71dfbd33a2".to_string(), + // equity: equity.clone(), + // debt: debt.clone(), + // }, + // Account{ + // id: "320b5ea99e653bc2b593db4130d10a4efd3a0b4cc2e1a6672b678d71dfbd33a3".to_string(), + // equity: equity.clone(), + // debt: debt.clone(), + // }, + // Account{ + // id: "320b5ea99e653bc2b593db4130d10a4efd3a0b4cc2e1a6672b678d71dfbd33a4".to_string(), + // equity: equity.clone(), + // debt: debt.clone(), + // } + // ]; + + // let msts: Vec = accounts + // .chunks(2) + // .map(|account_batch| MerkleSumTree::new_tree_from_accounts(&account_batch.to_vec())) + // .collect(); + + // let mst_hashes = msts.iter().map(|x| x.merkle_sum_tree.iter().map(|y| y.hash).collect_vec()).collect_vec(); + // println!("msts:{:?}", mst_hashes); + // let inputs = vec![ + // HashOut::from_vec( + // vec![ 
+ // 8699257539652901730, + // 12847577670763395377, + // 14540605839220144846, + // 1921995570040415498, + // ] + // .iter() + // .map(|x| F::from_canonical_u64(*x)) + // .collect::>(), + // ), + // HashOut::from_vec( + // vec![ + // 15026394135096265436, + // 13313300609834454638, + // 10151802728958521275, + // 6200471959130767555, + // ] + // .iter() + // .map(|x| F::from_canonical_u64(*x)) + // .collect::>(), + // ), + // HashOut::from_vec( + // vec![ + // 2010803994799996791, + // 568450490466247075, + // 18209684900543488748, + // 7678193912819861368, + // ] + // .iter() + // .map(|x| F::from_canonical_u64(*x)) + // .collect::>(), + // ), + // HashOut::from_vec( + // vec![ + // 13089029781628355232, + // 10704046654659337561, + // 15794212269117984095, + // 15948192230150472783, + // ] + // .iter() + // .map(|x| F::from_canonical_u64(*x)) + // .collect::>(), + // ), + // ]; + + // let hash = PoseidonHash::hash_no_pad( + // inputs.iter().map(|x| x.elements).flatten().collect_vec().as_slice(), + // ); + // println!("Hash: {:?}", hash); + // } } diff --git a/crates/zk-por-core/src/merkle_sum_tree.rs b/crates/zk-por-core/src/merkle_sum_tree.rs index 7b7fd1a..f203e01 100644 --- a/crates/zk-por-core/src/merkle_sum_tree.rs +++ b/crates/zk-por-core/src/merkle_sum_tree.rs @@ -36,6 +36,7 @@ impl MerkleSumNode { } /// Struct representing a merkle sum tree, it is represented as a vector of Merkle Sum Nodes. 
+#[derive(Debug, Clone)] pub struct MerkleSumTree { pub merkle_sum_tree: Vec, pub tree_depth: usize, diff --git a/crates/zk-por-core/src/parser.rs b/crates/zk-por-core/src/parser.rs index 4fcf75c..f4ce18b 100644 --- a/crates/zk-por-core/src/parser.rs +++ b/crates/zk-por-core/src/parser.rs @@ -233,35 +233,36 @@ impl AccountParser for FileAccountReader { fn parse_exchange_state(parsed_data: &Vec>) -> Vec { let mut accounts_data: Vec = Vec::new(); for obj in parsed_data { - let mut account_id = ""; - let mut inner_vec: Vec = Vec::new(); - for (key, value) in obj.iter() { - if key != "id" { - if let Some(number_str) = value.as_str() { - match number_str.parse::() { - Ok(number) => inner_vec.push(F::from_canonical_u64(number)), - Err(e) => { - error!("Error in parsing token value number: {:?}", e); - panic!("Error in parsing token value number: {:?}", e); - } + accounts_data.push(parse_account_state(obj)); + } + accounts_data +} + +/// Parses the exchanges state at some snapshot and returns. 
+pub fn parse_account_state(parsed_data: &BTreeMap) -> Account { + let mut account_id = ""; + let mut inner_vec: Vec = Vec::new(); + for (key, value) in parsed_data.iter() { + if key != "id" { + if let Some(number_str) = value.as_str() { + match number_str.parse::() { + Ok(number) => inner_vec.push(F::from_canonical_u64(number)), + Err(e) => { + error!("Error in parsing token value number: {:?}", e); + panic!("Error in parsing token value number: {:?}", e); } - } else { - error!("Error in parsing string from json: {:?}", value); - panic!("Error in parsing string from json: {:?}", value); } } else { - account_id = value.as_str().unwrap(); + error!("Error in parsing string from json: {:?}", value); + panic!("Error in parsing string from json: {:?}", value); } + } else { + account_id = value.as_str().unwrap(); } - // todo:: currently, we fill debt all to zero - let asset_len = inner_vec.len(); - accounts_data.push(Account { - id: account_id.into(), - equity: inner_vec, - debt: vec![F::ZERO; asset_len], - }); } - accounts_data + // todo:: currently, we fill debt all to zero + let asset_len = inner_vec.len(); + Account { id: account_id.into(), equity: inner_vec, debt: vec![F::ZERO; asset_len] } } pub struct RandomAccountParser { @@ -295,7 +296,7 @@ mod test { use crate::{ account::Account, - parser::{FileManager, FilesCfg}, + parser::{parse_exchange_state, FileManager, FilesCfg}, }; use mockall::*; use serde_json::Value; @@ -305,7 +306,7 @@ mod test { str::FromStr, }; - use super::{parse_exchange_state, AccountParser, FileAccountReader, JsonFileManager}; + use super::{AccountParser, FileAccountReader, JsonFileManager}; #[test] pub fn test_read_json_file_into_map() { diff --git a/crates/zk-por-core/src/util.rs b/crates/zk-por-core/src/util.rs index 9ca826e..b216fc3 100644 --- a/crates/zk-por-core/src/util.rs +++ b/crates/zk-por-core/src/util.rs @@ -1,3 +1,7 @@ +use crate::types::F; +use plonky2::hash::hash_types::HashOut; +use plonky2_field::types::Field; + pub fn 
pad_to_multiple_of(n: usize, multiple: usize) -> usize { if multiple == 0 { return n; // Avoid division by zero @@ -17,11 +21,29 @@ pub fn get_node_level(batch_size: usize, node_idx: usize) -> usize { ((total_nums - node_idx) as f64).log(2.0).floor() as usize } +/// Given a hash string, get a hashout +pub fn get_hash_from_hash_string(hash_string: String) -> HashOut { + let without_brackets = hash_string.trim_matches(|c| c == '[' || c == ']').to_string(); // Remove brackets + + let hash_as_vec_f: Vec = without_brackets + .split(',') + .map(|s| F::from_canonical_u64(s.parse::().unwrap())) + .collect(); + + if hash_as_vec_f.len() != 4 { + panic!("Incorrect format of hash"); + } + + HashOut::from_vec(hash_as_vec_f) +} + #[cfg(test)] pub mod test_util { + use plonky2::hash::hash_types::HashOut; + use crate::util::get_node_level; - use super::pad_to_multiple_of; + use super::{get_hash_from_hash_string, pad_to_multiple_of}; #[test] fn test_get_node_level() { @@ -40,4 +62,10 @@ pub mod test_util { assert_eq!(pad_to_multiple_of(24, 4), 24); assert_eq!(pad_to_multiple_of(27, 4), 28); } + + #[test] + fn test_get_hash_from_hash_string() { + let hash = get_hash_from_hash_string("[0000,0000,0000,0000]".to_string()); + assert_eq!(hash, HashOut::ZERO); + } } diff --git a/doc/solution.md b/doc/solution.md new file mode 100644 index 0000000..ed9048f --- /dev/null +++ b/doc/solution.md @@ -0,0 +1,9 @@ +# gmst + +```mermaid + graph TD; + A-->B; + A-->C; + B-->D; + C-->D; +``` \ No newline at end of file