diff --git a/std/algebra/emulated/sw_emulated/point.go b/std/algebra/emulated/sw_emulated/point.go index 7f4dc1dda5..7338bff86c 100644 --- a/std/algebra/emulated/sw_emulated/point.go +++ b/std/algebra/emulated/sw_emulated/point.go @@ -461,3 +461,73 @@ func (c *Curve[B, S]) ScalarMulBase(s *emulated.Element[S]) *AffinePoint[B] { return res } + +// JointScalarMulBase computes s2 * p + s1 * g and returns it, where g is the +// fixed generator. It doesn't modify p, s1 and s2. +// +// ⚠️ p must NOT be (0,0). +// ⚠️ s1 and s2 must NOT be 0. +// +// JointScalarMulBase is used to verify an ECDSA signature (r,s) on the +// secp256k1 curve. In this case, p is a public key, s2=r/s and s1=hash/s. +// - hash cannot be 0, because of pre-image resistance. +// - r cannot be 0, because r is the x coordinate of a random point on +// secp256k1 (y²=x³+7 mod p) and 7 is not a square mod p. For any other +// curve, (_,0) is a point of order 2 which is not the prime subgroup. +// - (0,0) is not a valid public key. +// +// The [EVM] specifies these checks, wich are performed on the zkEVM +// arithmetization side before calling the circuit that uses this method. +// +// This saves the Select logic related to (0,0) and the use of AddUnified to +// handle the 0-scalar edge case. +func (c *Curve[B, S]) JointScalarMulBase(p *AffinePoint[B], s2, s1 *emulated.Element[S]) *AffinePoint[B] { + g := c.Generator() + gm := c.GeneratorMultiples() + + var st S + s1r := c.scalarApi.Reduce(s1) + s1Bits := c.scalarApi.ToBits(s1r) + s2r := c.scalarApi.Reduce(s2) + s2Bits := c.scalarApi.ToBits(s2r) + n := st.Modulus().BitLen() + + // i = 1, 2 + // gm[0] = 3g, gm[1] = 5g, gm[2] = 7g + res1 := c.Lookup2(s1Bits[1], s1Bits[2], g, &gm[0], &gm[1], &gm[2]) + tmp2 := c.triple(p) + res2 := c.Select(s2Bits[1], tmp2, p) + acc := c.add(tmp2, p) + tmp2 = c.add(res2, acc) + res2 = c.Select(s2Bits[2], tmp2, res2) + acc = c.double(acc) + + for i := 3; i <= n-3; i++ { + // gm[i] = [2^i]g + tmp1 := c.add(res1, &gm[i]) + res1 = c.Select(s1Bits[i], tmp1, res1) + tmp2 = c.add(res2, acc) + res2 = c.Select(s2Bits[i], tmp2, res2) + acc = c.double(acc) + } + + // i = 0 + tmp1 := c.add(res1, c.Neg(g)) + res1 = c.Select(s1Bits[0], res1, tmp1) + tmp2 = c.add(res2, c.Neg(p)) + res2 = c.Select(s2Bits[0], res2, tmp2) + + // i = n-2 + tmp1 = c.add(res1, &gm[n-2]) + res1 = c.Select(s1Bits[n-2], tmp1, res1) + tmp2 = c.add(res2, acc) + res2 = c.Select(s2Bits[n-2], tmp2, res2) + + // i = n-1 + tmp1 = c.add(res1, &gm[n-1]) + res1 = c.Select(s1Bits[n-1], tmp1, res1) + tmp2 = c.doubleAndAdd(acc, res2) + res2 = c.Select(s2Bits[n-1], tmp2, res2) + + return c.add(res1, res2) +} diff --git a/std/evmprecompiles/01-ecrecover.go b/std/evmprecompiles/01-ecrecover.go index 00212b8468..1748a24295 100644 --- a/std/evmprecompiles/01-ecrecover.go +++ b/std/evmprecompiles/01-ecrecover.go @@ -71,9 +71,7 @@ func ECRecover(api frontend.API, msg emulated.Element[emulated.Secp256k1Fr], // compute u2 = s * rinv u2 := frField.MulMod(&s, rinv) // check u1 * G + u2 R == P - A := curve.ScalarMulBase(u1) - B := curve.ScalarMul(&R, u2) - C := curve.AddUnified(A, B) + C := curve.JointScalarMulBase(&R, u2, u1) curve.AssertIsEqual(C, &P) return &P } diff --git a/std/signature/ecdsa/ecdsa.go b/std/signature/ecdsa/ecdsa.go index 72495b303f..cac32d7db1 100644 --- a/std/signature/ecdsa/ecdsa.go +++ b/std/signature/ecdsa/ecdsa.go @@ -50,9 +50,8 @@ func (pk PublicKey[T, S]) Verify(api frontend.API, params sw_emulated.CurveParam msInv := scalarApi.MulMod(msg, sInv) rsInv := scalarApi.MulMod(&sig.R, sInv) - qa := cr.ScalarMulBase(msInv) - qb := cr.ScalarMul(&pkpt, rsInv) - q := cr.AddUnified(qa, qb) + // q = [rsInv]pkpt + [msInv]g + q := cr.JointScalarMulBase(&pkpt, rsInv, msInv) qx := baseApi.Reduce(&q.X) qxBits := baseApi.ToBits(qx) rbits := scalarApi.ToBits(&sig.R)