feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
This commit is contained in:
@@ -0,0 +1,44 @@
|
||||
[package]
|
||||
authors = ["louisfd <louisfd94@gmail.com>"]
|
||||
categories = ["science"]
|
||||
description = "[Deprecated] Candle backend for the Burn framework - use burn-cubecl, burn-ndarray, or burn-tch instead"
|
||||
edition.workspace = true
|
||||
keywords = ["deep-learning", "machine-learning", "data"]
|
||||
license.workspace = true
|
||||
name = "burn-candle"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-candle"
|
||||
documentation = "https://docs.rs/burn-candle"
|
||||
version.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[features]
|
||||
default = ["std"]
|
||||
std = []
|
||||
doc = ["default"]
|
||||
tracing = [
|
||||
"burn-backend/tracing",
|
||||
"burn-std/tracing",
|
||||
]
|
||||
|
||||
cuda = ["candle-core/cuda"]
|
||||
metal = ["candle-core/metal"]
|
||||
accelerate = ["candle-core/accelerate"]
|
||||
|
||||
[dependencies]
|
||||
burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false }
|
||||
# For rand utils and stub mutex
|
||||
burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = false }
|
||||
|
||||
candle-core = { workspace = true }
|
||||
derive-new = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", default-features = false, features = [
|
||||
] }
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["doc"]
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
@@ -0,0 +1 @@
|
||||
../../LICENSE-APACHE
|
||||
1
crates/stable-diffusion-burn/burn-crates/burn-candle/LICENSE-MIT
Symbolic link
1
crates/stable-diffusion-burn/burn-crates/burn-candle/LICENSE-MIT
Symbolic link
@@ -0,0 +1 @@
|
||||
../../LICENSE-MIT
|
||||
@@ -0,0 +1,14 @@
|
||||
# Burn Candle Backend
|
||||
|
||||
> **Deprecated:** This crate is deprecated as of `0.21.0-pre.2` and will be removed in a future release.
|
||||
> Please migrate to one of the actively maintained backends:
|
||||
> - **CubeCL backends** (CUDA, ROCm, Vulkan, Metal, WebGPU) for GPU acceleration
|
||||
> - **NdArray** for portable CPU execution
|
||||
> - **LibTorch** (`burn-tch`) for a mature CPU/GPU backend
|
||||
|
||||
This crate provides a backend for [Burn](https://github.com/tracel-ai/burn) based on the [Candle](https://github.com/huggingface/candle) framework.
|
||||
|
||||
## Feature Flags
|
||||
|
||||
- `cuda` - Cuda GPU device (NVIDIA only)
|
||||
- `accelerate` - Accelerate framework (macOS only)
|
||||
@@ -0,0 +1,300 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use burn_backend::{
|
||||
BackTrace, Backend, DType, DTypeUsage, DeviceId, DeviceOps, ExecutionError, QTensorPrimitive,
|
||||
tensor::Device,
|
||||
};
|
||||
use burn_std::{
|
||||
rand::{SeedableRng, StdRng},
|
||||
stub::Mutex,
|
||||
};
|
||||
use candle_core::{DeviceLocation, backend::BackendDevice};
|
||||
|
||||
use crate::{
|
||||
CandleTensor, IntoDType,
|
||||
element::{CandleElement, FloatCandleElement, IntCandleElement},
|
||||
};
|
||||
|
||||
/// Tensor backend that uses the [candle](candle_core) crate for executing tensor operations.
|
||||
///
|
||||
/// It is compatible with a wide range of hardware configurations, including CPUs and GPUs
|
||||
/// that support CUDA or Metal. Additionally, the backend can be compiled to `wasm` when using the CPU.
|
||||
#[derive(Clone, Default, Debug)]
|
||||
pub struct Candle<F = f32, I = i64>
|
||||
where
|
||||
F: FloatCandleElement,
|
||||
I: IntCandleElement,
|
||||
{
|
||||
_float: PhantomData<F>,
|
||||
_int: PhantomData<I>,
|
||||
}
|
||||
|
||||
// Seed for CPU device
|
||||
pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
|
||||
|
||||
pub(crate) fn get_seeded_rng() -> StdRng {
|
||||
let mut seed = SEED.lock().unwrap();
|
||||
seed.take().unwrap_or_else(burn_std::rand::get_seeded_rng)
|
||||
}
|
||||
|
||||
pub(crate) fn set_seeded_rng(rng_seeded: StdRng) {
|
||||
let mut seed = SEED.lock().unwrap();
|
||||
*seed = Some(rng_seeded);
|
||||
}
|
||||
|
||||
/// The device type for the candle backend.
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
/// The device struct when using the `candle` backend.
|
||||
///
|
||||
/// To create a Cuda or Metal device from the index, use the associated methods to create the variant:
|
||||
/// ```no_run
|
||||
/// use burn_candle::CandleDevice;
|
||||
///
|
||||
/// // Create a Cuda device from its index
|
||||
/// let device = CandleDevice::cuda(0);
|
||||
/// // Create a Metal device from its index
|
||||
/// let device = CandleDevice::metal(0);
|
||||
/// ```
|
||||
#[derive(Default)]
|
||||
pub enum CandleDevice {
|
||||
/// CPU device.
|
||||
#[default]
|
||||
Cpu,
|
||||
|
||||
/// Cuda device with the given index. The index is the index of the Cuda device in the list of
|
||||
/// all Cuda devices found on the system.
|
||||
Cuda(CudaDevice),
|
||||
|
||||
/// Metal device with the given index. The index is the index of the Metal device in the list of
|
||||
/// all Metal devices found on the system.
|
||||
Metal(MetalDevice),
|
||||
}
|
||||
|
||||
impl CandleDevice {
|
||||
/// Create a Cuda device with the given index.
|
||||
/// The index is the index of the Cuda device in the list of all Cuda devices found on the system.
|
||||
pub fn cuda(index: usize) -> Self {
|
||||
CandleDevice::Cuda(CudaDevice {
|
||||
device: candle_core::CudaDevice::new(index).unwrap(),
|
||||
index,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a Metal device with the given index.
|
||||
/// The index is the index of the Metal device in the list of all Metal devices found on the system.
|
||||
pub fn metal(index: usize) -> Self {
|
||||
CandleDevice::Metal(MetalDevice {
|
||||
device: candle_core::MetalDevice::new(index).unwrap(),
|
||||
index,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn set_seed(&self, seed: u64) {
|
||||
match self {
|
||||
CandleDevice::Cpu => {
|
||||
// candle_core::cpu_backend::CpuDevice.set_seed(seed).unwrap();
|
||||
// Candle does not support seeding the CPU rng so we use a global seed
|
||||
let rng = StdRng::seed_from_u64(seed);
|
||||
set_seeded_rng(rng);
|
||||
}
|
||||
CandleDevice::Cuda(cuda_device) => cuda_device.device.set_seed(seed).unwrap(),
|
||||
CandleDevice::Metal(metal_device) => metal_device.device.set_seed(seed).unwrap(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
/// A Cuda device for the `candle` backend.
|
||||
pub struct CudaDevice {
|
||||
pub(crate) device: candle_core::CudaDevice,
|
||||
/// The index of the Cuda device in the list of all devices on the system.
|
||||
pub index: usize,
|
||||
}
|
||||
|
||||
impl PartialEq for CudaDevice {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.device.same_device(&other.device) && self.index == other.index
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for CudaDevice {}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
/// A Metal device for the `candle` backend.
|
||||
pub struct MetalDevice {
|
||||
pub(crate) device: candle_core::MetalDevice,
|
||||
/// The index of the Metal device in the list of all devices on the system.
|
||||
pub index: usize,
|
||||
}
|
||||
|
||||
impl PartialEq for MetalDevice {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.device.same_device(&other.device) && self.index == other.index
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for MetalDevice {}
|
||||
|
||||
impl From<CandleDevice> for candle_core::Device {
|
||||
fn from(device: CandleDevice) -> Self {
|
||||
match device {
|
||||
CandleDevice::Cpu => candle_core::Device::Cpu,
|
||||
CandleDevice::Cuda(device) => candle_core::Device::Cuda(device.device),
|
||||
CandleDevice::Metal(device) => candle_core::Device::Metal(device.device),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<candle_core::Device> for CandleDevice {
|
||||
fn from(device: candle_core::Device) -> Self {
|
||||
match device.location() {
|
||||
DeviceLocation::Cpu => CandleDevice::Cpu,
|
||||
DeviceLocation::Cuda { gpu_id } => {
|
||||
if let candle_core::Device::Cuda(device) = device {
|
||||
CandleDevice::Cuda(CudaDevice {
|
||||
device,
|
||||
index: gpu_id,
|
||||
})
|
||||
} else {
|
||||
panic!("Expected CUDA device.");
|
||||
}
|
||||
}
|
||||
DeviceLocation::Metal { gpu_id } => {
|
||||
if let candle_core::Device::Metal(device) = device {
|
||||
CandleDevice::Metal(MetalDevice {
|
||||
device,
|
||||
index: gpu_id,
|
||||
})
|
||||
} else {
|
||||
panic!("Expected Metal device.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl burn_backend::Device for CandleDevice {
|
||||
fn to_id(&self) -> burn_backend::DeviceId {
|
||||
match self {
|
||||
CandleDevice::Cuda(device) => DeviceId::new(0, device.index as u32),
|
||||
CandleDevice::Metal(device) => DeviceId::new(1, device.index as u32),
|
||||
CandleDevice::Cpu => DeviceId::new(2, 0),
|
||||
}
|
||||
}
|
||||
|
||||
fn from_id(device_id: DeviceId) -> Self {
|
||||
match device_id.type_id {
|
||||
0 => CandleDevice::cuda(device_id.index_id as usize),
|
||||
1 => CandleDevice::metal(device_id.index_id as usize),
|
||||
_ => CandleDevice::Cpu,
|
||||
}
|
||||
}
|
||||
|
||||
fn device_count(type_id: u16) -> usize {
|
||||
// TODO: Fix that
|
||||
1
|
||||
}
|
||||
}
|
||||
impl DeviceOps for CandleDevice {}
|
||||
|
||||
impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
|
||||
type Device = CandleDevice;
|
||||
|
||||
type FloatTensorPrimitive = CandleTensor;
|
||||
type FloatElem = F;
|
||||
|
||||
type IntTensorPrimitive = CandleTensor;
|
||||
type IntElem = I;
|
||||
|
||||
type BoolTensorPrimitive = CandleTensor;
|
||||
type BoolElem = u8;
|
||||
|
||||
type QuantizedTensorPrimitive = CandleTensor;
|
||||
|
||||
fn ad_enabled(_device: &Self::Device) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn name(device: &Self::Device) -> String {
|
||||
match device {
|
||||
CandleDevice::Cpu => "candle<cpu>",
|
||||
CandleDevice::Cuda(..) => "candle<cuda>",
|
||||
CandleDevice::Metal(..) => "candle<metal>",
|
||||
}
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn seed(device: &CandleDevice, seed: u64) {
|
||||
device.set_seed(seed);
|
||||
}
|
||||
|
||||
fn sync(device: &Device<Self>) -> Result<(), ExecutionError> {
|
||||
let device: candle_core::Device = (device.clone()).into();
|
||||
|
||||
match device {
|
||||
candle_core::Device::Cpu => (),
|
||||
candle_core::Device::Cuda(device) => {
|
||||
#[cfg(feature = "cuda")]
|
||||
device
|
||||
.synchronize()
|
||||
.map_err(|err| ExecutionError::Generic {
|
||||
reason: format!("Can't sync the cuda device: {err}"),
|
||||
backtrace: BackTrace::capture(),
|
||||
})?;
|
||||
}
|
||||
candle_core::Device::Metal(device) => {
|
||||
// For some reason, device.wait_until_completed() does not seem to work,
|
||||
// and neither does writing and reading a value with into_data
|
||||
return Err(ExecutionError::Generic {
|
||||
reason:
|
||||
"Device synchronization unavailable with Metal device on Candle backend"
|
||||
.into(),
|
||||
backtrace: BackTrace::capture(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {
|
||||
if dtype.try_into_dtype().is_ok() {
|
||||
burn_backend::DTypeUsage::general()
|
||||
} else {
|
||||
burn_backend::DTypeUsageSet::empty()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use burn_std::QuantScheme;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn should_support_dtypes() {
|
||||
type B = Candle<f32>;
|
||||
let device = Default::default();
|
||||
|
||||
assert!(B::supports_dtype(&device, DType::F64));
|
||||
assert!(B::supports_dtype(&device, DType::F32));
|
||||
assert!(B::supports_dtype(&device, DType::Flex32));
|
||||
assert!(B::supports_dtype(&device, DType::F16));
|
||||
assert!(B::supports_dtype(&device, DType::BF16));
|
||||
assert!(B::supports_dtype(&device, DType::I64));
|
||||
assert!(B::supports_dtype(&device, DType::U32));
|
||||
assert!(B::supports_dtype(&device, DType::U8));
|
||||
assert!(B::supports_dtype(&device, DType::I32));
|
||||
assert!(B::supports_dtype(&device, DType::I16));
|
||||
|
||||
assert!(!B::supports_dtype(&device, DType::U64));
|
||||
assert!(!B::supports_dtype(&device, DType::U16));
|
||||
assert!(!B::supports_dtype(&device, DType::I8));
|
||||
assert!(!B::supports_dtype(&device, DType::Bool));
|
||||
assert!(!B::supports_dtype(
|
||||
&device,
|
||||
DType::QFloat(QuantScheme::default())
|
||||
));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
use std::borrow::Borrow;
|
||||
|
||||
use burn_backend::{Element, bf16, f16};
|
||||
use candle_core::{FloatDType, Tensor, WithDType};
|
||||
|
||||
/// Candle element
|
||||
pub trait CandleElement: Element + WithDType {}
|
||||
/// Candle float element
|
||||
pub trait FloatCandleElement: CandleElement + FloatDType {}
|
||||
/// Candle int element
|
||||
pub trait IntCandleElement: CandleElement {}
|
||||
|
||||
impl CandleElement for f64 {}
|
||||
impl FloatCandleElement for f64 {}
|
||||
|
||||
impl CandleElement for f32 {}
|
||||
impl FloatCandleElement for f32 {}
|
||||
|
||||
impl CandleElement for f16 {}
|
||||
impl FloatCandleElement for f16 {}
|
||||
|
||||
impl CandleElement for bf16 {}
|
||||
impl FloatCandleElement for bf16 {}
|
||||
|
||||
impl CandleElement for u8 {}
|
||||
impl IntCandleElement for u8 {}
|
||||
|
||||
impl CandleElement for u32 {}
|
||||
impl IntCandleElement for u32 {}
|
||||
|
||||
impl CandleElement for i64 {}
|
||||
impl IntCandleElement for i64 {}
|
||||
@@ -0,0 +1,27 @@
|
||||
#![warn(missing_docs)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
#![allow(unused)] // TODO remove when backend filled
|
||||
#![deprecated(
|
||||
since = "0.21.0-pre.2",
|
||||
note = "burn-candle is deprecated and will be removed in a future release. Use burn-cubecl (CUDA/ROCm/Vulkan/Metal/WebGPU), burn-ndarray, or burn-tch instead."
|
||||
)]
|
||||
|
||||
//! Burn Candle Backend
|
||||
//!
|
||||
//! **Deprecated:** This backend is deprecated and will be removed in a future release.
|
||||
//! Please migrate to one of the actively maintained backends:
|
||||
//! - CubeCL backends (CUDA, ROCm, Vulkan, Metal, WebGPU) for GPU acceleration
|
||||
//! - NdArray for portable CPU execution
|
||||
//! - LibTorch (`burn-tch`) for a mature CPU/GPU backend
|
||||
|
||||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
mod backend;
|
||||
mod element;
|
||||
mod ops;
|
||||
mod tensor;
|
||||
|
||||
pub use backend::*;
|
||||
pub use element::*;
|
||||
pub use tensor::*;
|
||||
@@ -0,0 +1,17 @@
|
||||
use burn_backend::{ops::ActivationOps, tensor::FloatTensor};
|
||||
|
||||
use crate::{
|
||||
Candle, CandleTensor,
|
||||
element::{CandleElement, FloatCandleElement, IntCandleElement},
|
||||
tensor,
|
||||
};
|
||||
|
||||
impl<F: FloatCandleElement, I: IntCandleElement> ActivationOps<Self> for Candle<F, I> {
|
||||
fn gelu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.gelu().unwrap())
|
||||
}
|
||||
|
||||
fn relu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.relu().unwrap())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,572 @@
|
||||
use std::cmp::max;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::{
|
||||
Candle, CandleDevice, CandleTensor,
|
||||
element::{CandleElement, FloatCandleElement, IntCandleElement},
|
||||
};
|
||||
use burn_backend::{
|
||||
BackTrace, Backend, Distribution, ExecutionError, Slice, bf16, f16,
|
||||
ops::unfold::{calculate_unfold_shape, calculate_unfold_windows},
|
||||
};
|
||||
use burn_backend::{Element, Shape, TensorData, TensorMetadata};
|
||||
use candle_core::{Layout, WithDType};
|
||||
|
||||
use super::tensor;
|
||||
|
||||
pub fn cpu_random<E: CandleElement>(shape: Shape, distribution: Distribution) -> TensorData {
|
||||
let mut rng = crate::get_seeded_rng();
|
||||
let data = TensorData::random::<E, _, _>(shape, distribution, &mut rng);
|
||||
crate::set_seeded_rng(rng);
|
||||
data
|
||||
}
|
||||
|
||||
pub fn cat(tensors: Vec<CandleTensor>, dim: usize) -> CandleTensor {
|
||||
let tensors: Vec<candle_core::Tensor> = tensors.into_iter().map(|t| t.tensor).collect();
|
||||
CandleTensor::new(candle_core::Tensor::cat(&tensors, dim).unwrap())
|
||||
}
|
||||
|
||||
pub fn from_data<E: CandleElement>(data: TensorData, device: &CandleDevice) -> CandleTensor {
|
||||
CandleTensor::from_data::<E>(data, device.clone())
|
||||
}
|
||||
pub fn into_data(tensor: CandleTensor) -> Result<TensorData, ExecutionError> {
|
||||
fn tensor_data_from_dtype<T: WithDType + Element>(
|
||||
tensor: &CandleTensor,
|
||||
) -> Result<TensorData, ExecutionError> {
|
||||
let data = tensor
|
||||
.tensor
|
||||
.flatten_all()
|
||||
.map_err(|err| ExecutionError::Generic {
|
||||
reason: format!("{err}"),
|
||||
backtrace: BackTrace::capture(),
|
||||
})?
|
||||
.to_vec1::<T>()
|
||||
.map_err(|err| ExecutionError::Generic {
|
||||
reason: format!("{err}"),
|
||||
backtrace: BackTrace::capture(),
|
||||
})?;
|
||||
Ok(TensorData::new(data, tensor.shape()))
|
||||
}
|
||||
|
||||
match tensor.tensor.dtype() {
|
||||
candle_core::DType::BF16 => tensor_data_from_dtype::<bf16>(&tensor),
|
||||
candle_core::DType::F16 => tensor_data_from_dtype::<f16>(&tensor),
|
||||
candle_core::DType::F32 => tensor_data_from_dtype::<f32>(&tensor),
|
||||
candle_core::DType::F64 => tensor_data_from_dtype::<f64>(&tensor),
|
||||
candle_core::DType::U8 => tensor_data_from_dtype::<u8>(&tensor),
|
||||
candle_core::DType::U32 => tensor_data_from_dtype::<u32>(&tensor),
|
||||
candle_core::DType::I16 => tensor_data_from_dtype::<i16>(&tensor),
|
||||
candle_core::DType::I32 => tensor_data_from_dtype::<i32>(&tensor),
|
||||
candle_core::DType::I64 => tensor_data_from_dtype::<i64>(&tensor),
|
||||
other => todo!("{other:?} not yet supported"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_device(tensor: CandleTensor, device: &CandleDevice) -> CandleTensor {
|
||||
CandleTensor::new(tensor.tensor.to_device(&(device.clone()).into()).unwrap())
|
||||
}
|
||||
|
||||
pub fn empty(shape: Shape, device: &CandleDevice, dtype: candle_core::DType) -> CandleTensor {
|
||||
zeros(shape, device, dtype)
|
||||
}
|
||||
|
||||
pub fn zeros(shape: Shape, device: &CandleDevice, dtype: candle_core::DType) -> CandleTensor {
|
||||
CandleTensor::new(
|
||||
candle_core::Tensor::zeros(shape.to_vec(), dtype, &(device.clone()).into()).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn ones(shape: Shape, device: &CandleDevice, dtype: candle_core::DType) -> CandleTensor {
|
||||
CandleTensor::new(
|
||||
candle_core::Tensor::ones(shape.to_vec(), dtype, &(device.clone()).into()).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn swap_dims(mut tensor: CandleTensor, dim1: usize, dim2: usize) -> CandleTensor {
|
||||
CandleTensor::new(tensor.tensor.transpose(dim1, dim2).unwrap())
|
||||
}
|
||||
|
||||
pub fn permute(tensor: CandleTensor, axes: &[usize]) -> CandleTensor {
|
||||
CandleTensor::new(tensor.tensor.permute(axes).unwrap())
|
||||
}
|
||||
|
||||
pub fn flip(tensor: CandleTensor, axes: &[usize]) -> CandleTensor {
|
||||
// FIXME: Replace with an appropriate method when Candle provides one.
|
||||
let mut tensor = tensor.tensor;
|
||||
for &axis in axes {
|
||||
// Ensure tensor is contiguous before index_select (required by Candle)
|
||||
tensor = tensor.contiguous().unwrap();
|
||||
|
||||
let indexes = candle_core::Tensor::arange_step(
|
||||
tensor.dim(axis).unwrap() as i64 - 1,
|
||||
-1,
|
||||
-1,
|
||||
tensor.device(),
|
||||
)
|
||||
.unwrap();
|
||||
tensor = tensor.index_select(&indexes, axis).unwrap();
|
||||
}
|
||||
|
||||
CandleTensor::new(tensor)
|
||||
}
|
||||
|
||||
pub fn reshape(tensor: CandleTensor, shape: Shape) -> CandleTensor {
|
||||
CandleTensor::new(tensor.tensor.reshape(shape.to_vec()).unwrap())
|
||||
}
|
||||
|
||||
pub fn device(tensor: &CandleTensor) -> CandleDevice {
|
||||
tensor.tensor.device().clone().into()
|
||||
}
|
||||
|
||||
pub fn shape(tensor: &CandleTensor) -> Shape {
|
||||
tensor.shape()
|
||||
}
|
||||
|
||||
pub fn slice(tensor: CandleTensor, ranges: &[std::ops::Range<usize>]) -> CandleTensor {
|
||||
let mut narrow_tensor = tensor.tensor;
|
||||
for (i, range) in ranges.iter().enumerate().take(ranges.len()) {
|
||||
narrow_tensor = narrow_tensor
|
||||
.narrow(i, range.start, range.end - range.start)
|
||||
.unwrap()
|
||||
}
|
||||
CandleTensor::new(narrow_tensor)
|
||||
}
|
||||
|
||||
pub fn slice_with_steps(tensor: CandleTensor, slices: &[Slice]) -> CandleTensor {
|
||||
let mut result_tensor = tensor.tensor;
|
||||
|
||||
for (dim, slice) in slices.iter().enumerate() {
|
||||
if slice.step == 1 {
|
||||
// Use narrow for step=1 (more efficient)
|
||||
// Convert slice to range using tensor shape
|
||||
let dim_size = result_tensor.dim(dim).unwrap();
|
||||
let range = slice.to_range(dim_size);
|
||||
let start = range.start;
|
||||
let length = range.end - range.start;
|
||||
result_tensor = result_tensor.narrow(dim, start, length).unwrap();
|
||||
} else {
|
||||
// Use index_select for step != 1
|
||||
let dim_size = result_tensor.dim(dim).unwrap();
|
||||
let range = slice.to_range(dim_size);
|
||||
let start = range.start;
|
||||
let end = range.end;
|
||||
let step = slice.step;
|
||||
|
||||
// Generate indices based on step direction
|
||||
let indices_vec = if step > 0 {
|
||||
// Forward stepping
|
||||
let step_usize = step as usize;
|
||||
(start..end).step_by(step_usize).collect::<Vec<_>>()
|
||||
} else {
|
||||
// Backward stepping (negative step)
|
||||
let step_usize = step.unsigned_abs();
|
||||
// Start from end-1 and go backwards
|
||||
let mut indices = Vec::new();
|
||||
let mut idx = end - 1;
|
||||
while idx >= start && idx < end {
|
||||
// Check for underflow
|
||||
indices.push(idx);
|
||||
if idx >= step_usize {
|
||||
idx -= step_usize;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
indices
|
||||
};
|
||||
|
||||
// Convert indices to tensor and use index_select
|
||||
let indices_len = indices_vec.len();
|
||||
let device = result_tensor.device();
|
||||
let indices = candle_core::Tensor::from_vec(
|
||||
indices_vec.iter().map(|&x| x as u32).collect::<Vec<_>>(),
|
||||
indices_len,
|
||||
device,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
result_tensor = result_tensor.index_select(&indices, dim).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
CandleTensor::new(result_tensor)
|
||||
}
|
||||
|
||||
pub fn slice_assign(tensor: CandleTensor, slices: &[Slice], value: CandleTensor) -> CandleTensor {
|
||||
// Check if all slices have step=1 (candle's native slice_assign requirement)
|
||||
let all_unit_steps = slices.iter().all(|s| s.step == 1);
|
||||
|
||||
if all_unit_steps {
|
||||
// Convert Slice to Range for candle's native slice_assign
|
||||
let ranges: Vec<std::ops::Range<usize>> = slices
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(dim, slice)| {
|
||||
let dim_size = tensor.tensor.dim(dim).unwrap_or(usize::MAX);
|
||||
slice.to_range(dim_size)
|
||||
})
|
||||
.collect();
|
||||
|
||||
CandleTensor::new(tensor.tensor.slice_assign(&ranges, &value.tensor).unwrap())
|
||||
} else {
|
||||
// Implement slice_assign with steps using scatter operations
|
||||
slice_assign_with_steps_workaround(tensor, slices, value)
|
||||
}
|
||||
}
|
||||
|
||||
/// Implements slice_assign for non-unit steps using index operations
|
||||
fn slice_assign_with_steps_workaround(
|
||||
tensor: CandleTensor,
|
||||
slices: &[Slice],
|
||||
value: CandleTensor,
|
||||
) -> CandleTensor {
|
||||
let shape = tensor.shape();
|
||||
let ndims = shape.num_dims();
|
||||
let device = tensor.tensor.device();
|
||||
|
||||
// Generate indices for each dimension based on slice specifications
|
||||
let indices_per_dim = generate_slice_indices(slices, &shape);
|
||||
|
||||
// Early return if no elements to assign
|
||||
let total_elements: usize = indices_per_dim.iter().map(|v| v.len()).product();
|
||||
if total_elements == 0 {
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// Flatten tensors and get metadata
|
||||
let value_flat = value.tensor.flatten_all().unwrap();
|
||||
let strides = tensor.tensor.stride();
|
||||
let tensor_shape = tensor.tensor.dims();
|
||||
|
||||
// Use a macro to handle different dtypes without code duplication
|
||||
macro_rules! apply_slice_assign {
|
||||
($dtype:ty, $to_vec_fn:ident) => {{
|
||||
let mut tensor_vec: Vec<$dtype> =
|
||||
tensor.tensor.flatten_all().unwrap().$to_vec_fn().unwrap();
|
||||
let value_vec: Vec<$dtype> = value_flat.$to_vec_fn().unwrap();
|
||||
|
||||
// Apply assignments using cartesian product of indices
|
||||
for (value_idx, &value) in value_vec.iter().enumerate() {
|
||||
let flat_idx = compute_flat_index(value_idx, &indices_per_dim, &strides);
|
||||
if flat_idx < tensor_vec.len() {
|
||||
tensor_vec[flat_idx] = value;
|
||||
}
|
||||
}
|
||||
|
||||
candle_core::Tensor::from_vec(tensor_vec, tensor_shape, device).unwrap()
|
||||
}};
|
||||
}
|
||||
|
||||
use candle_core::DType;
|
||||
let result = match tensor.tensor.dtype() {
|
||||
DType::F32 => apply_slice_assign!(f32, to_vec1),
|
||||
DType::F64 => apply_slice_assign!(f64, to_vec1),
|
||||
DType::I64 => apply_slice_assign!(i64, to_vec1),
|
||||
DType::U32 => apply_slice_assign!(u32, to_vec1),
|
||||
DType::U8 => apply_slice_assign!(u8, to_vec1),
|
||||
_ => panic!(
|
||||
"Unsupported dtype {:?} for slice_assign with steps",
|
||||
tensor.tensor.dtype()
|
||||
),
|
||||
};
|
||||
|
||||
CandleTensor::new(result)
|
||||
}
|
||||
|
||||
/// Generate indices for each dimension based on slice specifications
|
||||
fn generate_slice_indices(slices: &[Slice], tensor_dims: &[usize]) -> Vec<Vec<usize>> {
|
||||
let ndims = tensor_dims.len();
|
||||
let mut indices_per_dim = Vec::with_capacity(ndims);
|
||||
|
||||
// Process provided slices
|
||||
for (dim_idx, slice) in slices.iter().enumerate() {
|
||||
let dim_size = tensor_dims[dim_idx];
|
||||
let range = slice.to_range(dim_size);
|
||||
let indices = generate_stepped_indices(range.start, range.end, slice.step);
|
||||
indices_per_dim.push(indices);
|
||||
}
|
||||
|
||||
// Fill remaining dimensions with full ranges
|
||||
for &dim_size in tensor_dims.iter().skip(slices.len()) {
|
||||
indices_per_dim.push((0..dim_size).collect());
|
||||
}
|
||||
|
||||
indices_per_dim
|
||||
}
|
||||
|
||||
/// Generate indices for a single dimension with stepping
|
||||
fn generate_stepped_indices(start: usize, end: usize, step: isize) -> Vec<usize> {
|
||||
if step > 0 {
|
||||
// Forward stepping
|
||||
(start..end).step_by(step as usize).collect()
|
||||
} else if step < 0 {
|
||||
// Backward stepping: start from end-1 and go backwards
|
||||
let step_size = step.unsigned_abs();
|
||||
let mut indices = Vec::new();
|
||||
let mut idx = end.saturating_sub(1);
|
||||
|
||||
while idx >= start && idx < end {
|
||||
indices.push(idx);
|
||||
if idx >= step_size {
|
||||
idx -= step_size;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
indices
|
||||
} else {
|
||||
// This branch should never be reached since step is validated to be non-zero
|
||||
panic!("Step cannot be zero")
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute flat index from multi-dimensional indices using cartesian product logic
|
||||
fn compute_flat_index(
|
||||
value_idx: usize,
|
||||
indices_per_dim: &[Vec<usize>],
|
||||
strides: &[usize],
|
||||
) -> usize {
|
||||
let mut flat_idx = 0;
|
||||
let mut remainder = value_idx;
|
||||
|
||||
// Convert value_idx to multi-dimensional indices and compute flat tensor index
|
||||
for dim in (0..indices_per_dim.len()).rev() {
|
||||
let dim_size = indices_per_dim[dim].len();
|
||||
let idx_in_dim = remainder % dim_size;
|
||||
remainder /= dim_size;
|
||||
|
||||
let actual_idx = indices_per_dim[dim][idx_in_dim];
|
||||
flat_idx += actual_idx * strides[dim];
|
||||
}
|
||||
|
||||
flat_idx
|
||||
}
|
||||
|
||||
pub fn narrow(tensor: CandleTensor, dim: usize, start: usize, length: usize) -> CandleTensor {
|
||||
let tensor = tensor.tensor.narrow(dim, start, length);
|
||||
match tensor {
|
||||
Ok(tensor) => CandleTensor::new(tensor),
|
||||
Err(e) => panic!("error narrow from Candle"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn chunk(tensor: CandleTensor, chunks: usize, dim: usize) -> Vec<CandleTensor> {
|
||||
let tensors = tensor.tensor.chunk(chunks, dim);
|
||||
match tensors {
|
||||
Ok(tensors) => tensors.into_iter().map(CandleTensor::new).collect(),
|
||||
Err(e) => panic!("error chunk from Candle"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn expand(tensor: CandleTensor, shape: Shape) -> CandleTensor {
|
||||
CandleTensor::new(tensor.tensor.broadcast_as(shape.to_vec()).unwrap())
|
||||
}
|
||||
|
||||
pub fn unfold(tensor: CandleTensor, dim: usize, size: usize, step: usize) -> CandleTensor {
|
||||
let result_shape = calculate_unfold_shape(tensor.shape(), dim, size, step);
|
||||
let windows = result_shape[dim];
|
||||
|
||||
let mut select_ranges = tensor.shape().into_ranges();
|
||||
let new_axis = select_ranges.len();
|
||||
|
||||
let mut stack = Vec::with_capacity(windows);
|
||||
for widx in 0..windows {
|
||||
let start = widx * step;
|
||||
let end = start + size;
|
||||
select_ranges[dim] = start..end;
|
||||
|
||||
let mut window_slice = slice(tensor.clone(), &select_ranges);
|
||||
|
||||
window_slice = swap_dims(window_slice, dim, new_axis);
|
||||
let window_slice = CandleTensor::new(window_slice.tensor.unsqueeze(new_axis).unwrap());
|
||||
|
||||
stack.push(window_slice);
|
||||
}
|
||||
cat(stack, dim)
|
||||
}
|
||||
|
||||
pub fn sign(tensor: CandleTensor) -> CandleTensor {
|
||||
CandleTensor::new(tensor.tensor.sign().unwrap())
|
||||
}
|
||||
|
||||
pub fn mask_where_broadcasted(
|
||||
tensor: CandleTensor,
|
||||
mask: CandleTensor,
|
||||
value: CandleTensor,
|
||||
) -> CandleTensor {
|
||||
let shape = tensor
|
||||
.tensor
|
||||
.shape()
|
||||
.broadcast_shape_binary_op(mask.tensor.shape(), "where_cond")
|
||||
.unwrap();
|
||||
|
||||
let mut tensor = tensor.tensor;
|
||||
let mut mask = mask.tensor;
|
||||
let mut value = value.tensor;
|
||||
|
||||
if shape != *tensor.shape() {
|
||||
tensor = tensor.broadcast_as(shape.clone()).unwrap();
|
||||
}
|
||||
if shape != *mask.shape() {
|
||||
mask = mask.broadcast_as(shape.clone()).unwrap();
|
||||
}
|
||||
if shape != *value.shape() {
|
||||
value = value.broadcast_as(shape).unwrap();
|
||||
}
|
||||
|
||||
CandleTensor::new(mask.where_cond(&value, &tensor).unwrap())
|
||||
}
|
||||
|
||||
pub fn cross(lhs: CandleTensor, rhs: CandleTensor, dim: usize) -> CandleTensor {
|
||||
let shape_lhs = lhs.shape();
|
||||
let shape_rhs = rhs.shape();
|
||||
let ndims = shape_lhs.num_dims();
|
||||
|
||||
// Broadcast the shapes except along dim
|
||||
let mut broadcast_shape = vec![0; ndims];
|
||||
for (i, item) in broadcast_shape.iter_mut().enumerate().take(ndims) {
|
||||
if i == dim {
|
||||
*item = shape_lhs[i];
|
||||
} else {
|
||||
let l = shape_lhs[i];
|
||||
let r = shape_rhs[i];
|
||||
if l == r {
|
||||
*item = l;
|
||||
} else if l == 1 {
|
||||
*item = r;
|
||||
} else if r == 1 {
|
||||
*item = l;
|
||||
} else {
|
||||
panic!("Tensors are not broadcastable along dimension {}", i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Broadcast lhs and rhs
|
||||
let lhs_broadcast = if shape_lhs == Shape::from(broadcast_shape.clone()) {
|
||||
lhs
|
||||
} else {
|
||||
expand(lhs, Shape::from(broadcast_shape.clone()))
|
||||
};
|
||||
let rhs_broadcast = if shape_rhs == Shape::from(broadcast_shape.clone()) {
|
||||
rhs
|
||||
} else {
|
||||
expand(rhs, Shape::from(broadcast_shape.clone()))
|
||||
};
|
||||
|
||||
// Now, move dim to the last dimension
|
||||
let mut perm = (0..ndims).collect::<Vec<_>>();
|
||||
perm.remove(dim);
|
||||
perm.push(dim);
|
||||
|
||||
let lhs_permuted = permute(lhs_broadcast, &perm);
|
||||
let rhs_permuted = permute(rhs_broadcast, &perm);
|
||||
|
||||
// Reshape to (*, 3)
|
||||
let total_elements = lhs_permuted.shape().num_elements();
|
||||
let batch_size = total_elements / 3;
|
||||
let lhs_reshaped = reshape(lhs_permuted, Shape::new([batch_size, 3]));
|
||||
let rhs_reshaped = reshape(rhs_permuted, Shape::new([batch_size, 3]));
|
||||
|
||||
// Extract components using narrow and squeeze
|
||||
let lhs_0 = CandleTensor::new(
|
||||
lhs_reshaped
|
||||
.tensor
|
||||
.narrow(1, 0, 1)
|
||||
.unwrap()
|
||||
.squeeze(1)
|
||||
.unwrap(),
|
||||
);
|
||||
let lhs_1 = CandleTensor::new(
|
||||
lhs_reshaped
|
||||
.tensor
|
||||
.narrow(1, 1, 1)
|
||||
.unwrap()
|
||||
.squeeze(1)
|
||||
.unwrap(),
|
||||
);
|
||||
let lhs_2 = CandleTensor::new(
|
||||
lhs_reshaped
|
||||
.tensor
|
||||
.narrow(1, 2, 1)
|
||||
.unwrap()
|
||||
.squeeze(1)
|
||||
.unwrap(),
|
||||
);
|
||||
let rhs_0 = CandleTensor::new(
|
||||
rhs_reshaped
|
||||
.tensor
|
||||
.narrow(1, 0, 1)
|
||||
.unwrap()
|
||||
.squeeze(1)
|
||||
.unwrap(),
|
||||
);
|
||||
let rhs_1 = CandleTensor::new(
|
||||
rhs_reshaped
|
||||
.tensor
|
||||
.narrow(1, 1, 1)
|
||||
.unwrap()
|
||||
.squeeze(1)
|
||||
.unwrap(),
|
||||
);
|
||||
let rhs_2 = CandleTensor::new(
|
||||
rhs_reshaped
|
||||
.tensor
|
||||
.narrow(1, 2, 1)
|
||||
.unwrap()
|
||||
.squeeze(1)
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
// Compute cross product components
|
||||
let result_0 = CandleTensor::new(
|
||||
lhs_1
|
||||
.tensor
|
||||
.mul(&rhs_2.tensor)
|
||||
.unwrap()
|
||||
.sub(&lhs_2.tensor.mul(&rhs_1.tensor).unwrap())
|
||||
.unwrap(),
|
||||
);
|
||||
let result_1 = CandleTensor::new(
|
||||
lhs_2
|
||||
.tensor
|
||||
.mul(&rhs_0.tensor)
|
||||
.unwrap()
|
||||
.sub(&lhs_0.tensor.mul(&rhs_2.tensor).unwrap())
|
||||
.unwrap(),
|
||||
);
|
||||
let result_2 = CandleTensor::new(
|
||||
lhs_0
|
||||
.tensor
|
||||
.mul(&rhs_1.tensor)
|
||||
.unwrap()
|
||||
.sub(&lhs_1.tensor.mul(&rhs_0.tensor).unwrap())
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
// Stack the components
|
||||
let result_0_unsqueezed = CandleTensor::new(result_0.tensor.unsqueeze(1).unwrap());
|
||||
let result_1_unsqueezed = CandleTensor::new(result_1.tensor.unsqueeze(1).unwrap());
|
||||
let result_2_unsqueezed = CandleTensor::new(result_2.tensor.unsqueeze(1).unwrap());
|
||||
let result = cat(
|
||||
vec![
|
||||
result_0_unsqueezed,
|
||||
result_1_unsqueezed,
|
||||
result_2_unsqueezed,
|
||||
],
|
||||
1,
|
||||
);
|
||||
|
||||
// Reshape back to the broadcast shape with dim at the end
|
||||
let mut result_shape = broadcast_shape;
|
||||
result_shape.remove(dim);
|
||||
result_shape.push(3);
|
||||
let result_reshaped = reshape(result, Shape::from(result_shape));
|
||||
|
||||
// Permute back
|
||||
let mut inv_perm = vec![0; ndims];
|
||||
for (i, &p) in perm.iter().enumerate() {
|
||||
inv_perm[p] = i;
|
||||
}
|
||||
permute(result_reshaped, &inv_perm)
|
||||
}
|
||||
@@ -0,0 +1,212 @@
|
||||
use burn_backend::{
|
||||
BackTrace, DType, ExecutionError, Scalar, Shape, Slice, TensorData, TensorMetadata,
|
||||
ops::BoolTensorOps,
|
||||
tensor::{BoolTensor, Device, FloatTensor, IntTensor},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
Candle, CandleTensor,
|
||||
element::{CandleElement, FloatCandleElement, IntCandleElement},
|
||||
};
|
||||
|
||||
use super::base::{expand, permute, unfold};
|
||||
|
||||
impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<F, I> {
|
||||
fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
|
||||
super::base::empty(shape, device, candle_core::DType::U8)
|
||||
}
|
||||
|
||||
fn bool_zeros(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
|
||||
super::base::zeros(shape, device, candle_core::DType::U8)
|
||||
}
|
||||
|
||||
fn bool_ones(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
|
||||
super::base::ones(shape, device, candle_core::DType::U8)
|
||||
}
|
||||
|
||||
async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {
|
||||
let x: Vec<u8> = tensor
|
||||
.tensor
|
||||
.flatten_all()
|
||||
.map_err(|err| ExecutionError::Generic {
|
||||
reason: format!("{err}"),
|
||||
backtrace: BackTrace::capture(),
|
||||
})?
|
||||
.to_vec1()
|
||||
.map_err(|err| ExecutionError::Generic {
|
||||
reason: format!("{err}"),
|
||||
backtrace: BackTrace::capture(),
|
||||
})?;
|
||||
|
||||
let y = x.iter().map(|b| !matches!(b, 0)).collect();
|
||||
|
||||
Ok(TensorData::new(y, tensor.shape()))
|
||||
}
|
||||
|
||||
fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {
|
||||
match data.dtype {
|
||||
DType::U8 => super::base::from_data::<u8>(data, device),
|
||||
_ => unimplemented!("Unsupported dtype for `bool_from_data`"),
|
||||
}
|
||||
}
|
||||
|
||||
fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.to_dtype(I::DTYPE).unwrap())
|
||||
}
|
||||
|
||||
fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.to_dtype(F::DTYPE).unwrap())
|
||||
}
|
||||
|
||||
fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
|
||||
super::base::device(tensor)
|
||||
}
|
||||
|
||||
fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
|
||||
super::base::to_device(tensor, device)
|
||||
}
|
||||
|
||||
fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
|
||||
super::base::reshape(tensor, shape)
|
||||
}
|
||||
|
||||
fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {
|
||||
super::base::slice_with_steps(tensor, slices)
|
||||
}
|
||||
|
||||
fn bool_slice_assign(
|
||||
tensor: BoolTensor<Self>,
|
||||
slices: &[Slice],
|
||||
value: BoolTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
super::base::slice_assign(tensor, slices, value)
|
||||
}
|
||||
|
||||
fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {
|
||||
super::base::cat(tensors, dim)
|
||||
}
|
||||
|
||||
fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
|
||||
let (lhs_broadcast, rhs_broadcast) =
|
||||
super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
|
||||
CandleTensor::new(lhs_broadcast.eq(&rhs_broadcast).unwrap())
|
||||
}
|
||||
|
||||
fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
|
||||
let x = (candle_core::Tensor::zeros_like(&tensor.tensor).unwrap());
|
||||
CandleTensor::new(tensor.tensor.eq(&x).unwrap())
|
||||
}
|
||||
|
||||
fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
|
||||
let x = candle_core::Tensor::ones_like(&lhs.tensor).unwrap();
|
||||
CandleTensor::new(lhs.tensor.add(&rhs.tensor).unwrap().gt(&x).unwrap())
|
||||
}
|
||||
|
||||
fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
|
||||
CandleTensor::new(
|
||||
lhs.tensor
|
||||
.add(&rhs.tensor)
|
||||
.unwrap()
|
||||
.clamp(0u32, 1u32)
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
|
||||
super::base::swap_dims(tensor, dim1, dim2)
|
||||
}
|
||||
|
||||
fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
|
||||
super::base::permute(tensor, axes)
|
||||
}
|
||||
|
||||
fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
|
||||
super::base::flip(tensor, axes)
|
||||
}
|
||||
|
||||
fn bool_select(
|
||||
tensor: BoolTensor<Self>,
|
||||
dim: usize,
|
||||
indices: IntTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.index_select(&indices.tensor, dim).unwrap())
|
||||
}
|
||||
|
||||
fn bool_select_or(
|
||||
tensor: BoolTensor<Self>,
|
||||
dim: usize,
|
||||
indices: IntTensor<Self>,
|
||||
value: BoolTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
CandleTensor::new(
|
||||
tensor
|
||||
.tensor
|
||||
.index_add(&indices.tensor, &value.tensor, dim)
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
|
||||
expand(tensor, shape)
|
||||
}
|
||||
|
||||
fn bool_unfold(
|
||||
tensor: BoolTensor<Self>,
|
||||
dim: usize,
|
||||
size: usize,
|
||||
step: usize,
|
||||
) -> BoolTensor<Self> {
|
||||
unfold(tensor, dim, size, step)
|
||||
}
|
||||
|
||||
fn bool_mask_where(
|
||||
tensor: BoolTensor<Self>,
|
||||
mask: BoolTensor<Self>,
|
||||
value: BoolTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
super::base::mask_where_broadcasted(tensor, mask, value)
|
||||
}
|
||||
|
||||
fn bool_mask_fill(
|
||||
tensor: BoolTensor<Self>,
|
||||
mask: BoolTensor<Self>,
|
||||
value: Scalar,
|
||||
) -> BoolTensor<Self> {
|
||||
CandleTensor::new(
|
||||
mask.tensor
|
||||
.where_cond(
|
||||
&super::candle_utils::fill_like::<u8>(value.elem(), &tensor.tensor),
|
||||
&tensor.tensor,
|
||||
)
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn bool_gather(
|
||||
dim: usize,
|
||||
tensor: BoolTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
let tensor = tensor.tensor.contiguous().unwrap();
|
||||
let indices = indices.tensor.contiguous().unwrap();
|
||||
CandleTensor::new(tensor.gather(&indices, dim).unwrap())
|
||||
}
|
||||
|
||||
fn bool_scatter_or(
|
||||
dim: usize,
|
||||
tensor: BoolTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
value: BoolTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
CandleTensor::new(
|
||||
tensor
|
||||
.tensor
|
||||
.scatter_add(&indices.tensor, &value.tensor, dim)
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
CandleTensor::new(lhs.tensor.eq(rhs.elem::<u8>()).unwrap())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
use candle_core::{DType, Device, Shape, Tensor};
|
||||
|
||||
use crate::element::CandleElement;
|
||||
|
||||
pub(crate) fn fill<E: CandleElement, S: Into<Shape>>(
|
||||
value: E,
|
||||
shape: S,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
) -> Tensor {
|
||||
let values = (Tensor::ones((1), dtype, device).unwrap() * value.elem::<f64>()).unwrap();
|
||||
values.expand(shape).unwrap()
|
||||
}
|
||||
|
||||
pub(crate) fn fill_like<E: CandleElement>(value: E, reference_tensor: &Tensor) -> Tensor {
|
||||
fill(
|
||||
value,
|
||||
reference_tensor.shape(),
|
||||
reference_tensor.dtype(),
|
||||
reference_tensor.device(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Broadcasts two tensors to a common shape for comparison operations
|
||||
pub(crate) fn broadcast_for_comparison(
|
||||
lhs: &Tensor,
|
||||
rhs: &Tensor,
|
||||
) -> Result<(Tensor, Tensor), candle_core::Error> {
|
||||
let broadcast_shape = lhs
|
||||
.shape()
|
||||
.broadcast_shape_binary_op(rhs.shape(), "comparison")?;
|
||||
|
||||
let lhs = if broadcast_shape != *lhs.shape() {
|
||||
lhs.broadcast_as(&broadcast_shape)?
|
||||
} else {
|
||||
lhs.clone()
|
||||
};
|
||||
|
||||
let rhs = if broadcast_shape != *rhs.shape() {
|
||||
rhs.broadcast_as(&broadcast_shape)?
|
||||
} else {
|
||||
rhs.clone()
|
||||
};
|
||||
|
||||
Ok((lhs, rhs))
|
||||
}
|
||||
@@ -0,0 +1,521 @@
|
||||
use burn_backend::{
|
||||
DType, Distribution, ElementConversion, ExecutionError, IntDType, Scalar, Shape, Slice,
|
||||
TensorData,
|
||||
ops::{FloatTensorOps, IntTensorOps},
|
||||
tensor::{Bool, BoolTensor, Device, FloatTensor, IntElem, IntTensor},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
Candle, CandleDevice, CandleTensor, IntoDType,
|
||||
element::{CandleElement, FloatCandleElement, IntCandleElement},
|
||||
};
|
||||
|
||||
use super::base::{cpu_random, expand, permute, sign, unfold};
|
||||
|
||||
impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F, I> {
|
||||
fn int_empty(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
|
||||
super::base::empty(shape, device, dtype.into_dtype())
|
||||
}
|
||||
|
||||
async fn int_into_data(tensor: IntTensor<Self>) -> Result<TensorData, ExecutionError> {
|
||||
super::base::into_data(tensor)
|
||||
}
|
||||
|
||||
fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
|
||||
match data.dtype {
|
||||
DType::I64 => super::base::from_data::<i64>(data, device),
|
||||
DType::U32 => super::base::from_data::<u32>(data, device),
|
||||
DType::U8 => super::base::from_data::<u8>(data, device),
|
||||
_ => unimplemented!("Unsupported dtype for `int_from_data`"),
|
||||
}
|
||||
}
|
||||
|
||||
fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {
|
||||
super::base::device(tensor)
|
||||
}
|
||||
|
||||
fn int_to_device(tensor: IntTensor<Self>, device: &Device<Self>) -> IntTensor<Self> {
|
||||
super::base::to_device(tensor, device)
|
||||
}
|
||||
|
||||
fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
|
||||
super::base::reshape(tensor, shape)
|
||||
}
|
||||
|
||||
fn int_slice(tensor: IntTensor<Self>, slices: &[Slice]) -> IntTensor<Self> {
|
||||
super::base::slice_with_steps(tensor, slices)
|
||||
}
|
||||
|
||||
fn int_slice_assign(
|
||||
tensor: IntTensor<Self>,
|
||||
slices: &[Slice],
|
||||
value: IntTensor<Self>,
|
||||
) -> IntTensor<Self> {
|
||||
super::base::slice_assign(tensor, slices, value)
|
||||
}
|
||||
|
||||
fn int_into_float(tensor: IntTensor<Self>) -> FloatTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.to_dtype(F::DTYPE).unwrap())
|
||||
}
|
||||
|
||||
fn int_mask_where(
|
||||
tensor: IntTensor<Self>,
|
||||
mask: BoolTensor<Self>,
|
||||
source: IntTensor<Self>,
|
||||
) -> IntTensor<Self> {
|
||||
super::base::mask_where_broadcasted(tensor, mask, source)
|
||||
}
|
||||
|
||||
fn int_mask_fill(
|
||||
tensor: IntTensor<Self>,
|
||||
mask: BoolTensor<Self>,
|
||||
value: Scalar,
|
||||
) -> IntTensor<Self> {
|
||||
CandleTensor::new(
|
||||
mask.tensor
|
||||
.where_cond(
|
||||
&super::candle_utils::fill_like::<I>(value.elem(), &tensor.tensor),
|
||||
&tensor.tensor,
|
||||
)
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn int_gather(
|
||||
dim: usize,
|
||||
tensor: IntTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
) -> IntTensor<Self> {
|
||||
let tensor = tensor.tensor.contiguous().unwrap();
|
||||
let indices = indices.tensor.contiguous().unwrap();
|
||||
CandleTensor::new(tensor.gather(&indices, dim).unwrap())
|
||||
}
|
||||
|
||||
fn int_scatter_add(
|
||||
dim: usize,
|
||||
tensor: IntTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
value: IntTensor<Self>,
|
||||
) -> IntTensor<Self> {
|
||||
CandleTensor::new(
|
||||
tensor
|
||||
.tensor
|
||||
.scatter_add(&indices.tensor, &value.tensor, dim)
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn int_select(
|
||||
tensor: IntTensor<Self>,
|
||||
dim: usize,
|
||||
indices: IntTensor<Self>,
|
||||
) -> IntTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.index_select(&indices.tensor, dim).unwrap())
|
||||
}
|
||||
|
||||
fn int_select_add(
|
||||
tensor: IntTensor<Self>,
|
||||
dim: usize,
|
||||
indices: IntTensor<Self>,
|
||||
value: IntTensor<Self>,
|
||||
) -> IntTensor<Self> {
|
||||
CandleTensor::new(
|
||||
tensor
|
||||
.tensor
|
||||
.index_add(&indices.tensor, &value.tensor, dim)
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn int_cat(tensors: Vec<IntTensor<Self>>, dim: usize) -> IntTensor<Self> {
|
||||
super::base::cat(tensors, dim)
|
||||
}
|
||||
|
||||
fn int_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
|
||||
let (lhs_broadcast, rhs_broadcast) =
|
||||
super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
|
||||
CandleTensor::new(lhs_broadcast.eq(&rhs_broadcast).unwrap())
|
||||
}
|
||||
|
||||
fn int_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
CandleTensor::new(lhs.tensor.eq(rhs.elem::<I>()).unwrap())
|
||||
}
|
||||
|
||||
fn int_greater(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
|
||||
let (lhs_broadcast, rhs_broadcast) =
|
||||
super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
|
||||
CandleTensor::new(lhs_broadcast.gt(&rhs_broadcast).unwrap())
|
||||
}
|
||||
|
||||
fn int_greater_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
CandleTensor::new(
|
||||
lhs.tensor
|
||||
.gt(&super::candle_utils::fill_like::<I>(
|
||||
rhs.elem(),
|
||||
&lhs.tensor,
|
||||
))
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn int_greater_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
|
||||
let (lhs_broadcast, rhs_broadcast) =
|
||||
super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
|
||||
CandleTensor::new(lhs_broadcast.ge(&rhs_broadcast).unwrap())
|
||||
}
|
||||
|
||||
fn int_greater_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
CandleTensor::new(
|
||||
lhs.tensor
|
||||
.ge(&super::candle_utils::fill_like::<I>(
|
||||
rhs.elem(),
|
||||
&lhs.tensor,
|
||||
))
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn int_lower(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
|
||||
let (lhs_broadcast, rhs_broadcast) =
|
||||
super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
|
||||
CandleTensor::new(lhs_broadcast.lt(&rhs_broadcast).unwrap())
|
||||
}
|
||||
|
||||
fn int_lower_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
CandleTensor::new(
|
||||
lhs.tensor
|
||||
.lt(&super::candle_utils::fill_like::<I>(
|
||||
rhs.elem(),
|
||||
&lhs.tensor,
|
||||
))
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn int_lower_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
|
||||
let (lhs_broadcast, rhs_broadcast) =
|
||||
super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
|
||||
CandleTensor::new(lhs_broadcast.le(&rhs_broadcast).unwrap())
|
||||
}
|
||||
|
||||
fn int_lower_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
CandleTensor::new(
|
||||
lhs.tensor
|
||||
.le(&super::candle_utils::fill_like::<I>(
|
||||
rhs.elem(),
|
||||
&lhs.tensor,
|
||||
))
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
CandleTensor::new(lhs.tensor.broadcast_add(&rhs.tensor).unwrap())
|
||||
}
|
||||
|
||||
fn int_add_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
CandleTensor::new((lhs.tensor + rhs.elem::<f64>()).unwrap())
|
||||
}
|
||||
|
||||
fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
CandleTensor::new(lhs.tensor.broadcast_sub(&rhs.tensor).unwrap())
|
||||
}
|
||||
|
||||
fn int_sub_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
CandleTensor::new((lhs.tensor - rhs.elem::<f64>()).unwrap())
|
||||
}
|
||||
|
||||
fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
CandleTensor::new(lhs.tensor.broadcast_mul(&rhs.tensor).unwrap())
|
||||
}
|
||||
|
||||
fn int_mul_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
CandleTensor::new((lhs.tensor * rhs.elem::<f64>()).unwrap())
|
||||
}
|
||||
|
||||
fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
CandleTensor::new(lhs.tensor.broadcast_div(&rhs.tensor).unwrap())
|
||||
}
|
||||
|
||||
fn int_div_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
// Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0 so we always obtain 0.
|
||||
panic!("Not supported by Candle")
|
||||
}
|
||||
|
||||
fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
CandleTensor::new(
|
||||
(lhs.tensor.clone()
|
||||
- lhs
|
||||
.tensor
|
||||
.broadcast_div(&rhs.tensor)
|
||||
.unwrap()
|
||||
.broadcast_mul(&rhs.tensor)
|
||||
.unwrap())
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
// Same problem as int_div_scalar.
|
||||
panic!("Not supported by Candle")
|
||||
}
|
||||
|
||||
fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
|
||||
CandleTensor::new(
|
||||
candle_core::Tensor::zeros(
|
||||
shape.to_vec(),
|
||||
dtype.into_dtype(),
|
||||
&(device.clone()).into(),
|
||||
)
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
|
||||
CandleTensor::new(
|
||||
candle_core::Tensor::ones(shape.to_vec(), dtype.into_dtype(), &(device.clone()).into())
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
|
||||
let sum = tensor.tensor.sum_all().unwrap().to_scalar::<I>().unwrap();
|
||||
CandleTensor::from_data::<I>(
|
||||
TensorData::new([sum].into(), [1]),
|
||||
Self::int_device(&tensor),
|
||||
)
|
||||
}
|
||||
|
||||
fn int_sum_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap())
|
||||
}
|
||||
|
||||
fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
|
||||
todo!(
|
||||
"prod is not implemented for Candle IntTensor (see https://github.com/tracel-ai/burn/issues/1454)"
|
||||
)
|
||||
}
|
||||
|
||||
fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
todo!(
|
||||
"prod_int is not implemented for Candle IntTensor (see https://github.com/tracel-ai/burn/issues/1454)"
|
||||
)
|
||||
}
|
||||
|
||||
fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
// Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0 so we always obtain 0.
|
||||
panic!("Not supported by Candle")
|
||||
}
|
||||
|
||||
fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
// Candle's cumsum doesn't support integer types, so we convert to float,
|
||||
// compute cumsum, and convert back to int
|
||||
let dtype = tensor.tensor.dtype();
|
||||
let tensor_float = tensor.tensor.to_dtype(candle_core::DType::F32).unwrap();
|
||||
let result_float = tensor_float.cumsum(dim).unwrap();
|
||||
CandleTensor::new(result_float.to_dtype(dtype).unwrap())
|
||||
}
|
||||
|
||||
fn int_cumprod(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
// Convert to float for computation, then convert back
|
||||
let dtype = tensor.tensor.dtype();
|
||||
let tensor_float = tensor.tensor.to_dtype(candle_core::DType::F32).unwrap();
|
||||
|
||||
let result_float = super::utils::cumulative_with_op(&tensor_float, dim, |prev, curr| {
|
||||
prev.broadcast_mul(curr)
|
||||
});
|
||||
CandleTensor::new(result_float.to_dtype(dtype).unwrap())
|
||||
}
|
||||
|
||||
fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
// Convert to float for computation, then convert back
|
||||
let dtype = tensor.tensor.dtype();
|
||||
let tensor_float = tensor.tensor.to_dtype(candle_core::DType::F32).unwrap();
|
||||
|
||||
let result_float = super::utils::cumulative_with_op(&tensor_float, dim, |prev, curr| {
|
||||
prev.broadcast_minimum(curr)
|
||||
});
|
||||
CandleTensor::new(result_float.to_dtype(dtype).unwrap())
|
||||
}
|
||||
|
||||
fn int_cummax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
let result = super::utils::cumulative_with_op(&tensor.tensor, dim, |prev, curr| {
|
||||
prev.broadcast_maximum(curr)
|
||||
});
|
||||
CandleTensor::new(result)
|
||||
}
|
||||
|
||||
fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
CandleTensor::new(
|
||||
tensor
|
||||
.tensor
|
||||
.argmax_keepdim(dim)
|
||||
.unwrap()
|
||||
.to_dtype(I::DTYPE)
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
CandleTensor::new(
|
||||
tensor
|
||||
.tensor
|
||||
.argmin_keepdim(dim)
|
||||
.unwrap()
|
||||
.to_dtype(I::DTYPE)
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
|
||||
// Ugly type conversion here as Candle does not support unary ops on ints
|
||||
match tensor.tensor.dtype() {
|
||||
candle_core::DType::U8 | candle_core::DType::U32 => tensor,
|
||||
candle_core::DType::I64 => CandleTensor::new(
|
||||
tensor
|
||||
.tensor
|
||||
.to_dtype(F::DTYPE)
|
||||
.unwrap()
|
||||
.abs()
|
||||
.unwrap()
|
||||
.to_dtype(candle_core::DType::I64)
|
||||
.unwrap(),
|
||||
),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
|
||||
super::base::swap_dims(tensor, dim1, dim2)
|
||||
}
|
||||
|
||||
fn int_random(
|
||||
shape: Shape,
|
||||
distribution: Distribution,
|
||||
device: &Device<Self>,
|
||||
) -> IntTensor<Self> {
|
||||
if let CandleDevice::Cpu = device {
|
||||
let distribution = if distribution == Distribution::Default {
|
||||
Distribution::Uniform(0.0, 255.0)
|
||||
} else {
|
||||
distribution
|
||||
};
|
||||
// Use our own seed since candle doesn't support it on CPU
|
||||
return Self::int_from_data(cpu_random::<I>(shape, distribution), device);
|
||||
}
|
||||
|
||||
let shape = shape.to_vec();
|
||||
let device = &(device.clone()).into();
|
||||
match distribution {
|
||||
Distribution::Default => CandleTensor::new(
|
||||
candle_core::Tensor::rand(0.elem::<F>(), 255.elem::<F>(), shape, device)
|
||||
.unwrap()
|
||||
.to_dtype(I::DTYPE)
|
||||
.unwrap(),
|
||||
),
|
||||
Distribution::Bernoulli(prob) => CandleTensor::new(
|
||||
candle_core::Tensor::rand(0.elem::<F>(), 1.elem::<F>(), shape.clone(), device)
|
||||
.unwrap()
|
||||
.to_dtype(I::DTYPE)
|
||||
.unwrap()
|
||||
.lt(&super::candle_utils::fill(prob, shape, I::DTYPE, device))
|
||||
.unwrap()
|
||||
.to_dtype(I::DTYPE)
|
||||
.unwrap(),
|
||||
),
|
||||
Distribution::Uniform(from, to) => CandleTensor::new(
|
||||
candle_core::Tensor::rand(from.elem::<F>(), to.elem::<F>(), shape, device).unwrap(),
|
||||
),
|
||||
Distribution::Normal(mean, std) => CandleTensor::new(
|
||||
candle_core::Tensor::randn(mean.elem::<F>(), std.elem::<F>(), shape, device)
|
||||
.unwrap(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
|
||||
super::base::permute(tensor, axes)
|
||||
}
|
||||
|
||||
fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
|
||||
super::base::flip(tensor, axes)
|
||||
}
|
||||
|
||||
fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
|
||||
expand(tensor, shape)
|
||||
}
|
||||
|
||||
fn int_unfold(
|
||||
tensor: IntTensor<Self>,
|
||||
dim: usize,
|
||||
size: usize,
|
||||
step: usize,
|
||||
) -> IntTensor<Self> {
|
||||
unfold(tensor, dim, size, step)
|
||||
}
|
||||
|
||||
fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
|
||||
sign(tensor)
|
||||
}
|
||||
fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
unimplemented!("bitwise_and is not implemented for Candle IntTensor");
|
||||
}
|
||||
|
||||
fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
unimplemented!("bitwise_and_scalar is not implemented for Candle IntTensor");
|
||||
}
|
||||
|
||||
fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
unimplemented!("bitwise_or is not implemented for Candle IntTensor");
|
||||
}
|
||||
|
||||
fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
unimplemented!("bitwise_or_scalar is not implemented for Candle IntTensor");
|
||||
}
|
||||
|
||||
fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
unimplemented!("bitwise_xor is not implemented for Candle IntTensor");
|
||||
}
|
||||
|
||||
fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
unimplemented!("bitwise_xor_scalar is not implemented for Candle IntTensor");
|
||||
}
|
||||
|
||||
fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
|
||||
unimplemented!("bitwise_not is not implemented for Candle IntTensor");
|
||||
}
|
||||
|
||||
fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
unimplemented!("bitwise_left_shift is not implemented for Candle IntTensor");
|
||||
}
|
||||
|
||||
fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
unimplemented!("bitwise_right_shift is not implemented for Candle IntTensor");
|
||||
}
|
||||
|
||||
fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
unimplemented!("bitwise_left_shift_scalar is not implemented for Candle IntTensor");
|
||||
}
|
||||
|
||||
fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
|
||||
unimplemented!("bitwise_right_shift_scalar is not implemented for Candle IntTensor");
|
||||
}
|
||||
|
||||
fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
|
||||
let lhs = Self::int_into_float(lhs);
|
||||
let rhs = Self::int_into_float(rhs);
|
||||
|
||||
let out = Self::float_matmul(lhs, rhs);
|
||||
Self::float_into_int(out)
|
||||
}
|
||||
|
||||
fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
|
||||
let dtype = dtype.into_dtype();
|
||||
|
||||
if tensor.tensor.dtype() == dtype {
|
||||
tensor
|
||||
} else {
|
||||
CandleTensor::new(tensor.tensor.to_dtype(dtype).unwrap())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
mod activation;
|
||||
mod base;
|
||||
mod bool_tensor;
|
||||
mod candle_utils;
|
||||
mod int_tensor;
|
||||
mod module;
|
||||
mod qtensor;
|
||||
mod tensor;
|
||||
mod transaction;
|
||||
mod utils;
|
||||
@@ -0,0 +1,327 @@
|
||||
use burn_backend::{
|
||||
Shape,
|
||||
ops::{
|
||||
ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions,
|
||||
InterpolateMode, InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
|
||||
UnfoldOptions, attention::attention_fallback,
|
||||
},
|
||||
tensor::{FloatTensor, IntTensor},
|
||||
};
|
||||
use candle_core::ToUsize2;
|
||||
|
||||
use crate::{
|
||||
Candle, CandleTensor,
|
||||
element::{CandleElement, FloatCandleElement, IntCandleElement},
|
||||
ops::base::reshape,
|
||||
};
|
||||
|
||||
impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I> {
|
||||
fn conv1d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: ConvOptions<1>,
|
||||
) -> FloatTensor<Self> {
|
||||
let conv = x
|
||||
.tensor
|
||||
.conv1d(
|
||||
&weight.tensor,
|
||||
options.padding[0],
|
||||
options.stride[0],
|
||||
options.dilation[0],
|
||||
options.groups,
|
||||
)
|
||||
.unwrap();
|
||||
CandleTensor::new(match bias {
|
||||
Some(bias) => conv
|
||||
.broadcast_add(&bias.tensor.unsqueeze(1).unwrap())
|
||||
.unwrap(),
|
||||
None => conv,
|
||||
})
|
||||
}
|
||||
|
||||
fn conv2d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: ConvOptions<2>,
|
||||
) -> FloatTensor<Self> {
|
||||
assert!(
|
||||
options.dilation[0] == options.dilation[1]
|
||||
&& options.padding[0] == options.padding[1]
|
||||
&& options.stride[0] == options.stride[1],
|
||||
"Candle does not support per dimension options in convolutions"
|
||||
);
|
||||
let conv = x
|
||||
.tensor
|
||||
.conv2d(
|
||||
&weight.tensor,
|
||||
options.padding[0],
|
||||
options.stride[0],
|
||||
options.dilation[0],
|
||||
options.groups,
|
||||
)
|
||||
.unwrap();
|
||||
CandleTensor::new(match bias {
|
||||
Some(bias) => conv
|
||||
.broadcast_add(
|
||||
&bias
|
||||
.tensor
|
||||
.unsqueeze(0)
|
||||
.unwrap()
|
||||
.unsqueeze(2)
|
||||
.unwrap()
|
||||
.unsqueeze(3)
|
||||
.unwrap(),
|
||||
)
|
||||
.unwrap(),
|
||||
None => conv,
|
||||
})
|
||||
}
|
||||
|
||||
fn deform_conv2d(
|
||||
x: FloatTensor<Self>,
|
||||
offset: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
mask: Option<FloatTensor<Self>>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: DeformConvOptions<2>,
|
||||
) -> FloatTensor<Self> {
|
||||
unimplemented!("Candle does not support deformable convolutions")
|
||||
}
|
||||
|
||||
fn deform_conv2d_backward(
|
||||
x: FloatTensor<Self>,
|
||||
offset: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
mask: Option<FloatTensor<Self>>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
options: DeformConvOptions<2>,
|
||||
) -> DeformConv2dBackward<Self> {
|
||||
unimplemented!("Candle does not support deformable convolutions")
|
||||
}
|
||||
|
||||
fn conv3d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: ConvOptions<3>,
|
||||
) -> FloatTensor<Self> {
|
||||
panic!("Candle does not support 3D convolutions");
|
||||
}
|
||||
|
||||
fn conv_transpose1d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: ConvTransposeOptions<1>,
|
||||
) -> FloatTensor<Self> {
|
||||
let conv_transpose = x
|
||||
.tensor
|
||||
.conv_transpose1d(
|
||||
&weight.tensor,
|
||||
options.padding[0],
|
||||
options.padding_out[0],
|
||||
options.stride[0],
|
||||
options.dilation[0],
|
||||
options.groups,
|
||||
)
|
||||
.unwrap();
|
||||
CandleTensor::new(match bias {
|
||||
Some(bias) => conv_transpose
|
||||
.broadcast_add(&bias.tensor.unsqueeze(0).unwrap().unsqueeze(2).unwrap())
|
||||
.unwrap(),
|
||||
None => conv_transpose,
|
||||
})
|
||||
}
|
||||
|
||||
fn conv_transpose2d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> FloatTensor<Self> {
|
||||
assert!(
|
||||
options.dilation[0] == options.dilation[1]
|
||||
&& options.padding[0] == options.padding[1]
|
||||
&& options.padding_out[0] == options.padding_out[1]
|
||||
&& options.stride[0] == options.stride[1],
|
||||
"Candle does not support per dimension options in transposed convolutions"
|
||||
);
|
||||
assert!(
|
||||
options.groups == 1,
|
||||
"Candle does not support groups in transposed convolutions"
|
||||
);
|
||||
let conv_transpose = x
|
||||
.tensor
|
||||
.conv_transpose2d(
|
||||
&weight.tensor,
|
||||
options.padding[0],
|
||||
options.padding_out[0],
|
||||
options.stride[0],
|
||||
options.dilation[0],
|
||||
)
|
||||
.unwrap();
|
||||
CandleTensor::new(match bias {
|
||||
Some(bias) => conv_transpose
|
||||
.broadcast_add(
|
||||
&bias
|
||||
.tensor
|
||||
.unsqueeze(0)
|
||||
.unwrap()
|
||||
.unsqueeze(2)
|
||||
.unwrap()
|
||||
.unsqueeze(3)
|
||||
.unwrap(),
|
||||
)
|
||||
.unwrap(),
|
||||
None => conv_transpose,
|
||||
})
|
||||
}
|
||||
|
||||
fn conv_transpose3d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: ConvTransposeOptions<3>,
|
||||
) -> FloatTensor<Self> {
|
||||
panic!("Candle does not support 3D transposed convolutions");
|
||||
}
|
||||
|
||||
fn avg_pool2d(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
ceil_mode: bool,
|
||||
) -> FloatTensor<Self> {
|
||||
assert!(
|
||||
padding[0] == 0 && padding[1] == 0,
|
||||
"Candle does not support padding in pooling"
|
||||
);
|
||||
assert!(
|
||||
count_include_pad,
|
||||
"Candle does not support excluding pad count in pooling"
|
||||
);
|
||||
assert!(!ceil_mode, "Candle does not support ceil_mode in pooling");
|
||||
CandleTensor::new(
|
||||
x.tensor
|
||||
.avg_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1]))
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn avg_pool2d_backward(
|
||||
x: FloatTensor<Self>,
|
||||
grad: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
_ceil_mode: bool,
|
||||
) -> FloatTensor<Self> {
|
||||
panic!("avg_pool2d_backward is not supported by Candle")
|
||||
}
|
||||
|
||||
fn max_pool2d(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
ceil_mode: bool,
|
||||
) -> FloatTensor<Self> {
|
||||
assert!(
|
||||
padding[0] == 0 && padding[1] == 0,
|
||||
"Candle does not support padding in pooling"
|
||||
);
|
||||
assert!(
|
||||
dilation[0] == 1 && dilation[1] == 1,
|
||||
"Candle does not support dilation in pooling"
|
||||
);
|
||||
assert!(!ceil_mode, "Candle does not support ceil_mode in pooling");
|
||||
CandleTensor::new(
|
||||
x.tensor
|
||||
.max_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1]))
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indices(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
_ceil_mode: bool,
|
||||
) -> MaxPool2dWithIndices<Candle<F, I>> {
|
||||
panic!("max_pool2d_with_indices is not supported by Candle")
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indices_backward(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
_ceil_mode: bool,
|
||||
output_grad: FloatTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
) -> MaxPool2dBackward<Candle<F, I>> {
|
||||
panic!("max_pool2d_with_indices_backward is not supported by Candle")
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
|
||||
panic!("adaptive_avg_pool2 is not supported by Candle")
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool2d_backward(
|
||||
x: FloatTensor<Self>,
|
||||
grad: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
panic!("adaptive_avg_pool2d_backward is not supported by Candle")
|
||||
}
|
||||
|
||||
fn interpolate(
|
||||
x: FloatTensor<Self>,
|
||||
output_size: [usize; 2],
|
||||
options: InterpolateOptions,
|
||||
) -> FloatTensor<Self> {
|
||||
let tensor = match options.mode {
|
||||
InterpolateMode::Nearest => x
|
||||
.tensor
|
||||
.upsample_nearest2d(output_size[0], output_size[1])
|
||||
.unwrap(),
|
||||
InterpolateMode::Bilinear => {
|
||||
panic!("bilinear interpolation is not supported by Candle")
|
||||
}
|
||||
InterpolateMode::Bicubic => {
|
||||
panic!("bicubic interpolation is not supported by Candle")
|
||||
}
|
||||
};
|
||||
|
||||
CandleTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn interpolate_backward(
|
||||
x: FloatTensor<Self>,
|
||||
grad: FloatTensor<Self>,
|
||||
output_size: [usize; 2],
|
||||
options: InterpolateOptions,
|
||||
) -> FloatTensor<Self> {
|
||||
panic!("interpolate_backward is not supported by Candle")
|
||||
}
|
||||
|
||||
fn attention(
|
||||
query: FloatTensor<Self>,
|
||||
key: FloatTensor<Self>,
|
||||
value: FloatTensor<Self>,
|
||||
mask: Option<burn_backend::tensor::BoolTensor<Self>>,
|
||||
attn_bias: Option<FloatTensor<Self>>,
|
||||
options: burn_backend::ops::AttentionModuleOptions,
|
||||
) -> FloatTensor<Self> {
|
||||
attention_fallback::<Self>(query, key, value, mask, attn_bias, options)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
use burn_backend::{
|
||||
Backend, DType, ExecutionError, Shape, Slice, TensorData,
|
||||
ops::QTensorOps,
|
||||
quantization::{QuantScheme, QuantizationParametersPrimitive},
|
||||
tensor::{Device, FloatTensor, IntTensor, QuantizedTensor},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
Candle,
|
||||
element::{FloatCandleElement, IntCandleElement},
|
||||
};
|
||||
|
||||
impl<F: FloatCandleElement, I: IntCandleElement> QTensorOps<Self> for Candle<F, I> {
|
||||
fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn quantize(
|
||||
_tensor: FloatTensor<Self>,
|
||||
_scheme: &QuantScheme,
|
||||
_qparams: QuantizationParametersPrimitive<Self>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn dequantize(_tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_device(_tensor: &QuantizedTensor<Self>) -> Device<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_to_device(
|
||||
_tensor: QuantizedTensor<Self>,
|
||||
_device: &Device<Self>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_reshape(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_swap_dims(
|
||||
_tensor: QuantizedTensor<Self>,
|
||||
_dim1: usize,
|
||||
_dim2: usize,
|
||||
) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_permute(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_gather(
|
||||
_dim: usize,
|
||||
_tensor: QuantizedTensor<Self>,
|
||||
_indices: IntTensor<Self>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_select(
|
||||
_tensor: QuantizedTensor<Self>,
|
||||
_dim: usize,
|
||||
_indices: IntTensor<Self>,
|
||||
) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_slice(_tensor: QuantizedTensor<Self>, _slices: &[Slice]) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,612 @@
|
||||
use std::borrow::Borrow;
|
||||
|
||||
use burn_backend::{
|
||||
DType, Distribution, ElementConversion, ExecutionError, FloatDType, Scalar, Shape, Slice,
|
||||
TensorData, bf16, f16,
|
||||
ops::FloatTensorOps,
|
||||
tensor::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor},
|
||||
};
|
||||
use candle_core::{Tensor, backend::BackendStorage, shape};
|
||||
|
||||
use crate::{
|
||||
Candle, CandleDevice, CandleTensor, IntoDType,
|
||||
element::{CandleElement, FloatCandleElement, IntCandleElement},
|
||||
};
|
||||
|
||||
use super::base::{cpu_random, expand, permute, sign, unfold};
|
||||
|
||||
impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle<F, I> {
|
||||
fn float_from_data(data: TensorData, device: &Device<Self>) -> CandleTensor {
|
||||
match data.dtype {
|
||||
DType::F64 => super::base::from_data::<f64>(data, device),
|
||||
DType::F32 => super::base::from_data::<f32>(data, device),
|
||||
DType::F16 => super::base::from_data::<f16>(data, device),
|
||||
DType::BF16 => super::base::from_data::<bf16>(data, device),
|
||||
_ => unimplemented!("Unsupported dtype for `float_from_data`"),
|
||||
}
|
||||
}
|
||||
|
||||
fn float_random(
|
||||
shape: Shape,
|
||||
distribution: Distribution,
|
||||
device: &Device<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
if let CandleDevice::Cpu = device {
|
||||
// Use our own seed since candle doesn't support it on CPU
|
||||
return Self::float_from_data(cpu_random::<F>(shape, distribution), device);
|
||||
}
|
||||
|
||||
let shape = shape.to_vec();
|
||||
let device = &(device.clone()).into();
|
||||
match distribution {
|
||||
Distribution::Default => CandleTensor::new(
|
||||
candle_core::Tensor::rand(0.elem::<F>(), 1.elem::<F>(), shape, device)
|
||||
.unwrap()
|
||||
.to_dtype(F::DTYPE)
|
||||
.unwrap(),
|
||||
),
|
||||
Distribution::Bernoulli(prob) => CandleTensor::new(
|
||||
candle_core::Tensor::rand(0.elem::<F>(), 1.elem::<F>(), shape.clone(), device)
|
||||
.unwrap()
|
||||
.to_dtype(F::DTYPE)
|
||||
.unwrap()
|
||||
.lt(&super::candle_utils::fill(prob, shape, F::DTYPE, device))
|
||||
.unwrap()
|
||||
.to_dtype(F::DTYPE)
|
||||
.unwrap(),
|
||||
),
|
||||
Distribution::Uniform(from, to) => CandleTensor::new(
|
||||
candle_core::Tensor::rand(from.elem::<F>(), to.elem::<F>(), shape, device).unwrap(),
|
||||
),
|
||||
Distribution::Normal(mean, std) => CandleTensor::new(
|
||||
candle_core::Tensor::randn(mean.elem::<F>(), std.elem::<F>(), shape, device)
|
||||
.unwrap(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
async fn float_into_data(tensor: CandleTensor) -> Result<TensorData, ExecutionError> {
|
||||
super::base::into_data(tensor)
|
||||
}
|
||||
|
||||
fn float_device(tensor: &CandleTensor) -> Device<Self> {
|
||||
super::base::device(tensor)
|
||||
}
|
||||
|
||||
fn float_to_device(tensor: CandleTensor, device: &Device<Self>) -> CandleTensor {
|
||||
super::base::to_device(tensor, device)
|
||||
}
|
||||
|
||||
fn float_into_int(tensor: CandleTensor) -> IntTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.to_dtype(I::DTYPE).unwrap())
|
||||
}
|
||||
|
||||
fn float_empty(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
|
||||
super::base::empty(shape, device, dtype.into_dtype())
|
||||
}
|
||||
|
||||
fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
CandleTensor::new(lhs.tensor.broadcast_add(&rhs.tensor).unwrap())
|
||||
}
|
||||
|
||||
fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
|
||||
CandleTensor::new((lhs.tensor + rhs.elem::<f64>()).unwrap())
|
||||
}
|
||||
|
||||
fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
CandleTensor::new(lhs.tensor.broadcast_sub(&rhs.tensor).unwrap())
|
||||
}
|
||||
|
||||
fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
|
||||
CandleTensor::new((lhs.tensor - rhs.elem::<f64>()).unwrap())
|
||||
}
|
||||
|
||||
fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
CandleTensor::new(lhs.tensor.broadcast_mul(&rhs.tensor).unwrap())
|
||||
}
|
||||
|
||||
fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
|
||||
CandleTensor::new((lhs.tensor * rhs.elem::<f64>()).unwrap())
|
||||
}
|
||||
|
||||
fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
CandleTensor::new(lhs.tensor.broadcast_div(&rhs.tensor).unwrap())
|
||||
}
|
||||
|
||||
fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
|
||||
CandleTensor::new((lhs.tensor / rhs.elem::<f64>()).unwrap())
|
||||
}
|
||||
|
||||
fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
CandleTensor::new(
|
||||
(lhs.tensor.clone()
|
||||
- lhs
|
||||
.tensor
|
||||
.broadcast_div(&rhs.tensor)
|
||||
.unwrap()
|
||||
.floor()
|
||||
.unwrap()
|
||||
.broadcast_mul(&rhs.tensor)
|
||||
.unwrap())
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
|
||||
// In PyTorch, remainder can also be defined as torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b
|
||||
let rhs_val = rhs.elem::<f64>();
|
||||
let division_result = (lhs.tensor.clone() / rhs_val).unwrap().floor().unwrap();
|
||||
let product = division_result * rhs_val;
|
||||
|
||||
CandleTensor::new((lhs.tensor - product).unwrap())
|
||||
}
|
||||
|
||||
fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
let lhs_contiguous = if !lhs.tensor.is_contiguous() {
|
||||
lhs.tensor.contiguous().unwrap()
|
||||
} else {
|
||||
lhs.tensor
|
||||
};
|
||||
let rhs_contiguous = if !rhs.tensor.is_contiguous() {
|
||||
rhs.tensor.contiguous().unwrap()
|
||||
} else {
|
||||
rhs.tensor
|
||||
};
|
||||
CandleTensor::new(lhs_contiguous.broadcast_matmul(&rhs_contiguous).unwrap())
|
||||
}
|
||||
|
||||
fn float_cross(
|
||||
lhs: FloatTensor<Self>,
|
||||
rhs: FloatTensor<Self>,
|
||||
dim: usize,
|
||||
) -> FloatTensor<Self> {
|
||||
super::base::cross(lhs, rhs, dim)
|
||||
}
|
||||
|
||||
fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
|
||||
super::base::swap_dims(tensor, dim1, dim2)
|
||||
}
|
||||
|
||||
fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
|
||||
super::base::reshape(tensor, shape)
|
||||
}
|
||||
|
||||
fn float_gather(
|
||||
dim: usize,
|
||||
tensor: FloatTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
let tensor = tensor.tensor.contiguous().unwrap();
|
||||
let indices = indices.tensor.contiguous().unwrap();
|
||||
CandleTensor::new(tensor.gather(&indices, dim).unwrap())
|
||||
}
|
||||
|
||||
fn float_scatter_add(
|
||||
dim: usize,
|
||||
tensor: FloatTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
value: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
CandleTensor::new(
|
||||
tensor
|
||||
.tensor
|
||||
.scatter_add(&indices.tensor, &value.tensor, dim)
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn float_select(
|
||||
tensor: FloatTensor<Self>,
|
||||
dim: usize,
|
||||
indices: IntTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.index_select(&indices.tensor, dim).unwrap())
|
||||
}
|
||||
|
||||
fn float_select_add(
|
||||
tensor: FloatTensor<Self>,
|
||||
dim: usize,
|
||||
indices: IntTensor<Self>,
|
||||
value: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
CandleTensor::new(
|
||||
tensor
|
||||
.tensor
|
||||
.index_add(&indices.tensor, &value.tensor, dim)
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {
|
||||
super::base::slice_with_steps(tensor, slices)
|
||||
}
|
||||
|
||||
fn float_slice_assign(
|
||||
tensor: FloatTensor<Self>,
|
||||
slices: &[Slice],
|
||||
value: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
super::base::slice_assign(tensor, slices, value)
|
||||
}
|
||||
|
||||
fn float_mask_where(
|
||||
tensor: FloatTensor<Self>,
|
||||
mask: BoolTensor<Self>,
|
||||
value: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
super::base::mask_where_broadcasted(tensor, mask, value)
|
||||
}
|
||||
|
||||
fn float_mask_fill(
|
||||
tensor: FloatTensor<Self>,
|
||||
mask: BoolTensor<Self>,
|
||||
value: Scalar,
|
||||
) -> FloatTensor<Self> {
|
||||
let value = super::candle_utils::fill_like::<F>(value.elem(), &tensor.tensor);
|
||||
super::base::mask_where_broadcasted(tensor, mask, CandleTensor::new(value))
|
||||
}
|
||||
|
||||
fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
|
||||
let (lhs_broadcast, rhs_broadcast) =
|
||||
super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
|
||||
CandleTensor::new(lhs_broadcast.eq(&rhs_broadcast).unwrap())
|
||||
}
|
||||
|
||||
fn float_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
CandleTensor::new(lhs.tensor.eq(rhs.elem::<F>()).unwrap())
|
||||
}
|
||||
|
||||
fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
|
||||
let (lhs_broadcast, rhs_broadcast) =
|
||||
super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
|
||||
CandleTensor::new(lhs_broadcast.gt(&rhs_broadcast).unwrap())
|
||||
}
|
||||
|
||||
fn float_greater_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
CandleTensor::new(
|
||||
lhs.tensor
|
||||
.gt(&super::candle_utils::fill_like::<F>(
|
||||
rhs.elem(),
|
||||
&lhs.tensor,
|
||||
))
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
|
||||
let (lhs_broadcast, rhs_broadcast) =
|
||||
super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
|
||||
CandleTensor::new(lhs_broadcast.ge(&rhs_broadcast).unwrap())
|
||||
}
|
||||
|
||||
fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
CandleTensor::new(
|
||||
lhs.tensor
|
||||
.ge(&super::candle_utils::fill_like::<F>(
|
||||
rhs.elem(),
|
||||
&lhs.tensor,
|
||||
))
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
|
||||
let (lhs_broadcast, rhs_broadcast) =
|
||||
super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
|
||||
CandleTensor::new(lhs_broadcast.lt(&rhs_broadcast).unwrap())
|
||||
}
|
||||
|
||||
fn float_lower_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
CandleTensor::new(
|
||||
lhs.tensor
|
||||
.lt(&super::candle_utils::fill_like::<F>(
|
||||
rhs.elem(),
|
||||
&lhs.tensor,
|
||||
))
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
|
||||
let (lhs_broadcast, rhs_broadcast) =
|
||||
super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();
|
||||
CandleTensor::new(lhs_broadcast.le(&rhs_broadcast).unwrap())
|
||||
}
|
||||
|
||||
fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
|
||||
CandleTensor::new(
|
||||
lhs.tensor
|
||||
.le(&super::candle_utils::fill_like::<F>(
|
||||
rhs.elem(),
|
||||
&lhs.tensor,
|
||||
))
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
let sum = tensor.tensor.sum_all().unwrap().to_scalar::<F>().unwrap();
|
||||
CandleTensor::from_data::<F>(
|
||||
TensorData::new([sum].into(), [1]),
|
||||
Self::float_device(&tensor),
|
||||
)
|
||||
}
|
||||
|
||||
fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap())
|
||||
}
|
||||
|
||||
fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.mean_keepdim(dim).unwrap())
|
||||
}
|
||||
|
||||
fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.cumsum(dim).unwrap())
|
||||
}
|
||||
|
||||
fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
|
||||
let result = super::utils::cumulative_with_op(&tensor.tensor, dim, |prev, curr| {
|
||||
prev.broadcast_mul(curr)
|
||||
});
|
||||
CandleTensor::new(result)
|
||||
}
|
||||
|
||||
fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
|
||||
let result = super::utils::cumulative_with_op(&tensor.tensor, dim, |prev, curr| {
|
||||
prev.broadcast_minimum(curr)
|
||||
});
|
||||
CandleTensor::new(result)
|
||||
}
|
||||
|
||||
fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
|
||||
let result = super::utils::cumulative_with_op(&tensor.tensor, dim, |prev, curr| {
|
||||
prev.broadcast_maximum(curr)
|
||||
});
|
||||
CandleTensor::new(result)
|
||||
}
|
||||
|
||||
fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.exp().unwrap())
|
||||
}
|
||||
|
||||
fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.log().unwrap())
|
||||
}
|
||||
|
||||
fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
CandleTensor::new((tensor.tensor + 1.).unwrap().log().unwrap())
|
||||
}
|
||||
|
||||
fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.powf(value.elem::<f64>()).unwrap())
|
||||
}
|
||||
|
||||
fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.sqrt().unwrap())
|
||||
}
|
||||
|
||||
fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.abs().unwrap())
|
||||
}
|
||||
|
||||
fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.cos().unwrap())
|
||||
}
|
||||
|
||||
fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
// cosh(x) = (e^x + e^(-x)) / 2
|
||||
let exp_x = tensor.tensor.exp().unwrap();
|
||||
CandleTensor::new(((exp_x.clone() + exp_x.recip().unwrap()).unwrap() / 2.0).unwrap())
|
||||
}
|
||||
|
||||
fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.sin().unwrap())
|
||||
}
|
||||
|
||||
fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
// sinh(x) = (e^x - e^(-x)) / 2
|
||||
let exp_x = tensor.tensor.exp().unwrap();
|
||||
CandleTensor::new(((exp_x.clone() - exp_x.recip().unwrap()).unwrap() / 2.0).unwrap())
|
||||
}
|
||||
|
||||
fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
CandleTensor::new((tensor.tensor.sin().unwrap() / tensor.tensor.cos().unwrap()).unwrap())
|
||||
}
|
||||
|
||||
fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.tanh().unwrap())
|
||||
}
|
||||
|
||||
fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
// acos(x) = PI/2 - asin(x)
|
||||
let neg_asin_x = Self::float_neg(Self::float_asin(tensor));
|
||||
Self::float_add_scalar(neg_asin_x, core::f64::consts::FRAC_PI_2.into())
|
||||
}
|
||||
|
||||
fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
// acosh(x) = ln(x + sqrt(x^2 - 1))
|
||||
let x_squared = Self::float_powi_scalar(tensor.clone(), 2.into());
|
||||
let x_sq_minus_one = Self::float_sub_scalar(x_squared, 1f64.into());
|
||||
let sqrt_term = Self::float_sqrt(x_sq_minus_one);
|
||||
Self::float_log(Self::float_add(tensor, sqrt_term))
|
||||
}
|
||||
|
||||
fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
// asin(x) = atan(x / sqrt(1 - x^2))
|
||||
let x_squared = Self::float_powi_scalar(tensor.clone(), 2.into());
|
||||
let one_minus_x_sq = Self::float_add_scalar(Self::float_neg(x_squared), 1f64.into());
|
||||
let sqrt_term = Self::float_sqrt(one_minus_x_sq);
|
||||
Self::float_atan(Self::float_div(tensor, sqrt_term))
|
||||
}
|
||||
|
||||
fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
// asinh(x) = ln(x + sqrt(x^2 + 1))
|
||||
let x_squared = Self::float_powi_scalar(tensor.clone(), 2.into());
|
||||
let x_sq_plus_one = Self::float_add_scalar(x_squared, 1f64.into());
|
||||
let sqrt_term = Self::float_sqrt(x_sq_plus_one);
|
||||
Self::float_log(Self::float_add(tensor, sqrt_term))
|
||||
}
|
||||
|
||||
fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
// atan(x) = asin(x / sqrt(1 + x^2))
|
||||
let x_squared = Self::float_powi_scalar(tensor.clone(), 2.into());
|
||||
let one_plus_x_sq = Self::float_add_scalar(x_squared, 1f64.into());
|
||||
let sqrt_term = Self::float_sqrt(one_plus_x_sq);
|
||||
Self::float_asin(Self::float_div(tensor, sqrt_term))
|
||||
}
|
||||
|
||||
fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
// atanh(x) = ln((1 + x) / (1 - x)) / 2
|
||||
let num = (1.0 + tensor.tensor.clone()).unwrap();
|
||||
let denom = (1.0 - tensor.tensor).unwrap();
|
||||
CandleTensor::new(((num / denom).unwrap().log().unwrap() / 2.0).unwrap())
|
||||
}
|
||||
|
||||
fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
// atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
|
||||
let x_squared = Self::float_powi_scalar(rhs.clone(), 2.into());
|
||||
let y_squared = Self::float_powi_scalar(lhs.clone(), 2.into());
|
||||
let r = Self::float_sqrt(Self::float_add(x_squared, y_squared));
|
||||
let ratio = Self::float_div(lhs, Self::float_add(r, rhs));
|
||||
Self::float_mul_scalar(Self::float_atan(ratio), 2f64.into())
|
||||
}
|
||||
|
||||
fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
let inner = |tensor: FloatTensor<Self>| -> candle_core::Result<FloatTensor<Self>> {
|
||||
// implements round_to_even for consistent behavior vs libtorch
|
||||
// https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/runtime/register_ops_utils.h#L65-L67
|
||||
|
||||
let floor_a = tensor.tensor.floor()?;
|
||||
let frac_part = tensor.tensor.sub(&floor_a)?;
|
||||
|
||||
let half = (candle_core::Tensor::ones_like(&tensor.tensor)? * 0.5)?;
|
||||
let mask_half = frac_part.eq(&half)?;
|
||||
let half_tensor = tensor.tensor.mul(&half)?;
|
||||
let rounded_half = half_tensor.round()?;
|
||||
let doubled =
|
||||
rounded_half.mul(&(candle_core::Tensor::ones_like(&tensor.tensor)? * 2.0)?)?;
|
||||
let standard_round = tensor.tensor.round()?;
|
||||
Ok(CandleTensor::new(
|
||||
mask_half.where_cond(&doubled, &standard_round)?,
|
||||
))
|
||||
};
|
||||
inner(tensor).unwrap()
|
||||
}
|
||||
|
||||
fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.floor().unwrap())
|
||||
}
|
||||
|
||||
fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.ceil().unwrap())
|
||||
}
|
||||
|
||||
fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
// truncate(x) = ⌊x⌋ if x ≥ 0, and ⌈x⌉ if x < 0
|
||||
// This preserves the sign of zero and handles all special cases correctly
|
||||
let is_negative = tensor.tensor.lt(0.0).unwrap();
|
||||
let floored = tensor.tensor.floor().unwrap();
|
||||
let ceiled = tensor.tensor.ceil().unwrap();
|
||||
CandleTensor::new(is_negative.where_cond(&ceiled, &floored).unwrap())
|
||||
}
|
||||
|
||||
fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.erf().unwrap())
|
||||
}
|
||||
|
||||
fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
|
||||
super::base::cat(tensors, dim)
|
||||
}
|
||||
|
||||
fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
CandleTensor::new(
|
||||
tensor
|
||||
.tensor
|
||||
.argmax_keepdim(dim)
|
||||
.unwrap()
|
||||
.to_dtype(I::DTYPE)
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
|
||||
CandleTensor::new(
|
||||
tensor
|
||||
.tensor
|
||||
.argmin_keepdim(dim)
|
||||
.unwrap()
|
||||
.to_dtype(I::DTYPE)
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.minimum(max.elem::<F>()).unwrap())
|
||||
}
|
||||
|
||||
fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.maximum(min.elem::<F>()).unwrap())
|
||||
}
|
||||
|
||||
fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
|
||||
CandleTensor::new(
|
||||
tensor
|
||||
.tensor
|
||||
.clamp(min.elem::<F>(), max.elem::<F>())
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
CandleTensor::new(tensor.tensor.recip().unwrap())
|
||||
}
|
||||
|
||||
fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
//broadcast_pow is in main but not yet published
|
||||
//note: probably replace once pow once 0.3.3 is out
|
||||
//see: https://github.com/huggingface/candle/pull/1583/files#diff-6319fa1e16dadc4c7b4e25698139703d93b70f30a1f8e2ac0999978e39efaa81R2594
|
||||
|
||||
CandleTensor::new(
|
||||
rhs.tensor
|
||||
.broadcast_mul(&lhs.tensor.log().unwrap())
|
||||
.unwrap()
|
||||
.exp()
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
|
||||
super::base::permute(tensor, axes)
|
||||
}
|
||||
|
||||
fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
|
||||
super::base::flip(tensor, axes)
|
||||
}
|
||||
|
||||
fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
|
||||
expand(tensor, shape)
|
||||
}
|
||||
|
||||
fn float_unfold(
|
||||
tensor: FloatTensor<Self>,
|
||||
dim: usize,
|
||||
size: usize,
|
||||
step: usize,
|
||||
) -> FloatTensor<Self> {
|
||||
unfold(tensor, dim, size, step)
|
||||
}
|
||||
|
||||
fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||
sign(tensor)
|
||||
}
|
||||
|
||||
fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
|
||||
let dtype = dtype.into_dtype();
|
||||
|
||||
if tensor.tensor.dtype() == dtype {
|
||||
tensor
|
||||
} else {
|
||||
CandleTensor::new(tensor.tensor.to_dtype(dtype).unwrap())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
use burn_backend::{
|
||||
Backend,
|
||||
ops::{TransactionOps, TransactionPrimitive},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
Candle,
|
||||
element::{FloatCandleElement, IntCandleElement},
|
||||
};
|
||||
|
||||
impl<F: FloatCandleElement, I: IntCandleElement> TransactionOps<Self> for Candle<F, I> {}
|
||||
@@ -0,0 +1,29 @@
|
||||
/// Helper function for cumulative operations in Candle backend
|
||||
///
|
||||
/// This function reduces code duplication for cumulative operations (cumprod, cummin, cummax)
|
||||
/// which all follow the same pattern of slicing, applying an operation, and concatenating.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The input tensor
|
||||
/// * `dim` - The dimension along which to apply the cumulative operation
|
||||
/// * `op` - A closure that takes two tensor references and produces a result tensor
|
||||
pub fn cumulative_with_op<F>(tensor: &candle_core::Tensor, dim: usize, op: F) -> candle_core::Tensor
|
||||
where
|
||||
F: Fn(&candle_core::Tensor, &candle_core::Tensor) -> candle_core::Result<candle_core::Tensor>,
|
||||
{
|
||||
let dim_size = tensor.dims()[dim];
|
||||
let mut slices = Vec::with_capacity(dim_size);
|
||||
|
||||
// First slice is the initial value
|
||||
slices.push(tensor.narrow(dim, 0, 1).unwrap());
|
||||
|
||||
// Apply cumulative operation
|
||||
for i in 1..dim_size {
|
||||
let curr = tensor.narrow(dim, i, 1).unwrap();
|
||||
let result = op(&slices[i - 1], &curr).unwrap();
|
||||
slices.push(result);
|
||||
}
|
||||
|
||||
candle_core::Tensor::cat(&slices, dim).unwrap()
|
||||
}
|
||||
@@ -0,0 +1,114 @@
|
||||
use burn_backend::{DType, FloatDType, IntDType, Shape, quantization::QuantScheme};
|
||||
use burn_backend::{Element, QTensorPrimitive, TensorData, TensorMetadata};
|
||||
|
||||
use crate::{CandleDevice, element::CandleElement};
|
||||
|
||||
/// A tensor that uses the candle backend.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CandleTensor {
|
||||
pub(crate) tensor: candle_core::Tensor,
|
||||
}
|
||||
|
||||
impl TensorMetadata for CandleTensor {
|
||||
fn dtype(&self) -> DType {
|
||||
match self.tensor.dtype() {
|
||||
candle_core::DType::U8 => DType::U8,
|
||||
candle_core::DType::U32 => DType::U32,
|
||||
candle_core::DType::I64 => DType::I64,
|
||||
candle_core::DType::BF16 => DType::BF16,
|
||||
candle_core::DType::F16 => DType::F16,
|
||||
candle_core::DType::F32 => DType::F32,
|
||||
candle_core::DType::F64 => DType::F64,
|
||||
candle_core::DType::I16 => DType::I16,
|
||||
candle_core::DType::I32 => DType::I32,
|
||||
other => todo!("{other:?} not yet supported"),
|
||||
}
|
||||
}
|
||||
|
||||
fn shape(&self) -> Shape {
|
||||
Shape::from(self.tensor.dims().to_vec())
|
||||
}
|
||||
|
||||
fn rank(&self) -> usize {
|
||||
self.tensor.dims().len()
|
||||
}
|
||||
}
|
||||
|
||||
impl QTensorPrimitive for CandleTensor {
|
||||
fn scheme(&self) -> &QuantScheme {
|
||||
unimplemented!("Quantization is not supported")
|
||||
}
|
||||
}
|
||||
|
||||
impl CandleTensor {
|
||||
/// Create a new tensor.
|
||||
pub fn new(tensor: candle_core::Tensor) -> Self {
|
||||
Self { tensor }
|
||||
}
|
||||
|
||||
/// Creates a new tensor from data and a device.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `data` - The tensor's data.
|
||||
/// * `device` - The device on which the tensor will be allocated.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new tensor.
|
||||
pub fn from_data<E: CandleElement>(data: TensorData, device: CandleDevice) -> Self {
|
||||
let candle_shape: candle_core::Shape = data.shape.clone().into();
|
||||
let tensor = candle_core::Tensor::from_slice(
|
||||
data.as_slice::<E>().unwrap(),
|
||||
candle_shape,
|
||||
&device.into(),
|
||||
);
|
||||
Self::new(tensor.unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait IntoDType {
|
||||
fn try_into_dtype(self) -> Result<candle_core::DType, candle_core::Error>;
|
||||
|
||||
fn into_dtype(self) -> candle_core::DType
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
self.try_into_dtype().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoDType for IntDType {
|
||||
fn try_into_dtype(self) -> Result<candle_core::DType, candle_core::Error> {
|
||||
let dtype: DType = self.into();
|
||||
dtype.try_into_dtype()
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoDType for FloatDType {
|
||||
fn try_into_dtype(self) -> Result<candle_core::DType, candle_core::Error> {
|
||||
let dtype: DType = self.into();
|
||||
dtype.try_into_dtype()
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoDType for DType {
|
||||
fn try_into_dtype(self) -> Result<candle_core::DType, candle_core::Error> {
|
||||
match self {
|
||||
DType::F64 => Ok(candle_core::DType::F64),
|
||||
DType::F32 => Ok(candle_core::DType::F32),
|
||||
DType::Flex32 => Ok(candle_core::DType::F32),
|
||||
DType::F16 => Ok(candle_core::DType::F16),
|
||||
DType::BF16 => Ok(candle_core::DType::BF16),
|
||||
DType::I64 => Ok(candle_core::DType::I64),
|
||||
DType::U32 => Ok(candle_core::DType::U32),
|
||||
DType::U8 => Ok(candle_core::DType::U8),
|
||||
DType::I16 => Ok(candle_core::DType::I16),
|
||||
DType::I32 => Ok(candle_core::DType::I32),
|
||||
// DType::Bool => Ok(candle_core::DType::U8),
|
||||
_ => Err(candle_core::Error::Msg(format!(
|
||||
"Unsupported dtype {self:?}"
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user