feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions


@@ -0,0 +1,38 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science"]
description = "LibTorch backend for the Burn framework using the tch bindings."
documentation = "https://docs.rs/burn-tch"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "data"]
license.workspace = true
name = "burn-tch"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-tch"
version.workspace = true
[lints]
workspace = true
[features]
default = ["std"]
std = ["burn-backend/std"]
doc = ["tch/doc-only"]
tracing = [
"burn-backend/tracing",
]
[dependencies]
burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false }
libc = { workspace = true }
log = { workspace = true }
tch = { workspace = true, features = ["download-libtorch"] }
torch-sys = { workspace = true } # for build script lib dir detection
[build-dependencies]
cc = "1.2.56"
[package.metadata.docs.rs]
features = ["doc"]
rustdoc-args = ["--cfg", "docsrs"]


@@ -0,0 +1 @@
../../LICENSE-APACHE


@@ -0,0 +1 @@
../../LICENSE-MIT


@@ -0,0 +1,246 @@
# Burn Torch Backend
[Burn](https://github.com/tracel-ai/burn) Torch backend
[![Current Crates.io Version](https://img.shields.io/crates/v/burn-tch.svg)](https://crates.io/crates/burn-tch)
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-tch/blob/master/README.md)
This crate provides a Torch backend for [Burn](https://github.com/tracel-ai/burn) utilizing the
[`tch-rs`](https://github.com/LaurentMazare/tch-rs) crate, which offers a Rust interface to the
[PyTorch](https://pytorch.org/) C++ API.
The backend supports CPU (multithreaded), [CUDA](https://pytorch.org/docs/stable/notes/cuda.html)
(multiple GPUs), and [MPS](https://pytorch.org/docs/stable/notes/mps.html) devices (MacOS).
## Installation
[`tch-rs`](https://github.com/LaurentMazare/tch-rs) requires the C++ PyTorch library (LibTorch) to
be available on your system.
By default, the CPU distribution is installed for LibTorch v2.9.0 as required by `tch-rs`.
<details>
<summary><strong>CUDA</strong></summary>
To install the latest compatible CUDA distribution, set the `TORCH_CUDA_VERSION` environment
variable before the `tch-rs` dependency is retrieved with `cargo`.
```shell
export TORCH_CUDA_VERSION=cu128
```
On Windows:
```powershell
$Env:TORCH_CUDA_VERSION = "cu128"
```
> Note: `tch` doesn't expose the downloaded libtorch directory on Windows when using the automatic
> download feature, so the `torch_cuda.dll` cannot be detected properly during build. In this case,
> you can set the `LIBTORCH` environment variable to point to the `libtorch/` folder in `torch-sys`
> `OUT_DIR` (or move the downloaded lib to a different folder and point to it).
For example, running the validation sample for the first time could be done with the following
commands:
```shell
export TORCH_CUDA_VERSION=cu128
cargo run --bin cuda --release
```
**Important:** make sure your driver version is compatible with the selected CUDA version. A CUDA
Toolkit installation is not required since LibTorch ships with the appropriate CUDA runtimes. Having
the latest driver version is recommended, but you can always take a look at the
[toolkit driver version table](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#id4)
or
[minimum required driver version](https://docs.nvidia.com/deploy/cuda-compatibility/index.html#minor-version-compatibility)
(limited feature-set, might not work with all operations).
</details><br>
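The Windows note in the CUDA section above comes down to a file-existence check: the build script in this commit looks for the GPU libraries inside the LibTorch `lib` directory. A minimal sketch of that check, assuming only the file names used by the build script; the `GpuLibs` enum and `detect_gpu_libs` function are hypothetical names for illustration, and the real script tracks CUDA and HIP as two independent flags rather than a single enum.

```rust
use std::path::Path;

/// Hypothetical summary of what was found in the LibTorch `lib` directory.
#[derive(Debug, PartialEq, Eq)]
enum GpuLibs {
    Cuda,
    Hip,
    CpuOnly,
}

/// Hypothetical helper: checks for the same library file names as the build script.
/// CUDA is reported first when both flavours are present (a simplification).
fn detect_gpu_libs(lib_dir: &Path) -> GpuLibs {
    let has = |name: &str| lib_dir.join(name).exists();
    if has("libtorch_cuda.so") || has("torch_cuda.dll") {
        GpuLibs::Cuda
    } else if has("libtorch_hip.so") || has("torch_hip.dll") {
        GpuLibs::Hip
    } else {
        GpuLibs::CpuOnly
    }
}
```

If the directory handed to the check is not the one that actually holds `torch_cuda.dll`, the result is `CpuOnly`, which is why pointing `LIBTORCH` at the right folder matters on Windows.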
Once your installation is complete, you should be able to build/run your project. You can also
validate your installation by running the appropriate `cpu`, `cuda` or `mps` sample as below.
```shell
cargo run --bin cpu --release
cargo run --bin cuda --release
cargo run --bin mps --release
```
_Note: no MPS distribution is available for automatic download at this time. Please check out the
[manual instructions](#metal-mps)._
### Manual Download
To install `tch-rs` with a different LibTorch distribution, you will have to manually download the
desired LibTorch distribution. The instructions are detailed in the sections below for each
platform.
| Compute Platform | CPU | GPU | Linux | MacOS | Windows | Android | iOS | WASM |
| :------------------------ | :----------------------------: | :-: | :---: | :---: | :-----: | :-----: | :-: | :--: |
| [CPU](#cpu) | Yes | No | Yes | Yes | Yes | Yes | Yes | No |
| [CUDA](#cuda) | Yes <sup>[[1]](#cpu-sup)</sup> | Yes | Yes | No | Yes | No | No | No |
| [Metal (MPS)](#metal-mps) | No | Yes | No | Yes | No | No | No | No |
| Vulkan | Yes | Yes | Yes | Yes | Yes | Yes | No | No |
<sup><a id="cpu-sup">[1]</a> The LibTorch CUDA distribution also comes with CPU support.</sup>
#### CPU
<details open>
<summary><strong>🐧 Linux</strong></summary>
First, download the LibTorch CPU distribution.
```shell
wget -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-2.9.0%2Bcpu.zip
unzip libtorch.zip
```
Then, point to that installation using the `LIBTORCH` and `LD_LIBRARY_PATH` environment variables
before building `burn-tch` or a crate which depends on it.
```shell
export LIBTORCH=/absolute/path/to/libtorch/
export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH
```
</details><br>
<details>
<summary><strong>🍎 Mac</strong></summary>
First, download the LibTorch CPU distribution.
```shell
wget -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-2.9.0.zip
unzip libtorch.zip
```
Then, point to that installation using the `LIBTORCH` and `DYLD_LIBRARY_PATH` environment variables
before building `burn-tch` or a crate which depends on it.
```shell
export LIBTORCH=/absolute/path/to/libtorch/
export DYLD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$DYLD_LIBRARY_PATH
```
</details><br>
<details>
<summary><strong>🪟 Windows</strong></summary>
First, download the LibTorch CPU distribution.
```powershell
wget https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-2.9.0%2Bcpu.zip -OutFile libtorch.zip
Expand-Archive libtorch.zip
```
Then, set the `LIBTORCH` environment variable and append the library to your path with the
PowerShell commands below before building `burn-tch` or a crate which depends on it.
```powershell
$Env:LIBTORCH = "/absolute/path/to/libtorch/"
$Env:Path += ";/absolute/path/to/libtorch/"
```
</details><br>
#### CUDA
LibTorch 2.9.0 currently includes binary distributions with CUDA 12.6, 12.8 or 13.0 runtimes. The
manual installation instructions are detailed below for CUDA 12.6, but can be applied to the other
CUDA versions by replacing `cu126` with the corresponding version string (e.g., `cu130`).
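The version-string substitution can be sketched as a small URL builder. This is only an illustration of the pattern visible in the Linux `wget` commands of this README; the function name `libtorch_linux_url` is hypothetical, and whether a given version and variant combination is actually published is an assumption you should verify on the PyTorch download site.

```rust
/// Hypothetical helper: builds a Linux LibTorch archive URL following the
/// pattern of the `wget` commands in this README.
fn libtorch_linux_url(version: &str, variant: &str) -> String {
    // `%2B` is the URL-encoded `+` that separates the version from the variant.
    format!(
        "https://download.pytorch.org/libtorch/{variant}/libtorch-shared-with-deps-{version}%2B{variant}.zip",
        version = version,
        variant = variant
    )
}
```

Calling `libtorch_linux_url("2.9.0", "cu126")` reproduces the CUDA 12.6 URL used in the Linux instructions, and passing `"cpu"` as the variant reproduces the CPU one.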
<details open>
<summary><strong>🐧 Linux</strong></summary>
First, download the LibTorch CUDA 12.6 distribution.
```shell
wget -O libtorch.zip https://download.pytorch.org/libtorch/cu126/libtorch-shared-with-deps-2.9.0%2Bcu126.zip
unzip libtorch.zip
```
Then, point to that installation using the `LIBTORCH` and `LD_LIBRARY_PATH` environment variables
before building `burn-tch` or a crate which depends on it.
```shell
export LIBTORCH=/absolute/path/to/libtorch/
export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH
```
**Note:** make sure your CUDA installation is in your `PATH` and `LD_LIBRARY_PATH`.
</details><br>
<details>
<summary><strong>🪟 Windows</strong></summary>
First, download the LibTorch CUDA 12.6 distribution.
```powershell
wget https://download.pytorch.org/libtorch/cu126/libtorch-win-shared-with-deps-2.9.0%2Bcu126.zip -OutFile libtorch.zip
Expand-Archive libtorch.zip
```
Then, set the `LIBTORCH` environment variable and append the library to your path with the
PowerShell commands below before building `burn-tch` or a crate which depends on it.
```powershell
$Env:LIBTORCH = "/absolute/path/to/libtorch/"
$Env:Path += ";/absolute/path/to/libtorch/"
```
</details><br>
#### Metal (MPS)
There is no official LibTorch distribution with MPS support at this time, so the easiest alternative
is to use a PyTorch installation. This requires a Python installation.
_Note: MPS acceleration is available on MacOS 12.3+._
```shell
pip install torch==2.9.0 numpy==1.26.4 setuptools
export LIBTORCH_USE_PYTORCH=1
export DYLD_LIBRARY_PATH=/path/to/pytorch/lib:$DYLD_LIBRARY_PATH
```
**Note:** if `venv` is used, it should be activated during coding and building, or the compiler may
not work properly.
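The `venv` requirement follows from how the build script in this commit picks a Python executable when `LIBTORCH_USE_PYTORCH` is set. A minimal sketch of that choice; the function name and its boolean parameters are hypothetical (the real script reads the target OS and the `VIRTUAL_ENV` variable from the environment).

```rust
/// Hypothetical helper mirroring the interpreter choice in the build script:
/// `python.exe` on Windows, `python` inside an active virtual environment,
/// and `python3` otherwise.
fn python_interpreter(is_windows: bool, in_virtualenv: bool) -> &'static str {
    if is_windows {
        "python.exe"
    } else if in_virtualenv {
        "python"
    } else {
        "python3"
    }
}
```

On Linux and macOS, an inactive `venv` therefore means the system `python3` is queried, which may not have `torch` installed.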
## Example Usage
For a simple example, check out any of the test programs in [`src/bin/`](./src/bin/). Each program
sets the device to use and performs a simple element-wise addition.
For a more complete example using the `tch` backend, take a look at the
[Burn mnist example](https://github.com/tracel-ai/burn/tree/main/examples/mnist).
## Too many environment variables?
Try `.cargo/config.toml` ([cargo book](https://doc.rust-lang.org/cargo/reference/config.html#env)).
Instead of setting the environment variables in your shell, you can add them manually to your
`.cargo/config.toml`:
```toml
[env]
LD_LIBRARY_PATH = "/absolute/path/to/libtorch/lib"
LIBTORCH = "/absolute/path/to/libtorch/libtorch"
```
Or use the bash commands below:
```bash
mkdir .cargo
cat <<EOF > .cargo/config.toml
[env]
LD_LIBRARY_PATH = "/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH"
LIBTORCH = "/absolute/path/to/libtorch/libtorch"
EOF
```
This will automatically include the old `LD_LIBRARY_PATH` value in the new one.
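Whichever way the variables are set, the build script in this commit resolves the LibTorch directory in a fixed order. The sketch below follows that order under stated assumptions: the function name and the two injected closures are hypothetical (they stand in for `std::env::var` and `Path::exists` so the logic can be exercised without touching the real environment), and the real script only checks the system location on Linux.

```rust
use std::path::{Path, PathBuf};

/// Hypothetical helper reproducing the lookup order of the build script:
/// 1. `DEP_TCH_LIBTORCH_LIB` (exported by the `tch` dependency),
/// 2. `LIBTORCH`,
/// 3. a system-wide install detected through `/usr/lib/libtorch.so`,
/// 4. a `libtorch` folder inside `OUT_DIR`, if it exists.
fn resolve_libtorch_dir(
    get_env: impl Fn(&str) -> Option<String>,
    path_exists: impl Fn(&Path) -> bool,
) -> Option<PathBuf> {
    if let Some(dir) = get_env("DEP_TCH_LIBTORCH_LIB") {
        return Some(PathBuf::from(dir));
    }
    if let Some(dir) = get_env("LIBTORCH") {
        return Some(PathBuf::from(dir));
    }
    if path_exists(Path::new("/usr/lib/libtorch.so")) {
        return Some(PathBuf::from("/usr"));
    }
    // Last resort: a previously downloaded copy inside OUT_DIR.
    let candidate = PathBuf::from(get_env("OUT_DIR")?).join("libtorch");
    if path_exists(&candidate) {
        Some(candidate)
    } else {
        None
    }
}
```

When nothing is found the function returns `None`, which corresponds to the case where the build script falls back to the CPU-only stub.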


@@ -0,0 +1,243 @@
// The LIBTORCH environment variable can be used to specify the directory
// where libtorch has been installed.
// When not specified this script downloads the cpu version for libtorch
// and extracts it in OUT_DIR.
//
// On Linux, the TORCH_CUDA_VERSION environment variable can be used,
// like 9.0, 90, or cu90 to specify the version of CUDA to use for libtorch.
use std::path::{Path, PathBuf};
use std::{env, fs};
const PYTHON_PRINT_PYTORCH_DETAILS: &str = r"
import torch
from torch.utils import cpp_extension
print('LIBTORCH_VERSION:', torch.__version__.split('+')[0])
print('LIBTORCH_CXX11:', torch._C._GLIBCXX_USE_CXX11_ABI)
for include_path in cpp_extension.include_paths():
  print('LIBTORCH_INCLUDE:', include_path)
for library_path in cpp_extension.library_paths():
  print('LIBTORCH_LIB:', library_path)
";
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Os {
Linux,
Macos,
Windows,
}
#[allow(dead_code)]
#[derive(Debug, Clone)]
struct SystemInfo {
os: Os,
cxx11_abi: String,
libtorch_include_dirs: Vec<PathBuf>,
libtorch_lib_dir: PathBuf,
}
fn env_var_rerun(name: &str) -> Result<String, env::VarError> {
println!("cargo:rerun-if-env-changed={name}");
env::var(name)
}
impl SystemInfo {
fn new() -> Option<Self> {
let os = match env::var("CARGO_CFG_TARGET_OS")
.expect("Unable to get TARGET_OS")
.as_str()
{
"linux" => Os::Linux,
"windows" => Os::Windows,
"macos" => Os::Macos,
os => panic!("unsupported TARGET_OS '{os}'"),
};
// Locate the currently active Python binary, similar to:
// https://github.com/PyO3/maturin/blob/243b8ec91d07113f97a6fe74d9b2dcb88086e0eb/src/target.rs#L547
let python_interpreter = match os {
Os::Windows => PathBuf::from("python.exe"),
Os::Linux | Os::Macos => {
if env::var_os("VIRTUAL_ENV").is_some() {
PathBuf::from("python")
} else {
PathBuf::from("python3")
}
}
};
let mut libtorch_include_dirs = vec![];
let mut libtorch_lib_dir = None;
let cxx11_abi = if env_var_rerun("LIBTORCH_USE_PYTORCH").is_ok() {
let output = std::process::Command::new(&python_interpreter)
.arg("-c")
.arg(PYTHON_PRINT_PYTORCH_DETAILS)
.output()
.expect("error running python interpreter");
let mut cxx11_abi = None;
for line in String::from_utf8_lossy(&output.stdout).lines() {
match line.strip_prefix("LIBTORCH_CXX11: ") {
Some("True") => cxx11_abi = Some("1".to_owned()),
Some("False") => cxx11_abi = Some("0".to_owned()),
_ => {}
}
if let Some(path) = line.strip_prefix("LIBTORCH_INCLUDE: ") {
libtorch_include_dirs.push(PathBuf::from(path))
}
if let Some(path) = line.strip_prefix("LIBTORCH_LIB: ") {
libtorch_lib_dir = Some(PathBuf::from(path))
}
}
match cxx11_abi {
Some(cxx11_abi) => cxx11_abi,
None => panic!("no cxx11 abi returned by python {output:?}"),
}
} else {
let libtorch = Self::prepare_libtorch_dir(os)?;
let includes = env_var_rerun("LIBTORCH_INCLUDE")
.map(PathBuf::from)
.unwrap_or_else(|_| libtorch.clone());
let lib = env_var_rerun("LIBTORCH_LIB")
.map(PathBuf::from)
.unwrap_or_else(|_| libtorch.clone());
libtorch_include_dirs.push(includes.join("include"));
libtorch_include_dirs.push(includes.join("include/torch/csrc/api/include"));
if lib.ends_with("lib") {
// DEP_TCH_LIBTORCH_LIB might already point to /lib
libtorch_lib_dir = Some(lib);
} else {
libtorch_lib_dir = Some(lib.join("lib"));
}
env_var_rerun("LIBTORCH_CXX11_ABI").unwrap_or_else(|_| "1".to_owned())
};
let libtorch_lib_dir = libtorch_lib_dir?;
Some(Self {
os,
cxx11_abi,
libtorch_include_dirs,
libtorch_lib_dir,
})
}
fn check_system_location(os: Os) -> Option<PathBuf> {
match os {
Os::Linux => Path::new("/usr/lib/libtorch.so")
.exists()
.then(|| PathBuf::from("/usr")),
_ => None,
}
}
fn prepare_libtorch_dir(os: Os) -> Option<PathBuf> {
if let Ok(libtorch) = env_var_rerun("DEP_TCH_LIBTORCH_LIB") {
Some(PathBuf::from(libtorch))
} else if let Ok(libtorch) = env_var_rerun("LIBTORCH") {
Some(PathBuf::from(libtorch))
} else if let Some(pathbuf) = Self::check_system_location(os) {
Some(pathbuf)
} else {
check_out_dir()
}
}
fn make(&self, use_cuda: bool, use_hip: bool) {
let cuda_dependency = if use_cuda || use_hip {
"src/cuda_hack/dummy_cuda_dependency.cpp"
} else {
"src/cuda_hack/fake_cuda_dependency.cpp"
};
println!("cargo:rerun-if-changed={cuda_dependency}");
match self.os {
Os::Linux | Os::Macos => {
cc::Build::new()
.cpp(true)
.pic(true)
.warnings(false)
.includes(&self.libtorch_include_dirs)
.flag(format!("-Wl,-rpath={}", self.libtorch_lib_dir.display()))
.flag("-std=c++17")
.flag(format!("-D_GLIBCXX_USE_CXX11_ABI={}", self.cxx11_abi))
.files(&[cuda_dependency])
.compile("burn-tch");
}
Os::Windows => {
cc::Build::new()
.cpp(true)
.pic(true)
.warnings(false)
.includes(&self.libtorch_include_dirs)
.flag("/std:c++17")
.files(&[cuda_dependency])
.compile("burn-tch");
}
};
}
fn make_cpu() {
let cuda_dependency = "src/cuda_hack/fake_cuda_dependency.cpp";
println!("cargo:rerun-if-changed={cuda_dependency}");
let os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS");
match os.as_str() {
"windows" => {
cc::Build::new()
.cpp(true)
.pic(true)
.warnings(false)
.flag("/std:c++17")
.files(&[cuda_dependency])
.compile("burn-tch");
}
_ => {
cc::Build::new()
.cpp(true)
.pic(true)
.warnings(false)
.flag("-std=c++17")
.files(&[cuda_dependency])
.compile("tch");
}
};
}
}
fn check_out_dir() -> Option<PathBuf> {
let out_dir = env_var_rerun("OUT_DIR").ok()?;
let libtorch_dir = PathBuf::from(out_dir).join("libtorch");
libtorch_dir.exists().then_some(libtorch_dir)
}
fn main() {
let system_info = SystemInfo::new();
let out_dir = env_var_rerun("OUT_DIR").expect("Failed to get out dir");
let mut gpu_found = false;
let found_dir = system_info.is_some();
if let Some(system_info) = &system_info {
let si_lib = &system_info.libtorch_lib_dir;
let use_cuda =
si_lib.join("libtorch_cuda.so").exists() || si_lib.join("torch_cuda.dll").exists();
let use_hip =
si_lib.join("libtorch_hip.so").exists() || si_lib.join("torch_hip.dll").exists();
system_info.make(use_cuda, use_hip);
gpu_found = use_cuda || use_hip;
} else {
SystemInfo::make_cpu();
}
let check_file = PathBuf::from(out_dir).join("tch_gpu_check.rs");
if gpu_found {
fs::write(check_file, "#[allow(clippy::no_effect)]\n()").unwrap();
} else {
let message = if !found_dir {
r#"Could not find libtorch dir.
If you are trying to use the automatically downloaded version, the path is not directly available on Windows. Instead, try setting the `LIBTORCH` environment variable for the manual download instructions.
If the library has already been downloaded in the torch-sys OUT_DIR, you can point the variable to this path (or move the downloaded lib and point to it)."#
} else {
"No libtorch_cuda or libtorch_hip found. Download the GPU version of libtorch to use a GPU device"
};
fs::write(check_file, format!("panic!(\"{message}\")")).unwrap();
}
}


@@ -0,0 +1,175 @@
use std::marker::PhantomData;
use crate::IntoKind;
use super::TchTensor;
use super::element::TchElement;
use burn_backend::backend::{Backend, DeviceId, DeviceOps, ExecutionError};
use burn_backend::ops::IntTensorOps;
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
/// The device struct when using the `tch` backend.
///
/// Note that you need to provide the device index when using Cuda.
///
/// # Example
///
/// ```no_run
/// use burn_tch::LibTorchDevice;
///
/// let device_gpu_1 = LibTorchDevice::Cuda(0); // First GPU
/// let device_gpu_2 = LibTorchDevice::Cuda(1); // Second GPU
/// let device_cpu = LibTorchDevice::Cpu; // CPU
/// let device_mps = LibTorchDevice::Mps; // Metal Performance Shaders
/// let device_vulkan = LibTorchDevice::Vulkan; // Vulkan
/// ```
#[derive(Default)]
pub enum LibTorchDevice {
/// CPU device.
#[default]
Cpu,
/// Cuda device with the given index. The index is the index of the Cuda device in the list of
/// all Cuda devices found on the system.
Cuda(usize),
/// Metal Performance Shaders device.
Mps,
/// Vulkan device.
Vulkan,
}
impl From<LibTorchDevice> for tch::Device {
#[allow(
unreachable_code,
reason = "CUDA branch always panics if the library is missing"
)]
fn from(device: LibTorchDevice) -> Self {
match device {
LibTorchDevice::Cpu => tch::Device::Cpu,
LibTorchDevice::Cuda(_num) => {
include!(concat!(env!("OUT_DIR"), "/tch_gpu_check.rs"));
tch::Device::Cuda(_num)
}
LibTorchDevice::Mps => tch::Device::Mps,
LibTorchDevice::Vulkan => tch::Device::Vulkan,
}
}
}
impl From<tch::Device> for LibTorchDevice {
fn from(device: tch::Device) -> Self {
match device {
tch::Device::Cpu => LibTorchDevice::Cpu,
tch::Device::Cuda(num) => LibTorchDevice::Cuda(num),
tch::Device::Mps => LibTorchDevice::Mps,
tch::Device::Vulkan => LibTorchDevice::Vulkan,
}
}
}
impl burn_backend::Device for LibTorchDevice {
fn from_id(device_id: DeviceId) -> Self {
match device_id.type_id {
0 => Self::Cuda(device_id.index_id as usize),
1 => Self::Mps,
2 => Self::Cpu,
3 => Self::Vulkan,
_ => LibTorchDevice::Cpu,
}
}
fn to_id(&self) -> DeviceId {
match self {
LibTorchDevice::Cuda(index) => DeviceId::new(0, *index as u32),
LibTorchDevice::Mps => DeviceId::new(1, 0),
LibTorchDevice::Cpu => DeviceId::new(2, 0),
LibTorchDevice::Vulkan => DeviceId::new(3, 0),
}
}
fn device_count(_type_id: u16) -> usize {
// TODO: Somehow find the info using the tch API.
1
}
}
impl DeviceOps for LibTorchDevice {}
/// Tensor backend that uses `LibTorch` with the [tch] crate for executing tensor operations.
///
/// This backend is compatible with a wide range of hardware, from CPUs to GPUs, but
/// requires `LibTorch` to be installed correctly. The CPU version can be downloaded
/// automatically and the CUDA version as well by setting the `TORCH_CUDA_VERSION` environment
/// variable. For more complex configurations, check out the manual installation for
/// [burn-tch](https://github.com/tracel-ai/burn/tree/main/crates/burn-tch).
///
/// Refer to the [tch] crate for more information.
#[derive(Clone, Copy, Default, Debug)]
pub struct LibTorch<E = f32> {
_e: PhantomData<E>,
}
impl<E: TchElement> Backend for LibTorch<E> {
type Device = LibTorchDevice;
type FloatTensorPrimitive = TchTensor;
type FloatElem = E;
type IntTensorPrimitive = TchTensor;
type IntElem = i64;
type BoolTensorPrimitive = TchTensor;
type BoolElem = bool;
type QuantizedTensorPrimitive = TchTensor;
fn seed(_device: &Self::Device, seed: u64) {
tch::manual_seed(seed as i64);
}
fn ad_enabled(_device: &Self::Device) -> bool {
false
}
fn name(device: &Self::Device) -> String {
match device {
LibTorchDevice::Cpu => "libtorch<cpu>",
LibTorchDevice::Cuda(_) => "libtorch<cuda>",
LibTorchDevice::Mps => "libtorch<metal>",
LibTorchDevice::Vulkan => "libtorch<vulkan>",
}
.to_string()
}
fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
match device {
LibTorchDevice::Cpu => (),
LibTorchDevice::Cuda(index) => {
tch::Cuda::synchronize(*index as i64);
}
_ => {
// When there is no explicit way to synchronize, we write and read one value to sync
burn_backend::read_sync(Self::int_into_data(Self::int_zeros(
[1].into(),
device,
E::dtype().into(),
)))
.unwrap();
}
};
Ok(())
}
fn dtype_usage(
_device: &Self::Device,
dtype: burn_backend::DType,
) -> burn_backend::DTypeUsageSet {
if dtype.try_into_kind().is_ok() {
burn_backend::DTypeUsage::general()
} else {
burn_backend::DTypeUsageSet::empty()
}
}
}


@@ -0,0 +1,14 @@
use burn_backend::{TensorMetadata, ops::FloatTensorOps};
use burn_tch::{LibTorch, LibTorchDevice};
fn main() {
type B = LibTorch<f32>;
let device = LibTorchDevice::Cpu;
// Creation of two tensors, the first with explicit values and the second one with ones, with the same shape as the first
let tensor_1 = B::float_from_data([[2f32, 3.], [4., 5.]].into(), &device);
let tensor_2 = B::float_ones(tensor_1.shape(), &device, tensor_1.dtype().into());
// Print the element-wise addition of the two tensors.
println!("{}", B::float_add(tensor_1, tensor_2));
}


@@ -0,0 +1,19 @@
use burn_backend::{TensorMetadata, ops::FloatTensorOps};
use burn_tch::{LibTorch, LibTorchDevice};
fn main() {
assert!(
tch::utils::has_cuda(),
"Could not detect valid CUDA configuration"
);
type B = LibTorch<f32>;
let device = LibTorchDevice::Cuda(0);
// Creation of two tensors, the first with explicit values and the second one with ones, with the same shape as the first
let tensor_1 = B::float_from_data([[2f32, 3.], [4., 5.]].into(), &device);
let tensor_2 = B::float_ones(tensor_1.shape(), &device, tensor_1.dtype().into());
// Print the element-wise addition of the two tensors.
println!("{}", B::float_add(tensor_1, tensor_2));
}


@@ -0,0 +1,16 @@
use burn_backend::{TensorMetadata, ops::FloatTensorOps};
use burn_tch::{LibTorch, LibTorchDevice};
fn main() {
assert!(tch::utils::has_mps(), "Could not detect MPS");
type B = LibTorch<f32>;
let device = LibTorchDevice::Mps;
// Creation of two tensors, the first with explicit values and the second one with ones, with the same shape as the first
let tensor_1 = B::float_from_data([[2f32, 3.], [4., 5.]].into(), &device);
let tensor_2 = B::float_ones(tensor_1.shape(), &device, tensor_1.dtype().into());
// Print the element-wise addition of the two tensors.
println!("{}", B::float_add(tensor_1, tensor_2));
}


@@ -0,0 +1,28 @@
#include <iostream>
#include <stdexcept>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h> // getenv
using namespace std;
extern "C" {
void dummy_cuda_dependency();
}
struct cublasContext;
namespace at {
namespace cuda {
cublasContext *getCurrentCUDABlasHandle();
int warp_size();
} // namespace cuda
} // namespace at
char *magma_strerror(int err);
void dummy_cuda_dependency() {
try {
at::cuda::getCurrentCUDABlasHandle();
at::cuda::warp_size();
} catch (std::exception &e) {
if (getenv("TCH_PRINT_CUDA_INIT_ERROR") != nullptr) {
std::cerr << "error initializing cuda: " << e.what() << std::endl;
}
}
}


@@ -0,0 +1,5 @@
extern "C" {
void dummy_cuda_dependency();
}
void dummy_cuda_dependency() {}


@@ -0,0 +1,51 @@
use burn_backend::Element;
use burn_backend::{bf16, f16};
/// The element type for the tch backend.
pub trait TchElement: Element + tch::kind::Element {
/// Returns the associated tensor kind for [`tch::kind::Element`].
fn kind() -> tch::Kind {
Self::KIND
}
}
impl TchElement for f64 {}
impl TchElement for f32 {}
impl TchElement for f16 {}
impl TchElement for bf16 {
fn kind() -> tch::Kind {
let mut kind = <Self as tch::kind::Element>::KIND;
// Incorrect kind mapping in tch definitions, force bfloat16
if matches!(Self::dtype(), burn_backend::DType::BF16) && kind == tch::Kind::Half {
kind = tch::Kind::BFloat16
}
kind
}
}
impl TchElement for i64 {}
impl TchElement for i32 {}
impl TchElement for i16 {}
impl TchElement for i8 {}
impl TchElement for u8 {}
impl TchElement for bool {}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_elem_kinds() {
assert_eq!(f64::kind(), tch::Kind::Double);
assert_eq!(f32::kind(), tch::Kind::Float);
assert_eq!(f16::kind(), tch::Kind::Half);
assert_eq!(bf16::kind(), tch::Kind::BFloat16);
assert_eq!(i64::kind(), tch::Kind::Int64);
assert_eq!(i32::kind(), tch::Kind::Int);
assert_eq!(i16::kind(), tch::Kind::Int16);
assert_eq!(i8::kind(), tch::Kind::Int8);
assert_eq!(bool::kind(), tch::Kind::Bool);
}
}


@@ -0,0 +1,14 @@
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![allow(clippy::single_range_in_vec_init)]
//! Burn Tch Backend
mod backend;
mod element;
mod ops;
mod tensor;
pub use backend::*;
pub use element::*;
pub use tensor::*;


@@ -0,0 +1,37 @@
use crate::{LibTorch, TchTensor, element::TchElement};
use burn_backend::ops::ActivationOps;
impl<E: TchElement> ActivationOps<Self> for LibTorch<E> {
fn relu(tensor: TchTensor) -> TchTensor {
tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu())
}
fn gelu(tensor: TchTensor) -> TchTensor {
tensor.unary_ops(
|mut tensor| tensor.gelu_("none"),
|tensor| tensor.gelu("none"),
)
}
fn gelu_backward(tensor: TchTensor, grad: TchTensor) -> TchTensor {
let storage = tensor.storage.clone();
let tensor = tensor.tensor.gelu_backward(&grad.tensor, "none");
TchTensor::from_existing(tensor, storage)
}
fn sigmoid(tensor: TchTensor) -> TchTensor {
tensor.unary_ops(|mut tensor| tensor.sigmoid_(), |tensor| tensor.sigmoid())
}
fn log_sigmoid(tensor: TchTensor) -> TchTensor {
// NOTE: we don't override log_sigmoid_backward because Torch has a special backward
// formula that uses a buffer with computed values from the forward pass
// no in-place log_sigmoid_
let storage = tensor.storage.clone();
let tensor = tensor.tensor.log_sigmoid();
TchTensor::from_existing(tensor, storage)
}
}


@@ -0,0 +1,737 @@
use burn_backend::{Shape, TensorMetadata};
use tch::Scalar;
use crate::{LibTorchDevice, TchShape, TchTensor};
pub struct TchOps {
// e: PhantomData<E>,
}
impl TchOps {
pub fn to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
let device = (*device).into();
// We have to manually check if the device is the same, since when it's the case, we need to keep
// the same storage reference and not create a new one.
if tensor.tensor.device() == device {
return tensor;
}
TchTensor::new(tensor.tensor.to(device))
}
pub fn reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
let shape_tch: TchShape = shape.into();
TchTensor::from_existing(tensor.tensor.reshape(shape_tch.dims), tensor.storage)
}
pub fn repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
let mut dims = vec![1; tensor.shape().num_dims()];
dims[dim] = times as i64;
let tensor = tch::Tensor::repeat(&tensor.tensor, dims);
TchTensor::new(tensor)
}
pub fn slice_with_steps(tensor: TchTensor, slices: &[burn_backend::Slice]) -> TchTensor {
let storage = tensor.storage.clone();
let mut tensor = tensor.tensor.shallow_clone();
for (dim, slice) in slices.iter().enumerate() {
let dim_i64 = dim as i64;
// Resolve the slice to a concrete range using the current size of this dimension
let dim_size = tensor.size()[dim];
let range = slice.to_range(dim_size as usize);
let start = range.start as i64;
let end = range.end as i64;
let step = slice.step as i64;
if step > 0 {
// Forward stepping - use native slice
tensor = tensor.slice(dim_i64, Some(start), Some(end), step);
} else {
// Negative stepping - we need to handle the semantics correctly
// For negative steps, we iterate backwards from end-1
// PyTorch's negative step works differently than our semantics
// We need to reverse the selected range
// First get the slice with positive step
tensor = tensor.slice(dim_i64, Some(start), Some(end), 1);
// Then reverse it and apply the step
if step == -1 {
// Simple reversal
tensor = tensor.flip([dim_i64]);
} else {
// Reverse and then take every nth element
tensor = tensor.flip([dim_i64]);
let abs_step = step.abs();
tensor = tensor.slice(dim_i64, None, None, abs_step);
}
}
}
TchTensor::partial(tensor, storage)
}
pub fn slice_assign(
tensor: TchTensor,
slices: &[burn_backend::Slice],
value: TchTensor,
) -> TchTensor {
// PyTorch's narrow operation only supports contiguous slices (step=1)
// For non-unit steps, we use advanced indexing as a workaround
let all_unit_steps = slices.iter().all(|s| s.step == 1);
if all_unit_steps {
// Fast path: use narrow and copy_ for unit steps
let tch_shape = TchShape::from(tensor.shape());
// Copy the input tensor if we can't mutate it
let tensor_original: TchTensor =
tensor.unary_ops(|tensor| tensor, |tensor| tensor.copy());
let tensor_original = tensor_original.tensor;
let mut tensor = tensor_original.view_(tch_shape.dims);
for (i, slice) in slices.iter().enumerate() {
// Convert Slice to range for narrow operation
let dim_size = tensor.size()[i] as usize;
let range = slice.to_range(dim_size);
let start = range.start as i64;
let length = (range.end - range.start) as i64;
tensor = tensor.narrow(i as i64, start, length);
}
tensor.copy_(&value.tensor);
TchTensor::new(tensor_original)
} else {
// Workaround for non-unit steps: use PyTorch's index_put operation
// This generates explicit indices for the slice and uses advanced indexing
let tensor_shape = tensor.shape();
let dims = tensor_shape.clone();
// Copy the tensor since we'll modify it
let result_tensor = tensor.tensor.shallow_clone();
// Use advanced indexing to set the values
Self::slice_assign_with_advanced_indexing(result_tensor, slices, value.tensor, &dims)
}
}
/// Generate indices for a slice with potentially non-unit step.
/// For negative steps, generates indices in reverse order.
fn generate_slice_indices(slice: &burn_backend::Slice, dim_size: usize) -> Vec<i64> {
let step = slice.step;
let range = slice.to_range(dim_size);
let mut indices = Vec::new();
if step > 0 {
let mut idx = range.start as i64;
while idx < range.end as i64 {
indices.push(idx);
idx += step as i64;
}
} else if step < 0 {
// For negative steps, iterate backwards through the range
let mut idx = (range.end - 1) as i64;
while idx >= range.start as i64 {
indices.push(idx);
idx += step as i64; // step is negative, so this decreases
}
}
indices
}
/// Implementation using advanced indexing for non-unit steps.
/// Uses PyTorch's index_put operation to assign values at specific indices.
fn slice_assign_with_advanced_indexing(
mut tensor: tch::Tensor,
slices: &[burn_backend::Slice],
value: tch::Tensor,
dims: &[usize],
) -> TchTensor {
// Generate all index combinations for the sliced regions
let mut index_sets: Vec<Vec<i64>> = Vec::new();
for (i, slice) in slices.iter().enumerate() {
let dim_size = if i < dims.len() { dims[i] } else { 1 };
let indices = Self::generate_slice_indices(slice, dim_size);
index_sets.push(indices);
}
// For unsliced dimensions, include all indices
for &dim_size in dims.iter().skip(slices.len()) {
let indices: Vec<i64> = (0..dim_size as i64).collect();
index_sets.push(indices);
}
// Convert index sets to tensors for index_put
let mut final_indices = Vec::new();
let total_elements = index_sets.iter().map(|s| s.len()).product::<usize>();
// Build flattened index arrays for each dimension using cartesian product
// This creates the index tensors needed for PyTorch's index_put operation
for dim_idx in 0..index_sets.len() {
let mut dim_indices = Vec::with_capacity(total_elements);
let repeat = index_sets[dim_idx + 1..]
.iter()
.map(|s| s.len())
.product::<usize>()
.max(1);
let tile = index_sets[..dim_idx]
.iter()
.map(|s| s.len())
.product::<usize>()
.max(1);
for _ in 0..tile {
for &idx in &index_sets[dim_idx] {
for _ in 0..repeat {
dim_indices.push(idx);
}
}
}
let indices_tensor = tch::Tensor::from_slice(&dim_indices).to_device(tensor.device());
final_indices.push(indices_tensor);
}
// PyTorch's index_put handles assignment correctly for negative steps
// following NumPy semantics: values[i] goes to selected_indices[i]
let value_flat = value.view(-1);
// Use index_put to assign values - convert to Option<Tensor>
let final_indices_opt: Vec<Option<tch::Tensor>> =
final_indices.into_iter().map(Some).collect();
tensor = tensor.index_put(&final_indices_opt, &value_flat, false);
TchTensor::new(tensor)
}
pub fn gather(dim: usize, tensor: TchTensor, indices: TchTensor) -> TchTensor {
let storage = tensor.storage.clone();
let tensor = tensor.tensor.gather(dim as i64, &indices.tensor, false);
TchTensor::from_existing(tensor, storage)
}
/// Scatters `value` into `tensor` along `dim`, accumulating with `scatter_add`.
pub fn scatter(
dim: usize,
tensor: TchTensor,
indices: TchTensor,
value: TchTensor,
) -> TchTensor {
let storage = tensor.storage.clone();
let tensor = tensor
.tensor
.scatter_add(dim as i64, &indices.tensor, &value.tensor);
TchTensor::from_existing(tensor, storage)
}
pub fn index_select_dim(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
let storage = tensor.storage.clone();
let tensor = tensor.tensor.index_select(dim as i64, &indices.tensor);
TchTensor::from_existing(tensor, storage)
}
/// Despite the name, this accumulates: `value` is added at `indices` along `dim` via `index_add`.
pub fn select_assign(
tensor: TchTensor,
dim: usize,
indices: TchTensor,
value: TchTensor,
) -> TchTensor {
tensor.clone().unary_ops(
|mut tensor| tensor.index_add_(dim as i64, &indices.tensor, &value.tensor),
|tensor| tensor.index_add(dim as i64, &indices.tensor, &value.tensor),
)
}
pub fn cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
let tensors: Vec<tch::Tensor> = tensors
.into_iter()
.map(|t| t.tensor.shallow_clone())
.collect();
let tensor = tch::Tensor::cat(&tensors, dim as i64);
TchTensor::new(tensor)
}
pub fn equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchTensor::binary_ops_tensor(
lhs,
rhs,
|lhs, rhs| lhs.eq_tensor_(rhs).to_kind(tch::Kind::Bool),
|lhs, rhs| rhs.eq_tensor_(lhs).to_kind(tch::Kind::Bool),
|lhs, rhs| lhs.eq_tensor(rhs),
)
}
pub fn equal_elem<S: Into<tch::Scalar> + Clone>(lhs: TchTensor, rhs: S) -> TchTensor {
lhs.unary_ops(
|mut tensor| tensor.eq_(rhs.clone().into()).to_kind(tch::Kind::Bool),
|tensor| tensor.eq(rhs.clone().into()),
)
}
pub fn greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchTensor::binary_ops_tensor(
lhs,
rhs,
|lhs, rhs| lhs.greater_tensor_(rhs).to_kind(tch::Kind::Bool),
|lhs, rhs| rhs.less_tensor_(lhs).to_kind(tch::Kind::Bool),
|lhs, rhs| lhs.greater_tensor(rhs),
)
}
pub fn greater_elem<S: Into<tch::Scalar> + Clone>(lhs: TchTensor, rhs: S) -> TchTensor {
lhs.unary_ops(
|mut tensor| tensor.greater_(rhs.clone().into()).to_kind(tch::Kind::Bool),
|tensor| tensor.greater(rhs.clone().into()),
)
}
pub fn greater_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchTensor::binary_ops_tensor(
lhs,
rhs,
|lhs, rhs| lhs.greater_equal_tensor_(rhs).to_kind(tch::Kind::Bool),
|lhs, rhs| rhs.less_equal_tensor_(lhs).to_kind(tch::Kind::Bool),
|lhs, rhs| lhs.greater_equal_tensor(rhs),
)
}
pub fn greater_equal_elem<S: Into<Scalar> + Clone>(lhs: TchTensor, rhs: S) -> TchTensor {
lhs.unary_ops(
|mut tensor| {
tensor
.greater_equal_(rhs.clone().into())
.to_kind(tch::Kind::Bool)
},
|tensor| tensor.greater_equal(rhs.clone().into()),
)
}
pub fn lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchTensor::binary_ops_tensor(
lhs,
rhs,
|lhs, rhs| lhs.less_tensor_(rhs).to_kind(tch::Kind::Bool),
|lhs, rhs| rhs.greater_tensor_(lhs).to_kind(tch::Kind::Bool),
|lhs, rhs| lhs.less_tensor(rhs),
)
}
pub fn lower_elem<S: Into<Scalar> + Clone>(lhs: TchTensor, rhs: S) -> TchTensor {
lhs.unary_ops(
|mut tensor| tensor.less_(rhs.clone().into()).to_kind(tch::Kind::Bool),
|tensor| tensor.less(rhs.clone().into()),
)
}
pub fn lower_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchTensor::binary_ops_tensor(
lhs,
rhs,
|lhs, rhs| lhs.less_equal_tensor_(rhs).to_kind(tch::Kind::Bool),
|lhs, rhs| rhs.greater_equal_tensor_(lhs).to_kind(tch::Kind::Bool),
|lhs, rhs| lhs.less_equal_tensor(rhs),
)
}
pub fn lower_equal_elem<S: Into<Scalar> + Clone>(lhs: TchTensor, rhs: S) -> TchTensor {
lhs.unary_ops(
|mut tensor| {
tensor
.less_equal_(rhs.clone().into())
.to_kind(tch::Kind::Bool)
},
|tensor| tensor.less_equal(rhs.clone().into()),
)
}
pub fn add(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchTensor::binary_ops_tensor(
lhs,
rhs,
|lhs, rhs| lhs.f_add_(rhs).unwrap(),
|lhs, rhs| rhs.f_add_(lhs).unwrap(),
|lhs, rhs| lhs.f_add(rhs).unwrap(),
)
}
pub fn sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchTensor::binary_ops_tensor(
lhs,
rhs,
|lhs, rhs| lhs.f_sub_(rhs).unwrap(),
|lhs, rhs| lhs.f_sub(rhs).unwrap(),
|lhs, rhs| lhs.f_sub(rhs).unwrap(),
)
}
pub fn mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchTensor::binary_ops_tensor(
lhs,
rhs,
|lhs, rhs| lhs.f_mul_(rhs).unwrap(),
|lhs, rhs| rhs.f_mul_(lhs).unwrap(),
|lhs, rhs| lhs.f_mul(rhs).unwrap(),
)
}
pub fn div(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchTensor::binary_ops_tensor(
lhs,
rhs,
|lhs, rhs| lhs.f_div_(rhs).unwrap(),
|lhs, rhs| lhs.f_div(rhs).unwrap(),
|lhs, rhs| lhs.f_div(rhs).unwrap(),
)
}
pub fn remainder(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchTensor::binary_ops_tensor(
lhs,
rhs,
|lhs, rhs| lhs.f_remainder_tensor_(rhs).unwrap(),
|lhs, rhs| lhs.f_remainder_tensor(rhs).unwrap(),
|lhs, rhs| lhs.f_remainder_tensor(rhs).unwrap(),
)
}
pub fn mean(tensor: TchTensor) -> TchTensor {
// view as 1d tensor
let tensor = tensor.tensor.mean(tensor.tensor.kind()).view(1);
TchTensor::new(tensor)
}
pub fn mean_dim(tensor: TchTensor, dim: usize) -> TchTensor {
TchTensor::from_existing(
tensor
.tensor
.mean_dim(Some([dim as i64].as_slice()), true, tensor.tensor.kind()),
tensor.storage,
)
}
pub fn sum(tensor: TchTensor) -> TchTensor {
// view as 1d tensor
let tensor = tensor.tensor.sum(tensor.tensor.kind()).view(1);
TchTensor::new(tensor)
}
pub fn sum_dim(tensor: TchTensor, dim: usize) -> TchTensor {
TchTensor::from_existing(
tensor.tensor.sum_dim_intlist(
Some([dim as i64].as_slice()),
true,
tensor.tensor.kind(),
),
tensor.storage,
)
}
pub fn prod(tensor: TchTensor) -> TchTensor {
// view as 1d tensor
let tensor = tensor.tensor.prod(tensor.tensor.kind()).view(1);
TchTensor::new(tensor)
}
pub fn prod_dim(tensor: TchTensor, dim: usize) -> TchTensor {
TchTensor::from_existing(
tensor
.tensor
.prod_dim_int(dim as i64, true, tensor.tensor.kind()),
tensor.storage,
)
}
pub fn cumsum(tensor: TchTensor, dim: usize) -> TchTensor {
TchTensor::from_existing(
tensor.tensor.cumsum(dim as i64, tensor.tensor.kind()),
tensor.storage,
)
}
pub fn cumprod(tensor: TchTensor, dim: usize) -> TchTensor {
TchTensor::from_existing(
tensor.tensor.cumprod(dim as i64, tensor.tensor.kind()),
tensor.storage,
)
}
pub fn cummin(tensor: TchTensor, dim: usize) -> TchTensor {
let (values, _indices) = tensor.tensor.cummin(dim as i64);
TchTensor::from_existing(values, tensor.storage)
}
pub fn cummax(tensor: TchTensor, dim: usize) -> TchTensor {
// cummax returns a (values, indices) tuple in PyTorch; we only need the values
let (values, _indices) = tensor.tensor.cummax(dim as i64);
TchTensor::from_existing(values, tensor.storage)
}
pub fn argmax(tensor: TchTensor, dim: usize) -> TchTensor {
let storage = tensor.storage.clone();
let tensor = tensor.tensor.argmax(dim as i64, true);
TchTensor::from_existing(tensor, storage)
}
pub fn argmin(tensor: TchTensor, dim: usize) -> TchTensor {
let storage = tensor.storage.clone();
let tensor = tensor.tensor.argmin(dim as i64, true);
TchTensor::from_existing(tensor, storage)
}
pub fn max_dim(tensor: TchTensor, dim: usize) -> TchTensor {
let storage = tensor.storage.clone();
let (tensor, _indices) = tensor.tensor.max_dim(dim as i64, true);
TchTensor::from_existing(tensor, storage)
}
pub fn max_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
let storage = tensor.storage.clone();
let (tensor, indices) = tensor.tensor.max_dim(dim as i64, true);
let tensor = TchTensor::from_existing(tensor, storage);
let indices = TchTensor::new(indices);
(tensor, indices)
}
pub fn min_dim(tensor: TchTensor, dim: usize) -> TchTensor {
let storage = tensor.storage.clone();
let (tensor, _indices) = tensor.tensor.min_dim(dim as i64, true);
TchTensor::from_existing(tensor, storage)
}
pub fn min_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
let storage = tensor.storage.clone();
let (tensor, indices) = tensor.tensor.min_dim(dim as i64, true);
let tensor = TchTensor::from_existing(tensor, storage);
let indices = TchTensor::new(indices);
(tensor, indices)
}
pub fn clamp_min<S: Into<tch::Scalar> + Clone + Copy>(tensor: TchTensor, min: S) -> TchTensor {
tensor.unary_ops(
|mut tensor| tensor.clamp_min_(min),
|tensor| tensor.clamp_min(min),
)
}
pub fn clamp_max<S: Into<tch::Scalar> + Clone + Copy>(tensor: TchTensor, max: S) -> TchTensor {
tensor.unary_ops(
|mut tensor| tensor.clamp_max_(max),
|tensor| tensor.clamp_max(max),
)
}
pub fn clamp<S: Into<tch::Scalar> + Clone + Copy>(
tensor: TchTensor,
min: S,
max: S,
) -> TchTensor {
tensor.unary_ops(
|mut tensor| tensor.clamp_(min, max),
|tensor| tensor.clamp(min, max),
)
}
pub fn swap_dims(tensor: TchTensor, dim1: usize, dim2: usize) -> TchTensor {
let tensor = tensor.tensor.transpose(dim1 as i64, dim2 as i64);
TchTensor::new(tensor)
}
pub fn permute(tensor: TchTensor, axes: &[usize]) -> TchTensor {
let tensor = tensor
.tensor
.permute(axes.iter().map(|x| *x as i64).collect::<Vec<_>>());
TchTensor::new(tensor)
}
pub fn flip(tensor: TchTensor, axes: &[usize]) -> TchTensor {
let dims = axes.iter().map(|x| *x as i64).collect::<Vec<_>>();
let tensor = tensor.tensor.flip(dims);
TchTensor::new(tensor)
}
pub fn powf(tensor: TchTensor, exponent: TchTensor) -> TchTensor {
TchTensor::binary_ops_tensor(
tensor,
exponent,
|lhs, rhs| lhs.f_pow_tensor_(rhs).unwrap(),
|lhs, rhs| lhs.f_pow(rhs).unwrap(),
|lhs, rhs| lhs.f_pow(rhs).unwrap(),
)
}
pub fn sign(tensor: TchTensor) -> TchTensor {
tensor.unary_ops(|mut tensor| tensor.sign_(), |tensor| tensor.sign())
}
pub fn expand(tensor: TchTensor, shape: Shape) -> TchTensor {
let storage = tensor.storage.clone();
let broadcasted_tensor = tensor.tensor.broadcast_to(TchShape::from(shape).dims);
TchTensor::from_existing(broadcasted_tensor, storage)
}
pub fn unfold(tensor: TchTensor, dim: usize, size: usize, step: usize) -> TchTensor {
let storage = tensor.storage.clone();
let uf_tensor = tensor.tensor.unfold(dim as i64, size as i64, step as i64);
TchTensor::from_existing(uf_tensor, storage)
}
pub fn sort(tensor: TchTensor, dim: usize, descending: bool) -> TchTensor {
TchTensor::new(tensor.tensor.sort(dim as i64, descending).0)
}
pub fn sort_with_indices(
tensor: TchTensor,
dim: usize,
descending: bool,
) -> (TchTensor, TchTensor) {
let sorted = tensor.tensor.sort(dim as i64, descending);
(TchTensor::new(sorted.0), TchTensor::new(sorted.1))
}
pub fn argsort(tensor: TchTensor, dim: usize, descending: bool) -> TchTensor {
TchTensor::new(tensor.tensor.argsort(dim as i64, descending))
}
pub fn bitwise_and(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchTensor::binary_ops_tensor(
lhs,
rhs,
|lhs, rhs| lhs.f_bitwise_and_tensor_(rhs).unwrap(),
|lhs, rhs| rhs.f_bitwise_and_tensor_(lhs).unwrap(),
|lhs, rhs| lhs.f_bitwise_and_tensor(rhs).unwrap(),
)
}
pub fn bitwise_and_scalar<S: Into<Scalar> + Clone>(tensor: TchTensor, scalar: S) -> TchTensor {
tensor.unary_ops(
|mut tensor| tensor.f_bitwise_and_(scalar.clone().into()).unwrap(),
|tensor| tensor.f_bitwise_and(scalar.clone().into()).unwrap(),
)
}
pub fn bitwise_or(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchTensor::binary_ops_tensor(
lhs,
rhs,
|lhs, rhs| lhs.f_bitwise_or_tensor_(rhs).unwrap(),
|lhs, rhs| rhs.f_bitwise_or_tensor_(lhs).unwrap(),
|lhs, rhs| lhs.f_bitwise_or_tensor(rhs).unwrap(),
)
}
pub fn bitwise_or_scalar<S: Into<Scalar> + Clone>(tensor: TchTensor, scalar: S) -> TchTensor {
tensor.unary_ops(
|mut tensor| tensor.f_bitwise_or_(scalar.clone().into()).unwrap(),
|tensor| tensor.f_bitwise_or(scalar.clone().into()).unwrap(),
)
}
pub fn bitwise_xor(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchTensor::binary_ops_tensor(
lhs,
rhs,
|lhs, rhs| lhs.f_bitwise_xor_tensor_(rhs).unwrap(),
|lhs, rhs| rhs.f_bitwise_xor_tensor_(lhs).unwrap(),
|lhs, rhs| lhs.f_bitwise_xor_tensor(rhs).unwrap(),
)
}
pub fn bitwise_xor_scalar<S: Into<Scalar> + Clone>(tensor: TchTensor, scalar: S) -> TchTensor {
tensor.unary_ops(
|mut tensor| tensor.f_bitwise_xor_(scalar.clone().into()).unwrap(),
|tensor| tensor.f_bitwise_xor(scalar.clone().into()).unwrap(),
)
}
pub fn bitwise_not(tensor: TchTensor) -> TchTensor {
tensor.unary_ops(
|mut tensor| tensor.f_bitwise_not_().unwrap(),
|tensor| tensor.f_bitwise_not().unwrap(),
)
}
pub fn bitwise_left_shift(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchTensor::binary_ops_tensor(
lhs,
rhs,
|lhs, rhs| lhs.f_bitwise_left_shift_(rhs).unwrap(),
|lhs, rhs| lhs.f_bitwise_left_shift(rhs).unwrap(),
|lhs, rhs| lhs.f_bitwise_left_shift(rhs).unwrap(),
)
}
pub fn bitwise_left_shift_scalar<S: Into<Scalar> + Clone>(
tensor: TchTensor,
scalar: S,
) -> TchTensor {
tensor.unary_ops(
|mut tensor| {
tensor
.f_bitwise_left_shift_tensor_scalar_(scalar.clone().into())
.unwrap()
},
|tensor| {
tensor
.f_bitwise_left_shift_tensor_scalar(scalar.clone().into())
.unwrap()
},
)
}
pub fn bitwise_right_shift(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchTensor::binary_ops_tensor(
lhs,
rhs,
|lhs, rhs| lhs.f_bitwise_right_shift_(rhs).unwrap(),
|lhs, rhs| lhs.f_bitwise_right_shift(rhs).unwrap(),
|lhs, rhs| lhs.f_bitwise_right_shift(rhs).unwrap(),
)
}
pub fn bitwise_right_shift_scalar<S: Into<Scalar> + Clone>(
tensor: TchTensor,
scalar: S,
) -> TchTensor {
tensor.unary_ops(
|mut tensor| {
tensor
.f_bitwise_right_shift_tensor_scalar_(scalar.clone().into())
.unwrap()
},
|tensor| {
tensor
.f_bitwise_right_shift_tensor_scalar(scalar.clone().into())
.unwrap()
},
)
}
pub fn atan2(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchTensor::binary_ops_tensor(
lhs,
rhs,
|lhs, rhs| lhs.f_atan2_(rhs).unwrap(),
|lhs, rhs| lhs.f_atan2(rhs).unwrap(),
|lhs, rhs| lhs.f_atan2(rhs).unwrap(),
)
}
}
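
// Illustrative sketch, not part of burn-tch: the index construction used by
// `generate_slice_indices` and `slice_assign_with_advanced_indexing` above,
// written against plain `Vec<i64>` so it can be read without tch. The names
// `slice_indices` and `cartesian_indices` are hypothetical, and the slice is
// assumed to be already normalized to `start..end` within the dimension bounds.

```rust
/// Expand a normalized `start..end` range with a signed step into explicit indices.
/// A negative step walks the same range from `end - 1` down to `start`.
fn slice_indices(start: usize, end: usize, step: isize) -> Vec<i64> {
    let (start, end, step) = (start as i64, end as i64, step as i64);
    let mut indices = Vec::new();
    if step > 0 {
        let mut idx = start;
        while idx < end {
            indices.push(idx);
            idx += step;
        }
    } else if step < 0 {
        let mut idx = end - 1;
        while idx >= start {
            indices.push(idx);
            idx += step; // step is negative, so this decreases
        }
    }
    indices
}

/// Build one flat index vector per dimension so that position `k` across all
/// vectors names the `k`-th element of the cartesian product of `index_sets`,
/// enumerated in row-major order (the last dimension varies fastest).
fn cartesian_indices(index_sets: &[Vec<i64>]) -> Vec<Vec<i64>> {
    (0..index_sets.len())
        .map(|d| {
            // Each index is repeated once per combination of the later dimensions...
            let repeat: usize = index_sets[d + 1..].iter().map(Vec::len).product();
            // ...and the whole pattern is tiled once per combination of the earlier ones.
            let tile: usize = index_sets[..d].iter().map(Vec::len).product();
            let mut out = Vec::with_capacity(tile * index_sets[d].len() * repeat);
            for _ in 0..tile {
                for &idx in &index_sets[d] {
                    out.extend(std::iter::repeat(idx).take(repeat));
                }
            }
            out
        })
        .collect()
}
```

// For example, rows {0, 2} and columns {1, 0} give the per-dimension vectors
// [0, 0, 2, 2] and [1, 0, 1, 0], i.e. the coordinates (0,1), (0,0), (2,1), (2,0).
// Flattened values are then assigned to those coordinates in the same order.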


@@ -0,0 +1,219 @@
use super::TchOps;
use crate::{LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement};
use burn_backend::ExecutionError;
use burn_backend::Scalar;
use burn_backend::tensor::BoolTensor;
use burn_backend::tensor::IntTensor;
use burn_backend::{Shape, TensorData, TensorMetadata, ops::BoolTensorOps};
impl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {
fn bool_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {
match data.dtype {
burn_backend::DType::Bool => TchTensor::from_data::<bool>(data, (*device).into()),
_ => unimplemented!("Unsupported dtype for `bool_from_data`"),
}
}
fn bool_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
TchOps::repeat_dim(tensor, dim, times)
}
async fn bool_into_data(tensor: TchTensor) -> Result<TensorData, ExecutionError> {
let shape = tensor.shape();
let tensor = Self::bool_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
let values: Result<Vec<bool>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
Ok(TensorData::new(values.unwrap(), shape))
}
fn bool_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
TchOps::to_device(tensor, device)
}
fn bool_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
TchOps::reshape(tensor, shape)
}
fn bool_device(tensor: &TchTensor) -> LibTorchDevice {
tensor.tensor.device().into()
}
fn bool_empty(shape: Shape, device: &LibTorchDevice) -> TchTensor {
let tensor = tch::Tensor::empty(
TchShape::from(shape).dims,
(tch::Kind::Bool, (*device).into()),
);
TchTensor::new(tensor)
}
fn bool_zeros(shape: Shape, device: &LibTorchDevice) -> TchTensor {
let tensor = tch::Tensor::zeros(
TchShape::from(shape).dims,
(tch::Kind::Bool, (*device).into()),
);
TchTensor::new(tensor)
}
fn bool_ones(shape: Shape, device: &LibTorchDevice) -> TchTensor {
let tensor = tch::Tensor::ones(
TchShape::from(shape).dims,
(tch::Kind::Bool, (*device).into()),
);
TchTensor::new(tensor)
}
fn bool_slice(tensor: TchTensor, slices: &[burn_backend::Slice]) -> TchTensor {
TchOps::slice_with_steps(tensor, slices)
}
fn bool_slice_assign(
tensor: TchTensor,
slices: &[burn_backend::Slice],
value: TchTensor,
) -> TchTensor {
TchOps::slice_assign(tensor, slices, value)
}
fn bool_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
TchOps::cat(tensors, dim)
}
fn bool_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchOps::equal(lhs, rhs)
}
fn bool_not(tensor: TchTensor) -> TchTensor {
tensor.unary_ops(
|mut tensor| tensor.eq_(0).to_kind(tch::Kind::Bool),
|tensor| tensor.eq(0),
)
}
fn bool_and(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchTensor::binary_ops_tensor(
lhs,
rhs,
|lhs, rhs| lhs.logical_and_(rhs),
|lhs, rhs| rhs.logical_and_(lhs),
|lhs, rhs| lhs.logical_and(rhs),
)
}
fn bool_or(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchTensor::binary_ops_tensor(
lhs,
rhs,
|lhs, rhs| lhs.logical_or_(rhs),
|lhs, rhs| rhs.logical_or_(lhs),
|lhs, rhs| lhs.logical_or(rhs),
)
}
fn bool_into_int(tensor: TchTensor) -> TchTensor {
let tensor = tensor.tensor.to_kind(tch::Kind::Int64);
TchTensor::new(tensor)
}
fn bool_into_float(tensor: TchTensor) -> TchTensor {
let tensor = tensor.tensor.to_kind(E::kind());
TchTensor::new(tensor)
}
fn bool_swap_dims(tensor: TchTensor, dim1: usize, dim2: usize) -> TchTensor {
TchOps::swap_dims(tensor, dim1, dim2)
}
fn bool_permute(tensor: TchTensor, axes: &[usize]) -> TchTensor {
TchOps::permute(tensor, axes)
}
fn bool_flip(tensor: TchTensor, axes: &[usize]) -> TchTensor {
TchOps::flip(tensor, axes)
}
async fn bool_argwhere(tensor: TchTensor) -> TchTensor {
TchTensor::new(tensor.tensor.argwhere())
}
fn bool_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
TchOps::index_select_dim(tensor, dim, indices)
}
fn bool_select_or(
tensor: TchTensor,
dim: usize,
indices: TchTensor,
value: TchTensor,
) -> TchTensor {
TchOps::select_assign(tensor, dim, indices, value)
}
fn bool_expand(tensor: TchTensor, shape: Shape) -> TchTensor {
TchOps::expand(tensor, shape)
}
fn bool_unfold(
tensor: BoolTensor<Self>,
dim: usize,
size: usize,
step: usize,
) -> BoolTensor<Self> {
TchOps::unfold(tensor, dim, size, step)
}
fn bool_mask_where(
tensor: BoolTensor<Self>,
mask: BoolTensor<Self>,
value: BoolTensor<Self>,
) -> BoolTensor<Self> {
TchTensor::binary_ops_tensor(
tensor,
value,
|tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
|tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
|tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
)
}
fn bool_mask_fill(
tensor: BoolTensor<Self>,
mask: BoolTensor<Self>,
value: Scalar,
) -> BoolTensor<Self> {
tensor.unary_ops(
|mut tensor| {
tensor
.f_masked_fill_(&mask.tensor, value.elem::<i64>())
.unwrap()
},
|tensor| {
tensor
.f_masked_fill(&mask.tensor, value.elem::<i64>())
.unwrap()
},
)
}
fn bool_gather(
dim: usize,
tensor: BoolTensor<Self>,
indices: IntTensor<Self>,
) -> BoolTensor<Self> {
TchOps::gather(dim, tensor, indices)
}
fn bool_scatter_or(
dim: usize,
tensor: BoolTensor<Self>,
indices: IntTensor<Self>,
value: BoolTensor<Self>,
) -> BoolTensor<Self> {
TchOps::scatter(dim, tensor, indices, value)
}
fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
TchOps::equal_elem(lhs, rhs.elem::<i64>())
}
}
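
// Illustrative sketch, not part of burn-tch: the element-wise selection that
// `bool_mask_where` and `bool_mask_fill` above delegate to libtorch, written
// over plain slices. The function names are hypothetical, and broadcasting is
// deliberately left out: all inputs are assumed to have the same length.

```rust
/// Take `source[i]` where `mask[i]` is true, otherwise keep `tensor[i]`.
/// This matches the argument order `source.where_self(mask, tensor)`.
fn mask_where(tensor: &[bool], mask: &[bool], source: &[bool]) -> Vec<bool> {
    assert_eq!(tensor.len(), mask.len());
    assert_eq!(tensor.len(), source.len());
    tensor
        .iter()
        .zip(mask)
        .zip(source)
        .map(|((&t, &m), &s)| if m { s } else { t })
        .collect()
}

/// Write `value` where `mask[i]` is true, otherwise keep `tensor[i]`.
fn mask_fill(tensor: &[bool], mask: &[bool], value: bool) -> Vec<bool> {
    assert_eq!(tensor.len(), mask.len());
    tensor
        .iter()
        .zip(mask)
        .map(|(&t, &m)| if m { value } else { t })
        .collect()
}
```

// The only difference between the two is where the replacement comes from:
// a second tensor for `mask_where`, a single scalar for `mask_fill`.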


@@ -0,0 +1,501 @@
use std::ops::Range;
use burn_backend::{
Distribution, ExecutionError, IntDType, Scalar, Shape, TensorData, TensorMetadata,
ops::{FloatTensorOps, IntTensorOps},
tensor::IntTensor,
};
use crate::{IntoKind, LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement};
use super::TchOps;
impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
fn int_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {
match data.dtype {
burn_backend::DType::I64 => TchTensor::from_data::<i64>(data, (*device).into()),
_ => unimplemented!("Unsupported dtype for `int_from_data`"),
}
}
fn int_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
TchOps::repeat_dim(tensor, dim, times)
}
async fn int_into_data(tensor: TchTensor) -> Result<TensorData, ExecutionError> {
let shape = tensor.shape();
let tensor = Self::int_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
let values: Result<Vec<i64>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
Ok(TensorData::new(values.unwrap(), shape))
}
fn int_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
TchOps::to_device(tensor, device)
}
fn int_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
TchOps::reshape(tensor, shape)
}
fn int_device(tensor: &TchTensor) -> LibTorchDevice {
tensor.tensor.device().into()
}
fn int_empty(shape: Shape, device: &LibTorchDevice, dtype: IntDType) -> TchTensor {
let tensor = tch::Tensor::empty(
TchShape::from(shape).dims,
(dtype.into_kind(), (*device).into()),
);
TchTensor::new(tensor)
}
fn int_slice(tensor: TchTensor, slices: &[burn_backend::Slice]) -> TchTensor {
TchOps::slice_with_steps(tensor, slices)
}
fn int_slice_assign(
tensor: TchTensor,
slices: &[burn_backend::Slice],
value: TchTensor,
) -> TchTensor {
TchOps::slice_assign(tensor, slices, value)
}
fn int_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
TchOps::cat(tensors, dim)
}
fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
let lhs = Self::int_into_float(lhs);
let rhs = Self::int_into_float(rhs);
let out = lhs.tensor.f_matmul(&rhs.tensor).unwrap();
Self::float_into_int(TchTensor::new(out))
}
fn int_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchOps::equal(lhs, rhs)
}
fn int_equal_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {
TchOps::equal_elem(lhs, rhs.elem::<i64>())
}
fn int_greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchOps::greater(lhs, rhs)
}
fn int_greater_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {
TchOps::greater_elem(lhs, rhs.elem::<i64>())
}
fn int_greater_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchOps::greater_equal(lhs, rhs)
}
fn int_greater_equal_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {
TchOps::greater_equal_elem(lhs, rhs.elem::<i64>())
}
fn int_lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchOps::lower(lhs, rhs)
}
fn int_lower_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {
TchOps::lower_elem(lhs, rhs.elem::<i64>())
}
fn int_lower_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchOps::lower_equal(lhs, rhs)
}
fn int_lower_equal_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {
TchOps::lower_equal_elem(lhs, rhs.elem::<i64>())
}
fn int_add(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchOps::add(lhs, rhs)
}
fn int_add_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
lhs.unary_ops(
|mut tensor| tensor.f_add_scalar_(rhs.elem::<i64>()).unwrap(),
|tensor| tensor.f_add_scalar(rhs.elem::<i64>()).unwrap(),
)
}
fn int_sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchOps::sub(lhs, rhs)
}
fn int_sub_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
lhs.unary_ops(
|mut tensor| tensor.f_sub_scalar_(rhs.elem::<i64>()).unwrap(),
|tensor| tensor.f_sub_scalar(rhs.elem::<i64>()).unwrap(),
)
}
fn int_mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchOps::mul(lhs, rhs)
}
fn int_mul_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
lhs.unary_ops(
|mut tensor| tensor.f_mul_scalar_(rhs.elem::<i64>()).unwrap(),
|tensor| tensor.f_mul_scalar(rhs.elem::<i64>()).unwrap(),
)
}
fn int_div(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
// Divide in f32 (`Kind::Float`), then cast the result back to the original integer kind.
let dtype = lhs.tensor.kind();
let copy = false;
let non_blocking = true;
let lhs: TchTensor =
TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
let rhs: TchTensor =
TchTensor::new(rhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
let out = TchOps::div(lhs, rhs);
TchTensor::new(out.tensor.to_dtype(dtype, non_blocking, copy))
}
fn int_div_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
let dtype = lhs.tensor.kind();
let copy = false;
let non_blocking = true;
let lhs: TchTensor =
TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
let out: TchTensor = lhs.unary_ops(
|mut tensor| tensor.f_div_scalar_(rhs.elem::<i64>()).unwrap(),
|tensor| tensor.f_div_scalar(rhs.elem::<i64>()).unwrap(),
);
TchTensor::new(out.tensor.to_dtype(dtype, non_blocking, copy))
}
fn int_remainder(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
let dtype = lhs.tensor.kind();
let copy = false;
let non_blocking = true;
let lhs: TchTensor =
TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
let rhs: TchTensor =
TchTensor::new(rhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
let out = TchOps::remainder(lhs, rhs);
TchTensor::new(out.tensor.to_dtype(dtype, non_blocking, copy))
}
fn int_remainder_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
lhs.unary_ops(
|tensor| tensor.f_remainder(rhs.elem::<i64>()).unwrap(),
|tensor| tensor.f_remainder(rhs.elem::<i64>()).unwrap(),
)
}
fn int_zeros(shape: Shape, device: &LibTorchDevice, dtype: IntDType) -> TchTensor {
let shape = TchShape::from(shape);
let device: tch::Device = (*device).into();
TchTensor::new(tch::Tensor::zeros(shape.dims, (dtype.into_kind(), device)))
}
fn int_ones(shape: Shape, device: &LibTorchDevice, dtype: IntDType) -> TchTensor {
let shape = TchShape::from(shape);
let device: tch::Device = (*device).into();
TchTensor::new(tch::Tensor::ones(shape.dims, (dtype.into_kind(), device)))
}
fn int_full(
shape: Shape,
fill_value: Scalar,
device: &LibTorchDevice,
dtype: IntDType,
) -> TchTensor {
let shape = TchShape::from(shape);
let device: tch::Device = (*device).into();
TchTensor::new(tch::Tensor::full(
shape.dims,
fill_value.elem::<i64>(),
(dtype.into_kind(), device),
))
}
fn int_sum(tensor: TchTensor) -> TchTensor {
TchOps::sum(tensor)
}
fn int_sum_dim(tensor: TchTensor, dim: usize) -> TchTensor {
TchOps::sum_dim(tensor, dim)
}
fn int_prod(tensor: TchTensor) -> TchTensor {
TchOps::prod(tensor)
}
fn int_prod_dim(tensor: TchTensor, dim: usize) -> TchTensor {
TchOps::prod_dim(tensor, dim)
}
fn int_mean(tensor: TchTensor) -> TchTensor {
let dtype = tensor.tensor.kind();
let tensor: TchTensor =
TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false));
let output: TchTensor = TchTensor::new(TchOps::mean(tensor).tensor);
TchTensor::new(output.tensor.to_dtype(dtype, true, false))
}
fn int_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor {
let dtype = tensor.tensor.kind();
let tensor: TchTensor =
TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false));
let output: TchTensor = TchTensor::new(TchOps::mean_dim(tensor, dim).tensor);
TchTensor::new(output.tensor.to_dtype(dtype, true, false))
}
fn int_cumsum(tensor: TchTensor, dim: usize) -> TchTensor {
TchOps::cumsum(tensor, dim)
}
fn int_cumprod(tensor: TchTensor, dim: usize) -> TchTensor {
TchOps::cumprod(tensor, dim)
}
fn int_cummin(tensor: TchTensor, dim: usize) -> TchTensor {
TchOps::cummin(tensor, dim)
}
fn int_cummax(tensor: TchTensor, dim: usize) -> TchTensor {
TchOps::cummax(tensor, dim)
}
fn int_gather(dim: usize, tensor: TchTensor, indices: TchTensor) -> TchTensor {
TchOps::gather(dim, tensor, indices)
}
fn int_scatter_add(
dim: usize,
tensor: TchTensor,
indices: TchTensor,
value: TchTensor,
) -> TchTensor {
TchOps::scatter(dim, tensor, indices, value)
}
fn int_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
TchOps::index_select_dim(tensor, dim, indices)
}
fn int_select_add(
tensor: TchTensor,
dim: usize,
indices: TchTensor,
value: TchTensor,
) -> TchTensor {
TchOps::select_assign(tensor, dim, indices, value)
}
fn int_mask_where(tensor: TchTensor, mask: TchTensor, source: TchTensor) -> TchTensor {
TchTensor::binary_ops_tensor(
tensor,
source,
|tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
|tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
|tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
)
}
fn int_mask_fill(tensor: TchTensor, mask: TchTensor, value: Scalar) -> TchTensor {
let value = value.elem::<i64>();
tensor.unary_ops(
|mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(),
|tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(),
)
}
fn int_argmax(tensor: TchTensor, dim: usize) -> TchTensor {
TchOps::argmax(tensor, dim)
}
fn int_argmin(tensor: TchTensor, dim: usize) -> TchTensor {
TchOps::argmin(tensor, dim)
}
fn int_max_dim(tensor: TchTensor, dim: usize) -> TchTensor {
TchOps::max_dim(tensor, dim)
}
fn int_max_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
TchOps::max_dim_with_indices(tensor, dim)
}
fn int_min_dim(tensor: TchTensor, dim: usize) -> TchTensor {
TchOps::min_dim(tensor, dim)
}
fn int_min_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
TchOps::min_dim_with_indices(tensor, dim)
}
fn int_clamp_min(tensor: TchTensor, min: Scalar) -> TchTensor {
TchOps::clamp_min(tensor, min.elem::<i64>())
}
fn int_clamp_max(tensor: TchTensor, max: Scalar) -> TchTensor {
TchOps::clamp_max(tensor, max.elem::<i64>())
}
fn int_clamp(tensor: TchTensor, min: Scalar, max: Scalar) -> TchTensor {
TchOps::clamp(tensor, min.elem::<i64>(), max.elem::<i64>())
}
fn int_abs(tensor: TchTensor) -> TchTensor {
tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs())
}
fn int_into_float(tensor: TchTensor) -> TchTensor {
let tensor = tensor.tensor.to_kind(E::kind());
TchTensor::new(tensor)
}
fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
TchOps::swap_dims(tensor, dim1, dim2)
}
fn int_random(shape: Shape, distribution: Distribution, device: &LibTorchDevice) -> TchTensor {
match distribution {
Distribution::Default => TchTensor::new(tch::Tensor::randint_low(
0,
255,
shape.iter().map(|i| *i as i64).collect::<Vec<_>>(),
(tch::Kind::Int64, (*device).into()),
)),
Distribution::Bernoulli(prob) => {
let mut tensor = TchTensor::empty::<i64>(shape, *device);
tensor
.mut_ops(|tensor| tensor.f_bernoulli_float_(prob).unwrap())
.unwrap()
}
Distribution::Uniform(from, to) => TchTensor::new(tch::Tensor::randint_low(
from as i64,
to as i64,
shape.iter().map(|i| *i as i64).collect::<Vec<_>>(),
(tch::Kind::Int64, (*device).into()),
)),
Distribution::Normal(mean, std) => {
let mut tensor = TchTensor::empty::<i64>(shape, *device);
tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap()
}
}
}
fn int_arange(range: Range<i64>, device: &LibTorchDevice) -> TchTensor {
let device: tch::Device = (*device).into();
let mut tensor = tch::Tensor::arange(range.end - range.start, (tch::Kind::Int64, device));
if range.start != 0 {
tensor = tensor.f_add_scalar_(range.start).unwrap();
}
TchTensor::new(tensor)
}
fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
TchOps::permute(tensor, axes)
}
fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
TchOps::flip(tensor, axes)
}
fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
TchOps::sign(tensor)
}
fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
TchOps::expand(tensor, shape)
}
fn int_sort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
TchOps::sort(tensor, dim, descending)
}
fn int_argsort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
TchOps::argsort(tensor, dim, descending)
}
fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
TchOps::bitwise_and(lhs, rhs)
}
fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
TchOps::bitwise_or(lhs, rhs)
}
fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
TchOps::bitwise_xor(lhs, rhs)
}
fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
TchOps::bitwise_not(tensor)
}
fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
TchOps::bitwise_and_scalar(lhs, rhs.elem::<i64>())
}
fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
TchOps::bitwise_or_scalar(lhs, rhs.elem::<i64>())
}
fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
TchOps::bitwise_xor_scalar(lhs, rhs.elem::<i64>())
}
fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
TchOps::bitwise_left_shift(lhs, rhs)
}
fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
TchOps::bitwise_right_shift(lhs, rhs)
}
fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
TchOps::bitwise_left_shift_scalar(lhs, rhs.elem::<i64>())
}
fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
TchOps::bitwise_right_shift_scalar(lhs, rhs.elem::<i64>())
}
fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
// NOTE: when the dtypes of the inputs to an arithmetic operation differ, tch applies type
// promotion according to these rules: https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc
// Type promotion is not automatic on all backends, so this behavior might differ.
let kind = dtype.into_kind();
if tensor.tensor.kind() == kind {
tensor
} else {
TchTensor::new(tensor.tensor.to_kind(kind))
}
}
fn int_unfold(
tensor: IntTensor<Self>,
dim: usize,
size: usize,
step: usize,
) -> IntTensor<Self> {
TchOps::unfold(tensor, dim, size, step)
}
}


@@ -0,0 +1,10 @@
mod activation;
mod base;
mod bool_tensor;
mod int_tensor;
mod module;
mod qtensor;
mod tensor;
mod transaction;
pub(crate) use base::*;


@@ -0,0 +1,473 @@
use crate::{LibTorch, TchTensor, element::TchElement};
use burn_backend::{
TensorMetadata,
ops::{
AttentionModuleOptions, ConvOptions, ConvTransposeOptions, DeformConv2dBackward,
DeformConvOptions, InterpolateMode, InterpolateOptions, MaxPool1dWithIndices,
MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, attention::attention_fallback,
},
};
impl<E: TchElement> ModuleOps<Self> for LibTorch<E> {
fn embedding(weights: TchTensor, indices: TchTensor) -> TchTensor {
// Workaround for MPS "Placeholder storage has not been allocated" error.
// See: https://github.com/pytorch/pytorch/issues/123995
// MPS uses lazy allocation and the embedding operation (which uses index_select)
// can fail if the tensors haven't been materialized yet.
// We work around this by performing the embedding on CPU and transferring back to MPS.
if matches!(weights.tensor.device(), tch::Device::Mps) {
let cpu_weights = weights.tensor.to(tch::Device::Cpu);
let cpu_indices = indices.tensor.to(tch::Device::Cpu);
let result = tch::Tensor::embedding(&cpu_weights, &cpu_indices, -1, false, false)
.to(tch::Device::Mps);
return TchTensor::new(result);
}
// Trailing arguments: padding_idx = -1 (none), scale_grad_by_freq = false, sparse = false.
let tensor = tch::Tensor::embedding(&weights.tensor, &indices.tensor, -1, false, false);
TchTensor::new(tensor)
}
fn embedding_backward(weights: TchTensor, output: TchTensor, indices: TchTensor) -> TchTensor {
let [n_embedding, _d_model] = weights.shape().dims();
// Workaround for MPS "Placeholder storage has not been allocated" error.
// See: https://github.com/pytorch/pytorch/issues/123995
if matches!(output.tensor.device(), tch::Device::Mps) {
let cpu_output = output.tensor.to(tch::Device::Cpu);
let cpu_indices = indices.tensor.to(tch::Device::Cpu);
let result = tch::Tensor::embedding_backward(
&cpu_output,
&cpu_indices,
n_embedding as i64,
-1,
false,
false,
)
.to(tch::Device::Mps);
return TchTensor::new(result);
}
let tensor = tch::Tensor::embedding_backward(
&output.tensor,
&indices.tensor,
n_embedding as i64,
-1,
false,
false,
);
TchTensor::new(tensor)
}
fn conv1d(
x: TchTensor,
weight: TchTensor,
bias: Option<TchTensor>,
options: ConvOptions<1>,
) -> TchTensor {
let tensor = tch::Tensor::conv1d(
&x.tensor,
&weight.tensor,
bias.map(|t| t.tensor),
options.stride.map(|i| i as i64),
options.padding.map(|i| i as i64),
options.dilation.map(|i| i as i64),
options.groups as i64,
);
TchTensor::new(tensor)
}
fn conv2d(
x: TchTensor,
weight: TchTensor,
bias: Option<TchTensor>,
options: ConvOptions<2>,
) -> TchTensor {
let tensor = tch::Tensor::conv2d(
&x.tensor,
&weight.tensor,
bias.map(|t| t.tensor),
options.stride.map(|i| i as i64),
options.padding.map(|i| i as i64),
options.dilation.map(|i| i as i64),
options.groups as i64,
);
TchTensor::new(tensor)
}
fn conv3d(
x: TchTensor,
weight: TchTensor,
bias: Option<TchTensor>,
options: ConvOptions<3>,
) -> TchTensor {
let tensor = tch::Tensor::conv3d(
&x.tensor,
&weight.tensor,
bias.map(|t| t.tensor),
options.stride.map(|i| i as i64),
options.padding.map(|i| i as i64),
options.dilation.map(|i| i as i64),
options.groups as i64,
);
TchTensor::new(tensor)
}
fn deform_conv2d(
_x: TchTensor,
_offset: TchTensor,
_weight: TchTensor,
_mask: Option<TchTensor>,
_bias: Option<TchTensor>,
_options: DeformConvOptions<2>,
) -> TchTensor {
unimplemented!("Torch bindings don't support deform_conv2d");
}
fn deform_conv2d_backward(
_x: TchTensor,
_offset: TchTensor,
_weight: TchTensor,
_mask: Option<TchTensor>,
_bias: Option<TchTensor>,
_out_grad: TchTensor,
_options: DeformConvOptions<2>,
) -> DeformConv2dBackward<Self> {
unimplemented!("Torch bindings don't support deform_conv2d_backward");
}
fn conv_transpose1d(
x: TchTensor,
weight: TchTensor,
bias: Option<TchTensor>,
options: ConvTransposeOptions<1>,
) -> TchTensor {
let tensor = tch::Tensor::conv_transpose1d(
&x.tensor,
&weight.tensor,
bias.map(|t| t.tensor),
options.stride.map(|i| i as i64),
options.padding.map(|i| i as i64),
options.padding_out.map(|i| i as i64),
options.groups as i64,
options.dilation.map(|i| i as i64),
);
TchTensor::new(tensor)
}
fn conv_transpose2d(
x: TchTensor,
weight: TchTensor,
bias: Option<TchTensor>,
options: ConvTransposeOptions<2>,
) -> TchTensor {
let tensor = tch::Tensor::conv_transpose2d(
&x.tensor,
&weight.tensor,
bias.map(|t| t.tensor),
options.stride.map(|i| i as i64),
options.padding.map(|i| i as i64),
options.padding_out.map(|i| i as i64),
options.groups as i64,
options.dilation.map(|i| i as i64),
);
TchTensor::new(tensor)
}
fn conv_transpose3d(
x: TchTensor,
weight: TchTensor,
bias: Option<TchTensor>,
options: ConvTransposeOptions<3>,
) -> TchTensor {
let tensor = tch::Tensor::conv_transpose3d(
&x.tensor,
&weight.tensor,
bias.map(|t| t.tensor),
options.stride.map(|i| i as i64),
options.padding.map(|i| i as i64),
options.padding_out.map(|i| i as i64),
options.groups as i64,
options.dilation.map(|i| i as i64),
);
TchTensor::new(tensor)
}
fn avg_pool1d(
x: TchTensor,
kernel_size: usize,
stride: usize,
padding: usize,
count_include_pad: bool,
ceil_mode: bool,
) -> TchTensor {
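// Note: tch takes `ceil_mode` before `count_include_pad`, the reverse of this
// function's parameter order.
let tensor = tch::Tensor::avg_pool1d(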
&x.tensor,
[kernel_size as i64],
[stride as i64],
[padding as i64],
ceil_mode,
count_include_pad,
);
TchTensor::new(tensor)
}
fn avg_pool2d(
x: TchTensor,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
count_include_pad: bool,
ceil_mode: bool,
) -> TchTensor {
let tensor = tch::Tensor::avg_pool2d(
&x.tensor,
[kernel_size[0] as i64, kernel_size[1] as i64],
[stride[0] as i64, stride[1] as i64],
[padding[0] as i64, padding[1] as i64],
ceil_mode,
count_include_pad,
None,
);
TchTensor::new(tensor)
}
fn avg_pool2d_backward(
x: TchTensor,
grad: TchTensor,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
count_include_pad: bool,
ceil_mode: bool,
) -> TchTensor {
let tensor = tch::Tensor::avg_pool2d_backward(
&x.tensor,
&grad.tensor,
[kernel_size[0] as i64, kernel_size[1] as i64],
[stride[0] as i64, stride[1] as i64],
[padding[0] as i64, padding[1] as i64],
ceil_mode,
count_include_pad,
None,
);
TchTensor::new(tensor)
}
fn max_pool1d(
x: TchTensor,
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
ceil_mode: bool,
) -> TchTensor {
let tensor = tch::Tensor::max_pool1d(
&x.tensor,
kernel_size as i64,
stride as i64,
padding as i64,
dilation as i64,
ceil_mode,
);
TchTensor::new(tensor)
}
fn max_pool1d_with_indices(
x: TchTensor,
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
ceil_mode: bool,
) -> MaxPool1dWithIndices<Self> {
let (tensor, indices) = tch::Tensor::max_pool1d_with_indices(
&x.tensor,
kernel_size as i64,
stride as i64,
padding as i64,
dilation as i64,
ceil_mode,
);
MaxPool1dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices))
}
fn max_pool2d(
x: TchTensor,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
ceil_mode: bool,
) -> TchTensor {
let tensor = tch::Tensor::max_pool2d(
&x.tensor,
[kernel_size[0] as i64, kernel_size[1] as i64],
[stride[0] as i64, stride[1] as i64],
[padding[0] as i64, padding[1] as i64],
[dilation[0] as i64, dilation[1] as i64],
ceil_mode,
);
TchTensor::new(tensor)
}
fn max_pool2d_with_indices(
x: TchTensor,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
ceil_mode: bool,
) -> MaxPool2dWithIndices<Self> {
let (tensor, indices) = tch::Tensor::max_pool2d_with_indices(
&x.tensor,
[kernel_size[0] as i64, kernel_size[1] as i64],
[stride[0] as i64, stride[1] as i64],
[padding[0] as i64, padding[1] as i64],
[dilation[0] as i64, dilation[1] as i64],
ceil_mode,
);
MaxPool2dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices))
}
fn max_pool2d_with_indices_backward(
x: TchTensor,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
ceil_mode: bool,
output_grad: TchTensor,
indices: TchTensor,
) -> MaxPool2dBackward<Self> {
let grad = tch::Tensor::max_pool2d_with_indices_backward(
&x.tensor,
&output_grad.tensor,
[kernel_size[0] as i64, kernel_size[1] as i64],
[stride[0] as i64, stride[1] as i64],
[padding[0] as i64, padding[1] as i64],
[dilation[0] as i64, dilation[1] as i64],
ceil_mode,
&indices.tensor,
);
MaxPool2dBackward::new(TchTensor::new(grad))
}
fn adaptive_avg_pool2d(x: TchTensor, output_size: [usize; 2]) -> TchTensor {
let tensor = tch::Tensor::adaptive_avg_pool2d(&x.tensor, output_size.map(|e| e as i64));
TchTensor::new(tensor)
}
fn adaptive_avg_pool2d_backward(x: TchTensor, grad: TchTensor) -> TchTensor {
let tensor = tch::Tensor::internal_adaptive_avg_pool2d_backward(&x.tensor, &grad.tensor);
TchTensor::new(tensor)
}
fn adaptive_avg_pool1d(x: TchTensor, output_size: usize) -> TchTensor {
let tensor = tch::Tensor::adaptive_avg_pool1d(&x.tensor, output_size as i64);
TchTensor::new(tensor)
}
fn interpolate(
x: TchTensor,
output_size: [usize; 2],
options: InterpolateOptions,
) -> TchTensor {
let output_size = output_size.map(|e| e as i64);
let align_corners = options.align_corners;
let tensor = match options.mode {
InterpolateMode::Nearest => {
tch::Tensor::upsample_nearest2d(&x.tensor, output_size, None, None)
}
InterpolateMode::Bilinear => {
tch::Tensor::upsample_bilinear2d(&x.tensor, output_size, align_corners, None, None)
}
InterpolateMode::Bicubic => {
tch::Tensor::upsample_bicubic2d(&x.tensor, output_size, align_corners, None, None)
}
};
TchTensor::new(tensor)
}
fn interpolate_backward(
x: TchTensor,
grad: TchTensor,
output_size: [usize; 2],
options: InterpolateOptions,
) -> TchTensor {
let output_size = output_size.map(|e| e as i64);
let [n, c, h_in, w_in] = x.shape().dims();
let input_size = [n as i64, c as i64, h_in as i64, w_in as i64];
let align_corners = options.align_corners;
let tensor = match options.mode {
InterpolateMode::Nearest => tch::Tensor::upsample_nearest2d_backward(
&grad.tensor,
output_size,
input_size,
None,
None,
),
InterpolateMode::Bilinear => tch::Tensor::upsample_bilinear2d_backward(
&grad.tensor,
output_size,
input_size,
align_corners,
None,
None,
),
InterpolateMode::Bicubic => tch::Tensor::upsample_bicubic2d_backward(
&grad.tensor,
output_size,
input_size,
align_corners,
None,
None,
),
};
TchTensor::new(tensor)
}
fn attention(
query: TchTensor,
key: TchTensor,
value: TchTensor,
mask: Option<TchTensor>,
attn_bias: Option<TchTensor>,
options: AttentionModuleOptions,
) -> TchTensor {
if attn_bias.is_some() {
return attention_fallback::<Self>(query, key, value, mask, attn_bias, options);
}
TchTensor::new(tch::Tensor::scaled_dot_product_attention(
&query.tensor,
&key.tensor,
&value.tensor,
mask.map(|m| m.tensor),
0., // dropout_p: no dropout
options.is_causal,
options.scale,
false, // enable_gqa
))
}
}
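
The `embedding` and `embedding_backward` functions above work around an MPS allocation error by running the operation on the CPU and moving the result back. The following is a minimal, std-only sketch of that round-trip pattern. `Device`, `Table`, `lookup`, and this `embedding` are hypothetical stand-ins invented for illustration; they are not the `tch` or Burn types, and no real device transfer happens here.

```rust
/// Hypothetical device tag; not the `tch::Device` type.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Device {
    Cpu,
    Mps,
}

/// Hypothetical 2-D table of embedding rows plus a device tag.
#[derive(Clone, Debug, PartialEq)]
struct Table {
    rows: Vec<Vec<f32>>,
    device: Device,
}

impl Table {
    /// Returns a copy tagged with `device` (a real backend would move the data).
    fn to(&self, device: Device) -> Table {
        Table {
            rows: self.rows.clone(),
            device,
        }
    }
}

/// Plain row lookup, standing in for the backend's embedding kernel.
fn lookup(weights: &Table, indices: &[usize]) -> Table {
    Table {
        rows: indices.iter().map(|&i| weights.rows[i].clone()).collect(),
        device: weights.device,
    }
}

/// Embedding with a CPU round trip when the weights live on the MPS-like device.
fn embedding(weights: &Table, indices: &[usize]) -> Table {
    if weights.device == Device::Mps {
        // Compute on the CPU, then tag the result as living on MPS again.
        let cpu_weights = weights.to(Device::Cpu);
        return lookup(&cpu_weights, indices).to(Device::Mps);
    }
    lookup(weights, indices)
}

fn main() {
    let weights = Table {
        rows: vec![vec![0.0, 0.1], vec![1.0, 1.1], vec![2.0, 2.1]],
        device: Device::Mps,
    };
    let out = embedding(&weights, &[2, 0]);
    // The caller still sees a result on the original device.
    assert_eq!(out.device, Device::Mps);
    assert_eq!(out.rows, vec![vec![2.0, 2.1], vec![0.0, 0.1]]);
}
```

The caller never observes the detour, but each call pays for the extra copies, which is why the real code limits the fallback to the MPS branch.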


@@ -0,0 +1,140 @@
use burn_backend::{
ExecutionError, Shape, TensorData,
ops::QTensorOps,
quantization::{QuantScheme, QuantizationParametersPrimitive},
tensor::{Device, FloatTensor, IntTensor, QuantizedTensor},
};
use crate::{LibTorch, LibTorchDevice, TchElement};
impl<E: TchElement> QTensorOps<Self> for LibTorch<E> {
fn q_from_data(_data: TensorData, _device: &LibTorchDevice) -> QuantizedTensor<Self> {
unimplemented!()
}
fn quantize(
_tensor: FloatTensor<Self>,
_scheme: &QuantScheme,
_qparams: QuantizationParametersPrimitive<Self>,
) -> QuantizedTensor<Self> {
unimplemented!()
}
fn quantize_dynamic(
_tensor: FloatTensor<Self>,
_scheme: &QuantScheme,
) -> QuantizedTensor<Self> {
unimplemented!()
}
fn dequantize(_tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
unimplemented!()
}
fn q_device(_tensor: &QuantizedTensor<Self>) -> LibTorchDevice {
unimplemented!()
}
fn q_to_device(
_tensor: QuantizedTensor<Self>,
_device: &Device<Self>,
) -> QuantizedTensor<Self> {
unimplemented!()
}
fn q_reshape(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
unimplemented!()
}
async fn q_into_data(_tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
unimplemented!()
}
fn q_swap_dims(
_tensor: QuantizedTensor<Self>,
_dim1: usize,
_dim2: usize,
) -> QuantizedTensor<Self> {
unimplemented!()
}
fn q_permute(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
unimplemented!()
}
fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
unimplemented!()
}
fn q_select(
_tensor: QuantizedTensor<Self>,
_dim: usize,
_indices: IntTensor<Self>,
) -> QuantizedTensor<Self> {
unimplemented!()
}
fn q_slice(
_tensor: QuantizedTensor<Self>,
_slices: &[burn_backend::Slice],
) -> QuantizedTensor<Self> {
unimplemented!()
}
fn q_argmax(_tensor: QuantizedTensor<Self>, _dim: usize) -> IntTensor<Self> {
unimplemented!()
}
fn q_argmin(_tensor: QuantizedTensor<Self>, _dim: usize) -> IntTensor<Self> {
unimplemented!()
}
fn q_max_dim_with_indices(
_tensor: QuantizedTensor<Self>,
_dim: usize,
) -> (QuantizedTensor<Self>, IntTensor<Self>) {
unimplemented!()
}
fn q_max_dim(_tensor: QuantizedTensor<Self>, _dim: usize) -> QuantizedTensor<Self> {
unimplemented!()
}
fn q_min_dim(_tensor: QuantizedTensor<Self>, _dim: usize) -> QuantizedTensor<Self> {
unimplemented!()
}
fn q_min_dim_with_indices(
_tensor: QuantizedTensor<Self>,
_dim: usize,
) -> (QuantizedTensor<Self>, IntTensor<Self>) {
unimplemented!()
}
fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
unimplemented!()
}
fn q_sort(
_tensor: QuantizedTensor<Self>,
_dim: usize,
_descending: bool,
) -> QuantizedTensor<Self> {
unimplemented!()
}
fn q_sort_with_indices(
_tensor: QuantizedTensor<Self>,
_dim: usize,
_descending: bool,
) -> (QuantizedTensor<Self>, IntTensor<Self>) {
unimplemented!()
}
fn q_argsort(
_tensor: QuantizedTensor<Self>,
_dim: usize,
_descending: bool,
) -> IntTensor<Self> {
unimplemented!()
}
}


@@ -0,0 +1,539 @@
use super::TchOps;
use crate::{IntoKind, LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement};
use burn_backend::backend::ExecutionError;
use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor};
use burn_backend::{
DType, Distribution, FloatDType, Shape, TensorData, TensorMetadata, ops::FloatTensorOps,
};
use burn_backend::{Scalar, bf16, f16};
impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
fn float_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {
match data.dtype {
DType::F64 => TchTensor::from_data::<f64>(data, (*device).into()),
DType::F32 => TchTensor::from_data::<f32>(data, (*device).into()),
DType::F16 => TchTensor::from_data::<f16>(data, (*device).into()),
DType::BF16 => TchTensor::from_data::<bf16>(data, (*device).into()),
_ => unimplemented!("Unsupported dtype for `float_from_data`"),
}
}
fn float_random(
shape: Shape,
distribution: Distribution,
device: &LibTorchDevice,
) -> TchTensor {
match distribution {
Distribution::Default => {
let mut tensor = TchTensor::empty::<E>(shape, *device);
tensor
.mut_ops(|tensor| tensor.rand_like_out(tensor))
.unwrap()
}
Distribution::Bernoulli(prob) => {
let mut tensor = TchTensor::empty::<E>(shape, *device);
tensor
.mut_ops(|tensor| tensor.f_bernoulli_float_(prob).unwrap())
.unwrap()
}
Distribution::Uniform(from, to) => {
let mut tensor = TchTensor::empty::<E>(shape, *device);
tensor.mut_ops(|tensor| tensor.uniform_(from, to)).unwrap()
}
Distribution::Normal(mean, std) => {
let mut tensor = TchTensor::empty::<E>(shape, *device);
tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap()
}
}
}
fn float_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
TchOps::repeat_dim(tensor, dim, times)
}
fn float_zeros(shape: Shape, device: &LibTorchDevice, dtype: FloatDType) -> TchTensor {
let shape = TchShape::from(shape);
let device: tch::Device = (*device).into();
TchTensor::new(tch::Tensor::zeros(shape.dims, (dtype.into_kind(), device)))
}
fn float_ones(shape: Shape, device: &LibTorchDevice, dtype: FloatDType) -> TchTensor {
let shape = TchShape::from(shape);
let device: tch::Device = (*device).into();
TchTensor::new(tch::Tensor::ones(shape.dims, (dtype.into_kind(), device)))
}
async fn float_into_data(tensor: TchTensor) -> Result<TensorData, ExecutionError> {
let shape = tensor.shape();
let tensor = Self::float_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
Ok(match tensor.tensor.kind() {
tch::Kind::Half => {
let values = Vec::<f16>::try_from(&tensor).unwrap();
TensorData::new(values, shape)
}
tch::Kind::Float => {
let values = Vec::<f32>::try_from(&tensor).unwrap();
TensorData::new(values, shape)
}
tch::Kind::Double => {
let values = Vec::<f64>::try_from(&tensor).unwrap();
TensorData::new(values, shape)
}
tch::Kind::BFloat16 => {
let values = Vec::<bf16>::try_from(&tensor).unwrap();
TensorData::new(values, shape)
}
_ => panic!("Not a valid float kind"),
})
}
fn float_device(tensor: &TchTensor) -> LibTorchDevice {
tensor.tensor.device().into()
}
fn float_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
TchOps::to_device(tensor, device)
}
fn float_empty(shape: Shape, device: &LibTorchDevice, dtype: FloatDType) -> TchTensor {
let tensor = tch::Tensor::empty(
TchShape::from(shape).dims,
(dtype.into_kind(), (*device).into()),
);
TchTensor::new(tensor)
}
fn float_add(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchOps::add(lhs, rhs)
}
fn float_add_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
let rhs: f64 = rhs.elem();
lhs.unary_ops(
|mut tensor| tensor.f_add_scalar_(rhs).unwrap(),
|tensor| tensor.f_add_scalar(rhs).unwrap(),
)
}
fn float_sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchOps::sub(lhs, rhs)
}
fn float_sub_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
let rhs: f64 = rhs.elem();
lhs.unary_ops(
|mut tensor| tensor.f_sub_scalar_(rhs).unwrap(),
|tensor| tensor.f_sub_scalar(rhs).unwrap(),
)
}
fn float_mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchOps::mul(lhs, rhs)
}
fn float_mul_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
let rhs: f64 = rhs.elem();
lhs.unary_ops(
|mut tensor| tensor.f_mul_scalar_(rhs).unwrap(),
|tensor| tensor.f_mul_scalar(rhs).unwrap(),
)
}
fn float_div(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchOps::div(lhs, rhs)
}
fn float_div_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
let rhs: f64 = rhs.elem();
lhs.unary_ops(
|mut tensor| tensor.f_div_scalar_(rhs).unwrap(),
|tensor| tensor.f_div_scalar(rhs).unwrap(),
)
}
fn float_remainder(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchOps::remainder(lhs, rhs)
}
fn float_remainder_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
let rhs: f64 = rhs.elem();
lhs.unary_ops(
|tensor| tensor.f_remainder(rhs).unwrap(),
|tensor| tensor.f_remainder(rhs).unwrap(),
)
}
fn float_matmul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
let tensor = lhs.tensor.matmul(&rhs.tensor);
TchTensor::new(tensor)
}
fn float_cross(lhs: TchTensor, rhs: TchTensor, dim: usize) -> TchTensor {
let tensor = lhs.tensor.cross(&rhs.tensor, dim as i64);
TchTensor::new(tensor)
}
fn float_recip(tensor: TchTensor) -> TchTensor {
TchTensor::new(tensor.tensor.reciprocal())
}
fn float_swap_dims(tensor: TchTensor, dim1: usize, dim2: usize) -> TchTensor {
TchOps::swap_dims(tensor, dim1, dim2)
}
fn float_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
TchOps::reshape(tensor, shape)
}
fn float_gather(dim: usize, tensor: TchTensor, indices: TchTensor) -> TchTensor {
TchOps::gather(dim, tensor, indices)
}
fn float_scatter_add(
dim: usize,
tensor: TchTensor,
indices: TchTensor,
value: TchTensor,
) -> TchTensor {
TchOps::scatter(dim, tensor, indices, value)
}
fn float_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
TchOps::index_select_dim(tensor, dim, indices)
}
fn float_select_add(
tensor: TchTensor,
dim: usize,
indices: TchTensor,
value: TchTensor,
) -> TchTensor {
TchOps::select_assign(tensor, dim, indices, value)
}
fn float_slice(tensor: TchTensor, slices: &[burn_backend::Slice]) -> TchTensor {
TchOps::slice_with_steps(tensor, slices)
}
fn float_slice_assign(
tensor: TchTensor,
slices: &[burn_backend::Slice],
value: TchTensor,
) -> TchTensor {
TchOps::slice_assign(tensor, slices, value)
}
fn float_mask_where(tensor: TchTensor, mask: TchTensor, value: TchTensor) -> TchTensor {
// Takes `value` where `mask` is true and `tensor` elsewhere.
let output = value.tensor.where_self(&mask.tensor, &tensor.tensor);
TchTensor::new(output)
}
fn float_mask_fill(tensor: TchTensor, mask: TchTensor, value: Scalar) -> TchTensor {
let value: f64 = value.elem();
tensor.unary_ops(
|mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(),
|tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(),
)
}
fn float_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchOps::equal(lhs, rhs)
}
fn float_equal_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {
TchOps::equal_elem(lhs, rhs.elem::<f64>())
}
fn float_greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchOps::greater(lhs, rhs)
}
fn float_greater_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {
TchOps::greater_elem(lhs, rhs.elem::<f64>())
}
fn float_greater_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchOps::greater_equal(lhs, rhs)
}
fn float_greater_equal_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {
TchOps::greater_equal_elem(lhs, rhs.elem::<f64>())
}
fn float_lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchOps::lower(lhs, rhs)
}
fn float_lower_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {
TchOps::lower_elem(lhs, rhs.elem::<f64>())
}
fn float_lower_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchOps::lower_equal(lhs, rhs)
}
fn float_lower_equal_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {
TchOps::lower_equal_elem(lhs, rhs.elem::<f64>())
}
fn float_mean(tensor: TchTensor) -> TchTensor {
TchOps::mean(tensor)
}
fn float_sum(tensor: TchTensor) -> TchTensor {
TchOps::sum(tensor)
}
fn float_sum_dim(tensor: TchTensor, dim: usize) -> TchTensor {
TchOps::sum_dim(tensor, dim)
}
fn float_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor {
TchOps::mean_dim(tensor, dim)
}
fn float_cumsum(tensor: TchTensor, dim: usize) -> TchTensor {
TchOps::cumsum(tensor, dim)
}
fn float_cumprod(tensor: TchTensor, dim: usize) -> TchTensor {
TchOps::cumprod(tensor, dim)
}
fn float_cummin(tensor: TchTensor, dim: usize) -> TchTensor {
TchOps::cummin(tensor, dim)
}
fn float_cummax(tensor: TchTensor, dim: usize) -> TchTensor {
TchOps::cummax(tensor, dim)
}
fn float_prod(tensor: TchTensor) -> TchTensor {
TchOps::prod(tensor)
}
fn float_prod_dim(tensor: TchTensor, dim: usize) -> TchTensor {
TchOps::prod_dim(tensor, dim)
}
fn float_argmax(tensor: TchTensor, dim: usize) -> TchTensor {
TchOps::argmax(tensor, dim)
}
fn float_argmin(tensor: TchTensor, dim: usize) -> TchTensor {
TchOps::argmin(tensor, dim)
}
fn float_max_dim(tensor: TchTensor, dim: usize) -> TchTensor {
TchOps::max_dim(tensor, dim)
}
fn float_max_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
TchOps::max_dim_with_indices(tensor, dim)
}
fn float_min_dim(tensor: TchTensor, dim: usize) -> TchTensor {
TchOps::min_dim(tensor, dim)
}
fn float_min_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
TchOps::min_dim_with_indices(tensor, dim)
}
fn float_exp(tensor: TchTensor) -> TchTensor {
tensor.unary_ops(|mut tensor| tensor.exp_(), |tensor| tensor.exp())
}
fn float_log(tensor: TchTensor) -> TchTensor {
tensor.unary_ops(|mut tensor| tensor.log_(), |tensor| tensor.log())
}
fn float_log1p(tensor: TchTensor) -> TchTensor {
tensor.unary_ops(|mut tensor| tensor.log1p_(), |tensor| tensor.log1p())
}
fn float_powf_scalar_impl(tensor: TchTensor, value: Scalar) -> TchTensor {
tensor.unary_ops(
|mut tensor| tensor.f_pow_(value.elem::<f64>()).unwrap(),
|tensor| tensor.pow_tensor_scalar(value.elem::<f64>()),
)
}
fn float_sqrt(tensor: TchTensor) -> TchTensor {
tensor.unary_ops(|mut tensor| tensor.sqrt_(), |tensor| tensor.sqrt())
}
fn float_abs(tensor: TchTensor) -> TchTensor {
tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs())
}
fn float_cos(tensor: TchTensor) -> TchTensor {
tensor.unary_ops(|mut tensor| tensor.cos_(), |tensor| tensor.cos())
}
fn float_cosh(tensor: TchTensor) -> TchTensor {
tensor.unary_ops(|mut tensor| tensor.cosh_(), |tensor| tensor.cosh())
}
fn float_sin(tensor: TchTensor) -> TchTensor {
tensor.unary_ops(|mut tensor| tensor.sin_(), |tensor| tensor.sin())
}
fn float_sinh(tensor: TchTensor) -> TchTensor {
tensor.unary_ops(|mut tensor| tensor.sinh_(), |tensor| tensor.sinh())
}
fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
tensor.unary_ops(|mut tensor| tensor.tan_(), |tensor| tensor.tan())
}
fn float_tanh(tensor: TchTensor) -> TchTensor {
tensor.unary_ops(|mut tensor| tensor.tanh_(), |tensor| tensor.tanh())
}
fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
tensor.unary_ops(|mut tensor| tensor.acos_(), |tensor| tensor.acos())
}
fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
tensor.unary_ops(|mut tensor| tensor.acosh_(), |tensor| tensor.acosh())
}
fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
tensor.unary_ops(|mut tensor| tensor.asin_(), |tensor| tensor.asin())
}
fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
tensor.unary_ops(|mut tensor| tensor.asinh_(), |tensor| tensor.asinh())
}
fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
tensor.unary_ops(|mut tensor| tensor.atan_(), |tensor| tensor.atan())
}
fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
tensor.unary_ops(|mut tensor| tensor.atanh_(), |tensor| tensor.atanh())
}
fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
TchOps::atan2(lhs, rhs)
}
fn float_round(tensor: TchTensor) -> TchTensor {
tensor.unary_ops(|mut tensor| tensor.round_(), |tensor| tensor.round())
}
fn float_floor(tensor: TchTensor) -> TchTensor {
tensor.unary_ops(|mut tensor| tensor.floor_(), |tensor| tensor.floor())
}
fn float_ceil(tensor: TchTensor) -> TchTensor {
tensor.unary_ops(|mut tensor| tensor.ceil_(), |tensor| tensor.ceil())
}
fn float_trunc(tensor: TchTensor) -> TchTensor {
tensor.unary_ops(|mut tensor| tensor.trunc_(), |tensor| tensor.trunc())
}
fn float_erf(tensor: TchTensor) -> TchTensor {
tensor.unary_ops(|mut tensor| tensor.erf_(), |tensor| tensor.erf())
}
fn float_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
TchOps::cat(tensors, dim)
}
fn float_clamp_min(tensor: TchTensor, min: Scalar) -> TchTensor {
TchOps::clamp_min(tensor, min.elem::<f64>())
}
fn float_clamp_max(tensor: TchTensor, max: Scalar) -> TchTensor {
TchOps::clamp_max(tensor, max.elem::<f64>())
}
fn float_clamp(tensor: TchTensor, min: Scalar, max: Scalar) -> TchTensor {
TchOps::clamp(tensor, min.elem::<f64>(), max.elem::<f64>())
}
fn float_into_int(tensor: TchTensor) -> TchTensor {
let tensor = tensor.tensor.to_kind(tch::Kind::Int64);
TchTensor::new(tensor)
}
fn float_powf(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
TchOps::powf(lhs, rhs)
}
fn float_permute(tensor: TchTensor, axes: &[usize]) -> TchTensor {
TchOps::permute(tensor, axes)
}
fn float_flip(tensor: TchTensor, axes: &[usize]) -> TchTensor {
TchOps::flip(tensor, axes)
}
fn float_sign(tensor: TchTensor) -> TchTensor {
TchOps::sign(tensor)
}
fn float_expand(tensor: TchTensor, shape: Shape) -> TchTensor {
TchOps::expand(tensor, shape)
}
fn float_sort(tensor: TchTensor, dim: usize, descending: bool) -> TchTensor {
TchOps::sort(tensor, dim, descending)
}
fn float_sort_with_indices(
tensor: TchTensor,
dim: usize,
descending: bool,
) -> (TchTensor, TchTensor) {
TchOps::sort_with_indices(tensor, dim, descending)
}
fn float_argsort(tensor: TchTensor, dim: usize, descending: bool) -> IntTensor<Self> {
TchOps::argsort(tensor, dim, descending)
}
fn float_cast(tensor: TchTensor, dtype: FloatDType) -> TchTensor {
// NOTE: when the dtypes of the inputs to an arithmetic operation differ, tch applies type
// promotion according to these rules: https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc
// Type promotion is not automatic on all backends, so this behavior might differ.
let kind = dtype.into_kind();
if tensor.tensor.kind() == kind {
tensor
} else {
TchTensor::new(tensor.tensor.to_kind(kind))
}
}
fn float_unfold(
tensor: FloatTensor<Self>,
dim: usize,
size: usize,
step: usize,
) -> FloatTensor<Self> {
TchOps::unfold(tensor, dim, size, step)
}
fn float_is_nan(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
TchTensor::new(tensor.tensor.isnan())
}
fn float_is_inf(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
TchTensor::new(tensor.tensor.isinf())
}
}
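
Many of the float ops above pass two closures to `unary_ops`: one that mutates the tensor in place and one that allocates a new result. The following is a minimal, std-only sketch of that dispatch idea, assuming the choice depends on whether the storage handle is uniquely referenced (an `Arc::strong_count` check, which is an assumption about how `unary_ops` decides). `Buffer` and its methods are hypothetical stand-ins, not the crate's actual types.

```rust
use std::sync::Arc;

/// Hypothetical stand-in for a tensor whose data may be shared between handles.
#[derive(Clone, Debug)]
struct Buffer {
    data: Arc<Vec<f32>>,
}

impl Buffer {
    fn new(values: Vec<f32>) -> Self {
        Self {
            data: Arc::new(values),
        }
    }

    /// True when no other handle points at the same allocation.
    fn can_mut(&self) -> bool {
        Arc::strong_count(&self.data) == 1
    }

    /// Runs `f_own` in place when the buffer is uniquely owned,
    /// otherwise runs `f_ref` to build a fresh buffer.
    fn unary_ops<FOwn, FRef>(mut self, f_own: FOwn, f_ref: FRef) -> Buffer
    where
        FOwn: FnOnce(&mut Vec<f32>),
        FRef: FnOnce(&[f32]) -> Vec<f32>,
    {
        if self.can_mut() {
            let values = Arc::get_mut(&mut self.data).expect("buffer is uniquely owned");
            f_own(values);
            self
        } else {
            Buffer::new(f_ref(self.data.as_slice()))
        }
    }
}

fn main() {
    // Uniquely owned: the in-place closure runs and the allocation is reused.
    let unique = Buffer::new(vec![1.0, 2.0]);
    let doubled = unique.unary_ops(
        |v| v.iter_mut().for_each(|x| *x *= 2.0),
        |v| v.iter().map(|x| x * 2.0).collect(),
    );
    assert_eq!(*doubled.data, vec![2.0, 4.0]);

    // Shared: `other` still points at the data, so a new buffer is built
    // and the original values stay untouched.
    let shared = Buffer::new(vec![1.0, 2.0]);
    let other = shared.clone();
    let result = shared.unary_ops(
        |v| v.iter_mut().for_each(|x| *x *= 2.0),
        |v| v.iter().map(|x| x * 2.0).collect(),
    );
    assert_eq!(*other.data, vec![1.0, 2.0]);
    assert_eq!(*result.data, vec![2.0, 4.0]);
}
```

The design trade-off is that the in-place path avoids an allocation, while the allocating path keeps other handles to the same data valid.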


@@ -0,0 +1,5 @@
use burn_backend::ops::TransactionOps;
use crate::{LibTorch, TchElement};
impl<E: TchElement> TransactionOps<Self> for LibTorch<E> {}


@@ -0,0 +1,507 @@
use crate::{LibTorchDevice, TchElement};
use burn_backend::{DType, FloatDType, IntDType, Shape, TensorData, TensorMetadata};
use libc::c_void;
use std::sync::Arc;
/// A reference to a tensor storage.
///
/// `Sync` and `Send` are implemented manually with `unsafe`, so `Rc` would not be safe here
/// even where it would otherwise suffice.
#[allow(clippy::arc_with_non_send_sync)]
pub type StorageRef = Arc<*mut c_void>;
/// The storage backing a tensor: either a view into a larger buffer or a buffer used in full.
#[derive(PartialEq, Debug, Clone)]
pub enum Storage {
/// When a tensor is a partial view of another tensor.
View {
/// Storage reference for the whole buffer.
buffer_ref: StorageRef,
/// Storage reference for the partial buffer.
view_ref: StorageRef,
},
/// When a tensor uses all of its buffer.
Owned {
/// Storage reference for the whole buffer.
buffer_ref: StorageRef,
},
}
impl Storage {
/// Check if the storage can be used inplace.
pub fn can_mut(&self) -> bool {
match self {
Storage::View {
buffer_ref: start_ref,
view_ref,
} => Arc::strong_count(start_ref) == 1 && Arc::strong_count(view_ref) == 1,
Storage::Owned {
buffer_ref: start_ref,
} => Arc::strong_count(start_ref) == 1,
}
}
/// Get the whole buffer reference.
pub fn buffer_ref(&self) -> &StorageRef {
match self {
Storage::View {
buffer_ref: start_ref,
view_ref: _,
} => start_ref,
Storage::Owned {
buffer_ref: start_ref,
} => start_ref,
}
}
}
/// A tensor using the tch backend.
#[derive(Debug, PartialEq)]
pub struct TchTensor {
/// Handle to the tensor. Call methods on this field.
pub tensor: tch::Tensor,
/// The tensor's storage
pub storage: Storage,
}
impl TensorMetadata for TchTensor {
fn dtype(&self) -> DType {
match self.tensor.kind() {
tch::Kind::Uint8 => DType::U8,
tch::Kind::Int8 => DType::I8,
tch::Kind::Int16 => DType::I16,
tch::Kind::Int => DType::I32,
tch::Kind::Int64 => DType::I64,
tch::Kind::Half => DType::F16,
tch::Kind::Float => DType::F32,
tch::Kind::Double => DType::F64,
tch::Kind::Bool => DType::Bool,
tch::Kind::BFloat16 => DType::BF16,
// Complex and quantization types are not valid/implemented.
_ => unimplemented!(),
}
}
fn shape(&self) -> Shape {
Shape::from(self.tensor.size())
}
fn rank(&self) -> usize {
self.tensor.dim()
}
}
impl burn_backend::QTensorPrimitive for TchTensor {
fn scheme(&self) -> &burn_backend::quantization::QuantScheme {
unimplemented!("Quantization is not supported")
}
}
impl core::fmt::Display for TchTensor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.tensor)
}
}
pub(crate) trait IntoKind {
fn try_into_kind(self) -> Result<tch::Kind, tch::TchError>;
fn into_kind(self) -> tch::Kind
where
Self: Sized,
{
self.try_into_kind().unwrap()
}
}
impl IntoKind for IntDType {
fn try_into_kind(self) -> Result<tch::Kind, tch::TchError> {
let dtype: DType = self.into();
dtype.try_into_kind()
}
}
impl IntoKind for FloatDType {
fn try_into_kind(self) -> Result<tch::Kind, tch::TchError> {
let dtype: DType = self.into();
dtype.try_into_kind()
}
}
impl IntoKind for DType {
fn try_into_kind(self) -> Result<tch::Kind, tch::TchError> {
match self {
DType::F64 => Ok(tch::Kind::Double),
DType::F32 => Ok(tch::Kind::Float),
DType::Flex32 => Ok(tch::Kind::Float),
DType::F16 => Ok(tch::Kind::Half),
DType::BF16 => Ok(tch::Kind::BFloat16),
DType::I64 => Ok(tch::Kind::Int64),
DType::I32 => Ok(tch::Kind::Int),
DType::I16 => Ok(tch::Kind::Int16),
DType::I8 => Ok(tch::Kind::Int8),
DType::U8 => Ok(tch::Kind::Uint8),
DType::Bool => Ok(tch::Kind::Bool),
other => Err(tch::TchError::Kind(format!("Unsupported dtype {other:?}"))),
}
}
}
impl TchTensor {
/// Create a new tensor.
///
/// Note that if the tensor was created from an operation that may reuse the same tensor
/// storage as the parent, you should use [from_existing](TchTensor::from_existing)
/// instead.
pub fn new(tensor: tch::Tensor) -> Self {
#[allow(clippy::arc_with_non_send_sync)]
let storage = Storage::Owned {
buffer_ref: Arc::new(tensor.data_ptr()),
};
Self { tensor, storage }
}
/// Create a tensor that was created from an operation executed on a parent tensor.
///
/// If the child tensor shares the same storage as its parent, the parent's storage reference is
/// cloned, effectively tracking how many tensors point to the same memory space.
pub fn from_existing(tensor: tch::Tensor, storage_parent: Storage) -> Self {
let storage_child = tensor.data_ptr();
let mut is_a_new_tensor = true;
match &storage_parent {
Storage::View {
buffer_ref: start_ref,
view_ref,
} => {
if storage_child == *start_ref.as_ref() || storage_child == *view_ref.as_ref() {
is_a_new_tensor = false;
}
}
Storage::Owned {
buffer_ref: start_ref,
} => {
if storage_child == *start_ref.as_ref() {
is_a_new_tensor = false;
}
}
};
let storage = match is_a_new_tensor {
true => Storage::Owned {
#[allow(clippy::arc_with_non_send_sync)]
buffer_ref: Arc::new(storage_child),
},
false => storage_parent.clone(),
};
Self { tensor, storage }
}
/// Create a tensor that uses only part of its parent tensor, such as the result of a slice or narrow.
pub fn partial(tensor: tch::Tensor, storage_parent: Storage) -> Self {
let storage = Storage::View {
buffer_ref: storage_parent.buffer_ref().clone(),
#[allow(clippy::arc_with_non_send_sync)]
view_ref: Arc::new(tensor.data_ptr()),
};
Self { tensor, storage }
}
}
// This is safe since we don't use autodiff from LibTorch.
// Also, atomic reference counting is used to know if the tensor's data can be reused.
// If there are multiple references to the same tensor, it becomes read-only.
unsafe impl Send for TchTensor {}
unsafe impl Sync for TchTensor {}
impl TchTensor {
/// Checks if the tensor can be mutated in-place.
///
/// Returns `true` if the tensor's stride does not contain zero (no broadcasting)
/// and the storage can be mutated.
pub fn can_mut(&self) -> bool {
let stride_contains_zero = self.tensor.stride().contains(&0);
!stride_contains_zero && self.storage.can_mut()
}
/// Executes an operation on a tensor if the data can be reused.
pub fn mut_ops<F: Fn(&mut tch::Tensor) -> tch::Tensor>(
&mut self,
func: F,
) -> Option<TchTensor> {
if !self.can_mut() {
return None;
}
let data = self.storage.clone();
Some(TchTensor::from_existing(func(&mut self.tensor), data))
}
/// Executes a unary operation, reusing the tensor data if possible.
pub fn unary_ops<FOwn, FRef>(self, fown: FOwn, fref: FRef) -> TchTensor
where
FOwn: Fn(tch::Tensor) -> tch::Tensor,
FRef: Fn(&tch::Tensor) -> tch::Tensor,
{
if !self.can_mut() {
return TchTensor::from_existing(fref(&self.tensor), self.storage);
}
TchTensor::from_existing(fown(self.tensor), self.storage)
}
/// Executes a binary operation, reusing the tensor data if possible.
pub fn binary_ops_tensor<FLMut, FRMut, FRef>(
mut lhs: Self,
mut rhs: Self,
flmut: FLMut,
frmut: FRMut,
fref: FRef,
) -> TchTensor
where
FLMut: Fn(&mut tch::Tensor, &tch::Tensor) -> tch::Tensor,
FRMut: Fn(&tch::Tensor, &mut tch::Tensor) -> tch::Tensor,
FRef: Fn(&tch::Tensor, &tch::Tensor) -> tch::Tensor,
{
let lhs_shape = lhs.shape();
let rhs_shape = rhs.shape();
// Both lhs and rhs are expected to have the same rank
let d_out = lhs_shape.num_dims();
let mut out_shape = Shape::from(vec![1usize; d_out]);
for i in 0..d_out {
out_shape[i] = usize::max(lhs_shape[i], rhs_shape[i]);
}
let num_elements_out = out_shape.num_elements();
// Attempt to mutate lhs tensor
if lhs_shape.num_elements() == num_elements_out
&& let Some(output) = lhs.mut_ops(|lhs| flmut(lhs, &rhs.tensor))
{
return output;
}
// Attempt to mutate rhs tensor
if rhs_shape.num_elements() == num_elements_out
&& let Some(output) = rhs.mut_ops(|rhs| frmut(&lhs.tensor, rhs))
{
return output;
}
let storage = lhs.storage;
let tensor = fref(&lhs.tensor, &rhs.tensor);
TchTensor::from_existing(tensor, storage)
}
}
impl Clone for TchTensor {
fn clone(&self) -> Self {
Self {
tensor: self.tensor.shallow_clone(),
storage: self.storage.clone(),
}
}
}
/// A shape that can be used by LibTorch.
#[derive(Debug)]
pub struct TchShape {
/// The shape's dimensions.
pub dims: Vec<i64>,
}
impl From<Shape> for TchShape {
fn from(shape: Shape) -> Self {
TchShape {
dims: shape.iter().map(|d| *d as i64).collect(),
}
}
}
impl From<&[usize]> for TchShape {
fn from(shape: &[usize]) -> Self {
TchShape {
dims: shape.iter().map(|d| *d as i64).collect(),
}
}
}
impl TchTensor {
/// Creates a new tensor from data and a device.
///
/// # Arguments
///
/// * `data` - The tensor's data.
/// * `device` - The device on which the tensor will be allocated.
///
/// # Returns
///
/// A new tensor.
pub fn from_data<E: TchElement>(data: TensorData, device: tch::Device) -> Self {
let shape_tch = TchShape::from(data.shape.as_slice());
let tensor =
tch::Tensor::from_data_size(&data.bytes, &shape_tch.dims, E::kind()).to(device);
Self::new(tensor)
}
}
impl TchTensor {
/// Creates an empty tensor from a shape and a device.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// A new empty tensor.
pub fn empty<E: TchElement>(shape: Shape, device: LibTorchDevice) -> Self {
let shape_tch = TchShape::from(shape);
let tensor = tch::Tensor::empty(shape_tch.dims, (E::kind(), device.into()));
Self::new(tensor)
}
}
// Adapted from `tch` to use patched `T::kind()` instead of `T::KIND` which is incorrect for bf16.
// TODO: remove when fixed in `tch` release (https://github.com/LaurentMazare/tch-rs/pull/996).
impl<T: TchElement + Copy> TryFrom<&TchTensor> for Vec<T> {
type Error = tch::TchError;
fn try_from(tensor: &TchTensor) -> Result<Self, Self::Error> {
let tensor = &tensor.tensor;
let size = tensor.size();
if size.len() != 1 {
return Err(tch::TchError::Convert(format!(
"Attempting to convert a Tensor with {} dimensions to flat vector",
size.len()
)));
}
let numel = size[0] as usize;
let mut vec = vec![T::ZERO; numel];
// Adapted to use the patched `T::kind()` instead of `T::KIND`.
// TODO: revert to `tensor.f_to_kind(T::KIND)?.f_copy_data(&mut vec, numel)?` once fixed in `tch`.
f_copy_data(&mut tensor.f_to_kind(T::kind())?, &mut vec, numel)?;
Ok(vec)
}
}
unsafe fn ptr_to_string(ptr: *mut libc::c_char) -> Option<String> {
if !ptr.is_null() {
unsafe {
let str = std::ffi::CStr::from_ptr(ptr).to_string_lossy().into_owned();
libc::free(ptr as *mut libc::c_void);
Some(str)
}
} else {
None
}
}
/// Copies `numel` elements from `tensor` to `dst`.
fn f_copy_data<T: TchElement>(
tensor: &mut tch::Tensor,
dst: &mut [T],
numel: usize,
) -> Result<(), tch::TchError> {
if T::kind() != tensor.f_kind()? {
return Err(tch::TchError::Kind(format!(
"incoherent elt kind, {:?} != {:?}",
tensor.f_kind(),
T::kind()
)));
}
if dst.len() < numel {
return Err(tch::TchError::Shape(format!("slice len < {numel}")));
}
unsafe {
torch_sys::at_copy_data(
tensor.as_mut_ptr(),
dst.as_mut_ptr() as *const c_void,
numel,
T::kind().elt_size_in_bytes(),
);
match ptr_to_string(torch_sys::get_and_reset_last_err()) {
None => Ok(()),
Some(c_error) => Err(tch::TchError::Torch(c_error)),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use burn_backend::ops::FloatTensorOps;
use burn_backend::{Backend, quantization::QuantScheme, read_sync};
type B = crate::LibTorch<f32>;
#[test]
fn should_have_bf16_kind() {
let data = TensorData::from([4.0, 4.0]);
let tensor_1: TchTensor = B::float_from_data(data, &Default::default());
let tensor_2 = B::float_cast(tensor_1, DType::BF16.into());
assert_eq!(tensor_2.tensor.kind(), tch::Kind::BFloat16);
let out = read_sync(B::float_into_data(tensor_2)).unwrap();
out.assert_eq(&TensorData::from([4.0, 4.0]), false);
}
#[test]
fn should_support_dtypes() {
let device = Default::default();
assert!(B::supports_dtype(&device, DType::F64));
assert!(B::supports_dtype(&device, DType::F32));
assert!(B::supports_dtype(&device, DType::Flex32));
assert!(B::supports_dtype(&device, DType::F16));
assert!(B::supports_dtype(&device, DType::BF16));
assert!(B::supports_dtype(&device, DType::I64));
assert!(B::supports_dtype(&device, DType::I32));
assert!(B::supports_dtype(&device, DType::I16));
assert!(B::supports_dtype(&device, DType::I8));
assert!(B::supports_dtype(&device, DType::U8));
assert!(B::supports_dtype(&device, DType::Bool));
assert!(!B::supports_dtype(&device, DType::U64));
assert!(!B::supports_dtype(&device, DType::U32));
assert!(!B::supports_dtype(&device, DType::U16));
assert!(!B::supports_dtype(
&device,
DType::QFloat(QuantScheme::default())
));
}
#[test]
fn should_support_from_bf16() {
let data = TensorData::from([[1.0], [1.]]).convert_dtype(DType::BF16);
let tensor_1: TchTensor = B::float_from_data(data, &Default::default());
let data = TensorData::from([[2.0], [2.]]).convert_dtype(DType::BF16);
let tensor_2 = B::float_from_data(data, &Default::default());
let tensor_3 = B::float_add(tensor_1, tensor_2);
assert_eq!(tensor_3.tensor.kind(), tch::Kind::BFloat16);
let out = read_sync(B::float_into_data(tensor_3)).unwrap();
out.assert_eq(&TensorData::from([[3.0], [3.0]]), false);
}
}
unsafe extern "C" {
/// Dummy function to get CUDA to link properly.
pub fn dummy_cuda_dependency();
}
#[used]
static INIT_ARRAY: [unsafe extern "C" fn(); 1] = [dummy_cuda_dependency];