Skip to content

Commit

Permalink
Optimized gate argument code.
Browse files Browse the repository at this point in the history
  • Loading branch information
Martun Karapetyan committed Jul 12, 2024
1 parent 27b55cd commit dcf70a7
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 131 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ namespace nil {
*/
expression_evaluator(
const math::expression<VariableType>& expr,
std::function<ValueType(const VariableType&)> get_var_value)
std::function<const ValueType&(const VariableType&)> get_var_value)
: expr(expr)
, get_var_value(get_var_value) {
}
Expand Down Expand Up @@ -140,7 +140,7 @@ namespace nil {
const math::expression<VariableType>& expr;

// A function used to retrieve the value of a variable.
std::function<ValueType(const VariableType &var)> get_var_value;
std::function<const ValueType&(const VariableType &var)> get_var_value;

};

Expand Down Expand Up @@ -207,7 +207,7 @@ namespace nil {
*/
cached_expression_evaluator(
const math::expression<VariableType>& expr,
std::function<ValueType(const VariableType&)> get_var_value)
std::function<const ValueType&(const VariableType&)> get_var_value)
: _expr(expr)
, _get_var_value(get_var_value) {
}
Expand Down Expand Up @@ -304,7 +304,7 @@ namespace nil {
const math::expression<VariableType>& _expr;

// A function used to retrieve the value of a variable.
std::function<ValueType(const VariableType &var)> _get_var_value;
std::function<const ValueType&(const VariableType &var)> _get_var_value;

// Shows how many times each subexpression appears. We count have the expression
// itself as a key, but apparently it's waay too slow. Just map the hash->count, assume
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <nil/crypto3/zk/snark/arithmetization/plonk/padding.hpp>
#include <nil/crypto3/random/algebraic_engine.hpp>
#include <nil/crypto3/math/polynomial/polynomial_dfs.hpp>
#include <nil/crypto3/zk/snark/arithmetization/plonk/variable.hpp>

namespace nil {
namespace blueprint {
Expand All @@ -55,6 +56,7 @@ namespace nil {
class plonk_private_table {
public:
using witnesses_container_type = std::vector<ColumnType>;
using VariableType = plonk_variable<ColumnType>;

protected:

Expand Down Expand Up @@ -82,6 +84,30 @@ namespace nil {
return _witnesses[index].size();
}

const ColumnType& get_variable_value_without_rotation(const VariableType& var) const {
switch (var.type) {
case VariableType::column_type::witness:
return witness(var.index);
case VariableType::column_type::public_input:
return public_input(var.index);
case VariableType::column_type::constant:
return constant(var.index);
case VariableType::column_type::selector:
return selector(var.index);
default:
std::cerr << "Invalid column type" << std::endl;
abort();
}
}
ColumnType get_variable_value(const VariableType& var, std::shared_ptr<math::evaluation_domain<FieldType>> domain) const {
if (var.rotation == 0) {
return get_variable_value_without_rotation(var);
}
return math::polynomial_shift(
this->get_variable_value_without_rotation(var),
var.rotation, domain->m);
}

const ColumnType& witness(std::uint32_t index) const {
assert(index < _witnesses.size());
return _witnesses[index];
Expand Down Expand Up @@ -126,6 +152,7 @@ namespace nil {
using public_input_container_type = std::vector<ColumnType>;
using constant_container_type = std::vector<ColumnType>;
using selector_container_type = std::vector<ColumnType>;
using VariableType = plonk_variable<ColumnType>;

protected:

Expand Down Expand Up @@ -286,6 +313,7 @@ namespace nil {
using public_input_container_type = typename public_table_type::public_input_container_type;
using constant_container_type = typename public_table_type::constant_container_type;
using selector_container_type = typename public_table_type::selector_container_type;
using VariableType = plonk_variable<ColumnType>;

protected:
// These are normally created by the assigner, or read from a file.
Expand All @@ -309,6 +337,31 @@ namespace nil {
, _public_table(public_inputs_amount, constants_amount, selectors_amount) {
}

const ColumnType& get_variable_value_without_rotation(const VariableType& var) const {
switch (var.type) {
case VariableType::column_type::witness:
return witness(var.index);
case VariableType::column_type::public_input:
return public_input(var.index);
case VariableType::column_type::constant:
return constant(var.index);
case VariableType::column_type::selector:
return selector(var.index);
default:
std::cerr << "Invalid column type" << std::endl;
abort();
}
}

ColumnType get_variable_value(const VariableType& var, std::shared_ptr<math::evaluation_domain<FieldType>> domain) const {
if (var.rotation == 0) {
return get_variable_value_without_rotation(var);
}
return math::polynomial_shift(
this->get_variable_value_without_rotation(var),
var.rotation, domain->m);
}

const ColumnType& witness(std::uint32_t index) const {
return _private_table.witness(index);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ namespace nil {
const plonk_assignment_table<FieldType> &assignments) const {
math::expression_evaluator<VariableType> evaluator(
*this,
[&assignments, row_index](const VariableType &var) {
[&assignments, row_index](const VariableType &var) -> const typename VariableType::assignment_type& {
std::size_t rows_amount = assignments.rows_amount();
switch (var.type) {
case VariableType::column_type::witness:
Expand All @@ -100,48 +100,44 @@ namespace nil {
case VariableType::column_type::selector:
return assignments.selector(var.index)[(rows_amount + row_index + var.rotation) % rows_amount];
default:
BOOST_ASSERT_MSG(false, "Invalid column type");
return VariableType::assignment_type::zero();
std::cerr << "Invalid column type" << std::endl;
abort();
}
});

return evaluator.evaluate();
}

math::polynomial<typename VariableType::assignment_type>
evaluate(const plonk_polynomial_table<FieldType> &assignments,
std::shared_ptr<math::evaluation_domain<FieldType>>
domain) const {
using polynomial_type = math::polynomial<typename VariableType::assignment_type>;
using polynomial_variable_type = plonk_variable<polynomial_type>;
math::expression_variable_type_converter<VariableType, polynomial_variable_type> converter;

math::expression_evaluator<polynomial_variable_type> evaluator(
converter.convert(*this),
[&domain, &assignments](const VariableType &var) {
polynomial_type assignment;
switch (var.type) {
case VariableType::column_type::witness:
assignment = assignments.witness(var.index);
break;
case VariableType::column_type::public_input:
assignment = assignments.public_input(var.index);
break;
case VariableType::column_type::constant:
assignment = assignments.constant(var.index);
break;
case VariableType::column_type::selector:
assignment = assignments.selector(var.index);
break;
default:
BOOST_ASSERT_MSG(false, "Invalid column type");
}
evaluate(const plonk_polynomial_table<FieldType> &assignments,
std::shared_ptr<math::evaluation_domain<FieldType>> domain) const {

if (var.rotation != 0) {
assignment =
math::polynomial_shift(assignment, domain->get_domain_element(var.rotation));
using polynomial_type = math::polynomial<typename VariableType::assignment_type>;
using polynomial_variable_type = plonk_variable<polynomial_type>;

// Convert scalar values to polynomials inside the expression.
math::expression_variable_type_converter<VariableType, polynomial_variable_type> converter;
auto converted_expression = converter.convert(*this);

// For each variable with a rotation pre-compute its value.
std::unordered_map<polynomial_variable_type, polynomial_type> rotated_variable_values;

math::expression_for_each_variable_visitor<polynomial_variable_type> visitor(
[&rotated_variable_values, &assignments, &domain](const polynomial_variable_type& var) {
if (var.rotation == 0)
return;
rotated_variable_values[var] = assignments.get_variable_value(var, domain);
});
visitor.visit(converted_expression);

math::expression_evaluator<polynomial_variable_type> evaluator(
converted_expression,
[&domain, &assignments, &rotated_variable_values]
(const VariableType &var) -> const polynomial_type& {
if (var.rotation == 0) {
return assignments.get_variable_value_without_rotation(var, domain);
}
return assignment;
return rotated_variable_values[var];
});
return evaluator.evaluate();
}
Expand All @@ -152,35 +148,33 @@ namespace nil {
using polynomial_dfs_type = math::polynomial_dfs<typename VariableType::assignment_type>;
using polynomial_dfs_variable_type = plonk_variable<polynomial_dfs_type>;

// Convert scalar values to polynomials inside the expression.
math::expression_variable_type_converter<variable_type, polynomial_dfs_variable_type> converter(
[&assignments](const typename VariableType::assignment_type& coeff) {
polynomial_dfs_type(0, assignments.rows_amount(), coeff);
});
math::expression_evaluator<polynomial_dfs_variable_type> evaluator(
converter.convert(*this),
[&domain, &assignments](const polynomial_dfs_variable_type &var) {
polynomial_dfs_type assignment;
switch (var.type) {
case VariableType::column_type::witness:
assignment = assignments.witness(var.index);
break;
case VariableType::column_type::public_input:
assignment = assignments.public_input(var.index);
break;
case VariableType::column_type::constant:
assignment = assignments.constant(var.index);
break;
case VariableType::column_type::selector:
assignment = assignments.selector(var.index);
break;
default:
BOOST_ASSERT_MSG(false, "Invalid column type");
}

if (var.rotation != 0) {
assignment = math::polynomial_shift(assignment, var.rotation, domain->m);
auto converted_expression = converter.convert(*this);

// For each variable with a rotation pre-compute its value.
std::unordered_map<polynomial_dfs_variable_type, polynomial_dfs_type> rotated_variable_values;

math::expression_for_each_variable_visitor<polynomial_dfs_variable_type> visitor(
[&rotated_variable_values, &assignments, &domain](const polynomial_dfs_variable_type& var) {
if (var.rotation == 0)
return ;
rotated_variable_values[var] = assignments.get_variable_value(var, domain);
});
visitor.visit(converted_expression);

math::expression_evaluator<polynomial_dfs_variable_type> evaluator(
converted_expression,
[&domain, &assignments, &rotated_variable_values]
(const polynomial_dfs_variable_type &var) -> const polynomial_dfs_type& {
if (var.rotation == 0) {
return assignments.get_variable_value_without_rotation(var, domain);
}
return assignment;
return rotated_variable_values[var];
}
);

Expand All @@ -189,9 +183,10 @@ namespace nil {

typename VariableType::assignment_type
evaluate(detail::plonk_evaluation_map<VariableType> &assignments) const {

math::expression_evaluator<VariableType> evaluator(
*this,
[&assignments](const VariableType &var) {
[&assignments](const VariableType &var) -> const typename VariableType::assignment_type& {
std::tuple<std::size_t, int, typename VariableType::column_type> key =
std::make_tuple(var.index, var.rotation, var.type);

Expand Down
Loading

0 comments on commit dcf70a7

Please sign in to comment.