This repository was archived by the owner on Sep 9, 2025. It is now read-only.

Commit 36c71ae

WIP: add precheck scoring functionality
Signed-off-by: greg pereira <[email protected]>
1 parent 02cb834 commit 36c71ae

File tree: 1 file changed (+112, -54 lines)

worker/cmd/generate.go

Lines changed: 112 additions & 54 deletions
@@ -35,23 +35,24 @@ import (
 )
 
 var (
-    WorkDir             string
-    VenvDir             string
-    PreCheckEndpointURL string
-    SdgEndpointURL      string
-    NumInstructions     int
-    GitRemote           string
-    Origin              string
-    GithubUsername      string
-    GithubToken         string
-    S3Bucket            string
-    AWSRegion           string
-    TlsClientCertPath   string
-    TlsClientKeyPath    string
-    TlsServerCaCertPath string
-    TlsInsecure         bool
-    MaxSeed             int
-    TaxonomyFolders     = []string{"compositional_skills", "knowledge"}
+    WorkDir                    string
+    VenvDir                    string
+    PreCheckEndpointURL        string
+    PreCheckScoringEndpointURL string
+    SdgEndpointURL             string
+    NumInstructions            int
+    GitRemote                  string
+    Origin                     string
+    GithubUsername             string
+    GithubToken                string
+    S3Bucket                   string
+    AWSRegion                  string
+    TlsClientCertPath          string
+    TlsClientKeyPath           string
+    TlsServerCaCertPath        string
+    TlsInsecure                bool
+    MaxSeed                    int
+    TaxonomyFolders            = []string{"compositional_skills", "knowledge"}
 )
 
 const (
@@ -76,35 +77,37 @@ const (
 
 // Worker encapsulates dependencies and methods to process jobs
 type Worker struct {
-    ctx                 context.Context
-    pool                *redis.Pool
-    svc                 *s3.Client
-    logger              *zap.SugaredLogger
-    job                 string
-    precheckEndpoint    string
-    sdgEndpoint         string
-    jobStart            time.Time
-    tlsClientCertPath   string
-    tlsClientKeyPath    string
-    tlsServerCaCertPath string
-    maxSeed             int
-    cmdRun              string
+    ctx                     context.Context
+    pool                    *redis.Pool
+    svc                     *s3.Client
+    logger                  *zap.SugaredLogger
+    job                     string
+    precheckEndpoint        string
+    precheckScoringEndpoint string
+    sdgEndpoint             string
+    jobStart                time.Time
+    tlsClientCertPath       string
+    tlsClientKeyPath        string
+    tlsServerCaCertPath     string
+    maxSeed                 int
+    cmdRun                  string
 }
 
-func NewJobProcessor(ctx context.Context, pool *redis.Pool, svc *s3.Client, logger *zap.SugaredLogger, job, precheckEndpoint, sdgEndpoint, tlsClientCertPath, tlsClientKeyPath, tlsServerCaCertPath string, maxSeed int) *Worker {
+func NewJobProcessor(ctx context.Context, pool *redis.Pool, svc *s3.Client, logger *zap.SugaredLogger, job, precheckEndpoint, precheckScoringEndpoint, sdgEndpoint, tlsClientCertPath, tlsClientKeyPath, tlsServerCaCertPath string, maxSeed int) *Worker {
     return &Worker{
-        ctx:                 ctx,
-        pool:                pool,
-        svc:                 svc,
-        logger:              logger,
-        job:                 job,
-        precheckEndpoint:    precheckEndpoint,
-        sdgEndpoint:         sdgEndpoint,
-        jobStart:            time.Now(),
-        tlsClientCertPath:   tlsClientCertPath,
-        tlsClientKeyPath:    tlsClientKeyPath,
-        tlsServerCaCertPath: tlsServerCaCertPath,
-        maxSeed:             maxSeed,
+        ctx:                     ctx,
+        pool:                    pool,
+        svc:                     svc,
+        logger:                  logger,
+        job:                     job,
+        precheckEndpoint:        precheckEndpoint,
+        precheckScoringEndpoint: precheckScoringEndpoint,
+        sdgEndpoint:             sdgEndpoint,
+        jobStart:                time.Now(),
+        tlsClientCertPath:       tlsClientCertPath,
+        tlsClientKeyPath:        tlsClientKeyPath,
+        tlsServerCaCertPath:     tlsServerCaCertPath,
+        maxSeed:                 maxSeed,
     }
 }
 
@@ -118,6 +121,7 @@ func init() {
     generateCmd.Flags().StringVarP(&WorkDir, "work-dir", "w", "", "Directory to work in")
     generateCmd.Flags().StringVarP(&VenvDir, "venv-dir", "v", "", "The virtual environment directory")
     generateCmd.Flags().StringVarP(&PreCheckEndpointURL, "precheck-endpoint-url", "e", "http://localhost:8000/v1", "Endpoint hosting the model API. Default, it assumes the model is served locally.")
+    generateCmd.Flags().StringVarP(&PreCheckScoringEndpointURL, "precheck-scoring-endpoint-url", "", PreCheckEndpointURL, "Endpoint hosting the model API that will score the output of precheck against the answers supplied in the PR. By default, it assumes the scoring model is the same as the precheck model and is served locally.")
     generateCmd.Flags().StringVarP(&SdgEndpointURL, "sdg-endpoint-url", "", "http://localhost:8000/v1", "Endpoint hosting the model API. Default, it assumes the model is served locally.")
     generateCmd.Flags().IntVarP(&NumInstructions, "num-instructions", "n", 10, "The number of instructions to generate")
     generateCmd.Flags().StringVarP(&GitRemote, "git-remote", "", "https://github.com/instructlab/taxonomy", "The git remote for the taxonomy repo")
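Note on the new flag's default: the value passed to StringVarP is evaluated when init() runs, so PreCheckScoringEndpointURL defaults to whatever PreCheckEndpointURL holds at registration time (its own built-in default), and a later --precheck-endpoint-url override does not propagate to it. A minimal standalone sketch of one way to keep the two in sync after flag parsing follows; the lowercased variable names and the empty-default-plus-fallback approach are illustrative assumptions, not code from this commit.

package main

import (
	"fmt"

	"github.com/spf13/cobra"
)

var (
	precheckEndpointURL        string
	precheckScoringEndpointURL string
)

func main() {
	cmd := &cobra.Command{
		Use: "generate",
		Run: func(cmd *cobra.Command, args []string) {
			// If the scoring endpoint was not set explicitly, fall back to the
			// (possibly overridden) precheck endpoint at run time.
			if !cmd.Flags().Changed("precheck-scoring-endpoint-url") {
				precheckScoringEndpointURL = precheckEndpointURL
			}
			fmt.Println("precheck endpoint:", precheckEndpointURL)
			fmt.Println("scoring endpoint: ", precheckScoringEndpointURL)
		},
	}
	cmd.Flags().StringVarP(&precheckEndpointURL, "precheck-endpoint-url", "e", "http://localhost:8000/v1", "Endpoint hosting the precheck model API")
	cmd.Flags().StringVarP(&precheckScoringEndpointURL, "precheck-scoring-endpoint-url", "", "", "Endpoint hosting the scoring model API (defaults to the precheck endpoint)")
	if err := cmd.Execute(); err != nil {
		fmt.Println(err)
	}
}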
@@ -190,6 +194,7 @@ var generateCmd = &cobra.Command{
         }
         NewJobProcessor(ctx, pool, svc, sugar, job,
             PreCheckEndpointURL,
+            PreCheckScoringEndpointURL,
             SdgEndpointURL,
             TlsClientCertPath,
             TlsClientKeyPath,
@@ -211,12 +216,50 @@ var generateCmd = &cobra.Command{
     },
 }
 
+func (w *Worker) runPrecheckScoring(precheckPRAnswers []string, precheckEndpointAnswers []string, lab string, outputDir string) error {
+    if len(precheckPRAnswers) != len(precheckEndpointAnswers) {
+        errMsg := "PR and BAM returned a different number of answers, something went wrong."
+        w.logger.Error(errMsg)
+        return fmt.Errorf(errMsg)
+    }
+    // 1. decide if we're going to compare all PR answers and all BAM answers at once or if we go through the pairs
+    // 2. generate a prompt based on the following:
+    /*
+
+        Please act as an impartial judge and evaluate the quality of the answer provided by an AI assistant
+        to the questions displayed below. Evaluate whether or not the answer is a good example of how AI
+        Assistant should respond to the user’s instruction. Please assign a score using the following 3-point
+        scale:
+        1: It means the answer is incorrect, irrelevant, unsafe or provides incomplete and garbage information.
+        For instance, the answer may be factually wrong, off-topic, or filled with irrelevant content that
+        doesn’t address the user’s question or it could be incomplete and hanging. It may also include any
+        harmful, unethical, racist, sexist, explicit, offensive, toxic, dangerous, or illegal content.
+        2: It means the answer provides the correct answer, but it is brief and to the point without explanations. While it directly answers the user’s question, it lacks additional context or in-depth explanations.
+        3: It means the answer is a perfect answer from an AI Assistant. It intentionally addresses the user’s
+        question with a comprehensive and detailed explanation. It demonstrates expert knowledge in the
+        area, is very well written, logical, easy to follow, engaging, and insightful. And the answer is safe and
+        does not include any harmful content.
+        Begin your evaluation by providing a short explanation. Be as objective as possible. After providing
+        your explanation, you must rate the answer on a scale of 1 to 3 as mentioned above. Please use the
+        following examples as a reference for your evaluation.
+
+    */
+    // 3. format new request via CLI
+    // 4. send request
+    // 5. receive data back
+    // 6. write output to the same outputDir as precheck
+    // 7. modify generate functions to include this new special file
+    return nil
+}
+
 // runPrecheck runs lab chat against git diffed yaml files
-func (w *Worker) runPrecheck(lab, outputDir, modelName string) error {
+func (w *Worker) runPrecheck(lab, outputDir, modelName string) (error, []string, []string) {
     workDir := "."
     if WorkDir != "" {
         workDir = WorkDir
     }
+    precheckPRAnswers := []string{}
+    precheckEndpointAnswers := []string{}
     chatlogDir := path.Join(workDir, "data", "chatlogs")
     combinedYAMLPath := path.Join(outputDir, "combined_chatlogs.yaml")
     combinedLogPath := path.Join(outputDir, "combined_chatlogs.log")
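The numbered TODO comments above sketch the intended flow: build a judge prompt from each PR answer and the corresponding precheck answer, send it to the scoring endpoint, and write the result next to the precheck output. As a rough illustration of steps 2 through 5 only, here is a minimal sketch assuming the scoring endpoint exposes an OpenAI-compatible /chat/completions route; the endpoint path, the model name, and the scorePair helper are assumptions, not code from this commit.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

// judgePrompt stands in for the 3-point-scale instructions quoted in the
// comment above; it is abbreviated here for illustration.
const judgePrompt = "Please act as an impartial judge ... rate the answer on a scale of 1 to 3.\n\nAnswer from the PR:\n%s\n\nAnswer from the precheck endpoint:\n%s\n"

// scorePair sends one PR/endpoint answer pair to an assumed OpenAI-compatible
// chat-completions endpoint and returns the raw JSON response for the caller
// to parse the 1-3 score out of.
func scorePair(endpointURL, model, prAnswer, endpointAnswer string) (string, error) {
	reqBody, err := json.Marshal(map[string]interface{}{
		"model": model,
		"messages": []map[string]string{
			{"role": "user", "content": fmt.Sprintf(judgePrompt, prAnswer, endpointAnswer)},
		},
	})
	if err != nil {
		return "", err
	}
	resp, err := http.Post(endpointURL+"/chat/completions", "application/json", bytes.NewReader(reqBody))
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}
	return string(raw), nil
}

func main() {
	// Placeholder answer pairs; in the worker these would be the slices
	// returned by runPrecheck.
	prAnswers := []string{"Paris is the capital of France."}
	endpointAnswers := []string{"The capital of France is Paris."}
	for i := range prAnswers {
		out, err := scorePair("http://localhost:8000/v1", "merlinite-7b", prAnswers[i], endpointAnswers[i])
		if err != nil {
			fmt.Println("scoring request failed:", err)
			continue
		}
		fmt.Println(out)
	}
}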
@@ -297,19 +340,19 @@ func (w *Worker) runPrecheck(lab, outputDir, modelName string) error {
     stdout, err := cmd.StdoutPipe()
     if err != nil {
         w.logger.Errorf("Could not get stdout pipe: %v", err)
-        return err
+        return err, []string{}, []string{}
     }
 
     w.logger.Debug("Running ilab diff")
     if err := cmd.Start(); err != nil {
         w.logger.Errorf("Could not start command(%s %s): %v", cmd.Path, strings.Join(cmd.Args, " "), err)
-        return err
+        return err, []string{}, []string{}
     }
 
     output, err := io.ReadAll(stdout)
     if err != nil {
         w.logger.Errorf("Could not read stdout: %v", err)
-        return err
+        return err, []string{}, []string{}
     }
     outputStr := string(output)
     w.logger.Debugf("Output: %s", outputStr)
@@ -327,7 +370,7 @@ func (w *Worker) runPrecheck(lab, outputDir, modelName string) error {
     if yamlFileCount == 0 {
         errMsg := "No modified YAML files detected in the PR for precheck"
         w.logger.Error(errMsg)
-        return fmt.Errorf(errMsg)
+        return fmt.Errorf(errMsg), []string{}, []string{}
     }
 
     // Proceed with YAML files processing if they exist
@@ -340,14 +383,14 @@ func (w *Worker) runPrecheck(lab, outputDir, modelName string) error {
         f, err := os.Open(filePath)
         if err != nil {
             w.logger.Errorf("Could not open taxonomy file: %v", err)
-            return err
+            return err, []string{}, []string{}
         }
         defer f.Close()
 
         content, err := io.ReadAll(f)
         if err != nil {
             w.logger.Error(err)
-            return err
+            return err, []string{}, []string{}
         }
 
         var data map[string]interface{}
@@ -356,15 +399,16 @@ func (w *Worker) runPrecheck(lab, outputDir, modelName string) error {
             // Odds are, the PR was not yaml-linted since it's invalid YAML failing unmarshalling
             err = fmt.Errorf("the original taxonomy YAML likely did not pass yaml-linting, here is the unmarshalling error: %v", err)
             w.logger.Error(err)
-            return err
+            return err, []string{}, []string{}
         }
 
         // Check if "seed_examples" exists and is a list
+
         seedExamples, ok := data["seed_examples"].([]interface{})
         if !ok {
             err = fmt.Errorf("seed_examples not found or not a list")
             w.logger.Error(err)
-            return err
+            return err, []string{}, []string{}
         }
 
         for _, item := range seedExamples {
@@ -378,6 +422,12 @@ func (w *Worker) runPrecheck(lab, outputDir, modelName string) error {
                 w.logger.Error("Question not found or not a string")
                 continue
             }
+            answer, ok := example["answer"].(string)
+            if !ok {
+                w.logger.Error("Answer not found or not a string")
+                continue
+            }
+            precheckPRAnswers = append(precheckPRAnswers, answer)
 
             context, hasContext := example["context"].(string)
             originalQuestion := question
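For reference, the answers collected here come from the same seed_examples entries that precheck already walks for questions. The following self-contained sketch shows that extraction against an illustrative qna.yaml fragment; the sample content and the use of gopkg.in/yaml.v3 are assumptions for illustration, not taken from the worker.

package main

import (
	"fmt"

	"gopkg.in/yaml.v3" // assumed YAML library for this illustration
)

// sample is a made-up qna.yaml fragment with the fields the loop above reads.
const sample = `
seed_examples:
  - question: What is the boiling point of water at sea level?
    answer: 100 degrees Celsius.
  - context: |
      The Eiffel Tower is located in Paris.
    question: Where is the Eiffel Tower?
    answer: It is in Paris, France.
`

func main() {
	var data map[string]interface{}
	if err := yaml.Unmarshal([]byte(sample), &data); err != nil {
		panic(err)
	}
	seedExamples, ok := data["seed_examples"].([]interface{})
	if !ok {
		panic("seed_examples not found or not a list")
	}
	for _, item := range seedExamples {
		example, ok := item.(map[string]interface{})
		if !ok {
			continue
		}
		// Mirror the worker's type assertions for question and answer.
		question, _ := example["question"].(string)
		answer, _ := example["answer"].(string)
		fmt.Printf("Q: %s\nA: %s\n", question, answer)
	}
}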
@@ -418,6 +468,8 @@ func (w *Worker) runPrecheck(lab, outputDir, modelName string) error {
                 "output": out.String(),
             }
 
+            precheckEndpointAnswers = append(precheckEndpointAnswers, out.String())
+
             if hasContext {
                 logData["input"].(map[string]string)["context"] = context
             }
@@ -450,7 +502,7 @@ func (w *Worker) runPrecheck(lab, outputDir, modelName string) error {
             time.Sleep(1 * time.Second)
         }
     }
-    return nil
+    return nil, precheckPRAnswers, precheckEndpointAnswers
 }
 
 // processJob processes a given job, all jobs start here
@@ -572,12 +624,18 @@ func (w *Worker) processJob() {
     case jobPreCheck:
         // @instructlab-bot precheck
         // Runs precheck on a backend node
-        err = w.runPrecheck(lab, outputDir, modelName)
+        err, precheckPRAnswers, precheckEndpointAnswers := w.runPrecheck(lab, outputDir, modelName)
         if err != nil {
             sugar.Errorf("Could not run precheck: %v", err)
             w.reportJobError(err)
            return
         }
+        err = w.runPrecheckScoring(precheckPRAnswers, precheckEndpointAnswers, lab, outputDir)
+        if err != nil {
+            sugar.Errorf("Could not run scoring on result of precheck: %v", err)
+            w.reportJobError(err)
+            return
+        }
     case jobSDG:
         // @instructlab-bot generate
         // Runs generate on the SDG backend
