feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,44 @@
[package]
authors = ["louisfd <louisfd94@gmail.com>"]
categories = ["science"]
description = "[Deprecated] Candle backend for the Burn framework - use burn-cubecl, burn-ndarray, or burn-tch instead"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "data"]
license.workspace = true
name = "burn-candle"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-candle"
documentation = "https://docs.rs/burn-candle"
version.workspace = true
[lints]
workspace = true
[features]
default = ["std"]
std = []
doc = ["default"]
tracing = [
"burn-backend/tracing",
"burn-std/tracing",
]
cuda = ["candle-core/cuda"]
metal = ["candle-core/metal"]
accelerate = ["candle-core/accelerate"]
[dependencies]
burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false }
# For rand utils and stub mutex
burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = false }
candle-core = { workspace = true }
derive-new = { workspace = true }
[dev-dependencies]
burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", default-features = false, features = [
] }
[package.metadata.docs.rs]
features = ["doc"]
rustdoc-args = ["--cfg", "docsrs"]

View File

@@ -0,0 +1 @@
../../LICENSE-APACHE

View File

@@ -0,0 +1 @@
../../LICENSE-MIT

View File

@@ -0,0 +1,14 @@
# Burn Candle Backend
> **Deprecated:** This crate is deprecated as of `0.21.0-pre.2` and will be removed in a future release.
> Please migrate to one of the actively maintained backends:
> - **CubeCL backends** (CUDA, ROCm, Vulkan, Metal, WebGPU) for GPU acceleration
> - **NdArray** for portable CPU execution
> - **LibTorch** (`burn-tch`) for a mature CPU/GPU backend
This crate provides a backend for [Burn](https://github.com/tracel-ai/burn) based on the [Candle](https://github.com/huggingface/candle) framework.
## Feature Flags
- `cuda` - CUDA GPU device (NVIDIA only)
- `metal` - Metal GPU device (Apple only)
- `accelerate` - Accelerate framework (macOS only)

View File

@@ -0,0 +1,300 @@
use std::marker::PhantomData;
use burn_backend::{
BackTrace, Backend, DType, DTypeUsage, DeviceId, DeviceOps, ExecutionError, QTensorPrimitive,
tensor::Device,
};
use burn_std::{
rand::{SeedableRng, StdRng},
stub::Mutex,
};
use candle_core::{DeviceLocation, backend::BackendDevice};
use crate::{
CandleTensor, IntoDType,
element::{CandleElement, FloatCandleElement, IntCandleElement},
};
/// Tensor backend that uses the [candle](candle_core) crate for executing tensor operations.
///
/// It is compatible with a wide range of hardware configurations, including CPUs and GPUs
/// that support CUDA or Metal. Additionally, the backend can be compiled to `wasm` when using the CPU.
///
/// The type parameters select the float element type (`F`, `f32` by default) and the int
/// element type (`I`, `i64` by default). Both are carried only as `PhantomData` markers.
#[derive(Clone, Default, Debug)]
pub struct Candle<F = f32, I = i64>
where
F: FloatCandleElement,
I: IntCandleElement,
{
// Marker for the float element type; stores no data.
_float: PhantomData<F>,
// Marker for the int element type; stores no data.
_int: PhantomData<I>,
}
/// Seed for CPU device.
///
/// Holds the generator stored by `set_seeded_rng` until `get_seeded_rng` takes it out again;
/// it is `None` initially and after every take.
pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
/// Takes the generator currently stored in [`SEED`], leaving `None` behind.
///
/// Falls back to `burn_std::rand::get_seeded_rng` when nothing is stored.
///
/// # Panics
/// Panics if locking [`SEED`] fails.
pub(crate) fn get_seeded_rng() -> StdRng {
    SEED.lock()
        .unwrap()
        .take()
        .unwrap_or_else(burn_std::rand::get_seeded_rng)
}
/// Stores `rng_seeded` in [`SEED`], replacing whatever was there before.
///
/// # Panics
/// Panics if locking [`SEED`] fails.
pub(crate) fn set_seeded_rng(rng_seeded: StdRng) {
    *SEED.lock().unwrap() = Some(rng_seeded);
}
/// The device type used by the `candle` backend.
///
/// To create a Cuda or Metal device from the index, use the associated methods to create the variant:
/// ```no_run
/// use burn_candle::CandleDevice;
///
/// // Create a Cuda device from its index
/// let device = CandleDevice::cuda(0);
/// // Create a Metal device from its index
/// let device = CandleDevice::metal(0);
/// ```
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub enum CandleDevice {
    /// CPU device (the default).
    #[default]
    Cpu,
    /// Cuda device with the given index. The index is the index of the Cuda device in the list of
    /// all Cuda devices found on the system.
    Cuda(CudaDevice),
    /// Metal device with the given index. The index is the index of the Metal device in the list of
    /// all Metal devices found on the system.
    Metal(MetalDevice),
}
impl CandleDevice {
    /// Create a Cuda device with the given index.
    /// The index is the index of the Cuda device in the list of all Cuda devices found on the system.
    ///
    /// # Panics
    /// Panics if `candle_core::CudaDevice::new` returns an error for `index`.
    pub fn cuda(index: usize) -> Self {
        let device = candle_core::CudaDevice::new(index).unwrap();
        Self::Cuda(CudaDevice { device, index })
    }

    /// Create a Metal device with the given index.
    /// The index is the index of the Metal device in the list of all Metal devices found on the system.
    ///
    /// # Panics
    /// Panics if `candle_core::MetalDevice::new` returns an error for `index`.
    pub fn metal(index: usize) -> Self {
        let device = candle_core::MetalDevice::new(index).unwrap();
        Self::Metal(MetalDevice { device, index })
    }

    /// Seeds the random number generator associated with this device.
    ///
    /// # Panics
    /// Panics if the Cuda or Metal device rejects the seed.
    pub(crate) fn set_seed(&self, seed: u64) {
        match self {
            // Candle does not support seeding the CPU rng so we use a global seed
            Self::Cpu => set_seeded_rng(StdRng::seed_from_u64(seed)),
            Self::Cuda(cuda) => cuda.device.set_seed(seed).unwrap(),
            Self::Metal(metal) => metal.device.set_seed(seed).unwrap(),
        }
    }
}
#[derive(Clone, Debug)]
/// A Cuda device for the `candle` backend.
///
/// Pairs the underlying `candle_core::CudaDevice` handle with the index it was created from.
pub struct CudaDevice {
// Underlying Candle device handle; crate-private.
pub(crate) device: candle_core::CudaDevice,
/// The index of the Cuda device in the list of all devices on the system.
pub index: usize,
}
/// Two Cuda devices are equal when Candle reports them as the same device and their
/// indices match.
impl PartialEq for CudaDevice {
    fn eq(&self, rhs: &Self) -> bool {
        if !self.device.same_device(&rhs.device) {
            return false;
        }
        self.index == rhs.index
    }
}

impl Eq for CudaDevice {}
#[derive(Clone, Debug)]
/// A Metal device for the `candle` backend.
///
/// Pairs the underlying `candle_core::MetalDevice` handle with the index it was created from.
pub struct MetalDevice {
// Underlying Candle device handle; crate-private.
pub(crate) device: candle_core::MetalDevice,
/// The index of the Metal device in the list of all devices on the system.
pub index: usize,
}
/// Two Metal devices are equal when Candle reports them as the same device and their
/// indices match.
impl PartialEq for MetalDevice {
    fn eq(&self, rhs: &Self) -> bool {
        if !self.device.same_device(&rhs.device) {
            return false;
        }
        self.index == rhs.index
    }
}

impl Eq for MetalDevice {}
/// Unwraps a [`CandleDevice`] into the matching `candle_core::Device` variant.
impl From<CandleDevice> for candle_core::Device {
    fn from(device: CandleDevice) -> Self {
        match device {
            CandleDevice::Cpu => Self::Cpu,
            CandleDevice::Cuda(cuda) => Self::Cuda(cuda.device),
            CandleDevice::Metal(metal) => Self::Metal(metal.device),
        }
    }
}
/// Wraps a `candle_core::Device` into a [`CandleDevice`], taking the index from the
/// device's reported location.
///
/// # Panics
/// Panics if the reported location and the device variant disagree.
impl From<candle_core::Device> for CandleDevice {
    fn from(device: candle_core::Device) -> Self {
        // The location is computed first; the device itself is then moved into the tuple.
        match (device.location(), device) {
            (DeviceLocation::Cpu, _) => Self::Cpu,
            (DeviceLocation::Cuda { gpu_id }, candle_core::Device::Cuda(device)) => {
                Self::Cuda(CudaDevice {
                    device,
                    index: gpu_id,
                })
            }
            (DeviceLocation::Cuda { .. }, _) => panic!("Expected CUDA device."),
            (DeviceLocation::Metal { gpu_id }, candle_core::Device::Metal(device)) => {
                Self::Metal(MetalDevice {
                    device,
                    index: gpu_id,
                })
            }
            (DeviceLocation::Metal { .. }, _) => panic!("Expected Metal device."),
        }
    }
}
impl burn_backend::Device for CandleDevice {
    /// Encodes the device as a `DeviceId`: type id 0 for Cuda, 1 for Metal, 2 for CPU.
    fn to_id(&self) -> burn_backend::DeviceId {
        let (type_id, index_id) = match self {
            Self::Cuda(cuda) => (0, cuda.index as u32),
            Self::Metal(metal) => (1, metal.index as u32),
            Self::Cpu => (2, 0),
        };
        DeviceId::new(type_id, index_id)
    }

    /// Rebuilds a device from a `DeviceId`; any type id other than 0 or 1 yields the CPU.
    ///
    /// # Panics
    /// Panics if the requested Cuda or Metal device cannot be created.
    fn from_id(device_id: DeviceId) -> Self {
        let index = device_id.index_id as usize;
        match device_id.type_id {
            0 => Self::cuda(index),
            1 => Self::metal(index),
            _ => Self::Cpu,
        }
    }

    /// Currently always reports a single device, whatever `type_id` is.
    fn device_count(type_id: u16) -> usize {
        // TODO: Fix that
        1
    }
}
// Empty implementation block: nothing is overridden for `CandleDevice` here.
impl DeviceOps for CandleDevice {}
impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
    type Device = CandleDevice;
    type FloatTensorPrimitive = CandleTensor;
    type FloatElem = F;
    type IntTensorPrimitive = CandleTensor;
    type IntElem = I;
    type BoolTensorPrimitive = CandleTensor;
    type BoolElem = u8;
    type QuantizedTensorPrimitive = CandleTensor;

    /// Always `false` for this backend.
    fn ad_enabled(_device: &Self::Device) -> bool {
        false
    }

    /// Returns `candle<cpu>`, `candle<cuda>` or `candle<metal>` depending on the device.
    fn name(device: &Self::Device) -> String {
        let label = match device {
            CandleDevice::Cpu => "candle<cpu>",
            CandleDevice::Cuda(..) => "candle<cuda>",
            CandleDevice::Metal(..) => "candle<metal>",
        };
        label.to_string()
    }

    /// Seeds the random number generator of `device`.
    fn seed(device: &CandleDevice, seed: u64) {
        device.set_seed(seed);
    }

    /// Synchronizes `device`.
    ///
    /// The CPU needs no work. Cuda is synchronized only when the `cuda` feature is enabled.
    /// Metal always returns an error.
    fn sync(device: &Device<Self>) -> Result<(), ExecutionError> {
        match candle_core::Device::from(device.clone()) {
            candle_core::Device::Cpu => Ok(()),
            candle_core::Device::Cuda(device) => {
                #[cfg(feature = "cuda")]
                device
                    .synchronize()
                    .map_err(|err| ExecutionError::Generic {
                        reason: format!("Can't sync the cuda device: {err}"),
                        backtrace: BackTrace::capture(),
                    })?;
                Ok(())
            }
            // For some reason, device.wait_until_completed() does not seem to work,
            // and neither does writing and reading a value with into_data
            candle_core::Device::Metal(_) => Err(ExecutionError::Generic {
                reason: "Device synchronization unavailable with Metal device on Candle backend"
                    .into(),
                backtrace: BackTrace::capture(),
            }),
        }
    }

    /// Reports general usage for every dtype that converts to a Candle dtype, and an empty
    /// usage set otherwise. The device is not consulted.
    fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {
        match dtype.try_into_dtype() {
            Ok(_) => burn_backend::DTypeUsage::general(),
            Err(_) => burn_backend::DTypeUsageSet::empty(),
        }
    }
}
#[cfg(test)]
mod tests {
use burn_std::QuantScheme;
use super::*;
// Checks which dtypes the backend reports as supported on the default device.
#[test]
fn should_support_dtypes() {
type B = Candle<f32>;
let device = Default::default();
// Dtypes expected to be supported.
assert!(B::supports_dtype(&device, DType::F64));
assert!(B::supports_dtype(&device, DType::F32));
assert!(B::supports_dtype(&device, DType::Flex32));
assert!(B::supports_dtype(&device, DType::F16));
assert!(B::supports_dtype(&device, DType::BF16));
assert!(B::supports_dtype(&device, DType::I64));
assert!(B::supports_dtype(&device, DType::U32));
assert!(B::supports_dtype(&device, DType::U8));
assert!(B::supports_dtype(&device, DType::I32));
assert!(B::supports_dtype(&device, DType::I16));
// Dtypes expected to be rejected.
assert!(!B::supports_dtype(&device, DType::U64));
assert!(!B::supports_dtype(&device, DType::U16));
assert!(!B::supports_dtype(&device, DType::I8));
assert!(!B::supports_dtype(&device, DType::Bool));
assert!(!B::supports_dtype(
&device,
DType::QFloat(QuantScheme::default())
));
}
}

View File

@@ -0,0 +1,32 @@
use std::borrow::Borrow;
use burn_backend::{Element, bf16, f16};
use candle_core::{FloatDType, Tensor, WithDType};
/// Candle element.
///
/// Marker trait for types that are both a Burn `Element` and a Candle `WithDType`.
pub trait CandleElement: Element + WithDType {}
/// Candle float element.
///
/// Marker trait for [`CandleElement`] types that also implement Candle's `FloatDType`.
pub trait FloatCandleElement: CandleElement + FloatDType {}
/// Candle int element.
///
/// Marker trait for [`CandleElement`] types usable as the backend's int element.
pub trait IntCandleElement: CandleElement {}
// Float element types: f64, f32, f16 and bf16.
impl CandleElement for f64 {}
impl FloatCandleElement for f64 {}
impl CandleElement for f32 {}
impl FloatCandleElement for f32 {}
impl CandleElement for f16 {}
impl FloatCandleElement for f16 {}
impl CandleElement for bf16 {}
impl FloatCandleElement for bf16 {}
// Int element types: u8, u32 and i64.
impl CandleElement for u8 {}
impl IntCandleElement for u8 {}
impl CandleElement for u32 {}
impl IntCandleElement for u32 {}
impl CandleElement for i64 {}
impl IntCandleElement for i64 {}

View File

@@ -0,0 +1,27 @@
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![allow(unused)] // TODO remove when backend filled
#![deprecated(
since = "0.21.0-pre.2",
note = "burn-candle is deprecated and will be removed in a future release. Use burn-cubecl (CUDA/ROCm/Vulkan/Metal/WebGPU), burn-ndarray, or burn-tch instead."
)]
//! Burn Candle Backend
//!
//! **Deprecated:** This backend is deprecated and will be removed in a future release.
//! Please migrate to one of the actively maintained backends:
//! - CubeCL backends (CUDA, ROCm, Vulkan, Metal, WebGPU) for GPU acceleration
//! - NdArray for portable CPU execution
//! - LibTorch (`burn-tch`) for a mature CPU/GPU backend
#[macro_use]
extern crate derive_new;
mod backend;
mod element;
mod ops;
mod tensor;
pub use backend::*;
pub use element::*;
pub use tensor::*;

View File

@@ -0,0 +1,17 @@
use burn_backend::{ops::ActivationOps, tensor::FloatTensor};
use crate::{
Candle, CandleTensor,
element::{CandleElement, FloatCandleElement, IntCandleElement},
tensor,
};
/// Activation functions delegated to Candle's tensor methods.
impl<F: FloatCandleElement, I: IntCandleElement> ActivationOps<Self> for Candle<F, I> {
    /// Applies Candle's `gelu` to the tensor; panics if Candle returns an error.
    fn gelu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        let activated = tensor.tensor.gelu().unwrap();
        CandleTensor::new(activated)
    }

    /// Applies Candle's `relu` to the tensor; panics if Candle returns an error.
    fn relu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        let activated = tensor.tensor.relu().unwrap();
        CandleTensor::new(activated)
    }
}

View File

@@ -0,0 +1,572 @@
use std::cmp::max;
use std::marker::PhantomData;
use crate::{
Candle, CandleDevice, CandleTensor,
element::{CandleElement, FloatCandleElement, IntCandleElement},
};
use burn_backend::{
BackTrace, Backend, Distribution, ExecutionError, Slice, bf16, f16,
ops::unfold::{calculate_unfold_shape, calculate_unfold_windows},
};
use burn_backend::{Element, Shape, TensorData, TensorMetadata};
use candle_core::{Layout, WithDType};
use super::tensor;
/// Samples random tensor data of the given shape and distribution using the crate-wide
/// seeded generator.
///
/// The generator is taken out of the global slot, used, and then stored back so that its
/// advanced state is kept for the next call.
pub fn cpu_random<E: CandleElement>(shape: Shape, distribution: Distribution) -> TensorData {
    let mut generator = crate::get_seeded_rng();
    let sampled = TensorData::random::<E, _, _>(shape, distribution, &mut generator);
    crate::set_seeded_rng(generator);
    sampled
}
/// Concatenates `tensors` along `dim`.
///
/// # Panics
/// Panics if Candle rejects the concatenation.
pub fn cat(tensors: Vec<CandleTensor>, dim: usize) -> CandleTensor {
    let inner: Vec<candle_core::Tensor> = tensors
        .into_iter()
        .map(|wrapped| wrapped.tensor)
        .collect();
    let joined = candle_core::Tensor::cat(&inner, dim).unwrap();
    CandleTensor::new(joined)
}
/// Builds a tensor with element type `E` from `data` on a clone of `device`.
pub fn from_data<E: CandleElement>(data: TensorData, device: &CandleDevice) -> CandleTensor {
    let target = device.clone();
    CandleTensor::from_data::<E>(data, target)
}
pub fn into_data(tensor: CandleTensor) -> Result<TensorData, ExecutionError> {
fn tensor_data_from_dtype<T: WithDType + Element>(
tensor: &CandleTensor,
) -> Result<TensorData, ExecutionError> {
let data = tensor
.tensor
.flatten_all()
.map_err(|err| ExecutionError::Generic {
reason: format!("{err}"),
backtrace: BackTrace::capture(),
})?
.to_vec1::<T>()
.map_err(|err| ExecutionError::Generic {
reason: format!("{err}"),
backtrace: BackTrace::capture(),
})?;
Ok(TensorData::new(data, tensor.shape()))
}
match tensor.tensor.dtype() {
candle_core::DType::BF16 => tensor_data_from_dtype::<bf16>(&tensor),
candle_core::DType::F16 => tensor_data_from_dtype::<f16>(&tensor),
candle_core::DType::F32 => tensor_data_from_dtype::<f32>(&tensor),
candle_core::DType::F64 => tensor_data_from_dtype::<f64>(&tensor),
candle_core::DType::U8 => tensor_data_from_dtype::<u8>(&tensor),
candle_core::DType::U32 => tensor_data_from_dtype::<u32>(&tensor),
candle_core::DType::I16 => tensor_data_from_dtype::<i16>(&tensor),
candle_core::DType::I32 => tensor_data_from_dtype::<i32>(&tensor),
candle_core::DType::I64 => tensor_data_from_dtype::<i64>(&tensor),
other => todo!("{other:?} not yet supported"),
}
}
pub fn to_device(tensor: CandleTensor, device: &CandleDevice) -> CandleTensor {
CandleTensor::new(tensor.tensor.to_device(&(device.clone()).into()).unwrap())
}
/// Creates a tensor of the given shape and dtype on `device`.
///
/// Delegates to [`zeros`], so the returned tensor is zero-filled.
pub fn empty(shape: Shape, device: &CandleDevice, dtype: candle_core::DType) -> CandleTensor {
zeros(shape, device, dtype)
}
/// Creates a zero-filled tensor of the given shape and dtype on `device`.
///
/// # Panics
/// Panics if Candle fails to create the tensor.
pub fn zeros(shape: Shape, device: &CandleDevice, dtype: candle_core::DType) -> CandleTensor {
    let target: candle_core::Device = device.clone().into();
    let filled = candle_core::Tensor::zeros(shape.to_vec(), dtype, &target).unwrap();
    CandleTensor::new(filled)
}
/// Creates a one-filled tensor of the given shape and dtype on `device`.
///
/// # Panics
/// Panics if Candle fails to create the tensor.
pub fn ones(shape: Shape, device: &CandleDevice, dtype: candle_core::DType) -> CandleTensor {
    let target: candle_core::Device = device.clone().into();
    let filled = candle_core::Tensor::ones(shape.to_vec(), dtype, &target).unwrap();
    CandleTensor::new(filled)
}
/// Swaps dimensions `dim1` and `dim2` using Candle's `transpose`.
///
/// # Panics
/// Panics if Candle rejects the transpose.
pub fn swap_dims(tensor: CandleTensor, dim1: usize, dim2: usize) -> CandleTensor {
    let transposed = tensor.tensor.transpose(dim1, dim2).unwrap();
    CandleTensor::new(transposed)
}
/// Reorders the tensor's dimensions according to `axes`.
///
/// # Panics
/// Panics if Candle rejects the permutation.
pub fn permute(tensor: CandleTensor, axes: &[usize]) -> CandleTensor {
    let reordered = tensor.tensor.permute(axes).unwrap();
    CandleTensor::new(reordered)
}
/// Reverses the tensor along each axis in `axes`, one axis at a time.
///
/// # Panics
/// Panics if any Candle operation fails.
pub fn flip(tensor: CandleTensor, axes: &[usize]) -> CandleTensor {
    // FIXME: Replace with an appropriate method when Candle provides one.
    let flipped = axes.iter().fold(tensor.tensor, |current, &axis| {
        // Ensure tensor is contiguous before index_select (required by Candle)
        let current = current.contiguous().unwrap();
        // Descending indices `len - 1, ..., 0` select the axis in reverse order.
        let last = current.dim(axis).unwrap() as i64 - 1;
        let reversed =
            candle_core::Tensor::arange_step(last, -1, -1, current.device()).unwrap();
        current.index_select(&reversed, axis).unwrap()
    });
    CandleTensor::new(flipped)
}
/// Reshapes the tensor to `shape`.
///
/// # Panics
/// Panics if Candle rejects the new shape.
pub fn reshape(tensor: CandleTensor, shape: Shape) -> CandleTensor {
    let dims = shape.to_vec();
    CandleTensor::new(tensor.tensor.reshape(dims).unwrap())
}
/// Returns the device the tensor lives on, converted to a [`CandleDevice`].
pub fn device(tensor: &CandleTensor) -> CandleDevice {
    let inner = tensor.tensor.device().clone();
    inner.into()
}
/// Returns the shape of the tensor, as reported by its `shape()` method.
pub fn shape(tensor: &CandleTensor) -> Shape {
tensor.shape()
}
/// Narrows the tensor to `ranges`, where `ranges[i]` applies to dimension `i`.
///
/// Dimensions beyond `ranges.len()` are left untouched.
///
/// # Panics
/// Panics if Candle rejects one of the ranges.
pub fn slice(tensor: CandleTensor, ranges: &[std::ops::Range<usize>]) -> CandleTensor {
    let narrowed = ranges
        .iter()
        .enumerate()
        .fold(tensor.tensor, |current, (axis, range)| {
            current
                .narrow(axis, range.start, range.end - range.start)
                .unwrap()
        });
    CandleTensor::new(narrowed)
}
pub fn slice_with_steps(tensor: CandleTensor, slices: &[Slice]) -> CandleTensor {
let mut result_tensor = tensor.tensor;
for (dim, slice) in slices.iter().enumerate() {
if slice.step == 1 {
// Use narrow for step=1 (more efficient)
// Convert slice to range using tensor shape
let dim_size = result_tensor.dim(dim).unwrap();
let range = slice.to_range(dim_size);
let start = range.start;
let length = range.end - range.start;
result_tensor = result_tensor.narrow(dim, start, length).unwrap();
} else {
// Use index_select for step != 1
let dim_size = result_tensor.dim(dim).unwrap();
let range = slice.to_range(dim_size);
let start = range.start;
let end = range.end;
let step = slice.step;
// Generate indices based on step direction
let indices_vec = if step > 0 {
// Forward stepping
let step_usize = step as usize;
(start..end).step_by(step_usize).collect::<Vec<_>>()
} else {
// Backward stepping (negative step)
let step_usize = step.unsigned_abs();
// Start from end-1 and go backwards
let mut indices = Vec::new();
let mut idx = end - 1;
while idx >= start && idx < end {
// Check for underflow
indices.push(idx);
if idx >= step_usize {
idx -= step_usize;
} else {
break;
}
}
indices
};
// Convert indices to tensor and use index_select
let indices_len = indices_vec.len();
let device = result_tensor.device();
let indices = candle_core::Tensor::from_vec(
indices_vec.iter().map(|&x| x as u32).collect::<Vec<_>>(),
indices_len,
device,
)
.unwrap();
result_tensor = result_tensor.index_select(&indices, dim).unwrap();
}
}
CandleTensor::new(result_tensor)
}
/// Writes `value` into the region of `tensor` described by `slices`.
///
/// When every slice has step 1 the assignment goes through Candle's native `slice_assign`;
/// otherwise it falls back to [`slice_assign_with_steps_workaround`].
///
/// # Panics
/// Panics if Candle rejects the assignment.
pub fn slice_assign(tensor: CandleTensor, slices: &[Slice], value: CandleTensor) -> CandleTensor {
    // Candle's native slice_assign only handles unit steps.
    if slices.iter().any(|slice| slice.step != 1) {
        return slice_assign_with_steps_workaround(tensor, slices, value);
    }

    // Convert Slice to Range for candle's native slice_assign
    let mut ranges: Vec<std::ops::Range<usize>> = Vec::with_capacity(slices.len());
    for (dim, slice) in slices.iter().enumerate() {
        let dim_size = tensor.tensor.dim(dim).unwrap_or(usize::MAX);
        ranges.push(slice.to_range(dim_size));
    }

    let assigned = tensor.tensor.slice_assign(&ranges, &value.tensor).unwrap();
    CandleTensor::new(assigned)
}
/// Implements slice_assign for non-unit steps using index operations.
///
/// Both tensors are read back as flat vectors, the selected positions are overwritten,
/// and a new tensor is built from the result on the original device.
///
/// The flat buffer comes from `flatten_all()`, so positions are computed with row-major
/// strides derived from the tensor's dimensions. The tensor's own layout strides are not
/// used, since they need not match that buffer (for example after a transpose or a
/// broadcast).
///
/// # Panics
/// Panics for dtypes other than F32, F64, I64, U32 and U8, and if a Candle operation fails.
fn slice_assign_with_steps_workaround(
    tensor: CandleTensor,
    slices: &[Slice],
    value: CandleTensor,
) -> CandleTensor {
    let shape = tensor.shape();
    let device = tensor.tensor.device();
    // Generate indices for each dimension based on slice specifications
    let indices_per_dim = generate_slice_indices(slices, &shape);
    // Early return if no elements to assign
    let total_elements: usize = indices_per_dim.iter().map(|v| v.len()).product();
    if total_elements == 0 {
        return tensor;
    }
    // Flatten tensors and get metadata
    let value_flat = value.tensor.flatten_all().unwrap();
    let tensor_shape = tensor.tensor.dims();
    // Row-major strides matching the flattened buffer: the last dimension has stride 1.
    let mut strides = vec![1usize; tensor_shape.len()];
    for dim in (0..tensor_shape.len().saturating_sub(1)).rev() {
        strides[dim] = strides[dim + 1] * tensor_shape[dim + 1];
    }
    // Use a macro to handle different dtypes without code duplication
    macro_rules! apply_slice_assign {
        ($dtype:ty, $to_vec_fn:ident) => {{
            let mut tensor_vec: Vec<$dtype> =
                tensor.tensor.flatten_all().unwrap().$to_vec_fn().unwrap();
            let value_vec: Vec<$dtype> = value_flat.$to_vec_fn().unwrap();
            // Apply assignments using cartesian product of indices
            for (value_idx, &value) in value_vec.iter().enumerate() {
                let flat_idx = compute_flat_index(value_idx, &indices_per_dim, &strides);
                if flat_idx < tensor_vec.len() {
                    tensor_vec[flat_idx] = value;
                }
            }
            candle_core::Tensor::from_vec(tensor_vec, tensor_shape, device).unwrap()
        }};
    }
    use candle_core::DType;
    let result = match tensor.tensor.dtype() {
        DType::F32 => apply_slice_assign!(f32, to_vec1),
        DType::F64 => apply_slice_assign!(f64, to_vec1),
        DType::I64 => apply_slice_assign!(i64, to_vec1),
        DType::U32 => apply_slice_assign!(u32, to_vec1),
        DType::U8 => apply_slice_assign!(u8, to_vec1),
        _ => panic!(
            "Unsupported dtype {:?} for slice_assign with steps",
            tensor.tensor.dtype()
        ),
    };
    CandleTensor::new(result)
}
/// Builds, for every tensor dimension, the list of indices selected by `slices`.
///
/// Dimensions covered by a slice get the stepped indices of that slice; the remaining
/// dimensions get their full range `0..dim_size`.
///
/// # Panics
/// Panics if there are more slices than dimensions, or if a slice has a zero step.
fn generate_slice_indices(slices: &[Slice], tensor_dims: &[usize]) -> Vec<Vec<usize>> {
    // Dimensions that have an explicit slice.
    let sliced = slices.iter().enumerate().map(|(axis, slice)| {
        let range = slice.to_range(tensor_dims[axis]);
        generate_stepped_indices(range.start, range.end, slice.step)
    });
    // Trailing dimensions without a slice are taken in full.
    let untouched = tensor_dims
        .iter()
        .skip(slices.len())
        .map(|&extent| (0..extent).collect::<Vec<usize>>());
    sliced.chain(untouched).collect()
}
/// Lists the indices visited when walking `start..end` with the given step.
///
/// A positive step walks forward from `start`. A negative step walks backward from
/// `end - 1` and never goes below `start`. An empty range yields no indices.
///
/// # Panics
/// Panics if `step` is zero.
fn generate_stepped_indices(start: usize, end: usize, step: isize) -> Vec<usize> {
    match step {
        // This branch should never be reached since step is validated to be non-zero
        0 => panic!("Step cannot be zero"),
        forward if forward > 0 => (start..end).step_by(forward as usize).collect(),
        backward => (start..end)
            .rev()
            .step_by(backward.unsigned_abs())
            .collect(),
    }
}
/// Maps the `value_idx`-th element of the cartesian product of `indices_per_dim` to a
/// flat offset.
///
/// `value_idx` is decomposed with the last dimension varying fastest. The index selected
/// in each dimension is multiplied by that dimension's entry in `strides` and the products
/// are summed.
///
/// # Panics
/// Panics if a dimension has no indices, or if `strides` is shorter than `indices_per_dim`.
fn compute_flat_index(
    value_idx: usize,
    indices_per_dim: &[Vec<usize>],
    strides: &[usize],
) -> usize {
    let mut remaining = value_idx;
    indices_per_dim
        .iter()
        .enumerate()
        .rev()
        .map(|(dim, candidates)| {
            let extent = candidates.len();
            let position = remaining % extent;
            remaining /= extent;
            candidates[position] * strides[dim]
        })
        .sum()
}
/// Narrows the tensor along `dim` to `length` elements starting at `start`.
///
/// # Panics
/// Panics with the underlying Candle error if the narrow is rejected.
pub fn narrow(tensor: CandleTensor, dim: usize, start: usize, length: usize) -> CandleTensor {
    match tensor.tensor.narrow(dim, start, length) {
        Ok(narrowed) => CandleTensor::new(narrowed),
        // Keep the original message prefix and append the cause so failures are diagnosable.
        Err(err) => panic!("error narrow from Candle: {err}"),
    }
}
/// Splits the tensor into `chunks` pieces along `dim` using Candle's `chunk`.
///
/// # Panics
/// Panics with the underlying Candle error if the split is rejected.
pub fn chunk(tensor: CandleTensor, chunks: usize, dim: usize) -> Vec<CandleTensor> {
    match tensor.tensor.chunk(chunks, dim) {
        Ok(pieces) => pieces.into_iter().map(CandleTensor::new).collect(),
        // Keep the original message prefix and append the cause so failures are diagnosable.
        Err(err) => panic!("error chunk from Candle: {err}"),
    }
}
/// Broadcasts the tensor to `shape` using Candle's `broadcast_as`.
///
/// # Panics
/// Panics if the tensor cannot be broadcast to `shape`.
pub fn expand(tensor: CandleTensor, shape: Shape) -> CandleTensor {
    let target = shape.to_vec();
    CandleTensor::new(tensor.tensor.broadcast_as(target).unwrap())
}
/// Extracts sliding windows of length `size`, spaced `step` apart, along `dim`.
///
/// Each window is sliced out, given a new trailing axis, and has `dim` swapped with that
/// axis, so the window's elements end up in the last dimension and `dim` has size 1. The
/// windows are then concatenated along `dim`.
///
/// # Panics
/// Panics if a Candle operation fails.
pub fn unfold(tensor: CandleTensor, dim: usize, size: usize, step: usize) -> CandleTensor {
    let result_shape = calculate_unfold_shape(tensor.shape(), dim, size, step);
    let windows = result_shape[dim];
    let mut select_ranges = tensor.shape().into_ranges();
    // Index of the axis appended to every window; equal to the input rank.
    let new_axis = select_ranges.len();
    let mut stack = Vec::with_capacity(windows);
    for widx in 0..windows {
        let start = widx * step;
        let end = start + size;
        select_ranges[dim] = start..end;
        let window_slice = slice(tensor.clone(), &select_ranges);
        // The trailing axis must exist before it can be swapped: transposing with
        // `new_axis` on the un-extended window is out of range.
        let window_slice = CandleTensor::new(window_slice.tensor.unsqueeze(new_axis).unwrap());
        stack.push(swap_dims(window_slice, dim, new_axis));
    }
    cat(stack, dim)
}
/// Applies Candle's element-wise `sign` to the tensor.
///
/// # Panics
/// Panics if Candle returns an error.
pub fn sign(tensor: CandleTensor) -> CandleTensor {
    let signed = tensor.tensor.sign().unwrap();
    CandleTensor::new(signed)
}
pub fn mask_where_broadcasted(
tensor: CandleTensor,
mask: CandleTensor,
value: CandleTensor,
) -> CandleTensor {
let shape = tensor
.tensor
.shape()
.broadcast_shape_binary_op(mask.tensor.shape(), "where_cond")
.unwrap();
let mut tensor = tensor.tensor;
let mut mask = mask.tensor;
let mut value = value.tensor;
if shape != *tensor.shape() {
tensor = tensor.broadcast_as(shape.clone()).unwrap();
}
if shape != *mask.shape() {
mask = mask.broadcast_as(shape.clone()).unwrap();
}
if shape != *value.shape() {
value = value.broadcast_as(shape).unwrap();
}
CandleTensor::new(mask.where_cond(&value, &tensor).unwrap())
}
pub fn cross(lhs: CandleTensor, rhs: CandleTensor, dim: usize) -> CandleTensor {
let shape_lhs = lhs.shape();
let shape_rhs = rhs.shape();
let ndims = shape_lhs.num_dims();
// Broadcast the shapes except along dim
let mut broadcast_shape = vec![0; ndims];
for (i, item) in broadcast_shape.iter_mut().enumerate().take(ndims) {
if i == dim {
*item = shape_lhs[i];
} else {
let l = shape_lhs[i];
let r = shape_rhs[i];
if l == r {
*item = l;
} else if l == 1 {
*item = r;
} else if r == 1 {
*item = l;
} else {
panic!("Tensors are not broadcastable along dimension {}", i);
}
}
}
// Broadcast lhs and rhs
let lhs_broadcast = if shape_lhs == Shape::from(broadcast_shape.clone()) {
lhs
} else {
expand(lhs, Shape::from(broadcast_shape.clone()))
};
let rhs_broadcast = if shape_rhs == Shape::from(broadcast_shape.clone()) {
rhs
} else {
expand(rhs, Shape::from(broadcast_shape.clone()))
};
// Now, move dim to the last dimension
let mut perm = (0..ndims).collect::<Vec<_>>();
perm.remove(dim);
perm.push(dim);
let lhs_permuted = permute(lhs_broadcast, &perm);
let rhs_permuted = permute(rhs_broadcast, &perm);
// Reshape to (*, 3)
let total_elements = lhs_permuted.shape().num_elements();
let batch_size = total_elements / 3;
let lhs_reshaped = reshape(lhs_permuted, Shape::new([batch_size, 3]));
let rhs_reshaped = reshape(rhs_permuted, Shape::new([batch_size, 3]));
// Extract components using narrow and squeeze
let lhs_0 = CandleTensor::new(
lhs_reshaped
.tensor
.narrow(1, 0, 1)
.unwrap()
.squeeze(1)
.unwrap(),
);
let lhs_1 = CandleTensor::new(
lhs_reshaped
.tensor
.narrow(1, 1, 1)
.unwrap()
.squeeze(1)
.unwrap(),
);
let lhs_2 = CandleTensor::new(
lhs_reshaped
.tensor
.narrow(1, 2, 1)
.unwrap()
.squeeze(1)
.unwrap(),
);
let rhs_0 = CandleTensor::new(
rhs_reshaped
.tensor
.narrow(1, 0, 1)
.unwrap()
.squeeze(1)
.unwrap(),
);
let rhs_1 = CandleTensor::new(
rhs_reshaped
.tensor
.narrow(1, 1, 1)
.unwrap()
.squeeze(1)
.unwrap(),
);
let rhs_2 = CandleTensor::new(
rhs_reshaped
.tensor
.narrow(1, 2, 1)
.unwrap()
.squeeze(1)
.unwrap(),
);
// Compute cross product components
let result_0 = CandleTensor::new(
lhs_1
.tensor
.mul(&rhs_2.tensor)
.unwrap()
.sub(&lhs_2.tensor.mul(&rhs_1.tensor).unwrap())
.unwrap(),
);
let result_1 = CandleTensor::new(
lhs_2
.tensor
.mul(&rhs_0.tensor)
.unwrap()
.sub(&lhs_0.tensor.mul(&rhs_2.tensor).unwrap())
.unwrap(),
);
let result_2 = CandleTensor::new(
lhs_0
.tensor
.mul(&rhs_1.tensor)
.unwrap()
.sub(&lhs_1.tensor.mul(&rhs_0.tensor).unwrap())
.unwrap(),
);
// Stack the components
let result_0_unsqueezed = CandleTensor::new(result_0.tensor.unsqueeze(1).unwrap());
let result_1_unsqueezed = CandleTensor::new(result_1.tensor.unsqueeze(1).unwrap());
let result_2_unsqueezed = CandleTensor::new(result_2.tensor.unsqueeze(1).unwrap());
let result = cat(
vec![
result_0_unsqueezed,
result_1_unsqueezed,
result_2_unsqueezed,
],
1,
);
// Reshape back to the broadcast shape with dim at the end
let mut result_shape = broadcast_shape;
result_shape.remove(dim);
result_shape.push(3);
let result_reshaped = reshape(result, Shape::from(result_shape));
// Permute back
let mut inv_perm = vec![0; ndims];
for (i, &p) in perm.iter().enumerate() {
inv_perm[p] = i;
}
permute(result_reshaped, &inv_perm)
}

View File

@@ -0,0 +1,212 @@
use burn_backend::{
BackTrace, DType, ExecutionError, Scalar, Shape, Slice, TensorData, TensorMetadata,
ops::BoolTensorOps,
tensor::{BoolTensor, Device, FloatTensor, IntTensor},
};
use crate::{
Candle, CandleTensor,
element::{CandleElement, FloatCandleElement, IntCandleElement},
};
use super::base::{expand, permute, unfold};
impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<F, I> {
fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
super::base::empty(shape, device, candle_core::DType::U8)
}
fn bool_zeros(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
super::base::zeros(shape, device, candle_core::DType::U8)
}
fn bool_ones(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
super::base::ones(shape, device, candle_core::DType::U8)
}
async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {
let x: Vec<u8> = tensor
.tensor
.flatten_all()
.map_err(|err| ExecutionError::Generic {
reason: format!("{err}"),
backtrace: BackTrace::capture(),
})?
.to_vec1()
.map_err(|err| ExecutionError::Generic {
reason: format!("{err}"),
backtrace: BackTrace::capture(),
})?;
let y = x.iter().map(|b| !matches!(b, 0)).collect();
Ok(TensorData::new(y, tensor.shape()))
}
fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {
match data.dtype {
DType::U8 => super::base::from_data::<u8>(data, device),
_ => unimplemented!("Unsupported dtype for `bool_from_data`"),
}
}
fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
CandleTensor::new(tensor.tensor.to_dtype(I::DTYPE).unwrap())
}
fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
CandleTensor::new(tensor.tensor.to_dtype(F::DTYPE).unwrap())
}
fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
super::base::device(tensor)
}
fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
super::base::to_device(tensor, device)
}
fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
super::base::reshape(tensor, shape)
}
fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {
super::base::slice_with_steps(tensor, slices)
}
fn bool_slice_assign(
tensor: BoolTensor<Self>,
slices: &[Slice],
value: BoolTensor<Self>,
) -> BoolTensor<Self> {
super::base::slice_assign(tensor, slices, value)
}
fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {
super::base::cat(tensors, dim)
}
fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
let (lhs_broadcast, rhs_broadcast) =
super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
CandleTensor::new(lhs_broadcast.eq(&rhs_broadcast).unwrap())
}
fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
let x = (candle_core::Tensor::zeros_like(&tensor.tensor).unwrap());
CandleTensor::new(tensor.tensor.eq(&x).unwrap())
}
fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
let x = candle_core::Tensor::ones_like(&lhs.tensor).unwrap();
CandleTensor::new(lhs.tensor.add(&rhs.tensor).unwrap().gt(&x).unwrap())
}
fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
CandleTensor::new(
lhs.tensor
.add(&rhs.tensor)
.unwrap()
.clamp(0u32, 1u32)
.unwrap(),
)
}
fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
super::base::swap_dims(tensor, dim1, dim2)
}
fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
super::base::permute(tensor, axes)
}
/// Flips a boolean tensor along the given `axes`.
///
/// Delegates to `super::base::flip`.
fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
super::base::flip(tensor, axes)
}
/// Selects entries of `tensor` along `dim` at the positions given by `indices`.
///
/// Thin wrapper around Candle's `index_select`; panics if Candle rejects the
/// indices or the dimension.
fn bool_select(
    tensor: BoolTensor<Self>,
    dim: usize,
    indices: IntTensor<Self>,
) -> BoolTensor<Self> {
    let picked = tensor.tensor.index_select(&indices.tensor, dim).unwrap();
    CandleTensor::new(picked)
}
/// Combines `value` into `tensor` along `dim` at `indices` using logical OR.
///
/// Candle only offers an additive `index_add`, so the values are accumulated
/// and then clamped back to `[0, 1]`. Without the clamp an element that is set
/// in both `tensor` and `value` (or hit by several indices) would hold a value
/// greater than one, which is not a valid boolean for the other ops in this
/// backend (for example `bool_and` compares a sum against one).
///
/// NOTE(review): the accumulation happens in the boolean storage dtype, so an
/// extremely large number of hits on a single element could overflow before
/// the clamp — confirm whether that is a realistic input.
///
/// Panics if Candle rejects the indices, shapes or dimension.
fn bool_select_or(
    tensor: BoolTensor<Self>,
    dim: usize,
    indices: IntTensor<Self>,
    value: BoolTensor<Self>,
) -> BoolTensor<Self> {
    let accumulated = tensor
        .tensor
        .index_add(&indices.tensor, &value.tensor, dim)
        .unwrap();
    CandleTensor::new(accumulated.clamp(0u32, 1u32).unwrap())
}
/// Expands a boolean tensor to `shape`.
///
/// Delegates to the shared `expand` helper.
fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
expand(tensor, shape)
}
/// Extracts windows of length `size`, spaced by `step`, along dimension `dim`.
///
/// Delegates to the shared `unfold` helper.
fn bool_unfold(
tensor: BoolTensor<Self>,
dim: usize,
size: usize,
step: usize,
) -> BoolTensor<Self> {
unfold(tensor, dim, size, step)
}
/// Combines `tensor` and `value` element-wise according to `mask`.
///
/// Delegates to `super::base::mask_where_broadcasted`.
fn bool_mask_where(
tensor: BoolTensor<Self>,
mask: BoolTensor<Self>,
value: BoolTensor<Self>,
) -> BoolTensor<Self> {
super::base::mask_where_broadcasted(tensor, mask, value)
}
/// Replaces the elements of `tensor` selected by `mask` with the scalar `value`.
///
/// A constant tensor shaped like `tensor` is built from `value`, then Candle's
/// `where_cond` picks the constant where the mask is set and the original
/// element elsewhere. Panics if Candle rejects the operands.
fn bool_mask_fill(
    tensor: BoolTensor<Self>,
    mask: BoolTensor<Self>,
    value: Scalar,
) -> BoolTensor<Self> {
    let replacement = super::candle_utils::fill_like::<u8>(value.elem(), &tensor.tensor);
    let filled = mask
        .tensor
        .where_cond(&replacement, &tensor.tensor)
        .unwrap();
    CandleTensor::new(filled)
}
/// Gathers elements of `tensor` along `dim` using `indices`.
///
/// Both operands are made contiguous before calling Candle's `gather`.
/// Panics if Candle rejects the operands.
fn bool_gather(
    dim: usize,
    tensor: BoolTensor<Self>,
    indices: IntTensor<Self>,
) -> BoolTensor<Self> {
    let source = tensor.tensor.contiguous().unwrap();
    let index = indices.tensor.contiguous().unwrap();
    let gathered = source.gather(&index, dim).unwrap();
    CandleTensor::new(gathered)
}
/// Scatters `value` into `tensor` along `dim` at `indices` using logical OR.
///
/// Candle only offers an additive `scatter_add`, so the values are accumulated
/// and then clamped back to `[0, 1]`. Without the clamp an element that is set
/// in both `tensor` and `value` (or targeted by several indices) would hold a
/// value greater than one, which is not a valid boolean for the other ops in
/// this backend (for example `bool_and` compares a sum against one).
///
/// NOTE(review): the accumulation happens in the boolean storage dtype, so an
/// extremely large number of hits on a single element could overflow before
/// the clamp — confirm whether that is a realistic input.
///
/// Panics if Candle rejects the indices, shapes or dimension.
fn bool_scatter_or(
    dim: usize,
    tensor: BoolTensor<Self>,
    indices: IntTensor<Self>,
    value: BoolTensor<Self>,
) -> BoolTensor<Self> {
    let accumulated = tensor
        .tensor
        .scatter_add(&indices.tensor, &value.tensor, dim)
        .unwrap();
    CandleTensor::new(accumulated.clamp(0u32, 1u32).unwrap())
}
/// Compares every element of `lhs` with the scalar `rhs` for equality.
///
/// The scalar is converted to `u8` before being handed to Candle's `eq`.
fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
    let scalar = rhs.elem::<u8>();
    let equal = lhs.tensor.eq(scalar).unwrap();
    CandleTensor::new(equal)
}
}

View File

@@ -0,0 +1,46 @@
use candle_core::{DType, Device, Shape, Tensor};
use crate::element::CandleElement;
pub(crate) fn fill<E: CandleElement, S: Into<Shape>>(
value: E,
shape: S,
dtype: DType,
device: &Device,
) -> Tensor {
let values = (Tensor::ones((1), dtype, device).unwrap() * value.elem::<f64>()).unwrap();
values.expand(shape).unwrap()
}
/// Builds a tensor filled with `value` that has the same shape, dtype and
/// device as `reference_tensor`.
pub(crate) fn fill_like<E: CandleElement>(value: E, reference_tensor: &Tensor) -> Tensor {
    let shape = reference_tensor.shape();
    let dtype = reference_tensor.dtype();
    let device = reference_tensor.device();
    fill(value, shape, dtype, device)
}
/// Broadcasts two tensors to a common shape for comparison operations
pub(crate) fn broadcast_for_comparison(
lhs: &Tensor,
rhs: &Tensor,
) -> Result<(Tensor, Tensor), candle_core::Error> {
let broadcast_shape = lhs
.shape()
.broadcast_shape_binary_op(rhs.shape(), "comparison")?;
let lhs = if broadcast_shape != *lhs.shape() {
lhs.broadcast_as(&broadcast_shape)?
} else {
lhs.clone()
};
let rhs = if broadcast_shape != *rhs.shape() {
rhs.broadcast_as(&broadcast_shape)?
} else {
rhs.clone()
};
Ok((lhs, rhs))
}

View File

@@ -0,0 +1,521 @@
use burn_backend::{
DType, Distribution, ElementConversion, ExecutionError, IntDType, Scalar, Shape, Slice,
TensorData,
ops::{FloatTensorOps, IntTensorOps},
tensor::{Bool, BoolTensor, Device, FloatTensor, IntElem, IntTensor},
};
use crate::{
Candle, CandleDevice, CandleTensor, IntoDType,
element::{CandleElement, FloatCandleElement, IntCandleElement},
};
use super::base::{cpu_random, expand, permute, sign, unfold};
/// Integer tensor operations of the Candle backend.
///
/// Every operation unwraps the result returned by Candle, so a Candle error
/// (incompatible shapes, unsupported dtype or kernel, ...) surfaces as a panic.
impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F, I> {
    /// Creates an integer tensor of the given shape and dtype via `super::base::empty`.
    fn int_empty(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
        super::base::empty(shape, device, dtype.into_dtype())
    }

    /// Reads the tensor content back as `TensorData` via `super::base::into_data`.
    async fn int_into_data(tensor: IntTensor<Self>) -> Result<TensorData, ExecutionError> {
        super::base::into_data(tensor)
    }

    /// Creates an integer tensor from `data`, dispatching on the data dtype.
    ///
    /// Only `I64`, `U32` and `U8` are handled; any other dtype panics.
    fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
        match data.dtype {
            DType::I64 => super::base::from_data::<i64>(data, device),
            DType::U32 => super::base::from_data::<u32>(data, device),
            DType::U8 => super::base::from_data::<u8>(data, device),
            _ => unimplemented!("Unsupported dtype for `int_from_data`"),
        }
    }

    /// Returns the device of the tensor.
    fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {
        super::base::device(tensor)
    }

    /// Moves the tensor to `device`.
    fn int_to_device(tensor: IntTensor<Self>, device: &Device<Self>) -> IntTensor<Self> {
        super::base::to_device(tensor, device)
    }

    /// Reshapes the tensor to `shape`.
    fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
        super::base::reshape(tensor, shape)
    }

    /// Extracts the region described by `slices`.
    fn int_slice(tensor: IntTensor<Self>, slices: &[Slice]) -> IntTensor<Self> {
        super::base::slice_with_steps(tensor, slices)
    }

    /// Writes `value` into the region of `tensor` described by `slices`.
    fn int_slice_assign(
        tensor: IntTensor<Self>,
        slices: &[Slice],
        value: IntTensor<Self>,
    ) -> IntTensor<Self> {
        super::base::slice_assign(tensor, slices, value)
    }

    /// Casts the tensor to the backend float dtype `F::DTYPE`.
    fn int_into_float(tensor: IntTensor<Self>) -> FloatTensor<Self> {
        CandleTensor::new(tensor.tensor.to_dtype(F::DTYPE).unwrap())
    }

    /// Combines `tensor` and `source` element-wise according to `mask`.
    fn int_mask_where(
        tensor: IntTensor<Self>,
        mask: BoolTensor<Self>,
        source: IntTensor<Self>,
    ) -> IntTensor<Self> {
        super::base::mask_where_broadcasted(tensor, mask, source)
    }

    /// Replaces the elements selected by `mask` with the scalar `value`.
    fn int_mask_fill(
        tensor: IntTensor<Self>,
        mask: BoolTensor<Self>,
        value: Scalar,
    ) -> IntTensor<Self> {
        let replacement = super::candle_utils::fill_like::<I>(value.elem(), &tensor.tensor);
        CandleTensor::new(
            mask.tensor
                .where_cond(&replacement, &tensor.tensor)
                .unwrap(),
        )
    }

    /// Gathers elements along `dim` using `indices`; both operands are made
    /// contiguous before calling Candle's `gather`.
    fn int_gather(
        dim: usize,
        tensor: IntTensor<Self>,
        indices: IntTensor<Self>,
    ) -> IntTensor<Self> {
        let tensor = tensor.tensor.contiguous().unwrap();
        let indices = indices.tensor.contiguous().unwrap();
        CandleTensor::new(tensor.gather(&indices, dim).unwrap())
    }

    /// Adds `value` into `tensor` along `dim` at the positions given by `indices`.
    fn int_scatter_add(
        dim: usize,
        tensor: IntTensor<Self>,
        indices: IntTensor<Self>,
        value: IntTensor<Self>,
    ) -> IntTensor<Self> {
        CandleTensor::new(
            tensor
                .tensor
                .scatter_add(&indices.tensor, &value.tensor, dim)
                .unwrap(),
        )
    }

    /// Selects entries along `dim` at the positions given by `indices`.
    fn int_select(
        tensor: IntTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
    ) -> IntTensor<Self> {
        CandleTensor::new(tensor.tensor.index_select(&indices.tensor, dim).unwrap())
    }

    /// Adds `value` into the entries of `tensor` selected along `dim` by `indices`.
    fn int_select_add(
        tensor: IntTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
        value: IntTensor<Self>,
    ) -> IntTensor<Self> {
        CandleTensor::new(
            tensor
                .tensor
                .index_add(&indices.tensor, &value.tensor, dim)
                .unwrap(),
        )
    }

    /// Concatenates tensors along dimension `dim`.
    fn int_cat(tensors: Vec<IntTensor<Self>>, dim: usize) -> IntTensor<Self> {
        super::base::cat(tensors, dim)
    }

    /// Element-wise `lhs == rhs` with broadcasting.
    fn int_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
        let (lhs_broadcast, rhs_broadcast) =
            super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
        CandleTensor::new(lhs_broadcast.eq(&rhs_broadcast).unwrap())
    }

    /// Element-wise `lhs == rhs` against a scalar.
    fn int_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
        CandleTensor::new(lhs.tensor.eq(rhs.elem::<I>()).unwrap())
    }

    /// Element-wise `lhs > rhs` with broadcasting.
    fn int_greater(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
        let (lhs_broadcast, rhs_broadcast) =
            super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
        CandleTensor::new(lhs_broadcast.gt(&rhs_broadcast).unwrap())
    }

    /// Element-wise `lhs > rhs` against a scalar materialised with `fill_like`.
    fn int_greater_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
        let rhs = super::candle_utils::fill_like::<I>(rhs.elem(), &lhs.tensor);
        CandleTensor::new(lhs.tensor.gt(&rhs).unwrap())
    }

    /// Element-wise `lhs >= rhs` with broadcasting.
    fn int_greater_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
        let (lhs_broadcast, rhs_broadcast) =
            super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
        CandleTensor::new(lhs_broadcast.ge(&rhs_broadcast).unwrap())
    }

    /// Element-wise `lhs >= rhs` against a scalar materialised with `fill_like`.
    fn int_greater_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
        let rhs = super::candle_utils::fill_like::<I>(rhs.elem(), &lhs.tensor);
        CandleTensor::new(lhs.tensor.ge(&rhs).unwrap())
    }

    /// Element-wise `lhs < rhs` with broadcasting.
    fn int_lower(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
        let (lhs_broadcast, rhs_broadcast) =
            super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
        CandleTensor::new(lhs_broadcast.lt(&rhs_broadcast).unwrap())
    }

    /// Element-wise `lhs < rhs` against a scalar materialised with `fill_like`.
    fn int_lower_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
        let rhs = super::candle_utils::fill_like::<I>(rhs.elem(), &lhs.tensor);
        CandleTensor::new(lhs.tensor.lt(&rhs).unwrap())
    }

    /// Element-wise `lhs <= rhs` with broadcasting.
    fn int_lower_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
        let (lhs_broadcast, rhs_broadcast) =
            super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
        CandleTensor::new(lhs_broadcast.le(&rhs_broadcast).unwrap())
    }

    /// Element-wise `lhs <= rhs` against a scalar materialised with `fill_like`.
    fn int_lower_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
        let rhs = super::candle_utils::fill_like::<I>(rhs.elem(), &lhs.tensor);
        CandleTensor::new(lhs.tensor.le(&rhs).unwrap())
    }

    /// Element-wise addition with broadcasting.
    fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        CandleTensor::new(lhs.tensor.broadcast_add(&rhs.tensor).unwrap())
    }

    /// Adds the scalar `rhs` (converted to `f64`) to every element.
    fn int_add_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
        CandleTensor::new((lhs.tensor + rhs.elem::<f64>()).unwrap())
    }

    /// Element-wise subtraction with broadcasting.
    fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        CandleTensor::new(lhs.tensor.broadcast_sub(&rhs.tensor).unwrap())
    }

    /// Subtracts the scalar `rhs` (converted to `f64`) from every element.
    fn int_sub_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
        CandleTensor::new((lhs.tensor - rhs.elem::<f64>()).unwrap())
    }

    /// Element-wise multiplication with broadcasting.
    fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        CandleTensor::new(lhs.tensor.broadcast_mul(&rhs.tensor).unwrap())
    }

    /// Multiplies every element by the scalar `rhs` (converted to `f64`).
    fn int_mul_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
        CandleTensor::new((lhs.tensor * rhs.elem::<f64>()).unwrap())
    }

    /// Element-wise integer division with broadcasting.
    fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        CandleTensor::new(lhs.tensor.broadcast_div(&rhs.tensor).unwrap())
    }

    /// Divides every element by the scalar `rhs`.
    fn int_div_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
        // Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0 so the
        // scalar operator always yields 0. The scalar is therefore materialised as a tensor
        // (as the `*_elem` comparisons do) and the tensor division is used instead.
        let divisor = super::candle_utils::fill_like::<I>(rhs.elem(), &lhs.tensor);
        CandleTensor::new(lhs.tensor.broadcast_div(&divisor).unwrap())
    }

    /// Element-wise remainder with broadcasting.
    ///
    /// The result follows the floored convention used by `float_remainder`
    /// (`a - floor(a / b) * b`): a non-zero remainder takes the sign of the
    /// divisor. Integer division truncates, so for signed dtypes the truncated
    /// remainder is shifted by the divisor whenever its sign differs from the
    /// divisor's.
    fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        let quotient = lhs.tensor.broadcast_div(&rhs.tensor).unwrap();
        let truncated = lhs
            .tensor
            .broadcast_sub(&quotient.broadcast_mul(&rhs.tensor).unwrap())
            .unwrap();

        // Unsigned values cannot be negative, so the truncated remainder is already
        // the floored one; skipping the correction also avoids overflowing additions.
        if matches!(
            truncated.dtype(),
            candle_core::DType::U8 | candle_core::DType::U32
        ) {
            return CandleTensor::new(truncated);
        }

        let zeros = truncated.zeros_like().unwrap();
        let divisor = rhs.tensor.broadcast_as(truncated.shape()).unwrap();
        let remainder_negative = truncated.lt(&zeros).unwrap();
        let divisor_negative = divisor.lt(&zeros).unwrap();
        let remainder_non_zero = truncated.ne(&zeros).unwrap();
        // Masks hold 0/1 values, so their product is a logical AND.
        let needs_shift = remainder_negative
            .ne(&divisor_negative)
            .unwrap()
            .broadcast_mul(&remainder_non_zero)
            .unwrap();
        let shifted = truncated.broadcast_add(&divisor).unwrap();
        CandleTensor::new(needs_shift.where_cond(&shifted, &truncated).unwrap())
    }

    /// Remainder of every element divided by the scalar `rhs`.
    fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
        // Same problem as int_div_scalar: the scalar is materialised as a tensor and the
        // tensor implementation is reused.
        let divisor = super::candle_utils::fill_like::<I>(rhs.elem(), &lhs.tensor);
        Self::int_remainder(lhs, CandleTensor::new(divisor))
    }

    /// Creates a tensor of zeros.
    fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
        CandleTensor::new(
            candle_core::Tensor::zeros(
                shape.to_vec(),
                dtype.into_dtype(),
                &(device.clone()).into(),
            )
            .unwrap(),
        )
    }

    /// Creates a tensor of ones.
    fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
        CandleTensor::new(
            candle_core::Tensor::ones(shape.to_vec(), dtype.into_dtype(), &(device.clone()).into())
                .unwrap(),
        )
    }

    /// Sums all elements into a tensor of shape `[1]`.
    ///
    /// The reduction stays on the tensor's device and keeps the tensor's own
    /// dtype, so it also works for tensors that were cast away from `I` with
    /// `int_cast` (reading the scalar back as `I` would fail for those).
    fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
        CandleTensor::new(tensor.tensor.sum_all().unwrap().unsqueeze(0).unwrap())
    }

    /// Sums along `dim`, keeping the reduced dimension.
    fn int_sum_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap())
    }

    /// Not implemented for this backend.
    fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
        todo!(
            "prod is not implemented for Candle IntTensor (see https://github.com/tracel-ai/burn/issues/1454)"
        )
    }

    /// Not implemented for this backend.
    fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        todo!(
            "prod_int is not implemented for Candle IntTensor (see https://github.com/tracel-ai/burn/issues/1454)"
        )
    }

    /// Integer mean along `dim`, keeping the reduced dimension.
    ///
    /// The sum along `dim` is divided by the number of elements using integer
    /// division, so the fractional part is discarded.
    fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        // Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0, so the
        // element count is materialised as a tensor and the tensor division is used.
        let count = tensor.tensor.dims()[dim] as f64;
        let sum = tensor.tensor.sum_keepdim(dim).unwrap();
        let divisor = super::candle_utils::fill_like::<f64>(count, &sum);
        CandleTensor::new(sum.broadcast_div(&divisor).unwrap())
    }

    /// Cumulative sum along `dim`.
    fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        // Candle's cumsum doesn't support integer types. Accumulating directly on the
        // integer tensor (as int_cummax does) avoids a lossy round trip through f32.
        let result = super::utils::cumulative_with_op(&tensor.tensor, dim, |prev, curr| {
            prev.broadcast_add(curr)
        });
        CandleTensor::new(result)
    }

    /// Cumulative product along `dim`, computed directly on the integer tensor.
    fn int_cumprod(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        let result = super::utils::cumulative_with_op(&tensor.tensor, dim, |prev, curr| {
            prev.broadcast_mul(curr)
        });
        CandleTensor::new(result)
    }

    /// Cumulative minimum along `dim`, computed directly on the integer tensor.
    fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        let result = super::utils::cumulative_with_op(&tensor.tensor, dim, |prev, curr| {
            prev.broadcast_minimum(curr)
        });
        CandleTensor::new(result)
    }

    /// Cumulative maximum along `dim`.
    fn int_cummax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        let result = super::utils::cumulative_with_op(&tensor.tensor, dim, |prev, curr| {
            prev.broadcast_maximum(curr)
        });
        CandleTensor::new(result)
    }

    /// Index of the maximum along `dim` (dimension kept), cast to `I::DTYPE`.
    fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        CandleTensor::new(
            tensor
                .tensor
                .argmax_keepdim(dim)
                .unwrap()
                .to_dtype(I::DTYPE)
                .unwrap(),
        )
    }

    /// Index of the minimum along `dim` (dimension kept), cast to `I::DTYPE`.
    fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        CandleTensor::new(
            tensor
                .tensor
                .argmin_keepdim(dim)
                .unwrap()
                .to_dtype(I::DTYPE)
                .unwrap(),
        )
    }

    /// Element-wise absolute value.
    ///
    /// Unsigned tensors are returned unchanged. For other dtypes the result is
    /// `max(x, 0 - x)`, built from binary ops only because Candle does not
    /// support unary ops on ints; this keeps the computation exact instead of
    /// round-tripping through the float dtype.
    ///
    /// NOTE(review): the most negative representable value has no positive
    /// counterpart, so `0 - x` overflows for it — confirm that input is not
    /// expected.
    fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
        match tensor.tensor.dtype() {
            candle_core::DType::U8 | candle_core::DType::U32 => tensor,
            _ => {
                let negated = tensor
                    .tensor
                    .zeros_like()
                    .unwrap()
                    .broadcast_sub(&tensor.tensor)
                    .unwrap();
                CandleTensor::new(tensor.tensor.broadcast_maximum(&negated).unwrap())
            }
        }
    }

    /// Swaps dimensions `dim1` and `dim2`.
    fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
        super::base::swap_dims(tensor, dim1, dim2)
    }

    /// Samples an integer tensor from `distribution`.
    ///
    /// On the CPU device the samples come from `cpu_random`, with
    /// `Distribution::Default` mapped to `Uniform(0.0, 255.0)`. On other
    /// devices the samples are drawn by Candle in the float dtype and cast to
    /// `I::DTYPE` at the very end, so every distribution yields an integer
    /// tensor and the Bernoulli comparison happens on the float samples
    /// (casting to int first would truncate both the samples and the
    /// probability to 0).
    fn int_random(
        shape: Shape,
        distribution: Distribution,
        device: &Device<Self>,
    ) -> IntTensor<Self> {
        if let CandleDevice::Cpu = device {
            let distribution = if distribution == Distribution::Default {
                Distribution::Uniform(0.0, 255.0)
            } else {
                distribution
            };
            // Use our own seed since candle doesn't support it on CPU
            return Self::int_from_data(cpu_random::<I>(shape, distribution), device);
        }

        let shape = shape.to_vec();
        let device = &(device.clone()).into();
        let sampled = match distribution {
            Distribution::Default => {
                candle_core::Tensor::rand(0.elem::<F>(), 255.elem::<F>(), shape, device).unwrap()
            }
            Distribution::Bernoulli(prob) => {
                candle_core::Tensor::rand(0.elem::<F>(), 1.elem::<F>(), shape.clone(), device)
                    .unwrap()
                    .lt(&super::candle_utils::fill(prob, shape, F::DTYPE, device))
                    .unwrap()
            }
            Distribution::Uniform(from, to) => {
                candle_core::Tensor::rand(from.elem::<F>(), to.elem::<F>(), shape, device).unwrap()
            }
            Distribution::Normal(mean, std) => {
                candle_core::Tensor::randn(mean.elem::<F>(), std.elem::<F>(), shape, device)
                    .unwrap()
            }
        };
        CandleTensor::new(sampled.to_dtype(I::DTYPE).unwrap())
    }

    /// Reorders the dimensions according to `axes`.
    fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
        super::base::permute(tensor, axes)
    }

    /// Flips the tensor along the given `axes`.
    fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
        super::base::flip(tensor, axes)
    }

    /// Expands the tensor to `shape`.
    fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
        expand(tensor, shape)
    }

    /// Extracts windows of length `size`, spaced by `step`, along dimension `dim`.
    fn int_unfold(
        tensor: IntTensor<Self>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> IntTensor<Self> {
        unfold(tensor, dim, size, step)
    }

    /// Element-wise sign, via the shared `sign` helper.
    fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
        sign(tensor)
    }

    /// Not implemented for this backend.
    fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        unimplemented!("bitwise_and is not implemented for Candle IntTensor");
    }

    /// Not implemented for this backend.
    fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
        unimplemented!("bitwise_and_scalar is not implemented for Candle IntTensor");
    }

    /// Not implemented for this backend.
    fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        unimplemented!("bitwise_or is not implemented for Candle IntTensor");
    }

    /// Not implemented for this backend.
    fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
        unimplemented!("bitwise_or_scalar is not implemented for Candle IntTensor");
    }

    /// Not implemented for this backend.
    fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        unimplemented!("bitwise_xor is not implemented for Candle IntTensor");
    }

    /// Not implemented for this backend.
    fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
        unimplemented!("bitwise_xor_scalar is not implemented for Candle IntTensor");
    }

    /// Not implemented for this backend.
    fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
        unimplemented!("bitwise_not is not implemented for Candle IntTensor");
    }

    /// Not implemented for this backend.
    fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        unimplemented!("bitwise_left_shift is not implemented for Candle IntTensor");
    }

    /// Not implemented for this backend.
    fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        unimplemented!("bitwise_right_shift is not implemented for Candle IntTensor");
    }

    /// Not implemented for this backend.
    fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
        unimplemented!("bitwise_left_shift_scalar is not implemented for Candle IntTensor");
    }

    /// Not implemented for this backend.
    fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
        unimplemented!("bitwise_right_shift_scalar is not implemented for Candle IntTensor");
    }

    /// Matrix product of two integer tensors.
    ///
    /// Both operands are cast to the float dtype, multiplied with
    /// `float_matmul`, and the product is cast back to the int dtype.
    /// NOTE(review): values that are not exactly representable in the float
    /// dtype will presumably lose precision here — confirm acceptable ranges.
    fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        let lhs = Self::int_into_float(lhs);
        let rhs = Self::int_into_float(rhs);
        let out = Self::float_matmul(lhs, rhs);
        Self::float_into_int(out)
    }

    /// Casts the tensor to `dtype`, returning it unchanged when it already has that dtype.
    fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
        let dtype = dtype.into_dtype();
        if tensor.tensor.dtype() == dtype {
            tensor
        } else {
            CandleTensor::new(tensor.tensor.to_dtype(dtype).unwrap())
        }
    }
}

View File

@@ -0,0 +1,10 @@
mod activation;
mod base;
mod bool_tensor;
mod candle_utils;
mod int_tensor;
mod module;
mod qtensor;
mod tensor;
mod transaction;
mod utils;

View File

@@ -0,0 +1,327 @@
use burn_backend::{
Shape,
ops::{
ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions,
InterpolateMode, InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
UnfoldOptions, attention::attention_fallback,
},
tensor::{FloatTensor, IntTensor},
};
use candle_core::ToUsize2;
use crate::{
Candle, CandleTensor,
element::{CandleElement, FloatCandleElement, IntCandleElement},
ops::base::reshape,
};
/// Neural-network module operations of the Candle backend.
///
/// Candle results are unwrapped, so Candle errors surface as panics. Options
/// that Candle cannot express are rejected up front with an assertion, and
/// operations Candle does not provide panic unconditionally.
impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I> {
    /// 1D convolution, with the optional bias broadcast over the output.
    fn conv1d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<Self> {
        let conv = x
            .tensor
            .conv1d(
                &weight.tensor,
                options.padding[0],
                options.stride[0],
                options.dilation[0],
                options.groups,
            )
            .unwrap();
        CandleTensor::new(match bias {
            Some(bias) => conv
                .broadcast_add(&bias.tensor.unsqueeze(1).unwrap())
                .unwrap(),
            None => conv,
        })
    }

    /// 2D convolution, with the optional bias broadcast over the output.
    ///
    /// Panics when padding, stride or dilation differ between the two spatial
    /// dimensions, because a single value per option is passed to Candle.
    fn conv2d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<Self> {
        assert!(
            options.dilation[0] == options.dilation[1]
                && options.padding[0] == options.padding[1]
                && options.stride[0] == options.stride[1],
            "Candle does not support per dimension options in convolutions"
        );
        let conv = x
            .tensor
            .conv2d(
                &weight.tensor,
                options.padding[0],
                options.stride[0],
                options.dilation[0],
                options.groups,
            )
            .unwrap();
        CandleTensor::new(match bias {
            Some(bias) => conv
                .broadcast_add(
                    &bias
                        .tensor
                        .unsqueeze(0)
                        .unwrap()
                        .unsqueeze(2)
                        .unwrap()
                        .unsqueeze(3)
                        .unwrap(),
                )
                .unwrap(),
            None => conv,
        })
    }

    /// Not supported by this backend; always panics.
    fn deform_conv2d(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<Self> {
        unimplemented!("Candle does not support deformable convolutions")
    }

    /// Not supported by this backend; always panics.
    fn deform_conv2d_backward(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        output_grad: FloatTensor<Self>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        unimplemented!("Candle does not support deformable convolutions")
    }

    /// Not supported by this backend; always panics.
    fn conv3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<Self> {
        panic!("Candle does not support 3D convolutions");
    }

    /// 1D transposed convolution, with the optional bias broadcast over the output.
    fn conv_transpose1d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<Self> {
        let conv_transpose = x
            .tensor
            .conv_transpose1d(
                &weight.tensor,
                options.padding[0],
                options.padding_out[0],
                options.stride[0],
                options.dilation[0],
                options.groups,
            )
            .unwrap();
        CandleTensor::new(match bias {
            Some(bias) => conv_transpose
                .broadcast_add(&bias.tensor.unsqueeze(0).unwrap().unsqueeze(2).unwrap())
                .unwrap(),
            None => conv_transpose,
        })
    }

    /// 2D transposed convolution, with the optional bias broadcast over the output.
    ///
    /// Panics when an option differs between the two spatial dimensions or
    /// when `groups` is not 1, because Candle's call takes a single value per
    /// option and no group count.
    fn conv_transpose2d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<Self> {
        assert!(
            options.dilation[0] == options.dilation[1]
                && options.padding[0] == options.padding[1]
                && options.padding_out[0] == options.padding_out[1]
                && options.stride[0] == options.stride[1],
            "Candle does not support per dimension options in transposed convolutions"
        );
        assert!(
            options.groups == 1,
            "Candle does not support groups in transposed convolutions"
        );
        let conv_transpose = x
            .tensor
            .conv_transpose2d(
                &weight.tensor,
                options.padding[0],
                options.padding_out[0],
                options.stride[0],
                options.dilation[0],
            )
            .unwrap();
        CandleTensor::new(match bias {
            Some(bias) => conv_transpose
                .broadcast_add(
                    &bias
                        .tensor
                        .unsqueeze(0)
                        .unwrap()
                        .unsqueeze(2)
                        .unwrap()
                        .unsqueeze(3)
                        .unwrap(),
                )
                .unwrap(),
            None => conv_transpose,
        })
    }

    /// Not supported by this backend; always panics.
    fn conv_transpose3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<Self> {
        panic!("Candle does not support 3D transposed convolutions");
    }

    /// 2D average pooling.
    ///
    /// Panics when padding is non-zero or `ceil_mode` is set. The
    /// `count_include_pad` flag is accepted with either value: padding is
    /// required to be zero and `ceil_mode` is rejected, so every window lies
    /// fully inside the input and the flag cannot change the result.
    fn avg_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        _count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        assert!(
            padding[0] == 0 && padding[1] == 0,
            "Candle does not support padding in pooling"
        );
        assert!(!ceil_mode, "Candle does not support ceil_mode in pooling");
        CandleTensor::new(
            x.tensor
                .avg_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1]))
                .unwrap(),
        )
    }

    /// Not supported by this backend; always panics.
    fn avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        _ceil_mode: bool,
    ) -> FloatTensor<Self> {
        panic!("avg_pool2d_backward is not supported by Candle")
    }

    /// 2D max pooling.
    ///
    /// Panics when padding is non-zero, dilation is not 1, or `ceil_mode` is set.
    fn max_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        assert!(
            padding[0] == 0 && padding[1] == 0,
            "Candle does not support padding in pooling"
        );
        assert!(
            dilation[0] == 1 && dilation[1] == 1,
            "Candle does not support dilation in pooling"
        );
        assert!(!ceil_mode, "Candle does not support ceil_mode in pooling");
        CandleTensor::new(
            x.tensor
                .max_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1]))
                .unwrap(),
        )
    }

    /// Not supported by this backend; always panics.
    fn max_pool2d_with_indices(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        _ceil_mode: bool,
    ) -> MaxPool2dWithIndices<Candle<F, I>> {
        panic!("max_pool2d_with_indices is not supported by Candle")
    }

    /// Not supported by this backend; always panics.
    fn max_pool2d_with_indices_backward(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        _ceil_mode: bool,
        output_grad: FloatTensor<Self>,
        indices: IntTensor<Self>,
    ) -> MaxPool2dBackward<Candle<F, I>> {
        panic!("max_pool2d_with_indices_backward is not supported by Candle")
    }

    /// Not supported by this backend; always panics.
    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
        panic!("adaptive_avg_pool2d is not supported by Candle")
    }

    /// Not supported by this backend; always panics.
    fn adaptive_avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        panic!("adaptive_avg_pool2d_backward is not supported by Candle")
    }

    /// Resizes `x` to `output_size`.
    ///
    /// Only nearest-neighbour interpolation is available; the bilinear and
    /// bicubic modes panic.
    fn interpolate(
        x: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        let tensor = match options.mode {
            InterpolateMode::Nearest => x
                .tensor
                .upsample_nearest2d(output_size[0], output_size[1])
                .unwrap(),
            InterpolateMode::Bilinear => {
                panic!("bilinear interpolation is not supported by Candle")
            }
            InterpolateMode::Bicubic => {
                panic!("bicubic interpolation is not supported by Candle")
            }
        };
        CandleTensor::new(tensor)
    }

    /// Not supported by this backend; always panics.
    fn interpolate_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        panic!("interpolate_backward is not supported by Candle")
    }

    /// Attention, computed with the generic `attention_fallback` implementation.
    fn attention(
        query: FloatTensor<Self>,
        key: FloatTensor<Self>,
        value: FloatTensor<Self>,
        mask: Option<burn_backend::tensor::BoolTensor<Self>>,
        attn_bias: Option<FloatTensor<Self>>,
        options: burn_backend::ops::AttentionModuleOptions,
    ) -> FloatTensor<Self> {
        attention_fallback::<Self>(query, key, value, mask, attn_bias, options)
    }
}

View File

@@ -0,0 +1,88 @@
use burn_backend::{
Backend, DType, ExecutionError, Shape, Slice, TensorData,
ops::QTensorOps,
quantization::{QuantScheme, QuantizationParametersPrimitive},
tensor::{Device, FloatTensor, IntTensor, QuantizedTensor},
};
use crate::{
Candle,
element::{FloatCandleElement, IntCandleElement},
};
/// Quantized tensor operations of the Candle backend.
///
/// None of these operations is implemented: every method panics with a
/// message naming the operation. All parameters are underscore-prefixed
/// because they are never read.
impl<F: FloatCandleElement, I: IntCandleElement> QTensorOps<Self> for Candle<F, I> {
    /// Not implemented; always panics.
    fn q_from_data(_data: TensorData, _device: &Device<Self>) -> QuantizedTensor<Self> {
        unimplemented!("q_from_data is not implemented for the Candle backend")
    }

    /// Not implemented; always panics.
    fn quantize(
        _tensor: FloatTensor<Self>,
        _scheme: &QuantScheme,
        _qparams: QuantizationParametersPrimitive<Self>,
    ) -> QuantizedTensor<Self> {
        unimplemented!("quantize is not implemented for the Candle backend")
    }

    /// Not implemented; always panics.
    fn dequantize(_tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
        unimplemented!("dequantize is not implemented for the Candle backend")
    }

    /// Not implemented; always panics.
    fn q_device(_tensor: &QuantizedTensor<Self>) -> Device<Self> {
        unimplemented!("q_device is not implemented for the Candle backend")
    }

    /// Not implemented; always panics.
    fn q_to_device(
        _tensor: QuantizedTensor<Self>,
        _device: &Device<Self>,
    ) -> QuantizedTensor<Self> {
        unimplemented!("q_to_device is not implemented for the Candle backend")
    }

    /// Not implemented; always panics.
    fn q_reshape(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
        unimplemented!("q_reshape is not implemented for the Candle backend")
    }

    /// Not implemented; always panics.
    async fn q_into_data(_tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
        unimplemented!("q_into_data is not implemented for the Candle backend")
    }

    /// Not implemented; always panics.
    fn q_swap_dims(
        _tensor: QuantizedTensor<Self>,
        _dim1: usize,
        _dim2: usize,
    ) -> QuantizedTensor<Self> {
        unimplemented!("q_swap_dims is not implemented for the Candle backend")
    }

    /// Not implemented; always panics.
    fn q_permute(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
        unimplemented!("q_permute is not implemented for the Candle backend")
    }

    /// Not implemented; always panics.
    fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
        unimplemented!("q_flip is not implemented for the Candle backend")
    }

    /// Not implemented; always panics.
    fn q_gather(
        _dim: usize,
        _tensor: QuantizedTensor<Self>,
        _indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        unimplemented!("q_gather is not implemented for the Candle backend")
    }

    /// Not implemented; always panics.
    fn q_select(
        _tensor: QuantizedTensor<Self>,
        _dim: usize,
        _indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        unimplemented!("q_select is not implemented for the Candle backend")
    }

    /// Not implemented; always panics.
    fn q_slice(_tensor: QuantizedTensor<Self>, _slices: &[Slice]) -> QuantizedTensor<Self> {
        unimplemented!("q_slice is not implemented for the Candle backend")
    }

    /// Not implemented; always panics.
    fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
        unimplemented!("q_expand is not implemented for the Candle backend")
    }
}

View File

@@ -0,0 +1,612 @@
use std::borrow::Borrow;
use burn_backend::{
DType, Distribution, ElementConversion, ExecutionError, FloatDType, Scalar, Shape, Slice,
TensorData, bf16, f16,
ops::FloatTensorOps,
tensor::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor},
};
use candle_core::{Tensor, backend::BackendStorage, shape};
use crate::{
Candle, CandleDevice, CandleTensor, IntoDType,
element::{CandleElement, FloatCandleElement, IntCandleElement},
};
use super::base::{cpu_random, expand, permute, sign, unfold};
impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle<F, I> {
/// Creates a float tensor from `data`, dispatching on the data dtype.
///
/// Only `F64`, `F32`, `F16` and `BF16` are handled; any other dtype panics.
fn float_from_data(data: TensorData, device: &Device<Self>) -> CandleTensor {
match data.dtype {
DType::F64 => super::base::from_data::<f64>(data, device),
DType::F32 => super::base::from_data::<f32>(data, device),
DType::F16 => super::base::from_data::<f16>(data, device),
DType::BF16 => super::base::from_data::<bf16>(data, device),
_ => unimplemented!("Unsupported dtype for `float_from_data`"),
}
}
/// Samples a float tensor of the given shape from `distribution`.
///
/// On the CPU device the samples come from `cpu_random`; on any other device
/// they are drawn by Candle's own random generators.
fn float_random(
    shape: Shape,
    distribution: Distribution,
    device: &Device<Self>,
) -> FloatTensor<Self> {
    if matches!(device, CandleDevice::Cpu) {
        // Use our own seed since candle doesn't support it on CPU
        let data = cpu_random::<F>(shape, distribution);
        return Self::float_from_data(data, device);
    }

    let dims = shape.to_vec();
    let candle_device = &(device.clone()).into();
    let sampled = match distribution {
        Distribution::Default => {
            candle_core::Tensor::rand(0.elem::<F>(), 1.elem::<F>(), dims, candle_device)
                .unwrap()
                .to_dtype(F::DTYPE)
                .unwrap()
        }
        Distribution::Bernoulli(prob) => {
            let uniform = candle_core::Tensor::rand(
                0.elem::<F>(),
                1.elem::<F>(),
                dims.clone(),
                candle_device,
            )
            .unwrap()
            .to_dtype(F::DTYPE)
            .unwrap();
            // A sample is a success when the uniform draw falls below `prob`.
            let threshold = super::candle_utils::fill(prob, dims, F::DTYPE, candle_device);
            uniform
                .lt(&threshold)
                .unwrap()
                .to_dtype(F::DTYPE)
                .unwrap()
        }
        Distribution::Uniform(from, to) => {
            candle_core::Tensor::rand(from.elem::<F>(), to.elem::<F>(), dims, candle_device)
                .unwrap()
        }
        Distribution::Normal(mean, std) => {
            candle_core::Tensor::randn(mean.elem::<F>(), std.elem::<F>(), dims, candle_device)
                .unwrap()
        }
    };
    CandleTensor::new(sampled)
}
/// Reads the tensor content back as `TensorData`.
///
/// Delegates to `super::base::into_data`.
async fn float_into_data(tensor: CandleTensor) -> Result<TensorData, ExecutionError> {
super::base::into_data(tensor)
}
/// Returns the device of a float tensor.
///
/// Delegates to `super::base::device`.
fn float_device(tensor: &CandleTensor) -> Device<Self> {
super::base::device(tensor)
}
/// Moves a float tensor to `device`.
///
/// Delegates to `super::base::to_device`.
fn float_to_device(tensor: CandleTensor, device: &Device<Self>) -> CandleTensor {
super::base::to_device(tensor, device)
}
/// Converts a float tensor into an int tensor by casting the underlying
/// Candle tensor to the backend int dtype `I::DTYPE`.
///
/// Panics if Candle fails to perform the dtype conversion.
fn float_into_int(tensor: CandleTensor) -> IntTensor<Self> {
    let as_int = tensor.tensor.to_dtype(I::DTYPE).unwrap();
    CandleTensor::new(as_int)
}
/// Creates a float tensor of the given shape and dtype on `device`.
///
/// Delegates to `super::base::empty` after converting `dtype` with `into_dtype`.
fn float_empty(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
super::base::empty(shape, device, dtype.into_dtype())
}
/// Element-wise addition with broadcasting.
///
/// Panics if the shapes cannot be broadcast together.
fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
    let sum = lhs.tensor.broadcast_add(&rhs.tensor).unwrap();
    CandleTensor::new(sum)
}
/// Adds the scalar `rhs` (converted to `f64`) to every element of `lhs`.
fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
    let sum = lhs.tensor + rhs.elem::<f64>();
    CandleTensor::new(sum.unwrap())
}
/// Element-wise subtraction with broadcasting.
///
/// Panics if the shapes cannot be broadcast together.
fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
    let difference = lhs.tensor.broadcast_sub(&rhs.tensor).unwrap();
    CandleTensor::new(difference)
}
/// Subtracts the scalar `rhs` (converted to `f64`) from every element of `lhs`.
fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
    let difference = lhs.tensor - rhs.elem::<f64>();
    CandleTensor::new(difference.unwrap())
}
/// Element-wise multiplication with broadcasting.
///
/// Panics if the shapes cannot be broadcast together.
fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
    let product = lhs.tensor.broadcast_mul(&rhs.tensor).unwrap();
    CandleTensor::new(product)
}
/// Multiplies every element of `lhs` by the scalar `rhs` (converted to `f64`).
fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
    let product = lhs.tensor * rhs.elem::<f64>();
    CandleTensor::new(product.unwrap())
}
/// Element-wise division with broadcasting.
///
/// Panics if the shapes cannot be broadcast together.
fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
    let quotient = lhs.tensor.broadcast_div(&rhs.tensor).unwrap();
    CandleTensor::new(quotient)
}
/// Divides every element of `lhs` by the scalar `rhs` (converted to `f64`).
fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
    let quotient = lhs.tensor / rhs.elem::<f64>();
    CandleTensor::new(quotient.unwrap())
}
/// Element-wise floored remainder: `lhs - floor(lhs / rhs) * rhs`.
///
/// Every step uses a broadcasting operation, so `lhs` and `rhs` only need
/// broadcast-compatible shapes. Panics if the shapes cannot be broadcast
/// together.
fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
    let floored_quotient = lhs
        .tensor
        .broadcast_div(&rhs.tensor)
        .unwrap()
        .floor()
        .unwrap();
    let whole_part = floored_quotient.broadcast_mul(&rhs.tensor).unwrap();
    // The subtraction must broadcast as well: `whole_part` has the broadcast
    // shape, which differs from `lhs` when `rhs` is the larger operand.
    CandleTensor::new(lhs.tensor.broadcast_sub(&whole_part).unwrap())
}
/// Floored remainder of every element of `lhs` divided by the scalar `rhs`.
fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
    // In PyTorch, remainder can also be defined as
    // torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b
    let divisor = rhs.elem::<f64>();
    let floored_quotient = (lhs.tensor.clone() / divisor)
        .unwrap()
        .floor()
        .unwrap();
    let whole_part = (floored_quotient * divisor).unwrap();
    CandleTensor::new((lhs.tensor - whole_part).unwrap())
}
fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
let lhs_contiguous = if !lhs.tensor.is_contiguous() {
lhs.tensor.contiguous().unwrap()
} else {
lhs.tensor
};
let rhs_contiguous = if !rhs.tensor.is_contiguous() {
rhs.tensor.contiguous().unwrap()
} else {
rhs.tensor
};
CandleTensor::new(lhs_contiguous.broadcast_matmul(&rhs_contiguous).unwrap())
}
/// Float entry point for the cross operation along `dim`.
///
/// Forwards all arguments unchanged to `super::base::cross`.
fn float_cross(
lhs: FloatTensor<Self>,
rhs: FloatTensor<Self>,
dim: usize,
) -> FloatTensor<Self> {
super::base::cross(lhs, rhs, dim)
}
/// Float entry point for swapping dimensions `dim1` and `dim2`.
///
/// Forwards all arguments unchanged to `super::base::swap_dims`.
fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
super::base::swap_dims(tensor, dim1, dim2)
}
/// Float entry point for reshaping `tensor` to `shape`.
///
/// Forwards all arguments unchanged to `super::base::reshape`.
fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
super::base::reshape(tensor, shape)
}
/// Gathers values from `tensor` along `dim` at the positions given by `indices`.
///
/// Both the source and the index tensor are made contiguous before calling
/// candle's `gather`. Panics if candle returns an error.
fn float_gather(
    dim: usize,
    tensor: FloatTensor<Self>,
    indices: IntTensor<Self>,
) -> FloatTensor<Self> {
    let source = tensor.tensor.contiguous().unwrap();
    let index = indices.tensor.contiguous().unwrap();
    let gathered = source.gather(&index, dim).unwrap();
    CandleTensor::new(gathered)
}
/// Scatter-adds `value` into `tensor` along `dim` at `indices` via candle's
/// `scatter_add`.
///
/// Panics if candle returns an error.
fn float_scatter_add(
    dim: usize,
    tensor: FloatTensor<Self>,
    indices: IntTensor<Self>,
    value: FloatTensor<Self>,
) -> FloatTensor<Self> {
    let target = &tensor.tensor;
    let scattered = target
        .scatter_add(&indices.tensor, &value.tensor, dim)
        .unwrap();
    CandleTensor::new(scattered)
}
/// Selects entries of `tensor` along `dim` at `indices` via candle's
/// `index_select`.
///
/// Panics if candle returns an error.
fn float_select(
    tensor: FloatTensor<Self>,
    dim: usize,
    indices: IntTensor<Self>,
) -> FloatTensor<Self> {
    let selected = tensor.tensor.index_select(&indices.tensor, dim).unwrap();
    CandleTensor::new(selected)
}
/// Adds `value` into `tensor` along `dim` at `indices` via candle's `index_add`.
///
/// Panics if candle returns an error.
fn float_select_add(
    tensor: FloatTensor<Self>,
    dim: usize,
    indices: IntTensor<Self>,
    value: FloatTensor<Self>,
) -> FloatTensor<Self> {
    let target = &tensor.tensor;
    let accumulated = target
        .index_add(&indices.tensor, &value.tensor, dim)
        .unwrap();
    CandleTensor::new(accumulated)
}
/// Float entry point for slicing `tensor` with `slices`.
///
/// Forwards all arguments unchanged to `super::base::slice_with_steps`.
fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {
super::base::slice_with_steps(tensor, slices)
}
/// Float entry point for assigning `value` into the region of `tensor` described by `slices`.
///
/// Forwards all arguments unchanged to `super::base::slice_assign`.
fn float_slice_assign(
tensor: FloatTensor<Self>,
slices: &[Slice],
value: FloatTensor<Self>,
) -> FloatTensor<Self> {
super::base::slice_assign(tensor, slices, value)
}
/// Float entry point for the masked select between `tensor` and `value`.
///
/// Forwards all arguments unchanged to `super::base::mask_where_broadcasted`.
fn float_mask_where(
tensor: FloatTensor<Self>,
mask: BoolTensor<Self>,
value: FloatTensor<Self>,
) -> FloatTensor<Self> {
super::base::mask_where_broadcasted(tensor, mask, value)
}
/// Masked fill with a scalar.
///
/// Builds a tensor from `value` with `candle_utils::fill_like` (using `tensor`
/// as the template) and then reuses `super::base::mask_where_broadcasted`.
fn float_mask_fill(
    tensor: FloatTensor<Self>,
    mask: BoolTensor<Self>,
    value: Scalar,
) -> FloatTensor<Self> {
    let filled = super::candle_utils::fill_like::<F>(value.elem(), &tensor.tensor);
    let replacement = CandleTensor::new(filled);
    super::base::mask_where_broadcasted(tensor, mask, replacement)
}
/// Element-wise `lhs == rhs`.
///
/// Operands are first aligned with `candle_utils::broadcast_for_comparison`.
/// Panics if broadcasting or the comparison fails.
fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
    let (left, right) =
        super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
    let mask = left.eq(&right).unwrap();
    CandleTensor::new(mask)
}
/// Element-wise `lhs == rhs` against a scalar converted to the float element type `F`.
///
/// Panics if candle returns an error.
fn float_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
    let target = rhs.elem::<F>();
    let mask = lhs.tensor.eq(target).unwrap();
    CandleTensor::new(mask)
}
/// Element-wise `lhs > rhs`.
///
/// Operands are first aligned with `candle_utils::broadcast_for_comparison`.
/// Panics if broadcasting or the comparison fails.
fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
    let (left, right) =
        super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
    let mask = left.gt(&right).unwrap();
    CandleTensor::new(mask)
}
/// Element-wise `lhs > rhs` against a scalar.
///
/// The scalar is expanded with `candle_utils::fill_like` (using `lhs` as the
/// template) before comparing. Panics if candle returns an error.
fn float_greater_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
    let threshold = super::candle_utils::fill_like::<F>(rhs.elem(), &lhs.tensor);
    let mask = lhs.tensor.gt(&threshold).unwrap();
    CandleTensor::new(mask)
}
/// Element-wise `lhs >= rhs`.
///
/// Operands are first aligned with `candle_utils::broadcast_for_comparison`.
/// Panics if broadcasting or the comparison fails.
fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
    let (left, right) =
        super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
    let mask = left.ge(&right).unwrap();
    CandleTensor::new(mask)
}
/// Element-wise `lhs >= rhs` against a scalar.
///
/// The scalar is expanded with `candle_utils::fill_like` (using `lhs` as the
/// template) before comparing. Panics if candle returns an error.
fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
    let threshold = super::candle_utils::fill_like::<F>(rhs.elem(), &lhs.tensor);
    let mask = lhs.tensor.ge(&threshold).unwrap();
    CandleTensor::new(mask)
}
/// Element-wise `lhs < rhs`.
///
/// Operands are first aligned with `candle_utils::broadcast_for_comparison`.
/// Panics if broadcasting or the comparison fails.
fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
    let (left, right) =
        super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
    let mask = left.lt(&right).unwrap();
    CandleTensor::new(mask)
}
/// Element-wise `lhs < rhs` against a scalar.
///
/// The scalar is expanded with `candle_utils::fill_like` (using `lhs` as the
/// template) before comparing. Panics if candle returns an error.
fn float_lower_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
    let threshold = super::candle_utils::fill_like::<F>(rhs.elem(), &lhs.tensor);
    let mask = lhs.tensor.lt(&threshold).unwrap();
    CandleTensor::new(mask)
}
/// Element-wise `lhs <= rhs`.
///
/// Operands are first aligned with `candle_utils::broadcast_for_comparison`.
/// Panics if broadcasting or the comparison fails.
fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
    let (left, right) =
        super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
    let mask = left.le(&right).unwrap();
    CandleTensor::new(mask)
}
/// Element-wise `lhs <= rhs` against a scalar.
///
/// The scalar is expanded with `candle_utils::fill_like` (using `lhs` as the
/// template) before comparing. Panics if candle returns an error.
fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
    let threshold = super::candle_utils::fill_like::<F>(rhs.elem(), &lhs.tensor);
    let mask = lhs.tensor.le(&threshold).unwrap();
    CandleTensor::new(mask)
}
/// Sums every element of `tensor` into a new tensor of shape `[1]`.
///
/// The total is read back as a scalar of type `F` and re-uploaded to the
/// device reported by `Self::float_device`. Panics if candle returns an error.
fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
    let reduced = tensor.tensor.sum_all().unwrap();
    let total = reduced.to_scalar::<F>().unwrap();
    let data = TensorData::new([total].into(), [1]);
    let device = Self::float_device(&tensor);
    CandleTensor::from_data::<F>(data, device)
}
/// Sums `tensor` along `dim` via candle's `sum_keepdim`.
///
/// Panics if candle returns an error.
fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
    let summed = tensor.tensor.sum_keepdim(dim).unwrap();
    CandleTensor::new(summed)
}
/// Averages `tensor` along `dim` via candle's `mean_keepdim`.
///
/// Panics if candle returns an error.
fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
    let averaged = tensor.tensor.mean_keepdim(dim).unwrap();
    CandleTensor::new(averaged)
}
/// Cumulative sum of `tensor` along `dim` via candle's `cumsum`.
///
/// Panics if candle returns an error.
fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
    let running_sum = tensor.tensor.cumsum(dim).unwrap();
    CandleTensor::new(running_sum)
}
/// Cumulative product of `tensor` along `dim`.
///
/// Built on `super::utils::cumulative_with_op`, combining steps with
/// `broadcast_mul`.
fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
    let source = &tensor.tensor;
    let running = super::utils::cumulative_with_op(source, dim, |acc, next| {
        acc.broadcast_mul(next)
    });
    CandleTensor::new(running)
}
/// Cumulative minimum of `tensor` along `dim`.
///
/// Built on `super::utils::cumulative_with_op`, combining steps with
/// `broadcast_minimum`.
fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
    let source = &tensor.tensor;
    let running = super::utils::cumulative_with_op(source, dim, |acc, next| {
        acc.broadcast_minimum(next)
    });
    CandleTensor::new(running)
}
/// Cumulative maximum of `tensor` along `dim`.
///
/// Built on `super::utils::cumulative_with_op`, combining steps with
/// `broadcast_maximum`.
fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
    let source = &tensor.tensor;
    let running = super::utils::cumulative_with_op(source, dim, |acc, next| {
        acc.broadcast_maximum(next)
    });
    CandleTensor::new(running)
}
/// Element-wise exponential via candle's `exp`. Panics if candle returns an error.
fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
    let exponentiated = tensor.tensor.exp().unwrap();
    CandleTensor::new(exponentiated)
}
/// Element-wise natural logarithm via candle's `log`. Panics if candle returns an error.
fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
    let logged = tensor.tensor.log().unwrap();
    CandleTensor::new(logged)
}
/// Element-wise `ln(1 + x)`, computed literally as `log(x + 1)`.
///
/// Panics if candle returns an error.
fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
    let shifted = (tensor.tensor + 1.).unwrap();
    let logged = shifted.log().unwrap();
    CandleTensor::new(logged)
}
/// Raises every element of `tensor` to the scalar power `value` (as `f64`)
/// via candle's `powf`. Panics if candle returns an error.
fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
    let exponent = value.elem::<f64>();
    let powered = tensor.tensor.powf(exponent).unwrap();
    CandleTensor::new(powered)
}
/// Element-wise square root via candle's `sqrt`. Panics if candle returns an error.
fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
    let rooted = tensor.tensor.sqrt().unwrap();
    CandleTensor::new(rooted)
}
/// Element-wise absolute value via candle's `abs`. Panics if candle returns an error.
fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
    let magnitude = tensor.tensor.abs().unwrap();
    CandleTensor::new(magnitude)
}
/// Element-wise cosine via candle's `cos`. Panics if candle returns an error.
fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
    let cosine = tensor.tensor.cos().unwrap();
    CandleTensor::new(cosine)
}
/// Element-wise hyperbolic cosine, using `cosh(x) = (e^x + e^(-x)) / 2`.
///
/// `e^(-x)` is obtained as the reciprocal of `e^x`. Panics if candle returns an error.
fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
    let exp_pos = tensor.tensor.exp().unwrap();
    let exp_neg = exp_pos.recip().unwrap();
    let sum = (exp_pos.clone() + exp_neg).unwrap();
    let halved = (sum / 2.0).unwrap();
    CandleTensor::new(halved)
}
/// Element-wise sine via candle's `sin`. Panics if candle returns an error.
fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
    let sine = tensor.tensor.sin().unwrap();
    CandleTensor::new(sine)
}
/// Element-wise hyperbolic sine, using `sinh(x) = (e^x - e^(-x)) / 2`.
///
/// `e^(-x)` is obtained as the reciprocal of `e^x`. Panics if candle returns an error.
fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
    let exp_pos = tensor.tensor.exp().unwrap();
    let exp_neg = exp_pos.recip().unwrap();
    let difference = (exp_pos.clone() - exp_neg).unwrap();
    let halved = (difference / 2.0).unwrap();
    CandleTensor::new(halved)
}
/// Element-wise tangent, computed as `sin(x) / cos(x)`.
///
/// Panics if candle returns an error.
fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
    let sine = tensor.tensor.sin().unwrap();
    let cosine = tensor.tensor.cos().unwrap();
    let tangent = (sine / cosine).unwrap();
    CandleTensor::new(tangent)
}
/// Element-wise hyperbolic tangent via candle's `tanh`. Panics if candle returns an error.
fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
    let squashed = tensor.tensor.tanh().unwrap();
    CandleTensor::new(squashed)
}
/// Element-wise arccosine, using `acos(x) = PI/2 - asin(x)`.
///
/// Implemented as `(-asin(x)) + PI/2` on top of `float_asin`, `float_neg` and
/// `float_add_scalar`.
fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
    let asin_x = Self::float_asin(tensor);
    let negated = Self::float_neg(asin_x);
    let half_pi = core::f64::consts::FRAC_PI_2;
    Self::float_add_scalar(negated, half_pi.into())
}
/// Element-wise inverse hyperbolic cosine, using
/// `acosh(x) = ln(x + sqrt(x^2 - 1))`.
fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
    let squared = Self::float_powi_scalar(tensor.clone(), 2.into());
    let radicand = Self::float_sub_scalar(squared, 1f64.into());
    let root = Self::float_sqrt(radicand);
    let log_arg = Self::float_add(tensor, root);
    Self::float_log(log_arg)
}
/// Element-wise arcsine, using `asin(x) = atan(x / sqrt(1 - x^2))`.
///
/// `1 - x^2` is formed as `(-x^2) + 1`. This relies on `float_atan` being
/// implemented without calling back into `float_asin`.
fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
    let squared = Self::float_powi_scalar(tensor.clone(), 2.into());
    let negated = Self::float_neg(squared);
    let radicand = Self::float_add_scalar(negated, 1f64.into());
    let root = Self::float_sqrt(radicand);
    let ratio = Self::float_div(tensor, root);
    Self::float_atan(ratio)
}
/// Element-wise inverse hyperbolic sine, using
/// `asinh(x) = ln(x + sqrt(x^2 + 1))`.
fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
    let squared = Self::float_powi_scalar(tensor.clone(), 2.into());
    let radicand = Self::float_add_scalar(squared, 1f64.into());
    let root = Self::float_sqrt(radicand);
    let log_arg = Self::float_add(tensor, root);
    Self::float_log(log_arg)
}
fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
// atan(x) = asin(x / sqrt(1 + x^2))
let x_squared = Self::float_powi_scalar(tensor.clone(), 2.into());
let one_plus_x_sq = Self::float_add_scalar(x_squared, 1f64.into());
let sqrt_term = Self::float_sqrt(one_plus_x_sq);
Self::float_asin(Self::float_div(tensor, sqrt_term))
}
/// Element-wise inverse hyperbolic tangent, using
/// `atanh(x) = ln((1 + x) / (1 - x)) / 2`.
///
/// Panics if candle returns an error.
fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
    let one_plus_x = (1.0 + tensor.tensor.clone()).unwrap();
    let one_minus_x = (1.0 - tensor.tensor).unwrap();
    let ratio = (one_plus_x / one_minus_x).unwrap();
    let logged = ratio.log().unwrap();
    CandleTensor::new((logged / 2.0).unwrap())
}
/// Element-wise two-argument arctangent with `lhs` as `y` and `rhs` as `x`,
/// using `atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))`.
///
/// NOTE(review): when `y == 0` and `x < 0` the denominator `sqrt(x^2 + y^2) + x`
/// is zero, so the ratio is `0 / 0` for those elements — confirm whether callers
/// need that case handled.
fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
    let x_sq = Self::float_powi_scalar(rhs.clone(), 2.into());
    let y_sq = Self::float_powi_scalar(lhs.clone(), 2.into());
    let radius = Self::float_sqrt(Self::float_add(x_sq, y_sq));
    let denominator = Self::float_add(radius, rhs);
    let half_angle = Self::float_atan(Self::float_div(lhs, denominator));
    Self::float_mul_scalar(half_angle, 2f64.into())
}
/// Rounds every element to the nearest integer, with ties going to the even
/// neighbour.
///
/// Implements round-half-to-even for consistent behavior with libtorch:
/// https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/runtime/register_ops_utils.h#L65-L67
///
/// Elements whose fractional part (`x - floor(x)`) equals exactly 0.5 take
/// `round(x * 0.5) * 2`; all others take candle's plain `round(x)`.
/// Panics if any candle operation returns an error.
fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
    let round_half_even = |input: &candle_core::Tensor| -> candle_core::Result<candle_core::Tensor> {
        let fractional = input.sub(&input.floor()?)?;
        let half = (candle_core::Tensor::ones_like(input)? * 0.5)?;
        let is_tie = fractional.eq(&half)?;
        let halved_rounded = input.mul(&half)?.round()?;
        let two = (candle_core::Tensor::ones_like(input)? * 2.0)?;
        let nearest_even = halved_rounded.mul(&two)?;
        let nearest = input.round()?;
        is_tie.where_cond(&nearest_even, &nearest)
    };
    CandleTensor::new(round_half_even(&tensor.tensor).unwrap())
}
/// Element-wise floor via candle's `floor`. Panics if candle returns an error.
fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
    let floored = tensor.tensor.floor().unwrap();
    CandleTensor::new(floored)
}
/// Element-wise ceiling via candle's `ceil`. Panics if candle returns an error.
fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
    let ceiled = tensor.tensor.ceil().unwrap();
    CandleTensor::new(ceiled)
}
/// Truncates every element toward zero.
///
/// `trunc(x) = floor(x)` when `x >= 0` and `ceil(x)` when `x < 0`; the choice
/// is made per element with a `x < 0` mask. Panics if candle returns an error.
fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
    let input = &tensor.tensor;
    let below_zero = input.lt(0.0).unwrap();
    let rounded_down = input.floor().unwrap();
    let rounded_up = input.ceil().unwrap();
    // Negative elements take the ceiling, everything else the floor.
    let truncated = below_zero.where_cond(&rounded_up, &rounded_down).unwrap();
    CandleTensor::new(truncated)
}
/// Element-wise error function via candle's `erf`. Panics if candle returns an error.
fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
    let result = tensor.tensor.erf().unwrap();
    CandleTensor::new(result)
}
/// Float entry point for concatenating `tensors` along `dim`.
///
/// Forwards all arguments unchanged to `super::base::cat`.
fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
super::base::cat(tensors, dim)
}
/// Index of the maximum along `dim` via candle's `argmax_keepdim`.
///
/// The index tensor is converted to the backend's integer dtype `I::DTYPE`.
/// Panics if candle returns an error.
fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
    let positions = tensor.tensor.argmax_keepdim(dim).unwrap();
    let as_int = positions.to_dtype(I::DTYPE).unwrap();
    CandleTensor::new(as_int)
}
/// Index of the minimum along `dim` via candle's `argmin_keepdim`.
///
/// The index tensor is converted to the backend's integer dtype `I::DTYPE`.
/// Panics if candle returns an error.
fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
    let positions = tensor.tensor.argmin_keepdim(dim).unwrap();
    let as_int = positions.to_dtype(I::DTYPE).unwrap();
    CandleTensor::new(as_int)
}
/// Caps every element at `max` by taking the element-wise minimum with it.
///
/// Panics if candle returns an error.
fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {
    let upper = max.elem::<F>();
    let capped = tensor.tensor.minimum(upper).unwrap();
    CandleTensor::new(capped)
}
/// Raises every element to at least `min` by taking the element-wise maximum with it.
///
/// Panics if candle returns an error.
fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {
    let lower = min.elem::<F>();
    let floored = tensor.tensor.maximum(lower).unwrap();
    CandleTensor::new(floored)
}
/// Clamps every element into `[min, max]` via candle's `clamp`.
///
/// Both bounds are converted to the float element type `F`.
/// Panics if candle returns an error.
fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
    let lower = min.elem::<F>();
    let upper = max.elem::<F>();
    let clamped = tensor.tensor.clamp(lower, upper).unwrap();
    CandleTensor::new(clamped)
}
/// Element-wise reciprocal via candle's `recip`. Panics if candle returns an error.
fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
    let inverted = tensor.tensor.recip().unwrap();
    CandleTensor::new(inverted)
}
/// Element-wise `lhs ^ rhs`, computed as `exp(rhs * ln(lhs))`.
///
/// Panics if candle returns an error.
fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
    // When this was written, candle's `broadcast_pow` was on main but not yet
    // published; consider switching to it once a release (0.3.3 was expected)
    // provides it.
    // See: https://github.com/huggingface/candle/pull/1583/files#diff-6319fa1e16dadc4c7b4e25698139703d93b70f30a1f8e2ac0999978e39efaa81R2594
    let log_base = lhs.tensor.log().unwrap();
    let scaled = rhs.tensor.broadcast_mul(&log_base).unwrap();
    let powered = scaled.exp().unwrap();
    CandleTensor::new(powered)
}
/// Float entry point for permuting the dimensions of `tensor` according to `axes`.
///
/// Forwards all arguments unchanged to `super::base::permute`.
fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
super::base::permute(tensor, axes)
}
/// Float entry point for flipping `tensor` along `axes`.
///
/// Forwards all arguments unchanged to `super::base::flip`.
fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
super::base::flip(tensor, axes)
}
/// Float entry point for expanding `tensor` to `shape`.
///
/// Forwards all arguments unchanged to the `expand` helper in scope.
fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
expand(tensor, shape)
}
/// Float entry point for the unfold operation along `dim` with the given `size` and `step`.
///
/// Forwards all arguments unchanged to the `unfold` helper in scope.
fn float_unfold(
tensor: FloatTensor<Self>,
dim: usize,
size: usize,
step: usize,
) -> FloatTensor<Self> {
unfold(tensor, dim, size, step)
}
/// Float entry point for the sign operation.
///
/// Forwards `tensor` unchanged to the `sign` helper in scope.
fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
sign(tensor)
}
/// Casts `tensor` to the requested float dtype.
///
/// Returns the input untouched when it already has the target dtype; otherwise
/// converts with candle's `to_dtype`. Panics if the dtype has no candle
/// equivalent (`into_dtype` unwraps) or if the conversion fails.
fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
    let target = dtype.into_dtype();
    if tensor.tensor.dtype() != target {
        return CandleTensor::new(tensor.tensor.to_dtype(target).unwrap());
    }
    tensor
}
}

View File

@@ -0,0 +1,11 @@
use burn_backend::{
Backend,
ops::{TransactionOps, TransactionPrimitive},
};
use crate::{
Candle,
element::{FloatCandleElement, IntCandleElement},
};
/// Transaction support for the Candle backend.
///
/// The impl body is empty, so every `TransactionOps` method uses the trait's
/// default implementation.
impl<F, I> TransactionOps<Self> for Candle<F, I>
where
    F: FloatCandleElement,
    I: IntCandleElement,
{
}

View File

@@ -0,0 +1,29 @@
/// Helper function for cumulative operations in the Candle backend.
///
/// This function reduces code duplication for cumulative operations (cumprod, cummin, cummax),
/// which all follow the same pattern of slicing, applying an operation, and concatenating.
///
/// Output slice `0` along `dim` is the input's first slice; output slice `i` is
/// `op(output[i - 1], input[i])`.
///
/// # Arguments
///
/// * `tensor` - The input tensor
/// * `dim` - The dimension along which to apply the cumulative operation
/// * `op` - A closure that takes two tensor references (the previous accumulated slice and the
///   current input slice) and produces a result tensor
///
/// # Returns
///
/// The concatenation of all accumulated slices along `dim`. If `dim` has length zero the
/// input is returned unchanged, since there is nothing to accumulate.
///
/// # Panics
///
/// Panics if `dim` is out of range for the tensor, or if `narrow`, `op`, or `cat` returns an
/// error.
pub fn cumulative_with_op<F>(tensor: &candle_core::Tensor, dim: usize, op: F) -> candle_core::Tensor
where
    F: Fn(&candle_core::Tensor, &candle_core::Tensor) -> candle_core::Result<candle_core::Tensor>,
{
    let dim_size = tensor.dims()[dim];
    // With an empty dimension `narrow(dim, 0, 1)` would be out of range; the cumulative
    // result of an empty tensor is the empty tensor itself.
    if dim_size == 0 {
        return tensor.clone();
    }
    let mut slices = Vec::with_capacity(dim_size);
    // First slice is the initial value
    slices.push(tensor.narrow(dim, 0, 1).unwrap());
    // Apply cumulative operation: each step combines the previous result with the next slice
    for i in 1..dim_size {
        let curr = tensor.narrow(dim, i, 1).unwrap();
        let result = op(&slices[i - 1], &curr).unwrap();
        slices.push(result);
    }
    candle_core::Tensor::cat(&slices, dim).unwrap()
}

View File

@@ -0,0 +1,114 @@
use burn_backend::{DType, FloatDType, IntDType, Shape, quantization::QuantScheme};
use burn_backend::{Element, QTensorPrimitive, TensorData, TensorMetadata};
use crate::{CandleDevice, element::CandleElement};
/// A tensor that uses the candle backend.
///
/// Wraps a single `candle_core::Tensor`.
#[derive(Debug, Clone)]
pub struct CandleTensor {
/// The wrapped candle tensor.
pub(crate) tensor: candle_core::Tensor,
}
impl TensorMetadata for CandleTensor {
fn dtype(&self) -> DType {
match self.tensor.dtype() {
candle_core::DType::U8 => DType::U8,
candle_core::DType::U32 => DType::U32,
candle_core::DType::I64 => DType::I64,
candle_core::DType::BF16 => DType::BF16,
candle_core::DType::F16 => DType::F16,
candle_core::DType::F32 => DType::F32,
candle_core::DType::F64 => DType::F64,
candle_core::DType::I16 => DType::I16,
candle_core::DType::I32 => DType::I32,
other => todo!("{other:?} not yet supported"),
}
}
fn shape(&self) -> Shape {
Shape::from(self.tensor.dims().to_vec())
}
fn rank(&self) -> usize {
self.tensor.dims().len()
}
}
/// Placeholder quantized-tensor impl for `CandleTensor`.
impl QTensorPrimitive for CandleTensor {
/// Always panics via `unimplemented!`: quantization is not supported here.
fn scheme(&self) -> &QuantScheme {
unimplemented!("Quantization is not supported")
}
}
impl CandleTensor {
/// Create a new tensor.
pub fn new(tensor: candle_core::Tensor) -> Self {
Self { tensor }
}
/// Creates a new tensor from data and a device.
///
/// # Arguments
///
/// * `data` - The tensor's data.
/// * `device` - The device on which the tensor will be allocated.
///
/// # Returns
///
/// A new tensor.
pub fn from_data<E: CandleElement>(data: TensorData, device: CandleDevice) -> Self {
let candle_shape: candle_core::Shape = data.shape.clone().into();
let tensor = candle_core::Tensor::from_slice(
data.as_slice::<E>().unwrap(),
candle_shape,
&device.into(),
);
Self::new(tensor.unwrap())
}
}
/// Conversion from a burn dtype description into a `candle_core::DType`.
pub(crate) trait IntoDType {
/// Fallible conversion; returns a `candle_core::Error` when there is no candle equivalent.
fn try_into_dtype(self) -> Result<candle_core::DType, candle_core::Error>;
/// Infallible convenience wrapper around `try_into_dtype`.
///
/// Panics (via `unwrap`) when `try_into_dtype` returns an error.
fn into_dtype(self) -> candle_core::DType
where
Self: Sized,
{
self.try_into_dtype().unwrap()
}
}
impl IntoDType for IntDType {
    /// Converts to the general `DType` first, then reuses its conversion.
    fn try_into_dtype(self) -> Result<candle_core::DType, candle_core::Error> {
        Into::<DType>::into(self).try_into_dtype()
    }
}
impl IntoDType for FloatDType {
    /// Converts to the general `DType` first, then reuses its conversion.
    fn try_into_dtype(self) -> Result<candle_core::DType, candle_core::Error> {
        Into::<DType>::into(self).try_into_dtype()
    }
}
impl IntoDType for DType {
fn try_into_dtype(self) -> Result<candle_core::DType, candle_core::Error> {
match self {
DType::F64 => Ok(candle_core::DType::F64),
DType::F32 => Ok(candle_core::DType::F32),
DType::Flex32 => Ok(candle_core::DType::F32),
DType::F16 => Ok(candle_core::DType::F16),
DType::BF16 => Ok(candle_core::DType::BF16),
DType::I64 => Ok(candle_core::DType::I64),
DType::U32 => Ok(candle_core::DType::U32),
DType::U8 => Ok(candle_core::DType::U8),
DType::I16 => Ok(candle_core::DType::I16),
DType::I32 => Ok(candle_core::DType::I32),
// DType::Bool => Ok(candle_core::DType::U8),
_ => Err(candle_core::Error::Msg(format!(
"Unsupported dtype {self:?}"
))),
}
}
}