Skip to content

Commit 769c4bf

Browse files
committed
Squashed commit of the following:
Updated Readme Fixing possible memory leaks and improving code generation for different data types Properly pass build target to cmake Changed the interface to accept multiple inputs Thread-safe pointer map, cosmetic renames
1 parent 257098a commit 769c4bf

File tree

12 files changed

+628
-243
lines changed

12 files changed

+628
-243
lines changed

README.md

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
# Pytorch Fortran bindings
22

3-
The goal of this code is provide Fortran HPC codes with a simple way to use Pytorch deep learning framework.
3+
The goal of this code is to provide Fortran HPC codes with a simple way to use Pytorch deep learning framework.
44
We want Fortran developers to take advantage of rich and optimized Torch ecosystem from within their existing codes.
55
The code is very much work-in-progress right now and any feedback or bug reports are welcome.
66

77
## Features
88

9-
* Define the model convinently in Python, save it and open in Fortran
9+
* Define the model conveniently in Python, save it and open in Fortran
1010
* Pass Fortran arrays into the model, run inference and get output as a native Fortran array
11-
* Train the model from inside Fortran (limit support for now) and save it
11+
* Train the model from inside Fortran and save it
1212
* Run the model on the CPU or the GPU with the data also coming from the CPU or GPU
1313
* Use OpenACC to achieve zero-copy data transfer for the GPU models
1414
* Focus on achieving negligible performance overhead
1515

1616
## Building
1717

18-
To assist with the build, we provide the Docker and [HPCCM](https://github.com/NVIDIA/hpc-container-maker) recipe for the container with all the necessary dependancies installed, see [container](container/)
18+
To assist with the build, we provide the Docker and [HPCCM](https://github.com/NVIDIA/hpc-container-maker) recipe for the container with all the necessary dependencies installed, see [container](container/)
1919

2020
You'll need to mount a folder with the cloned repository into the container, cd into this folder from the running container and execute `./make_nvhpc.sh`, `./make_gcc.sh` or `./make_intel.sh` depending on the compiler you want to use.
2121

@@ -40,4 +40,46 @@ install/bin/python_training ../examples/python_training/model.py
4040

4141
## API
4242

43-
We are working on documenting the API, for now please refer to the examples.
43+
We are working on documenting the full API. Please refer to the examples for more details.
44+
The bindings are provided through the following Fortran classes:
45+
46+
### Class `torch_tensor`
47+
This class represents a light-weight Pytorch representation of a Fortran array. It does not own the data and only keeps the respective pointer.
48+
Supported arrays of ranks up to 7 and datatypes `real32`, `real64`, `int32`, `int64`.
49+
Members:
50+
* `from_array(Fortran array or pointer :: array)` : create the tensor representation of a Fortran array.
51+
* `to_array(pointer :: array)` : create a Fortran pointer from the tensor. This API should be used to convert the returning data of a Pytorch model to the Fortran array.
52+
53+
### Class `torch_tensor_wrap`
54+
This class wraps a few tensors or scalars that can be passed as input into Pytorch models.
55+
Arrays and scalars must be of types `real32`, `real64`, `int32` or `int64`.
56+
Members:
57+
* `add_scalar(scalar)` : add the scalar value into the wrapper.
58+
* `add_tensor(torch_tensor :: tensor)` : add the tensor into the wrapper.
59+
* `add_array(Fortran array or pointe :: array)` : create the tensor representation of a Fortran array and add it into the wrapper.
60+
61+
62+
### Class `torch_module`
63+
This class represents the traced Pytorch model, typically a result of `torch.jit.trace` or `torch.jit.script` call from your Python script. This class in **not thread-safe**. For multi-threaded inference either create a threaded Pytorch model, or use a `torch_module` instance per thread (the latter could be less efficient).
64+
Members:
65+
* `load( character(*) :: filename, integer :: flags)` : load the module from a file. Flag can be set to `module_use_device` to enable the GPU processing.
66+
* `forward(torch_tensor_wrap :: inputs, torch_tensor :: output, integer :: flags)` : run the inference with Pytorch. The tensors and scalars from the `inputs` will be passed into Pytorch and the `output` will contain the result. `flags` is unused now
67+
* `create_optimizer_sgd(real :: learning_rate)` : create an SGD optimizer to use in the following training
68+
* `train(torch_tensor_wrap :: inputs, torch_tensor :: target, real :: loss)` : perform a single training step where `target` is the target result and `loss` is the L2 squared loss returned by the optimizer
69+
* `save(character(*) :: filename)` : save the trained model
70+
71+
### Class `torch_pymodule`
72+
This class represents the Pytorch Python script and required the interpreter to be called. Only one `torch_pymodule` can be opened at a time due to the Python interpreter limitation. Overheads calling this class are higher than with `torch_module`, but contrary to the `torch_module%train` one can now train their Pytorch model with any optimizer, dropouts, etc. The intended usage of this class is to run online training with a complex pipeline that cannot be expressed as TorchScript.
73+
Members:
74+
* `load( character(*) :: filename)` : load the module from a Python script
75+
* `forward(torch_tensor_wrap :: inputs, torch_tensor :: output)` : execute `ftn_pytorch_forward` function from the Python script. The function is expected to accept tensors and scalars and returns one tensor. The tensors and scalars from the `inputs` will be passed as argument and the `output` will contain the result.
76+
* `train(torch_tensor_wrap :: inputs, torch_tensor :: target, real :: loss)` : execute `ftn_pytorch_train` function from the Python script. The function is expected to accept tensors and scalars (with the last argument required to be the target tensor) and returns a tuple of bool `is_completed` and float `loss`. `is_completed` is returned as a result of the `train` function, and `loss` is set accordingly to the Python output. `is_completed` is meant to signify that the training is completed due to any stopping criterion
77+
* `save(character(*) :: filename)` : save the trained model
78+
79+
## Changelog
80+
81+
### v0.3
82+
* Changed interface: `forward` and `train` routines now accept `torch_tensor_wrap` instead of just `torch_tensor`. This allows a user to add multiple inputs consisting of tensors of different size and scalar values
83+
* Fixed possible small memory leaks due to tensor handles
84+
* Fixed build targets in the scripts, they now properly build Release versions by default
85+
* Added a short API help

examples/polynomial/polynomial.f90

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,9 @@ program polynomial
7373
logical, parameter :: use_gpu = .false.
7474
#endif
7575

76-
type(torch_module) :: torch_mod
77-
type(torch_tensor) :: in_tensor, out_tensor, target_tensor
76+
type(torch_module) :: torch_mod
77+
type(torch_tensor) :: out_tensor, target_tensor
78+
type(torch_tensor_wrap) :: in_tensors
7879

7980
real(real32) :: loss
8081
real(real32), dimension(1, batch_size) :: input, target
@@ -107,11 +108,12 @@ program polynomial
107108
end if
108109
call torch_mod%load(in_fname, flag)
109110
call torch_mod%create_optimizer_sgd(0.1)
111+
call in_tensors%create
110112

111113
!$acc data create (input, target) copyin(coeffs)
112114

113115
!$acc host_data use_device(input)
114-
call in_tensor% from_array(input)
116+
call in_tensors%add_array(input)
115117
!$acc end host_data
116118

117119
!$acc host_data use_device(target)
@@ -125,12 +127,12 @@ program polynomial
125127
!$acc update device(input)
126128

127129
call eval_polynomial(coeffs, input, target)
128-
call torch_mod%train(in_tensor, target_tensor, loss)
130+
call torch_mod%train(in_tensors, target_tensor, loss)
129131

130132
if (mod(batch_idx, 100) == 0) then
131133
print "(A,I6,A,F9.6)", "Batch ",batch_idx," loss is ",loss
132134
end if
133-
if (loss < 1e-3) exit
135+
if (loss < 1e-4) exit
134136
end do
135137

136138
if (batch_idx < max_batch_id) then
@@ -145,14 +147,14 @@ program polynomial
145147
!$acc update device(input)
146148
call eval_polynomial(coeffs, input, target)
147149

148-
call torch_mod%forward(in_tensor, out_tensor)
150+
call torch_mod%forward(in_tensors, out_tensor)
149151
call out_tensor%to_array(output)
150152

151153
!$acc update host(target, output)
152154
loss = sum( (target-output)**2 ) / batch_size
153155

154156
print *, target(1,1:4), output(1,1:4)
155-
print "(A,F9.6)", "L2 error of the trained model is ", loss
157+
print "(A,F9.6)", "Mean squared error of the trained model is ", loss
156158

157159
!$acc end data
158160

examples/python_training/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@
2525

2626
def ftn_pytorch_forward(input):
2727
print('Hello from python')
28+
for i in input:
29+
print(i)
2830
return torch.tensor([[1., -1.], [1., -1.]])
2931

30-
def ftn_pytorch_train(input):
32+
def ftn_pytorch_train(input, target):
3133
print('train from python')
3234
return (True, 42.0)

examples/python_training/python_training.f90

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,10 @@ program python_training
2727

2828
integer :: n
2929
type(torch_pymodule) :: torch_pymod
30-
type(torch_tensor) :: t_in, t_out, t_target
30+
type(torch_tensor) :: t_out, t_target
31+
type(torch_tensor_wrap) :: tw_in
3132

32-
real(real32) :: input(224, 224, 3, 10), target(224)
33+
real(real32) :: input(2, 3), target(2), factor
3334
real(real32), pointer :: output(:,:)
3435
real(real32) :: loss
3536
logical :: is_completed
@@ -46,18 +47,23 @@ program python_training
4647
allocate(character(arglen) :: filename)
4748
call get_command_argument(number=1, value=filename, status=stat)
4849

49-
input = 1.0
50-
call t_in%from_array(input)
50+
input(1,:) = 1.0
51+
input(2,:) = 2.0
52+
factor = 3.0
53+
54+
call tw_in%create
55+
call tw_in%add_array(input)
56+
call tw_in%add_scalar(factor)
5157
call t_target%from_array(target)
5258

5359
call torch_pymod%load(filename)
5460
! will call Python function ftn_pytorch_forward(input) -> output
55-
call torch_pymod%forward(t_in, t_out)
61+
call torch_pymod%forward(tw_in, t_out)
5662
call t_out%to_array(output)
5763
print *, output
5864

5965
! will call Python function ftn_pytorch_train(input, target) -> (is_completed, loss)
60-
is_completed = torch_pymod%train(t_in, t_target, loss)
66+
is_completed = torch_pymod%train(tw_in, t_target, loss)
6167
print *, is_completed, loss
6268

6369
end program

examples/resnet_forward/resnet_forward.f90

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ program resnet_forward
2727

2828
integer :: n
2929
type(torch_module) :: torch_mod
30-
type(torch_tensor) :: in_tensor, out_tensor
30+
type(torch_tensor_wrap) :: input_tensors
31+
type(torch_tensor) :: out_tensor
3132

3233
real(real32) :: input(224, 224, 3, 10)
3334
real(real32), pointer :: output(:, :)
@@ -45,9 +46,10 @@ program resnet_forward
4546
call get_command_argument(number=1, value=filename, status=stat)
4647

4748
input = 1.0
48-
call in_tensor%from_array(input)
49+
call input_tensors%create
50+
call input_tensors%add_array(input)
4951
call torch_mod%load(filename)
50-
call torch_mod%forward(in_tensor, out_tensor)
52+
call torch_mod%forward(input_tensors, out_tensor)
5153
call out_tensor%to_array(output)
5254

5355
print *, output(1:5, 1)

make_gnu.sh

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,25 +38,25 @@ mkdir -p $BUILD_PATH/build_proxy $BUILD_PATH/build_fortproxy $BUILD_PATH/build_e
3838
# c++ wrappers
3939
(
4040
cd $BUILD_PATH/build_proxy
41-
cmake -DOPENACC=$OPENACC -DCMAKE_INSTALL_PREFIX=$INSTALL_PATH -DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH -DTORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST ../../src/proxy_lib
42-
cmake --build . --config $CONFIG --parallel
41+
cmake -DOPENACC=$OPENACC -DCMAKE_BUILD_TYPE=$CONFIG -DCMAKE_INSTALL_PREFIX=$INSTALL_PATH -DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH -DTORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST ../../src/proxy_lib
42+
cmake --build . --parallel
4343
make install
4444
)
4545

4646
# fortran bindings
4747
(
4848
export PATH=$NVPATH:$PATH
4949
cd $BUILD_PATH/build_fortproxy
50-
cmake -DOPENACC=$OPENACC -DCMAKE_Fortran_COMPILER=gfortran -DCMAKE_INSTALL_PREFIX=$INSTALL_PATH -DCMAKE_PREFIX_PATH=$INSTALL_PATH/lib ../../src/f90_bindings/
51-
cmake --build . --config $CONFIG --parallel
50+
cmake -DOPENACC=$OPENACC -DCMAKE_BUILD_TYPE=$CONFIG -DCMAKE_INSTALL_PREFIX=$INSTALL_PATH -DCMAKE_Fortran_COMPILER=gfortran -DCMAKE_PREFIX_PATH=$INSTALL_PATH/lib ../../src/f90_bindings/
51+
cmake --build . --parallel
5252
make install
5353
)
5454

5555
# fortran examples
5656
(
5757
export PATH=$NVPATH:$PATH
5858
cd $BUILD_PATH/build_example
59-
cmake -DOPENACC=$OPENACC -DCMAKE_Fortran_COMPILER=gfortran -DCMAKE_INSTALL_PREFIX=$INSTALL_PATH ../../examples/
60-
cmake --build . --config $CONFIG --parallel
59+
cmake -DOPENACC=$OPENACC -DCMAKE_BUILD_TYPE=$CONFIG -DCMAKE_INSTALL_PREFIX=$INSTALL_PATH -DCMAKE_Fortran_COMPILER=gfortran ../../examples/
60+
cmake --build . --parallel
6161
make install
6262
)

make_intel.sh

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,25 +41,25 @@ mkdir -p $BUILD_PATH/build_proxy $BUILD_PATH/build_fortproxy $BUILD_PATH/build_e
4141
# c++ wrappers
4242
(
4343
cd $BUILD_PATH/build_proxy
44-
cmake -DOPENACC=$OPENACC -DCMAKE_INSTALL_PREFIX=$INSTALL_PATH -DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH ../../src/proxy_lib
45-
cmake --build . --config $CONFIG --parallel
44+
cmake -DOPENACC=$OPENACC -DCMAKE_BUILD_TYPE=$CONFIG -DCMAKE_INSTALL_PREFIX=$INSTALL_PATH -DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH ../../src/proxy_lib
45+
cmake --build . --parallel
4646
make install
4747
)
4848

4949
# fortran bindings
5050
(
5151
export PATH=$NVPATH:$PATH
5252
cd $BUILD_PATH/build_fortproxy
53-
cmake -DOPENACC=$OPENACC -DCMAKE_Fortran_COMPILER=ifort -DCMAKE_INSTALL_PREFIX=$INSTALL_PATH -DCMAKE_PREFIX_PATH=$INSTALL_PATH/lib ../../src/f90_bindings/
54-
cmake --build . --config $CONFIG --parallel
53+
cmake -DOPENACC=$OPENACC -DCMAKE_BUILD_TYPE=$CONFIG -DCMAKE_INSTALL_PREFIX=$INSTALL_PATH -DCMAKE_Fortran_COMPILER=ifort -DCMAKE_PREFIX_PATH=$INSTALL_PATH/lib ../../src/f90_bindings/
54+
cmake --build . --parallel
5555
make install
5656
)
5757

5858
# fortran examples
5959
(
6060
export PATH=$NVPATH:$PATH
6161
cd $BUILD_PATH/build_example
62-
cmake -DOPENACC=$OPENACC -DCMAKE_Fortran_COMPILER=ifort -DCMAKE_INSTALL_PREFIX=$INSTALL_PATH ../../examples/
63-
cmake --build . --config $CONFIG --parallel
62+
cmake -DOPENACC=$OPENACC -DCMAKE_BUILD_TYPE=$CONFIG -DCMAKE_INSTALL_PREFIX=$INSTALL_PATH -DCMAKE_Fortran_COMPILER=ifort ../../examples/
63+
cmake --build . --parallel
6464
make install
6565
)

make_nvhpc.sh

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,25 +45,25 @@ mkdir -p $BUILD_PATH/build_proxy $BUILD_PATH/build_fortproxy $BUILD_PATH/build_e
4545
# c++ wrappers
4646
(
4747
cd $BUILD_PATH/build_proxy
48-
cmake -DOPENACC=$OPENACC -DCMAKE_INSTALL_PREFIX=$INSTALL_PATH -DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH -DTORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST ../../src/proxy_lib
49-
cmake --build . --config $CONFIG --parallel
48+
cmake -DOPENACC=$OPENACC -DCMAKE_BUILD_TYPE=$CONFIG -DCMAKE_INSTALL_PREFIX=$INSTALL_PATH -DCMAKE_CXX_COMPILER=g++ -DCUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda -DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH -DTORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST ../../src/proxy_lib
49+
cmake --build . --parallel
5050
make install
5151
)
5252

5353
# fortran bindings
5454
(
5555
export PATH=$NVPATH:$PATH
5656
cd $BUILD_PATH/build_fortproxy
57-
cmake -DOPENACC=$OPENACC -DCMAKE_Fortran_COMPILER=nvfortran -DCMAKE_INSTALL_PREFIX=$INSTALL_PATH -DCMAKE_PREFIX_PATH=$INSTALL_PATH/lib ../../src/f90_bindings/
58-
cmake --build . --config $CONFIG --parallel
57+
cmake -DOPENACC=$OPENACC -DCMAKE_BUILD_TYPE=$CONFIG -DCMAKE_INSTALL_PREFIX=$INSTALL_PATH -DCMAKE_Fortran_COMPILER=nvfortran -DCMAKE_PREFIX_PATH=$INSTALL_PATH/lib ../../src/f90_bindings/
58+
cmake --build . --parallel
5959
make install
6060
)
6161

6262
# fortran examples
6363
(
6464
export PATH=$NVPATH:$PATH
6565
cd $BUILD_PATH/build_example
66-
cmake -DOPENACC=$OPENACC -DCMAKE_Fortran_COMPILER=nvfortran -DCMAKE_INSTALL_PREFIX=$INSTALL_PATH ../../examples/
67-
cmake --build . --config $CONFIG --parallel
66+
cmake -DOPENACC=$OPENACC -DCMAKE_BUILD_TYPE=$CONFIG -DCMAKE_INSTALL_PREFIX=$INSTALL_PATH -DCMAKE_Fortran_COMPILER=nvfortran ../../examples/
67+
cmake --build . --parallel
6868
make install
6969
)

0 commit comments

Comments
 (0)