From b096acb5a4c040cb03d8cec8cae65b223589ff51 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 3 Jul 2023 15:31:48 +0300 Subject: [PATCH 1/9] feat: use squaring instead of mimcs for commitment expansion --- std/multicommit/nativecommit.go | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/std/multicommit/nativecommit.go b/std/multicommit/nativecommit.go index f3f81f2a4b..4f8e56a0cc 100644 --- a/std/multicommit/nativecommit.go +++ b/std/multicommit/nativecommit.go @@ -20,7 +20,6 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/kvstore" - "github.com/consensys/gnark/std/hash/mimc" ) type multicommitter struct { @@ -94,22 +93,13 @@ func (mct *multicommitter) commitAndCall(api frontend.API) error { if err != nil { return fmt.Errorf("commit: %w", err) } - if len(mct.cbs) == 1 { - if err = mct.cbs[0](api, cmt); err != nil { - return fmt.Errorf("single callback: %w", err) - } - } else { - hasher, err := mimc.NewMiMC(api) - if err != nil { - return fmt.Errorf("new hasher: %w", err) - } - for i, cb := range mct.cbs { - hasher.Reset() - hasher.Write(i+1, cmt) - localcmt := hasher.Sum() - if err = cb(api, localcmt); err != nil { - return fmt.Errorf("with commitment callback %d: %w", i, err) - } + if err = mct.cbs[0](api, cmt); err != nil { + return fmt.Errorf("callback 0: %w", err) + } + for i := 1; i < len(mct.cbs); i++ { + cmt = api.Mul(cmt, cmt) + if err := mct.cbs[i](api, cmt); err != nil { + return fmt.Errorf("callback %d: %w", i, err) } } return nil From 1539ac2a5ed8687602e5fce4323fa2cca27d700c Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 3 Jul 2023 15:33:12 +0300 Subject: [PATCH 2/9] feat: cache finite field APIs --- std/math/emulated/field.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/std/math/emulated/field.go b/std/math/emulated/field.go index 124f4e9b37..df5a6b6055 100644 --- a/std/math/emulated/field.go +++ b/std/math/emulated/field.go @@ -6,6 +6,7 @@ import ( "sync" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/kvstore" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" "github.com/consensys/gnark/std/rangecheck" @@ -44,6 +45,8 @@ type Field[T FieldParams] struct { checker frontend.Rangechecker } +type ctxKey[T FieldParams] struct{} + // NewField returns an object to be used in-circuit to perform emulated // arithmetic over the field defined by type parameter [FieldParams]. The // operations on this type are defined on [Element]. There is also another type @@ -53,6 +56,12 @@ type Field[T FieldParams] struct { // This is an experimental feature and performing emulated arithmetic in-circuit // is extremly costly. See package doc for more info. func NewField[T FieldParams](native frontend.API) (*Field[T], error) { + if storer, ok := native.(kvstore.Store); ok { + ff := storer.GetKeyValue(ctxKey[T]{}) + if ff, ok := ff.(*Field[T]); ok { + return ff, nil + } + } f := &Field[T]{ api: native, log: logger.Logger(), @@ -89,6 +98,9 @@ func NewField[T FieldParams](native frontend.API) (*Field[T], error) { return nil, fmt.Errorf("elements with limb length %d does not fit into scalar field", f.fParams.BitsPerLimb()) } + if storer, ok := native.(kvstore.Store); ok { + storer.SetKeyValue(ctxKey[T]{}, f) + } return f, nil } From c7587c785f5b81f4402a30086642065449752836 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 3 Jul 2023 15:37:55 +0300 Subject: [PATCH 3/9] feat: add mulmod by poly evaluation --- std/math/emulated/element.go | 3 + std/math/emulated/field.go | 3 + std/math/emulated/field_mul.go | 229 +++++++++++++++++++++++++++++++++ std/math/emulated/hints.go | 1 + 4 files changed, 236 insertions(+) create mode 100644 std/math/emulated/field_mul.go diff --git a/std/math/emulated/element.go b/std/math/emulated/element.go index b75511b7a9..f3da9d3c7c 100644 --- a/std/math/emulated/element.go +++ b/std/math/emulated/element.go @@ -31,6 +31,9 @@ type Element[T FieldParams] struct { // ensure that the limbs are width-constrained. We do not store the // enforcement info in the Element to prevent modifying the witness. internal bool + + isEvaluated bool + evaluation frontend.Variable `gnark:"-"` } // ValueOf returns an Element[T] from a constant value. diff --git a/std/math/emulated/field.go b/std/math/emulated/field.go index df5a6b6055..3480aaab31 100644 --- a/std/math/emulated/field.go +++ b/std/math/emulated/field.go @@ -43,6 +43,8 @@ type Field[T FieldParams] struct { constrainedLimbs map[uint64]struct{} checker frontend.Rangechecker + + mulChecks []mulCheck[T] } type ctxKey[T FieldParams] struct{} @@ -98,6 +100,7 @@ func NewField[T FieldParams](native frontend.API) (*Field[T], error) { return nil, fmt.Errorf("elements with limb length %d does not fit into scalar field", f.fParams.BitsPerLimb()) } + native.Compiler().Defer(f.performMulChecks) if storer, ok := native.(kvstore.Store); ok { storer.SetKeyValue(ctxKey[T]{}, f) } diff --git a/std/math/emulated/field_mul.go b/std/math/emulated/field_mul.go new file mode 100644 index 0000000000..67c5bd75c7 --- /dev/null +++ b/std/math/emulated/field_mul.go @@ -0,0 +1,229 @@ +package emulated + +import ( + "fmt" + "math" + "math/big" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/internal/multicommit" +) + +type mulCheck[T FieldParams] struct { + f *Field[T] + // a * b = r + k*p + c + a, b *Element[T] // inputs + r *Element[T] // reduced value + k *Element[T] // coefficient + c *Element[T] // carry +} + +func (mc *mulCheck[T]) evalRound1(api frontend.API, at []frontend.Variable) { + mc.c = mc.f.evalWithChallenge(mc.c, at) + mc.r = mc.f.evalWithChallenge(mc.r, at) + mc.k = mc.f.evalWithChallenge(mc.k, at) +} + +func (mc *mulCheck[T]) evalRound2(api frontend.API, at []frontend.Variable) { + mc.a = mc.f.evalWithChallenge(mc.a, at) + mc.b = mc.f.evalWithChallenge(mc.b, at) +} + +func (mc *mulCheck[T]) check(api frontend.API, peval, coef frontend.Variable) { + ls := api.Mul(mc.a.evaluation, mc.b.evaluation) + rs := api.Add(mc.r.evaluation, api.Mul(peval, mc.k.evaluation), api.Mul(mc.c.evaluation, coef)) + api.AssertIsEqual(ls, rs) +} + +func (f *Field[T]) mulMod(a, b *Element[T], nextOverflow uint) *Element[T] { + f.enforceWidthConditional(a) + f.enforceWidthConditional(b) + k, r, c, err := f.callMulHint(a, b) + if err != nil { + panic(err) + } + mc := mulCheck[T]{ + f: f, + a: a, + b: b, + c: c, + k: k, + r: r, + } + f.mulChecks = append(f.mulChecks, mc) + return r +} + +func (f *Field[T]) evalWithChallenge(a *Element[T], at []frontend.Variable) *Element[T] { + if a.isEvaluated { + return a + } + if len(at) < len(a.Limbs)-1 { + panic("evaluation powers less than limbs") + } + sum := f.api.Mul(a.Limbs[0], 1) // copy because we use MulAcc + for i := 1; i < len(a.Limbs); i++ { + sum = f.api.MulAcc(sum, a.Limbs[i], at[i-1]) + } + a.isEvaluated = true + a.evaluation = sum + return a +} + +func (f *Field[T]) performMulChecks(api frontend.API) error { + // use given api. We are in defer and API may be different to what we have + // stored. + + // there are no multiplication checks, nothing to do + if len(f.mulChecks) == 0 { + return nil + } + + // we construct a list of elements we want to commit to. Even though we have + // commited when doing range checks, do it again here explicitly for safety. + // TODO: committing is actually expensive in PLONK. We create a constraint + // for every variable we commit to (to set the selector polynomial). So, it + // is actually better not to commit again. However, if we would be to use + // multi-commit and range checks are in different commitment, then we have + // problem. + var toCommit []frontend.Variable + for i := range f.mulChecks { + toCommit = append(toCommit, f.mulChecks[i].a.Limbs...) + toCommit = append(toCommit, f.mulChecks[i].b.Limbs...) + toCommit = append(toCommit, f.mulChecks[i].r.Limbs...) + toCommit = append(toCommit, f.mulChecks[i].k.Limbs...) + toCommit = append(toCommit, f.mulChecks[i].c.Limbs...) + } + multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { + coefsLen := 0 + for i := range f.mulChecks { + coefsLen = max(coefsLen, len(f.mulChecks[i].c.Limbs)) + } + at := make([]frontend.Variable, coefsLen) + var prev frontend.Variable = 1 + for i := range at { + at[i] = api.Mul(prev, commitment) + prev = at[i] + } + for i := range f.mulChecks { + f.mulChecks[i].evalRound1(api, at) + } + // assuming r is input to some other multiplication, then is already evaluated + for i := range f.mulChecks { + f.mulChecks[i].evalRound2(api, at) + } + pval := f.evalWithChallenge(f.Modulus(), at) + coef := big.NewInt(1) + coef.Lsh(coef, f.fParams.BitsPerLimb()) + ccoef := api.Sub(coef, commitment) + for i := range f.mulChecks { + f.mulChecks[i].check(api, pval.evaluation, ccoef) + } + return nil + }, toCommit...) + return nil +} + +func (f *Field[T]) callMulHint(a, b *Element[T]) (quo, rem, carries *Element[T], err error) { + // inputs is always nblimbs + // quotient may be larger if inputs have overflow + // remainder is always nblimbs + // carries is 2 * nblimbs - 2 (do not consider first limb) + nextOverflow, _ := f.mulPreCond(a, b) + // skip error handle - it happens when we are supposed to reduce. But we + // already check it as a precondition. We only need the overflow here. + nbLimbs, nbBits := f.fParams.NbLimbs(), f.fParams.BitsPerLimb() + nbQuoLimbs := ((2*nbLimbs-1)*nbBits + nextOverflow + 1 - // + uint(f.fParams.Modulus().BitLen()) + // + nbBits - 1) / + nbBits + nbRemLimbs := nbLimbs + nbCarryLimbs := (nbQuoLimbs + nbLimbs) - 2 + hintInputs := []frontend.Variable{ + nbBits, + nbLimbs, + } + hintInputs = append(hintInputs, f.Modulus().Limbs...) + hintInputs = append(hintInputs, a.Limbs...) + hintInputs = append(hintInputs, b.Limbs...) + ret, err := f.api.NewHint(mulHint, int(nbQuoLimbs)+int(nbRemLimbs)+int(nbCarryLimbs), hintInputs...) + if err != nil { + err = fmt.Errorf("call hint: %w", err) + return + } + quo = f.packLimbs(ret[:nbQuoLimbs], false) + rem = f.packLimbs(ret[nbQuoLimbs:nbQuoLimbs+nbRemLimbs], true) + carries = f.newInternalElement(ret[nbQuoLimbs+nbRemLimbs:], 0) + return +} + +func mulHint(field *big.Int, inputs, outputs []*big.Int) error { + nbBits := int(inputs[0].Int64()) + nbLimbs := int(inputs[1].Int64()) + ptr := 2 + plimbs := inputs[ptr : ptr+nbLimbs] + ptr += nbLimbs + alimbs := inputs[ptr : ptr+nbLimbs] + ptr += nbLimbs + blimbs := inputs[ptr : ptr+nbLimbs] + + nbQuoLen := (len(outputs) - 2*nbLimbs + 2) / 2 + nbCarryLen := nbLimbs + nbQuoLen - 2 + outptr := 0 + quoLimbs := outputs[outptr : outptr+nbQuoLen] + outptr += nbQuoLen + remLimbs := outputs[outptr : outptr+nbLimbs] + outptr += nbLimbs + carryLimbs := outputs[outptr : outptr+nbCarryLen] + + p := new(big.Int) + a := new(big.Int) + b := new(big.Int) + if err := recompose(plimbs, uint(nbBits), p); err != nil { + return fmt.Errorf("recompose p: %w", err) + } + if err := recompose(alimbs, uint(nbBits), a); err != nil { + return fmt.Errorf("recompose a: %w", err) + } + if err := recompose(blimbs, uint(nbBits), b); err != nil { + return fmt.Errorf("recompose b: %w", err) + } + quo := new(big.Int) + rem := new(big.Int) + ab := new(big.Int).Mul(a, b) + quo.QuoRem(ab, p, rem) + if err := decompose(quo, uint(nbBits), quoLimbs); err != nil { + return fmt.Errorf("decompose quo: %w", err) + } + if err := decompose(rem, uint(nbBits), remLimbs); err != nil { + return fmt.Errorf("decompose rem: %w", err) + } + xp := make([]*big.Int, nbLimbs+nbQuoLen-1) + yp := make([]*big.Int, nbLimbs+nbQuoLen-1) + for i := range xp { + xp[i] = new(big.Int) + } + for i := range yp { + yp[i] = new(big.Int) + } + tmp := new(big.Int) + for i := 0; i < nbLimbs; i++ { + for j := 0; j < nbLimbs; j++ { + tmp.Mul(alimbs[i], blimbs[j]) + xp[i+j].Add(xp[i+j], tmp) + } + yp[i].Add(yp[i], remLimbs[i]) + for j := 0; j < nbQuoLen; j++ { + tmp.Mul(quoLimbs[j], plimbs[i]) + yp[i+j].Add(yp[i+j], tmp) + } + } + carry := new(big.Int) + for i := range carryLimbs { + carry.Add(carry, xp[i]) + carry.Sub(carry, yp[i]) + carry.Rsh(carry, uint(nbBits)) + carryLimbs[i] = new(big.Int).Set(carry) + } + return nil +} diff --git a/std/math/emulated/hints.go b/std/math/emulated/hints.go index 278518372e..b8821eb174 100644 --- a/std/math/emulated/hints.go +++ b/std/math/emulated/hints.go @@ -25,6 +25,7 @@ func GetHints() []solver.Hint { RemHint, RightShift, SqrtHint, + mulHint, } } From 46f08c5fbc8f0e6606ecaeacaf43a24e6a1ccc3d Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 3 Jul 2023 15:38:32 +0300 Subject: [PATCH 4/9] refactor: move multiplication --- std/math/emulated/field_mul.go | 120 +++++++++++++++++++++++++++++++- std/math/emulated/field_ops.go | 123 +-------------------------------- 2 files changed, 119 insertions(+), 124 deletions(-) diff --git a/std/math/emulated/field_mul.go b/std/math/emulated/field_mul.go index 67c5bd75c7..d44a229f6a 100644 --- a/std/math/emulated/field_mul.go +++ b/std/math/emulated/field_mul.go @@ -2,11 +2,11 @@ package emulated import ( "fmt" - "math" "math/big" + "math/bits" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/internal/multicommit" + "github.com/consensys/gnark/std/multicommit" ) type mulCheck[T FieldParams] struct { @@ -227,3 +227,119 @@ func mulHint(field *big.Int, inputs, outputs []*big.Int) error { } return nil } + +// Mul computes a*b and returns it. It doesn't reduce the output and it may be +// larger than the modulus. The returned Element has as many limbs as the inputs +// together. If the result wouldn't fit into Element, then locally reduces the +// inputs first. Doesn't mutate inputs. +// +// Even though this method skips reduction and allows for multiplication chains, +// then in most cases it is more efficient to use [Field[T].MulMod] as reducing +// Element with 2 times the limbs is 2 times more expensive. +// +// For multiplying by a constant, use [Field[T].MulConst] method which is more +// efficient. +func (f *Field[T]) Mul(a, b *Element[T]) *Element[T] { + return f.reduceAndOp(f.mulMod, f.mulPreCond, a, b) +} + +// Mul computes a*b and reduces it modulo the field order. The returned Element +// has default number of limbs and zero overflow. +func (f *Field[T]) MulMod(a, b *Element[T]) *Element[T] { + return f.reduceAndOp(f.mulMod, f.mulPreCond, a, b) +} + +// MulConst multiplies a by a constant c and returns it. We assume that the +// input constant is "small", so that we can compute the product by multiplying +// all individual limbs with the constant. If it is not small, then use the +// general [Field[T].Mul] or [Field[T].MulMod] with creating new Element from +// the constant on-the-fly. +func (f *Field[T]) MulConst(a *Element[T], c *big.Int) *Element[T] { + switch c.Sign() { + case -1: + f.MulConst(f.Neg(a), new(big.Int).Neg(c)) + case 0: + return f.Zero() + } + cbl := uint(c.BitLen()) + if cbl > f.maxOverflow() { + panic(fmt.Sprintf("constant bit length %d exceeds max %d", cbl, f.maxOverflow())) + } + return f.reduceAndOp( + func(a, _ *Element[T], u uint) *Element[T] { + if ba, aConst := f.constantValue(a); aConst { + ba.Mul(ba, c) + return newConstElement[T](ba) + } + limbs := make([]frontend.Variable, len(a.Limbs)) + for i := range a.Limbs { + limbs[i] = f.api.Mul(a.Limbs[i], c) + } + return f.newInternalElement(limbs, a.overflow+cbl) + }, + func(a, _ *Element[T]) (nextOverflow uint, err error) { + nextOverflow = a.overflow + uint(cbl) + if nextOverflow > f.maxOverflow() { + err = overflowError{op: "mulConst", nextOverflow: nextOverflow, maxOverflow: f.maxOverflow()} + } + return + }, + a, nil, + ) +} + +func (f *Field[T]) mulPreCond(a, b *Element[T]) (nextOverflow uint, err error) { + reduceRight := a.overflow < b.overflow + nbResLimbs := nbMultiplicationResLimbs(len(a.Limbs), len(b.Limbs)) + nbLimbsOverflow := uint(1) + if nbResLimbs > 0 { + nbLimbsOverflow = uint(bits.Len(uint(2*nbResLimbs - 1))) + } + nextOverflow = f.fParams.BitsPerLimb() + nbLimbsOverflow + a.overflow + b.overflow + if nextOverflow > f.maxOverflow() { + err = overflowError{op: "mul", nextOverflow: nextOverflow, maxOverflow: f.maxOverflow(), reduceRight: reduceRight} + } + return +} + +func (f *Field[T]) mul(a, b *Element[T], nextOverflow uint) *Element[T] { + ba, aConst := f.constantValue(a) + bb, bConst := f.constantValue(b) + if aConst && bConst { + ba.Mul(ba, bb).Mod(ba, f.fParams.Modulus()) + return newConstElement[T](ba) + } + + // mulResult contains the result (out of circuit) of a * b school book multiplication + // len(mulResult) == len(a) + len(b) - 1 + mulResult, err := f.computeMultiplicationHint(a.Limbs, b.Limbs) + if err != nil { + panic(fmt.Sprintf("multiplication hint: %s", err)) + } + + // we computed the result of the mul outside the circuit (mulResult) + // and we want to constrain inside the circuit that this injected value + // actually matches the in-circuit a * b values + // create constraints (\sum_{i=0}^{m-1} a_i c^i) * (\sum_{i=0}^{m-1} b_i + // c^i) = (\sum_{i=0}^{2m-2} z_i c^i) for c \in {1, 2m-1} + w := new(big.Int) + for c := 1; c <= len(mulResult); c++ { + w.SetInt64(1) // c^i + l := f.api.Mul(a.Limbs[0], 1) + r := f.api.Mul(b.Limbs[0], 1) + o := f.api.Mul(mulResult[0], 1) + + for i := 1; i < len(mulResult); i++ { + w.Lsh(w, uint(c)) + if i < len(a.Limbs) { + l = f.api.MulAcc(l, a.Limbs[i], w) + } + if i < len(b.Limbs) { + r = f.api.MulAcc(r, b.Limbs[i], w) + } + o = f.api.MulAcc(o, mulResult[i], w) + } + f.api.AssertIsEqual(f.api.Mul(l, r), o) + } + return f.newInternalElement(mulResult, nextOverflow) +} diff --git a/std/math/emulated/field_ops.go b/std/math/emulated/field_ops.go index 1cbcc45ef3..c089af9b45 100644 --- a/std/math/emulated/field_ops.go +++ b/std/math/emulated/field_ops.go @@ -3,8 +3,6 @@ package emulated import ( "errors" "fmt" - "math/big" - "math/bits" "github.com/consensys/gnark/frontend" ) @@ -133,126 +131,7 @@ func (f *Field[T]) add(a, b *Element[T], nextOverflow uint) *Element[T] { return f.newInternalElement(limbs, nextOverflow) } -// Mul computes a*b and returns it. It doesn't reduce the output and it may be -// larger than the modulus. The returned Element has as many limbs as the inputs -// together. If the result wouldn't fit into Element, then locally reduces the -// inputs first. Doesn't mutate inputs. -// -// Even though this method skips reduction and allows for multiplication chains, -// then in most cases it is more efficient to use [Field[T].MulMod] as reducing -// Element with 2 times the limbs is 2 times more expensive. -// -// For multiplying by a constant, use [Field[T].MulConst] method which is more -// efficient. -// -// Uses [MultiplicationHint]. -func (f *Field[T]) Mul(a, b *Element[T]) *Element[T] { - return f.reduceAndOp(f.mul, f.mulPreCond, a, b) -} - -// Mul computes a*b and reduces it modulo the field order. The returned Element -// has default number of limbs and zero overflow. -func (f *Field[T]) MulMod(a, b *Element[T]) *Element[T] { - r := f.Mul(a, b) - return f.Reduce(r) -} - -// MulConst multiplies a by a constant c and returns it. We assume that the -// input constant is "small", so that we can compute the product by multiplying -// all individual limbs with the constant. If it is not small, then use the -// general [Field[T].Mul] or [Field[T].MulMod] with creating new Element from -// the constant on-the-fly. -func (f *Field[T]) MulConst(a *Element[T], c *big.Int) *Element[T] { - switch c.Sign() { - case -1: - f.MulConst(f.Neg(a), new(big.Int).Neg(c)) - case 0: - return f.Zero() - } - cbl := uint(c.BitLen()) - if cbl > f.maxOverflow() { - panic(fmt.Sprintf("constant bit length %d exceeds max %d", cbl, f.maxOverflow())) - } - return f.reduceAndOp( - func(a, _ *Element[T], u uint) *Element[T] { - if ba, aConst := f.constantValue(a); aConst { - ba.Mul(ba, c) - return newConstElement[T](ba) - } - limbs := make([]frontend.Variable, len(a.Limbs)) - for i := range a.Limbs { - limbs[i] = f.api.Mul(a.Limbs[i], c) - } - return f.newInternalElement(limbs, a.overflow+cbl) - }, - func(a, _ *Element[T]) (nextOverflow uint, err error) { - nextOverflow = a.overflow + uint(cbl) - if nextOverflow > f.maxOverflow() { - err = overflowError{op: "mulConst", nextOverflow: nextOverflow, maxOverflow: f.maxOverflow()} - } - return - }, - a, nil, - ) -} - -func (f *Field[T]) mulPreCond(a, b *Element[T]) (nextOverflow uint, err error) { - reduceRight := a.overflow < b.overflow - nbResLimbs := nbMultiplicationResLimbs(len(a.Limbs), len(b.Limbs)) - nbLimbsOverflow := uint(1) - if nbResLimbs > 0 { - nbLimbsOverflow = uint(bits.Len(uint(2*nbResLimbs - 1))) - } - nextOverflow = f.fParams.BitsPerLimb() + nbLimbsOverflow + a.overflow + b.overflow - if nextOverflow > f.maxOverflow() { - err = overflowError{op: "mul", nextOverflow: nextOverflow, maxOverflow: f.maxOverflow(), reduceRight: reduceRight} - } - return -} - -func (f *Field[T]) mul(a, b *Element[T], nextOverflow uint) *Element[T] { - ba, aConst := f.constantValue(a) - bb, bConst := f.constantValue(b) - if aConst && bConst { - ba.Mul(ba, bb).Mod(ba, f.fParams.Modulus()) - return newConstElement[T](ba) - } - - // mulResult contains the result (out of circuit) of a * b school book multiplication - // len(mulResult) == len(a) + len(b) - 1 - mulResult, err := f.computeMultiplicationHint(a.Limbs, b.Limbs) - if err != nil { - panic(fmt.Sprintf("multiplication hint: %s", err)) - } - - // we computed the result of the mul outside the circuit (mulResult) - // and we want to constrain inside the circuit that this injected value - // actually matches the in-circuit a * b values - // create constraints (\sum_{i=0}^{m-1} a_i c^i) * (\sum_{i=0}^{m-1} b_i - // c^i) = (\sum_{i=0}^{2m-2} z_i c^i) for c \in {1, 2m-1} - w := new(big.Int) - for c := 1; c <= len(mulResult); c++ { - w.SetInt64(1) // c^i - l := f.api.Mul(a.Limbs[0], 1) - r := f.api.Mul(b.Limbs[0], 1) - o := f.api.Mul(mulResult[0], 1) - - for i := 1; i < len(mulResult); i++ { - w.Lsh(w, uint(c)) - if i < len(a.Limbs) { - l = f.api.MulAcc(l, a.Limbs[i], w) - } - if i < len(b.Limbs) { - r = f.api.MulAcc(r, b.Limbs[i], w) - } - o = f.api.MulAcc(o, mulResult[i], w) - } - f.api.AssertIsEqual(f.api.Mul(l, r), o) - } - return f.newInternalElement(mulResult, nextOverflow) -} - -// Reduce reduces a modulo the field order and returns it. Uses hint [RemHint]. +// Reduce reduces a modulo the field order and returns it. func (f *Field[T]) Reduce(a *Element[T]) *Element[T] { f.enforceWidthConditional(a) if a.overflow == 0 { From d00f3cfdc1b5a5b13152714ed69ca59cde7ca207 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 3 Jul 2023 15:38:55 +0300 Subject: [PATCH 5/9] feat: use mulmod for reduction --- std/math/emulated/field_ops.go | 8 +------- std/math/emulated/hints.go | 34 ---------------------------------- 2 files changed, 1 insertion(+), 41 deletions(-) diff --git a/std/math/emulated/field_ops.go b/std/math/emulated/field_ops.go index c089af9b45..4e9b26462c 100644 --- a/std/math/emulated/field_ops.go +++ b/std/math/emulated/field_ops.go @@ -142,14 +142,8 @@ func (f *Field[T]) Reduce(a *Element[T]) *Element[T] { if _, aConst := f.constantValue(a); aConst { panic("trying to reduce a constant, which happen to have an overflow flag set") } - // slow path - use hint to reduce value - e, err := f.computeRemHint(a, f.Modulus()) - if err != nil { - panic(fmt.Sprintf("reduction hint: %v", err)) - } - f.AssertIsEqual(e, a) - return e + return f.mulMod(a, f.One(), 0) } // Sub subtracts b from a and returns it. Reduces locally if wouldn't fit into diff --git a/std/math/emulated/hints.go b/std/math/emulated/hints.go index b8821eb174..16a560f9c7 100644 --- a/std/math/emulated/hints.go +++ b/std/math/emulated/hints.go @@ -22,7 +22,6 @@ func GetHints() []solver.Hint { QuoHint, InverseHint, MultiplicationHint, - RemHint, RightShift, SqrtHint, mulHint, @@ -89,39 +88,6 @@ func MultiplicationHint(mod *big.Int, inputs []*big.Int, outputs []*big.Int) err return nil } -// computeRemHint packs inputs for the RemHint hint function. -// sets z to the remainder x%y for y != 0 and returns z. -func (f *Field[T]) computeRemHint(x, y *Element[T]) (z *Element[T], err error) { - var fp T - hintInputs := []frontend.Variable{ - fp.BitsPerLimb(), - len(x.Limbs), - } - hintInputs = append(hintInputs, x.Limbs...) - hintInputs = append(hintInputs, y.Limbs...) - limbs, err := f.api.NewHint(RemHint, int(len(y.Limbs)), hintInputs...) - if err != nil { - return nil, err - } - return f.packLimbs(limbs, true), nil -} - -// RemHint sets z to the remainder x%y for y != 0 and returns z. -// If y == 0, returns an error. -// Rem implements truncated modulus (like Go); see QuoRem for more details. -func RemHint(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { - nbBits, _, x, y, err := parseHintDivInputs(inputs) - if err != nil { - return err - } - r := new(big.Int) - r.Rem(x, y) - if err := decompose(r, nbBits, outputs); err != nil { - return fmt.Errorf("decompose remainder: %w", err) - } - return nil -} - // computeQuoHint packs the inputs for QuoHint function and returns z = x / y // (discards remainder) func (f *Field[T]) computeQuoHint(x *Element[T]) (z *Element[T], err error) { From ef9d20de22cd43c3b4fc69f53e57bb1c9d09e8f4 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 3 Jul 2023 15:44:38 +0300 Subject: [PATCH 6/9] fix: clean evaluations after performing mulchecks --- std/math/emulated/field_mul.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/std/math/emulated/field_mul.go b/std/math/emulated/field_mul.go index d44a229f6a..74623a0f20 100644 --- a/std/math/emulated/field_mul.go +++ b/std/math/emulated/field_mul.go @@ -35,6 +35,19 @@ func (mc *mulCheck[T]) check(api frontend.API, peval, coef frontend.Variable) { api.AssertIsEqual(ls, rs) } +func (mc *mulCheck[T]) cleanEvaluations() { + mc.a.evaluation = 0 + mc.a.isEvaluated = false + mc.b.evaluation = 0 + mc.b.isEvaluated = false + mc.r.evaluation = 0 + mc.r.isEvaluated = false + mc.k.evaluation = 0 + mc.k.isEvaluated = false + mc.c.evaluation = 0 + mc.c.isEvaluated = false +} + func (f *Field[T]) mulMod(a, b *Element[T], nextOverflow uint) *Element[T] { f.enforceWidthConditional(a) f.enforceWidthConditional(b) @@ -119,6 +132,11 @@ func (f *Field[T]) performMulChecks(api frontend.API) error { for i := range f.mulChecks { f.mulChecks[i].check(api, pval.evaluation, ccoef) } + // clean cached evaluation. Helps in case we compile the same circuit + // multiple times. + for i := range f.mulChecks { + f.mulChecks[i].cleanEvaluations() + } return nil }, toCommit...) return nil From 05e553079a1da121e0ef26e455c35a1c56ee8d03 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 3 Jul 2023 16:06:02 +0300 Subject: [PATCH 7/9] fix: constant strict width check --- std/math/emulated/field_assert.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/std/math/emulated/field_assert.go b/std/math/emulated/field_assert.go index 2a02715900..54fa833412 100644 --- a/std/math/emulated/field_assert.go +++ b/std/math/emulated/field_assert.go @@ -95,7 +95,7 @@ func (f *Field[T]) AssertLimbsEquality(a, b *Element[T]) { // (defined by the field parameter). func (f *Field[T]) enforceWidth(a *Element[T], modWidth bool) { if _, aConst := f.constantValue(a); aConst { - if len(a.Limbs) != int(f.fParams.NbLimbs()) { + if modWidth && len(a.Limbs) != int(f.fParams.NbLimbs()) { panic("constant limb width doesn't match parametrized field") } } From 4e5546541fe3d8b425aacaf4125793a7935cbdf8 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 3 Jul 2023 16:10:54 +0300 Subject: [PATCH 8/9] perf: update stats --- internal/stats/latest.stats | Bin 2816 -> 2816 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/internal/stats/latest.stats b/internal/stats/latest.stats index 3901e90f4984b7e3aa17810acee8d64da52ccc6b..385af8dff9b6af103beff52ffb3f98c84d487cb2 100644 GIT binary patch delta 259 zcmZn=YY>}|Ir$6o!^zg14<=t_xjXp{Tg&8GY(b3TlO2KBnCtB10`|L;9T@*i%)CC? zm*d^!KBkA0H?eBebF5?h$G?ez@gF`$9Z*dY7BvhUjQ>=QG5%A20F=bTQ2qu~BY{N? RgFS^NOtxYQnViqu3IMGhX=?xg delta 265 zcmZn=YY>}|Ie9kQ^~v*D|4i283Ywh2@o;iAS1Y69*z@ spAd=}>?t&1awl`hWOnA3$%k0(PA+19H`$Bv0i*cj1+3zezcQZ%099*dAOHXW From 366950e073bd7ad6e96d8db89f6196b7b180f8ea Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Fri, 1 Dec 2023 02:18:35 +0100 Subject: [PATCH 9/9] docs: update package documentation --- std/math/emulated/doc.go | 47 +++++------------- std/math/emulated/field_mul.go | 87 ++++++++++++++++++++++++++++++---- 2 files changed, 89 insertions(+), 45 deletions(-) diff --git a/std/math/emulated/doc.go b/std/math/emulated/doc.go index 7838ba710f..cb04ce4a38 100644 --- a/std/math/emulated/doc.go +++ b/std/math/emulated/doc.go @@ -86,52 +86,29 @@ then the overflow value f' for the sum is computed as The complexity of native limb-wise multiplication is k^2. This translates directly to the complexity in the number of constraints in the constraint -system. However, alternatively, when instead computing the limb values -off-circuit and constructing a system of k linear equations, we can ensure that -the product was computed correctly. +system. -Let the factors be +For multiplication, we would instead use polynomial representation of the elements: x = ∑_{i=0}^k x_i 2^{w i} - -and - y = ∑_{i=0}^k y_i 2^{w i}. -For computing the product, we compute off-circuit the limbs - - z_i = ∑_{j, j'>0, j+j'=i, j+j'≤2k-2} x_{j} y_{j'}, // in MultiplicationHint() - -and assert in-circuit - - ∑_{i=0}^{2k-2} z_i c^i = (∑_{i=0}^k x_i) (∑_{i=0}^k y_i), ∀ c ∈ {1, ..., 2k-1}. - -Computing the overflow for the multiplication result is slightly more -complicated. The overflow for - - x_{j} y_{j'} - -is - - w+f+f'+1. - -Naively, as the limbs of the result are summed over all 0 ≤ i ≤ 2k-2, then the -overflow of the limbs should be - - w+f+f'+2k-1. +as -For computing the number of bits and thus in the overflow, we can instead look -at the maximal possible value. This can be computed by + x(X) = ∑_{i=0}^k x_i X^i + y(X) = ∑_{i=0}^k y_i X^i. - (2^{2w+f+f'+2}-1)*(2k-1). +If the multiplication result modulo r is c, then the following holds: -Its bitlength is + x * y = c + z*r. - 2w+f+f'+1+log_2(2k-1), +We can check the correctness of the multiplication by checking the following +identity at a random point: -which leads to maximal overflow of + x(X) * y(X) = c(X) + z(X) * r(X) + (2^w' - X) e(X), - w+f+f'+1+log_2(2k-1). +where e(X) is a polynomial used for carrying the overflows of the left- and +right-hand side of the above equation. # Subtraction diff --git a/std/math/emulated/field_mul.go b/std/math/emulated/field_mul.go index 74623a0f20..964a4f3058 100644 --- a/std/math/emulated/field_mul.go +++ b/std/math/emulated/field_mul.go @@ -9,6 +9,48 @@ import ( "github.com/consensys/gnark/std/multicommit" ) +// mulCheck represents a single multiplication check. Instead of doing a +// multiplication exactly where called, we compute the result using hint and +// return it. Additionally, we store the correctness check for later checking +// (together with every other multiplication) to share the verifier challenge +// computation. +// +// With this approach this is important that we do not change the [Element] +// values after they are returned from [mulMod] as mulCheck keeps pointers and +// the check will fail if the values refered to by the pointers change. By +// following the [Field] public methods this shouldn't happend as we always take +// and return pointers, and to change the values the user has to explicitly +// dereference. +// +// We store the values a, b, r, k, c. They are as follows: +// - a, b - the inputs what we are multiplying. Do not have to be reduced. +// - r - the multiplication result reduced modulo the emulation parameter. +// - k - the quotient for integer multiplication a*b divided by emulation parameter. +// - c - element representing carry. Used only for aligning the limb widths. +// +// Given these values, the following holds: +// +// a * b = r * k*p +// +// But for asserting that the previous equation holds, we instead use the +// polynomial representation of the elements. If a non-native element a is given +// by its limbs +// +// a = (a_0, ..., a_n) +// +// then +// +// a(X) = \sum_i a_i * X^i. +// +// Now, the multiplication check instead becomes +// +// a(X) * b(X) = r(X) + k(X) * p(X) + (2^t-X) c(X), +// +// which can be checked only at a single random point. Here we need an +// additional polynomial c(X) which is used for carrying the overflow bits to +// the consecutive limbs. By subtracting 2^t c(X) we can remove the bits from +// the corresponding coefficients in r(X)+k(X)*p(X) and by adding X c(X) we can +// add the bits to X(r(X) + k(X) * p(X)) (i.e. to the next coefficient). type mulCheck[T FieldParams] struct { f *Field[T] // a * b = r + k*p + c @@ -18,23 +60,34 @@ type mulCheck[T FieldParams] struct { c *Element[T] // carry } +// evalRound1 evaluates first c(X), r(X) and k(X) at a given random point at[0]. +// In the first round we do not assume that any of them is already evaluated as +// they come directly from hint. func (mc *mulCheck[T]) evalRound1(api frontend.API, at []frontend.Variable) { mc.c = mc.f.evalWithChallenge(mc.c, at) mc.r = mc.f.evalWithChallenge(mc.r, at) mc.k = mc.f.evalWithChallenge(mc.k, at) } +// evalRound2 now evaluates a and b at a given random point at[0]. However, it +// may happen that a or b is equal to r from a previous mulcheck. In that case +// we can reuse the evaluation to save constraints. func (mc *mulCheck[T]) evalRound2(api frontend.API, at []frontend.Variable) { mc.a = mc.f.evalWithChallenge(mc.a, at) mc.b = mc.f.evalWithChallenge(mc.b, at) } +// check checks a(ch) * b(ch) = r(ch) + k(ch) * p(ch) + (2^t - ch) c(ch). As the +// computation of p(ch) and (2^t-ch) can be shared over all mulCheck instances, +// then we get them already evaluated as peval and coef. func (mc *mulCheck[T]) check(api frontend.API, peval, coef frontend.Variable) { ls := api.Mul(mc.a.evaluation, mc.b.evaluation) rs := api.Add(mc.r.evaluation, api.Mul(peval, mc.k.evaluation), api.Mul(mc.c.evaluation, coef)) api.AssertIsEqual(ls, rs) } +// cleanEvaluations cleans the cached evaluation values. This is necessary for +// ensuring the circuit stability over many compilations. func (mc *mulCheck[T]) cleanEvaluations() { mc.a.evaluation = 0 mc.a.isEvaluated = false @@ -48,7 +101,9 @@ func (mc *mulCheck[T]) cleanEvaluations() { mc.c.isEvaluated = false } -func (f *Field[T]) mulMod(a, b *Element[T], nextOverflow uint) *Element[T] { +// mulMod returns a*b mod r. In practice it computes the result using a hint and +// defers the actual multiplication check. +func (f *Field[T]) mulMod(a, b *Element[T], _ uint) *Element[T] { f.enforceWidthConditional(a) f.enforceWidthConditional(b) k, r, c, err := f.callMulHint(a, b) @@ -67,6 +122,10 @@ func (f *Field[T]) mulMod(a, b *Element[T], nextOverflow uint) *Element[T] { return r } +// evalWithChallenge represents element a as a polynomial a(X) and evaluates at +// at[0]. For efficiency, we use already evaluated powers of at[0] given by at. +// It stores the evaluation result inside the Element and marks it as evaluated. +// If the method is called for already evaluated a then returns the known value. func (f *Field[T]) evalWithChallenge(a *Element[T], at []frontend.Variable) *Element[T] { if a.isEvaluated { return a @@ -83,6 +142,8 @@ func (f *Field[T]) evalWithChallenge(a *Element[T], at []frontend.Variable) *Ele return a } +// performMulChecks should be deferred to actually perform all the +// multiplication checks. func (f *Field[T]) performMulChecks(api frontend.API) error { // use given api. We are in defer and API may be different to what we have // stored. @@ -107,7 +168,9 @@ func (f *Field[T]) performMulChecks(api frontend.API) error { toCommit = append(toCommit, f.mulChecks[i].k.Limbs...) toCommit = append(toCommit, f.mulChecks[i].c.Limbs...) } + // we give all the inputs as inputs to obtain random verifier challenge. multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { + // for efficiency, we compute all powers of the challenge as slice at. coefsLen := 0 for i := range f.mulChecks { coefsLen = max(coefsLen, len(f.mulChecks[i].c.Limbs)) @@ -118,6 +181,7 @@ func (f *Field[T]) performMulChecks(api frontend.API) error { at[i] = api.Mul(prev, commitment) prev = at[i] } + // evaluate all r, k, c for i := range f.mulChecks { f.mulChecks[i].evalRound1(api, at) } @@ -125,10 +189,13 @@ func (f *Field[T]) performMulChecks(api frontend.API) error { for i := range f.mulChecks { f.mulChecks[i].evalRound2(api, at) } + // evaluate p(X) at challenge pval := f.evalWithChallenge(f.Modulus(), at) + // compute (2^t-X) at challenge coef := big.NewInt(1) coef.Lsh(coef, f.fParams.BitsPerLimb()) ccoef := api.Sub(coef, commitment) + // verify all mulchecks for i := range f.mulChecks { f.mulChecks[i].check(api, pval.evaluation, ccoef) } @@ -142,6 +209,7 @@ func (f *Field[T]) performMulChecks(api frontend.API) error { return nil } +// callMulHint uses hint to compute r, k and c. func (f *Field[T]) callMulHint(a, b *Element[T]) (quo, rem, carries *Element[T], err error) { // inputs is always nblimbs // quotient may be larger if inputs have overflow @@ -246,14 +314,9 @@ func mulHint(field *big.Int, inputs, outputs []*big.Int) error { return nil } -// Mul computes a*b and returns it. It doesn't reduce the output and it may be -// larger than the modulus. The returned Element has as many limbs as the inputs -// together. If the result wouldn't fit into Element, then locally reduces the -// inputs first. Doesn't mutate inputs. -// -// Even though this method skips reduction and allows for multiplication chains, -// then in most cases it is more efficient to use [Field[T].MulMod] as reducing -// Element with 2 times the limbs is 2 times more expensive. +// Mul computes a*b and reduces it modulo the field order. The returned Element +// has default number of limbs and zero overflow. If the result wouldn't fit +// into Element, then locally reduces the inputs first. Doesn't mutate inputs. // // For multiplying by a constant, use [Field[T].MulConst] method which is more // efficient. @@ -261,8 +324,10 @@ func (f *Field[T]) Mul(a, b *Element[T]) *Element[T] { return f.reduceAndOp(f.mulMod, f.mulPreCond, a, b) } -// Mul computes a*b and reduces it modulo the field order. The returned Element +// MulMod computes a*b and reduces it modulo the field order. The returned Element // has default number of limbs and zero overflow. +// +// Equivalent to [Field[T].Mul], kept for backwards compatibility. func (f *Field[T]) MulMod(a, b *Element[T]) *Element[T] { return f.reduceAndOp(f.mulMod, f.mulPreCond, a, b) } @@ -321,6 +386,8 @@ func (f *Field[T]) mulPreCond(a, b *Element[T]) (nextOverflow uint, err error) { } func (f *Field[T]) mul(a, b *Element[T], nextOverflow uint) *Element[T] { + // TODO: kept for [AssertIsEqual]. Consider if this can be removed and we + // can use MulMod for equality assertion. ba, aConst := f.constantValue(a) bb, bConst := f.constantValue(b) if aConst && bConst {