Skip to content

Commit

Permalink
Merge pull request #81 from okx/ruanpc/production
Browse files Browse the repository at this point in the history
Ruanpc/production
  • Loading branch information
RUAN0007 authored Sep 9, 2024
2 parents 19e2e61 + fe38c93 commit 595a181
Show file tree
Hide file tree
Showing 13 changed files with 271 additions and 111 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ test-data/proofs*/
my_permanent_leveldb/
*level_db*/
*.json
*DS_Store
*tar
release/
2 changes: 2 additions & 0 deletions crates/zk-por-cli/src/constant.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
pub const RECURSION_BRANCHOUT_NUM: usize = 64;
pub const DEFAULT_BATCH_SIZE: usize = 1024;
pub const GLOBAL_PROOF_FILENAME: &str = "global_proof.json";
pub const GLOBAL_INFO_FILENAME: &str = "global_info.json";
pub const USER_PROOF_DIRNAME: &str = "user_proofs";
pub const DEFAULT_USER_PROOF_FILE_PATTERN: &str = "*_inclusion_proof.json";
2 changes: 1 addition & 1 deletion crates/zk-por-cli/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
mod constant;
pub mod constant;
pub mod prover;
pub mod verifier;
46 changes: 34 additions & 12 deletions crates/zk-por-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::{path::PathBuf, str::FromStr};

use clap::{Parser, Subcommand};
use zk_por_cli::{
constant::{DEFAULT_USER_PROOF_FILE_PATTERN, GLOBAL_PROOF_FILENAME},
prover::prove,
verifier::{verify_global, verify_user},
};
Expand All @@ -11,7 +12,7 @@ use zk_por_core::error::PoRError;
#[command(version, about, long_about = None)]
struct Cli {
#[command(subcommand)]
command: ZkPorCommitCommands,
command: Option<ZkPorCommitCommands>,
}

pub trait Execute {
Expand Down Expand Up @@ -39,34 +40,55 @@ pub enum ZkPorCommitCommands {
},
}

impl Execute for ZkPorCommitCommands {
impl Execute for Option<ZkPorCommitCommands> {
fn execute(&self) -> std::result::Result<(), PoRError> {
match self {
ZkPorCommitCommands::Prove { cfg_path, output_path } => {
Some(ZkPorCommitCommands::Prove { cfg_path, output_path }) => {
let cfg = zk_por_core::config::ProverConfig::load(&cfg_path)
.map_err(|e| PoRError::ConfigError(e))?;
let prover_cfg = cfg.try_deserialize().unwrap();
let output_file = PathBuf::from_str(&output_path).unwrap();
prove(prover_cfg, output_file)
let output_path = PathBuf::from_str(&output_path).unwrap();
prove(prover_cfg, output_path)
}

ZkPorCommitCommands::VerifyGlobal { proof_path: global_proof_path } => {
Some(ZkPorCommitCommands::VerifyGlobal { proof_path: global_proof_path }) => {
let global_proof_path = PathBuf::from_str(&global_proof_path).unwrap();
verify_global(global_proof_path)
verify_global(global_proof_path, true)
}

ZkPorCommitCommands::VerifyUser { global_proof_path, user_proof_path_pattern } => {
Some(ZkPorCommitCommands::VerifyUser {
global_proof_path,
user_proof_path_pattern,
}) => {
let global_proof_path = PathBuf::from_str(&global_proof_path).unwrap();
verify_user(global_proof_path, user_proof_path_pattern)
verify_user(global_proof_path, user_proof_path_pattern, true)
}

None => {
println!("============Validation started============");
let global_proof_path = PathBuf::from_str(GLOBAL_PROOF_FILENAME).unwrap();
let user_proof_path_pattern = DEFAULT_USER_PROOF_FILE_PATTERN.to_owned();
if verify_global(global_proof_path.clone(), false).is_ok() {
println!("Total sum and non-negative constraint validation passed")
} else {
println!("Total sum and non-negative constraint validation failed")
}

if verify_user(global_proof_path, &user_proof_path_pattern, false).is_ok() {
println!("Inclusion constraint validation passed")
} else {
println!("Inclusion constraint validation failed")
}
println!("============Validation finished============");
Ok(())
}
}
}
}

fn main() -> std::result::Result<(), PoRError> {
let cli = Cli::parse();
let start = std::time::Instant::now();
let result = cli.command.execute();
println!("result: {:?}, elapsed: {:?}", result, start.elapsed());
let r = cli.command.execute();
println!("Execution result: {:?}", r);
Ok(())
}
39 changes: 37 additions & 2 deletions crates/zk-por-cli/src/prover.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use super::constant::{
DEFAULT_BATCH_SIZE, GLOBAL_PROOF_FILENAME, RECURSION_BRANCHOUT_NUM, USER_PROOF_DIRNAME,
DEFAULT_BATCH_SIZE, GLOBAL_INFO_FILENAME, GLOBAL_PROOF_FILENAME, RECURSION_BRANCHOUT_NUM,
USER_PROOF_DIRNAME,
};
use indicatif::ProgressBar;
use plonky2::hash::hash_types::HashOut;
use plonky2_field::types::PrimeField64;
use rayon::{iter::ParallelIterator, prelude::*};
use serde_json::json;
use std::{
Expand All @@ -28,7 +30,7 @@ use zk_por_core::{
parser::{AccountParser, FileAccountReader, FileManager, FilesCfg},
recursive_prover::recursive_circuit::RecursiveTargets,
types::F,
General, Proof,
General, Info, Proof,
};
use zk_por_tracing::{init_tracing, TraceConfig};

Expand Down Expand Up @@ -367,6 +369,39 @@ fn dump_proofs(
.write_all(json!(root_proof).to_string().as_bytes())
.map_err(|e| return PoRError::Io(e))?;

///////////////////////////////////////////////
let hash_offset = RecursiveTargets::<RECURSION_BRANCHOUT_NUM>::pub_input_hash_offset();
let root_hash = HashOut::<F>::from_partial(&root_proof.proof.public_inputs[hash_offset]);
let root_hash_bytes = root_hash
.elements
.iter()
.map(|x| x.to_canonical_u64().to_le_bytes())
.flatten()
.collect::<Vec<u8>>();
let root_hash = hex::encode(root_hash_bytes);

let equity_offset = RecursiveTargets::<RECURSION_BRANCHOUT_NUM>::pub_input_equity_offset();
let equity_sum = root_proof.proof.public_inputs[equity_offset].to_canonical_u64();

let debt_offset = RecursiveTargets::<RECURSION_BRANCHOUT_NUM>::pub_input_debt_offset();
let debt_sum = root_proof.proof.public_inputs[debt_offset].to_canonical_u64();
assert!(equity_sum >= debt_sum);
let balance_sum = equity_sum - debt_sum;
let info = Info {
root_hash: root_hash,
equity_sum: equity_sum,
debt_sum: debt_sum,
balance_sum: balance_sum,
};

let global_info_output_path = proof_output_dir_path.join(GLOBAL_INFO_FILENAME);
let mut global_info_file =
File::create(global_info_output_path.clone()).map_err(|e| PoRError::Io(e))?;

global_info_file
.write_all(json!(info).to_string().as_bytes())
.map_err(|e| return PoRError::Io(e))?;

///////////////////////////////////////////////
// generate and dump proof for each user
// create a new account reader to avoid buffering previously loaded accounts in memory
Expand Down
100 changes: 63 additions & 37 deletions crates/zk-por-cli/src/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ fn find_matching_files(pattern: &str) -> Result<Vec<PathBuf>, io::Error> {
pub fn verify_user(
global_proof_path: PathBuf,
user_proof_path_pattern: &String,
verbose: bool,
) -> Result<(), PoRError> {
let proof_file = File::open(&global_proof_path).unwrap();
let reader = std::io::BufReader::new(proof_file);
Expand All @@ -52,34 +53,55 @@ pub fn verify_user(
let user_proof_paths =
find_matching_files(user_proof_path_pattern).map_err(|e| PoRError::Io(e))?;
let proof_file_num = user_proof_paths.len();
println!("successfully identify {} user proof files", proof_file_num);
if proof_file_num == 0 {
return Err(PoRError::InvalidParameter(format!(
"fail to find any user proof files with pattern {}",
user_proof_path_pattern
)));
}

if verbose {
println!("successfully identify {} user proof files", proof_file_num);
}

let bar = ProgressBar::new(proof_file_num as u64);
let chunk_size: usize = num_cpus::get();
user_proof_paths.chunks(chunk_size).for_each(|chunks| {
chunks.par_iter().for_each(|user_proof_path| {
let invalid_proof_paths = user_proof_paths
.par_iter()
.map(|user_proof_path| {
let merkle_path = File::open(&user_proof_path).unwrap();
let reader = std::io::BufReader::new(merkle_path);
let proof: MerkleProof = from_reader(reader).unwrap();
if let Err(e) = proof.verify_merkle_proof(root_hash) {
panic!(
"fail to verify the user proof on path {:?} due to error {:?}",
user_proof_path, e
)
let result = proof.verify_merkle_proof(root_hash);
if verbose {
bar.inc(1);
}
});
bar.inc(chunks.len() as u64);
});
bar.finish();
println!(
"successfully verify {} user proofs with file pattern {}",
proof_file_num, user_proof_path_pattern
);
(result, user_proof_path)
})
.filter(|(result, _)| result.is_err())
.map(|(_, invalid_proof_path)| invalid_proof_path.to_str().unwrap().to_owned())
.collect::<Vec<String>>();
if verbose {
bar.finish();
}

let invalid_proof_num = invalid_proof_paths.len();
let valid_proof_num = proof_file_num - invalid_proof_num;
if verbose {
let max_to_display_valid_proof_num = 10;

println!(
"{}/{} user proofs pass the verification. {} fail, the first {} failed proof files: {:?}",
valid_proof_num, proof_file_num, invalid_proof_num, std::cmp::min(invalid_proof_num, invalid_proof_num), invalid_proof_paths.iter().take(max_to_display_valid_proof_num).collect::<Vec<&String>>(),
);
}

if invalid_proof_num > 0 {
return Err(PoRError::InvalidProof);
}
Ok(())
}

pub fn verify_global(global_proof_path: PathBuf) -> Result<(), PoRError> {
pub fn verify_global(global_proof_path: PathBuf, verbose: bool) -> Result<(), PoRError> {
let proof_file = File::open(&global_proof_path).unwrap();
let reader = std::io::BufReader::new(proof_file);

Expand All @@ -98,11 +120,13 @@ pub fn verify_global(global_proof_path: PathBuf) -> Result<(), PoRError> {
get_recursive_circuit_configs::<RECURSION_BRANCHOUT_NUM>(batch_num);

// not to use trace::log to avoid the dependency on the trace config.
println!(
"start to reconstruct the circuit with {} recursive levels for round {}",
recursive_circuit_configs.len(),
round_num
);
if verbose {
println!(
"start to reconstruct the circuit with {} recursive levels for round {}",
recursive_circuit_configs.len(),
round_num
);
}
let start = std::time::Instant::now();
let circuit_registry = CircuitRegistry::<RECURSION_BRANCHOUT_NUM>::init(
batch_size,
Expand All @@ -118,22 +142,24 @@ pub fn verify_global(global_proof_path: PathBuf) -> Result<(), PoRError> {
return Err(PoRError::CircuitDigestMismatch);
}

println!(
"successfully reconstruct the circuit for round {} in {:?}",
round_num,
start.elapsed()
);
if verbose {
println!(
"successfully reconstruct the circuit for round {} in {:?}",
round_num,
start.elapsed()
);

let equity = proof.proof.public_inputs
[RecursiveTargets::<RECURSION_BRANCHOUT_NUM>::pub_input_equity_offset()];
let debt = proof.proof.public_inputs
[RecursiveTargets::<RECURSION_BRANCHOUT_NUM>::pub_input_debt_offset()];
if !root_circuit.verify(proof.proof).is_ok() {
return Err(PoRError::InvalidProof);
}

let equity = proof.proof.public_inputs
[RecursiveTargets::<RECURSION_BRANCHOUT_NUM>::pub_input_equity_offset()];
let debt = proof.proof.public_inputs
[RecursiveTargets::<RECURSION_BRANCHOUT_NUM>::pub_input_debt_offset()];
if !root_circuit.verify(proof.proof).is_ok() {
return Err(PoRError::InvalidProof);
println!("successfully verify the global proof for round {}, total exchange users' equity is {}, debt is {}, exchange liability is {}",
round_num, equity.to_canonical_u64(), debt.to_canonical_u64(), (equity- debt).to_canonical_u64());
}

println!("successfully verify the global proof for round {}, total exchange users' equity is {}, debt is {}, exchange liability is {}",
round_num, equity.to_canonical_u64(), debt.to_canonical_u64(), (equity- debt).to_canonical_u64());

Ok(())
}
12 changes: 6 additions & 6 deletions crates/zk-por-core/src/circuit_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use plonky2::{
plonk::circuit_data::CircuitConfig,
};

use super::circuit_utils::recursive_levels;

pub const STANDARD_CONFIG: CircuitConfig = CircuitConfig {
num_wires: 135,
num_routed_wires: 80,
Expand Down Expand Up @@ -61,15 +63,13 @@ pub const STANDARD_ZK_CONFIG: CircuitConfig = CircuitConfig {
pub fn get_recursive_circuit_configs<const RECURSION_BRANCHOUT_NUM: usize>(
batch_num: usize,
) -> Vec<CircuitConfig> {
let level = (batch_num as f64).log(RECURSION_BRANCHOUT_NUM as f64).ceil() as usize;
let level = recursive_levels(batch_num, RECURSION_BRANCHOUT_NUM);

let mut configs = vec![STANDARD_CONFIG; level];

if let Some(last) = configs.last_mut() {
*last = STANDARD_ZK_CONFIG; // Change the last element to 0
} else {
configs.push(STANDARD_ZK_CONFIG); // Add 0 if the vec is empty
}
// The last circuit is a zk circuit, and the rest are standard circuits.
// there is at least one recursive circuit.
*configs.last_mut().unwrap() = STANDARD_ZK_CONFIG;
configs
}

Expand Down
7 changes: 6 additions & 1 deletion crates/zk-por-core/src/circuit_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ pub fn prove_timing() -> TimingTree {
TimingTree::new("prove", level)
}

pub fn recursive_levels(batch_num: usize, recursion_branchout_num: usize) -> usize {
let level = (batch_num as f64).log(recursion_branchout_num as f64).ceil() as usize;
std::cmp::max(1, level)
}

/// Test runner for ease of testing
#[allow(clippy::unused_unit)]
pub fn run_circuit_test<T, F, const D: usize>(test: T)
Expand Down Expand Up @@ -71,7 +76,7 @@ pub mod test {
#[should_panic]
fn test_assert_non_negative_unsigned_panic() {
run_circuit_test(|builder, _pw| {
let x = builder.constant(F::from_canonical_i64(-1));
let x = builder.constant(F::from_canonical_u64(F::ORDER - 1));
assert_non_negative_unsigned(builder, x);
});
}
Expand Down
15 changes: 9 additions & 6 deletions crates/zk-por-core/src/global.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::{
circuit_utils::recursive_levels,
database::PoRDB,
merkle_sum_prover::utils::hash_2_subhashes,
recursive_prover::prover::hash_n_subhashes,
Expand Down Expand Up @@ -28,8 +29,7 @@ pub struct GlobalMst {

impl GlobalMst {
pub fn new(cfg: GlobalConfig) -> Self {
let top_level =
(cfg.num_of_batches as f64).log(cfg.recursion_branchout_num as f64).ceil() as usize;
let top_level = recursive_levels(cfg.num_of_batches, cfg.recursion_branchout_num);

let mst_vec = vec![HashOut::default(); 0]; // will resize later
let mut mst = Self { inner: mst_vec, top_recursion_level: top_level, cfg: cfg };
Expand Down Expand Up @@ -157,11 +157,14 @@ impl GlobalMst {

/// `recursive_level` count from bottom to top; recursive_level = 1 means the bottom layer; increase whilve moving to the top.
pub fn set_recursive_hash(&mut self, recursive_level: usize, index: usize, hash: HashOut<F>) {
debug!(
"set_recursive_hash, recursive_level: {:?}, index: {:?}, hash: {:?}",
recursive_level, index, hash
);
let idx = GlobalMst::get_recursive_global_index(&self.cfg, recursive_level, index);
tracing::debug!(
"set_recursive_hash, recursive_level: {:?}, index: {:?}, hash: {:?}, idx: {:?}",
recursive_level,
index,
hash,
idx,
);
self.inner[idx] = hash;
}

Expand Down
Loading

0 comments on commit 595a181

Please sign in to comment.