feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
This commit is contained in:
57
crates/stable-diffusion-burn/burn-crates/burn-std/Cargo.toml
Normal file
57
crates/stable-diffusion-burn/burn-crates/burn-std/Cargo.toml
Normal file
@@ -0,0 +1,57 @@
|
||||
[package]
|
||||
authors = ["Dilshod Tadjibaev (@antimora)"]
|
||||
categories = []
|
||||
description = "Core types and utilities shared across the Burn ecosystem."
|
||||
documentation = "https://docs.rs/burn-std"
|
||||
edition.workspace = true
|
||||
keywords = []
|
||||
license.workspace = true
|
||||
name = "burn-std"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-std"
|
||||
version.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[features]
|
||||
cubecl = ["dep:cubecl"]
|
||||
default = ["std", "cubecl-common/default"]
|
||||
doc = ["default"]
|
||||
std = ["cubecl-common/std", "num-traits/std"]
|
||||
tracing = ["cubecl?/tracing", "cubecl-common/tracing"]
|
||||
|
||||
network = ["dep:indicatif", "dep:reqwest", "dep:tokio"]
|
||||
|
||||
[dependencies]
|
||||
bytemuck = { workspace = true, features = ["extern_crate_alloc"] }
|
||||
half = { workspace = true, features = ["bytemuck"] }
|
||||
num-traits = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
smallvec = { workspace = true, features = ["serde"] }
|
||||
|
||||
cubecl = { workspace = true, optional = true, default-features = false }
|
||||
cubecl-common = { workspace = true, default-features = false, features = [
|
||||
"serde",
|
||||
"shared-bytes",
|
||||
] }
|
||||
cubecl-zspace = { workspace = true, default-features = false }
|
||||
# Enable extra-platforms for portable-atomic support on targets without native atomics (e.g., thumbv6m)
|
||||
# This is needed because cubecl-common's shared-bytes feature pulls in bytes
|
||||
bytes = { workspace = true }
|
||||
|
||||
# Network downloader
|
||||
indicatif = { workspace = true, optional = true }
|
||||
reqwest = { workspace = true, optional = true }
|
||||
tokio = { workspace = true, optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
dashmap = { workspace = true }
|
||||
|
||||
# Enable extra-platforms for bytes on targets without native atomics (e.g., thumbv6m-none-eabi)
|
||||
[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
|
||||
bytes = { workspace = true, features = ["extra-platforms"] }
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["doc"]
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
1
crates/stable-diffusion-burn/burn-crates/burn-std/LICENSE-APACHE
Symbolic link
1
crates/stable-diffusion-burn/burn-crates/burn-std/LICENSE-APACHE
Symbolic link
@@ -0,0 +1 @@
|
||||
../../LICENSE-APACHE
|
||||
1
crates/stable-diffusion-burn/burn-crates/burn-std/LICENSE-MIT
Symbolic link
1
crates/stable-diffusion-burn/burn-crates/burn-std/LICENSE-MIT
Symbolic link
@@ -0,0 +1 @@
|
||||
../../LICENSE-MIT
|
||||
@@ -0,0 +1,7 @@
|
||||
# Burn Standard Library
|
||||
|
||||
`burn-std` provides the core types and utilities shared across the Burn ecosystem.
|
||||
It includes foundational definitions for shapes, indexing, and data types.
|
||||
|
||||
This crate supports both `std` and `no_std` environments and must compile with
|
||||
`cargo build --no-default-features` as well.
|
||||
69
crates/stable-diffusion-burn/burn-crates/burn-std/src/id.rs
Normal file
69
crates/stable-diffusion-burn/burn-crates/burn-std/src/id.rs
Normal file
@@ -0,0 +1,69 @@
|
||||
//! # Unique Identifiers
|
||||
use crate::rand::gen_random;
|
||||
|
||||
/// Simple ID generator.
|
||||
pub struct IdGenerator {}
|
||||
|
||||
impl IdGenerator {
|
||||
/// Generates a new ID.
|
||||
pub fn generate() -> u64 {
|
||||
// Generate a random u64 (18,446,744,073,709,551,615 combinations)
|
||||
let random_bytes: [u8; 8] = gen_random();
|
||||
u64::from_le_bytes(random_bytes)
|
||||
}
|
||||
}
|
||||
|
||||
pub use cubecl_common::stream_id::StreamId;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use alloc::collections::BTreeSet;
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
use dashmap::DashSet; //Concurrent HashMap
|
||||
#[cfg(feature = "std")]
|
||||
use std::{sync::Arc, thread};
|
||||
|
||||
#[test]
|
||||
fn uniqueness_test() {
|
||||
const IDS_CNT: usize = 10_000;
|
||||
|
||||
let mut set: BTreeSet<u64> = BTreeSet::new();
|
||||
|
||||
for _i in 0..IDS_CNT {
|
||||
assert!(set.insert(IdGenerator::generate()));
|
||||
}
|
||||
|
||||
assert_eq!(set.len(), IDS_CNT);
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
#[test]
|
||||
fn thread_safety_test() {
|
||||
const NUM_THREADS: usize = 10;
|
||||
const NUM_REPEATS: usize = 1_000;
|
||||
const EXPECTED_TOTAL_IDS: usize = NUM_THREADS * NUM_REPEATS;
|
||||
|
||||
let set: Arc<DashSet<u64>> = Arc::new(DashSet::new());
|
||||
|
||||
let mut handles = vec![];
|
||||
|
||||
for _ in 0..NUM_THREADS {
|
||||
let set = set.clone();
|
||||
|
||||
let handle = thread::spawn(move || {
|
||||
for _i in 0..NUM_REPEATS {
|
||||
assert!(set.insert(IdGenerator::generate()));
|
||||
}
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
for handle in handles {
|
||||
handle.join().unwrap();
|
||||
}
|
||||
assert_eq!(set.len(), EXPECTED_TOTAL_IDS);
|
||||
}
|
||||
}
|
||||
97
crates/stable-diffusion-burn/burn-crates/burn-std/src/lib.rs
Normal file
97
crates/stable-diffusion-burn/burn-crates/burn-std/src/lib.rs
Normal file
@@ -0,0 +1,97 @@
|
||||
#![cfg_attr(not(feature = "std"), no_std)]
|
||||
#![warn(missing_docs)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
|
||||
//! # Burn Standard Library
|
||||
//!
|
||||
//! This library contains core types and utilities shared across Burn, including shapes, indexing,
|
||||
//! and data types.
|
||||
|
||||
extern crate alloc;
|
||||
|
||||
/// Id module contains types for unique identifiers.
|
||||
pub mod id;
|
||||
|
||||
/// Tensor utilities.
|
||||
pub mod tensor;
|
||||
pub use tensor::*;
|
||||
|
||||
/// Common Errors.
|
||||
pub use cubecl_zspace::errors::{self, *};
|
||||
|
||||
/// Network utilities.
|
||||
#[cfg(feature = "network")]
|
||||
pub mod network;
|
||||
|
||||
// Re-exported types
|
||||
pub use cubecl_common::bytes::*;
|
||||
pub use cubecl_common::*;
|
||||
pub use half::{bf16, f16};
|
||||
|
||||
#[cfg(feature = "cubecl")]
|
||||
pub use cubecl::flex32;
|
||||
|
||||
#[cfg(feature = "cubecl")]
|
||||
mod cube {
|
||||
use cubecl::ir::{ElemType, FloatKind, IntKind, StorageType, UIntKind};
|
||||
use cubecl_common::quant::scheme::QuantScheme;
|
||||
|
||||
use crate::tensor::DType;
|
||||
use crate::tensor::quantization::{QuantStore, QuantValue};
|
||||
|
||||
impl From<DType> for cubecl::ir::ElemType {
|
||||
fn from(dtype: DType) -> Self {
|
||||
match dtype {
|
||||
DType::F64 => ElemType::Float(FloatKind::F64),
|
||||
DType::F32 => ElemType::Float(FloatKind::F32),
|
||||
DType::Flex32 => ElemType::Float(FloatKind::Flex32),
|
||||
DType::F16 => ElemType::Float(FloatKind::F16),
|
||||
DType::BF16 => ElemType::Float(FloatKind::BF16),
|
||||
DType::I64 => ElemType::Int(IntKind::I64),
|
||||
DType::I32 => ElemType::Int(IntKind::I32),
|
||||
DType::I16 => ElemType::Int(IntKind::I16),
|
||||
DType::I8 => ElemType::Int(IntKind::I8),
|
||||
DType::U64 => ElemType::UInt(UIntKind::U64),
|
||||
DType::U32 => ElemType::UInt(UIntKind::U32),
|
||||
DType::U16 => ElemType::UInt(UIntKind::U16),
|
||||
DType::U8 => ElemType::UInt(UIntKind::U8),
|
||||
DType::Bool => ElemType::Bool,
|
||||
DType::QFloat(scheme) => match scheme.store {
|
||||
QuantStore::Native => match scheme.value {
|
||||
QuantValue::Q8F | QuantValue::Q8S => Self::Int(IntKind::I8),
|
||||
QuantValue::E4M3 => Self::Float(FloatKind::E4M3),
|
||||
QuantValue::E5M2 => Self::Float(FloatKind::E5M2),
|
||||
QuantValue::Q4F
|
||||
| QuantValue::Q4S
|
||||
| QuantValue::Q2F
|
||||
| QuantValue::Q2S
|
||||
| QuantValue::E2M1 => {
|
||||
panic!("Can't store native sub-byte values")
|
||||
}
|
||||
},
|
||||
QuantStore::PackedU32(_) => Self::UInt(UIntKind::U32),
|
||||
QuantStore::PackedNative(_) => match scheme.value {
|
||||
QuantValue::E2M1 => panic!("Can't store native sub-byte values"),
|
||||
other => panic!("{other:?} doesn't support native packing"),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DType> for cubecl::ir::StorageType {
|
||||
fn from(dtype: DType) -> cubecl::ir::StorageType {
|
||||
match dtype {
|
||||
DType::QFloat(QuantScheme {
|
||||
store: QuantStore::PackedNative(_),
|
||||
value: QuantValue::E2M1,
|
||||
..
|
||||
}) => StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2),
|
||||
_ => {
|
||||
let elem: ElemType = dtype.into();
|
||||
elem.into()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
//! # Common Network Utilities
|
||||
|
||||
/// Network download utilities.
|
||||
pub mod downloader {
|
||||
use indicatif::{ProgressBar, ProgressState, ProgressStyle};
|
||||
use reqwest::Client;
|
||||
use std::io::Write;
|
||||
|
||||
/// Download the file at the specified url.
|
||||
/// File download progress is reported with the help of a [progress bar](indicatif).
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `url` - The file URL to download.
|
||||
/// * `message` - The message to display on the progress bar during download.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A vector of bytes containing the downloaded file data.
|
||||
#[tokio::main(flavor = "current_thread")]
|
||||
pub async fn download_file_as_bytes(url: &str, message: &str) -> Vec<u8> {
|
||||
// Get file from web
|
||||
let mut response = Client::new().get(url).send().await.unwrap();
|
||||
let total_size = response.content_length().unwrap();
|
||||
|
||||
// Pretty progress bar
|
||||
let pb = ProgressBar::new(total_size);
|
||||
let msg = message.to_owned();
|
||||
pb.set_style(
|
||||
ProgressStyle::with_template(
|
||||
"{msg}\n {wide_bar:.cyan/blue} {bytes}/{total_bytes} ({eta})",
|
||||
)
|
||||
.unwrap()
|
||||
.with_key(
|
||||
"eta",
|
||||
|state: &ProgressState, w: &mut dyn std::fmt::Write| {
|
||||
write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()
|
||||
},
|
||||
)
|
||||
.progress_chars("▬ "),
|
||||
);
|
||||
pb.set_message(msg.clone());
|
||||
|
||||
// Read stream into bytes
|
||||
let mut downloaded: u64 = 0;
|
||||
let mut bytes: Vec<u8> = Vec::with_capacity(total_size as usize);
|
||||
while let Some(chunk) = response.chunk().await.unwrap() {
|
||||
let num_bytes = bytes.write(&chunk).unwrap();
|
||||
let new = std::cmp::min(downloaded + (num_bytes as u64), total_size);
|
||||
downloaded = new;
|
||||
pb.set_position(new);
|
||||
}
|
||||
pb.finish_with_message(msg);
|
||||
|
||||
bytes
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,224 @@
|
||||
//! Tensor data type.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::tensor::quantization::{QuantScheme, QuantStore, QuantValue};
|
||||
use crate::{bf16, f16};
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum DType {
|
||||
F64,
|
||||
F32,
|
||||
Flex32,
|
||||
F16,
|
||||
BF16,
|
||||
I64,
|
||||
I32,
|
||||
I16,
|
||||
I8,
|
||||
U64,
|
||||
U32,
|
||||
U16,
|
||||
U8,
|
||||
Bool,
|
||||
QFloat(QuantScheme),
|
||||
}
|
||||
|
||||
#[cfg(feature = "cubecl")]
|
||||
impl From<cubecl::ir::ElemType> for DType {
|
||||
fn from(value: cubecl::ir::ElemType) -> Self {
|
||||
match value {
|
||||
cubecl::ir::ElemType::Float(float_kind) => match float_kind {
|
||||
cubecl::ir::FloatKind::F16 => DType::F16,
|
||||
cubecl::ir::FloatKind::BF16 => DType::BF16,
|
||||
cubecl::ir::FloatKind::Flex32 => DType::Flex32,
|
||||
cubecl::ir::FloatKind::F32 => DType::F32,
|
||||
cubecl::ir::FloatKind::F64 => DType::F64,
|
||||
cubecl::ir::FloatKind::TF32 => panic!("Not a valid DType for tensors."),
|
||||
cubecl::ir::FloatKind::E2M1
|
||||
| cubecl::ir::FloatKind::E2M3
|
||||
| cubecl::ir::FloatKind::E3M2
|
||||
| cubecl::ir::FloatKind::E4M3
|
||||
| cubecl::ir::FloatKind::E5M2
|
||||
| cubecl::ir::FloatKind::UE8M0 => {
|
||||
unimplemented!("Not yet supported, will be used for quantization")
|
||||
}
|
||||
},
|
||||
cubecl::ir::ElemType::Int(int_kind) => match int_kind {
|
||||
cubecl::ir::IntKind::I8 => DType::I8,
|
||||
cubecl::ir::IntKind::I16 => DType::I16,
|
||||
cubecl::ir::IntKind::I32 => DType::I32,
|
||||
cubecl::ir::IntKind::I64 => DType::I64,
|
||||
},
|
||||
cubecl::ir::ElemType::UInt(uint_kind) => match uint_kind {
|
||||
cubecl::ir::UIntKind::U8 => DType::U8,
|
||||
cubecl::ir::UIntKind::U16 => DType::U16,
|
||||
cubecl::ir::UIntKind::U32 => DType::U32,
|
||||
cubecl::ir::UIntKind::U64 => DType::U64,
|
||||
},
|
||||
_ => panic!("Not a valid DType for tensors."),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DType {
|
||||
/// Returns the size of a type in bytes.
|
||||
pub const fn size(&self) -> usize {
|
||||
match self {
|
||||
DType::F64 => core::mem::size_of::<f64>(),
|
||||
DType::F32 => core::mem::size_of::<f32>(),
|
||||
DType::Flex32 => core::mem::size_of::<f32>(),
|
||||
DType::F16 => core::mem::size_of::<f16>(),
|
||||
DType::BF16 => core::mem::size_of::<bf16>(),
|
||||
DType::I64 => core::mem::size_of::<i64>(),
|
||||
DType::I32 => core::mem::size_of::<i32>(),
|
||||
DType::I16 => core::mem::size_of::<i16>(),
|
||||
DType::I8 => core::mem::size_of::<i8>(),
|
||||
DType::U64 => core::mem::size_of::<u64>(),
|
||||
DType::U32 => core::mem::size_of::<u32>(),
|
||||
DType::U16 => core::mem::size_of::<u16>(),
|
||||
DType::U8 => core::mem::size_of::<u8>(),
|
||||
DType::Bool => core::mem::size_of::<bool>(),
|
||||
DType::QFloat(scheme) => match scheme.store {
|
||||
QuantStore::Native => match scheme.value {
|
||||
QuantValue::Q8F | QuantValue::Q8S => core::mem::size_of::<i8>(),
|
||||
// e2m1 native is automatically packed by the kernels, so the actual storage is
|
||||
// 8 bits wide.
|
||||
QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
|
||||
core::mem::size_of::<u8>()
|
||||
}
|
||||
QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
|
||||
// Sub-byte values have fractional size
|
||||
0
|
||||
}
|
||||
},
|
||||
QuantStore::PackedU32(_) => core::mem::size_of::<u32>(),
|
||||
QuantStore::PackedNative(_) => match scheme.value {
|
||||
QuantValue::E2M1 => core::mem::size_of::<u8>(),
|
||||
_ => 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
/// Returns true if the data type is a floating point type.
|
||||
pub fn is_float(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
DType::F64 | DType::F32 | DType::Flex32 | DType::F16 | DType::BF16
|
||||
)
|
||||
}
|
||||
/// Returns true if the data type is a signed integer type.
|
||||
pub fn is_int(&self) -> bool {
|
||||
matches!(self, DType::I64 | DType::I32 | DType::I16 | DType::I8)
|
||||
}
|
||||
/// Returns true if the data type is an unsigned integer type.
|
||||
pub fn is_uint(&self) -> bool {
|
||||
matches!(self, DType::U64 | DType::U32 | DType::U16 | DType::U8)
|
||||
}
|
||||
|
||||
/// Returns true if the data type is a boolean type
|
||||
pub fn is_bool(&self) -> bool {
|
||||
matches!(self, DType::Bool)
|
||||
}
|
||||
|
||||
/// Returns the data type name.
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
DType::F64 => "f64",
|
||||
DType::F32 => "f32",
|
||||
DType::Flex32 => "flex32",
|
||||
DType::F16 => "f16",
|
||||
DType::BF16 => "bf16",
|
||||
DType::I64 => "i64",
|
||||
DType::I32 => "i32",
|
||||
DType::I16 => "i16",
|
||||
DType::I8 => "i8",
|
||||
DType::U64 => "u64",
|
||||
DType::U32 => "u32",
|
||||
DType::U16 => "u16",
|
||||
DType::U8 => "u8",
|
||||
DType::Bool => "bool",
|
||||
DType::QFloat(_) => "qfloat",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
|
||||
pub enum FloatDType {
|
||||
F64,
|
||||
F32,
|
||||
Flex32,
|
||||
F16,
|
||||
BF16,
|
||||
}
|
||||
|
||||
impl From<DType> for FloatDType {
|
||||
fn from(value: DType) -> Self {
|
||||
match value {
|
||||
DType::F64 => FloatDType::F64,
|
||||
DType::F32 => FloatDType::F32,
|
||||
DType::Flex32 => FloatDType::Flex32,
|
||||
DType::F16 => FloatDType::F16,
|
||||
DType::BF16 => FloatDType::BF16,
|
||||
_ => panic!("Expected float data type, got {value:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<FloatDType> for DType {
|
||||
fn from(value: FloatDType) -> Self {
|
||||
match value {
|
||||
FloatDType::F64 => DType::F64,
|
||||
FloatDType::F32 => DType::F32,
|
||||
FloatDType::Flex32 => DType::Flex32,
|
||||
FloatDType::F16 => DType::F16,
|
||||
FloatDType::BF16 => DType::BF16,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
|
||||
pub enum IntDType {
|
||||
I64,
|
||||
I32,
|
||||
I16,
|
||||
I8,
|
||||
U64,
|
||||
U32,
|
||||
U16,
|
||||
U8,
|
||||
}
|
||||
|
||||
impl From<DType> for IntDType {
|
||||
fn from(value: DType) -> Self {
|
||||
match value {
|
||||
DType::I64 => IntDType::I64,
|
||||
DType::I32 => IntDType::I32,
|
||||
DType::I16 => IntDType::I16,
|
||||
DType::I8 => IntDType::I8,
|
||||
DType::U64 => IntDType::U64,
|
||||
DType::U32 => IntDType::U32,
|
||||
DType::U16 => IntDType::U16,
|
||||
DType::U8 => IntDType::U8,
|
||||
_ => panic!("Expected int data type, got {value:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<IntDType> for DType {
|
||||
fn from(value: IntDType) -> Self {
|
||||
match value {
|
||||
IntDType::I64 => DType::I64,
|
||||
IntDType::I32 => DType::I32,
|
||||
IntDType::I16 => DType::I16,
|
||||
IntDType::I8 => DType::I8,
|
||||
IntDType::U64 => DType::U64,
|
||||
IntDType::U32 => DType::U32,
|
||||
IntDType::U16 => DType::U16,
|
||||
IntDType::U8 => DType::U8,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,221 @@
|
||||
pub mod dtype;
|
||||
pub mod quantization;
|
||||
pub mod shape;
|
||||
pub mod slice;
|
||||
|
||||
pub use dtype::*;
|
||||
pub use quantization::*;
|
||||
pub use shape::*;
|
||||
pub use slice::*;
|
||||
|
||||
pub use cubecl_zspace::indexing::{self, *};
|
||||
pub use cubecl_zspace::{Strides, metadata::Metadata, strides};
|
||||
|
||||
/// Check if the current tensor is contiguous.
|
||||
///
|
||||
/// A tensor is considered contiguous if its elements are stored in memory
|
||||
/// such that the stride at position `k` is equal to the product of the shapes
|
||||
/// of all dimensions greater than `k`.
|
||||
///
|
||||
/// This means that strides increase as you move from the rightmost to the leftmost dimension.
|
||||
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
|
||||
if shape.is_empty() {
|
||||
return true;
|
||||
}
|
||||
|
||||
for (&expected, &stride) in contiguous_strides(shape).iter().zip(strides) {
|
||||
if expected != stride {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Computes the strides for a contiguous tensor with the given shape.
|
||||
///
|
||||
/// In a contiguous row-major tensor, the stride for each dimension
|
||||
/// equals the product of all dimension sizes to its right.
|
||||
pub fn contiguous_strides(shape: &[usize]) -> Strides {
|
||||
let mut strides = strides![0; shape.len()];
|
||||
let mut current = 1;
|
||||
|
||||
for (i, &dim) in shape.iter().enumerate().rev() {
|
||||
strides[i] = current;
|
||||
current *= dim;
|
||||
}
|
||||
|
||||
strides
|
||||
}
|
||||
|
||||
/// The action to take for a reshape operation.
|
||||
#[derive(Debug)]
|
||||
pub enum ReshapeAction {
|
||||
/// Updating the strides is sufficient to handle the reshape.
|
||||
UpdateStrides {
|
||||
/// The new strides.
|
||||
strides: Strides,
|
||||
},
|
||||
/// The strides are not compatible, we should recompute the buffer.
|
||||
Recompute,
|
||||
/// The strides are already correct.
|
||||
NoChange,
|
||||
}
|
||||
|
||||
/// The reshape kind.
|
||||
#[derive(Debug)]
|
||||
pub enum ReshapeAnalysis {
|
||||
/// Original tensor is contiguous, can update the strides.
|
||||
IsContiguous,
|
||||
/// Original tensor is highly permutated, can't update the strides.
|
||||
HighlyPermuted,
|
||||
/// Only batch dimensions are added, can update the strides.
|
||||
Broadcasted,
|
||||
/// Dimensions are only split, can update the strides.
|
||||
Split,
|
||||
/// Original tensor is bigger than output shape.
|
||||
SmallerRank,
|
||||
/// New shape is the same.
|
||||
NoChange,
|
||||
}
|
||||
|
||||
impl ReshapeAnalysis {
|
||||
/// Returns the proper action to take for the current analysis.
|
||||
fn action(self, shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {
|
||||
match self {
|
||||
ReshapeAnalysis::IsContiguous => ReshapeAction::UpdateStrides {
|
||||
strides: contiguous_strides(shape_new),
|
||||
},
|
||||
ReshapeAnalysis::NoChange => ReshapeAction::NoChange,
|
||||
ReshapeAnalysis::HighlyPermuted | ReshapeAnalysis::SmallerRank => {
|
||||
ReshapeAction::Recompute
|
||||
}
|
||||
ReshapeAnalysis::Broadcasted => {
|
||||
let shape_rank = shape.len();
|
||||
let shape_new_rank = shape_new.len();
|
||||
let n_new_batch = shape_new_rank - shape_rank;
|
||||
let num_elems = shape.iter().product::<usize>();
|
||||
let strides_new = broadcast_strides(n_new_batch, shape_rank, num_elems, strides);
|
||||
|
||||
ReshapeAction::UpdateStrides {
|
||||
strides: strides_new,
|
||||
}
|
||||
}
|
||||
ReshapeAnalysis::Split => {
|
||||
let strides_new = split_strides(shape, strides, shape_new);
|
||||
|
||||
ReshapeAction::UpdateStrides {
|
||||
strides: strides_new,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the proper action to take when reshaping a tensor.
|
||||
pub fn reshape_action(shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {
|
||||
reshape_analysis(shape, Some(strides), shape_new).action(shape, strides, shape_new)
|
||||
}
|
||||
|
||||
/// Calculate the new strides given added batch dimensions.
|
||||
pub fn broadcast_strides(
|
||||
n_new_batch: usize,
|
||||
rank_prev: usize,
|
||||
num_elems: usize,
|
||||
strides: &[usize],
|
||||
) -> Strides {
|
||||
let mut strides_new = strides![num_elems; rank_prev + n_new_batch];
|
||||
|
||||
for (i, s) in strides.iter().enumerate() {
|
||||
strides_new[i + n_new_batch] = *s;
|
||||
}
|
||||
|
||||
strides_new
|
||||
}
|
||||
|
||||
/// Calculate the new strides given added split dimensions.
|
||||
pub fn split_strides(shape: &[usize], strides: &[usize], shape_new: &[usize]) -> Strides {
|
||||
let mut strides_new = strides![1; shape_new.len()];
|
||||
|
||||
let mut old_idx = shape.len() - 1;
|
||||
let mut current_stride = strides[old_idx];
|
||||
let mut dim_prod = 1;
|
||||
|
||||
for (i, dim) in shape_new.iter().enumerate().rev() {
|
||||
dim_prod *= *dim;
|
||||
strides_new[i] = current_stride;
|
||||
if *dim == 1 {
|
||||
continue;
|
||||
} else if dim_prod == shape[old_idx] {
|
||||
old_idx = old_idx.saturating_sub(1);
|
||||
current_stride = strides[old_idx];
|
||||
dim_prod = 1;
|
||||
} else {
|
||||
current_stride *= *dim;
|
||||
}
|
||||
}
|
||||
|
||||
strides_new
|
||||
}
|
||||
|
||||
/// Returns the analysis of a reshape operation.
|
||||
pub fn reshape_analysis(
|
||||
shape: &[usize],
|
||||
strides: Option<&[usize]>,
|
||||
shape_new: &[usize],
|
||||
) -> ReshapeAnalysis {
|
||||
let shape_rank = shape.len();
|
||||
let shape_new_rank = shape_new.len();
|
||||
|
||||
let is_contiguous = match strides {
|
||||
Some(strides) => is_contiguous(shape, strides),
|
||||
None => false,
|
||||
};
|
||||
|
||||
if is_contiguous {
|
||||
return ReshapeAnalysis::IsContiguous;
|
||||
}
|
||||
|
||||
if shape_new_rank < shape_rank {
|
||||
return ReshapeAnalysis::SmallerRank;
|
||||
}
|
||||
|
||||
let n_new_batch = shape_new_rank - shape_rank;
|
||||
|
||||
match n_new_batch > 0 {
|
||||
true => {
|
||||
if shape == &shape_new[n_new_batch..shape_new_rank]
|
||||
&& shape_new[0..n_new_batch].iter().all(|it| *it == 1)
|
||||
{
|
||||
return ReshapeAnalysis::Broadcasted;
|
||||
} else {
|
||||
let mut dim_prod = 1;
|
||||
let mut old_idx = 0;
|
||||
for dim in shape_new {
|
||||
dim_prod *= *dim;
|
||||
|
||||
// We need to ignore unit dims because they don't affect analysis and break
|
||||
// things because they match the default `dim_prod`. If we don't do this,
|
||||
// reshapes like [2, 3] to [2, 3, 1] will panic from out of bounds access.
|
||||
if *dim == 1 {
|
||||
continue;
|
||||
} else if dim_prod == shape[old_idx] {
|
||||
dim_prod = 1;
|
||||
old_idx += 1;
|
||||
} else if dim_prod > shape[old_idx] {
|
||||
return ReshapeAnalysis::HighlyPermuted;
|
||||
}
|
||||
}
|
||||
return ReshapeAnalysis::Split;
|
||||
}
|
||||
}
|
||||
|
||||
false => {
|
||||
if shape == shape_new {
|
||||
return ReshapeAnalysis::NoChange;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
ReshapeAnalysis::HighlyPermuted
|
||||
}
|
||||
@@ -0,0 +1,393 @@
|
||||
//! Quantization data representation.
|
||||
|
||||
// Re-exported types
|
||||
pub use cubecl_common::quant::scheme::{
|
||||
BlockSize, QuantLevel, QuantMode, QuantParam, QuantScheme, QuantStore, QuantValue,
|
||||
};
|
||||
|
||||
/// Alignment (in bytes) for quantization parameters in serialized tensor data.
|
||||
///
|
||||
/// NOTE: This is currently f32-based since scales were originally always f32.
|
||||
/// With `QuantParam` now supporting different precisions (F16, BF16, etc.),
|
||||
/// this alignment may need to be revisited in the future.
|
||||
pub const QPARAM_ALIGN: usize = core::mem::align_of::<f32>();
|
||||
|
||||
use alloc::vec::Vec;
|
||||
use core::any::TypeId;
|
||||
use num_traits::PrimInt;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{DType, Metadata, Shape, bytes::Bytes};
|
||||
|
||||
#[derive(
|
||||
Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
|
||||
)]
|
||||
/// The precision of accumulating elements.
|
||||
pub enum QuantAcc {
|
||||
/// Full precision.
|
||||
#[default]
|
||||
F32,
|
||||
/// Half precision.
|
||||
F16,
|
||||
/// bfloat16 precision.
|
||||
BF16,
|
||||
}
|
||||
|
||||
/// Specify if the output of an operation is quantized using the scheme of the input
|
||||
/// or returned unquantized.
|
||||
#[derive(
|
||||
Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
|
||||
)]
|
||||
pub enum QuantPropagation {
|
||||
/// The output is quantized using the scheme of the input.
|
||||
Propagate,
|
||||
/// The output is not quantized.
|
||||
#[default]
|
||||
Inhibit,
|
||||
}
|
||||
|
||||
/// The quantization tensor data parameters.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct QParams<S> {
|
||||
/// The scaling factor.
|
||||
pub scales: S,
|
||||
}
|
||||
|
||||
/// A quantization parameter tensor descriptor.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct QParamTensor {
|
||||
/// Start of the tensor in the buffer
|
||||
pub offset_start: usize,
|
||||
/// Offset of tensor end from the end of the buffer
|
||||
pub offset_end: usize,
|
||||
/// Metadata of the tensor
|
||||
pub metadata: Metadata,
|
||||
/// Data type of the tensor
|
||||
pub dtype: DType,
|
||||
}
|
||||
|
||||
/// Calculate the shape of the quantization parameters for a given tensor and level
|
||||
pub fn params_shape(data_shape: &Shape, level: QuantLevel) -> Shape {
|
||||
match level {
|
||||
QuantLevel::Tensor => Shape::new([1]),
|
||||
QuantLevel::Block(block_size) => {
|
||||
let mut params_shape = data_shape.clone();
|
||||
let block_size = block_size.to_dim_vec(data_shape.num_dims());
|
||||
|
||||
for (shape, block_size) in params_shape.iter_mut().zip(block_size) {
|
||||
*shape = (*shape).div_ceil(block_size as usize);
|
||||
}
|
||||
|
||||
params_shape
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Quantized data bytes representation.
|
||||
///
|
||||
/// # Notes
|
||||
/// 1) The quantized values are packed into 32-bit unsigned integers. For example, int8
|
||||
/// quantized values pack 4 grouped values into a single `u32`. When unpacking these values,
|
||||
/// we make sure to retrieve only the meaningful values (and ignore the alignment padding).
|
||||
/// 2) Quantization parameters are appended to the tensor data.
|
||||
/// As such, the last bytes always correspond to the scale parameter.
|
||||
/// If the quantization scheme includes an offset (zero-point) parameter, it is next to last.
|
||||
pub struct QuantizedBytes {
|
||||
/// The quantized values and quantization parameters represented as bytes.
|
||||
pub bytes: Bytes,
|
||||
/// The quantization scheme.
|
||||
pub scheme: QuantScheme,
|
||||
/// The number of quantized elements.
|
||||
pub num_elements: usize,
|
||||
}
|
||||
|
||||
impl QuantizedBytes {
|
||||
/// Creates a new quantized bytes representation.
|
||||
pub fn new<E: bytemuck::CheckedBitPattern + bytemuck::NoUninit>(
|
||||
value: Vec<E>,
|
||||
scheme: QuantScheme,
|
||||
scales: &[f32],
|
||||
) -> Self {
|
||||
let num_elements = value.len();
|
||||
// Only used for 8-bit quantization data comparison in tests
|
||||
if TypeId::of::<E>() != TypeId::of::<i8>() {
|
||||
panic!("Invalid quantized type");
|
||||
}
|
||||
|
||||
// Re-interpret `Vec<E>` as `Vec<i8>` with `Vec::from_raw_parts`
|
||||
let i8s: Vec<i8> = bytemuck::allocation::cast_vec(value);
|
||||
let mut bytes = Bytes::from_elems(i8s);
|
||||
|
||||
match scheme.level {
|
||||
QuantLevel::Tensor => {
|
||||
let scale_bytes = bytemuck::bytes_of(&scales[0]);
|
||||
bytes.extend_from_byte_slice_aligned(scale_bytes, QPARAM_ALIGN);
|
||||
}
|
||||
QuantLevel::Block(_block_size) => {
|
||||
let mut scale_bytes = Vec::with_capacity(size_of_val(scales));
|
||||
for scale in scales {
|
||||
scale_bytes.extend_from_slice(bytemuck::bytes_of(scale));
|
||||
}
|
||||
bytes.extend_from_byte_slice_aligned(scale_bytes.as_slice(), QPARAM_ALIGN);
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
bytes,
|
||||
scheme,
|
||||
num_elements,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the int8 quantized values with the quantization parameters.
|
||||
pub fn into_vec_i8(self) -> (Vec<i8>, QParams<Vec<f32>>) {
|
||||
let (values, (qparams, num_params)) = self.split_values_off();
|
||||
|
||||
// Quantization parameters are added at the end of the tensor data.
|
||||
// As such, the last bytes always correspond to the scale parameter(s).
|
||||
// For example, per-block quantization can have multiple parameters for a single tensor:
|
||||
// [scale, scale, scale, ...]
|
||||
let scale_size = core::mem::size_of::<f32>(); // scale is stored as f32
|
||||
let qparams_bytes: &[u8] = bytemuck::cast_slice(&qparams);
|
||||
let total_bytes = qparams_bytes.len();
|
||||
|
||||
let scales_size = scale_size * num_params;
|
||||
|
||||
let scales = bytemuck::cast_slice(&qparams_bytes[total_bytes - scales_size..]).to_vec();
|
||||
|
||||
(values, QParams { scales })
|
||||
}
|
||||
|
||||
fn split_i8_values(self, num_params: usize) -> (Vec<i8>, Vec<u32>) {
|
||||
let mut values = read_bytes_to_i8(self.bytes);
|
||||
|
||||
let scale_size = num_params * size_of::<f32>();
|
||||
let values_end = values.len() - scale_size;
|
||||
|
||||
let qparams = values.split_off(values_end);
|
||||
|
||||
let qparams = if (qparams.as_ptr() as usize).is_multiple_of(4) {
|
||||
let mut qparams = core::mem::ManuallyDrop::new(qparams);
|
||||
unsafe {
|
||||
Vec::<u32>::from_raw_parts(
|
||||
qparams.as_mut_ptr() as _,
|
||||
qparams.len() / 4,
|
||||
qparams.capacity() / 4,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
#[cfg(target_endian = "little")]
|
||||
{
|
||||
// SAFETY: quantized bytes representation is created from packed u32 values in little endian
|
||||
bytemuck::cast_vec(qparams)
|
||||
}
|
||||
#[cfg(target_endian = "big")]
|
||||
{
|
||||
crate::quantization::pack_i8s_to_u32s(bytemuck::cast_vec(qparams))
|
||||
}
|
||||
};
|
||||
(values, qparams)
|
||||
}
|
||||
|
||||
/// Splits the quantized values of the tensor from the quantization parameters.
|
||||
///
|
||||
/// Returns the values in i8 and a newly allocated vector containing the quantization parameters.
|
||||
fn split_values_off(self) -> (Vec<i8>, (Vec<u32>, usize)) {
|
||||
let num_params = match self.scheme.level {
|
||||
QuantLevel::Tensor => 1,
|
||||
QuantLevel::Block(block_size) => self.num_elements / block_size.num_elements(),
|
||||
};
|
||||
|
||||
if let QuantStore::PackedU32(packed_dim) = self.scheme.store {
|
||||
assert_eq!(
|
||||
packed_dim, 0,
|
||||
"Packing must be on innermost dimension for splitting off values"
|
||||
);
|
||||
}
|
||||
|
||||
let (values, qparams) = match self.scheme.store {
|
||||
QuantStore::Native => self.split_i8_values(num_params),
|
||||
QuantStore::PackedU32(_) => match self.scheme.value {
|
||||
QuantValue::Q8F | QuantValue::Q8S => self.split_i8_values(num_params),
|
||||
QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
|
||||
let mut values = self.bytes.try_into_vec::<u32>().unwrap();
|
||||
let scale_size = num_params; // size of f32 same as u32
|
||||
let values_end = values.len() - scale_size;
|
||||
|
||||
let qparams = values.split_off(values_end);
|
||||
// Sub-byte values are unpacked as i8s for value equality tests
|
||||
let values = unpack_q_to_i8s(&values, self.num_elements, &self.scheme.value);
|
||||
(values, qparams)
|
||||
}
|
||||
QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
|
||||
unimplemented!("Not yet supported")
|
||||
}
|
||||
},
|
||||
QuantStore::PackedNative(_) => unimplemented!("Not yet supported"),
|
||||
};
|
||||
|
||||
(values, (qparams, num_params))
|
||||
}
|
||||
}
|
||||
|
||||
fn read_bytes_to_i8(bytes: Bytes) -> Vec<i8> {
|
||||
match bytes.try_into_vec::<i8>() {
|
||||
Ok(val) => val,
|
||||
// Safety,
|
||||
//
|
||||
// `Vec<u8>` can be Re-interpreted as `Vec<i8>` since they share the same alignment.
|
||||
Err(bytes) => unsafe { core::mem::transmute::<Vec<u8>, Vec<i8>>(bytes.to_vec()) },
|
||||
}
|
||||
}
|
||||
|
||||
/// Pack signed 8-bit integer values into a sequence of unsigned 32-bit integers.
|
||||
pub fn pack_i8s_to_u32s(values: Vec<i8>) -> Vec<u32> {
|
||||
// Shift and combine groups of four 8-bit values into a u32.
|
||||
// Same as doing this:
|
||||
// let result = (d_u8 & 0xFF) << 24 | (c_u8 & 0xFF) << 16 | (b_u8 & 0xFF) << 8 | (a_u8 & 0xFF);
|
||||
#[cfg(target_endian = "big")]
|
||||
{
|
||||
values
|
||||
.chunks(4)
|
||||
.map(|x| {
|
||||
x.iter()
|
||||
.enumerate()
|
||||
.fold(0u32, |acc, (i, x)| acc | (*x as u32 & 0xFF) << (i * 8))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
// The order of bytes in little endian matches the above description, we just need to
|
||||
// handle padding when the number of values is not a factor of 4
|
||||
#[cfg(target_endian = "little")]
|
||||
{
|
||||
let mut values = values;
|
||||
let remainder = values.len() % 4;
|
||||
if remainder != 0 {
|
||||
// Pad with zeros
|
||||
values.extend(core::iter::repeat_n(0, 4 - remainder));
|
||||
}
|
||||
|
||||
let len = values.len() / 4;
|
||||
let capacity = values.capacity() / 4;
|
||||
|
||||
// Pre-forget the old vec and re-interpret as u32
|
||||
let mut values = core::mem::ManuallyDrop::new(values);
|
||||
let ptr = values.as_mut_ptr() as *mut u32;
|
||||
|
||||
unsafe { Vec::from_raw_parts(ptr, len, capacity) }
|
||||
}
|
||||
}
|
||||
|
||||
/// Unpack integer values into a sequence of signed 8-bit integers.
|
||||
pub(crate) fn unpack_q_to_i8s<Q: PrimInt>(
|
||||
values: &[Q],
|
||||
numel: usize,
|
||||
value: &QuantValue,
|
||||
) -> Vec<i8> {
|
||||
let size_store = size_of::<Q>() * 8;
|
||||
let size_quant = value.size_bits();
|
||||
let num_quants = size_store / size_quant;
|
||||
let mask = Q::from((1 << size_quant) - 1).unwrap();
|
||||
let sign_shift = 8 - size_quant; // sign extension for sub-byte values
|
||||
values
|
||||
.iter()
|
||||
.enumerate()
|
||||
.flat_map(|(i, &packed)| {
|
||||
// A single u32 could contain less than four 8-bit values...
|
||||
let n = core::cmp::min(num_quants, numel - i * num_quants);
|
||||
// Extract each 8-bit segment from u32 and cast back to i8
|
||||
// Same as doing this (when 4 values are fully packed):
|
||||
// let a = (packed & 0xFF) as i8;
|
||||
// let b = ((packed >> 8) & 0xFF) as i8;
|
||||
// let c = ((packed >> 16) & 0xFF) as i8;
|
||||
// let d = ((packed >> 24) & 0xFF) as i8;
|
||||
(0..n).map(move |i| {
|
||||
let raw = (packed >> (i * size_quant) & mask).to_u8().unwrap();
|
||||
((raw << sign_shift) as i8) >> sign_shift
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::*;
|
||||
use alloc::vec;
|
||||
|
||||
#[test]
|
||||
fn should_pack_i8s_to_u32() {
|
||||
let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127]);
|
||||
|
||||
assert_eq!(packed, vec![2147287680]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_pack_i8s_to_u32_padded() {
|
||||
let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55]);
|
||||
let packed_padded = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55, 0, 0, 0]);
|
||||
|
||||
assert_eq!(packed, vec![2147287680, 55]);
|
||||
assert_eq!(packed, packed_padded);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_unpack_u32s_to_i8s() {
|
||||
let unpacked = unpack_q_to_i8s(&[2147287680u32], 4, &QuantValue::Q8S);
|
||||
|
||||
assert_eq!(unpacked, vec![-128, 2, -3, 127]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_unpack_u32s_to_i8s_padded() {
|
||||
let unpacked = unpack_q_to_i8s(&[55u32], 1, &QuantValue::Q8S);
|
||||
|
||||
assert_eq!(unpacked, vec![55]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_unpack_u32s_to_i8s_arange() {
|
||||
let unpacked = unpack_q_to_i8s(
|
||||
&[
|
||||
0u32, 286331136, 286331153, 572657937, 572662306, 857874978, 858993459, 858993459,
|
||||
1145324612, 1145324612, 1431655748, 1431655765, 1717982549, 1717986918, 2003199590,
|
||||
2004318071,
|
||||
],
|
||||
128,
|
||||
&QuantValue::Q4S,
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
unpacked,
|
||||
vec![
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
|
||||
3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5,
|
||||
5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
|
||||
6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_pack_unpack_quantization_parameters_per_tensor_symmetric() {
|
||||
// Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
|
||||
let scale = 0.03937008;
|
||||
let values = vec![0i8, 25, 51, 76, 102, 127];
|
||||
|
||||
let q_bytes = QuantizedBytes::new(
|
||||
values.clone(),
|
||||
QuantScheme::default()
|
||||
.with_value(QuantValue::Q8S)
|
||||
.with_store(QuantStore::Native),
|
||||
&[scale],
|
||||
);
|
||||
|
||||
let (q_values, qparams) = q_bytes.into_vec_i8();
|
||||
|
||||
assert_eq!(qparams.scales, vec![scale]);
|
||||
|
||||
assert_eq!(q_values, values);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,271 @@
|
||||
//! Tensor shape definition.
|
||||
|
||||
use super::{Slice, SliceArg};
|
||||
use alloc::vec::Vec;
|
||||
use core::ops::Range;
|
||||
|
||||
pub use crate::errors::ExpressionError;
|
||||
|
||||
pub use cubecl_zspace::{MetadataError, Shape, calculate_matmul_output, shape};
|
||||
|
||||
/// Slice-related ops on [`Shape`]
|
||||
pub trait SliceOps: Sized {
|
||||
/// Convert shape dimensions to full covering ranges (0..dim) for each dimension.
|
||||
fn into_ranges(self) -> Vec<Range<usize>>;
|
||||
/// Converts slice arguments into an array of slice specifications for the shape.
|
||||
///
|
||||
/// This method returns an array of `Slice` objects that can be used for slicing operations.
|
||||
/// The slices are clamped to the shape's dimensions. Similar to `into_ranges()`, but
|
||||
/// allows custom slice specifications instead of full ranges.
|
||||
/// For creating complex slice specifications, use the [`s!`] macro.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `slices` - An array of slice specifications, where each element can be:
|
||||
/// - A range (e.g., `2..5`)
|
||||
/// - An index
|
||||
/// - A `Slice` object
|
||||
/// - The output of the [`s!`] macro for advanced slicing
|
||||
///
|
||||
/// # Behavior
|
||||
///
|
||||
/// - Supports partial and full slicing in any number of dimensions.
|
||||
/// - Missing ranges are treated as full slices if D > D2.
|
||||
/// - Handles negative indices by wrapping around from the end of the dimension.
|
||||
/// - Clamps ranges to the shape's dimensions if they exceed the bounds.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An array of `Slice` objects corresponding to the provided slice specifications,
|
||||
/// clamped to the shape's actual dimensions.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust
|
||||
/// use burn_std::{Shape, Slice, s, SliceOps};
|
||||
///
|
||||
/// fn example() {
|
||||
/// // 1D slicing
|
||||
/// let slices = Shape::new([4]).into_slices(1..4);
|
||||
/// assert_eq!(slices[0].to_range(4), 1..3);
|
||||
///
|
||||
/// // 2D slicing
|
||||
/// let slices = Shape::new([3, 4]).into_slices(s![1..4, 0..2]);
|
||||
/// assert_eq!(slices[0].to_range(3), 1..3);
|
||||
/// assert_eq!(slices[1].to_range(4), 0..2);
|
||||
///
|
||||
/// // Using negative indices
|
||||
/// let slices = Shape::new([3]).into_slices(..-2);
|
||||
/// assert_eq!(slices[0].to_range(3), 0..1);
|
||||
///
|
||||
/// // Using the slice macro to select different ranges
|
||||
/// let slices = Shape::new([2, 3, 4]).into_slices(s![.., 1..-1]);
|
||||
/// assert_eq!(slices[0].to_range(2), 0..2);
|
||||
/// assert_eq!(slices[1].to_range(3), 1..2);
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// # See Also
|
||||
///
|
||||
/// - [`s!`] - The recommended macro for creating slice specifications
|
||||
/// - [`Shape::into_ranges`] - Convert to full covering ranges
|
||||
///
|
||||
/// [`s!`]: crate::s!
|
||||
fn into_slices<S>(self, slices: S) -> Vec<Slice>
|
||||
where
|
||||
S: SliceArg;
|
||||
/// Compute the output shape from the given slices.
|
||||
fn slice(self, slices: &[Slice]) -> Result<Self, MetadataError>;
|
||||
}
|
||||
|
||||
impl SliceOps for Shape {
|
||||
fn into_ranges(self) -> Vec<Range<usize>> {
|
||||
self.iter().map(|&d| 0..d).collect()
|
||||
}
|
||||
|
||||
fn into_slices<S>(self, slices: S) -> Vec<Slice>
|
||||
where
|
||||
S: SliceArg,
|
||||
{
|
||||
slices.into_slices(&self)
|
||||
}
|
||||
|
||||
fn slice(mut self, slices: &[Slice]) -> Result<Self, MetadataError> {
|
||||
if slices.len() > self.rank() {
|
||||
return Err(MetadataError::RankMismatch {
|
||||
left: self.rank(),
|
||||
right: slices.len(),
|
||||
});
|
||||
}
|
||||
|
||||
slices
|
||||
.iter()
|
||||
.zip(self.iter_mut())
|
||||
.for_each(|(slice, dim_size)| *dim_size = slice.output_size(*dim_size));
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(clippy::identity_op, reason = "useful for clarity")]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::s;
|
||||
use alloc::vec;
|
||||
|
||||
#[test]
|
||||
fn test_into_ranges() {
|
||||
let dims = [2, 3, 4, 5];
|
||||
let shape = Shape::new(dims);
|
||||
assert_eq!(shape.into_ranges(), vec![0..2, 0..3, 0..4, 0..5]);
|
||||
}
|
||||
|
||||
#[allow(clippy::single_range_in_vec_init)]
|
||||
#[test]
|
||||
fn test_into_slices() {
|
||||
let slices = Shape::new([3]).into_slices(1..4);
|
||||
assert_eq!(slices[0].to_range(3), 1..3);
|
||||
|
||||
let slices = Shape::new([3, 4]).into_slices(s![1..4, 0..2]);
|
||||
assert_eq!(slices[0].to_range(3), 1..3);
|
||||
assert_eq!(slices[1].to_range(4), 0..2);
|
||||
|
||||
let slices = Shape::new([3]).into_slices(..-2);
|
||||
assert_eq!(slices[0].to_range(3), 0..1);
|
||||
|
||||
let slices = Shape::new([2, 3, 4]).into_slices(s![.., 1..-1]);
|
||||
assert_eq!(slices[0].to_range(2), 0..2);
|
||||
assert_eq!(slices[1].to_range(3), 1..2);
|
||||
|
||||
let slices = Shape::new([2, 3, 4]).into_slices(s![..20, 2]);
|
||||
assert_eq!(slices[0].to_range(2), 0..2);
|
||||
assert_eq!(slices[1].to_range(3), 2..3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shape_as_slice() {
|
||||
let dims = [2, 3, 4, 5];
|
||||
let shape = Shape::new(dims);
|
||||
|
||||
assert_eq!(shape.as_slice(), dims.as_slice());
|
||||
|
||||
// Deref coercion
|
||||
let shape_slice: &[usize] = &shape;
|
||||
assert_eq!(shape_slice, *&[2, 3, 4, 5]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shape_as_mut_slice() {
|
||||
let mut dims = [2, 3, 4, 5];
|
||||
let mut shape = Shape::new(dims);
|
||||
|
||||
let shape_mut = shape.as_mut_slice();
|
||||
assert_eq!(shape_mut, dims.as_mut_slice());
|
||||
shape_mut[1] = 6;
|
||||
|
||||
assert_eq!(shape_mut, &[2, 6, 4, 5]);
|
||||
|
||||
let mut shape = Shape::new(dims);
|
||||
let shape = &mut shape[..];
|
||||
shape[1] = 6;
|
||||
|
||||
assert_eq!(shape, shape_mut)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shape_slice_output_shape_basic() {
|
||||
// Test basic slicing with step=1
|
||||
let slices = [
|
||||
Slice::new(0, Some(5), 1), // 5 elements
|
||||
Slice::new(2, Some(8), 1), // 6 elements
|
||||
];
|
||||
let original_shape = Shape::new([10, 10, 10]);
|
||||
let result = original_shape.slice(&slices).unwrap();
|
||||
assert_eq!(result, Shape::new([5, 6, 10]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shape_slice_output_shape_with_positive_steps() {
|
||||
// Test slicing with various positive steps
|
||||
let slices = [
|
||||
Slice::new(0, Some(10), 2), // [0,2,4,6,8] -> 5 elements
|
||||
Slice::new(1, Some(9), 3), // [1,4,7] -> 3 elements
|
||||
Slice::new(0, Some(7), 4), // [0,4] -> 2 elements
|
||||
];
|
||||
let original_shape = Shape::new([20, 20, 20, 30]);
|
||||
let result = original_shape.slice(&slices).unwrap();
|
||||
assert_eq!(result, Shape::new([5, 3, 2, 30]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shape_slice_output_shape_with_negative_steps() {
|
||||
// Test slicing with negative steps (backward iteration)
|
||||
let slices = [
|
||||
Slice::new(0, Some(10), -1), // 10 elements traversed backward
|
||||
Slice::new(2, Some(8), -2), // [7,5,3] -> 3 elements
|
||||
];
|
||||
let original_shape = Shape::new([20, 20, 20]);
|
||||
let result = original_shape.slice(&slices).unwrap();
|
||||
assert_eq!(result, Shape::new([10, 3, 20]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shape_slice_output_shape_mixed_steps() {
|
||||
// Test with a mix of positive, negative, and unit steps
|
||||
let slices = [
|
||||
Slice::from_range_stepped(1..6, 1), // 5 elements
|
||||
Slice::from_range_stepped(0..10, -3), // [9,6,3,0] -> 4 elements
|
||||
Slice::from_range_stepped(2..14, 4), // [2,6,10] -> 3 elements
|
||||
];
|
||||
let original_shape = Shape::new([20, 20, 20]);
|
||||
let result = original_shape.slice(&slices).unwrap();
|
||||
assert_eq!(result, Shape::new([5, 4, 3]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shape_slice_output_shape_partial_dims() {
|
||||
// Test when slices has fewer dimensions than original shape
|
||||
let slices = [
|
||||
Slice::from_range_stepped(2..7, 2), // [2,4,6] -> 3 elements
|
||||
];
|
||||
let original_shape = Shape::new([10, 20, 30, 40]);
|
||||
let result = original_shape.slice(&slices).unwrap();
|
||||
assert_eq!(result, Shape::new([3, 20, 30, 40]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shape_slice_output_shape_edge_cases() {
|
||||
// Test edge cases with small ranges and large steps
|
||||
let slices = [
|
||||
Slice::from_range_stepped(0..1, 1), // Single element
|
||||
Slice::from_range_stepped(0..10, 100), // Step larger than range -> 1 element
|
||||
Slice::from_range_stepped(5..5, 1), // Empty range -> 0 elements
|
||||
];
|
||||
let original_shape = Shape::new([10, 20, 30]);
|
||||
let result = original_shape.slice(&slices).unwrap();
|
||||
assert_eq!(result, Shape::new([1, 1, 0]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shape_slice_output_shape_empty() {
|
||||
// Test with no slice infos (should return original shape)
|
||||
let slices = [];
|
||||
let original_shape = Shape::new([10, 20, 30]);
|
||||
let result = original_shape.slice(&slices).unwrap();
|
||||
assert_eq!(result, Shape::new([10, 20, 30]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shape_slice_output_shape_uneven_division() {
|
||||
// Test cases where range size doesn't divide evenly by step
|
||||
let slices = [
|
||||
Slice::from_range_stepped(0..7, 3), // ceil(7/3) = 3 elements: [0,3,6]
|
||||
Slice::from_range_stepped(0..11, 4), // ceil(11/4) = 3 elements: [0,4,8]
|
||||
Slice::from_range_stepped(1..10, 5), // ceil(9/5) = 2 elements: [1,6]
|
||||
];
|
||||
let original_shape = Shape::new([20, 20, 20]);
|
||||
let result = original_shape.slice(&slices).unwrap();
|
||||
assert_eq!(result, Shape::new([3, 3, 2]));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,937 @@
|
||||
//! Tensor slice utilities.
|
||||
|
||||
use crate::Shape;
|
||||
use crate::indexing::AsIndex;
|
||||
use alloc::format;
|
||||
use alloc::vec::Vec;
|
||||
use core::fmt::{Display, Formatter};
|
||||
use core::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
|
||||
use core::str::FromStr;
|
||||
|
||||
/// Trait for slice arguments that can be converted into an array of slices.
|
||||
/// This allows the `slice` method to accept both single slices (from `s![..]`)
|
||||
/// and arrays of slices (from `s![.., ..]` or `[0..5, 1..3]`).
|
||||
pub trait SliceArg {
|
||||
/// Convert to an vec of slices with clamping to shape dimensions.
|
||||
///
|
||||
/// Returns a [Slice] for each dimension in `shape`.
|
||||
fn into_slices(self, shape: &Shape) -> Vec<Slice>;
|
||||
}
|
||||
|
||||
impl<S: Into<Slice> + Clone> SliceArg for &[S] {
|
||||
fn into_slices(self, shape: &Shape) -> Vec<Slice> {
|
||||
assert!(
|
||||
self.len() <= shape.num_dims(),
|
||||
"Too many slices provided for shape, got {} but expected at most {}",
|
||||
self.len(),
|
||||
shape.num_dims()
|
||||
);
|
||||
|
||||
shape
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, dim_size)| {
|
||||
let slice = if i >= self.len() {
|
||||
Slice::full()
|
||||
} else {
|
||||
self[i].clone().into()
|
||||
};
|
||||
// Apply shape clamping by converting to range and back
|
||||
let clamped_range = slice.to_range(*dim_size);
|
||||
Slice::new(
|
||||
clamped_range.start as isize,
|
||||
Some(clamped_range.end as isize),
|
||||
slice.step(),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
}
|
||||
|
||||
impl SliceArg for &Vec<Slice> {
|
||||
fn into_slices(self, shape: &Shape) -> Vec<Slice> {
|
||||
self.as_slice().into_slices(shape)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize, T> SliceArg for [T; R]
|
||||
where
|
||||
T: Into<Slice> + Clone,
|
||||
{
|
||||
fn into_slices(self, shape: &Shape) -> Vec<Slice> {
|
||||
self.as_slice().into_slices(shape)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> SliceArg for T
|
||||
where
|
||||
T: Into<Slice>,
|
||||
{
|
||||
fn into_slices(self, shape: &Shape) -> Vec<Slice> {
|
||||
let slice: Slice = self.into();
|
||||
[slice].as_slice().into_slices(shape)
|
||||
}
|
||||
}
|
||||
|
||||
/// Slice argument constructor for tensor indexing.
|
||||
///
|
||||
/// The `s![]` macro is used to create multi-dimensional slice specifications for tensors.
|
||||
/// It converts various range syntax forms into a `&[Slice]` that can be used with
|
||||
/// `tensor.slice()` and `tensor.slice_assign()` operations.
|
||||
///
|
||||
/// # Syntax Overview
|
||||
///
|
||||
/// ## Basic Forms
|
||||
///
|
||||
/// * **`s![index]`** - Index a single element (produces a subview with that axis removed)
|
||||
/// * **`s![range]`** - Slice a range of elements
|
||||
/// * **`s![range;step]`** - Slice a range with a custom step
|
||||
/// * **`s![dim1, dim2, ...]`** - Multiple dimensions, each can be any of the above forms
|
||||
///
|
||||
/// ## Range Types
|
||||
///
|
||||
/// All standard Rust range types are supported:
|
||||
/// * **`a..b`** - From `a` (inclusive) to `b` (exclusive)
|
||||
/// * **`a..=b`** - From `a` to `b` (both inclusive)
|
||||
/// * **`a..`** - From `a` to the end
|
||||
/// * **`..b`** - From the beginning to `b` (exclusive)
|
||||
/// * **`..=b`** - From the beginning to `b` (inclusive)
|
||||
/// * **`..`** - The full range (all elements)
|
||||
///
|
||||
/// ## Negative Indices
|
||||
///
|
||||
/// Negative indices count from the end of the axis:
|
||||
/// * **`-1`** refers to the last element
|
||||
/// * **`-2`** refers to the second-to-last element
|
||||
/// * And so on...
|
||||
///
|
||||
/// This works in all range forms: `s![-3..-1]`, `s![-2..]`, `s![..-1]`
|
||||
///
|
||||
/// ## Step Syntax
|
||||
///
|
||||
/// Steps control the stride between selected elements:
|
||||
/// * **`;step`** after a range specifies the step
|
||||
/// * **Positive steps** select every nth element going forward
|
||||
/// * **Negative steps** select every nth element going backward
|
||||
/// * Default step is `1` when not specified
|
||||
/// * Step cannot be `0`
|
||||
///
|
||||
/// ### Negative Step Behavior
|
||||
///
|
||||
/// With negative steps, the range bounds still specify *which* elements to include,
|
||||
/// but the traversal order is reversed:
|
||||
///
|
||||
/// * `s![0..5;-1]` selects indices `[4, 3, 2, 1, 0]` (not `[0, 1, 2, 3, 4]`)
|
||||
/// * `s![2..8;-2]` selects indices `[7, 5, 3]` (starting from 7, going backward by 2)
|
||||
/// * `s![..;-1]` reverses the entire axis
|
||||
///
|
||||
/// This matches the semantics of NumPy and the ndarray crate.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ## Basic Slicing
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use burn_tensor::{Tensor, s};
|
||||
///
|
||||
/// # fn example<B: Backend>(tensor: Tensor<B, 3>) {
|
||||
/// // Select rows 0-5 (exclusive)
|
||||
/// let subset = tensor.slice(s![0..5, .., ..]);
|
||||
///
|
||||
/// // Select the last row
|
||||
/// let last_row = tensor.slice(s![-1, .., ..]);
|
||||
///
|
||||
/// // Select columns 2, 3, 4
|
||||
/// let cols = tensor.slice(s![.., 2..5, ..]);
|
||||
///
|
||||
/// // Select a single element at position [1, 2, 3]
|
||||
/// let element = tensor.slice(s![1, 2, 3]);
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// ## Slicing with Steps
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use burn_tensor::{Tensor, s};
|
||||
///
|
||||
/// # fn example<B: Backend>(tensor: Tensor<B, 2>) {
|
||||
/// // Select every 2nd row
|
||||
/// let even_rows = tensor.slice(s![0..10;2, ..]);
|
||||
///
|
||||
/// // Select every 3rd column
|
||||
/// let cols = tensor.slice(s![.., 0..9;3]);
|
||||
///
|
||||
/// // Select every 2nd element in reverse order
|
||||
/// let reversed_even = tensor.slice(s![10..0;-2, ..]);
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// ## Reversing Dimensions
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use burn_tensor::{Tensor, s};
|
||||
///
|
||||
/// # fn example<B: Backend>(tensor: Tensor<B, 2>) {
|
||||
/// // Reverse the first dimension
|
||||
/// let reversed = tensor.slice(s![..;-1, ..]);
|
||||
///
|
||||
/// // Reverse both dimensions
|
||||
/// let fully_reversed = tensor.slice(s![..;-1, ..;-1]);
|
||||
///
|
||||
/// // Reverse a specific range
|
||||
/// let range_reversed = tensor.slice(s![2..8;-1, ..]);
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// ## Complex Multi-dimensional Slicing
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use burn_tensor::{Tensor, s};
|
||||
///
|
||||
/// # fn example<B: Backend>(tensor: Tensor<B, 4>) {
|
||||
/// // Mix of different slice types
|
||||
/// let complex = tensor.slice(s![
|
||||
/// 0..10;2, // Every 2nd element from 0 to 10
|
||||
/// .., // All elements in dimension 1
|
||||
/// 5..15;-3, // Every 3rd element from 14 down to 5
|
||||
/// -1 // Last element in dimension 3
|
||||
/// ]);
|
||||
///
|
||||
/// // Using inclusive ranges
|
||||
/// let inclusive = tensor.slice(s![2..=5, 1..=3, .., ..]);
|
||||
///
|
||||
/// // Negative indices with steps
|
||||
/// let from_end = tensor.slice(s![-5..-1;2, .., .., ..]);
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// ## Slice Assignment
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use burn_tensor::{Tensor, s};
|
||||
///
|
||||
/// # fn example<B: Backend>(tensor: Tensor<B, 2>, values: Tensor<B, 2>) {
|
||||
/// // Assign to every 2nd row
|
||||
/// let tensor = tensor.slice_assign(s![0..10;2, ..], values);
|
||||
///
|
||||
/// // Assign to a reversed slice
|
||||
/// let tensor = tensor.slice_assign(s![..;-1, 0..5], values);
|
||||
/// # }
|
||||
/// ```
|
||||
#[macro_export]
|
||||
macro_rules! s {
|
||||
// Empty - should not happen
|
||||
[] => {
|
||||
compile_error!("Empty slice specification")
|
||||
};
|
||||
|
||||
// Single expression with step
|
||||
[$range:expr; $step:expr] => {
|
||||
{
|
||||
#[allow(clippy::reversed_empty_ranges)]
|
||||
{
|
||||
$crate::tensor::Slice::from_range_stepped($range, $step)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Single expression without step (no comma after)
|
||||
[$range:expr] => {
|
||||
{
|
||||
#[allow(clippy::reversed_empty_ranges)]
|
||||
{
|
||||
$crate::tensor::Slice::from($range)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Two or more expressions with first having step
|
||||
[$range:expr; $step:expr, $($rest:tt)*] => {
|
||||
{
|
||||
#[allow(clippy::reversed_empty_ranges)]
|
||||
{
|
||||
$crate::s!(@internal [$crate::tensor::Slice::from_range_stepped($range, $step)] $($rest)*)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Two or more expressions with first not having step
|
||||
[$range:expr, $($rest:tt)*] => {
|
||||
{
|
||||
#[allow(clippy::reversed_empty_ranges)]
|
||||
{
|
||||
$crate::s!(@internal [$crate::tensor::Slice::from($range)] $($rest)*)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Internal: finished parsing
|
||||
(@internal [$($acc:expr),*]) => {
|
||||
[$($acc),*]
|
||||
};
|
||||
|
||||
// Internal: parse range with step followed by comma
|
||||
(@internal [$($acc:expr),*] $range:expr; $step:expr, $($rest:tt)*) => {
|
||||
$crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from_range_stepped($range, $step as isize)] $($rest)*)
|
||||
};
|
||||
|
||||
// Internal: parse range with step at end
|
||||
(@internal [$($acc:expr),*] $range:expr; $step:expr) => {
|
||||
$crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from_range_stepped($range, $step as isize)])
|
||||
};
|
||||
|
||||
// Internal: parse range without step followed by comma
|
||||
(@internal [$($acc:expr),*] $range:expr, $($rest:tt)*) => {
|
||||
$crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from($range)] $($rest)*)
|
||||
};
|
||||
|
||||
// Internal: parse range without step at end
|
||||
(@internal [$($acc:expr),*] $range:expr) => {
|
||||
$crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from($range)])
|
||||
};
|
||||
}
|
||||
|
||||
/// A slice specification for a single tensor dimension.
|
||||
///
|
||||
/// This struct represents a range with an optional step, used for advanced indexing
|
||||
/// operations on tensors. It is typically created using the [`s!`] macro rather than
|
||||
/// constructed directly.
|
||||
///
|
||||
/// # Fields
|
||||
///
|
||||
/// * `start` - The starting index (inclusive). Negative values count from the end.
|
||||
/// * `end` - The ending index (exclusive). `None` means to the end of the dimension.
|
||||
/// * `step` - The stride between elements. Must be non-zero.
|
||||
///
|
||||
/// # Index Interpretation
|
||||
///
|
||||
/// - **Positive indices**: Count from the beginning (0-based)
|
||||
/// - **Negative indices**: Count from the end (-1 is the last element)
|
||||
/// - **Bounds checking**: Indices are clamped to valid ranges
|
||||
///
|
||||
/// # Step Behavior
|
||||
///
|
||||
/// - **Positive step**: Traverse forward through the range
|
||||
/// - **Negative step**: Traverse backward through the range
|
||||
/// - **Step size**: Determines how many elements to skip
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// While you typically use the [`s!`] macro, you can also construct slices directly:
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use burn_tensor::Slice;
|
||||
///
|
||||
/// // Equivalent to s![2..8]
|
||||
/// let slice1 = Slice::new(2, Some(8), 1);
|
||||
///
|
||||
/// // Equivalent to s![0..10;2]
|
||||
/// let slice2 = Slice::new(0, Some(10), 2);
|
||||
///
|
||||
/// // Equivalent to s![..;-1] (reverse)
|
||||
/// let slice3 = Slice::new(0, None, -1);
|
||||
/// ```
|
||||
///
|
||||
/// See also the [`s!`] macro for the preferred way to create slices.
|
||||
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
|
||||
pub struct Slice {
|
||||
/// Slice start index.
|
||||
pub start: isize,
|
||||
/// Slice end index (exclusive).
|
||||
pub end: Option<isize>,
|
||||
/// Step between elements (default: 1).
|
||||
pub step: isize,
|
||||
}
|
||||
|
||||
/// Defines an [`Iterator`] over a [`Slice`].
|
||||
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
|
||||
pub struct SliceIter {
|
||||
slice: Slice,
|
||||
current: isize,
|
||||
}
|
||||
|
||||
impl Iterator for SliceIter {
|
||||
type Item = isize;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let next = self.current;
|
||||
self.current += self.slice.step;
|
||||
|
||||
if let Some(end) = self.slice.end {
|
||||
if self.slice.is_reversed() {
|
||||
if next <= end {
|
||||
return None;
|
||||
}
|
||||
} else if next >= end {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
/// Note: Unbounded [`Slice`]s produce infinite iterators.
|
||||
impl IntoIterator for Slice {
|
||||
type Item = isize;
|
||||
type IntoIter = SliceIter;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
SliceIter {
|
||||
slice: self,
|
||||
current: self.start,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Slice {
|
||||
fn default() -> Self {
|
||||
Self::full()
|
||||
}
|
||||
}
|
||||
|
||||
impl Slice {
|
||||
/// Creates a new slice with start, end, and step
|
||||
pub const fn new(start: isize, end: Option<isize>, step: isize) -> Self {
|
||||
assert!(step != 0, "Step cannot be zero");
|
||||
Self { start, end, step }
|
||||
}
|
||||
|
||||
/// Creates a slice that represents the full range.
|
||||
pub const fn full() -> Self {
|
||||
Self::new(0, None, 1)
|
||||
}
|
||||
|
||||
/// Creates a slice that represents a single index
|
||||
pub fn index(idx: isize) -> Self {
|
||||
Self {
|
||||
start: idx,
|
||||
end: handle_signed_inclusive_end(idx),
|
||||
step: 1,
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts the slice to a vector.
|
||||
pub fn into_vec(self) -> Vec<isize> {
|
||||
assert!(
|
||||
self.end.is_some(),
|
||||
"Slice must have an end to convert to a vector: {self:?}"
|
||||
);
|
||||
self.into_iter().collect()
|
||||
}
|
||||
|
||||
/// Clips the slice to a maximum size.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// assert_eq!(
|
||||
/// Slice::new(0, None, 1).bound_to(10),
|
||||
/// Slice::new(0, Some(10), 1));
|
||||
/// assert_eq!(
|
||||
/// Slice::new(0, Some(5), 1).bound_to(10),
|
||||
/// Slice::new(0, Some(5), 1));
|
||||
/// assert_eq!(
|
||||
/// Slice::new(0, None, -1).bound_to(10),
|
||||
/// Slice::new(0, Some(-11), -1));
|
||||
/// assert_eq!(
|
||||
/// Slice::new(0, Some(-5), -1).bound_to(10),
|
||||
/// Slice::new(0, Some(-5), -1));
|
||||
/// ```
|
||||
pub fn bound_to(self, size: usize) -> Self {
|
||||
let mut bounds = size as isize;
|
||||
|
||||
if let Some(end) = self.end {
|
||||
if end > 0 {
|
||||
bounds = end.min(bounds);
|
||||
} else {
|
||||
bounds = end.max(-(bounds + 1));
|
||||
}
|
||||
} else if self.is_reversed() {
|
||||
bounds = -(bounds + 1);
|
||||
}
|
||||
|
||||
Self {
|
||||
end: Some(bounds),
|
||||
..self
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a slice with a custom step
|
||||
pub fn with_step(start: isize, end: Option<isize>, step: isize) -> Self {
|
||||
assert!(step != 0, "Step cannot be zero");
|
||||
Self { start, end, step }
|
||||
}
|
||||
|
||||
/// Creates a slice from a range with a specified step
|
||||
pub fn from_range_stepped<R: Into<Slice>>(range: R, step: isize) -> Self {
|
||||
assert!(step != 0, "Step cannot be zero");
|
||||
let mut slice = range.into();
|
||||
slice.step = step;
|
||||
slice
|
||||
}
|
||||
|
||||
/// Returns the step of the slice
|
||||
pub fn step(&self) -> isize {
|
||||
self.step
|
||||
}
|
||||
|
||||
/// Returns the range for this slice given a dimension size
|
||||
pub fn range(&self, size: usize) -> Range<usize> {
|
||||
self.to_range(size)
|
||||
}
|
||||
|
||||
/// Convert this slice to a range for a dimension of the given size.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `size` - The size of the dimension to slice.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A `Range<usize>` representing the slice bounds.
|
||||
pub fn to_range(&self, size: usize) -> Range<usize> {
|
||||
// Always return a valid range with start <= end
|
||||
// The step information will be handled separately
|
||||
let start = convert_signed_index(self.start, size);
|
||||
let end = match self.end {
|
||||
Some(end) => convert_signed_index(end, size),
|
||||
None => size,
|
||||
};
|
||||
start..end
|
||||
}
|
||||
|
||||
/// Converts the slice into a range and step tuple
|
||||
pub fn to_range_and_step(&self, size: usize) -> (Range<usize>, isize) {
|
||||
let range = self.to_range(size);
|
||||
(range, self.step)
|
||||
}
|
||||
|
||||
/// Returns true if the step is negative
|
||||
pub fn is_reversed(&self) -> bool {
|
||||
self.step < 0
|
||||
}
|
||||
|
||||
/// Calculates the output size for this slice operation
|
||||
pub fn output_size(&self, dim_size: usize) -> usize {
|
||||
let range = self.to_range(dim_size);
|
||||
// Handle empty slices (start >= end)
|
||||
if range.start >= range.end {
|
||||
return 0;
|
||||
}
|
||||
let len = range.end - range.start;
|
||||
if self.step.unsigned_abs() == 1 {
|
||||
len
|
||||
} else {
|
||||
len.div_ceil(self.step.unsigned_abs())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_signed_index(index: isize, size: usize) -> usize {
|
||||
if index < 0 {
|
||||
(size as isize + index).max(0) as usize
|
||||
} else {
|
||||
(index as usize).min(size)
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_signed_inclusive_end(end: isize) -> Option<isize> {
|
||||
match end {
|
||||
-1 => None,
|
||||
end => Some(end + 1),
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: AsIndex> From<Range<I>> for Slice {
|
||||
fn from(r: Range<I>) -> Self {
|
||||
Self {
|
||||
start: r.start.as_index(),
|
||||
end: Some(r.end.as_index()),
|
||||
step: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: AsIndex + Copy> From<RangeInclusive<I>> for Slice {
|
||||
fn from(r: RangeInclusive<I>) -> Self {
|
||||
Self {
|
||||
start: r.start().as_index(),
|
||||
end: handle_signed_inclusive_end(r.end().as_index()),
|
||||
step: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: AsIndex> From<RangeFrom<I>> for Slice {
|
||||
fn from(r: RangeFrom<I>) -> Self {
|
||||
Self {
|
||||
start: r.start.as_index(),
|
||||
end: None,
|
||||
step: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: AsIndex> From<RangeTo<I>> for Slice {
|
||||
fn from(r: RangeTo<I>) -> Self {
|
||||
Self {
|
||||
start: 0,
|
||||
end: Some(r.end.as_index()),
|
||||
step: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: AsIndex> From<RangeToInclusive<I>> for Slice {
|
||||
fn from(r: RangeToInclusive<I>) -> Self {
|
||||
Self {
|
||||
start: 0,
|
||||
end: handle_signed_inclusive_end(r.end.as_index()),
|
||||
step: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<RangeFull> for Slice {
|
||||
fn from(_: RangeFull) -> Self {
|
||||
Self {
|
||||
start: 0,
|
||||
end: None,
|
||||
step: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<usize> for Slice {
|
||||
fn from(i: usize) -> Self {
|
||||
Slice::index(i as isize)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<isize> for Slice {
|
||||
fn from(i: isize) -> Self {
|
||||
Slice::index(i)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<i32> for Slice {
|
||||
fn from(i: i32) -> Self {
|
||||
Slice::index(i as isize)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<i64> for Slice {
|
||||
fn from(i: i64) -> Self {
|
||||
Slice::index(i as isize)
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Slice {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
|
||||
if self.step == 1
|
||||
&& let Some(end) = self.end
|
||||
&& self.start == end - 1
|
||||
{
|
||||
f.write_fmt(format_args!("{}", self.start))
|
||||
} else {
|
||||
if self.start != 0 {
|
||||
f.write_fmt(format_args!("{}", self.start))?;
|
||||
}
|
||||
f.write_str("..")?;
|
||||
if let Some(end) = self.end {
|
||||
f.write_fmt(format_args!("{}", end))?;
|
||||
}
|
||||
if self.step != 1 {
|
||||
f.write_fmt(format_args!(";{}", self.step))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Slice {
|
||||
type Err = crate::ExpressionError;
|
||||
|
||||
fn from_str(source: &str) -> Result<Self, Self::Err> {
|
||||
let mut s = source.trim();
|
||||
|
||||
let parse_int = |v: &str| -> Result<isize, Self::Err> {
|
||||
v.parse::<isize>().map_err(|e| {
|
||||
crate::ExpressionError::parse_error(
|
||||
format!("Invalid integer: '{v}': {}", e),
|
||||
source,
|
||||
)
|
||||
})
|
||||
};
|
||||
|
||||
let mut start: isize = 0;
|
||||
let mut end: Option<isize> = None;
|
||||
let mut step: isize = 1;
|
||||
|
||||
if let Some((head, tail)) = s.split_once(";") {
|
||||
step = parse_int(tail)?;
|
||||
s = head;
|
||||
}
|
||||
|
||||
if s.is_empty() {
|
||||
return Err(crate::ExpressionError::parse_error(
|
||||
"Empty expression",
|
||||
source,
|
||||
));
|
||||
}
|
||||
|
||||
if let Some((start_s, end_s)) = s.split_once("..") {
|
||||
if !start_s.is_empty() {
|
||||
start = parse_int(start_s)?;
|
||||
}
|
||||
if !end_s.is_empty() {
|
||||
if let Some(end_s) = end_s.strip_prefix('=') {
|
||||
end = Some(parse_int(end_s)? + 1);
|
||||
} else {
|
||||
end = Some(parse_int(end_s)?);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
start = parse_int(s)?;
|
||||
end = Some(start + 1);
|
||||
}
|
||||
|
||||
if step == 0 {
|
||||
return Err(crate::ExpressionError::invalid_expression(
|
||||
"Step cannot be zero",
|
||||
source,
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Slice::new(start, end, step))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use alloc::string::ToString;
|
||||
use alloc::vec;
|
||||
|
||||
#[test]
|
||||
fn test_slice_to_str() {
|
||||
assert_eq!(Slice::new(0, None, 1).to_string(), "..");
|
||||
|
||||
assert_eq!(Slice::new(0, Some(1), 1).to_string(), "0");
|
||||
|
||||
assert_eq!(Slice::new(0, Some(10), 1).to_string(), "..10");
|
||||
assert_eq!(Slice::new(1, Some(10), 1).to_string(), "1..10");
|
||||
|
||||
assert_eq!(Slice::new(-3, Some(10), -2).to_string(), "-3..10;-2");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_slice_from_str() {
|
||||
assert_eq!("1".parse::<Slice>(), Ok(Slice::new(1, Some(2), 1)));
|
||||
assert_eq!("..".parse::<Slice>(), Ok(Slice::new(0, None, 1)));
|
||||
assert_eq!("..3".parse::<Slice>(), Ok(Slice::new(0, Some(3), 1)));
|
||||
assert_eq!("..=3".parse::<Slice>(), Ok(Slice::new(0, Some(4), 1)));
|
||||
|
||||
assert_eq!("-12..3".parse::<Slice>(), Ok(Slice::new(-12, Some(3), 1)));
|
||||
assert_eq!("..;-1".parse::<Slice>(), Ok(Slice::new(0, None, -1)));
|
||||
|
||||
assert_eq!("..=3;-2".parse::<Slice>(), Ok(Slice::new(0, Some(4), -2)));
|
||||
|
||||
assert_eq!(
|
||||
"..;0".parse::<Slice>(),
|
||||
Err(crate::ExpressionError::invalid_expression(
|
||||
"Step cannot be zero",
|
||||
"..;0"
|
||||
))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
"".parse::<Slice>(),
|
||||
Err(crate::ExpressionError::parse_error("Empty expression", ""))
|
||||
);
|
||||
assert_eq!(
|
||||
"a".parse::<Slice>(),
|
||||
Err(crate::ExpressionError::parse_error(
|
||||
"Invalid integer: 'a': invalid digit found in string",
|
||||
"a"
|
||||
))
|
||||
);
|
||||
assert_eq!(
|
||||
"..a".parse::<Slice>(),
|
||||
Err(crate::ExpressionError::parse_error(
|
||||
"Invalid integer: 'a': invalid digit found in string",
|
||||
"..a"
|
||||
))
|
||||
);
|
||||
assert_eq!(
|
||||
"a:b:c".parse::<Slice>(),
|
||||
Err(crate::ExpressionError::parse_error(
|
||||
"Invalid integer: 'a:b:c': invalid digit found in string",
|
||||
"a:b:c"
|
||||
))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_slice_output_size() {
|
||||
// Test the output_size method directly
|
||||
assert_eq!(Slice::new(0, Some(10), 1).output_size(10), 10);
|
||||
assert_eq!(Slice::new(0, Some(10), 2).output_size(10), 5);
|
||||
assert_eq!(Slice::new(0, Some(10), 3).output_size(10), 4); // ceil(10/3)
|
||||
assert_eq!(Slice::new(0, Some(10), -1).output_size(10), 10);
|
||||
assert_eq!(Slice::new(0, Some(10), -2).output_size(10), 5);
|
||||
assert_eq!(Slice::new(2, Some(8), -3).output_size(10), 2); // ceil(6/3)
|
||||
assert_eq!(Slice::new(5, Some(5), 1).output_size(10), 0); // empty range
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bound_to() {
|
||||
assert_eq!(
|
||||
Slice::new(0, None, 1).bound_to(10),
|
||||
Slice::new(0, Some(10), 1)
|
||||
);
|
||||
assert_eq!(
|
||||
Slice::new(0, Some(5), 1).bound_to(10),
|
||||
Slice::new(0, Some(5), 1)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
Slice::new(0, None, -1).bound_to(10),
|
||||
Slice::new(0, Some(-11), -1)
|
||||
);
|
||||
assert_eq!(
|
||||
Slice::new(0, Some(-5), -1).bound_to(10),
|
||||
Slice::new(0, Some(-5), -1)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_slice_iter() {
|
||||
assert_eq!(
|
||||
Slice::new(2, Some(3), 1).into_iter().collect::<Vec<_>>(),
|
||||
vec![2]
|
||||
);
|
||||
assert_eq!(
|
||||
Slice::new(3, Some(-1), -1).into_iter().collect::<Vec<_>>(),
|
||||
vec![3, 2, 1, 0]
|
||||
);
|
||||
|
||||
assert_eq!(Slice::new(3, Some(-1), -1).into_vec(), vec![3, 2, 1, 0]);
|
||||
|
||||
assert_eq!(
|
||||
Slice::new(3, None, 2)
|
||||
.into_iter()
|
||||
.take(3)
|
||||
.collect::<Vec<_>>(),
|
||||
vec![3, 5, 7]
|
||||
);
|
||||
assert_eq!(
|
||||
Slice::new(3, None, 2)
|
||||
.bound_to(8)
|
||||
.into_iter()
|
||||
.collect::<Vec<_>>(),
|
||||
vec![3, 5, 7]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(
|
||||
expected = "Slice must have an end to convert to a vector: Slice { start: 0, end: None, step: 1 }"
|
||||
)]
|
||||
fn test_unbound_slice_into_vec() {
|
||||
Slice::new(0, None, 1).into_vec();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn into_slices_should_return_for_all_shape_dims() {
|
||||
let slice = s![1];
|
||||
let shape = Shape::new([2, 3, 1]);
|
||||
|
||||
let slices = slice.into_slices(&shape);
|
||||
|
||||
assert_eq!(slices.len(), shape.len());
|
||||
|
||||
assert_eq!(slices[0], Slice::new(1, Some(2), 1));
|
||||
assert_eq!(slices[1], Slice::new(0, Some(3), 1));
|
||||
assert_eq!(slices[2], Slice::new(0, Some(1), 1));
|
||||
|
||||
let slice = s![1, 0..2];
|
||||
let slices = slice.into_slices(&shape);
|
||||
|
||||
assert_eq!(slices.len(), shape.len());
|
||||
|
||||
assert_eq!(slices[0], Slice::new(1, Some(2), 1));
|
||||
assert_eq!(slices[1], Slice::new(0, Some(2), 1));
|
||||
assert_eq!(slices[2], Slice::new(0, Some(1), 1));
|
||||
|
||||
let slice = s![..];
|
||||
let slices = slice.into_slices(&shape);
|
||||
|
||||
assert_eq!(slices.len(), shape.len());
|
||||
|
||||
assert_eq!(slices[0], Slice::new(0, Some(2), 1));
|
||||
assert_eq!(slices[1], Slice::new(0, Some(3), 1));
|
||||
assert_eq!(slices[2], Slice::new(0, Some(1), 1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn into_slices_all_dimensions() {
|
||||
let slice = s![1, ..2, ..];
|
||||
let shape = Shape::new([2, 3, 1]);
|
||||
|
||||
let slices = slice.into_slices(&shape);
|
||||
|
||||
assert_eq!(slices.len(), shape.len());
|
||||
|
||||
assert_eq!(slices[0], Slice::new(1, Some(2), 1));
|
||||
assert_eq!(slices[1], Slice::new(0, Some(2), 1));
|
||||
assert_eq!(slices[2], Slice::new(0, Some(1), 1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn into_slices_supports_empty_dimensions() {
|
||||
let slice = s![.., 1, ..];
|
||||
let shape = Shape::new([0, 3, 1]);
|
||||
|
||||
let slices = slice.into_slices(&shape);
|
||||
|
||||
assert_eq!(slices.len(), shape.len());
|
||||
|
||||
assert_eq!(slices[0], Slice::new(0, Some(0), 1));
|
||||
assert_eq!(slices[1], Slice::new(1, Some(2), 1));
|
||||
assert_eq!(slices[2], Slice::new(0, Some(1), 1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic = "Too many slices provided for shape"]
|
||||
fn into_slices_should_match_shape_rank() {
|
||||
let slice = s![.., 1, ..];
|
||||
let shape = Shape::new([3, 1]);
|
||||
|
||||
let _ = slice.into_slices(&shape);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_const_and_full() {
|
||||
static SLICES: [Slice; 2] = [Slice::full(), Slice::new(2, None, 1)];
|
||||
assert_eq!(SLICES[0], Slice::new(0, None, 1));
|
||||
assert_eq!(SLICES[1], Slice::new(2, None, 1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_default() {
|
||||
assert_eq!(Slice::default(), Slice::new(0, None, 1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_copy() {
|
||||
let mut slice = Slice::new(1, Some(3), 2);
|
||||
let slice_copy = slice;
|
||||
|
||||
slice.end = Some(4);
|
||||
|
||||
assert_eq!(slice, Slice::new(1, Some(4), 2));
|
||||
assert_eq!(slice_copy, Slice::new(1, Some(3), 2));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user