feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,84 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science"]
description = "Library with simple dataset APIs for creating ML data pipelines"
documentation = "https://docs.rs/burn-dataset"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "data"]
license.workspace = true
name = "burn-dataset"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-dataset"
version.workspace = true
[lints]
workspace = true
[features]
default = ["sqlite-bundled"]
doc = ["default"]
tracing = [
"burn-std/tracing",
]
audio = ["hound"]
builtin-sources = ["vision", "dep:tar", "nlp"]
fake = ["dep:fake"]
network = ["dep:burn-std"]
sqlite = ["__sqlite-shared", "dep:rusqlite"]
sqlite-bundled = ["__sqlite-shared", "rusqlite/bundled"]
vision = ["dep:flate2", "dep:globwalk", "dep:image", "network"]
nlp = ["dep:zip", "dep:encoding_rs"]
# internal
__sqlite-shared = [
"dep:r2d2",
"dep:r2d2_sqlite",
"dep:serde_rusqlite",
"dep:image",
"dep:gix-tempfile",
]
dataframe = ["dep:polars", "dep:planus"]
[dependencies]
burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", optional = true, features = [
"network",
] }
csv = { workspace = true }
derive-new = { workspace = true }
dirs = { workspace = true }
fake = { workspace = true, optional = true }
flate2 = { workspace = true, optional = true }
gix-tempfile = { workspace = true, optional = true }
globwalk = { workspace = true, optional = true }
hound = { workspace = true, optional = true }
image = { workspace = true, optional = true }
planus = { workspace = true, optional = true }
encoding_rs = { workspace = true, optional = true }
polars = { workspace = true, optional = true }
r2d2 = { workspace = true, optional = true }
r2d2_sqlite = { workspace = true, optional = true }
rand = { workspace = true, features = ["std", "sys_rng"] }
zip = { workspace = true, optional = true }
rmp-serde = { workspace = true }
rusqlite = { workspace = true, optional = true }
sanitize-filename = { workspace = true }
serde = { workspace = true, features = ["std", "derive"] }
serde_json = { workspace = true, features = ["std"] }
serde_rusqlite = { workspace = true, optional = true }
strum = { workspace = true }
tar = { workspace = true, optional = true }
tempfile = { workspace = true }
thiserror = { workspace = true }
[dev-dependencies]
fake = { workspace = true }
rayon = { workspace = true }
rstest = { workspace = true }
[package.metadata.cargo-udeps.ignore]
normal = ["strum", "strum_macros"]
[package.metadata.docs.rs]
features = ["doc"]
rustdoc-args = ["--cfg", "docsrs"]

View File

@@ -0,0 +1 @@
../../LICENSE-APACHE

View File

@@ -0,0 +1 @@
../../LICENSE-MIT

View File

@@ -0,0 +1,17 @@
# Burn Dataset
> [Burn](https://github.com/tracel-ai/burn) dataset library
[![Current Crates.io Version](https://img.shields.io/crates/v/burn-dataset.svg)](https://crates.io/crates/burn-dataset)
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-dataset/blob/master/README.md)
The Burn Dataset library is designed to streamline your machine learning (ML) data pipeline creation
process. It offers a variety of dataset implementations, transformation functions, and data sources.
## Feature Flags
- `audio` - enables audio dataset (SpeechCommandsDataset). Run the following example to try it out:
```shell
cargo run --example speech_commands --features audio
```

View File

@@ -0,0 +1,22 @@
use burn_dataset::HuggingfaceDatasetLoader;
use burn_dataset::SqliteDataset;
use serde::Deserialize;
#[derive(Deserialize, Debug, Clone)]
struct MnistItemRaw {
pub _image_bytes: Vec<u8>,
pub _label: usize,
}
fn main() {
// There are some datasets, such as https://huggingface.co/datasets/ylecun/mnist/tree/main that contains a script,
// In this cases you must enable trusting remote code execution if you want to use it.
let _train_ds: SqliteDataset<MnistItemRaw> = HuggingfaceDatasetLoader::new("mnist")
.with_trust_remote_code(true)
.dataset("train")
.unwrap();
// However not all dataset requires it https://huggingface.co/datasets/Anthropic/hh-rlhf/tree/main
let _train_ds: SqliteDataset<MnistItemRaw> = HuggingfaceDatasetLoader::new("Anthropic/hh-rlhf")
.dataset("train")
.unwrap();
}

View File

@@ -0,0 +1,23 @@
#[cfg(feature = "audio")]
use burn_dataset::{Dataset, audio::SpeechCommandsDataset};
#[cfg(feature = "audio")]
fn speech_command() {
let index: usize = 4835;
let test = SpeechCommandsDataset::test();
let item = test.get(index).unwrap();
println!("Item: {:?}", item);
println!("Item Length: {:?}", item.audio_samples.len());
println!("Label: {}", item.label);
assert_eq!(test.len(), 4890);
assert_eq!(item.label.to_string(), "Yes");
assert_eq!(item.sample_rate, 16000);
assert_eq!(item.audio_samples.len(), 16000);
}
fn main() {
#[cfg(feature = "audio")]
speech_command()
}

View File

@@ -0,0 +1,3 @@
mod speech_commands;
pub use speech_commands::*;

View File

@@ -0,0 +1,208 @@
use crate::{
Dataset, HuggingfaceDatasetLoader, SqliteDataset,
transform::{Mapper, MapperDataset},
};
use hound::WavReader;
use serde::{Deserialize, Serialize};
use strum::{Display, EnumCount, FromRepr};
type MappedDataset = MapperDataset<SqliteDataset<SpeechItemRaw>, ConvertSamples, SpeechItemRaw>;
/// Enum representing speech command classes in the Speech Commands dataset.
/// Class names are based on the Speech Commands dataset from Huggingface.
/// See [speech_commands](https://huggingface.co/datasets/speech_commands)
/// for more information.
#[allow(missing_docs)]
#[derive(Debug, Display, Clone, Copy, FromRepr, Serialize, Deserialize, EnumCount)]
pub enum SpeechCommandClass {
// Target command words
Yes = 0,
No = 1,
Up = 2,
Down = 3,
Left = 4,
Right = 5,
On = 6,
Off = 7,
Stop = 8,
Go = 9,
Zero = 10,
One = 11,
Two = 12,
Three = 13,
Four = 14,
Five = 15,
Six = 16,
Seven = 17,
Eight = 18,
Nine = 19,
// Non-target words that can be grouped into "Other"
Bed = 20,
Bird = 21,
Cat = 22,
Dog = 23,
Happy = 24,
House = 25,
Marvin = 26,
Sheila = 27,
Tree = 28,
Wow = 29,
// Commands from v2 dataset, that can be grouped into "Other"
Backward = 30,
Forward = 31,
Follow = 32,
Learn = 33,
Visual = 34,
// Background noise
Silence = 35,
// Other miscellaneous words
Other = 36,
}
/// Struct containing raw speech data returned from a database.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SpeechItemRaw {
/// Audio file bytes.
pub audio_bytes: Vec<u8>,
/// Label index.
pub label: usize,
/// Indicates if the label is unknown.
pub is_unknown: bool,
}
/// Speech item with audio samples and label.
///
/// The audio samples are floats in the range [-1.0, 1.0].
/// The sample rate is in Hz.
/// The label is the class index (see [SpeechCommandClass]).
/// To convert to usize simply use `as usize`. To convert label to string use `.to_string()`.
///
/// The original label is also stored in the `label_original` field for debugging and remapping if needed.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SpeechItem {
/// Audio samples in the range [-1.0, 1.0].
pub audio_samples: Vec<f32>,
/// The sample rate of the audio.
pub sample_rate: usize,
/// The label of the audio.
pub label: SpeechCommandClass,
}
/// Speech Commands dataset from Huggingface v0.02.
/// See [Speech Commands dataset](https://huggingface.co/datasets/speech_commands).
///
/// The data is downloaded from Huggingface and stored in a SQLite database (3.0 GB).
/// The dataset contains 99,720 audio samples of 2,607 people saying 35 different words.
///
/// NOTE: The most samples are under 1 second long but there are some with pure background noise that
/// need splitting into shorter segmants.
///
/// The labels are 20 target words, silence and other words.
///
/// The dataset is split into 3 parts:
/// - train: 84,848 audio files
/// - test: 4,890 audio files
/// - validation: 9,982 audio files
pub struct SpeechCommandsDataset {
dataset: MappedDataset,
}
impl SpeechCommandsDataset {
/// Create a new dataset with the given split.
pub fn new(split: &str) -> Self {
let dataset: SqliteDataset<SpeechItemRaw> =
HuggingfaceDatasetLoader::new("speech_commands")
.with_subset("v0.02")
.dataset(split)
.unwrap();
let dataset = MapperDataset::new(dataset, ConvertSamples);
Self { dataset }
}
/// Create a new dataset with the train split.
pub fn train() -> Self {
Self::new("train")
}
/// Create a new dataset with the test split.
pub fn test() -> Self {
Self::new("test")
}
/// Create a new dataset with the validation split.
pub fn validation() -> Self {
Self::new("validation")
}
/// Returns the number of classes in the dataset
pub fn num_classes() -> usize {
SpeechCommandClass::COUNT
}
}
impl Dataset<SpeechItem> for SpeechCommandsDataset {
fn get(&self, index: usize) -> Option<SpeechItem> {
self.dataset.get(index)
}
fn len(&self) -> usize {
self.dataset.len()
}
}
/// Mapper converting audio bytes into audio samples and the label to enum class.
struct ConvertSamples;
impl ConvertSamples {
/// Convert label to enum class.
fn to_speechcommandclass(label: usize) -> SpeechCommandClass {
SpeechCommandClass::from_repr(label).unwrap()
}
/// Convert audio bytes into samples of floats [-1.0, 1.0].
fn to_audiosamples(bytes: &Vec<u8>) -> (Vec<f32>, usize) {
let reader = WavReader::new(bytes.as_slice()).unwrap();
let spec = reader.spec();
// Maximum value of the audio samples (using bit shift to raise 2 to the power of bits per sample).
let max_value = (1 << (spec.bits_per_sample - 1)) as f32;
// The sample rate of the audio.
let sample_rate = spec.sample_rate as usize;
// Convert the audio samples to floats [-1.0, 1.0].
let audio_samples: Vec<f32> = reader
.into_samples::<i32>()
.filter_map(Result::ok)
.map(|sample| sample as f32 / max_value)
.collect();
(audio_samples, sample_rate)
}
}
impl Mapper<SpeechItemRaw, SpeechItem> for ConvertSamples {
/// Convert audio bytes into samples of floats [-1.0, 1.0]
/// and the label to enum class with the target word, other and silence classes.
fn map(&self, item: &SpeechItemRaw) -> SpeechItem {
let (audio_samples, sample_rate) = Self::to_audiosamples(&item.audio_bytes);
// Convert the label to enum class, with the target words, other and silence classes.
let label = Self::to_speechcommandclass(item.label);
SpeechItem {
audio_samples,
sample_rate,
label,
}
}
}

View File

@@ -0,0 +1,71 @@
use std::sync::Arc;
use crate::DatasetIterator;
/// The dataset trait defines a basic collection of items with a predefined size.
pub trait Dataset<I>: Send + Sync {
/// Gets the item at the given index.
fn get(&self, index: usize) -> Option<I>;
/// Gets the number of items in the dataset.
fn len(&self) -> usize;
/// Checks if the dataset is empty.
fn is_empty(&self) -> bool {
self.len() == 0
}
/// Returns an iterator over the dataset.
fn iter(&self) -> DatasetIterator<'_, I>
where
Self: Sized,
{
DatasetIterator::new(self)
}
}
impl<D, I> Dataset<I> for Arc<D>
where
D: Dataset<I>,
{
fn get(&self, index: usize) -> Option<I> {
self.as_ref().get(index)
}
fn len(&self) -> usize {
self.as_ref().len()
}
}
impl<I> Dataset<I> for Arc<dyn Dataset<I>> {
fn get(&self, index: usize) -> Option<I> {
self.as_ref().get(index)
}
fn len(&self) -> usize {
self.as_ref().len()
}
}
impl<D, I> Dataset<I> for Box<D>
where
D: Dataset<I>,
{
fn get(&self, index: usize) -> Option<I> {
self.as_ref().get(index)
}
fn len(&self) -> usize {
self.as_ref().len()
}
}
impl<I> Dataset<I> for Box<dyn Dataset<I>> {
fn get(&self, index: usize) -> Option<I> {
self.as_ref().get(index)
}
fn len(&self) -> usize {
self.as_ref().len()
}
}

View File

@@ -0,0 +1,465 @@
use std::marker::PhantomData;
use crate::Dataset;
use polars::frame::row::Row;
use polars::prelude::*;
use serde::de::DeserializeSeed;
use serde::{
Deserialize,
de::{self, DeserializeOwned, Deserializer, SeqAccess, Visitor},
forward_to_deserialize_any,
};
/// Error type for DataframeDataset
#[derive(thiserror::Error, Debug)]
pub enum DataframeDatasetError {
/// Error occurred during deserialization or other operations
#[error("{0}")]
Other(String),
}
impl de::Error for DataframeDatasetError {
fn custom<T: std::fmt::Display>(msg: T) -> Self {
DataframeDatasetError::Other(msg.to_string())
}
}
/// Dataset implementation for Polars DataFrame
///
/// This struct provides a way to access data from a Polars DataFrame
/// as if it were a Dataset of type I.
pub struct DataframeDataset<I> {
df: DataFrame,
len: usize,
column_name_mapping: Vec<usize>,
phantom: PhantomData<I>,
}
impl<I> DataframeDataset<I>
where
I: Clone + Send + Sync + DeserializeOwned,
{
/// Create a new DataframeDataset from a Polars DataFrame
///
/// # Arguments
///
/// * `df` - A Polars DataFrame
///
/// # Returns
///
/// A Result containing the new DataframeDataset or a DataframeDatasetError
pub fn new(df: DataFrame) -> Result<Self, DataframeDatasetError> {
let len = df.height();
let field_names = extract_field_names::<I>();
let column_name_mapping = field_names
.iter()
.map(|name| {
df.schema()
.try_get_full(name)
.expect("Corresponding column should exist in the DataFrame")
.0
})
.collect::<Vec<_>>();
Ok(DataframeDataset {
df,
len,
column_name_mapping,
phantom: PhantomData,
})
}
}
impl<I> Dataset<I> for DataframeDataset<I>
where
I: Clone + Send + Sync + DeserializeOwned,
{
/// Get an item from the dataset at the specified index
///
/// # Arguments
///
/// * `index` - The index of the item to retrieve
///
/// # Returns
///
/// An Option containing the item if it exists, or None if it doesn't
fn get(&self, index: usize) -> Option<I> {
let row = self.df.get_row(index).ok()?;
let mut deserializer = RowDeserializer::new(&row, &self.column_name_mapping);
I::deserialize(&mut deserializer).ok()
}
/// Get the length of the dataset
fn len(&self) -> usize {
self.len
}
/// Check if the dataset is empty
fn is_empty(&self) -> bool {
self.len == 0
}
}
/// A deserializer for Polars DataFrame rows
struct RowDeserializer<'a> {
row: &'a Row<'a>,
column_name_mapping: &'a Vec<usize>,
index: usize,
}
impl<'a> RowDeserializer<'a> {
/// Create a new RowDeserializer
///
/// # Arguments
///
/// * `row` - A reference to a Polars DataFrame row
/// * `column_name_mapping` - A reference to a vector mapping field names to column indices
fn new(row: &'a Row, column_name_mapping: &'a Vec<usize>) -> RowDeserializer<'a> {
RowDeserializer {
row,
column_name_mapping,
index: 0,
}
}
}
impl<'de, 'a> Deserializer<'de> for &'a mut RowDeserializer<'a> {
type Error = DataframeDatasetError;
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, DataframeDatasetError>
where
V: Visitor<'de>,
{
let i = self.column_name_mapping[self.index];
let value = &self.row.0[i];
match value {
AnyValue::Null => visitor.visit_none(),
AnyValue::Boolean(b) => visitor.visit_bool(*b),
AnyValue::Int8(i) => visitor.visit_i8(*i),
AnyValue::Int16(i) => visitor.visit_i16(*i),
AnyValue::Int32(i) => visitor.visit_i32(*i),
AnyValue::Int64(i) => visitor.visit_i64(*i),
AnyValue::UInt8(i) => visitor.visit_u8(*i),
AnyValue::UInt16(i) => visitor.visit_u16(*i),
AnyValue::UInt32(i) => visitor.visit_u32(*i),
AnyValue::UInt64(i) => visitor.visit_u64(*i),
AnyValue::Float32(f) => visitor.visit_f32(*f),
AnyValue::Float64(f) => visitor.visit_f64(*f),
AnyValue::Date(i) => visitor.visit_i32(*i),
AnyValue::String(s) => visitor.visit_string(s.to_string()),
AnyValue::Binary(b) => {
visitor.visit_seq(de::value::SeqDeserializer::new(b.iter().copied()))
}
AnyValue::Time(t) => visitor.visit_i64(*t),
ty => Err(DataframeDatasetError::Other(
format!("Unsupported type: {ty:?}").to_string(),
)),
}
}
fn deserialize_struct<V>(
self,
_name: &'static str,
_fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value, DataframeDatasetError>
where
V: Visitor<'de>,
{
visitor.visit_seq(self)
}
forward_to_deserialize_any! {
bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string
bytes byte_buf option unit unit_struct newtype_struct seq tuple
tuple_struct map enum identifier ignored_any
}
}
impl<'de, 'a> SeqAccess<'de> for RowDeserializer<'a> {
type Error = DataframeDatasetError;
fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, DataframeDatasetError>
where
T: DeserializeSeed<'de>,
{
if self.index >= self.row.0.len() {
return Ok(None);
}
let mut deserializer = RowDeserializer {
row: self.row,
column_name_mapping: self.column_name_mapping,
index: self.index,
};
self.index += 1;
seed.deserialize(&mut deserializer).map(Some)
}
}
struct FieldExtractor {
fields: Vec<&'static str>,
}
impl<'de> Deserializer<'de> for &mut FieldExtractor {
type Error = de::value::Error;
fn deserialize_any<V>(self, _visitor: V) -> core::result::Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
Err(de::Error::custom("Field extractor"))
}
fn deserialize_struct<V>(
self,
_name: &'static str,
fields: &'static [&'static str],
_visitor: V,
) -> core::result::Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
self.fields.extend_from_slice(fields);
Err(de::Error::custom("Field extractor"))
}
forward_to_deserialize_any! {
bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes
byte_buf option unit unit_struct newtype_struct seq tuple
tuple_struct map enum identifier ignored_any
}
}
/// Extract field names from a type T that implements Deserialize
///
/// # Returns
///
/// A vector of field names as static string slices
fn extract_field_names<'de, T>() -> Vec<&'static str>
where
T: Deserialize<'de>,
{
let mut extractor = FieldExtractor { fields: Vec::new() };
let _ = T::deserialize(&mut extractor);
extractor.fields
}
#[cfg(test)]
mod tests {
use polars::prelude::*;
use serde::Deserialize;
use super::*;
#[derive(Clone, Debug, Deserialize, PartialEq)]
struct TestData {
int32: i32,
bool: bool,
float64: f64,
string: String,
int16: i16,
uint32: u32,
uint64: u64,
float32: f32,
int64: i64,
int8: i8,
binary: Vec<u8>,
}
fn create_test_dataframe() -> DataFrame {
let s0 = Column::new("int32".into(), &[1i32, 2i32, 3i32]);
let s1 = Column::new("bool".into(), &[true, false, true]);
let s2 = Column::new("float64".into(), &[1.1f64, 2.2f64, 3.3f64]);
let s3 = Column::new("string".into(), &["Boo", "Boo2", "Boo3"]);
let s6 = Column::new("int16".into(), &[1i16, 2i16, 3i16]);
let s8 = Column::new("uint32".into(), &[1u32, 2u32, 3u32]);
let s9 = Column::new("uint64".into(), &[1u64, 2u64, 3u64]);
let s10 = Column::new("float32".into(), &[1.1f32, 2.2f32, 3.3f32]);
let s11 = Column::new("int64".into(), &[1i64, 2i64, 3i64]);
let s12 = Column::new("int8".into(), &[1i8, 2i8, 3i8]);
let binary_data: Vec<&[u8]> = vec![&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]];
let s13 = Column::new("binary".into(), binary_data);
DataFrame::new_infer_height(vec![s0, s1, s2, s3, s6, s8, s9, s10, s11, s12, s13]).unwrap()
}
#[test]
fn test_dataframe_dataset_creation() {
let df = create_test_dataframe();
let dataset = DataframeDataset::<TestData>::new(df);
assert!(dataset.is_ok());
}
#[test]
fn test_dataframe_dataset_length() {
let df = create_test_dataframe();
let dataset = DataframeDataset::<TestData>::new(df).unwrap();
assert_eq!(dataset.len(), 3);
assert!(!dataset.is_empty());
}
#[test]
fn test_dataframe_dataset_get() {
let df = create_test_dataframe();
let dataset = DataframeDataset::<TestData>::new(df).unwrap();
let expected_items = vec![
TestData {
int32: 1,
bool: true,
float64: 1.1,
string: "Boo".to_string(),
int16: 1,
uint32: 1,
uint64: 1,
float32: 1.1,
int64: 1,
int8: 1,
binary: vec![1, 2, 3],
},
TestData {
int32: 2,
bool: false,
float64: 2.2,
string: "Boo2".to_string(),
int16: 2,
uint32: 2,
uint64: 2,
float32: 2.2,
int64: 2,
int8: 2,
binary: vec![4, 5, 6],
},
TestData {
int32: 3,
bool: true,
float64: 3.3,
string: "Boo3".to_string(),
int16: 3,
uint32: 3,
uint64: 3,
float32: 3.3,
int64: 3,
int8: 3,
binary: vec![7, 8, 9],
},
];
for (index, expected_item) in expected_items.iter().enumerate() {
let item = dataset.get(index).unwrap();
assert_eq!(&item, expected_item);
}
}
#[test]
fn test_dataframe_dataset_out_of_bounds() {
let df = create_test_dataframe();
let dataset = DataframeDataset::<TestData>::new(df).unwrap();
assert!(dataset.get(3).is_none());
}
#[test]
fn test_dataframe_dataset() {
let df = create_test_dataframe();
let dataset: DataframeDataset<TestData> = DataframeDataset::new(df).unwrap();
assert_eq!(dataset.len(), 3);
assert!(!dataset.is_empty());
let item = dataset.get(1).unwrap();
assert_eq!(
item,
TestData {
int32: 2,
bool: false,
float64: 2.2,
string: "Boo2".to_string(),
int16: 2,
uint32: 2,
uint64: 2,
float32: 2.2,
int64: 2,
int8: 2,
binary: vec![4, 5, 6],
}
);
let item = dataset.get(2).unwrap();
assert_eq!(
item,
TestData {
int32: 3,
bool: true,
float64: 3.3,
string: "Boo3".to_string(),
int16: 3,
uint32: 3,
uint64: 3,
float32: 3.3,
int64: 3,
int8: 3,
binary: vec![7, 8, 9],
}
);
}
#[test]
#[should_panic = "Corresponding column should exist in the DataFrame: SchemaFieldNotFound(ErrString(\"non_existent\"))"]
fn test_non_existing_struct_fields() {
#[derive(Clone, Debug, Deserialize, PartialEq)]
struct PartialTestData {
int32: i32,
bool: bool,
non_existent: String,
}
let df = create_test_dataframe();
let dataset = DataframeDataset::<PartialTestData>::new(df);
assert!(dataset.is_err());
if let Err(e) = dataset {
assert!(matches!(e, DataframeDatasetError::Other(_)));
}
}
#[test]
fn test_partial_table() {
#[derive(Clone, Debug, Deserialize, PartialEq)]
struct PartialTestData {
int32: i32,
bool: bool,
string: String,
}
let df = create_test_dataframe();
let dataset = DataframeDataset::<PartialTestData>::new(df).unwrap();
assert_eq!(dataset.len(), 3);
assert!(!dataset.is_empty());
let item = dataset.get(1).unwrap();
assert_eq!(
item,
PartialTestData {
int32: 2,
bool: false,
string: "Boo2".to_string(),
}
);
let item = dataset.get(2).unwrap();
assert_eq!(
item,
PartialTestData {
int32: 3,
bool: true,
string: "Boo3".to_string(),
}
);
}
}

View File

@@ -0,0 +1,38 @@
use crate::{Dataset, DatasetIterator, InMemDataset};
use fake::{Dummy, Fake, Faker};
/// Dataset filled with fake items generated from the [fake](fake) crate.
pub struct FakeDataset<I> {
dataset: InMemDataset<I>,
}
impl<I: Dummy<Faker>> FakeDataset<I> {
/// Create a new fake dataset with the given size.
pub fn new(size: usize) -> Self {
let mut items = Vec::with_capacity(size);
for _ in 0..size {
items.push(Faker.fake());
}
let dataset = InMemDataset::new(items);
Self { dataset }
}
}
impl<I: Send + Sync + Clone> Dataset<I> for FakeDataset<I> {
fn iter(&self) -> DatasetIterator<'_, I> {
DatasetIterator::new(self)
}
fn get(&self, index: usize) -> Option<I> {
self.dataset.get(index)
}
fn len(&self) -> usize {
self.dataset.len()
}
fn is_empty(&self) -> bool {
self.dataset.is_empty()
}
}

View File

@@ -0,0 +1,192 @@
use std::{
fs::File,
io::{BufRead, BufReader},
path::Path,
};
use serde::de::DeserializeOwned;
use crate::Dataset;
/// Dataset where all items are stored in ram.
pub struct InMemDataset<I> {
items: Vec<I>,
}
impl<I> InMemDataset<I> {
/// Creates a new in memory dataset from the given items.
pub fn new(items: Vec<I>) -> Self {
InMemDataset { items }
}
}
impl<I> Dataset<I> for InMemDataset<I>
where
I: Clone + Send + Sync,
{
fn get(&self, index: usize) -> Option<I> {
self.items.get(index).cloned()
}
fn len(&self) -> usize {
self.items.len()
}
}
impl<I> InMemDataset<I>
where
I: Clone + DeserializeOwned,
{
/// Create from a dataset. All items are loaded in memory.
pub fn from_dataset(dataset: &impl Dataset<I>) -> Self {
let items: Vec<I> = dataset.iter().collect();
Self::new(items)
}
/// Create from a json rows file (one json per line).
///
/// [Supported field types](https://docs.rs/serde_json/latest/serde_json/value/enum.Value.html)
pub fn from_json_rows<P: AsRef<Path>>(path: P) -> Result<Self, std::io::Error> {
let file = File::open(path)?;
let reader = BufReader::new(file);
let mut items = Vec::new();
for line in reader.lines() {
let item = serde_json::from_str(line.unwrap().as_str()).unwrap();
items.push(item);
}
let dataset = Self::new(items);
Ok(dataset)
}
/// Create from a csv file.
///
/// The provided `csv::ReaderBuilder` can be configured to fit your csv format.
///
/// The supported field types are: String, integer, float, and bool.
///
/// See:
/// - [Reading with Serde](https://docs.rs/csv/latest/csv/tutorial/index.html#reading-with-serde)
/// - [Delimiters, quotes and variable length records](https://docs.rs/csv/latest/csv/tutorial/index.html#delimiters-quotes-and-variable-length-records)
pub fn from_csv<P: AsRef<Path>>(
path: P,
builder: &csv::ReaderBuilder,
) -> Result<Self, std::io::Error> {
let mut rdr = builder.from_path(path)?;
let mut items = Vec::new();
for result in rdr.deserialize() {
let item: I = result?;
items.push(item);
}
let dataset = Self::new(items);
Ok(dataset)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{SqliteDataset, test_data};
use rstest::{fixture, rstest};
use serde::{Deserialize, Serialize};
const DB_FILE: &str = "tests/data/sqlite-dataset.db";
const JSON_FILE: &str = "tests/data/dataset.json";
const CSV_FILE: &str = "tests/data/dataset.csv";
const CSV_FMT_FILE: &str = "tests/data/dataset-fmt.csv";
type SqlDs = SqliteDataset<Sample>;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Sample {
column_str: String,
column_bytes: Vec<u8>,
column_int: i64,
column_bool: bool,
column_float: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct SampleCsv {
column_str: String,
column_int: i64,
column_bool: bool,
column_float: f64,
}
#[fixture]
fn train_dataset() -> SqlDs {
SqliteDataset::from_db_file(DB_FILE, "train").unwrap()
}
#[rstest]
pub fn from_dataset(train_dataset: SqlDs) {
let dataset = InMemDataset::from_dataset(&train_dataset);
let non_existing_record_index: usize = 10;
let record_index: usize = 0;
assert_eq!(train_dataset.get(non_existing_record_index), None);
assert_eq!(dataset.get(record_index).unwrap().column_str, "HI1");
}
#[test]
pub fn from_json_rows() {
let dataset = InMemDataset::<Sample>::from_json_rows(JSON_FILE).unwrap();
let non_existing_record_index: usize = 10;
let record_index: usize = 1;
assert_eq!(dataset.get(non_existing_record_index), None);
assert_eq!(dataset.get(record_index).unwrap().column_str, "HI2");
assert!(!dataset.get(record_index).unwrap().column_bool);
}
#[test]
pub fn from_csv_rows() {
let rdr = csv::ReaderBuilder::new();
let dataset = InMemDataset::<SampleCsv>::from_csv(CSV_FILE, &rdr).unwrap();
let non_existing_record_index: usize = 10;
let record_index: usize = 1;
assert_eq!(dataset.get(non_existing_record_index), None);
assert_eq!(dataset.get(record_index).unwrap().column_str, "HI2");
assert_eq!(dataset.get(record_index).unwrap().column_int, 1);
assert!(!dataset.get(record_index).unwrap().column_bool);
assert_eq!(dataset.get(record_index).unwrap().column_float, 1.0);
}
#[test]
pub fn from_csv_rows_fmt() {
let mut rdr = csv::ReaderBuilder::new();
let rdr = rdr.delimiter(b' ').has_headers(false);
let dataset = InMemDataset::<SampleCsv>::from_csv(CSV_FMT_FILE, rdr).unwrap();
let non_existing_record_index: usize = 10;
let record_index: usize = 1;
assert_eq!(dataset.get(non_existing_record_index), None);
assert_eq!(dataset.get(record_index).unwrap().column_str, "HI2");
assert_eq!(dataset.get(record_index).unwrap().column_int, 1);
assert!(!dataset.get(record_index).unwrap().column_bool);
assert_eq!(dataset.get(record_index).unwrap().column_float, 1.0);
}
#[test]
pub fn given_in_memory_dataset_when_iterate_should_iterate_though_all_items() {
let items_original = test_data::string_items();
let dataset = InMemDataset::new(items_original.clone());
let items: Vec<String> = dataset.iter().collect();
assert_eq!(items_original, items);
}
}

View File

@@ -0,0 +1,31 @@
use crate::dataset::Dataset;
use std::iter::Iterator;
/// Dataset iterator.
pub struct DatasetIterator<'a, I> {
current: usize,
dataset: &'a dyn Dataset<I>,
}
impl<'a, I> DatasetIterator<'a, I> {
/// Creates a new dataset iterator.
pub fn new<D>(dataset: &'a D) -> Self
where
D: Dataset<I>,
{
DatasetIterator {
current: 0,
dataset,
}
}
}
impl<I> Iterator for DatasetIterator<'_, I> {
type Item = I;
fn next(&mut self) -> Option<I> {
let item = self.dataset.get(self.current);
self.current += 1;
item
}
}

View File

@@ -0,0 +1,25 @@
mod base;
mod in_memory;
mod iterator;
pub use base::*;
pub use in_memory::*;
pub use iterator::*;
#[cfg(any(test, feature = "fake"))]
mod fake;
#[cfg(any(test, feature = "fake"))]
pub use self::fake::*;
#[cfg(feature = "dataframe")]
mod dataframe;
#[cfg(feature = "dataframe")]
pub use dataframe::*;
#[cfg(any(feature = "sqlite", feature = "sqlite-bundled"))]
pub use sqlite::*;
#[cfg(any(feature = "sqlite", feature = "sqlite-bundled"))]
mod sqlite;

View File

@@ -0,0 +1,851 @@
use std::{
collections::HashSet,
fs, io,
marker::PhantomData,
path::{Path, PathBuf},
sync::{Arc, RwLock},
};
use crate::Dataset;
use gix_tempfile::{
AutoRemove, ContainingDirectory, Handle,
handle::{Writable, persist},
};
use r2d2::{Pool, PooledConnection};
use r2d2_sqlite::{
SqliteConnectionManager,
rusqlite::{OpenFlags, OptionalExtension},
};
use sanitize_filename::sanitize;
use serde::{Serialize, de::DeserializeOwned};
use serde_rusqlite::{columns_from_statement, from_row_with_columns};
/// Result type for the sqlite dataset.
pub type Result<T> = core::result::Result<T, SqliteDatasetError>;
/// Sqlite dataset error.
#[derive(thiserror::Error, Debug)]
pub enum SqliteDatasetError {
/// IO related error.
#[error("IO error: {0}")]
Io(#[from] io::Error),
/// Sql related error.
#[error("Sql error: {0}")]
Sql(#[from] serde_rusqlite::rusqlite::Error),
/// Serde related error.
#[error("Serde error: {0}")]
Serde(#[from] rmp_serde::encode::Error),
/// The database file already exists error.
#[error("Overwrite flag is set to false and the database file already exists: {0}")]
FileExists(PathBuf),
/// Error when creating the connection pool.
#[error("Failed to create connection pool: {0}")]
ConnectionPool(#[from] r2d2::Error),
/// Error when persisting the temporary database file.
#[error("Could not persist the temporary database file: {0}")]
PersistDbFile(#[from] persist::Error<Writable>),
/// Any other error.
#[error("{0}")]
Other(&'static str),
}
impl From<&'static str> for SqliteDatasetError {
fn from(s: &'static str) -> Self {
SqliteDatasetError::Other(s)
}
}
/// This struct represents a dataset where all items are stored in an SQLite database.
/// Each instance of this struct corresponds to a specific table within the SQLite database,
/// and allows for interaction with the data stored in the table in a structured and typed manner.
///
/// The SQLite database must contain a table with the same name as the `split` field. This table should
/// have a primary key column named `row_id`, which is used to index the rows in the table. The `row_id`
/// should start at 1, while the corresponding dataset `index` should start at 0, i.e., `row_id` = `index` + 1.
///
/// Table columns can be represented in two ways:
///
/// 1. The table can have a column for each field in the `I` struct. In this case, the column names in the table
/// should match the field names of the `I` struct. The field names can be a subset of column names and
/// can be in any order.
///
/// For the supported field types, refer to:
/// - [Serialization field types](https://docs.rs/serde_rusqlite/latest/serde_rusqlite)
/// - [SQLite data types](https://www.sqlite.org/datatype3.html)
///
/// 2. The fields in the `I` struct can be serialized into a single column `item` in the table. In this case, the table
/// should have a single column named `item` of type `BLOB`. This is useful when the `I` struct contains complex fields
/// that cannot be mapped to a SQLite type, such as nested structs, vectors, etc. The serialization is done using
/// [MessagePack](https://msgpack.org/).
///
/// Note: The code automatically figures out which of the above two cases is applicable, and uses the appropriate
/// method to read the data from the table.
#[derive(Debug)]
pub struct SqliteDataset<I> {
db_file: PathBuf,
split: String,
conn_pool: Pool<SqliteConnectionManager>,
columns: Vec<String>,
len: usize,
select_statement: String,
row_serialized: bool,
phantom: PhantomData<I>,
}
impl<I> SqliteDataset<I> {
/// Initializes a `SqliteDataset` from a SQLite database file and a split name.
pub fn from_db_file<P: AsRef<Path>>(db_file: P, split: &str) -> Result<Self> {
// Create a connection pool
let conn_pool = create_conn_pool(&db_file, false)?;
// Determine how the table is stored
let row_serialized = Self::check_if_row_serialized(&conn_pool, split)?;
// Create a select statement and save it
let select_statement = if row_serialized {
format!("select item from {split} where row_id = ?")
} else {
format!("select * from {split} where row_id = ?")
};
// Save the column names and the number of rows
let (columns, len) = fetch_columns_and_len(&conn_pool, &select_statement, split)?;
Ok(SqliteDataset {
db_file: db_file.as_ref().to_path_buf(),
split: split.to_string(),
conn_pool,
columns,
len,
select_statement,
row_serialized,
phantom: PhantomData,
})
}
/// Returns true if table has two columns: row_id (integer) and item (blob).
///
/// This is used to determine if the table is row serialized or not.
fn check_if_row_serialized(
conn_pool: &Pool<SqliteConnectionManager>,
split: &str,
) -> Result<bool> {
// This struct is used to store the column name and type
struct Column {
name: String,
ty: String,
}
const COLUMN_NAME: usize = 1;
const COLUMN_TYPE: usize = 2;
let sql_statement = format!("PRAGMA table_info({split})");
let conn = conn_pool.get()?;
let mut stmt = conn.prepare(sql_statement.as_str())?;
let column_iter = stmt.query_map([], |row| {
Ok(Column {
name: row
.get::<usize, String>(COLUMN_NAME)
.unwrap()
.to_lowercase(),
ty: row
.get::<usize, String>(COLUMN_TYPE)
.unwrap()
.to_lowercase(),
})
})?;
let mut columns: Vec<Column> = vec![];
for column in column_iter {
columns.push(column?);
}
if columns.len() != 2 {
Ok(false)
} else {
// Check if the column names and types match the expected values
Ok(columns[0].name == "row_id"
&& columns[0].ty == "integer"
&& columns[1].name == "item"
&& columns[1].ty == "blob")
}
}
/// Get the database file name.
pub fn db_file(&self) -> PathBuf {
self.db_file.clone()
}
/// Get the split name.
pub fn split(&self) -> &str {
self.split.as_str()
}
}
impl<I> Dataset<I> for SqliteDataset<I>
where
I: Clone + Send + Sync + DeserializeOwned,
{
/// Get an item from the dataset.
fn get(&self, index: usize) -> Option<I> {
// Row ids start with 1 (one) and index starts with 0 (zero)
let row_id = index + 1;
// Get a connection from the pool
let connection = self.conn_pool.get().unwrap();
let mut statement = connection.prepare(self.select_statement.as_str()).unwrap();
if self.row_serialized {
// Fetch with a single column `item` and deserialize it with MessagePack
statement
.query_row([row_id], |row| {
// Deserialize item (blob) with MessagePack (rmp-serde)
Ok(
rmp_serde::from_slice::<I>(row.get_ref(0).unwrap().as_blob().unwrap())
.unwrap(),
)
})
.optional() //Converts Error (not found) to None
.unwrap()
} else {
// Fetch a row with multiple columns and deserialize it serde_rusqlite
statement
.query_row([row_id], |row| {
// Deserialize the row with serde_rusqlite
Ok(from_row_with_columns::<I>(row, &self.columns).unwrap())
})
.optional() //Converts Error (not found) to None
.unwrap()
}
}
/// Return the number of rows in the dataset.
fn len(&self) -> usize {
self.len
}
}
/// Fetch the column names and the number of rows from the database.
fn fetch_columns_and_len(
conn_pool: &Pool<SqliteConnectionManager>,
select_statement: &str,
split: &str,
) -> Result<(Vec<String>, usize)> {
// Save the column names
let connection = conn_pool.get()?;
let statement = connection.prepare(select_statement)?;
let columns = columns_from_statement(&statement);
// Count the number of rows and save it as len
//
// NOTE: Using coalesce(max(row_id), 0) instead of count(*) because count(*) is super slow for large tables.
// The coalesce(max(row_id), 0) returns 0 if the table is empty, otherwise it returns the max row_id,
// which corresponds to the number of rows in the table.
// The main assumption, which always holds true, is that the row_id is always increasing and there are no gaps.
// This is true for all the datasets that we are using, otherwise row_id will not correspond to the index.
let mut statement =
connection.prepare(format!("select coalesce(max(row_id), 0) from {split}").as_str())?;
let len = statement.query_row([], |row| {
let len: usize = row.get(0)?;
Ok(len)
})?;
Ok((columns, len))
}
/// Helper function to create a connection pool
fn create_conn_pool<P: AsRef<Path>>(
db_file: P,
write: bool,
) -> Result<Pool<SqliteConnectionManager>> {
let sqlite_flags = if write {
OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE
} else {
OpenFlags::SQLITE_OPEN_READ_ONLY
};
let manager = SqliteConnectionManager::file(db_file).with_flags(sqlite_flags);
Pool::new(manager).map_err(SqliteDatasetError::ConnectionPool)
}
/// The `SqliteDatasetStorage` struct represents a SQLite database for storing datasets.
/// It consists of an optional name, a database file path, and a base directory for storage.
#[derive(Clone, Debug)]
pub struct SqliteDatasetStorage {
name: Option<String>,
db_file: Option<PathBuf>,
base_dir: Option<PathBuf>,
}
impl SqliteDatasetStorage {
/// Creates a new instance of `SqliteDatasetStorage` using a dataset name.
///
/// # Arguments
///
/// * `name` - A string slice that holds the name of the dataset.
pub fn from_name(name: &str) -> Self {
SqliteDatasetStorage {
name: Some(name.to_string()),
db_file: None,
base_dir: None,
}
}
/// Creates a new instance of `SqliteDatasetStorage` using a database file path.
///
/// # Arguments
///
/// * `db_file` - A reference to the Path that represents the database file path.
pub fn from_file<P: AsRef<Path>>(db_file: P) -> Self {
SqliteDatasetStorage {
name: None,
db_file: Some(db_file.as_ref().to_path_buf()),
base_dir: None,
}
}
/// Sets the base directory for storing the dataset.
///
/// # Arguments
///
/// * `base_dir` - A string slice that represents the base directory.
pub fn with_base_dir<P: AsRef<Path>>(mut self, base_dir: P) -> Self {
self.base_dir = Some(base_dir.as_ref().to_path_buf());
self
}
/// Checks if the database file exists in the given path.
///
/// # Returns
///
/// * A boolean value indicating whether the file exists or not.
pub fn exists(&self) -> bool {
self.db_file().exists()
}
/// Fetches the database file path.
///
/// # Returns
///
/// * A `PathBuf` instance representing the file path.
pub fn db_file(&self) -> PathBuf {
match &self.db_file {
Some(db_file) => db_file.clone(),
None => {
let name = sanitize(self.name.as_ref().expect("Name is not set"));
Self::base_dir(self.base_dir.to_owned()).join(format!("{name}.db"))
}
}
}
/// Determines the base directory for storing the dataset.
///
/// # Arguments
///
/// * `base_dir` - An `Option` that may contain a `PathBuf` instance representing the base directory.
///
/// # Returns
///
/// * A `PathBuf` instance representing the base directory.
pub fn base_dir(base_dir: Option<PathBuf>) -> PathBuf {
match base_dir {
Some(base_dir) => base_dir,
None => dirs::cache_dir()
.expect("Could not get cache directory")
.join("burn-dataset"),
}
}
/// Provides a writer instance for the SQLite dataset.
///
/// # Arguments
///
/// * `overwrite` - A boolean indicating if the existing database file should be overwritten.
///
/// # Returns
///
/// * A `Result` which is `Ok` if the writer could be created, `Err` otherwise.
pub fn writer<I>(&self, overwrite: bool) -> Result<SqliteDatasetWriter<I>>
where
I: Clone + Send + Sync + Serialize + DeserializeOwned,
{
SqliteDatasetWriter::new(self.db_file(), overwrite)
}
/// Provides a reader instance for the SQLite dataset.
///
/// # Arguments
///
/// * `split` - A string slice that defines the data split for reading (e.g., "train", "test").
///
/// # Returns
///
/// * A `Result` which is `Ok` if the reader could be created, `Err` otherwise.
pub fn reader<I>(&self, split: &str) -> Result<SqliteDataset<I>>
where
I: Clone + Send + Sync + Serialize + DeserializeOwned,
{
if !self.exists() {
panic!("The database file does not exist");
}
SqliteDataset::from_db_file(self.db_file(), split)
}
}
/// This `SqliteDatasetWriter` struct is a SQLite database writer dedicated to storing datasets.
/// It retains the current writer's state and its database connection.
///
/// Being thread-safe, this writer can be concurrently used across multiple threads.
///
/// Typical applications include:
///
/// - Generation of a new dataset
/// - Storage of preprocessed data or metadata
/// - Enlargement of a dataset's item count post preprocessing
#[derive(Debug)]
pub struct SqliteDatasetWriter<I> {
db_file: PathBuf,
db_file_tmp: Option<Handle<Writable>>,
splits: Arc<RwLock<HashSet<String>>>,
overwrite: bool,
conn_pool: Option<Pool<SqliteConnectionManager>>,
is_completed: Arc<RwLock<bool>>,
phantom: PhantomData<I>,
}
impl<I> SqliteDatasetWriter<I>
where
I: Clone + Send + Sync + Serialize + DeserializeOwned,
{
/// Creates a new instance of `SqliteDatasetWriter`.
///
/// # Arguments
///
/// * `db_file` - A reference to the Path that represents the database file path.
/// * `overwrite` - A boolean indicating if the existing database file should be overwritten.
///
/// # Returns
///
/// * A `Result` which is `Ok` if the writer could be created, `Err` otherwise.
pub fn new<P: AsRef<Path>>(db_file: P, overwrite: bool) -> Result<Self> {
let writer = Self {
db_file: db_file.as_ref().to_path_buf(),
db_file_tmp: None,
splits: Arc::new(RwLock::new(HashSet::new())),
overwrite,
conn_pool: None,
is_completed: Arc::new(RwLock::new(false)),
phantom: PhantomData,
};
writer.init()
}
/// Initializes the dataset writer by creating the database file, tables, and connection pool.
///
/// # Returns
///
/// * A `Result` which is `Ok` if the writer could be initialized, `Err` otherwise.
fn init(mut self) -> Result<Self> {
// Remove the db file if it already exists
if self.db_file.exists() {
if self.overwrite {
fs::remove_file(&self.db_file)?;
} else {
return Err(SqliteDatasetError::FileExists(self.db_file));
}
}
// Create the database file directory if it does not exist
let db_file_dir = self
.db_file
.parent()
.ok_or("Unable to get parent directory")?;
if !db_file_dir.exists() {
fs::create_dir_all(db_file_dir)?;
}
// Create a temp database file name as {base_dir}/{name}.db.tmp
let mut db_file_tmp = self.db_file.clone();
db_file_tmp.set_extension("db.tmp");
if db_file_tmp.exists() {
fs::remove_file(&db_file_tmp)?;
}
// Create the temp database file and wrap it with a gix_tempfile::Handle
// This will ensure that the temp file is deleted when the writer is dropped
// or when process exits with SIGINT or SIGTERM (tempfile crate does not do this)
gix_tempfile::signal::setup(Default::default());
self.db_file_tmp = Some(gix_tempfile::writable_at(
&db_file_tmp,
ContainingDirectory::Exists,
AutoRemove::Tempfile,
)?);
let conn_pool = create_conn_pool(db_file_tmp, true)?;
self.conn_pool = Some(conn_pool);
Ok(self)
}
/// Serializes and writes an item to the database. The item is written to the table for the
/// specified split. If the table does not exist, it is created. If the table exists, the item
/// is appended to the table. The serialization is done using the [MessagePack](https://msgpack.org/)
///
/// # Arguments
///
/// * `split` - A string slice that defines the data split for writing (e.g., "train", "test").
/// * `item` - A reference to the item to be written to the database.
///
/// # Returns
///
/// * A `Result` containing the index of the inserted row if successful, an error otherwise.
pub fn write(&self, split: &str, item: &I) -> Result<usize> {
// Acquire the read lock (wont't block other reads)
let is_completed = self.is_completed.read().unwrap();
// If the writer is completed, return an error
if *is_completed {
return Err(SqliteDatasetError::Other(
"Cannot save to a completed dataset writer",
));
}
// create the table for the split if it does not exist
if !self.splits.read().unwrap().contains(split) {
self.create_table(split)?;
}
// Get a connection from the pool
let conn_pool = self.conn_pool.as_ref().unwrap();
let conn = conn_pool.get()?;
// Serialize the item using MessagePack
let serialized_item = rmp_serde::to_vec(item)?;
// Turn off the synchronous and journal mode for speed up
// We are sacrificing durability for speed but it's okay because
// we always recreate the dataset if it is not completed.
pragma_update_with_error_handling(&conn, "synchronous", "OFF")?;
pragma_update_with_error_handling(&conn, "journal_mode", "OFF")?;
// Insert the serialized item into the database
let insert_statement = format!("insert into {split} (item) values (?)");
conn.execute(insert_statement.as_str(), [serialized_item])?;
// Get the primary key of the last inserted row and convert to index (row_id-1)
let index = (conn.last_insert_rowid() - 1) as usize;
Ok(index)
}
/// Marks the dataset as completed and persists the temporary database file.
pub fn set_completed(&mut self) -> Result<()> {
let mut is_completed = self.is_completed.write().unwrap();
// Force close the connection pool
// This is required on Windows platform where the connection pool prevents
// from persisting the db by renaming the temp file.
if let Some(pool) = self.conn_pool.take() {
std::mem::drop(pool);
}
// Rename the database file from tmp to db
let _file_result = self
.db_file_tmp
.take() // take ownership of the temporary file and set to None
.unwrap() // unwrap the temporary file
.persist(&self.db_file)?
.ok_or("Unable to persist the database file")?;
*is_completed = true;
Ok(())
}
/// Creates table for the data split.
///
/// Note: call is idempotent and thread-safe.
///
/// # Arguments
///
/// * `split` - A string slice that defines the data split for the table (e.g., "train", "test").
///
/// # Returns
///
/// * A `Result` which is `Ok` if the table could be created, `Err` otherwise.
///
/// TODO (@antimora): add support creating a table with columns corresponding to the item fields
fn create_table(&self, split: &str) -> Result<()> {
// Check if the split already exists
if self.splits.read().unwrap().contains(split) {
return Ok(());
}
let conn_pool = self.conn_pool.as_ref().unwrap();
let connection = conn_pool.get()?;
let create_table_statement = format!(
"create table if not exists {split} (row_id integer primary key autoincrement not \
null, item blob not null)"
);
connection.execute(create_table_statement.as_str(), [])?;
// Add the split to the splits
self.splits.write().unwrap().insert(split.to_string());
Ok(())
}
}
/// Runs a pragma update and ignores the `ExecuteReturnedResults` error.
///
/// Sometimes ExecuteReturnedResults is returned when running a pragma update. This is not an error
/// and can be ignored. This function runs the pragma update and ignores the error if it is
/// `ExecuteReturnedResults`.
fn pragma_update_with_error_handling(
conn: &PooledConnection<SqliteConnectionManager>,
setting: &str,
value: &str,
) -> Result<()> {
let result = conn.pragma_update(None, setting, value);
if let Err(error) = result
&& error != rusqlite::Error::ExecuteReturnedResults
{
return Err(SqliteDatasetError::Sql(error));
}
Ok(())
}
#[cfg(test)]
mod tests {
use rayon::prelude::*;
use rstest::{fixture, rstest};
use serde::{Deserialize, Serialize};
use tempfile::{NamedTempFile, TempDir, tempdir};
use super::*;
type SqlDs = SqliteDataset<Sample>;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Sample {
column_str: String,
column_bytes: Vec<u8>,
column_int: i64,
column_bool: bool,
column_float: f64,
}
#[fixture]
fn train_dataset() -> SqlDs {
SqliteDataset::<Sample>::from_db_file("tests/data/sqlite-dataset.db", "train").unwrap()
}
#[rstest]
pub fn len(train_dataset: SqlDs) {
assert_eq!(train_dataset.len(), 2);
}
#[rstest]
pub fn get_some(train_dataset: SqlDs) {
let item = train_dataset.get(0).unwrap();
assert_eq!(item.column_str, "HI1");
assert_eq!(item.column_bytes, vec![55, 231, 159]);
assert_eq!(item.column_int, 1);
assert!(item.column_bool);
assert_eq!(item.column_float, 1.0);
}
#[rstest]
pub fn get_none(train_dataset: SqlDs) {
assert_eq!(train_dataset.get(10), None);
}
#[rstest]
pub fn multi_thread(train_dataset: SqlDs) {
let indices: Vec<usize> = vec![0, 1, 1, 3, 4, 5, 6, 0, 8, 1];
let results: Vec<Option<Sample>> =
indices.par_iter().map(|&i| train_dataset.get(i)).collect();
let mut match_count = 0;
for (_index, result) in indices.iter().zip(results.iter()) {
if let Some(_val) = result {
match_count += 1
}
}
assert_eq!(match_count, 5);
}
#[test]
fn sqlite_dataset_storage() {
// Test with non-existing file
let storage = SqliteDatasetStorage::from_file("non-existing.db");
assert!(!storage.exists());
// Test with non-existing name
let storage = SqliteDatasetStorage::from_name("non-existing.db");
assert!(!storage.exists());
// Test with existing file
let storage = SqliteDatasetStorage::from_file("tests/data/sqlite-dataset.db");
assert!(storage.exists());
let result = storage.reader::<Sample>("train");
assert!(result.is_ok());
let train = result.unwrap();
assert_eq!(train.len(), 2);
// Test get writer
let temp_file = NamedTempFile::new().unwrap();
let storage = SqliteDatasetStorage::from_file(temp_file.path());
assert!(storage.exists());
let result = storage.writer::<Sample>(true);
assert!(result.is_ok());
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Complex {
column_str: String,
column_bytes: Vec<u8>,
column_int: i64,
column_bool: bool,
column_float: f64,
column_complex: Vec<Vec<Vec<[u8; 3]>>>,
}
/// Create a temporary directory.
#[fixture]
fn tmp_dir() -> TempDir {
// Create a TempDir. This object will be automatically
// deleted when it goes out of scope.
tempdir().unwrap()
}
type Writer = SqliteDatasetWriter<Complex>;
/// Create a SqliteDatasetWriter with a temporary directory.
/// Make sure to return the temporary directory so that it is not deleted.
#[fixture]
fn writer_fixture(tmp_dir: TempDir) -> (Writer, TempDir) {
let temp_dir_str = tmp_dir.path();
let storage = SqliteDatasetStorage::from_name("preprocessed").with_base_dir(temp_dir_str);
let overwrite = true;
let result = storage.writer::<Complex>(overwrite);
assert!(result.is_ok());
let writer = result.unwrap();
(writer, tmp_dir)
}
#[test]
fn test_new() {
// Test that the constructor works with overwrite = true
let test_path = NamedTempFile::new().unwrap();
let _writer = SqliteDatasetWriter::<Complex>::new(&test_path, true).unwrap();
assert!(!test_path.path().exists());
// Test that the constructor works with overwrite = false
let test_path = NamedTempFile::new().unwrap();
let result = SqliteDatasetWriter::<Complex>::new(&test_path, false);
assert!(result.is_err());
// Test that the constructor works with no existing file
let temp = NamedTempFile::new().unwrap();
let test_path = temp.path().to_path_buf();
assert!(temp.close().is_ok());
assert!(!test_path.exists());
let _writer = SqliteDatasetWriter::<Complex>::new(&test_path, true).unwrap();
assert!(!test_path.exists());
}
#[rstest]
pub fn sqlite_writer_write(writer_fixture: (Writer, TempDir)) {
// Get the dataset_saver from the fixture and tmp_dir (will be deleted after scope)
let (writer, _tmp_dir) = writer_fixture;
assert!(writer.overwrite);
assert!(!writer.db_file.exists());
let new_item = Complex {
column_str: "HI1".to_string(),
column_bytes: vec![1_u8, 2, 3],
column_int: 0,
column_bool: true,
column_float: 1.0,
column_complex: vec![vec![vec![[1, 23_u8, 3]]]],
};
let index = writer.write("train", &new_item).unwrap();
assert_eq!(index, 0);
let mut writer = writer;
writer.set_completed().expect("Failed to set completed");
assert!(writer.db_file.exists());
assert!(writer.db_file_tmp.is_none());
let result = writer.write("train", &new_item);
// Should fail because the writer is completed
assert!(result.is_err());
let dataset = SqliteDataset::<Complex>::from_db_file(writer.db_file, "train").unwrap();
let fetched_item = dataset.get(0).unwrap();
assert_eq!(fetched_item, new_item);
assert_eq!(dataset.len(), 1);
}
#[rstest]
pub fn sqlite_writer_write_multi_thread(writer_fixture: (Writer, TempDir)) {
// Get the dataset_saver from the fixture and tmp_dir (will be deleted after scope)
let (writer, _tmp_dir) = writer_fixture;
let writer = Arc::new(writer);
let record_count = 20;
let splits = ["train", "test"];
(0..record_count).into_par_iter().for_each(|index: i64| {
let thread_id: std::thread::ThreadId = std::thread::current().id();
let sample = Complex {
column_str: format!("test_{thread_id:?}_{index}"),
column_bytes: vec![index as u8, 2, 3],
column_int: index,
column_bool: true,
column_float: 1.0,
column_complex: vec![vec![vec![[1, index as u8, 3]]]],
};
// half for train and half for test
let split = splits[index as usize % 2];
let _index = writer.write(split, &sample).unwrap();
});
let mut writer = Arc::try_unwrap(writer).unwrap();
writer
.set_completed()
.expect("Should set completed successfully");
let train =
SqliteDataset::<Complex>::from_db_file(writer.db_file.clone(), "train").unwrap();
let test = SqliteDataset::<Complex>::from_db_file(writer.db_file, "test").unwrap();
assert_eq!(train.len(), record_count as usize / 2);
assert_eq!(test.len(), record_count as usize / 2);
}
}

View File

@@ -0,0 +1,52 @@
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
//! # Burn Dataset
//!
//! Burn Dataset is a library for creating and loading datasets.
#[macro_use]
extern crate derive_new;
extern crate alloc;
extern crate dirs;
/// Sources for datasets.
pub mod source;
pub mod transform;
/// Audio datasets.
#[cfg(feature = "audio")]
pub mod audio;
/// Vision datasets.
#[cfg(feature = "vision")]
pub mod vision;
/// Natural language processing datasets.
#[cfg(feature = "nlp")]
pub mod nlp;
/// Network dataset utilities.
#[cfg(feature = "network")]
pub mod network {
pub use burn_std::network::*;
}
mod dataset;
pub use dataset::*;
#[cfg(any(feature = "sqlite", feature = "sqlite-bundled"))]
pub use source::huggingface::downloader::*;
#[cfg(test)]
mod test_data {
pub fn string_items() -> Vec<String> {
vec![
"1 Item".to_string(),
"2 Items".to_string(),
"3 Items".to_string(),
"4 Items".to_string(),
]
}
}

View File

@@ -0,0 +1,211 @@
//! AG NEWS Dataset Module
//!
//! This module provides functionality for loading the AG NEWS text classification dataset.
//! AG NEWS is a collection of news articles categorized into different topics.
//! The dataset is split into training (120,000 articles) and test (7,600 articles) sets.
//!
//! ## Dataset Details
//! - **Classes**: 4 categories (World, Sports, Business, Sci/Tech)
//! - **AG NEWS mirror**: [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L83)
//! - **License**: [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE)
//!
//! ## Usage Example
//! ```rust
//! use burn_dataset::nlp::AgNewsDataset;
//!
//! // Create an AG NEWS dataset accessor
//! let dataset = AgNewsDataset::new();
//!
//! // Access training and test sets
//! let train_dataset = dataset.train();
//! let test_dataset = dataset.test();
//! ```
use std::{path::PathBuf, sync::Mutex};
use flate2::read::GzDecoder;
use serde::{Deserialize, Serialize};
use tar::Archive;
use crate::InMemDataset;
use crate::network::downloader;
/// AG NEWS mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L83).
/// Licensed under the [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE).
const AG_NEWS_URL: &str = "https://s3.amazonaws.com/fast-ai-nlp/ag_news_csv.tgz";
/// Represents an item in the AG NEWS dataset.
///
/// Each item contains a label, title, and content of a news article.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct AgNewsItem {
/// The category label of the news article.
pub label: String,
/// The title of the news article.
pub title: String,
/// The content/body of the news article.
pub content: String,
}
/// AG NEWS dataset accessor.
///
/// This struct provides convenient access to the AG NEWS text classification dataset.
/// It automatically downloads (if not already downloaded), extracts, and loads the datasets.
///
/// The dataset is split into training (120,000 articles) and test (7,600 articles) sets.
pub struct AgNewsDataset {
agnews_dir: PathBuf,
}
/// AG NEWS dataset download lock.
///
/// This lock ensures that only one thread downloads the AG NEWS dataset at a time.
static DOWNLOAD_LOCK: Mutex<()> = Mutex::new(());
impl AgNewsDataset {
/// Creates a new AG NEWS dataset accessor.
///
/// This will download and extract the dataset if it's not already present.
pub fn new() -> Self {
Self {
agnews_dir: Self::download(),
}
}
/// Downloads and extracts the AG NEWS dataset.
///
/// # Returns
/// Path to the directory containing the extracted dataset.
fn download() -> PathBuf {
// Acquire the lock. This will block if another thread already holds the lock.
let _lock = DOWNLOAD_LOCK.lock().unwrap();
// Dataset files are stored in the burn-dataset cache directory
let cache_dir = dirs::cache_dir()
.expect("Could not get cache directory")
.join("burn-dataset");
// AG NEWS dataset directory
let agnews_dir = cache_dir.join("ag_news_csv");
// AG NEWS dataset url
let url = AG_NEWS_URL;
// AG NEWS dataset archive filename
let filename = "ag_news_csv.tgz";
// Check for already downloaded content
if !agnews_dir.exists() {
// Download gzip file
let bytes = downloader::download_file_as_bytes(url, filename);
// Decode gzip file content and unpack archive
let gz_buffer = GzDecoder::new(&bytes[..]);
let mut archive = Archive::new(gz_buffer);
archive.unpack(cache_dir).unwrap();
}
agnews_dir
}
/// Parses a CSV file into an in-memory dataset.
///
/// # Arguments
/// * `file_path` - Path to the CSV file to parse.
///
/// # Returns
/// An `InMemDataset` containing the parsed data.
fn parse_csv(file_path: &str) -> InMemDataset<AgNewsItem> {
let mut rdr = csv::ReaderBuilder::new();
let rdr = rdr.has_headers(false);
InMemDataset::from_csv(file_path, &rdr).expect("Failed to parse CSV file")
}
/// Gets the training dataset.
///
/// # Returns
/// An `InMemDataset` instance containing 120,000 training articles.
pub fn train(&self) -> InMemDataset<AgNewsItem> {
let file_path = self.agnews_dir.join("train.csv");
Self::parse_csv(file_path.to_str().unwrap())
}
/// Gets the test dataset.
///
/// # Returns
/// An `InMemDataset` instance containing 7,600 test articles.
pub fn test(&self) -> InMemDataset<AgNewsItem> {
let file_path = self.agnews_dir.join("test.csv");
Self::parse_csv(file_path.to_str().unwrap())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::Dataset;
// AG NEWS dataset train and test dataset lengths
const TRAIN_DATASET_LEN: usize = 120000;
const TEST_DATASET_LEN: usize = 7600;
#[test]
fn test_agnews_download() {
let agnews_dir = AgNewsDataset::download();
assert!(agnews_dir.exists());
}
#[test]
fn test_agnews_len() {
let agnews = AgNewsDataset::new();
let train_dataset = agnews.train();
let test_dataset = agnews.test();
assert_eq!(train_dataset.len(), TRAIN_DATASET_LEN);
assert_eq!(test_dataset.len(), TEST_DATASET_LEN);
}
#[test]
fn test_agnews_first_and_last_item() {
let agnews = AgNewsDataset::new();
// Test the first and the last item in training dataset
let train_dataset = agnews.train();
let first_item = train_dataset.get(0).unwrap();
let last_item = train_dataset.get(train_dataset.len() - 1).unwrap();
assert!(compare_item(&first_item, &("3".to_string(), "Wall St. Bears Claw Back Into the Black (Reuters)".to_string(), "Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.".to_string())));
assert!(compare_item(
&last_item,
&(
"2".to_string(),
"Nets get Carter from Raptors".to_string(),
"INDIANAPOLIS -- All-Star Vince Carter was traded by the Toronto Raptors to the New Jersey Nets for Alonzo Mourning, Eric Williams, Aaron Williams, and a pair of first-round draft picks yesterday.".to_string()
)
));
// Test the first and the last item in test dataset
let test_dataset = agnews.test();
let first_item = test_dataset.get(0).unwrap();
let last_item = test_dataset.get(test_dataset.len() - 1).unwrap();
assert!(compare_item(
&first_item,
&(
"3".to_string(),
"Fears for T N pension after talks".to_string(),
"Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.".to_string()
)
));
assert!(compare_item(
&last_item,
&(
"3".to_string(),
"EBay gets into rentals".to_string(),
"EBay plans to buy the apartment and home rental service Rent.com for \\$415 million, adding to its already exhaustive breadth of offerings.".to_string()
)
));
}
fn compare_item(item: &AgNewsItem, target: &(String, String, String)) -> bool {
item.label == target.0 && item.title == target.1 && item.content == target.2
}
}

View File

@@ -0,0 +1,7 @@
#[cfg(feature = "builtin-sources")]
mod ag_news;
mod text_folder;
#[cfg(feature = "builtin-sources")]
pub use ag_news::*;
pub use text_folder::*;

View File

@@ -0,0 +1,421 @@
use crate::transform::{Mapper, MapperDataset};
use crate::{Dataset, InMemDataset};
use encoding_rs::{GB18030, GBK, UTF_8, UTF_16BE, UTF_16LE};
use globwalk::{self, DirEntry};
use std::collections::{HashMap, HashSet};
use std::fs;
use std::io::Read;
use std::path::{Path, PathBuf};
use thiserror::Error;
const SUPPORTED_FILES: [&str; 1] = ["txt"];
/// Text data type.
#[derive(Debug, Clone, PartialEq)]
pub struct TextData {
/// The text content.
pub text: String,
/// Original text source.
pub text_path: String,
}
/// Text dataset item.
#[derive(Debug, Clone, PartialEq)]
pub struct TextDatasetItem {
/// Text content.
pub text: TextData,
/// Label for the text.
pub label: usize,
}
/// Raw text dataset item.
#[derive(Debug, Clone)]
struct TextDatasetItemRaw {
/// Text path.
text_path: PathBuf,
/// Text label.
label: String,
}
impl TextDatasetItemRaw {
fn new<P: AsRef<Path>>(text_path: P, label: String) -> TextDatasetItemRaw {
TextDatasetItemRaw {
text_path: text_path.as_ref().to_path_buf(),
label,
}
}
}
struct PathToTextDatasetItem {
classes: HashMap<String, usize>,
}
/// Parse the text content from file with auto-detection of encoding.
fn parse_text_content(text_path: &PathBuf) -> String {
// Read raw bytes from disk
let mut file = fs::File::open(text_path).unwrap();
let mut bytes = Vec::new();
file.read_to_end(&mut bytes).unwrap();
// Try to detect encoding and decode text
// First try UTF-8 with BOM
if bytes.starts_with(&[0xEF, 0xBB, 0xBF]) && bytes.len() >= 3 {
let (result, _, had_errors) = UTF_8.decode(&bytes[3..]);
if !had_errors {
return result.into_owned();
}
}
// Try UTF-8 without BOM
let (result, _, had_errors) = UTF_8.decode(&bytes);
if !had_errors {
return result.into_owned();
}
// Try UTF-16LE with BOM
if bytes.starts_with(&[0xFF, 0xFE]) && bytes.len() >= 2 {
let (result, had_errors) = UTF_16LE.decode_with_bom_removal(&bytes[2..]);
if !had_errors {
return result.into_owned();
}
}
// Try UTF-16BE with BOM
if bytes.starts_with(&[0xFE, 0xFF]) && bytes.len() >= 2 {
let (result, had_errors) = UTF_16BE.decode_with_bom_removal(&bytes[2..]);
if !had_errors {
return result.into_owned();
}
}
// Try GB18030 encoding
let (result, _, had_errors) = GB18030.decode(&bytes);
if !had_errors {
return result.into_owned();
}
// Try GBK encoding
let (result, _, had_errors) = GBK.decode(&bytes);
if !had_errors {
return result.into_owned();
}
// Default fallback - use from_utf8_lossy for any remaining cases
String::from_utf8_lossy(&bytes).to_string()
}
impl Mapper<TextDatasetItemRaw, TextDatasetItem> for PathToTextDatasetItem {
/// Convert a raw text dataset item (path-like) to text content with a target label.
fn map(&self, item: &TextDatasetItemRaw) -> TextDatasetItem {
let label = *self.classes.get(&item.label).unwrap();
// Load text from disk
let text_content = parse_text_content(&item.text_path);
let text_data = TextData {
text: text_content,
text_path: item.text_path.display().to_string(),
};
TextDatasetItem {
text: text_data,
label,
}
}
}
/// Error type for [TextFolderDataset](TextFolderDataset).
#[derive(Error, Debug)]
pub enum TextLoaderError {
/// Unknown error.
#[error("unknown: `{0}`")]
Unknown(String),
/// I/O operation error.
#[error("I/O error: `{0}`")]
IOError(String),
/// Invalid file error.
#[error("Invalid file extension: `{0}`")]
InvalidFileExtensionError(String),
/// Encoding error.
#[error("Encoding error: `{0}`")]
EncodingError(String),
}
type TextDatasetMapper =
MapperDataset<InMemDataset<TextDatasetItemRaw>, PathToTextDatasetItem, TextDatasetItemRaw>;
/// A generic dataset to load texts from disk.
pub struct TextFolderDataset {
dataset: TextDatasetMapper,
}
impl Dataset<TextDatasetItem> for TextFolderDataset {
fn get(&self, index: usize) -> Option<TextDatasetItem> {
self.dataset.get(index)
}
fn len(&self) -> usize {
self.dataset.len()
}
}
impl TextFolderDataset {
/// Create a text classification dataset from the root folder.
///
/// # Arguments
///
/// * `root` - Dataset root folder.
///
/// # Returns
/// A new dataset instance.
pub fn new_classification<P: AsRef<Path>>(root: P) -> Result<Self, TextLoaderError> {
// New dataset containing any of the supported file types
TextFolderDataset::new_classification_with(root, &SUPPORTED_FILES)
}
/// Create a text classification dataset from the root folder.
/// The included texts are filtered based on the provided extensions.
///
/// # Arguments
///
/// * `root` - Dataset root folder.
/// * `extensions` - List of allowed extensions.
///
/// # Returns
/// A new dataset instance.
pub fn new_classification_with<P, S>(root: P, extensions: &[S]) -> Result<Self, TextLoaderError>
where
P: AsRef<Path>,
S: AsRef<str>,
{
// Glob all texts with extensions
let walker = globwalk::GlobWalkerBuilder::from_patterns(
root.as_ref(),
&[format!(
"*.{{{}}}", // "*.{ext1,ext2,ext3}
extensions
.iter()
.map(Self::check_extension)
.collect::<Result<Vec<_>, _>>()?
.join(",")
)],
)
.follow_links(true)
.sort_by(|p1: &DirEntry, p2: &DirEntry| p1.path().cmp(p2.path())) // order by path
.build()
.map_err(|err| TextLoaderError::Unknown(format!("{err:?}")))?
.filter_map(Result::ok);
// Get all dataset items
let mut items = Vec::new();
let mut classes = HashSet::new();
for text in walker {
let text_path = text.path();
// Label name is represented by the parent folder name
let label = text_path
.parent()
.ok_or_else(|| {
TextLoaderError::IOError("Could not resolve text parent folder".to_string())
})?
.file_name()
.ok_or_else(|| {
TextLoaderError::IOError(
"Could not resolve text parent folder name".to_string(),
)
})?
.to_string_lossy()
.into_owned();
classes.insert(label.clone());
items.push(TextDatasetItemRaw::new(text_path, label))
}
// Sort class names
let mut classes = classes.into_iter().collect::<Vec<_>>();
classes.sort();
Self::with_items(items, &classes)
}
/// Create a text classification dataset with the specified items.
///
/// # Arguments
///
/// * `items` - List of dataset items, each item represented by a tuple `(text path, label)`.
/// * `classes` - Dataset class names.
///
/// # Returns
/// A new dataset instance.
pub fn new_classification_with_items<P: AsRef<Path>, S: AsRef<str>>(
items: Vec<(P, String)>,
classes: &[S],
) -> Result<Self, TextLoaderError> {
// Parse items and check valid text extension types
let items = items
.into_iter()
.map(|(path, label)| {
// Map text path and label
let path = path.as_ref();
let label = label;
Self::check_extension(&path.extension().unwrap().to_str().unwrap())?;
Ok(TextDatasetItemRaw::new(path, label))
})
.collect::<Result<Vec<_>, _>>()?;
Self::with_items(items, classes)
}
/// Create a text dataset with the specified items.
///
/// # Arguments
///
/// * `items` - Raw dataset items.
/// * `classes` - Dataset class names.
///
/// # Returns
/// A new dataset instance.
fn with_items<S: AsRef<str>>(
items: Vec<TextDatasetItemRaw>,
classes: &[S],
) -> Result<Self, TextLoaderError> {
// NOTE: right now we don't need to validate the supported text files since
// the method is private. We assume it's already validated.
let dataset = InMemDataset::new(items);
// Class names to index map
let classes = classes.iter().map(|c| c.as_ref()).collect::<Vec<_>>();
let classes_map: HashMap<_, _> = classes
.into_iter()
.enumerate()
.map(|(idx, cls)| (cls.to_string(), idx))
.collect();
let mapper = PathToTextDatasetItem {
classes: classes_map,
};
let dataset = MapperDataset::new(dataset, mapper);
Ok(Self { dataset })
}
/// Check if extension is supported.
fn check_extension<S: AsRef<str>>(extension: &S) -> Result<String, TextLoaderError> {
let extension = extension.as_ref();
if !SUPPORTED_FILES.contains(&extension) {
Err(TextLoaderError::InvalidFileExtensionError(
extension.to_string(),
))
} else {
Ok(extension.to_string())
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::path::Path;
const TEXT_ROOT: &str = "tests/data/text_folder";
#[test]
fn test_text_folder_dataset() {
let dataset = TextFolderDataset::new_classification(TEXT_ROOT).unwrap();
// Dataset should have 4 elements (2 positive + 2 negative)
assert_eq!(dataset.len(), 4);
assert_eq!(dataset.get(4), None);
// Check that we have items from both classes
let mut found_positive = false;
let mut found_negative = false;
for i in 0..dataset.len() {
let item = dataset.get(i).unwrap();
if item.label == 0 {
found_negative = true;
// Check that the text content is loaded correctly
assert!(!item.text.text.is_empty());
assert!(item.text.text_path.contains("negative"));
} else if item.label == 1 {
found_positive = true;
// Check that the text content is loaded correctly
assert!(!item.text.text.is_empty());
assert!(item.text.text_path.contains("positive"));
}
}
// Verify we found items from both classes
assert!(found_positive);
assert!(found_negative);
}
#[test]
fn test_text_folder_dataset_with_invalid_extension() {
// Try to create a dataset with an unsupported extension
let result = TextFolderDataset::new_classification_with(TEXT_ROOT, &["invalid"]);
assert!(result.is_err());
}
#[test]
fn test_text_folder_dataset_with_items() {
// Create the dataset
let root = Path::new(TEXT_ROOT);
let items = vec![
(
root.join("positive").join("sample1.txt"),
"positive".to_string(),
),
(
root.join("negative").join("sample2.txt"),
"negative".to_string(),
),
];
let classes = vec!["positive", "negative"];
let dataset = TextFolderDataset::new_classification_with_items(items, &classes).unwrap();
// Dataset should have 2 elements
assert_eq!(dataset.len(), 2);
assert_eq!(dataset.get(2), None);
// Get items
let item0 = dataset.get(0).unwrap();
let item1 = dataset.get(1).unwrap();
// Check item0
assert!(compare_item(
&item0,
&(
"This is a positive text sample for testing the text folder dataset functionality."
.to_string(),
0
)
));
// Check item1
assert_eq!(item1.label, 1);
assert!(item1.text.text_path.contains("negative"));
assert!(compare_item(
&item1,
&(
"另一个负面文本样本,用以确保数据集能够处理同一类别中的多个文件。".to_string(),
1
)
));
}
fn compare_item(item: &TextDatasetItem, target: &(String, usize)) -> bool {
item.text.text == target.0 && item.label == target.1
}
}

View File

@@ -0,0 +1,367 @@
use std::fs::{self, create_dir_all};
use std::path::{Path, PathBuf};
use std::process::Command;
use crate::{SqliteDataset, SqliteDatasetError, SqliteDatasetStorage};
use sanitize_filename::sanitize;
use serde::de::DeserializeOwned;
use thiserror::Error;
const PYTHON_SOURCE: &str = include_str!("importer.py");
#[cfg(not(target_os = "windows"))]
const VENV_BIN_PYTHON: &str = "bin/python3";
#[cfg(target_os = "windows")]
const VENV_BIN_PYTHON: &str = "Scripts\\python";
/// Error type for [HuggingfaceDatasetLoader](HuggingfaceDatasetLoader).
#[derive(Error, Debug)]
pub enum ImporterError {
/// Unknown error.
#[error("unknown: `{0}`")]
Unknown(String),
/// Fail to download python dependencies.
#[error("fail to download python dependencies: `{0}`")]
FailToDownloadPythonDependencies(String),
/// Fail to create sqlite dataset.
#[error("sqlite dataset: `{0}`")]
SqliteDataset(#[from] SqliteDatasetError),
/// python3 is not installed.
#[error("python3 is not installed")]
PythonNotInstalled,
/// venv environment is not initialized.
#[error("venv environment is not initialized")]
VenvNotInitialized,
}
/// Load a dataset from [huggingface datasets](https://huggingface.co/datasets).
///
/// The dataset with all splits is stored in a single sqlite database (see [SqliteDataset](SqliteDataset)).
///
/// # Example
/// ```no_run
/// use burn_dataset::HuggingfaceDatasetLoader;
/// use burn_dataset::SqliteDataset;
/// use serde::{Deserialize, Serialize};
///
/// #[derive(Deserialize, Debug, Clone)]
/// struct MnistItemRaw {
/// pub image_bytes: Vec<u8>,
/// pub label: usize,
/// }
///
/// let train_ds:SqliteDataset<MnistItemRaw> = HuggingfaceDatasetLoader::new("mnist")
/// .dataset("train")
/// .unwrap();
/// ```
///
/// # Note
/// This loader relies on the [`datasets` library by HuggingFace](https://huggingface.co/docs/datasets/index)
/// to download datasets. This is a Python library, so you must have an existing Python installation.
pub struct HuggingfaceDatasetLoader {
name: String,
subset: Option<String>,
base_dir: Option<PathBuf>,
huggingface_token: Option<String>,
huggingface_cache_dir: Option<String>,
huggingface_data_dir: Option<String>,
trust_remote_code: bool,
use_python_venv: bool,
}
impl HuggingfaceDatasetLoader {
/// Create a huggingface dataset loader.
pub fn new(name: &str) -> Self {
Self {
name: name.to_string(),
subset: None,
base_dir: None,
huggingface_token: None,
huggingface_cache_dir: None,
huggingface_data_dir: None,
trust_remote_code: false,
use_python_venv: true,
}
}
/// Create a huggingface dataset loader for a subset of the dataset.
///
/// The subset name must be one of the subsets listed in the dataset page.
///
/// If no subset names are listed, then do not use this method.
pub fn with_subset(mut self, subset: &str) -> Self {
self.subset = Some(subset.to_string());
self
}
/// Specify a base directory to store the dataset.
///
/// If not specified, the dataset will be stored in the system cache directory under `burn-dataset`.
pub fn with_base_dir(mut self, base_dir: &str) -> Self {
self.base_dir = Some(base_dir.into());
self
}
/// Specify a huggingface token to download datasets behind authentication.
///
/// You can get a token from [tokens settings](https://huggingface.co/settings/tokens)
pub fn with_huggingface_token(mut self, huggingface_token: &str) -> Self {
self.huggingface_token = Some(huggingface_token.to_string());
self
}
/// Specify a huggingface cache directory to store the downloaded datasets.
///
/// If not specified, the dataset will be stored in the system cache directory under `huggingface/datasets`.
pub fn with_huggingface_cache_dir(mut self, huggingface_cache_dir: &str) -> Self {
self.huggingface_cache_dir = Some(huggingface_cache_dir.to_string());
self
}
/// Specify a relative path to a subset of a dataset. This is used in some datasets for the
/// manual steps of dataset download process.
///
/// Unless you've encountered a ManualDownloadError
/// when loading your dataset you probably don't have to worry about this setting.
pub fn with_huggingface_data_dir(mut self, huggingface_data_dir: &str) -> Self {
self.huggingface_data_dir = Some(huggingface_data_dir.to_string());
self
}
/// Specify whether or not to trust remote code.
///
/// If not specified, trust remote code is set to true.
pub fn with_trust_remote_code(mut self, trust_remote_code: bool) -> Self {
self.trust_remote_code = trust_remote_code;
self
}
/// Specify whether or not to use the burn-dataset Python
/// virtualenv for running the importer script. If false, local
/// `python3`'s environment is used.
///
/// If not specified, the virtualenv is used.
pub fn with_use_python_venv(mut self, use_python_venv: bool) -> Self {
self.use_python_venv = use_python_venv;
self
}
/// Load the dataset.
pub fn dataset<I: DeserializeOwned + Clone>(
self,
split: &str,
) -> Result<SqliteDataset<I>, ImporterError> {
let db_file = self.db_file()?;
let dataset = SqliteDataset::from_db_file(db_file, split)?;
Ok(dataset)
}
/// Get the path to the sqlite database file.
///
/// If the database file does not exist, it will be downloaded and imported.
pub fn db_file(self) -> Result<PathBuf, ImporterError> {
// determine (and create if needed) the base directory
let base_dir = SqliteDatasetStorage::base_dir(self.base_dir);
if !base_dir.exists() {
create_dir_all(&base_dir).expect("Failed to create base directory");
}
//sanitize the name and subset
let name = sanitize(self.name.as_str());
// create the db file path
let db_file_name = if let Some(subset) = self.subset.clone() {
format!("{name}-{}.db", sanitize(subset.as_str()))
} else {
format!("{name}.db")
};
let db_file = base_dir.join(db_file_name);
// import the dataset if needed
if !Path::new(&db_file).exists() {
import(
self.name,
self.subset,
db_file.clone(),
base_dir,
self.huggingface_token,
self.huggingface_cache_dir,
self.huggingface_data_dir,
self.trust_remote_code,
self.use_python_venv,
)?;
}
Ok(db_file)
}
}
/// Import a dataset from huggingface. The transformed dataset is stored as sqlite database.
#[allow(clippy::too_many_arguments)]
fn import(
name: String,
subset: Option<String>,
base_file: PathBuf,
base_dir: PathBuf,
huggingface_token: Option<String>,
huggingface_cache_dir: Option<String>,
huggingface_data_dir: Option<String>,
trust_remote_code: bool,
use_python_venv: bool,
) -> Result<(), ImporterError> {
let python_path = if use_python_venv {
install_python_deps(&base_dir)?
} else {
get_python_name()?.into()
};
let mut command = Command::new(python_path);
command.arg(importer_script_path(&base_dir));
command.arg("--name");
command.arg(name);
command.arg("--file");
command.arg(base_file);
if let Some(subset) = subset {
command.arg("--subset");
command.arg(subset);
}
if let Some(huggingface_token) = huggingface_token {
command.arg("--token");
command.arg(huggingface_token);
}
if let Some(huggingface_cache_dir) = huggingface_cache_dir {
command.arg("--cache_dir");
command.arg(huggingface_cache_dir);
}
if let Some(huggingface_data_dir) = huggingface_data_dir {
command.arg("--data_dir");
command.arg(huggingface_data_dir);
}
if trust_remote_code {
command.arg("--trust_remote_code");
command.arg("True");
}
let mut handle = command.spawn().unwrap();
let exit_status = handle
.wait()
.map_err(|err| ImporterError::Unknown(format!("{err:?}")))?;
if !exit_status.success() {
return Err(ImporterError::Unknown(format!("{exit_status}")));
}
Ok(())
}
/// check python --version output is `Python 3.x.x`
fn check_python_version_is_3(python: &str) -> bool {
let output = Command::new(python).arg("--version").output();
match output {
Ok(output) => {
if output.status.success() {
let version_string = String::from_utf8_lossy(&output.stdout);
if let Some(index) = version_string.find(' ') {
let version = &version_string[index + 1..];
version.starts_with("3.")
} else {
false
}
} else {
false
}
}
Err(_error) => false,
}
}
/// get python3 name `python` `python3` or `py`
fn get_python_name() -> Result<&'static str, ImporterError> {
let python_name_list = ["python3", "python", "py"];
for python_name in python_name_list.iter() {
if check_python_version_is_3(python_name) {
return Ok(python_name);
}
}
Err(ImporterError::PythonNotInstalled)
}
fn importer_script_path(base_dir: &Path) -> PathBuf {
let path_file = base_dir.join("importer.py");
fs::write(&path_file, PYTHON_SOURCE).expect("Write python dataset downloader");
path_file
}
fn install_python_deps(base_dir: &Path) -> Result<PathBuf, ImporterError> {
let venv_dir = base_dir.join("venv");
let venv_python_path = venv_dir.join(VENV_BIN_PYTHON);
// If the venv environment is already initialized, skip the initialization.
if !check_python_version_is_3(venv_python_path.to_str().unwrap()) {
let python_name = get_python_name()?;
let mut command = Command::new(python_name);
command.args([
"-m",
"venv",
venv_dir
.as_os_str()
.to_str()
.expect("Path utf8 conversion should not fail"),
]);
// Spawn the venv creation process and wait for it to complete.
let mut handle = command.spawn().unwrap();
handle.wait().map_err(|err| {
ImporterError::FailToDownloadPythonDependencies(format!(" error: {err}"))
})?;
// Check if the venv environment can be used successfully."
if !check_python_version_is_3(venv_python_path.to_str().unwrap()) {
return Err(ImporterError::VenvNotInitialized);
}
}
let mut ensurepip_cmd = Command::new(&venv_python_path);
ensurepip_cmd.args(["-m", "ensurepip", "--upgrade"]);
let status = ensurepip_cmd.status().map_err(|err| {
ImporterError::FailToDownloadPythonDependencies(format!("failed to run ensurepip: {err}"))
})?;
if !status.success() {
return Err(ImporterError::FailToDownloadPythonDependencies(
"ensurepip failed to initialize pip".to_string(),
));
}
let mut command = Command::new(&venv_python_path);
command.args([
"-m",
"pip",
"--quiet",
"install",
"pyarrow",
"sqlalchemy",
"Pillow",
"soundfile",
"datasets",
]);
// Spawn the pip install process and wait for it to complete.
let mut handle = command.spawn().unwrap();
handle
.wait()
.map_err(|err| ImporterError::FailToDownloadPythonDependencies(format!(" error: {err}")))?;
Ok(venv_python_path)
}

View File

@@ -0,0 +1,207 @@
import argparse
import pyarrow as pa
from datasets import Audio, Image, load_dataset
from sqlalchemy import Column, Integer, Table, create_engine, event, inspect
from sqlalchemy.types import LargeBinary
def download_and_export(
name: str,
subset: str,
db_file: str,
token: str,
cache_dir: str,
data_dir: str | None,
trust_remote_code: bool,
):
"""
Download a dataset from using HuggingFace dataset and export it to a sqlite database.
"""
# TODO For media columns (Image and Audio) sometimes when decode=False,
# bytes can be none {'bytes': None, 'path': 'healthy_train.265.jpg'}
# We should handle this case, but unfortunately we did not come across this case yet to test it.
print("*" * 80)
print("Starting huggingface dataset download and export")
print(f"Dataset Name: {name}")
print(f"Subset Name: {subset}")
print(f"Sqlite database file: {db_file}")
print(f"Trust remote code: {trust_remote_code}")
if cache_dir is None:
print(f"Custom cache dir: {cache_dir}")
print("*" * 80)
# Load the dataset
dataset_all = load_dataset(
name,
subset,
cache_dir=cache_dir,
data_dir=data_dir,
use_auth_token=token,
trust_remote_code=trust_remote_code,
)
print(f"Dataset: {dataset_all}")
# Create the database connection descriptor (sqlite)
engine = create_engine(f"sqlite:///{db_file}")
# Set some sqlite pragmas to speed up the database
event.listen(engine, "connect", set_sqlite_pragma)
# Add an row_id column to each table as primary key (datasets does not have API for this)
event.listen(Table, "before_create", add_pk_column)
# Export each split in the dataset
for key in dataset_all.keys():
dataset = dataset_all[key]
# Disable decoding for audio and image fields
dataset = disable_decoding(dataset)
# Flatten the dataset
dataset = dataset.flatten()
# Rename columns to remove dots from the names
dataset = rename_columns(dataset)
print(f"Saving dataset: {name} - {key}")
print(f"Dataset features: {dataset.features}")
# Save the dataset to a sqlite database
dataset.to_sql(
key, # table name
engine,
# don't save the index, use row_id instead (index is not unique)
index=False,
dtype=blob_columns(dataset), # save binary columns as blob
)
# Print the schema of the database so we can reference the columns in the rust code
print_table_info(engine)
def disable_decoding(dataset):
"""
Disable decoding for audio and image fields. The fields will be saved as raw file bytes.
"""
for k, v in dataset.features.items():
if isinstance(v, Audio):
dataset = dataset.cast_column(k, Audio(decode=False))
elif isinstance(v, Image):
dataset = dataset.cast_column(k, Image(decode=False))
return dataset
def rename_columns(dataset):
"""
Rename columns to remove dots from the names. Dots appear in the column names because of the flattening.
Dots are not allowed in column names in rust and sql (unless quoted). So we replace them with underscores.
This way there is an easy name mapping between the rust and sql columns.
"""
for name in dataset.features.keys():
if "." in name:
dataset = dataset.rename_column(name, name.replace(".", "_"))
return dataset
def blob_columns(dataset):
"""
Make sure all binary columns are blob columns in the database because
`to_sql` exports binary values as TEXT instead of BLOB.
"""
type_mapping = {}
for name, value in dataset.features.items():
if value.pa_type is not None and pa.types.is_binary(value.pa_type):
type_mapping[name] = LargeBinary
return type_mapping
def set_sqlite_pragma(dbapi_connection, connection_record):
"""
Set some sqlite pragmas to speed up the database
"""
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA synchronous = OFF")
cursor.execute("PRAGMA journal_mode = OFF")
cursor.close()
def add_pk_column(target, connection, **kw):
"""
Add an id column to each table.
"""
target.append_column(Column("row_id", Integer, primary_key=True))
def print_table_info(engine):
"""
Print the schema of the database so we can reference the columns in the rust code
"""
print(f"Printing table schema for sqlite3 db ({engine})")
inspector = inspect(engine)
for table_name in inspector.get_table_names():
print(f"Table: {table_name}")
for column in inspector.get_columns(table_name):
print(f"Column: {column['name']} - {column['type']}")
print("")
def parse_args():
parser = argparse.ArgumentParser(
description="Huggingface datasets downloader to use with burn-dataset"
)
parser.add_argument(
"--name", type=str, help="Name of the dataset to download", required=True
)
parser.add_argument(
"--file", type=str, help="Base file name where the data is saved", required=True
)
parser.add_argument(
"--subset", type=str, help="Subset name", required=False, default=None
)
parser.add_argument(
"--token",
type=str,
help="HuggingFace authentication token",
required=False,
default=None,
)
parser.add_argument(
"--cache_dir", type=str, help="Cache directory", required=False, default=None
)
parser.add_argument(
"--data_dir", type=str, help="Relative path to a specific subset of your dataset", required=False, default=None
)
parser.add_argument(
"--trust_remote_code",
type=bool,
help="Trust remote code",
required=False,
default=None,
)
return parser.parse_args()
def run():
args = parse_args()
download_and_export(
args.name,
args.subset,
args.file,
args.token,
args.data_dir,
args.cache_dir,
args.trust_remote_code,
)
if __name__ == "__main__":
run()

View File

@@ -0,0 +1,3 @@
pub(crate) mod downloader;
pub use downloader::*;

View File

@@ -0,0 +1,3 @@
/// Huggingface source
#[cfg(any(feature = "sqlite", feature = "sqlite-bundled"))]
pub mod huggingface;

View File

@@ -0,0 +1,56 @@
use crate::Dataset;
/// Compose multiple datasets together to create a bigger one.
#[derive(new)]
pub struct ComposedDataset<D> {
datasets: Vec<D>,
}
impl<D, I> Dataset<I> for ComposedDataset<D>
where
D: Dataset<I>,
I: Clone,
{
fn get(&self, index: usize) -> Option<I> {
let mut current_index = 0;
for dataset in self.datasets.iter() {
if index < dataset.len() + current_index {
return dataset.get(index - current_index);
}
current_index += dataset.len();
}
None
}
fn len(&self) -> usize {
let mut total = 0;
for dataset in self.datasets.iter() {
total += dataset.len();
}
total
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::FakeDataset;
#[test]
fn test_composed_dataset() {
let dataset1 = FakeDataset::<String>::new(10);
let dataset2 = FakeDataset::<String>::new(5);
let items1 = dataset1.iter().collect::<Vec<_>>();
let items2 = dataset2.iter().collect::<Vec<_>>();
let composed = ComposedDataset::new(vec![dataset1, dataset2]);
assert_eq!(composed.len(), 15);
let expected_items: Vec<String> = items1.iter().chain(items2.iter()).cloned().collect();
let items = composed.iter().collect::<Vec<_>>();
assert_eq!(items, expected_items);
}
}

View File

@@ -0,0 +1,60 @@
use crate::Dataset;
use std::marker::PhantomData;
/// Basic mapper trait to be used with the [mapper dataset](MapperDataset).
pub trait Mapper<I, O>: Send + Sync {
/// Maps an item of type I to an item of type O.
fn map(&self, item: &I) -> O;
}
/// Dataset mapping each element in an inner dataset to another element type lazily.
#[derive(new)]
pub struct MapperDataset<D, M, I> {
dataset: D,
mapper: M,
input: PhantomData<I>,
}
impl<D, M, I, O> Dataset<O> for MapperDataset<D, M, I>
where
D: Dataset<I>,
M: Mapper<I, O> + Send + Sync,
I: Send + Sync,
O: Send + Sync,
{
fn get(&self, index: usize) -> Option<O> {
let item = self.dataset.get(index);
item.map(|item| self.mapper.map(&item))
}
fn len(&self) -> usize {
self.dataset.len()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{InMemDataset, test_data};
#[test]
pub fn given_mapper_dataset_when_iterate_should_iterate_though_all_map_items() {
struct StringToFirstChar;
impl Mapper<String, String> for StringToFirstChar {
fn map(&self, item: &String) -> String {
let mut item = item.clone();
item.truncate(1);
item
}
}
let items_original = test_data::string_items();
let dataset = InMemDataset::new(items_original);
let dataset = MapperDataset::new(dataset, StringToFirstChar);
let items: Vec<String> = dataset.iter().collect();
assert_eq!(vec!["1", "2", "3", "4"], items);
}
}

View File

@@ -0,0 +1,30 @@
//! # Dataset Transformations
//!
//! This module provides a collection of [`crate::Dataset`] composition wrappers;
//! providing composition, subset selection, sampling, random shuffling, and windowing.
//!
//! * [`ComposedDataset`] - composes a list of datasets.
//! * [`PartialDataset`] - selects a contiguous index range subset of a dataset.
//! * [`ShuffledDataset`] - a randomly shuffled / mutably shuffle-able dataset;
//! a thin wrapper around [`SelectionDataset`].
//! * [`SamplerDataset`] - samples a dataset; support for with/without replacement,
//! and under/oversampling.
//! * [`SelectionDataset`] - selects a subset of a dataset via indices; support for shuffling.
//! * [`WindowsDataset`] - creates a sliding window over a dataset.
mod composed;
mod mapper;
mod options;
mod partial;
mod sampler;
mod selection;
mod shuffle;
mod window;
pub use composed::*;
pub use mapper::*;
pub use options::*;
pub use partial::*;
pub use sampler::*;
pub use selection::*;
pub use shuffle::*;
pub use window::*;

View File

@@ -0,0 +1,199 @@
use rand::SeedableRng;
use rand::prelude::StdRng;
use rand::rngs::SysRng;
/// Defines a source for a `StdRng`.
///
/// # Examples
///
/// ```rust,no_run
/// use rand::rngs::StdRng;
/// use rand::SeedableRng;
/// use burn_dataset::transform::RngSource;
///
/// // Default via `StdRng::from_os_rng()` (`RngSource::Default`)
/// let system: RngSource = RngSource::default();
///
/// // From a fixed seed (`RngSource::Seed`)
/// let seeded: RngSource = 42.into();
///
/// // From an existing rng (`RngSource::Rng`)
/// let rng = StdRng::seed_from_u64(123);
/// let with_rng: RngSource = rng.into();
///
/// // Forks the parent RNG to derive an independent, deterministic child RNG.
/// // The original `rng` is modified, and the resulting `RngSource` contains
/// // a new RNG starting from a unique state.
/// let mut rng = StdRng::seed_from_u64(123);
/// let forked: RngSource = (&mut rng).into();
/// ```
#[derive(Debug, Default, PartialEq, Eq)]
#[allow(clippy::large_enum_variant)]
pub enum RngSource {
/// Build a new rng from the system.
#[default]
Default,
/// The rng is passed as a seed.
Seed(u64),
/// The rng is passed as an option.
Rng(StdRng),
}
impl From<RngSource> for StdRng {
fn from(source: RngSource) -> Self {
match source {
RngSource::Default => StdRng::try_from_rng(&mut SysRng).unwrap(),
RngSource::Rng(rng) => rng,
RngSource::Seed(seed) => StdRng::seed_from_u64(seed),
}
}
}
impl From<u64> for RngSource {
fn from(seed: u64) -> Self {
Self::Seed(seed)
}
}
impl From<StdRng> for RngSource {
fn from(rng: StdRng) -> Self {
Self::Rng(rng)
}
}
/// Derive an independent RNG from a mutable parent RNG.
///
/// This advances the parent RNG and creates a new RNG seeded from its output.
/// The derived RNG is *not* a clone of the parent's state, but an independent
/// stream (equivalent to `SeedableRng::fork`).
impl From<&mut StdRng> for RngSource {
fn from(rng: &mut StdRng) -> Self {
Self::Rng(rng.fork())
}
}
/// Helper option to describe the size of a wrapper, relative to a wrapped object.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub enum SizeConfig {
/// Use the size of the source dataset.
#[default]
Default,
/// Use the size as a ratio of the source dataset size.
///
/// Must be >= 0.
Ratio(f64),
/// Use a fixed size.
Fixed(usize),
}
impl SizeConfig {
/// Construct a source which will have the same size as the source dataset.
pub fn source() -> Self {
Self::Default
}
/// Resolve the effective size.
///
/// ## Arguments
///
/// - `source_size`: the size of the source dataset.
///
/// ## Returns
///
/// The resolved size of the wrapper dataset.
pub fn resolve(self, source_size: usize) -> usize {
match self {
SizeConfig::Default => source_size,
SizeConfig::Ratio(ratio) => {
assert!(ratio >= 0.0, "Ratio must be positive: {ratio}");
((source_size as f64) * ratio) as usize
}
SizeConfig::Fixed(size) => size,
}
}
}
impl From<usize> for SizeConfig {
fn from(size: usize) -> Self {
Self::Fixed(size)
}
}
impl From<f64> for SizeConfig {
fn from(ratio: f64) -> Self {
Self::Ratio(ratio)
}
}
#[cfg(test)]
mod tests {
use super::*;
use rand::SeedableRng;
#[test]
fn test_rng_source_default() {
let rng_source: RngSource = Default::default();
assert_eq!(&rng_source, &RngSource::Default);
assert_eq!(&rng_source, &RngSource::default());
// Exercise the from_os_rng() call; but we don't know its seed;
let _rng: StdRng = rng_source.into();
}
#[test]
fn test_rng_source_seed() {
let rng_source = RngSource::from(42);
assert_eq!(&rng_source, &RngSource::Seed(42));
let rng: StdRng = rng_source.into();
let expected = StdRng::seed_from_u64(42);
assert_eq!(rng, expected);
}
#[test]
fn test_rng_source_rng() {
// From StdRng (owned).
{
let original = StdRng::seed_from_u64(42);
let rng_source = RngSource::from(original);
let rng: StdRng = rng_source.into();
// No longer clone, but from <> into should not have advanced the state
let original = StdRng::seed_from_u64(42);
assert_eq!(rng, original);
}
// From &mut StdRng (forks parent)
{
let mut original = StdRng::seed_from_u64(42);
let mut rng = StdRng::seed_from_u64(42);
let rng_forked = rng.fork();
let rng_source = RngSource::from(&mut original);
// Ensure the original was advanced
assert_eq!(original, rng);
// Ensure the sourced RNG matches the fork
let rng: StdRng = rng_source.into();
assert_eq!(rng, rng_forked);
}
}
#[test]
fn test_size_config() {
assert_eq!(SizeConfig::default(), SizeConfig::Default);
assert_eq!(SizeConfig::from(42), SizeConfig::Fixed(42));
assert_eq!(SizeConfig::from(1.5), SizeConfig::Ratio(1.5));
assert_eq!(SizeConfig::source(), SizeConfig::Default);
assert_eq!(SizeConfig::source().resolve(50), 50);
}
}

View File

@@ -0,0 +1,206 @@
use crate::Dataset;
use std::{marker::PhantomData, sync::Arc};
/// Only use a fraction of an existing dataset lazily.
#[derive(new, Clone)]
pub struct PartialDataset<D, I> {
dataset: D,
start_index: usize,
end_index: usize,
input: PhantomData<I>,
}
impl<D, I> PartialDataset<D, I>
where
D: Dataset<I>,
{
/// Splits a dataset into multiple partial datasets.
pub fn split(dataset: D, num: usize) -> Vec<PartialDataset<Arc<D>, I>> {
let dataset = Arc::new(dataset); // cheap cloning.
let mut current = 0;
let mut datasets = Vec::with_capacity(num);
let batch_size = dataset.len() / num;
for i in 0..num {
let start = current;
let mut end = current + batch_size;
if i == (num - 1) {
end = dataset.len();
}
let dataset = PartialDataset::new(dataset.clone(), start, end);
current += batch_size;
datasets.push(dataset);
}
datasets
}
/// Splits a dataset by distributing complete chunks/batches across multiple partial datasets.
pub fn split_chunks(
dataset: D,
num: usize,
batch_size: usize,
) -> Vec<PartialDataset<Arc<D>, I>> {
let dataset = Arc::new(dataset); // cheap cloning.
let total_items = dataset.len();
// Total number of complete batches
let total_batches = total_items.div_ceil(batch_size);
let batches_per_split = total_batches / num;
let extra_batches = total_batches % num;
let mut datasets = Vec::with_capacity(num);
let mut current_batch = 0;
for i in 0..num {
// Extra batches distributed across first splits
let split_batches = if i < extra_batches {
batches_per_split + 1
} else {
batches_per_split
};
let start_batch = current_batch;
let end_batch = start_batch + split_batches;
let start_index = start_batch * batch_size;
let end_index = core::cmp::min(end_batch * batch_size, total_items);
if start_index < total_items {
datasets.push(PartialDataset::new(dataset.clone(), start_index, end_index));
}
current_batch = end_batch;
}
datasets
}
}
impl<D, I> Dataset<I> for PartialDataset<D, I>
where
D: Dataset<I>,
I: Clone + Send + Sync,
{
fn get(&self, index: usize) -> Option<I> {
let index = index + self.start_index;
if index < self.start_index || index >= self.end_index {
return None;
}
self.dataset.get(index)
}
fn len(&self) -> usize {
usize::min(self.end_index - self.start_index, self.dataset.len())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::FakeDataset;
use std::collections::HashSet;
#[test]
fn test_start_from_beginning() {
let dataset_original = FakeDataset::<String>::new(27);
let mut items_original_1 = HashSet::new();
let mut items_original_2 = HashSet::new();
let mut items_partial = HashSet::new();
dataset_original.iter().enumerate().for_each(|(i, item)| {
match i >= 10 {
true => items_original_2.insert(item),
false => items_original_1.insert(item),
};
});
let dataset_partial = PartialDataset::new(dataset_original, 0, 10);
for item in dataset_partial.iter() {
items_partial.insert(item);
}
assert_eq!(dataset_partial.len(), 10);
assert_eq!(items_original_1, items_partial);
for item in items_original_2 {
assert!(!items_partial.contains(&item));
}
}
#[test]
fn test_start_inside() {
let dataset_original = FakeDataset::<String>::new(27);
let mut items_original_1 = HashSet::new();
let mut items_original_2 = HashSet::new();
let mut items_partial = HashSet::new();
dataset_original.iter().enumerate().for_each(|(i, item)| {
match !(10..20).contains(&i) {
true => items_original_2.insert(item),
false => items_original_1.insert(item),
};
});
let dataset_partial = PartialDataset::new(dataset_original, 10, 20);
for item in dataset_partial.iter() {
items_partial.insert(item);
}
assert_eq!(dataset_partial.len(), 10);
assert_eq!(items_original_1, items_partial);
for item in items_original_2 {
assert!(!items_partial.contains(&item));
}
}
#[test]
fn test_split_contains_all_items_without_duplicates() {
let dataset_original = FakeDataset::<String>::new(27);
let mut items_original = Vec::new();
let mut items_partial = Vec::new();
for item in dataset_original.iter() {
items_original.push(item);
}
let dataset_partials = PartialDataset::split(dataset_original, 4);
let expected_len = [6, 6, 6, 9];
for (i, dataset) in dataset_partials.iter().enumerate() {
assert_eq!(dataset.len(), expected_len[i]);
for item in dataset.iter() {
items_partial.push(item);
}
}
assert_eq!(items_original, items_partial);
}
#[test]
fn test_split_chunks_contains_all_items_without_duplicates() {
let dataset_original = FakeDataset::<String>::new(27);
let mut items_original = Vec::new();
let mut items_partial = Vec::new();
for item in dataset_original.iter() {
items_original.push(item);
}
let dataset_partials = PartialDataset::split_chunks(dataset_original, 4, 5);
// [(2 * 5), (2 * 5), 5, 2] -> 5 complete chunks + 1 incomplete with 2 remaining items
// OTOH, `split(dataset, 4)` would yield [6, 6, 6, 9] -> 4 incomplete chunks + 4 incomplete with [1, 1, 1, 4]
let expected_len = [10, 10, 5, 2];
for (i, dataset) in dataset_partials.iter().enumerate() {
assert_eq!(dataset.len(), expected_len[i]);
for item in dataset.iter() {
items_partial.push(item);
}
}
assert_eq!(items_original, items_partial);
}
}

View File

@@ -0,0 +1,438 @@
use crate::Dataset;
use crate::transform::{RngSource, SizeConfig};
use rand::prelude::SliceRandom;
use rand::{RngExt, distr::Uniform, rngs::StdRng, seq::IteratorRandom};
use std::{marker::PhantomData, ops::DerefMut, sync::Mutex};
/// Options to configure a [SamplerDataset].
#[derive(Debug, PartialEq)]
pub struct SamplerDatasetOptions {
/// The sampling mode.
pub replace_samples: bool,
/// The size source of the wrapper relative to the dataset.
pub size_config: SizeConfig,
/// The source of the random number generator.
pub rng_source: RngSource,
}
impl Default for SamplerDatasetOptions {
fn default() -> Self {
Self {
replace_samples: true,
size_config: SizeConfig::Default,
rng_source: RngSource::Default,
}
}
}
impl<T> From<Option<T>> for SamplerDatasetOptions
where
T: Into<SamplerDatasetOptions>,
{
fn from(option: Option<T>) -> Self {
match option {
Some(option) => option.into(),
None => Self::default(),
}
}
}
impl From<usize> for SamplerDatasetOptions {
fn from(size: usize) -> Self {
Self::default().with_replacement().with_fixed_size(size)
}
}
impl SamplerDatasetOptions {
/// Set the replacement mode.
pub fn with_replace_samples(self, replace_samples: bool) -> Self {
Self {
replace_samples,
..self
}
}
/// Set the replacement mode to WithReplacement.
pub fn with_replacement(self) -> Self {
self.with_replace_samples(true)
}
/// Set the replacement mode to WithoutReplacement.
pub fn without_replacement(self) -> Self {
self.with_replace_samples(false)
}
/// Set the size source.
pub fn with_size<S>(self, source: S) -> Self
where
S: Into<SizeConfig>,
{
Self {
size_config: source.into(),
..self
}
}
/// Set the size to the size of the source.
pub fn with_source_size(self) -> Self {
self.with_size(SizeConfig::Default)
}
/// Set the size to a fixed size.
pub fn with_fixed_size(self, size: usize) -> Self {
self.with_size(size)
}
/// Set the size to be a multiple of the ration and the source size.
pub fn with_size_ratio(self, size_ratio: f64) -> Self {
self.with_size(size_ratio)
}
/// Set the `RngSource`.
pub fn with_rng<R>(self, rng: R) -> Self
where
R: Into<RngSource>,
{
Self {
rng_source: rng.into(),
..self
}
}
/// Use the system rng.
pub fn with_system_rng(self) -> Self {
self.with_rng(RngSource::Default)
}
/// Use a rng, built from a seed.
pub fn with_seed(self, seed: u64) -> Self {
self.with_rng(seed)
}
}
/// Sample items from a dataset.
///
/// This is a convenient way of modeling a dataset as a probability distribution of a fixed size.
/// You have multiple options to instantiate the dataset sampler.
///
/// * With replacement (Default): This is the most efficient way of using the sampler because no state is
/// required to keep indices that have been selected.
///
/// * Without replacement: This has a similar effect to using a
/// [shuffled dataset](crate::transform::ShuffledDataset), but with more flexibility since you can
/// set the dataset to an arbitrary size. Once every item has been used, a new cycle is
/// created with a new random suffle.
pub struct SamplerDataset<D, I> {
dataset: D,
size: usize,
state: Mutex<SamplerState>,
input: PhantomData<I>,
}
enum SamplerState {
WithReplacement(StdRng),
WithoutReplacement(StdRng, Vec<usize>),
}
impl<D, I> SamplerDataset<D, I>
where
D: Dataset<I>,
I: Send + Sync,
{
/// Creates a new sampler dataset with replacement.
///
/// When the sample size is less than or equal to the source dataset size,
/// data will be sampled without replacement from the source dataset in
/// a uniformly shuffled order.
///
/// When the sample size is greater than the source dataset size,
/// the entire source dataset will be sampled once for every multiple
/// of the size ratios; with the remaining samples taken without replacement
/// uniformly from the source. All samples will be returned uniformly shuffled.
///
/// ## Arguments
///
/// * `dataset`: the dataset to wrap.
/// * `options`: the options to configure the sampler dataset.
///
/// ## Examples
/// ```rust,ignore
/// use burn_dataset::transform::{
/// SamplerDataset,
/// SamplerDatasetOptions,
/// };
///
/// // Examples below assuming `dataset.len()` = `10`.
///
/// // sample size: 5
/// // WithReplacement
/// // rng: StdRng::from_os_rng()
/// SamplerDataset::new(dataset, 5);
///
/// // sample size: 10 (source)
/// // WithReplacement
/// // rng: StdRng::from_os_rng()
/// SamplerDataset::new(dataset, SamplerDatasetOptions::default());
///
/// // sample size: 15
/// // WithoutReplacement
/// // rng: StdRng::seed_from_u64(42)
/// SamplerDataset::new(
/// dataset,
/// SamplerDatasetOptions::default()
/// .with_size(1.5)
/// .without_replacement()
/// .with_rng(42),
/// );
/// ```
pub fn new<O>(dataset: D, options: O) -> Self
where
O: Into<SamplerDatasetOptions>,
{
let options = options.into();
let size = options.size_config.resolve(dataset.len());
let rng = options.rng_source.into();
Self {
dataset,
size,
state: Mutex::new(match options.replace_samples {
true => SamplerState::WithReplacement(rng),
false => SamplerState::WithoutReplacement(rng, Vec::with_capacity(size)),
}),
input: PhantomData,
}
}
/// Creates a new sampler dataset with replacement.
///
/// # Arguments
///
/// - `dataset`: the dataset to wrap.
/// - `size`: the effective size of the sampled dataset.
pub fn with_replacement(dataset: D, size: usize) -> Self {
Self::new(
dataset,
SamplerDatasetOptions::default()
.with_replacement()
.with_fixed_size(size),
)
}
/// Creates a new sampler dataset without replacement.
///
/// When the sample size is less than or equal to the source dataset size,
/// data will be sampled without replacement from the source dataset in
/// a uniformly shuffled order.
///
/// When the sample size is greater than the source dataset size,
/// the entire source dataset will be sampled once for every multiple
/// of the size ratios; with the remaining samples taken without replacement
/// uniformly from the source. All samples will be returned uniformly shuffled.
///
/// # Arguments
/// - `dataset`: the dataset to wrap.
/// - `size`: the effective size of the sampled dataset.
pub fn without_replacement(dataset: D, size: usize) -> Self {
Self::new(
dataset,
SamplerDatasetOptions::default()
.without_replacement()
.with_fixed_size(size),
)
}
/// Determines if the sampler is using the "with replacement" strategy.
///
/// # Returns
/// - `true`: If the sampler is configured to sample with replacement.
/// - `false`: If the sampler is configured to sample without replacement.
pub fn is_with_replacement(&self) -> bool {
match self.state.lock().unwrap().deref_mut() {
SamplerState::WithReplacement(_) => true,
SamplerState::WithoutReplacement(_, _) => false,
}
}
fn index(&self) -> usize {
match self.state.lock().unwrap().deref_mut() {
SamplerState::WithReplacement(rng) => {
rng.sample(Uniform::new(0, self.dataset.len()).unwrap())
}
SamplerState::WithoutReplacement(rng, indices) => {
if indices.is_empty() {
// Refill the state.
let idx_range = 0..self.dataset.len();
for _ in 0..(self.size / self.dataset.len()) {
// No need to `.choose_multiple` here because we're using
// the entire source range; and `.choose_multiple` will
// not return a random sample anyway.
indices.extend(idx_range.clone())
}
// From `choose_multiple` documentation:
// > Although the elements are selected randomly, the order of elements in
// > the buffer is neither stable nor fully random. If random ordering is
// > desired, shuffle the result.
indices.extend(idx_range.sample(rng, self.size - indices.len()));
// The real shuffling is done here.
indices.shuffle(rng);
}
indices.pop().expect("Indices are refilled when empty.")
}
}
}
}
impl<D, I> Dataset<I> for SamplerDataset<D, I>
where
D: Dataset<I>,
I: Send + Sync,
{
fn get(&self, index: usize) -> Option<I> {
if index >= self.size {
return None;
}
self.dataset.get(self.index())
}
fn len(&self) -> usize {
self.size
}
}
#[cfg(test)]
mod tests {
#![allow(clippy::bool_assert_comparison)]
use super::*;
use crate::FakeDataset;
use rand::SeedableRng;
use std::collections::HashMap;
#[test]
fn test_samplerdataset_options() {
let options = SamplerDatasetOptions::default();
assert_eq!(options.replace_samples, true);
assert_eq!(options.size_config, SizeConfig::Default);
assert_eq!(options.rng_source, RngSource::Default);
// ReplacementMode
let options = options.with_replace_samples(false);
assert_eq!(options.replace_samples, false);
let options = options.with_replacement();
assert_eq!(options.replace_samples, true);
let options = options.without_replacement();
assert_eq!(options.replace_samples, false);
// SourceSize
let options = options.with_size(SizeConfig::Default);
assert_eq!(options.size_config, SizeConfig::Default);
let options = options.with_source_size();
assert_eq!(options.size_config, SizeConfig::Default);
let options = options.with_fixed_size(10);
assert_eq!(options.size_config, SizeConfig::Fixed(10));
let options = options.with_size_ratio(1.5);
assert_eq!(options.size_config, SizeConfig::Ratio(1.5));
// RngSource
let options = options.with_system_rng();
assert_eq!(options.rng_source, RngSource::Default);
let options = options.with_seed(42);
assert_eq!(options.rng_source, RngSource::Seed(42));
let rng = StdRng::seed_from_u64(9);
let options = options.with_rng(rng);
assert!(matches!(options.rng_source, RngSource::Rng(_)));
}
#[test]
fn sampler_dataset_constructors_test() {
let ds = SamplerDataset::new(FakeDataset::<u32>::new(10), 15);
assert_eq!(ds.len(), 15);
assert_eq!(ds.dataset.len(), 10);
assert!(ds.is_with_replacement());
let ds = SamplerDataset::with_replacement(FakeDataset::<u32>::new(10), 15);
assert_eq!(ds.len(), 15);
assert_eq!(ds.dataset.len(), 10);
assert!(ds.is_with_replacement());
let ds = SamplerDataset::without_replacement(FakeDataset::<u32>::new(10), 15);
assert_eq!(ds.len(), 15);
assert_eq!(ds.dataset.len(), 10);
assert!(!ds.is_with_replacement());
}
#[test]
fn sampler_dataset_with_replacement_iter() {
let factor = 3;
let len_original = 10;
let dataset_sampler = SamplerDataset::with_replacement(
FakeDataset::<String>::new(len_original),
len_original * factor,
);
let mut total = 0;
for _item in dataset_sampler.iter() {
total += 1;
}
assert_eq!(total, factor * len_original);
}
#[test]
fn sampler_dataset_without_replacement_bucket_test() {
let factor = 3;
let len_original = 10;
let dataset_sampler = SamplerDataset::new(
FakeDataset::<String>::new(len_original),
SamplerDatasetOptions::default()
.without_replacement()
.with_size_ratio(factor as f64),
);
let mut buckets = HashMap::new();
for item in dataset_sampler.iter() {
let count = match buckets.get(&item) {
Some(count) => count + 1,
None => 1,
};
buckets.insert(item, count);
}
let mut total = 0;
for count in buckets.into_values() {
assert_eq!(count, factor);
total += count;
}
assert_eq!(total, factor * len_original);
}
#[test]
fn sampler_dataset_without_replacement_uniform_order_test() {
// This is a reversion test on the indices.shuffle(rng) call in SamplerDataset::index().
let size = 1000;
let dataset_sampler =
SamplerDataset::without_replacement(FakeDataset::<i32>::new(size), size);
let indices: Vec<_> = (0..size).map(|_| dataset_sampler.index()).collect();
let mean_delta = indices
.windows(2)
.map(|pair| pair[1].abs_diff(pair[0]))
.sum::<usize>() as f64
/ (size - 1) as f64;
let expected = (size + 2) as f64 / 3.0;
assert!(
(mean_delta - expected).abs() <= 0.25 * expected,
"Sampled indices are not uniformly distributed: mean_delta: {mean_delta}, expected: {expected}"
);
}
}

View File

@@ -0,0 +1,374 @@
use crate::Dataset;
use crate::transform::RngSource;
use rand::prelude::SliceRandom;
use rand::rngs::StdRng;
use std::marker::PhantomData;
use std::sync::Arc;
/// Generates a vector of indices from 0 to size - 1.
///
/// # Arguments
///
/// * `size` - The size of the dataset.
///
/// # Returns
///
/// A vector containing indices from 0 to size - 1.
#[inline(always)]
pub fn iota(size: usize) -> Vec<usize> {
(0..size).collect()
}
/// Generates a shuffled vector of indices up to a size.
///
/// # Arguments
///
/// * `size` - The size of the dataset to shuffle.
///
/// # Returns
///
/// A vector of shuffled indices.
#[inline(always)]
pub fn shuffled_indices(size: usize, rng: &mut StdRng) -> Vec<usize> {
let mut indices = iota(size);
indices.shuffle(rng);
indices
}
/// A dataset that selects a subset of indices from an existing dataset.
///
/// Indices may appear multiple times, but they must be within the bounds of the original dataset.
#[derive(Clone)]
pub struct SelectionDataset<D, I>
where
D: Dataset<I>,
I: Clone + Send + Sync,
{
/// The wrapped dataset from which to select indices.
pub wrapped: Arc<D>,
/// The indices to select from the wrapped dataset.
pub indices: Vec<usize>,
input: PhantomData<I>,
}
impl<D, I> SelectionDataset<D, I>
where
D: Dataset<I>,
I: Clone + Send + Sync,
{
/// Creates a new selection dataset with the given dataset and indices.
///
/// Checks that all indices are within the bounds of the dataset.
///
/// # Arguments
///
/// * `dataset` - The original dataset to select from.
/// * `indices` - A slice of indices to select from the dataset.
/// These indices must be within the bounds of the dataset.
///
/// # Panics
///
/// Panics if any index is out of bounds for the dataset.
pub fn from_indices_checked<S>(dataset: S, indices: Vec<usize>) -> Self
where
S: Into<Arc<D>>,
{
let dataset = dataset.into();
let size = dataset.len();
if let Some(idx) = indices.iter().find(|&i| *i >= size) {
panic!("Index out of bounds for wrapped dataset size: {idx} >= {size}");
}
Self::from_indices_unchecked(dataset, indices)
}
/// Creates a new selection dataset with the given dataset and indices without checking bounds.
///
/// # Arguments
///
/// * `dataset` - The original dataset to select from.
/// * `indices` - A vector of indices to select from the dataset.
///
/// # Safety
///
/// This function does not check if the indices are within the bounds of the dataset.
pub fn from_indices_unchecked<S>(dataset: S, indices: Vec<usize>) -> Self
where
S: Into<Arc<D>>,
{
Self {
wrapped: dataset.into(),
indices,
input: PhantomData,
}
}
/// Creates a new selection dataset that selects all indices from the dataset.
///
/// This allocates a 1-to-1 mapping of indices to the dataset size,
/// essentially functioning as a no-op selection. This is only useful
/// when the dataset will later be shuffled or transformed in place.
///
/// # Arguments
///
/// * `dataset` - The original dataset to select from.
///
/// # Returns
///
/// A new `SelectionDataset` that selects all indices from the dataset.
pub fn new_select_all<S>(dataset: S) -> Self
where
S: Into<Arc<D>>,
{
let dataset = dataset.into();
let size = dataset.len();
Self::from_indices_unchecked(dataset, iota(size))
}
/// Creates a new selection dataset with shuffled indices.
///
/// Selects every index of the dataset and shuffles them
/// with randomness from the provided random number generator.
///
/// # Arguments
///
/// * `dataset` - The original dataset to select from.
/// * `rng` - A mutable reference to a random number generator.
///
/// # Returns
///
/// A new `SelectionDataset` with shuffled indices.
pub fn new_shuffled<S, R>(dataset: S, rng_source: R) -> Self
where
S: Into<Arc<D>>,
R: Into<RngSource>,
{
let mut this = Self::new_select_all(dataset);
this.shuffle(rng_source);
this
}
/// Shuffles the indices of the dataset using a mutable random number generator.
///
/// This method modifies the dataset in place, shuffling the indices.
///
/// # Arguments
///
/// * `rng` - A mutable reference to a random number generator.
pub fn shuffle<R>(&mut self, rng_source: R)
where
R: Into<RngSource>,
{
let mut rng: StdRng = rng_source.into().into();
self.indices.shuffle(&mut rng)
}
/// Creates a new dataset that is a slice of the current selection dataset.
///
/// Slices the *selection indices* from ``[start..end]``.
///
/// Independent of future shuffles on the parent, but shares the same wrapped dataset.
///
///
/// # Arguments
///
/// * `start` - The start of the range.
/// * `end` - The end of the range (exclusive).
// TODO: SliceArg in burn-tensor should be lifted to burn-std; this should use SliceArg.
pub fn slice(&self, start: usize, end: usize) -> Self {
Self::from_indices_unchecked(self.wrapped.clone(), self.indices[start..end].to_vec())
}
/// Split into `num` datasets by slicing the selection indices evenly.
///
/// Split is done via `slice`, so the datasets share the same wrapped dataset.
///
/// Independent of future shuffles on the parent, but shares the same wrapped dataset.
///
/// # Arguments
///
/// * `num` - The number of datasets to split into.
///
/// # Returns
///
/// A vector of `SelectionDataset` instances, each containing a subset of the indices.
pub fn split(&self, num: usize) -> Vec<Self> {
let n = self.indices.len();
let mut current = 0;
let mut datasets = Vec::with_capacity(num);
let batch_size = n / num;
for i in 0..num {
let start = current;
let mut end = current + batch_size;
if i == (num - 1) {
end = n;
}
let dataset = self.slice(start, end);
current += batch_size;
datasets.push(dataset);
}
datasets
}
}
impl<D, I> Dataset<I> for SelectionDataset<D, I>
where
D: Dataset<I>,
I: Clone + Send + Sync,
{
fn get(&self, index: usize) -> Option<I> {
let index = self.indices.get(index)?;
self.wrapped.get(*index)
}
fn len(&self) -> usize {
self.indices.len()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::FakeDataset;
use rand::SeedableRng;
#[test]
fn test_iota() {
let size = 10;
let indices = iota(size);
assert_eq!(indices.len(), size);
assert_eq!(indices, vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
}
#[test]
fn test_shuffled_indices_same_seed_is_deterministic() {
let size = 10;
let mut rng1 = StdRng::seed_from_u64(10);
// `StdRng` is no longer `Clone`, so its internal state cannot be duplicated.
// To test determinism, we must explicitly create a second RNG from the same seed.
let mut rng2 = StdRng::seed_from_u64(10);
let mut expected = iota(size);
expected.shuffle(&mut rng1);
let indices = shuffled_indices(size, &mut rng2);
assert_eq!(indices, expected);
}
#[test]
fn test_shuffled_indices_forked_rngs_differ() {
let size = 10;
let mut rng1 = StdRng::seed_from_u64(10);
let mut rng2 = rng1.fork();
let mut a = iota(size);
let mut b = iota(size);
a.shuffle(&mut rng1);
b.shuffle(&mut rng2);
assert_ne!(a, b);
}
#[should_panic(expected = "Index out of bounds for wrapped dataset size: 300 >= 27")]
#[test]
fn test_from_indices_checked_panics() {
let source_dataset = FakeDataset::<String>::new(27);
let indices: Vec<usize> = vec![15, 1, 12, 300];
SelectionDataset::from_indices_checked(source_dataset, indices);
}
#[test]
fn test_checked_selection_dataset() {
let source_dataset = FakeDataset::<String>::new(27);
let indices: Vec<usize> = vec![15, 1, 12, 12];
let expected: Vec<String> = indices
.iter()
.map(|i| source_dataset.get(*i).unwrap())
.collect();
let selection = SelectionDataset::from_indices_checked(source_dataset, indices.clone());
assert_eq!(&selection.indices, &indices);
let items = selection.iter().collect::<Vec<_>>();
assert_eq!(items, expected);
}
#[test]
fn test_shuffled_dataset() {
let dataset = FakeDataset::<String>::new(27);
let source_items = dataset.iter().collect::<Vec<_>>();
let selection = SelectionDataset::new_shuffled(dataset, 42);
let indices = shuffled_indices(source_items.len(), &mut StdRng::seed_from_u64(42));
assert_eq!(&selection.indices, &indices);
assert_eq!(selection.len(), source_items.len());
let expected_items: Vec<_> = indices
.iter()
.map(|&i| source_items[i].to_string())
.collect();
assert_eq!(&selection.iter().collect::<Vec<_>>(), &expected_items);
}
#[test]
fn test_slice() {
let dataset = FakeDataset::<String>::new(27);
let source_items = dataset.iter().collect::<Vec<_>>();
let selection = SelectionDataset::new_select_all(dataset);
let start = 5;
let end = 15;
let sliced_selection = selection.slice(start, end);
assert_eq!(sliced_selection.len(), end - start);
#[allow(clippy::needless_range_loop)]
for i in start..end {
assert_eq!(
sliced_selection.get(i - start),
Some(source_items[i].to_string())
);
}
}
#[test]
fn test_split() {
let dataset = FakeDataset::<String>::new(28);
let source_items = dataset.iter().collect::<Vec<_>>();
let selection = SelectionDataset::new_select_all(dataset);
let split_contents: Vec<Vec<_>> = selection
.split(3)
.iter()
.map(|d| d.iter().collect::<Vec<_>>())
.collect();
assert_eq!(
split_contents,
vec![
source_items[0..9].to_vec(),
source_items[9..18].to_vec(),
source_items[18..28].to_vec(),
]
);
}
}

View File

@@ -0,0 +1,109 @@
use crate::Dataset;
use crate::transform::{RngSource, SelectionDataset};
/// A Shuffled a dataset.
///
/// This is a thin wrapper around a [SelectionDataset] which selects and shuffles
/// the full indices of the original dataset.
///
/// Consider using [SelectionDataset] if you are only interested in
/// shuffling mechanisms.
///
/// Consider using [sampler dataset](crate::transform::SamplerDataset) if you
/// want a probability distribution which is computed lazily.
pub struct ShuffledDataset<D, I>
where
D: Dataset<I>,
I: Clone + Send + Sync,
{
wrapped: SelectionDataset<D, I>,
}
impl<D, I> ShuffledDataset<D, I>
where
D: Dataset<I>,
I: Clone + Send + Sync,
{
/// Creates a new selection dataset with shuffled indices.
///
/// This is a thin wrapper around `SelectionDataset::new_shuffled`.
///
/// # Arguments
///
/// * `dataset` - The original dataset to select from.
/// * `rng_source` - The source of the random number generator.
///
/// # Returns
///
/// A new `ShuffledDataset`.
pub fn new<R>(dataset: D, rng_source: R) -> Self
where
R: Into<RngSource>,
{
Self {
wrapped: SelectionDataset::new_shuffled(dataset, rng_source),
}
}
/// Creates a new selection dataset with shuffled indices using a fixed seed.
///
/// This is a thin wrapper around `SelectionDataset::new_shuffled_with_seed`.
///
/// # Arguments
///
/// * `dataset` - The original dataset to select from.
/// * `seed` - A fixed seed for the random number generator.
///
/// # Returns
///
/// A new `ShuffledDataset`.
#[deprecated(since = "0.19.0", note = "Use `new(dataset, seed)` instead`")]
pub fn with_seed(dataset: D, seed: u64) -> Self {
Self::new(dataset, seed)
}
}
impl<D, I> Dataset<I> for ShuffledDataset<D, I>
where
D: Dataset<I>,
I: Clone + Send + Sync,
{
fn get(&self, index: usize) -> Option<I> {
self.wrapped.get(index)
}
fn len(&self) -> usize {
self.wrapped.len()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::FakeDataset;
use crate::transform::selection::shuffled_indices;
use rand::SeedableRng;
use rand::prelude::StdRng;
#[test]
fn test_shuffled_dataset() {
let dataset = FakeDataset::<String>::new(27);
let source_items = dataset.iter().collect::<Vec<_>>();
let seed = 42;
#[allow(deprecated)]
let shuffled = ShuffledDataset::with_seed(dataset, seed);
let mut rng = StdRng::seed_from_u64(seed);
let indices = shuffled_indices(source_items.len(), &mut rng);
assert_eq!(shuffled.len(), source_items.len());
let expected_items: Vec<_> = indices
.iter()
.map(|&i| source_items[i].to_string())
.collect();
assert_eq!(&shuffled.iter().collect::<Vec<_>>(), &expected_items);
}
}

View File

@@ -0,0 +1,290 @@
use std::{cmp::max, marker::PhantomData, num::NonZeroUsize};
use crate::Dataset;
/// Functionality to create a window.
pub trait Window<I> {
/// Creates a window of a collection.
///
/// # Returns
///
/// A `Vec<I>` representing the window.
fn window(&self, current: usize, size: NonZeroUsize) -> Option<Vec<I>>;
}
impl<I, T: Dataset<I> + ?Sized> Window<I> for T {
fn window(&self, current: usize, size: NonZeroUsize) -> Option<Vec<I>> {
(current..current + size.get())
.map(|x| self.get(x))
.collect()
}
}
/// Functionality to create a `WindowsIterator`.
pub trait Windows<I> {
/// Creates and returns an iterator over all the windows of length `size`.
fn windows(&self, size: usize) -> WindowsIterator<'_, I>;
}
impl<I, T: Dataset<I>> Windows<I> for T {
/// Is empty if the `Dataset` is shorter than `size`.
///
/// # Panics
///
/// Panics if `size` is 0.
///
/// # Examples
///
/// ```
/// use crate::burn_dataset::{
/// transform::{Windows, WindowsDataset},
/// Dataset, InMemDataset,
/// };
///
/// let items = [1, 2, 3, 4].to_vec();
/// let dataset = InMemDataset::new(items.clone());
///
/// for window in dataset.windows(2) {
/// // do sth with window
/// }
/// ```
fn windows(&self, size: usize) -> WindowsIterator<'_, I> {
let size = NonZeroUsize::new(size).expect("window size must be non-zero");
WindowsIterator::new(self, size)
}
}
/// Overlapping windows iterator.
pub struct WindowsIterator<'a, I> {
/// The size of the windows.
pub size: NonZeroUsize,
current: usize,
dataset: &'a dyn Dataset<I>,
}
impl<'a, I> WindowsIterator<'a, I> {
/// Creates a new `WindowsIterator` instance. The windows overlap.
/// Is empty if the input `Dataset` is shorter than `size`.
///
/// # Parameters
///
/// - `dataset`: The dataset over which windows will be created.
/// - `size`: The size of the windows.
pub fn new(dataset: &'a dyn Dataset<I>, size: NonZeroUsize) -> Self {
WindowsIterator {
current: 0,
dataset,
size,
}
}
}
impl<I> Iterator for WindowsIterator<'_, I> {
type Item = Vec<I>;
fn next(&mut self) -> Option<Vec<I>> {
self.current += 1;
self.dataset.window(self.current - 1, self.size)
}
}
impl<I> Clone for WindowsIterator<'_, I> {
fn clone(&self) -> Self {
WindowsIterator {
size: self.size,
dataset: self.dataset,
current: self.current,
}
}
}
/// Dataset designed to work with overlapping windows of data.
pub struct WindowsDataset<D, I> {
/// The size of the windows.
pub size: NonZeroUsize,
dataset: D,
input: PhantomData<I>,
}
impl<D, I> WindowsDataset<D, I>
where
D: Dataset<I>,
{
/// Creates a new `WindowsDataset` instance. The windows overlap.
/// Is empty if the input `Dataset` is shorter than `size`.
///
/// # Parameters
///
/// - `dataset`: The dataset over which windows will be created.
/// - `size`: The size of the windows.
pub fn new(dataset: D, size: usize) -> Self
where
D:,
{
let size = NonZeroUsize::new(size).expect("window size must be non-zero");
WindowsDataset::<D, I> {
size,
dataset,
input: PhantomData,
}
}
}
impl<D, I> Dataset<Vec<I>> for WindowsDataset<D, I>
where
D: Dataset<I>,
I: Send + Sync,
{
/// Retrieves a window of items from the dataset.
///
/// # Parameters
///
/// - `index`: The index of the window.
///
/// # Returns
///
/// A vector representing the window.
fn get(&self, index: usize) -> Option<Vec<I>> {
self.dataset.window(index, self.size)
}
/// Retrieves the number of windows in the dataset.
///
/// # Returns
///
/// A size representing the number of windows.
fn len(&self) -> usize {
let len = self.dataset.len() as isize - self.size.get() as isize + 1;
max(len, 0) as usize
}
}
#[cfg(test)]
mod tests {
use rstest::rstest;
use crate::{
Dataset, InMemDataset,
transform::{Windows, WindowsDataset},
};
#[rstest]
pub fn windows_should_be_equal_to_vec_windows() {
let items = [1, 2, 3, 4, 5].to_vec();
let dataset = InMemDataset::new(items.clone());
let expected = items
.windows(3)
.map(|x| x.to_vec())
.collect::<Vec<Vec<i32>>>();
let result = dataset.windows(3).collect::<Vec<Vec<i32>>>();
assert_eq!(result, expected);
}
#[rstest]
pub fn windows_dataset_should_be_equal_to_vec_windows() {
let items = [1, 2, 3, 4, 5].to_vec();
let dataset = InMemDataset::new(items.clone());
let expected = items
.windows(3)
.map(|x| x.to_vec())
.collect::<Vec<Vec<i32>>>();
let result = WindowsDataset::new(dataset, 3)
.iter()
.collect::<Vec<Vec<i32>>>();
assert_eq!(result, expected);
}
#[rstest]
pub fn cloned_iterator_should_be_equal() {
let items = [1, 2, 3, 4, 5].to_vec();
let dataset = InMemDataset::new(items.clone());
let original = dataset.windows(4);
let cloned = original.clone();
assert!(std::ptr::eq(cloned.dataset, original.dataset));
assert_eq!(cloned.size, original.size);
assert_eq!(cloned.current, original.current);
}
#[rstest]
pub fn cloned_iterator_should_be_unaffected() {
let items = [1, 2, 3, 4, 5].to_vec();
let dataset = InMemDataset::new(items.clone());
let mut original = dataset.windows(4);
let cloned = original.clone();
original.current = 2;
assert_ne!(cloned.current, original.current);
}
#[rstest]
#[should_panic(expected = "window size must be non-zero")]
pub fn windows_should_panic() {
let items = [1, 2].to_vec();
let dataset = InMemDataset::new(items.clone());
dataset.windows(0);
}
#[rstest]
#[should_panic(expected = "window size must be non-zero")]
pub fn new_window_dataset_should_panic() {
let items = [1, 2].to_vec();
let dataset = InMemDataset::new(items.clone());
WindowsDataset::new(dataset, 0);
}
#[rstest]
pub fn window_dataset_len_should_be_equal() {
let dataset = InMemDataset::new([1, 2, 3, 4].to_vec());
let result = WindowsDataset::new(dataset, 2).len();
assert_eq!(result, 3);
}
#[rstest]
pub fn window_iterator_should_be_empty() {
let dataset = InMemDataset::new([1, 2].to_vec());
let mut peekable = dataset.windows(4).peekable();
let result = peekable.peek();
assert_eq!(result, None);
}
#[rstest]
pub fn window_dataset_len_should_be_zero() {
let dataset = InMemDataset::new([1, 2].to_vec());
let result = WindowsDataset::new(dataset, 4).len();
assert_eq!(result, 0);
}
#[rstest]
pub fn window_dataset_get_should_be_equal() {
let dataset = InMemDataset::new([1, 2, 3, 4].to_vec());
let expected = Some([1, 2, 3].to_vec());
let result = WindowsDataset::new(dataset, 3).get(0);
assert_eq!(result, expected);
}
#[rstest]
pub fn window_dataset_get_should_be_none() {
let dataset = InMemDataset::new([1, 2].to_vec());
let result = WindowsDataset::new(dataset, 4).get(0);
assert_eq!(result, None);
}
}

View File

@@ -0,0 +1,241 @@
//! CIFAR Dataset Module
//!
//! This module provides functionality for loading the CIFAR-10 and CIFAR-100 image classification datasets.
//! CIFAR (Canadian Institute For Advanced Research) datasets are widely used benchmarks in computer vision,
//! consisting of 32×32 pixel color images split into training (50,000 images) and test (10,000 images) sets.
//!
//! ## Dataset Variants
//! - **CIFAR-10**: Contains 10 distinct classes (e.g., airplane, automobile, bird, cat)
//! - CIFAR-10 mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L44).
//! - Licensed under the [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE).
//! - **CIFAR-100**: Contains 100 fine-grained classes (e.g., beaver, dolphin, oak tree)
//! - CIFAR-100 mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L75).
//! - Licensed under the [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE).
//!
//! ## Usage Example
//! ```rust
//! use burn_dataset::vision::CifarDataset;
//! use burn_dataset::vision::CifarType;
//!
//! // Create a CIFAR-10 dataset accessor
//! let dataset = CifarDataset::new(CifarType::Cifar10);
//!
//! // Access training and test sets
//! let train_dataset = dataset.train();
//! let test_dataset = dataset.test();
//! ```
//! ```rust
//! use burn_dataset::vision::CifarDataset;
//! use burn_dataset::vision::CifarType;
//!
//! // Create a CIFAR-100 dataset accessor
//! let dataset = CifarDataset::new(CifarType::Cifar100);
//!
//! // Access training and test sets
//! let train_dataset = dataset.train();
//! let test_dataset = dataset.test();
//! ```
use std::{path::PathBuf, sync::Mutex};
use flate2::read::GzDecoder;
use tar::Archive;
use crate::network::downloader;
use crate::vision::ImageFolderDataset;
/// CIFAR-10 mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L44).
/// Licensed under the [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE).
const CIFAR10_URL: &str = "https://s3.amazonaws.com/fast-ai-sample/cifar10.tgz";
/// CIFAR-100 mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L75).
/// Licensed under the [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE).
const CIFAR100_URL: &str = "https://s3.amazonaws.com/fast-ai-imageclas/cifar100.tgz";
/// Enum representing the types of CIFAR datasets available.
///
/// CIFAR (Canadian Institute For Advanced Research) datasets are widely used benchmarks for image classification.
/// This enum provides support for the two main CIFAR datasets.
#[derive(Debug, Clone, Copy)]
#[allow(dead_code)]
pub enum CifarType {
/// CIFAR-10 dataset containing 10 classes with 60,000 images in total.
Cifar10,
/// CIFAR-100 dataset containing 100 classes with 60,000 images in total.
Cifar100,
}
/// CIFAR dataset accessor.
///
/// This struct provides convenient access to the CIFAR-10 and CIFAR-100 image classification datasets.
/// It automatically downloads (if not already downloaded), extracts, and loads the datasets.
///
/// All images in CIFAR datasets are 32×32 pixel color images, with 50,000 images in the training set
/// and 10,000 images in the test set.
///
/// ## Differences between datasets
/// - **CIFAR-10**: Contains 10 mutually exclusive classes such as airplane, automobile, bird, cat, etc.
/// - **CIFAR-100**: Contains 100 fine-grained classes such as beaver, dolphin, etc.
pub struct CifarDataset {
cifar_dir: PathBuf,
}
impl CifarDataset {
/// Creates a new CIFAR dataset accessor.
///
/// # Arguments
/// * `cifar_type` - Specifies whether to use CIFAR-10 or CIFAR-100 dataset
pub fn new(cifar_type: CifarType) -> Self {
Self {
cifar_dir: download(&cifar_type),
}
}
/// Gets the training dataset.
///
/// # Returns
/// An `ImageFolderDataset` instance containing 50,000 training images
pub fn train(&self) -> ImageFolderDataset {
ImageFolderDataset::new_classification(self.cifar_dir.join("train")).unwrap()
}
/// Gets the test dataset.
///
/// # Returns
/// An `ImageFolderDataset` instance containing 10,000 test images
pub fn test(&self) -> ImageFolderDataset {
ImageFolderDataset::new_classification(self.cifar_dir.join("test")).unwrap()
}
}
/// CIFAR dataset download lock.
///
/// This lock ensures that only one thread downloads the CIFAR dataset at a time.
static DOWNLOAD_LOCK: Mutex<()> = Mutex::new(());
fn download(cifar_type: &CifarType) -> PathBuf {
// Acquire the lock. This will block if another thread already holds the lock.
let _lock = DOWNLOAD_LOCK.lock().unwrap();
// Dataset files are stored in the burn-dataset cache directory
let cache_dir = dirs::cache_dir()
.expect("Could not get cache directory")
.join("burn-dataset");
// Cifar store directory
let cifar_dir = match cifar_type {
CifarType::Cifar10 => cache_dir.join("cifar10"),
CifarType::Cifar100 => cache_dir.join("cifar100"),
};
// Cifar dataset url
let url = match cifar_type {
CifarType::Cifar10 => CIFAR10_URL,
CifarType::Cifar100 => CIFAR100_URL,
};
// Cifar dataset archive filename
let filename = match cifar_type {
CifarType::Cifar10 => "cifar10.tgz",
CifarType::Cifar100 => "cifar100.tgz",
};
// Check for already downloaded content
if !cifar_dir.exists() {
// Download gzip file
let bytes = downloader::download_file_as_bytes(url, filename);
// Decode gzip file content and unpack archive
let gz_buffer = GzDecoder::new(&bytes[..]);
let mut archive = Archive::new(gz_buffer);
archive.unpack(cache_dir).unwrap();
}
cifar_dir
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{Dataset, vision::Annotation};
/// CIFAR dataset length
const TRAINDATASET_LEN: usize = 50000;
const TESTDATASET_LEN: usize = 10000;
/// CIFAR-10 label range
const CIFAR10_LABEL_MIN: usize = 0;
const CIFAR10_LABEL_MAX: usize = 9;
/// CIFAR-100 label range
const CIFAR100_LABEL_MIN: usize = 0;
const CIFAR100_LABEL_MAX: usize = 99;
#[test]
fn test_cifar10_download() {
let cifar_dir = download(&CifarType::Cifar10);
assert!(cifar_dir.exists());
}
#[test]
fn test_cifar100_download() {
let cifar_dir = download(&CifarType::Cifar100);
assert!(cifar_dir.exists());
}
#[test]
fn test_cifar10_len() {
let dataset = CifarDataset::new(CifarType::Cifar10);
let train_dataset = dataset.train();
let test_dataset = dataset.test();
assert_eq!(train_dataset.len(), TRAINDATASET_LEN);
assert_eq!(test_dataset.len(), TESTDATASET_LEN);
}
#[test]
fn test_cifar100_len() {
let dataset = CifarDataset::new(CifarType::Cifar100);
let train_dataset = dataset.train();
let test_dataset = dataset.test();
assert_eq!(train_dataset.len(), TRAINDATASET_LEN);
assert_eq!(test_dataset.len(), TESTDATASET_LEN);
}
#[test]
fn test_cifar10_label_range() {
let dataset = CifarDataset::new(CifarType::Cifar10);
let test_dataset = dataset.test();
let (min, max) = get_label_range(&test_dataset);
assert_eq!(min, CIFAR10_LABEL_MIN);
assert_eq!(max, CIFAR10_LABEL_MAX);
}
#[test]
fn test_cifar100_label_range() {
let dataset = CifarDataset::new(CifarType::Cifar100);
let test_dataset = dataset.test();
let (min, max) = get_label_range(&test_dataset);
assert_eq!(min, CIFAR100_LABEL_MIN);
assert_eq!(max, CIFAR100_LABEL_MAX);
}
fn get_label_range(dataset: &ImageFolderDataset) -> (usize, usize) {
let labels: Vec<_> = dataset.iter().map(|item| item.annotation).collect();
let mut min = 128;
let mut max = 0;
for label in labels {
let index = match label {
Annotation::Label(index) => index,
_ => 0,
};
if index < min {
min = index;
}
if index > max {
max = index;
}
}
(min, max)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,221 @@
use std::fs::{File, create_dir_all};
use std::io::{Read, Seek, SeekFrom};
use std::path::{Path, PathBuf};
use flate2::read::GzDecoder;
use serde::{Deserialize, Serialize};
use crate::{
Dataset, InMemDataset,
transform::{Mapper, MapperDataset},
};
use crate::network::downloader::download_file_as_bytes;
// CVDF mirror of http://yann.lecun.com/exdb/mnist/
const URL: &str = "https://storage.googleapis.com/cvdf-datasets/mnist/";
const TRAIN_IMAGES: &str = "train-images-idx3-ubyte";
const TRAIN_LABELS: &str = "train-labels-idx1-ubyte";
const TEST_IMAGES: &str = "t10k-images-idx3-ubyte";
const TEST_LABELS: &str = "t10k-labels-idx1-ubyte";
const WIDTH: usize = 28;
const HEIGHT: usize = 28;
/// MNIST item.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct MnistItem {
/// Image as a 2D array of floats.
pub image: [[f32; WIDTH]; HEIGHT],
/// Label of the image.
pub label: u8,
}
#[derive(Deserialize, Debug, Clone)]
struct MnistItemRaw {
pub image_bytes: Vec<u8>,
pub label: u8,
}
struct BytesToImage;
impl Mapper<MnistItemRaw, MnistItem> for BytesToImage {
/// Convert a raw MNIST item (image bytes) to a MNIST item (2D array image).
fn map(&self, item: &MnistItemRaw) -> MnistItem {
// Ensure the image dimensions are correct.
debug_assert_eq!(item.image_bytes.len(), WIDTH * HEIGHT);
// Convert the image to a 2D array of floats.
let mut image_array = [[0f32; WIDTH]; HEIGHT];
for (i, pixel) in item.image_bytes.iter().enumerate() {
let x = i % WIDTH;
let y = i / HEIGHT;
image_array[y][x] = *pixel as f32;
}
MnistItem {
image: image_array,
label: item.label,
}
}
}
type MappedDataset = MapperDataset<InMemDataset<MnistItemRaw>, BytesToImage, MnistItemRaw>;
/// The MNIST dataset consists of 70,000 28x28 black-and-white images in 10 classes (one for each digits), with 7,000
/// images per class. There are 60,000 training images and 10,000 test images.
///
/// The data is downloaded from the web from the [CVDF mirror](https://github.com/cvdfoundation/mnist).
pub struct MnistDataset {
dataset: MappedDataset,
}
impl Dataset<MnistItem> for MnistDataset {
fn get(&self, index: usize) -> Option<MnistItem> {
self.dataset.get(index)
}
fn len(&self) -> usize {
self.dataset.len()
}
}
impl MnistDataset {
/// Creates a new train dataset.
pub fn train() -> Self {
Self::new("train")
}
/// Creates a new test dataset.
pub fn test() -> Self {
Self::new("test")
}
fn new(split: &str) -> Self {
// Download dataset
let root = MnistDataset::download(split);
// MNIST is tiny so we can load it in-memory
// Train images (u8): 28 * 28 * 60000 = 47.04Mb
// Test images (u8): 28 * 28 * 10000 = 7.84Mb
let images = MnistDataset::read_images(&root, split);
let labels = MnistDataset::read_labels(&root, split);
// Collect as vector of MnistItemRaw
let items: Vec<_> = images
.into_iter()
.zip(labels)
.map(|(image_bytes, label)| MnistItemRaw { image_bytes, label })
.collect();
let dataset = InMemDataset::new(items);
let dataset = MapperDataset::new(dataset, BytesToImage);
Self { dataset }
}
/// Download the MNIST dataset files from the web.
/// Panics if the download cannot be completed or the content of the file cannot be written to disk.
fn download(split: &str) -> PathBuf {
// Dataset files are stored in the burn-dataset cache directory
let cache_dir = dirs::cache_dir()
.expect("Could not get cache directory")
.join("burn-dataset");
let split_dir = cache_dir.join("mnist").join(split);
if !split_dir.exists() {
create_dir_all(&split_dir).expect("Failed to create base directory");
}
// Download split files
match split {
"train" => {
MnistDataset::download_file(TRAIN_IMAGES, &split_dir);
MnistDataset::download_file(TRAIN_LABELS, &split_dir);
}
"test" => {
MnistDataset::download_file(TEST_IMAGES, &split_dir);
MnistDataset::download_file(TEST_LABELS, &split_dir);
}
_ => panic!("Invalid split specified {split}"),
};
split_dir
}
/// Download a file from the MNIST dataset URL to the destination directory.
/// File download progress is reported with the help of a [progress bar](indicatif).
fn download_file<P: AsRef<Path>>(name: &str, dest_dir: &P) -> PathBuf {
// Output file name
let file_name = dest_dir.as_ref().join(name);
if !file_name.exists() {
// Download gzip file
let bytes = download_file_as_bytes(&format!("{URL}{name}.gz"), name);
// Create file to write the downloaded content to
let mut output_file = File::create(&file_name).unwrap();
// Decode gzip file content and write to disk
let mut gz_buffer = GzDecoder::new(&bytes[..]);
std::io::copy(&mut gz_buffer, &mut output_file).unwrap();
}
file_name
}
/// Read images at the provided path for the specified split.
/// Each image is a vector of bytes.
fn read_images<P: AsRef<Path>>(root: &P, split: &str) -> Vec<Vec<u8>> {
let file_name = if split == "train" {
TRAIN_IMAGES
} else {
TEST_IMAGES
};
let file_name = root.as_ref().join(file_name);
// Read number of images from 16-byte header metadata
let mut f = File::open(file_name).unwrap();
let mut buf = [0u8; 4];
let _ = f.seek(SeekFrom::Start(4)).unwrap();
f.read_exact(&mut buf)
.expect("Should be able to read image file header");
let size = u32::from_be_bytes(buf);
let mut buf_images: Vec<u8> = vec![0u8; WIDTH * HEIGHT * (size as usize)];
let _ = f.seek(SeekFrom::Start(16)).unwrap();
f.read_exact(&mut buf_images)
.expect("Should be able to read image file header");
buf_images
.chunks(WIDTH * HEIGHT)
.map(|chunk| chunk.to_vec())
.collect()
}
/// Read labels at the provided path for the specified split.
fn read_labels<P: AsRef<Path>>(root: &P, split: &str) -> Vec<u8> {
let file_name = if split == "train" {
TRAIN_LABELS
} else {
TEST_LABELS
};
let file_name = root.as_ref().join(file_name);
// Read number of labels from 8-byte header metadata
let mut f = File::open(file_name).unwrap();
let mut buf = [0u8; 4];
let _ = f.seek(SeekFrom::Start(4)).unwrap();
f.read_exact(&mut buf)
.expect("Should be able to read label file header");
let size = u32::from_be_bytes(buf);
let mut buf_labels: Vec<u8> = vec![0u8; size as usize];
let _ = f.seek(SeekFrom::Start(8)).unwrap();
f.read_exact(&mut buf_labels)
.expect("Should be able to read labels from file");
buf_labels
}
}

View File

@@ -0,0 +1,9 @@
#[cfg(feature = "builtin-sources")]
mod cifar;
mod image_folder;
mod mnist;
#[cfg(feature = "builtin-sources")]
pub use cifar::*;
pub use image_folder::*;
pub use mnist::*;

View File

@@ -0,0 +1,2 @@
HI1 1 true 1.0
HI2 1 false 1.0
1 HI1 1 true 1.0
2 HI2 1 false 1.0

View File

@@ -0,0 +1,3 @@
column_str,column_int,column_bool,column_float
HI1,1,true,1.0
HI2,1,false,1.0
1 column_str column_int column_bool column_float
2 HI1 1 true 1.0
3 HI2 1 false 1.0

View File

@@ -0,0 +1,2 @@
{"column_str":"HI1","column_bytes":[1,2,3,3],"column_int":1,"column_bool":true,"column_float":1.0}
{"column_str":"HI2","column_bytes":[1,2,3,3],"column_int":1,"column_bool":false,"column_float":1.0}

View File

@@ -0,0 +1,132 @@
{
"images": [
{
"width": 32,
"height": 32,
"id": 0,
"file_name": "two_dots_and_triangle.jpg"
},
{
"width": 32,
"height": 32,
"id": 1,
"file_name": "dot_triangle.jpg"
},
{
"width": 32,
"height": 32,
"id": 2,
"file_name": "one_dot.jpg"
}
],
"categories": [
{
"id": 0,
"name": "dot"
},
{
"id": 1,
"name": "triangle"
}
],
"annotations": [
{
"id": 0,
"image_id": 0,
"category_id": 0,
"segmentation": [],
"bbox": [
3.1251719394773056,
18.0907840440165,
10.96011004126548,
10.740027510316379
],
"ignore": 0,
"iscrowd": 0,
"area": 117.71188335928603
},
{
"id": 1,
"image_id": 0,
"category_id": 0,
"segmentation": [],
"bbox": [
3.2572214580467658,
3.0371389270976605,
10.563961485557085,
10.828060522696012
],
"ignore": 0,
"iscrowd": 0,
"area": 114.38721432504178
},
{
"id": 2,
"image_id": 0,
"category_id": 1,
"segmentation": [],
"bbox": [
15.097661623108666,
3.3892709766162312,
12.632737276478679,
11.18019257221458
],
"ignore": 0,
"iscrowd": 0,
"area": 141.23643546522516
},
{
"id": 3,
"image_id": 1,
"category_id": 0,
"segmentation": [],
"bbox": [
3.125171939477304,
17.914718019257222,
10.82806052269601,
11.004126547455297
],
"ignore": 0,
"iscrowd": 0,
"area": 119.15334825525184
},
{
"id": 4,
"image_id": 1,
"category_id": 1,
"segmentation": [],
"bbox": [
15.27372764786794,
3.301237964236589,
12.192572214580478,
11.708390646492433
],
"ignore": 0,
"iscrowd": 0,
"area": 142.7553984738776
},
{
"id": 5,
"image_id": 2,
"category_id": 0,
"segmentation": [],
"bbox": [
10.07977991746905,
9.59559834938102,
10.960110041265464,
11.356258596973863
],
"ignore": 0,
"iscrowd": 0,
"area": 124.46584387990049
}
],
"info": {
"year": 2024,
"version": "1.0",
"description": "",
"contributor": "",
"url": "",
"date_created": "2024-12-11 22:16:31.823494"
}
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 727 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 634 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 120 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 KiB

View File

@@ -0,0 +1,8 @@
1 2 1 2 1 2 1 2
2 1 2 1 2 1 2 1
1 2 1 2 1 2 1 2
2 1 2 1 2 1 2 1
1 2 1 2 1 2 1 2
2 1 2 1 2 1 2 1
1 2 1 2 1 2 1 2
2 1 2 1 2 1 2 1

View File

@@ -0,0 +1,8 @@
1 2 1 1 1 2 1 1
1 2 1 1 1 1 2 1
2 2 2 1 2 1 2 2
2 2 2 2 2 2 1 1
2 2 2 1 2 1 1 1
1 1 2 2 2 2 2 1
2 2 1 2 1 2 1 2
2 1 1 1 1 1 1 1

View File

@@ -0,0 +1,8 @@
3 1 3 3 1 1 3 2
3 3 3 3 1 3 2 1
2 2 2 2 1 1 2 2
1 1 1 3 3 3 2 3
2 2 3 2 3 3 1 3
1 3 3 1 1 3 2 1
2 2 2 1 2 1 2 3
3 1 3 3 2 1 2 2

Binary file not shown.

After

Width:  |  Height:  |  Size: 165 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 133 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 204 B

View File

@@ -0,0 +1 @@
This is a negative text sample for testing the text folder dataset functionality.

View File

@@ -0,0 +1 @@
另一个负面文本样本,用以确保数据集能够处理同一类别中的多个文件。

View File

@@ -0,0 +1 @@
This is a positive text sample for testing the text folder dataset functionality.

View File

@@ -0,0 +1 @@
另一个正面文本样本,以确保数据集能够处理同一类别中的多个文件。