Skip to content

Commit 79d1524

Browse files
authored
[NFC][IR2Vec] Moving parseVocabSection() to VocabStorage (#161711)
1 parent 3960ff6 commit 79d1524

File tree

2 files changed

+50
-45
lines changed

2 files changed

+50
-45
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,13 @@ class VocabStorage {
210210
const_iterator end() const {
211211
return const_iterator(this, getNumSections(), 0);
212212
}
213+
214+
using VocabMap = std::map<std::string, Embedding>;
215+
/// Parse a vocabulary section from JSON and populate the target vocabulary
216+
/// map.
217+
static Error parseVocabSection(StringRef Key,
218+
const json::Value &ParsedVocabValue,
219+
VocabMap &TargetVocab, unsigned &Dim);
213220
};
214221

215222
/// Class for storing and accessing the IR2Vec vocabulary.
@@ -600,8 +607,6 @@ class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
600607

601608
Error readVocabulary(VocabMap &OpcVocab, VocabMap &TypeVocab,
602609
VocabMap &ArgVocab);
603-
Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue,
604-
VocabMap &TargetVocab, unsigned &Dim);
605610
void generateVocabStorage(VocabMap &OpcVocab, VocabMap &TypeVocab,
606611
VocabMap &ArgVocab);
607612
void emitError(Error Err, LLVMContext &Ctx);

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 43 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,43 @@ bool VocabStorage::const_iterator::operator!=(
330330
return !(*this == Other);
331331
}
332332

333+
Error VocabStorage::parseVocabSection(StringRef Key,
334+
const json::Value &ParsedVocabValue,
335+
VocabMap &TargetVocab, unsigned &Dim) {
336+
json::Path::Root Path("");
337+
const json::Object *RootObj = ParsedVocabValue.getAsObject();
338+
if (!RootObj)
339+
return createStringError(errc::invalid_argument,
340+
"JSON root is not an object");
341+
342+
const json::Value *SectionValue = RootObj->get(Key);
343+
if (!SectionValue)
344+
return createStringError(errc::invalid_argument,
345+
"Missing '" + std::string(Key) +
346+
"' section in vocabulary file");
347+
if (!json::fromJSON(*SectionValue, TargetVocab, Path))
348+
return createStringError(errc::illegal_byte_sequence,
349+
"Unable to parse '" + std::string(Key) +
350+
"' section from vocabulary");
351+
352+
Dim = TargetVocab.begin()->second.size();
353+
if (Dim == 0)
354+
return createStringError(errc::illegal_byte_sequence,
355+
"Dimension of '" + std::string(Key) +
356+
"' section of the vocabulary is zero");
357+
358+
if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),
359+
[Dim](const std::pair<StringRef, Embedding> &Entry) {
360+
return Entry.second.size() == Dim;
361+
}))
362+
return createStringError(
363+
errc::illegal_byte_sequence,
364+
"All vectors in the '" + std::string(Key) +
365+
"' section of the vocabulary are not of the same dimension");
366+
367+
return Error::success();
368+
}
369+
333370
// ==----------------------------------------------------------------------===//
334371
// Vocabulary
335372
//===----------------------------------------------------------------------===//
@@ -460,43 +497,6 @@ VocabStorage Vocabulary::createDummyVocabForTest(unsigned Dim) {
460497
// IR2VecVocabAnalysis
461498
//===----------------------------------------------------------------------===//
462499

463-
Error IR2VecVocabAnalysis::parseVocabSection(
464-
StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab,
465-
unsigned &Dim) {
466-
json::Path::Root Path("");
467-
const json::Object *RootObj = ParsedVocabValue.getAsObject();
468-
if (!RootObj)
469-
return createStringError(errc::invalid_argument,
470-
"JSON root is not an object");
471-
472-
const json::Value *SectionValue = RootObj->get(Key);
473-
if (!SectionValue)
474-
return createStringError(errc::invalid_argument,
475-
"Missing '" + std::string(Key) +
476-
"' section in vocabulary file");
477-
if (!json::fromJSON(*SectionValue, TargetVocab, Path))
478-
return createStringError(errc::illegal_byte_sequence,
479-
"Unable to parse '" + std::string(Key) +
480-
"' section from vocabulary");
481-
482-
Dim = TargetVocab.begin()->second.size();
483-
if (Dim == 0)
484-
return createStringError(errc::illegal_byte_sequence,
485-
"Dimension of '" + std::string(Key) +
486-
"' section of the vocabulary is zero");
487-
488-
if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),
489-
[Dim](const std::pair<StringRef, Embedding> &Entry) {
490-
return Entry.second.size() == Dim;
491-
}))
492-
return createStringError(
493-
errc::illegal_byte_sequence,
494-
"All vectors in the '" + std::string(Key) +
495-
"' section of the vocabulary are not of the same dimension");
496-
497-
return Error::success();
498-
}
499-
500500
// FIXME: Make this optional. We can avoid file reads
501501
// by auto-generating a default vocabulary during the build time.
502502
Error IR2VecVocabAnalysis::readVocabulary(VocabMap &OpcVocab,
@@ -513,16 +513,16 @@ Error IR2VecVocabAnalysis::readVocabulary(VocabMap &OpcVocab,
513513
return ParsedVocabValue.takeError();
514514

515515
unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0;
516-
if (auto Err =
517-
parseVocabSection("Opcodes", *ParsedVocabValue, OpcVocab, OpcodeDim))
516+
if (auto Err = VocabStorage::parseVocabSection("Opcodes", *ParsedVocabValue,
517+
OpcVocab, OpcodeDim))
518518
return Err;
519519

520-
if (auto Err =
521-
parseVocabSection("Types", *ParsedVocabValue, TypeVocab, TypeDim))
520+
if (auto Err = VocabStorage::parseVocabSection("Types", *ParsedVocabValue,
521+
TypeVocab, TypeDim))
522522
return Err;
523523

524-
if (auto Err =
525-
parseVocabSection("Arguments", *ParsedVocabValue, ArgVocab, ArgDim))
524+
if (auto Err = VocabStorage::parseVocabSection("Arguments", *ParsedVocabValue,
525+
ArgVocab, ArgDim))
526526
return Err;
527527

528528
if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))

0 commit comments

Comments
 (0)