feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
This commit is contained in:
38
crates/stable-diffusion-burn/burn-crates/burn-tch/Cargo.toml
Normal file
38
crates/stable-diffusion-burn/burn-crates/burn-tch/Cargo.toml
Normal file
@@ -0,0 +1,38 @@
|
||||
[package]
|
||||
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||
categories = ["science"]
|
||||
description = "LibTorch backend for the Burn framework using the tch bindings."
|
||||
documentation = "https://docs.rs/burn-tch"
|
||||
edition.workspace = true
|
||||
keywords = ["deep-learning", "machine-learning", "data"]
|
||||
license.workspace = true
|
||||
name = "burn-tch"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-tch"
|
||||
version.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[features]
|
||||
default = ["std"]
|
||||
std = ["burn-backend/std"]
|
||||
doc = ["tch/doc-only"]
|
||||
tracing = [
|
||||
"burn-backend/tracing",
|
||||
]
|
||||
|
||||
[dependencies]
|
||||
burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false }
|
||||
|
||||
libc = { workspace = true }
|
||||
log = { workspace = true }
|
||||
tch = { workspace = true, features = ["download-libtorch"] }
|
||||
torch-sys = { workspace = true } # for build script lib dir detection
|
||||
|
||||
[build-dependencies]
|
||||
cc = "1.2.56"
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["doc"]
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
1
crates/stable-diffusion-burn/burn-crates/burn-tch/LICENSE-APACHE
Symbolic link
1
crates/stable-diffusion-burn/burn-crates/burn-tch/LICENSE-APACHE
Symbolic link
@@ -0,0 +1 @@
|
||||
../../LICENSE-APACHE
|
||||
1
crates/stable-diffusion-burn/burn-crates/burn-tch/LICENSE-MIT
Symbolic link
1
crates/stable-diffusion-burn/burn-crates/burn-tch/LICENSE-MIT
Symbolic link
@@ -0,0 +1 @@
|
||||
../../LICENSE-MIT
|
||||
246
crates/stable-diffusion-burn/burn-crates/burn-tch/README.md
Normal file
246
crates/stable-diffusion-burn/burn-crates/burn-tch/README.md
Normal file
@@ -0,0 +1,246 @@
|
||||
# Burn Torch Backend
|
||||
|
||||
[Burn](https://github.com/tracel-ai/burn) Torch backend
|
||||
|
||||
[](https://crates.io/crates/burn-tch)
|
||||
[](https://github.com/tracel-ai/burn-tch/blob/master/README.md)
|
||||
|
||||
This crate provides a Torch backend for [Burn](https://github.com/tracel-ai/burn) utilizing the
|
||||
[`tch-rs`](https://github.com/LaurentMazare/tch-rs) crate, which offers a Rust interface to the
|
||||
[PyTorch](https://pytorch.org/) C++ API.
|
||||
|
||||
The backend supports CPU (multithreaded), [CUDA](https://pytorch.org/docs/stable/notes/cuda.html)
|
||||
(multiple GPUs), and [MPS](https://pytorch.org/docs/stable/notes/mps.html) devices (MacOS).
|
||||
|
||||
## Installation
|
||||
|
||||
[`tch-rs`](https://github.com/LaurentMazare/tch-rs) requires the C++ PyTorch library (LibTorch) to
|
||||
be available on your system.
|
||||
|
||||
By default, the CPU distribution is installed for LibTorch v2.9.0 as required by `tch-rs`.
|
||||
|
||||
<details>
|
||||
<summary><strong>CUDA</strong></summary>
|
||||
|
||||
To install the latest compatible CUDA distribution, set the `TORCH_CUDA_VERSION` environment
|
||||
variable before the `tch-rs` dependency is retrieved with `cargo`.
|
||||
|
||||
```shell
|
||||
export TORCH_CUDA_VERSION=cu128
|
||||
```
|
||||
|
||||
On Windows:
|
||||
|
||||
```powershell
|
||||
$Env:TORCH_CUDA_VERSION = "cu128"
|
||||
```
|
||||
|
||||
> Note: `tch` doesn't expose the downloaded libtorch directory on Windows when using the automatic
|
||||
> download feature, so the `torch_cuda.dll` cannot be detected properly during build. In this case,
|
||||
> you can set the `LIBTORCH` environment variable to point to the `libtorch/` folder in `torch-sys`
|
||||
> `OUT_DIR` (or move the downloaded lib to a different folder and point to it).
|
||||
|
||||
For example, running the validation sample for the first time could be done with the following
|
||||
commands:
|
||||
|
||||
```shell
|
||||
export TORCH_CUDA_VERSION=cu128
|
||||
cargo run --bin cuda --release
|
||||
```
|
||||
|
||||
**Important:** make sure your driver version is compatible with the selected CUDA version. A CUDA
|
||||
Toolkit installation is not required since LibTorch ships with the appropriate CUDA runtimes. Having
|
||||
the latest driver version is recommended, but you can always take a look at the
|
||||
[toolkit driver version table](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#id4)
|
||||
or
|
||||
[minimum required driver version](https://docs.nvidia.com/deploy/cuda-compatibility/index.html#minor-version-compatibility)
|
||||
(limited feature-set, might not work with all operations).
|
||||
|
||||
</details><br>
|
||||
|
||||
Once your installation is complete, you should be able to build/run your project. You can also
|
||||
validate your installation by running the appropriate `cpu`, `cuda` or `mps` sample as below.
|
||||
|
||||
```shell
|
||||
cargo run --bin cpu --release
|
||||
cargo run --bin cuda --release
|
||||
cargo run --bin mps --release
|
||||
```
|
||||
|
||||
_Note: no MPS distribution is available for automatic download at this time, please check out the
|
||||
[manual instructions](#metal-mps)._
|
||||
|
||||
### Manual Download
|
||||
|
||||
To install `tch-rs` with a different LibTorch distribution, you will have to manually download the
|
||||
desired LibTorch distribution. The instructions are detailed in the sections below for each
|
||||
platform.
|
||||
|
||||
| Compute Platform | CPU | GPU | Linux | MacOS | Windows | Android | iOS | WASM |
|
||||
| :------------------------ | :----------------------------: | :-: | :---: | :---: | :-----: | :-----: | :-: | :--: |
|
||||
| [CPU](#cpu) | Yes | No | Yes | Yes | Yes | Yes | Yes | No |
|
||||
| [CUDA](#cuda) | Yes <sup>[[1]](#cpu-sup)</sup> | Yes | Yes | No | Yes | No | No | No |
|
||||
| [Metal (MPS)](#metal-mps) | No | Yes | No | Yes | No | No | No | No |
|
||||
| Vulkan | Yes | Yes | Yes | Yes | Yes | Yes | No | No |
|
||||
|
||||
<sup><a id="cpu-sup">[1]</a> The LibTorch CUDA distribution also comes with CPU support.</sup>
|
||||
|
||||
#### CPU
|
||||
|
||||
<details open>
|
||||
<summary><strong>🐧 Linux</strong></summary>
|
||||
|
||||
First, download the LibTorch CPU distribution.
|
||||
|
||||
```shell
|
||||
wget -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-2.9.0%2Bcpu.zip
|
||||
unzip libtorch.zip
|
||||
```
|
||||
|
||||
Then, point to that installation using the `LIBTORCH` and `LD_LIBRARY_PATH` environment variables
|
||||
before building `burn-tch` or a crate which depends on it.
|
||||
|
||||
```shell
|
||||
export LIBTORCH=/absolute/path/to/libtorch/
|
||||
export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH
|
||||
```
|
||||
|
||||
</details><br>
|
||||
|
||||
<details>
|
||||
<summary><strong>🍎 Mac</strong></summary>
|
||||
|
||||
First, download the LibTorch CPU distribution.
|
||||
|
||||
```shell
|
||||
wget -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-2.9.0.zip
|
||||
unzip libtorch.zip
|
||||
```
|
||||
|
||||
Then, point to that installation using the `LIBTORCH` and `DYLD_LIBRARY_PATH` environment variables
|
||||
before building `burn-tch` or a crate which depends on it.
|
||||
|
||||
```shell
|
||||
export LIBTORCH=/absolute/path/to/libtorch/
|
||||
export DYLD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$DYLD_LIBRARY_PATH
|
||||
```
|
||||
|
||||
</details><br>
|
||||
|
||||
<details>
|
||||
<summary><strong>🪟 Windows</strong></summary>
|
||||
|
||||
First, download the LibTorch CPU distribution.
|
||||
|
||||
```powershell
|
||||
wget https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-2.9.0%2Bcpu.zip -OutFile libtorch.zip
|
||||
Expand-Archive libtorch.zip
|
||||
```
|
||||
|
||||
Then, set the `LIBTORCH` environment variable and append the library to your path as with the
|
||||
PowerShell commands below before building `burn-tch` or a crate which depends on it.
|
||||
|
||||
```powershell
|
||||
$Env:LIBTORCH = "/absolute/path/to/libtorch/"
|
||||
$Env:Path += ";/absolute/path/to/libtorch/"
|
||||
```
|
||||
|
||||
</details><br>
|
||||
|
||||
#### CUDA
|
||||
|
||||
LibTorch 2.9.0 currently includes binary distributions with CUDA 12.6, 12.8 or 13.0 runtimes. The
|
||||
manual installation instructions are detailed below for CUDA 12.6, but can be applied to the other
|
||||
CUDA versions by replacing `cu126` with the corresponding version string (e.g., `cu130`).
|
||||
|
||||
<details open>
|
||||
<summary><strong>🐧 Linux</strong></summary>
|
||||
|
||||
First, download the LibTorch CUDA 12.6 distribution.
|
||||
|
||||
```shell
|
||||
wget -O libtorch.zip https://download.pytorch.org/libtorch/cu126/libtorch-shared-with-deps-2.9.0%2Bcu126.zip
|
||||
unzip libtorch.zip
|
||||
```
|
||||
|
||||
Then, point to that installation using the `LIBTORCH` and `LD_LIBRARY_PATH` environment variables
|
||||
before building `burn-tch` or a crate which depends on it.
|
||||
|
||||
```shell
|
||||
export LIBTORCH=/absolute/path/to/libtorch/
|
||||
export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH
|
||||
```
|
||||
|
||||
**Note:** make sure your CUDA installation is in your `PATH` and `LD_LIBRARY_PATH`.
|
||||
|
||||
</details><br>
|
||||
|
||||
<details>
|
||||
<summary><strong>🪟 Windows</strong></summary>
|
||||
|
||||
First, download the LibTorch CUDA 12.6 distribution.
|
||||
|
||||
```powershell
|
||||
wget https://download.pytorch.org/libtorch/cu126/libtorch-win-shared-with-deps-2.9.0%2Bcu126.zip -OutFile libtorch.zip
|
||||
Expand-Archive libtorch.zip
|
||||
```
|
||||
|
||||
Then, set the `LIBTORCH` environment variable and append the library to your path as with the
|
||||
PowerShell commands below before building `burn-tch` or a crate which depends on it.
|
||||
|
||||
```powershell
|
||||
$Env:LIBTORCH = "/absolute/path/to/libtorch/"
|
||||
$Env:Path += ";/absolute/path/to/libtorch/"
|
||||
```
|
||||
|
||||
</details><br>
|
||||
|
||||
#### Metal (MPS)
|
||||
|
||||
There is no official LibTorch distribution with MPS support at this time, so the easiest alternative
|
||||
is to use a PyTorch installation. This requires a Python installation.
|
||||
|
||||
_Note: MPS acceleration is available on MacOS 12.3+._
|
||||
|
||||
```shell
|
||||
pip install torch==2.9.0 numpy==1.26.4 setuptools
|
||||
export LIBTORCH_USE_PYTORCH=1
|
||||
export DYLD_LIBRARY_PATH=/path/to/pytorch/lib:$DYLD_LIBRARY_PATH
|
||||
```
|
||||
|
||||
**Note:** if `venv` is used, it should be activated during coding and building, or the compiler may
|
||||
not work properly.
|
||||
|
||||
## Example Usage
|
||||
|
||||
For a simple example, check out any of the test programs in [`src/bin/`](./src/bin/). Each program
|
||||
sets the device to use and performs a simple element-wise addition.
|
||||
|
||||
For a more complete example using the `tch` backend, take a loot at the
|
||||
[Burn mnist example](https://github.com/tracel-ai/burn/tree/main/examples/mnist).
|
||||
|
||||
## Too many environment variables?
|
||||
|
||||
Try `.cargo/config.toml` ([cargo book](https://doc.rust-lang.org/cargo/reference/config.html#env)).
|
||||
|
||||
Instead of setting the environments in your shell, you can manually add them to your
|
||||
`.cargo/config.toml`:
|
||||
|
||||
```toml
|
||||
[env]
|
||||
LD_LIBRARY_PATH = "/absolute/path/to/libtorch/lib"
|
||||
LIBTORCH = "/absolute/path/to/libtorch/libtorch"
|
||||
```
|
||||
|
||||
Or use bash commands below:
|
||||
|
||||
```bash
|
||||
mkdir .cargo
|
||||
cat <<EOF > .cargo/config.toml
|
||||
[env]
|
||||
LD_LIBRARY_PATH = "/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH"
|
||||
LIBTORCH = "/absolute/path/to/libtorch/libtorch"
|
||||
EOF
|
||||
```
|
||||
|
||||
This will automatically include the old `LD_LIBRARY_PATH` value in the new one.
|
||||
243
crates/stable-diffusion-burn/burn-crates/burn-tch/build.rs
Normal file
243
crates/stable-diffusion-burn/burn-crates/burn-tch/build.rs
Normal file
@@ -0,0 +1,243 @@
|
||||
// The LIBTORCH environment variable can be used to specify the directory
|
||||
// where libtorch has been installed.
|
||||
// When not specified this script downloads the cpu version for libtorch
|
||||
// and extracts it in OUT_DIR.
|
||||
//
|
||||
// On Linux, the TORCH_CUDA_VERSION environment variable can be used,
|
||||
// like 9.0, 90, or cu90 to specify the version of CUDA to use for libtorch.
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::{env, fs};
|
||||
|
||||
const PYTHON_PRINT_PYTORCH_DETAILS: &str = r"
|
||||
import torch
|
||||
from torch.utils import cpp_extension
|
||||
print('LIBTORCH_VERSION:', torch.__version__.split('+')[0])
|
||||
print('LIBTORCH_CXX11:', torch._C._GLIBCXX_USE_CXX11_ABI)
|
||||
for include_path in cpp_extension.include_paths():
|
||||
print('LIBTORCH_INCLUDE:', include_path)
|
||||
for library_path in cpp_extension.library_paths():
|
||||
print('LIBTORCH_LIB:', library_path)
|
||||
";
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum Os {
|
||||
Linux,
|
||||
Macos,
|
||||
Windows,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct SystemInfo {
|
||||
os: Os,
|
||||
cxx11_abi: String,
|
||||
libtorch_include_dirs: Vec<PathBuf>,
|
||||
libtorch_lib_dir: PathBuf,
|
||||
}
|
||||
|
||||
fn env_var_rerun(name: &str) -> Result<String, env::VarError> {
|
||||
println!("cargo:rerun-if-env-changed={name}");
|
||||
env::var(name)
|
||||
}
|
||||
|
||||
impl SystemInfo {
|
||||
fn new() -> Option<Self> {
|
||||
let os = match env::var("CARGO_CFG_TARGET_OS")
|
||||
.expect("Unable to get TARGET_OS")
|
||||
.as_str()
|
||||
{
|
||||
"linux" => Os::Linux,
|
||||
"windows" => Os::Windows,
|
||||
"macos" => Os::Macos,
|
||||
os => panic!("unsupported TARGET_OS '{os}'"),
|
||||
};
|
||||
// Locate the currently active Python binary, similar to:
|
||||
// https://github.com/PyO3/maturin/blob/243b8ec91d07113f97a6fe74d9b2dcb88086e0eb/src/target.rs#L547
|
||||
let python_interpreter = match os {
|
||||
Os::Windows => PathBuf::from("python.exe"),
|
||||
Os::Linux | Os::Macos => {
|
||||
if env::var_os("VIRTUAL_ENV").is_some() {
|
||||
PathBuf::from("python")
|
||||
} else {
|
||||
PathBuf::from("python3")
|
||||
}
|
||||
}
|
||||
};
|
||||
let mut libtorch_include_dirs = vec![];
|
||||
let mut libtorch_lib_dir = None;
|
||||
let cxx11_abi = if env_var_rerun("LIBTORCH_USE_PYTORCH").is_ok() {
|
||||
let output = std::process::Command::new(&python_interpreter)
|
||||
.arg("-c")
|
||||
.arg(PYTHON_PRINT_PYTORCH_DETAILS)
|
||||
.output()
|
||||
.expect("error running python interpreter");
|
||||
let mut cxx11_abi = None;
|
||||
for line in String::from_utf8_lossy(&output.stdout).lines() {
|
||||
match line.strip_prefix("LIBTORCH_CXX11: ") {
|
||||
Some("True") => cxx11_abi = Some("1".to_owned()),
|
||||
Some("False") => cxx11_abi = Some("0".to_owned()),
|
||||
_ => {}
|
||||
}
|
||||
if let Some(path) = line.strip_prefix("LIBTORCH_INCLUDE: ") {
|
||||
libtorch_include_dirs.push(PathBuf::from(path))
|
||||
}
|
||||
if let Some(path) = line.strip_prefix("LIBTORCH_LIB: ") {
|
||||
libtorch_lib_dir = Some(PathBuf::from(path))
|
||||
}
|
||||
}
|
||||
match cxx11_abi {
|
||||
Some(cxx11_abi) => cxx11_abi,
|
||||
None => panic!("no cxx11 abi returned by python {output:?}"),
|
||||
}
|
||||
} else {
|
||||
let libtorch = Self::prepare_libtorch_dir(os)?;
|
||||
let includes = env_var_rerun("LIBTORCH_INCLUDE")
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|_| libtorch.clone());
|
||||
let lib = env_var_rerun("LIBTORCH_LIB")
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|_| libtorch.clone());
|
||||
libtorch_include_dirs.push(includes.join("include"));
|
||||
libtorch_include_dirs.push(includes.join("include/torch/csrc/api/include"));
|
||||
if lib.ends_with("lib") {
|
||||
// DEP_TCH_LIBTORCH_LIB might already point to /lib
|
||||
libtorch_lib_dir = Some(lib);
|
||||
} else {
|
||||
libtorch_lib_dir = Some(lib.join("lib"));
|
||||
}
|
||||
env_var_rerun("LIBTORCH_CXX11_ABI").unwrap_or_else(|_| "1".to_owned())
|
||||
};
|
||||
let libtorch_lib_dir = libtorch_lib_dir?;
|
||||
Some(Self {
|
||||
os,
|
||||
cxx11_abi,
|
||||
libtorch_include_dirs,
|
||||
libtorch_lib_dir,
|
||||
})
|
||||
}
|
||||
|
||||
fn check_system_location(os: Os) -> Option<PathBuf> {
|
||||
match os {
|
||||
Os::Linux => Path::new("/usr/lib/libtorch.so")
|
||||
.exists()
|
||||
.then(|| PathBuf::from("/usr")),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn prepare_libtorch_dir(os: Os) -> Option<PathBuf> {
|
||||
if let Ok(libtorch) = env_var_rerun("DEP_TCH_LIBTORCH_LIB") {
|
||||
Some(PathBuf::from(libtorch))
|
||||
} else if let Ok(libtorch) = env_var_rerun("LIBTORCH") {
|
||||
Some(PathBuf::from(libtorch))
|
||||
} else if let Some(pathbuf) = Self::check_system_location(os) {
|
||||
Some(pathbuf)
|
||||
} else {
|
||||
check_out_dir()
|
||||
}
|
||||
}
|
||||
|
||||
fn make(&self, use_cuda: bool, use_hip: bool) {
|
||||
let cuda_dependency = if use_cuda || use_hip {
|
||||
"src/cuda_hack/dummy_cuda_dependency.cpp"
|
||||
} else {
|
||||
"src/cuda_hack/fake_cuda_dependency.cpp"
|
||||
};
|
||||
println!("cargo:rerun-if-changed={cuda_dependency}");
|
||||
|
||||
match self.os {
|
||||
Os::Linux | Os::Macos => {
|
||||
cc::Build::new()
|
||||
.cpp(true)
|
||||
.pic(true)
|
||||
.warnings(false)
|
||||
.includes(&self.libtorch_include_dirs)
|
||||
.flag(format!("-Wl,-rpath={}", self.libtorch_lib_dir.display()))
|
||||
.flag("-std=c++17")
|
||||
.flag(format!("-D_GLIBCXX_USE_CXX11_ABI={}", self.cxx11_abi))
|
||||
.files(&[cuda_dependency])
|
||||
.compile("burn-tch");
|
||||
}
|
||||
Os::Windows => {
|
||||
cc::Build::new()
|
||||
.cpp(true)
|
||||
.pic(true)
|
||||
.warnings(false)
|
||||
.includes(&self.libtorch_include_dirs)
|
||||
.flag("/std:c++17")
|
||||
.files(&[cuda_dependency])
|
||||
.compile("burn-tch");
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
fn make_cpu() {
|
||||
let cuda_dependency = "src/cuda_hack/fake_cuda_dependency.cpp";
|
||||
println!("cargo:rerun-if-changed={cuda_dependency}");
|
||||
|
||||
let os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS");
|
||||
|
||||
match os.as_str() {
|
||||
"windows" => {
|
||||
cc::Build::new()
|
||||
.cpp(true)
|
||||
.pic(true)
|
||||
.warnings(false)
|
||||
.flag("/std:c++17")
|
||||
.files(&[cuda_dependency])
|
||||
.compile("burn-tch");
|
||||
}
|
||||
_ => {
|
||||
cc::Build::new()
|
||||
.cpp(true)
|
||||
.pic(true)
|
||||
.warnings(false)
|
||||
.flag("-std=c++17")
|
||||
.files(&[cuda_dependency])
|
||||
.compile("tch");
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn check_out_dir() -> Option<PathBuf> {
|
||||
let out_dir = env_var_rerun("OUT_DIR").ok()?;
|
||||
let libtorch_dir = PathBuf::from(out_dir).join("libtorch");
|
||||
libtorch_dir.exists().then_some(libtorch_dir)
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let system_info = SystemInfo::new();
|
||||
let out_dir = env_var_rerun("OUT_DIR").expect("Failed to get out dir");
|
||||
|
||||
let mut gpu_found = false;
|
||||
let found_dir = system_info.is_some();
|
||||
if let Some(system_info) = &system_info {
|
||||
let si_lib = &system_info.libtorch_lib_dir;
|
||||
let use_cuda =
|
||||
si_lib.join("libtorch_cuda.so").exists() || si_lib.join("torch_cuda.dll").exists();
|
||||
let use_hip =
|
||||
si_lib.join("libtorch_hip.so").exists() || si_lib.join("torch_hip.dll").exists();
|
||||
|
||||
system_info.make(use_cuda, use_hip);
|
||||
gpu_found = use_cuda || use_hip;
|
||||
} else {
|
||||
SystemInfo::make_cpu();
|
||||
}
|
||||
let check_file = PathBuf::from(out_dir).join("tch_gpu_check.rs");
|
||||
if gpu_found {
|
||||
fs::write(check_file, "#[allow(clippy::no_effect)]\n()").unwrap();
|
||||
} else {
|
||||
let message = if !found_dir {
|
||||
r#"Could not find libtorch dir.
|
||||
|
||||
If you are trying to use the automatically downloaded version, the path is not directly available on Windows. Instead, try setting the `LIBTORCH` environment variable for the manual download instructions.
|
||||
|
||||
If the library has already been downloaded in the torch-sys OUT_DIR, you can point the variable to this path (or move the downloaded lib and point to it)."#
|
||||
} else {
|
||||
"No libtorch_cuda or libtorch_hip found. Download the GPU version of libtorch to use a GPU device"
|
||||
};
|
||||
fs::write(check_file, format!("panic!(\"{message}\")")).unwrap();
|
||||
}
|
||||
}
|
||||
175
crates/stable-diffusion-burn/burn-crates/burn-tch/src/backend.rs
Normal file
175
crates/stable-diffusion-burn/burn-crates/burn-tch/src/backend.rs
Normal file
@@ -0,0 +1,175 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::IntoKind;
|
||||
|
||||
use super::TchTensor;
|
||||
use super::element::TchElement;
|
||||
use burn_backend::backend::{Backend, DeviceId, DeviceOps, ExecutionError};
|
||||
use burn_backend::ops::IntTensorOps;
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
/// The device struct when using the `tch` backend.
|
||||
///
|
||||
/// Note that you need to provide the device index when using Cuda.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use burn_tch::LibTorchDevice;
|
||||
///
|
||||
/// let device_gpu_1 = LibTorchDevice::Cuda(0); // First GPU
|
||||
/// let device_gpu_2 = LibTorchDevice::Cuda(1); // Second GPU
|
||||
/// let device_cpu = LibTorchDevice::Cpu; // CPU
|
||||
/// let device_mps = LibTorchDevice::Mps; // Metal Performance Shaders
|
||||
/// let device_vulkan = LibTorchDevice::Vulkan; // Vulkan
|
||||
/// ```
|
||||
#[derive(Default)]
|
||||
pub enum LibTorchDevice {
|
||||
/// CPU device.
|
||||
#[default]
|
||||
Cpu,
|
||||
|
||||
/// Cuda device with the given index. The index is the index of the Cuda device in the list of
|
||||
/// all Cuda devices found on the system.
|
||||
Cuda(usize),
|
||||
|
||||
/// Metal Performance Shaders device.
|
||||
Mps,
|
||||
|
||||
/// Vulkan device.
|
||||
Vulkan,
|
||||
}
|
||||
|
||||
impl From<LibTorchDevice> for tch::Device {
|
||||
#[allow(
|
||||
unreachable_code,
|
||||
reason = "CUDA branch always panics if the library is missing"
|
||||
)]
|
||||
fn from(device: LibTorchDevice) -> Self {
|
||||
match device {
|
||||
LibTorchDevice::Cpu => tch::Device::Cpu,
|
||||
LibTorchDevice::Cuda(_num) => {
|
||||
include!(concat!(env!("OUT_DIR"), "/tch_gpu_check.rs"));
|
||||
tch::Device::Cuda(_num)
|
||||
}
|
||||
LibTorchDevice::Mps => tch::Device::Mps,
|
||||
LibTorchDevice::Vulkan => tch::Device::Vulkan,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<tch::Device> for LibTorchDevice {
|
||||
fn from(device: tch::Device) -> Self {
|
||||
match device {
|
||||
tch::Device::Cpu => LibTorchDevice::Cpu,
|
||||
tch::Device::Cuda(num) => LibTorchDevice::Cuda(num),
|
||||
tch::Device::Mps => LibTorchDevice::Mps,
|
||||
tch::Device::Vulkan => LibTorchDevice::Vulkan,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl burn_backend::Device for LibTorchDevice {
|
||||
fn from_id(device_id: DeviceId) -> Self {
|
||||
match device_id.type_id {
|
||||
0 => Self::Cuda(device_id.index_id as usize),
|
||||
1 => Self::Mps,
|
||||
2 => Self::Cpu,
|
||||
3 => Self::Vulkan,
|
||||
_ => LibTorchDevice::Cpu,
|
||||
}
|
||||
}
|
||||
|
||||
fn to_id(&self) -> DeviceId {
|
||||
match self {
|
||||
LibTorchDevice::Cuda(index) => DeviceId::new(0, *index as u32),
|
||||
LibTorchDevice::Mps => DeviceId::new(1, 0),
|
||||
LibTorchDevice::Cpu => DeviceId::new(2, 0),
|
||||
LibTorchDevice::Vulkan => DeviceId::new(3, 0),
|
||||
}
|
||||
}
|
||||
|
||||
fn device_count(_type_id: u16) -> usize {
|
||||
// TODO: Somehow find the info using the tch API.
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
impl DeviceOps for LibTorchDevice {}
|
||||
|
||||
/// Tensor backend that uses `LibTorch` with the [tch] crate for executing tensor operations.
|
||||
///
|
||||
/// This backend is compatible with a wide range of hardwares ranging from CPUs to GPUs, but
|
||||
/// requires `LibTorch` to be installed correctly. The CPU version can be downloaded
|
||||
/// automatically and the CUDA version as well by setting the `TORCH_CUDA_VERSION` environment
|
||||
/// variable. For more complex configurations, check out the manual installation for
|
||||
/// [burn-tch](https://github.com/tracel-ai/burn/tree/main/crates/burn-tch).
|
||||
///
|
||||
/// Refer to the [tch] crate for more information.
|
||||
#[derive(Clone, Copy, Default, Debug)]
|
||||
pub struct LibTorch<E = f32> {
|
||||
_e: PhantomData<E>,
|
||||
}
|
||||
|
||||
impl<E: TchElement> Backend for LibTorch<E> {
|
||||
type Device = LibTorchDevice;
|
||||
|
||||
type FloatTensorPrimitive = TchTensor;
|
||||
type FloatElem = E;
|
||||
|
||||
type IntTensorPrimitive = TchTensor;
|
||||
type IntElem = i64;
|
||||
type BoolTensorPrimitive = TchTensor;
|
||||
type BoolElem = bool;
|
||||
|
||||
type QuantizedTensorPrimitive = TchTensor;
|
||||
|
||||
fn seed(_device: &Self::Device, seed: u64) {
|
||||
tch::manual_seed(seed as i64);
|
||||
}
|
||||
|
||||
fn ad_enabled(_device: &Self::Device) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn name(device: &Self::Device) -> String {
|
||||
match device {
|
||||
LibTorchDevice::Cpu => "libtorch<cpu>",
|
||||
LibTorchDevice::Cuda(_) => "libtorch<cuda>",
|
||||
LibTorchDevice::Mps => "libtorch<metal>",
|
||||
LibTorchDevice::Vulkan => "libtorch<vulkan>",
|
||||
}
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
|
||||
match device {
|
||||
LibTorchDevice::Cpu => (),
|
||||
LibTorchDevice::Cuda(index) => {
|
||||
tch::Cuda::synchronize(*index as i64);
|
||||
}
|
||||
_ => {
|
||||
// When there is no explicit way to synchronize, we write and read one value to sync
|
||||
burn_backend::read_sync(Self::int_into_data(Self::int_zeros(
|
||||
[1].into(),
|
||||
device,
|
||||
E::dtype().into(),
|
||||
)))
|
||||
.unwrap();
|
||||
}
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn dtype_usage(
|
||||
_device: &Self::Device,
|
||||
dtype: burn_backend::DType,
|
||||
) -> burn_backend::DTypeUsageSet {
|
||||
if dtype.try_into_kind().is_ok() {
|
||||
burn_backend::DTypeUsage::general()
|
||||
} else {
|
||||
burn_backend::DTypeUsageSet::empty()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
use burn_backend::{TensorMetadata, ops::FloatTensorOps};
|
||||
use burn_tch::{LibTorch, LibTorchDevice};
|
||||
|
||||
fn main() {
|
||||
type B = LibTorch<f32>;
|
||||
let device = LibTorchDevice::Cpu;
|
||||
|
||||
// Creation of two tensors, the first with explicit values and the second one with ones, with the same shape as the first
|
||||
let tensor_1 = B::float_from_data([[2f32, 3.], [4., 5.]].into(), &device);
|
||||
let tensor_2 = B::float_ones(tensor_1.shape(), &device, tensor_1.dtype().into());
|
||||
|
||||
// Print the element-wise addition of the two tensors.
|
||||
println!("{}", B::float_add(tensor_1, tensor_2));
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
use burn_backend::{TensorMetadata, ops::FloatTensorOps};
|
||||
use burn_tch::{LibTorch, LibTorchDevice};
|
||||
|
||||
fn main() {
|
||||
assert!(
|
||||
tch::utils::has_cuda(),
|
||||
"Could not detect valid CUDA configuration"
|
||||
);
|
||||
|
||||
type B = LibTorch<f32>;
|
||||
let device = LibTorchDevice::Cuda(0);
|
||||
|
||||
// Creation of two tensors, the first with explicit values and the second one with ones, with the same shape as the first
|
||||
let tensor_1 = B::float_from_data([[2f32, 3.], [4., 5.]].into(), &device);
|
||||
let tensor_2 = B::float_ones(tensor_1.shape(), &device, tensor_1.dtype().into());
|
||||
|
||||
// Print the element-wise addition of the two tensors.
|
||||
println!("{}", B::float_add(tensor_1, tensor_2));
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
use burn_backend::{TensorMetadata, ops::FloatTensorOps};
|
||||
use burn_tch::{LibTorch, LibTorchDevice};
|
||||
|
||||
fn main() {
|
||||
assert!(tch::utils::has_mps(), "Could not detect MPS");
|
||||
|
||||
type B = LibTorch<f32>;
|
||||
let device = LibTorchDevice::Mps;
|
||||
|
||||
// Creation of two tensors, the first with explicit values and the second one with ones, with the same shape as the first
|
||||
let tensor_1 = B::float_from_data([[2f32, 3.], [4., 5.]].into(), &device);
|
||||
let tensor_2 = B::float_ones(tensor_1.shape(), &device, tensor_1.dtype().into());
|
||||
|
||||
// Print the element-wise addition of the two tensors.
|
||||
println!("{}", B::float_add(tensor_1, tensor_2));
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
using namespace std;
|
||||
extern "C" {
|
||||
void dummy_cuda_dependency();
|
||||
}
|
||||
|
||||
struct cublasContext;
|
||||
|
||||
namespace at {
|
||||
namespace cuda {
|
||||
cublasContext *getCurrentCUDABlasHandle();
|
||||
int warp_size();
|
||||
} // namespace cuda
|
||||
} // namespace at
|
||||
char *magma_strerror(int err);
|
||||
void dummy_cuda_dependency() {
|
||||
try {
|
||||
at::cuda::getCurrentCUDABlasHandle();
|
||||
at::cuda::warp_size();
|
||||
} catch (std::exception &e) {
|
||||
if (getenv("TCH_PRINT_CUDA_INIT_ERROR") != nullptr) {
|
||||
std::cerr << "error initializing cuda: " << e.what() << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
extern "C" {
|
||||
void dummy_cuda_dependency();
|
||||
}
|
||||
|
||||
void dummy_cuda_dependency() {}
|
||||
@@ -0,0 +1,51 @@
|
||||
use burn_backend::Element;
|
||||
use burn_backend::{bf16, f16};
|
||||
|
||||
/// The element type for the tch backend.
|
||||
pub trait TchElement: Element + tch::kind::Element {
|
||||
/// Returns the associated tensor kind for [`tch::kind::Element`].
|
||||
fn kind() -> tch::Kind {
|
||||
Self::KIND
|
||||
}
|
||||
}
|
||||
|
||||
impl TchElement for f64 {}
|
||||
impl TchElement for f32 {}
|
||||
impl TchElement for f16 {}
|
||||
impl TchElement for bf16 {
|
||||
fn kind() -> tch::Kind {
|
||||
let mut kind = <Self as tch::kind::Element>::KIND;
|
||||
// Incorrect kind mapping in tch definitions, force bfloat16
|
||||
if matches!(Self::dtype(), burn_backend::DType::BF16) && kind == tch::Kind::Half {
|
||||
kind = tch::Kind::BFloat16
|
||||
}
|
||||
kind
|
||||
}
|
||||
}
|
||||
|
||||
impl TchElement for i64 {}
|
||||
impl TchElement for i32 {}
|
||||
impl TchElement for i16 {}
|
||||
impl TchElement for i8 {}
|
||||
|
||||
impl TchElement for u8 {}
|
||||
|
||||
impl TchElement for bool {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_elem_kinds() {
|
||||
assert_eq!(f64::kind(), tch::Kind::Double);
|
||||
assert_eq!(f32::kind(), tch::Kind::Float);
|
||||
assert_eq!(f16::kind(), tch::Kind::Half);
|
||||
assert_eq!(bf16::kind(), tch::Kind::BFloat16);
|
||||
assert_eq!(i64::kind(), tch::Kind::Int64);
|
||||
assert_eq!(i32::kind(), tch::Kind::Int);
|
||||
assert_eq!(i16::kind(), tch::Kind::Int16);
|
||||
assert_eq!(i8::kind(), tch::Kind::Int8);
|
||||
assert_eq!(bool::kind(), tch::Kind::Bool);
|
||||
}
|
||||
}
|
||||
14
crates/stable-diffusion-burn/burn-crates/burn-tch/src/lib.rs
Normal file
14
crates/stable-diffusion-burn/burn-crates/burn-tch/src/lib.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
#![warn(missing_docs)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
#![allow(clippy::single_range_in_vec_init)]
|
||||
|
||||
//! Burn Tch Backend
|
||||
|
||||
mod backend;
|
||||
mod element;
|
||||
mod ops;
|
||||
mod tensor;
|
||||
|
||||
pub use backend::*;
|
||||
pub use element::*;
|
||||
pub use tensor::*;
|
||||
@@ -0,0 +1,37 @@
|
||||
use crate::{LibTorch, TchTensor, element::TchElement};
|
||||
use burn_backend::ops::ActivationOps;
|
||||
|
||||
impl<E: TchElement> ActivationOps<Self> for LibTorch<E> {
|
||||
fn relu(tensor: TchTensor) -> TchTensor {
|
||||
tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu())
|
||||
}
|
||||
|
||||
fn gelu(tensor: TchTensor) -> TchTensor {
|
||||
tensor.unary_ops(
|
||||
|mut tensor| tensor.gelu_("none"),
|
||||
|tensor| tensor.gelu("none"),
|
||||
)
|
||||
}
|
||||
|
||||
fn gelu_backward(tensor: TchTensor, grad: TchTensor) -> TchTensor {
|
||||
let storage = tensor.storage.clone();
|
||||
let tensor = tensor.tensor.gelu_backward(&grad.tensor, "none");
|
||||
|
||||
TchTensor::from_existing(tensor, storage)
|
||||
}
|
||||
|
||||
fn sigmoid(tensor: TchTensor) -> TchTensor {
|
||||
tensor.unary_ops(|mut tensor| tensor.sigmoid_(), |tensor| tensor.sigmoid())
|
||||
}
|
||||
|
||||
fn log_sigmoid(tensor: TchTensor) -> TchTensor {
|
||||
// NOTE: we don't override log_sigmoid_backward because Torch has a special backward
|
||||
// formula that uses a buffer with computed values from the forward pass
|
||||
|
||||
// no in-place log_sigmoid_
|
||||
let storage = tensor.storage.clone();
|
||||
let tensor = tensor.tensor.log_sigmoid();
|
||||
|
||||
TchTensor::from_existing(tensor, storage)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,737 @@
|
||||
use burn_backend::{Shape, TensorMetadata};
|
||||
use tch::Scalar;
|
||||
|
||||
use crate::{LibTorchDevice, TchShape, TchTensor};
|
||||
|
||||
pub struct TchOps {
|
||||
// e: PhantomData<E>,
|
||||
}
|
||||
|
||||
impl TchOps {
|
||||
pub fn to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
|
||||
let device = (*device).into();
|
||||
|
||||
// We have to manually check if the device is the same, since when it's the case, we need to keep
|
||||
// the same storage reference and not create a new one.
|
||||
if tensor.tensor.device() == device {
|
||||
return tensor;
|
||||
}
|
||||
|
||||
TchTensor::new(tensor.tensor.to(device))
|
||||
}
|
||||
|
||||
pub fn reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
|
||||
let shape_tch: TchShape = shape.into();
|
||||
|
||||
TchTensor::from_existing(tensor.tensor.reshape(shape_tch.dims), tensor.storage)
|
||||
}
|
||||
|
||||
pub fn repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
|
||||
let mut dims = vec![1; tensor.shape().num_dims()];
|
||||
dims[dim] = times as i64;
|
||||
let tensor = tch::Tensor::repeat(&tensor.tensor, dims);
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
pub fn slice_with_steps(tensor: TchTensor, slices: &[burn_backend::Slice]) -> TchTensor {
|
||||
let storage = tensor.storage.clone();
|
||||
let mut tensor = tensor.tensor.shallow_clone();
|
||||
|
||||
for (dim, slice) in slices.iter().enumerate() {
|
||||
let dim_i64 = dim as i64;
|
||||
// Convert slice to range using a dummy size (we'll use tensor dimensions)
|
||||
let dim_size = tensor.size()[dim];
|
||||
let range = slice.to_range(dim_size as usize);
|
||||
let start = range.start as i64;
|
||||
let end = range.end as i64;
|
||||
let step = slice.step as i64;
|
||||
|
||||
if step > 0 {
|
||||
// Forward stepping - use native slice
|
||||
tensor = tensor.slice(dim_i64, Some(start), Some(end), step);
|
||||
} else {
|
||||
// Negative stepping - we need to handle the semantics correctly
|
||||
// For negative steps, we iterate backwards from end-1
|
||||
// PyTorch's negative step works differently than our semantics
|
||||
// We need to reverse the selected range
|
||||
|
||||
// First get the slice with positive step
|
||||
tensor = tensor.slice(dim_i64, Some(start), Some(end), 1);
|
||||
|
||||
// Then reverse it and apply the step
|
||||
if step == -1 {
|
||||
// Simple reversal
|
||||
tensor = tensor.flip([dim_i64]);
|
||||
} else {
|
||||
// Reverse and then take every nth element
|
||||
tensor = tensor.flip([dim_i64]);
|
||||
let abs_step = step.abs();
|
||||
tensor = tensor.slice(dim_i64, None, None, abs_step);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TchTensor::partial(tensor, storage)
|
||||
}
|
||||
|
||||
pub fn slice_assign(
|
||||
tensor: TchTensor,
|
||||
slices: &[burn_backend::Slice],
|
||||
value: TchTensor,
|
||||
) -> TchTensor {
|
||||
// PyTorch's narrow operation only supports contiguous slices (step=1)
|
||||
// For non-unit steps, we use advanced indexing as a workaround
|
||||
let all_unit_steps = slices.iter().all(|s| s.step == 1);
|
||||
|
||||
if all_unit_steps {
|
||||
// Fast path: use narrow and copy_ for unit steps
|
||||
let tch_shape = TchShape::from(tensor.shape());
|
||||
|
||||
// Copy the input tensor if we can't mutate it
|
||||
let tensor_original: TchTensor =
|
||||
tensor.unary_ops(|tensor| tensor, |tensor| tensor.copy());
|
||||
let tensor_original = tensor_original.tensor;
|
||||
|
||||
let mut tensor = tensor_original.view_(tch_shape.dims);
|
||||
|
||||
for (i, slice) in slices.iter().enumerate().take(slices.len()) {
|
||||
// Convert Slice to range for narrow operation
|
||||
let dim_size = tensor.size()[i] as usize;
|
||||
let range = slice.to_range(dim_size);
|
||||
let start = range.start as i64;
|
||||
let length = (range.end - range.start) as i64;
|
||||
|
||||
tensor = tensor.narrow(i as i64, start, length);
|
||||
}
|
||||
|
||||
tensor.copy_(&value.tensor);
|
||||
TchTensor::new(tensor_original)
|
||||
} else {
|
||||
// Workaround for non-unit steps: use PyTorch's index_put operation
|
||||
// This generates explicit indices for the slice and uses advanced indexing
|
||||
let tensor_shape = tensor.shape();
|
||||
let dims = tensor_shape.clone();
|
||||
|
||||
// Copy the tensor since we'll modify it
|
||||
let result_tensor = tensor.tensor.shallow_clone();
|
||||
|
||||
// Use advanced indexing to set the values
|
||||
Self::slice_assign_with_advanced_indexing(result_tensor, slices, value.tensor, &dims)
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate indices for a slice with potentially non-unit step.
|
||||
/// For negative steps, generates indices in reverse order.
|
||||
fn generate_slice_indices(slice: &burn_backend::Slice, dim_size: usize) -> Vec<i64> {
|
||||
let step = slice.step;
|
||||
let range = slice.to_range(dim_size);
|
||||
|
||||
let mut indices = Vec::new();
|
||||
|
||||
if step > 0 {
|
||||
let mut idx = range.start as i64;
|
||||
while idx < range.end as i64 {
|
||||
indices.push(idx);
|
||||
idx += step as i64;
|
||||
}
|
||||
} else if step < 0 {
|
||||
// For negative steps, iterate backwards through the range
|
||||
let mut idx = (range.end - 1) as i64;
|
||||
while idx >= range.start as i64 {
|
||||
indices.push(idx);
|
||||
idx += step as i64; // step is negative, so this decreases
|
||||
}
|
||||
}
|
||||
|
||||
indices
|
||||
}
|
||||
|
||||
/// Implementation using advanced indexing for non-unit steps.
|
||||
/// Uses PyTorch's index_put operation to assign values at specific indices.
|
||||
fn slice_assign_with_advanced_indexing(
|
||||
mut tensor: tch::Tensor,
|
||||
slices: &[burn_backend::Slice],
|
||||
value: tch::Tensor,
|
||||
dims: &[usize],
|
||||
) -> TchTensor {
|
||||
// Generate all index combinations for the sliced regions
|
||||
let mut index_sets: Vec<Vec<i64>> = Vec::new();
|
||||
for (i, slice) in slices.iter().enumerate() {
|
||||
let dim_size = if i < dims.len() { dims[i] } else { 1 };
|
||||
let indices = Self::generate_slice_indices(slice, dim_size);
|
||||
index_sets.push(indices);
|
||||
}
|
||||
|
||||
// For unsliced dimensions, include all indices
|
||||
for &dim_size in dims.iter().skip(slices.len()) {
|
||||
let indices: Vec<i64> = (0..dim_size as i64).collect();
|
||||
index_sets.push(indices);
|
||||
}
|
||||
|
||||
// Convert index sets to tensors for index_put
|
||||
let mut final_indices = Vec::new();
|
||||
let total_elements = index_sets.iter().map(|s| s.len()).product::<usize>();
|
||||
|
||||
// Build flattened index arrays for each dimension using cartesian product
|
||||
// This creates the index tensors needed for PyTorch's index_put operation
|
||||
for dim_idx in 0..index_sets.len() {
|
||||
let mut dim_indices = Vec::with_capacity(total_elements);
|
||||
let repeat = index_sets[dim_idx + 1..]
|
||||
.iter()
|
||||
.map(|s| s.len())
|
||||
.product::<usize>()
|
||||
.max(1);
|
||||
let tile = index_sets[..dim_idx]
|
||||
.iter()
|
||||
.map(|s| s.len())
|
||||
.product::<usize>()
|
||||
.max(1);
|
||||
|
||||
for _ in 0..tile {
|
||||
for &idx in &index_sets[dim_idx] {
|
||||
for _ in 0..repeat {
|
||||
dim_indices.push(idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let indices_tensor = tch::Tensor::from_slice(&dim_indices).to_device(tensor.device());
|
||||
final_indices.push(indices_tensor);
|
||||
}
|
||||
|
||||
// PyTorch's index_put handles assignment correctly for negative steps
|
||||
// following NumPy semantics: values[i] goes to selected_indices[i]
|
||||
let value_flat = value.view(-1);
|
||||
|
||||
// Use index_put to assign values - convert to Option<Tensor>
|
||||
let final_indices_opt: Vec<Option<tch::Tensor>> =
|
||||
final_indices.into_iter().map(Some).collect();
|
||||
tensor = tensor.index_put(&final_indices_opt, &value_flat, false);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
pub fn gather(dim: usize, tensor: TchTensor, indices: TchTensor) -> TchTensor {
|
||||
let storage = tensor.storage.clone();
|
||||
let tensor = tensor.tensor.gather(dim as i64, &indices.tensor, false);
|
||||
|
||||
TchTensor::from_existing(tensor, storage)
|
||||
}
|
||||
|
||||
pub fn scatter(
|
||||
dim: usize,
|
||||
tensor: TchTensor,
|
||||
indices: TchTensor,
|
||||
value: TchTensor,
|
||||
) -> TchTensor {
|
||||
let storage = tensor.storage.clone();
|
||||
let tensor = tensor
|
||||
.tensor
|
||||
.scatter_add(dim as i64, &indices.tensor, &value.tensor);
|
||||
|
||||
TchTensor::from_existing(tensor, storage)
|
||||
}
|
||||
|
||||
pub fn index_select_dim(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
|
||||
let storage = tensor.storage.clone();
|
||||
let tensor = tensor.tensor.index_select(dim as i64, &indices.tensor);
|
||||
|
||||
TchTensor::from_existing(tensor, storage)
|
||||
}
|
||||
|
||||
pub fn select_assign(
|
||||
tensor: TchTensor,
|
||||
dim: usize,
|
||||
indices: TchTensor,
|
||||
value: TchTensor,
|
||||
) -> TchTensor {
|
||||
tensor.clone().unary_ops(
|
||||
|mut tensor| tensor.index_add_(dim as i64, &indices.tensor, &value.tensor),
|
||||
|tensor| tensor.index_add(dim as i64, &indices.tensor, &value.tensor),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
|
||||
let tensors: Vec<tch::Tensor> = tensors
|
||||
.into_iter()
|
||||
.map(|t| t.tensor.shallow_clone())
|
||||
.collect();
|
||||
let tensor = tch::Tensor::cat(&tensors, dim as i64);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
pub fn equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchTensor::binary_ops_tensor(
|
||||
lhs,
|
||||
rhs,
|
||||
|lhs, rhs| lhs.eq_tensor_(rhs).to_kind(tch::Kind::Bool),
|
||||
|lhs, rhs| rhs.eq_tensor_(lhs).to_kind(tch::Kind::Bool),
|
||||
|lhs, rhs| lhs.eq_tensor(rhs),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn equal_elem<S: Into<tch::Scalar> + Clone>(lhs: TchTensor, rhs: S) -> TchTensor {
|
||||
lhs.unary_ops(
|
||||
|mut tensor| tensor.eq_(rhs.clone().into()).to_kind(tch::Kind::Bool),
|
||||
|tensor| tensor.eq(rhs.clone().into()),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchTensor::binary_ops_tensor(
|
||||
lhs,
|
||||
rhs,
|
||||
|lhs, rhs| lhs.greater_tensor_(rhs).to_kind(tch::Kind::Bool),
|
||||
|lhs, rhs| rhs.less_tensor_(lhs).to_kind(tch::Kind::Bool),
|
||||
|lhs, rhs| lhs.greater_tensor(rhs),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn greater_elem<S: Into<tch::Scalar> + Clone>(lhs: TchTensor, rhs: S) -> TchTensor {
|
||||
lhs.unary_ops(
|
||||
|mut tensor| tensor.greater_(rhs.clone().into()).to_kind(tch::Kind::Bool),
|
||||
|tensor| tensor.greater(rhs.clone().into()),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn greater_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchTensor::binary_ops_tensor(
|
||||
lhs,
|
||||
rhs,
|
||||
|lhs, rhs| lhs.greater_equal_tensor_(rhs).to_kind(tch::Kind::Bool),
|
||||
|lhs, rhs| rhs.less_equal_tensor_(lhs).to_kind(tch::Kind::Bool),
|
||||
|lhs, rhs| lhs.greater_equal_tensor(rhs),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn greater_equal_elem<S: Into<Scalar> + Clone>(lhs: TchTensor, rhs: S) -> TchTensor {
|
||||
lhs.unary_ops(
|
||||
|mut tensor| {
|
||||
tensor
|
||||
.greater_equal_(rhs.clone().into())
|
||||
.to_kind(tch::Kind::Bool)
|
||||
},
|
||||
|tensor| tensor.greater_equal(rhs.clone().into()),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchTensor::binary_ops_tensor(
|
||||
lhs,
|
||||
rhs,
|
||||
|lhs, rhs| lhs.less_tensor_(rhs).to_kind(tch::Kind::Bool),
|
||||
|lhs, rhs| rhs.greater_tensor_(lhs).to_kind(tch::Kind::Bool),
|
||||
|lhs, rhs| lhs.less_tensor(rhs),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn lower_elem<S: Into<Scalar> + Clone>(lhs: TchTensor, rhs: S) -> TchTensor {
|
||||
lhs.unary_ops(
|
||||
|mut tensor| tensor.less_(rhs.clone().into()).to_kind(tch::Kind::Bool),
|
||||
|tensor| tensor.less(rhs.clone().into()),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn lower_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchTensor::binary_ops_tensor(
|
||||
lhs,
|
||||
rhs,
|
||||
|lhs, rhs| lhs.less_equal_tensor_(rhs).to_kind(tch::Kind::Bool),
|
||||
|lhs, rhs| rhs.greater_equal_tensor_(lhs).to_kind(tch::Kind::Bool),
|
||||
|lhs, rhs| lhs.less_equal_tensor(rhs),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn lower_equal_elem<S: Into<Scalar> + Clone>(lhs: TchTensor, rhs: S) -> TchTensor {
|
||||
lhs.unary_ops(
|
||||
|mut tensor| {
|
||||
tensor
|
||||
.less_equal_(rhs.clone().into())
|
||||
.to_kind(tch::Kind::Bool)
|
||||
},
|
||||
|tensor| tensor.less_equal(rhs.clone().into()),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn add(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchTensor::binary_ops_tensor(
|
||||
lhs,
|
||||
rhs,
|
||||
|lhs, rhs| lhs.f_add_(rhs).unwrap(),
|
||||
|lhs, rhs| rhs.f_add_(lhs).unwrap(),
|
||||
|lhs, rhs| lhs.f_add(rhs).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchTensor::binary_ops_tensor(
|
||||
lhs,
|
||||
rhs,
|
||||
|lhs, rhs| lhs.f_sub_(rhs).unwrap(),
|
||||
|lhs, rhs| lhs.f_sub(rhs).unwrap(),
|
||||
|lhs, rhs| lhs.f_sub(rhs).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchTensor::binary_ops_tensor(
|
||||
lhs,
|
||||
rhs,
|
||||
|lhs, rhs| lhs.f_mul_(rhs).unwrap(),
|
||||
|lhs, rhs| rhs.f_mul_(lhs).unwrap(),
|
||||
|lhs, rhs| lhs.f_mul(rhs).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn div(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchTensor::binary_ops_tensor(
|
||||
lhs,
|
||||
rhs,
|
||||
|lhs, rhs| lhs.f_div_(rhs).unwrap(),
|
||||
|lhs, rhs| lhs.f_div(rhs).unwrap(),
|
||||
|lhs, rhs| lhs.f_div(rhs).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn remainder(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchTensor::binary_ops_tensor(
|
||||
lhs,
|
||||
rhs,
|
||||
|lhs, rhs| lhs.f_remainder_tensor_(rhs).unwrap(),
|
||||
|lhs, rhs| lhs.f_remainder_tensor(rhs).unwrap(),
|
||||
|lhs, rhs| lhs.f_remainder_tensor(rhs).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn mean(tensor: TchTensor) -> TchTensor {
|
||||
// view as 1d tensor
|
||||
let tensor = tensor.tensor.mean(tensor.tensor.kind()).view(1);
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
pub fn mean_dim(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchTensor::from_existing(
|
||||
tensor
|
||||
.tensor
|
||||
.mean_dim(Some([dim as i64].as_slice()), true, tensor.tensor.kind()),
|
||||
tensor.storage,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn sum(tensor: TchTensor) -> TchTensor {
|
||||
// view as 1d tensor
|
||||
let tensor = tensor.tensor.sum(tensor.tensor.kind()).view(1);
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
pub fn sum_dim(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchTensor::from_existing(
|
||||
tensor.tensor.sum_dim_intlist(
|
||||
Some([dim as i64].as_slice()),
|
||||
true,
|
||||
tensor.tensor.kind(),
|
||||
),
|
||||
tensor.storage,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn prod(tensor: TchTensor) -> TchTensor {
|
||||
// view as 1d tensor
|
||||
let tensor = tensor.tensor.prod(tensor.tensor.kind()).view(1);
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
pub fn prod_dim(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchTensor::from_existing(
|
||||
tensor
|
||||
.tensor
|
||||
.prod_dim_int(dim as i64, true, tensor.tensor.kind()),
|
||||
tensor.storage,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn cumsum(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchTensor::from_existing(
|
||||
tensor.tensor.cumsum(dim as i64, tensor.tensor.kind()),
|
||||
tensor.storage,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn cumprod(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchTensor::from_existing(
|
||||
tensor.tensor.cumprod(dim as i64, tensor.tensor.kind()),
|
||||
tensor.storage,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn cummin(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
let (values, _indices) = tensor.tensor.cummin(dim as i64);
|
||||
TchTensor::from_existing(values, tensor.storage)
|
||||
}
|
||||
|
||||
pub fn cummax(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
// cummax returns (values, indices) tuple in PyTorch, we only need values
|
||||
let (values, _indices) = tensor.tensor.cummax(dim as i64);
|
||||
TchTensor::from_existing(values, tensor.storage)
|
||||
}
|
||||
|
||||
pub fn argmax(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
let storage = tensor.storage.clone();
|
||||
let tensor = tensor.tensor.argmax(dim as i64, true);
|
||||
|
||||
TchTensor::from_existing(tensor, storage)
|
||||
}
|
||||
|
||||
pub fn argmin(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
let storage = tensor.storage.clone();
|
||||
let tensor = tensor.tensor.argmin(dim as i64, true);
|
||||
|
||||
TchTensor::from_existing(tensor, storage)
|
||||
}
|
||||
|
||||
pub fn max_dim(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
let storage = tensor.storage.clone();
|
||||
let (tensor, _indices) = tensor.tensor.max_dim(dim as i64, true);
|
||||
|
||||
TchTensor::from_existing(tensor, storage)
|
||||
}
|
||||
|
||||
pub fn max_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
|
||||
let storage = tensor.storage.clone();
|
||||
let (tensor, indices) = tensor.tensor.max_dim(dim as i64, true);
|
||||
|
||||
let tensor = TchTensor::from_existing(tensor, storage);
|
||||
let indices = TchTensor::new(indices);
|
||||
|
||||
(tensor, indices)
|
||||
}
|
||||
|
||||
pub fn min_dim(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
let storage = tensor.storage.clone();
|
||||
let (tensor, _indices) = tensor.tensor.min_dim(dim as i64, true);
|
||||
|
||||
TchTensor::from_existing(tensor, storage)
|
||||
}
|
||||
|
||||
pub fn min_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
|
||||
let storage = tensor.storage.clone();
|
||||
let (tensor, indices) = tensor.tensor.min_dim(dim as i64, true);
|
||||
|
||||
let tensor = TchTensor::from_existing(tensor, storage);
|
||||
let indices = TchTensor::new(indices);
|
||||
|
||||
(tensor, indices)
|
||||
}
|
||||
|
||||
pub fn clamp_min<S: Into<tch::Scalar> + Clone + Copy>(tensor: TchTensor, min: S) -> TchTensor {
|
||||
tensor.unary_ops(
|
||||
|mut tensor| tensor.clamp_min_(min),
|
||||
|tensor| tensor.clamp_min(min),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn clamp_max<S: Into<tch::Scalar> + Clone + Copy>(tensor: TchTensor, max: S) -> TchTensor {
|
||||
tensor.unary_ops(
|
||||
|mut tensor| tensor.clamp_max_(max),
|
||||
|tensor| tensor.clamp_max(max),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn clamp<S: Into<tch::Scalar> + Clone + Copy>(
|
||||
tensor: TchTensor,
|
||||
min: S,
|
||||
max: S,
|
||||
) -> TchTensor {
|
||||
tensor.unary_ops(
|
||||
|mut tensor| tensor.clamp_(min, max),
|
||||
|tensor| tensor.clamp(min, max),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn swap_dims(tensor: TchTensor, dim1: usize, dim2: usize) -> TchTensor {
|
||||
let tensor = tensor.tensor.transpose(dim1 as i64, dim2 as i64);
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
pub fn permute(tensor: TchTensor, axes: &[usize]) -> TchTensor {
|
||||
let tensor = tensor
|
||||
.tensor
|
||||
.permute(axes.iter().map(|x| *x as i64).collect::<Vec<_>>());
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
pub fn flip(tensor: TchTensor, axes: &[usize]) -> TchTensor {
|
||||
let dims = axes.iter().map(|x| *x as i64).collect::<Vec<_>>();
|
||||
let tensor = tensor.tensor.flip(dims);
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
pub fn powf(tensor: TchTensor, exponent: TchTensor) -> TchTensor {
|
||||
TchTensor::binary_ops_tensor(
|
||||
tensor,
|
||||
exponent,
|
||||
|lhs, rhs| lhs.f_pow_tensor_(rhs).unwrap(),
|
||||
|lhs, rhs| lhs.f_pow(rhs).unwrap(),
|
||||
|lhs, rhs| lhs.f_pow(rhs).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn sign(tensor: TchTensor) -> TchTensor {
|
||||
tensor.unary_ops(|mut tensor| tensor.sign_(), |tensor| tensor.sign())
|
||||
}
|
||||
|
||||
pub fn expand(tensor: TchTensor, shape: Shape) -> TchTensor {
|
||||
let storage = tensor.storage.clone();
|
||||
let broadcasted_tensor = tensor.tensor.broadcast_to(TchShape::from(shape).dims);
|
||||
TchTensor::from_existing(broadcasted_tensor, storage)
|
||||
}
|
||||
|
||||
pub fn unfold(tensor: TchTensor, dim: usize, size: usize, step: usize) -> TchTensor {
|
||||
let storage = tensor.storage.clone();
|
||||
let uf_tensor = tensor.tensor.unfold(dim as i64, size as i64, step as i64);
|
||||
|
||||
TchTensor::from_existing(uf_tensor, storage)
|
||||
}
|
||||
|
||||
pub fn sort(tensor: TchTensor, dim: usize, descending: bool) -> TchTensor {
|
||||
TchTensor::new(tensor.tensor.sort(dim as i64, descending).0)
|
||||
}
|
||||
|
||||
pub fn sort_with_indices(
|
||||
tensor: TchTensor,
|
||||
dim: usize,
|
||||
descending: bool,
|
||||
) -> (TchTensor, TchTensor) {
|
||||
let sorted = tensor.tensor.sort(dim as i64, descending);
|
||||
(TchTensor::new(sorted.0), TchTensor::new(sorted.1))
|
||||
}
|
||||
|
||||
pub fn argsort(tensor: TchTensor, dim: usize, descending: bool) -> TchTensor {
|
||||
TchTensor::new(tensor.tensor.argsort(dim as i64, descending))
|
||||
}
|
||||
|
||||
pub fn bitwise_and(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchTensor::binary_ops_tensor(
|
||||
lhs,
|
||||
rhs,
|
||||
|lhs, rhs| lhs.f_bitwise_and_tensor_(rhs).unwrap(),
|
||||
|lhs, rhs| rhs.f_bitwise_and_tensor_(lhs).unwrap(),
|
||||
|lhs, rhs| lhs.f_bitwise_and_tensor(rhs).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn bitwise_and_scalar<S: Into<Scalar> + Clone>(tensor: TchTensor, scalar: S) -> TchTensor {
|
||||
tensor.unary_ops(
|
||||
|mut tensor| tensor.f_bitwise_and_(scalar.clone().into()).unwrap(),
|
||||
|tensor| tensor.f_bitwise_and(scalar.clone().into()).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn bitwise_or(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchTensor::binary_ops_tensor(
|
||||
lhs,
|
||||
rhs,
|
||||
|lhs, rhs| lhs.f_bitwise_or_tensor_(rhs).unwrap(),
|
||||
|lhs, rhs| rhs.f_bitwise_or_tensor_(lhs).unwrap(),
|
||||
|lhs, rhs| lhs.f_bitwise_or_tensor(rhs).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn bitwise_or_scalar<S: Into<Scalar> + Clone>(tensor: TchTensor, scalar: S) -> TchTensor {
|
||||
tensor.unary_ops(
|
||||
|mut tensor| tensor.f_bitwise_or_(scalar.clone().into()).unwrap(),
|
||||
|tensor| tensor.f_bitwise_or(scalar.clone().into()).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn bitwise_xor(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchTensor::binary_ops_tensor(
|
||||
lhs,
|
||||
rhs,
|
||||
|lhs, rhs| lhs.f_bitwise_xor_tensor_(rhs).unwrap(),
|
||||
|lhs, rhs| rhs.f_bitwise_xor_tensor_(lhs).unwrap(),
|
||||
|lhs, rhs| lhs.f_bitwise_xor_tensor(rhs).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn bitwise_xor_scalar<S: Into<Scalar> + Clone>(tensor: TchTensor, scalar: S) -> TchTensor {
|
||||
tensor.unary_ops(
|
||||
|mut tensor| tensor.f_bitwise_xor_(scalar.clone().into()).unwrap(),
|
||||
|tensor| tensor.f_bitwise_xor(scalar.clone().into()).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn bitwise_not(tensor: TchTensor) -> TchTensor {
|
||||
tensor.unary_ops(
|
||||
|mut tensor| tensor.f_bitwise_not_().unwrap(),
|
||||
|tensor| tensor.f_bitwise_not().unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn bitwise_left_shift(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchTensor::binary_ops_tensor(
|
||||
lhs,
|
||||
rhs,
|
||||
|lhs, rhs| lhs.f_bitwise_left_shift_(rhs).unwrap(),
|
||||
|lhs, rhs| lhs.f_bitwise_left_shift(rhs).unwrap(),
|
||||
|lhs, rhs| lhs.f_bitwise_left_shift(rhs).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn bitwise_left_shift_scalar<S: Into<Scalar> + Clone>(
|
||||
tensor: TchTensor,
|
||||
scalar: S,
|
||||
) -> TchTensor {
|
||||
tensor.unary_ops(
|
||||
|mut tensor| {
|
||||
tensor
|
||||
.f_bitwise_left_shift_tensor_scalar_(scalar.clone().into())
|
||||
.unwrap()
|
||||
},
|
||||
|tensor| {
|
||||
tensor
|
||||
.f_bitwise_left_shift_tensor_scalar(scalar.clone().into())
|
||||
.unwrap()
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
pub fn bitwise_right_shift(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchTensor::binary_ops_tensor(
|
||||
lhs,
|
||||
rhs,
|
||||
|lhs, rhs| lhs.f_bitwise_right_shift_(rhs).unwrap(),
|
||||
|lhs, rhs| lhs.f_bitwise_right_shift(rhs).unwrap(),
|
||||
|lhs, rhs| lhs.f_bitwise_right_shift(rhs).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn bitwise_right_shift_scalar<S: Into<Scalar> + Clone>(
|
||||
tensor: TchTensor,
|
||||
scalar: S,
|
||||
) -> TchTensor {
|
||||
tensor.unary_ops(
|
||||
|mut tensor| {
|
||||
tensor
|
||||
.f_bitwise_right_shift_tensor_scalar_(scalar.clone().into())
|
||||
.unwrap()
|
||||
},
|
||||
|tensor| {
|
||||
tensor
|
||||
.f_bitwise_right_shift_tensor_scalar(scalar.clone().into())
|
||||
.unwrap()
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
pub fn atan2(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchTensor::binary_ops_tensor(
|
||||
lhs,
|
||||
rhs,
|
||||
|lhs, rhs| lhs.f_atan2_(rhs).unwrap(),
|
||||
|lhs, rhs| lhs.f_atan2(rhs).unwrap(),
|
||||
|lhs, rhs| lhs.f_atan2(rhs).unwrap(),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,219 @@
|
||||
use super::TchOps;
|
||||
use crate::{LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement};
|
||||
use burn_backend::ExecutionError;
|
||||
use burn_backend::Scalar;
|
||||
use burn_backend::tensor::BoolTensor;
|
||||
use burn_backend::tensor::IntTensor;
|
||||
use burn_backend::{Shape, TensorData, TensorMetadata, ops::BoolTensorOps};
|
||||
|
||||
impl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {
|
||||
fn bool_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {
|
||||
match data.dtype {
|
||||
burn_backend::DType::Bool => TchTensor::from_data::<bool>(data, (*device).into()),
|
||||
_ => unimplemented!("Unsupported dtype for `bool_from_data`"),
|
||||
}
|
||||
}
|
||||
|
||||
fn bool_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
|
||||
TchOps::repeat_dim(tensor, dim, times)
|
||||
}
|
||||
|
||||
async fn bool_into_data(tensor: TchTensor) -> Result<TensorData, ExecutionError> {
|
||||
let shape = tensor.shape();
|
||||
let tensor = Self::bool_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
|
||||
let values: Result<Vec<bool>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
|
||||
Ok(TensorData::new(values.unwrap(), shape))
|
||||
}
|
||||
|
||||
fn bool_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
|
||||
TchOps::to_device(tensor, device)
|
||||
}
|
||||
|
||||
fn bool_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
|
||||
TchOps::reshape(tensor, shape)
|
||||
}
|
||||
|
||||
fn bool_device(tensor: &TchTensor) -> LibTorchDevice {
|
||||
tensor.tensor.device().into()
|
||||
}
|
||||
|
||||
fn bool_empty(shape: Shape, device: &LibTorchDevice) -> TchTensor {
|
||||
let tensor = tch::Tensor::empty(
|
||||
TchShape::from(shape).dims,
|
||||
(tch::Kind::Bool, (*device).into()),
|
||||
);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn bool_zeros(shape: Shape, device: &LibTorchDevice) -> TchTensor {
|
||||
let tensor = tch::Tensor::zeros(
|
||||
TchShape::from(shape).dims,
|
||||
(tch::Kind::Bool, (*device).into()),
|
||||
);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn bool_ones(shape: Shape, device: &LibTorchDevice) -> TchTensor {
|
||||
let tensor = tch::Tensor::ones(
|
||||
TchShape::from(shape).dims,
|
||||
(tch::Kind::Bool, (*device).into()),
|
||||
);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn bool_slice(tensor: TchTensor, slices: &[burn_backend::Slice]) -> TchTensor {
|
||||
TchOps::slice_with_steps(tensor, slices)
|
||||
}
|
||||
|
||||
fn bool_slice_assign(
|
||||
tensor: TchTensor,
|
||||
slices: &[burn_backend::Slice],
|
||||
value: TchTensor,
|
||||
) -> TchTensor {
|
||||
TchOps::slice_assign(tensor, slices, value)
|
||||
}
|
||||
|
||||
fn bool_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
|
||||
TchOps::cat(tensors, dim)
|
||||
}
|
||||
|
||||
fn bool_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchOps::equal(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bool_not(tensor: TchTensor) -> TchTensor {
|
||||
tensor.unary_ops(
|
||||
|mut tensor| tensor.eq_(0).to_kind(tch::Kind::Bool),
|
||||
|tensor| tensor.eq(0),
|
||||
)
|
||||
}
|
||||
|
||||
fn bool_and(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchTensor::binary_ops_tensor(
|
||||
lhs,
|
||||
rhs,
|
||||
|lhs, rhs| lhs.logical_and_(rhs),
|
||||
|lhs, rhs| rhs.logical_and_(lhs),
|
||||
|lhs, rhs| lhs.logical_and(rhs),
|
||||
)
|
||||
}
|
||||
|
||||
fn bool_or(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchTensor::binary_ops_tensor(
|
||||
lhs,
|
||||
rhs,
|
||||
|lhs, rhs| lhs.logical_or_(rhs),
|
||||
|lhs, rhs| rhs.logical_or_(lhs),
|
||||
|lhs, rhs| lhs.logical_or(rhs),
|
||||
)
|
||||
}
|
||||
|
||||
fn bool_into_int(tensor: TchTensor) -> TchTensor {
|
||||
let tensor = tensor.tensor.to_kind(tch::Kind::Int64);
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn bool_into_float(tensor: TchTensor) -> TchTensor {
|
||||
let tensor = tensor.tensor.to_kind(E::kind());
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn bool_swap_dims(tensor: TchTensor, dim1: usize, dim2: usize) -> TchTensor {
|
||||
TchOps::swap_dims(tensor, dim1, dim2)
|
||||
}
|
||||
|
||||
fn bool_permute(tensor: TchTensor, axes: &[usize]) -> TchTensor {
|
||||
TchOps::permute(tensor, axes)
|
||||
}
|
||||
|
||||
fn bool_flip(tensor: TchTensor, axes: &[usize]) -> TchTensor {
|
||||
TchOps::flip(tensor, axes)
|
||||
}
|
||||
|
||||
async fn bool_argwhere(tensor: TchTensor) -> TchTensor {
|
||||
TchTensor::new(tensor.tensor.argwhere())
|
||||
}
|
||||
|
||||
fn bool_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
|
||||
TchOps::index_select_dim(tensor, dim, indices)
|
||||
}
|
||||
|
||||
fn bool_select_or(
|
||||
tensor: TchTensor,
|
||||
dim: usize,
|
||||
indices: TchTensor,
|
||||
value: TchTensor,
|
||||
) -> TchTensor {
|
||||
TchOps::select_assign(tensor, dim, indices, value)
|
||||
}
|
||||
|
||||
fn bool_expand(tensor: TchTensor, shape: Shape) -> TchTensor {
|
||||
TchOps::expand(tensor, shape)
|
||||
}
|
||||
|
||||
fn bool_unfold(
|
||||
tensor: IntTensor<Self>,
|
||||
dim: usize,
|
||||
size: usize,
|
||||
step: usize,
|
||||
) -> IntTensor<Self> {
|
||||
TchOps::unfold(tensor, dim, size, step)
|
||||
}
|
||||
|
||||
fn bool_mask_where(
|
||||
tensor: BoolTensor<Self>,
|
||||
mask: BoolTensor<Self>,
|
||||
value: BoolTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
TchTensor::binary_ops_tensor(
|
||||
tensor,
|
||||
value,
|
||||
|tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
|
||||
|tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
|
||||
|tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn bool_mask_fill(
|
||||
tensor: BoolTensor<Self>,
|
||||
mask: BoolTensor<Self>,
|
||||
value: Scalar,
|
||||
) -> BoolTensor<Self> {
|
||||
tensor.unary_ops(
|
||||
|mut tensor| {
|
||||
tensor
|
||||
.f_masked_fill_(&mask.tensor, value.elem::<i64>())
|
||||
.unwrap()
|
||||
},
|
||||
|tensor| {
|
||||
tensor
|
||||
.f_masked_fill(&mask.tensor, value.elem::<i64>())
|
||||
.unwrap()
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
fn bool_gather(
|
||||
dim: usize,
|
||||
tensor: BoolTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
TchOps::gather(dim, tensor, indices)
|
||||
}
|
||||
|
||||
fn bool_scatter_or(
|
||||
dim: usize,
|
||||
tensor: BoolTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
value: BoolTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
TchOps::scatter(dim, tensor, indices, value)
|
||||
}
|
||||
|
||||
fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
TchOps::equal_elem(lhs, rhs.elem::<i64>())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,501 @@
|
||||
use std::ops::Range;
|
||||
|
||||
use burn_backend::{
|
||||
Distribution, ExecutionError, IntDType, Scalar, Shape, TensorData, TensorMetadata,
|
||||
ops::{FloatTensorOps, IntTensorOps},
|
||||
tensor::IntTensor,
|
||||
};
|
||||
|
||||
use crate::{IntoKind, LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement};
|
||||
|
||||
use super::TchOps;
|
||||
|
||||
impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
|
||||
fn int_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {
|
||||
match data.dtype {
|
||||
burn_backend::DType::I64 => TchTensor::from_data::<i64>(data, (*device).into()),
|
||||
_ => unimplemented!("Unsupported dtype for `int_from_data`"),
|
||||
}
|
||||
}
|
||||
|
||||
fn int_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
|
||||
TchOps::repeat_dim(tensor, dim, times)
|
||||
}
|
||||
|
||||
async fn int_into_data(tensor: TchTensor) -> Result<TensorData, ExecutionError> {
|
||||
let shape = tensor.shape();
|
||||
let tensor = Self::int_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
|
||||
let values: Result<Vec<i64>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
|
||||
Ok(TensorData::new(values.unwrap(), shape))
|
||||
}
|
||||
|
||||
fn int_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
|
||||
TchOps::to_device(tensor, device)
|
||||
}
|
||||
|
||||
fn int_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
|
||||
TchOps::reshape(tensor, shape)
|
||||
}
|
||||
|
||||
fn int_device(tensor: &TchTensor) -> LibTorchDevice {
|
||||
tensor.tensor.device().into()
|
||||
}
|
||||
|
||||
fn int_empty(shape: Shape, device: &LibTorchDevice, dtype: IntDType) -> TchTensor {
|
||||
let tensor = tch::Tensor::empty(
|
||||
TchShape::from(shape).dims,
|
||||
(dtype.into_kind(), (*device).into()),
|
||||
);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn int_slice(tensor: TchTensor, slices: &[burn_backend::Slice]) -> TchTensor {
|
||||
TchOps::slice_with_steps(tensor, slices)
|
||||
}
|
||||
|
||||
fn int_slice_assign(
|
||||
tensor: TchTensor,
|
||||
slices: &[burn_backend::Slice],
|
||||
value: TchTensor,
|
||||
) -> TchTensor {
|
||||
TchOps::slice_assign(tensor, slices, value)
|
||||
}
|
||||
|
||||
fn int_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
|
||||
TchOps::cat(tensors, dim)
|
||||
}
|
||||
|
||||
fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
let lhs = Self::int_into_float(lhs);
|
||||
let rhs = Self::int_into_float(rhs);
|
||||
let out = lhs.tensor.f_matmul(&rhs.tensor).unwrap();
|
||||
Self::float_into_int(TchTensor::new(out))
|
||||
}
|
||||
|
||||
fn int_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchOps::equal(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_equal_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {
|
||||
TchOps::equal_elem(lhs, rhs.elem::<i64>())
|
||||
}
|
||||
|
||||
fn int_greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchOps::greater(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_greater_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {
|
||||
TchOps::greater_elem(lhs, rhs.elem::<i64>())
|
||||
}
|
||||
|
||||
fn int_greater_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchOps::greater_equal(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_greater_equal_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {
|
||||
TchOps::greater_equal_elem(lhs, rhs.elem::<i64>())
|
||||
}
|
||||
|
||||
fn int_lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchOps::lower(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_lower_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {
|
||||
TchOps::lower_elem(lhs, rhs.elem::<i64>())
|
||||
}
|
||||
|
||||
fn int_lower_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchOps::lower_equal(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_lower_equal_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {
|
||||
TchOps::lower_equal_elem(lhs, rhs.elem::<i64>())
|
||||
}
|
||||
|
||||
fn int_add(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchOps::add(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_add_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
|
||||
lhs.unary_ops(
|
||||
|mut tensor| tensor.f_add_scalar_(rhs.elem::<i64>()).unwrap(),
|
||||
|tensor| tensor.f_add_scalar(rhs.elem::<i64>()).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn int_sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchOps::sub(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_sub_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
|
||||
lhs.unary_ops(
|
||||
|mut tensor| tensor.f_sub_scalar_(rhs.elem::<i64>()).unwrap(),
|
||||
|tensor| tensor.f_sub_scalar(rhs.elem::<i64>()).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn int_mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchOps::mul(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_mul_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
|
||||
lhs.unary_ops(
|
||||
|mut tensor| tensor.f_mul_scalar_(rhs.elem::<i64>()).unwrap(),
|
||||
|tensor| tensor.f_mul_scalar(rhs.elem::<i64>()).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn int_div(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
let dtype = lhs.tensor.kind();
|
||||
let copy = false;
|
||||
let non_blocking = true;
|
||||
let lhs: TchTensor =
|
||||
TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
|
||||
let rhs: TchTensor =
|
||||
TchTensor::new(rhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
|
||||
|
||||
let out = TchOps::div(lhs, rhs);
|
||||
|
||||
TchTensor::new(out.tensor.to_dtype(dtype, non_blocking, copy))
|
||||
}
|
||||
|
||||
fn int_div_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
|
||||
let dtype = lhs.tensor.kind();
|
||||
let copy = false;
|
||||
let non_blocking = true;
|
||||
let lhs: TchTensor =
|
||||
TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
|
||||
|
||||
let out: TchTensor = lhs.unary_ops(
|
||||
|mut tensor| tensor.f_div_scalar_(rhs.elem::<i64>()).unwrap(),
|
||||
|tensor| tensor.f_div_scalar(rhs.elem::<i64>()).unwrap(),
|
||||
);
|
||||
|
||||
TchTensor::new(out.tensor.to_dtype(dtype, non_blocking, copy))
|
||||
}
|
||||
|
||||
fn int_remainder(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
let dtype = lhs.tensor.kind();
|
||||
let copy = false;
|
||||
let non_blocking = true;
|
||||
let lhs: TchTensor =
|
||||
TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
|
||||
let rhs: TchTensor =
|
||||
TchTensor::new(rhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
|
||||
|
||||
let out = TchOps::remainder(lhs, rhs);
|
||||
|
||||
TchTensor::new(out.tensor.to_dtype(dtype, non_blocking, copy))
|
||||
}
|
||||
|
||||
fn int_remainder_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
|
||||
lhs.unary_ops(
|
||||
|tensor| tensor.f_remainder(rhs.elem::<i64>()).unwrap(),
|
||||
|tensor| tensor.f_remainder(rhs.elem::<i64>()).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn int_zeros(shape: Shape, device: &LibTorchDevice, dtype: IntDType) -> TchTensor {
|
||||
let shape = TchShape::from(shape);
|
||||
let device: tch::Device = (*device).into();
|
||||
|
||||
TchTensor::new(tch::Tensor::zeros(shape.dims, (dtype.into_kind(), device)))
|
||||
}
|
||||
|
||||
fn int_ones(shape: Shape, device: &LibTorchDevice, dtype: IntDType) -> TchTensor {
|
||||
let shape = TchShape::from(shape);
|
||||
let device: tch::Device = (*device).into();
|
||||
|
||||
TchTensor::new(tch::Tensor::ones(shape.dims, (dtype.into_kind(), device)))
|
||||
}
|
||||
|
||||
fn int_full(
|
||||
shape: Shape,
|
||||
fill_value: Scalar,
|
||||
device: &LibTorchDevice,
|
||||
dtype: IntDType,
|
||||
) -> TchTensor {
|
||||
let shape = TchShape::from(shape);
|
||||
let device: tch::Device = (*device).into();
|
||||
|
||||
TchTensor::new(tch::Tensor::full(
|
||||
shape.dims,
|
||||
fill_value.elem::<i64>(),
|
||||
(dtype.into_kind(), device),
|
||||
))
|
||||
}
|
||||
|
||||
fn int_sum(tensor: TchTensor) -> TchTensor {
|
||||
TchOps::sum(tensor)
|
||||
}
|
||||
|
||||
fn int_sum_dim(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchOps::sum_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_prod(tensor: TchTensor) -> TchTensor {
|
||||
TchOps::prod(tensor)
|
||||
}
|
||||
|
||||
fn int_prod_dim(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchOps::prod_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_mean(tensor: TchTensor) -> TchTensor {
|
||||
let dtype = tensor.tensor.kind();
|
||||
let tensor: TchTensor =
|
||||
TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false));
|
||||
let output: TchTensor = TchTensor::new(TchOps::mean(tensor).tensor);
|
||||
|
||||
TchTensor::new(output.tensor.to_dtype(dtype, true, false))
|
||||
}
|
||||
|
||||
fn int_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
let dtype = tensor.tensor.kind();
|
||||
let tensor: TchTensor =
|
||||
TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false));
|
||||
|
||||
let output: TchTensor = TchTensor::new(TchOps::mean_dim(tensor, dim).tensor);
|
||||
|
||||
TchTensor::new(output.tensor.to_dtype(dtype, true, false))
|
||||
}
|
||||
|
||||
fn int_cumsum(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchOps::cumsum(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_cumprod(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchOps::cumprod(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_cummin(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchOps::cummin(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_cummax(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchOps::cummax(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_gather(dim: usize, tensor: TchTensor, indices: TchTensor) -> TchTensor {
|
||||
TchOps::gather(dim, tensor, indices)
|
||||
}
|
||||
|
||||
fn int_scatter_add(
|
||||
dim: usize,
|
||||
tensor: TchTensor,
|
||||
indices: TchTensor,
|
||||
value: TchTensor,
|
||||
) -> TchTensor {
|
||||
TchOps::scatter(dim, tensor, indices, value)
|
||||
}
|
||||
|
||||
fn int_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
|
||||
TchOps::index_select_dim(tensor, dim, indices)
|
||||
}
|
||||
|
||||
fn int_select_add(
|
||||
tensor: TchTensor,
|
||||
dim: usize,
|
||||
indices: TchTensor,
|
||||
value: TchTensor,
|
||||
) -> TchTensor {
|
||||
TchOps::select_assign(tensor, dim, indices, value)
|
||||
}
|
||||
|
||||
fn int_mask_where(tensor: TchTensor, mask: TchTensor, source: TchTensor) -> TchTensor {
|
||||
TchTensor::binary_ops_tensor(
|
||||
tensor,
|
||||
source,
|
||||
|tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
|
||||
|tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
|
||||
|tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn int_mask_fill(tensor: TchTensor, mask: TchTensor, value: Scalar) -> TchTensor {
|
||||
let value = value.elem::<i64>();
|
||||
tensor.unary_ops(
|
||||
|mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(),
|
||||
|tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn int_argmax(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchOps::argmax(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_argmin(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchOps::argmin(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_max_dim(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchOps::max_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_max_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
|
||||
TchOps::max_dim_with_indices(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_min_dim(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchOps::min_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_min_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
|
||||
TchOps::min_dim_with_indices(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_clamp_min(tensor: TchTensor, min: Scalar) -> TchTensor {
|
||||
TchOps::clamp_min(tensor, min.elem::<i64>())
|
||||
}
|
||||
|
||||
fn int_clamp_max(tensor: TchTensor, max: Scalar) -> TchTensor {
|
||||
TchOps::clamp_max(tensor, max.elem::<i64>())
|
||||
}
|
||||
|
||||
fn int_clamp(tensor: TchTensor, min: Scalar, max: Scalar) -> TchTensor {
|
||||
TchOps::clamp(tensor, min.elem::<i64>(), max.elem::<i64>())
|
||||
}
|
||||
|
||||
fn int_abs(tensor: TchTensor) -> TchTensor {
|
||||
tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs())
|
||||
}
|
||||
|
||||
fn int_into_float(tensor: TchTensor) -> TchTensor {
|
||||
let tensor = tensor.tensor.to_kind(E::kind());
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
|
||||
TchOps::swap_dims(tensor, dim1, dim2)
|
||||
}
|
||||
|
||||
fn int_random(shape: Shape, distribution: Distribution, device: &LibTorchDevice) -> TchTensor {
|
||||
match distribution {
|
||||
Distribution::Default => TchTensor::new(tch::Tensor::randint_low(
|
||||
0,
|
||||
255,
|
||||
shape.iter().map(|i| *i as i64).collect::<Vec<_>>(),
|
||||
(tch::Kind::Int64, (*device).into()),
|
||||
)),
|
||||
Distribution::Bernoulli(prob) => {
|
||||
let mut tensor = TchTensor::empty::<i64>(shape, *device);
|
||||
tensor
|
||||
.mut_ops(|tensor| tensor.f_bernoulli_float_(prob).unwrap())
|
||||
.unwrap()
|
||||
}
|
||||
Distribution::Uniform(from, to) => TchTensor::new(tch::Tensor::randint_low(
|
||||
from as i64,
|
||||
to as i64,
|
||||
shape.iter().map(|i| *i as i64).collect::<Vec<_>>(),
|
||||
(tch::Kind::Int64, (*device).into()),
|
||||
)),
|
||||
Distribution::Normal(mean, std) => {
|
||||
let mut tensor = TchTensor::empty::<i64>(shape, *device);
|
||||
tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn int_arange(range: Range<i64>, device: &LibTorchDevice) -> TchTensor {
|
||||
let device: tch::Device = (*device).into();
|
||||
let mut tensor = tch::Tensor::arange(range.end - range.start, (tch::Kind::Int64, device));
|
||||
|
||||
if range.start != 0 {
|
||||
tensor = tensor.f_add_scalar_(range.start).unwrap();
|
||||
}
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
|
||||
TchOps::permute(tensor, axes)
|
||||
}
|
||||
|
||||
fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
|
||||
TchOps::flip(tensor, axes)
|
||||
}
|
||||
|
||||
fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
|
||||
TchOps::sign(tensor)
|
||||
}
|
||||
|
||||
fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
|
||||
TchOps::expand(tensor, shape)
|
||||
}
|
||||
|
||||
fn int_sort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
|
||||
TchOps::sort(tensor, dim, descending)
|
||||
}
|
||||
|
||||
fn int_argsort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
|
||||
TchOps::argsort(tensor, dim, descending)
|
||||
}
|
||||
|
||||
fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
TchOps::bitwise_and(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
TchOps::bitwise_or(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
TchOps::bitwise_xor(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
|
||||
TchOps::bitwise_not(tensor)
|
||||
}
|
||||
|
||||
fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
TchOps::bitwise_and_scalar(lhs, rhs.elem::<i64>())
|
||||
}
|
||||
|
||||
fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
TchOps::bitwise_or_scalar(lhs, rhs.elem::<i64>())
|
||||
}
|
||||
|
||||
fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
TchOps::bitwise_xor_scalar(lhs, rhs.elem::<i64>())
|
||||
}
|
||||
|
||||
fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
TchOps::bitwise_left_shift(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
TchOps::bitwise_right_shift(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
TchOps::bitwise_left_shift_scalar(lhs, rhs.elem::<i64>())
|
||||
}
|
||||
|
||||
fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
TchOps::bitwise_right_shift_scalar(lhs, rhs.elem::<i64>())
|
||||
}
|
||||
|
||||
fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
|
||||
// NOTE: when dtypes of inputs to an arithmetic operation differ, tch handles type
|
||||
// promotion based on a set of rules: https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc
|
||||
|
||||
// Type promotion is not automatic on all backends so this behavior might differ
|
||||
let kind = dtype.into_kind();
|
||||
|
||||
if tensor.tensor.kind() == kind {
|
||||
tensor
|
||||
} else {
|
||||
TchTensor::new(tensor.tensor.to_kind(kind))
|
||||
}
|
||||
}
|
||||
|
||||
fn int_unfold(
|
||||
tensor: IntTensor<Self>,
|
||||
dim: usize,
|
||||
size: usize,
|
||||
step: usize,
|
||||
) -> IntTensor<Self> {
|
||||
TchOps::unfold(tensor, dim, size, step)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
mod activation;
|
||||
mod base;
|
||||
mod bool_tensor;
|
||||
mod int_tensor;
|
||||
mod module;
|
||||
mod qtensor;
|
||||
mod tensor;
|
||||
mod transaction;
|
||||
|
||||
pub(crate) use base::*;
|
||||
@@ -0,0 +1,473 @@
|
||||
use crate::{LibTorch, TchTensor, element::TchElement};
|
||||
use burn_backend::{
|
||||
TensorMetadata,
|
||||
ops::{
|
||||
AttentionModuleOptions, ConvOptions, ConvTransposeOptions, DeformConv2dBackward,
|
||||
DeformConvOptions, InterpolateMode, InterpolateOptions, MaxPool1dWithIndices,
|
||||
MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, attention::attention_fallback,
|
||||
},
|
||||
};
|
||||
|
||||
impl<E: TchElement> ModuleOps<Self> for LibTorch<E> {
|
||||
fn embedding(weights: TchTensor, indices: TchTensor) -> TchTensor {
|
||||
// Workaround for MPS "Placeholder storage has not been allocated" error.
|
||||
// See: https://github.com/pytorch/pytorch/issues/123995
|
||||
// MPS uses lazy allocation and the embedding operation (which uses index_select)
|
||||
// can fail if the tensors haven't been materialized yet.
|
||||
// We work around this by performing the embedding on CPU and transferring back to MPS.
|
||||
if matches!(weights.tensor.device(), tch::Device::Mps) {
|
||||
let cpu_weights = weights.tensor.to(tch::Device::Cpu);
|
||||
let cpu_indices = indices.tensor.to(tch::Device::Cpu);
|
||||
let result = tch::Tensor::embedding(&cpu_weights, &cpu_indices, -1, false, false)
|
||||
.to(tch::Device::Mps);
|
||||
return TchTensor::new(result);
|
||||
}
|
||||
|
||||
let tensor = tch::Tensor::embedding(&weights.tensor, &indices.tensor, -1, false, false);
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn embedding_backward(weights: TchTensor, output: TchTensor, indices: TchTensor) -> TchTensor {
|
||||
let [n_embedding, _d_model] = weights.shape().dims();
|
||||
|
||||
// Workaround for MPS "Placeholder storage has not been allocated" error.
|
||||
// See: https://github.com/pytorch/pytorch/issues/123995
|
||||
if matches!(output.tensor.device(), tch::Device::Mps) {
|
||||
let cpu_output = output.tensor.to(tch::Device::Cpu);
|
||||
let cpu_indices = indices.tensor.to(tch::Device::Cpu);
|
||||
let result = tch::Tensor::embedding_backward(
|
||||
&cpu_output,
|
||||
&cpu_indices,
|
||||
n_embedding as i64,
|
||||
-1,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.to(tch::Device::Mps);
|
||||
return TchTensor::new(result);
|
||||
}
|
||||
|
||||
let tensor = tch::Tensor::embedding_backward(
|
||||
&output.tensor,
|
||||
&indices.tensor,
|
||||
n_embedding as i64,
|
||||
-1,
|
||||
false,
|
||||
false,
|
||||
);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn conv1d(
|
||||
x: TchTensor,
|
||||
weight: TchTensor,
|
||||
bias: Option<TchTensor>,
|
||||
options: ConvOptions<1>,
|
||||
) -> TchTensor {
|
||||
let tensor = tch::Tensor::conv1d(
|
||||
&x.tensor,
|
||||
&weight.tensor,
|
||||
bias.map(|t| t.tensor),
|
||||
options.stride.map(|i| i as i64),
|
||||
options.padding.map(|i| i as i64),
|
||||
options.dilation.map(|i| i as i64),
|
||||
options.groups as i64,
|
||||
);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn conv2d(
|
||||
x: TchTensor,
|
||||
weight: TchTensor,
|
||||
bias: Option<TchTensor>,
|
||||
options: ConvOptions<2>,
|
||||
) -> TchTensor {
|
||||
let tensor = tch::Tensor::conv2d(
|
||||
&x.tensor,
|
||||
&weight.tensor,
|
||||
bias.map(|t| t.tensor),
|
||||
options.stride.map(|i| i as i64),
|
||||
options.padding.map(|i| i as i64),
|
||||
options.dilation.map(|i| i as i64),
|
||||
options.groups as i64,
|
||||
);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn conv3d(
|
||||
x: TchTensor,
|
||||
weight: TchTensor,
|
||||
bias: Option<TchTensor>,
|
||||
options: ConvOptions<3>,
|
||||
) -> TchTensor {
|
||||
let tensor = tch::Tensor::conv3d(
|
||||
&x.tensor,
|
||||
&weight.tensor,
|
||||
bias.map(|t| t.tensor),
|
||||
options.stride.map(|i| i as i64),
|
||||
options.padding.map(|i| i as i64),
|
||||
options.dilation.map(|i| i as i64),
|
||||
options.groups as i64,
|
||||
);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn deform_conv2d(
|
||||
_x: TchTensor,
|
||||
_offset: TchTensor,
|
||||
_weight: TchTensor,
|
||||
_mask: Option<TchTensor>,
|
||||
_bias: Option<TchTensor>,
|
||||
_options: DeformConvOptions<2>,
|
||||
) -> TchTensor {
|
||||
unimplemented!("Torch bindings don't support deform_conv2d");
|
||||
}
|
||||
|
||||
fn deform_conv2d_backward(
|
||||
_x: TchTensor,
|
||||
_offset: TchTensor,
|
||||
_weight: TchTensor,
|
||||
_mask: Option<TchTensor>,
|
||||
_bias: Option<TchTensor>,
|
||||
_out_grad: TchTensor,
|
||||
_options: DeformConvOptions<2>,
|
||||
) -> DeformConv2dBackward<Self> {
|
||||
unimplemented!("Torch bindings don't support deform_conv2d");
|
||||
}
|
||||
|
||||
fn conv_transpose1d(
|
||||
x: TchTensor,
|
||||
weight: TchTensor,
|
||||
bias: Option<TchTensor>,
|
||||
options: ConvTransposeOptions<1>,
|
||||
) -> TchTensor {
|
||||
let tensor = tch::Tensor::conv_transpose1d(
|
||||
&x.tensor,
|
||||
&weight.tensor,
|
||||
bias.map(|t| t.tensor),
|
||||
options.stride.map(|i| i as i64),
|
||||
options.padding.map(|i| i as i64),
|
||||
options.padding_out.map(|i| i as i64),
|
||||
options.groups as i64,
|
||||
options.dilation.map(|i| i as i64),
|
||||
);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn conv_transpose2d(
|
||||
x: TchTensor,
|
||||
weight: TchTensor,
|
||||
bias: Option<TchTensor>,
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> TchTensor {
|
||||
let tensor = tch::Tensor::conv_transpose2d(
|
||||
&x.tensor,
|
||||
&weight.tensor,
|
||||
bias.map(|t| t.tensor),
|
||||
options.stride.map(|i| i as i64),
|
||||
options.padding.map(|i| i as i64),
|
||||
options.padding_out.map(|i| i as i64),
|
||||
options.groups as i64,
|
||||
options.dilation.map(|i| i as i64),
|
||||
);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn conv_transpose3d(
|
||||
x: TchTensor,
|
||||
weight: TchTensor,
|
||||
bias: Option<TchTensor>,
|
||||
options: ConvTransposeOptions<3>,
|
||||
) -> TchTensor {
|
||||
let tensor = tch::Tensor::conv_transpose3d(
|
||||
&x.tensor,
|
||||
&weight.tensor,
|
||||
bias.map(|t| t.tensor),
|
||||
options.stride.map(|i| i as i64),
|
||||
options.padding.map(|i| i as i64),
|
||||
options.padding_out.map(|i| i as i64),
|
||||
options.groups as i64,
|
||||
options.dilation.map(|i| i as i64),
|
||||
);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn avg_pool1d(
|
||||
x: TchTensor,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
count_include_pad: bool,
|
||||
ceil_mode: bool,
|
||||
) -> TchTensor {
|
||||
let tensor = tch::Tensor::avg_pool1d(
|
||||
&x.tensor,
|
||||
[kernel_size as i64],
|
||||
[stride as i64],
|
||||
[padding as i64],
|
||||
ceil_mode,
|
||||
count_include_pad,
|
||||
);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
fn avg_pool2d(
|
||||
x: TchTensor,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
ceil_mode: bool,
|
||||
) -> TchTensor {
|
||||
let tensor = tch::Tensor::avg_pool2d(
|
||||
&x.tensor,
|
||||
[kernel_size[0] as i64, kernel_size[1] as i64],
|
||||
[stride[0] as i64, stride[1] as i64],
|
||||
[padding[0] as i64, padding[1] as i64],
|
||||
ceil_mode,
|
||||
count_include_pad,
|
||||
None,
|
||||
);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn avg_pool2d_backward(
|
||||
x: TchTensor,
|
||||
grad: TchTensor,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
ceil_mode: bool,
|
||||
) -> TchTensor {
|
||||
let tensor = tch::Tensor::avg_pool2d_backward(
|
||||
&x.tensor,
|
||||
&grad.tensor,
|
||||
[kernel_size[0] as i64, kernel_size[1] as i64],
|
||||
[stride[0] as i64, stride[1] as i64],
|
||||
[padding[0] as i64, padding[1] as i64],
|
||||
ceil_mode,
|
||||
count_include_pad,
|
||||
None,
|
||||
);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn max_pool1d(
|
||||
x: TchTensor,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
ceil_mode: bool,
|
||||
) -> TchTensor {
|
||||
let tensor = tch::Tensor::max_pool1d(
|
||||
&x.tensor,
|
||||
kernel_size as i64,
|
||||
stride as i64,
|
||||
padding as i64,
|
||||
dilation as i64,
|
||||
ceil_mode,
|
||||
);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn max_pool1d_with_indices(
|
||||
x: TchTensor,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
ceil_mode: bool,
|
||||
) -> MaxPool1dWithIndices<Self> {
|
||||
let (tensor, indices) = tch::Tensor::max_pool1d_with_indices(
|
||||
&x.tensor,
|
||||
kernel_size as i64,
|
||||
stride as i64,
|
||||
padding as i64,
|
||||
dilation as i64,
|
||||
ceil_mode,
|
||||
);
|
||||
|
||||
MaxPool1dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices))
|
||||
}
|
||||
|
||||
fn max_pool2d(
|
||||
x: TchTensor,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
ceil_mode: bool,
|
||||
) -> TchTensor {
|
||||
let tensor = tch::Tensor::max_pool2d(
|
||||
&x.tensor,
|
||||
[kernel_size[0] as i64, kernel_size[1] as i64],
|
||||
[stride[0] as i64, stride[1] as i64],
|
||||
[padding[0] as i64, padding[1] as i64],
|
||||
[dilation[0] as i64, dilation[1] as i64],
|
||||
ceil_mode,
|
||||
);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indices(
|
||||
x: TchTensor,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
ceil_mode: bool,
|
||||
) -> MaxPool2dWithIndices<Self> {
|
||||
let (tensor, indices) = tch::Tensor::max_pool2d_with_indices(
|
||||
&x.tensor,
|
||||
[kernel_size[0] as i64, kernel_size[1] as i64],
|
||||
[stride[0] as i64, stride[1] as i64],
|
||||
[padding[0] as i64, padding[1] as i64],
|
||||
[dilation[0] as i64, dilation[1] as i64],
|
||||
ceil_mode,
|
||||
);
|
||||
|
||||
MaxPool2dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices))
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indices_backward(
|
||||
x: TchTensor,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
ceil_mode: bool,
|
||||
output_grad: TchTensor,
|
||||
indices: TchTensor,
|
||||
) -> MaxPool2dBackward<Self> {
|
||||
let grad = tch::Tensor::max_pool2d_with_indices_backward(
|
||||
&x.tensor,
|
||||
&output_grad.tensor,
|
||||
[kernel_size[0] as i64, kernel_size[1] as i64],
|
||||
[stride[0] as i64, stride[1] as i64],
|
||||
[padding[0] as i64, padding[1] as i64],
|
||||
[dilation[0] as i64, dilation[1] as i64],
|
||||
ceil_mode,
|
||||
&indices.tensor,
|
||||
);
|
||||
|
||||
MaxPool2dBackward::new(TchTensor::new(grad))
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool2d(x: TchTensor, output_size: [usize; 2]) -> TchTensor {
|
||||
let tensor = tch::Tensor::adaptive_avg_pool2d(&x.tensor, output_size.map(|e| e as i64));
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool2d_backward(x: TchTensor, grad: TchTensor) -> TchTensor {
|
||||
let tensor = tch::Tensor::internal_adaptive_avg_pool2d_backward(&x.tensor, &grad.tensor);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool1d(x: TchTensor, output_size: usize) -> TchTensor {
|
||||
let tensor = tch::Tensor::adaptive_avg_pool1d(&x.tensor, output_size as i64);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn interpolate(
|
||||
x: TchTensor,
|
||||
output_size: [usize; 2],
|
||||
options: InterpolateOptions,
|
||||
) -> TchTensor {
|
||||
let output_size = output_size.map(|e| e as i64);
|
||||
|
||||
let align_corners = options.align_corners;
|
||||
let tensor = match options.mode {
|
||||
InterpolateMode::Nearest => {
|
||||
tch::Tensor::upsample_nearest2d(&x.tensor, output_size, None, None)
|
||||
}
|
||||
InterpolateMode::Bilinear => {
|
||||
tch::Tensor::upsample_bilinear2d(&x.tensor, output_size, align_corners, None, None)
|
||||
}
|
||||
InterpolateMode::Bicubic => {
|
||||
tch::Tensor::upsample_bicubic2d(&x.tensor, output_size, align_corners, None, None)
|
||||
}
|
||||
};
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn interpolate_backward(
|
||||
x: TchTensor,
|
||||
grad: TchTensor,
|
||||
output_size: [usize; 2],
|
||||
options: InterpolateOptions,
|
||||
) -> TchTensor {
|
||||
let output_size = output_size.map(|e| e as i64);
|
||||
let [n, c, h_in, w_in] = x.shape().dims();
|
||||
let input_size = [n as i64, c as i64, h_in as i64, w_in as i64];
|
||||
let align_corners = options.align_corners;
|
||||
|
||||
let tensor = match options.mode {
|
||||
InterpolateMode::Nearest => tch::Tensor::upsample_nearest2d_backward(
|
||||
&grad.tensor,
|
||||
output_size,
|
||||
input_size,
|
||||
None,
|
||||
None,
|
||||
),
|
||||
InterpolateMode::Bilinear => tch::Tensor::upsample_bilinear2d_backward(
|
||||
&grad.tensor,
|
||||
output_size,
|
||||
input_size,
|
||||
align_corners,
|
||||
None,
|
||||
None,
|
||||
),
|
||||
InterpolateMode::Bicubic => tch::Tensor::upsample_bicubic2d_backward(
|
||||
&grad.tensor,
|
||||
output_size,
|
||||
input_size,
|
||||
align_corners,
|
||||
None,
|
||||
None,
|
||||
),
|
||||
};
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn attention(
|
||||
query: TchTensor,
|
||||
key: TchTensor,
|
||||
value: TchTensor,
|
||||
mask: Option<TchTensor>,
|
||||
attn_bias: Option<TchTensor>,
|
||||
options: AttentionModuleOptions,
|
||||
) -> TchTensor {
|
||||
if attn_bias.is_some() {
|
||||
return attention_fallback::<Self>(query, key, value, mask, attn_bias, options);
|
||||
}
|
||||
|
||||
TchTensor::new(tch::Tensor::scaled_dot_product_attention(
|
||||
&query.tensor,
|
||||
&key.tensor,
|
||||
&value.tensor,
|
||||
mask.map(|m| m.tensor),
|
||||
0.,
|
||||
options.is_causal,
|
||||
options.scale,
|
||||
false,
|
||||
))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,140 @@
|
||||
use burn_backend::{
|
||||
ExecutionError, Shape, TensorData,
|
||||
ops::QTensorOps,
|
||||
quantization::{QuantScheme, QuantizationParametersPrimitive},
|
||||
tensor::{Device, FloatTensor, IntTensor, QuantizedTensor},
|
||||
};
|
||||
|
||||
use crate::{LibTorch, LibTorchDevice, TchElement};
|
||||
|
||||
impl<E: TchElement> QTensorOps<Self> for LibTorch<E> {
|
||||
fn q_from_data(_data: TensorData, _device: &LibTorchDevice) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn quantize(
|
||||
_tensor: FloatTensor<Self>,
|
||||
_scheme: &QuantScheme,
|
||||
_qparams: QuantizationParametersPrimitive<Self>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn quantize_dynamic(
|
||||
_tensor: FloatTensor<Self>,
|
||||
_scheme: &QuantScheme,
|
||||
) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn dequantize(_tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_device(_tensor: &QuantizedTensor<Self>) -> LibTorchDevice {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_to_device(
|
||||
_tensor: QuantizedTensor<Self>,
|
||||
_device: &Device<Self>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_reshape(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
async fn q_into_data(_tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
|
||||
unimplemented!()
|
||||
}
|
||||
fn q_swap_dims(
|
||||
_tensor: QuantizedTensor<Self>,
|
||||
_dim1: usize,
|
||||
_dim2: usize,
|
||||
) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_permute(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_select(
|
||||
_tensor: QuantizedTensor<Self>,
|
||||
_dim: usize,
|
||||
_indices: IntTensor<Self>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_slice(
|
||||
_tensor: QuantizedTensor<Self>,
|
||||
_slices: &[burn_backend::Slice],
|
||||
) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_argmax(_tensor: QuantizedTensor<Self>, _dim: usize) -> IntTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_argmin(_tensor: QuantizedTensor<Self>, _dim: usize) -> IntTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_max_dim_with_indices(
|
||||
_tensor: QuantizedTensor<Self>,
|
||||
_dim: usize,
|
||||
) -> (QuantizedTensor<Self>, IntTensor<Self>) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_max_dim(_tensor: QuantizedTensor<Self>, _dim: usize) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_min_dim(_tensor: QuantizedTensor<Self>, _dim: usize) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_min_dim_with_indices(
|
||||
_tensor: QuantizedTensor<Self>,
|
||||
_dim: usize,
|
||||
) -> (QuantizedTensor<Self>, IntTensor<Self>) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_sort(
|
||||
_tensor: QuantizedTensor<Self>,
|
||||
_dim: usize,
|
||||
_descending: bool,
|
||||
) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_sort_with_indices(
|
||||
_tensor: QuantizedTensor<Self>,
|
||||
_dim: usize,
|
||||
_descending: bool,
|
||||
) -> (QuantizedTensor<Self>, IntTensor<Self>) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_argsort(
|
||||
_tensor: QuantizedTensor<Self>,
|
||||
_dim: usize,
|
||||
_descending: bool,
|
||||
) -> IntTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,539 @@
|
||||
use super::TchOps;
|
||||
use crate::{IntoKind, LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement};
|
||||
use burn_backend::backend::ExecutionError;
|
||||
use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor};
|
||||
use burn_backend::{
|
||||
DType, Distribution, FloatDType, Shape, TensorData, TensorMetadata, ops::FloatTensorOps,
|
||||
};
|
||||
use burn_backend::{Scalar, bf16, f16};
|
||||
|
||||
impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
|
||||
fn float_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {
|
||||
match data.dtype {
|
||||
DType::F64 => TchTensor::from_data::<f64>(data, (*device).into()),
|
||||
DType::F32 => TchTensor::from_data::<f32>(data, (*device).into()),
|
||||
DType::F16 => TchTensor::from_data::<f16>(data, (*device).into()),
|
||||
DType::BF16 => TchTensor::from_data::<bf16>(data, (*device).into()),
|
||||
_ => unimplemented!("Unsupported dtype for `float_from_data`"),
|
||||
}
|
||||
}
|
||||
|
||||
fn float_random(
|
||||
shape: Shape,
|
||||
distribution: Distribution,
|
||||
device: &LibTorchDevice,
|
||||
) -> TchTensor {
|
||||
match distribution {
|
||||
Distribution::Default => {
|
||||
let mut tensor = TchTensor::empty::<E>(shape, *device);
|
||||
tensor
|
||||
.mut_ops(|tensor| tensor.rand_like_out(tensor))
|
||||
.unwrap()
|
||||
}
|
||||
Distribution::Bernoulli(prob) => {
|
||||
let mut tensor = TchTensor::empty::<E>(shape, *device);
|
||||
tensor
|
||||
.mut_ops(|tensor| tensor.f_bernoulli_float_(prob).unwrap())
|
||||
.unwrap()
|
||||
}
|
||||
Distribution::Uniform(from, to) => {
|
||||
let mut tensor = TchTensor::empty::<E>(shape, *device);
|
||||
tensor.mut_ops(|tensor| tensor.uniform_(from, to)).unwrap()
|
||||
}
|
||||
Distribution::Normal(mean, std) => {
|
||||
let mut tensor = TchTensor::empty::<E>(shape, *device);
|
||||
tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn float_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
|
||||
TchOps::repeat_dim(tensor, dim, times)
|
||||
}
|
||||
|
||||
fn float_zeros(shape: Shape, device: &LibTorchDevice, dtype: FloatDType) -> TchTensor {
|
||||
let shape = TchShape::from(shape);
|
||||
let device: tch::Device = (*device).into();
|
||||
|
||||
TchTensor::new(tch::Tensor::zeros(shape.dims, (dtype.into_kind(), device)))
|
||||
}
|
||||
|
||||
fn float_ones(shape: Shape, device: &LibTorchDevice, dtype: FloatDType) -> TchTensor {
|
||||
let shape = TchShape::from(shape);
|
||||
let device: tch::Device = (*device).into();
|
||||
|
||||
TchTensor::new(tch::Tensor::ones(shape.dims, (dtype.into_kind(), device)))
|
||||
}
|
||||
|
||||
async fn float_into_data(tensor: TchTensor) -> Result<TensorData, ExecutionError> {
|
||||
let shape = tensor.shape();
|
||||
let tensor = Self::float_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
|
||||
Ok(match tensor.tensor.kind() {
|
||||
tch::Kind::Half => {
|
||||
let values = Vec::<f16>::try_from(&tensor).unwrap();
|
||||
TensorData::new(values, shape)
|
||||
}
|
||||
tch::Kind::Float => {
|
||||
let values = Vec::<f32>::try_from(&tensor).unwrap();
|
||||
TensorData::new(values, shape)
|
||||
}
|
||||
tch::Kind::Double => {
|
||||
let values = Vec::<f64>::try_from(&tensor).unwrap();
|
||||
TensorData::new(values, shape)
|
||||
}
|
||||
tch::Kind::BFloat16 => {
|
||||
let values = Vec::<bf16>::try_from(&tensor).unwrap();
|
||||
TensorData::new(values, shape)
|
||||
}
|
||||
_ => panic!("Not a valid float kind"),
|
||||
})
|
||||
}
|
||||
|
||||
fn float_device(tensor: &TchTensor) -> LibTorchDevice {
|
||||
tensor.tensor.device().into()
|
||||
}
|
||||
|
||||
fn float_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
|
||||
TchOps::to_device(tensor, device)
|
||||
}
|
||||
|
||||
fn float_empty(shape: Shape, device: &LibTorchDevice, dtype: FloatDType) -> TchTensor {
|
||||
let tensor = tch::Tensor::empty(
|
||||
TchShape::from(shape).dims,
|
||||
(dtype.into_kind(), (*device).into()),
|
||||
);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn float_add(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchOps::add(lhs, rhs)
|
||||
}
|
||||
|
||||
fn float_add_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
|
||||
let rhs: f64 = rhs.elem();
|
||||
|
||||
lhs.unary_ops(
|
||||
|mut tensor| tensor.f_add_scalar_(rhs).unwrap(),
|
||||
|tensor| tensor.f_add_scalar(rhs).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn float_sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchOps::sub(lhs, rhs)
|
||||
}
|
||||
|
||||
fn float_sub_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
|
||||
let rhs: f64 = rhs.elem();
|
||||
|
||||
lhs.unary_ops(
|
||||
|mut tensor| tensor.f_sub_scalar_(rhs).unwrap(),
|
||||
|tensor| tensor.f_sub_scalar(rhs).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn float_mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchOps::mul(lhs, rhs)
|
||||
}
|
||||
|
||||
fn float_mul_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
|
||||
let rhs: f64 = rhs.elem();
|
||||
|
||||
lhs.unary_ops(
|
||||
|mut tensor| tensor.f_mul_scalar_(rhs).unwrap(),
|
||||
|tensor| tensor.f_mul_scalar(rhs).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn float_div(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchOps::div(lhs, rhs)
|
||||
}
|
||||
|
||||
fn float_div_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
|
||||
let rhs: f64 = rhs.elem();
|
||||
|
||||
lhs.unary_ops(
|
||||
|mut tensor| tensor.f_div_scalar_(rhs).unwrap(),
|
||||
|tensor| tensor.f_div_scalar(rhs).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn float_remainder(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchOps::remainder(lhs, rhs)
|
||||
}
|
||||
|
||||
fn float_remainder_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
|
||||
let rhs: f64 = rhs.elem();
|
||||
|
||||
lhs.unary_ops(
|
||||
|tensor| tensor.f_remainder(rhs).unwrap(),
|
||||
|tensor| tensor.f_remainder(rhs).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn float_matmul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
let tensor = lhs.tensor.matmul(&rhs.tensor);
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn float_cross(lhs: TchTensor, rhs: TchTensor, dim: usize) -> TchTensor {
|
||||
let tensor = lhs.tensor.cross(&rhs.tensor, dim as i64);
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn float_recip(tensor: TchTensor) -> TchTensor {
|
||||
TchTensor::new(tensor.tensor.reciprocal())
|
||||
}
|
||||
|
||||
fn float_swap_dims(tensor: TchTensor, dim1: usize, dim2: usize) -> TchTensor {
|
||||
TchOps::swap_dims(tensor, dim1, dim2)
|
||||
}
|
||||
|
||||
fn float_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
|
||||
TchOps::reshape(tensor, shape)
|
||||
}
|
||||
|
||||
fn float_gather(dim: usize, tensor: TchTensor, indices: TchTensor) -> TchTensor {
|
||||
TchOps::gather(dim, tensor, indices)
|
||||
}
|
||||
|
||||
fn float_scatter_add(
|
||||
dim: usize,
|
||||
tensor: TchTensor,
|
||||
indices: TchTensor,
|
||||
value: TchTensor,
|
||||
) -> TchTensor {
|
||||
TchOps::scatter(dim, tensor, indices, value)
|
||||
}
|
||||
|
||||
fn float_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
|
||||
TchOps::index_select_dim(tensor, dim, indices)
|
||||
}
|
||||
|
||||
fn float_select_add(
|
||||
tensor: TchTensor,
|
||||
dim: usize,
|
||||
indices: TchTensor,
|
||||
value: TchTensor,
|
||||
) -> TchTensor {
|
||||
TchOps::select_assign(tensor, dim, indices, value)
|
||||
}
|
||||
|
||||
fn float_slice(tensor: TchTensor, slices: &[burn_backend::Slice]) -> TchTensor {
|
||||
TchOps::slice_with_steps(tensor, slices)
|
||||
}
|
||||
|
||||
fn float_slice_assign(
|
||||
tensor: TchTensor,
|
||||
slices: &[burn_backend::Slice],
|
||||
value: TchTensor,
|
||||
) -> TchTensor {
|
||||
TchOps::slice_assign(tensor, slices, value)
|
||||
}
|
||||
|
||||
fn float_mask_where(tensor: TchTensor, mask: TchTensor, value: TchTensor) -> TchTensor {
|
||||
let output = value.tensor.where_self(&mask.tensor, &tensor.tensor);
|
||||
|
||||
TchTensor::new(output)
|
||||
}
|
||||
|
||||
fn float_mask_fill(tensor: TchTensor, mask: TchTensor, value: Scalar) -> TchTensor {
|
||||
let value: f64 = value.elem();
|
||||
|
||||
tensor.unary_ops(
|
||||
|mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(),
|
||||
|tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn float_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchOps::equal(lhs, rhs)
|
||||
}
|
||||
|
||||
fn float_equal_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {
|
||||
TchOps::equal_elem(lhs, rhs.elem::<f64>())
|
||||
}
|
||||
|
||||
fn float_greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchOps::greater(lhs, rhs)
|
||||
}
|
||||
|
||||
fn float_greater_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {
|
||||
TchOps::greater_elem(lhs, rhs.elem::<f64>())
|
||||
}
|
||||
|
||||
fn float_greater_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchOps::greater_equal(lhs, rhs)
|
||||
}
|
||||
|
||||
fn float_greater_equal_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {
|
||||
TchOps::greater_equal_elem(lhs, rhs.elem::<f64>())
|
||||
}
|
||||
|
||||
fn float_lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchOps::lower(lhs, rhs)
|
||||
}
|
||||
|
||||
fn float_lower_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {
|
||||
TchOps::lower_elem(lhs, rhs.elem::<f64>())
|
||||
}
|
||||
|
||||
fn float_lower_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchOps::lower_equal(lhs, rhs)
|
||||
}
|
||||
|
||||
fn float_lower_equal_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {
|
||||
TchOps::lower_equal_elem(lhs, rhs.elem::<f64>())
|
||||
}
|
||||
|
||||
fn float_mean(tensor: TchTensor) -> TchTensor {
|
||||
TchOps::mean(tensor)
|
||||
}
|
||||
|
||||
fn float_sum(tensor: TchTensor) -> TchTensor {
|
||||
TchOps::sum(tensor)
|
||||
}
|
||||
|
||||
fn float_sum_dim(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchOps::sum_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn float_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchOps::mean_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn float_cumsum(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchOps::cumsum(tensor, dim)
|
||||
}
|
||||
|
||||
fn float_cumprod(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchOps::cumprod(tensor, dim)
|
||||
}
|
||||
|
||||
fn float_cummin(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchOps::cummin(tensor, dim)
|
||||
}
|
||||
|
||||
fn float_cummax(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchOps::cummax(tensor, dim)
|
||||
}
|
||||
|
||||
fn float_prod(tensor: TchTensor) -> TchTensor {
|
||||
TchOps::prod(tensor)
|
||||
}
|
||||
|
||||
fn float_prod_dim(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchOps::prod_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn float_argmax(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchOps::argmax(tensor, dim)
|
||||
}
|
||||
|
||||
fn float_argmin(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchOps::argmin(tensor, dim)
|
||||
}
|
||||
|
||||
fn float_max_dim(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchOps::max_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn float_max_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
|
||||
TchOps::max_dim_with_indices(tensor, dim)
|
||||
}
|
||||
|
||||
fn float_min_dim(tensor: TchTensor, dim: usize) -> TchTensor {
|
||||
TchOps::min_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn float_min_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
|
||||
TchOps::min_dim_with_indices(tensor, dim)
|
||||
}
|
||||
|
||||
fn float_exp(tensor: TchTensor) -> TchTensor {
|
||||
tensor.unary_ops(|mut tensor| tensor.exp_(), |tensor| tensor.exp())
|
||||
}
|
||||
|
||||
fn float_log(tensor: TchTensor) -> TchTensor {
|
||||
tensor.unary_ops(|mut tensor| tensor.log_(), |tensor| tensor.log())
|
||||
}
|
||||
|
||||
fn float_log1p(tensor: TchTensor) -> TchTensor {
|
||||
tensor.unary_ops(|mut tensor| tensor.log1p_(), |tensor| tensor.log1p())
|
||||
}
|
||||
|
||||
fn float_powf_scalar_impl(tensor: TchTensor, value: Scalar) -> TchTensor {
|
||||
tensor.unary_ops(
|
||||
|mut tensor| tensor.f_pow_(value.elem::<f64>()).unwrap(),
|
||||
|tensor| tensor.pow_tensor_scalar(value.elem::<f64>()),
|
||||
)
|
||||
}
|
||||
|
||||
fn float_sqrt(tensor: TchTensor) -> TchTensor {
|
||||
tensor.unary_ops(|mut tensor| tensor.sqrt_(), |tensor| tensor.sqrt())
|
||||
}
|
||||
|
||||
fn float_abs(tensor: TchTensor) -> TchTensor {
|
||||
tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs())
|
||||
}
|
||||
|
||||
fn float_cos(tensor: TchTensor) -> TchTensor {
|
||||
tensor.unary_ops(|mut tensor| tensor.cos_(), |tensor| tensor.cos())
|
||||
}
|
||||
|
||||
fn float_cosh(tensor: TchTensor) -> TchTensor {
|
||||
tensor.unary_ops(|mut tensor| tensor.cosh_(), |tensor| tensor.cosh())
|
||||
}
|
||||
|
||||
fn float_sin(tensor: TchTensor) -> TchTensor {
|
||||
tensor.unary_ops(|mut tensor| tensor.sin_(), |tensor| tensor.sin())
|
||||
}
|
||||
|
||||
fn float_sinh(tensor: TchTensor) -> TchTensor {
|
||||
tensor.unary_ops(|mut tensor| tensor.sinh_(), |tensor| tensor.sinh())
|
||||
}
|
||||
|
||||
fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
tensor.unary_ops(|mut tensor| tensor.tan_(), |tensor| tensor.tan())
|
||||
}
|
||||
|
||||
fn float_tanh(tensor: TchTensor) -> TchTensor {
|
||||
tensor.unary_ops(|mut tensor| tensor.tanh_(), |tensor| tensor.tanh())
|
||||
}
|
||||
|
||||
fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
tensor.unary_ops(|mut tensor| tensor.acos_(), |tensor| tensor.acos())
|
||||
}
|
||||
|
||||
fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
tensor.unary_ops(|mut tensor| tensor.acosh_(), |tensor| tensor.acosh())
|
||||
}
|
||||
|
||||
fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
tensor.unary_ops(|mut tensor| tensor.asin_(), |tensor| tensor.asin())
|
||||
}
|
||||
|
||||
fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
tensor.unary_ops(|mut tensor| tensor.asinh_(), |tensor| tensor.asinh())
|
||||
}
|
||||
|
||||
fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
tensor.unary_ops(|mut tensor| tensor.atan_(), |tensor| tensor.atan())
|
||||
}
|
||||
|
||||
fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
tensor.unary_ops(|mut tensor| tensor.atanh_(), |tensor| tensor.atanh())
|
||||
}
|
||||
|
||||
fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
TchOps::atan2(lhs, rhs)
|
||||
}
|
||||
|
||||
fn float_round(tensor: TchTensor) -> TchTensor {
|
||||
tensor.unary_ops(|mut tensor| tensor.round_(), |tensor| tensor.round())
|
||||
}
|
||||
|
||||
fn float_floor(tensor: TchTensor) -> TchTensor {
|
||||
tensor.unary_ops(|mut tensor| tensor.floor_(), |tensor| tensor.floor())
|
||||
}
|
||||
|
||||
fn float_ceil(tensor: TchTensor) -> TchTensor {
|
||||
tensor.unary_ops(|mut tensor| tensor.ceil_(), |tensor| tensor.ceil())
|
||||
}
|
||||
|
||||
fn float_trunc(tensor: TchTensor) -> TchTensor {
|
||||
tensor.unary_ops(|mut tensor| tensor.trunc_(), |tensor| tensor.trunc())
|
||||
}
|
||||
|
||||
fn float_erf(tensor: TchTensor) -> TchTensor {
|
||||
tensor.unary_ops(|mut tensor| tensor.erf_(), |tensor| tensor.erf())
|
||||
}
|
||||
|
||||
fn float_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
|
||||
TchOps::cat(tensors, dim)
|
||||
}
|
||||
|
||||
fn float_clamp_min(tensor: TchTensor, min: Scalar) -> TchTensor {
|
||||
TchOps::clamp_min(tensor, min.elem::<f64>())
|
||||
}
|
||||
|
||||
fn float_clamp_max(tensor: TchTensor, max: Scalar) -> TchTensor {
|
||||
TchOps::clamp_max(tensor, max.elem::<f64>())
|
||||
}
|
||||
|
||||
fn float_clamp(tensor: TchTensor, min: Scalar, max: Scalar) -> TchTensor {
|
||||
TchOps::clamp(tensor, min.elem::<f64>(), max.elem::<f64>())
|
||||
}
|
||||
|
||||
fn float_into_int(tensor: TchTensor) -> TchTensor {
|
||||
let tensor = tensor.tensor.to_kind(tch::Kind::Int64);
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn float_powf(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
|
||||
TchOps::powf(lhs, rhs)
|
||||
}
|
||||
|
||||
fn float_permute(tensor: TchTensor, axes: &[usize]) -> TchTensor {
|
||||
TchOps::permute(tensor, axes)
|
||||
}
|
||||
|
||||
fn float_flip(tensor: TchTensor, axes: &[usize]) -> TchTensor {
|
||||
TchOps::flip(tensor, axes)
|
||||
}
|
||||
|
||||
fn float_sign(tensor: TchTensor) -> TchTensor {
|
||||
TchOps::sign(tensor)
|
||||
}
|
||||
|
||||
fn float_expand(tensor: TchTensor, shape: Shape) -> TchTensor {
|
||||
TchOps::expand(tensor, shape)
|
||||
}
|
||||
|
||||
fn float_sort(tensor: TchTensor, dim: usize, descending: bool) -> TchTensor {
|
||||
TchOps::sort(tensor, dim, descending)
|
||||
}
|
||||
|
||||
fn float_sort_with_indices(
|
||||
tensor: TchTensor,
|
||||
dim: usize,
|
||||
descending: bool,
|
||||
) -> (TchTensor, TchTensor) {
|
||||
TchOps::sort_with_indices(tensor, dim, descending)
|
||||
}
|
||||
|
||||
fn float_argsort(tensor: TchTensor, dim: usize, descending: bool) -> IntTensor<Self> {
|
||||
TchOps::argsort(tensor, dim, descending)
|
||||
}
|
||||
|
||||
fn float_cast(tensor: TchTensor, dtype: FloatDType) -> TchTensor {
|
||||
// NOTE: when dtypes of inputs to an arithmetic operation differ, tch handles type
|
||||
// promotion based on a set of rules: https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc
|
||||
|
||||
// Type promotion is not automatic on all backends so this behavior might differ
|
||||
let kind = dtype.into_kind();
|
||||
|
||||
if tensor.tensor.kind() == kind {
|
||||
tensor
|
||||
} else {
|
||||
TchTensor::new(tensor.tensor.to_kind(kind))
|
||||
}
|
||||
}
|
||||
|
||||
fn float_unfold(
|
||||
tensor: FloatTensor<Self>,
|
||||
dim: usize,
|
||||
size: usize,
|
||||
step: usize,
|
||||
) -> FloatTensor<Self> {
|
||||
TchOps::unfold(tensor, dim, size, step)
|
||||
}
|
||||
|
||||
fn float_is_nan(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
|
||||
TchTensor::new(tensor.tensor.isnan())
|
||||
}
|
||||
|
||||
fn float_is_inf(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
|
||||
TchTensor::new(tensor.tensor.isinf())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
use burn_backend::ops::TransactionOps;
|
||||
|
||||
use crate::{LibTorch, TchElement};
|
||||
|
||||
impl<E: TchElement> TransactionOps<Self> for LibTorch<E> {}
|
||||
507
crates/stable-diffusion-burn/burn-crates/burn-tch/src/tensor.rs
Normal file
507
crates/stable-diffusion-burn/burn-crates/burn-tch/src/tensor.rs
Normal file
@@ -0,0 +1,507 @@
|
||||
use crate::{LibTorchDevice, TchElement};
|
||||
use burn_backend::{DType, FloatDType, IntDType, Shape, TensorData, TensorMetadata};
|
||||
use libc::c_void;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// A reference to a tensor storage.
|
||||
///
|
||||
/// We manually implement `Sync` and `Send` unsafely, so even if we could use `Rc`, it isn't safe.
|
||||
#[allow(clippy::arc_with_non_send_sync)]
|
||||
pub type StorageRef = Arc<*mut c_void>;
|
||||
|
||||
/// A reference to a tensor storage.
|
||||
#[derive(PartialEq, Debug, Clone)]
|
||||
pub enum Storage {
|
||||
/// When a tensor is a partial view of another tensor.
|
||||
View {
|
||||
/// Storage reference for the whole buffer.
|
||||
buffer_ref: StorageRef,
|
||||
/// Storage reference for the partial buffer.
|
||||
view_ref: StorageRef,
|
||||
},
|
||||
/// When a tensor use all of its buffer.
|
||||
Owned {
|
||||
/// Storage reference for the whole buffer.
|
||||
buffer_ref: StorageRef,
|
||||
},
|
||||
}
|
||||
|
||||
impl Storage {
|
||||
/// Check if the storage can be used inplace.
|
||||
pub fn can_mut(&self) -> bool {
|
||||
match self {
|
||||
Storage::View {
|
||||
buffer_ref: start_ref,
|
||||
view_ref,
|
||||
} => Arc::strong_count(start_ref) == 1 && Arc::strong_count(view_ref) == 1,
|
||||
Storage::Owned {
|
||||
buffer_ref: start_ref,
|
||||
} => Arc::strong_count(start_ref) == 1,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the whole buffer reference.
|
||||
pub fn buffer_ref(&self) -> &StorageRef {
|
||||
match self {
|
||||
Storage::View {
|
||||
buffer_ref: start_ref,
|
||||
view_ref: _,
|
||||
} => start_ref,
|
||||
Storage::Owned {
|
||||
buffer_ref: start_ref,
|
||||
} => start_ref,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A tensor using the tch backend.
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct TchTensor {
|
||||
/// Handle to the tensor. Call methods on this field.
|
||||
pub tensor: tch::Tensor,
|
||||
|
||||
/// The tensor's storage
|
||||
pub storage: Storage,
|
||||
}
|
||||
|
||||
impl TensorMetadata for TchTensor {
|
||||
fn dtype(&self) -> DType {
|
||||
match self.tensor.kind() {
|
||||
tch::Kind::Uint8 => DType::U8,
|
||||
tch::Kind::Int8 => DType::I8,
|
||||
tch::Kind::Int16 => DType::I16,
|
||||
tch::Kind::Int => DType::I32,
|
||||
tch::Kind::Int64 => DType::I64,
|
||||
tch::Kind::Half => DType::F16,
|
||||
tch::Kind::Float => DType::F32,
|
||||
tch::Kind::Double => DType::F64,
|
||||
tch::Kind::Bool => DType::Bool,
|
||||
tch::Kind::BFloat16 => DType::BF16,
|
||||
// Complex and quantization types are not valid/implemented.
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn shape(&self) -> Shape {
|
||||
Shape::from(self.tensor.size())
|
||||
}
|
||||
|
||||
fn rank(&self) -> usize {
|
||||
self.tensor.dim()
|
||||
}
|
||||
}
|
||||
|
||||
impl burn_backend::QTensorPrimitive for TchTensor {
|
||||
fn scheme(&self) -> &burn_backend::quantization::QuantScheme {
|
||||
unimplemented!("Quantization is not supported")
|
||||
}
|
||||
}
|
||||
|
||||
impl core::fmt::Display for TchTensor {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.tensor)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait IntoKind {
|
||||
fn try_into_kind(self) -> Result<tch::Kind, tch::TchError>;
|
||||
fn into_kind(self) -> tch::Kind
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
self.try_into_kind().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoKind for IntDType {
|
||||
fn try_into_kind(self) -> Result<tch::Kind, tch::TchError> {
|
||||
let dtype: DType = self.into();
|
||||
dtype.try_into_kind()
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoKind for FloatDType {
|
||||
fn try_into_kind(self) -> Result<tch::Kind, tch::TchError> {
|
||||
let dtype: DType = self.into();
|
||||
dtype.try_into_kind()
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoKind for DType {
|
||||
fn try_into_kind(self) -> Result<tch::Kind, tch::TchError> {
|
||||
match self {
|
||||
DType::F64 => Ok(tch::Kind::Double),
|
||||
DType::F32 => Ok(tch::Kind::Float),
|
||||
DType::Flex32 => Ok(tch::Kind::Float),
|
||||
DType::F16 => Ok(tch::Kind::Half),
|
||||
DType::BF16 => Ok(tch::Kind::BFloat16),
|
||||
DType::I64 => Ok(tch::Kind::Int64),
|
||||
DType::I32 => Ok(tch::Kind::Int),
|
||||
DType::I16 => Ok(tch::Kind::Int16),
|
||||
DType::I8 => Ok(tch::Kind::Int8),
|
||||
DType::U8 => Ok(tch::Kind::Uint8),
|
||||
DType::Bool => Ok(tch::Kind::Bool),
|
||||
other => Err(tch::TchError::Kind(format!("Unsupported dtype {other:?}"))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TchTensor {
|
||||
/// Create a new tensor.
|
||||
///
|
||||
/// Note that if the tensor was created from an operation that may reuse the same tensor
|
||||
/// storage as the parent, you should use [from_existing](TchTensor::from_existing)
|
||||
/// instead.
|
||||
pub fn new(tensor: tch::Tensor) -> Self {
|
||||
#[allow(clippy::arc_with_non_send_sync)]
|
||||
let storage = Storage::Owned {
|
||||
buffer_ref: Arc::new(tensor.data_ptr()),
|
||||
};
|
||||
|
||||
Self { tensor, storage }
|
||||
}
|
||||
|
||||
/// Create a tensor that was created from an operation executed on a parent tensor.
|
||||
///
|
||||
/// If the child tensor shared the same storage as its parent, it will be cloned, effectively
|
||||
/// tracking how much tensors point to the same memory space.
|
||||
pub fn from_existing(tensor: tch::Tensor, storage_parent: Storage) -> Self {
|
||||
let storage_child = tensor.data_ptr();
|
||||
let mut is_a_new_tensor = true;
|
||||
|
||||
match &storage_parent {
|
||||
Storage::View {
|
||||
buffer_ref: start_ref,
|
||||
view_ref,
|
||||
} => {
|
||||
if storage_child == *start_ref.as_ref() || storage_child == *view_ref.as_ref() {
|
||||
is_a_new_tensor = false;
|
||||
}
|
||||
}
|
||||
Storage::Owned {
|
||||
buffer_ref: start_ref,
|
||||
} => {
|
||||
if storage_child == *start_ref.as_ref() {
|
||||
is_a_new_tensor = false;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let storage = match is_a_new_tensor {
|
||||
true => Storage::Owned {
|
||||
#[allow(clippy::arc_with_non_send_sync)]
|
||||
buffer_ref: Arc::new(storage_child),
|
||||
},
|
||||
false => storage_parent.clone(),
|
||||
};
|
||||
|
||||
Self { tensor, storage }
|
||||
}
|
||||
|
||||
/// Create a tensor that uses a part of its parent tensor such as slice and narrow.
|
||||
pub fn partial(tensor: tch::Tensor, storage_parent: Storage) -> Self {
|
||||
let storage = Storage::View {
|
||||
buffer_ref: storage_parent.buffer_ref().clone(),
|
||||
#[allow(clippy::arc_with_non_send_sync)]
|
||||
view_ref: Arc::new(tensor.data_ptr()),
|
||||
};
|
||||
Self { tensor, storage }
|
||||
}
|
||||
}
|
||||
|
||||
// This is safe since we don't use autodiff from LibTorch.
|
||||
// Also, atomic reference counting is used to know if the tensor's data can be reused.
|
||||
// If there are multiple reference on the same tensor, it becomes read only.
|
||||
unsafe impl Send for TchTensor {}
|
||||
unsafe impl Sync for TchTensor {}
|
||||
|
||||
impl TchTensor {
|
||||
/// Checks if the tensor can be mutated in-place.
|
||||
///
|
||||
/// Returns `true` if the tensor's stride does not contain zero (no broadcasting)
|
||||
/// and the storage can be mutated.
|
||||
pub fn can_mut(&self) -> bool {
|
||||
let stride_contains_zero = self.tensor.stride().contains(&0);
|
||||
|
||||
!stride_contains_zero && self.storage.can_mut()
|
||||
}
|
||||
|
||||
/// Executes an operation on a tensor if the data can be reused.
|
||||
pub fn mut_ops<F: Fn(&mut tch::Tensor) -> tch::Tensor>(
|
||||
&mut self,
|
||||
func: F,
|
||||
) -> Option<TchTensor> {
|
||||
if !self.can_mut() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let data = self.storage.clone();
|
||||
Some(TchTensor::from_existing(func(&mut self.tensor), data))
|
||||
}
|
||||
|
||||
/// Executes a unary operation, reusing the tensor data if possible.
|
||||
pub fn unary_ops<FOwn, FRef>(self, fown: FOwn, fref: FRef) -> TchTensor
|
||||
where
|
||||
FOwn: Fn(tch::Tensor) -> tch::Tensor,
|
||||
FRef: Fn(&tch::Tensor) -> tch::Tensor,
|
||||
{
|
||||
if !self.can_mut() {
|
||||
return TchTensor::from_existing(fref(&self.tensor), self.storage);
|
||||
}
|
||||
|
||||
TchTensor::from_existing(fown(self.tensor), self.storage)
|
||||
}
|
||||
|
||||
/// Executes a binary operation, reusing the tensor data if possible.
|
||||
pub fn binary_ops_tensor<FLMut, FRMut, FRef>(
|
||||
mut lhs: Self,
|
||||
mut rhs: Self,
|
||||
flmut: FLMut,
|
||||
frmut: FRMut,
|
||||
fref: FRef,
|
||||
) -> TchTensor
|
||||
where
|
||||
FLMut: Fn(&mut tch::Tensor, &tch::Tensor) -> tch::Tensor,
|
||||
FRMut: Fn(&tch::Tensor, &mut tch::Tensor) -> tch::Tensor,
|
||||
FRef: Fn(&tch::Tensor, &tch::Tensor) -> tch::Tensor,
|
||||
{
|
||||
let lhs_shape = lhs.shape();
|
||||
let rhs_shape = rhs.shape();
|
||||
|
||||
// Both lhs and rhs are expected to have the same rank
|
||||
let d_out = lhs_shape.num_dims();
|
||||
let mut out_shape = Shape::from(vec![1usize; d_out]);
|
||||
|
||||
for i in 0..d_out {
|
||||
out_shape[i] = usize::max(lhs_shape[i], rhs_shape[i]);
|
||||
}
|
||||
|
||||
let num_elements_out = out_shape.num_elements();
|
||||
|
||||
// Attempt to mutate lhs tensor
|
||||
if lhs_shape.num_elements() == num_elements_out
|
||||
&& let Some(output) = lhs.mut_ops(|lhs| flmut(lhs, &rhs.tensor))
|
||||
{
|
||||
return output;
|
||||
}
|
||||
|
||||
// Attempt to mutate rhs tensor
|
||||
if rhs_shape.num_elements() == num_elements_out
|
||||
&& let Some(output) = rhs.mut_ops(|rhs| frmut(&lhs.tensor, rhs))
|
||||
{
|
||||
return output;
|
||||
}
|
||||
|
||||
let storage = lhs.storage;
|
||||
let tensor = fref(&lhs.tensor, &rhs.tensor);
|
||||
|
||||
TchTensor::from_existing(tensor, storage)
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for TchTensor {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
tensor: self.tensor.shallow_clone(),
|
||||
storage: self.storage.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A shape that can be used by LibTorch.
|
||||
#[derive(Debug)]
|
||||
pub struct TchShape {
|
||||
/// The shape's dimensions.
|
||||
pub dims: Vec<i64>,
|
||||
}
|
||||
|
||||
impl From<Shape> for TchShape {
|
||||
fn from(shape: Shape) -> Self {
|
||||
TchShape {
|
||||
dims: shape.iter().map(|d| *d as i64).collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&[usize]> for TchShape {
|
||||
fn from(shape: &[usize]) -> Self {
|
||||
TchShape {
|
||||
dims: shape.iter().map(|d| *d as i64).collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TchTensor {
|
||||
/// Creates a new tensor from a shape and a device.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `data` - The tensor's data.
|
||||
/// * `device` - The device on which the tensor will be allocated.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new tensor.
|
||||
pub fn from_data<E: TchElement>(data: TensorData, device: tch::Device) -> Self {
|
||||
let shape_tch = TchShape::from(data.shape.as_slice());
|
||||
let tensor =
|
||||
tch::Tensor::from_data_size(&data.bytes, &shape_tch.dims, E::kind()).to(device);
|
||||
|
||||
Self::new(tensor)
|
||||
}
|
||||
}
|
||||
|
||||
impl TchTensor {
|
||||
/// Creates an empty tensor from a shape and a device.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `shape` - The shape of the tensor.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new empty tensor.
|
||||
pub fn empty<E: TchElement>(shape: Shape, device: LibTorchDevice) -> Self {
|
||||
let shape_tch = TchShape::from(shape);
|
||||
let tensor = tch::Tensor::empty(shape_tch.dims, (E::kind(), device.into()));
|
||||
|
||||
Self::new(tensor)
|
||||
}
|
||||
}
|
||||
|
||||
// Adapted from `tch` to use patched `T::kind()` instead of `T::KIND` which is incorrect for bf16.
|
||||
// TODO: remove when fixed in `tch` release (https://github.com/LaurentMazare/tch-rs/pull/996).
|
||||
impl<T: TchElement + Copy> TryFrom<&TchTensor> for Vec<T> {
|
||||
type Error = tch::TchError;
|
||||
fn try_from(tensor: &TchTensor) -> Result<Self, Self::Error> {
|
||||
let tensor = &tensor.tensor;
|
||||
let size = tensor.size();
|
||||
if size.len() != 1 {
|
||||
Err(tch::TchError::Convert(format!(
|
||||
"Attempting to convert a Tensor with {} dimensions to flat vector",
|
||||
size.len()
|
||||
)))?;
|
||||
}
|
||||
let numel = size[0] as usize;
|
||||
let mut vec = vec![T::ZERO; numel];
|
||||
// Adapted to use patched `T::kind()` instead
|
||||
// TODO: tensor.f_to_kind(T::KIND)?.f_copy_data(&mut vec, numel)?;
|
||||
f_copy_data(&mut tensor.f_to_kind(T::kind())?, &mut vec, numel)?;
|
||||
Ok(vec)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn ptr_to_string(ptr: *mut libc::c_char) -> Option<String> {
|
||||
if !ptr.is_null() {
|
||||
unsafe {
|
||||
let str = std::ffi::CStr::from_ptr(ptr).to_string_lossy().into_owned();
|
||||
libc::free(ptr as *mut libc::c_void);
|
||||
Some(str)
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Copies `numel` elements from `self` to `dst`.
|
||||
fn f_copy_data<T: TchElement>(
|
||||
tensor: &mut tch::Tensor,
|
||||
dst: &mut [T],
|
||||
numel: usize,
|
||||
) -> Result<(), tch::TchError> {
|
||||
if T::kind() != tensor.f_kind()? {
|
||||
return Err(tch::TchError::Kind(format!(
|
||||
"incoherent elt kind, {:?} != {:?}",
|
||||
tensor.f_kind(),
|
||||
T::kind()
|
||||
)));
|
||||
}
|
||||
if dst.len() < numel {
|
||||
return Err(tch::TchError::Shape(format!("slice len < {numel}")));
|
||||
}
|
||||
|
||||
unsafe {
|
||||
torch_sys::at_copy_data(
|
||||
tensor.as_mut_ptr(),
|
||||
dst.as_mut_ptr() as *const c_void,
|
||||
numel,
|
||||
T::kind().elt_size_in_bytes(),
|
||||
);
|
||||
match ptr_to_string(torch_sys::get_and_reset_last_err()) {
|
||||
None => Ok(()),
|
||||
Some(c_error) => Err(tch::TchError::Torch(c_error)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_backend::ops::FloatTensorOps;
|
||||
use burn_backend::{Backend, quantization::QuantScheme, read_sync};
|
||||
|
||||
type B = crate::LibTorch<f32>;
|
||||
|
||||
#[test]
|
||||
fn should_have_bf16_kind() {
|
||||
let data = TensorData::from([4.0, 4.0]);
|
||||
let tensor_1: TchTensor = B::float_from_data(data, &Default::default());
|
||||
let tensor_2 = B::float_cast(tensor_1, DType::BF16.into());
|
||||
|
||||
assert_eq!(tensor_2.tensor.kind(), tch::Kind::BFloat16);
|
||||
|
||||
let out = read_sync(B::float_into_data(tensor_2)).unwrap();
|
||||
|
||||
out.assert_eq(&TensorData::from([4.0, 4.0]), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_dtypes() {
|
||||
let device = Default::default();
|
||||
|
||||
assert!(B::supports_dtype(&device, DType::F64));
|
||||
assert!(B::supports_dtype(&device, DType::F32));
|
||||
assert!(B::supports_dtype(&device, DType::Flex32));
|
||||
assert!(B::supports_dtype(&device, DType::F16));
|
||||
assert!(B::supports_dtype(&device, DType::BF16));
|
||||
assert!(B::supports_dtype(&device, DType::I64));
|
||||
assert!(B::supports_dtype(&device, DType::I32));
|
||||
assert!(B::supports_dtype(&device, DType::I16));
|
||||
assert!(B::supports_dtype(&device, DType::I8));
|
||||
assert!(B::supports_dtype(&device, DType::U8));
|
||||
assert!(B::supports_dtype(&device, DType::Bool));
|
||||
|
||||
assert!(!B::supports_dtype(&device, DType::U64));
|
||||
assert!(!B::supports_dtype(&device, DType::U32));
|
||||
assert!(!B::supports_dtype(&device, DType::U16));
|
||||
assert!(!B::supports_dtype(
|
||||
&device,
|
||||
DType::QFloat(QuantScheme::default())
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_from_bf16() {
|
||||
let data = TensorData::from([[1.0], [1.]]).convert_dtype(DType::BF16);
|
||||
let tensor_1: TchTensor = B::float_from_data(data, &Default::default());
|
||||
let data = TensorData::from([[2.0], [2.]]).convert_dtype(DType::BF16);
|
||||
let tensor_2 = B::float_from_data(data, &Default::default());
|
||||
|
||||
let tensor_3 = B::float_add(tensor_1, tensor_2);
|
||||
|
||||
assert_eq!(tensor_3.tensor.kind(), tch::Kind::BFloat16);
|
||||
|
||||
let out = read_sync(B::float_into_data(tensor_3)).unwrap();
|
||||
|
||||
out.assert_eq(&TensorData::from([[3.0], [3.0]]), false);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
/// Dummy function to get CUDA to link properly
|
||||
pub fn dummy_cuda_dependency();
|
||||
}
|
||||
|
||||
#[used]
|
||||
static INIT_ARRAY: [unsafe extern "C" fn(); 1] = [dummy_cuda_dependency];
|
||||
Reference in New Issue
Block a user