Skip to content

Commit c6c47d8

Browse files
mongodben and Ben Perlmutter authored
(EAI-1235): NL2AS benchmark orchestration and CLI integration (#874)
* rename dir for consistent naming * checkpoint * checkpoint * checkpoint * checkpoint * glue it all together * revert name change * updates based on review * remove mcp from core (not needed) * Fix build err * (EAI-1234): NL2AS Eval metrics (#875) * move retrieval metrics to core * add atlas search benchmark metrics --------- Co-authored-by: Ben Perlmutter <[email protected]> --------- Co-authored-by: Ben Perlmutter <[email protected]>
1 parent 6203daf commit c6c47d8

38 files changed

+4935
-1913
lines changed

package-lock.json

Lines changed: 4198 additions & 1799 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

packages/benchmarks/package.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,18 @@
4545
},
4646
"dependencies": {
4747
"@ai-sdk/openai": "^1.3.6",
48+
"@modelcontextprotocol/sdk": "^1.17.2",
4849
"@supercharge/promise-pool": "^3.2.0",
4950
"ai": "^4.2.10",
5051
"autoevals": "^0.0.129",
5152
"csv-writer": "^1.6.0",
5253
"dotenv": "^16",
5354
"mongodb-chatbot-server": "*",
55+
"mongodb-mcp-server": "^0.2.0",
5456
"mongodb-rag-core": "*",
5557
"mongodb-schema": "^12.2.0",
5658
"yaml": "^2.7.1",
5759
"yargs": "^17.7.2",
5860
"zod": "^3.23.8"
5961
}
60-
}
62+
}

packages/benchmarks/src/bin/mongoDbBenchmarkCli.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ import { assertEnvVars, BRAINTRUST_ENV_VARS } from "mongodb-rag-core";
66
import { multipleChoiceBenchmarkConfig } from "../quizQuestions/config";
77
import { nlPromptResponseBenchmark } from "../nlPromptResponse/config";
88
import { discoveryBenchmarkConfig } from "../discovery/config";
9-
import { nlToMongoshBenchmarkConfig } from "../textToDriver/config";
9+
import { nlToMongoshBenchmarkConfig } from "../textToDriver/nlToMongoshBenchmarkConfig";
10+
import { nlToAtlasSearchBenchmarkConfig } from "../textToDriver/nltoAtlasSearchBenchmarkConfig";
1011

1112
const { BRAINTRUST_API_KEY, BRAINTRUST_ENDPOINT } =
1213
assertEnvVars(BRAINTRUST_ENV_VARS);
@@ -22,6 +23,7 @@ const config: BenchmarkCliConfig = {
2223
nl_prompt_response: nlPromptResponseBenchmark,
2324
discovery: discoveryBenchmarkConfig,
2425
nl_to_mongosh: nlToMongoshBenchmarkConfig,
26+
nl_to_atlas_search: nlToAtlasSearchBenchmarkConfig,
2527
},
2628
};
2729

packages/benchmarks/src/cli/BenchmarkConfig.ts

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ export interface BenchmarkTask<
2323
taskFunc: (
2424
modelProvider: ModelProvider,
2525
deployment: ModelConfig
26-
) => EvalTask<Input, Output, Expected, Metadata, Parameters>;
26+
) =>
27+
| Promise<EvalTask<Input, Output, Expected, Metadata, Parameters>>
28+
| EvalTask<Input, Output, Expected, Metadata, Parameters>;
2729
description?: string;
2830
}
2931

@@ -48,6 +50,10 @@ export interface BenchmarkConfig<
4850
datasets: Record<string, BenchmarkDataset<Input, Expected, Metadata>>;
4951
tasks: Record<string, BenchmarkTask<Input, Output, Expected, Metadata>>;
5052
scorers: Record<string, BenchmarkScorer<Input, Output, Expected, Metadata>>;
53+
environment?: {
54+
beforeAll?: () => Promise<void>;
55+
afterAll?: () => Promise<void>;
56+
};
5157
}
5258

5359
export type ModelProvider = {

packages/benchmarks/src/cli/runBenchmark.test.ts

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ describe("runBenchmark", () => {
2828
let mockScorerFunc: jest.Mock;
2929
let mockProcessFunction: jest.Mock;
3030
let mockPromisePoolInstance: any;
31+
let mockBeforeAll: jest.Mock;
32+
let mockAfterAll: jest.Mock;
3133

3234
beforeEach(() => {
3335
mockDataset1 = [
@@ -38,6 +40,8 @@ describe("runBenchmark", () => {
3840

3941
mockTaskFunc = jest.fn().mockReturnValue("mock-task-result");
4042
mockScorerFunc = jest.fn().mockReturnValue("mock-scorer-result");
43+
mockBeforeAll = jest.fn();
44+
mockAfterAll = jest.fn();
4145

4246
mockModels = [
4347
{
@@ -64,6 +68,10 @@ describe("runBenchmark", () => {
6468
},
6569
benchmarks: {
6670
"test-benchmark": {
71+
environment: {
72+
beforeAll: mockBeforeAll,
73+
afterAll: mockAfterAll,
74+
},
6775
description: "Test benchmark",
6876
projectName: "test-project",
6977
datasets: {
@@ -629,5 +637,12 @@ describe("runBenchmark", () => {
629637
mockConfig.models[1]
630638
);
631639
});
640+
641+
it("should call beforeAll and afterAll functions", async () => {
642+
await runBenchmark(mockConfig, mockArgs);
643+
644+
expect(mockBeforeAll).toHaveBeenCalled();
645+
expect(mockAfterAll).toHaveBeenCalled();
646+
});
632647
});
633648
});

packages/benchmarks/src/cli/runBenchmark.ts

Lines changed: 64 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -53,65 +53,72 @@ export async function runBenchmark(
5353
console.log(`Task: ${task}`);
5454
console.log(`Model concurrency: ${modelConcurrency}`);
5555

56-
// Run benchmarks with model concurrency
57-
const { results } = await PromisePool.for(models)
58-
.withConcurrency(modelConcurrency)
59-
.handleError((error) => {
60-
console.error(error);
61-
})
62-
.process(async (model) => {
63-
const maxConcurrency = taskConcurrency ?? model.maxConcurrency ?? 1;
64-
65-
console.log(`Running experiments for model: ${model.label}`);
66-
67-
// Run each task-dataset combination
68-
69-
const dataset = (
70-
await Promise.all(
71-
datasetsToRun.map(([_datasetName, datasetConfig]) =>
72-
datasetConfig.getDataset()
56+
// Setup environment
57+
await benchmarkConfig.environment?.beforeAll?.();
58+
59+
try {
60+
// Run benchmarks with model concurrency
61+
const { results } = await PromisePool.for(models)
62+
.withConcurrency(modelConcurrency)
63+
.handleError((error) => {
64+
console.error(error);
65+
})
66+
.process(async (model) => {
67+
const maxConcurrency = taskConcurrency ?? model.maxConcurrency ?? 1;
68+
69+
console.log(`Running experiments for model: ${model.label}`);
70+
71+
// Run each task-dataset combination
72+
73+
const dataset = (
74+
await Promise.all(
75+
datasetsToRun.map(([_datasetName, datasetConfig]) =>
76+
datasetConfig.getDataset()
77+
)
7378
)
74-
)
75-
).flat();
76-
const datasetName = datasetsToRun.map(([name]) => name).join("+");
77-
78-
const experimentName = makeExperimentName({
79-
baseName: type,
80-
experimentType: task,
81-
datasets: datasetName,
82-
model: model.label,
83-
});
84-
85-
console.log(`Running experiment: ${experimentName}`);
86-
87-
const scores = Object.values(benchmarkConfig.scorers).map(
88-
(scorer) => scorer.scorerFunc
89-
);
90-
91-
try {
92-
// Load dataset
93-
// Run evaluation
94-
const evalResult = await Eval(benchmarkConfig.projectName, {
95-
data: dataset,
96-
experimentName,
97-
maxConcurrency,
98-
metadata: {
99-
model: model.label,
100-
task,
101-
dataset: datasetName,
102-
taskConcurrency,
103-
},
104-
task: taskToRun.taskFunc(config.modelProvider, model),
105-
scores,
79+
).flat();
80+
const datasetName = datasetsToRun.map(([name]) => name).join("+");
81+
82+
const experimentName = makeExperimentName({
83+
baseName: type,
84+
experimentType: task,
85+
datasets: datasetName,
86+
model: model.label,
10687
});
10788

108-
console.log(`✓ Completed experiment: ${experimentName}`);
109-
return { evalResult, dataset, experimentName };
110-
} catch (error) {
111-
console.error(`✗ Failed experiment: ${experimentName}`, error);
112-
}
113-
});
89+
console.log(`Running experiment: ${experimentName}`);
90+
91+
const scores = Object.values(benchmarkConfig.scorers).map(
92+
(scorer) => scorer.scorerFunc
93+
);
94+
95+
try {
96+
// Load dataset
97+
// Run evaluation
98+
const evalResult = await Eval(benchmarkConfig.projectName, {
99+
data: dataset,
100+
experimentName,
101+
maxConcurrency,
102+
metadata: {
103+
model: model.label,
104+
task,
105+
dataset: datasetName,
106+
taskConcurrency,
107+
},
108+
task: await taskToRun.taskFunc(config.modelProvider, model),
109+
scores,
110+
});
111+
112+
console.log(`✓ Completed experiment: ${experimentName}`);
113+
return { evalResult, dataset, experimentName };
114+
} catch (error) {
115+
console.error(`✗ Failed experiment: ${experimentName}`, error);
116+
}
117+
});
114118

115-
console.log("Benchmark run completed");
116-
return results;
119+
console.log("Benchmark run completed");
120+
return results;
121+
} finally {
122+
await benchmarkConfig.environment?.afterAll?.();
123+
}
117124
}

packages/benchmarks/src/textToDriver/bin/mongoshBenchmarks/claudeGenerated/promptCompletionAnnotatedSchema.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import { loadTextToDriverBraintrustEvalCases } from "../../../loadBraintrustData
33
import {
44
ReasonableOutput,
55
SuccessfulExecution,
6-
} from "../../../evaluationMetrics";
6+
} from "../../../scorers/evaluationMetrics";
77
import { annotatedDbSchemas } from "../../../generateDriverCode/annotatedDbSchemas";
88
import { createOpenAI, wrapLanguageModel } from "mongodb-rag-core/aiSdk";
99
import { BraintrustMiddleware } from "mongodb-rag-core/braintrust";

packages/benchmarks/src/textToDriver/bin/mongoshBenchmarks/config.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import {
77
SchemaStrategy,
88
SystemPromptStrategy,
99
} from "../../generateDriverCode/languagePrompts/PromptStrategies";
10-
import { makeMongoshBenchmarkMetrics } from "../../evaluationMetrics";
10+
import { makeMongoshBenchmarkMetrics } from "../../scorers/evaluationMetrics";
1111

1212
export { MODELS } from "../../../benchmarkModels";
1313

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import { GenerateTextResult, ToolSet } from "mongodb-rag-core/aiSdk";
2+
import { TextToDriverEvalTask, TextToDriverOutput } from "../TextToDriverEval";
3+
import {
4+
makeMongoDbMcpAgent,
5+
MakeMongoDbMcpAgentParams,
6+
} from "./mongoDbMcpAgent";
7+
8+
export async function makeGenerateAtlasSearchCodeAgenticTask(
9+
constructorArgs: MakeMongoDbMcpAgentParams
10+
): Promise<TextToDriverEvalTask> {
11+
const agent = await makeMongoDbMcpAgent(constructorArgs);
12+
return async function generateAtlasSearchCodeAgentic({
13+
databaseName,
14+
nlQuery,
15+
}) {
16+
const response = await agent({
17+
messages: [makeAtlasSearchUserMessage(databaseName, nlQuery)],
18+
});
19+
20+
return extractOutputFromMessages(response);
21+
};
22+
}
23+
24+
function makeAtlasSearchUserMessage(dbName: string, nlQuery: string) {
25+
return {
26+
role: "user" as const,
27+
content: `Database name: ${dbName}
28+
Natural language query: ${nlQuery}`,
29+
};
30+
}
31+
32+
function extractOutputFromMessages(
33+
agentResponse: GenerateTextResult<ToolSet, unknown>
34+
): TextToDriverOutput {
35+
// Find the last call to the `aggregate` tool
36+
const toolCalls =
37+
agentResponse.steps?.flatMap((step) => step.toolCalls || []) || [];
38+
const lastAggregateCall = toolCalls.findLast(
39+
(call) => call.toolName === "aggregate"
40+
);
41+
42+
if (!lastAggregateCall) {
43+
return {
44+
execution: {
45+
executionTimeMs: null,
46+
result: null,
47+
error: { message: "No tool calls found" },
48+
},
49+
generatedCode: "",
50+
} satisfies TextToDriverOutput;
51+
}
52+
53+
// Extract the tool call argument and stringify it for generatedCode
54+
const generatedCode = JSON.stringify(lastAggregateCall.input, null, 2);
55+
56+
// Get the result from the tool results in the steps
57+
const toolResults =
58+
agentResponse.steps?.flatMap((step) => step.toolResults || []) || [];
59+
const correspondingResult = toolResults.find(
60+
(result) => result.toolCallId === lastAggregateCall.toolCallId
61+
);
62+
const toolResult = correspondingResult?.output || null;
63+
64+
return {
65+
execution: {
66+
executionTimeMs: null,
67+
result: toolResult,
68+
},
69+
generatedCode,
70+
} satisfies TextToDriverOutput;
71+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
/**
 * System prompt for the NL-to-Atlas-Search agent. Instructs the model to use
 * its available tools to explore the database, produce an Atlas Search query
 * for the user's natural language request, and reply "Done" when finished.
 *
 * NOTE(review): the final answer is recovered from the agent's tool calls
 * (not from the "Done" text) — presumably via the `aggregate` tool; confirm
 * against the task's output-extraction logic.
 */
export const atlasSearchPrompt = `You are a MongoDB Atlas Search expert. You are given a natural language query and you need to generate the appropriate Atlas Search query.

You may use the available tools to help you explore the database and generate the query

Once you have generated a query that you are confident in, simply respond "Done" to the user.`;

0 commit comments

Comments
 (0)