Skip to content

Commit df1c471

Browse files
Merge pull request #2152 from AI-Hypercomputer:unitTestAgent
PiperOrigin-RevId: 794217061
2 parents 01c079d + 90c1129 commit df1c471

File tree

7 files changed

+474
-5
lines changed

7 files changed

+474
-5
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Code Evaluation Agent
2+
3+
This agent automates the evaluation of JAX code that has been converted from PyTorch. It works by generating and executing `pytest` test cases to compare the functional equivalence of the original PyTorch code and the converted JAX code. The agent leverages a large language model (Gemini) to create these test cases dynamically.
4+
5+
## Workflow
6+
7+
1. **File Pairing**: The agent identifies pairs of corresponding PyTorch and JAX files from specified input directories.
8+
2. **Test Case Generation**: For each file pair, it prompts the Gemini model to generate a comprehensive `pytest` test case. The generated test compares the outputs of the PyTorch and JAX modules using randomized inputs to ensure they are numerically close (`numpy.allclose`).
9+
3. **Test Execution**: The generated test case is saved as a Python file and executed using `pytest`.
10+
4. **Result Aggregation**: The agent captures the results (pass/fail counts) from each test run.
11+
5. **Reporting**: Finally, it calculates and logs two key metrics:
12+
* **Test Case Accuracy**: The percentage of individual test cases that passed across all files.
13+
* **File Accuracy**: The percentage of files for which all generated test cases passed.
14+
15+
## File Descriptions
16+
17+
- **`code_evaluation_agent.py`**: The main executable script that orchestrates the entire evaluation process.
18+
- **`prompt_code_evaluation.py`**: Contains the system and user prompt templates that instruct the Gemini model on how to generate the `pytest` test cases.
19+
- **`utils.py`**: Provides helper functions, including `run_pytest_capture_output` to execute `pytest` and capture its results, and `get_last_defined_module` to identify the primary component in a code file.
20+
21+
## Setup
22+
23+
1. **Install Dependencies**:
24+
Make sure you have the required Python packages installed.
25+
```bash
26+
pip install pytest google-generativeai backoff python-dotenv
27+
```
28+
29+
2. **Configure Environment Variables**:
30+
This agent uses the `GeminiAgent` from the `code_generation_agent`, which requires a `.env` file in the `code_generation_agent` directory.
31+
32+
```.env
33+
# in MaxText/experimental/agent/code_generation_agent/.env
34+
GOOGLE_API_KEY="YOUR_API_KEY_HERE"
35+
Model="gemini-2.5-pro"
36+
```
37+
38+
3. **Configure Paths**:
39+
In `code_evaluation_agent.py`, set the following path variables to point to your datasets. The script will create the test case directory if it doesn't exist. You can modify the paths as needed.
40+
41+
```python
42+
# in code_evaluation_agent.py
43+
pytorch_path="../code_generation_agent/dataset/PyTorch/"
44+
jax_path="../code_generation_agent/dataset/jax_converted/"
45+
testcase_path="../code_generation_agent/dataset/test_cases/"
46+
```
47+
48+
## Usage
49+
50+
Before running the agent, ensure you have:
51+
52+
1. Your original PyTorch files in the directory specified by `pytorch_path`.
53+
2. The corresponding converted JAX files in the directory specified by `jax_path`. The filenames must match between the two directories.
54+
55+
To start the evaluation process, run the following command from within the `code_evaluation_agent` directory:
56+
57+
```bash
58+
python code_evaluation_agent.py
59+
```
60+
61+
The agent will process each file pair, generate tests, run them, and print the progress and final accuracy metrics to the console.
62+
63+
## Output
64+
65+
The agent provides real-time logging for each file being processed. At the end of the run, it prints a summary of the results, including:
66+
67+
- The number of files that passed all tests.
68+
- The number of files that had at least one failing test.
69+
- The overall **Test Case Accuracy**.
70+
- The overall **File Accuracy**.
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
"""
18+
This file implements an agent that evaluates the correctness of JAX code
19+
generated from PyTorch code by running pytest test cases. It uses a language
20+
model to generate the test cases and captures the results of the tests.
21+
22+
The agent performs the following steps:
23+
1. Reads pairs of PyTorch and JAX files from specified directories.
24+
2. For each pair, it generates a pytest-compatible test case using a language
25+
model.
26+
3. It runs the generated test case and captures the output, including the number
27+
of passed and failed tests.
28+
4. It logs the results and calculates overall accuracy metrics.
29+
30+
Example Invocation:
31+
python code_evaluation_agent.py
32+
33+
Ensure the paths to the PyTorch and JAX code directories are correctly set in
34+
the script. The script will create a directory for test cases if it doesn't
35+
exist and will overwrite existing test cases based on the `overwrite_existing_files`
36+
flag.
37+
38+
Overall Accuracy Metrics:
39+
- Test Case Accuracy: The percentage of individual test cases that passed across
40+
all generated tests.
41+
- File Accuracy: The percentage of files for which all generated test cases passed.
42+
43+
Relevant Files:
44+
- `prompt_code_evaluation.py`: Contains the prompts used by the language model
45+
for generating test cases.
46+
- `utils.py`: Provides utility functions such as `get_last_defined_module`
47+
(to extract the main module from a code string) and `run_pytest_capture_output`
48+
(to execute pytest and capture its results).
49+
- `code_generation_agent/llm_agent.py`: Contains the `GeminiAgent` class used
50+
to interact with the language model.
51+
- `orchestration_agent/Utils.py`: Contains `parse_python_code` for extracting
52+
code from LLM responses.
53+
"""
54+
import argparse
55+
import os, logging, sys
56+
from prompt_code_evaluation import CodeEvaluation
57+
from utils import get_last_defined_module, run_pytest_capture_output
58+
59+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
60+
61+
from code_generation_agent.llm_agent import GeminiAgent
62+
from orchestration_agent.Utils import parse_python_code
63+
64+
logging.basicConfig(
65+
format="%(asctime)s %(levelname)-8s %(message)s",
66+
level=logging.INFO,
67+
datefmt="%Y-%m-%d %H:%M:%S",
68+
)
69+
logger = logging.getLogger(__name__)
70+
# logging.raiseExceptions = False
71+
72+
73+
parser = argparse.ArgumentParser(description="Code Evaluation Agent")
74+
parser.add_argument("--error_penalty", type=int, default=10, help="Penalty for errors in test case generation or execution.")
75+
parser.add_argument("--pytorch_path", type=str, default="../code_generation_agent/dataset/PyTorch/", help="Path to the directory containing PyTorch files.")
76+
parser.add_argument("--jax_path", type=str, default="../code_generation_agent/dataset/jax_converted/", help="Path to the directory containing JAX files.")
77+
parser.add_argument("--testcase_path", type=str, default="../code_generation_agent/dataset/test_cases/", help="Path to the directory for generated test cases.")
78+
parser.add_argument("--overwrite_existing_files", action="store_true", help="Overwrite existing test case files.")
79+
args = parser.parse_args()
80+
81+
overwrite_existing_files = args.overwrite_existing_files
82+
error_penalty = args.error_penalty
83+
pytorch_path = args.pytorch_path
84+
jax_path = args.jax_path
85+
testcase_path = args.testcase_path
86+
os.makedirs(testcase_path, exist_ok=True)
87+
88+
llm_agent = GeminiAgent(CodeEvaluation["SystemPrompt"])
89+
90+
91+
def get_file_pairs(pytorch_path, jax_path):
92+
"""Generates lists of file paths for PyTorch and JAX files that have a common name.
93+
94+
This function finds files with the same name in the specified PyTorch and JAX
95+
directories, filtering out any files in the JAX directory that start with "__".
96+
97+
Args:
98+
pytorch_path: The path to the directory containing PyTorch files.
99+
jax_path: The path to the directory containing JAX files.
100+
101+
Returns:
102+
A tuple containing two lists of strings:
103+
- The first list contains the full paths to the common PyTorch files.
104+
- The second list contains the full paths to the common JAX files.
105+
"""
106+
pytorch_files = os.listdir(pytorch_path)
107+
jax_files = list(filter(lambda x: not x.startswith("__"), os.listdir(jax_path)))
108+
common_files = list(set(pytorch_files).intersection(jax_files))
109+
return list(map(lambda x: pytorch_path + x, common_files)), list(map(lambda x: jax_path + x, common_files))
110+
111+
112+
def make_test_case_and_run(python_file, jax_file):
113+
"""Generates a test case and runs it for a given PyTorch and JAX file pair.
114+
115+
This function uses a language model to generate a pytest-compatible test case
116+
for a PyTorch and JAX code file pair. It then runs the test and captures the output.
117+
If the files have inconsistent entry points or the test case cannot be generated,
118+
a penalty is applied.
119+
120+
Args:
121+
python_file: The path to the PyTorch code file.
122+
jax_file: The path to the JAX code file.
123+
124+
Returns:
125+
A tuple containing the number of passed and failed test cases.
126+
"""
127+
try:
128+
logger.info(f"Processing {python_file}")
129+
out_file_path = os.path.join(testcase_path, python_file.split("/")[-1])
130+
if overwrite_existing_files or not os.path.exists(out_file_path):
131+
with open(python_file) as f:
132+
python_code = f.read()
133+
with open(jax_file) as f:
134+
jax_code = f.read()
135+
entry_module = get_last_defined_module(python_code)
136+
if get_last_defined_module(jax_code) != entry_module:
137+
logger.error(
138+
f"It seems inconsistency in {python_file} code PyTorch have {entry_module} and JAX have {get_last_defined_module(jax_code)} as entry Module"
139+
)
140+
# Penalty in case of Entry point not exist or different from torch
141+
return 0, error_penalty
142+
prompt = CodeEvaluation["TESTCASE"]
143+
python_code = (
144+
"from " + ".".join(python_file.split("/")[1:]).replace(".py", " import " + entry_module) + "\n\n" + python_code
145+
)
146+
jax_code = "from " + ".".join(jax_file.split("/")[1:]).replace(".py", " import " + entry_module) + "\n\n" + jax_code
147+
prompt = prompt.replace("<module.path.to.pytorch_code>", python_code)
148+
prompt = prompt.replace("<module.path.to.jax_code>", jax_code)
149+
prompt = prompt.replace("<function_or_class_to_call>", entry_module)
150+
response = llm_agent(prompt)
151+
generated_code = parse_python_code(response.text)
152+
with open(out_file_path, "w") as f:
153+
f.write("import os,sys\nsys.path.append(os.path.abspath('..'))\n")
154+
f.write(generated_code)
155+
logger.info("Written at %s", out_file_path)
156+
if "<UNABLETOGENERATE>" in response:
157+
return 0, error_penalty
158+
else:
159+
logger.info("File Exists using same")
160+
file = python_file.split("/")[-1]
161+
output, exit_code, is_dependency_error, passed, failed = run_pytest_capture_output(file, code_folder=testcase_path)
162+
return passed, failed
163+
except Exception as e:
164+
logger.error("Exception in code generation %s", e)
165+
logger.error("The code file is %s", python_file.split("/")[-1])
166+
logger.error("The generated Code is %s", response)
167+
# Penalty in case of Exception
168+
return 0, error_penalty
169+
170+
171+
def run_code_evaluation():
172+
"""Runs the full code evaluation process.
173+
174+
This function orchestrates the evaluation of PyTorch and JAX code file pairs.
175+
It iterates through the common files, generates and runs a test case for each,
176+
and then logs the results. It also calculates and prints the overall
177+
test case and file accuracy.
178+
"""
179+
total_passed, total_failed = 0, 0
180+
all_passed, all_failed, total_files = 0, 0, 0
181+
for python_file, jax_file in zip(*get_file_pairs(pytorch_path, jax_path)):
182+
num_passed, num_failed = make_test_case_and_run(python_file, jax_file)
183+
if num_passed == num_failed == 0: # when the code cannot be executed
184+
# Penalty in case of issue in test case and not executed
185+
num_failed = error_penalty
186+
logger.info(f"{python_file.split('/')[-1]} have {num_passed} cases passed and {num_failed} cases failed")
187+
total_passed += num_passed
188+
total_failed += num_failed
189+
if num_passed == 0:
190+
all_failed += 1
191+
if num_failed == 0:
192+
all_passed += 1
193+
total_files += 1
194+
195+
logger.info("****** Results ******")
196+
logger.info(f"{all_passed} files have all module passed {all_failed} files have all module failed")
197+
logger.info(
198+
f"Test case Accuracy {total_passed*100/(total_passed+total_failed):.2f}%",
199+
)
200+
logger.info(
201+
f"File Accuracy {all_passed * 100 / total_files:.2f}%",
202+
)
203+
204+
205+
if __name__ == "__main__":
206+
run_code_evaluation()
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
"""
18+
This file contains the prompt templates used by the code evaluation agent.
19+
"""
20+
21+
CodeEvaluation = {
22+
"SystemPrompt": """You are an expert machine learning engineer and automated testing specialist with deep
23+
knowledge of Python, NumPy, PyTorch, JAX (Including libraries such as Flax, Flax.nnx and Optax).
24+
25+
You can:
26+
- Convert code written in PyTorch, Numpy, or other frameworks into functionally equivalent JAX code using appropriate libraries.
27+
- Analyze JAX-based code and generate meaningful testcases using `pytest`.
28+
- When both PyTorch and JAX modules are provided, generate a comprehensive test suite that:
29+
1. validates the PyTorch module independently.
30+
2. validates the JAX module independently.
31+
3. Compares their outputs across multiple randomized inputs using `numpy.allclose`.
32+
33+
Guidelines:
34+
- Assume helper functions and classes not defined in the code are already implemented and available.
35+
- Do not add or modify import statements unless they exist in the provided code.
36+
- Only return test code (no explanations) unless explicitly asked.
37+
- For trivial or untestable code, return `NOTESTCASE`.
38+
- When comparing PyTorch and JAX:
39+
- Accept `#torch_path` and `#jax_path` as import paths.
40+
- Accept an optional `#entry_point` that identifies the function or class to invoke.
41+
- Automatically generate randomized test inputs for shapes like `(2,3)`, `(4,)`, etc.
42+
- Write clear assertions for:
43+
- Output validity (no errors or exceptions)
44+
- Output comparison (`np.allclose`)
45+
""",
46+
"TESTCASE": """#torch_path
47+
<module.path.to.pytorch_code>
48+
49+
#jax_path
50+
<module.path.to.jax_code>
51+
52+
#entry_point
53+
<function_or_class_to_call>
54+
55+
#input_gen
56+
<code to generate input tensors or arrays>
57+
58+
#torch_code
59+
'''
60+
<insert full PyTorch code here>
61+
'''
62+
63+
#jax_code
64+
'''
65+
<insert full JAX code here>
66+
'''""",
67+
}

0 commit comments

Comments
 (0)