Skip to content

Commit ca19ae6

Browse files
rework fetchModelName to work by endpoint
This change allows us to use different model names for the precheckEndpoint and precheckScoringEndpoint. Signed-off-by: greg pereira <[email protected]>
1 parent 5d8d28a commit ca19ae6

File tree

3 files changed

+25
-12
lines changed

3 files changed

+25
-12
lines changed

ui/apiserver/apiserver.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ func (api *ApiServer) runIlabChatCommand(question, context string) (string, erro
336336
cmd = exec.Command("echo", cmdArgs...)
337337
api.logger.Infof("Running in test mode: %s", commandStr)
338338
} else {
339-
modelName, err := api.fetchModelName(true)
339+
modelName, err := api.fetchModelName(true, api.preCheckEndpointURL)
340340
if err != nil {
341341
api.logger.Errorf("Failed to fetch model name: %v", err)
342342
return "failed to retrieve the model name", err
@@ -382,9 +382,8 @@ func setupLogger(debugMode bool) *zap.SugaredLogger {
382382

383383
// fetchModelName hits the defined precheck endpoint with "/models" appended to extract the model name.
384384
// If fullName is true, it returns the entire ID value; if false, it returns the parsed out name after the double hyphens.
385-
func (api *ApiServer) fetchModelName(fullName bool) (string, error) {
385+
func (api *ApiServer) fetchModelName(fullName bool, endpoint string) (string, error) {
386386
// Ensure the endpoint URL ends with "/models"
387-
endpoint := api.preCheckEndpointURL
388387
if !strings.HasSuffix(endpoint, "/") {
389388
endpoint += "/"
390389
}

worker/cmd/generate.go

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ var generateCmd = &cobra.Command{
218218

219219
func (w *Worker) runPrecheckScoring(precheckPRAnswers []string, precheckEndpointAnswers []string, precheckPRQuestions []string, lab string, outputDir string, preCheckScoringModelName string) error {
220220
if len(precheckPRAnswers) != len(precheckEndpointAnswers) {
221-
errMsg := "PR questions a Endpoint answers returned a different number of entries, something went wrong."
221+
errMsg := "PR questions a Endpoint answers returned a different number of entries, something went wrong"
222222
w.logger.Error(errMsg)
223223
return fmt.Errorf(errMsg)
224224
}
@@ -638,7 +638,7 @@ func (w *Worker) processJob() {
638638
// sdg-svc does not have a models endpoint as yet
639639
if jobType != jobSDG && PreCheckEndpointURL != localEndpoint {
640640
var err error
641-
modelName, err = w.fetchModelName(true)
641+
modelName, err = w.fetchModelName(true, w.precheckEndpoint)
642642
if err != nil {
643643
w.logger.Errorf("Failed to fetch model name: %v", err)
644644
modelName = "unknown"
@@ -683,7 +683,21 @@ func (w *Worker) processJob() {
683683
w.reportJobError(err)
684684
return
685685
}
686-
err = w.runPrecheckScoring(precheckPRAnswers, precheckEndpointAnswers, precheckPRQuestions, lab, outputDir, modelName)
686+
687+
var scoringModelName string
688+
// sdg-svc does not have a models endpoint as yet
689+
if jobType == jobPreCheck && w.precheckScoringEndpoint != localEndpoint {
690+
var err error
691+
scoringModelName, err = w.fetchModelName(true, w.precheckScoringEndpoint)
692+
if err != nil {
693+
w.logger.Errorf("Failed to fetch model name: %v", err)
694+
scoringModelName = "unknown"
695+
}
696+
} else {
697+
scoringModelName = w.getModelNameFromConfig() // will default to standard precheck model
698+
}
699+
700+
err = w.runPrecheckScoring(precheckPRAnswers, precheckEndpointAnswers, precheckPRQuestions, lab, outputDir, scoringModelName)
687701
if err != nil {
688702
sugar.Errorf("Could not run scoring on result of precheck: %v", err)
689703
w.reportJobError(err)
@@ -975,9 +989,9 @@ func (w *Worker) getModelNameFromConfig() string {
975989

976990
// fetchModelName hits the defined precheckEndpoint with "/models" appended to extract the model name.
977991
// If fullName is true, it returns the entire ID value; if false, it returns the parsed out name after the double hyphens.
978-
func (w *Worker) fetchModelName(fullName bool) (string, error) {
992+
func (w *Worker) fetchModelName(fullName bool, endpoint string) (string, error) {
979993
// Ensure the endpoint URL ends with "/models"
980-
endpoint := w.precheckEndpoint
994+
// endpoint := w.precheckEndpoint
981995
if !strings.HasSuffix(endpoint, "/") {
982996
endpoint += "/"
983997
}
@@ -1073,7 +1087,7 @@ func (w *Worker) determineModelName(jobType string) string {
10731087

10741088
// precheck is the only case we use a remote OpenAI endpoint right now
10751089
if PreCheckEndpointURL != localEndpoint && jobType == jobPreCheck {
1076-
modelName, err := w.fetchModelName(false)
1090+
modelName, err := w.fetchModelName(false, w.precheckEndpoint)
10771091
if err != nil {
10781092
w.logger.Errorf("Failed to fetch model name: %v", err)
10791093
return "unknown"

worker/cmd/generate_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,12 +161,12 @@ func TestFetchModelName(t *testing.T) {
161161
20,
162162
)
163163

164-
modelName, err := w.fetchModelName(false)
164+
modelName, err := w.fetchModelName(false, w.precheckEndpoint)
165165
assert.NoError(t, err, "fetchModelName should not return an error")
166166
expectedModelName := "Mixtral-8x7B-Instruct-v0.1"
167167
assert.Equal(t, expectedModelName, modelName, "The model name should be extracted correctly")
168168

169-
modelName, err = w.fetchModelName(true)
169+
modelName, err = w.fetchModelName(true, w.precheckEndpoint)
170170
assert.NoError(t, err, "fetchModelName should not return an error")
171171
expectedModelName = "/shared_model_storage/transformers_cache/models--mistralai--Mixtral-8x7B-Instruct-v0.1/snapshots/5c79a376139be989ef1838f360bf4f1f256d7aec"
172172
assert.Equal(t, expectedModelName, modelName, "The model name should be extracted correctly")
@@ -222,7 +222,7 @@ func TestFetchModelNameWithInvalidObject(t *testing.T) {
222222
"dummy-ca-cert-path.pem",
223223
20,
224224
)
225-
modelName, err := w.fetchModelName(false)
225+
modelName, err := w.fetchModelName(false, w.precheckEndpoint)
226226

227227
// Verify that an error was returned due to the invalid "object" field
228228
assert.Error(t, err, "fetchModelName should return an error for invalid object field")

0 commit comments

Comments
 (0)