 #include "llm.h"
+
 #include <HalideRuntime.h>
+
+#include <chrono>
 #include <iomanip>
 #include <iostream>
 
@@ -20,6 +23,49 @@ ABSL_FLAG(int, max_tokens, 512,
2023 "Maximum number of input and output tokens. This value needs to be "
2124 "at least larger than the number of input tokens.");
2225
+ABSL_FLAG(bool, show_timing, false,
+          "Show timing for operations.");
+
+namespace {
+
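+// Note: high_resolution_clock is allowed to be an alias of system_clock,
+// whose time can jump when the wall clock is adjusted, so interval timing
+// wants a monotonic ("steady") clock.
+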
+// Prefer high_resolution_clock, but only if it's steady...
+template<bool HighResIsSteady = std::chrono::high_resolution_clock::is_steady>
+struct SteadyClock {
+    using type = std::chrono::high_resolution_clock;
+};
+
+// ...otherwise use steady_clock.
+template<>
+struct SteadyClock<false> {
+    using type = std::chrono::steady_clock;
+};
+
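+// RAII timer: records the start time on construction; on destruction, if
+// --show_timing is set, reports elapsed wall-clock seconds to stderr, plus a
+// per-iteration average when iterations > 1. Scope the work to be timed:
+//     { TimingScope t("Step", n); /* n iterations of work */ }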
+struct TimingScope {
+    TimingScope(const char *name, int iterations = 1)
+        : name(name), iterations(iterations) {
+        start = SteadyClock<>::type::now();
+    }
+
+    ~TimingScope() {
+        if (absl::GetFlag(FLAGS_show_timing)) {
+            SteadyClock<>::type::time_point end = SteadyClock<>::type::now();
+            double secs =
+                std::chrono::duration_cast<std::chrono::duration<double>>(end - start).count();
+            std::cerr << name << ": took " << secs << "s";
+            if (iterations != 1) {
+                std::cerr << ", " << secs / iterations << "s per iteration";
+            }
+            std::cerr << "\n";
+        }
+    }
+
+    std::string name;
+    int iterations;
+    SteadyClock<>::type::time_point start;
+};
+
+}  // namespace
+
 int main(int argc, char *argv[]) {
     absl::ParseCommandLine(argc, argv);
 
@@ -30,6 +76,7 @@ int main(int argc, char *argv[]) {
 
     sentencepiece::SentencePieceProcessor tokenizer;
     {
+        TimingScope load_tokenizer("Loading tokenizer");
         auto result = tokenizer.Load(tokenizer_path);
         if (!result.ok()) {
             std::cerr << result.message();
@@ -49,66 +96,87 @@ int main(int argc, char *argv[]) {
         auto result = tokenizer.Encode(bracketed_prompt, &prompt_tokens);
     }
 
-    std::cerr << "Loading LLM params.\n";
-    auto p = hallmark::LoadLlmParams(model_path);
-    if (!p.ok()) {
-        std::cerr << p.status() << "\n";
-        return 1;
+    hallmark::LlmParams llm_params;
+    {
+        TimingScope load_params("Loading LLM params");
+        auto p = hallmark::LoadLlmParams(model_path);
+        if (!p.ok()) {
+            std::cerr << p.status() << "\n";
+            return 1;
+        }
+        llm_params = std::move(p.value());
     }
-    auto llm_params = std::move(p.value());
     llm_params.seq_size_T = max_tokens;
 
-    std::cerr << "Loading LLM weights.\n";
-    auto w = hallmark::LoadLlmWeights(model_path, llm_params);
-    if (!w.ok()) {
-        std::cerr << w.status() << "\n";
-        return 1;
+    hallmark::LlmWeights llm_weights;
+    {
+        TimingScope load_weights("Loading LLM weights");
+        auto w = hallmark::LoadLlmWeights(model_path, llm_params);
+        if (!w.ok()) {
+            std::cerr << w.status() << "\n";
+            return 1;
+        }
+        llm_weights = std::move(w.value());
     }
-    auto llm_weights = std::move(w.value());
 
-    std::cerr << "Creating LLM.\n";
-    auto l = hallmark::Llm::CreateLlm(llm_weights, llm_params);
-    if (!l.ok()) {
-        std::cerr << l.status() << "\n";
-        return 2;
+    std::unique_ptr<hallmark::Llm> llm;
+    {
+        TimingScope create_llm("Creating LLM");
+        auto l = hallmark::Llm::CreateLlm(llm_weights, llm_params);
+        if (!l.ok()) {
+            std::cerr << l.status() << "\n";
+            return 2;
+        }
+        llm = std::move(l.value());
     }
-    auto llm = std::move(l.value());
 
     if (!llm->Reset().ok()) {
         std::cerr << "Reset fails\n";
         return 3;
     }
-    if (!llm->InitAttentionMaskValues(llm_params.seq_size_T).ok()) {
-        std::cerr << "InitAttentionMaskValues fails\n";
-        return 4;
+    {
+        TimingScope init_attention_mask("Init attention mask");
+        if (!llm->InitAttentionMaskValues(llm_params.seq_size_T).ok()) {
+            std::cerr << "InitAttentionMaskValues fails\n";
+            return 4;
+        }
     }
 
-    if (!llm->InitInputTokens(prompt_tokens).ok()) {
-        std::cerr << "InitInputTokens fails\n";
-        return 1;
+    {
+        TimingScope init_input_tokens("Init input tokens",
+                                      static_cast<int>(prompt_tokens.size()));
+        if (!llm->InitInputTokens(prompt_tokens).ok()) {
+            std::cerr << "InitInputTokens fails\n";
+            return 1;
+        }
     }
 
     std::cout << prompt << "\n";
 
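+    // Generation loop: each GetNextToken() call may return more than one
+    // token, so the loop index advances by output_tokens.size(); decoded
+    // text is streamed to stdout as it arrives.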
-    for (int token = prompt_tokens.size(); token < max_tokens; token++) {
+    {
+        // Count only the tokens to be generated, not the prompt.
+        TimingScope generate("\nGenerate tokens",
+                             max_tokens - static_cast<int>(prompt_tokens.size()));
         std::vector<int> output_tokens;
-        if (!llm->GetNextToken(&output_tokens).ok()) {
-            std::cerr << "GetNextToken fails\n";
-            return 6;
-        }
-        if (output_tokens.empty()) {
-            std::cerr << "Empty result from GetNextToken.\n";
-        }
-        std::string decoded_tokens;
-        if (!tokenizer.Decode(output_tokens, &decoded_tokens).ok()) {
-            std::cerr << "Decode fails\n";
-            return 7;
-        }
-        if (decoded_tokens.empty()) {
-            std::cout << "_";
+        for (int token = static_cast<int>(prompt_tokens.size());
+             token < max_tokens - 2;
+             token += static_cast<int>(output_tokens.size())) {
+            output_tokens.clear();
+            if (!llm->GetNextToken(&output_tokens).ok()) {
+                std::cerr << "GetNextToken fails\n";
+                return 6;
+            }
+            if (output_tokens.empty()) {
+                std::cerr << "Empty result from GetNextToken.\n";
+                // Guard against an infinite loop: the step expression above
+                // advances by output_tokens.size(), which is zero here.
+                break;
+            } else if (output_tokens.size() > 1) {
+                std::cerr << "More than one token returned from GetNextToken at token "
+                          << token << ".\n";
+            }
+            std::string decoded_tokens;
+            if (!tokenizer.Decode(output_tokens, &decoded_tokens).ok()) {
+                std::cerr << "Decode fails\n";
+                return 7;
+            }
+            if (decoded_tokens.empty()) {
+                std::cout << "_";
+            }
+            std::cout << decoded_tokens;
+            std::cout.flush();
         }
-        std::cout << decoded_tokens;
-        std::cout.flush();
     }
 
     return 0;