From b8170f084e18c6977f38a0f587c6145c2d2cfa43 Mon Sep 17 00:00:00 2001 From: nils m Date: Tue, 15 Apr 2025 16:10:25 -0700 Subject: [PATCH 1/2] Add circuit location statistics for recursion predicates * Overhaul EDSL location handling; we now generate mlir Locations directly instead of keeping our own SourceLoc structure, and it's relatively easy to add additional context to component construction using ScopedSourceLoc. * Expand --op-stats flag to work on `gen_predicates`; this outputs `...-opstats.txt` files in the output directory for each predicate processed. * --op-stats on `gen_predicates` also outputs encoded cycle counts in `...-encoded-cycles.txt` * Added a bunch more source annotations to get better granularity for recursion predicates. --- .github/workflows/main.yml | 2 +- risc0/core/BUILD.bazel | 1 - risc0/core/source_loc.h | 78 ------- zirgen/circuit/keccak/predicates.cpp | 2 + zirgen/circuit/predicates/gen_predicates.cpp | 2 + zirgen/circuit/recursion/encode.cpp | 48 ++++- zirgen/circuit/recursion/encode.h | 5 +- zirgen/circuit/recursion/recursion.cpp | 3 +- zirgen/circuit/rv32im/v1/edsl/rv32im.cpp | 3 +- zirgen/circuit/verify/fri.cpp | 8 +- zirgen/circuit/verify/merkle.cpp | 6 +- zirgen/circuit/verify/poly.cpp | 4 +- zirgen/circuit/verify/verify.cpp | 23 ++- zirgen/compiler/codegen/gen_cpp.cpp | 3 +- zirgen/compiler/codegen/gen_recursion.cpp | 36 +++- zirgen/compiler/edsl/BUILD.bazel | 1 - zirgen/compiler/edsl/component.cpp | 6 +- zirgen/compiler/edsl/component.h | 18 +- zirgen/compiler/edsl/edsl.cpp | 207 +++++++++---------- zirgen/compiler/edsl/edsl.h | 144 ++++++++----- zirgen/compiler/edsl/source_loc.h | 84 -------- zirgen/compiler/stats/OpStats.cpp | 85 +++++++- zirgen/compiler/stats/OpStats.h | 17 ++ zirgen/components/bits.cpp | 10 +- zirgen/components/fpext.cpp | 32 ++- zirgen/components/fpext.h | 18 +- zirgen/components/onehot.h | 6 +- zirgen/components/reg.h | 10 +- zirgen/components/u32.cpp | 6 +- zirgen/components/u32.h | 4 +- zirgen/dsl/passes/GenerateAccum.cpp | 157 +++++++------- 31 files changed, 533 insertions(+), 496 deletions(-) delete mode 100644 risc0/core/source_loc.h delete mode 100644 zirgen/compiler/edsl/source_loc.h diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index fdeaed76..02bc6c09 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -35,7 +35,7 @@ jobs: fetch-depth: 0 - uses: risc0/risc0/.github/actions/rustup@a9d723e29a44563497220a998b5de4e03d9da049 - name: Install cargo-sort - uses: risc0/cargo-install@v1 + uses: risc0/cargo-install@9f6037ed331dcf7da101461a20656273fa72abf0 with: crate: cargo-sort version: "1.0" diff --git a/risc0/core/BUILD.bazel b/risc0/core/BUILD.bazel index 1c850361..748d7941 100644 --- a/risc0/core/BUILD.bazel +++ b/risc0/core/BUILD.bazel @@ -8,7 +8,6 @@ cc_library( hdrs = [ "elf.h", "log.h", - "source_loc.h", "util.h", ], visibility = ["//visibility:public"], diff --git a/risc0/core/source_loc.h b/risc0/core/source_loc.h deleted file mode 100644 index b85524b4..00000000 --- a/risc0/core/source_loc.h +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2024 RISC Zero, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -/// \file -/// SourceLoc is basically a near clone of std::source_location, but since it's not in the spec -/// until c++20, we build our own minialist version. This is currently used to help track EDSL -/// expressions for the IR. - -#ifdef __has_builtin -#if __has_builtin(__builtin_FILE) -#define FILE_EXPR __builtin_FILE() -#else -#define FILE_EXPR __FILE__ -#endif -#else -#define FILE_EXPR __FILE__ -#endif - -#ifdef __has_builtin -#if __has_builtin(__builtin_LINE) -#define LINE_EXPR __builtin_LINE() -#else -#define LINE_EXPR __LINE__ -#endif -#else -#define LINE_EXPR __LINE__ -#endif - -#ifdef __has_builtin -#if __has_builtin(__builtin_COLUMN) -#define COLUMN_EXPR __builtin_COLUMN() -#else -#define COLUMN_EXPR 0 -#endif -#else -#define COLUMN_EXPR 0 -#endif - -namespace risc0 { - -/// A source location. Capture as much as we can from the compiler, which might be nothing. -struct SourceLoc { -public: - /// Get the "current" source location. When used in default values, this effectively captures the - /// call site of the function declaring the default value, which is very useful. - static constexpr SourceLoc current(const char* filename = FILE_EXPR, - int line = LINE_EXPR, - int column = COLUMN_EXPR) noexcept { - SourceLoc loc; - loc.filename = filename; - loc.line = line; - loc.column = column; - return loc; - } - - constexpr SourceLoc() noexcept : filename(""), line(0), column(0) {} - - const char* filename; - size_t line; - size_t column; -}; - -} // namespace risc0 diff --git a/zirgen/circuit/keccak/predicates.cpp b/zirgen/circuit/keccak/predicates.cpp index 2e053640..f0931a6d 100644 --- a/zirgen/circuit/keccak/predicates.cpp +++ b/zirgen/circuit/keccak/predicates.cpp @@ -16,6 +16,7 @@ #include "zirgen/circuit/recursion/code.h" #include "zirgen/circuit/verify/wrap_zirgen.h" #include "zirgen/compiler/codegen/codegen.h" +#include "zirgen/compiler/stats/OpStats.h" using namespace zirgen; using namespace zirgen::verify; @@ -55,6 +56,7 @@ static cl::opt keccakIR("keccak-ir", int main(int argc, char* argv[]) { llvm::InitLLVM y(argc, argv); registerEdslCLOptions(); + registerOpStatsCLOptions(); llvm::cl::ParseCommandLineOptions(argc, argv, "keccak predicates"); Module module; diff --git a/zirgen/circuit/predicates/gen_predicates.cpp b/zirgen/circuit/predicates/gen_predicates.cpp index a0e99b98..525a7e22 100644 --- a/zirgen/circuit/predicates/gen_predicates.cpp +++ b/zirgen/circuit/predicates/gen_predicates.cpp @@ -19,6 +19,7 @@ #include "zirgen/circuit/verify/wrap_rv32im.h" #include "zirgen/circuit/verify/wrap_zirgen.h" #include "zirgen/compiler/codegen/codegen.h" +#include "zirgen/compiler/stats/OpStats.h" using namespace zirgen; using namespace zirgen::verify; @@ -202,6 +203,7 @@ static cl::opt int main(int argc, char* argv[]) { llvm::InitLLVM y(argc, argv); registerEdslCLOptions(); + registerOpStatsCLOptions(); llvm::cl::ParseCommandLineOptions(argc, argv, "gen_predicates edsl"); Module module; diff --git a/zirgen/circuit/recursion/encode.cpp b/zirgen/circuit/recursion/encode.cpp index dfe1ce31..bbbb40f7 100644 --- a/zirgen/circuit/recursion/encode.cpp +++ b/zirgen/circuit/recursion/encode.cpp @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -136,8 +136,10 @@ struct Instructions { uint64_t padShaCountConst; uint64_t shaRngConsts; llvm::StringMap tagConsts; + EncodeStats* stats = nullptr; - Instructions(HashType hashType) : hashType(hashType), nextOut(1), microUsed(0) { + Instructions(HashType hashType, EncodeStats* stats) + : hashType(hashType), nextOut(1), microUsed(0), stats(stats) { addMacro(/*outs=*/0, MacroOpcode::WOM_INIT); // Make: [0, 1, 0, 0], [0, 0, 1, 0], and [0, 0, 0, 1] fp4Rot1 = addConst(0, 1); @@ -224,6 +226,15 @@ struct Instructions { if (microUsed == 3) { microUsed = 0; } + + if (stats) { + if (out) { + ScopedLocation loc(out.getLoc()); + stats->locs[currentLoc()]++; + } else { + stats->locs[currentLoc()]++; + } + } return outId; } @@ -245,6 +256,8 @@ struct Instructions { uint64_t writeAddr = nextOut; data.back().writeAddr = writeAddr; nextOut += outs; + if (stats) + stats->locs[currentLoc()] += 3; return writeAddr; } @@ -263,6 +276,8 @@ struct Instructions { uint64_t writeAddr = nextOut; data.back().writeAddr = writeAddr; nextOut += 1; // Write exactly one thing (evaluated point) + if (stats) + stats->locs[currentLoc()] += 3; return writeAddr; } @@ -286,6 +301,8 @@ struct Instructions { data.back().data.poseidon2Mem.inputs[i] = inputs[i]; } data.back().writeAddr = nextOut; + if (stats) + stats->locs[currentLoc()] += 3; } void addPoseidon2Full(uint64_t cycle) { @@ -295,6 +312,8 @@ struct Instructions { data.back().opType = OpType::POSEIDON2_FULL; data.back().data.poseidon2Full.cycle = cycle; data.back().writeAddr = nextOut; + if (stats) + stats->locs[currentLoc()] += 3; } void addPoseidon2Partial() { @@ -303,6 +322,8 @@ struct Instructions { data.emplace_back(); data.back().opType = OpType::POSEIDON2_PARTIAL; data.back().writeAddr = nextOut; + if (stats) + stats->locs[currentLoc()] += 3; } uint64_t addPoseidon2Store(uint64_t doMont, uint64_t group) { @@ -315,10 +336,13 @@ struct Instructions { uint64_t writeAddr = nextOut; data.back().writeAddr = writeAddr; nextOut += 8; + if (stats) + stats->locs[currentLoc()] += 3; return writeAddr; } void doShaInit() { + ScopedLocation loc; for (size_t i = 0; i < 4; i++) { shaUsed++; addMacro(/*outs=*/0, MacroOpcode::SHA_INIT, shaInit[3 - i], shaInit[3 - i + 4]); @@ -326,6 +350,7 @@ struct Instructions { } void doShaLoad(llvm::ArrayRef values, uint64_t subtype) { + ScopedLocation loc; for (size_t i = 0; i < 16; i++) { shaUsed++; addMacro(/*outs=*/0, MacroOpcode::SHA_LOAD, values[i], shaK[i], subtype); @@ -333,12 +358,14 @@ struct Instructions { } void doShaMix() { + ScopedLocation loc; for (size_t i = 0; i < 48; i++) { shaUsed++; addMacro(/*outs=*/0, MacroOpcode::SHA_MIX, 0, shaK[16 + i]); } } uint64_t doShaFini() { + ScopedLocation loc; uint64_t out = nextOut; for (size_t i = 0; i < 4; i++) { shaUsed++; @@ -349,6 +376,7 @@ struct Instructions { } uint64_t doSha(llvm::ArrayRef values, uint64_t subtype) { + ScopedLocation loc; doShaInit(); uint64_t ret = 0; for (size_t i = 0; i < values.size() / 16; i++) { @@ -360,6 +388,7 @@ struct Instructions { } uint64_t doShaFold(uint64_t lhs, uint64_t rhs) { + ScopedLocation loc; std::vector ids(16); for (size_t i = 0; i < 8; i++) { ids[i] = lhs + i; @@ -369,6 +398,7 @@ struct Instructions { } uint64_t doIntoDigestShaBytes(llvm::ArrayRef bytes) { + ScopedLocation loc; // We keep things in low / high form right until the end so that the final adds are // all contiguous since all the 'digest' stuff assumes digests are always contiguous. std::vector low; @@ -388,6 +418,7 @@ struct Instructions { } uint64_t doIntoDigestShaWords(llvm::ArrayRef words) { + ScopedLocation loc; // We keep things in low / high form right until the end so that the final adds are // all contiguous since all the 'digest' stuff assumes digests are always contiguous. std::vector low; @@ -404,6 +435,7 @@ struct Instructions { } uint64_t doShaTag(llvm::StringRef tag) { + ScopedLocation loc; if (tagConsts.count(tag)) { return tagConsts.find(tag)->second; } else { @@ -422,6 +454,7 @@ struct Instructions { llvm::ArrayRef digests, llvm::ArrayRef digestTypes, llvm::ArrayRef vals) { + ScopedLocation loc; std::vector words; for (size_t i = 0; i < 8; i++) { @@ -472,6 +505,7 @@ struct Instructions { // representation, with 16 bits in each of the two low components of an // extension field element. void taggedStructPushVals(std::vector& words, llvm::ArrayRef vals) { + ScopedLocation loc; // Get low 16 bits of each value (done in a loop for better packing) std::vector lowVals; for (size_t i = 0; i < vals.size(); i++) { @@ -489,6 +523,7 @@ struct Instructions { } std::pair> doHashCheckedBytes(uint64_t evalPt, uint64_t count) { + ScopedLocation loc; if (!count) { // Special case for 0 outputs return {doPoseidon2({}), {}}; @@ -519,6 +554,7 @@ struct Instructions { std::tuple> doHashCheckedBytesPublic(uint64_t evalPt, uint64_t count) { + ScopedLocation loc; if (!count) throw std::runtime_error("Cannont publically hash empty checked bytes"); @@ -570,6 +606,7 @@ struct Instructions { } uint64_t doPoseidon2(llvm::ArrayRef values) { + ScopedLocation loc; if (values.empty()) { auto psuite = poseidon2HashSuite(); auto hashVal = psuite->hash(nullptr, 0); @@ -634,6 +671,7 @@ struct Instructions { } uint64_t doIntoDigestPoseidon2(llvm::ArrayRef words) { + ScopedLocation loc; // Do pointless adds to make all the words land in sequential spots uint64_t ret = nextOut; size_t pad = words.size() / 8 - 1; @@ -646,6 +684,7 @@ struct Instructions { } void addInst(Operation& op) { + ScopedLocation loc(op.getLoc()); TypeSwitch(&op) .Case([&](Zll::ExternOp op) { if (op.getName() == "write") { @@ -984,6 +1023,7 @@ uint64_t ShaRng::generateFp(Instructions& insts) { } void ShaRng::mix(Instructions& insts, uint64_t digest) { + ScopedLocation loc; uint64_t xorOut = insts.nextOut; for (size_t i = 0; i < 8; i++) { // Xors and returns 2 shorts: [a, b, 0, 0] ^ [c, d, 0, 0] -> [a ^ c, b ^ d, 0, 0] @@ -1022,6 +1062,7 @@ uint64_t Poseidon2Rng::generateFp(Instructions& insts) { } void Poseidon2Rng::mix(Instructions& insts, uint64_t digest) { + ScopedLocation loc; if (stateUsed != 0) { stateUsed = 0; mix(insts, 0); @@ -1069,6 +1110,7 @@ void Poseidon2Rng::mix(Instructions& insts, uint64_t digest) { } void MixedPoseidon2ShaRng::mix(Instructions& insts, uint64_t digest) { + ScopedLocation loc; // For each element of the Poseidon2 hash, we convert it to a form usable by SHA. // This is done in stages so that macro ops are grouped together for efficiency // First, we 'and' things by 0xffff @@ -1104,7 +1146,7 @@ std::vector encode(HashType hashType, mlir::Block* block, llvm::DenseMap* toIdReturn, EncodeStats* stats) { - Instructions insts(hashType); + Instructions insts(hashType, stats); for (Operation& op : block->without_terminator()) { insts.addInst(op); } diff --git a/zirgen/circuit/recursion/encode.h b/zirgen/circuit/recursion/encode.h index 1f5bb217..7441b0a3 100644 --- a/zirgen/circuit/recursion/encode.h +++ b/zirgen/circuit/recursion/encode.h @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -31,6 +31,9 @@ struct EncodeStats { size_t totCycles = 0; size_t shaCycles = 0; size_t poseidon2Cycles = 0; + + // Locations and the number of micro cycles used (= number of macro cycles * 3). + llvm::DenseMap locs; }; std::vector encode(HashType hashType, diff --git a/zirgen/circuit/recursion/recursion.cpp b/zirgen/circuit/recursion/recursion.cpp index 273197cc..600d943e 100644 --- a/zirgen/circuit/recursion/recursion.cpp +++ b/zirgen/circuit/recursion/recursion.cpp @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,7 +20,6 @@ using namespace zirgen; using namespace zirgen::recursion; -using namespace risc0; using namespace mlir; int main(int argc, char* argv[]) { diff --git a/zirgen/circuit/rv32im/v1/edsl/rv32im.cpp b/zirgen/circuit/rv32im/v1/edsl/rv32im.cpp index f4e3a119..faa7a0df 100644 --- a/zirgen/circuit/rv32im/v1/edsl/rv32im.cpp +++ b/zirgen/circuit/rv32im/v1/edsl/rv32im.cpp @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -23,7 +23,6 @@ using namespace zirgen; using namespace zirgen::rv32im_v1; -using namespace risc0; using namespace mlir; int main(int argc, char* argv[]) { diff --git a/zirgen/circuit/verify/fri.cpp b/zirgen/circuit/verify/fri.cpp index ffec017c..60988c80 100644 --- a/zirgen/circuit/verify/fri.cpp +++ b/zirgen/circuit/verify/fri.cpp @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -73,6 +73,8 @@ Val fold_eval(const std::vector& values, Val x) { } Val dynamic_pow(Val in, Val pow, size_t maxPow) { + ScopedLocation loc; + Val out = 1; Val mul = in; for (size_t i = 0; i < log2Ceil(maxPow); i++) { @@ -98,6 +100,8 @@ struct VerifyRoundInfo { , mix(iop.rngExtVal()) {} void verifyQuery(ReadIopVal& iop, Val* pos, Val* goal) const { + ScopedLocation loc; + // Compute which group we are in Val group = *pos & (domain - 1); Val qout = (*pos - group) / domain; @@ -117,6 +121,8 @@ struct VerifyRoundInfo { // Verify a FRI proof, void friVerify(ReadIopVal& iop, size_t deg, InnerVerify inner) { + ScopedLocation loc; + size_t domain = deg * kInvRate; size_t origDomain = domain; std::vector rounds; diff --git a/zirgen/circuit/verify/merkle.cpp b/zirgen/circuit/verify/merkle.cpp index 459f2361..eaa91af4 100644 --- a/zirgen/circuit/verify/merkle.cpp +++ b/zirgen/circuit/verify/merkle.cpp @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -51,6 +51,8 @@ MerkleTreeVerifier::MerkleTreeVerifier(std::string bufName, size_t queries, bool useExtension) : MerkleTreeParams(rowSize, colSize, queries, useExtension), top(topSize), bufName(bufName) { + ScopedLocation loc; + auto topRec = iop.readDigests(topSize); top.insert(top.end(), topRec.begin(), topRec.end()); for (size_t i = topSize; i-- > 1;) { @@ -64,6 +66,8 @@ DigestVal MerkleTreeVerifier::getRoot() const { } std::vector MerkleTreeVerifier::verify(ReadIopVal& iop, Val idx) const { + ScopedLocation loc; + std::vector out; if (useExtension) { out = iop.readExtVals(colSize); diff --git a/zirgen/circuit/verify/poly.cpp b/zirgen/circuit/verify/poly.cpp index d8adff7a..7919d3cd 100644 --- a/zirgen/circuit/verify/poly.cpp +++ b/zirgen/circuit/verify/poly.cpp @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,6 +17,8 @@ namespace zirgen::verify { Val poly_eval(const std::vector& coeffs, Val x) { + ScopedLocation loc; + Val tot = 0; Val mul = 1; for (size_t i = 0; i < coeffs.size(); i++) { diff --git a/zirgen/circuit/verify/verify.cpp b/zirgen/circuit/verify/verify.cpp index 6a4e9498..c095e6b0 100644 --- a/zirgen/circuit/verify/verify.cpp +++ b/zirgen/circuit/verify/verify.cpp @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -38,6 +38,8 @@ template T dbg(std::string fmt, T arg) { } // namespace VerifyInfo verify(ReadIopVal& iop, size_t po2, const CircuitInterface& circuit) { + ScopedLocation loc; + VerifyInfo verifyInfo; // At the start of verification, add the version strings to the Fiat-Shamir transcript. @@ -177,11 +179,22 @@ VerifyInfo verify(ReadIopVal& iop, size_t po2, const CircuitInterface& circuit) } // Finally, do a FRI verification friVerify(iop, size, [&](ReadIopVal& iop, Val idx) { + ScopedLocation loc; + auto x = dynamic_pow(kRouFwd[log2Ceil(domain)], idx, domain); std::map> rows; - rows[/*REGISTER_GROUP_ACCUM*/ 0] = accumMerkle.verify(iop, idx); - rows[/*REGISTER_GROUP_CODE=*/1] = codeMerkle.verify(iop, idx); - rows[/*REGISTER_GROUP_DATA=*/2] = dataMerkle.verify(iop, idx); + { + ScopedLocation loc; + rows[/*REGISTER_GROUP_ACCUM*/ 0] = accumMerkle.verify(iop, idx); + } + { + ScopedLocation loc; + rows[/*REGISTER_GROUP_CODE=*/1] = codeMerkle.verify(iop, idx); + } + { + ScopedLocation loc; + rows[/*REGISTER_GROUP_DATA=*/2] = dataMerkle.verify(iop, idx); + } auto checkRow = checkMerkle.verify(iop, idx); Val curMix = 1; std::vector tot(comboU.size(), 0); @@ -219,6 +232,8 @@ VerifyInfo verifyRecursion(ReadIopVal& allowedRoot, std::vector seals, std::vector alloweds, const CircuitInterface& circuit) { + ScopedLocation loc; + VerifyInfo verifyInfo; verifyInfo.codeRoot = allowedRoot.readDigests(1)[0]; diff --git a/zirgen/compiler/codegen/gen_cpp.cpp b/zirgen/compiler/codegen/gen_cpp.cpp index 92720403..fe87ddad 100644 --- a/zirgen/compiler/codegen/gen_cpp.cpp +++ b/zirgen/compiler/codegen/gen_cpp.cpp @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -37,7 +37,6 @@ struct CppStreamEmitterImpl : CppStreamEmitter { void header(func::FuncOp func) { ofs << "// This code is automatically generated\n\n"; ofs << "#include \"impl.h\"\n\n"; - ofs << "using namespace risc0;\n\n"; ofs << "namespace circuit::" << func.getName() << " {\n\n"; } diff --git a/zirgen/compiler/codegen/gen_recursion.cpp b/zirgen/compiler/codegen/gen_recursion.cpp index 663c6c35..2b74e04d 100644 --- a/zirgen/compiler/codegen/gen_recursion.cpp +++ b/zirgen/compiler/codegen/gen_recursion.cpp @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ #include "zirgen/Dialect/Zll/IR/IR.h" #include "zirgen/circuit/recursion/encode.h" +#include "zirgen/compiler/stats/OpStats.h" using namespace mlir; @@ -29,16 +30,6 @@ namespace zirgen { namespace { -class EmitRecursionPass : public impl::EmitRecursionBase { -public: - EmitRecursionPass() = default; - EmitRecursionPass(StringRef dir) { this->outputDir = dir.str(); } - void runOnOperation() override { - recursion::EncodeStats stats; - emitRecursion(outputDir, getOperation(), &stats); - } -}; - std::unique_ptr openOutputFile(const std::string& path, const std::string& name) { std::string filename = path + "/" + name; @@ -50,6 +41,23 @@ std::unique_ptr openOutputFile(const std::string& path, return ofs; } +class EmitRecursionPass : public impl::EmitRecursionBase { +public: + EmitRecursionPass() = default; + EmitRecursionPass(StringRef dir) { this->outputDir = dir.str(); } + void runOnOperation() override { + BogoCycleAnalysis bogoCycles; + + if (bogoCycles.isEnabled()) { + auto ofs = openOutputFile(outputDir, getOperation().getName().str() + "-opstats.txt"); + bogoCycles.printStatsIfRequired(getOperation(), *ofs); + } + + recursion::EncodeStats stats; + emitRecursion(outputDir, getOperation(), &stats); + } +}; + } // namespace void emitRecursion(const std::string& path, func::FuncOp func, recursion::EncodeStats* stats) { @@ -69,6 +77,12 @@ void emitRecursion(const std::string& path, func::FuncOp func, recursion::Encode for (const auto& elem : locs) { *debugOfs << elem.first << " <- " << elem.second << "\n"; } + + BogoCycleAnalysis bogoCycles; + if (bogoCycles.isEnabled()) { + auto ofs = openOutputFile(path, func.getName().str() + "-encoded-cycles.txt"); + bogoCycles.printStatsFromMap(func.getContext(), stats->locs, *ofs); + } } std::unique_ptr> diff --git a/zirgen/compiler/edsl/BUILD.bazel b/zirgen/compiler/edsl/BUILD.bazel index 3286d303..93c39654 100644 --- a/zirgen/compiler/edsl/BUILD.bazel +++ b/zirgen/compiler/edsl/BUILD.bazel @@ -12,7 +12,6 @@ cc_library( hdrs = [ "component.h", "edsl.h", - "source_loc.h", ], deps = [ "//zirgen/Dialect/ZStruct/IR", diff --git a/zirgen/compiler/edsl/component.cpp b/zirgen/compiler/edsl/component.cpp index 48ef88ad..a98aaaac 100644 --- a/zirgen/compiler/edsl/component.cpp +++ b/zirgen/compiler/edsl/component.cpp @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -214,7 +214,9 @@ std::string demangle(std::string ident) { } // namespace -void CompContext::pushConstruct(llvm::StringRef ident, llvm::StringRef mangledTy, SourceLoc loc) { +void CompContext::pushConstruct(llvm::StringRef ident, + llvm::StringRef mangledTy, + mlir::Location loc) { std::string demangledTy = demangle(mangledTy.str()); // Generate something a bit more readable out of excessive things like this: diff --git a/zirgen/compiler/edsl/component.h b/zirgen/compiler/edsl/component.h index 111a52f4..aaecf8f7 100644 --- a/zirgen/compiler/edsl/component.h +++ b/zirgen/compiler/edsl/component.h @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -43,7 +43,7 @@ struct ConstructInfo { std::map labels; std::string typeName; std::map> subcomponents; - SourceLoc loc; + mlir::LocationAttr loc; }; // A context singleton used during component constructon @@ -61,7 +61,8 @@ class CompContext { static void leaveMux(); // Debug info - static void pushConstruct(llvm::StringRef ident, llvm::StringRef ty, SourceLoc loc = current()); + static void + pushConstruct(llvm::StringRef ident, llvm::StringRef ty, mlir::Location loc = currentLoc()); static void popConstruct(); static std::shared_ptr getCurConstruct(); static void saveLabel(Buffer buf, llvm::StringRef label); @@ -91,12 +92,13 @@ class CompContext { // A label for a component in the layout. class Label { public: - Label(SourceLoc loc = current()) : loc(loc) {} + Label(mlir::Location loc = currentLoc()) : loc(loc) {} - /* implicit */ Label(llvm::StringRef label, SourceLoc loc = current()) : label(label), loc(loc) {} + /* implicit */ Label(llvm::StringRef label, mlir::Location loc = currentLoc()) + : label(label), loc(loc) {} // Numbered instance of something - Label(llvm::StringRef label, size_t index, SourceLoc loc = current()) + Label(llvm::StringRef label, size_t index, mlir::Location loc = currentLoc()) : label((label + "[" + std::to_string(index) + "]").str()), loc(loc) {} // Convert to a singular label @@ -109,7 +111,7 @@ class Label { return genArray(seq); } - SourceLoc getLoc() { return loc; } + mlir::Location getLoc() { return loc; } private: template @@ -118,7 +120,7 @@ class Label { } std::string label; - SourceLoc loc; + mlir::Location loc; }; inline std::vector Labels(std::initializer_list labels) { diff --git a/zirgen/compiler/edsl/edsl.cpp b/zirgen/compiler/edsl/edsl.cpp index 7469b367..b0488618 100644 --- a/zirgen/compiler/edsl/edsl.cpp +++ b/zirgen/compiler/edsl/edsl.cpp @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ #include "zirgen/compiler/edsl/edsl.h" #include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/IR/Location.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" @@ -36,6 +37,7 @@ bool gBackUnchecked = false; bool gBackUsed = false; Module* curModule = nullptr; +MLIRContext* mlirContext = nullptr; OpBuilder& getBuilder() { assert(curModule); @@ -43,69 +45,42 @@ OpBuilder& getBuilder() { } MLIRContext* getCtx() { - assert(curModule); - return curModule->getCtx(); -} - -Location toLoc(SourceLoc loc, StringRef ident = {}) { - LocationAttr inner; - if (loc.filename) { - auto id = StringAttr::get(getCtx(), loc.filename); - inner = FileLineColLoc::get(id, loc.line, loc.column); - } else { - inner = UnknownLoc::get(getCtx()); - } - - if (ident.empty()) { - return inner; - } else { - return NameLoc::get(StringAttr::get(getCtx(), ident), inner); - } + if (curModule) + assert(mlirContext == curModule->getCtx()); + return mlirContext; } -std::vector& getLocStack() { - static std::vector stack; - return stack; +Location nameLoc(Location parent, StringRef ident) { + return NameLoc::get(getBuilder().getStringAttr(ident), parent); } } // namespace -Module* Module::getCurModule() { - return curModule; -} - -OverrideLocation::OverrideLocation(SourceLoc loc) { - getLocStack().push_back(loc); -} - -OverrideLocation::~OverrideLocation() { - getLocStack().pop_back(); +void registerEdslContext(mlir::MLIRContext* ctx) { + mlirContext = ctx; } -SourceLoc checkCurrentLoc(SourceLoc loc) { - if (getLocStack().empty()) { - return loc; - } - return getLocStack().back(); +Module* Module::getCurModule() { + return curModule; } -Val::Val(uint64_t val, SourceLoc loc) { +Val::Val(uint64_t val, Location loc) { Type ty = ValType::getBaseType(getBuilder().getContext()); - value = getBuilder().create(toLoc(loc), ty, val); + value = getBuilder().create(loc, ty, val); } -Val::Val(llvm::ArrayRef val, SourceLoc loc) { +Val::Val(llvm::ArrayRef val, Location loc) { assert(val.size() == kBabyBearExtSize); Type ty = ValType::get(getBuilder().getContext(), kFieldPrimeDefault, val.size()); - value = getBuilder().create(toLoc(loc), ty, val); + value = getBuilder().create(loc, ty, val); } -Val::Val(Register reg, SourceLoc loc) { +Val::Val(Register reg, Location loc) { if (cast(reg.buf.getType()).getKind() == BufferKind::Global) { - value = getBuilder().create(toLoc(loc, reg.ident), reg.buf, 0); + value = getBuilder().create(nameLoc(loc, reg.ident), reg.buf, 0); } else { auto getOp = - getBuilder().create(toLoc(loc, reg.ident), reg.buf, 0, gBackDist, IntegerAttr()); + getBuilder().create(nameLoc(loc, reg.ident), reg.buf, 0, gBackDist, IntegerAttr()); if (gBackUnchecked) { getOp->setAttr("unchecked", UnitAttr::get(getOp.getContext())); } @@ -116,56 +91,57 @@ Val::Val(Register reg, SourceLoc loc) { void Register::operator=(CaptureVal x) { if (cast(buf.getType()).getKind() == BufferKind::Global) { - getBuilder().create(toLoc(x.loc, x.ident), buf, 0, x.getValue()); + getBuilder().create(nameLoc(x.loc, x.ident), buf, 0, x.getValue()); } else { - getBuilder().create(toLoc(x.loc, x.ident), buf, 0, x.getValue()); + getBuilder().create(nameLoc(x.loc, x.ident), buf, 0, x.getValue()); } } mlir::Location CaptureIdx::getLoc() { - return toLoc(loc); + return loc; } -Val Buffer::get(size_t idx, StringRef ident, SourceLoc loc) { +Val Buffer::get(size_t idx, StringRef ident, Location loc) { if (cast(buf.getType()).getKind() == BufferKind::Global) { - return Val(getBuilder().create(toLoc(loc, ident), buf, 0)); + return Val(getBuilder().create(nameLoc(loc, ident), buf, 0)); } else { - return Val(getBuilder().create(toLoc(loc, ident), buf, idx, 0, IntegerAttr())); + return Val(getBuilder().create(nameLoc(loc, ident), buf, idx, 0, IntegerAttr())); } } -void Buffer::set(size_t idx, Val x, StringRef ident, SourceLoc loc) { +void Buffer::set(size_t idx, Val x, StringRef ident, Location loc) { if (cast(buf.getType()).getKind() == BufferKind::Global) { - getBuilder().create(toLoc(loc, ident), buf, idx, x.getValue()); + getBuilder().create(nameLoc(loc, ident), buf, idx, x.getValue()); } else { - getBuilder().create(toLoc(loc, ident), buf, idx, x.getValue()); + getBuilder().create(nameLoc(loc, ident), buf, idx, x.getValue()); } } -void Buffer::setDigest(size_t idx, DigestVal x, StringRef ident, SourceLoc loc) { +void Buffer::setDigest(size_t idx, DigestVal x, StringRef ident, Location loc) { if (cast(buf.getType()).getKind() != BufferKind::Global) { throw(std::runtime_error("Currently digests can only be stored in globals")); } - getBuilder().create(toLoc(loc, ident), buf, idx, x.getValue()); + getBuilder().create(nameLoc(loc, ident), buf, idx, x.getValue()); } -Buffer Buffer::slice(size_t offset, size_t size, SourceLoc loc) { - return Buffer(getBuilder().create(toLoc(loc), buf, offset, size)); +Buffer Buffer::slice(size_t offset, size_t size, Location loc) { + return Buffer(getBuilder().create(loc, buf, offset, size)); } -Register Buffer::getRegister(size_t idx, StringRef ident, SourceLoc loc) { +Register Buffer::getRegister(size_t idx, StringRef ident, Location loc) { if (idx >= cast(buf.getType()).getSize()) { - llvm::errs() << "Out of bounds index: " << loc.filename << ":" << loc.line << "\n"; + llvm::errs() << "Out of bounds index: " << loc << "\n"; throw std::runtime_error("OOB Index"); } - return Register(getBuilder().create(toLoc(loc), buf, idx, 1), ident); + return Register(getBuilder().create(loc, buf, idx, 1), ident); } mlir::Location CaptureVal::getLoc() { - return toLoc(loc); + return loc; } Module::Module() : builder(&ctx) { + registerEdslContext(&ctx); ctx.getOrLoadDialect(); ctx.getOrLoadDialect(); ctx.getOrLoadDialect(); @@ -335,7 +311,7 @@ void Module::dumpStage(size_t stage, bool debug) { void Module::beginFunc(const std::string& name, const std::vector& args, - SourceLoc loc) { + Location loc) { curModule = this; std::vector inTypes; for (auto ai : args) { @@ -347,7 +323,7 @@ void Module::beginFunc(const std::string& name, } } auto funcType = FunctionType::get(&ctx, inTypes, {}); - auto func = builder.create(toLoc(loc), name, funcType); + auto func = builder.create(loc, name, funcType); pushIP(func.addEntryBlock()); } @@ -400,8 +376,8 @@ void Module::setProtocolInfo(ProtocolInfo info) { setModuleAttr(getModule(), getBuilder().getAttr(info)); } -mlir::func::FuncOp Module::endFunc(SourceLoc loc) { - auto returnOp = builder.create(toLoc(loc)); +mlir::func::FuncOp Module::endFunc(Location loc) { + auto returnOp = builder.create(loc); popIP(); curModule = nullptr; return returnOp->getParentOfType(); @@ -538,7 +514,7 @@ std::vector doExtern(const std::string& name, const std::string& extra, size_t outSize, llvm::ArrayRef in, - SourceLoc loc) { + Location loc) { MLIRContext* ctx = getBuilder().getContext(); std::vector outTypes; for (size_t i = 0; i < outSize; i++) { @@ -548,7 +524,7 @@ std::vector doExtern(const std::string& name, for (auto& val : in) { inValues.push_back(val.getValue()); } - auto op = getBuilder().create(toLoc(loc), outTypes, inValues, name, extra); + auto op = getBuilder().create(loc, outTypes, inValues, name, extra); std::vector outs; for (size_t i = 0; i < outSize; i++) { outs.emplace_back(op.getResult(i)); @@ -577,9 +553,9 @@ void emitLayoutInternal(std::shared_ptr info) { } } -NondetGuard::NondetGuard(SourceLoc loc) { +NondetGuard::NondetGuard(Location loc) { assert(curModule); - auto nondetOp = getBuilder().create(toLoc(loc)); + auto nondetOp = getBuilder().create(loc); Block* innerBlock = new Block(); nondetOp.getInner().push_back(innerBlock); curModule->pushIP(innerBlock); @@ -587,21 +563,21 @@ NondetGuard::NondetGuard(SourceLoc loc) { NondetGuard::~NondetGuard() { assert(curModule); - getBuilder().create(toLoc(SourceLoc())); + getBuilder().create(currentLoc()); curModule->popIP(); } -IfGuard::IfGuard(Val cond, SourceLoc loc) { +IfGuard::IfGuard(Val cond, Location loc) { assert(curModule); Block* innerBlock = new Block(); - auto ifOp = getBuilder().create(toLoc(loc), cond.getValue()); + auto ifOp = getBuilder().create(loc, cond.getValue()); ifOp.getInner().push_back(innerBlock); curModule->pushIP(innerBlock); } IfGuard::~IfGuard() { assert(curModule); - getBuilder().create(toLoc(SourceLoc())); + getBuilder().create(currentLoc()); curModule->popIP(); } @@ -617,31 +593,31 @@ void endBack() { gBackUnchecked = false; } -DigestVal hash(llvm::ArrayRef inputs, bool flip, SourceLoc loc) { +DigestVal hash(llvm::ArrayRef inputs, bool flip, Location loc) { std::vector vals; for (const auto& in : inputs) { vals.push_back(in.getValue()); } Type digestType = DigestType::get(getBuilder().getContext(), DigestKind::Default); - Value out = getBuilder().create(toLoc(loc), digestType, flip, vals); + Value out = getBuilder().create(loc, digestType, flip, vals); return DigestVal(out); } -DigestVal intoDigest(llvm::ArrayRef inputs, DigestKind kind, SourceLoc loc) { +DigestVal intoDigest(llvm::ArrayRef inputs, DigestKind kind, Location loc) { std::vector vals; for (const auto& in : inputs) { vals.push_back(in.getValue()); } auto digestType = DigestType::get(getBuilder().getContext(), kind); - Value out = getBuilder().create(toLoc(loc), digestType, vals); + Value out = getBuilder().create(loc, digestType, vals); return DigestVal(out); } -std::vector fromDigest(DigestVal digest, size_t size, SourceLoc loc) { +std::vector fromDigest(DigestVal digest, size_t size, Location loc) { auto& builder = getBuilder(); Type valType = ValType::getBaseType(builder.getContext()); std::vector types(size, valType); - auto fromOp = builder.create(toLoc(loc), types, digest.getValue()); + auto fromOp = builder.create(loc, types, digest.getValue()); std::vector vals; for (Value out : fromOp.getOut()) { vals.push_back(Val(out)); @@ -649,15 +625,15 @@ std::vector fromDigest(DigestVal digest, size_t size, SourceLoc loc) { return vals; } -DigestVal fold(DigestVal lhs, DigestVal rhs, SourceLoc loc) { - Value out = getBuilder().create(toLoc(loc), lhs.getValue(), rhs.getValue()); +DigestVal fold(DigestVal lhs, DigestVal rhs, Location loc) { + Value out = getBuilder().create(loc, lhs.getValue(), rhs.getValue()); return DigestVal(out); } DigestVal taggedStruct(llvm::StringRef tag, llvm::ArrayRef digests, llvm::ArrayRef vals, - SourceLoc loc) { + Location loc) { std::vector digestVals; for (const auto& in : digests) { digestVals.push_back(in.getValue()); @@ -667,23 +643,23 @@ DigestVal taggedStruct(llvm::StringRef tag, valsVals.push_back(in.getValue()); } - Value out = getBuilder().create(toLoc(loc), tag, digestVals, valsVals); + Value out = getBuilder().create(loc, tag, digestVals, valsVals); return DigestVal(out); } -DigestVal taggedListCons(llvm::StringRef tag, DigestVal head, DigestVal tail, SourceLoc loc) { +DigestVal taggedListCons(llvm::StringRef tag, DigestVal head, DigestVal tail, Location loc) { return taggedStruct(tag, {head, tail}, {}, loc); } -void assert_eq(DigestVal lhs, DigestVal rhs, SourceLoc loc) { - getBuilder().create(toLoc(loc), lhs.getValue(), rhs.getValue()); +void assert_eq(DigestVal lhs, DigestVal rhs, Location loc) { + getBuilder().create(loc, lhs.getValue(), rhs.getValue()); } -std::vector ReadIopVal::readBaseVals(size_t count, bool flip, SourceLoc sloc) { +std::vector ReadIopVal::readBaseVals(size_t count, bool flip, Location sloc) { auto& builder = getBuilder(); Type valType = ValType::getBaseType(builder.getContext()); std::vector types(count, valType); - auto readOp = builder.create(toLoc(sloc), types, getValue(), flip); + auto readOp = builder.create(sloc, types, getValue(), flip); std::vector out; for (size_t i = 0; i < count; i++) { out.emplace_back(readOp.getOuts()[i]); @@ -691,11 +667,11 @@ std::vector ReadIopVal::readBaseVals(size_t count, bool flip, SourceLoc slo return out; } -std::vector ReadIopVal::readExtVals(size_t count, bool flip, SourceLoc sloc) { +std::vector ReadIopVal::readExtVals(size_t count, bool flip, Location sloc) { auto& builder = getBuilder(); Type valType = ValType::getExtensionType(builder.getContext()); std::vector types(count, valType); - auto readOp = builder.create(toLoc(sloc), types, getValue(), flip); + auto readOp = builder.create(sloc, types, getValue(), flip); std::vector out; for (size_t i = 0; i < count; i++) { out.emplace_back(readOp.getOuts()[i]); @@ -703,11 +679,11 @@ std::vector ReadIopVal::readExtVals(size_t count, bool flip, SourceLoc sloc return out; } -std::vector ReadIopVal::readDigests(size_t count, SourceLoc sloc) { +std::vector ReadIopVal::readDigests(size_t count, Location sloc) { auto& builder = getBuilder(); Type digestType = DigestType::get(builder.getContext(), DigestKind::Default); std::vector types(count, digestType); - auto readOp = builder.create(toLoc(sloc), types, getValue(), false); + auto readOp = builder.create(sloc, types, getValue(), false); std::vector out; for (size_t i = 0; i < count; i++) { out.emplace_back(readOp.getOuts()[i]); @@ -715,51 +691,51 @@ std::vector ReadIopVal::readDigests(size_t count, SourceLoc sloc) { return out; } -void ReadIopVal::commit(DigestVal digest, SourceLoc loc) { - getBuilder().create(toLoc(loc), getValue(), digest.getValue()); +void ReadIopVal::commit(DigestVal digest, Location loc) { + getBuilder().create(loc, getValue(), digest.getValue()); } -Val ReadIopVal::rngBits(uint32_t bits, SourceLoc loc) { +Val ReadIopVal::rngBits(uint32_t bits, Location loc) { auto& builder = getBuilder(); Type valType = ValType::getBaseType(builder.getContext()); - return Val(builder.create(toLoc(loc), valType, getValue(), bits)); + return Val(builder.create(loc, valType, getValue(), bits)); } -Val ReadIopVal::rngBaseVal(SourceLoc loc) { +Val ReadIopVal::rngBaseVal(Location loc) { auto& builder = getBuilder(); Type valType = ValType::getBaseType(builder.getContext()); - return Val(builder.create(toLoc(loc), valType, getValue())); + return Val(builder.create(loc, valType, getValue())); } -Val ReadIopVal::rngExtVal(SourceLoc loc) { +Val ReadIopVal::rngExtVal(Location loc) { auto& builder = getBuilder(); Type valType = ValType::getExtensionType(builder.getContext()); - return Val(builder.create(toLoc(loc), valType, getValue())); + return Val(builder.create(loc, valType, getValue())); } -Val select(Val idx, llvm::ArrayRef inputs, SourceLoc loc) { +Val select(Val idx, llvm::ArrayRef inputs, Location loc) { auto& builder = getBuilder(); std::vector vals; for (const auto& in : inputs) { vals.push_back(in.getValue()); } - Value out = builder.create(toLoc(loc), vals[0].getType(), idx.getValue(), vals); + Value out = builder.create(loc, vals[0].getType(), idx.getValue(), vals); return out; } -DigestVal select(Val idx, llvm::ArrayRef inputs, SourceLoc loc) { +DigestVal select(Val idx, llvm::ArrayRef inputs, Location loc) { auto& builder = getBuilder(); std::vector vals; for (const auto& in : inputs) { vals.push_back(in.getValue()); } - Value out = builder.create(toLoc(loc), vals[0].getType(), idx.getValue(), vals); + Value out = builder.create(loc, vals[0].getType(), idx.getValue(), vals); return out; } -Val normalize(Val in, SourceLoc loc) { +Val normalize(Val in, Location loc) { auto& builder = getBuilder(); - Value out = builder.create(toLoc(loc), in.getValue(), 0, ""); + Value out = builder.create(loc, in.getValue(), 0, ""); return out; } @@ -770,4 +746,23 @@ void registerEdslCLOptions() { mlir::registerDefaultTimingManagerCLOptions(); } +static LocationAttr currentLocation; + +Location currentLoc(const char* filename, int line, int column) { + Location loc = mlir::FileLineColRange::get(getCtx(), filename, line, column); + + if (currentLocation) + return mlir::CallSiteLoc::get(/*callee=*/loc, /*caller=*/currentLocation); + else + return loc; +} + +ScopedLocation::ScopedLocation(Location loc) : prevLoc(currentLocation) { + currentLocation = loc; +} + +ScopedLocation::~ScopedLocation() { + currentLocation = prevLoc; +} + } // namespace zirgen diff --git a/zirgen/compiler/edsl/edsl.h b/zirgen/compiler/edsl/edsl.h index d24d1951..6e5a01f0 100644 --- a/zirgen/compiler/edsl/edsl.h +++ b/zirgen/compiler/edsl/edsl.h @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -22,23 +22,59 @@ #include "zirgen/Dialect/Zll/IR/IR.h" #include "zirgen/Dialect/Zll/IR/Interpreter.h" #include "zirgen/compiler/codegen/protocol_info_const.h" -#include "zirgen/compiler/edsl/source_loc.h" namespace zirgen { -using risc0::SourceLoc; - -SourceLoc checkCurrentLoc(SourceLoc loc); +/// TODO: Move to std::source_location when we're ok to require c++20 everywhere. + +#ifdef __has_builtin +#if __has_builtin(__builtin_FILE) +#define FILE_EXPR __builtin_FILE() +#else +#define FILE_EXPR __FILE__ +#endif +#else +#define FILE_EXPR __FILE__ +#endif + +#ifdef __has_builtin +#if __has_builtin(__builtin_LINE) +#define LINE_EXPR __builtin_LINE() +#else +#define LINE_EXPR __LINE__ +#endif +#else +#define LINE_EXPR __LINE__ +#endif + +#ifdef __has_builtin +#if __has_builtin(__builtin_COLUMN) +#define COLUMN_EXPR __builtin_COLUMN() +#else +#define COLUMN_EXPR 0 +#endif +#else +#define COLUMN_EXPR 0 +#endif + +/// Register an EDSL context; this must be called before doing anything like `current`. +void registerEdslContext(mlir::MLIRContext* ctx); + +/// Get the "current" source location. When used in default values, this effectively captures the +/// call site of the function declaring the default value, which is very useful. +mlir::Location +currentLoc(const char* filename = FILE_EXPR, int line = LINE_EXPR, int column = COLUMN_EXPR); + +/// Generate CallSiteLocs while a location is in scope. +struct ScopedLocation { +public: + ScopedLocation(mlir::Location loc = zirgen::currentLoc()); + ~ScopedLocation(); -struct OverrideLocation { - OverrideLocation(SourceLoc loc); - ~OverrideLocation(); +private: + mlir::LocationAttr prevLoc; }; -inline SourceLoc current(SourceLoc loc = SourceLoc::current()) { - return checkCurrentLoc(loc); -} - class Val; class DigestVal; class Register; @@ -53,9 +89,9 @@ class Val { public: Val() = default; Val(mlir::Value value) : value(value) {} - Val(uint64_t val, SourceLoc loc = current()); - Val(llvm::ArrayRef coeffs, SourceLoc loc = current()); - Val(Register reg, SourceLoc loc = current()); + Val(uint64_t val, mlir::Location loc = currentLoc()); + Val(llvm::ArrayRef coeffs, mlir::Location loc = currentLoc()); + Val(Register reg, mlir::Location loc = currentLoc()); mlir::Value getValue() const { return value; } @@ -79,10 +115,10 @@ class Register { }; struct CaptureIdx { - CaptureIdx(size_t idx, SourceLoc loc = current()) : idx(idx), loc(loc) {} + CaptureIdx(size_t idx, mlir::Location loc = currentLoc()) : idx(idx), loc(loc) {} size_t idx; - SourceLoc loc; + mlir::Location loc; mlir::Location getLoc(); }; @@ -93,13 +129,13 @@ class Buffer { public: Buffer(mlir::Value buf) : buf(buf) {} size_t size() { return mlir::cast(buf.getType()).getSize(); } - Val get(size_t idx, llvm::StringRef ident, SourceLoc loc = current()); - void set(size_t idx, Val x, llvm::StringRef ident, SourceLoc loc = current()); - void setDigest(size_t idx, DigestVal x, llvm::StringRef ident, SourceLoc loc = current()); - Buffer slice(size_t offset, size_t size, SourceLoc loc = current()); - Register getRegister(size_t idx, llvm::StringRef ident = {}, SourceLoc loc = current()); + Val get(size_t idx, llvm::StringRef ident, mlir::Location loc = currentLoc()); + void set(size_t idx, Val x, llvm::StringRef ident, mlir::Location loc = currentLoc()); + void setDigest(size_t idx, DigestVal x, llvm::StringRef ident, mlir::Location loc = currentLoc()); + Buffer slice(size_t offset, size_t size, mlir::Location loc = currentLoc()); + Register getRegister(size_t idx, llvm::StringRef ident = {}, mlir::Location loc = currentLoc()); Register operator[](CaptureIdx idx) { return getRegister(idx.idx, {}, idx.loc); } - void labelLayout(llvm::ArrayRef labels, SourceLoc loc = current()) const; + void labelLayout(llvm::ArrayRef labels, mlir::Location loc = currentLoc()) const; mlir::Value getBuf() { return buf; } private: @@ -118,18 +154,19 @@ struct CaptureVal { template static int try_get(U obj, long) { return 0; } public: - CaptureVal(uint64_t val, SourceLoc loc = current()) : val(val, loc), loc(loc) {} - CaptureVal(Val val, SourceLoc loc = current()) : val(val), loc(loc) {} - CaptureVal(Register val, SourceLoc loc = current()) : val(val, loc), loc(loc), ident(val.ident) {} + CaptureVal(uint64_t val, mlir::Location loc = currentLoc()) : val(val, loc), loc(loc) {} + CaptureVal(Val val, mlir::Location loc = currentLoc()) : val(val), loc(loc) {} + CaptureVal(Register val, mlir::Location loc = currentLoc()) + : val(val, loc), loc(loc), ident(val.ident) {} template (nullptr), 0))>::value, int>::type = 0> - CaptureVal(Comp comp, SourceLoc loc = current()) : val(comp->get()), loc(loc) {} + CaptureVal(Comp comp, mlir::Location loc = currentLoc()) : val(comp->get()), loc(loc) {} Val val; - SourceLoc loc; + mlir::Location loc; mlir::Value getValue() { return val.getValue(); } mlir::Location getLoc(); std::string ident; @@ -183,7 +220,7 @@ class Module { inline mlir::func::FuncOp addFunc(const std::string& name, std::array args, F func, - SourceLoc loc = current()) { + mlir::Location loc = currentLoc()) { beginFunc(name, std::vector(args.begin(), args.end()), loc); std::array vargs; for (size_t i = 0; i < N; i++) { @@ -229,8 +266,9 @@ class Module { void setProtocolInfo(ProtocolInfo info); private: - void beginFunc(const std::string& name, const std::vector& args, SourceLoc loc); - mlir::func::FuncOp endFunc(SourceLoc loc); + void + beginFunc(const std::string& name, const std::vector& args, mlir::Location loc); + mlir::func::FuncOp endFunc(mlir::Location loc); void pushIP(mlir::Block* block); void popIP(); void runFunc(mlir::func::FuncOp func, @@ -272,18 +310,18 @@ std::vector doExtern(const std::string& name, const std::string& extra, size_t outSize, llvm::ArrayRef in, - SourceLoc loc = current()); + mlir::Location loc = currentLoc()); class NondetGuard { public: - NondetGuard(SourceLoc loc = current()); + NondetGuard(mlir::Location loc = currentLoc()); ~NondetGuard(); operator bool() { return true; } }; class IfGuard { public: - IfGuard(Val cond, SourceLoc loc = current()); + IfGuard(Val cond, mlir::Location loc = currentLoc()); ~IfGuard(); operator bool() { return true; } }; @@ -345,43 +383,45 @@ template <> struct LogPrep { static void toLogVec(std::vector& out, DigestVal x) { out.push_back(Val(x.getValue())); } }; -DigestVal hash(llvm::ArrayRef inputs, bool flip = false, SourceLoc loc = current()); +DigestVal hash(llvm::ArrayRef inputs, bool flip = false, mlir::Location loc = currentLoc()); DigestVal intoDigest(llvm::ArrayRef inputs, Zll::DigestKind kind = Zll::DigestKind::Default, - SourceLoc loc = current()); -std::vector fromDigest(DigestVal digest, size_t size, SourceLoc loc = current()); -DigestVal fold(DigestVal lhs, DigestVal rhs, SourceLoc loc = current()); + mlir::Location loc = currentLoc()); +std::vector fromDigest(DigestVal digest, size_t size, mlir::Location loc = currentLoc()); +DigestVal fold(DigestVal lhs, DigestVal rhs, mlir::Location loc = currentLoc()); DigestVal taggedStruct(llvm::StringRef tag, llvm::ArrayRef digests, llvm::ArrayRef vals, - SourceLoc loc = current()); -DigestVal -taggedListCons(llvm::StringRef tag, DigestVal head, DigestVal tail, SourceLoc loc = current()); -void assert_eq(DigestVal lhs, DigestVal rhs, SourceLoc loc = current()); + mlir::Location loc = currentLoc()); +DigestVal taggedListCons(llvm::StringRef tag, + DigestVal head, + DigestVal tail, + mlir::Location loc = currentLoc()); +void assert_eq(DigestVal lhs, DigestVal rhs, mlir::Location loc = currentLoc()); class ReadIopVal { public: ReadIopVal(mlir::Value value) : value(value) {} mlir::Value getValue() const { return value; } - std::vector readBaseVals(size_t count, bool flip = false, SourceLoc loc = current()); - std::vector readExtVals(size_t count, bool flip = false, SourceLoc loc = current()); + std::vector readBaseVals(size_t count, bool flip = false, mlir::Location loc = currentLoc()); + std::vector readExtVals(size_t count, bool flip = false, mlir::Location loc = currentLoc()); // Read digests of the DigestKind::Default from the IOP stream. - std::vector readDigests(size_t count, SourceLoc loc = current()); - void commit(DigestVal digest, SourceLoc loc = current()); - Val rngBits(uint32_t bits, SourceLoc loc = current()); - Val rngBaseVal(SourceLoc loc = current()); - Val rngExtVal(SourceLoc loc = current()); + std::vector readDigests(size_t count, mlir::Location loc = currentLoc()); + void commit(DigestVal digest, mlir::Location loc = currentLoc()); + Val rngBits(uint32_t bits, mlir::Location loc = currentLoc()); + Val rngBaseVal(mlir::Location loc = currentLoc()); + Val rngExtVal(mlir::Location loc = currentLoc()); private: mlir::Value value; }; -Val select(Val idx, llvm::ArrayRef in, SourceLoc loc = current()); -DigestVal select(Val idx, llvm::ArrayRef in, SourceLoc loc = current()); +Val select(Val idx, llvm::ArrayRef in, mlir::Location loc = currentLoc()); +DigestVal select(Val idx, llvm::ArrayRef in, mlir::Location loc = currentLoc()); -Val normalize(Val in, SourceLoc loc = current()); +Val normalize(Val in, mlir::Location loc = currentLoc()); struct HashCheckedPublicOutput { DigestVal poseidon; diff --git a/zirgen/compiler/edsl/source_loc.h b/zirgen/compiler/edsl/source_loc.h deleted file mode 100644 index 092e87bb..00000000 --- a/zirgen/compiler/edsl/source_loc.h +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright 2024 RISC Zero, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -/// \file -/// SourceLoc is basically a near clone of std::source_location, but since it's not in the spec -/// until c++20, we build our own minialist version. This is currently used to help track EDSL -/// expressions for the IR. - -#ifdef __clang__ -#if defined(__has_builtin) -#if __has_builtin(__builtin_FILE) && __has_builtin(__builtin_LINE) -#define HAS_FILE_LINE 1 -#else // __has_builtin(__builtin_FILE) && __has_builtin(__builtin_LINE) -#define HAS_FILE_LINE 0 -#endif // __has_builtin(__builtin_FILE) && __has_builtin(__builtin_LINE) -#if __has_builtin(__builtin_COLUMN) && __has_builtin(__builtin_COLUMN) -#define HAS_COLUMN 1 -#else // __has_builtin(__builtin_COLUMN) && __has_builtin(__builtin_COLUMN) -#define HAS_COLUMN 0 -#endif // __has_builtin(__builtin_COLUMN) && __has_builtin(__builtin_COLUMN) -#else // defined(__has_builtin) -#define HAS_FILE_LINE 0 -#define HAS_COLUMN 0 -#endif // defined(__has_builtin) -#elif defined(__GNUC__) && __GNUC__ >= 7 -// gcc says it supports __has_builtin but experimentation indicates otherwise. -#define HAS_FILE_LINE 1 -#define HAS_COLUMN 0 -#else // defined(__GNUC__) && __GNUC__ >= 7 -#define HAS_FILE_LINE 0 -#define HAS_COLUMN 0 -#endif // defined(__GNUC__) && __GNUC__ >= 7 - -namespace risc0 { - -/// A source location. Capture as much as we can from the compiler, which might be nothing. -struct SourceLoc { -public: - /// Get the "current" source location. When used in default values, this effectively captures the - /// call site of the function declaring the default value, which is very useful. - static constexpr SourceLoc current( -#if HAS_FILE_LINE - const char* filename = __builtin_FILE(), - int line = __builtin_LINE(), -#else - const char* filename = __FILE__, - int line = __LINE__, -#endif -#if HAS_COLUMN - int column = __builtin_COLUMN() -#else - int column = 0 -#endif - ) noexcept { - SourceLoc loc; - loc.filename = filename; - loc.line = line; - loc.column = column; - return loc; - } - - constexpr SourceLoc() noexcept : filename(""), line(0), column(0) {} - - const char* filename; - size_t line; - size_t column; -}; - -} // namespace risc0 diff --git a/zirgen/compiler/stats/OpStats.cpp b/zirgen/compiler/stats/OpStats.cpp index f9231072..a95a7a36 100644 --- a/zirgen/compiler/stats/OpStats.cpp +++ b/zirgen/compiler/stats/OpStats.cpp @@ -52,6 +52,7 @@ struct OpStatsCLOptions { clEnumValN(SortOrder::Flat, "flat", "Uncombined"), clEnumValN(SortOrder::Any, "any", "Combined occurrences anywhere in call stack"), clEnumValN(SortOrder::Max, "max", "Maximum of all the metrics"))}; + cl::opt flatCost{"op-stats-flat-cost", cl::desc("Count all operations as the same cost")}; }; llvm::ManagedStatic clOpts; @@ -141,7 +142,9 @@ struct LocStat { bool operator<(const LocStat& rhs) const { return sortStat() < rhs.sortStat(); } }; -class LocStats { +} // namespace + +class BogoCycleAnalysis::LocStats { public: void countLoc(Location loc, double c) { seenFlat.clear(); @@ -255,8 +258,6 @@ class LocStats { DenseSet seenFlat, seenInside, seenOutside, seenAny; }; -} // namespace - // Calculate the number of bogocycles used for OpT with the given extension field. template double BogoCycleAnalysis::getOrCalcBogoCycles() { std::pair id = std::make_pair(OpT::getOperationName(), K); @@ -287,6 +288,9 @@ template double BogoCycleAnalysis::getOrC } double BogoCycleAnalysis::getBogoCycles(Operation* op) { + if (clOpts->flatCost) + return 1; + // Relative costs of operations, in bogocycles. return TypeSwitch(op) .Case([&](auto op) { @@ -299,7 +303,34 @@ double BogoCycleAnalysis::getBogoCycles(Operation* op) { }) .Case( [&](auto op) { return getOrCalcBogoCycles(); }) - .Default([](auto) { return 0; }); + .Default([&](auto) { + // Unknown operations + if (didWarnAbout.insert(op->getName()).second) { + static llvm::DenseSet globalDidWarnAbout; + static llvm::sys::Mutex mu; + + mu.lock(); + if (globalDidWarnAbout.insert(op->getName()).second) { + llvm::errs() << "OpStats: Don't know how to estimate how long " << op->getName() + << " takes\n"; + } + mu.unlock(); + } + return 0; + }); +} + +bool BogoCycleAnalysis::isEnabled() { + if (!clOpts.isConstructed()) { + throw(std::runtime_error("op-stats command line options must be registered")); + } + + if (clOpts->sortOrder == SortOrder::Disabled) { + LLVM_DEBUG({ llvm::dbgs() << "BogoCycle operation statistics not requested"; }); + return false; + } + + return true; } void BogoCycleAnalysis::printStatsIfRequired(Operation* topOp, llvm::raw_ostream& os) { @@ -324,6 +355,47 @@ void BogoCycleAnalysis::printStatsIfRequired(Operation* topOp, llvm::raw_ostream locStats.countLoc(op->getLoc(), c); }); + OpPrintingFlags flags; + flags.enableDebugInfo(/*enable=*/true, /*prettyForm=*/true); + + AsmState asmState(topOp, flags); + printStatsInternal(asmState, totCycles, locStats, os); +} + +void BogoCycleAnalysis::printStatsFromMap(MLIRContext* ctx, + llvm::DenseMap locs, + llvm::raw_ostream& os) { + if (!clOpts.isConstructed()) { + throw(std::runtime_error("op-stats command line options must be registered")); + } + + if (clOpts->sortOrder == SortOrder::Disabled) { + LLVM_DEBUG({ llvm::dbgs() << "BogoCycle operation statistics not requested"; }); + return; + } + + double totCycles = 0; + LocStats locStats; + + for (auto& [loc, c] : locs) { + if (!c) + continue; + + totCycles += c; + locStats.countLoc(loc, c); + } + + OpPrintingFlags flags; + flags.enableDebugInfo(/*enable=*/true, /*prettyForm=*/true); + + AsmState asmState(ctx, flags); + printStatsInternal(asmState, totCycles, locStats, os); +} + +void BogoCycleAnalysis::printStatsInternal(AsmState& asmState, + double totCycles, + LocStats& locStats, + llvm::raw_ostream& os) { auto totLocStats = locStats.toVector(); llvm::sort(totLocStats, llvm::less_second()); @@ -335,11 +407,6 @@ void BogoCycleAnalysis::printStatsIfRequired(Operation* topOp, llvm::raw_ostream (const char*)"Any", (const char*)"Location"); - // Set us up to be able to - OpPrintingFlags flags; - flags.enableDebugInfo(/*enable=*/true, /*prettyForm=*/true); - AsmState asmState(topOp, flags); - for (auto& [loc, stat] : llvm::reverse(totLocStats)) { if (stat.flat) os << llvm::format("%9.5f%% ", stat.flat * 100. / totCycles); diff --git a/zirgen/compiler/stats/OpStats.h b/zirgen/compiler/stats/OpStats.h index 83ee75de..247a538d 100644 --- a/zirgen/compiler/stats/OpStats.h +++ b/zirgen/compiler/stats/OpStats.h @@ -33,11 +33,28 @@ class BogoCycleAnalysis { // if the op-stqts void printStatsIfRequired(mlir::Operation* topOp, llvm::raw_ostream& os); + // Returns true if statistics are enabled. + bool isEnabled(); + + // Tally up and print out statistics for a raw location-to-cycles map + void printStatsFromMap(mlir::MLIRContext* ctx, + llvm::DenseMap locs, + llvm::raw_ostream& os); + private: + class LocStats; + template double getOrCalcBogoCycles(); + void printStatsInternal(mlir::AsmState& asmState, + double totCycles, + LocStats& locStats, + llvm::raw_ostream& os); + // Bogo cycles measured for various operations std::map, double> bogoCycles; + + llvm::DenseSet didWarnAbout; }; void registerOpStatsCLOptions(); diff --git a/zirgen/components/bits.cpp b/zirgen/components/bits.cpp index 83759f5c..20ee815f 100644 --- a/zirgen/components/bits.cpp +++ b/zirgen/components/bits.cpp @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,14 +17,14 @@ namespace zirgen { // Checks that a val is a bit. -void isBit(Val val, SourceLoc loc) { - OverrideLocation local(loc); +void isBit(Val val, Location loc) { + ScopedLocation local(loc); // The following constraint enforces that either val = 0 or val = 1 eqz(val * (1 - val)); } -void isBits(Buffer buf, SourceLoc loc) { - OverrideLocation local(loc); +void isBits(Buffer buf, Location loc) { + ScopedLocation local(loc); for (size_t i = 0; i < buf.size(); i++) { isBit(buf[i]); } diff --git a/zirgen/components/fpext.cpp b/zirgen/components/fpext.cpp index d68a48f9..3ec837d7 100644 --- a/zirgen/components/fpext.cpp +++ b/zirgen/components/fpext.cpp @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,8 +14,6 @@ #include "fpext.h" -using namespace risc0; - namespace zirgen { FpExtRegImpl::FpExtRegImpl(llvm::StringRef source) { @@ -24,8 +22,8 @@ FpExtRegImpl::FpExtRegImpl(llvm::StringRef source) { } } -FpExt FpExtRegImpl::get(SourceLoc loc) { - OverrideLocation local(loc); +FpExt FpExtRegImpl::get(mlir::Location loc) { + ScopedLocation local(loc); std::array arr; for (size_t i = 0; i < kExtSize; i++) { arr[i] = elems[i]; @@ -39,29 +37,29 @@ void FpExtRegImpl::set(CaptureFpExt rhs) { } } -FpExt::FpExt(Val x, SourceLoc loc) { - OverrideLocation local(loc); +FpExt::FpExt(Val x, mlir::Location loc) { + ScopedLocation local(loc); elems[0] = x; for (size_t i = 1; i < kExtSize; i++) { elems[i] = 0; } } -FpExt::FpExt(std::array elems, risc0::SourceLoc loc) { - OverrideLocation local(loc); +FpExt::FpExt(std::array elems, mlir::Location loc) { + ScopedLocation local(loc); for (size_t i = 0; i < kExtSize; i++) { this->elems[i] = elems[i]; } } -FpExt::FpExt(FpExtReg reg, risc0::SourceLoc loc) { - OverrideLocation local(loc); +FpExt::FpExt(FpExtReg reg, mlir::Location loc) { + ScopedLocation local(loc); for (size_t i = 0; i < kExtSize; i++) { elems[i] = reg->elem(i); } } -FpExt FpExt::fromVals(llvm::ArrayRef vals, risc0::SourceLoc loc) { +FpExt FpExt::fromVals(llvm::ArrayRef vals, mlir::Location loc) { assert(vals.size() == kExtSize); std::array elems; std::copy(vals.begin(), vals.end(), elems.begin()); @@ -69,7 +67,7 @@ FpExt FpExt::fromVals(llvm::ArrayRef vals, risc0::SourceLoc loc) { } FpExt operator+(CaptureFpExt a, CaptureFpExt b) { - OverrideLocation local(a.loc); + ScopedLocation local(a.loc); std::array out; for (size_t i = 0; i < kExtSize; i++) { out[i] = a.ext.elem(i) + b.ext.elem(i); @@ -78,7 +76,7 @@ FpExt operator+(CaptureFpExt a, CaptureFpExt b) { } FpExt operator-(CaptureFpExt a, CaptureFpExt b) { - OverrideLocation local(a.loc); + ScopedLocation local(a.loc); std::array out; for (size_t i = 0; i < kExtSize; i++) { out[i] = a.ext.elem(i) - b.ext.elem(i); @@ -87,7 +85,7 @@ FpExt operator-(CaptureFpExt a, CaptureFpExt b) { } FpExt operator*(CaptureFpExt a, CaptureFpExt b) { - OverrideLocation local(a.loc); + ScopedLocation local(a.loc); std::array out; Val NBETA = -Val(11); // Rename the element arrays to something small for readability @@ -108,14 +106,14 @@ FpExt operator*(CaptureFpExt a, CaptureFpExt b) { } void eq(CaptureFpExt a, CaptureFpExt b) { - OverrideLocation local(a.loc); + ScopedLocation local(a.loc); for (size_t i = 0; i < kExtSize; i++) { eq(a.ext.elem(i), b.ext.elem(i)); } } FpExt inv(CaptureFpExt a) { - OverrideLocation local(a.loc); + ScopedLocation local(a.loc); Val BETA = 11; #define a(i) a.ext.elem(i) #if GOLDILOCKS diff --git a/zirgen/components/fpext.h b/zirgen/components/fpext.h index 5a37fec7..c39b3182 100644 --- a/zirgen/components/fpext.h +++ b/zirgen/components/fpext.h @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -31,7 +31,7 @@ class FpExt; class FpExtRegImpl : public CompImpl { public: FpExtRegImpl(llvm::StringRef source = "data"); - FpExt get(risc0::SourceLoc loc = current()); + FpExt get(mlir::Location loc = currentLoc()); void set(CaptureFpExt rhs); Val elem(size_t i) { return elems[i]; } @@ -44,13 +44,13 @@ using FpExtReg = Comp; class FpExt { public: FpExt() = default; - FpExt(Val x, risc0::SourceLoc loc = current()); - FpExt(std::array elems, risc0::SourceLoc loc = current()); - FpExt(FpExtReg reg, risc0::SourceLoc loc = current()); + FpExt(Val x, mlir::Location loc = currentLoc()); + FpExt(std::array elems, mlir::Location loc = currentLoc()); + FpExt(FpExtReg reg, mlir::Location loc = currentLoc()); Val elem(size_t i) { return elems[i]; } std::array getElems() { return elems; } std::vector toVals() { return std::vector(elems.begin(), elems.end()); } - static FpExt fromVals(llvm::ArrayRef vals, SourceLoc loc = current()); + static FpExt fromVals(llvm::ArrayRef vals, mlir::Location loc = currentLoc()); private: std::array elems; @@ -58,10 +58,10 @@ class FpExt { class CaptureFpExt { public: - CaptureFpExt(FpExt ext, risc0::SourceLoc loc = current()) : ext(ext), loc(loc) {} - CaptureFpExt(FpExtReg ext, risc0::SourceLoc loc = current()) : ext(ext, loc), loc(loc) {} + CaptureFpExt(FpExt ext, mlir::Location loc = currentLoc()) : ext(ext), loc(loc) {} + CaptureFpExt(FpExtReg ext, mlir::Location loc = currentLoc()) : ext(ext, loc), loc(loc) {} FpExt ext; - risc0::SourceLoc loc; + mlir::Location loc; }; FpExt operator+(CaptureFpExt a, CaptureFpExt b); diff --git a/zirgen/components/onehot.h b/zirgen/components/onehot.h index c3e733cf..9c93753a 100644 --- a/zirgen/components/onehot.h +++ b/zirgen/components/onehot.h @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -58,8 +58,8 @@ template class OneHotImpl : public CompImpl> { return tot; } - Val at(size_t idx, SourceLoc loc = current()) { - OverrideLocation guard(loc); + Val at(size_t idx, mlir::Location loc = currentLoc()) { + ScopedLocation guard(loc); return bits[idx]; } diff --git a/zirgen/components/reg.h b/zirgen/components/reg.h index 60fcd324..c6402e1e 100644 --- a/zirgen/components/reg.h +++ b/zirgen/components/reg.h @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -33,12 +33,12 @@ class RegImpl : public CompImpl { CompContext::saveLabel(buf, label); } // The "source" parameter indicates the grouping of a RegImpl. - Val get(SourceLoc loc = current()) { - OverrideLocation guard(loc); + Val get(mlir::Location loc = currentLoc()) { + ScopedLocation guard(loc); return buf.getRegister(0, constructPath, loc); } - void set(Val in, SourceLoc loc = current()) { - OverrideLocation guard(loc); + void set(Val in, mlir::Location loc = currentLoc()) { + ScopedLocation guard(loc); buf.getRegister(0, constructPath, loc) = in; } Buffer raw() { return buf; } diff --git a/zirgen/components/u32.cpp b/zirgen/components/u32.cpp index 64b4ce22..ac2d7486 100644 --- a/zirgen/components/u32.cpp +++ b/zirgen/components/u32.cpp @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -26,8 +26,8 @@ U32Val U32Val::underflowProtect() { return U32Val(0x100, 0xff, 0xff, 0xff); } -void eq(U32Val a, U32Val b, SourceLoc loc) { - OverrideLocation local(loc); +void eq(U32Val a, U32Val b, mlir::Location loc) { + ScopedLocation local(loc); for (size_t i = 0; i < 4; i++) { eq(a.bytes[i], b.bytes[i]); } diff --git a/zirgen/components/u32.h b/zirgen/components/u32.h index c6b1ec43..ddbaa522 100644 --- a/zirgen/components/u32.h +++ b/zirgen/components/u32.h @@ -1,4 +1,4 @@ -// Copyright 2024 RISC Zero, Inc. +// Copyright 2025 RISC Zero, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -50,7 +50,7 @@ template <> struct LogPrep { } }; -void eq(U32Val a, U32Val b, SourceLoc loc = SourceLoc::current()); +void eq(U32Val a, U32Val b, mlir::Location loc = currentLoc()); class U32RegImpl : public CompImpl { public: diff --git a/zirgen/dsl/passes/GenerateAccum.cpp b/zirgen/dsl/passes/GenerateAccum.cpp index cc59e893..d0ae2def 100644 --- a/zirgen/dsl/passes/GenerateAccum.cpp +++ b/zirgen/dsl/passes/GenerateAccum.cpp @@ -20,7 +20,7 @@ #include "zirgen/Dialect/ZStruct/Transforms/RewritePatterns.h" #include "zirgen/Dialect/Zll/IR/IR.h" #include "zirgen/Utilities/KeyPath.h" -#include "zirgen/compiler/edsl/source_loc.h" +#include "zirgen/compiler/edsl/edsl.h" #include "zirgen/dsl/passes/CommonRewrites.h" #include "zirgen/dsl/passes/PassDetail.h" @@ -29,13 +29,6 @@ using namespace zirgen::Zhlt; using namespace zirgen::ZStruct; using namespace zirgen::Zll; -namespace { - -mlir::Location currentLoc(MLIRContext* ctx, risc0::SourceLoc loc = risc0::SourceLoc::current()) { - return FileLineColLoc::get(StringAttr::get(ctx, loc.filename), loc.line, loc.column); -} - -} // namespace namespace zirgen { namespace dsl { @@ -44,19 +37,19 @@ struct RandomnessMap { RandomnessMap(OpBuilder& builder, Value pivot, Value zeroDistance) : pivot(pivot), ctx(pivot.getContext()) { if (pivot.getType() == Zhlt::getExtRefType(pivot.getContext())) { - this->pivot = builder.create(currentLoc(ctx), pivot, zeroDistance); + this->pivot = builder.create(currentLoc(), pivot, zeroDistance); } else if (auto pivotType = dyn_cast(pivot.getType())) { for (auto field : pivotType.getFields()) { if (field.name == "$offset") continue; - Value member = builder.create(currentLoc(ctx), pivot, field.name); + Value member = builder.create(currentLoc(), pivot, field.name); map.insert({field.name, RandomnessMap(builder, member, zeroDistance)}); } } else if (auto pivotType = dyn_cast(pivot.getType())) { for (size_t i = 0; i < pivotType.getSize(); i++) { - Value index = builder.create(currentLoc(ctx), builder.getIndexAttr(i)); - Value element = builder.create(currentLoc(ctx), pivot, index); + Value index = builder.create(currentLoc(), builder.getIndexAttr(i)); + Value element = builder.create(currentLoc(), pivot, index); map.insert({i, RandomnessMap(builder, element, zeroDistance)}); } } else { @@ -76,21 +69,21 @@ class AccumBuilder { : builder(builder), ctx(builder.getContext()), accumLayout(accumLayout) { // Preemptively lookup all verifier randomness values and store them in a // dictionary. Unneeded values will be pruned later by folding. - zeroDistance = builder.create(currentLoc(ctx), builder.getIndexAttr(0)); + zeroDistance = builder.create(currentLoc(), builder.getIndexAttr(0)); verifierRandomness = RandomnessMap(builder, randomnessLayout, zeroDistance); - offset = builder.create(currentLoc(ctx), randomnessLayout, "$offset"); - offset = builder.create(currentLoc(ctx), offset, zeroDistance); + offset = builder.create(currentLoc(), randomnessLayout, "$offset"); + offset = builder.create(currentLoc(), offset, zeroDistance); auto accumLayoutType = cast(accumLayout.getType()); // Read last accumulator from previous row IntegerAttr distanceAttr = builder.getIndexAttr(1); Value distance = - builder.create(currentLoc(ctx), builder.getIndexType(), distanceAttr); + builder.create(currentLoc(), builder.getIndexType(), distanceAttr); Value lastLayout = getAccumColumnLayout(accumLayoutType.getSize() - 1); - auto loadPrevOp = builder.create(currentLoc(ctx), lastLayout, distance); + auto loadPrevOp = builder.create(currentLoc(), lastLayout, distance); loadPrevOp->setAttr("unchecked", builder.getUnitAttr()); this->oldT = loadPrevOp; this->t = oldT; @@ -119,10 +112,10 @@ class AccumBuilder { // we still need to write the last accum column for this cycle. if (accCol == 0 || accCol != accumLayoutType.getSize()) { Value lastLayout = getAccumColumnLayout(accumLayoutType.getSize() - 1); - builder.create(currentLoc(ctx), lastLayout, t); - Value newT = builder.create(currentLoc(ctx), lastLayout, zeroDistance); - Value diff = builder.create(currentLoc(ctx), newT, oldT); - builder.create(currentLoc(ctx), diff); + builder.create(currentLoc(), lastLayout, t); + Value newT = builder.create(currentLoc(), lastLayout, zeroDistance); + Value diff = builder.create(currentLoc(), newT, oldT); + builder.create(currentLoc(), diff); } } @@ -141,30 +134,30 @@ class AccumBuilder { // For a single register, compute r_a * a[i] assert(randomness.pivot.getType() == extValType); Value r = randomness.pivot; - Value colLayout = builder.create(currentLoc(ctx), layout, "@super"); - Value colValue = builder.create(currentLoc(ctx), colLayout, zeroDistance); - return builder.create(currentLoc(ctx), r, colValue); + Value colLayout = builder.create(currentLoc(), layout, "@super"); + Value colValue = builder.create(currentLoc(), colLayout, zeroDistance); + return builder.create(currentLoc(), r, colValue); } else if (auto layoutType = dyn_cast(layout.getType())) { // for a LayoutType, sum condensations of non-count fields auto fields = layoutType.getFields(); if (layoutType.getKind() == LayoutKind::Argument) fields = fields.drop_front(); - Value v = builder.create(currentLoc(ctx), valType, 0); + Value v = builder.create(currentLoc(), valType, 0); for (auto field : fields) { - Value sublayout = builder.create(currentLoc(ctx), layout, field.name); + Value sublayout = builder.create(currentLoc(), layout, field.name); Value sumTerm = condenseArgument(sublayout, randomness.map.at(field.name.getValue())); - v = builder.create(currentLoc(ctx), v, sumTerm); + v = builder.create(currentLoc(), v, sumTerm); } return v; } else if (auto layoutType = dyn_cast(layout.getType())) { // For a LayoutArrayType, sum condensations at all subscripts - Value v = builder.create(currentLoc(ctx), valType, 0); + Value v = builder.create(currentLoc(), valType, 0); for (size_t i = 0; i < layoutType.getSize(); i++) { - Value index = builder.create(currentLoc(ctx), builder.getIndexAttr(i)); - Value sublayout = builder.create(currentLoc(ctx), layout, index); + Value index = builder.create(currentLoc(), builder.getIndexAttr(i)); + Value sublayout = builder.create(currentLoc(), layout, index); Value sumTerm = condenseArgument(sublayout, randomness.map.at(i)); - v = builder.create(currentLoc(ctx), v, sumTerm); + v = builder.create(currentLoc(), v, sumTerm); } return v; } else { @@ -174,20 +167,20 @@ class AccumBuilder { } void addConstraint(Value newT) { - Value deltaT = builder.create(currentLoc(ctx), newT, oldT); - Value constraint = builder.create(currentLoc(ctx), deltaT, constraintLhs); + Value deltaT = builder.create(currentLoc(), newT, oldT); + Value constraint = builder.create(currentLoc(), deltaT, constraintLhs); for (size_t i = 0; i < accCount; i++) { - constraint = builder.create(currentLoc(ctx), constraint, constraintRhsTerms[i]); + constraint = builder.create(currentLoc(), constraint, constraintRhsTerms[i]); } - builder.create(currentLoc(ctx), constraint); + builder.create(currentLoc(), constraint); } // Writes t to the next accum column, and returns a value loaded from it. Value storeTemporarySum(Value t) { assert(accCount != 0 && "writing accum column without accumulating any new arguments"); Value tLayout = getAccumColumnLayout(accCol); - builder.create(currentLoc(ctx), tLayout, t); - Value newT = builder.create(currentLoc(ctx), tLayout, zeroDistance); + builder.create(currentLoc(), tLayout, t); + Value newT = builder.create(currentLoc(), tLayout, zeroDistance); addConstraint(newT); oldT = newT; @@ -214,13 +207,13 @@ class AccumBuilder { constraintRhsTerms[1] = vPlusOffset; constraintRhsTerms[2] = vPlusOffset; } else { - constraintLhs = builder.create(currentLoc(ctx), constraintLhs, vPlusOffset); + constraintLhs = builder.create(currentLoc(), constraintLhs, vPlusOffset); constraintRhsTerms[accCount] = - builder.create(currentLoc(ctx), constraintRhsTerms[accCount], c); + builder.create(currentLoc(), constraintRhsTerms[accCount], c); for (size_t i = 0; i < 3; i++) { if (i != accCount) { constraintRhsTerms[i] = - builder.create(currentLoc(ctx), constraintRhsTerms[i], vPlusOffset); + builder.create(currentLoc(), constraintRhsTerms[i], vPlusOffset); } } } @@ -231,14 +224,14 @@ class AccumBuilder { Value accumulateArgument(Value t, LayoutType type, Value layout) { MLIRContext* ctx = builder.getContext(); StringAttr countName = type.getFields()[0].name; - Value cLayout = builder.create(currentLoc(ctx), layout, countName); + Value cLayout = builder.create(currentLoc(), layout, countName); cLayout = Zhlt::coerceTo(cLayout, Zhlt::getRefType(ctx), builder); - Value c = builder.create(currentLoc(ctx), cLayout, zeroDistance); + Value c = builder.create(currentLoc(), cLayout, zeroDistance); Value v = condenseArgument(layout, verifierRandomness.map.at(type.getId())); - Value vPlusOffset = builder.create(currentLoc(ctx), v, offset); - Value denominator = builder.create(currentLoc(ctx), vPlusOffset); - Value delta = builder.create(currentLoc(ctx), c, denominator); - Value tNew = builder.create(currentLoc(ctx), t, delta); + Value vPlusOffset = builder.create(currentLoc(), v, offset); + Value denominator = builder.create(currentLoc(), vPlusOffset); + Value delta = builder.create(currentLoc(), c, denominator); + Value tNew = builder.create(currentLoc(), t, delta); buildConstraintTerms(vPlusOffset, c); @@ -257,9 +250,9 @@ class AccumBuilder { // would disrupt the register allocation and constraint generation. for (size_t i = 0; i < layoutArrayType.getSize(); i++) { IntegerAttr indexAttr = builder.getIndexAttr(i); - Value indexValue = builder.create( - currentLoc(ctx), builder.getIndexType(), indexAttr); - Value sublayout = builder.create(currentLoc(ctx), layout, indexValue); + Value indexValue = + builder.create(currentLoc(), builder.getIndexType(), indexAttr); + Value sublayout = builder.create(currentLoc(), layout, indexValue); t = doAccum(t, sublayout); } }) @@ -273,7 +266,7 @@ class AccumBuilder { return; } for (auto field : layoutType.getFields()) { - Value sublayout = builder.create(currentLoc(ctx), layout, field.name); + Value sublayout = builder.create(currentLoc(), layout, field.name); t = doAccum(t, sublayout); } }); @@ -283,8 +276,8 @@ class AccumBuilder { Value getAccumColumnLayout(size_t index) { IntegerAttr indexAttr = builder.getIndexAttr(index); Value indexValue = - builder.create(currentLoc(ctx), builder.getIndexType(), indexAttr); - return builder.create(currentLoc(ctx), accumLayout, indexValue); + builder.create(currentLoc(), builder.getIndexType(), indexAttr); + return builder.create(currentLoc(), accumLayout, indexValue); } OpBuilder& builder; @@ -423,6 +416,9 @@ struct GenerateAccumPass : public GenerateAccumBase { } void runOnOperation() override { + // Register our MLIRContext so we can use EDSL location tracking. + registerEdslContext(&getContext()); + // Get user accum function (if one exist, otherwise null smart-ptr) auto userAccum = getUserAccum(getOperation()); @@ -522,14 +518,14 @@ struct GenerateAccumPass : public GenerateAccumBase { // Make the accum component std::string accumName = (component.getName() + "$accum").str(); auto accum = builder.create( - currentLoc(ctx), accumName, Zhlt::getComponentType(ctx), accumParams, accumLayoutType); + currentLoc(), accumName, Zhlt::getComponentType(ctx), accumParams, accumLayoutType); SymbolTable::setSymbolVisibility(accum, SymbolTable::Visibility::Public); builder.setInsertionPointToStart(accum.addEntryBlock()); // Create globals for verifier randomness LayoutType mixLayoutType = getRandomnessLayoutType(userRandomnessSize); - auto randomnessLayout = builder.create( - currentLoc(ctx), mixLayoutType, "mix", "randomness"); + auto randomnessLayout = + builder.create(currentLoc(), mixLayoutType, "mix", "randomness"); // Walk down the key path to the major mux layout component, capture top as we go. Value cur = accum.getConstructParam()[0]; @@ -541,17 +537,17 @@ struct GenerateAccumPass : public GenerateAccumBase { } } if (auto* strKey = std::get_if(&key)) { - cur = builder.create(currentLoc(ctx), cur, *strKey); + cur = builder.create(currentLoc(), cur, *strKey); } else { Value index = builder.create( - currentLoc(ctx), builder.getIndexAttr(std::get(key))); - cur = builder.create(currentLoc(ctx), cur, index); + currentLoc(), builder.getIndexAttr(std::get(key))); + cur = builder.create(currentLoc(), cur, index); } } // Do 'user' accum work (if any) if (userAccum) { - auto loc = currentLoc(ctx); + auto loc = currentLoc(); Value userLayout = builder.create(loc, accum.getLayout(), "user"); Value userRandomness = builder.create(loc, randomnessLayout, "$user"); // Load randomness and put into an array @@ -570,23 +566,22 @@ struct GenerateAccumPass : public GenerateAccumBase { } // Load from the list of saved selectors - Value selectorLayoutArray = builder.create(currentLoc(ctx), cur, "@selector"); + Value selectorLayoutArray = builder.create(currentLoc(), cur, "@selector"); size_t armCount = dyn_cast(selectorLayoutArray.getType()).getSize(); SmallVector selectors; - Value zeroDistance = - builder.create(currentLoc(ctx), builder.getIndexAttr(0)); + Value zeroDistance = builder.create(currentLoc(), builder.getIndexAttr(0)); for (size_t i = 0; i < armCount; i++) { - Value index = builder.create(currentLoc(ctx), builder.getIndexAttr(i)); - Value nondetReg = builder.create(currentLoc(ctx), selectorLayoutArray, index); - Value ref = builder.create(currentLoc(ctx), nondetReg, "@super"); - Value val = builder.create(currentLoc(ctx), ref, zeroDistance); + Value index = builder.create(currentLoc(), builder.getIndexAttr(i)); + Value nondetReg = builder.create(currentLoc(), selectorLayoutArray, index); + Value ref = builder.create(currentLoc(), nondetReg, "@super"); + Value val = builder.create(currentLoc(), ref, zeroDistance); selectors.push_back(val); } // Build a switch op over the (saved) selectors Type componentType = Zhlt::getComponentType(builder.getContext()); auto switchOp = builder.create( - currentLoc(ctx), componentType, selectors, /*numArms=*/armCount); + currentLoc(), componentType, selectors, /*numArms=*/armCount); // Now, generate each arm for (size_t i = 0; i < armCount; i++) { @@ -598,36 +593,34 @@ struct GenerateAccumPass : public GenerateAccumBase { majorType.getFields().end(), [&](auto field) { return field.name == armName; }); if (hasArm) { - Value armLayout = builder.create(currentLoc(ctx), cur, armName); - Value columnsLayout = - builder.create(currentLoc(ctx), accum.getLayout(), "columns"); + Value armLayout = builder.create(currentLoc(), cur, armName); + Value columnsLayout = builder.create(currentLoc(), accum.getLayout(), "columns"); AccumBuilder accumBuilder(builder, columnsLayout, randomnessLayout); accumBuilder.build(armLayout); accumBuilder.finalize(); } else { // Read accumulator + forward Value distance = builder.create( - currentLoc(ctx), builder.getIndexType(), builder.getIndexAttr(1)); + currentLoc(), builder.getIndexType(), builder.getIndexAttr(1)); Value indexValue = builder.create( - currentLoc(ctx), builder.getIndexType(), builder.getIndexAttr(columns - 1)); - Value columnsLayout = - builder.create(currentLoc(ctx), accum.getLayout(), "columns"); - Value lastLayout = builder.create(currentLoc(ctx), columnsLayout, indexValue); - auto prevLoadOp = builder.create(currentLoc(ctx), lastLayout, distance); + currentLoc(), builder.getIndexType(), builder.getIndexAttr(columns - 1)); + Value columnsLayout = builder.create(currentLoc(), accum.getLayout(), "columns"); + Value lastLayout = builder.create(currentLoc(), columnsLayout, indexValue); + auto prevLoadOp = builder.create(currentLoc(), lastLayout, distance); prevLoadOp->setAttr("unchecked", builder.getUnitAttr()); Value prevVal = prevLoadOp; - builder.create(currentLoc(ctx), lastLayout, prevVal); + builder.create(currentLoc(), lastLayout, prevVal); } - Value empty = builder.create( - currentLoc(ctx), Zhlt::getComponentType(ctx), ValueRange{}); - builder.create(currentLoc(ctx), empty); + Value empty = + builder.create(currentLoc(), Zhlt::getComponentType(ctx), ValueRange{}); + builder.create(currentLoc(), empty); } // Make a null return Value empty = - builder.create(currentLoc(ctx), Zhlt::getComponentType(ctx), ValueRange{}); - builder.create(currentLoc(ctx), empty); + builder.create(currentLoc(), Zhlt::getComponentType(ctx), ValueRange{}); + builder.create(currentLoc(), empty); } private: From 9b7da7028c702d9545c3ec80858009300fa92d58 Mon Sep 17 00:00:00 2001 From: nils m Date: Wed, 16 Apr 2025 10:40:51 -0700 Subject: [PATCH 2/2] Add more predicate op stats granularity --- zirgen/circuit/recursion/encode.cpp | 30 +++++++++++++++++++---------- zirgen/compiler/stats/OpStats.cpp | 4 ++++ 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/zirgen/circuit/recursion/encode.cpp b/zirgen/circuit/recursion/encode.cpp index bbbb40f7..6165ef94 100644 --- a/zirgen/circuit/recursion/encode.cpp +++ b/zirgen/circuit/recursion/encode.cpp @@ -207,8 +207,16 @@ struct Instructions { uint64_t addHalfsConst(uint32_t tot) { return addConst(tot & 0xffff, tot >> 16); } - uint64_t - addMicro(Value out, MicroOpcode opcode, uint64_t op0 = 0, uint64_t op1 = 0, uint64_t op2 = 0) { + uint64_t addMicro(Value out, + MicroOpcode opcode, + uint64_t op0 = 0, + uint64_t op1 = 0, + uint64_t op2 = 0, + Location callerLoc = currentLoc()) { + std::optional valueLoc; + if (out) + valueLoc.emplace(out.getLoc()); + ScopedLocation loc(callerLoc); if (microUsed == 0) { data.emplace_back(); data.back().opType = OpType::MICRO; @@ -228,12 +236,7 @@ struct Instructions { } if (stats) { - if (out) { - ScopedLocation loc(out.getLoc()); - stats->locs[currentLoc()]++; - } else { - stats->locs[currentLoc()]++; - } + stats->locs[currentLoc()]++; } return outId; } @@ -244,8 +247,13 @@ struct Instructions { } } - uint64_t - addMacro(size_t outs, MacroOpcode opcode, uint64_t op0 = 0, uint64_t op1 = 0, uint64_t op2 = 0) { + uint64_t addMacro(size_t outs, + MacroOpcode opcode, + uint64_t op0 = 0, + uint64_t op1 = 0, + uint64_t op2 = 0, + Location callerLoc = currentLoc()) { + ScopedLocation loc(callerLoc); finishMicros(); data.emplace_back(); data.back().opType = OpType::MACRO; @@ -837,6 +845,8 @@ struct Instructions { } }) .Case([&](Iop::ReadOp op) { + ScopedLocation loc; + size_t k = 0; size_t rep = 1; size_t demont = false; diff --git a/zirgen/compiler/stats/OpStats.cpp b/zirgen/compiler/stats/OpStats.cpp index a95a7a36..4d961442 100644 --- a/zirgen/compiler/stats/OpStats.cpp +++ b/zirgen/compiler/stats/OpStats.cpp @@ -396,6 +396,7 @@ void BogoCycleAnalysis::printStatsInternal(AsmState& asmState, double totCycles, LocStats& locStats, llvm::raw_ostream& os) { + const size_t kMaxLines = 100000; auto totLocStats = locStats.toVector(); llvm::sort(totLocStats, llvm::less_second()); @@ -407,7 +408,10 @@ void BogoCycleAnalysis::printStatsInternal(AsmState& asmState, (const char*)"Any", (const char*)"Location"); + size_t lineCount = 0; for (auto& [loc, stat] : llvm::reverse(totLocStats)) { + if (++lineCount > kMaxLines) + break; if (stat.flat) os << llvm::format("%9.5f%% ", stat.flat * 100. / totCycles); else