Skip to content

Commit d85fc05

Browse files
authored
CIFAR 10 Debugging/ Training on ET
Differential Revision: D78284284 Pull Request resolved: #12457
1 parent e9c11a4 commit d85fc05

File tree

6 files changed

+1588
-0
lines changed

6 files changed

+1588
-0
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Any targets that should be shared between fbcode and xplat must be defined in
2+
# targets.bzl. This file can contain fbcode-only targets.
3+
4+
load(":targets.bzl", "define_common_targets")
5+
6+
oncall("executorch")
7+
8+
define_common_targets()
Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import argparse

import torch
from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    ExecutorchProgramManager,
    to_edge,
)
from executorch.extension.pybindings.portable_lib import ExecuTorchModule
from executorch.extension.training.examples.CIFAR.model import (
    CIFAR10Model,
    ModelWithLoss,
)
from executorch.extension.training.examples.CIFAR.utils import (
    fine_tune_executorch_model,
    get_data_loaders,
    save_json,
    train_model,
)
from torch.export import export
from torch.export.experimental import _export_forward_backward
27+
28+
def export_model(
    net: torch.nn.Module, input_tensor: torch.Tensor, label_tensor: torch.Tensor
) -> ExecutorchProgramManager:
    """
    Export a PyTorch model, together with its loss, to an ExecuTorch program.

    The model is wrapped in ``ModelWithLoss`` with a cross-entropy criterion
    so that the loss computation is part of the captured graph. The wrapped
    module is then traced with ``torch.export.export`` (strict mode), passed
    through ``_export_forward_backward``, lowered with ``to_edge`` (edge IR
    validity checking disabled) and finally converted with
    ``to_executorch()`` using the default backend config.

    Args:
        net (torch.nn.Module): The PyTorch model to be exported.
        input_tensor (torch.Tensor): A sample input tensor used as the
            example input for tracing.
        label_tensor (torch.Tensor): A sample label tensor used as the
            example target for tracing.

    Returns:
        ExecutorchProgramManager: The object returned by ``to_executorch()``.
            Its ``buffer`` attribute holds the serialized program that
            ``save_model`` writes to disk.
    """
    criterion = torch.nn.CrossEntropyLoss()
    model_with_loss = ModelWithLoss(net, criterion)

    # Each stage yields a different kind of object, so give each its own name
    # instead of rebinding a single variable.
    exported = export(model_with_loss, (input_tensor, label_tensor), strict=True)
    joint = _export_forward_backward(exported)
    edge = to_edge(joint, compile_config=EdgeCompileConfig(_check_ir_validity=False))
    return edge.to_executorch()
58+
59+
60+
def export_model_with_ptd(
    net: torch.nn.Module, input_tensor: torch.Tensor, label_tensor: torch.Tensor
) -> ExecutorchProgramManager:
    """
    Export a PyTorch model to an ExecuTorch program with external tensor data.

    Follows the same pipeline as a plain export (wrap the model in
    ``ModelWithLoss`` with a cross-entropy criterion, trace with
    ``torch.export.export``, apply ``_export_forward_backward``, lower with
    ``to_edge``), but the final ``to_executorch`` call uses an
    ``ExecutorchBackendConfig`` with ``external_constants`` and
    ``external_mutable_weights`` enabled.

    Args:
        net (torch.nn.Module): The PyTorch model to be exported.
        input_tensor (torch.Tensor): A sample input tensor used as the
            example input for tracing.
        label_tensor (torch.Tensor): A sample label tensor used as the
            example target for tracing.

    Returns:
        ExecutorchProgramManager: The object returned by ``to_executorch()``.
            Its ``buffer`` attribute holds the serialized program; the
            caller in this file additionally uses ``_tensor_data`` and
            ``write_tensor_data_to_file`` to emit the external data.
    """
    criterion = torch.nn.CrossEntropyLoss()
    model_with_loss = ModelWithLoss(net, criterion)

    exported = export(model_with_loss, (input_tensor, label_tensor), strict=True)
    joint = _export_forward_backward(exported)
    edge = to_edge(joint, compile_config=EdgeCompileConfig(_check_ir_validity=False))

    backend_config = ExecutorchBackendConfig(
        # Store constants in a separate file, external to the PTE file.
        external_constants=True,
        # Store all trainable weights in a separate file, external to the
        # PTE file.
        external_mutable_weights=True,
    )
    return edge.to_executorch(config=backend_config)
100+
101+
102+
def save_model(ep: ExecutorchProgramManager, model_path: str) -> None:
    """
    Save an exported ExecuTorch program to a specified file path.

    Writes ``ep.buffer`` to ``model_path`` in binary mode, replacing any
    existing file at that path.

    Args:
        ep (ExecutorchProgramManager): The exported program whose
            ``buffer`` attribute holds the serialized bytes.
        model_path (str): The file path where the program will be saved.
    """
    with open(model_path, "wb") as out_file:
        out_file.write(ep.buffer)
115+
116+
117+
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """
    Parse command line arguments for the CIFAR-10 training script.

    This function sets up an argument parser with the configuration options
    for training a CIFAR-10 model with ExecuTorch, including data paths,
    training hyperparameters, and model save locations.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behavior of calling ``parse_args()`` with no arguments.

    Returns:
        argparse.Namespace: An object containing all the parsed command line
            arguments with their respective values (either user-provided or
            defaults).
    """
    parser = argparse.ArgumentParser(description="CIFAR-10 Training Example")
    parser.add_argument(
        "--data-dir",
        type=str,
        default="./data",
        help="Directory to download CIFAR-10 dataset (default: %(default)s)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=4,
        help="Batch size for data loaders (default: %(default)s)",
    )
    # ``store_true`` with ``default=True`` can never yield False on its own,
    # so a paired ``--no-...`` flag writing to the same dest is provided to
    # actually allow turning the option off.
    parser.add_argument(
        "--use-balanced-dataset",
        dest="use_balanced_dataset",
        action="store_true",
        default=True,
        help="Use balanced dataset instead of full CIFAR-10 (default: True)",
    )
    parser.add_argument(
        "--no-use-balanced-dataset",
        dest="use_balanced_dataset",
        action="store_false",
        help="Use the full CIFAR-10 dataset instead of the balanced dataset",
    )
    parser.add_argument(
        "--images-per-class",
        type=int,
        default=100,
        help="Number of images per class for balanced dataset "
        "(default: %(default)s)",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="cifar10_model.pth",
        help="PyTorch model path (default: %(default)s)",
    )
    parser.add_argument(
        "--pte-model-path",
        type=str,
        default="cifar10_model.pte",
        help="PTE model path (default: %(default)s)",
    )
    parser.add_argument(
        "--split-pte-model-path",
        type=str,
        default="split_cifar10_model.pte",
        help="Split PTE model path (default: %(default)s)",
    )
    parser.add_argument(
        "--ptd-model-dir",
        type=str,
        default=".",
        help="PTD model path (default: %(default)s)",
    )
    parser.add_argument(
        "--save-pt-json",
        type=str,
        default="cifar10_pt_model_finetuned_history.json",
        help="Path of the PyTorch training history json file "
        "(default: %(default)s)",
    )
    parser.add_argument(
        "--save-et-json",
        type=str,
        default="cifar10_et_pte_only_model_finetuned_history.json",
        help="Path of the ExecuTorch fine-tuning history json file "
        "(default: %(default)s)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=1,
        help="Number of epochs for training (default: %(default)s)",
    )
    # Help text is derived from the actual default so the two cannot drift
    # apart (the previous text advertised 150 while the default was 10).
    parser.add_argument(
        "--fine-tune-epochs",
        type=int,
        default=10,
        help="Number of epochs for fine-tuning (default: %(default)s)",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=0.001,
        help="Learning rate for fine-tuning (default: %(default)s)",
    )

    return parser.parse_args(argv)
216+
217+
218+
def main() -> None:
    """
    Run the CIFAR-10 example end to end.

    Steps, in order:
      1. Parse the command line and build the train/test data loaders.
      2. Train a fresh ``CIFAR10Model`` with ``train_model`` for
         ``--epochs`` epochs and save its history JSON.
      3. Export the trained model to a PTE file and fine-tune it through
         ``fine_tune_executorch_model``; save that history JSON.
      4. Export again with external tensor data, writing the PTD data into
         ``--ptd-model-dir`` and the split PTE file.
      5. Continue training the PyTorch model for ``--fine-tune-epochs``
         epochs and save its history JSON.
    """
    args = parse_args()

    train_loader, test_loader = get_data_loaders(
        batch_size=args.batch_size,
        data_dir=args.data_dir,
        use_balanced_dataset=args.use_balanced_dataset,
        images_per_class=args.images_per_class,
    )

    # initialize the main model
    model = CIFAR10Model()

    # Initial training run; honors --epochs (previously hard-coded to 1, which
    # silently ignored the command line option).
    model, train_hist = train_model(
        model,
        train_loader,
        test_loader,
        epochs=args.epochs,
        lr=0.001,
        momentum=0.9,
        save_path=args.model_path,
    )

    # NOTE: the same path is written again at the end of this function, so the
    # history saved here is replaced by the fine-tuning history.
    save_json(train_hist, args.save_pt_json)

    # Export the model for et runtime, using a single sample from the first
    # test batch as the example input/label.
    validation_sample_data = next(iter(test_loader))
    img, lbl = validation_sample_data
    sample_input = img[0:1, :]
    sample_label = lbl[0:1]

    ep = export_model(model, sample_input, sample_label)

    save_model(ep, args.pte_model_path)

    # NOTE(review): the PTE path is passed for both of the first two
    # positional parameters - confirm against fine_tune_executorch_model's
    # signature that this is intended.
    _, et_hist = fine_tune_executorch_model(
        args.pte_model_path,
        args.pte_model_path,
        train_loader,
        test_loader,
        epochs=args.fine_tune_epochs,
        learning_rate=args.learning_rate,
    )

    save_json(et_hist, args.save_et_json)

    # Split the model into the pte and ptd files
    exported_program = export_model_with_ptd(model, sample_input, sample_label)

    # Re-key the external tensor data from its default name to "generic_cifar"
    # before writing it out.
    exported_program._tensor_data["generic_cifar"] = exported_program._tensor_data.pop(
        "_default_external_constant"
    )
    exported_program.write_tensor_data_to_file(args.ptd_model_dir)
    save_model(exported_program, args.split_pte_model_path)

    # Finetune the PyTorch model
    model, train_hist = train_model(
        model,
        train_loader,
        test_loader,
        epochs=args.fine_tune_epochs,
        lr=args.learning_rate,
        momentum=0.9,
        save_path=args.model_path,
    )

    save_json(train_hist, args.save_pt_json)
287+
288+
if __name__ == "__main__":
289+
main()
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import typing
10+
11+
import torch
12+
13+
14+
class CIFAR10Model(torch.nn.Module):
    """
    Small convolutional classifier: three conv/ReLU/max-pool stages
    followed by a two-layer fully connected head with dropout.
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        nn = torch.nn

        # Build the convolutional trunk stage by stage. Every stage is a 3x3
        # convolution (padding 1), an in-place ReLU and a 2x2 max-pool with
        # stride 2; only the channel counts differ.
        stage_layers = []
        for in_channels, out_channels in ((3, 32), (32, 64), (64, 128)):
            stage_layers.append(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
            )
            stage_layers.append(nn.ReLU(inplace=True))
            stage_layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        self.features = nn.Sequential(*stage_layers)

        # The head expects 128 * 4 * 4 flattened features, i.e. a 4x4 map
        # with 128 channels coming out of the trunk.
        head_layers = [
            nn.Linear(128 * 4 * 4, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x) -> torch.Tensor:
        """
        Run the convolutional feature extractor on the input, flatten
        everything except the first dimension, and return the output of
        the fully connected classifier.
        """
        feature_maps = self.features(x)
        flat = torch.flatten(feature_maps, 1)
        return self.classifier(flat)
47+
48+
49+
class ModelWithLoss(torch.nn.Module):
    """
    NOTE: Wrapper module that bundles a model with its loss function so
    that a single forward call covers both the model's forward pass and
    the loss calculation. Exporting this wrapper therefore captures the
    loss computation in the graph as well, which is what on-device
    training needs.
    """

    def __init__(
        self, model: torch.nn.Module, criterion: torch.nn.CrossEntropyLoss
    ) -> None:
        super().__init__()
        self.model, self.criterion = model, criterion

    def forward(
        self, x: torch.Tensor, target: torch.Tensor
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(loss, predicted_class)`` for inputs ``x`` and ``target``."""
        logits = self.model(x)
        # Loss is computed before the prediction, on the non-detached logits.
        loss = self.criterion(logits, target)
        # The predicted class is the argmax along dim 1 of the detached
        # logits, so it carries no gradient.
        predicted = torch.argmax(logits.detach(), dim=1)
        return loss, predicted

0 commit comments

Comments
 (0)