
Conversation

@tolleybot

Add support for torch.export exported models (#1498)

Implements functionality to load and execute PyTorch models exported via torch.export (.pt2 files), enabling .NET applications to run ExportedProgram models as the PyTorch ecosystem transitions from ONNX to torch.export.

Summary

This PR adds support for loading and running AOTInductor-compiled .pt2 models in TorchSharp using torch::inductor::AOTIModelPackageLoader from LibTorch 2.9+.

Key Points:

  • ✅ Inference-only API (no training support)
  • ✅ Models must be compiled with torch._inductor.aoti_compile_and_package() in Python
  • ✅ 30-40% better latency than TorchScript (according to PyTorch docs)
  • ✅ Compatible with LibTorch 2.9+ which includes AOTIModelPackageLoader symbols

Implementation

Native Layer (C++)

Files:

  • src/Native/LibTorchSharp/Utils.h - Added AOTIModelPackageLoader header include
  • src/Native/LibTorchSharp/THSExport.h - C++ API declarations
  • src/Native/LibTorchSharp/THSExport.cpp - Implementation using torch::inductor::AOTIModelPackageLoader

Key Changes:

// Utils.h - Added header include for all files
#include "torch/csrc/inductor/aoti_package/model_package_loader.h"

// THSExport.cpp - Simple wrapper around AOTIModelPackageLoader
ExportedProgramModule THSExport_load(const char* filename)
{
    auto* loader = new torch::inductor::AOTIModelPackageLoader(filename);
    return loader;
}

void THSExport_Module_run(
    const ExportedProgramModule module,
    const Tensor* input_tensors,
    const int input_length,
    Tensor** result_tensors,
    int* result_length)
{
    std::vector<torch::Tensor> inputs;
    // ... convert inputs
    std::vector<torch::Tensor> outputs = module->run(inputs);
    // ... convert outputs
}

Managed Layer (C#)

Files:

  • src/TorchSharp/PInvoke/LibTorchSharp.THSExport.cs - PInvoke declarations
  • src/TorchSharp/Export/ExportedProgram.cs - High-level C# API
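
For orientation, here is a sketch of what the PInvoke surface might look like. The entry-point names mirror the native functions shown above, but the exact signatures are assumptions; the real declarations live in LibTorchSharp.THSExport.cs.

// Hypothetical sketch, not the actual declarations.
using System;
using System.Runtime.InteropServices;

internal static class THSExportNative
{
    [DllImport("LibTorchSharp")]
    internal static extern IntPtr THSExport_load(
        [MarshalAs(UnmanagedType.LPStr)] string filename);

    [DllImport("LibTorchSharp")]
    internal static extern void THSExport_Module_run(
        IntPtr module,                 // ExportedProgramModule handle
        IntPtr[] inputTensors,         // input Tensor handles
        int inputLength,
        out IntPtr resultTensors,      // output buffer allocated by the native side
        out int resultLength);
}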

API Design:

// Basic usage
using var exported = torch.export.load("model.pt2");
var results = exported.run(input);

// Generic typing for single tensor output
using var exported = torch.export.load<Tensor>("model.pt2");
Tensor result = exported.run(input);

// Generic typing for tuple output
using var exported = torch.export.load<(Tensor, Tensor)>("model.pt2");
var (sum, diff) = exported.run(x, y);

Features:

  • Implements IDisposable for proper resource cleanup
  • Generic ExportedProgram<TResult> for type-safe returns
  • Support for single tensors, arrays, and tuples (up to 3 elements)
  • run(), forward(), and call() methods (all equivalent)
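
As a usage sketch for the array case (the model file name, input shape, and the ownership comment are assumptions, not taken from the PR):

using System;
using TorchSharp;
using static TorchSharp.torch;

// Hypothetical model file; any AOTInductor package with a list output would do.
using var exported = torch.export.load<Tensor[]>("model_list_output.pt2");
using var input = randn(1, 10);

Tensor[] outputs = exported.run(input);
foreach (var t in outputs)
{
    Console.WriteLine(string.Join("x", t.shape));
    t.Dispose();  // assuming returned tensors are caller-owned
}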

Testing

Files:

  • test/TorchSharpTest/TestExport.cs - 7 comprehensive unit tests
  • test/TorchSharpTest/generate_export_models.py - Python script to generate test models
  • test/TorchSharpTest/*.pt2 - 6 test models

Test Coverage:

[Fact] public void TestLoadExport_SimpleLinear()       // Basic model
[Fact] public void TestLoadExport_LinearReLU()         // Multi-layer
[Fact] public void TestLoadExport_TwoInputs()          // Multiple inputs
[Fact] public void TestLoadExport_TupleOutput()        // Tuple return
[Fact] public void TestLoadExport_ListOutput()         // Array return
[Fact] public void TestLoadExport_Sequential()         // Complex model
[Fact] public void TestExport_LoadNonExistentFile()    // Error handling

All 7 tests pass successfully.
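
For reference, one of these tests might take roughly the following shape (the file name and expected output shape are illustrative, not copied from TestExport.cs):

[Fact]
public void TestLoadExport_SimpleLinear()
{
    // Illustrative file name and dimensions.
    using var exported = torch.export.load("simple_linear.pt2");
    using var input = torch.randn(1, 10);

    var results = exported.run(input);

    Assert.Single(results);
    Assert.Equal(new long[] { 1, 5 }, results[0].shape);
}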

Dependencies

Updated:

  • build/Dependencies.props - Updated LibTorch from 2.7.1 to 2.9.0

LibTorch 2.9.0 includes the torch::inductor::AOTIModelPackageLoader implementation that was previously only available in PyTorch source code.

Technical Details

Two .pt2 Formats

PyTorch has two different .pt2 export formats:

  1. Python-only (from torch.export.save()):

    • Cannot be loaded in C++
    • Uses pickle-based serialization
    • NOT supported by this implementation
  2. AOTInductor-compiled (from torch._inductor.aoti_compile_and_package()):

    • Can be loaded in C++ via AOTIModelPackageLoader
    • Ahead-of-time compiled for specific device
    • ✅ Supported by this implementation
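
Since both formats share the .pt2 extension, a .NET caller cannot tell them apart by file name alone. A defensive load, assuming a bad package surfaces as an exception (the TestExport_LoadNonExistentFile test above exercises the missing-file case), might look like:

ExportedProgram exported;
try
{
    exported = torch.export.load("model.pt2");
}
catch (Exception e)
{
    // Assumption: Python-only packages and missing files fail here.
    Console.Error.WriteLine($"Could not load AOTInductor package: {e.Message}");
    throw;
}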

Python Model Generation

To create compatible .pt2 files:

import torch
import torch._inductor

model = MyModule()
example_inputs = (torch.randn(1, 10),)

# Export the model
exported = torch.export.export(model, example_inputs)

# Compile with AOTInductor for C++ compatibility
torch._inductor.aoti_compile_and_package(
    exported,
    package_path="model.pt2"
)
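
The resulting package can then be consumed from .NET with the API introduced above:

// C# side: load and run the package produced by the Python snippet.
using var exported = torch.export.load("model.pt2");
using var input = torch.randn(1, 10);
var results = exported.run(input);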

Limitations

  • Inference only: No training, no parameter updates, no gradient computation
  • Device-specific: Models compiled for CPU cannot run on CUDA and vice versa
  • No device movement: Cannot move model between devices at runtime
  • LibTorch 2.9+ required: Older versions don't include AOTIModelPackageLoader

Performance

According to PyTorch documentation, AOTInductor provides:

  • 30-40% better latency compared to TorchScript
  • Optimized for production inference workloads
  • Single-graph representation with only ATen-level operations

Running the Tests

# Build
dotnet build src/TorchSharp/TorchSharp.csproj

# Run tests
dotnet test test/TorchSharpTest/TorchSharpTest.csproj --filter "FullyQualifiedName~TestExport"

Migration Guide

For users currently using TorchScript:

Before (TorchScript):

# Python
torch.jit.save(traced_model, "model.pt")
// C#
var module = torch.jit.load("model.pt");
var result = module.forward(input);

After (torch.export):

# Python
import torch._inductor
exported = torch.export.export(model, example_inputs)
torch._inductor.aoti_compile_and_package(exported, package_path="model.pt2")
// C#
using var exported = torch.export.load("model.pt2");
var result = exported.run(input);

References

Fixes #1498

Implements functionality to load and execute PyTorch models exported via
torch.export (.pt2 files), enabling .NET applications to run ExportedProgram
models as the PyTorch ecosystem transitions from ONNX to torch.export.

## Implementation

### Native Layer
- Add THSExport.h and THSExport.cpp C++ wrappers for AOTIModelPackageLoader API
- Update Utils.h to include torch/csrc/inductor/aoti_package/model_package_loader.h
- Upgrade to LibTorch 2.9.0 which includes AOTIModelPackageLoader symbols

### Managed Layer
- Add LibTorchSharp.THSExport.cs with PInvoke declarations
- Implement ExportedProgram and ExportedProgram<TResult> classes in Export namespace
- Provide torch.export.load() API following PyTorch conventions

### Features
- Load .pt2 ExportedProgram files compiled with torch._inductor.aoti_compile_and_package()
- Execute inference-only forward pass with type-safe generics
- Support for single tensor, array, and tuple (up to 3 elements) outputs
- Proper IDisposable implementation for resource cleanup

### Testing
- Add TestExport.cs with 7 comprehensive unit tests (all passing)
- Include 6 test .pt2 models covering various scenarios:
  - Simple linear model
  - Linear + ReLU
  - Multiple inputs
  - Tuple and list outputs
  - Sequential models
- Add generate_export_models.py for regenerating test models

## Technical Details

The implementation uses torch::inductor::AOTIModelPackageLoader from LibTorch 2.9+
for AOTInductor-compiled models, providing 30-40% better latency than TorchScript.
Models are inference-only and compiled for specific device (CPU/CUDA) at build time.

Note: .pt2 files from torch.export.save() are Python-only and not supported.
Only .pt2 files from torch._inductor.aoti_compile_and_package() work in C++.

Fixes dotnet#1498
@tolleybot
Author

@dotnet-policy-service agree

@tolleybot
Author

tolleybot commented Oct 30, 2025

Build Failures: Missing LibTorch 2.9.0 Packages

I believe the CI builds are failing because the build system requires .sha files for LibTorch package validation, and these are missing for LibTorch 2.9.0.

Missing SHA files:

  • ❌ Linux: libtorch-cxx11-abi-shared-with-deps-2.9.0+cpu.zip.sha
  • ❌ Windows: libtorch-win-shared-with-deps-2.9.0+cpu.zip.sha
  • ✅ macOS arm64: libtorch-macos-arm64-2.9.0.zip.sha (exists)

Package availability check:

  • Linux cxx11-abi: 403 error (not published yet)
  • Windows: Available
  • macOS arm64: Available

Why my local tests passed: I was building against the PyTorch Python installation at
/opt/homebrew/lib/python3.11/site-packages/torch/, which includes LibTorch 2.9.0 with AOTIModelPackageLoader support.

Should we wait for PyTorch to publish all LibTorch 2.9.0 packages?

@masaru-kimura-hacarus

masaru-kimura-hacarus commented Oct 31, 2025

@tolleybot

Missing SHA files:

  • ❌ Linux: libtorch-cxx11-abi-shared-with-deps-2.9.0+cpu.zip.sha
    ...

Package availability check:

  • Linux cxx11-abi: 403 error (not published yet)
    ...

Should we wait for PyTorch to publish all LibTorch 2.9.0 packages?

  • Although I'm not sure about the LibTorch package naming convention:
    • The PyTorch Get Started page indicates that libtorch-shared-with-deps-2.9.0+cpu.zip uses the cxx11 ABI.
    • PyTorch upstream released libtorch-cxx11-abi-shared-with-deps-2.7.1+cpu.zip, but there is no release for 2.8.0 or later under that naming.
    • On the other hand, PyTorch upstream released libtorch-shared-with-deps-*.zip for 2.6.0 and earlier, and again for 2.8.0 and later; only 2.7.0 and 2.7.1 are missing under that naming.

@masaru-kimura-hacarus

masaru-kimura-hacarus commented Oct 31, 2025

@tolleybot

  • I've attached a report created with Deep Research-enabled Google Gemini 2.5 Pro to answer "why libtorch-cxx11-abi-shared-with-deps-2.9.0+cpu.zip doesn't exist":
    Technical Analysis of the LibTorch ZIP File Naming Convention Change.pdf
    • As the executive summary puts it:

      since PyTorch version 2.8.0, filenames in the format libtorch-cxx11-abi-shared-with-deps-VERSION.zip are no longer present, having been replaced by a unified format: libtorch-shared-with-deps-VERSION.zip.

    • Please disregard that the last section, titled "引用文献" (Japanese for "bibliography"), contains some Japanese text; the original research was done with a Japanese prompt, and Gemini's export feature appears to malfunction when a translation task is involved.

Add SHA validation files for LibTorch 2.9.0 packages to enable CI builds.
PyTorch changed naming convention at 2.8.0 from 'libtorch-cxx11-abi-*' to
unified 'libtorch-shared-with-deps-*' (which is cxx11-abi by default).

Added:
- libtorch-shared-with-deps-2.9.0+cpu.zip.sha (Linux)
- libtorch-win-shared-with-deps-2.9.0+cpu.zip.sha (Windows)
- libtorch-win-shared-with-deps-debug-2.9.0+cpu.zip.sha (Windows Debug)

SHA values computed from official PyTorch downloads at download.pytorch.org.
@tolleybot
Author

tolleybot commented Oct 31, 2025

@masaru-kimura-hacarus Thank you for the detailed investigation and the Gemini Deep Research report! You're absolutely right. I was looking for the wrong package name.

I've just pushed the correct SHA files using the new naming convention. Let's see if the CI builds pass now.

@tolleybot
Author

@dotnet-policy-service agree
