diff --git a/DataFormats/Detectors/Common/include/DetectorsCommonDataFormats/EncodedBlocks.h b/DataFormats/Detectors/Common/include/DetectorsCommonDataFormats/EncodedBlocks.h index 8d1b34d105dfc..149a3a009a25a 100644 --- a/DataFormats/Detectors/Common/include/DetectorsCommonDataFormats/EncodedBlocks.h +++ b/DataFormats/Detectors/Common/include/DetectorsCommonDataFormats/EncodedBlocks.h @@ -27,6 +27,25 @@ namespace o2 { namespace ctf { + +namespace detail +{ + +template +struct is_iterator : std::false_type { +}; + +template +struct is_iterator::iterator_category> || + std::is_same_v::iterator_category>>> + : std::true_type { +}; + +template +inline constexpr bool is_iterator_v = is_iterator::value; +} // namespace detail + using namespace o2::rans; constexpr size_t Alignment = 16; @@ -376,7 +395,7 @@ class EncodedBlocks template inline void encode(const VE& src, int slot, uint8_t probabilityBits, Metadata::OptStore opt, VB* buffer = nullptr, const void* encoderExt = nullptr) { - encode(&(*src.begin()), &(*src.end()), slot, probabilityBits, opt, buffer, encoderExt); + encode(std::begin(src), std::end(src), slot, probabilityBits, opt, buffer, encoderExt); } /// encode vector src to bloc at provided slot @@ -384,12 +403,12 @@ class EncodedBlocks void encode(const S_IT srcBegin, const S_IT srcEnd, int slot, uint8_t probabilityBits, Metadata::OptStore opt, VB* buffer = nullptr, const void* encoderExt = nullptr); /// decode block at provided slot to destination vector (will be resized as needed) - template - void decode(VD& dest, int slot, const void* decoderExt = nullptr) const; + template + void decode(container_T& dest, int slot, const void* decoderExt = nullptr) const; /// decode block at provided slot to destination pointer, the needed space assumed to be available - template - void decode(D* dest, int slot, const void* decoderExt = nullptr) const; + template , bool> = true> + void decode(D_IT dest, int slot, const void* decoderExt = nullptr) const; /// create a special EncodedBlocks containing only dictionaries made from provided vector of frequency tables static std::vector createDictionaryBlocks(const std::vector& vfreq, const std::vector& prbits); @@ -666,19 +685,19 @@ void EncodedBlocks::print(const std::string& prefix) const ///_____________________________________________________________________________ template -template -inline void EncodedBlocks::decode(VD& dest, // destination container +template +inline void EncodedBlocks::decode(container_T& dest, // destination container int slot, // slot of the block to decode const void* decoderExt) const // optional externally provided decoder { dest.resize(mMetadata[slot].messageLength); // allocate output buffer - decode(dest.data(), slot, decoderExt); + decode(std::begin(dest), slot, decoderExt); } ///_____________________________________________________________________________ template -template -void EncodedBlocks::decode(D* dest, // destination pointer +template , bool>> +void EncodedBlocks::decode(D_IT dest, // iterator to destination int slot, // slot of the block to decode const void* decoderExt) const // optional externally provided decoder { @@ -686,6 +705,8 @@ void EncodedBlocks::decode(D* dest, // destination const auto& block = mBlocks[slot]; const auto& md = mMetadata[slot]; + using dest_t = typename std::iterator_traits::value_type; + // decode if (block.getNStored()) { if (md.opt == Metadata::OptStore::EENCODE) { @@ -693,12 +714,12 @@ void EncodedBlocks::decode(D* dest, // destination LOG(ERROR) << "Dictionaty is not saved for slot " << slot << " and no external decoder is provided"; throw std::runtime_error("Dictionary is not saved and no external decoder provided"); } - const o2::rans::LiteralDecoder64* decoder = reinterpret_cast*>(decoderExt); - std::unique_ptr> decoderLoc; + const o2::rans::LiteralDecoder64* decoder = reinterpret_cast*>(decoderExt); + std::unique_ptr> decoderLoc; if (block.getNDict()) { // if dictionaty is saved, prefer it o2::rans::FrequencyTable frequencies; frequencies.addFrequencies(block.getDict(), block.getDict() + block.getNDict(), md.min, md.max); - decoderLoc = std::make_unique>(frequencies, md.probabilityBits); + decoderLoc = std::make_unique>(frequencies, md.probabilityBits); decoder = decoderLoc.get(); } else { // verify that decoded corresponds to stored metadata if (md.min != decoder->getMinSymbol() || md.max != decoder->getMaxSymbol()) { @@ -708,16 +729,20 @@ void EncodedBlocks::decode(D* dest, // destination } } // load incompressible symbols if they existed - std::vector literals; + std::vector literals; if (block.getNLiterals()) { // note: here we have to use md.nLiterals (original number of literal words) rather than md.nLiteralWords == block.getNLiterals() // (number of W-words in the EncodedBlock occupied by literals) as we cast literals stored in W-word array // to D-word array - literals = std::vector{reinterpret_cast(block.getLiterals()), reinterpret_cast(block.getLiterals()) + md.nLiterals}; + literals = std::vector{reinterpret_cast(block.getLiterals()), reinterpret_cast(block.getLiterals()) + md.nLiterals}; } decoder->process(dest, block.getData() + block.getNData(), md.messageLength, literals); } else { // data was stored as is - std::memcpy(dest, block.payload, md.messageLength * sizeof(D)); + using destPtr_t = typename std::iterator_traits::pointer; + destPtr_t srcBegin = reinterpret_cast(block.payload); + destPtr_t srcEnd = srcBegin + md.messageLength * sizeof(dest_t); + std::copy(srcBegin, srcEnd, dest); + //std::memcpy(dest, block.payload, md.messageLength * sizeof(dest_t)); } } } @@ -738,7 +763,7 @@ void EncodedBlocks::encode(const S_IT srcBegin, // iterator begin o mRegistry.nFilledBlocks++; using STYP = typename std::iterator_traits::value_type; using stream_t = typename o2::rans::Encoder64::stream_t; - ; + const size_t messageLength = std::distance(srcBegin, srcEnd); // cover three cases: // * empty source message: no entropy coding @@ -825,6 +850,7 @@ void EncodedBlocks::encode(const S_IT srcBegin, // iterator begin o // no dictionary needed expandStorage(dataSize); *meta = Metadata{messageLength, 0, sizeof(uint64_t), sizeof(stream_t), probabilityBits, opt, 0, 0, 0, dataSize, 0}; + //FIXME: no we don't need an intermediate vector. // provided iterator is not necessarily pointer, need to use intermediate vector!!! std::vector vtmp(srcBegin, srcEnd); bl->storeData(meta->nDataWords, reinterpret_cast(vtmp.data())); diff --git a/DataFormats/Detectors/TPC/include/DataFormatsTPC/CTF.h b/DataFormats/Detectors/TPC/include/DataFormatsTPC/CTF.h index c501776e34e06..3e352e2667451 100644 --- a/DataFormats/Detectors/TPC/include/DataFormatsTPC/CTF.h +++ b/DataFormats/Detectors/TPC/include/DataFormatsTPC/CTF.h @@ -33,6 +33,8 @@ struct CTFHeader : public CompressedClustersCounters { /// wrapper for the Entropy-encoded clusters of the TF struct CTF : public o2::ctf::EncodedBlocks { + using container_t = o2::ctf::EncodedBlocks; + static constexpr size_t N = getNBlocks(); static constexpr int NBitsQTot = 16; static constexpr int NBitsQMax = 10; diff --git a/Detectors/Base/include/DetectorsBase/CTFCoderBase.h b/Detectors/Base/include/DetectorsBase/CTFCoderBase.h index 8cf0e02639e04..59da6d5c2187b 100644 --- a/Detectors/Base/include/DetectorsBase/CTFCoderBase.h +++ b/Detectors/Base/include/DetectorsBase/CTFCoderBase.h @@ -59,10 +59,13 @@ class CTFCoderBase template void createCoder(OpType op, const o2::rans::FrequencyTable& freq, uint8_t probabilityBits, int slot) { - if (op == OpType::Encoder) { - mCoders[slot].reset(new o2::rans::LiteralEncoder64(freq, probabilityBits)); - } else { - mCoders[slot].reset(new o2::rans::LiteralDecoder64(freq, probabilityBits)); + switch (op) { + case OpType::Encoder: + mCoders[slot].reset(new o2::rans::LiteralEncoder64(freq, probabilityBits)); + break; + case OpType::Decoder: + mCoders[slot].reset(new o2::rans::LiteralDecoder64(freq, probabilityBits)); + break; } } diff --git a/Detectors/CTF/test/test_ctf_io_tpc.cxx b/Detectors/CTF/test/test_ctf_io_tpc.cxx index 99ebe3a36a5ef..12efdfcfe74fe 100644 --- a/Detectors/CTF/test/test_ctf_io_tpc.cxx +++ b/Detectors/CTF/test/test_ctf_io_tpc.cxx @@ -42,6 +42,7 @@ BOOST_AUTO_TEST_CASE(CTFTest) { CTFCoder coder; coder.setCompClusAddresses(c, buff); + coder.setCombineColumns(true); } ccFlat->set(sz, c); @@ -85,6 +86,7 @@ BOOST_AUTO_TEST_CASE(CTFTest) std::vector vecIO; { CTFCoder coder; + coder.setCombineColumns(true); coder.encode(vecIO, c); // compress } sw.Stop(); @@ -120,6 +122,7 @@ BOOST_AUTO_TEST_CASE(CTFTest) const auto ctfImage = o2::tpc::CTF::getImage(vecIO.data()); { CTFCoder coder; + coder.setCombineColumns(true); coder.decode(ctfImage, vecIn); // decompress } sw.Stop(); diff --git a/Detectors/TPC/reconstruction/include/TPCReconstruction/CTFCoder.h b/Detectors/TPC/reconstruction/include/TPCReconstruction/CTFCoder.h index 6e2bca379cfcd..d6f1a9f6d76b3 100644 --- a/Detectors/TPC/reconstruction/include/TPCReconstruction/CTFCoder.h +++ b/Detectors/TPC/reconstruction/include/TPCReconstruction/CTFCoder.h @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -27,6 +28,7 @@ #include "DetectorsCommonDataFormats/DetID.h" #include "DetectorsBase/CTFCoderBase.h" #include "rANS/rans.h" +#include "rANS/utils.h" class TTree; @@ -35,11 +37,71 @@ namespace o2 namespace tpc { +namespace detail +{ + +template +struct combinedType { + using type = std::conditional_t<(A + B > 16), uint32_t, std::conditional_t<(A + B > 8), uint16_t, uint8_t>>; +}; + +template +using combinedType_t = typename combinedType::type; + +template +class ShiftFunctor +{ + public: + template + inline value_T operator()(iterA_T iterA, iterB_T iterB) const + { + return *iterB + (static_cast(*iterA) << shift); + }; + + template + inline void operator()(iterA_T iterA, iterB_T iterB, value_T value) const + { + *iterA = value >> shift; + *iterB = value & ((0x1 << shift) - 0x1); + }; +}; + +template +auto makeInputIterators(iterA_T iterA, iterB_T iterB, size_t nElements, F functor) +{ + using namespace o2::rans::utils; + + auto advanceIter = [](auto iter, size_t nElements) { + auto tmp = iter; + std::advance(tmp, nElements); + return tmp; + }; + + return std::make_tuple(CombinedInputIterator{iterA, iterB, functor}, + CombinedInputIterator{advanceIter(iterA, nElements), advanceIter(iterB, nElements), functor}); +}; + +template +struct MergedColumnsDecoder { + + using combined_t = combinedType_t; + + template + static void decode(iterA_T iterA, iterB_T iterB, CTF::Slots slot, F decodingFunctor) + { + ShiftFunctor f{}; + auto iter = rans::utils::CombinedOutputIteratorFactory::makeIter(iterA, iterB, f); + + decodingFunctor(iter, slot); + } +}; + +} // namespace detail + class CTFCoder : public o2::ctf::CTFCoderBase { public: CTFCoder() : o2::ctf::CTFCoderBase(CTF::getNBlocks(), o2::detectors::DetID::TPC) {} - ~CTFCoder() = default; /// entropy-encode compressed clusters to flat buffer template @@ -79,55 +141,32 @@ class CTFCoder : public o2::ctf::CTFCoderBase private: void checkDataDictionaryConsistency(const CTFHeader& h); - template - static constexpr auto MVAR() - { - typename std::conditional<(NU + NL > 16), uint32_t, typename std::conditional<(NU + NL > 8), uint16_t, uint8_t>::type>::type tp = 0; - return tp; - } - template - static constexpr auto MPTR() - { - typename std::conditional<(NU + NL > 16), uint32_t, typename std::conditional<(NU + NL > 8), uint16_t, uint8_t>::type>::type* tp = nullptr; - return tp; - } -#define MTYPE(A, B) decltype(CTFCoder::MVAR()) - template - static void splitColumns(const std::vector& vm, CU*& vu, CL*& vl); + static void splitColumns(const std::vector>& vm, CU*& vu, CL*& vl); - template - static auto mergeColumns(const CU* vu, const CL* vl, size_t nelem); + template + void buildCoder(ctf::CTFCoderBase::OpType coderType, const CTF::container_t& ctf, CTF::Slots slot); bool mCombineColumns = false; // combine correlated columns ClassDefNV(CTFCoder, 1); }; -/// split words of input vector to columns assigned to provided pointers (memory must be allocated in advance) -template -void CTFCoder::splitColumns(const std::vector& vm, CU*& vu, CL*& vl) +template +void CTFCoder::buildCoder(ctf::CTFCoderBase::OpType coderType, const CTF::container_t& ctf, CTF::Slots slot) { - static_assert(NU <= sizeof(CU) * 8 && NL <= sizeof(CL) * 8, "output columns bit count is wrong"); - size_t n = vm.size(); - for (size_t i = 0; i < n; i++) { - vu[i] = static_cast(vm[i] >> NL); - vl[i] = static_cast(vm[i] & ((0x1 << NL) - 1)); - } -} + auto buildFrequencyTable = [](const CTF::container_t& ctf, CTF::Slots slot) -> rans::FrequencyTable { + rans::FrequencyTable frequencyTable; + auto block = ctf.getBlock(slot); + auto metaData = ctf.getMetadata(slot); + frequencyTable.addFrequencies(block.getDict(), block.getDict() + block.getNDict(), metaData.min, metaData.max); + return frequencyTable; + }; + auto getProbabilityBits = [](const CTF::container_t& ctf, CTF::Slots slot) -> int { + return ctf.getMetadata(slot).probabilityBits; + }; -/// merge elements of 2 columns pointed by vu and vl to a single vector with wider field -template -auto CTFCoder::mergeColumns(const CU* vu, const CL* vl, size_t nelem) -{ - // merge 2 columns to 1 - static_assert(NU <= sizeof(NU) * 8 && NL <= sizeof(NL) * 8, "input columns bit count is wrong"); - std::vector outv; - outv.reserve(nelem); - for (size_t i = 0; i < nelem; i++) { - outv.push_back((static_cast(vu[i]) << NL) | static_cast(vl[i])); - } - return std::move(outv); + this->createCoder(coderType, buildFrequencyTable(ctf, slot), getProbabilityBits(ctf, slot), static_cast(slot)); } /// entropy-encode clusters to buffer with CTF @@ -135,6 +174,7 @@ template void CTFCoder::encode(VEC& buff, const CompressedClusters& ccl) { using MD = o2::ctf::Metadata::OptStore; + using namespace detail; // what to do which each field: see o2::ctf::Metadata explanation constexpr MD optField[CTF::getNBlocks()] = { MD::EENCODE, //qTotA @@ -174,73 +214,75 @@ void CTFCoder::encode(VEC& buff, const CompressedClusters& ccl) ec->setHeader(CTFHeader{reinterpret_cast(ccl), flags}); ec->getANSHeader().majorVersion = 0; ec->getANSHeader().minorVersion = 1; - // at every encoding the buffer might be autoexpanded, so we don't work with fixed pointer ec -#define ENCODETPC(beg, end, slot, bits) CTF::get(buff.data())->encode(beg, end, int(slot), bits, optField[int(slot)], &buff, mCoders[int(slot)].get()); - // clang-format off + + auto encodeTPC = [&buff, &optField, &coders = mCoders](auto begin, auto end, CTF::Slots slot, size_t probabilityBits) { + // at every encoding the buffer might be autoexpanded, so we don't work with fixed pointer ec + const auto slotVal = static_cast(slot); + CTF::get(buff.data())->encode(begin, end, slotVal, probabilityBits, optField[slotVal], &buff, coders[slotVal].get()); + }; if (mCombineColumns) { - auto mrg = mergeColumns(ccl.qTotA, ccl.qMaxA, ccl.nAttachedClusters); - ENCODETPC(&mrg[0], (&mrg[0]) + ccl.nAttachedClusters, CTF::BLCqTotA, 0); - } - else { - ENCODETPC(ccl.qTotA, ccl.qTotA + ccl.nAttachedClusters, CTF::BLCqTotA, 0); + const auto [begin, end] = makeInputIterators(ccl.qTotA, ccl.qMaxA, ccl.nAttachedClusters, + ShiftFunctor, CTF::NBitsQMax>{}); + encodeTPC(begin, end, CTF::BLCqTotA, 0); + } else { + encodeTPC(ccl.qTotA, ccl.qTotA + ccl.nAttachedClusters, CTF::BLCqTotA, 0); } - ENCODETPC(ccl.qMaxA, ccl.qMaxA + (mCombineColumns ? 0 : ccl.nAttachedClusters), CTF::BLCqMaxA, 0); - - ENCODETPC(ccl.flagsA, ccl.flagsA + ccl.nAttachedClusters, CTF::BLCflagsA, 0); - + encodeTPC(ccl.qMaxA, ccl.qMaxA + (mCombineColumns ? 0 : ccl.nAttachedClusters), CTF::BLCqMaxA, 0); + + encodeTPC(ccl.flagsA, ccl.flagsA + ccl.nAttachedClusters, CTF::BLCflagsA, 0); + if (mCombineColumns) { - auto mrg = mergeColumns(ccl.rowDiffA, ccl.sliceLegDiffA, ccl.nAttachedClustersReduced); - ENCODETPC(&mrg[0], (&mrg[0]) + ccl.nAttachedClustersReduced, CTF::BLCrowDiffA, 0); - } - else { - ENCODETPC(ccl.rowDiffA, ccl.rowDiffA + ccl.nAttachedClustersReduced, CTF::BLCrowDiffA, 0); + const auto [begin, end] = makeInputIterators(ccl.rowDiffA, ccl.sliceLegDiffA, ccl.nAttachedClustersReduced, + ShiftFunctor, CTF::NBitsSliceLegDiff>{}); + encodeTPC(begin, end, CTF::BLCrowDiffA, 0); + } else { + encodeTPC(ccl.rowDiffA, ccl.rowDiffA + ccl.nAttachedClustersReduced, CTF::BLCrowDiffA, 0); } - ENCODETPC(ccl.sliceLegDiffA, ccl.sliceLegDiffA + (mCombineColumns ? 0 : ccl.nAttachedClustersReduced), CTF::BLCsliceLegDiffA, 0); + encodeTPC(ccl.sliceLegDiffA, ccl.sliceLegDiffA + (mCombineColumns ? 0 : ccl.nAttachedClustersReduced), CTF::BLCsliceLegDiffA, 0); - ENCODETPC(ccl.padResA, ccl.padResA + ccl.nAttachedClustersReduced, CTF::BLCpadResA, 0); - ENCODETPC(ccl.timeResA, ccl.timeResA + ccl.nAttachedClustersReduced, CTF::BLCtimeResA, 0); + encodeTPC(ccl.padResA, ccl.padResA + ccl.nAttachedClustersReduced, CTF::BLCpadResA, 0); + encodeTPC(ccl.timeResA, ccl.timeResA + ccl.nAttachedClustersReduced, CTF::BLCtimeResA, 0); if (mCombineColumns) { - auto mrg = mergeColumns(ccl.sigmaPadA, ccl.sigmaTimeA, ccl.nAttachedClusters); - ENCODETPC(&mrg[0], &mrg[0] + ccl.nAttachedClusters, CTF::BLCsigmaPadA, 0); + const auto [begin, end] = makeInputIterators(ccl.sigmaPadA, ccl.sigmaTimeA, ccl.nAttachedClusters, + ShiftFunctor, CTF::NBitsSigmaTime>{}); + encodeTPC(begin, end, CTF::BLCsigmaPadA, 0); + } else { + encodeTPC(ccl.sigmaPadA, ccl.sigmaPadA + ccl.nAttachedClusters, CTF::BLCsigmaPadA, 0); } - else { - ENCODETPC(ccl.sigmaPadA, ccl.sigmaPadA + ccl.nAttachedClusters, CTF::BLCsigmaPadA, 0); - } - ENCODETPC(ccl.sigmaTimeA, ccl.sigmaTimeA + (mCombineColumns ? 0 : ccl.nAttachedClusters), CTF::BLCsigmaTimeA, 0); + encodeTPC(ccl.sigmaTimeA, ccl.sigmaTimeA + (mCombineColumns ? 0 : ccl.nAttachedClusters), CTF::BLCsigmaTimeA, 0); - ENCODETPC(ccl.qPtA, ccl.qPtA + ccl.nTracks, CTF::BLCqPtA, 0); - ENCODETPC(ccl.rowA, ccl.rowA + ccl.nTracks, CTF::BLCrowA, 0); - ENCODETPC(ccl.sliceA, ccl.sliceA + ccl.nTracks, CTF::BLCsliceA, 0); - ENCODETPC(ccl.timeA, ccl.timeA + ccl.nTracks, CTF::BLCtimeA, 0); - ENCODETPC(ccl.padA, ccl.padA + ccl.nTracks, CTF::BLCpadA, 0); + encodeTPC(ccl.qPtA, ccl.qPtA + ccl.nTracks, CTF::BLCqPtA, 0); + encodeTPC(ccl.rowA, ccl.rowA + ccl.nTracks, CTF::BLCrowA, 0); + encodeTPC(ccl.sliceA, ccl.sliceA + ccl.nTracks, CTF::BLCsliceA, 0); + encodeTPC(ccl.timeA, ccl.timeA + ccl.nTracks, CTF::BLCtimeA, 0); + encodeTPC(ccl.padA, ccl.padA + ccl.nTracks, CTF::BLCpadA, 0); if (mCombineColumns) { - auto mrg = mergeColumns(ccl.qTotU,ccl.qMaxU,ccl.nUnattachedClusters); - ENCODETPC(&mrg[0], &mrg[0] + ccl.nUnattachedClusters, CTF::BLCqTotU, 0); - } - else { - ENCODETPC(ccl.qTotU, ccl.qTotU + ccl.nUnattachedClusters, CTF::BLCqTotU, 0); + const auto [begin, end] = makeInputIterators(ccl.qTotU, ccl.qMaxU, ccl.nUnattachedClusters, + ShiftFunctor, CTF::NBitsQMax>{}); + encodeTPC(begin, end, CTF::BLCqTotU, 0); + } else { + encodeTPC(ccl.qTotU, ccl.qTotU + ccl.nUnattachedClusters, CTF::BLCqTotU, 0); } - ENCODETPC(ccl.qMaxU, ccl.qMaxU + (mCombineColumns ? 0 : ccl.nUnattachedClusters), CTF::BLCqMaxU, 0); + encodeTPC(ccl.qMaxU, ccl.qMaxU + (mCombineColumns ? 0 : ccl.nUnattachedClusters), CTF::BLCqMaxU, 0); - ENCODETPC(ccl.flagsU, ccl.flagsU + ccl.nUnattachedClusters, CTF::BLCflagsU, 0); - ENCODETPC(ccl.padDiffU, ccl.padDiffU + ccl.nUnattachedClusters, CTF::BLCpadDiffU, 0); - ENCODETPC(ccl.timeDiffU, ccl.timeDiffU + ccl.nUnattachedClusters, CTF::BLCtimeDiffU, 0); + encodeTPC(ccl.flagsU, ccl.flagsU + ccl.nUnattachedClusters, CTF::BLCflagsU, 0); + encodeTPC(ccl.padDiffU, ccl.padDiffU + ccl.nUnattachedClusters, CTF::BLCpadDiffU, 0); + encodeTPC(ccl.timeDiffU, ccl.timeDiffU + ccl.nUnattachedClusters, CTF::BLCtimeDiffU, 0); if (mCombineColumns) { - auto mrg = mergeColumns(ccl.sigmaPadU, ccl.sigmaTimeU, ccl.nUnattachedClusters); - ENCODETPC(&mrg[0], &mrg[0] + ccl.nUnattachedClusters, CTF::BLCsigmaPadU, 0); + const auto [begin, end] = makeInputIterators(ccl.sigmaPadU, ccl.sigmaTimeU, ccl.nUnattachedClusters, + ShiftFunctor, CTF::NBitsSigmaTime>{}); + encodeTPC(begin, end, CTF::BLCsigmaPadU, 0); + } else { + encodeTPC(ccl.sigmaPadU, ccl.sigmaPadU + ccl.nUnattachedClusters, CTF::BLCsigmaPadU, 0); } - else { - ENCODETPC(ccl.sigmaPadU, ccl.sigmaPadU + ccl.nUnattachedClusters, CTF::BLCsigmaPadU, 0); - } - ENCODETPC(ccl.sigmaTimeU, ccl.sigmaTimeU + (mCombineColumns ? 0 : ccl.nUnattachedClusters), CTF::BLCsigmaTimeU, 0); + encodeTPC(ccl.sigmaTimeU, ccl.sigmaTimeU + (mCombineColumns ? 0 : ccl.nUnattachedClusters), CTF::BLCsigmaTimeU, 0); - ENCODETPC(ccl.nTrackClusters, ccl.nTrackClusters + ccl.nTracks, CTF::BLCnTrackClusters, 0); - ENCODETPC(ccl.nSliceRowClusters, ccl.nSliceRowClusters + ccl.nSliceRows, CTF::BLCnSliceRowClusters, 0); - // clang-format on + encodeTPC(ccl.nTrackClusters, ccl.nTrackClusters + ccl.nTracks, CTF::BLCnTrackClusters, 0); + encodeTPC(ccl.nSliceRowClusters, ccl.nSliceRowClusters + ccl.nSliceRows, CTF::BLCnSliceRowClusters, 0); CTF::get(buff.data())->print(getPrefix()); } @@ -248,6 +290,7 @@ void CTFCoder::encode(VEC& buff, const CompressedClusters& ccl) template void CTFCoder::decode(const CTF::base& ec, VEC& buffVec) { + using namespace detail; CompressedClusters cc; CompressedClustersCounters& ccCount = cc; auto& header = ec.getHeader(); @@ -266,78 +309,64 @@ void CTFCoder::decode(const CTF::base& ec, VEC& buffVec) ec.print(getPrefix()); // decode encoded data directly to destination buff -#define DECODETPC(part, slot) ec.decode(part, int(slot), mCoders[int(slot)].get()) - // clang-format off + auto decodeTPC = [&ec, &coders = mCoders](auto begin, CTF::Slots slot) { + const auto slotVal = static_cast(slot); + ec.decode(begin, slotVal, coders[slotVal].get()); + }; + if (mCombineColumns) { - std::vector mrg; - DECODETPC(mrg, CTF::BLCqTotA); - splitColumns(mrg, cc.qTotA, cc.qMaxA); - } - else { - DECODETPC(cc.qTotA, CTF::BLCqTotA); - DECODETPC(cc.qMaxA, CTF::BLCqMaxA); + detail::MergedColumnsDecoder::decode(cc.qTotA, cc.qMaxA, CTF::BLCqTotA, decodeTPC); + } else { + decodeTPC(cc.qTotA, CTF::BLCqTotA); + decodeTPC(cc.qMaxA, CTF::BLCqMaxA); } - - DECODETPC(cc.flagsA, CTF::BLCflagsA); - + + decodeTPC(cc.flagsA, CTF::BLCflagsA); + if (mCombineColumns) { - std::vector mrg; - DECODETPC(mrg, CTF::BLCrowDiffA); - splitColumns(mrg, cc.rowDiffA, cc.sliceLegDiffA); + detail::MergedColumnsDecoder::decode(cc.rowDiffA, cc.sliceLegDiffA, CTF::BLCrowDiffA, decodeTPC); + } else { + decodeTPC(cc.rowDiffA, CTF::BLCrowDiffA); + decodeTPC(cc.sliceLegDiffA, CTF::BLCsliceLegDiffA); } - else { - DECODETPC(cc.rowDiffA, CTF::BLCrowDiffA); - DECODETPC(cc.sliceLegDiffA, CTF::BLCsliceLegDiffA); - } - - DECODETPC(cc.padResA, CTF::BLCpadResA); - DECODETPC(cc.timeResA, CTF::BLCtimeResA); + + decodeTPC(cc.padResA, CTF::BLCpadResA); + decodeTPC(cc.timeResA, CTF::BLCtimeResA); if (mCombineColumns) { - std::vector mrg; - DECODETPC(mrg, CTF::BLCsigmaPadA); - splitColumns(mrg, cc.sigmaPadA, cc.sigmaTimeA); + detail::MergedColumnsDecoder::decode(cc.sigmaPadA, cc.sigmaTimeA, CTF::BLCsigmaPadA, decodeTPC); + } else { + decodeTPC(cc.sigmaPadA, CTF::BLCsigmaPadA); + decodeTPC(cc.sigmaTimeA, CTF::BLCsigmaTimeA); } - else { - DECODETPC(cc.sigmaPadA, CTF::BLCsigmaPadA); - DECODETPC(cc.sigmaTimeA, CTF::BLCsigmaTimeA); - } - - DECODETPC(cc.qPtA, CTF::BLCqPtA); - DECODETPC(cc.rowA, CTF::BLCrowA); - DECODETPC(cc.sliceA, CTF::BLCsliceA); - DECODETPC(cc.timeA, CTF::BLCtimeA); - DECODETPC(cc.padA, CTF::BLCpadA); + + decodeTPC(cc.qPtA, CTF::BLCqPtA); + decodeTPC(cc.rowA, CTF::BLCrowA); + decodeTPC(cc.sliceA, CTF::BLCsliceA); + decodeTPC(cc.timeA, CTF::BLCtimeA); + decodeTPC(cc.padA, CTF::BLCpadA); if (mCombineColumns) { - std::vector mrg; - DECODETPC(mrg, CTF::BLCqTotU); - splitColumns(mrg, cc.qTotU, cc.qMaxU); - } - else { - DECODETPC(cc.qTotU, CTF::BLCqTotU); - DECODETPC(cc.qMaxU, CTF::BLCqMaxU); + detail::MergedColumnsDecoder::decode(cc.qTotU, cc.qMaxU, CTF::BLCqTotU, decodeTPC); + } else { + decodeTPC(cc.qTotU, CTF::BLCqTotU); + decodeTPC(cc.qMaxU, CTF::BLCqMaxU); } - DECODETPC(cc.flagsU, CTF::BLCflagsU); - DECODETPC(cc.padDiffU, CTF::BLCpadDiffU); - DECODETPC(cc.timeDiffU, CTF::BLCtimeDiffU); + decodeTPC(cc.flagsU, CTF::BLCflagsU); + decodeTPC(cc.padDiffU, CTF::BLCpadDiffU); + decodeTPC(cc.timeDiffU, CTF::BLCtimeDiffU); if (mCombineColumns) { - std::vector mrg; - DECODETPC(mrg, CTF::BLCsigmaPadU); - splitColumns(mrg, cc.sigmaPadU, cc.sigmaTimeU); - } - else { - DECODETPC(cc.sigmaPadU, CTF::BLCsigmaPadU); - DECODETPC(cc.sigmaTimeU, CTF::BLCsigmaTimeU); + detail::MergedColumnsDecoder::decode(cc.sigmaPadU, cc.sigmaTimeU, CTF::BLCsigmaPadU, decodeTPC); + } else { + decodeTPC(cc.sigmaPadU, CTF::BLCsigmaPadU); + decodeTPC(cc.sigmaTimeU, CTF::BLCsigmaTimeU); } - - DECODETPC(cc.nTrackClusters, CTF::BLCnTrackClusters); - DECODETPC(cc.nSliceRowClusters, CTF::BLCnSliceRowClusters); - // clang-format on + + decodeTPC(cc.nTrackClusters, CTF::BLCnTrackClusters); + decodeTPC(cc.nSliceRowClusters, CTF::BLCnSliceRowClusters); } -#undef MTYPE } // namespace tpc } // namespace o2 diff --git a/Detectors/TPC/reconstruction/src/CTFCoder.cxx b/Detectors/TPC/reconstruction/src/CTFCoder.cxx index c7888b2bdd16f..20873f20bcdbf 100644 --- a/Detectors/TPC/reconstruction/src/CTFCoder.cxx +++ b/Detectors/TPC/reconstruction/src/CTFCoder.cxx @@ -91,6 +91,8 @@ void CTFCoder::setCompClusAddresses(CompressedClusters& c, void*& buff) ///________________________________ void CTFCoder::createCoders(const std::string& dictPath, o2::ctf::CTFCoderBase::OpType op) { + using namespace detail; + bool mayFail = true; // RS FIXME if the dictionary file is not there, do not produce exception auto buff = readDictionaryFromFile(dictPath, mayFail); if (!buff.size()) { @@ -99,73 +101,55 @@ void CTFCoder::createCoders(const std::string& dictPath, o2::ctf::CTFCoderBase:: } throw std::runtime_error("Failed to create CTF dictionaty"); } - const auto* ctf = CTF::get(buff.data()); + const CTF::container_t* ctf = CTF::get(buff.data()); mCombineColumns = ctf->getHeader().flags & CTFHeader::CombinedColumns; LOG(INFO) << "TPC CTF Columns Combining " << (mCombineColumns ? "ON" : "OFF"); - auto getFreq = [ctf](CTF::Slots slot) -> o2::rans::FrequencyTable { - o2::rans::FrequencyTable ft; - auto bl = ctf->getBlock(slot); - auto md = ctf->getMetadata(slot); - ft.addFrequencies(bl.getDict(), bl.getDict() + bl.getNDict(), md.min, md.max); - return std::move(ft); - }; - auto getProbBits = [ctf](CTF::Slots slot) -> int { - return ctf->getMetadata(slot).probabilityBits; - }; - - CompressedClusters cc; // just to get member types -#define MAKECODER(part, slot) createCoder::type>(op, getFreq(slot), getProbBits(slot), int(slot)) - // clang-format off + const CompressedClusters cc; // just to get member types if (mCombineColumns) { - MAKECODER( (MPTR()), CTF::BLCqTotA); // merged qTotA and qMaxA - } - else { - MAKECODER(cc.qTotA, CTF::BLCqTotA); + buildCoder>(op, *ctf, CTF::BLCqTotA); + } else { + buildCoder>(op, *ctf, CTF::BLCqTotA); } - MAKECODER(cc.qMaxA, CTF::BLCqMaxA); - MAKECODER(cc.flagsA, CTF::BLCflagsA); + buildCoder>(op, *ctf, CTF::BLCqMaxA); + buildCoder>(op, *ctf, CTF::BLCflagsA); if (mCombineColumns) { - MAKECODER( (MPTR()), CTF::BLCrowDiffA); // merged rowDiffA and sliceLegDiffA - } - else { - MAKECODER(cc.rowDiffA, CTF::BLCrowDiffA); + buildCoder>(op, *ctf, CTF::BLCrowDiffA); // merged rowDiffA and sliceLegDiffA + + } else { + buildCoder>(op, *ctf, CTF::BLCrowDiffA); } - MAKECODER(cc.sliceLegDiffA, CTF::BLCsliceLegDiffA); - MAKECODER(cc.padResA, CTF::BLCpadResA); - MAKECODER(cc.timeResA, CTF::BLCtimeResA); + buildCoder>(op, *ctf, CTF::BLCsliceLegDiffA); + buildCoder>(op, *ctf, CTF::BLCpadResA); + buildCoder>(op, *ctf, CTF::BLCtimeResA); if (mCombineColumns) { - MAKECODER( (MPTR()), CTF::BLCsigmaPadA); // merged sigmaPadA and sigmaTimeA - } - else { - MAKECODER(cc.sigmaPadA, CTF::BLCsigmaPadA); + buildCoder>(op, *ctf, CTF::BLCsigmaPadA); // merged sigmaPadA and sigmaTimeA + } else { + buildCoder>(op, *ctf, CTF::BLCsigmaPadA); } - MAKECODER(cc.sigmaTimeA, CTF::BLCsigmaTimeA); - MAKECODER(cc.qPtA, CTF::BLCqPtA); - MAKECODER(cc.rowA, CTF::BLCrowA); - MAKECODER(cc.sliceA, CTF::BLCsliceA); - MAKECODER(cc.timeA, CTF::BLCtimeA); - MAKECODER(cc.padA, CTF::BLCpadA); + buildCoder>(op, *ctf, CTF::BLCsigmaTimeA); + buildCoder>(op, *ctf, CTF::BLCqPtA); + buildCoder>(op, *ctf, CTF::BLCrowA); + buildCoder>(op, *ctf, CTF::BLCsliceA); + buildCoder>(op, *ctf, CTF::BLCtimeA); + buildCoder>(op, *ctf, CTF::BLCpadA); if (mCombineColumns) { - MAKECODER( (MPTR()), CTF::BLCqTotU); // merged qTotU and qMaxU - } - else { - MAKECODER(cc.qTotU, CTF::BLCqTotU); + buildCoder>(op, *ctf, CTF::BLCqTotU); // merged qTotU and qMaxU + } else { + buildCoder>(op, *ctf, CTF::BLCqTotU); } - MAKECODER(cc.qMaxU, CTF::BLCqMaxU); - MAKECODER(cc.flagsU, CTF::BLCflagsU); - MAKECODER(cc.padDiffU, CTF::BLCpadDiffU); - MAKECODER(cc.timeDiffU, CTF::BLCtimeDiffU); + buildCoder>(op, *ctf, CTF::BLCqMaxU); + buildCoder>(op, *ctf, CTF::BLCflagsU); + buildCoder>(op, *ctf, CTF::BLCpadDiffU); + buildCoder>(op, *ctf, CTF::BLCtimeDiffU); if (mCombineColumns) { - MAKECODER( (MPTR()), CTF::BLCsigmaPadU); // merged sigmaPadA and sigmaTimeA - } - else { - MAKECODER(cc.sigmaPadU, CTF::BLCsigmaPadU); + buildCoder>(op, *ctf, CTF::BLCsigmaPadU); // merged sigmaPadU and sigmaTimeU + } else { + buildCoder>(op, *ctf, CTF::BLCsigmaPadU); } - MAKECODER(cc.sigmaTimeU, CTF::BLCsigmaTimeU); - MAKECODER(cc.nTrackClusters, CTF::BLCnTrackClusters); - MAKECODER(cc.nSliceRowClusters, CTF::BLCnSliceRowClusters); - // clang-format on + buildCoder>(op, *ctf, CTF::BLCsigmaTimeU); + buildCoder>(op, *ctf, CTF::BLCnTrackClusters); + buildCoder>(op, *ctf, CTF::BLCnSliceRowClusters); } /// make sure loaded dictionaries (if any) are consistent with data diff --git a/Utilities/rANS/benchmarks/bench_ransCombinedIterator.cxx b/Utilities/rANS/benchmarks/bench_ransCombinedIterator.cxx index 8b3ab59181cd0..6fcff21090c8d 100644 --- a/Utilities/rANS/benchmarks/bench_ransCombinedIterator.cxx +++ b/Utilities/rANS/benchmarks/bench_ransCombinedIterator.cxx @@ -99,7 +99,7 @@ static void BM_Array_Write_Iterator(benchmark::State& state) *iterB = value & ((1 << shift) - 1); }; - o2::rans::utils::CombinedOutputIterator out(a.begin(), b.begin(), writeOP); + auto out = o2::rans::utils::CombinedOutputIteratorFactory::makeIter(a.begin(), b.begin(), writeOP); for (auto iter = c.begin(); iter != c.end(); ++iter) { *out = *iter + 1; diff --git a/Utilities/rANS/include/rANS/Decoder.h b/Utilities/rANS/include/rANS/Decoder.h index b9ace6e49fc2b..419d8a942a5f0 100644 --- a/Utilities/rANS/include/rANS/Decoder.h +++ b/Utilities/rANS/include/rANS/Decoder.h @@ -31,6 +31,7 @@ #include "internal/SymbolTable.h" #include "internal/Decoder.h" #include "internal/SymbolStatistics.h" +#include "internal/helper.h" namespace o2 { @@ -54,7 +55,7 @@ class Decoder ~Decoder() = default; Decoder(const FrequencyTable& stats, size_t probabilityBits); - template + template && internal::isCompatibleIter_v, bool> = true> void process(const source_IT outputBegin, const stream_IT inputEnd, size_t messageLength) const; size_t getAlphabetRangeBits() const { return mSymbolTable->getAlphabetRangeBits(); } @@ -107,15 +108,13 @@ Decoder::Decoder(const FrequencyTable& frequencies, }; template -template +template && internal::isCompatibleIter_v, bool>> void Decoder::process(const source_IT outputBegin, const stream_IT inputEnd, size_t messageLength) const { using namespace internal; LOG(trace) << "start decoding"; RANSTimer t; t.start(); - static_assert(std::is_same::value_type, source_T>::value); - static_assert(std::is_same::value_type, stream_T>::value); if (messageLength == 0) { LOG(warning) << "Empty message passed to decoder, skipping decode process"; diff --git a/Utilities/rANS/include/rANS/DedupDecoder.h b/Utilities/rANS/include/rANS/DedupDecoder.h index b9c3538b95dcf..91c028ad07bea 100644 --- a/Utilities/rANS/include/rANS/DedupDecoder.h +++ b/Utilities/rANS/include/rANS/DedupDecoder.h @@ -44,12 +44,12 @@ class DedupDecoder : public Decoder public: using duplicatesMap_t = std::map; - template + template && internal::isCompatibleIter_v, bool> = true> void process(const source_IT outputBegin, const stream_IT inputEnd, size_t messageLength, duplicatesMap_t& duplicates) const; }; template -template +template && internal::isCompatibleIter_v, bool>> void DedupDecoder::process(const source_IT outputBegin, const stream_IT inputEnd, size_t messageLength, duplicatesMap_t& duplicates) const { using namespace internal; diff --git a/Utilities/rANS/include/rANS/DedupEncoder.h b/Utilities/rANS/include/rANS/DedupEncoder.h index 3a4efaf2d9f74..1dfcf3c6db85d 100644 --- a/Utilities/rANS/include/rANS/DedupEncoder.h +++ b/Utilities/rANS/include/rANS/DedupEncoder.h @@ -46,13 +46,13 @@ class DedupEncoder : public Encoder public: using duplicatesMap_t = std::map; - template + template && internal::isCompatibleIter_v, bool> = true> const stream_IT process(const stream_IT outputBegin, const stream_IT outputEnd, const source_IT inputBegin, source_IT inputEnd, duplicatesMap_t& duplicates) const; }; template -template +template && internal::isCompatibleIter_v, bool>> const stream_IT DedupEncoder::process(const stream_IT outputBegin, const stream_IT outputEnd, const source_IT inputBegin, const source_IT inputEnd, duplicatesMap_t& duplicates) const { using namespace internal; @@ -107,7 +107,7 @@ const stream_IT DedupEncoder::process(const stream_ return std::tuple(++dedupIT, coder.putSymbol(outputIter, encoderSymbol, this->mProbabilityBits)); }; - while (inputIT > inputBegin) { // NB: working in reverse! + while (inputIT != inputBegin) { // NB: working in reverse! std::tie(inputIT, outputIter) = encode(--inputIT, outputIter, rans); assert(outputIter < outputEnd); } diff --git a/Utilities/rANS/include/rANS/Encoder.h b/Utilities/rANS/include/rANS/Encoder.h index 907302b14202c..45633e2e86a58 100644 --- a/Utilities/rANS/include/rANS/Encoder.h +++ b/Utilities/rANS/include/rANS/Encoder.h @@ -53,7 +53,7 @@ class Encoder Encoder(encoderSymbolTable_t&& e, size_t probabilityBits); Encoder(const FrequencyTable& frequencies, size_t probabilityBits); - template + template && internal::isCompatibleIter_v, bool> = true> const stream_IT process(const stream_IT outputBegin, const stream_IT outputEnd, const source_IT inputBegin, const source_IT inputEnd) const; @@ -113,7 +113,7 @@ Encoder::Encoder(const FrequencyTable& frequencies, } template -template +template && internal::isCompatibleIter_v, bool>> const stream_IT Encoder::Encoder::process(const stream_IT outputBegin, const stream_IT outputEnd, const source_IT inputBegin, const source_IT inputEnd) const { using namespace internal; @@ -154,7 +154,7 @@ const stream_IT Encoder::Encoder::process(const str assert(outputIter < outputEnd); } - while (inputIT > inputBegin) { // NB: working in reverse! + while (inputIT != inputBegin) { // NB: working in reverse! std::tie(inputIT, outputIter) = encode(--inputIT, outputIter, rans1); std::tie(inputIT, outputIter) = encode(--inputIT, outputIter, rans0); assert(outputIter < outputEnd); diff --git a/Utilities/rANS/include/rANS/LiteralDecoder.h b/Utilities/rANS/include/rANS/LiteralDecoder.h index 4e8b2716184c9..8c3e3d77113c4 100644 --- a/Utilities/rANS/include/rANS/LiteralDecoder.h +++ b/Utilities/rANS/include/rANS/LiteralDecoder.h @@ -42,12 +42,12 @@ class LiteralDecoder : public Decoder using Decoder::Decoder; public: - template + template && internal::isCompatibleIter_v, bool> = true> void process(const source_IT outputBegin, const stream_IT inputEnd, size_t messageLength, std::vector& literals) const; }; template -template +template && internal::isCompatibleIter_v, bool>> void LiteralDecoder::process(const source_IT outputBegin, const stream_IT inputEnd, size_t messageLength, std::vector& literals) const { using namespace internal; diff --git a/Utilities/rANS/include/rANS/LiteralEncoder.h b/Utilities/rANS/include/rANS/LiteralEncoder.h index 989ee035cb5e5..8805ba91aee4c 100644 --- a/Utilities/rANS/include/rANS/LiteralEncoder.h +++ b/Utilities/rANS/include/rANS/LiteralEncoder.h @@ -41,13 +41,13 @@ class LiteralEncoder : public Encoder using Encoder::Encoder; public: - template + template && internal::isCompatibleIter_v, bool> = true> const stream_IT process(const stream_IT outputBegin, const stream_IT outputEnd, const source_IT inputBegin, source_IT inputEnd, std::vector& literals) const; }; template -template +template && internal::isCompatibleIter_v, bool>> const stream_IT LiteralEncoder::process(const stream_IT outputBegin, const stream_IT outputEnd, const source_IT inputBegin, const source_IT inputEnd, std::vector& literals) const { using namespace internal; @@ -92,7 +92,7 @@ const stream_IT LiteralEncoder::process(const strea assert(outputIter < outputEnd); } - while (inputIT > inputBegin) { // NB: working in reverse! + while (inputIT != inputBegin) { // NB: working in reverse! std::tie(inputIT, outputIter) = encode(--inputIT, outputIter, rans1); std::tie(inputIT, outputIter) = encode(--inputIT, outputIter, rans0); assert(outputIter < outputEnd); diff --git a/Utilities/rANS/include/rANS/internal/Decoder.h b/Utilities/rANS/include/rANS/internal/Decoder.h index d983a1f91a34d..60e9b4d162769 100644 --- a/Utilities/rANS/include/rANS/internal/Decoder.h +++ b/Utilities/rANS/include/rANS/internal/Decoder.h @@ -49,21 +49,21 @@ class Decoder // Initializes a rANS decoder. // Unlike the encoder, the decoder works forwards as you'd expect. - template + template , bool> = true> Stream_IT init(Stream_IT iter); // Returns the current cumulative frequency (map it to a symbol yourself!) uint32_t get(uint32_t scale_bits); // Equivalent to Rans32DecAdvance that takes a symbol. - template + template , bool> = true> Stream_IT advanceSymbol(Stream_IT iter, const DecoderSymbol& sym, uint32_t scale_bits); private: State_T mState; // Renormalize. - template + template , bool> = true> std::tuple renorm(State_T x, Stream_IT iter); // L ('l' in the paper) is the lower bound of our normalization interval. @@ -79,10 +79,9 @@ template Decoder::Decoder() : mState(0){}; template -template +template , bool>> Stream_IT Decoder::init(Stream_IT iter) { - static_assert(std::is_same::value_type, Stream_T>::value); State_T x = 0; Stream_IT streamPos = iter; @@ -116,7 +115,7 @@ uint32_t Decoder::get(uint32_t scale_bits) }; template -template +template , bool>> Stream_IT Decoder::advanceSymbol(Stream_IT iter, const DecoderSymbol& sym, uint32_t scale_bits) { static_assert(std::is_same::value_type, Stream_T>::value); @@ -134,7 +133,7 @@ Stream_IT Decoder::advanceSymbol(Stream_IT iter, const Decode }; template -template +template , bool>> inline std::tuple Decoder::renorm(State_T x, Stream_IT iter) { static_assert(std::is_same::value_type, Stream_T>::value); diff --git a/Utilities/rANS/include/rANS/internal/Encoder.h b/Utilities/rANS/include/rANS/internal/Encoder.h index c3e9c9fb04b89..8cd1247c1147a 100644 --- a/Utilities/rANS/include/rANS/internal/Encoder.h +++ b/Utilities/rANS/include/rANS/internal/Encoder.h @@ -55,25 +55,25 @@ class Encoder // NOTE: With rANS, you need to encode symbols in *reverse order*, i.e. from // beginning to end! Likewise, the output bytestream is written *backwards*: // ptr starts pointing at the end of the output buffer and keeps decrementing. - template + template , bool> = true> Stream_IT put(Stream_IT iter, uint32_t start, uint32_t freq, uint32_t scale_bits); // Flushes the rANS encoder. - template + template , bool> = true> Stream_IT flush(Stream_IT iter); // Encodes a given symbol. This is faster than straight RansEnc since we can do // multiplications instead of a divide. // // See Rans32EncSymbolInit for a description of how this works. - template + template , bool> = true> Stream_IT putSymbol(Stream_IT iter, const EncoderSymbol& sym, uint32_t scale_bits); private: State_T mState; // Renormalize the encoder. - template + template , bool> = true> std::tuple renorm(State_T x, Stream_IT iter, uint32_t freq, uint32_t scale_bits); // L ('l' in the paper) is the lower bound of our normalization interval. @@ -89,10 +89,9 @@ template Encoder::Encoder() : mState(LOWER_BOUND){}; template -template +template , bool>> Stream_IT Encoder::put(Stream_IT iter, uint32_t start, uint32_t freq, uint32_t scale_bits) { - static_assert(std::is_same::value_type, Stream_T>::value); // renormalize Stream_IT streamPos; State_T x; @@ -104,10 +103,9 @@ Stream_IT Encoder::put(Stream_IT iter, uint32_t start, uint32 }; template -template +template , bool>> Stream_IT Encoder::flush(Stream_IT iter) { - static_assert(std::is_same::value_type, Stream_T>::value); Stream_IT streamPos = iter; @@ -134,10 +132,9 @@ Stream_IT Encoder::flush(Stream_IT iter) }; template -template +template , bool>> Stream_IT Encoder::putSymbol(Stream_IT iter, const EncoderSymbol& sym, uint32_t scale_bits) { - static_assert(std::is_same::value_type, Stream_T>::value); assert(sym.freq != 0); // can't encode symbol with freq=0 @@ -163,11 +160,9 @@ Stream_IT Encoder::putSymbol(Stream_IT iter, const EncoderSym }; template -template +template , bool>> inline std::tuple Encoder::renorm(State_T x, Stream_IT iter, uint32_t freq, uint32_t scale_bits) { - static_assert(std::is_same::value_type, Stream_T>::value); - Stream_IT streamPos = iter; State_T x_max = ((LOWER_BOUND >> scale_bits) << STREAM_BITS) * freq; // this turns into a shift. diff --git a/Utilities/rANS/include/rANS/internal/helper.h b/Utilities/rANS/include/rANS/internal/helper.h index b8c5c46283d43..9fbdb2e041978 100644 --- a/Utilities/rANS/include/rANS/internal/helper.h +++ b/Utilities/rANS/include/rANS/internal/helper.h @@ -19,6 +19,8 @@ #include #include #include +#include +#include namespace o2 { @@ -60,6 +62,9 @@ class RANSTimer std::chrono::time_point mStop; }; +template +inline constexpr bool isCompatibleIter_v = std::is_same_v::value_type, T>; + } // namespace internal } // namespace rans } // namespace o2 diff --git a/Utilities/rANS/include/rANS/utils/CombinedIterator.h b/Utilities/rANS/include/rANS/utils/CombinedIterator.h index 46a92d45ccfba..0bc6eb41b282e 100644 --- a/Utilities/rANS/include/rANS/utils/CombinedIterator.h +++ b/Utilities/rANS/include/rANS/utils/CombinedIterator.h @@ -32,13 +32,15 @@ namespace utils template class CombinedInputIterator { + + public: using difference_type = std::ptrdiff_t; - using value_type = std::result_of; + using value_type = std::invoke_result_t; using pointer = value_type*; using reference = value_type&; - using iterator_category = std::input_iterator_tag; + using iterator_category = std::bidirectional_iterator_tag; - public: + CombinedInputIterator() = default; CombinedInputIterator(iterA_T iterA, iterB_T iterB, F functor); CombinedInputIterator(const CombinedInputIterator& iter) = default; CombinedInputIterator(CombinedInputIterator&& iter) = default; @@ -53,14 +55,16 @@ class CombinedInputIterator //pointer arithmetics CombinedInputIterator& operator++(); CombinedInputIterator operator++(int); + CombinedInputIterator& operator--(); + CombinedInputIterator operator--(int); // dereference auto operator*() const; private: - iterA_T mIterA; - iterB_T mIterB; - F mFunctor; + iterA_T mIterA{}; + iterB_T mIterB{}; + F mFunctor{}; public: friend std::ostream& operator<<(std::ostream& o, const CombinedInputIterator& iter) @@ -70,7 +74,7 @@ class CombinedInputIterator } }; -template +template class CombinedOutputIterator { @@ -79,20 +83,19 @@ class CombinedOutputIterator public: Proxy(CombinedOutputIterator& iter); - template - Proxy& operator=(value_T value); + Proxy& operator=(input_T value); private: - CombinedOutputIterator& mIter; + CombinedOutputIterator* mIter; }; + public: using difference_type = std::ptrdiff_t; - using value_type = Proxy; + using value_type = input_T; using pointer = value_type*; using reference = value_type&; using iterator_category = std::input_iterator_tag; - public: CombinedOutputIterator(iterA_T iterA, iterB_T iterB, F functor); CombinedOutputIterator(const CombinedOutputIterator& iter) = default; CombinedOutputIterator(CombinedOutputIterator&& iter) = default; @@ -105,12 +108,13 @@ class CombinedOutputIterator CombinedOutputIterator operator++(int); // dereference - value_type operator*(); + Proxy& operator*(); private: - iterA_T mIterA; - iterB_T mIterB; - F mFunctor; + iterA_T mIterA{}; + iterB_T mIterB{}; + F mFunctor{}; + Proxy mProxy{*this}; public: friend std::ostream& operator<<(std::ostream& o, const CombinedOutputIterator& iter) @@ -120,6 +124,16 @@ class CombinedOutputIterator } }; +template +struct CombinedOutputIteratorFactory { + + template + static inline auto makeIter(iterA_T iterA, iterB_T iterB, F functor) -> CombinedOutputIterator + { + return {iterA, iterB, functor}; + } +}; + template CombinedInputIterator::CombinedInputIterator(iterA_T iterA, iterB_T iterB, F functor) : mIterA(iterA), mIterB(iterB), mFunctor(functor) { @@ -162,56 +176,72 @@ inline auto CombinedInputIterator::operator++(int) -> Combi } template -inline auto CombinedInputIterator::operator*() const +inline auto CombinedInputIterator::operator--() -> CombinedInputIterator& { - return mFunctor(mIterA, mIterB); + --mIterA; + --mIterB; + return *this; } template -CombinedOutputIterator::CombinedOutputIterator(iterA_T iterA, iterB_T iterB, F functor) : mIterA(iterA), mIterB(iterB), mFunctor(functor) +inline auto CombinedInputIterator::operator--(int) -> CombinedInputIterator { + auto res = *this; + --(*this); + return res; } template -auto CombinedOutputIterator::operator=(const CombinedOutputIterator& other) -> CombinedOutputIterator& +inline auto CombinedInputIterator::operator*() const +{ + return mFunctor(mIterA, mIterB); +} + +template +CombinedOutputIterator::CombinedOutputIterator(iterA_T iterA, iterB_T iterB, F functor) : mIterA(iterA), mIterB(iterB), mFunctor(functor) +{ +} + +template +auto CombinedOutputIterator::operator=(const CombinedOutputIterator& other) -> CombinedOutputIterator& { mIterA = other.mIterA; mIterB = other.mIterB; return *this; } -template -inline auto CombinedOutputIterator::operator++() -> CombinedOutputIterator& +template +inline auto CombinedOutputIterator::operator++() -> CombinedOutputIterator& { ++mIterA; ++mIterB; return *this; } -template -inline auto CombinedOutputIterator::operator++(int) -> CombinedOutputIterator +template +inline auto CombinedOutputIterator::operator++(int) -> CombinedOutputIterator { auto res = *this; ++(*this); return res; } -template -inline auto CombinedOutputIterator::operator*() -> value_type +template +inline auto CombinedOutputIterator::operator*() -> Proxy& { - return Proxy(*this); + mProxy = {*this}; + return mProxy; } -template -CombinedOutputIterator::Proxy::Proxy(CombinedOutputIterator& iter) : mIter(iter) +template +CombinedOutputIterator::Proxy::Proxy(CombinedOutputIterator& iter) : mIter(&iter) { } -template -template -inline auto CombinedOutputIterator::Proxy::operator=(value_T value) -> CombinedOutputIterator::Proxy& +template +inline auto CombinedOutputIterator::Proxy::operator=(input_T value) -> Proxy& { - mIter.mFunctor(mIter.mIterA, mIter.mIterB, value); + mIter->mFunctor(mIter->mIterA, mIter->mIterB, value); return *this; } diff --git a/Utilities/rANS/test/test_ransCombinedIterator.cxx b/Utilities/rANS/test/test_ransCombinedIterator.cxx index e9a14774df33e..ad6c37952b5f1 100644 --- a/Utilities/rANS/test/test_ransCombinedIterator.cxx +++ b/Utilities/rANS/test/test_ransCombinedIterator.cxx @@ -21,19 +21,10 @@ #include "rANS/utils.h" -struct test_CombninedIteratorFixture { - std::vector a{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}; - std::vector b{a.rbegin(), a.rend()}; - size_t shift = 16; - std::vector aAndB{0x0001000f, 0x0002000e, 0x0003000d, 0x0004000c, 0x0005000b, - 0x0006000a, 0x00070009, 0x00080008, 0x00090007, 0x000a0006, - 0x000b0005, 0x000c0004, 0x000d0003, 0x000e0002, 0x000f0001}; -}; - -class ReadShiftFunctor +class ShiftFunctor { public: - ReadShiftFunctor(size_t shift) : mShift(shift){}; + ShiftFunctor(size_t shift) : mShift{shift} {}; template inline uint32_t operator()(iterA_T iterA, iterB_T iterB) const @@ -41,15 +32,6 @@ class ReadShiftFunctor return *iterB + (static_cast(*iterA) << mShift); }; - private: - size_t mShift; -}; - -class WriteShiftFunctor -{ - public: - WriteShiftFunctor(size_t shift) : mShift(shift){}; - template inline void operator()(iterA_T iterA, iterB_T iterB, uint32_t value) const { @@ -61,15 +43,18 @@ class WriteShiftFunctor size_t mShift; }; +struct test_CombninedIteratorFixture { + const std::vector a{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}; + const std::vector b{a.rbegin(), a.rend()}; + const ShiftFunctor f{16}; + const std::vector aAndB{0x0001000f, 0x0002000e, 0x0003000d, 0x0004000c, 0x0005000b, + 0x0006000a, 0x00070009, 0x00080008, 0x00090007, 0x000a0006, + 0x000b0005, 0x000c0004, 0x000d0003, 0x000e0002, 0x000f0001}; +}; + BOOST_FIXTURE_TEST_CASE(test_CombinedInputIteratorBase, test_CombninedIteratorFixture) { - // auto readOP = [](auto iterA, auto iterB) -> uint32_t { - // return *iterB + (static_cast(*iterA) << 16); - // }; - - ReadShiftFunctor f(shift); - o2::rans::utils::CombinedInputIterator iter(a.begin(), b.begin(), f); // test equal const o2::rans::utils::CombinedInputIterator first(a.begin(), b.begin(), f); @@ -77,13 +62,22 @@ BOOST_FIXTURE_TEST_CASE(test_CombinedInputIteratorBase, test_CombninedIteratorFi // test not equal const o2::rans::utils::CombinedInputIterator second(++(a.begin()), ++(b.begin()), f); BOOST_CHECK_NE(iter, second); - // test pre increment + // test pre-increment ++iter; BOOST_CHECK_EQUAL(iter, second); - //test postIncrement + //test post-increment iter = first; BOOST_CHECK_EQUAL(iter++, first); BOOST_CHECK_EQUAL(iter, second); + // test pre-decrement + iter = second; + --iter; + BOOST_CHECK_EQUAL(iter, first); + // test post-decrement + iter = second; + BOOST_CHECK_EQUAL(iter--, second); + BOOST_CHECK_EQUAL(iter, first); + //test deref const uint32_t val = first.operator*(); BOOST_CHECK_EQUAL(val, aAndB.front()); @@ -94,15 +88,8 @@ BOOST_FIXTURE_TEST_CASE(test_CombinedOutputIteratorBase, test_CombninedIteratorF std::vector aOut(2, 0x0); std::vector bOut(2, 0x0); - // auto writeOP = [](auto iterA, auto iterB, uint32_t value) -> void { - // const uint32_t shift = 16; - // *iterA = value >> shift; - // *iterB = value & ((1 << shift) - 1); - // }; - - WriteShiftFunctor f(shift); - - o2::rans::utils::CombinedOutputIterator iter(aOut.begin(), bOut.begin(), f); + o2::rans::utils::CombinedOutputIteratorFactory iterFactory; + auto iter = iterFactory.makeIter(aOut.begin(), bOut.begin(), f); // test deref: *iter = aAndB[0]; @@ -111,7 +98,7 @@ BOOST_FIXTURE_TEST_CASE(test_CombinedOutputIteratorBase, test_CombninedIteratorF aOut[0] = 0x0; bOut[0] = 0x0; - // test pre increment + // test pre-increment *(++iter) = aAndB[1]; BOOST_CHECK_EQUAL(aOut[0], 0); BOOST_CHECK_EQUAL(bOut[0], 0); @@ -119,9 +106,9 @@ BOOST_FIXTURE_TEST_CASE(test_CombinedOutputIteratorBase, test_CombninedIteratorF BOOST_CHECK_EQUAL(bOut[1], b[1]); aOut.assign(2, 0x0); bOut.assign(2, 0x0); - iter = o2::rans::utils::CombinedOutputIterator(aOut.begin(), bOut.begin(), f); + iter = iterFactory.makeIter(aOut.begin(), bOut.begin(), f); - // test post increment + // test post-increment auto preInc = iter++; *preInc = aAndB[0]; BOOST_CHECK_EQUAL(aOut[0], a[0]); @@ -139,10 +126,6 @@ BOOST_FIXTURE_TEST_CASE(test_CombinedOutputIteratorBase, test_CombninedIteratorF BOOST_FIXTURE_TEST_CASE(test_CombinedInputIteratorReadArray, test_CombninedIteratorFixture) { - // auto readOP = [](auto iterA, auto iterB) -> uint32_t { - // return *iterB + (static_cast(*iterA) << 16); - // }; - ReadShiftFunctor f(shift); const o2::rans::utils::CombinedInputIterator begin(a.begin(), b.begin(), f); const o2::rans::utils::CombinedInputIterator end(a.end(), b.end(), f); @@ -154,7 +137,7 @@ BOOST_FIXTURE_TEST_CASE(test_CombinedOutputIteratorWriteArray, test_CombninedIte std::vector aRes(a.size(), 0); std::vector bRes(b.size(), 0); - o2::rans::utils::CombinedOutputIterator iter(aRes.begin(), bRes.begin(), WriteShiftFunctor(shift)); + auto iter = o2::rans::utils::CombinedOutputIteratorFactory::makeIter(aRes.begin(), bRes.begin(), f); for (auto input : aAndB) { *iter++ = input; }