Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions std/hints.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ func registerHints() {
solver.RegisterHint(bits.NNAF)
solver.RegisterHint(bits.IthBit)
solver.RegisterHint(bits.NBits)
solver.RegisterHint(selector.MuxIndicators)
solver.RegisterHint(selector.MapIndicators)
solver.RegisterHint(selector.GetHints()...)
solver.RegisterHint(emulated.GetHints()...)
}
80 changes: 80 additions & 0 deletions std/selector/doc_partition_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package selector_test

import "github.com/consensys/gnark/frontend"

import (
"fmt"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/backend/groth16"
"github.com/consensys/gnark/frontend/cs/r1cs"
"github.com/consensys/gnark/std/selector"
)

// adderCircuit adds first Count number of its input array In.
type adderCircuit struct {
Count frontend.Variable
In [10]frontend.Variable
ExpectedSum frontend.Variable
}

// Define defines the arithmetic circuit.
func (c *adderCircuit) Define(api frontend.API) error {
selectedPart := selector.Partition(api, c.Count, false, c.In[:])
sum := api.Add(selectedPart[0], selectedPart[1], selectedPart[2:]...)
api.AssertIsEqual(sum, c.ExpectedSum)
return nil
}

// ExamplePartition gives an example on how to use selector.Partition to make a circuit that accepts a Count and an
// input array In, and then calculates the sum of first Count numbers of the input array.
func ExamplePartition() {
circuit := adderCircuit{}
witness := adderCircuit{
Count: 6,
In: [10]frontend.Variable{0, 2, 4, 6, 8, 10, 12, 14, 16, 18},
ExpectedSum: 30,
}
ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit)
if err != nil {
panic(err)
} else {
fmt.Println("compiled")
}
pk, vk, err := groth16.Setup(ccs)
if err != nil {
panic(err)
} else {
fmt.Println("setup done")
}
secretWitness, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField())
if err != nil {
panic(err)
} else {
fmt.Println("secret witness")
}
publicWitness, err := secretWitness.Public()
if err != nil {
panic(err)
} else {
fmt.Println("public witness")
}
proof, err := groth16.Prove(ccs, pk, secretWitness)
if err != nil {
panic(err)
} else {
fmt.Println("proof")
}
err = groth16.Verify(proof, vk, publicWitness)
if err != nil {
panic(err)
} else {
fmt.Println("verify")
}
// Output: compiled
// setup done
// secret witness
// public witness
// proof
// verify
}
26 changes: 18 additions & 8 deletions std/selector/multiplexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package selector

import (
"fmt"
"math/big"

"github.com/consensys/gnark/constraint/solver"
Expand All @@ -21,8 +22,13 @@ import (

func init() {
// register hints
solver.RegisterHint(MuxIndicators)
solver.RegisterHint(MapIndicators)
solver.RegisterHint(GetHints()...)
}

// GetHints returns all hint functions used in this package. This method is
// useful for registering all hints in the solver.
func GetHints() []solver.Hint {
return []solver.Hint{stepOutput, muxIndicators, mapIndicators}
}

// Map is a key value associative array: the output will be values[i] such that keys[i] == queryKey. If keys does not
Expand Down Expand Up @@ -55,10 +61,14 @@ func generateSelector(api frontend.API, wantMux bool, sel frontend.Variable,
keys []frontend.Variable, values []frontend.Variable) (out frontend.Variable) {

var indicators []frontend.Variable
var err error
if wantMux {
indicators, _ = api.Compiler().NewHint(MuxIndicators, len(values), sel)
indicators, err = api.Compiler().NewHint(muxIndicators, len(values), sel)
} else {
indicators, _ = api.Compiler().NewHint(MapIndicators, len(keys), append(keys, sel)...)
indicators, err = api.Compiler().NewHint(mapIndicators, len(keys), append(keys, sel)...)
}
if err != nil {
panic(fmt.Sprintf("error in calling Mux/Map hint: %v", err))
}

out = 0
Expand All @@ -82,9 +92,9 @@ func generateSelector(api frontend.API, wantMux bool, sel frontend.Variable,
return out
}

// MuxIndicators is a hint function used within [Mux] function. It must be
// muxIndicators is a hint function used within [Mux] function. It must be
// provided to the prover when circuit uses it.
func MuxIndicators(_ *big.Int, inputs []*big.Int, results []*big.Int) error {
func muxIndicators(_ *big.Int, inputs []*big.Int, results []*big.Int) error {
sel := inputs[0]
for i := 0; i < len(results); i++ {
// i is an int which can be int32 or int64. We convert i to int64 then to bigInt, which is safe. We should
Expand All @@ -98,9 +108,9 @@ func MuxIndicators(_ *big.Int, inputs []*big.Int, results []*big.Int) error {
return nil
}

// MapIndicators is a hint function used within [Map] function. It must be
// mapIndicators is a hint function used within [Map] function. It must be
// provided to the prover when circuit uses it.
func MapIndicators(_ *big.Int, inputs []*big.Int, results []*big.Int) error {
func mapIndicators(_ *big.Int, inputs []*big.Int, results []*big.Int) error {
key := inputs[len(inputs)-1]
// We must make sure that we are initializing all elements of results
for i := 0; i < len(results); i++ {
Expand Down
105 changes: 105 additions & 0 deletions std/selector/slice.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package selector

import (
"fmt"
"github.com/consensys/gnark/frontend"
"math/big"
)

// Slice selects a slice of the input array at indices [start, end), and zeroes the array at other
// indices. More precisely, for each i we have:
//
// if i >= start and i < end
// out[i] = input[i]
// else
// out[i] = 0
//
// We must have start >= 0 and end <= len(input), otherwise a proof cannot be generated.
func Slice(api frontend.API, start, end frontend.Variable, input []frontend.Variable) []frontend.Variable {
// it appears that this is the most efficient implementation. There is also another implementation
// which creates the mask by adding two stepMask outputs, however that would not work correctly when
// end < start.
out := Partition(api, end, false, input)
out = Partition(api, start, true, out)
return out
}

// Partition selects left or right side of the input array, with respect to the pivotPosition.
// More precisely when rightSide is false, for each i we have:
//
// if i < pivotPosition
// out[i] = input[i]
// else
// out[i] = 0
//
// and when rightSide is true, for each i we have:
//
// if i >= pivotPosition
// out[i] = input[i]
// else
// out[i] = 0
//
// We must have pivotPosition >= 0 and pivotPosition <= len(input), otherwise a proof cannot be generated.
func Partition(api frontend.API, pivotPosition frontend.Variable, rightSide bool,
Comment thread
ivokub marked this conversation as resolved.
input []frontend.Variable) (out []frontend.Variable) {
out = make([]frontend.Variable, len(input))
var mask []frontend.Variable
// we create a bit mask to multiply with the input.
if rightSide {
mask = stepMask(api, len(input), pivotPosition, 0, 1)
} else {
mask = stepMask(api, len(input), pivotPosition, 1, 0)
}
for i := 0; i < len(out); i++ {
out[i] = api.Mul(mask[i], input[i])
}
return
}

// stepMask generates a step like function into an output array of a given length.
// The output is an array of length outputLen,
// such that its first stepPosition elements are equal to startValue and the remaining elements are equal to
// endValue. Note that outputLen cannot be a circuit variable.
//
// We must have stepPosition >= 0 and stepPosition <= outputLen, otherwise a proof cannot be generated.
// This function panics when outputLen is less than 2.
func stepMask(api frontend.API, outputLen int,
stepPosition, startValue, endValue frontend.Variable) []frontend.Variable {
if outputLen < 2 {
panic("the output len of StepMask must be >= 2")
}
// get the output as a hint
out, err := api.Compiler().NewHint(stepOutput, outputLen, stepPosition, startValue, endValue)
if err != nil {
panic(fmt.Sprintf("error in calling StepMask hint: %v", err))
}

// add the boundary constraints:
// (out[0] - startValue) * stepPosition == 0
api.AssertIsEqual(api.Mul(api.Sub(out[0], startValue), stepPosition), 0)
// (out[len(out)-1] - endValue) * (len(out) - stepPosition) == 0
api.AssertIsEqual(api.Mul(api.Sub(out[len(out)-1], endValue), api.Sub(len(out), stepPosition)), 0)
Comment thread
ivokub marked this conversation as resolved.

// add constraints for the correct form of a step function that steps at the stepPosition
for i := 1; i < len(out); i++ {
// (out[i] - out[i-1]) * (i - stepPosition) == 0
api.AssertIsEqual(api.Mul(api.Sub(out[i], out[i-1]), api.Sub(i, stepPosition)), 0)
}
return out
}

// stepOutput is a hint function used within [StepMask] function. It must be
// provided to the prover when circuit uses it.
func stepOutput(_ *big.Int, inputs, results []*big.Int) error {
stepPos := inputs[0]
Comment thread
ivokub marked this conversation as resolved.
startValue := inputs[1]
endValue := inputs[2]
for i := 0; i < len(results); i++ {
if i < int(stepPos.Int64()) {
results[i].Set(startValue)
} else {
results[i].Set(endValue)
}
}
return nil
}
Loading