From 9a30b1a6949f5a7eae5e8209a1d809a19abc55ed Mon Sep 17 00:00:00 2001 From: MITSUNARI Shigeo Date: Fri, 19 Jan 2024 13:04:54 +0900 Subject: [PATCH] refactoring vint --- include/mcl/vint.hpp | 309 +++++++++++++------------------------------ 1 file changed, 91 insertions(+), 218 deletions(-) diff --git a/include/mcl/vint.hpp b/include/mcl/vint.hpp index 8c5c9634..d29cb4e8 100644 --- a/include/mcl/vint.hpp +++ b/include/mcl/vint.hpp @@ -22,84 +22,16 @@ namespace mcl { -namespace vint { - -class FixedBuffer { - static const size_t N = maxUnitSize * 2 + 1; - size_t size_; - Unit v_[N]; -public: - FixedBuffer() - : size_(0) - { - } - FixedBuffer(const FixedBuffer& rhs) - { - operator=(rhs); - } - FixedBuffer& operator=(const FixedBuffer& rhs) - { - size_ = rhs.size_; -#if defined(__GNUC__) && !defined(__EMSCRIPTEN__) && !defined(__clang__) - #pragma GCC diagnostic push - #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" -#endif - for (size_t i = 0; i < size_; i++) { - v_[i] = rhs.v_[i]; - } -#if defined(__GNUC__) && !defined(__EMSCRIPTEN__) && !defined(__clang__) - #pragma GCC diagnostic pop -#endif - return *this; - } - void clear() { size_ = 0; } - void alloc(bool *pb, size_t n) - { - if (n > N) { - *pb = false; - return; - } - size_ = n; - *pb = true; - } - void swap(FixedBuffer& rhs) - { - FixedBuffer *p1 = this; - FixedBuffer *p2 = &rhs; - if (p1->size_ < p2->size_) { - fp::swap_(p1, p2); - } - assert(p1->size_ >= p2->size_); - for (size_t i = 0; i < p2->size_; i++) { - fp::swap_(p1->v_[i], p2->v_[i]); - } - for (size_t i = p2->size_; i < p1->size_; i++) { - p2->v_[i] = p1->v_[i]; - } - fp::swap_(p1->size_, p2->size_); - } - // to avoid warning of gcc - void verify(size_t n) const - { - assert(n <= N); - (void)n; - } - const Unit& operator[](size_t n) const { verify(n); return v_[n]; } - Unit& operator[](size_t n) { verify(n); return v_[n]; } -}; - -} // vint - /** signed integer with variable length */ class Vint { public: - typedef vint::FixedBuffer Buffer; static const size_t UnitBitSize = sizeof(Unit) * 8; static const int invalidVar = -2147483647 - 1; // abs(invalidVar) is not defined + static const size_t N = maxUnitSize * 2 + 1; private: - Buffer buf_; + Unit buf_[N]; size_t size_; bool isNeg_; void trim(size_t n) @@ -118,29 +50,24 @@ class Vint { isNeg_ = false; } } - static int ucompare(const Buffer& x, size_t xn, const Buffer& y, size_t yn) + static int ucompare(const Unit *x, size_t xn, const Unit *y, size_t yn) { - if (xn == yn) return bint::cmpN(&x[0], &y[0], xn); + if (xn == yn) return bint::cmpN(x, y, xn); return xn > yn ? 1 : -1; } - static void uadd(Vint& z, const Buffer& x, size_t xn, const Buffer& y, size_t yn) + static void uadd(Vint& z, const Unit *px, size_t xn, const Unit *py, size_t yn) { - const Unit *px = &x[0]; - const Unit *py = &y[0]; if (yn > xn) { fp::swap_(xn, yn); fp::swap_(px, py); } assert(xn >= yn); - bool b; // &x[0] and &y[0] will not change if z == x or z == y because they are FixedBuffer - z.buf_.alloc(&b, xn + 1); - assert(b); - if (!b) { + if (!z.setSize(xn + 1)) { z.clear(); return; } - Unit *dst = &z.buf_[0]; + Unit *dst = z.buf_; Unit c = bint::addN(dst, px, py, yn); if (xn > yn) { size_t n = xn - yn; @@ -150,49 +77,38 @@ class Vint { dst[xn] = c; z.trim(xn + 1); } - static void uadd1(Vint& z, const Buffer& x, size_t xn, Unit y) + static void uadd1(Vint& z, const Unit *x, size_t xn, Unit y) { size_t zn = xn + 1; - bool b; - z.buf_.alloc(&b, zn); - assert(b); - if (!b) { + if (!z.setSize(zn)) { z.clear(); return; } - if (&z.buf_[0] != &x[0]) bint::copyN(&z.buf_[0], &x[0], xn); - z.buf_[zn - 1] = bint::addUnit(&z.buf_[0], xn, y); + if (z.buf_ != x) bint::copyN(z.buf_, x, xn); + z.buf_[zn - 1] = bint::addUnit(z.buf_, xn, y); z.trim(zn); } - static void usub1(Vint& z, const Buffer& x, size_t xn, Unit y) + static void usub1(Vint& z, const Unit *x, size_t xn, Unit y) { size_t zn = xn; - bool b; - z.buf_.alloc(&b, zn); - assert(b); - if (!b) { + if (!z.setSize(zn)) { z.clear(); - return; } - Unit *dst = &z.buf_[0]; - const Unit *src = &x[0]; - if (dst != src) bint::copyN(dst, src, xn); + Unit *dst = z.buf_; + if (dst != x) bint::copyN(dst, x, xn); Unit c = bint::subUnit(dst, xn, y); (void)c; assert(!c); z.trim(zn); } - static void usub(Vint& z, const Buffer& x, size_t xn, const Buffer& y, size_t yn) + static void usub(Vint& z, const Unit *x, size_t xn, const Unit *y, size_t yn) { assert(xn >= yn); - bool b; - z.buf_.alloc(&b, xn); - assert(b); - if (!b) { + if (!z.setSize(xn)) { z.clear(); return; } - Unit c = bint::subN(&z.buf_[0], &x[0], &y[0], yn); + Unit c = bint::subN(z.buf_, x, y, yn); if (xn > yn) { size_t n = xn - yn; Unit *dst = &z.buf_[yn]; @@ -257,28 +173,23 @@ class Vint { @param q [out] x / y if q != 0 @param r [out] x % y */ - static void udiv(Vint* q, Vint& r, const Buffer& x, size_t xn, const Buffer& y, size_t yn) + static void udiv(Vint* q, Vint& r, const Unit *x, size_t xn, const Unit *y, size_t yn) { assert(q != &r); if (xn < yn) { - r.buf_ = x; - r.trim(xn); + r.copy(x, xn); if (q) q->clear(); return; } size_t qn = xn - yn + 1; - bool b; if (q) { - q->buf_.alloc(&b, qn); - assert(b); (void)b; + q->setSize(qn); } Unit *xx = (Unit*)CYBOZU_ALLOCA(sizeof(Unit) * xn); - bint::copyN(xx, &x[0], xn); + bint::copyN(xx, x, xn); Unit *qq = q ? &q->buf_[0] : 0; size_t rn = bint::div(qq, qn, xx, xn, &y[0], yn); - r.buf_.alloc(&b, rn); - assert(b); (void)b; - bint::copyN(&r.buf_[0], xx, rn); + r.copy(xx, rn); if (q) { q->trim(qn); } @@ -372,31 +283,41 @@ class Vint { } } } - + bool setSize(size_t n) + { + if (n > N) return false; + size_ = n; + return true; + } + void copy(const Unit *x, size_t n) + { + if (setSize(n)) { + bint::copyN(buf_, x, n); + } + } public: - Vint(int x = 0) - : size_(0) + Vint() + : size_(1) + , isNeg_(false) + { + buf_[0] = 0; + } + Vint(int x) { *this = x; } Vint(Unit x) - : size_(0) { *this = x; } Vint(const Vint& rhs) - : buf_(rhs.buf_) - , size_(rhs.size_) - , isNeg_(rhs.isNeg_) { + *this = rhs; } Vint& operator=(int x) { assert(x != invalidVar); isNeg_ = x < 0; - bool b; - buf_.alloc(&b, 1); - assert(b); (void)b; buf_[0] = fp::abs_(x); size_ = 1; return *this; @@ -404,32 +325,20 @@ class Vint { Vint& operator=(Unit x) { isNeg_ = false; - bool b; - buf_.alloc(&b, 1); - assert(b); (void)b; buf_[0] = x; size_ = 1; return *this; } Vint& operator=(const Vint& rhs) { - buf_ = rhs.buf_; size_ = rhs.size_; isNeg_ = rhs.isNeg_; + mcl::bint::copyN(buf_, rhs.buf_, size_); return *this; } - void swap(Vint& rhs) -#if CYBOZU_CPP_VERSION >= CYBOZU_CPP_VERSION_CPP11 - noexcept -#endif - { - fp::swap_(buf_, rhs.buf_); - fp::swap_(size_, rhs.size_); - fp::swap_(isNeg_, rhs.isNeg_); - } void dump(const char *msg = "") const { - bint::dump(&buf_[0], size_, msg); + bint::dump(buf_, size_, msg); } /* set positive value @@ -445,11 +354,14 @@ class Vint { return; } size_t unitSize = (sizeof(S) * size + sizeof(Unit) - 1) / sizeof(Unit); - buf_.alloc(pb, unitSize); - if (!*pb) return; - bool b = fp::convertArrayAsLE(&buf_[0], unitSize, x, size); - assert(b); - (void)b; + if (!setSize(unitSize)) { + *pb = false; + return; + } + *pb = fp::convertArrayAsLE(buf_, unitSize, x, size); + if (!*pb) { + return; + } trim(unitSize); } /* @@ -460,9 +372,8 @@ class Vint { assert(max > 0); if (rg.isZero()) rg = fp::RandGen::get(); size_t n = max.size(); - buf_.alloc(pb, n); - if (!*pb) return; - rg.read(pb, &buf_[0], n * sizeof(buf_[0])); + size_ = n; + rg.read(pb, buf_, n * sizeof(Unit)); if (!*pb) return; trim(n); *this %= max; @@ -480,7 +391,7 @@ class Vint { *pb = false; return; } - bint::copyN(x, &buf_[0], n); + bint::copyN(x, buf_, n); bint::clearN(x + n, maxSize - n); *pb = true; } @@ -490,7 +401,7 @@ class Vint { { if (isNeg_) cybozu::writeChar(pb, os, '-'); char buf[1024]; - size_t n = mcl::fp::arrayToStr(buf, sizeof(buf), &buf_[0], size_, base, false); + size_t n = mcl::fp::arrayToStr(buf, sizeof(buf), buf_, size_, base, false); if (n == 0) { *pb = false; return; @@ -526,24 +437,23 @@ class Vint { // ignore sign bool testBit(size_t i) const { + assert(i < N * UnitBitSize); + if (i >= N * UnitBitSize) return false; size_t q = i / UnitBitSize; size_t r = i % UnitBitSize; - assert(q <= size()); Unit mask = Unit(1) << r; return (buf_[q] & mask) != 0; } void setBit(size_t i, bool v = true) { + assert(i < N * UnitBitSize); + if (i >= N * UnitBitSize) return; size_t q = i / UnitBitSize; size_t r = i % UnitBitSize; - assert(q <= size()); - bool b; - buf_.alloc(&b, q + 1); - assert(b); - if (!b) { - clear(); - return; + if (q > size_) { + bint::clearN(buf_ + size_, q - size_); } + size_ = q + 1; Unit mask = Unit(1) << r; if (v) { buf_[q] |= mask; @@ -562,9 +472,10 @@ class Vint { { // allow twice size of MCL_MAX_BIT_SIZE because of multiplication const size_t maxN = (MCL_MAX_BIT_SIZE * 2 + UnitBitSize - 1) / UnitBitSize; - buf_.alloc(pb, maxN); - if (!*pb) return; - *pb = false; + if (!setSize(maxN)) { + *pb = false; + return; + } isNeg_ = false; size_t len = strlen(str); size_t n = fp::strToArray(&isNeg_, &buf_[0], maxN, str, len, base); @@ -595,7 +506,7 @@ class Vint { } else { // same sign Unit y0 = fp::abs_(y); - int c = (x.size() > 1) ? 1 : bint::cmpT<1>(&x.buf_[0], &y0); + int c = (x.size() > 1) ? 1 : bint::cmpT<1>(x.buf_, &y0); if (x.isNeg_) { return -c; } @@ -615,7 +526,7 @@ class Vint { uint32_t getLow32bit() const { return (uint32_t)buf_[0]; } bool isOdd() const { return (buf_[0] & 1) == 1; } bool isEven() const { return !isOdd(); } - const Unit *getUnit() const { return &buf_[0]; } + const Unit *getUnit() const { return buf_; } size_t getUnitSize() const { return size_; } static void add(Vint& z, const Vint& x, const Vint& y) { @@ -630,10 +541,8 @@ class Vint { const size_t xn = x.size(); const size_t yn = y.size(); size_t zn = xn + yn; - bool b; - z.buf_.alloc(&b, zn); - assert(b); (void)b; - bint::mulNM(&z.buf_[0], &x.buf_[0], xn, &y.buf_[0], yn); + if (!z.setSize(zn)) return; + bint::mulNM(z.buf_, x.buf_, xn, y.buf_, yn); z.trim(zn); z.isNeg_ = x.isNeg_ ^ y.isNeg_; } @@ -653,14 +562,8 @@ class Vint { { size_t xn = x.size(); size_t zn = xn + 1; - bool b; - z.buf_.alloc(&b, zn); - assert(b); - if (!b) { - z.clear(); - return; - } - z.buf_[zn - 1] = bint::mulUnitN(&z.buf_[0], &x.buf_[0], y, xn); + if (!z.setSize(zn)) return; + z.buf_[zn - 1] = bint::mulUnitN(z.buf_, x.buf_, y, xn); z.isNeg_ = x.isNeg_; z.trim(zn); } @@ -706,17 +609,11 @@ class Vint { int r; if (q) { q->isNeg_ = xNeg ^ yNeg; - bool b; - q->buf_.alloc(&b, xn); - assert(b); - if (!b) { - q->clear(); - return 0; - } - r = (int)bint::divUnit(&q->buf_[0], &x.buf_[0], xn, absY); + if (!q->setSize(xn)) return 0; + r = (int)bint::divUnit(q->buf_, x.buf_, xn, absY); q->trim(xn); } else { - r = (int)bint::modUnit(&x.buf_[0], xn, absY); + r = (int)bint::modUnit(x.buf_, xn, absY); } return xNeg ? -r : r; } @@ -758,16 +655,8 @@ class Vint { { assert(!x.isNeg_); size_t xn = x.size(); - if (q) { - bool b; - q->buf_.alloc(&b, xn); - assert(b); - if (!b) { - q->clear(); - return 0; - } - } - Unit r = bint::divUnit(q ? &q->buf_[0] : 0, &x.buf_[0], xn, y); + if (q && !q->setSize(N)) return 0; + Unit r = bint::divUnit(q ? q->buf_ : 0, x.buf_, xn, y); if (q) { q->trim(xn); q->isNeg_ = false; @@ -805,10 +694,12 @@ class Vint { size_t n = fp::local::loadWord(buf, sizeof(buf), is); if (n == 0) return; const size_t maxN = 384 / (sizeof(MCL_SIZEOF_UNIT) * 8); - buf_.alloc(pb, maxN); - if (!*pb) return; + if (!setSize(maxN)) { + *pb = false; + return; + } isNeg_ = false; - n = fp::strToArray(&isNeg_, &buf_[0], maxN, buf, n, ioMode); + n = fp::strToArray(&isNeg_, buf_, maxN, buf, n, ioMode); if (n == 0) return; trim(n); *pb = true; @@ -819,10 +710,8 @@ class Vint { assert(shiftBit <= MCL_MAX_BIT_SIZE * 2); // many be too big size_t xn = x.size(); size_t yn = xn + (shiftBit + UnitBitSize - 1) / UnitBitSize; - bool b; - y.buf_.alloc(&b, yn); - assert(b); (void)b; - bint::shiftLeft(&y.buf_[0], &x.buf_[0], shiftBit, xn); + if (!y.setSize(yn)) return; + bint::shiftLeft(y.buf_, x.buf_, shiftBit, xn); y.isNeg_ = x.isNeg_; y.trim(yn); } @@ -836,9 +725,7 @@ class Vint { return; } size_t yn = xn - shiftBit / UnitBitSize; - bool b; - y.buf_.alloc(&b, yn); - assert(b); (void)b; + y.setSize(yn); bint::shiftRight(&y.buf_[0], &x.buf_[0], shiftBit, xn); y.isNeg_ = x.isNeg_; y.trim(yn); @@ -870,16 +757,11 @@ class Vint { size_t xn = px->size(); size_t yn = py->size(); assert(xn >= yn); - bool b; - z.buf_.alloc(&b, xn); - assert(b); - if (!b) { - z.clear(); - } + z.setSize(xn); for (size_t i = 0; i < yn; i++) { z.buf_[i] = x.buf_[i] | y.buf_[i]; } - bint::copyN(&z.buf_[0] + yn, &px->buf_[0] + yn, xn - yn); + bint::copyN(z.buf_ + yn, px->buf_ + yn, xn - yn); z.trim(xn); } static void andBit(Vint& z, const Vint& x, const Vint& y) @@ -891,13 +773,7 @@ class Vint { } size_t yn = py->size(); assert(px->size() >= yn); - bool b; - z.buf_.alloc(&b, yn); - assert(b); - if (!b) { - z.clear(); - return; - } + z.setSize(yn); for (size_t i = 0; i < yn; i++) { z.buf_[i] = x.buf_[i] & y.buf_[i]; } @@ -912,9 +788,6 @@ class Vint { static void andBitu1(Vint& z, const Vint& x, Unit y) { assert(!x.isNeg_); - bool b; - z.buf_.alloc(&b, 1); - assert(b); (void)b; z.buf_[0] = x.buf_[0] & y; z.size_ = 1; z.isNeg_ = false; @@ -925,7 +798,7 @@ class Vint { static void pow(Vint& z, const Vint& x, const Vint& y) { assert(!y.isNeg_); - powT(z, x, &y.buf_[0], y.size(), mul, sqr); + powT(z, x, y.buf_, y.size(), mul, sqr); } /* REMARK y >= 0;