Skip to content

Commit

Permalink
mulVec with AVX-512 IFMA
Browse files Browse the repository at this point in the history
  • Loading branch information
herumi committed Apr 18, 2024
1 parent df89180 commit 0285b26
Show file tree
Hide file tree
Showing 5 changed files with 1,461 additions and 3 deletions.
10 changes: 10 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,16 @@ src/bint64.ll: src/gen_bint.exe
src/bint32.ll: src/gen_bint.exe
$< -u 32 -ver 0x90 > $@
endif
# AVX-512 IFMA backend (src/msm_avx.cpp): selected and enabled by default on
# x86_64. Pass MCL_MSM=0 to build without it.
ifeq ($(ARCH),x86_64)
MSM=msm_avx
MCL_MSM?=1
endif
# Build the backend only when one was actually selected above. On other
# architectures MSM is empty, so an MCL_MSM=1 coming from the command line or
# the environment would otherwise add the bogus object $(OBJ_DIR)/.o, define
# MCL_MSM for the headers, and compile with AVX-512 flags.
ifeq ($(MCL_MSM),1)
ifneq ($(MSM),)
CFLAGS+=-DMCL_MSM=1
LIB_OBJ+=$(OBJ_DIR)/$(MSM).o
# Only this translation unit is compiled with the AVX-512 options.
$(OBJ_DIR)/$(MSM).o: src/$(MSM).cpp
	$(PRE)$(CXX) -c $< -o $@ $(CFLAGS) -mavx512f -mavx512ifma $(CFLAGS_USER)
endif
endif
include/mcl/bint_proto.hpp: src/gen_bint_header.py
python3 $< > $@ proto $(GEN_BINT_HEADER_PY_OPT)
src/bint_switch.hpp: src/gen_bint_header.py
Expand Down
5 changes: 5 additions & 0 deletions include/mcl/bn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ void mulByCofactorBLS12fast(T& Q, const T& P);
#endif
namespace mcl {

/*
	Set up the MSM backend for the curve described by cp.
	The definition is not in this header; init() calls this function only
	when MCL_MSM is defined.
*/
void initMsm(const mcl::CurveParam& cp);

namespace MCL_NAMESPACE_BN {

namespace local {
Expand Down Expand Up @@ -2274,6 +2276,9 @@ inline void init(bool *pb, const mcl::CurveParam& cp = mcl::BN254, fp::Mode mode
if (!*pb) return;
G1::setMulVecGLV(mcl::ec::mulVecGLVT<local::GLV1, G1, Fr>);
G2::setMulVecGLV(mcl::ec::mulVecGLVT<local::GLV2, G2, Fr>);
#ifdef MCL_MSM
mcl::initMsm(cp);
#endif
Fp12::setPowVecGLV(local::powVecGLV);
G1::setCompressedExpression();
G2::setCompressedExpression();
Expand Down
11 changes: 11 additions & 0 deletions include/mcl/ec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1315,6 +1315,7 @@ class EcT : public fp::Serializable<EcT<_Fp, _Fr> > {
static bool verifyOrder_;
static mpz_class order_;
static bool (*mulVecGLV)(EcT& z, const EcT *xVec, const void *yVec, size_t n, bool constTime);
static void (*mulVecOpti)(EcT& z, EcT *xVec, const void *yVec, size_t n);
static bool (*isValidOrderFast)(const EcT& x);
/* default constructor is undefined value */
EcT() {}
Expand Down Expand Up @@ -1380,6 +1381,7 @@ class EcT : public fp::Serializable<EcT<_Fp, _Fr> > {
verifyOrder_ = false;
order_ = 0;
mulVecGLV = 0;
mulVecOpti = 0;
isValidOrderFast = 0;
mode_ = mode;
}
Expand All @@ -1406,6 +1408,10 @@ class EcT : public fp::Serializable<EcT<_Fp, _Fr> > {
{
mulVecGLV = f;
}
/*
	Register f as the optimized multi-point multiplication routine.
	f is stored in mulVecOpti; the dispatch code in this class calls it only
	when it is non-null and n >= 128.
	f receives the output z, the point array xVec, the scalar array yVec
	(type-erased as const void*) and the element count n.
	NOTE(review): xVec is passed as non-const, so presumably the routine may
	modify the input points in place -- confirm against the implementation.
	The callback's last parameter is named n to match the mulVecOpti
	declaration (it was yn); the function type is unchanged.
*/
static void setMulVecOpti(void f(EcT& z, EcT *xVec, const void *yVec, size_t n))
{
	mulVecOpti = f;
}
static inline void init(bool *pb, const char *astr, const char *bstr, int mode = ec::Jacobi)
{
Fp a, b;
Expand Down Expand Up @@ -2069,6 +2075,10 @@ class EcT : public fp::Serializable<EcT<_Fp, _Fr> > {
z.clear();
return;
}
if (mulVecOpti && n >= 128) {
mulVecOpti(z, xVec, yVec, n);
return;
}
if (mulVecGLV && mulVecGLV(z, xVec, yVec, n, false)) {
return;
}
Expand Down Expand Up @@ -2175,6 +2185,7 @@ template<class Fp, class Fr> int EcT<Fp, Fr>::ioMode_;
// Out-of-class definitions of EcT's static data members (declared inside the class template).
template<class Fp, class Fr> bool EcT<Fp, Fr>::verifyOrder_;
template<class Fp, class Fr> mpz_class EcT<Fp, Fr>::order_;
// Optional hook; the dispatch code tries it when it is non-null.
template<class Fp, class Fr> bool (*EcT<Fp, Fr>::mulVecGLV)(EcT& z, const EcT *xVec, const void *yVec, size_t n, bool constTime);
// Optional hook installed via setMulVecOpti(); takes xVec as non-const, unlike mulVecGLV.
template<class Fp, class Fr> void (*EcT<Fp, Fr>::mulVecOpti)(EcT& z, EcT *xVec, const void *yVec, size_t n);
template<class Fp, class Fr> bool (*EcT<Fp, Fr>::isValidOrderFast)(const EcT& x);
template<class Fp, class Fr> int EcT<Fp, Fr>::mode_;

Expand Down
6 changes: 3 additions & 3 deletions include/mcl/operator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ struct Operator : public E {
return;
}
const size_t w = 4;
const size_t N = 1 << w;
const size_t n = 1 << w;
uint8_t idxTbl[sizeof(T) * 8 / w];
mcl::fp::BitIterator<Unit> iter(y, yn);
size_t idxN = 0;
Expand All @@ -204,9 +204,9 @@ struct Operator : public E {
idxTbl[idxN++] = iter.getNext(w);
}
assert(idxN > 0);
T tbl[N];
T tbl[n];
tbl[1] = x;
for (size_t i = 2; i < N; i++) {
for (size_t i = 2; i < n; i++) {
tbl[i] = tbl[i-1] * x;
}
uint32_t idx = idxTbl[idxN - 1];
Expand Down
Loading

0 comments on commit 0285b26

Please sign in to comment.