|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +# pyre-unsafe |
| 8 | + |
| 9 | +import argparse |
| 10 | + |
| 11 | +import torch |
| 12 | +from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge |
| 13 | +from executorch.extension.pybindings.portable_lib import ExecuTorchModule |
| 14 | +from executorch.extension.training.examples.CIFAR.model import ( |
| 15 | + CIFAR10Model, |
| 16 | + ModelWithLoss, |
| 17 | +) |
| 18 | +from executorch.extension.training.examples.CIFAR.utils import ( |
| 19 | + fine_tune_executorch_model, |
| 20 | + get_data_loaders, |
| 21 | + save_json, |
| 22 | + train_model, |
| 23 | +) |
| 24 | +from torch.export import export |
| 25 | +from torch.export.experimental import _export_forward_backward |
| 26 | + |
| 27 | + |
| 28 | +def export_model( |
| 29 | + net: torch.nn.Module, input_tensor: torch.Tensor, label_tensor: torch.Tensor |
| 30 | +) -> ExecuTorchModule: |
| 31 | + """ |
| 32 | + Export a PyTorch model to an ExecutorTorch module format. |
| 33 | +
|
| 34 | + This function takes a PyTorch model and sample input/label |
| 35 | + tensors, wraps the model with a loss function, exports it |
| 36 | + using torch.export, applies forward-backward pass |
| 37 | + optimization, converts it to edge format, and finally to |
| 38 | + ExecutorTorch format. |
| 39 | +
|
| 40 | + Args: |
| 41 | + net (torch.nn.Module): The PyTorch model to be exported |
| 42 | + input_tensor (torch.Tensor): A sample input tensor with |
| 43 | + the correct shape |
| 44 | + label_tensor (torch.Tensor): A sample label tensor with |
| 45 | + the correct shape |
| 46 | +
|
| 47 | + Returns: |
| 48 | + ExecuTorchModule: The exported model in ExecutorTorch |
| 49 | + format ready for deployment |
| 50 | + """ |
| 51 | + criterion = torch.nn.CrossEntropyLoss() |
| 52 | + model_with_loss = ModelWithLoss(net, criterion) |
| 53 | + ep = export(model_with_loss, (input_tensor, label_tensor), strict=True) |
| 54 | + ep = _export_forward_backward(ep) |
| 55 | + ep = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=False)) |
| 56 | + ep = ep.to_executorch() |
| 57 | + return ep |
| 58 | + |
| 59 | + |
| 60 | +def export_model_with_ptd( |
| 61 | + net: torch.nn.Module, input_tensor: torch.Tensor, label_tensor: torch.Tensor |
| 62 | +) -> ExecuTorchModule: |
| 63 | + """ |
| 64 | + Export a PyTorch model to an ExecutorTorch module format with external |
| 65 | + tensor data. |
| 66 | +
|
| 67 | + This function takes a PyTorch model and sample input/label tensors, |
| 68 | + wraps the model with a loss function, exports it using torch.export, |
| 69 | + applies forward-backward pass optimization, converts it to edge format, |
| 70 | + and finally to ExecutorTorch format with external constants and mutable |
| 71 | + weights. |
| 72 | +
|
| 73 | + Args: |
| 74 | + net (torch.nn.Module): The PyTorch model to be exported |
| 75 | + input_tensor (torch.Tensor): A sample input tensor with the correct |
| 76 | + shape |
| 77 | + label_tensor (torch.Tensor): A sample label tensor with the correct |
| 78 | + shape |
| 79 | +
|
| 80 | + Returns: |
| 81 | + ExecuTorchModule: The exported model in ExecutorTorch format ready for |
| 82 | + deployment |
| 83 | + """ |
| 84 | + criterion = torch.nn.CrossEntropyLoss() |
| 85 | + model_with_loss = ModelWithLoss(net, criterion) |
| 86 | + ep = export(model_with_loss, (input_tensor, label_tensor), strict=True) |
| 87 | + ep = _export_forward_backward(ep) |
| 88 | + ep = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=False)) |
| 89 | + ep = ep.to_executorch( |
| 90 | + config=ExecutorchBackendConfig( |
| 91 | + external_constants=True, # This is the flag that |
| 92 | + # enables the external constants to be stored in a |
| 93 | + # separate file external to the PTE file. |
| 94 | + external_mutable_weights=True, # This is the flag |
| 95 | + # that enables all trainable weights will be stored |
| 96 | + # in a separate file external to the PTE file. |
| 97 | + ) |
| 98 | + ) |
| 99 | + return ep |
| 100 | + |
| 101 | + |
| 102 | +def save_model(ep: ExecuTorchModule, model_path: str) -> None: |
| 103 | + """ |
| 104 | + Save an ExecutorTorch model to a specified file path. |
| 105 | +
|
| 106 | + This function writes the buffer of an ExecutorTorchModule to a |
| 107 | + file in binary format. |
| 108 | +
|
| 109 | + Args: |
| 110 | + ep (ExecuTorchModule): The ExecutorTorch module to be saved. |
| 111 | + model_path (str): The file path where the model will be saved. |
| 112 | + """ |
| 113 | + with open(model_path, "wb") as file: |
| 114 | + file.write(ep.buffer) |
| 115 | + |
| 116 | + |
| 117 | +def parse_args() -> argparse.Namespace: |
| 118 | + """ |
| 119 | + Parse command line arguments for the CIFAR-10 training script. |
| 120 | +
|
| 121 | + This function sets up an argument parser with various configuration options |
| 122 | + for training a CIFAR-10 model with ExecutorTorch, including data paths, |
| 123 | + training hyperparameters, and model save locations. |
| 124 | +
|
| 125 | + Returns: |
| 126 | + argparse.Namespace: An object containing all the parsed command line |
| 127 | + arguments with their respective values (either user-provided or |
| 128 | + defaults). |
| 129 | + """ |
| 130 | + parser = argparse.ArgumentParser(description="CIFAR-10 Training Example") |
| 131 | + parser.add_argument( |
| 132 | + "--data-dir", |
| 133 | + type=str, |
| 134 | + default="./data", |
| 135 | + help="Directory to download CIFAR-10 dataset (default: ./data)", |
| 136 | + ) |
| 137 | + parser.add_argument( |
| 138 | + "--batch-size", |
| 139 | + type=int, |
| 140 | + default=4, |
| 141 | + help="Batch size for data loaders (default: 4)", |
| 142 | + ) |
| 143 | + parser.add_argument( |
| 144 | + "--use-balanced-dataset", |
| 145 | + action="store_true", |
| 146 | + default=True, |
| 147 | + help="Use balanced dataset instead of full CIFAR-10 (default: True)", |
| 148 | + ) |
| 149 | + parser.add_argument( |
| 150 | + "--images-per-class", |
| 151 | + type=int, |
| 152 | + default=100, |
| 153 | + help="Number of images per class for balanced dataset (default: 100)", |
| 154 | + ) |
| 155 | + parser.add_argument( |
| 156 | + "--model-path", |
| 157 | + type=str, |
| 158 | + default="cifar10_model.pth", |
| 159 | + help="PyTorch model path (default: cifar10_model.pth)", |
| 160 | + ) |
| 161 | + |
| 162 | + parser.add_argument( |
| 163 | + "--pte-model-path", |
| 164 | + type=str, |
| 165 | + default="cifar10_model.pte", |
| 166 | + help="PTE model path (default: cifar10_model.pte)", |
| 167 | + ) |
| 168 | + |
| 169 | + parser.add_argument( |
| 170 | + "--split-pte-model-path", |
| 171 | + type=str, |
| 172 | + default="split_cifar10_model.pte", |
| 173 | + help="Split PTE model path (default: split_cifar10_model.pte)", |
| 174 | + ) |
| 175 | + |
| 176 | + parser.add_argument( |
| 177 | + "--ptd-model-dir", type=str, default=".", help="PTD model path (default: .)" |
| 178 | + ) |
| 179 | + |
| 180 | + parser.add_argument( |
| 181 | + "--save-pt-json", |
| 182 | + type=str, |
| 183 | + default="cifar10_pt_model_finetuned_history.json", |
| 184 | + help="Save the et json file", |
| 185 | + ) |
| 186 | + |
| 187 | + parser.add_argument( |
| 188 | + "--save-et-json", |
| 189 | + type=str, |
| 190 | + default="cifar10_et_pte_only_model_finetuned_history.json", |
| 191 | + help="Save the et json file", |
| 192 | + ) |
| 193 | + |
| 194 | + parser.add_argument( |
| 195 | + "--epochs", |
| 196 | + type=int, |
| 197 | + default=1, |
| 198 | + help="Number of epochs for training (default: 1)", |
| 199 | + ) |
| 200 | + |
| 201 | + parser.add_argument( |
| 202 | + "--fine-tune-epochs", |
| 203 | + type=int, |
| 204 | + default=10, |
| 205 | + help="Number of fine-tuning epochs for fine-tuning (default: 150)", |
| 206 | + ) |
| 207 | + |
| 208 | + parser.add_argument( |
| 209 | + "--learning-rate", |
| 210 | + type=float, |
| 211 | + default=0.001, |
| 212 | + help="Learning rate for fine-tuning (default: 0.001)", |
| 213 | + ) |
| 214 | + |
| 215 | + return parser.parse_args() |
| 216 | + |
| 217 | + |
| 218 | +def main() -> None: |
| 219 | + |
| 220 | + args = parse_args() |
| 221 | + |
| 222 | + train_loader, test_loader = get_data_loaders( |
| 223 | + batch_size=args.batch_size, |
| 224 | + data_dir=args.data_dir, |
| 225 | + use_balanced_dataset=args.use_balanced_dataset, |
| 226 | + images_per_class=args.images_per_class, |
| 227 | + ) |
| 228 | + |
| 229 | + # initialize the main model |
| 230 | + model = CIFAR10Model() |
| 231 | + |
| 232 | + model, train_hist = train_model( |
| 233 | + model, |
| 234 | + train_loader, |
| 235 | + test_loader, |
| 236 | + epochs=1, |
| 237 | + lr=0.001, |
| 238 | + momentum=0.9, |
| 239 | + save_path=args.model_path, |
| 240 | + ) |
| 241 | + |
| 242 | + save_json(train_hist, args.save_pt_json) |
| 243 | + |
| 244 | + # Export the model for et runtime |
| 245 | + validation_sample_data = next(iter(test_loader)) |
| 246 | + img, lbl = validation_sample_data |
| 247 | + sample_input = img[0:1, :] |
| 248 | + sample_label = lbl[0:1] |
| 249 | + |
| 250 | + ep = export_model(model, sample_input, sample_label) |
| 251 | + |
| 252 | + save_model(ep, args.pte_model_path) |
| 253 | + |
| 254 | + et_model, et_hist = fine_tune_executorch_model( |
| 255 | + args.pte_model_path, |
| 256 | + args.pte_model_path, |
| 257 | + train_loader, |
| 258 | + test_loader, |
| 259 | + epochs=args.fine_tune_epochs, |
| 260 | + learning_rate=args.learning_rate, |
| 261 | + ) |
| 262 | + |
| 263 | + save_json(et_hist, args.save_et_json) |
| 264 | + |
| 265 | + # Split the model into the pte and ptd files |
| 266 | + exported_program = export_model_with_ptd(model, sample_input, sample_label) |
| 267 | + |
| 268 | + exported_program._tensor_data["generic_cifar"] = exported_program._tensor_data.pop( |
| 269 | + "_default_external_constant" |
| 270 | + ) |
| 271 | + exported_program.write_tensor_data_to_file(args.ptd_model_dir) |
| 272 | + save_model(exported_program, args.split_pte_model_path) |
| 273 | + |
| 274 | + # Finetune the PyTorch model |
| 275 | + model, train_hist = train_model( |
| 276 | + model, |
| 277 | + train_loader, |
| 278 | + test_loader, |
| 279 | + epochs=args.fine_tune_epochs, |
| 280 | + lr=args.learning_rate, |
| 281 | + momentum=0.9, |
| 282 | + save_path=args.model_path, |
| 283 | + ) |
| 284 | + |
| 285 | + save_json(train_hist, args.save_pt_json) |
| 286 | + |
| 287 | + |
| 288 | +if __name__ == "__main__": |
| 289 | + main() |
0 commit comments