diff --git a/std/hash/hash.go b/std/hash/hash.go index c1f6fe1358..554c3ea471 100644 --- a/std/hash/hash.go +++ b/std/hash/hash.go @@ -84,10 +84,34 @@ type BinaryHasher interface { // the length of the input is the total number of bytes written. type BinaryFixedLengthHasher interface { BinaryHasher - // FixedLengthSum returns digest of the first length bytes. + // FixedLengthSum returns digest of the first length bytes. See the + // [WithMinimalLength] option for setting lower bound on length. FixedLengthSum(length frontend.Variable) []uints.U8 } +// HasherConfig allows to configure the behavior of the hash constructors. Do +// not initialize the configuration directly but rather use the [Option] +// functions which perform correct initializations. This configuration is +// exported for importing in hash implementations. +type HasherConfig struct { + MinimalLength int +} + +// Option allows configuring the hash functions. +type Option func(*HasherConfig) error + +// WithMinimalLength hints the minimal length of the input to the hash function. +// This allows to optimize the constraint count when calling +// [BinaryFixedLengthHasher.FixedLengthSum] as we can avoid selecting between +// the dummy padding and actual padding. If this option is not provided, then we +// assume the minimal length is 0. +func WithMinimalLength(minimalLength int) Option { + return func(cfg *HasherConfig) error { + cfg.MinimalLength = minimalLength + return nil + } +} + // Compressor is a 2-1 one-way function. It takes two inputs and compresses // them into one output. // diff --git a/std/hash/sha2/sha2.go b/std/hash/sha2/sha2.go index ea36f7f70d..4faf8577de 100644 --- a/std/hash/sha2/sha2.go +++ b/std/hash/sha2/sha2.go @@ -6,6 +6,7 @@ package sha2 import ( "encoding/binary" + "fmt" "math/big" "github.com/consensys/gnark/frontend" @@ -25,14 +26,22 @@ type digest struct { api frontend.API uapi *uints.BinaryField[uints.U32] in []uints.U8 + + minimalLength int } -func New(api frontend.API) (hash.BinaryFixedLengthHasher, error) { +func New(api frontend.API, opts ...hash.Option) (hash.BinaryFixedLengthHasher, error) { + cfg := new(hash.HasherConfig) + for _, opt := range opts { + if err := opt(cfg); err != nil { + return nil, fmt.Errorf("applying option: %w", err) + } + } uapi, err := uints.New[uints.U32](api) if err != nil { - return nil, err + return nil, fmt.Errorf("initializing uints: %w", err) } - return &digest{api: api, uapi: uapi}, nil + return &digest{api: api, uapi: uapi, minimalLength: cfg.MinimalLength}, nil } func (d *digest) Write(data []uints.U8) { @@ -68,9 +77,14 @@ func (d *digest) Sum() []uints.U8 { copy(buf[:], padded[i*64:(i+1)*64]) runningDigest = sha2.Permute(d.uapi, runningDigest, buf) } + + return d.unpackU8Digest(runningDigest) +} + +func (d *digest) unpackU8Digest(digest [8]uints.U32) []uints.U8 { var ret []uints.U8 - for i := range runningDigest { - ret = append(ret, d.uapi.UnpackMSB(runningDigest[i])...) + for i := range digest { + ret = append(ret, d.uapi.UnpackMSB(digest[i])...) } return ret } @@ -85,15 +99,18 @@ func (d *digest) FixedLengthSum(length frontend.Variable) []uints.U8 { // idea - have a mask for blocks where 1 is only for the block we want to // use. - data := make([]uints.U8, len(d.in)) - copy(data, d.in) - - comparator := cmp.NewBoundedComparator(d.api, big.NewInt(int64(len(data)+64+8)), false) - - for i := 0; i < 64+8; i++ { - data = append(data, uints.NewU8(0)) + maxLen := len(d.in) + comparator := cmp.NewBoundedComparator(d.api, big.NewInt(int64(maxLen+64+8)), false) + // when minimal length is 0 (i.e. not set), then we can skip the check as it holds naturally (all field elements are non-negative) + if d.minimalLength > 0 { + // we use comparator as [frontend.API] doesn't have a fast path for case API.AssertIsLessOrEqual(constant, variable) + comparator.AssertIsLessEq(d.minimalLength, length) } + data := make([]uints.U8, maxLen) + copy(data, d.in) + data = append(data, uints.NewU8Array(make([]uint8, 64+8))...) + lenMod64 := d.mod64(length) lenMod64Less56 := comparator.IsLess(lenMod64, 56) @@ -106,16 +123,18 @@ func (d *digest) FixedLengthSum(length frontend.Variable) []uints.U8 { var dataLenBtyes [8]frontend.Variable d.bigEndianPutUint64(dataLenBtyes[:], d.api.Mul(length, 8)) - for i := range data { - isPaddingStartPos := d.api.IsZero(d.api.Sub(i, length)) + // When i < minLen or i > maxLen, padding 1 0r 0 is completely unnecessary + for i := d.minimalLength; i <= maxLen; i++ { + isPaddingStartPos := cmp.IsEqual(d.api, i, length) data[i].Val = d.api.Select(isPaddingStartPos, 0x80, data[i].Val) isPaddingPos := comparator.IsLess(length, i) data[i].Val = d.api.Select(isPaddingPos, 0, data[i].Val) } - for i := range data { - isLast8BytesPos := d.api.IsZero(d.api.Sub(i, last8BytesPos)) + // When i <= minLen, padding length is completely unnecessary + for i := d.minimalLength + 1; i < len(data); i++ { + isLast8BytesPos := cmp.IsEqual(d.api, i, last8BytesPos) for j := 0; j < 8; j++ { if i+j < len(data) { data[i+j].Val = d.api.Select(isLast8BytesPos, dataLenBtyes[j], data[i+j].Val) @@ -127,14 +146,20 @@ func (d *digest) FixedLengthSum(length frontend.Variable) []uints.U8 { var resultDigest [8]uints.U32 var buf [64]uints.U8 copy(runningDigest[:], _seed) - copy(resultDigest[:], _seed) for i := 0; i < len(data)/64; i++ { copy(buf[:], data[i*64:(i+1)*64]) runningDigest = sha2.Permute(d.uapi, runningDigest, buf) - isInRange := comparator.IsLess(i*64, totalLen) + // When i < minLen/64, runningDigest cannot be resultDigest, and proceed to the next loop directly + if i < d.minimalLength/64 { + continue + } else if i == d.minimalLength/64 { // init resultDigests + copy(resultDigest[:], runningDigest[:]) + continue + } + isInRange := comparator.IsLess(i*64, totalLen) for j := 0; j < 8; j++ { for k := 0; k < 4; k++ { resultDigest[j][k].Val = d.api.Select(isInRange, runningDigest[j][k].Val, resultDigest[j][k].Val) @@ -142,11 +167,7 @@ func (d *digest) FixedLengthSum(length frontend.Variable) []uints.U8 { } } - var ret []uints.U8 - for i := range resultDigest { - ret = append(ret, d.uapi.UnpackMSB(resultDigest[i])...) - } - return ret + return d.unpackU8Digest(resultDigest) } func (d *digest) Reset() { diff --git a/std/hash/sha2/sha2_test.go b/std/hash/sha2/sha2_test.go index 0093fddc43..5f3572ac8a 100644 --- a/std/hash/sha2/sha2_test.go +++ b/std/hash/sha2/sha2_test.go @@ -1,12 +1,14 @@ package sha2 import ( + "crypto/rand" "crypto/sha256" "fmt" "testing" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/std/math/uints" "github.com/consensys/gnark/test" ) @@ -53,10 +55,13 @@ type sha2FixedLengthCircuit struct { In []uints.U8 Length frontend.Variable Expected [32]uints.U8 + + // minimal length of the input is the circuit parameter + minimalLength int } func (c *sha2FixedLengthCircuit) Define(api frontend.API) error { - h, err := New(api) + h, err := New(api, hash.WithMinimalLength(c.minimalLength)) if err != nil { return err } @@ -76,16 +81,30 @@ func (c *sha2FixedLengthCircuit) Define(api frontend.API) error { } func TestSHA2FixedLengthSum(t *testing.T) { - bts := make([]byte, 144) - length := 56 - dgst := sha256.Sum256(bts[:length]) - witness := sha2FixedLengthCircuit{ - In: uints.NewU8Array(bts), - Length: length, - } - copy(witness.Expected[:], uints.NewU8Array(dgst[:])) - err := test.IsSolved(&sha2FixedLengthCircuit{In: make([]uints.U8, len(bts))}, &witness, ecc.BN254.ScalarField()) - if err != nil { - t.Fatal(err) + const maxLen = 144 + assert := test.NewAssert(t) + bts := make([]byte, maxLen) + _, err := rand.Reader.Read(bts) + assert.NoError(err) + + for _, lengthBound := range []int{0, 1, 63, 64, 65, len(bts)} { + circuit := &sha2FixedLengthCircuit{In: make([]uints.U8, len(bts)), minimalLength: lengthBound} + for _, length := range []int{0, 1, 63, 64, 65, len(bts)} { + assert.Run(func(assert *test.Assert) { + dgst := sha256.Sum256(bts[:length]) + witness := &sha2FixedLengthCircuit{ + In: uints.NewU8Array(bts), + Length: length, + Expected: [32]uints.U8(uints.NewU8Array(dgst[:])), + } + + err = test.IsSolved(circuit, witness, ecc.BN254.ScalarField()) + if length >= lengthBound { + assert.NoError(err) + } else if length < lengthBound { + assert.Error(err, "expected error for length < lengthBound") + } + }, fmt.Sprintf("bound=%d/length=%d", lengthBound, length)) + } } } diff --git a/std/hash/sha3/hashes.go b/std/hash/sha3/hashes.go index 4fac328955..38273f9ea7 100644 --- a/std/hash/sha3/hashes.go +++ b/std/hash/sha3/hashes.go @@ -1,99 +1,69 @@ package sha3 import ( + "fmt" + "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/std/math/uints" ) -// New256 creates a new SHA3-256 hash. -// Its generic security strength is 256 bits against preimage attacks, -// and 128 bits against collision attacks. -func New256(api frontend.API) (hash.BinaryFixedLengthHasher, error) { +// newHash is a helper function to create a new SHA3 hash. +func newHash(api frontend.API, dsByte byte, rate, outputLen int, opts ...hash.Option) (hash.BinaryFixedLengthHasher, error) { + cfg := new(hash.HasherConfig) + for _, opt := range opts { + if err := opt(cfg); err != nil { + return nil, fmt.Errorf("applying option: %w", err) + } + } uapi, err := uints.New[uints.U64](api) if err != nil { - return nil, err + return nil, fmt.Errorf("initializing uints: %w", err) } return &digest{ - api: api, - uapi: uapi, - state: newState(), - dsbyte: 0x06, - rate: 136, - outputLen: 32, + api: api, + uapi: uapi, + state: newState(), + dsbyte: dsByte, + rate: rate, + outputLen: outputLen, + minimalLength: cfg.MinimalLength, }, nil } +// New256 creates a new SHA3-256 hash. +// Its generic security strength is 256 bits against preimage attacks, +// and 128 bits against collision attacks. +func New256(api frontend.API, opts ...hash.Option) (hash.BinaryFixedLengthHasher, error) { + return newHash(api, 0x06, 136, 32, opts...) +} + // New384 creates a new SHA3-384 hash. // Its generic security strength is 384 bits against preimage attacks, // and 192 bits against collision attacks. -func New384(api frontend.API) (hash.BinaryFixedLengthHasher, error) { - uapi, err := uints.New[uints.U64](api) - if err != nil { - return nil, err - } - return &digest{ - api: api, - uapi: uapi, - state: newState(), - dsbyte: 0x06, - rate: 104, - outputLen: 48, - }, nil +func New384(api frontend.API, opts ...hash.Option) (hash.BinaryFixedLengthHasher, error) { + return newHash(api, 0x06, 104, 48, opts...) } // New512 creates a new SHA3-512 hash. // Its generic security strength is 512 bits against preimage attacks, // and 256 bits against collision attacks. -func New512(api frontend.API) (hash.BinaryFixedLengthHasher, error) { - uapi, err := uints.New[uints.U64](api) - if err != nil { - return nil, err - } - return &digest{ - api: api, - uapi: uapi, - state: newState(), - dsbyte: 0x06, - rate: 72, - outputLen: 64, - }, nil +func New512(api frontend.API, opts ...hash.Option) (hash.BinaryFixedLengthHasher, error) { + return newHash(api, 0x06, 72, 64, opts...) } // NewLegacyKeccak256 creates a new Keccak-256 hash. // // Only use this function if you require compatibility with an existing cryptosystem // that uses non-standard padding. All other users should use New256 instead. -func NewLegacyKeccak256(api frontend.API) (hash.BinaryFixedLengthHasher, error) { - uapi, err := uints.New[uints.U64](api) - if err != nil { - return nil, err - } - return &digest{ - api: api, - uapi: uapi, - state: newState(), - dsbyte: 0x01, - rate: 136, - outputLen: 32, - }, nil +func NewLegacyKeccak256(api frontend.API, opts ...hash.Option) (hash.BinaryFixedLengthHasher, error) { + return newHash(api, 0x01, 136, 32, opts...) } // NewLegacyKeccak512 creates a new Keccak-512 hash. // // Only use this function if you require compatibility with an existing cryptosystem // that uses non-standard padding. All other users should use New512 instead. -func NewLegacyKeccak512(api frontend.API) (hash.BinaryFixedLengthHasher, error) { - uapi, err := uints.New[uints.U64](api) - if err != nil { - return nil, err - } - return &digest{ - api: api, - uapi: uapi, - state: newState(), - dsbyte: 0x01, - rate: 72, - outputLen: 64, - }, nil +func NewLegacyKeccak512(api frontend.API, opts ...hash.Option) (hash.BinaryFixedLengthHasher, error) { + return newHash(api, 0x01, 72, 64, opts...) } diff --git a/std/hash/sha3/sha3.go b/std/hash/sha3/sha3.go index be5ba4df56..cc6152ffc2 100644 --- a/std/hash/sha3/sha3.go +++ b/std/hash/sha3/sha3.go @@ -10,13 +10,14 @@ import ( ) type digest struct { - api frontend.API - uapi *uints.BinaryField[uints.U64] - state [25]uints.U64 // 1600 bits state: 25 x 64 - in []uints.U8 // input to be digested - dsbyte byte // dsbyte contains the "domain separation" bits and the first bit of the padding - rate int // the number of bytes of state to use - outputLen int // the default output size in bytes + api frontend.API + uapi *uints.BinaryField[uints.U64] + state [25]uints.U64 // 1600 bits state: 25 x 64 + in []uints.U8 // input to be digested + dsbyte byte // dsbyte contains the "domain separation" bits and the first bit of the padding + rate int // the number of bytes of state to use + outputLen int // the default output size in bytes + minimalLength int // lower bound on the length of the input to optimize fixed length hashing } func (d *digest) Write(in []uints.U8) { @@ -39,10 +40,18 @@ func (d *digest) Sum() []uints.U8 { } func (d *digest) FixedLengthSum(length frontend.Variable) []uints.U8 { + comparator := cmp.NewBoundedComparator(d.api, big.NewInt(int64(len(d.in))), false) + // in case the lower bound on the length of input is given, check that the input is long enough + if d.minimalLength > 0 { + comparator.AssertIsLessEq(d.minimalLength, length) + } + padded, numberOfBlocks := d.paddingFixedWidth(length) blocks := d.composeBlocks(padded) + d.absorbingFixedWidth(blocks, numberOfBlocks) + return d.squeezeBlocks() } @@ -67,11 +76,13 @@ func (d *digest) padding() []uints.U8 { func (d *digest) paddingFixedWidth(length frontend.Variable) (padded []uints.U8, numberOfBlocks frontend.Variable) { numberOfBlocks = frontend.Variable(0) - padded = make([]uints.U8, len(d.in)) + maxLen := len(d.in) + padded = make([]uints.U8, maxLen) copy(padded[:], d.in[:]) padded = append(padded, uints.NewU8Array(make([]uint8, d.rate))...) - for i := 0; i <= len(padded)-d.rate; i++ { + // When i < minLen or i > maxLen, it is completely unnecessary + for i := d.minimalLength; i <= maxLen; i++ { reachEnd := cmp.IsEqual(d.api, i, length) switch q := d.rate - ((i) % d.rate); q { case 1: @@ -83,7 +94,7 @@ func (d *digest) paddingFixedWidth(length frontend.Variable) (padded []uints.U8, numberOfBlocks = d.api.Select(reachEnd, (i+2)/d.rate, numberOfBlocks) default: padded[i].Val = d.api.Select(reachEnd, d.dsbyte, padded[i].Val) - for j := 0; j < q-1; j++ { + for j := 0; j < q-2; j++ { padded[i+1+j].Val = d.api.Select(reachEnd, 0, padded[i+1+j].Val) } padded[i+q-1].Val = d.api.Select(reachEnd, 0x80, padded[i+q-1].Val) @@ -119,9 +130,9 @@ func (d *digest) absorbing(blocks [][]uints.U64) { } func (d *digest) absorbingFixedWidth(blocks [][]uints.U64, nbBlocks frontend.Variable) { + minNbOfBlocks := d.minimalLength / d.rate var state [25]uints.U64 var resultState [25]uints.U64 - copy(resultState[:], d.state[:]) copy(state[:], d.state[:]) comparator := cmp.NewBoundedComparator(d.api, big.NewInt(int64(len(blocks))), false) @@ -131,9 +142,18 @@ func (d *digest) absorbingFixedWidth(blocks [][]uints.U64, nbBlocks frontend.Var state[j] = d.uapi.Xor(state[j], block[j]) } state = keccakf.Permute(d.uapi, state) + + // When i < minNbOfBlocks, state cannot be resultState, and proceed to the next loop directly + if i < minNbOfBlocks { + continue + } else if i == minNbOfBlocks { // init resultState + copy(resultState[:], state[:]) + continue + } + isInRange := comparator.IsLess(i, nbBlocks) - // only select blocks that are in range - for j := 0; j < 25; j++ { + // only select blocks that are in range. Only process the first outputLen data relevant to the result + for j := 0; j < d.outputLen/8; j++ { for k := 0; k < 8; k++ { resultState[j][k].Val = d.api.Select(isInRange, state[j][k].Val, resultState[j][k].Val) } diff --git a/std/hash/sha3/sha3_test.go b/std/hash/sha3/sha3_test.go index eeec091338..9e8e7c75c4 100644 --- a/std/hash/sha3/sha3_test.go +++ b/std/hash/sha3/sha3_test.go @@ -16,7 +16,7 @@ import ( ) type testCase struct { - zk func(api frontend.API) (zkhash.BinaryFixedLengthHasher, error) + zk func(api frontend.API, opts ...zkhash.Option) (zkhash.BinaryFixedLengthHasher, error) native func() hash.Hash } @@ -95,6 +95,9 @@ type sha3FixedLengthSumCircuit struct { Expected []uints.U8 Length frontend.Variable hasher string + + // minimal length of the input is the circuit parameter + minimalLength int } func (c *sha3FixedLengthSumCircuit) Define(api frontend.API) error { @@ -102,7 +105,7 @@ func (c *sha3FixedLengthSumCircuit) Define(api frontend.API) error { if !ok { return fmt.Errorf("hash function unknown: %s", c.hasher) } - h, err := newHasher.zk(api) + h, err := newHasher.zk(api, zkhash.WithMinimalLength(c.minimalLength)) if err != nil { return err } @@ -120,8 +123,9 @@ func (c *sha3FixedLengthSumCircuit) Define(api frontend.API) error { } func TestSHA3FixedLengthSum(t *testing.T) { + const maxLen = 310 assert := test.NewAssert(t) - in := make([]byte, 310) + in := make([]byte, maxLen) _, err := rand.Reader.Read(in) assert.NoError(err) @@ -129,29 +133,33 @@ func TestSHA3FixedLengthSum(t *testing.T) { assert.Run(func(assert *test.Assert) { name := name strategy := testCases[name] - for _, length := range []int{0, 1, 31, 32, 33, 135, 136, 137, len(in)} { - assert.Run(func(assert *test.Assert) { - h := strategy.native() - h.Write(in[:length]) - expected := h.Sum(nil) - - circuit := &sha3FixedLengthSumCircuit{ - In: make([]uints.U8, len(in)), - Expected: make([]uints.U8, len(expected)), - Length: 0, - hasher: name, - } - - witness := &sha3FixedLengthSumCircuit{ - In: uints.NewU8Array(in), - Expected: uints.NewU8Array(expected), - Length: length, - } - - if err := test.IsSolved(circuit, witness, ecc.BN254.ScalarField()); err != nil { - t.Fatalf("%s: %s", name, err) - } - }, fmt.Sprintf("length=%d", length)) + nHasher := strategy.native() + for _, lengthBound := range []int{0, 1, nHasher.BlockSize() - 1, nHasher.BlockSize(), nHasher.BlockSize() + 1, len(in)} { + circuit := &sha3FixedLengthSumCircuit{ + In: make([]uints.U8, len(in)), + Expected: make([]uints.U8, nHasher.Size()), + hasher: name, + minimalLength: lengthBound, + } + for _, length := range []int{0, 1, nHasher.BlockSize() - 1, nHasher.BlockSize(), nHasher.BlockSize() + 1, len(in)} { + assert.Run(func(assert *test.Assert) { + h := strategy.native() + h.Write(in[:length]) + expected := h.Sum(nil) + + witness := &sha3FixedLengthSumCircuit{ + In: uints.NewU8Array(in), + Expected: uints.NewU8Array(expected), + Length: length, + } + err := test.IsSolved(circuit, witness, ecc.BN254.ScalarField()) + if length >= lengthBound { + assert.NoError(err) + } else if length < lengthBound { + assert.Error(err, "expected error for length < lengthBound") + } + }, fmt.Sprintf("bound=%d/length=%d", lengthBound, length)) + } } }, fmt.Sprintf("hash=%s", name)) }