feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,57 @@
[package]
authors = ["Dilshod Tadjibaev (@antimora)"]
categories = []
description = "Core types and utilities shared across the Burn ecosystem."
documentation = "https://docs.rs/burn-std"
edition.workspace = true
keywords = []
license.workspace = true
name = "burn-std"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-std"
version.workspace = true
[lints]
workspace = true
[features]
cubecl = ["dep:cubecl"]
default = ["std", "cubecl-common/default"]
doc = ["default"]
std = ["cubecl-common/std", "num-traits/std"]
tracing = ["cubecl?/tracing", "cubecl-common/tracing"]
network = ["dep:indicatif", "dep:reqwest", "dep:tokio"]
[dependencies]
bytemuck = { workspace = true, features = ["extern_crate_alloc"] }
half = { workspace = true, features = ["bytemuck"] }
num-traits = { workspace = true }
serde = { workspace = true }
smallvec = { workspace = true, features = ["serde"] }
cubecl = { workspace = true, optional = true, default-features = false }
cubecl-common = { workspace = true, default-features = false, features = [
"serde",
"shared-bytes",
] }
cubecl-zspace = { workspace = true, default-features = false }
# Enable extra-platforms for portable-atomic support on targets without native atomics (e.g., thumbv6m)
# This is needed because cubecl-common's shared-bytes feature pulls in bytes
bytes = { workspace = true }
# Network downloader
indicatif = { workspace = true, optional = true }
reqwest = { workspace = true, optional = true }
tokio = { workspace = true, optional = true }
[dev-dependencies]
dashmap = { workspace = true }
# Enable extra-platforms for bytes on targets without native atomics (e.g., thumbv6m-none-eabi)
[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
bytes = { workspace = true, features = ["extra-platforms"] }
[package.metadata.docs.rs]
features = ["doc"]
rustdoc-args = ["--cfg", "docsrs"]

View File

@@ -0,0 +1 @@
../../LICENSE-APACHE

View File

@@ -0,0 +1 @@
../../LICENSE-MIT

View File

@@ -0,0 +1,7 @@
# Burn Standard Library
`burn-std` provides the core types and utilities shared across the Burn ecosystem.
It includes foundational definitions for shapes, indexing, and data types.
This crate supports both `std` and `no_std` environments and must compile with
`cargo build --no-default-features` as well.

View File

@@ -0,0 +1,69 @@
//! # Unique Identifiers
use crate::rand::gen_random;
/// Simple ID generator.
pub struct IdGenerator {}
impl IdGenerator {
/// Generates a new ID.
pub fn generate() -> u64 {
// Generate a random u64 (18,446,744,073,709,551,615 combinations)
let random_bytes: [u8; 8] = gen_random();
u64::from_le_bytes(random_bytes)
}
}
pub use cubecl_common::stream_id::StreamId;
#[cfg(test)]
mod tests {
use super::*;
use alloc::collections::BTreeSet;
#[cfg(feature = "std")]
use dashmap::DashSet; //Concurrent HashMap
#[cfg(feature = "std")]
use std::{sync::Arc, thread};
#[test]
fn uniqueness_test() {
const IDS_CNT: usize = 10_000;
let mut set: BTreeSet<u64> = BTreeSet::new();
for _i in 0..IDS_CNT {
assert!(set.insert(IdGenerator::generate()));
}
assert_eq!(set.len(), IDS_CNT);
}
#[cfg(feature = "std")]
#[test]
fn thread_safety_test() {
const NUM_THREADS: usize = 10;
const NUM_REPEATS: usize = 1_000;
const EXPECTED_TOTAL_IDS: usize = NUM_THREADS * NUM_REPEATS;
let set: Arc<DashSet<u64>> = Arc::new(DashSet::new());
let mut handles = vec![];
for _ in 0..NUM_THREADS {
let set = set.clone();
let handle = thread::spawn(move || {
for _i in 0..NUM_REPEATS {
assert!(set.insert(IdGenerator::generate()));
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
assert_eq!(set.len(), EXPECTED_TOTAL_IDS);
}
}

View File

@@ -0,0 +1,97 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
//! # Burn Standard Library
//!
//! This library contains core types and utilities shared across Burn, including shapes, indexing,
//! and data types.
extern crate alloc;
/// Id module contains types for unique identifiers.
pub mod id;
/// Tensor utilities.
pub mod tensor;
pub use tensor::*;
/// Common Errors.
pub use cubecl_zspace::errors::{self, *};
/// Network utilities.
#[cfg(feature = "network")]
pub mod network;
// Re-exported types
pub use cubecl_common::bytes::*;
pub use cubecl_common::*;
pub use half::{bf16, f16};
#[cfg(feature = "cubecl")]
pub use cubecl::flex32;
#[cfg(feature = "cubecl")]
mod cube {
use cubecl::ir::{ElemType, FloatKind, IntKind, StorageType, UIntKind};
use cubecl_common::quant::scheme::QuantScheme;
use crate::tensor::DType;
use crate::tensor::quantization::{QuantStore, QuantValue};
impl From<DType> for cubecl::ir::ElemType {
fn from(dtype: DType) -> Self {
match dtype {
DType::F64 => ElemType::Float(FloatKind::F64),
DType::F32 => ElemType::Float(FloatKind::F32),
DType::Flex32 => ElemType::Float(FloatKind::Flex32),
DType::F16 => ElemType::Float(FloatKind::F16),
DType::BF16 => ElemType::Float(FloatKind::BF16),
DType::I64 => ElemType::Int(IntKind::I64),
DType::I32 => ElemType::Int(IntKind::I32),
DType::I16 => ElemType::Int(IntKind::I16),
DType::I8 => ElemType::Int(IntKind::I8),
DType::U64 => ElemType::UInt(UIntKind::U64),
DType::U32 => ElemType::UInt(UIntKind::U32),
DType::U16 => ElemType::UInt(UIntKind::U16),
DType::U8 => ElemType::UInt(UIntKind::U8),
DType::Bool => ElemType::Bool,
DType::QFloat(scheme) => match scheme.store {
QuantStore::Native => match scheme.value {
QuantValue::Q8F | QuantValue::Q8S => Self::Int(IntKind::I8),
QuantValue::E4M3 => Self::Float(FloatKind::E4M3),
QuantValue::E5M2 => Self::Float(FloatKind::E5M2),
QuantValue::Q4F
| QuantValue::Q4S
| QuantValue::Q2F
| QuantValue::Q2S
| QuantValue::E2M1 => {
panic!("Can't store native sub-byte values")
}
},
QuantStore::PackedU32(_) => Self::UInt(UIntKind::U32),
QuantStore::PackedNative(_) => match scheme.value {
QuantValue::E2M1 => panic!("Can't store native sub-byte values"),
other => panic!("{other:?} doesn't support native packing"),
},
},
}
}
}
impl From<DType> for cubecl::ir::StorageType {
fn from(dtype: DType) -> cubecl::ir::StorageType {
match dtype {
DType::QFloat(QuantScheme {
store: QuantStore::PackedNative(_),
value: QuantValue::E2M1,
..
}) => StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2),
_ => {
let elem: ElemType = dtype.into();
elem.into()
}
}
}
}
}

View File

@@ -0,0 +1,57 @@
//! # Common Network Utilities
/// Network download utilities.
pub mod downloader {
use indicatif::{ProgressBar, ProgressState, ProgressStyle};
use reqwest::Client;
use std::io::Write;
/// Download the file at the specified url.
/// File download progress is reported with the help of a [progress bar](indicatif).
///
/// # Arguments
///
/// * `url` - The file URL to download.
/// * `message` - The message to display on the progress bar during download.
///
/// # Returns
///
/// A vector of bytes containing the downloaded file data.
#[tokio::main(flavor = "current_thread")]
pub async fn download_file_as_bytes(url: &str, message: &str) -> Vec<u8> {
// Get file from web
let mut response = Client::new().get(url).send().await.unwrap();
let total_size = response.content_length().unwrap();
// Pretty progress bar
let pb = ProgressBar::new(total_size);
let msg = message.to_owned();
pb.set_style(
ProgressStyle::with_template(
"{msg}\n {wide_bar:.cyan/blue} {bytes}/{total_bytes} ({eta})",
)
.unwrap()
.with_key(
"eta",
|state: &ProgressState, w: &mut dyn std::fmt::Write| {
write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()
},
)
.progress_chars(""),
);
pb.set_message(msg.clone());
// Read stream into bytes
let mut downloaded: u64 = 0;
let mut bytes: Vec<u8> = Vec::with_capacity(total_size as usize);
while let Some(chunk) = response.chunk().await.unwrap() {
let num_bytes = bytes.write(&chunk).unwrap();
let new = std::cmp::min(downloaded + (num_bytes as u64), total_size);
downloaded = new;
pb.set_position(new);
}
pb.finish_with_message(msg);
bytes
}
}

View File

@@ -0,0 +1,224 @@
//! Tensor data type.
use serde::{Deserialize, Serialize};
use crate::tensor::quantization::{QuantScheme, QuantStore, QuantValue};
use crate::{bf16, f16};
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum DType {
F64,
F32,
Flex32,
F16,
BF16,
I64,
I32,
I16,
I8,
U64,
U32,
U16,
U8,
Bool,
QFloat(QuantScheme),
}
#[cfg(feature = "cubecl")]
impl From<cubecl::ir::ElemType> for DType {
fn from(value: cubecl::ir::ElemType) -> Self {
match value {
cubecl::ir::ElemType::Float(float_kind) => match float_kind {
cubecl::ir::FloatKind::F16 => DType::F16,
cubecl::ir::FloatKind::BF16 => DType::BF16,
cubecl::ir::FloatKind::Flex32 => DType::Flex32,
cubecl::ir::FloatKind::F32 => DType::F32,
cubecl::ir::FloatKind::F64 => DType::F64,
cubecl::ir::FloatKind::TF32 => panic!("Not a valid DType for tensors."),
cubecl::ir::FloatKind::E2M1
| cubecl::ir::FloatKind::E2M3
| cubecl::ir::FloatKind::E3M2
| cubecl::ir::FloatKind::E4M3
| cubecl::ir::FloatKind::E5M2
| cubecl::ir::FloatKind::UE8M0 => {
unimplemented!("Not yet supported, will be used for quantization")
}
},
cubecl::ir::ElemType::Int(int_kind) => match int_kind {
cubecl::ir::IntKind::I8 => DType::I8,
cubecl::ir::IntKind::I16 => DType::I16,
cubecl::ir::IntKind::I32 => DType::I32,
cubecl::ir::IntKind::I64 => DType::I64,
},
cubecl::ir::ElemType::UInt(uint_kind) => match uint_kind {
cubecl::ir::UIntKind::U8 => DType::U8,
cubecl::ir::UIntKind::U16 => DType::U16,
cubecl::ir::UIntKind::U32 => DType::U32,
cubecl::ir::UIntKind::U64 => DType::U64,
},
_ => panic!("Not a valid DType for tensors."),
}
}
}
impl DType {
/// Returns the size of a type in bytes.
pub const fn size(&self) -> usize {
match self {
DType::F64 => core::mem::size_of::<f64>(),
DType::F32 => core::mem::size_of::<f32>(),
DType::Flex32 => core::mem::size_of::<f32>(),
DType::F16 => core::mem::size_of::<f16>(),
DType::BF16 => core::mem::size_of::<bf16>(),
DType::I64 => core::mem::size_of::<i64>(),
DType::I32 => core::mem::size_of::<i32>(),
DType::I16 => core::mem::size_of::<i16>(),
DType::I8 => core::mem::size_of::<i8>(),
DType::U64 => core::mem::size_of::<u64>(),
DType::U32 => core::mem::size_of::<u32>(),
DType::U16 => core::mem::size_of::<u16>(),
DType::U8 => core::mem::size_of::<u8>(),
DType::Bool => core::mem::size_of::<bool>(),
DType::QFloat(scheme) => match scheme.store {
QuantStore::Native => match scheme.value {
QuantValue::Q8F | QuantValue::Q8S => core::mem::size_of::<i8>(),
// e2m1 native is automatically packed by the kernels, so the actual storage is
// 8 bits wide.
QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
core::mem::size_of::<u8>()
}
QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
// Sub-byte values have fractional size
0
}
},
QuantStore::PackedU32(_) => core::mem::size_of::<u32>(),
QuantStore::PackedNative(_) => match scheme.value {
QuantValue::E2M1 => core::mem::size_of::<u8>(),
_ => 0,
},
},
}
}
/// Returns true if the data type is a floating point type.
pub fn is_float(&self) -> bool {
matches!(
self,
DType::F64 | DType::F32 | DType::Flex32 | DType::F16 | DType::BF16
)
}
/// Returns true if the data type is a signed integer type.
pub fn is_int(&self) -> bool {
matches!(self, DType::I64 | DType::I32 | DType::I16 | DType::I8)
}
/// Returns true if the data type is an unsigned integer type.
pub fn is_uint(&self) -> bool {
matches!(self, DType::U64 | DType::U32 | DType::U16 | DType::U8)
}
/// Returns true if the data type is a boolean type
pub fn is_bool(&self) -> bool {
matches!(self, DType::Bool)
}
/// Returns the data type name.
pub fn name(&self) -> &'static str {
match self {
DType::F64 => "f64",
DType::F32 => "f32",
DType::Flex32 => "flex32",
DType::F16 => "f16",
DType::BF16 => "bf16",
DType::I64 => "i64",
DType::I32 => "i32",
DType::I16 => "i16",
DType::I8 => "i8",
DType::U64 => "u64",
DType::U32 => "u32",
DType::U16 => "u16",
DType::U8 => "u8",
DType::Bool => "bool",
DType::QFloat(_) => "qfloat",
}
}
}
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum FloatDType {
F64,
F32,
Flex32,
F16,
BF16,
}
impl From<DType> for FloatDType {
fn from(value: DType) -> Self {
match value {
DType::F64 => FloatDType::F64,
DType::F32 => FloatDType::F32,
DType::Flex32 => FloatDType::Flex32,
DType::F16 => FloatDType::F16,
DType::BF16 => FloatDType::BF16,
_ => panic!("Expected float data type, got {value:?}"),
}
}
}
impl From<FloatDType> for DType {
fn from(value: FloatDType) -> Self {
match value {
FloatDType::F64 => DType::F64,
FloatDType::F32 => DType::F32,
FloatDType::Flex32 => DType::Flex32,
FloatDType::F16 => DType::F16,
FloatDType::BF16 => DType::BF16,
}
}
}
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum IntDType {
I64,
I32,
I16,
I8,
U64,
U32,
U16,
U8,
}
impl From<DType> for IntDType {
fn from(value: DType) -> Self {
match value {
DType::I64 => IntDType::I64,
DType::I32 => IntDType::I32,
DType::I16 => IntDType::I16,
DType::I8 => IntDType::I8,
DType::U64 => IntDType::U64,
DType::U32 => IntDType::U32,
DType::U16 => IntDType::U16,
DType::U8 => IntDType::U8,
_ => panic!("Expected int data type, got {value:?}"),
}
}
}
impl From<IntDType> for DType {
fn from(value: IntDType) -> Self {
match value {
IntDType::I64 => DType::I64,
IntDType::I32 => DType::I32,
IntDType::I16 => DType::I16,
IntDType::I8 => DType::I8,
IntDType::U64 => DType::U64,
IntDType::U32 => DType::U32,
IntDType::U16 => DType::U16,
IntDType::U8 => DType::U8,
}
}
}

View File

@@ -0,0 +1,221 @@
pub mod dtype;
pub mod quantization;
pub mod shape;
pub mod slice;
pub use dtype::*;
pub use quantization::*;
pub use shape::*;
pub use slice::*;
pub use cubecl_zspace::indexing::{self, *};
pub use cubecl_zspace::{Strides, metadata::Metadata, strides};
/// Check if the current tensor is contiguous.
///
/// A tensor is considered contiguous if its elements are stored in memory
/// such that the stride at position `k` is equal to the product of the shapes
/// of all dimensions greater than `k`.
///
/// This means that strides increase as you move from the rightmost to the leftmost dimension.
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
if shape.is_empty() {
return true;
}
for (&expected, &stride) in contiguous_strides(shape).iter().zip(strides) {
if expected != stride {
return false;
}
}
true
}
/// Computes the strides for a contiguous tensor with the given shape.
///
/// In a contiguous row-major tensor, the stride for each dimension
/// equals the product of all dimension sizes to its right.
pub fn contiguous_strides(shape: &[usize]) -> Strides {
let mut strides = strides![0; shape.len()];
let mut current = 1;
for (i, &dim) in shape.iter().enumerate().rev() {
strides[i] = current;
current *= dim;
}
strides
}
/// The action to take for a reshape operation.
#[derive(Debug)]
pub enum ReshapeAction {
/// Updating the strides is sufficient to handle the reshape.
UpdateStrides {
/// The new strides.
strides: Strides,
},
/// The strides are not compatible, we should recompute the buffer.
Recompute,
/// The strides are already correct.
NoChange,
}
/// The reshape kind.
#[derive(Debug)]
pub enum ReshapeAnalysis {
/// Original tensor is contiguous, can update the strides.
IsContiguous,
/// Original tensor is highly permutated, can't update the strides.
HighlyPermuted,
/// Only batch dimensions are added, can update the strides.
Broadcasted,
/// Dimensions are only split, can update the strides.
Split,
/// Original tensor is bigger than output shape.
SmallerRank,
/// New shape is the same.
NoChange,
}
impl ReshapeAnalysis {
/// Returns the proper action to take for the current analysis.
fn action(self, shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {
match self {
ReshapeAnalysis::IsContiguous => ReshapeAction::UpdateStrides {
strides: contiguous_strides(shape_new),
},
ReshapeAnalysis::NoChange => ReshapeAction::NoChange,
ReshapeAnalysis::HighlyPermuted | ReshapeAnalysis::SmallerRank => {
ReshapeAction::Recompute
}
ReshapeAnalysis::Broadcasted => {
let shape_rank = shape.len();
let shape_new_rank = shape_new.len();
let n_new_batch = shape_new_rank - shape_rank;
let num_elems = shape.iter().product::<usize>();
let strides_new = broadcast_strides(n_new_batch, shape_rank, num_elems, strides);
ReshapeAction::UpdateStrides {
strides: strides_new,
}
}
ReshapeAnalysis::Split => {
let strides_new = split_strides(shape, strides, shape_new);
ReshapeAction::UpdateStrides {
strides: strides_new,
}
}
}
}
}
/// Returns the proper action to take when reshaping a tensor.
pub fn reshape_action(shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {
reshape_analysis(shape, Some(strides), shape_new).action(shape, strides, shape_new)
}
/// Calculate the new strides given added batch dimensions.
pub fn broadcast_strides(
n_new_batch: usize,
rank_prev: usize,
num_elems: usize,
strides: &[usize],
) -> Strides {
let mut strides_new = strides![num_elems; rank_prev + n_new_batch];
for (i, s) in strides.iter().enumerate() {
strides_new[i + n_new_batch] = *s;
}
strides_new
}
/// Calculate the new strides given added split dimensions.
pub fn split_strides(shape: &[usize], strides: &[usize], shape_new: &[usize]) -> Strides {
let mut strides_new = strides![1; shape_new.len()];
let mut old_idx = shape.len() - 1;
let mut current_stride = strides[old_idx];
let mut dim_prod = 1;
for (i, dim) in shape_new.iter().enumerate().rev() {
dim_prod *= *dim;
strides_new[i] = current_stride;
if *dim == 1 {
continue;
} else if dim_prod == shape[old_idx] {
old_idx = old_idx.saturating_sub(1);
current_stride = strides[old_idx];
dim_prod = 1;
} else {
current_stride *= *dim;
}
}
strides_new
}
/// Returns the analysis of a reshape operation.
pub fn reshape_analysis(
shape: &[usize],
strides: Option<&[usize]>,
shape_new: &[usize],
) -> ReshapeAnalysis {
let shape_rank = shape.len();
let shape_new_rank = shape_new.len();
let is_contiguous = match strides {
Some(strides) => is_contiguous(shape, strides),
None => false,
};
if is_contiguous {
return ReshapeAnalysis::IsContiguous;
}
if shape_new_rank < shape_rank {
return ReshapeAnalysis::SmallerRank;
}
let n_new_batch = shape_new_rank - shape_rank;
match n_new_batch > 0 {
true => {
if shape == &shape_new[n_new_batch..shape_new_rank]
&& shape_new[0..n_new_batch].iter().all(|it| *it == 1)
{
return ReshapeAnalysis::Broadcasted;
} else {
let mut dim_prod = 1;
let mut old_idx = 0;
for dim in shape_new {
dim_prod *= *dim;
// We need to ignore unit dims because they don't affect analysis and break
// things because they match the default `dim_prod`. If we don't do this,
// reshapes like [2, 3] to [2, 3, 1] will panic from out of bounds access.
if *dim == 1 {
continue;
} else if dim_prod == shape[old_idx] {
dim_prod = 1;
old_idx += 1;
} else if dim_prod > shape[old_idx] {
return ReshapeAnalysis::HighlyPermuted;
}
}
return ReshapeAnalysis::Split;
}
}
false => {
if shape == shape_new {
return ReshapeAnalysis::NoChange;
}
}
};
ReshapeAnalysis::HighlyPermuted
}

View File

@@ -0,0 +1,393 @@
//! Quantization data representation.
// Re-exported types
pub use cubecl_common::quant::scheme::{
BlockSize, QuantLevel, QuantMode, QuantParam, QuantScheme, QuantStore, QuantValue,
};
/// Alignment (in bytes) for quantization parameters in serialized tensor data.
///
/// NOTE: This is currently f32-based since scales were originally always f32.
/// With `QuantParam` now supporting different precisions (F16, BF16, etc.),
/// this alignment may need to be revisited in the future.
pub const QPARAM_ALIGN: usize = core::mem::align_of::<f32>();
use alloc::vec::Vec;
use core::any::TypeId;
use num_traits::PrimInt;
use serde::{Deserialize, Serialize};
use crate::{DType, Metadata, Shape, bytes::Bytes};
#[derive(
Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
)]
/// The precision of accumulating elements.
pub enum QuantAcc {
/// Full precision.
#[default]
F32,
/// Half precision.
F16,
/// bfloat16 precision.
BF16,
}
/// Specify if the output of an operation is quantized using the scheme of the input
/// or returned unquantized.
#[derive(
Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
)]
pub enum QuantPropagation {
/// The output is quantized using the scheme of the input.
Propagate,
/// The output is not quantized.
#[default]
Inhibit,
}
/// The quantization tensor data parameters.
#[derive(Clone, Debug)]
pub struct QParams<S> {
/// The scaling factor.
pub scales: S,
}
/// A quantization parameter tensor descriptor.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QParamTensor {
/// Start of the tensor in the buffer
pub offset_start: usize,
/// Offset of tensor end from the end of the buffer
pub offset_end: usize,
/// Metadata of the tensor
pub metadata: Metadata,
/// Data type of the tensor
pub dtype: DType,
}
/// Calculate the shape of the quantization parameters for a given tensor and level
pub fn params_shape(data_shape: &Shape, level: QuantLevel) -> Shape {
match level {
QuantLevel::Tensor => Shape::new([1]),
QuantLevel::Block(block_size) => {
let mut params_shape = data_shape.clone();
let block_size = block_size.to_dim_vec(data_shape.num_dims());
for (shape, block_size) in params_shape.iter_mut().zip(block_size) {
*shape = (*shape).div_ceil(block_size as usize);
}
params_shape
}
}
}
/// Quantized data bytes representation.
///
/// # Notes
/// 1) The quantized values are packed into 32-bit unsigned integers. For example, int8
/// quantized values pack 4 grouped values into a single `u32`. When unpacking these values,
/// we make sure to retrieve only the meaningful values (and ignore the alignment padding).
/// 2) Quantization parameters are appended to the tensor data.
/// As such, the last bytes always correspond to the scale parameter.
/// If the quantization scheme includes an offset (zero-point) parameter, it is next to last.
pub struct QuantizedBytes {
/// The quantized values and quantization parameters represented as bytes.
pub bytes: Bytes,
/// The quantization scheme.
pub scheme: QuantScheme,
/// The number of quantized elements.
pub num_elements: usize,
}
impl QuantizedBytes {
/// Creates a new quantized bytes representation.
pub fn new<E: bytemuck::CheckedBitPattern + bytemuck::NoUninit>(
value: Vec<E>,
scheme: QuantScheme,
scales: &[f32],
) -> Self {
let num_elements = value.len();
// Only used for 8-bit quantization data comparison in tests
if TypeId::of::<E>() != TypeId::of::<i8>() {
panic!("Invalid quantized type");
}
// Re-interpret `Vec<E>` as `Vec<i8>` with `Vec::from_raw_parts`
let i8s: Vec<i8> = bytemuck::allocation::cast_vec(value);
let mut bytes = Bytes::from_elems(i8s);
match scheme.level {
QuantLevel::Tensor => {
let scale_bytes = bytemuck::bytes_of(&scales[0]);
bytes.extend_from_byte_slice_aligned(scale_bytes, QPARAM_ALIGN);
}
QuantLevel::Block(_block_size) => {
let mut scale_bytes = Vec::with_capacity(size_of_val(scales));
for scale in scales {
scale_bytes.extend_from_slice(bytemuck::bytes_of(scale));
}
bytes.extend_from_byte_slice_aligned(scale_bytes.as_slice(), QPARAM_ALIGN);
}
}
Self {
bytes,
scheme,
num_elements,
}
}
/// Returns the int8 quantized values with the quantization parameters.
pub fn into_vec_i8(self) -> (Vec<i8>, QParams<Vec<f32>>) {
let (values, (qparams, num_params)) = self.split_values_off();
// Quantization parameters are added at the end of the tensor data.
// As such, the last bytes always correspond to the scale parameter(s).
// For example, per-block quantization can have multiple parameters for a single tensor:
// [scale, scale, scale, ...]
let scale_size = core::mem::size_of::<f32>(); // scale is stored as f32
let qparams_bytes: &[u8] = bytemuck::cast_slice(&qparams);
let total_bytes = qparams_bytes.len();
let scales_size = scale_size * num_params;
let scales = bytemuck::cast_slice(&qparams_bytes[total_bytes - scales_size..]).to_vec();
(values, QParams { scales })
}
fn split_i8_values(self, num_params: usize) -> (Vec<i8>, Vec<u32>) {
let mut values = read_bytes_to_i8(self.bytes);
let scale_size = num_params * size_of::<f32>();
let values_end = values.len() - scale_size;
let qparams = values.split_off(values_end);
let qparams = if (qparams.as_ptr() as usize).is_multiple_of(4) {
let mut qparams = core::mem::ManuallyDrop::new(qparams);
unsafe {
Vec::<u32>::from_raw_parts(
qparams.as_mut_ptr() as _,
qparams.len() / 4,
qparams.capacity() / 4,
)
}
} else {
#[cfg(target_endian = "little")]
{
// SAFETY: quantized bytes representation is created from packed u32 values in little endian
bytemuck::cast_vec(qparams)
}
#[cfg(target_endian = "big")]
{
crate::quantization::pack_i8s_to_u32s(bytemuck::cast_vec(qparams))
}
};
(values, qparams)
}
/// Splits the quantized values of the tensor from the quantization parameters.
///
/// Returns the values in i8 and a newly allocated vector containing the quantization parameters.
fn split_values_off(self) -> (Vec<i8>, (Vec<u32>, usize)) {
let num_params = match self.scheme.level {
QuantLevel::Tensor => 1,
QuantLevel::Block(block_size) => self.num_elements / block_size.num_elements(),
};
if let QuantStore::PackedU32(packed_dim) = self.scheme.store {
assert_eq!(
packed_dim, 0,
"Packing must be on innermost dimension for splitting off values"
);
}
let (values, qparams) = match self.scheme.store {
QuantStore::Native => self.split_i8_values(num_params),
QuantStore::PackedU32(_) => match self.scheme.value {
QuantValue::Q8F | QuantValue::Q8S => self.split_i8_values(num_params),
QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
let mut values = self.bytes.try_into_vec::<u32>().unwrap();
let scale_size = num_params; // size of f32 same as u32
let values_end = values.len() - scale_size;
let qparams = values.split_off(values_end);
// Sub-byte values are unpacked as i8s for value equality tests
let values = unpack_q_to_i8s(&values, self.num_elements, &self.scheme.value);
(values, qparams)
}
QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
unimplemented!("Not yet supported")
}
},
QuantStore::PackedNative(_) => unimplemented!("Not yet supported"),
};
(values, (qparams, num_params))
}
}
fn read_bytes_to_i8(bytes: Bytes) -> Vec<i8> {
match bytes.try_into_vec::<i8>() {
Ok(val) => val,
// Safety,
//
// `Vec<u8>` can be Re-interpreted as `Vec<i8>` since they share the same alignment.
Err(bytes) => unsafe { core::mem::transmute::<Vec<u8>, Vec<i8>>(bytes.to_vec()) },
}
}
/// Pack signed 8-bit integer values into a sequence of unsigned 32-bit integers.
pub fn pack_i8s_to_u32s(values: Vec<i8>) -> Vec<u32> {
// Shift and combine groups of four 8-bit values into a u32.
// Same as doing this:
// let result = (d_u8 & 0xFF) << 24 | (c_u8 & 0xFF) << 16 | (b_u8 & 0xFF) << 8 | (a_u8 & 0xFF);
#[cfg(target_endian = "big")]
{
values
.chunks(4)
.map(|x| {
x.iter()
.enumerate()
.fold(0u32, |acc, (i, x)| acc | (*x as u32 & 0xFF) << (i * 8))
})
.collect()
}
// The order of bytes in little endian matches the above description, we just need to
// handle padding when the number of values is not a factor of 4
#[cfg(target_endian = "little")]
{
let mut values = values;
let remainder = values.len() % 4;
if remainder != 0 {
// Pad with zeros
values.extend(core::iter::repeat_n(0, 4 - remainder));
}
let len = values.len() / 4;
let capacity = values.capacity() / 4;
// Pre-forget the old vec and re-interpret as u32
let mut values = core::mem::ManuallyDrop::new(values);
let ptr = values.as_mut_ptr() as *mut u32;
unsafe { Vec::from_raw_parts(ptr, len, capacity) }
}
}
/// Unpack integer values into a sequence of signed 8-bit integers.
pub(crate) fn unpack_q_to_i8s<Q: PrimInt>(
values: &[Q],
numel: usize,
value: &QuantValue,
) -> Vec<i8> {
let size_store = size_of::<Q>() * 8;
let size_quant = value.size_bits();
let num_quants = size_store / size_quant;
let mask = Q::from((1 << size_quant) - 1).unwrap();
let sign_shift = 8 - size_quant; // sign extension for sub-byte values
values
.iter()
.enumerate()
.flat_map(|(i, &packed)| {
// A single u32 could contain less than four 8-bit values...
let n = core::cmp::min(num_quants, numel - i * num_quants);
// Extract each 8-bit segment from u32 and cast back to i8
// Same as doing this (when 4 values are fully packed):
// let a = (packed & 0xFF) as i8;
// let b = ((packed >> 8) & 0xFF) as i8;
// let c = ((packed >> 16) & 0xFF) as i8;
// let d = ((packed >> 24) & 0xFF) as i8;
(0..n).map(move |i| {
let raw = (packed >> (i * size_quant) & mask).to_u8().unwrap();
((raw << sign_shift) as i8) >> sign_shift
})
})
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
use alloc::vec;
#[test]
fn should_pack_i8s_to_u32() {
let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127]);
assert_eq!(packed, vec![2147287680]);
}
#[test]
fn should_pack_i8s_to_u32_padded() {
let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55]);
let packed_padded = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55, 0, 0, 0]);
assert_eq!(packed, vec![2147287680, 55]);
assert_eq!(packed, packed_padded);
}
#[test]
fn should_unpack_u32s_to_i8s() {
let unpacked = unpack_q_to_i8s(&[2147287680u32], 4, &QuantValue::Q8S);
assert_eq!(unpacked, vec![-128, 2, -3, 127]);
}
#[test]
fn should_unpack_u32s_to_i8s_padded() {
let unpacked = unpack_q_to_i8s(&[55u32], 1, &QuantValue::Q8S);
assert_eq!(unpacked, vec![55]);
}
#[test]
fn should_unpack_u32s_to_i8s_arange() {
let unpacked = unpack_q_to_i8s(
&[
0u32, 286331136, 286331153, 572657937, 572662306, 857874978, 858993459, 858993459,
1145324612, 1145324612, 1431655748, 1431655765, 1717982549, 1717986918, 2003199590,
2004318071,
],
128,
&QuantValue::Q4S,
);
assert_eq!(
unpacked,
vec![
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5,
5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
]
);
}
#[test]
fn should_pack_unpack_quantization_parameters_per_tensor_symmetric() {
// Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
let scale = 0.03937008;
let values = vec![0i8, 25, 51, 76, 102, 127];
let q_bytes = QuantizedBytes::new(
values.clone(),
QuantScheme::default()
.with_value(QuantValue::Q8S)
.with_store(QuantStore::Native),
&[scale],
);
let (q_values, qparams) = q_bytes.into_vec_i8();
assert_eq!(qparams.scales, vec![scale]);
assert_eq!(q_values, values);
}
}

View File

@@ -0,0 +1,271 @@
//! Tensor shape definition.
use super::{Slice, SliceArg};
use alloc::vec::Vec;
use core::ops::Range;
pub use crate::errors::ExpressionError;
pub use cubecl_zspace::{MetadataError, Shape, calculate_matmul_output, shape};
/// Slice-related ops on [`Shape`]
pub trait SliceOps: Sized {
/// Convert shape dimensions to full covering ranges (0..dim) for each dimension.
fn into_ranges(self) -> Vec<Range<usize>>;
/// Converts slice arguments into an array of slice specifications for the shape.
///
/// This method returns an array of `Slice` objects that can be used for slicing operations.
/// The slices are clamped to the shape's dimensions. Similar to `into_ranges()`, but
/// allows custom slice specifications instead of full ranges.
/// For creating complex slice specifications, use the [`s!`] macro.
///
/// # Arguments
///
/// * `slices` - An array of slice specifications, where each element can be:
/// - A range (e.g., `2..5`)
/// - An index
/// - A `Slice` object
/// - The output of the [`s!`] macro for advanced slicing
///
/// # Behavior
///
/// - Supports partial and full slicing in any number of dimensions.
/// - Missing ranges are treated as full slices if D > D2.
/// - Handles negative indices by wrapping around from the end of the dimension.
/// - Clamps ranges to the shape's dimensions if they exceed the bounds.
///
/// # Returns
///
/// An array of `Slice` objects corresponding to the provided slice specifications,
/// clamped to the shape's actual dimensions.
///
/// # Examples
///
/// ```rust
/// use burn_std::{Shape, Slice, s, SliceOps};
///
/// fn example() {
/// // 1D slicing
/// let slices = Shape::new([4]).into_slices(1..4);
/// assert_eq!(slices[0].to_range(4), 1..3);
///
/// // 2D slicing
/// let slices = Shape::new([3, 4]).into_slices(s![1..4, 0..2]);
/// assert_eq!(slices[0].to_range(3), 1..3);
/// assert_eq!(slices[1].to_range(4), 0..2);
///
/// // Using negative indices
/// let slices = Shape::new([3]).into_slices(..-2);
/// assert_eq!(slices[0].to_range(3), 0..1);
///
/// // Using the slice macro to select different ranges
/// let slices = Shape::new([2, 3, 4]).into_slices(s![.., 1..-1]);
/// assert_eq!(slices[0].to_range(2), 0..2);
/// assert_eq!(slices[1].to_range(3), 1..2);
/// }
/// ```
///
/// # See Also
///
/// - [`s!`] - The recommended macro for creating slice specifications
/// - [`Shape::into_ranges`] - Convert to full covering ranges
///
/// [`s!`]: crate::s!
fn into_slices<S>(self, slices: S) -> Vec<Slice>
where
S: SliceArg;
/// Compute the output shape from the given slices.
fn slice(self, slices: &[Slice]) -> Result<Self, MetadataError>;
}
impl SliceOps for Shape {
fn into_ranges(self) -> Vec<Range<usize>> {
self.iter().map(|&d| 0..d).collect()
}
fn into_slices<S>(self, slices: S) -> Vec<Slice>
where
S: SliceArg,
{
slices.into_slices(&self)
}
fn slice(mut self, slices: &[Slice]) -> Result<Self, MetadataError> {
if slices.len() > self.rank() {
return Err(MetadataError::RankMismatch {
left: self.rank(),
right: slices.len(),
});
}
slices
.iter()
.zip(self.iter_mut())
.for_each(|(slice, dim_size)| *dim_size = slice.output_size(*dim_size));
Ok(self)
}
}
#[cfg(test)]
#[allow(clippy::identity_op, reason = "useful for clarity")]
mod tests {
use super::*;
use crate::s;
use alloc::vec;
#[test]
fn test_into_ranges() {
let dims = [2, 3, 4, 5];
let shape = Shape::new(dims);
assert_eq!(shape.into_ranges(), vec![0..2, 0..3, 0..4, 0..5]);
}
#[allow(clippy::single_range_in_vec_init)]
#[test]
fn test_into_slices() {
let slices = Shape::new([3]).into_slices(1..4);
assert_eq!(slices[0].to_range(3), 1..3);
let slices = Shape::new([3, 4]).into_slices(s![1..4, 0..2]);
assert_eq!(slices[0].to_range(3), 1..3);
assert_eq!(slices[1].to_range(4), 0..2);
let slices = Shape::new([3]).into_slices(..-2);
assert_eq!(slices[0].to_range(3), 0..1);
let slices = Shape::new([2, 3, 4]).into_slices(s![.., 1..-1]);
assert_eq!(slices[0].to_range(2), 0..2);
assert_eq!(slices[1].to_range(3), 1..2);
let slices = Shape::new([2, 3, 4]).into_slices(s![..20, 2]);
assert_eq!(slices[0].to_range(2), 0..2);
assert_eq!(slices[1].to_range(3), 2..3);
}
#[test]
fn test_shape_as_slice() {
let dims = [2, 3, 4, 5];
let shape = Shape::new(dims);
assert_eq!(shape.as_slice(), dims.as_slice());
// Deref coercion
let shape_slice: &[usize] = &shape;
assert_eq!(shape_slice, *&[2, 3, 4, 5]);
}
#[test]
fn test_shape_as_mut_slice() {
let mut dims = [2, 3, 4, 5];
let mut shape = Shape::new(dims);
let shape_mut = shape.as_mut_slice();
assert_eq!(shape_mut, dims.as_mut_slice());
shape_mut[1] = 6;
assert_eq!(shape_mut, &[2, 6, 4, 5]);
let mut shape = Shape::new(dims);
let shape = &mut shape[..];
shape[1] = 6;
assert_eq!(shape, shape_mut)
}
#[test]
fn test_shape_slice_output_shape_basic() {
// Test basic slicing with step=1
let slices = [
Slice::new(0, Some(5), 1), // 5 elements
Slice::new(2, Some(8), 1), // 6 elements
];
let original_shape = Shape::new([10, 10, 10]);
let result = original_shape.slice(&slices).unwrap();
assert_eq!(result, Shape::new([5, 6, 10]));
}
#[test]
fn test_shape_slice_output_shape_with_positive_steps() {
// Test slicing with various positive steps
let slices = [
Slice::new(0, Some(10), 2), // [0,2,4,6,8] -> 5 elements
Slice::new(1, Some(9), 3), // [1,4,7] -> 3 elements
Slice::new(0, Some(7), 4), // [0,4] -> 2 elements
];
let original_shape = Shape::new([20, 20, 20, 30]);
let result = original_shape.slice(&slices).unwrap();
assert_eq!(result, Shape::new([5, 3, 2, 30]));
}
#[test]
fn test_shape_slice_output_shape_with_negative_steps() {
// Test slicing with negative steps (backward iteration)
let slices = [
Slice::new(0, Some(10), -1), // 10 elements traversed backward
Slice::new(2, Some(8), -2), // [7,5,3] -> 3 elements
];
let original_shape = Shape::new([20, 20, 20]);
let result = original_shape.slice(&slices).unwrap();
assert_eq!(result, Shape::new([10, 3, 20]));
}
#[test]
fn test_shape_slice_output_shape_mixed_steps() {
// Test with a mix of positive, negative, and unit steps
let slices = [
Slice::from_range_stepped(1..6, 1), // 5 elements
Slice::from_range_stepped(0..10, -3), // [9,6,3,0] -> 4 elements
Slice::from_range_stepped(2..14, 4), // [2,6,10] -> 3 elements
];
let original_shape = Shape::new([20, 20, 20]);
let result = original_shape.slice(&slices).unwrap();
assert_eq!(result, Shape::new([5, 4, 3]));
}
#[test]
fn test_shape_slice_output_shape_partial_dims() {
// Test when slices has fewer dimensions than original shape
let slices = [
Slice::from_range_stepped(2..7, 2), // [2,4,6] -> 3 elements
];
let original_shape = Shape::new([10, 20, 30, 40]);
let result = original_shape.slice(&slices).unwrap();
assert_eq!(result, Shape::new([3, 20, 30, 40]));
}
#[test]
fn test_shape_slice_output_shape_edge_cases() {
// Test edge cases with small ranges and large steps
let slices = [
Slice::from_range_stepped(0..1, 1), // Single element
Slice::from_range_stepped(0..10, 100), // Step larger than range -> 1 element
Slice::from_range_stepped(5..5, 1), // Empty range -> 0 elements
];
let original_shape = Shape::new([10, 20, 30]);
let result = original_shape.slice(&slices).unwrap();
assert_eq!(result, Shape::new([1, 1, 0]));
}
#[test]
fn test_shape_slice_output_shape_empty() {
// Test with no slice infos (should return original shape)
let slices = [];
let original_shape = Shape::new([10, 20, 30]);
let result = original_shape.slice(&slices).unwrap();
assert_eq!(result, Shape::new([10, 20, 30]));
}
#[test]
fn test_shape_slice_output_shape_uneven_division() {
// Test cases where range size doesn't divide evenly by step
let slices = [
Slice::from_range_stepped(0..7, 3), // ceil(7/3) = 3 elements: [0,3,6]
Slice::from_range_stepped(0..11, 4), // ceil(11/4) = 3 elements: [0,4,8]
Slice::from_range_stepped(1..10, 5), // ceil(9/5) = 2 elements: [1,6]
];
let original_shape = Shape::new([20, 20, 20]);
let result = original_shape.slice(&slices).unwrap();
assert_eq!(result, Shape::new([3, 3, 2]));
}
}

View File

@@ -0,0 +1,937 @@
//! Tensor slice utilities.
use crate::Shape;
use crate::indexing::AsIndex;
use alloc::format;
use alloc::vec::Vec;
use core::fmt::{Display, Formatter};
use core::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
use core::str::FromStr;
/// Trait for slice arguments that can be converted into an array of slices.
/// This allows the `slice` method to accept both single slices (from `s![..]`)
/// and arrays of slices (from `s![.., ..]` or `[0..5, 1..3]`).
pub trait SliceArg {
/// Convert to an vec of slices with clamping to shape dimensions.
///
/// Returns a [Slice] for each dimension in `shape`.
fn into_slices(self, shape: &Shape) -> Vec<Slice>;
}
impl<S: Into<Slice> + Clone> SliceArg for &[S] {
fn into_slices(self, shape: &Shape) -> Vec<Slice> {
assert!(
self.len() <= shape.num_dims(),
"Too many slices provided for shape, got {} but expected at most {}",
self.len(),
shape.num_dims()
);
shape
.iter()
.enumerate()
.map(|(i, dim_size)| {
let slice = if i >= self.len() {
Slice::full()
} else {
self[i].clone().into()
};
// Apply shape clamping by converting to range and back
let clamped_range = slice.to_range(*dim_size);
Slice::new(
clamped_range.start as isize,
Some(clamped_range.end as isize),
slice.step(),
)
})
.collect::<Vec<_>>()
}
}
impl SliceArg for &Vec<Slice> {
fn into_slices(self, shape: &Shape) -> Vec<Slice> {
self.as_slice().into_slices(shape)
}
}
impl<const R: usize, T> SliceArg for [T; R]
where
T: Into<Slice> + Clone,
{
fn into_slices(self, shape: &Shape) -> Vec<Slice> {
self.as_slice().into_slices(shape)
}
}
impl<T> SliceArg for T
where
T: Into<Slice>,
{
fn into_slices(self, shape: &Shape) -> Vec<Slice> {
let slice: Slice = self.into();
[slice].as_slice().into_slices(shape)
}
}
/// Slice argument constructor for tensor indexing.
///
/// The `s![]` macro is used to create multi-dimensional slice specifications for tensors.
/// It converts various range syntax forms into a `&[Slice]` that can be used with
/// `tensor.slice()` and `tensor.slice_assign()` operations.
///
/// # Syntax Overview
///
/// ## Basic Forms
///
/// * **`s![index]`** - Index a single element (produces a subview with that axis removed)
/// * **`s![range]`** - Slice a range of elements
/// * **`s![range;step]`** - Slice a range with a custom step
/// * **`s![dim1, dim2, ...]`** - Multiple dimensions, each can be any of the above forms
///
/// ## Range Types
///
/// All standard Rust range types are supported:
/// * **`a..b`** - From `a` (inclusive) to `b` (exclusive)
/// * **`a..=b`** - From `a` to `b` (both inclusive)
/// * **`a..`** - From `a` to the end
/// * **`..b`** - From the beginning to `b` (exclusive)
/// * **`..=b`** - From the beginning to `b` (inclusive)
/// * **`..`** - The full range (all elements)
///
/// ## Negative Indices
///
/// Negative indices count from the end of the axis:
/// * **`-1`** refers to the last element
/// * **`-2`** refers to the second-to-last element
/// * And so on...
///
/// This works in all range forms: `s![-3..-1]`, `s![-2..]`, `s![..-1]`
///
/// ## Step Syntax
///
/// Steps control the stride between selected elements:
/// * **`;step`** after a range specifies the step
/// * **Positive steps** select every nth element going forward
/// * **Negative steps** select every nth element going backward
/// * Default step is `1` when not specified
/// * Step cannot be `0`
///
/// ### Negative Step Behavior
///
/// With negative steps, the range bounds still specify *which* elements to include,
/// but the traversal order is reversed:
///
/// * `s![0..5;-1]` selects indices `[4, 3, 2, 1, 0]` (not `[0, 1, 2, 3, 4]`)
/// * `s![2..8;-2]` selects indices `[7, 5, 3]` (starting from 7, going backward by 2)
/// * `s![..;-1]` reverses the entire axis
///
/// This matches the semantics of NumPy and the ndarray crate.
///
/// # Examples
///
/// ## Basic Slicing
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, s};
///
/// # fn example<B: Backend>(tensor: Tensor<B, 3>) {
/// // Select rows 0-5 (exclusive)
/// let subset = tensor.slice(s![0..5, .., ..]);
///
/// // Select the last row
/// let last_row = tensor.slice(s![-1, .., ..]);
///
/// // Select columns 2, 3, 4
/// let cols = tensor.slice(s![.., 2..5, ..]);
///
/// // Select a single element at position [1, 2, 3]
/// let element = tensor.slice(s![1, 2, 3]);
/// # }
/// ```
///
/// ## Slicing with Steps
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, s};
///
/// # fn example<B: Backend>(tensor: Tensor<B, 2>) {
/// // Select every 2nd row
/// let even_rows = tensor.slice(s![0..10;2, ..]);
///
/// // Select every 3rd column
/// let cols = tensor.slice(s![.., 0..9;3]);
///
/// // Select every 2nd element in reverse order
/// let reversed_even = tensor.slice(s![10..0;-2, ..]);
/// # }
/// ```
///
/// ## Reversing Dimensions
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, s};
///
/// # fn example<B: Backend>(tensor: Tensor<B, 2>) {
/// // Reverse the first dimension
/// let reversed = tensor.slice(s![..;-1, ..]);
///
/// // Reverse both dimensions
/// let fully_reversed = tensor.slice(s![..;-1, ..;-1]);
///
/// // Reverse a specific range
/// let range_reversed = tensor.slice(s![2..8;-1, ..]);
/// # }
/// ```
///
/// ## Complex Multi-dimensional Slicing
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, s};
///
/// # fn example<B: Backend>(tensor: Tensor<B, 4>) {
/// // Mix of different slice types
/// let complex = tensor.slice(s![
/// 0..10;2, // Every 2nd element from 0 to 10
/// .., // All elements in dimension 1
/// 5..15;-3, // Every 3rd element from 14 down to 5
/// -1 // Last element in dimension 3
/// ]);
///
/// // Using inclusive ranges
/// let inclusive = tensor.slice(s![2..=5, 1..=3, .., ..]);
///
/// // Negative indices with steps
/// let from_end = tensor.slice(s![-5..-1;2, .., .., ..]);
/// # }
/// ```
///
/// ## Slice Assignment
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, s};
///
/// # fn example<B: Backend>(tensor: Tensor<B, 2>, values: Tensor<B, 2>) {
/// // Assign to every 2nd row
/// let tensor = tensor.slice_assign(s![0..10;2, ..], values);
///
/// // Assign to a reversed slice
/// let tensor = tensor.slice_assign(s![..;-1, 0..5], values);
/// # }
/// ```
#[macro_export]
macro_rules! s {
// Empty - should not happen
[] => {
compile_error!("Empty slice specification")
};
// Single expression with step
[$range:expr; $step:expr] => {
{
#[allow(clippy::reversed_empty_ranges)]
{
$crate::tensor::Slice::from_range_stepped($range, $step)
}
}
};
// Single expression without step (no comma after)
[$range:expr] => {
{
#[allow(clippy::reversed_empty_ranges)]
{
$crate::tensor::Slice::from($range)
}
}
};
// Two or more expressions with first having step
[$range:expr; $step:expr, $($rest:tt)*] => {
{
#[allow(clippy::reversed_empty_ranges)]
{
$crate::s!(@internal [$crate::tensor::Slice::from_range_stepped($range, $step)] $($rest)*)
}
}
};
// Two or more expressions with first not having step
[$range:expr, $($rest:tt)*] => {
{
#[allow(clippy::reversed_empty_ranges)]
{
$crate::s!(@internal [$crate::tensor::Slice::from($range)] $($rest)*)
}
}
};
// Internal: finished parsing
(@internal [$($acc:expr),*]) => {
[$($acc),*]
};
// Internal: parse range with step followed by comma
(@internal [$($acc:expr),*] $range:expr; $step:expr, $($rest:tt)*) => {
$crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from_range_stepped($range, $step as isize)] $($rest)*)
};
// Internal: parse range with step at end
(@internal [$($acc:expr),*] $range:expr; $step:expr) => {
$crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from_range_stepped($range, $step as isize)])
};
// Internal: parse range without step followed by comma
(@internal [$($acc:expr),*] $range:expr, $($rest:tt)*) => {
$crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from($range)] $($rest)*)
};
// Internal: parse range without step at end
(@internal [$($acc:expr),*] $range:expr) => {
$crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from($range)])
};
}
/// A slice specification for a single tensor dimension.
///
/// This struct represents a range with an optional step, used for advanced indexing
/// operations on tensors. It is typically created using the [`s!`] macro rather than
/// constructed directly.
///
/// # Fields
///
/// * `start` - The starting index (inclusive). Negative values count from the end.
/// * `end` - The ending index (exclusive). `None` means to the end of the dimension.
/// * `step` - The stride between elements. Must be non-zero.
///
/// # Index Interpretation
///
/// - **Positive indices**: Count from the beginning (0-based)
/// - **Negative indices**: Count from the end (-1 is the last element)
/// - **Bounds checking**: Indices are clamped to valid ranges
///
/// # Step Behavior
///
/// - **Positive step**: Traverse forward through the range
/// - **Negative step**: Traverse backward through the range
/// - **Step size**: Determines how many elements to skip
///
/// # Examples
///
/// While you typically use the [`s!`] macro, you can also construct slices directly:
///
/// ```rust,ignore
/// use burn_tensor::Slice;
///
/// // Equivalent to s![2..8]
/// let slice1 = Slice::new(2, Some(8), 1);
///
/// // Equivalent to s![0..10;2]
/// let slice2 = Slice::new(0, Some(10), 2);
///
/// // Equivalent to s![..;-1] (reverse)
/// let slice3 = Slice::new(0, None, -1);
/// ```
///
/// See also the [`s!`] macro for the preferred way to create slices.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct Slice {
/// Slice start index.
pub start: isize,
/// Slice end index (exclusive).
pub end: Option<isize>,
/// Step between elements (default: 1).
pub step: isize,
}
/// Defines an [`Iterator`] over a [`Slice`].
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct SliceIter {
slice: Slice,
current: isize,
}
impl Iterator for SliceIter {
type Item = isize;
fn next(&mut self) -> Option<Self::Item> {
let next = self.current;
self.current += self.slice.step;
if let Some(end) = self.slice.end {
if self.slice.is_reversed() {
if next <= end {
return None;
}
} else if next >= end {
return None;
}
}
Some(next)
}
}
/// Note: Unbounded [`Slice`]s produce infinite iterators.
impl IntoIterator for Slice {
type Item = isize;
type IntoIter = SliceIter;
fn into_iter(self) -> Self::IntoIter {
SliceIter {
slice: self,
current: self.start,
}
}
}
impl Default for Slice {
fn default() -> Self {
Self::full()
}
}
impl Slice {
/// Creates a new slice with start, end, and step
pub const fn new(start: isize, end: Option<isize>, step: isize) -> Self {
assert!(step != 0, "Step cannot be zero");
Self { start, end, step }
}
/// Creates a slice that represents the full range.
pub const fn full() -> Self {
Self::new(0, None, 1)
}
/// Creates a slice that represents a single index
pub fn index(idx: isize) -> Self {
Self {
start: idx,
end: handle_signed_inclusive_end(idx),
step: 1,
}
}
/// Converts the slice to a vector.
pub fn into_vec(self) -> Vec<isize> {
assert!(
self.end.is_some(),
"Slice must have an end to convert to a vector: {self:?}"
);
self.into_iter().collect()
}
/// Clips the slice to a maximum size.
///
/// # Example
///
/// ```rust,ignore
/// assert_eq!(
/// Slice::new(0, None, 1).bound_to(10),
/// Slice::new(0, Some(10), 1));
/// assert_eq!(
/// Slice::new(0, Some(5), 1).bound_to(10),
/// Slice::new(0, Some(5), 1));
/// assert_eq!(
/// Slice::new(0, None, -1).bound_to(10),
/// Slice::new(0, Some(-11), -1));
/// assert_eq!(
/// Slice::new(0, Some(-5), -1).bound_to(10),
/// Slice::new(0, Some(-5), -1));
/// ```
pub fn bound_to(self, size: usize) -> Self {
let mut bounds = size as isize;
if let Some(end) = self.end {
if end > 0 {
bounds = end.min(bounds);
} else {
bounds = end.max(-(bounds + 1));
}
} else if self.is_reversed() {
bounds = -(bounds + 1);
}
Self {
end: Some(bounds),
..self
}
}
/// Creates a slice with a custom step
pub fn with_step(start: isize, end: Option<isize>, step: isize) -> Self {
assert!(step != 0, "Step cannot be zero");
Self { start, end, step }
}
/// Creates a slice from a range with a specified step
pub fn from_range_stepped<R: Into<Slice>>(range: R, step: isize) -> Self {
assert!(step != 0, "Step cannot be zero");
let mut slice = range.into();
slice.step = step;
slice
}
/// Returns the step of the slice
pub fn step(&self) -> isize {
self.step
}
/// Returns the range for this slice given a dimension size
pub fn range(&self, size: usize) -> Range<usize> {
self.to_range(size)
}
/// Convert this slice to a range for a dimension of the given size.
///
/// # Arguments
///
/// * `size` - The size of the dimension to slice.
///
/// # Returns
///
/// A `Range<usize>` representing the slice bounds.
pub fn to_range(&self, size: usize) -> Range<usize> {
// Always return a valid range with start <= end
// The step information will be handled separately
let start = convert_signed_index(self.start, size);
let end = match self.end {
Some(end) => convert_signed_index(end, size),
None => size,
};
start..end
}
/// Converts the slice into a range and step tuple
pub fn to_range_and_step(&self, size: usize) -> (Range<usize>, isize) {
let range = self.to_range(size);
(range, self.step)
}
/// Returns true if the step is negative
pub fn is_reversed(&self) -> bool {
self.step < 0
}
/// Calculates the output size for this slice operation
pub fn output_size(&self, dim_size: usize) -> usize {
let range = self.to_range(dim_size);
// Handle empty slices (start >= end)
if range.start >= range.end {
return 0;
}
let len = range.end - range.start;
if self.step.unsigned_abs() == 1 {
len
} else {
len.div_ceil(self.step.unsigned_abs())
}
}
}
fn convert_signed_index(index: isize, size: usize) -> usize {
if index < 0 {
(size as isize + index).max(0) as usize
} else {
(index as usize).min(size)
}
}
fn handle_signed_inclusive_end(end: isize) -> Option<isize> {
match end {
-1 => None,
end => Some(end + 1),
}
}
impl<I: AsIndex> From<Range<I>> for Slice {
fn from(r: Range<I>) -> Self {
Self {
start: r.start.as_index(),
end: Some(r.end.as_index()),
step: 1,
}
}
}
impl<I: AsIndex + Copy> From<RangeInclusive<I>> for Slice {
fn from(r: RangeInclusive<I>) -> Self {
Self {
start: r.start().as_index(),
end: handle_signed_inclusive_end(r.end().as_index()),
step: 1,
}
}
}
impl<I: AsIndex> From<RangeFrom<I>> for Slice {
fn from(r: RangeFrom<I>) -> Self {
Self {
start: r.start.as_index(),
end: None,
step: 1,
}
}
}
impl<I: AsIndex> From<RangeTo<I>> for Slice {
fn from(r: RangeTo<I>) -> Self {
Self {
start: 0,
end: Some(r.end.as_index()),
step: 1,
}
}
}
impl<I: AsIndex> From<RangeToInclusive<I>> for Slice {
fn from(r: RangeToInclusive<I>) -> Self {
Self {
start: 0,
end: handle_signed_inclusive_end(r.end.as_index()),
step: 1,
}
}
}
impl From<RangeFull> for Slice {
fn from(_: RangeFull) -> Self {
Self {
start: 0,
end: None,
step: 1,
}
}
}
impl From<usize> for Slice {
fn from(i: usize) -> Self {
Slice::index(i as isize)
}
}
impl From<isize> for Slice {
fn from(i: isize) -> Self {
Slice::index(i)
}
}
impl From<i32> for Slice {
fn from(i: i32) -> Self {
Slice::index(i as isize)
}
}
impl From<i64> for Slice {
fn from(i: i64) -> Self {
Slice::index(i as isize)
}
}
impl Display for Slice {
fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
if self.step == 1
&& let Some(end) = self.end
&& self.start == end - 1
{
f.write_fmt(format_args!("{}", self.start))
} else {
if self.start != 0 {
f.write_fmt(format_args!("{}", self.start))?;
}
f.write_str("..")?;
if let Some(end) = self.end {
f.write_fmt(format_args!("{}", end))?;
}
if self.step != 1 {
f.write_fmt(format_args!(";{}", self.step))?;
}
Ok(())
}
}
}
impl FromStr for Slice {
type Err = crate::ExpressionError;
fn from_str(source: &str) -> Result<Self, Self::Err> {
let mut s = source.trim();
let parse_int = |v: &str| -> Result<isize, Self::Err> {
v.parse::<isize>().map_err(|e| {
crate::ExpressionError::parse_error(
format!("Invalid integer: '{v}': {}", e),
source,
)
})
};
let mut start: isize = 0;
let mut end: Option<isize> = None;
let mut step: isize = 1;
if let Some((head, tail)) = s.split_once(";") {
step = parse_int(tail)?;
s = head;
}
if s.is_empty() {
return Err(crate::ExpressionError::parse_error(
"Empty expression",
source,
));
}
if let Some((start_s, end_s)) = s.split_once("..") {
if !start_s.is_empty() {
start = parse_int(start_s)?;
}
if !end_s.is_empty() {
if let Some(end_s) = end_s.strip_prefix('=') {
end = Some(parse_int(end_s)? + 1);
} else {
end = Some(parse_int(end_s)?);
}
}
} else {
start = parse_int(s)?;
end = Some(start + 1);
}
if step == 0 {
return Err(crate::ExpressionError::invalid_expression(
"Step cannot be zero",
source,
));
}
Ok(Slice::new(start, end, step))
}
}
#[cfg(test)]
mod tests {
use super::*;
use alloc::string::ToString;
use alloc::vec;
#[test]
fn test_slice_to_str() {
assert_eq!(Slice::new(0, None, 1).to_string(), "..");
assert_eq!(Slice::new(0, Some(1), 1).to_string(), "0");
assert_eq!(Slice::new(0, Some(10), 1).to_string(), "..10");
assert_eq!(Slice::new(1, Some(10), 1).to_string(), "1..10");
assert_eq!(Slice::new(-3, Some(10), -2).to_string(), "-3..10;-2");
}
#[test]
fn test_slice_from_str() {
assert_eq!("1".parse::<Slice>(), Ok(Slice::new(1, Some(2), 1)));
assert_eq!("..".parse::<Slice>(), Ok(Slice::new(0, None, 1)));
assert_eq!("..3".parse::<Slice>(), Ok(Slice::new(0, Some(3), 1)));
assert_eq!("..=3".parse::<Slice>(), Ok(Slice::new(0, Some(4), 1)));
assert_eq!("-12..3".parse::<Slice>(), Ok(Slice::new(-12, Some(3), 1)));
assert_eq!("..;-1".parse::<Slice>(), Ok(Slice::new(0, None, -1)));
assert_eq!("..=3;-2".parse::<Slice>(), Ok(Slice::new(0, Some(4), -2)));
assert_eq!(
"..;0".parse::<Slice>(),
Err(crate::ExpressionError::invalid_expression(
"Step cannot be zero",
"..;0"
))
);
assert_eq!(
"".parse::<Slice>(),
Err(crate::ExpressionError::parse_error("Empty expression", ""))
);
assert_eq!(
"a".parse::<Slice>(),
Err(crate::ExpressionError::parse_error(
"Invalid integer: 'a': invalid digit found in string",
"a"
))
);
assert_eq!(
"..a".parse::<Slice>(),
Err(crate::ExpressionError::parse_error(
"Invalid integer: 'a': invalid digit found in string",
"..a"
))
);
assert_eq!(
"a:b:c".parse::<Slice>(),
Err(crate::ExpressionError::parse_error(
"Invalid integer: 'a:b:c': invalid digit found in string",
"a:b:c"
))
);
}
#[test]
fn test_slice_output_size() {
// Test the output_size method directly
assert_eq!(Slice::new(0, Some(10), 1).output_size(10), 10);
assert_eq!(Slice::new(0, Some(10), 2).output_size(10), 5);
assert_eq!(Slice::new(0, Some(10), 3).output_size(10), 4); // ceil(10/3)
assert_eq!(Slice::new(0, Some(10), -1).output_size(10), 10);
assert_eq!(Slice::new(0, Some(10), -2).output_size(10), 5);
assert_eq!(Slice::new(2, Some(8), -3).output_size(10), 2); // ceil(6/3)
assert_eq!(Slice::new(5, Some(5), 1).output_size(10), 0); // empty range
}
#[test]
fn test_bound_to() {
assert_eq!(
Slice::new(0, None, 1).bound_to(10),
Slice::new(0, Some(10), 1)
);
assert_eq!(
Slice::new(0, Some(5), 1).bound_to(10),
Slice::new(0, Some(5), 1)
);
assert_eq!(
Slice::new(0, None, -1).bound_to(10),
Slice::new(0, Some(-11), -1)
);
assert_eq!(
Slice::new(0, Some(-5), -1).bound_to(10),
Slice::new(0, Some(-5), -1)
);
}
#[test]
fn test_slice_iter() {
assert_eq!(
Slice::new(2, Some(3), 1).into_iter().collect::<Vec<_>>(),
vec![2]
);
assert_eq!(
Slice::new(3, Some(-1), -1).into_iter().collect::<Vec<_>>(),
vec![3, 2, 1, 0]
);
assert_eq!(Slice::new(3, Some(-1), -1).into_vec(), vec![3, 2, 1, 0]);
assert_eq!(
Slice::new(3, None, 2)
.into_iter()
.take(3)
.collect::<Vec<_>>(),
vec![3, 5, 7]
);
assert_eq!(
Slice::new(3, None, 2)
.bound_to(8)
.into_iter()
.collect::<Vec<_>>(),
vec![3, 5, 7]
);
}
#[test]
#[should_panic(
expected = "Slice must have an end to convert to a vector: Slice { start: 0, end: None, step: 1 }"
)]
fn test_unbound_slice_into_vec() {
Slice::new(0, None, 1).into_vec();
}
#[test]
fn into_slices_should_return_for_all_shape_dims() {
let slice = s![1];
let shape = Shape::new([2, 3, 1]);
let slices = slice.into_slices(&shape);
assert_eq!(slices.len(), shape.len());
assert_eq!(slices[0], Slice::new(1, Some(2), 1));
assert_eq!(slices[1], Slice::new(0, Some(3), 1));
assert_eq!(slices[2], Slice::new(0, Some(1), 1));
let slice = s![1, 0..2];
let slices = slice.into_slices(&shape);
assert_eq!(slices.len(), shape.len());
assert_eq!(slices[0], Slice::new(1, Some(2), 1));
assert_eq!(slices[1], Slice::new(0, Some(2), 1));
assert_eq!(slices[2], Slice::new(0, Some(1), 1));
let slice = s![..];
let slices = slice.into_slices(&shape);
assert_eq!(slices.len(), shape.len());
assert_eq!(slices[0], Slice::new(0, Some(2), 1));
assert_eq!(slices[1], Slice::new(0, Some(3), 1));
assert_eq!(slices[2], Slice::new(0, Some(1), 1));
}
#[test]
fn into_slices_all_dimensions() {
let slice = s![1, ..2, ..];
let shape = Shape::new([2, 3, 1]);
let slices = slice.into_slices(&shape);
assert_eq!(slices.len(), shape.len());
assert_eq!(slices[0], Slice::new(1, Some(2), 1));
assert_eq!(slices[1], Slice::new(0, Some(2), 1));
assert_eq!(slices[2], Slice::new(0, Some(1), 1));
}
#[test]
fn into_slices_supports_empty_dimensions() {
let slice = s![.., 1, ..];
let shape = Shape::new([0, 3, 1]);
let slices = slice.into_slices(&shape);
assert_eq!(slices.len(), shape.len());
assert_eq!(slices[0], Slice::new(0, Some(0), 1));
assert_eq!(slices[1], Slice::new(1, Some(2), 1));
assert_eq!(slices[2], Slice::new(0, Some(1), 1));
}
#[test]
#[should_panic = "Too many slices provided for shape"]
fn into_slices_should_match_shape_rank() {
let slice = s![.., 1, ..];
let shape = Shape::new([3, 1]);
let _ = slice.into_slices(&shape);
}
#[test]
fn should_support_const_and_full() {
static SLICES: [Slice; 2] = [Slice::full(), Slice::new(2, None, 1)];
assert_eq!(SLICES[0], Slice::new(0, None, 1));
assert_eq!(SLICES[1], Slice::new(2, None, 1));
}
#[test]
fn should_support_default() {
assert_eq!(Slice::default(), Slice::new(0, None, 1));
}
#[test]
fn should_support_copy() {
let mut slice = Slice::new(1, Some(3), 2);
let slice_copy = slice;
slice.end = Some(4);
assert_eq!(slice, Slice::new(1, Some(4), 2));
assert_eq!(slice_copy, Slice::new(1, Some(3), 2));
}
}