# Burn Torch Backend

[Burn](https://github.com/tracel-ai/burn) Torch backend

[![Current Crates.io Version](https://img.shields.io/crates/v/burn-tch.svg)](https://crates.io/crates/burn-tch)
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-tch/blob/master/README.md)

This crate provides a Torch backend for [Burn](https://github.com/tracel-ai/burn) utilizing the [`tch-rs`](https://github.com/LaurentMazare/tch-rs) crate, which offers a Rust interface to the [PyTorch](https://pytorch.org/) C++ API.

The backend supports CPU (multithreaded), [CUDA](https://pytorch.org/docs/stable/notes/cuda.html) (multiple GPUs), and [MPS](https://pytorch.org/docs/stable/notes/mps.html) devices (MacOS).

## Installation

[`tch-rs`](https://github.com/LaurentMazare/tch-rs) requires the C++ PyTorch library (LibTorch) to be available on your system.

By default, the CPU distribution is installed for LibTorch v2.9.0 as required by `tch-rs`.


**CUDA**

To install the latest compatible CUDA distribution, set the `TORCH_CUDA_VERSION` environment variable before the `tch-rs` dependency is retrieved with `cargo`.

```shell
export TORCH_CUDA_VERSION=cu128
```

On Windows:

```powershell
$Env:TORCH_CUDA_VERSION = "cu128"
```

> Note: `tch` doesn't expose the downloaded libtorch directory on Windows when using the automatic
> download feature, so the `torch_cuda.dll` cannot be detected properly during build. In this case,
> you can set the `LIBTORCH` environment variable to point to the `libtorch/` folder in `torch-sys`
> `OUT_DIR` (or move the downloaded lib to a different folder and point to it).

For example, running the validation sample for the first time could be done with the following commands:

```shell
export TORCH_CUDA_VERSION=cu128
cargo run --bin cuda --release
```

**Important:** make sure your driver version is compatible with the selected CUDA version. A CUDA Toolkit installation is not required since LibTorch ships with the appropriate CUDA runtimes. Having the latest driver version is recommended, but you can always take a look at the [toolkit driver version table](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#id4) or [minimum required driver version](https://docs.nvidia.com/deploy/cuda-compatibility/index.html#minor-version-compatibility) (limited feature-set, might not work with all operations).

Once your installation is complete, you should be able to build/run your project. You can also validate your installation by running the appropriate `cpu`, `cuda` or `mps` sample as below.

```shell
cargo run --bin cpu --release
cargo run --bin cuda --release
cargo run --bin mps --release
```

_Note: no MPS distribution is available for automatic download at this time, please check out the [manual instructions](#metal-mps)._

### Manual Download

To install `tch-rs` with a different LibTorch distribution, you will have to manually download the desired LibTorch distribution. The instructions are detailed in the sections below for each platform.

| Compute Platform          |         CPU         | GPU | Linux | MacOS | Windows | Android | iOS | WASM |
| :------------------------ | :-----------------: | :-: | :---: | :---: | :-----: | :-----: | :-: | :--: |
| [CPU](#cpu)               |         Yes         | No  |  Yes  |  Yes  |   Yes   |   Yes   | Yes |  No  |
| [CUDA](#cuda)             | Yes [[1]](#cpu-sup) | Yes |  Yes  |  No   |   Yes   |   No    | No  |  No  |
| [Metal (MPS)](#metal-mps) |         No          | Yes |  No   |  Yes  |   No    |   No    | No  |  No  |
| Vulkan                    |         Yes         | Yes |  Yes  |  Yes  |   Yes   |   Yes   | No  |  No  |

[1] The LibTorch CUDA distribution also comes with CPU support.

#### CPU
**🐧 Linux**

First, download the LibTorch CPU distribution.

```shell
wget -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-2.9.0%2Bcpu.zip
unzip libtorch.zip
```

Then, point to that installation using the `LIBTORCH` and `LD_LIBRARY_PATH` environment variables before building `burn-tch` or a crate which depends on it.

```shell
export LIBTORCH=/absolute/path/to/libtorch/
export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH
```
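
Before building, it can help to confirm that `LIBTORCH` points at the extracted folder. The following is a minimal sketch, not part of `burn-tch` or `tch-rs`: `check_libtorch` is a hypothetical helper, and the only thing it checks is that a `lib/` subfolder exists, which is an assumption based on the `LD_LIBRARY_PATH` value shown above.

```shell
# Hypothetical helper: succeeds if the given directory (or $LIBTORCH when no
# argument is passed) looks like an extracted LibTorch distribution.
# Assumption: a valid distribution contains a `lib/` subfolder.
check_libtorch() {
    dir="${1:-${LIBTORCH:-}}"
    if [ -z "$dir" ]; then
        echo "LIBTORCH is not set" >&2
        return 1
    fi
    if [ ! -d "$dir/lib" ]; then
        echo "no lib/ directory under $dir" >&2
        return 1
    fi
    echo "LibTorch found at $dir"
}
```

Calling `check_libtorch` with no argument inspects the current `LIBTORCH` value; passing a path checks that path instead.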

**🍎 Mac**

First, download the LibTorch CPU distribution.

```shell
wget -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-2.9.0.zip
unzip libtorch.zip
```

Then, point to that installation using the `LIBTORCH` and `DYLD_LIBRARY_PATH` environment variables before building `burn-tch` or a crate which depends on it.

```shell
export LIBTORCH=/absolute/path/to/libtorch/
export DYLD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$DYLD_LIBRARY_PATH
```

**🪟 Windows**

First, download the LibTorch CPU distribution.

```powershell
wget https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-2.9.0%2Bcpu.zip -OutFile libtorch.zip
Expand-Archive libtorch.zip
```

Then, set the `LIBTORCH` environment variable and append the library to your path, as shown in the PowerShell commands below, before building `burn-tch` or a crate which depends on it.

```powershell
$Env:LIBTORCH = "/absolute/path/to/libtorch/"
$Env:Path += ";/absolute/path/to/libtorch/"
```

#### CUDA

LibTorch 2.9.0 currently includes binary distributions with CUDA 12.6, 12.8 or 13.0 runtimes. The manual installation instructions are detailed below for CUDA 12.6, but can be applied to the other CUDA versions by replacing `cu126` with the corresponding version string (e.g., `cu130`).


**🐧 Linux**

First, download the LibTorch CUDA 12.6 distribution.

```shell
wget -O libtorch.zip https://download.pytorch.org/libtorch/cu126/libtorch-shared-with-deps-2.9.0%2Bcu126.zip
unzip libtorch.zip
```

Then, point to that installation using the `LIBTORCH` and `LD_LIBRARY_PATH` environment variables before building `burn-tch` or a crate which depends on it.

```shell
export LIBTORCH=/absolute/path/to/libtorch/
export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH
```

**Note:** make sure your CUDA installation is in your `PATH` and `LD_LIBRARY_PATH`.
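
Since the Linux download URLs for the CPU and CUDA 12.6 distributions differ only in the compute tag, the substitution can be sketched as a small function. `libtorch_linux_url` is a hypothetical helper, and it assumes the other tags (e.g., `cu128`, `cu130`) follow the same URL pattern as the two URLs shown in this README.

```shell
# Hypothetical helper: prints the Linux LibTorch download URL for a compute
# tag (cpu, cu126, ...) and an optional LibTorch version (default: 2.9.0).
# Assumption: every tag follows the URL pattern used for cpu and cu126.
libtorch_linux_url() {
    tag="$1"
    version="${2:-2.9.0}"
    # "%%2B" prints "%2B", the URL-encoded "+" between version and tag.
    printf 'https://download.pytorch.org/libtorch/%s/libtorch-shared-with-deps-%s%%2B%s.zip\n' \
        "$tag" "$version" "$tag"
}
```

It could then be combined with the download step, for example `wget -O libtorch.zip "$(libtorch_linux_url cu130)"`.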

**🪟 Windows**

First, download the LibTorch CUDA 12.6 distribution.

```powershell
wget https://download.pytorch.org/libtorch/cu126/libtorch-win-shared-with-deps-2.9.0%2Bcu126.zip -OutFile libtorch.zip
Expand-Archive libtorch.zip
```

Then, set the `LIBTORCH` environment variable and append the library to your path, as shown in the PowerShell commands below, before building `burn-tch` or a crate which depends on it.

```powershell
$Env:LIBTORCH = "/absolute/path/to/libtorch/"
$Env:Path += ";/absolute/path/to/libtorch/"
```

#### Metal (MPS)

There is no official LibTorch distribution with MPS support at this time, so the easiest alternative is to use a PyTorch installation. This requires a Python installation.

_Note: MPS acceleration is available on MacOS 12.3+._

```shell
pip install torch==2.9.0 numpy==1.26.4 setuptools
export LIBTORCH_USE_PYTORCH=1
export DYLD_LIBRARY_PATH=/path/to/pytorch/lib:$DYLD_LIBRARY_PATH
```

**Note:** if `venv` is used, it should be activated during coding and building, or the compiler may not work properly.

## Example Usage

For a simple example, check out any of the test programs in [`src/bin/`](./src/bin/). Each program sets the device to use and performs a simple element-wise addition.

For a more complete example using the `tch` backend, take a look at the [Burn mnist example](https://github.com/tracel-ai/burn/tree/main/examples/mnist).

## Too many environment variables?

Try `.cargo/config.toml` ([cargo book](https://doc.rust-lang.org/cargo/reference/config.html#env)).

Instead of setting the environment variables in your shell, you can manually add them to your `.cargo/config.toml`:

```toml
[env]
LD_LIBRARY_PATH = "/absolute/path/to/libtorch/lib"
LIBTORCH = "/absolute/path/to/libtorch/libtorch"
```

Or use the bash commands below:

```bash
mkdir -p .cargo
cat <<EOF > .cargo/config.toml
[env]
LD_LIBRARY_PATH = "/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH"
LIBTORCH = "/absolute/path/to/libtorch/libtorch"
EOF
```

This will automatically include the old `LD_LIBRARY_PATH` value in the new one.
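
The same idea can be sketched as a reusable function that takes the project folder and the LibTorch folder as arguments. `write_libtorch_env` is a hypothetical helper, not something shipped with `burn-tch`. It assumes that `lib/` sits directly inside the LibTorch folder you pass in, and it overwrites any existing `.cargo/config.toml`.

```shell
# Hypothetical helper: writes an [env] table to <project>/.cargo/config.toml.
# Usage: write_libtorch_env <project-dir> <libtorch-dir>
# Note: this overwrites an existing config.toml in that folder.
write_libtorch_env() {
    project="$1"
    libtorch="$2"
    mkdir -p "$project/.cargo"
    # The heredoc delimiter is unquoted, so the shell expands the variables
    # now: the file stores the LD_LIBRARY_PATH value from the time of writing.
    # If LD_LIBRARY_PATH is empty, the stored value ends with a colon.
    cat <<EOF > "$project/.cargo/config.toml"
[env]
LD_LIBRARY_PATH = "$libtorch/lib:${LD_LIBRARY_PATH:-}"
LIBTORCH = "$libtorch"
EOF
}
```

Because the expansion happens when the file is written, later changes to `LD_LIBRARY_PATH` in your shell are not picked up until the function is run again.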