diff --git a/std/math/emulated/element_test.go b/std/math/emulated/element_test.go index fa8ebaa059..c80460996f 100644 --- a/std/math/emulated/element_test.go +++ b/std/math/emulated/element_test.go @@ -1500,3 +1500,50 @@ func testFastPaths[T FieldParams](t *testing.T) { assert.CheckCircuit(circuit, test.WithValidAssignment(assignment)) } + +type TestAssertIsDifferentCircuit[T FieldParams] struct { + A, B Element[T] + addMod bool +} + +func (c *TestAssertIsDifferentCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + b := &c.B + if c.addMod { + b = f.Add(b, f.Modulus()) + } + f.AssertIsDifferent(&c.A, b) + return nil +} + +func TestAssertIsDifferent(t *testing.T) { + testAssertIsDifferent[Goldilocks](t) + testAssertIsDifferent[Secp256k1Fp](t) + testAssertIsDifferent[BN254Fp](t) +} + +func testAssertIsDifferent[T FieldParams](t *testing.T) { + assert := test.NewAssert(t) + circuitNoMod := &TestAssertIsDifferentCircuit[T]{addMod: false} + var fp T + a, _ := rand.Int(rand.Reader, fp.Modulus()) + assignment1 := &TestAssertIsDifferentCircuit[T]{A: ValueOf[T](a), B: ValueOf[T](a)} + var b *big.Int + for { + b, _ = rand.Int(rand.Reader, fp.Modulus()) + if b.Cmp(a) == 0 { + continue + } + break + } + assignment2 := &TestAssertIsDifferentCircuit[T]{A: ValueOf[T](a), B: ValueOf[T](b)} + assert.CheckCircuit(circuitNoMod, test.WithInvalidAssignment(assignment1), test.WithValidAssignment(assignment2)) + + circuitWithMod := &TestAssertIsDifferentCircuit[T]{addMod: true} + assignment3 := &TestAssertIsDifferentCircuit[T]{A: ValueOf[T](a), B: ValueOf[T](a)} + assignment4 := &TestAssertIsDifferentCircuit[T]{A: ValueOf[T](a), B: ValueOf[T](b)} + assert.CheckCircuit(circuitWithMod, test.WithInvalidAssignment(assignment3), test.WithValidAssignment(assignment4)) +} diff --git a/std/math/emulated/field_assert.go b/std/math/emulated/field_assert.go index e28f7c7da0..c747c68683 100644 --- a/std/math/emulated/field_assert.go +++ b/std/math/emulated/field_assert.go @@ -153,6 +153,14 @@ func (f *Field[T]) IsZero(a *Element[T]) frontend.Variable { return f.api.Or(res0, resP) } +// AssertIsDifferent asserts that a and b are different. +func (f *Field[T]) AssertIsDifferent(a, b *Element[T]) { + // we skip conditional width checking as it is done in [Sub] below + diff := f.Sub(a, b) + diffIsZero := f.IsZero(diff) + f.api.AssertIsEqual(diffIsZero, 0) +} + // // Cmp returns: // // - -1 if a < b // // - 0 if a = b @@ -172,18 +180,3 @@ func (f *Field[T]) IsZero(a *Element[T]) frontend.Variable { // } // return res // } - -// TODO(@ivokub) -// func (f *Field[T]) AssertIsDifferent(a, b *Element[T]) { -// ca := f.Reduce(a) -// f.AssertIsInRange(ca) -// cb := f.Reduce(b) -// f.AssertIsInRange(cb) -// var res frontend.Variable = 0 -// for i := 0; i < int(f.fParams.NbLimbs()); i++ { -// cmp := f.api.Cmp(ca.Limbs[i], cb.Limbs[i]) -// cmpsq := f.api.Mul(cmp, cmp) -// res = f.api.Add(res, cmpsq) -// } -// f.api.AssertIsDifferent(res, 0) -// }