diff --git a/frontend/api.go b/frontend/api.go index f12060ff33..4daa79cf33 100644 --- a/frontend/api.go +++ b/frontend/api.go @@ -160,4 +160,7 @@ type BatchInverter interface { type PlonkAPI interface { // EvaluatePlonkExpression returns res = qL.a + qR.b + qM.ab + qC EvaluatePlonkExpression(a, b Variable, qL, qR, qM, qC int) Variable + + // AddPlonkConstraint asserts qL.a + qR.b + qM.ab + qO.o + qC + AddPlonkConstraint(a, b, o Variable, qL, qR, qO, qM, qC int) } diff --git a/frontend/cs/scs/api.go b/frontend/cs/scs/api.go index 18ce8a1359..12ff9be0ce 100644 --- a/frontend/cs/scs/api.go +++ b/frontend/cs/scs/api.go @@ -683,6 +683,37 @@ func (builder *builder) EvaluatePlonkExpression(a, b frontend.Variable, qL, qR, return res } +// AddPlonkConstraint asserts qL.a + qR.b + qO.o + qM.ab + qC = 0 +func (builder *builder) AddPlonkConstraint(a, b, o frontend.Variable, qL, qR, qO, qM, qC int) { + _, aConstant := builder.constantValue(a) + _, bConstant := builder.constantValue(b) + _, oConstant := builder.constantValue(o) + if aConstant || bConstant || oConstant { + builder.AssertIsEqual( + builder.Add( + builder.Mul(a, qL), + builder.Mul(b, qR), + builder.Mul(a, b, qM), + builder.Mul(o, qO), + qC, + ), + 0, + ) + return + } + + builder.addPlonkConstraint(sparseR1C{ + xa: a.(expr.Term).VID, + xb: b.(expr.Term).VID, + xc: o.(expr.Term).VID, + qL: builder.cs.Mul(builder.cs.FromInterface(qL), a.(expr.Term).Coeff), + qR: builder.cs.Mul(builder.cs.FromInterface(qR), b.(expr.Term).Coeff), + qO: builder.cs.Mul(builder.cs.FromInterface(qO), o.(expr.Term).Coeff), + qM: builder.cs.Mul(builder.cs.FromInterface(qM), builder.cs.Mul(a.(expr.Term).Coeff, b.(expr.Term).Coeff)), + qC: builder.cs.FromInterface(qC), + }) +} + func filterConstants(v []frontend.Variable) []frontend.Variable { res := make([]frontend.Variable, 0, len(v)) for _, vI := range v { diff --git a/std/compress/internal/io.go b/std/compress/internal/io.go new file mode 100644 index 0000000000..dc2a37df88 --- /dev/null +++ b/std/compress/internal/io.go @@ -0,0 +1,236 @@ +package internal + +import ( + "errors" + hint "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/compress" + "github.com/consensys/gnark/std/compress/internal/plonk" + "github.com/consensys/gnark/std/lookup/logderivlookup" + "math/big" +) + +// NumReader takes a sequence of words [ b₀ b₁ ... ], along with a base r and length n +// and returns the numbers (b₀ b₁ ... bₙ₋₁)ᵣ, (b₁ b₂ ... bₙ)ᵣ, ... upon successive calls to Next() +type NumReader struct { + api frontend.API + toRead []frontend.Variable + radix int + maxCoeff int + wordsPerNum int + last frontend.Variable +} + +// NewNumReader returns a new NumReader +// toRead is the slice of words to read from +// numNbBits defines the radix as r = 2ⁿᵘᵐᴺᵇᴮⁱᵗˢ (or rather numNbBits = log₂(r) ) +// wordNbBits defines the number of bits in each word such that n = numNbBits/wordNbBits +// it is the caller's responsibility to check 0 ≤ bᵢ < r ∀ i +func NewNumReader(api frontend.API, toRead []frontend.Variable, numNbBits, wordNbBits int) *NumReader { + wordsPerNum := numNbBits / wordNbBits + + if wordsPerNum*wordNbBits != numNbBits { + panic("wordNbBits must be a divisor of 8") + } + + radix := 1 << wordNbBits + return &NumReader{ + api: api, + toRead: toRead, + radix: radix, + maxCoeff: 1 << numNbBits, + wordsPerNum: wordsPerNum, + } +} + +// Next returns the next number in the sequence and advances the reader head by one word. assumes bits past the end of the Slice are 0 +func (nr *NumReader) Next() frontend.Variable { + return nr.next(nil) +} + +// AssertNextEquals is functionally equivalent to +// +// z := nr.Next() +// api.AssertIsEqual(v, z) +// +// while saving exactly one constraint +func (nr *NumReader) AssertNextEquals(v frontend.Variable) { + nr.next(v) +} + +// next returns the next number in the sequence. +// if v != nil, it returns v and asserts it is equal to the next number in the sequence (making a petty saving of one constraint by not creating a new variable) +func (nr *NumReader) next(v frontend.Variable) frontend.Variable { + if len(nr.toRead) == 0 { + return 0 + } + + if nr.last == nil { // the very first call + nr.last = compress.ReadNum(nr.api, nr.toRead[:min(len(nr.toRead), nr.wordsPerNum)], nr.radix) + if v != nil { + nr.api.AssertIsEqual(nr.last, v) + } + return nr.last + } + + // let r := nr.radix, n := log(nr.maxCoeff)ᵣ + // then (b₁ b₂ ... bₙ)ᵣ = r × (b₀ b₁ ... bₙ₋₁)ᵣ - rⁿ × b₀ + bₙ + nr.last = nr.api.Sub(nr.api.Mul(nr.last, nr.radix), nr.api.Mul(nr.toRead[0], nr.maxCoeff)) // r × (b₀ b₁ ... bₙ₋₁)ᵣ - rⁿ × b₀ + if nr.wordsPerNum < len(nr.toRead) { + if v == nil { // return r × (b₀ b₁ ... bₙ₋₁)ᵣ - rⁿ × b₀ + bₙ + nr.last = nr.api.Add(nr.last, nr.toRead[nr.wordsPerNum]) + } else { // assert v = r × (b₀ b₁ ... bₙ₋₁)ᵣ - rⁿ × b₀ + bₙ + plonk.AddConstraint(nr.api, nr.last, nr.toRead[nr.wordsPerNum], v, 1, 1, -1, 0, 0) + nr.last = v + } + } else if v != nil { + panic("todo refactoring required") + } + + nr.toRead = nr.toRead[1:] + return nr.last +} + +// TODO Use std/rangecheck instead +type RangeChecker struct { + api frontend.API + tables map[uint]*logderivlookup.Table +} + +func NewRangeChecker(api frontend.API) *RangeChecker { + return &RangeChecker{api: api, tables: make(map[uint]*logderivlookup.Table)} +} + +func (r *RangeChecker) AssertLessThan(bound uint, c ...frontend.Variable) { + + var check func(frontend.Variable) + switch bound { + case 1: + check = func(v frontend.Variable) { r.api.AssertIsEqual(v, 0) } + case 2: + check = r.api.AssertIsBoolean + case 4: + check = r.api.AssertIsCrumb + default: + cRangeTable, ok := r.tables[bound] + if !ok { + cRangeTable := logderivlookup.New(r.api) + for i := uint(0); i < bound; i++ { + cRangeTable.Insert(0) + } + } + _ = cRangeTable.Lookup(c...) + return + } + for i := range c { + check(c[i]) + } +} + +// IsLessThan returns a variable that is 1 if 0 ≤ c < bound, 0 otherwise +// TODO perf @Tabaie see if we can get away with a weaker contract, where the return value is 0 iff 0 ≤ c < bound +func (r *RangeChecker) IsLessThan(bound uint, c frontend.Variable) frontend.Variable { + switch bound { + case 1: + return r.api.IsZero(c) + } + + if bound%2 != 0 { + panic("odd bounds not yet supported") + } + v := plonk.EvaluateExpression(r.api, c, c, -int(bound-1), 0, 1, 0) // toRead² - (bound-1)× toRead + res := v + for i := uint(1); i < bound/2; i++ { + res = plonk.EvaluateExpression(r.api, res, v, int(i*(bound-i-1)), 0, 1, 0) + } + + return r.api.IsZero(res) +} + +var wordNbBitsToHint = map[int]hint.Hint{1: BreakUpBytesIntoBitsHint, 2: BreakUpBytesIntoCrumbsHint, 4: BreakUpBytesIntoHalfHint} + +// BreakUpBytesIntoWords breaks up bytes into words of size wordNbBits +// It also returns a Slice of bytes which are a reading of the input byte Slice starting from each of the words, thus a super-Slice of the input +// It has the side effect of checking that the input does in fact consist of bytes +// As an example, let the words be bits and the input be the bytes [b₀ b₁ b₂ b₃ b₄ b₅ b₆ b₇], [b₈ b₉ b₁₀ b₁₁ b₁₂ b₁₃ b₁₄ b₁₅] +// Then the output words are b₀, b₁, b₂, b₃, b₄, b₅, b₆, b₇, b₈, b₉, b₁₀, b₁₁, b₁₂, b₁₃, b₁₄, b₁₅ +// The "recombined" output is the slice {[b₀ b₁ b₂ b₃ b₄ b₅ b₆ b₇], [b₁ b₂ b₃ b₄ b₅ b₆ b₇ b₈], ...} +// Note that for any i in range we get recombined[8*i] = bytes[i] +func (r *RangeChecker) BreakUpBytesIntoWords(wordNbBits int, bytes ...frontend.Variable) (words, recombined []frontend.Variable) { + + wordsPerByte := 8 / wordNbBits + if wordsPerByte*wordNbBits != 8 { + panic("wordNbBits must be a divisor of 8") + } + + // solving: break up bytes into words + words = bytes + if wordsPerByte != 1 { + var err error + if words, err = r.api.Compiler().NewHint(wordNbBitsToHint[wordNbBits], wordsPerByte*len(bytes), bytes...); err != nil { + panic(err) + } + } + + // proving: check that words are in range + r.AssertLessThan(1<= 0 { + return errors.New("not a byte") + } + for j := 8/wordNbBits - 1; j >= 0; j-- { + outs[i*8/wordNbBits+j].Mod(&v, wordMod) // todo @tabaie more efficiently + v.Rsh(&v, uint(wordNbBits)) + } + } + + return nil +} + +func BreakUpBytesIntoBitsHint(_ *big.Int, ins, outs []*big.Int) error { + return breakUpBytesIntoWords(1, ins, outs) +} + +func BreakUpBytesIntoCrumbsHint(_ *big.Int, ins, outs []*big.Int) error { + return breakUpBytesIntoWords(2, ins, outs) +} + +func BreakUpBytesIntoHalfHint(_ *big.Int, ins, outs []*big.Int) error { // todo find catchy name for 4 bits + return breakUpBytesIntoWords(4, ins, outs) +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/std/compress/internal/io_test.go b/std/compress/internal/io_test.go new file mode 100644 index 0000000000..b0048a9a3e --- /dev/null +++ b/std/compress/internal/io_test.go @@ -0,0 +1,160 @@ +package internal_test + +import ( + "bytes" + "crypto/rand" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/gnark/std/compress/internal" + "github.com/consensys/gnark/std/compress/lzss" + "github.com/consensys/gnark/std/math/bits" + test_vector_utils "github.com/consensys/gnark/std/utils/test_vectors_utils" + "github.com/consensys/gnark/test" + "github.com/icza/bitio" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestRecombineBytes(t *testing.T) { + // get some random bytes + _bytes := make([]byte, 50000) + _, err := rand.Read(_bytes) + assert.NoError(t, err) + + // turn them into bits + r := bitio.NewReader(bytes.NewReader(_bytes)) + bits := make([]byte, 8*len(_bytes)) + for i := range bits { + if b := r.TryReadBool(); b { + bits[i] = 1 + } + } + + // turn them back into bytes + recombined := make([]byte, len(bits)) + for i := range recombined { + for j := 0; j < 8 && i+j < len(bits); j++ { + recombined[i] += bits[i+j] << (7 - j) + } + } + assert.NoError(t, r.TryError) + + circuit := recombineBytesCircuit{ + Bytes: make([]frontend.Variable, len(_bytes)), + Bits: make([]frontend.Variable, len(bits)), + Recombined: make([]frontend.Variable, len(recombined)), + } + + assignment := recombineBytesCircuit{ + Bytes: test_vector_utils.ToVariableSlice(_bytes), + Bits: test_vector_utils.ToVariableSlice(bits), + Recombined: test_vector_utils.ToVariableSlice(recombined), + } + + lzss.RegisterHints() + test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithBackends(backend.PLONK), test.WithCurves(ecc.BLS12_377)) +} + +type recombineBytesCircuit struct { + Bytes, Bits, Recombined []frontend.Variable +} + +func (c *recombineBytesCircuit) Define(api frontend.API) error { + r := internal.NewRangeChecker(api) + bits, recombined := r.BreakUpBytesIntoWords(1, c.Bytes...) + if len(bits) != len(c.Bits) { + panic("wrong number of bits") + } + for i := range bits { + api.AssertIsEqual(c.Bits[i], bits[i]) + } + if len(recombined) != len(c.Recombined) { + panic("wrong number of bytes") + } + for i := range recombined { + api.AssertIsEqual(c.Recombined[i], recombined[i]) + } + return nil +} + +func TestRangeChecker_IsLessThan(t *testing.T) { + ins := []frontend.Variable{-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + outs := []frontend.Variable{0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0} + circuit := rangeCheckerCircuit{ + Ins: make([]frontend.Variable, len(ins)), + Outs: make([]frontend.Variable, len(outs)), + Range: 8, + } + assignment := rangeCheckerCircuit{ + Ins: ins, + Outs: outs, + } + test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithBackends(backend.GROTH16, backend.PLONK), test.WithCurves(ecc.BLS12_377)) +} + +type rangeCheckerCircuit struct { + Ins, Outs []frontend.Variable + Range uint +} + +func (c *rangeCheckerCircuit) Define(api frontend.API) error { + if len(c.Ins) != len(c.Outs) { + panic("witness length mismatch") + } + r := internal.NewRangeChecker(api) + + for i := range c.Ins { + lt := r.IsLessThan(c.Range, c.Ins[i]) + api.AssertIsEqual(c.Outs[i], lt) + } + + return nil +} + +func TestBreakUpBytesIntoWordsGains(t *testing.T) { + customCircuit := breakUpBytesIntoWordsCustomCircuit{make([]frontend.Variable, 128*1024)} + stdCircuit := breakUpBytesIntoWordsStdCircuit{make([]frontend.Variable, 128*1024)} + + csCustom, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &customCircuit) + assert.NoError(t, err) + + csStd, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &stdCircuit) + assert.NoError(t, err) + + customNbConstraints := csCustom.GetNbConstraints() + stdNbConstraints := csStd.GetNbConstraints() + + assert.Greater(t, stdNbConstraints-customNbConstraints, 1000000, "custom circuit must save at least 1M constraints") + assert.LessOrEqual(t, 100*customNbConstraints/stdNbConstraints, 75, "custom circuit should achieve at least a 25%% reduction in constraints") +} + +type breakUpBytesIntoWordsCircuit struct { + Bytes []frontend.Variable +} + +type breakUpBytesIntoWordsStdCircuit breakUpBytesIntoWordsCircuit +type breakUpBytesIntoWordsCustomCircuit breakUpBytesIntoWordsCircuit + +func (c *breakUpBytesIntoWordsStdCircuit) Define(api frontend.API) error { + words := make([]frontend.Variable, 0, len(c.Bytes)*8) + for _, _byte := range c.Bytes { + words = append(words, + bits.ToBase(api, bits.Binary, _byte, bits.WithNbDigits(8), bits.WithUnconstrainedInputs(), bits.OmitModulusCheck())..., + ) + } + + _bytes := make([]frontend.Variable, len(words)) + r := internal.NewNumReader(api, words, 8, 1) + for i := range words { + _bytes[i] = r.Next() + } + return nil +} + +func (c *breakUpBytesIntoWordsCustomCircuit) Define(api frontend.API) error { + r := internal.NewRangeChecker(api) + _, _ = r.BreakUpBytesIntoWords(1, c.Bytes...) + return nil +} diff --git a/std/compress/internal/plonk/plonk.go b/std/compress/internal/plonk/plonk.go new file mode 100644 index 0000000000..77d8e0a29a --- /dev/null +++ b/std/compress/internal/plonk/plonk.go @@ -0,0 +1,27 @@ +package plonk + +import "github.com/consensys/gnark/frontend" + +func EvaluateExpression(api frontend.API, a, b frontend.Variable, aCoeff, bCoeff, mCoeff, constant int) frontend.Variable { + if plonkAPI, ok := api.(frontend.PlonkAPI); ok { + return plonkAPI.EvaluatePlonkExpression(a, b, aCoeff, bCoeff, mCoeff, constant) + } + return api.Add(api.Mul(a, aCoeff), api.Mul(b, bCoeff), api.Mul(mCoeff, a, b), constant) +} + +func AddConstraint(api frontend.API, a, b, o frontend.Variable, qL, qR, qO, qM, qC int) { + if papi, ok := api.(frontend.PlonkAPI); ok { + papi.AddPlonkConstraint(a, b, o, qL, qR, qO, qM, qC) + } else { + api.AssertIsEqual( + api.Add( + api.Mul(a, qL), + api.Mul(b, qR), + api.Mul(a, b, qM), + api.Mul(o, qO), + qC, + ), + 0, + ) + } +} diff --git a/std/compress/internal/plonk/plonk_test.go b/std/compress/internal/plonk/plonk_test.go new file mode 100644 index 0000000000..1023d1ee2a --- /dev/null +++ b/std/compress/internal/plonk/plonk_test.go @@ -0,0 +1,132 @@ +package plonk + +import ( + "crypto/rand" + "encoding/binary" + "errors" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/frontend" + test_vector_utils "github.com/consensys/gnark/std/utils/test_vectors_utils" + "github.com/consensys/gnark/test" + "reflect" + "testing" +) + +func TestCustomConstraint(t *testing.T) { + const nbCases = 1000 + + // only testing cases with qO = -1 + + circuit := customConstraintCircuit{ + A: make([]frontend.Variable, nbCases), + B: make([]frontend.Variable, nbCases), + O: make([]frontend.Variable, nbCases), + mode: make([]int, nbCases), + aVal: make(fr.Vector, nbCases), + bVal: make(fr.Vector, nbCases), + oVal: make(fr.Vector, nbCases), + qC: make([]int, nbCases), + qL: make([]int, nbCases), + qR: make([]int, nbCases), + qM: make([]int, nbCases), + } + + assignment := customConstraintCircuit{ + A: make([]frontend.Variable, nbCases), + B: make([]frontend.Variable, nbCases), + O: make([]frontend.Variable, nbCases), + } + + randomizeInts(circuit.qC, circuit.qL, circuit.qR, circuit.qM) + randomizeElems(circuit.aVal, circuit.bVal) + + var sum, summand fr.Element + for i := range circuit.A { + circuit.mode[i] = i % 8 + + sum.SetInt64(int64(circuit.qC[i])) + + summand.SetInt64(int64(circuit.qL[i])) + summand.Mul(&summand, &circuit.aVal[i]) + sum.Add(&sum, &summand) + + summand.SetInt64(int64(circuit.qR[i])) + summand.Mul(&summand, &circuit.bVal[i]) + sum.Add(&sum, &summand) + + summand.SetInt64(int64(circuit.qM[i])) + summand.Mul(&summand, &circuit.aVal[i]).Mul(&summand, &circuit.bVal[i]) + sum.Add(&sum, &summand) + + assignment.O[i] = sum + circuit.oVal[i] = sum + } + + assignment.A = test_vector_utils.ToVariableSlice(circuit.aVal) + assignment.B = test_vector_utils.ToVariableSlice(circuit.bVal) + + test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithBackends(backend.PLONK), test.WithCurves(ecc.BLS12_377)) +} + +func randomizeInts(slices ...[]int) { + var buff [8]byte + for _, slice := range slices { + for i := range slice { + if _, err := rand.Read(buff[:]); err != nil { + panic(err) + } + neg := 1 - 2*int(buff[0]>>7) + buff[0] &= 127 + slice[i] = int(binary.BigEndian.Uint64(buff[:])) * neg + } + } +} + +func randomizeElems(vectors ...fr.Vector) { + for _, vector := range vectors { + for i := range vector { + if _, err := vector[i].SetRandom(); err != nil { + panic(err) + } + } + } +} + +type customConstraintCircuit struct { + A, B []frontend.Variable + O []frontend.Variable + aVal, bVal, oVal fr.Vector + mode, qC, qL, qR, qM []int +} + +func ifConstThenElse(api frontend.API, isConst int, val fr.Element, _var frontend.Variable) frontend.Variable { + api.AssertIsEqual(val, _var) + + if isConst != 0 { + return val + } + + return _var +} + +func (c *customConstraintCircuit) Define(api frontend.API) error { + slices := []interface{}{c.B, c.O, c.mode, c.aVal, c.bVal, c.oVal, c.qC, c.qL, c.qR, c.qM} + for _, slice := range slices { + if reflect.ValueOf(slice).Len() != len(c.A) { + return errors.New("inconsistent lengths") + } + } + + for i := range c.A { + a, b, o := ifConstThenElse(api, c.mode[i]&1, c.aVal[i], c.A[i]), ifConstThenElse(api, c.mode[i]&2, c.bVal[i], c.B[i]), ifConstThenElse(api, c.mode[i]&4, c.oVal[i], c.O[i]) + + _o := EvaluateExpression(api, a, b, c.qL[i], c.qR[i], c.qM[i], c.qC[i]) + api.AssertIsEqual(_o, o) + + AddConstraint(api, a, b, o, c.qL[i], c.qR[i], -1, c.qM[i], c.qC[i]) + } + + return nil +} diff --git a/std/compress/io.go b/std/compress/io.go new file mode 100644 index 0000000000..ff3d54710b --- /dev/null +++ b/std/compress/io.go @@ -0,0 +1,145 @@ +package compress + +import ( + "errors" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/compress/internal/plonk" + "github.com/consensys/gnark/std/hash/mimc" + "github.com/consensys/gnark/std/lookup/logderivlookup" + "hash" + "math/big" +) + +// Pack packs the words as tightly as possible, and works Big Endian: i.e. the first word is the most significant in the packed elem +// it is on the caller to make sure the words are within range +func Pack(api frontend.API, words []frontend.Variable, bitsPerWord int) []frontend.Variable { + return PackN(api, words, bitsPerWord, (api.Compiler().FieldBitLen()-1)/bitsPerWord) +} + +// PackN packs the words wordsPerElem at a time into field elements, and works Big Endian: i.e. the first word is the most significant in the packed elem +// it is on the caller to make sure the words are within range +func PackN(api frontend.API, words []frontend.Variable, bitsPerWord, wordsPerElem int) []frontend.Variable { + res := make([]frontend.Variable, (len(words)+wordsPerElem-1)/wordsPerElem) + + r := make([]big.Int, wordsPerElem) + r[wordsPerElem-1].SetInt64(1) + for i := wordsPerElem - 2; i >= 0; i-- { + r[i].Lsh(&r[i+1], uint(bitsPerWord)) + } + + for elemI := range res { + res[elemI] = 0 + for wordI := 0; wordI < wordsPerElem; wordI++ { + absWordI := elemI*wordsPerElem + wordI + if absWordI >= len(words) { + break + } + res[elemI] = api.Add(res[elemI], api.Mul(words[absWordI], r[wordI])) + } + } + return res +} + +// AssertChecksumEquals takes a MiMC hash of e and asserts it is equal to checksum +func AssertChecksumEquals(api frontend.API, e []frontend.Variable, checksum frontend.Variable) error { + hsh, err := mimc.NewMiMC(api) + if err != nil { + return err + } + hsh.Write(e...) + api.AssertIsEqual(hsh.Sum(), checksum) + return nil +} + +// ChecksumPaddedBytes packs b into field elements, then hashes the field elements along with validLength (encoded into a field element of its own) +func ChecksumPaddedBytes(b []byte, validLength int, hsh hash.Hash, fieldNbBits int) []byte { + if validLength < 0 || validLength > len(b) { + panic("invalid length") + } + usableBytesPerElem := (fieldNbBits+7)/8 - 1 + buf := make([]byte, usableBytesPerElem+1) + for i := 0; i < len(b); i += usableBytesPerElem { + copy(buf[1:], b[i:]) + for j := usableBytesPerElem; j+i > len(b) && j > 0; j-- { + buf[j] = 0 + } + hsh.Write(buf) + } + big.NewInt(int64(validLength)).FillBytes(buf) + hsh.Write(buf) + + return hsh.Sum(nil) +} + +// UnpackIntoBytes construes every element in packed as consisting of bytesPerElem bytes, returning those bytes +// it DOES NOT prove that the elements in unpacked are actually bytes +// nbBytes is the number of "valid" bytes according to the padding scheme in https://github.com/Consensys/zkevm-monorepo/blob/main/prover/lib/compressor/blob/blob_maker.go#L299 +// TODO @tabaie @gbotrel move the padding/packing code to gnark or compress +// the very last non-zero byte in the unpacked stream is meant to encode the number of unused bytes in the last field element used. +// though UnpackIntoBytes includes that last byte in unpacked, it is not counted in nbBytes +func UnpackIntoBytes(api frontend.API, bytesPerElem int, packed []frontend.Variable) (unpacked []frontend.Variable, nbBytes frontend.Variable, err error) { + if unpacked, err = api.Compiler().NewHint(UnpackIntoBytesHint, bytesPerElem*len(packed), packed...); err != nil { + return + } + found := frontend.Variable(0) + nbBytes = frontend.Variable(0) + for i := len(unpacked) - 1; i >= 0; i-- { + + z := api.IsZero(unpacked[i]) + + lastNonZero := plonk.EvaluateExpression(api, z, found, -1, -1, 1, 1) // nz - found + nbBytes = api.Add(nbBytes, api.Mul(lastNonZero, frontend.Variable(i))) // the last nonzero byte itself is useless + + //api.AssertIsEqual(api.Mul(api.Sub(bytesPerElem-i%bytesPerElem, unpacked[i]), lastNonZero), 0) // sanity check, technically unnecessary TODO @Tabaie make sure it's one constraint only or better yet, remove + + found = plonk.EvaluateExpression(api, z, found, -1, 0, 1, 1) // found ? 1 : nz = nz + found (1 - nz) = 1 - z + found z + } + return +} + +func UnpackIntoBytesHint(_ *big.Int, ins, outs []*big.Int) error { + bytesPerElem := len(outs) / len(ins) + if len(ins)*bytesPerElem != len(outs) { + return errors.New("in length must divide out length") + } + _256 := big.NewInt(256) + var v big.Int + for i := range ins { + v.Set(ins[i]) + for j := bytesPerElem - 1; j >= 0; j-- { + v.DivMod(&v, _256, outs[i*bytesPerElem+j]) + } + } + return nil +} + +// ReadNum reads the slice c as a big endian number in base radix +func ReadNum(api frontend.API, c []frontend.Variable, radix int) frontend.Variable { + if len(c) == 0 { + return 0 + } + + res := c[0] + for i := 1; i < len(c); i++ { + res = api.Add(c[i], api.Mul(res, radix)) + } + + return res +} + +// ShiftLeft erases shiftAmount many elements from the left of Slice and replaces them in the right with zeros +// it is the caller's responsibility to make sure that 0 \le shift < len(c) +func ShiftLeft(api frontend.API, slice []frontend.Variable, shiftAmount frontend.Variable) []frontend.Variable { + res := make([]frontend.Variable, len(slice)) + l := logderivlookup.New(api) + for i := range slice { + l.Insert(slice[i]) + } + for range slice { + l.Insert(0) + } + for i := range slice { + res[i] = l.Lookup(api.Add(i, shiftAmount))[0] + } + return res +} diff --git a/std/compress/io_test.go b/std/compress/io_test.go new file mode 100644 index 0000000000..d69813b063 --- /dev/null +++ b/std/compress/io_test.go @@ -0,0 +1,119 @@ +package compress + +import ( + "crypto/rand" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/hash" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/gnark/profile" + test_vector_utils "github.com/consensys/gnark/std/utils/test_vectors_utils" + "github.com/consensys/gnark/test" + "github.com/stretchr/testify/assert" + "math/big" + "testing" +) + +func TestShiftLeft(t *testing.T) { + for n := 4; n < 20; n++ { + b := make([]byte, n) + _, err := rand.Read(b) + assert.NoError(t, err) + + shiftAmount, err := rand.Int(rand.Reader, big.NewInt(int64(n))) + assert.NoError(t, err) + + shifted := make([]byte, n) + for i := range shifted { + if j := i + int(shiftAmount.Int64()); j < len(shifted) { + shifted[i] = b[j] + } else { + shifted[i] = 0 + } + } + + circuit := shiftLeftCircuit{ + Slice: make([]frontend.Variable, len(b)), + Shifted: make([]frontend.Variable, len(shifted)), + } + + assignment := shiftLeftCircuit{ + Slice: test_vector_utils.ToVariableSlice(b), + Shifted: test_vector_utils.ToVariableSlice(shifted), + ShiftAmount: shiftAmount, + } + + test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithBackends(backend.PLONK), test.WithCurves(ecc.BLS12_377)) + } +} + +func BenchmarkShiftLeft(b *testing.B) { + const n = 128 * 1024 + + circuit := shiftLeftCircuit{ + Slice: make([]frontend.Variable, n), + Shifted: make([]frontend.Variable, n), + } + + p := profile.Start() + cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit) + assert.NoError(b, err) + p.Stop() + fmt.Println(cs.GetNbConstraints(), "constraints") +} + +type shiftLeftCircuit struct { + Slice []frontend.Variable + Shifted []frontend.Variable + ShiftAmount frontend.Variable +} + +func (c *shiftLeftCircuit) Define(api frontend.API) error { + if len(c.Slice) != len(c.Shifted) { + panic("witness length mismatch") + } + shifted := ShiftLeft(api, c.Slice, c.ShiftAmount) + if len(shifted) != len(c.Shifted) { + panic("wrong length") + } + for i := range shifted { + api.AssertIsEqual(c.Shifted[i], shifted[i]) + } + return nil +} + +func TestChecksumBytes(t *testing.T) { + + for n := 1; n < 100; n++ { + b := make([]byte, n) + _, err := rand.Read(b) + assert.NoError(t, err) + + checksum := ChecksumPaddedBytes(b, len(b), hash.MIMC_BLS12_377.New(), fr.Bits) + + circuit := checksumTestCircuit{ + Bytes: make([]frontend.Variable, len(b)), + } + + assignment := checksumTestCircuit{ + Bytes: test_vector_utils.ToVariableSlice(b), + Sum: checksum, + } + + test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithBackends(backend.PLONK), test.WithCurves(ecc.BLS12_377)) + + } +} + +type checksumTestCircuit struct { + Bytes []frontend.Variable + Sum frontend.Variable +} + +func (c *checksumTestCircuit) Define(api frontend.API) error { + Packed := append(Pack(api, c.Bytes, 8), len(c.Bytes)) + return AssertChecksumEquals(api, Packed, c.Sum) +} diff --git a/std/compress/lzss/e2e_test.go b/std/compress/lzss/e2e_test.go deleted file mode 100644 index 4ad2fde7db..0000000000 --- a/std/compress/lzss/e2e_test.go +++ /dev/null @@ -1,116 +0,0 @@ -package lzss - -import ( - goCompress "github.com/consensys/compress" - "github.com/consensys/compress/lzss" - "os" - "testing" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/frontend" - test_vector_utils "github.com/consensys/gnark/std/utils/test_vectors_utils" - "github.com/consensys/gnark/test" - "github.com/stretchr/testify/assert" -) - -func TestCompression1ZeroE2E(t *testing.T) { - testCompressionE2E(t, []byte{0}, nil, "1_zero") -} - -func BenchmarkCompression26KBE2E(b *testing.B) { - _, err := BenchCompressionE2ECompilation(nil, "./testdata/3c2943") - assert.NoError(b, err) -} - -func testCompressionE2E(t *testing.T, d, dict []byte, name string) { - if d == nil { - var err error - d, err = os.ReadFile("./testdata/" + name + "/data.bin") - assert.NoError(t, err) - } - - // compress - - level := lzss.GoodCompression - compressor, err := lzss.NewCompressor(dict, level) - assert.NoError(t, err) - - c, err := compressor.Compress(d) - assert.NoError(t, err) - - cStream, err := goCompress.NewStream(c, uint8(level)) - assert.NoError(t, err) - - cSum, err := check(cStream, cStream.Len()) - assert.NoError(t, err) - - dStream, err := goCompress.NewStream(d, 8) - assert.NoError(t, err) - - dSum, err := check(dStream, len(d)) - assert.NoError(t, err) - - dict = lzss.AugmentDict(dict) - - dictStream, err := goCompress.NewStream(dict, 8) - assert.NoError(t, err) - - dictSum, err := check(dictStream, len(dict)) - assert.NoError(t, err) - - circuit := TestCompressionCircuit{ - C: make([]frontend.Variable, cStream.Len()), - D: make([]frontend.Variable, len(d)), - Dict: make([]frontend.Variable, len(dict)), - Level: level, - } - - // solve the circuit or only compile it - - assignment := TestCompressionCircuit{ - CChecksum: cSum, - DChecksum: dSum, - DictChecksum: dictSum, - C: test_vector_utils.ToVariableSlice(cStream.D), - D: test_vector_utils.ToVariableSlice(d), - Dict: test_vector_utils.ToVariableSlice(dict), - CLen: cStream.Len(), - DLen: len(d), - } - test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithBackends(backend.PLONK), test.WithCurves(ecc.BLS12_377)) -} - -func TestChecksum0(t *testing.T) { - testChecksum(t, goCompress.Stream{D: []int{}, NbSymbs: 256}) -} - -func testChecksum(t *testing.T, d goCompress.Stream) { - circuit := checksumTestCircuit{ - Inputs: make([]frontend.Variable, d.Len()), - InputLen: d.Len(), - } - - sum, err := check(d, d.Len()) - assert.NoError(t, err) - - assignment := checksumTestCircuit{ - Inputs: test_vector_utils.ToVariableSlice(d.D), - InputLen: d.Len(), - Sum: sum, - } - test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithBackends(backend.PLONK), test.WithCurves(ecc.BLS12_377)) -} - -type checksumTestCircuit struct { - Inputs []frontend.Variable - InputLen frontend.Variable - Sum frontend.Variable -} - -func (c *checksumTestCircuit) Define(api frontend.API) error { - if err := checkSnark(api, c.Inputs, len(c.Inputs), c.Sum); err != nil { - return err - } - return nil -} diff --git a/std/compress/lzss/snark.go b/std/compress/lzss/snark.go index 1143cfdf35..a6945772dc 100644 --- a/std/compress/lzss/snark.go +++ b/std/compress/lzss/snark.go @@ -2,48 +2,59 @@ package lzss import ( "github.com/consensys/compress/lzss" + hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/compress" + "github.com/consensys/gnark/std/compress/internal" + "github.com/consensys/gnark/std/compress/internal/plonk" "github.com/consensys/gnark/std/lookup/logderivlookup" ) // Decompress decompresses c into d using dict as the dictionary // which must come pre "augmented" +// it is on the caller to ensure that the dictionary is correct; in particular it must consist of bytes. Decompress does not check this. +// it is recommended to pack the dictionary using compress.Pack and take a MiMC checksum of it. +// d will consist of bytes // It returns the length of d as a frontend.Variable func Decompress(api frontend.API, c []frontend.Variable, cLength frontend.Variable, d, dict []frontend.Variable, level lzss.Level) (dLength frontend.Variable, err error) { + // size-related "constants" wordNbBits := int(level) - - // ensure input is in range - checkInputRange(api, c, wordNbBits) - - // init the dictionary and backref types - shortBackRefType, longBackRefType, dictBackRefType := lzss.InitBackRefTypes(len(dict), level) - + shortBackRefType, longBackRefType, dictBackRefType := lzss.InitBackRefTypes(len(dict), level) // init the dictionary and backref types; only needed for the constants below shortBrNbWords := int(shortBackRefType.NbBitsBackRef) / wordNbBits longBrNbWords := int(longBackRefType.NbBitsBackRef) / wordNbBits dictBrNbWords := int(dictBackRefType.NbBitsBackRef) / wordNbBits - byteNbWords := 8 / wordNbBits - - const sizeHeader = 3 // TODO @tabaie @gbotrel Handle this outside the circuit instead - - api.AssertIsEqual(compress.ReadNum(api, c, (sizeHeader-1)*byteNbWords, wordNbBits), 0) // compressor version TODO @tabaie @gbotrel Handle this outside the circuit instead? - fileCompressionMode := compress.ReadNum(api, c[(sizeHeader-1)*byteNbWords:], byteNbWords, wordNbBits) + byteNbWords := uint(8 / wordNbBits) + + // check header: version and compression level + const ( + sizeHeader = 3 + version = 0 + ) + api.AssertIsEqual(c[0], version/256) + api.AssertIsEqual(c[1], version%256) + fileCompressionMode := c[2] api.AssertIsEqual(api.Mul(fileCompressionMode, fileCompressionMode), api.Mul(fileCompressionMode, wordNbBits)) // if fcm!=0, then fcm=wordNbBits decompressionNotBypassed := api.Sub(1, api.IsZero(fileCompressionMode)) - c = c[sizeHeader*byteNbWords:] - cLength = api.Sub(cLength, sizeHeader*byteNbWords) + // check that the input is in range and convert into small words + rangeChecker := internal.NewRangeChecker(api) + bytes := make([]frontend.Variable, len(c)-sizeHeader+1) + copy(bytes, c[sizeHeader:]) + bytes[len(bytes)-1] = 0 // pad with a zero to avoid out of range errors + c, bytes = rangeChecker.BreakUpBytesIntoWords(wordNbBits, bytes...) // from this point on c is in words + cLength = api.Mul(api.Sub(cLength, sizeHeader), 8/wordNbBits) // one constraint; insignificant impact anyway + + // create a random-access table to be referenced outTable := logderivlookup.New(api) for i := range dict { outTable.Insert(dict[i]) } // formatted input - bytes := combineIntoBytes(api, c, wordNbBits) bytesTable := sliceToTable(api, bytes) - bytesTable.Insert(0) // just because we use this table for looking up backref lengths as well + addrTable := initAddrTable(api, bytes, c, wordNbBits, []lzss.BackrefType{shortBackRefType, longBackRefType, dictBackRefType}) // state variables @@ -73,7 +84,7 @@ func Decompress(api frontend.API, c []frontend.Variable, cLength frontend.Variab // copying = copyLen01 ? copyLen==1 : 1 either from previous iterations or starting a new copy // copying = copyLen01 ? copyLen : 1 - copying := evaluatePlonkExpression(api, copyLen01, copyLen, -1, 0, 1, 1) + copying := plonk.EvaluateExpression(api, copyLen01, copyLen, -1, 0, 1, 1) copyAddr := api.Mul(api.Sub(outI+len(dict)-1, currIndicatedCpAddr), currIndicatesBr) dictCopyAddr := api.Add(currIndicatedCpAddr, api.Sub(currIndicatedCpLen, copyLen)) @@ -81,7 +92,9 @@ func Decompress(api frontend.API, c []frontend.Variable, cLength frontend.Variab toCopy := outTable.Lookup(copyAddr)[0] // write to output - d[outI] = api.Select(copying, toCopy, curr) + outVal := api.Select(copying, toCopy, curr) + // TODO previously the last byte of the output kept getting repeated. That can be worked with. If there was a reason to save some 600K constraints in the zkEVM decompressor, take this out again + d[outI] = plonk.EvaluateExpression(api, outVal, eof, 1, 0, -1, 0) // write zeros past eof // WARNING: curr modified by MulAcc outTable.Insert(d[outI]) @@ -96,10 +109,10 @@ func Decompress(api frontend.API, c []frontend.Variable, cLength frontend.Variab if eof == 0 { inI = api.Add(inI, inIDelta) } else { - inI = api.Add(inI, evaluatePlonkExpression(api, inIDelta, eof, 1, 0, -1, 0)) // if eof, stay put + inI = api.Add(inI, plonk.EvaluateExpression(api, inIDelta, eof, 1, 0, -1, 0)) // if eof, stay put } - eofNow := api.IsZero(api.Sub(inI, cLength)) + eofNow := rangeChecker.IsLessThan(byteNbWords, api.Sub(cLength, inI)) // less than a byte left; meaning we are at the end of the input dLength = api.Add(dLength, api.Mul(api.Sub(eofNow, eof), outI+1)) // if eof, don't advance dLength eof = eofNow @@ -108,29 +121,6 @@ func Decompress(api frontend.API, c []frontend.Variable, cLength frontend.Variab return dLength, nil } -func checkInputRange(api frontend.API, c []frontend.Variable, wordNbBits int) { - if wordNbBits > 2 { - cRangeTable := logderivlookup.New(api) - for i := 0; i < 1< lzss.MaxInputSize { t.Skip("input too large") @@ -84,79 +115,63 @@ func FuzzSnark(f *testing.F) { // TODO This is always skipped if len(input) == 0 { t.Skip("input too small") } - testCompressionRoundTripSnark(t, input, dict) + testCompressionRoundTrip(t, input, dict) }) } -type testCompressionRoundTripOption func(*lzss.Level) +type testCompressionRoundTripSettings struct { + level lzss.Level + cBegin int +} + +type testCompressionRoundTripOption func(settings *testCompressionRoundTripSettings) func withLevel(level lzss.Level) testCompressionRoundTripOption { - return func(l *lzss.Level) { - *l = level + return func(s *testCompressionRoundTripSettings) { + s.level = level } } -func testCompressionRoundTripSnark(t *testing.T, d, dict []byte, options ...testCompressionRoundTripOption) { +func withCBegin(cBegin int) testCompressionRoundTripOption { + return func(s *testCompressionRoundTripSettings) { + s.cBegin = cBegin + } +} - level := lzss.BestCompression +func testCompressionRoundTrip(t *testing.T, d, dict []byte, options ...testCompressionRoundTripOption) { + + settings := testCompressionRoundTripSettings{ + level: lzss.BestCompression, + } for _, option := range options { - option(&level) + option(&settings) } - compressor, err := lzss.NewCompressor(dict, level) + compressor, err := lzss.NewCompressor(dict, settings.level) require.NoError(t, err) c, err := compressor.Compress(d) require.NoError(t, err) - cStream, err := lzss.ReadIntoStream(c, dict, level) - require.NoError(t, err) + //assert.NoError(t, os.WriteFile("compress.csv", lzss.CompressedStreamInfo(c, dict).ToCsv(), 0644)) circuit := &DecompressionTestCircuit{ - C: make([]frontend.Variable, cStream.Len()), + C: make([]frontend.Variable, len(c)+inputExtraBytes), D: d, Dict: dict, CheckCorrectness: true, - Level: level, + Level: settings.level, } assignment := &DecompressionTestCircuit{ - C: test_vector_utils.ToVariableSlice(cStream.D), - CLength: cStream.Len(), + C: test_vector_utils.ToVariableSlice(append(c, make([]byte, inputExtraBytes)...)), + CBegin: settings.cBegin, + CLength: len(c), } + RegisterHints() test.NewAssert(t).CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithBackends(backend.PLONK), test.WithCurves(ecc.BLS12_377)) } -func TestReadBytes(t *testing.T) { - expected := []byte{254, 0, 0, 0} - circuit := &readBytesCircuit{ - Words: make([]frontend.Variable, 8*len(expected)), - WordNbBits: 1, - Expected: expected, - } - words, err := goCompress.NewStream(expected, 8) - assert.NoError(t, err) - words = words.BreakUp(2) - assignment := &readBytesCircuit{ - Words: test_vector_utils.ToVariableSlice(words.D), - } - test.NewAssert(t).CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithBackends(backend.PLONK), test.WithCurves(ecc.BLS12_377)) -} - -type readBytesCircuit struct { - Words []frontend.Variable - WordNbBits int - Expected []byte -} - -func (c *readBytesCircuit) Define(api frontend.API) error { - byts := combineIntoBytes(api, c.Words, c.WordNbBits) - for i := range c.Expected { - api.AssertIsEqual(c.Expected[i], byts[i*8]) - } - return nil -} - func getDictionary() []byte { d, err := os.ReadFile("./testdata/dict_naive") if err != nil { diff --git a/std/compress/lzss/snark_testing.go b/std/compress/lzss/snark_testing.go index 179742d915..1d1d20b7c6 100644 --- a/std/compress/lzss/snark_testing.go +++ b/std/compress/lzss/snark_testing.go @@ -1,28 +1,17 @@ package lzss import ( - "compress/gzip" - "fmt" - goCompress "github.com/consensys/compress" "github.com/consensys/compress/lzss" - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/scs" - "github.com/consensys/gnark/profile" "github.com/consensys/gnark/std/compress" - "github.com/consensys/gnark/std/hash/mimc" test_vector_utils "github.com/consensys/gnark/std/utils/test_vectors_utils" - "os" - "time" ) type DecompressionTestCircuit struct { C []frontend.Variable D []byte Dict []byte + CBegin frontend.Variable CLength frontend.Variable CheckCorrectness bool Level lzss.Level @@ -31,6 +20,9 @@ type DecompressionTestCircuit struct { func (c *DecompressionTestCircuit) Define(api frontend.API) error { dict := test_vector_utils.ToVariableSlice(lzss.AugmentDict(c.Dict)) dBack := make([]frontend.Variable, len(c.D)) // TODO Try smaller constants + if cb, ok := c.CBegin.(int); !ok || cb != 0 { + c.C = compress.ShiftLeft(api, c.C, c.CBegin) + } dLen, err := Decompress(api, c.C, c.CLength, dBack, dict, c.Level) if err != nil { return err @@ -43,140 +35,3 @@ func (c *DecompressionTestCircuit) Define(api frontend.API) error { } return nil } - -func BenchCompressionE2ECompilation(dict []byte, name string) (constraint.ConstraintSystem, error) { - d, err := os.ReadFile(name + "/data.bin") - if err != nil { - return nil, err - } - - // compress - - level := lzss.GoodCompression - - compressor, err := lzss.NewCompressor(dict, level) - if err != nil { - return nil, err - } - - c, err := compressor.Compress(d) - if err != nil { - return nil, err - } - - cStream, err := goCompress.NewStream(c, uint8(level)) - if err != nil { - return nil, err - } - - circuit := TestCompressionCircuit{ - C: make([]frontend.Variable, cStream.Len()), - D: make([]frontend.Variable, len(d)), - Dict: make([]frontend.Variable, len(lzss.AugmentDict(dict))), - Level: level, - } - - var start int64 - resetTimer := func() { - end := time.Now().UnixMilli() - if start != 0 { - fmt.Println(end-start, "ms") - } - start = end - } - - // compilation - fmt.Println("compilation") - p := profile.Start() - resetTimer() - cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit, frontend.WithCapacity(70620000*2)) - if err != nil { - return nil, err - } - p.Stop() - fmt.Println(1+len(d)/1024, "KB:", p.NbConstraints(), "constraints, estimated", (p.NbConstraints()*600000)/len(d), "constraints for 600KB at", float64(p.NbConstraints())/float64(len(d)), "constraints per uncompressed byte") - resetTimer() - - outFile, err := os.OpenFile("./testdata/test_cases/"+name+"/e2e_cs.gz", os.O_CREATE, 0600) - closeFile := func() { - if err := outFile.Close(); err != nil { - panic(err) - } - } - defer closeFile() - if err != nil { - return nil, err - } - gz := gzip.NewWriter(outFile) - closeZip := func() { - if err := gz.Close(); err != nil { - panic(err) - } - } - defer closeZip() - if _, err = cs.WriteTo(gz); err != nil { - return nil, err - } - return cs, gz.Close() -} - -type TestCompressionCircuit struct { - CChecksum, DChecksum, DictChecksum frontend.Variable `gnark:",public"` - C []frontend.Variable - D []frontend.Variable - Dict []frontend.Variable - CLen, DLen frontend.Variable - Level lzss.Level -} - -func (c *TestCompressionCircuit) Define(api frontend.API) error { - - fmt.Println("packing") - cPacked := compress.Pack(api, c.C, int(c.Level)) - dPacked := compress.Pack(api, c.D, 8) - dictPacked := compress.Pack(api, c.Dict, 8) - - fmt.Println("computing checksum") - if err := checkSnark(api, cPacked, c.CLen, c.CChecksum); err != nil { - return err - } - if err := checkSnark(api, dPacked, c.DLen, c.DChecksum); err != nil { - return err - } - if err := checkSnark(api, dictPacked, len(c.Dict), c.DictChecksum); err != nil { - return err - } - - fmt.Println("decompressing") - dComputed := make([]frontend.Variable, len(c.D)) - if dComputedLen, err := Decompress(api, c.C, c.CLen, dComputed, c.Dict, c.Level); err != nil { - return err - } else { - api.AssertIsEqual(dComputedLen, c.DLen) - for i := range c.D { - api.AssertIsEqual(c.D[i], dComputed[i]) // could do this much more efficiently in groth16 using packing :( - } - } - - return nil -} - -func check(s goCompress.Stream, padTo int) (checksum fr.Element, err error) { - - s.D = append(s.D, make([]int, padTo-len(s.D))...) - - csb := s.Checksum(hash.MIMC_BLS12_377.New(), fr.Bits) - checksum.SetBytes(csb) - return -} - -func checkSnark(api frontend.API, e []frontend.Variable, eLen, checksum frontend.Variable) error { - hsh, err := mimc.NewMiMC(api) - if err != nil { - return err - } - hsh.Write(e...) - hsh.Write(eLen) - api.AssertIsEqual(hsh.Sum(), checksum) - return nil -} diff --git a/std/compress/snark_io.go b/std/compress/snark_io.go deleted file mode 100644 index c72cf0585c..0000000000 --- a/std/compress/snark_io.go +++ /dev/null @@ -1,69 +0,0 @@ -package compress - -import ( - "github.com/consensys/gnark/frontend" -) - -func Pack(api frontend.API, words []frontend.Variable, wordLen int) []frontend.Variable { - wordsPerElem := (api.Compiler().FieldBitLen() - 1) / wordLen - res := make([]frontend.Variable, 1+(len(words)-1)/wordsPerElem) - for elemI := range res { - res[elemI] = 0 - for wordI := 0; wordI < wordsPerElem; wordI++ { - absWordI := elemI*wordsPerElem + wordI - if absWordI >= len(words) { - break - } - res[elemI] = api.Add(res[elemI], api.Mul(words[absWordI], 1< api.Compiler().FieldBitLen() { @@ -104,7 +104,7 @@ func toBinary(api frontend.API, v frontend.Variable, opts ...BaseConversionOptio } // restore the zero bits which exceed the field bit-length when requested by - // setting WithNbDigits larger than the field bitlength. + // setting WithNbDigits larger than the field bitLength. bits = append(bits, make([]frontend.Variable, paddingBits)...) for i := cfg.NbDigits; i < len(bits); i++ { bits[i] = 0 // frontend.Variable is interface{}, we get nil pointer err if trying to access it.