|
| 1 | +""" |
| 2 | +Copyright 2025 Google LLC |
| 3 | +
|
| 4 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +you may not use this file except in compliance with the License. |
| 6 | +You may obtain a copy of the License at |
| 7 | +
|
| 8 | + https://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +
|
| 10 | +Unless required by applicable law or agreed to in writing, software |
| 11 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +See the License for the specific language governing permissions and |
| 14 | +limitations under the License. |
| 15 | +""" |
| 16 | + |
| 17 | +""" |
| 18 | +This file implements an agent that evaluates the correctness of JAX code |
| 19 | +generated from PyTorch code by running pytest test cases. It uses a language |
| 20 | +model to generate the test cases and captures the results of the tests. |
| 21 | +
|
| 22 | +The agent performs the following steps: |
| 23 | +1. Reads pairs of PyTorch and JAX files from specified directories. |
| 24 | +2. For each pair, it generates a pytest-compatible test case using a language |
| 25 | + model. |
| 26 | +3. It runs the generated test case and captures the output, including the number |
| 27 | + of passed and failed tests. |
| 28 | +4. It logs the results and calculates overall accuracy metrics. |
| 29 | +
|
| 30 | +Example Invocation: |
| 31 | +python code_evaluation_agent.py |
| 32 | +
|
| 33 | +Ensure the paths to the PyTorch and JAX code directories are correctly set in |
| 34 | +the script. The script will create a directory for test cases if it doesn't |
| 35 | +exist and will overwrite existing test cases based on the `overwrite_existing_files` |
| 36 | +flag. |
| 37 | +
|
| 38 | +Overall Accuracy Metrics: |
| 39 | +- Test Case Accuracy: The percentage of individual test cases that passed across |
| 40 | + all generated tests. |
| 41 | +- File Accuracy: The percentage of files for which all generated test cases passed. |
| 42 | +
|
| 43 | +Relevant Files: |
| 44 | +- `prompt_code_evaluation.py`: Contains the prompts used by the language model |
| 45 | + for generating test cases. |
| 46 | +- `utils.py`: Provides utility functions such as `get_last_defined_module` |
| 47 | + (to extract the main module from a code string) and `run_pytest_capture_output` |
| 48 | + (to execute pytest and capture its results). |
| 49 | +- `code_generation_agent/llm_agent.py`: Contains the `GeminiAgent` class used |
| 50 | + to interact with the language model. |
| 51 | +- `orchestration_agent/Utils.py`: Contains `parse_python_code` for extracting |
| 52 | + code from LLM responses. |
| 53 | +""" |
| 54 | +import argparse |
| 55 | +import os, logging, sys |
| 56 | +from prompt_code_evaluation import CodeEvaluation |
| 57 | +from utils import get_last_defined_module, run_pytest_capture_output |
| 58 | + |
| 59 | +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) |
| 60 | + |
| 61 | +from code_generation_agent.llm_agent import GeminiAgent |
| 62 | +from orchestration_agent.Utils import parse_python_code |
| 63 | + |
# Module-wide logging: timestamped, level-tagged messages via the root handler.
logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)


# NOTE(review): argument parsing runs at import time, so importing this module
# from elsewhere would consume sys.argv — acceptable for a script entry point.
parser = argparse.ArgumentParser(description="Code Evaluation Agent")
parser.add_argument("--error_penalty", type=int, default=10, help="Penalty for errors in test case generation or execution.")
parser.add_argument("--pytorch_path", type=str, default="../code_generation_agent/dataset/PyTorch/", help="Path to the directory containing PyTorch files.")
parser.add_argument("--jax_path", type=str, default="../code_generation_agent/dataset/jax_converted/", help="Path to the directory containing JAX files.")
parser.add_argument("--testcase_path", type=str, default="../code_generation_agent/dataset/test_cases/", help="Path to the directory for generated test cases.")
parser.add_argument("--overwrite_existing_files", action="store_true", help="Overwrite existing test case files.")
args = parser.parse_args()

# Module-level configuration read by the evaluation functions below.
overwrite_existing_files = args.overwrite_existing_files  # regenerate test files even if present
error_penalty = args.error_penalty  # failed-case count charged when a file cannot be evaluated
pytorch_path = args.pytorch_path
jax_path = args.jax_path
testcase_path = args.testcase_path
os.makedirs(testcase_path, exist_ok=True)

# LLM wrapper (project-local) used to author the pytest test cases.
llm_agent = GeminiAgent(CodeEvaluation["SystemPrompt"])
| 89 | + |
| 90 | + |
def get_file_pairs(pytorch_path, jax_path):
    """Generates lists of file paths for PyTorch and JAX files that have a common name.

    This function finds files with the same name in the specified PyTorch and JAX
    directories, filtering out any files in the JAX directory that start with "__"
    (e.g. ``__init__.py``, ``__pycache__``).

    Args:
        pytorch_path: The path to the directory containing PyTorch files.
        jax_path: The path to the directory containing JAX files.

    Returns:
        A tuple containing two lists of strings (same length, name-aligned):
        - The first list contains the full paths to the common PyTorch files.
        - The second list contains the full paths to the common JAX files.
    """
    pytorch_files = set(os.listdir(pytorch_path))
    jax_files = {name for name in os.listdir(jax_path) if not name.startswith("__")}
    # Sort for a deterministic evaluation order; set intersection order is arbitrary.
    common_files = sorted(pytorch_files & jax_files)
    # os.path.join tolerates directory paths with or without a trailing
    # separator, unlike the plain string concatenation it replaces.
    return (
        [os.path.join(pytorch_path, name) for name in common_files],
        [os.path.join(jax_path, name) for name in common_files],
    )
| 110 | + |
| 111 | + |
def make_test_case_and_run(python_file, jax_file):
    """Generates a test case and runs it for a given PyTorch and JAX file pair.

    This function uses a language model to generate a pytest-compatible test case
    for a PyTorch and JAX code file pair. It then runs the test and captures the output.
    If the files have inconsistent entry points or the test case cannot be generated,
    a penalty is applied.

    Args:
        python_file: The path to the PyTorch code file.
        jax_file: The path to the JAX code file.

    Returns:
        A tuple ``(passed, failed)`` with the number of passed and failed test
        cases; ``(0, error_penalty)`` when generation or execution fails.
    """
    # BUG FIX: pre-bind so the except-handler's log line cannot raise a
    # NameError when the exception occurs before the LLM call.
    response = None
    try:
        logger.info(f"Processing {python_file}")
        out_file_path = os.path.join(testcase_path, os.path.basename(python_file))
        if overwrite_existing_files or not os.path.exists(out_file_path):
            with open(python_file) as f:
                python_code = f.read()
            with open(jax_file) as f:
                jax_code = f.read()
            entry_module = get_last_defined_module(python_code)
            jax_entry_module = get_last_defined_module(jax_code)
            if jax_entry_module != entry_module:
                logger.error(
                    f"It seems inconsistency in {python_file} code PyTorch have {entry_module} and JAX have {jax_entry_module} as entry Module"
                )
                # Penalty in case of Entry point not exist or different from torch
                return 0, error_penalty
            # Prefix each snippet with an import of its entry module so the
            # generated test can import both implementations side by side.
            prompt = CodeEvaluation["TESTCASE"]
            python_code = (
                "from " + ".".join(python_file.split("/")[1:]).replace(".py", " import " + entry_module) + "\n\n" + python_code
            )
            jax_code = "from " + ".".join(jax_file.split("/")[1:]).replace(".py", " import " + entry_module) + "\n\n" + jax_code
            prompt = prompt.replace("<module.path.to.pytorch_code>", python_code)
            prompt = prompt.replace("<module.path.to.jax_code>", jax_code)
            prompt = prompt.replace("<function_or_class_to_call>", entry_module)
            response = llm_agent(prompt)
            # BUG FIX: the sentinel must be searched in the response *text*;
            # the original tested `in response` on the response object. Check
            # before writing so failed generations are not persisted to disk.
            if "<UNABLETOGENERATE>" in response.text:
                return 0, error_penalty
            generated_code = parse_python_code(response.text)
            with open(out_file_path, "w") as f:
                f.write("import os,sys\nsys.path.append(os.path.abspath('..'))\n")
                f.write(generated_code)
            logger.info("Written at %s", out_file_path)
        else:
            # BUG FIX: in the original this branch was attached to the
            # <UNABLETOGENERATE> check, so the reuse message could never be
            # logged; it belongs to the file-existence check above.
            logger.info("File Exists using same")
        file = os.path.basename(python_file)
        output, exit_code, is_dependency_error, passed, failed = run_pytest_capture_output(file, code_folder=testcase_path)
        return passed, failed
    except Exception as e:
        logger.error("Exception in code generation %s", e)
        logger.error("The code file is %s", os.path.basename(python_file))
        logger.error("The generated Code is %s", response)
        # Penalty in case of Exception
        return 0, error_penalty
| 169 | + |
| 170 | + |
def run_code_evaluation():
    """Runs the full code evaluation process.

    This function orchestrates the evaluation of PyTorch and JAX code file pairs.
    It iterates through the common files, generates and runs a test case for each,
    and then logs the results. It also calculates and prints the overall
    test case and file accuracy.
    """
    total_passed, total_failed = 0, 0
    all_passed, all_failed, total_files = 0, 0, 0
    for python_file, jax_file in zip(*get_file_pairs(pytorch_path, jax_path)):
        num_passed, num_failed = make_test_case_and_run(python_file, jax_file)
        if num_passed == num_failed == 0:  # when the code cannot be executed
            # Penalty in case of issue in test case and not executed
            num_failed = error_penalty
        logger.info(f"{python_file.split('/')[-1]} have {num_passed} cases passed and {num_failed} cases failed")
        total_passed += num_passed
        total_failed += num_failed
        if num_passed == 0:
            all_failed += 1
        if num_failed == 0:
            all_passed += 1
        total_files += 1

    logger.info("****** Results ******")
    logger.info(f"{all_passed} files have all module passed {all_failed} files have all module failed")
    # BUG FIX: guard both divisions against an empty run (no common file
    # pairs, or zero total cases when error_penalty is 0), which previously
    # raised ZeroDivisionError.
    total_cases = total_passed + total_failed
    if total_cases:
        logger.info(
            f"Test case Accuracy {total_passed * 100 / total_cases:.2f}%",
        )
    if total_files:
        logger.info(
            f"File Accuracy {all_passed * 100 / total_files:.2f}%",
        )
    else:
        logger.warning("No common PyTorch/JAX file pairs found; nothing to evaluate.")
| 203 | + |
| 204 | + |
| 205 | +if __name__ == "__main__": |
| 206 | + run_code_evaluation() |
0 commit comments