Skip to content

Commit

Permalink
Add support for collecting node-value coverage information to eval_ir…
Browse files Browse the repository at this point in the history
…/proc_main

PiperOrigin-RevId: 679791162
  • Loading branch information
allight authored and copybara-github committed Sep 28, 2024
1 parent c8a5247 commit 67b0033
Show file tree
Hide file tree
Showing 8 changed files with 861 additions and 15 deletions.
54 changes: 54 additions & 0 deletions xls/tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,25 @@ package(
licenses = ["notice"], # Apache 2.0
)

proto_library(
name = "node_coverage_stats_proto",
srcs = ["node_coverage_stats.proto"],
visibility = ["//xls:xls_users"],
deps = ["//xls/ir:xls_value_proto"],
)

cc_proto_library(
name = "node_coverage_stats_cc_proto",
visibility = ["//xls:xls_users"],
deps = [":node_coverage_stats_proto"],
)

py_proto_library(
name = "node_coverage_stats_py_pb2",
visibility = ["//xls:xls_users"],
deps = [":node_coverage_stats_proto"],
)

cc_binary(
name = "lec_main",
srcs = ["lec_main.cc"],
Expand Down Expand Up @@ -162,6 +181,7 @@ cc_binary(
srcs = ["eval_ir_main.cc"],
visibility = ["//xls:xls_users"],
deps = [
":node_coverage_utils",
"//xls/common:exit_status",
"//xls/common:init_xls",
"//xls/common/file:filesystem",
Expand All @@ -177,6 +197,7 @@ cc_binary(
"//xls/dslx/ir_convert:ir_converter",
"//xls/interpreter:evaluator_options",
"//xls/interpreter:ir_interpreter",
"//xls/interpreter:observer",
"//xls/interpreter:random_value",
"//xls/ir",
"//xls/ir:bits",
Expand Down Expand Up @@ -262,6 +283,7 @@ cc_binary(
visibility = ["//xls:xls_users"],
deps = [
":eval_utils",
":node_coverage_utils",
"//xls/codegen:module_signature_cc_proto",
"//xls/common:exit_status",
"//xls/common:init_xls",
Expand Down Expand Up @@ -703,9 +725,12 @@ py_test(
python_version = "PY3",
srcs_version = "PY3",
deps = [
":node_coverage_stats_py_pb2",
"//xls/common:runfiles",
"//xls/common:test_base",
"//xls/ir:xls_value_py_pb2",
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
],
)

Expand Down Expand Up @@ -739,6 +764,7 @@ py_test(
python_version = "PY3",
srcs_version = "PY3",
deps = [
":node_coverage_stats_py_pb2",
"//xls/common:runfiles",
"@com_google_absl_py//absl/logging",
"@com_google_absl_py//absl/testing:absltest",
Expand Down Expand Up @@ -902,3 +928,31 @@ py_test(
"@com_google_absl_py//absl/testing:absltest",
],
)

cc_library(
name = "node_coverage_utils",
srcs = ["node_coverage_utils.cc"],
hdrs = ["node_coverage_utils.h"],
deps = [
":node_coverage_stats_cc_proto",
"//xls/common/file:filesystem",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
"//xls/data_structures:inline_bitmap",
"//xls/data_structures:leaf_type_tree",
"//xls/interpreter:observer",
"//xls/ir",
"//xls/ir:bits",
"//xls/ir:source_location",
"//xls/ir:type",
"//xls/ir:value",
"//xls/ir:value_utils",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@com_google_protobuf//:protobuf",
],
)
67 changes: 56 additions & 11 deletions xls/tools/eval_ir_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,14 @@
#include "xls/dslx/parse_and_typecheck.h"
#include "xls/dslx/warning_kind.h"
#include "xls/interpreter/function_interpreter.h"
#include "xls/interpreter/observer.h"
#include "xls/interpreter/random_value.h"
#include "xls/ir/bits.h"
#include "xls/ir/events.h"
#include "xls/ir/format_preference.h"
#include "xls/ir/function.h"
#include "xls/ir/ir_parser.h"
#include "xls/ir/node.h"
#include "xls/ir/nodes.h"
#include "xls/ir/package.h"
#include "xls/ir/type.h"
Expand All @@ -93,6 +95,7 @@
#include "xls/passes/optimization_pass.h"
#include "xls/passes/optimization_pass_pipeline.h"
#include "xls/passes/pass_base.h"
#include "xls/tools/node_coverage_utils.h"

static constexpr std::string_view kUsage = R"(
Evaluates an IR file with user-specified or random inputs using the IR
Expand Down Expand Up @@ -198,6 +201,20 @@ ABSL_FLAG(
"force mismatches between JIT and interpreter for testing purposed.");
// LINT.ThenChange(//xls/build_rules/xls_ir_rules.bzl)

// TODO(allight): It might be nice to allow one to specify these in build files.
// Right now if you want this report you need to run eval_ir_main on the command
// line. Being able to use generate it with a xls_eval_ir_test target could
// conceivably be useful.
ABSL_FLAG(std::optional<std::string>, output_node_coverage_stats_proto,
std::nullopt,
"File to write a (binary) NodeCoverageStatsProto showing which bits "
"in the run were actually set for each node.");
// TODO(allight): It might be nice to allow one to specify these in build files.
ABSL_FLAG(std::optional<std::string>, output_node_coverage_stats_textproto,
std::nullopt,
"File to write a (text) NodeCoverageStatsProto showing which bits "
"in the run were actually set for each node.");

// TODO(allight): It would be nice to enable doing this automatically if the
// llvm jit code crashes or something.
ABSL_FLAG(
Expand Down Expand Up @@ -347,16 +364,18 @@ absl::StatusOr<InterpreterResult<Value>> RunLlvmInterpreter(
// results, respectively. These strings are included in error messages.
absl::StatusOr<std::vector<Value>> Eval(
Function* f, absl::Span<const ArgSet> arg_sets, bool use_jit,
std::optional<EvaluationObserver*> eval_observer = std::nullopt,
std::string_view actual_src = "actual",
std::string_view expected_src = "expected") {
EvalIrJitObserver observer(absl::GetFlag(FLAGS_use_llvm_jit_interpreter));
std::unique_ptr<FunctionJit> jit;
if (use_jit) {
// No support for procs yet.
XLS_ASSIGN_OR_RETURN(
jit,
FunctionJit::Create(f, absl::GetFlag(FLAGS_llvm_opt_level),
/*include_observer_callbacks=*/false, &observer));
jit, FunctionJit::Create(
f, absl::GetFlag(FLAGS_llvm_opt_level),
/*include_observer_callbacks=*/eval_observer.has_value(),
&observer));
}

std::vector<Value> results;
Expand All @@ -365,12 +384,25 @@ absl::StatusOr<std::vector<Value>> Eval(
if (use_jit) {
if (absl::GetFlag(FLAGS_test_only_inject_jit_result).empty()) {
if (absl::GetFlag(FLAGS_use_llvm_jit_interpreter)) {
XLS_RET_CHECK(!eval_observer)
<< "Observer not supported with llvm interpreter.";
XLS_ASSIGN_OR_RETURN(
result, DropInterpreterEvents(RunLlvmInterpreter(
observer.saved_opt_ir(), jit.get(), arg_set.args)));
} else {
std::optional<RuntimeEvaluationObserverAdapter> adapt;
if (eval_observer) {
adapt.emplace(
eval_observer.value(),
[](int64_t v) -> Node* {
return reinterpret_cast<Node*>(static_cast<intptr_t>(v));
},
jit->runtime());
XLS_RETURN_IF_ERROR(jit->SetRuntimeObserver(&adapt.value()));
}
XLS_ASSIGN_OR_RETURN(result,
DropInterpreterEvents(jit->Run(arg_set.args)));
jit->ClearRuntimeObserver();
}
} else {
XLS_ASSIGN_OR_RETURN(result, Parser::ParseTypedValue(absl::GetFlag(
Expand All @@ -381,8 +413,8 @@ absl::StatusOr<std::vector<Value>> Eval(
// resulting events once the JIT fully supports events. Note: This will
// require rethinking some of the control flow because event comparison
// only makes sense for certain modes (optimize_ir and test_llvm_jit).
XLS_ASSIGN_OR_RETURN(
result, DropInterpreterEvents(InterpretFunction(f, arg_set.args)));
XLS_ASSIGN_OR_RETURN(result, DropInterpreterEvents(InterpretFunction(
f, arg_set.args, eval_observer)));
}
std::cout << result.ToString(FormatPreference::kHex) << '\n';

Expand Down Expand Up @@ -416,6 +448,9 @@ class EvalInvariantChecker : public OptimizationInvariantChecker {
}
XLS_ASSIGN_OR_RETURN(Function * f, package->GetTopAsFunction());
XLS_RETURN_IF_ERROR(Eval(f, arg_sets_, use_jit_,
// Runs between passes don't give useful coverage
// information.
/*eval_observer=*/std::nullopt,
/*actual_src=*/results->invocations.empty()
? std::string("start of pipeline")
: results->invocations.back().pass_name,
Expand All @@ -434,29 +469,38 @@ class EvalInvariantChecker : public OptimizationInvariantChecker {
// after optimizations.
absl::Status Run(Package* package, absl::Span<const ArgSet> arg_sets_in) {
XLS_ASSIGN_OR_RETURN(Function * f, package->GetTopAsFunction());
ScopedRecordNodeCoverage cov(
absl::GetFlag(FLAGS_output_node_coverage_stats_proto),
absl::GetFlag(FLAGS_output_node_coverage_stats_textproto));
// Copy the input ArgSets because we want to write in expected values if they
// do not exist.
std::vector<ArgSet> arg_sets(arg_sets_in.begin(), arg_sets_in.end());

if (absl::GetFlag(FLAGS_test_llvm_jit)) {
QCHECK(!absl::GetFlag(FLAGS_optimize_ir))
<< "Cannot specify both --test_llvm_jit and --optimize_ir";
XLS_ASSIGN_OR_RETURN(std::vector<Value> interpreter_results,
Eval(f, arg_sets, /*use_jit=*/false));
XLS_ASSIGN_OR_RETURN(
std::vector<Value> interpreter_results,
Eval(f, arg_sets, /*use_jit=*/false, /*eval_observer=*/std::nullopt));
for (int64_t i = 0; i < arg_sets.size(); ++i) {
QCHECK(!arg_sets[i].expected.has_value())
<< "Cannot specify expected values when using --test_llvm_jit";
arg_sets[i].expected = interpreter_results[i];
}
return Eval(f, arg_sets, /*use_jit=*/true, "JIT", "interpreter").status();
XLS_RETURN_IF_ERROR(Eval(f, arg_sets, /*use_jit=*/true,
/*eval_observer=*/cov.observer(), "JIT",
"interpreter")
.status());
return absl::OkStatus();
}

// Run the argsets through the IR before any optimizations. Write in the
// results as the expected values if the expected value is not already
// set. These expected values are used in any later evaluation after
// optimizations.
XLS_ASSIGN_OR_RETURN(std::vector<Value> results,
Eval(f, arg_sets, absl::GetFlag(FLAGS_use_llvm_jit)));
XLS_ASSIGN_OR_RETURN(
std::vector<Value> results,
Eval(f, arg_sets, absl::GetFlag(FLAGS_use_llvm_jit), cov.observer()));
for (int64_t i = 0; i < arg_sets.size(); ++i) {
if (!arg_sets[i].expected.has_value()) {
arg_sets[i].expected = results[i];
Expand All @@ -478,7 +522,8 @@ absl::Status Run(Package* package, absl::Span<const ArgSet> arg_sets_in) {
pipeline->Run(package, OptimizationPassOptions(), &results).status());

XLS_RETURN_IF_ERROR(Eval(f, arg_sets, absl::GetFlag(FLAGS_use_llvm_jit),
"after optimizations", "before optimizations")
cov.observer(), "after optimizations",
"before optimizations")
.status());
} else {
XLS_RET_CHECK(!absl::GetFlag(FLAGS_eval_after_each_pass))
Expand Down
71 changes: 69 additions & 2 deletions xls/tools/eval_ir_main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@
# limitations under the License.

import ctypes
import struct
import subprocess

from absl.testing import absltest
from absl.testing import parameterized
from xls.common import runfiles
from xls.common import test_base
from xls.ir import xls_value_pb2
from xls.tools import node_coverage_stats_pb2

EVAL_IR_MAIN_PATH = runfiles.get_path('xls/tools/eval_ir_main')

Expand All @@ -43,7 +46,22 @@
"""


class EvalMainTest(absltest.TestCase):
def _value_32_bits(v: int) -> xls_value_pb2.ValueProto:
return xls_value_pb2.ValueProto(
bits=xls_value_pb2.ValueProto.Bits(
bit_count=32, data=struct.pack('<i', v)
)
)


def parameterized_proc_backends(func):
return parameterized.named_parameters(
('jit', ['--use_llvm_jit']),
('interpreter', ['--nouse_llvm_jit']),
)(func)


class EvalMainTest(parameterized.TestCase):

def test_one_input_jit(self):
ir_file = self.create_tempfile(content=ADD_IR)
Expand Down Expand Up @@ -331,6 +349,55 @@ def test_validator_fails(self):
self.assertNotEqual(comp.returncode, 0)
self.assertIn('Unable to generate valid input', comp.stderr.decode('utf-8'))

@parameterized_proc_backends
def test_coverage(self, backend):
ir_file = self.create_tempfile(content=ADD_IR)
cov = self.create_tempfile()
subprocess.run(
[
EVAL_IR_MAIN_PATH,
ir_file.full_path,
'--input',
'bits[32]:0x5; bits[32]:0xC',
'--expected=bits[32]:0x11',
'--alsologtostderr',
f'--output_node_coverage_stats_proto={cov.full_path}',
]
+ backend,
check=True,
)
node_coverage = node_coverage_stats_pb2.NodeCoverageStatsProto.FromString(
cov.read_bytes()
)
node_stats = node_coverage_stats_pb2.NodeCoverageStatsProto.NodeStats
node_coverage.nodes.sort(key=lambda n: n.node_id)
self.assertEqual(
node_coverage.nodes,
[
node_stats(
node_id=1,
node_text='add.1: bits[32] = add(x, y, id=1)',
set_bits=_value_32_bits(0x11),
unset_bit_count=30,
total_bit_count=32,
),
node_stats(
node_id=4,
node_text='x: bits[32] = param(name=x, id=4)',
set_bits=_value_32_bits(0x5),
unset_bit_count=30,
total_bit_count=32,
),
node_stats(
node_id=5,
node_text='y: bits[32] = param(name=y, id=5)',
set_bits=_value_32_bits(0xC),
unset_bit_count=30,
total_bit_count=32,
),
],
)


if __name__ == '__main__':
test_base.main()
Loading

0 comments on commit 67b0033

Please sign in to comment.