Skip to content
Merged
16 changes: 16 additions & 0 deletions std/algebra/emulated/sw_emulated/point.go
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,22 @@ func (c *Curve[B, S]) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 *AffinePo
}
}

// Mux performs a lookup from the inputs and returns inputs[sel]. It is most
// efficient for power of two lengths of the inputs, but works for any number of
// inputs.
func (c *Curve[B, S]) Mux(sel frontend.Variable, inputs ...*AffinePoint[B]) *AffinePoint[B] {
xs := make([]*emulated.Element[B], len(inputs))
ys := make([]*emulated.Element[B], len(inputs))
for i := range inputs {
xs[i] = &inputs[i].X
ys[i] = &inputs[i].Y
}
return &AffinePoint[B]{
X: *c.baseApi.Mux(sel, xs...),
Y: *c.baseApi.Mux(sel, ys...),
}
}

// ScalarMul computes s * p and returns it. It doesn't modify p nor s.
// This function doesn't check that the p is on the curve. See AssertIsOnCurve.
//
Expand Down
48 changes: 48 additions & 0 deletions std/algebra/emulated/sw_emulated/point_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1091,3 +1091,51 @@ func TestJointScalarMul6(t *testing.T) {
err := test.IsSolved(&circuit, &witness, testCurve.ScalarField())
assert.NoError(err)
}

type MuxCircuitTest[T, S emulated.FieldParams] struct {
Selector frontend.Variable
Inputs [8]AffinePoint[T]
Expected AffinePoint[T]
}

func (c *MuxCircuitTest[T, S]) Define(api frontend.API) error {
cr, err := New[T, S](api, GetCurveParams[T]())
if err != nil {
return err
}
els := make([]*AffinePoint[T], len(c.Inputs))
for i := range c.Inputs {
els[i] = &c.Inputs[i]
}
res := cr.Mux(c.Selector, els...)
cr.AssertIsEqual(res, &c.Expected)
return nil
}

func TestMux(t *testing.T) {
assert := test.NewAssert(t)
circuit := MuxCircuitTest[emulated.BN254Fp, emulated.BN254Fr]{}
r := make([]fr_bn.Element, len(circuit.Inputs))
for i := range r {
r[i].SetRandom()
}
selector, _ := rand.Int(rand.Reader, big.NewInt(int64(len(r))))
expectedR := r[selector.Int64()]
expected := new(bn254.G1Affine).ScalarMultiplicationBase(expectedR.BigInt(new(big.Int)))
witness := MuxCircuitTest[emulated.BN254Fp, emulated.BLS12381Fr]{
Selector: selector,
Expected: AffinePoint[emparams.BN254Fp]{
X: emulated.ValueOf[emulated.BN254Fp](expected.X),
Y: emulated.ValueOf[emulated.BN254Fp](expected.Y),
},
}
for i := range r {
eli := new(bn254.G1Affine).ScalarMultiplicationBase(r[i].BigInt(new(big.Int)))
witness.Inputs[i] = AffinePoint[emparams.BN254Fp]{
X: emulated.ValueOf[emulated.BN254Fp](eli.X),
Y: emulated.ValueOf[emulated.BN254Fp](eli.Y),
}
}
err := test.IsSolved(&circuit, &witness, testCurve.ScalarField())
assert.NoError(err)
}
5 changes: 5 additions & 0 deletions std/algebra/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ type Curve[FR emulated.FieldParams, G1El G1ElementT] interface {
// - p3 if b0=0 and b1=1,
// - p4 if b0=1 and b1=1.
Lookup2(b1 frontend.Variable, b2 frontend.Variable, p1 *G1El, p2 *G1El, p3 *G1El, p4 *G1El) *G1El

// Mux performs a lookup from the inputs and returns inputs[sel]. It is most
// efficient for power of two lengths of the inputs, but works for any
// number of inputs.
Mux(sel frontend.Variable, inputs ...*G1El) *G1El
}

// Pairing allows to compute the bi-linear pairing of G1 and G2 elements.
Expand Down
17 changes: 17 additions & 0 deletions std/algebra/native/sw_bls12377/pairing2.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/consensys/gnark/std/math/bits"
"github.com/consensys/gnark/std/math/emulated"
"github.com/consensys/gnark/std/math/emulated/emparams"
"github.com/consensys/gnark/std/selector"
)

// Curve allows G1 operations in BLS12-377.
Expand Down Expand Up @@ -213,6 +214,22 @@ func (c *Curve) Lookup2(b1, b2 frontend.Variable, p1, p2, p3, p4 *G1Affine) *G1A
}
}

// Mux performs a lookup from the inputs and returns inputs[sel]. It is most
// efficient for power of two lengths of the inputs, but works for any number of
// inputs.
func (c *Curve) Mux(sel frontend.Variable, inputs ...*G1Affine) *G1Affine {
xs := make([]frontend.Variable, len(inputs))
ys := make([]frontend.Variable, len(inputs))
for i := range inputs {
xs[i] = inputs[i].X
ys[i] = inputs[i].Y
}
return &G1Affine{
X: selector.Mux(c.api, sel, xs...),
Y: selector.Mux(c.api, sel, ys...),
}
}

// Pairing allows computing pairing-related operations in BLS12-377.
type Pairing struct {
api frontend.API
Expand Down
55 changes: 55 additions & 0 deletions std/algebra/native/sw_bls12377/pairing2_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package sw_bls12377

import (
"crypto/rand"
"math/big"
"testing"

"github.com/consensys/gnark-crypto/ecc"
bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377"
fr_bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/test"
)

type MuxCircuitTest struct {
Selector frontend.Variable
Inputs [8]G1Affine
Expected G1Affine
}

func (c *MuxCircuitTest) Define(api frontend.API) error {
cr, err := NewCurve(api)
if err != nil {
return err
}
els := make([]*G1Affine, len(c.Inputs))
for i := range c.Inputs {
els[i] = &c.Inputs[i]
}
res := cr.Mux(c.Selector, els...)
cr.AssertIsEqual(res, &c.Expected)
return nil
}

func TestMux(t *testing.T) {
assert := test.NewAssert(t)
circuit := MuxCircuitTest{}
r := make([]fr_bls12377.Element, len(circuit.Inputs))
for i := range r {
r[i].SetRandom()
}
selector, _ := rand.Int(rand.Reader, big.NewInt(int64(len(r))))
expectedR := r[selector.Int64()]
expected := new(bls12377.G1Affine).ScalarMultiplicationBase(expectedR.BigInt(new(big.Int)))
witness := MuxCircuitTest{
Selector: selector,
Expected: NewG1Affine(*expected),
}
for i := range r {
eli := new(bls12377.G1Affine).ScalarMultiplicationBase(r[i].BigInt(new(big.Int)))
witness.Inputs[i] = NewG1Affine(*eli)
}
err := test.IsSolved(&circuit, &witness, ecc.BW6_761.ScalarField())
assert.NoError(err)
}
17 changes: 17 additions & 0 deletions std/algebra/native/sw_bls24315/pairing2.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/consensys/gnark/std/math/bits"
"github.com/consensys/gnark/std/math/emulated"
"github.com/consensys/gnark/std/math/emulated/emparams"
"github.com/consensys/gnark/std/selector"
)

// Curve allows G1 operations in BLS24-315.
Expand Down Expand Up @@ -213,6 +214,22 @@ func (c *Curve) Lookup2(b1, b2 frontend.Variable, p1, p2, p3, p4 *G1Affine) *G1A
}
}

// Mux performs a lookup from the inputs and returns inputs[sel]. It is most
// efficient for power of two lengths of the inputs, but works for any number of
// inputs.
func (c *Curve) Mux(sel frontend.Variable, inputs ...*G1Affine) *G1Affine {
xs := make([]frontend.Variable, len(inputs))
ys := make([]frontend.Variable, len(inputs))
for i := range inputs {
xs[i] = inputs[i].X
ys[i] = inputs[i].Y
}
return &G1Affine{
X: selector.Mux(c.api, sel, xs...),
Y: selector.Mux(c.api, sel, ys...),
}
}

// Pairing allows computing pairing-related operations in BLS24-315.
type Pairing struct {
api frontend.API
Expand Down
55 changes: 55 additions & 0 deletions std/algebra/native/sw_bls24315/pairing2_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package sw_bls24315

import (
"crypto/rand"
"math/big"
"testing"

"github.com/consensys/gnark-crypto/ecc"
bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315"
fr_bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/test"
)

type MuxCircuitTest struct {
Selector frontend.Variable
Inputs [8]G1Affine
Expected G1Affine
}

func (c *MuxCircuitTest) Define(api frontend.API) error {
cr, err := NewCurve(api)
if err != nil {
return err
}
els := make([]*G1Affine, len(c.Inputs))
for i := range c.Inputs {
els[i] = &c.Inputs[i]
}
res := cr.Mux(c.Selector, els...)
cr.AssertIsEqual(res, &c.Expected)
return nil
}

func TestMux(t *testing.T) {
assert := test.NewAssert(t)
circuit := MuxCircuitTest{}
r := make([]fr_bls24315.Element, len(circuit.Inputs))
for i := range r {
r[i].SetRandom()
}
selector, _ := rand.Int(rand.Reader, big.NewInt(int64(len(r))))
expectedR := r[selector.Int64()]
expected := new(bls24315.G1Affine).ScalarMultiplicationBase(expectedR.BigInt(new(big.Int)))
witness := MuxCircuitTest{
Selector: selector,
Expected: NewG1Affine(*expected),
}
for i := range r {
eli := new(bls24315.G1Affine).ScalarMultiplicationBase(r[i].BigInt(new(big.Int)))
witness.Inputs[i] = NewG1Affine(*eli)
}
err := test.IsSolved(&circuit, &witness, ecc.BW6_761.ScalarField())
assert.NoError(err)
}
45 changes: 45 additions & 0 deletions std/math/emulated/element_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,51 @@ func testLookup2[T FieldParams](t *testing.T) {
}, testName[T]())
}

type MuxCircuit[T FieldParams] struct {
Selector frontend.Variable
Inputs [8]Element[T]
Expected Element[T]
}

func (c *MuxCircuit[T]) Define(api frontend.API) error {
f, err := NewField[T](api)
if err != nil {
return err
}
inputs := make([]*Element[T], len(c.Inputs))
for i := range inputs {
inputs[i] = &c.Inputs[i]
}
res := f.Mux(c.Selector, inputs...)
f.AssertIsEqual(res, &c.Expected)
return nil
}

func TestMux(t *testing.T) {
testMux[Goldilocks](t)
testMux[Secp256k1Fp](t)
testMux[BN254Fp](t)
}

func testMux[T FieldParams](t *testing.T) {
var fp T
assert := test.NewAssert(t)
assert.Run(func(assert *test.Assert) {
var circuit, witness MuxCircuit[T]
vals := make([]*big.Int, len(witness.Inputs))
for i := range witness.Inputs {
vals[i], _ = rand.Int(rand.Reader, fp.Modulus())
witness.Inputs[i] = ValueOf[T](vals[i])
}
selector, _ := rand.Int(rand.Reader, big.NewInt(int64(len(witness.Inputs))))
expected := vals[selector.Int64()]
witness.Expected = ValueOf[T](expected)
witness.Selector = selector

assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness))
})
}

type ComputationCircuit[T FieldParams] struct {
noReduce bool

Expand Down
50 changes: 50 additions & 0 deletions std/math/emulated/field_ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/selector"
)

// Div computes a/b and returns it. It uses [DivHint] as a hint function.
Expand Down Expand Up @@ -266,6 +267,55 @@ func (f *Field[T]) Lookup2(b0, b1 frontend.Variable, a, b, c, d *Element[T]) *El
return e
}

// Mux selects element inputs[sel] and returns it. The number of the limbs and
// overflow in the result is the maximum of the inputs'. If the inputs are very
// unbalanced, then reduce the inputs before calling the method. It is most
// efficient for power of two lengths of the inputs, but works for any
// number of inputs.
func (f *Field[T]) Mux(sel frontend.Variable, inputs ...*Element[T]) *Element[T] {
if len(inputs) == 0 {
return nil
}
nbInputs := len(inputs)
overflow := uint(0)
nbLimbs := 0
for i := range inputs {
f.enforceWidthConditional(inputs[i])
if inputs[i].overflow > overflow {
overflow = inputs[i].overflow
}
if len(inputs[i].Limbs) > nbLimbs {
nbLimbs = len(inputs[i].Limbs)
}
}
normalize := func(limbs []frontend.Variable) []frontend.Variable {
if len(limbs) < nbLimbs {
tail := make([]frontend.Variable, nbLimbs-len(limbs))
for i := range tail {
tail[i] = 0
}
return append(limbs, tail...)
}
return limbs
}
normLimbs := make([][]frontend.Variable, nbInputs)
for i := range inputs {
normLimbs[i] = normalize(inputs[i].Limbs)
}
normLimbsTransposed := make([][]frontend.Variable, nbLimbs)
for i := range normLimbsTransposed {
normLimbsTransposed[i] = make([]frontend.Variable, nbInputs)
for j := range normLimbsTransposed[i] {
normLimbsTransposed[i][j] = normLimbs[j][i]
}
}
e := f.newInternalElement(make([]frontend.Variable, nbLimbs), overflow)
for i := range inputs[0].Limbs {
e.Limbs[i] = selector.Mux(f.api, sel, normLimbsTransposed[i]...)
}
return e
}

// reduceAndOp applies op on the inputs. If the pre-condition check preCond
// errs, then first reduces the input arguments. The reduction is done
// one-by-one with the element with highest overflow reduced first.
Expand Down
Loading