Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/backend/circuits/div.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func init() {
m := ecc.BN254.ScalarField()
var c big.Int
c.ModInverse(b, m).Mul(&c, a)
c.Mod(&c, m)

// good.A = a
good.A = a
Expand Down
156 changes: 86 additions & 70 deletions test/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,71 +133,73 @@ func IsSolved(circuit, witness frontend.Circuit, field *big.Int, opts ...TestEng
}

func (e *engine) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
b1, b2 := e.toBigInt(i1), e.toBigInt(i2)
b1.Add(&b1, &b2)
res := new(big.Int)
res.Add(e.toBigInt(i1), e.toBigInt(i2))
for i := 0; i < len(in); i++ {
bn := e.toBigInt(in[i])
b1.Add(&b1, &bn)
res.Add(res, e.toBigInt(in[i]))
}
b1.Mod(&b1, e.modulus())
return b1
res.Mod(res, e.modulus())
return res
}

func (e *engine) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
b1, b2 := e.toBigInt(i1), e.toBigInt(i2)
b1.Sub(&b1, &b2)
res := new(big.Int)
res.Sub(e.toBigInt(i1), e.toBigInt(i2))
for i := 0; i < len(in); i++ {
bn := e.toBigInt(in[i])
b1.Sub(&b1, &bn)
res.Sub(res, e.toBigInt(in[i]))
}
b1.Mod(&b1, e.modulus())
return b1
res.Mod(res, e.modulus())
return res
}

func (e *engine) Neg(i1 frontend.Variable) frontend.Variable {
b1 := e.toBigInt(i1)
b1.Neg(&b1)
b1.Mod(&b1, e.modulus())
return b1
res := new(big.Int)
res.Neg(e.toBigInt(i1))
res.Mod(res, e.modulus())
return res
}

func (e *engine) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
b1, b2 := e.toBigInt(i1), e.toBigInt(i2)
b1.Mul(&b1, &b2).Mod(&b1, e.modulus())
res := new(big.Int)
res.Mul(e.toBigInt(i1), e.toBigInt(i2))
res.Mod(res, e.modulus())
for i := 0; i < len(in); i++ {
bn := e.toBigInt(in[i])
b1.Mul(&b1, &bn).Mod(&b1, e.modulus())
res.Mul(res, e.toBigInt(in[i]))
res.Mod(res, e.modulus())
}
return b1
return res
}

func (e *engine) Div(i1, i2 frontend.Variable) frontend.Variable {
b1, b2 := e.toBigInt(i1), e.toBigInt(i2)
if b2.ModInverse(&b2, e.modulus()) == nil {
res := new(big.Int)
if res.ModInverse(e.toBigInt(i2), e.modulus()) == nil {
panic("no inverse")
}
b2.Mul(&b1, &b2).Mod(&b2, e.modulus())
return b2
res.Mul(res, e.toBigInt(i1))
res.Mod(res, e.modulus())
return res
}

func (e *engine) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable {
res := new(big.Int)
b1, b2 := e.toBigInt(i1), e.toBigInt(i2)
if b1.IsUint64() && b2.IsUint64() && b1.Uint64() == 0 && b2.Uint64() == 0 {
return 0
}
if b2.ModInverse(&b2, e.modulus()) == nil {
if res.ModInverse(b2, e.modulus()) == nil {
panic("no inverse")
}
b2.Mul(&b1, &b2).Mod(&b2, e.modulus())
return b2
res.Mul(res, b1)
res.Mod(res, e.modulus())
return res
}

func (e *engine) Inverse(i1 frontend.Variable) frontend.Variable {
b1 := e.toBigInt(i1)
if b1.ModInverse(&b1, e.modulus()) == nil {
res := new(big.Int)
if res.ModInverse(e.toBigInt(i1), e.modulus()) == nil {
panic("no inverse")
}
return b1
return res
}

func (e *engine) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable {
Expand All @@ -224,62 +226,67 @@ func (e *engine) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable {

// this is a sanity check, it should never happen
value := e.toBigInt(e.FromBinary(ri...))
if value.Cmp(&b1) != 0 {
if value.Cmp(b1) != 0 {

panic(fmt.Sprintf("[ToBinary] decomposing %s (bitLen == %d) with %d bits reconstructs into %s", b1.String(), b1.BitLen(), nbBits, value.String()))
}
return r
}

func (e *engine) FromBinary(v ...frontend.Variable) frontend.Variable {
bits := make([]big.Int, len(v))
bits := make([]*big.Int, len(v))
for i := 0; i < len(v); i++ {
bits[i] = e.toBigInt(v[i])
e.mustBeBoolean(&bits[i])
e.mustBeBoolean(bits[i])
}

// Σ (2**i * bits[i]) == r
var c, r big.Int
c := new(big.Int)
r := new(big.Int)
tmp := new(big.Int)
c.SetUint64(1)

for i := 0; i < len(bits); i++ {
bits[i].Mul(&bits[i], &c)
r.Add(&r, &bits[i])
c.Lsh(&c, 1)
tmp.Mul(bits[i], c)
r.Add(r, tmp)
c.Lsh(c, 1)
}
r.Mod(&r, e.modulus())
r.Mod(r, e.modulus())

return r
}

func (e *engine) Xor(i1, i2 frontend.Variable) frontend.Variable {
b1, b2 := e.toBigInt(i1), e.toBigInt(i2)
e.mustBeBoolean(&b1)
e.mustBeBoolean(&b2)
b1.Xor(&b1, &b2)
return b1
e.mustBeBoolean(b1)
e.mustBeBoolean(b2)
res := new(big.Int)
res.Xor(b1, b2)
return res
}

func (e *engine) Or(i1, i2 frontend.Variable) frontend.Variable {
b1, b2 := e.toBigInt(i1), e.toBigInt(i2)
e.mustBeBoolean(&b1)
e.mustBeBoolean(&b2)
b1.Or(&b1, &b2)
return b1
e.mustBeBoolean(b1)
e.mustBeBoolean(b2)
res := new(big.Int)
res.Or(b1, b2)
return res
}

func (e *engine) And(i1, i2 frontend.Variable) frontend.Variable {
b1, b2 := e.toBigInt(i1), e.toBigInt(i2)
e.mustBeBoolean(&b1)
e.mustBeBoolean(&b2)
b1.And(&b1, &b2)
return b1
e.mustBeBoolean(b1)
e.mustBeBoolean(b2)
res := new(big.Int)
res.And(b1, b2)
return res
}

// Select if b is true, yields i1 else yields i2
func (e *engine) Select(b frontend.Variable, i1, i2 frontend.Variable) frontend.Variable {
b1 := e.toBigInt(b)
e.mustBeBoolean(&b1)
e.mustBeBoolean(b1)

if b1.Uint64() == 1 {
return e.toBigInt(i1)
Expand All @@ -293,10 +300,10 @@ func (e *engine) Select(b frontend.Variable, i1, i2 frontend.Variable) frontend.
func (e *engine) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Variable) frontend.Variable {
s0 := e.toBigInt(b0)
s1 := e.toBigInt(b1)
e.mustBeBoolean(&s0)
e.mustBeBoolean(&s1)
lookup := new(big.Int).Lsh(&s1, 1)
lookup.Or(lookup, &s0)
e.mustBeBoolean(s0)
e.mustBeBoolean(s1)
lookup := new(big.Int).Lsh(s1, 1)
lookup.Or(lookup, s0)
return e.toBigInt([]frontend.Variable{i0, i1, i2, i3}[lookup.Uint64()])
}

Expand All @@ -305,36 +312,38 @@ func (e *engine) IsZero(i1 frontend.Variable) frontend.Variable {
b1 := e.toBigInt(i1)

if b1.IsUint64() && b1.Uint64() == 0 {
return 1
return big.NewInt(1)
}

return (0)
return big.NewInt(0)
}

// Cmp returns 1 if i1>i2, 0 if i1==i2, -1 if i1<i2
func (e *engine) Cmp(i1, i2 frontend.Variable) frontend.Variable {
b1 := e.toBigInt(i1)
b2 := e.toBigInt(i2)
return e.toBigInt(b1.Cmp(&b2))
res := big.NewInt(int64(b1.Cmp(b2)))
res.Mod(res, e.modulus())
return res
}

func (e *engine) AssertIsEqual(i1, i2 frontend.Variable) {
b1, b2 := e.toBigInt(i1), e.toBigInt(i2)
if b1.Cmp(&b2) != 0 {
if b1.Cmp(b2) != 0 {
panic(fmt.Sprintf("[assertIsEqual] %s == %s", b1.String(), b2.String()))
}
}

func (e *engine) AssertIsDifferent(i1, i2 frontend.Variable) {
b1, b2 := e.toBigInt(i1), e.toBigInt(i2)
if b1.Cmp(&b2) == 0 {
if b1.Cmp(b2) == 0 {
panic(fmt.Sprintf("[assertIsDifferent] %s != %s", b1.String(), b2.String()))
}
}

func (e *engine) AssertIsBoolean(i1 frontend.Variable) {
b1 := e.toBigInt(i1)
e.mustBeBoolean(&b1)
e.mustBeBoolean(b1)
}

func (e *engine) AssertIsLessOrEqual(v frontend.Variable, bound frontend.Variable) {
Expand All @@ -346,7 +355,7 @@ func (e *engine) AssertIsLessOrEqual(v frontend.Variable, bound frontend.Variabl
}

b1 := e.toBigInt(v)
if b1.Cmp(&bValue) == 1 {
if b1.Cmp(bValue) == 1 {
panic(fmt.Sprintf("[assertIsLessOrEqual] %s > %s", b1.String(), bValue.String()))
}
}
Expand Down Expand Up @@ -380,8 +389,7 @@ func (e *engine) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Vari
in := make([]*big.Int, len(inputs))

for i := 0; i < len(inputs); i++ {
v := e.toBigInt(inputs[i])
in[i] = &v
in[i] = e.toBigInt(inputs[i])
}
res := make([]*big.Int, nbOutputs)
for i := range res {
Expand All @@ -396,6 +404,7 @@ func (e *engine) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Vari

out := make([]frontend.Variable, len(res))
for i := range res {
res[i].Mod(res[i], e.q)
out[i] = res[i]
}

Expand All @@ -410,7 +419,7 @@ func (e *engine) IsConstant(v frontend.Variable) bool {
// ConstantValue returns the big.Int value of v
func (e *engine) ConstantValue(v frontend.Variable) (*big.Int, bool) {
r := e.toBigInt(v)
return &r, e.constVars
return r, e.constVars
}

func (e *engine) IsBoolean(v frontend.Variable) bool {
Expand All @@ -433,10 +442,17 @@ func (e *engine) AddCounter(from, to frontend.Tag) {
// do nothing, we don't measure constraints with the test engine
}

func (e *engine) toBigInt(i1 frontend.Variable) big.Int {
b := utils.FromInterface(i1)
b.Mod(&b, e.modulus())
return b
func (e *engine) toBigInt(i1 frontend.Variable) *big.Int {
switch vv := i1.(type) {
case *big.Int:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't we still do a mod reduce if it's a big int? may slow things down (or we could have a fast path with a comparaison first) --> you can have circuit inputs / constants at this stage, so it may trigger weird edge cases if we allow some values in the test engine to not be mod reduced?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to rewrite the API methods such that toBigInt is only called for method inputs and I always recreate a new *big.Int for temp variables and result. The operations and results are always mod reduced.

I also tried to mod reduce the inputs in toBigInt method, but as I will be modifying inputs, I started having a few failing edge cases for the integration tests. Another approach would be to allocate a new *big.Int, set its value from the input and mod reduce it, but this goes against the goal of this PR which was to prevent unnecessary allocations.

return vv
case big.Int:
return &vv
default:
b := utils.FromInterface(i1)
b.Mod(&b, e.modulus())
return &b
}
}

// bitLen returns the number of bits needed to represent a fr.Element
Expand Down