From 03d03d2c479fa366a93f62967306a873aa7ff2c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=20Sewi=C5=82o?= <95349104+ksew1@users.noreply.github.com> Date: Tue, 17 Sep 2024 04:35:15 -0700 Subject: [PATCH] Support contracts (#64) Closes #43 --- CHANGELOG.md | 2 + Cargo.lock | 36 +++++-- Cargo.toml | 1 + README.md | 2 - crates/cairo-coverage/Cargo.toml | 4 +- crates/cairo-coverage/src/cli.rs | 13 +++ .../src/data_loader/loaded_data.rs | 97 +++++++++++-------- crates/cairo-coverage/src/data_loader/mod.rs | 1 + .../src/data_loader/sierra_program.rs | 85 ++++++++++++++++ crates/cairo-coverage/src/input/data.rs | 33 ++----- .../src/input/sierra_to_cairo_map.rs | 14 ++- .../src/input/statement_category_filter.rs | 24 ++--- .../src/input/unique_executed_sierra_ids.rs | 36 +++---- crates/cairo-coverage/src/main.rs | 94 +++++++++++++++--- crates/cairo-coverage/src/merge.rs | 45 +++++++++ .../tests/data/snforge_template/Scarb.toml | 24 +++++ .../tests/data/snforge_template/src/lib.cairo | 25 +++++ .../tests/test_contract.cairo | 47 +++++++++ crates/cairo-coverage/tests/e2e/general.rs | 7 ++ .../expected_output/snforge_template.lcov | 27 ++++++ 20 files changed, 486 insertions(+), 131 deletions(-) create mode 100644 crates/cairo-coverage/src/data_loader/sierra_program.rs create mode 100644 crates/cairo-coverage/src/merge.rs create mode 100644 crates/cairo-coverage/tests/data/snforge_template/Scarb.toml create mode 100644 crates/cairo-coverage/tests/data/snforge_template/src/lib.cairo create mode 100644 crates/cairo-coverage/tests/data/snforge_template/tests/test_contract.cairo create mode 100644 crates/cairo-coverage/tests/expected_output/snforge_template.lcov diff --git a/CHANGELOG.md b/CHANGELOG.md index 7a853f5..d1a9e19 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 #### Added +- Support for contracts - Option to not include macros in coverage report. To get the same behavior as before use `--include macros` +- `--project-path` flag to specify the path to the project root directory. This useful when inference fails #### Fixed diff --git a/Cargo.lock b/Cargo.lock index 16f03bd..0c865fa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -80,9 +80,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.88" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e1496f8fb1fbf272686b8d37f523dab3e4a7443300055e74cdaa449f3114356" +checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6" [[package]] name = "arrayvec" @@ -192,6 +192,7 @@ dependencies = [ "assert_fs", "cairo-lang-sierra", "cairo-lang-sierra-to-casm", + "cairo-lang-starknet-classes", "camino", "clap", "derived-deref", @@ -318,6 +319,29 @@ dependencies = [ "cairo-lang-utils", ] +[[package]] +name = "cairo-lang-starknet-classes" +version = "2.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "482b8f9d7f8cc7140f1260ee71f3308a66d15bd228a06281067ca3f8f4410db2" +dependencies = [ + "cairo-lang-casm", + "cairo-lang-sierra", + "cairo-lang-sierra-to-casm", + "cairo-lang-utils", + "convert_case", + "itertools 0.12.1", + "num-bigint", + "num-integer", + "num-traits 0.2.19", + "serde", + "serde_json", + "sha3", + "smol_str", + "starknet-types-core", + "thiserror", +] + [[package]] name = "cairo-lang-utils" version = "2.8.2" @@ -969,9 +993,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.19.0" +version = "1.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +checksum = "33ea5043e58958ee56f3e15a90aee535795cd7dfd319846288d93c5b57d85cbe" [[package]] name = "oorandom" @@ -1612,9 +1636,9 @@ checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" [[package]] name = "unicode-segmentation" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" [[package]] name = "unicode-xid" diff --git a/Cargo.toml b/Cargo.toml index 8c6b07a..9de5ee9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ camino = "1.1.9" clap = { version = "4.5.17", features = ["derive"] } cairo-lang-sierra = "2.8.2" cairo-lang-sierra-to-casm = "2.8.2" +cairo-lang-starknet-classes = "2.8.2" derived-deref = "2.1.0" itertools = "0.13.0" serde = "1.0.210" diff --git a/README.md b/README.md index f752ec5..54d0aed 100644 --- a/README.md +++ b/README.md @@ -9,10 +9,8 @@ > > We currently don't support: > - Branch coverage -> - Contracts > > Things that might not work as expected: -> - Macros coverage > - Counters for how many times line was executed ## Installation diff --git a/crates/cairo-coverage/Cargo.toml b/crates/cairo-coverage/Cargo.toml index 7cfafcd..e210ed5 100644 --- a/crates/cairo-coverage/Cargo.toml +++ b/crates/cairo-coverage/Cargo.toml @@ -8,6 +8,7 @@ anyhow.workspace = true camino.workspace = true cairo-lang-sierra.workspace = true cairo-lang-sierra-to-casm.workspace = true +cairo-lang-starknet-classes.workspace = true clap.workspace = true derived-deref.workspace = true itertools.workspace = true @@ -15,8 +16,9 @@ serde.workspace = true serde_json.workspace = true trace-data.workspace = true regex.workspace = true +indoc.workspace = true [dev-dependencies] assert_fs.workspace = true snapbox.workspace = true -indoc.workspace = true + diff --git a/crates/cairo-coverage/src/cli.rs b/crates/cairo-coverage/src/cli.rs index 89802b3..df7f38c 100644 --- a/crates/cairo-coverage/src/cli.rs +++ b/crates/cairo-coverage/src/cli.rs @@ -16,6 +16,10 @@ pub struct Cli { /// Include additional components in the coverage report. #[arg(long, short, num_args = 1..)] pub include: Vec, + + /// Path to the project directory. If not provided, the project directory is inferred from the trace. + #[arg(value_parser = parse_project_path, long)] + pub project_path: Option, } #[derive(ValueEnum, Debug, Clone, Eq, PartialEq)] @@ -38,3 +42,12 @@ fn parse_trace_file(path: &str) -> Result { Ok(trace_file) } + +fn parse_project_path(path: &str) -> Result { + let project_path = Utf8PathBuf::from(path); + + ensure!(project_path.exists(), "Project path does not exist"); + ensure!(project_path.is_dir(), "Project path is not a directory"); + + Ok(project_path) +} diff --git a/crates/cairo-coverage/src/data_loader/loaded_data.rs b/crates/cairo-coverage/src/data_loader/loaded_data.rs index 5ce59cf..474eda4 100644 --- a/crates/cairo-coverage/src/data_loader/loaded_data.rs +++ b/crates/cairo-coverage/src/data_loader/loaded_data.rs @@ -1,65 +1,84 @@ -use anyhow::Context; -use anyhow::Result; +use crate::data_loader::sierra_program::{GetDebugInfos, SierraProgram}; +use anyhow::{Context, Result}; use cairo_lang_sierra::debug_info::DebugInfo; -use cairo_lang_sierra::program::{Program, ProgramArtifact, VersionedProgram}; +use cairo_lang_sierra_to_casm::compiler::CairoProgramDebugInfo; use camino::Utf8PathBuf; use derived_deref::Deref; use serde::de::DeserializeOwned; use std::collections::HashMap; use std::fs; -use trace_data::CallTrace; - -type SourceSierraPath = String; +use trace_data::{CallTrace, CallTraceNode, CasmLevelInfo}; #[derive(Deref)] -pub struct LoadedDataMap(HashMap); +pub struct LoadedDataMap(HashMap); pub struct LoadedData { - pub program: Program, pub debug_info: DebugInfo, - pub call_traces: Vec, + pub casm_level_infos: Vec, + pub casm_debug_info: CairoProgramDebugInfo, } impl LoadedDataMap { - pub fn load(call_trace_paths: &Vec) -> Result { - let mut map: HashMap = HashMap::new(); - for call_trace_path in call_trace_paths { - let call_trace: CallTrace = read_and_deserialize(call_trace_path)?; + pub fn load(call_trace_paths: &[Utf8PathBuf]) -> Result { + let execution_infos = call_trace_paths + .iter() + .map(read_and_deserialize) + .collect::>>()? + .into_iter() + .flat_map(load_nested_traces) + .filter_map(|call_trace| call_trace.cairo_execution_info) + .collect::>(); - let source_sierra_path = &call_trace - .cairo_execution_info - .as_ref() - .context("Missing key 'cairo_execution_info' in call trace. Perhaps you have outdated scarb?")? - .source_sierra_path; + // OPTIMIZATION: + // Group execution info by source Sierra path + // so that the same Sierra program does not need to be deserialized multiple times. + let execution_infos_by_sierra_path = execution_infos.into_iter().fold( + HashMap::new(), + |mut acc: HashMap<_, Vec<_>>, execution_info| { + acc.entry(execution_info.source_sierra_path) + .or_default() + .push(execution_info.casm_level_info); + acc + }, + ); - if let Some(loaded_data) = map.get_mut(&source_sierra_path.to_string()) { - loaded_data.call_traces.push(call_trace); - } else { - let VersionedProgram::V1 { - program: - ProgramArtifact { - program, + Ok(Self( + execution_infos_by_sierra_path + .into_iter() + .map(|(source_sierra_path, casm_level_infos)| { + read_and_deserialize::(&source_sierra_path)? + .compile_and_get_debug_infos() + .map(|(debug_info, casm_debug_info)| LoadedData { debug_info, - }, - .. - } = read_and_deserialize(source_sierra_path)?; + casm_level_infos, + casm_debug_info, + }) + .context(format!( + "Error occurred while loading program from: {source_sierra_path}" + )) + .map(|loaded_data| (source_sierra_path, loaded_data)) + }) + .collect::>()?, + )) + } +} - map.insert( - source_sierra_path.to_string(), - LoadedData { - program, - debug_info: debug_info - .context(format!("Debug info not found in: {source_sierra_path}"))?, - call_traces: vec![call_trace], - }, - ); +fn load_nested_traces(call_trace: CallTrace) -> Vec { + fn load_recursively(call_trace: CallTrace, acc: &mut Vec) { + acc.push(call_trace.clone()); + for call_trace_node in call_trace.nested_calls { + if let CallTraceNode::EntryPointCall(nested_call_trace) = call_trace_node { + load_recursively(nested_call_trace, acc); } } - Ok(Self(map)) } + + let mut call_traces = Vec::new(); + load_recursively(call_trace, &mut call_traces); + call_traces } -fn read_and_deserialize(file_path: &Utf8PathBuf) -> anyhow::Result { +fn read_and_deserialize(file_path: &Utf8PathBuf) -> Result { fs::read_to_string(file_path) .context(format!("Failed to read file at path: {file_path}")) .and_then(|content| { diff --git a/crates/cairo-coverage/src/data_loader/mod.rs b/crates/cairo-coverage/src/data_loader/mod.rs index 09d9bdb..f9615e4 100644 --- a/crates/cairo-coverage/src/data_loader/mod.rs +++ b/crates/cairo-coverage/src/data_loader/mod.rs @@ -1,4 +1,5 @@ mod loaded_data; +mod sierra_program; mod types; pub use loaded_data::{LoadedData, LoadedDataMap}; diff --git a/crates/cairo-coverage/src/data_loader/sierra_program.rs b/crates/cairo-coverage/src/data_loader/sierra_program.rs new file mode 100644 index 0000000..d0a5326 --- /dev/null +++ b/crates/cairo-coverage/src/data_loader/sierra_program.rs @@ -0,0 +1,85 @@ +use anyhow::{Context, Result}; +use cairo_lang_sierra::debug_info::DebugInfo; +use cairo_lang_sierra::program::{Program, ProgramArtifact, VersionedProgram}; +use cairo_lang_sierra_to_casm::compiler::{CairoProgramDebugInfo, SierraToCasmConfig}; +use cairo_lang_sierra_to_casm::metadata::{calc_metadata, MetadataComputationConfig}; +use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass; +use cairo_lang_starknet_classes::contract_class::ContractClass; +use serde::Deserialize; + +#[derive(Deserialize)] +#[serde(untagged)] +pub enum SierraProgram { + VersionedProgram(VersionedProgram), + ContractClass(ContractClass), +} + +pub trait GetDebugInfos { + fn compile_and_get_debug_infos(self) -> Result<(DebugInfo, CairoProgramDebugInfo)>; +} + +impl GetDebugInfos for VersionedProgram { + fn compile_and_get_debug_infos(self) -> Result<(DebugInfo, CairoProgramDebugInfo)> { + let VersionedProgram::V1 { + program: + ProgramArtifact { + program, + debug_info, + }, + .. + } = self; + + let debug_info = debug_info.context("Debug info not found in program")?; + let casm_debug_info = compile_program_to_casm_debug_info(&program)?; + Ok((debug_info, casm_debug_info)) + } +} + +impl GetDebugInfos for ContractClass { + fn compile_and_get_debug_infos(self) -> Result<(DebugInfo, CairoProgramDebugInfo)> { + let debug_info = self + .sierra_program_debug_info + .context("Debug info not found in contract")?; + + // OPTIMIZATION: + // Debug info is unused in the compilation. This saves us a costly clone. + let casm_debug_info = compile_contract_class_to_casm_debug_info(ContractClass { + sierra_program_debug_info: None, + ..self + })?; + + Ok((debug_info, casm_debug_info)) + } +} +impl GetDebugInfos for SierraProgram { + fn compile_and_get_debug_infos(self) -> Result<(DebugInfo, CairoProgramDebugInfo)> { + match self { + SierraProgram::VersionedProgram(program) => program.compile_and_get_debug_infos(), + SierraProgram::ContractClass(contract_class) => { + contract_class.compile_and_get_debug_infos() + } + } + } +} + +fn compile_program_to_casm_debug_info(program: &Program) -> Result { + cairo_lang_sierra_to_casm::compiler::compile( + program, + &calc_metadata(program, MetadataComputationConfig::default()) + .context("Failed calculating Sierra variables")?, + SierraToCasmConfig { + gas_usage_check: false, + max_bytecode_size: usize::MAX, + }, + ) + .map(|casm| casm.debug_info) + .context("Failed to compile program to casm") +} + +fn compile_contract_class_to_casm_debug_info( + contract_class: ContractClass, +) -> Result { + CasmContractClass::from_contract_class_with_debug_info(contract_class, false, usize::MAX) + .map(|(_, casm_debug_info)| casm_debug_info) + .context("Failed to compile contract class to casm") +} diff --git a/crates/cairo-coverage/src/input/data.rs b/crates/cairo-coverage/src/input/data.rs index 92ec18f..7fa647b 100644 --- a/crates/cairo-coverage/src/input/data.rs +++ b/crates/cairo-coverage/src/input/data.rs @@ -1,10 +1,8 @@ use crate::data_loader::LoadedData; use crate::input::statement_category_filter::StatementCategoryFilter; use crate::input::{create_sierra_to_cairo_map, SierraToCairoMap, UniqueExecutedSierraIds}; +use crate::merge::MergeOwned; use anyhow::{Context, Result}; -use cairo_lang_sierra::program::Program; -use cairo_lang_sierra_to_casm::compiler::{CairoProgram, SierraToCasmConfig}; -use cairo_lang_sierra_to_casm::metadata::{calc_metadata, MetadataComputationConfig}; pub struct InputData { pub unique_executed_sierra_ids: UniqueExecutedSierraIds, @@ -14,23 +12,19 @@ pub struct InputData { impl InputData { pub fn new( LoadedData { - program, debug_info, - call_traces, + casm_level_infos, + casm_debug_info, }: &LoadedData, filter: &StatementCategoryFilter, ) -> Result { let sierra_to_cairo_map = create_sierra_to_cairo_map(debug_info, filter)?; - let casm = compile_to_casm(program)?; - let unique_executed_sierra_ids = call_traces + let unique_executed_sierra_ids = casm_level_infos .iter() - .map(|call_trace| UniqueExecutedSierraIds::new(&casm, call_trace, &sierra_to_cairo_map)) - .collect::>>()? - .into_iter() - .reduce(|mut acc, unique_executed_sierra_ids| { - acc.extend(unique_executed_sierra_ids.clone().into_iter()); - acc + .map(|casm_level_info| { + UniqueExecutedSierraIds::new(casm_debug_info, casm_level_info, &sierra_to_cairo_map) }) + .reduce(MergeOwned::merge_owned) .context("Failed to create unique executed sierra ids")?; Ok(Self { @@ -39,16 +33,3 @@ impl InputData { }) } } - -fn compile_to_casm(program: &Program) -> Result { - cairo_lang_sierra_to_casm::compiler::compile( - program, - &calc_metadata(program, MetadataComputationConfig::default()) - .context("Failed calculating Sierra variables")?, - SierraToCasmConfig { - gas_usage_check: false, - max_bytecode_size: usize::MAX, - }, - ) - .context("Failed to compile sierra to casm") -} diff --git a/crates/cairo-coverage/src/input/sierra_to_cairo_map.rs b/crates/cairo-coverage/src/input/sierra_to_cairo_map.rs index 4ce4ae6..d9f5b28 100644 --- a/crates/cairo-coverage/src/input/sierra_to_cairo_map.rs +++ b/crates/cairo-coverage/src/input/sierra_to_cairo_map.rs @@ -5,6 +5,7 @@ use anyhow::{Context, Result}; use cairo_lang_sierra::debug_info::{Annotations, DebugInfo}; use cairo_lang_sierra::program::StatementIdx; use derived_deref::Deref; +use indoc::formatdoc; use serde::de::DeserializeOwned; use std::collections::HashMap; @@ -82,7 +83,18 @@ trait Namespace { annotations .get(Self::NAMESPACE) .cloned() - .context(format!("Expected key: {} but was missing", Self::NAMESPACE)) + .context(formatdoc! { + r#"Expected key: {} but was missing. + + Perhaps you are missing the following entries in Scarb.toml: + + [profile.dev.cairo] + unstable-add-statements-functions-debug-info = true + unstable-add-statements-code-locations-debug-info = true + inlining-strategy= "avoid" + "#, + Self::NAMESPACE, + }) .and_then(|value| { serde_json::from_value(value) .context(format!("Failed to deserialize at key: {}", Self::NAMESPACE)) diff --git a/crates/cairo-coverage/src/input/statement_category_filter.rs b/crates/cairo-coverage/src/input/statement_category_filter.rs index 19f6224..6cc0920 100644 --- a/crates/cairo-coverage/src/input/statement_category_filter.rs +++ b/crates/cairo-coverage/src/input/statement_category_filter.rs @@ -1,7 +1,6 @@ use crate::cli::IncludedComponent; use crate::data_loader::LoadedData; use crate::input::sierra_to_cairo_map::StatementOrigin; -use anyhow::{Context, Result}; use camino::Utf8PathBuf; use regex::Regex; use std::collections::HashSet; @@ -10,7 +9,6 @@ use std::sync::LazyLock; pub static VIRTUAL_FILE_REGEX: LazyLock = LazyLock::new(|| Regex::new(r"\[.*?]").unwrap()); const SNFORGE_TEST_EXECUTABLE: &str = "snforge_internal_test_executable"; -const SNFORGE_SIERRA_DIR: &str = ".snfoundry_versioned_programs"; #[derive(Eq, PartialEq, Hash)] enum StatementCategory { @@ -37,10 +35,10 @@ pub struct StatementCategoryFilter { impl StatementCategoryFilter { pub fn new( - source_sierra_path: &str, + user_project_path: &Utf8PathBuf, included_component: &[IncludedComponent], loaded_data: &LoadedData, - ) -> Result { + ) -> Self { let test_functions = loaded_data .debug_info .executables @@ -57,12 +55,13 @@ impl StatementCategoryFilter { .chain(once(StatementCategory::UserFunction)) .collect(); - let user_project_path = find_user_project_path(source_sierra_path)?; - Ok(Self { + let user_project_path = user_project_path.to_string(); + + Self { user_project_path, allowed_statement_categories, test_functions, - }) + } } pub fn should_include(&self, statement_origin: &StatementOrigin) -> bool { @@ -95,14 +94,3 @@ impl StatementCategoryFilter { labels } } - -fn find_user_project_path(source_sierra_path: &str) -> Result { - Utf8PathBuf::from(source_sierra_path) - .parent() - .filter(|parent| parent.file_name() == Some(SNFORGE_SIERRA_DIR)) - .and_then(|parent| parent.parent()) - .map(ToString::to_string) - .context(format!( - "Source sierra path should be in the format: /{SNFORGE_SIERRA_DIR}/.sierra.json, got: {source_sierra_path}" - )) -} diff --git a/crates/cairo-coverage/src/input/unique_executed_sierra_ids.rs b/crates/cairo-coverage/src/input/unique_executed_sierra_ids.rs index 63b1dca..23fd5b0 100644 --- a/crates/cairo-coverage/src/input/unique_executed_sierra_ids.rs +++ b/crates/cairo-coverage/src/input/unique_executed_sierra_ids.rs @@ -1,40 +1,31 @@ use crate::input::SierraToCairoMap; -use anyhow::{Context, Result}; +use crate::merge::MergeOwned; use cairo_lang_sierra::program::StatementIdx; -use cairo_lang_sierra_to_casm::compiler::CairoProgram; +use cairo_lang_sierra_to_casm::compiler::CairoProgramDebugInfo; use derived_deref::Deref; use itertools::Itertools; use std::collections::HashMap; -use trace_data::{CallTrace, CasmLevelInfo}; +use trace_data::CasmLevelInfo; #[derive(Deref)] pub struct UniqueExecutedSierraIds(HashMap); -impl Extend<(StatementIdx, usize)> for UniqueExecutedSierraIds { - fn extend>(&mut self, iter: T) { - for (key, value) in iter { - self.0 - .entry(key) - .and_modify(|e| *e += value) - .or_insert(value); - } +impl MergeOwned for UniqueExecutedSierraIds { + fn merge_owned(self, other: Self) -> Self { + Self(self.0.merge_owned(other.0)) } } impl UniqueExecutedSierraIds { pub fn new( - casm: &CairoProgram, - call_trace: &CallTrace, + casm_debug_info: &CairoProgramDebugInfo, + casm_level_info: &CasmLevelInfo, sierra_to_cairo_map: &SierraToCairoMap, - ) -> Result { + ) -> Self { let CasmLevelInfo { run_with_call_header, vm_trace, - } = &call_trace - .cairo_execution_info - .as_ref() - .context("Missing key 'cairo_execution_info' in call trace")? - .casm_level_info; + } = &casm_level_info; let real_minimal_pc = run_with_call_header .then(|| vm_trace.last().map_or(1, |trace| trace.pc + 1)) @@ -46,17 +37,14 @@ impl UniqueExecutedSierraIds { .filter(|pc| pc >= &real_minimal_pc) .map(|pc| { let real_pc_code_offset = pc - real_minimal_pc; - casm.debug_info + casm_debug_info .sierra_statement_info .partition_point(|debug_info| debug_info.start_offset <= real_pc_code_offset) - 1 }) .map(StatementIdx); - Ok(squash_idx_pointing_to_same_statement( - iter, - sierra_to_cairo_map, - )) + squash_idx_pointing_to_same_statement(iter, sierra_to_cairo_map) } } diff --git a/crates/cairo-coverage/src/main.rs b/crates/cairo-coverage/src/main.rs index c63e9ce..e1267bc 100644 --- a/crates/cairo-coverage/src/main.rs +++ b/crates/cairo-coverage/src/main.rs @@ -2,6 +2,7 @@ mod cli; mod coverage_data; mod data_loader; mod input; +mod merge; mod output; mod types; @@ -9,33 +10,98 @@ use crate::coverage_data::create_files_coverage_data_with_hits; use crate::data_loader::LoadedDataMap; use crate::input::{InputData, StatementCategoryFilter}; use crate::output::lcov::LcovFormat; -use anyhow::{Context, Result}; +use anyhow::{bail, ensure, Context, Result}; +use camino::Utf8PathBuf; use clap::Parser; use cli::Cli; +use indoc::indoc; +use merge::MergeOwned; use std::fs::OpenOptions; use std::io::Write; +const SNFORGE_SIERRA_DIR: &str = ".snfoundry_versioned_programs"; + fn main() -> Result<()> { - let cli = Cli::parse(); + let Cli { + trace_files, + include, + output_path, + project_path, + } = &Cli::parse(); - let output_path = &cli.output_path; + let coverage_data = LoadedDataMap::load(trace_files)? + .iter() + .map(|(source_sierra_path, loaded_data)| { + let project_path = &get_project_path(source_sierra_path, project_path)?; + let filter = StatementCategoryFilter::new(project_path, include, loaded_data); + let input_data = InputData::new(loaded_data, &filter)?; + Ok(create_files_coverage_data_with_hits(&input_data)) + }) + .collect::>>()? + .into_iter() + // Versioned programs and contract classes can represent the same piece of code, + // so we merge the file coverage after processing them to avoid duplicate entries. + .reduce(MergeOwned::merge_owned) + .context("No elements to merge")?; - let mut file = OpenOptions::new() + OpenOptions::new() .append(true) .create(true) .open(output_path) - .context(format!("Failed to open output file at path: {output_path}"))?; + .context(format!("Failed to open output file at path: {output_path}"))? + .write_all(LcovFormat::from(coverage_data).to_string().as_bytes()) + .context("Failed to write to output file")?; - let loaded_data = LoadedDataMap::load(&cli.trace_files)?; - for (source_sierra_path, loaded_data) in loaded_data.iter() { - let filter = StatementCategoryFilter::new(source_sierra_path, &cli.include, loaded_data)?; - let input_data = InputData::new(loaded_data, &filter)?; - let coverage_data = create_files_coverage_data_with_hits(&input_data); - let output_data = LcovFormat::from(coverage_data); + Ok(()) +} - file.write_all(output_data.to_string().as_bytes()) - .context("Failed to write to output file")?; +fn get_project_path( + source_sierra_path: &Utf8PathBuf, + project_path: &Option, +) -> Result { + if let Some(project_path) = project_path { + Ok(project_path.clone()) + } else { + find_user_project_path(source_sierra_path).context(indoc! { + r"Inference of project path failed. + Please provide the project path explicitly using the --project-path flag. + If you are using snforge, it is not possible to use cairo-coverage flags. + You need to run cairo-coverage directly." + }) } +} - Ok(()) +fn find_user_project_path(source_sierra_path: &Utf8PathBuf) -> Result { + ensure!( + source_sierra_path.extension() == Some("json"), + "Source sierra path should have a .json extension, got: {source_sierra_path}" + ); + + match source_sierra_path.with_extension("").extension() { + Some("sierra") => { + source_sierra_path + .parent() + .filter(|parent| parent.file_name() == Some(SNFORGE_SIERRA_DIR)) + .and_then(|parent| parent.parent()) + .map(Utf8PathBuf::from) + .context(format!( + "Source sierra path should be in the format: /{SNFORGE_SIERRA_DIR}/.sierra.json, got: {source_sierra_path}" + )) + } + Some("contract_class") => { + source_sierra_path + .parent() + .filter(|parent| parent.file_name() == Some("dev")) + .and_then(|parent| parent.parent()) + .filter(|parent| parent.file_name() == Some("target")) + .and_then(|parent| parent.parent()) + .map(Utf8PathBuf::from) + .context(format!( + "Source sierra path should be in the format: /target/dev/.contract_class.json, got: {source_sierra_path}" + )) + } + _ => bail!( + "Source sierra path should have a .sierra or .contract_class extension, got: {source_sierra_path}" + ), + } } diff --git a/crates/cairo-coverage/src/merge.rs b/crates/cairo-coverage/src/merge.rs new file mode 100644 index 0000000..0f7123a --- /dev/null +++ b/crates/cairo-coverage/src/merge.rs @@ -0,0 +1,45 @@ +use std::collections::HashMap; +use std::hash::Hash; + +trait Merge { + fn merge(&mut self, other: Self); +} + +impl Merge for HashMap +where + K: Eq + Hash, +{ + fn merge(&mut self, other: Self) { + for (key, value) in other { + *self.entry(key).or_default() += value; + } + } +} + +impl Merge for HashMap +where + K: Eq + Hash, + V: Merge + Clone, +{ + fn merge(&mut self, other: Self) { + for (key, value) in other { + self.entry(key) + .and_modify(|e| e.merge(value.clone())) + .or_insert(value); + } + } +} + +pub trait MergeOwned { + fn merge_owned(self, other: Self) -> Self; +} + +impl MergeOwned for T +where + T: Merge, +{ + fn merge_owned(mut self, other: Self) -> Self { + self.merge(other); + self + } +} diff --git a/crates/cairo-coverage/tests/data/snforge_template/Scarb.toml b/crates/cairo-coverage/tests/data/snforge_template/Scarb.toml new file mode 100644 index 0000000..15652d2 --- /dev/null +++ b/crates/cairo-coverage/tests/data/snforge_template/Scarb.toml @@ -0,0 +1,24 @@ +[package] +name = "snforge_template" +version = "0.1.0" +edition = "2023_11" + +# See more keys and their definitions at https://docs.swmansion.com/scarb/docs/reference/manifest.html + +[dependencies] +starknet = "2.8.0" + +[dev-dependencies] +snforge_std = { git = "https://github.com/foundry-rs/starknet-foundry", tag = "v0.30.0" } + +[[target.starknet-contract]] +sierra = true + +[scripts] +test = "snforge test" + + +[profile.dev.cairo] +unstable-add-statements-functions-debug-info = true +unstable-add-statements-code-locations-debug-info = true +inlining-strategy= "avoid" \ No newline at end of file diff --git a/crates/cairo-coverage/tests/data/snforge_template/src/lib.cairo b/crates/cairo-coverage/tests/data/snforge_template/src/lib.cairo new file mode 100644 index 0000000..4955786 --- /dev/null +++ b/crates/cairo-coverage/tests/data/snforge_template/src/lib.cairo @@ -0,0 +1,25 @@ +#[starknet::interface] +pub trait IHelloStarknet { + fn increase_balance(ref self: TContractState, amount: felt252); + fn get_balance(self: @TContractState) -> felt252; +} + +#[starknet::contract] +mod HelloStarknet { + #[storage] + struct Storage { + balance: felt252, + } + + #[abi(embed_v0)] + impl HelloStarknetImpl of super::IHelloStarknet { + fn increase_balance(ref self: ContractState, amount: felt252) { + assert(amount != 0, 'Amount cannot be 0'); + self.balance.write(self.balance.read() + amount); + } + + fn get_balance(self: @ContractState) -> felt252 { + self.balance.read() + } + } +} diff --git a/crates/cairo-coverage/tests/data/snforge_template/tests/test_contract.cairo b/crates/cairo-coverage/tests/data/snforge_template/tests/test_contract.cairo new file mode 100644 index 0000000..f9d53e3 --- /dev/null +++ b/crates/cairo-coverage/tests/data/snforge_template/tests/test_contract.cairo @@ -0,0 +1,47 @@ +use starknet::ContractAddress; + +use snforge_std::{declare, ContractClassTrait, DeclareResultTrait}; + +use snforge_template::IHelloStarknetSafeDispatcher; +use snforge_template::IHelloStarknetSafeDispatcherTrait; +use snforge_template::IHelloStarknetDispatcher; +use snforge_template::IHelloStarknetDispatcherTrait; + +fn deploy_contract(name: ByteArray) -> ContractAddress { + let contract = declare(name).unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@ArrayTrait::new()).unwrap(); + contract_address +} + +#[test] +fn test_increase_balance() { + let contract_address = deploy_contract("HelloStarknet"); + + let dispatcher = IHelloStarknetDispatcher { contract_address }; + + let balance_before = dispatcher.get_balance(); + assert(balance_before == 0, 'Invalid balance'); + + dispatcher.increase_balance(42); + + let balance_after = dispatcher.get_balance(); + assert(balance_after == 42, 'Invalid balance'); +} + +#[test] +#[feature("safe_dispatcher")] +fn test_cannot_increase_balance_with_zero_value() { + let contract_address = deploy_contract("HelloStarknet"); + + let safe_dispatcher = IHelloStarknetSafeDispatcher { contract_address }; + + let balance_before = safe_dispatcher.get_balance().unwrap(); + assert(balance_before == 0, 'Invalid balance'); + + match safe_dispatcher.increase_balance(0) { + Result::Ok(_) => core::panic_with_felt252('Should have panicked'), + Result::Err(panic_data) => { + assert(*panic_data.at(0) == 'Amount cannot be 0', *panic_data.at(0)); + } + }; +} diff --git a/crates/cairo-coverage/tests/e2e/general.rs b/crates/cairo-coverage/tests/e2e/general.rs index 7cb59aa..1901255 100644 --- a/crates/cairo-coverage/tests/e2e/general.rs +++ b/crates/cairo-coverage/tests/e2e/general.rs @@ -63,3 +63,10 @@ fn macros_not_included() { .run() .output_same_as_in_file("macros_not_included.lcov"); } + +#[test] +fn snforge_template() { + TestProject::new("snforge_template") + .run() + .output_same_as_in_file("snforge_template.lcov"); +} diff --git a/crates/cairo-coverage/tests/expected_output/snforge_template.lcov b/crates/cairo-coverage/tests/expected_output/snforge_template.lcov new file mode 100644 index 0000000..3215180 --- /dev/null +++ b/crates/cairo-coverage/tests/expected_output/snforge_template.lcov @@ -0,0 +1,27 @@ +TN: +SF:{dir}/src/lib.cairo +FN:22,snforge_template::HelloStarknet::HelloStarknetImpl::get_balance +FNDA:12,snforge_template::HelloStarknet::HelloStarknetImpl::get_balance +FN:16,snforge_template::HelloStarknet::HelloStarknetImpl::increase_balance +FNDA:7,snforge_template::HelloStarknet::HelloStarknetImpl::increase_balance +FNF:2 +FNH:2 +DA:16,1 +DA:17,2 +DA:18,7 +DA:22,12 +LF:4 +LH:4 +end_of_record +TN: +SF:{dir}/tests/test_contract.cairo +FN:10,snforge_template_integrationtest::test_contract::deploy_contract +FNDA:6,snforge_template_integrationtest::test_contract::deploy_contract +FNF:1 +FNH:1 +DA:10,2 +DA:11,6 +DA:12,6 +LF:3 +LH:3 +end_of_record