Skip to content

Commit 8f03af2

Browse files
authored
Merge pull request #806 from take-cheeze/pt2.2
Support torch 2.2
2 parents 364db08 + f1d445d commit 8f03af2

File tree

10 files changed

+1034
-1516
lines changed

10 files changed

+1034
-1516
lines changed

.flexci/config.pbtxt

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,22 @@ configs {
110110
}
111111
}
112112

113+
configs {
114+
key: "pytorch-pfn-extras.torch202-linux"
115+
value {
116+
requirement {
117+
cpu: 8
118+
gpu: 2
119+
memory: 52
120+
disk: 10
121+
}
122+
time_limit: {
123+
seconds: 1800
124+
}
125+
command: "bash .flexci/linux/main-flexci.sh torch202"
126+
}
127+
}
128+
113129
configs {
114130
key: "pytorch-pfn-extras.torch110-win"
115131
value {
@@ -211,3 +227,20 @@ configs {
211227
command: ".flexci\\windows\\run.bat torch201"
212228
}
213229
}
230+
231+
configs {
232+
key: "pytorch-pfn-extras.torch202-win"
233+
value {
234+
requirement {
235+
cpu: 4
236+
gpu: 2
237+
memory: 24
238+
disk: 10
239+
image: "windows"
240+
}
241+
time_limit: {
242+
seconds: 1200
243+
}
244+
command: ".flexci\\windows\\run.bat torch202"
245+
}
246+
}

.flexci/linux/build_and_push.sh

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,15 @@ case "${TARGET}" in
8181
--build-arg pip_install_dep_args="cupy-cuda11x"
8282
;;
8383

84+
torch202 )
85+
# PyTorch 2.2 + Python 3.10
86+
docker_build_and_push \
87+
--build-arg base_image="nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04" \
88+
--build-arg python_version="3.10.5" \
89+
--build-arg pip_install_torch_args="torch==2.2.* torchvision==0.17.* -f https://download.pytorch.org/whl/cu121/torch_stable.html" \
90+
--build-arg pip_install_dep_args="cupy-cuda12x"
91+
;;
92+
8493
* )
8594
echo "${1}: Unknown test name."
8695
exit 1

.flexci/windows/_flexci.ps1

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,14 @@ function ActivateCUDA($version) {
5252
$Env:CUDA_PATH = $Env:CUDA_PATH_V11_8
5353
} elseif ($version -eq "11.x") {
5454
$Env:CUDA_PATH = $Env:CUDA_PATH_V11_8
55+
} elseif ($version -eq "12.0") {
56+
$Env:CUDA_PATH = $Env:CUDA_PATH_V12_0
57+
} elseif ($version -eq "12.1") {
58+
$Env:CUDA_PATH = $Env:CUDA_PATH_V12_1
59+
} elseif ($version -eq "12.2") {
60+
$Env:CUDA_PATH = $Env:CUDA_PATH_V12_2
61+
} elseif ($version -eq "12.x") {
62+
$Env:CUDA_PATH = $Env:CUDA_PATH_V12_2
5563
} else {
5664
throw "Unsupported CUDA version: $version"
5765
}

.flexci/windows/test.ps1

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,13 @@ if ($test -eq "torch110") {
5151
RunOrDie python -m pip install -U pip "setuptools<59.6"
5252
RunOrDieWithRetry 3 python -m pip install torch==2.1.* torchvision==0.16.* -f https://download.pytorch.org/whl/cu118/torch_stable.html
5353

54+
} elseif ($test -eq "torch202") {
55+
# PyTorch 2.2 + Python 3.10
56+
ActivateCUDA 12.1
57+
ActivatePython 3.10
58+
RunOrDie python -m pip install -U pip "setuptools<59.6"
59+
RunOrDieWithRetry 3 python -m pip install torch==2.2.* torchvision==0.17.* -f https://download.pytorch.org/whl/cu121/torch_stable.html
60+
5461
} else {
5562
throw "Unsupported test variant: $test"
5663
}

.github/workflows/pretest-and-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ jobs:
77
runs-on: ubuntu-22.04
88
strategy:
99
matrix:
10-
torch: ['1.10.*', '1.11.*', '1.12.*', '1.13.*', '2.0.*', '2.1.*']
10+
torch: ['1.10.*', '1.11.*', '1.12.*', '1.13.*', '2.0.*', '2.1.*', '2.2.*']
1111

1212
steps:
1313
- name: Checkout

pytorch_pfn_extras/nn/modules/lazy_batchnorm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
4747
)
4848
if self.track_running_stats:
4949
assert isinstance(self.running_mean, torch.Tensor)
50+
assert isinstance(self.num_features, int)
5051
self.running_mean = torch.zeros(
5152
self.num_features,
5253
device=self.running_mean.device,

pytorch_pfn_extras/onnx/pfto_exporter/export.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ def optimize_onnx(self, graph: torch._C.Graph) -> torch._C.Graph:
507507
for k, t in folded.items():
508508
c: torch._C.Value = graph.create("onnx::Constant", 1).output()
509509
assert isinstance(t, torch.Tensor)
510-
c.node().t_("value", cast(torch.Tensor, t))
510+
c.node().t_("value", t)
511511
graph.prependNode(c.node())
512512
# TODO(twata): Determine folded nodes from original graph and document it
513513
self.node_doc_string[c.node()] = f"Constant folded node: {input_table[k]}"
@@ -1155,9 +1155,9 @@ def _convert(self) -> None:
11551155
GLOBALS.onnx_shape_inference = False
11561156
else:
11571157
to_utils.__IN_ONNX_EXPORT = True # type: ignore[attr-defined]
1158-
sym_hel._set_opset_version(self.opset_version) # type: ignore[no-untyped-call]
1159-
sym_hel._set_operator_export_type(self.operator_export_type) # type: ignore[no-untyped-call]
1160-
sym_hel._set_onnx_shape_inference( # type: ignore[no-untyped-call]
1158+
sym_hel._set_opset_version(self.opset_version) # type: ignore[attr-defined, no-untyped-call]
1159+
sym_hel._set_operator_export_type(self.operator_export_type) # type: ignore[attr-defined, no-untyped-call]
1160+
sym_hel._set_onnx_shape_inference( # type: ignore[attr-defined, no-untyped-call]
11611161
False # TODO(twata): Use `self.onnx_shape_inference`
11621162
)
11631163
with record("pfto.original_outputs"):
@@ -1177,11 +1177,11 @@ def _convert(self) -> None:
11771177
else:
11781178
to_utils.__IN_ONNX_EXPORT = False # type: ignore[attr-defined]
11791179
if prev_opset_version is not None:
1180-
sym_hel._set_opset_version(prev_opset_version) # type: ignore[no-untyped-call]
1180+
sym_hel._set_opset_version(prev_opset_version) # type: ignore[attr-defined, no-untyped-call]
11811181
if prev_export_type is not None:
1182-
sym_hel._set_operator_export_type(prev_export_type) # type: ignore[no-untyped-call]
1182+
sym_hel._set_operator_export_type(prev_export_type) # type: ignore[attr-defined, no-untyped-call]
11831183
if prev_shape_inference is not None:
1184-
sym_hel._set_onnx_shape_inference(prev_shape_inference) # type: ignore[no-untyped-call]
1184+
sym_hel._set_onnx_shape_inference(prev_shape_inference) # type: ignore[attr-defined, no-untyped-call]
11851185

11861186
def generate(self, f: Union[str, typing.IO]) -> None:
11871187
with record("pfto.write_to_file"):

pytorch_pfn_extras/runtime/_to.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ def to(
5858
runtime = runtime_class(device, options)
5959
obj = module_or_tensor
6060
if isinstance(obj, torch.nn.Module):
61-
obj = runtime.move_module(obj)
62-
for module in obj.modules():
61+
mod = runtime.move_module(obj)
62+
for module in mod.modules():
6363
ppe.runtime._runtime._set_module_runtime_tag(module, runtime)
64-
return obj
64+
return mod
6565
elif isinstance(obj, torch.Tensor):
6666
return runtime.move_tensor(obj)
6767
else:

0 commit comments

Comments
 (0)