feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
@@ -0,0 +1,84 @@
|
||||
[package]
|
||||
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||
categories = ["science"]
|
||||
description = "Library with simple dataset APIs for creating ML data pipelines"
|
||||
documentation = "https://docs.rs/burn-dataset"
|
||||
edition.workspace = true
|
||||
keywords = ["deep-learning", "machine-learning", "data"]
|
||||
license.workspace = true
|
||||
name = "burn-dataset"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-dataset"
|
||||
version.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[features]
|
||||
default = ["sqlite-bundled"]
|
||||
doc = ["default"]
|
||||
tracing = [
|
||||
"burn-std/tracing",
|
||||
]
|
||||
|
||||
audio = ["hound"]
|
||||
builtin-sources = ["vision", "dep:tar", "nlp"]
|
||||
fake = ["dep:fake"]
|
||||
network = ["dep:burn-std"]
|
||||
sqlite = ["__sqlite-shared", "dep:rusqlite"]
|
||||
sqlite-bundled = ["__sqlite-shared", "rusqlite/bundled"]
|
||||
vision = ["dep:flate2", "dep:globwalk", "dep:image", "network"]
|
||||
nlp = ["dep:zip", "dep:encoding_rs"]
|
||||
# internal
|
||||
__sqlite-shared = [
|
||||
"dep:r2d2",
|
||||
"dep:r2d2_sqlite",
|
||||
"dep:serde_rusqlite",
|
||||
"dep:image",
|
||||
"dep:gix-tempfile",
|
||||
]
|
||||
dataframe = ["dep:polars", "dep:planus"]
|
||||
|
||||
[dependencies]
|
||||
burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", optional = true, features = [
|
||||
"network",
|
||||
] }
|
||||
csv = { workspace = true }
|
||||
derive-new = { workspace = true }
|
||||
dirs = { workspace = true }
|
||||
fake = { workspace = true, optional = true }
|
||||
flate2 = { workspace = true, optional = true }
|
||||
gix-tempfile = { workspace = true, optional = true }
|
||||
globwalk = { workspace = true, optional = true }
|
||||
hound = { workspace = true, optional = true }
|
||||
image = { workspace = true, optional = true }
|
||||
planus = { workspace = true, optional = true }
|
||||
encoding_rs = { workspace = true, optional = true }
|
||||
polars = { workspace = true, optional = true }
|
||||
r2d2 = { workspace = true, optional = true }
|
||||
r2d2_sqlite = { workspace = true, optional = true }
|
||||
rand = { workspace = true, features = ["std", "sys_rng"] }
|
||||
zip = { workspace = true, optional = true }
|
||||
rmp-serde = { workspace = true }
|
||||
rusqlite = { workspace = true, optional = true }
|
||||
sanitize-filename = { workspace = true }
|
||||
serde = { workspace = true, features = ["std", "derive"] }
|
||||
serde_json = { workspace = true, features = ["std"] }
|
||||
serde_rusqlite = { workspace = true, optional = true }
|
||||
strum = { workspace = true }
|
||||
tar = { workspace = true, optional = true }
|
||||
tempfile = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
fake = { workspace = true }
|
||||
rayon = { workspace = true }
|
||||
rstest = { workspace = true }
|
||||
|
||||
[package.metadata.cargo-udeps.ignore]
|
||||
normal = ["strum", "strum_macros"]
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["doc"]
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
@@ -0,0 +1 @@
|
||||
../../LICENSE-APACHE
|
||||
@@ -0,0 +1 @@
|
||||
../../LICENSE-MIT
|
||||
@@ -0,0 +1,17 @@
|
||||
# Burn Dataset
|
||||
|
||||
> [Burn](https://github.com/tracel-ai/burn) dataset library
|
||||
|
||||
[crates.io](https://crates.io/crates/burn-dataset)
|
||||
[README](https://github.com/tracel-ai/burn-dataset/blob/master/README.md)
|
||||
|
||||
The Burn Dataset library is designed to streamline your machine learning (ML) data pipeline creation
|
||||
process. It offers a variety of dataset implementations, transformation functions, and data sources.
|
||||
|
||||
## Feature Flags
|
||||
|
||||
- `audio` - enables audio dataset (SpeechCommandsDataset). Run the following example to try it out:
|
||||
|
||||
```shell
|
||||
cargo run --example speech_commands --features audio
|
||||
```
|
||||
@@ -0,0 +1,22 @@
|
||||
use burn_dataset::HuggingfaceDatasetLoader;
|
||||
use burn_dataset::SqliteDataset;
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
struct MnistItemRaw {
|
||||
pub _image_bytes: Vec<u8>,
|
||||
pub _label: usize,
|
||||
}
|
||||
fn main() {
|
||||
// There are some datasets, such as https://huggingface.co/datasets/ylecun/mnist/tree/main that contains a script,
|
||||
// In this cases you must enable trusting remote code execution if you want to use it.
|
||||
let _train_ds: SqliteDataset<MnistItemRaw> = HuggingfaceDatasetLoader::new("mnist")
|
||||
.with_trust_remote_code(true)
|
||||
.dataset("train")
|
||||
.unwrap();
|
||||
|
||||
// However not all dataset requires it https://huggingface.co/datasets/Anthropic/hh-rlhf/tree/main
|
||||
let _train_ds: SqliteDataset<MnistItemRaw> = HuggingfaceDatasetLoader::new("Anthropic/hh-rlhf")
|
||||
.dataset("train")
|
||||
.unwrap();
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
#[cfg(feature = "audio")]
|
||||
use burn_dataset::{Dataset, audio::SpeechCommandsDataset};
|
||||
|
||||
/// Loads the test split of the Speech Commands dataset, prints one item and
/// checks a few expected properties of the split and of that item.
#[cfg(feature = "audio")]
fn speech_command() {
    let position: usize = 4835;
    let dataset = SpeechCommandsDataset::test();
    let item = dataset.get(position).unwrap();
    let sample_count = item.audio_samples.len();

    println!("Item: {item:?}");
    println!("Item Length: {sample_count:?}");
    println!("Label: {}", item.label);

    assert_eq!(dataset.len(), 4890);
    assert_eq!(item.label.to_string(), "Yes");
    assert_eq!(item.sample_rate, 16000);
    assert_eq!(sample_count, 16000);
}
|
||||
|
||||
fn main() {
    // Without the `audio` feature the call below is compiled out and nothing runs.
    #[cfg(feature = "audio")]
    speech_command();
}
|
||||
@@ -0,0 +1,3 @@
|
||||
mod speech_commands;
|
||||
|
||||
pub use speech_commands::*;
|
||||
@@ -0,0 +1,208 @@
|
||||
use crate::{
|
||||
Dataset, HuggingfaceDatasetLoader, SqliteDataset,
|
||||
transform::{Mapper, MapperDataset},
|
||||
};
|
||||
|
||||
use hound::WavReader;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use strum::{Display, EnumCount, FromRepr};
|
||||
|
||||
type MappedDataset = MapperDataset<SqliteDataset<SpeechItemRaw>, ConvertSamples, SpeechItemRaw>;
|
||||
|
||||
/// Enum representing speech command classes in the Speech Commands dataset.
|
||||
/// Class names are based on the Speech Commands dataset from Huggingface.
|
||||
/// See [speech_commands](https://huggingface.co/datasets/speech_commands)
|
||||
/// for more information.
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Debug, Display, Clone, Copy, FromRepr, Serialize, Deserialize, EnumCount)]
|
||||
pub enum SpeechCommandClass {
|
||||
// Target command words
|
||||
Yes = 0,
|
||||
No = 1,
|
||||
Up = 2,
|
||||
Down = 3,
|
||||
Left = 4,
|
||||
Right = 5,
|
||||
On = 6,
|
||||
Off = 7,
|
||||
Stop = 8,
|
||||
Go = 9,
|
||||
Zero = 10,
|
||||
One = 11,
|
||||
Two = 12,
|
||||
Three = 13,
|
||||
Four = 14,
|
||||
Five = 15,
|
||||
Six = 16,
|
||||
Seven = 17,
|
||||
Eight = 18,
|
||||
Nine = 19,
|
||||
|
||||
// Non-target words that can be grouped into "Other"
|
||||
Bed = 20,
|
||||
Bird = 21,
|
||||
Cat = 22,
|
||||
Dog = 23,
|
||||
Happy = 24,
|
||||
House = 25,
|
||||
Marvin = 26,
|
||||
Sheila = 27,
|
||||
Tree = 28,
|
||||
Wow = 29,
|
||||
|
||||
// Commands from v2 dataset, that can be grouped into "Other"
|
||||
Backward = 30,
|
||||
Forward = 31,
|
||||
Follow = 32,
|
||||
Learn = 33,
|
||||
Visual = 34,
|
||||
|
||||
// Background noise
|
||||
Silence = 35,
|
||||
|
||||
// Other miscellaneous words
|
||||
Other = 36,
|
||||
}
|
||||
|
||||
/// Struct containing raw speech data returned from a database.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct SpeechItemRaw {
|
||||
/// Audio file bytes.
|
||||
pub audio_bytes: Vec<u8>,
|
||||
|
||||
/// Label index.
|
||||
pub label: usize,
|
||||
|
||||
/// Indicates if the label is unknown.
|
||||
pub is_unknown: bool,
|
||||
}
|
||||
|
||||
/// Speech item with audio samples and label.
|
||||
///
|
||||
/// The audio samples are floats in the range [-1.0, 1.0].
|
||||
/// The sample rate is in Hz.
|
||||
/// The label is the class index (see [SpeechCommandClass]).
|
||||
/// To convert to usize simply use `as usize`. To convert label to string use `.to_string()`.
|
||||
///
|
||||
/// The original label is also stored in the `label_original` field for debugging and remapping if needed.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct SpeechItem {
|
||||
/// Audio samples in the range [-1.0, 1.0].
|
||||
pub audio_samples: Vec<f32>,
|
||||
|
||||
/// The sample rate of the audio.
|
||||
pub sample_rate: usize,
|
||||
|
||||
/// The label of the audio.
|
||||
pub label: SpeechCommandClass,
|
||||
}
|
||||
|
||||
/// Speech Commands dataset from Huggingface v0.02.
|
||||
/// See [Speech Commands dataset](https://huggingface.co/datasets/speech_commands).
|
||||
///
|
||||
/// The data is downloaded from Huggingface and stored in a SQLite database (3.0 GB).
|
||||
/// The dataset contains 99,720 audio samples of 2,607 people saying 35 different words.
|
||||
///
|
||||
/// NOTE: The most samples are under 1 second long but there are some with pure background noise that
|
||||
/// need splitting into shorter segmants.
|
||||
///
|
||||
/// The labels are 20 target words, silence and other words.
|
||||
///
|
||||
/// The dataset is split into 3 parts:
|
||||
/// - train: 84,848 audio files
|
||||
/// - test: 4,890 audio files
|
||||
/// - validation: 9,982 audio files
|
||||
pub struct SpeechCommandsDataset {
|
||||
dataset: MappedDataset,
|
||||
}
|
||||
|
||||
impl SpeechCommandsDataset {
|
||||
/// Create a new dataset with the given split.
|
||||
pub fn new(split: &str) -> Self {
|
||||
let dataset: SqliteDataset<SpeechItemRaw> =
|
||||
HuggingfaceDatasetLoader::new("speech_commands")
|
||||
.with_subset("v0.02")
|
||||
.dataset(split)
|
||||
.unwrap();
|
||||
let dataset = MapperDataset::new(dataset, ConvertSamples);
|
||||
Self { dataset }
|
||||
}
|
||||
|
||||
/// Create a new dataset with the train split.
|
||||
pub fn train() -> Self {
|
||||
Self::new("train")
|
||||
}
|
||||
|
||||
/// Create a new dataset with the test split.
|
||||
pub fn test() -> Self {
|
||||
Self::new("test")
|
||||
}
|
||||
|
||||
/// Create a new dataset with the validation split.
|
||||
pub fn validation() -> Self {
|
||||
Self::new("validation")
|
||||
}
|
||||
|
||||
/// Returns the number of classes in the dataset
|
||||
pub fn num_classes() -> usize {
|
||||
SpeechCommandClass::COUNT
|
||||
}
|
||||
}
|
||||
|
||||
impl Dataset<SpeechItem> for SpeechCommandsDataset {
    /// Forwards the lookup to the wrapped mapped dataset.
    fn get(&self, index: usize) -> Option<SpeechItem> {
        let inner = &self.dataset;
        inner.get(index)
    }

    /// Forwards to the length of the wrapped mapped dataset.
    fn len(&self) -> usize {
        let inner = &self.dataset;
        inner.len()
    }
}
|
||||
|
||||
/// Mapper converting audio bytes into audio samples and the label to enum class.
|
||||
struct ConvertSamples;
|
||||
|
||||
impl ConvertSamples {
|
||||
/// Convert label to enum class.
|
||||
fn to_speechcommandclass(label: usize) -> SpeechCommandClass {
|
||||
SpeechCommandClass::from_repr(label).unwrap()
|
||||
}
|
||||
|
||||
/// Convert audio bytes into samples of floats [-1.0, 1.0].
|
||||
fn to_audiosamples(bytes: &Vec<u8>) -> (Vec<f32>, usize) {
|
||||
let reader = WavReader::new(bytes.as_slice()).unwrap();
|
||||
let spec = reader.spec();
|
||||
|
||||
// Maximum value of the audio samples (using bit shift to raise 2 to the power of bits per sample).
|
||||
let max_value = (1 << (spec.bits_per_sample - 1)) as f32;
|
||||
|
||||
// The sample rate of the audio.
|
||||
let sample_rate = spec.sample_rate as usize;
|
||||
|
||||
// Convert the audio samples to floats [-1.0, 1.0].
|
||||
let audio_samples: Vec<f32> = reader
|
||||
.into_samples::<i32>()
|
||||
.filter_map(Result::ok)
|
||||
.map(|sample| sample as f32 / max_value)
|
||||
.collect();
|
||||
|
||||
(audio_samples, sample_rate)
|
||||
}
|
||||
}
|
||||
|
||||
impl Mapper<SpeechItemRaw, SpeechItem> for ConvertSamples {
|
||||
/// Convert audio bytes into samples of floats [-1.0, 1.0]
|
||||
/// and the label to enum class with the target word, other and silence classes.
|
||||
fn map(&self, item: &SpeechItemRaw) -> SpeechItem {
|
||||
let (audio_samples, sample_rate) = Self::to_audiosamples(&item.audio_bytes);
|
||||
|
||||
// Convert the label to enum class, with the target words, other and silence classes.
|
||||
let label = Self::to_speechcommandclass(item.label);
|
||||
|
||||
SpeechItem {
|
||||
audio_samples,
|
||||
sample_rate,
|
||||
label,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::DatasetIterator;
|
||||
|
||||
/// The dataset trait defines a basic collection of items with a predefined size.
|
||||
pub trait Dataset<I>: Send + Sync {
|
||||
/// Gets the item at the given index.
|
||||
fn get(&self, index: usize) -> Option<I>;
|
||||
|
||||
/// Gets the number of items in the dataset.
|
||||
fn len(&self) -> usize;
|
||||
|
||||
/// Checks if the dataset is empty.
|
||||
fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
|
||||
/// Returns an iterator over the dataset.
|
||||
fn iter(&self) -> DatasetIterator<'_, I>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
DatasetIterator::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, I> Dataset<I> for Arc<D>
|
||||
where
|
||||
D: Dataset<I>,
|
||||
{
|
||||
fn get(&self, index: usize) -> Option<I> {
|
||||
self.as_ref().get(index)
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.as_ref().len()
|
||||
}
|
||||
}
|
||||
|
||||
/// A shared pointer to a dataset trait object forwards to the pointee.
impl<I> Dataset<I> for Arc<dyn Dataset<I>> {
    fn get(&self, index: usize) -> Option<I> {
        (**self).get(index)
    }

    fn len(&self) -> usize {
        (**self).len()
    }
}
|
||||
|
||||
impl<D, I> Dataset<I> for Box<D>
|
||||
where
|
||||
D: Dataset<I>,
|
||||
{
|
||||
fn get(&self, index: usize) -> Option<I> {
|
||||
self.as_ref().get(index)
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.as_ref().len()
|
||||
}
|
||||
}
|
||||
|
||||
/// A boxed dataset trait object forwards to the boxed value.
impl<I> Dataset<I> for Box<dyn Dataset<I>> {
    fn get(&self, index: usize) -> Option<I> {
        (**self).get(index)
    }

    fn len(&self) -> usize {
        (**self).len()
    }
}
|
||||
@@ -0,0 +1,465 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::Dataset;
|
||||
|
||||
use polars::frame::row::Row;
|
||||
use polars::prelude::*;
|
||||
use serde::de::DeserializeSeed;
|
||||
use serde::{
|
||||
Deserialize,
|
||||
de::{self, DeserializeOwned, Deserializer, SeqAccess, Visitor},
|
||||
forward_to_deserialize_any,
|
||||
};
|
||||
|
||||
/// Error type for DataframeDataset
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum DataframeDatasetError {
|
||||
/// Error occurred during deserialization or other operations
|
||||
#[error("{0}")]
|
||||
Other(String),
|
||||
}
|
||||
|
||||
impl de::Error for DataframeDatasetError {
    /// Stores the rendered message in the `Other` variant.
    fn custom<T>(msg: T) -> Self
    where
        T: std::fmt::Display,
    {
        let message = msg.to_string();
        Self::Other(message)
    }
}
|
||||
|
||||
/// Dataset implementation for Polars DataFrame
|
||||
///
|
||||
/// This struct provides a way to access data from a Polars DataFrame
|
||||
/// as if it were a Dataset of type I.
|
||||
pub struct DataframeDataset<I> {
|
||||
df: DataFrame,
|
||||
len: usize,
|
||||
column_name_mapping: Vec<usize>,
|
||||
phantom: PhantomData<I>,
|
||||
}
|
||||
|
||||
impl<I> DataframeDataset<I>
|
||||
where
|
||||
I: Clone + Send + Sync + DeserializeOwned,
|
||||
{
|
||||
/// Create a new DataframeDataset from a Polars DataFrame
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `df` - A Polars DataFrame
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A Result containing the new DataframeDataset or a DataframeDatasetError
|
||||
pub fn new(df: DataFrame) -> Result<Self, DataframeDatasetError> {
|
||||
let len = df.height();
|
||||
let field_names = extract_field_names::<I>();
|
||||
|
||||
let column_name_mapping = field_names
|
||||
.iter()
|
||||
.map(|name| {
|
||||
df.schema()
|
||||
.try_get_full(name)
|
||||
.expect("Corresponding column should exist in the DataFrame")
|
||||
.0
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
Ok(DataframeDataset {
|
||||
df,
|
||||
len,
|
||||
column_name_mapping,
|
||||
phantom: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> Dataset<I> for DataframeDataset<I>
|
||||
where
|
||||
I: Clone + Send + Sync + DeserializeOwned,
|
||||
{
|
||||
/// Get an item from the dataset at the specified index
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `index` - The index of the item to retrieve
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An Option containing the item if it exists, or None if it doesn't
|
||||
fn get(&self, index: usize) -> Option<I> {
|
||||
let row = self.df.get_row(index).ok()?;
|
||||
|
||||
let mut deserializer = RowDeserializer::new(&row, &self.column_name_mapping);
|
||||
I::deserialize(&mut deserializer).ok()
|
||||
}
|
||||
|
||||
/// Get the length of the dataset
|
||||
fn len(&self) -> usize {
|
||||
self.len
|
||||
}
|
||||
|
||||
/// Check if the dataset is empty
|
||||
fn is_empty(&self) -> bool {
|
||||
self.len == 0
|
||||
}
|
||||
}
|
||||
|
||||
/// A deserializer for Polars DataFrame rows
|
||||
struct RowDeserializer<'a> {
|
||||
row: &'a Row<'a>,
|
||||
column_name_mapping: &'a Vec<usize>,
|
||||
index: usize,
|
||||
}
|
||||
|
||||
impl<'a> RowDeserializer<'a> {
|
||||
/// Create a new RowDeserializer
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `row` - A reference to a Polars DataFrame row
|
||||
/// * `column_name_mapping` - A reference to a vector mapping field names to column indices
|
||||
fn new(row: &'a Row, column_name_mapping: &'a Vec<usize>) -> RowDeserializer<'a> {
|
||||
RowDeserializer {
|
||||
row,
|
||||
column_name_mapping,
|
||||
index: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de, 'a> Deserializer<'de> for &'a mut RowDeserializer<'a> {
|
||||
type Error = DataframeDatasetError;
|
||||
|
||||
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, DataframeDatasetError>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
let i = self.column_name_mapping[self.index];
|
||||
|
||||
let value = &self.row.0[i];
|
||||
match value {
|
||||
AnyValue::Null => visitor.visit_none(),
|
||||
AnyValue::Boolean(b) => visitor.visit_bool(*b),
|
||||
AnyValue::Int8(i) => visitor.visit_i8(*i),
|
||||
AnyValue::Int16(i) => visitor.visit_i16(*i),
|
||||
AnyValue::Int32(i) => visitor.visit_i32(*i),
|
||||
AnyValue::Int64(i) => visitor.visit_i64(*i),
|
||||
AnyValue::UInt8(i) => visitor.visit_u8(*i),
|
||||
AnyValue::UInt16(i) => visitor.visit_u16(*i),
|
||||
AnyValue::UInt32(i) => visitor.visit_u32(*i),
|
||||
AnyValue::UInt64(i) => visitor.visit_u64(*i),
|
||||
AnyValue::Float32(f) => visitor.visit_f32(*f),
|
||||
AnyValue::Float64(f) => visitor.visit_f64(*f),
|
||||
AnyValue::Date(i) => visitor.visit_i32(*i),
|
||||
AnyValue::String(s) => visitor.visit_string(s.to_string()),
|
||||
AnyValue::Binary(b) => {
|
||||
visitor.visit_seq(de::value::SeqDeserializer::new(b.iter().copied()))
|
||||
}
|
||||
AnyValue::Time(t) => visitor.visit_i64(*t),
|
||||
ty => Err(DataframeDatasetError::Other(
|
||||
format!("Unsupported type: {ty:?}").to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn deserialize_struct<V>(
|
||||
self,
|
||||
_name: &'static str,
|
||||
_fields: &'static [&'static str],
|
||||
visitor: V,
|
||||
) -> Result<V::Value, DataframeDatasetError>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
visitor.visit_seq(self)
|
||||
}
|
||||
|
||||
forward_to_deserialize_any! {
|
||||
bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string
|
||||
bytes byte_buf option unit unit_struct newtype_struct seq tuple
|
||||
tuple_struct map enum identifier ignored_any
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de, 'a> SeqAccess<'de> for RowDeserializer<'a> {
|
||||
type Error = DataframeDatasetError;
|
||||
|
||||
fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, DataframeDatasetError>
|
||||
where
|
||||
T: DeserializeSeed<'de>,
|
||||
{
|
||||
if self.index >= self.row.0.len() {
|
||||
return Ok(None);
|
||||
}
|
||||
let mut deserializer = RowDeserializer {
|
||||
row: self.row,
|
||||
column_name_mapping: self.column_name_mapping,
|
||||
index: self.index,
|
||||
};
|
||||
self.index += 1;
|
||||
seed.deserialize(&mut deserializer).map(Some)
|
||||
}
|
||||
}
|
||||
|
||||
/// Collects the field names that a type announces when it is asked to
/// deserialize itself as a struct.
struct FieldExtractor {
    /// Field names gathered so far.
    fields: Vec<&'static str>,
}
|
||||
|
||||
impl<'de> Deserializer<'de> for &mut FieldExtractor {
|
||||
type Error = de::value::Error;
|
||||
|
||||
fn deserialize_any<V>(self, _visitor: V) -> core::result::Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
Err(de::Error::custom("Field extractor"))
|
||||
}
|
||||
|
||||
fn deserialize_struct<V>(
|
||||
self,
|
||||
_name: &'static str,
|
||||
fields: &'static [&'static str],
|
||||
_visitor: V,
|
||||
) -> core::result::Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
self.fields.extend_from_slice(fields);
|
||||
Err(de::Error::custom("Field extractor"))
|
||||
}
|
||||
|
||||
forward_to_deserialize_any! {
|
||||
bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes
|
||||
byte_buf option unit unit_struct newtype_struct seq tuple
|
||||
tuple_struct map enum identifier ignored_any
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract field names from a type T that implements Deserialize
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A vector of field names as static string slices
|
||||
fn extract_field_names<'de, T>() -> Vec<&'static str>
|
||||
where
|
||||
T: Deserialize<'de>,
|
||||
{
|
||||
let mut extractor = FieldExtractor { fields: Vec::new() };
|
||||
let _ = T::deserialize(&mut extractor);
|
||||
extractor.fields
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use polars::prelude::*;
    use serde::Deserialize;

    use super::*;

    /// Item type with one field per column of the test DataFrame.
    #[derive(Clone, Debug, Deserialize, PartialEq)]
    struct TestData {
        int32: i32,
        bool: bool,
        float64: f64,
        string: String,
        int16: i16,
        uint32: u32,
        uint64: u64,
        float32: f32,
        int64: i64,
        int8: i8,
        binary: Vec<u8>,
    }

    /// Builds a three-row DataFrame with one column per `TestData` field.
    fn create_test_dataframe() -> DataFrame {
        let blobs: Vec<&[u8]> = vec![&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]];

        let columns = vec![
            Column::new("int32".into(), &[1i32, 2i32, 3i32]),
            Column::new("bool".into(), &[true, false, true]),
            Column::new("float64".into(), &[1.1f64, 2.2f64, 3.3f64]),
            Column::new("string".into(), &["Boo", "Boo2", "Boo3"]),
            Column::new("int16".into(), &[1i16, 2i16, 3i16]),
            Column::new("uint32".into(), &[1u32, 2u32, 3u32]),
            Column::new("uint64".into(), &[1u64, 2u64, 3u64]),
            Column::new("float32".into(), &[1.1f32, 2.2f32, 3.3f32]),
            Column::new("int64".into(), &[1i64, 2i64, 3i64]),
            Column::new("int8".into(), &[1i8, 2i8, 3i8]),
            Column::new("binary".into(), blobs),
        ];

        DataFrame::new_infer_height(columns).unwrap()
    }

    /// The rows of the test DataFrame as `TestData` values, in row order.
    fn expected_rows() -> Vec<TestData> {
        vec![
            TestData {
                int32: 1,
                bool: true,
                float64: 1.1,
                string: "Boo".to_string(),
                int16: 1,
                uint32: 1,
                uint64: 1,
                float32: 1.1,
                int64: 1,
                int8: 1,
                binary: vec![1, 2, 3],
            },
            TestData {
                int32: 2,
                bool: false,
                float64: 2.2,
                string: "Boo2".to_string(),
                int16: 2,
                uint32: 2,
                uint64: 2,
                float32: 2.2,
                int64: 2,
                int8: 2,
                binary: vec![4, 5, 6],
            },
            TestData {
                int32: 3,
                bool: true,
                float64: 3.3,
                string: "Boo3".to_string(),
                int16: 3,
                uint32: 3,
                uint64: 3,
                float32: 3.3,
                int64: 3,
                int8: 3,
                binary: vec![7, 8, 9],
            },
        ]
    }

    #[test]
    fn test_dataframe_dataset_creation() {
        let created = DataframeDataset::<TestData>::new(create_test_dataframe());
        assert!(created.is_ok());
    }

    #[test]
    fn test_dataframe_dataset_length() {
        let dataset = DataframeDataset::<TestData>::new(create_test_dataframe()).unwrap();
        assert_eq!(dataset.len(), 3);
        assert!(!dataset.is_empty());
    }

    #[test]
    fn test_dataframe_dataset_get() {
        let dataset = DataframeDataset::<TestData>::new(create_test_dataframe()).unwrap();
        let expected = expected_rows();

        for (index, expected_item) in expected.iter().enumerate() {
            let item = dataset.get(index).unwrap();
            assert_eq!(&item, expected_item);
        }
    }

    #[test]
    fn test_dataframe_dataset_out_of_bounds() {
        let dataset = DataframeDataset::<TestData>::new(create_test_dataframe()).unwrap();
        assert!(dataset.get(3).is_none());
    }

    #[test]
    fn test_dataframe_dataset() {
        let dataset: DataframeDataset<TestData> =
            DataframeDataset::new(create_test_dataframe()).unwrap();
        let expected = expected_rows();

        assert_eq!(dataset.len(), 3);
        assert!(!dataset.is_empty());

        let second = dataset.get(1).unwrap();
        assert_eq!(second, expected[1]);

        let third = dataset.get(2).unwrap();
        assert_eq!(third, expected[2]);
    }

    #[test]
    #[should_panic(
        expected = "Corresponding column should exist in the DataFrame: SchemaFieldNotFound(ErrString(\"non_existent\"))"
    )]
    fn test_non_existing_struct_fields() {
        #[derive(Clone, Debug, Deserialize, PartialEq)]
        struct PartialTestData {
            int32: i32,
            bool: bool,
            non_existent: String,
        }

        // Construction is what is expected to panic, because of the unknown column.
        let dataset = DataframeDataset::<PartialTestData>::new(create_test_dataframe());

        assert!(matches!(dataset, Err(DataframeDatasetError::Other(_))));
    }

    #[test]
    fn test_partial_table() {
        #[derive(Clone, Debug, Deserialize, PartialEq)]
        struct PartialTestData {
            int32: i32,
            bool: bool,
            string: String,
        }

        let dataset = DataframeDataset::<PartialTestData>::new(create_test_dataframe()).unwrap();

        assert_eq!(dataset.len(), 3);
        assert!(!dataset.is_empty());

        let second = dataset.get(1).unwrap();
        let expected_second = PartialTestData {
            int32: 2,
            bool: false,
            string: "Boo2".to_string(),
        };
        assert_eq!(second, expected_second);

        let third = dataset.get(2).unwrap();
        let expected_third = PartialTestData {
            int32: 3,
            bool: true,
            string: "Boo3".to_string(),
        };
        assert_eq!(third, expected_third);
    }
}
|
||||
@@ -0,0 +1,38 @@
|
||||
use crate::{Dataset, DatasetIterator, InMemDataset};
|
||||
use fake::{Dummy, Fake, Faker};
|
||||
|
||||
/// Dataset filled with fake items generated from the [fake](fake) crate.
|
||||
pub struct FakeDataset<I> {
|
||||
dataset: InMemDataset<I>,
|
||||
}
|
||||
|
||||
impl<I: Dummy<Faker>> FakeDataset<I> {
|
||||
/// Create a new fake dataset with the given size.
|
||||
pub fn new(size: usize) -> Self {
|
||||
let mut items = Vec::with_capacity(size);
|
||||
for _ in 0..size {
|
||||
items.push(Faker.fake());
|
||||
}
|
||||
let dataset = InMemDataset::new(items);
|
||||
|
||||
Self { dataset }
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: Send + Sync + Clone> Dataset<I> for FakeDataset<I> {
|
||||
fn iter(&self) -> DatasetIterator<'_, I> {
|
||||
DatasetIterator::new(self)
|
||||
}
|
||||
|
||||
fn get(&self, index: usize) -> Option<I> {
|
||||
self.dataset.get(index)
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.dataset.len()
|
||||
}
|
||||
|
||||
fn is_empty(&self) -> bool {
|
||||
self.dataset.is_empty()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,192 @@
|
||||
use std::{
|
||||
fs::File,
|
||||
io::{BufRead, BufReader},
|
||||
path::Path,
|
||||
};
|
||||
|
||||
use serde::de::DeserializeOwned;
|
||||
|
||||
use crate::Dataset;
|
||||
|
||||
/// Dataset keeping every item in RAM.
pub struct InMemDataset<I> {
    /// The stored items, in dataset order.
    items: Vec<I>,
}

impl<I> InMemDataset<I> {
    /// Builds an in-memory dataset that takes ownership of `items`.
    pub fn new(items: Vec<I>) -> Self {
        Self { items }
    }
}
|
||||
|
||||
impl<I> Dataset<I> for InMemDataset<I>
|
||||
where
|
||||
I: Clone + Send + Sync,
|
||||
{
|
||||
fn get(&self, index: usize) -> Option<I> {
|
||||
self.items.get(index).cloned()
|
||||
}
|
||||
fn len(&self) -> usize {
|
||||
self.items.len()
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> InMemDataset<I>
|
||||
where
|
||||
I: Clone + DeserializeOwned,
|
||||
{
|
||||
/// Create from a dataset. All items are loaded in memory.
|
||||
pub fn from_dataset(dataset: &impl Dataset<I>) -> Self {
|
||||
let items: Vec<I> = dataset.iter().collect();
|
||||
Self::new(items)
|
||||
}
|
||||
|
||||
/// Create from a json rows file (one json per line).
|
||||
///
|
||||
/// [Supported field types](https://docs.rs/serde_json/latest/serde_json/value/enum.Value.html)
|
||||
pub fn from_json_rows<P: AsRef<Path>>(path: P) -> Result<Self, std::io::Error> {
|
||||
let file = File::open(path)?;
|
||||
let reader = BufReader::new(file);
|
||||
let mut items = Vec::new();
|
||||
|
||||
for line in reader.lines() {
|
||||
let item = serde_json::from_str(line.unwrap().as_str()).unwrap();
|
||||
items.push(item);
|
||||
}
|
||||
|
||||
let dataset = Self::new(items);
|
||||
|
||||
Ok(dataset)
|
||||
}
|
||||
|
||||
/// Create from a csv file.
|
||||
///
|
||||
/// The provided `csv::ReaderBuilder` can be configured to fit your csv format.
|
||||
///
|
||||
/// The supported field types are: String, integer, float, and bool.
|
||||
///
|
||||
/// See:
|
||||
/// - [Reading with Serde](https://docs.rs/csv/latest/csv/tutorial/index.html#reading-with-serde)
|
||||
/// - [Delimiters, quotes and variable length records](https://docs.rs/csv/latest/csv/tutorial/index.html#delimiters-quotes-and-variable-length-records)
|
||||
pub fn from_csv<P: AsRef<Path>>(
|
||||
path: P,
|
||||
builder: &csv::ReaderBuilder,
|
||||
) -> Result<Self, std::io::Error> {
|
||||
let mut rdr = builder.from_path(path)?;
|
||||
|
||||
let mut items = Vec::new();
|
||||
|
||||
for result in rdr.deserialize() {
|
||||
let item: I = result?;
|
||||
items.push(item);
|
||||
}
|
||||
|
||||
let dataset = Self::new(items);
|
||||
|
||||
Ok(dataset)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {

    use super::*;
    use crate::{SqliteDataset, test_data};

    use rstest::{fixture, rstest};
    use serde::{Deserialize, Serialize};

    const DB_FILE: &str = "tests/data/sqlite-dataset.db";
    const JSON_FILE: &str = "tests/data/dataset.json";
    const CSV_FILE: &str = "tests/data/dataset.csv";
    const CSV_FMT_FILE: &str = "tests/data/dataset-fmt.csv";

    type SqlDs = SqliteDataset<Sample>;

    /// Row type used by the SQLite and JSON fixtures.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct Sample {
        column_str: String,
        column_bytes: Vec<u8>,
        column_int: i64,
        column_bool: bool,
        column_float: f64,
    }

    /// Row type used by the CSV fixtures (no bytes column).
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct SampleCsv {
        column_str: String,
        column_int: i64,
        column_bool: bool,
        column_float: f64,
    }

    /// Opens the `train` split of the SQLite fixture database.
    #[fixture]
    fn train_dataset() -> SqlDs {
        SqliteDataset::from_db_file(DB_FILE, "train").unwrap()
    }

    #[rstest]
    pub fn from_dataset(train_dataset: SqlDs) {
        let dataset = InMemDataset::from_dataset(&train_dataset);

        let non_existing_record_index: usize = 10;
        let record_index: usize = 0;

        // Both assertions target the in-memory copy, which is the code under test here.
        assert_eq!(dataset.get(non_existing_record_index), None);
        assert_eq!(dataset.get(record_index).unwrap().column_str, "HI1");
    }

    #[test]
    pub fn from_json_rows() {
        let dataset = InMemDataset::<Sample>::from_json_rows(JSON_FILE).unwrap();

        let non_existing_record_index: usize = 10;
        let record_index: usize = 1;

        assert_eq!(dataset.get(non_existing_record_index), None);

        let record = dataset.get(record_index).unwrap();
        assert_eq!(record.column_str, "HI2");
        assert!(!record.column_bool);
    }

    #[test]
    pub fn from_csv_rows() {
        let rdr = csv::ReaderBuilder::new();
        let dataset = InMemDataset::<SampleCsv>::from_csv(CSV_FILE, &rdr).unwrap();

        let non_existing_record_index: usize = 10;
        let record_index: usize = 1;

        assert_eq!(dataset.get(non_existing_record_index), None);

        let record = dataset.get(record_index).unwrap();
        assert_eq!(record.column_str, "HI2");
        assert_eq!(record.column_int, 1);
        assert!(!record.column_bool);
        assert_eq!(record.column_float, 1.0);
    }

    #[test]
    pub fn from_csv_rows_fmt() {
        // Space-delimited file without a header row.
        let mut rdr = csv::ReaderBuilder::new();
        rdr.delimiter(b' ').has_headers(false);
        let dataset = InMemDataset::<SampleCsv>::from_csv(CSV_FMT_FILE, &rdr).unwrap();

        let non_existing_record_index: usize = 10;
        let record_index: usize = 1;

        assert_eq!(dataset.get(non_existing_record_index), None);

        let record = dataset.get(record_index).unwrap();
        assert_eq!(record.column_str, "HI2");
        assert_eq!(record.column_int, 1);
        assert!(!record.column_bool);
        assert_eq!(record.column_float, 1.0);
    }

    #[test]
    pub fn given_in_memory_dataset_when_iterate_should_iterate_though_all_items() {
        let items_original = test_data::string_items();
        let dataset = InMemDataset::new(items_original.clone());

        let items: Vec<String> = dataset.iter().collect();

        assert_eq!(items_original, items);
    }
}
|
||||
@@ -0,0 +1,31 @@
|
||||
use crate::dataset::Dataset;
|
||||
use std::iter::Iterator;
|
||||
|
||||
/// Dataset iterator.
|
||||
pub struct DatasetIterator<'a, I> {
|
||||
current: usize,
|
||||
dataset: &'a dyn Dataset<I>,
|
||||
}
|
||||
|
||||
impl<'a, I> DatasetIterator<'a, I> {
|
||||
/// Creates a new dataset iterator.
|
||||
pub fn new<D>(dataset: &'a D) -> Self
|
||||
where
|
||||
D: Dataset<I>,
|
||||
{
|
||||
DatasetIterator {
|
||||
current: 0,
|
||||
dataset,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> Iterator for DatasetIterator<'_, I> {
|
||||
type Item = I;
|
||||
|
||||
fn next(&mut self) -> Option<I> {
|
||||
let item = self.dataset.get(self.current);
|
||||
self.current += 1;
|
||||
item
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
mod base;
|
||||
mod in_memory;
|
||||
mod iterator;
|
||||
|
||||
pub use base::*;
|
||||
pub use in_memory::*;
|
||||
pub use iterator::*;
|
||||
|
||||
#[cfg(any(test, feature = "fake"))]
|
||||
mod fake;
|
||||
|
||||
#[cfg(any(test, feature = "fake"))]
|
||||
pub use self::fake::*;
|
||||
|
||||
#[cfg(feature = "dataframe")]
|
||||
mod dataframe;
|
||||
|
||||
#[cfg(feature = "dataframe")]
|
||||
pub use dataframe::*;
|
||||
|
||||
#[cfg(any(feature = "sqlite", feature = "sqlite-bundled"))]
|
||||
pub use sqlite::*;
|
||||
|
||||
#[cfg(any(feature = "sqlite", feature = "sqlite-bundled"))]
|
||||
mod sqlite;
|
||||
@@ -0,0 +1,851 @@
|
||||
use std::{
|
||||
collections::HashSet,
|
||||
fs, io,
|
||||
marker::PhantomData,
|
||||
path::{Path, PathBuf},
|
||||
sync::{Arc, RwLock},
|
||||
};
|
||||
|
||||
use crate::Dataset;
|
||||
|
||||
use gix_tempfile::{
|
||||
AutoRemove, ContainingDirectory, Handle,
|
||||
handle::{Writable, persist},
|
||||
};
|
||||
use r2d2::{Pool, PooledConnection};
|
||||
use r2d2_sqlite::{
|
||||
SqliteConnectionManager,
|
||||
rusqlite::{OpenFlags, OptionalExtension},
|
||||
};
|
||||
use sanitize_filename::sanitize;
|
||||
use serde::{Serialize, de::DeserializeOwned};
|
||||
use serde_rusqlite::{columns_from_statement, from_row_with_columns};
|
||||
|
||||
/// Result type for the sqlite dataset.
|
||||
pub type Result<T> = core::result::Result<T, SqliteDatasetError>;
|
||||
|
||||
/// Sqlite dataset error.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum SqliteDatasetError {
|
||||
/// IO related error.
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] io::Error),
|
||||
|
||||
/// Sql related error.
|
||||
#[error("Sql error: {0}")]
|
||||
Sql(#[from] serde_rusqlite::rusqlite::Error),
|
||||
|
||||
/// Serde related error.
|
||||
#[error("Serde error: {0}")]
|
||||
Serde(#[from] rmp_serde::encode::Error),
|
||||
|
||||
/// The database file already exists error.
|
||||
#[error("Overwrite flag is set to false and the database file already exists: {0}")]
|
||||
FileExists(PathBuf),
|
||||
|
||||
/// Error when creating the connection pool.
|
||||
#[error("Failed to create connection pool: {0}")]
|
||||
ConnectionPool(#[from] r2d2::Error),
|
||||
|
||||
/// Error when persisting the temporary database file.
|
||||
#[error("Could not persist the temporary database file: {0}")]
|
||||
PersistDbFile(#[from] persist::Error<Writable>),
|
||||
|
||||
/// Any other error.
|
||||
#[error("{0}")]
|
||||
Other(&'static str),
|
||||
}
|
||||
|
||||
impl From<&'static str> for SqliteDatasetError {
|
||||
fn from(s: &'static str) -> Self {
|
||||
SqliteDatasetError::Other(s)
|
||||
}
|
||||
}
|
||||
|
||||
/// This struct represents a dataset where all items are stored in an SQLite database.
|
||||
/// Each instance of this struct corresponds to a specific table within the SQLite database,
|
||||
/// and allows for interaction with the data stored in the table in a structured and typed manner.
|
||||
///
|
||||
/// The SQLite database must contain a table with the same name as the `split` field. This table should
|
||||
/// have a primary key column named `row_id`, which is used to index the rows in the table. The `row_id`
|
||||
/// should start at 1, while the corresponding dataset `index` should start at 0, i.e., `row_id` = `index` + 1.
|
||||
///
|
||||
/// Table columns can be represented in two ways:
|
||||
///
|
||||
/// 1. The table can have a column for each field in the `I` struct. In this case, the column names in the table
|
||||
/// should match the field names of the `I` struct. The field names can be a subset of column names and
|
||||
/// can be in any order.
|
||||
///
|
||||
/// For the supported field types, refer to:
|
||||
/// - [Serialization field types](https://docs.rs/serde_rusqlite/latest/serde_rusqlite)
|
||||
/// - [SQLite data types](https://www.sqlite.org/datatype3.html)
|
||||
///
|
||||
/// 2. The fields in the `I` struct can be serialized into a single column `item` in the table. In this case, the table
|
||||
/// should have a single column named `item` of type `BLOB`. This is useful when the `I` struct contains complex fields
|
||||
/// that cannot be mapped to a SQLite type, such as nested structs, vectors, etc. The serialization is done using
|
||||
/// [MessagePack](https://msgpack.org/).
|
||||
///
|
||||
/// Note: The code automatically figures out which of the above two cases is applicable, and uses the appropriate
|
||||
/// method to read the data from the table.
|
||||
#[derive(Debug)]
|
||||
pub struct SqliteDataset<I> {
|
||||
db_file: PathBuf,
|
||||
split: String,
|
||||
conn_pool: Pool<SqliteConnectionManager>,
|
||||
columns: Vec<String>,
|
||||
len: usize,
|
||||
select_statement: String,
|
||||
row_serialized: bool,
|
||||
phantom: PhantomData<I>,
|
||||
}
|
||||
|
||||
impl<I> SqliteDataset<I> {
|
||||
/// Initializes a `SqliteDataset` from a SQLite database file and a split name.
|
||||
pub fn from_db_file<P: AsRef<Path>>(db_file: P, split: &str) -> Result<Self> {
|
||||
// Create a connection pool
|
||||
let conn_pool = create_conn_pool(&db_file, false)?;
|
||||
|
||||
// Determine how the table is stored
|
||||
let row_serialized = Self::check_if_row_serialized(&conn_pool, split)?;
|
||||
|
||||
// Create a select statement and save it
|
||||
let select_statement = if row_serialized {
|
||||
format!("select item from {split} where row_id = ?")
|
||||
} else {
|
||||
format!("select * from {split} where row_id = ?")
|
||||
};
|
||||
|
||||
// Save the column names and the number of rows
|
||||
let (columns, len) = fetch_columns_and_len(&conn_pool, &select_statement, split)?;
|
||||
|
||||
Ok(SqliteDataset {
|
||||
db_file: db_file.as_ref().to_path_buf(),
|
||||
split: split.to_string(),
|
||||
conn_pool,
|
||||
columns,
|
||||
len,
|
||||
select_statement,
|
||||
row_serialized,
|
||||
phantom: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns true if table has two columns: row_id (integer) and item (blob).
|
||||
///
|
||||
/// This is used to determine if the table is row serialized or not.
|
||||
fn check_if_row_serialized(
|
||||
conn_pool: &Pool<SqliteConnectionManager>,
|
||||
split: &str,
|
||||
) -> Result<bool> {
|
||||
// This struct is used to store the column name and type
|
||||
struct Column {
|
||||
name: String,
|
||||
ty: String,
|
||||
}
|
||||
|
||||
const COLUMN_NAME: usize = 1;
|
||||
const COLUMN_TYPE: usize = 2;
|
||||
|
||||
let sql_statement = format!("PRAGMA table_info({split})");
|
||||
|
||||
let conn = conn_pool.get()?;
|
||||
|
||||
let mut stmt = conn.prepare(sql_statement.as_str())?;
|
||||
let column_iter = stmt.query_map([], |row| {
|
||||
Ok(Column {
|
||||
name: row
|
||||
.get::<usize, String>(COLUMN_NAME)
|
||||
.unwrap()
|
||||
.to_lowercase(),
|
||||
ty: row
|
||||
.get::<usize, String>(COLUMN_TYPE)
|
||||
.unwrap()
|
||||
.to_lowercase(),
|
||||
})
|
||||
})?;
|
||||
|
||||
let mut columns: Vec<Column> = vec![];
|
||||
|
||||
for column in column_iter {
|
||||
columns.push(column?);
|
||||
}
|
||||
|
||||
if columns.len() != 2 {
|
||||
Ok(false)
|
||||
} else {
|
||||
// Check if the column names and types match the expected values
|
||||
Ok(columns[0].name == "row_id"
|
||||
&& columns[0].ty == "integer"
|
||||
&& columns[1].name == "item"
|
||||
&& columns[1].ty == "blob")
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the database file name.
|
||||
pub fn db_file(&self) -> PathBuf {
|
||||
self.db_file.clone()
|
||||
}
|
||||
|
||||
/// Get the split name.
|
||||
pub fn split(&self) -> &str {
|
||||
self.split.as_str()
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> Dataset<I> for SqliteDataset<I>
|
||||
where
|
||||
I: Clone + Send + Sync + DeserializeOwned,
|
||||
{
|
||||
/// Get an item from the dataset.
|
||||
fn get(&self, index: usize) -> Option<I> {
|
||||
// Row ids start with 1 (one) and index starts with 0 (zero)
|
||||
let row_id = index + 1;
|
||||
|
||||
// Get a connection from the pool
|
||||
let connection = self.conn_pool.get().unwrap();
|
||||
let mut statement = connection.prepare(self.select_statement.as_str()).unwrap();
|
||||
|
||||
if self.row_serialized {
|
||||
// Fetch with a single column `item` and deserialize it with MessagePack
|
||||
statement
|
||||
.query_row([row_id], |row| {
|
||||
// Deserialize item (blob) with MessagePack (rmp-serde)
|
||||
Ok(
|
||||
rmp_serde::from_slice::<I>(row.get_ref(0).unwrap().as_blob().unwrap())
|
||||
.unwrap(),
|
||||
)
|
||||
})
|
||||
.optional() //Converts Error (not found) to None
|
||||
.unwrap()
|
||||
} else {
|
||||
// Fetch a row with multiple columns and deserialize it serde_rusqlite
|
||||
statement
|
||||
.query_row([row_id], |row| {
|
||||
// Deserialize the row with serde_rusqlite
|
||||
Ok(from_row_with_columns::<I>(row, &self.columns).unwrap())
|
||||
})
|
||||
.optional() //Converts Error (not found) to None
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the number of rows in the dataset.
|
||||
fn len(&self) -> usize {
|
||||
self.len
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch the column names and the number of rows from the database.
|
||||
fn fetch_columns_and_len(
|
||||
conn_pool: &Pool<SqliteConnectionManager>,
|
||||
select_statement: &str,
|
||||
split: &str,
|
||||
) -> Result<(Vec<String>, usize)> {
|
||||
// Save the column names
|
||||
let connection = conn_pool.get()?;
|
||||
let statement = connection.prepare(select_statement)?;
|
||||
let columns = columns_from_statement(&statement);
|
||||
|
||||
// Count the number of rows and save it as len
|
||||
//
|
||||
// NOTE: Using coalesce(max(row_id), 0) instead of count(*) because count(*) is super slow for large tables.
|
||||
// The coalesce(max(row_id), 0) returns 0 if the table is empty, otherwise it returns the max row_id,
|
||||
// which corresponds to the number of rows in the table.
|
||||
// The main assumption, which always holds true, is that the row_id is always increasing and there are no gaps.
|
||||
// This is true for all the datasets that we are using, otherwise row_id will not correspond to the index.
|
||||
let mut statement =
|
||||
connection.prepare(format!("select coalesce(max(row_id), 0) from {split}").as_str())?;
|
||||
|
||||
let len = statement.query_row([], |row| {
|
||||
let len: usize = row.get(0)?;
|
||||
Ok(len)
|
||||
})?;
|
||||
Ok((columns, len))
|
||||
}
|
||||
|
||||
/// Helper function to create a connection pool
|
||||
fn create_conn_pool<P: AsRef<Path>>(
|
||||
db_file: P,
|
||||
write: bool,
|
||||
) -> Result<Pool<SqliteConnectionManager>> {
|
||||
let sqlite_flags = if write {
|
||||
OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE
|
||||
} else {
|
||||
OpenFlags::SQLITE_OPEN_READ_ONLY
|
||||
};
|
||||
|
||||
let manager = SqliteConnectionManager::file(db_file).with_flags(sqlite_flags);
|
||||
Pool::new(manager).map_err(SqliteDatasetError::ConnectionPool)
|
||||
}
|
||||
|
||||
/// The `SqliteDatasetStorage` struct represents a SQLite database for storing datasets.
|
||||
/// It consists of an optional name, a database file path, and a base directory for storage.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct SqliteDatasetStorage {
|
||||
name: Option<String>,
|
||||
db_file: Option<PathBuf>,
|
||||
base_dir: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl SqliteDatasetStorage {
|
||||
/// Creates a new instance of `SqliteDatasetStorage` using a dataset name.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `name` - A string slice that holds the name of the dataset.
|
||||
pub fn from_name(name: &str) -> Self {
|
||||
SqliteDatasetStorage {
|
||||
name: Some(name.to_string()),
|
||||
db_file: None,
|
||||
base_dir: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new instance of `SqliteDatasetStorage` using a database file path.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `db_file` - A reference to the Path that represents the database file path.
|
||||
pub fn from_file<P: AsRef<Path>>(db_file: P) -> Self {
|
||||
SqliteDatasetStorage {
|
||||
name: None,
|
||||
db_file: Some(db_file.as_ref().to_path_buf()),
|
||||
base_dir: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the base directory for storing the dataset.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `base_dir` - A string slice that represents the base directory.
|
||||
pub fn with_base_dir<P: AsRef<Path>>(mut self, base_dir: P) -> Self {
|
||||
self.base_dir = Some(base_dir.as_ref().to_path_buf());
|
||||
self
|
||||
}
|
||||
|
||||
/// Checks if the database file exists in the given path.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * A boolean value indicating whether the file exists or not.
|
||||
pub fn exists(&self) -> bool {
|
||||
self.db_file().exists()
|
||||
}
|
||||
|
||||
/// Fetches the database file path.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * A `PathBuf` instance representing the file path.
|
||||
pub fn db_file(&self) -> PathBuf {
|
||||
match &self.db_file {
|
||||
Some(db_file) => db_file.clone(),
|
||||
None => {
|
||||
let name = sanitize(self.name.as_ref().expect("Name is not set"));
|
||||
Self::base_dir(self.base_dir.to_owned()).join(format!("{name}.db"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Determines the base directory for storing the dataset.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `base_dir` - An `Option` that may contain a `PathBuf` instance representing the base directory.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * A `PathBuf` instance representing the base directory.
|
||||
pub fn base_dir(base_dir: Option<PathBuf>) -> PathBuf {
|
||||
match base_dir {
|
||||
Some(base_dir) => base_dir,
|
||||
None => dirs::cache_dir()
|
||||
.expect("Could not get cache directory")
|
||||
.join("burn-dataset"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides a writer instance for the SQLite dataset.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `overwrite` - A boolean indicating if the existing database file should be overwritten.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * A `Result` which is `Ok` if the writer could be created, `Err` otherwise.
|
||||
pub fn writer<I>(&self, overwrite: bool) -> Result<SqliteDatasetWriter<I>>
|
||||
where
|
||||
I: Clone + Send + Sync + Serialize + DeserializeOwned,
|
||||
{
|
||||
SqliteDatasetWriter::new(self.db_file(), overwrite)
|
||||
}
|
||||
|
||||
/// Provides a reader instance for the SQLite dataset.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `split` - A string slice that defines the data split for reading (e.g., "train", "test").
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * A `Result` which is `Ok` if the reader could be created, `Err` otherwise.
|
||||
pub fn reader<I>(&self, split: &str) -> Result<SqliteDataset<I>>
|
||||
where
|
||||
I: Clone + Send + Sync + Serialize + DeserializeOwned,
|
||||
{
|
||||
if !self.exists() {
|
||||
panic!("The database file does not exist");
|
||||
}
|
||||
|
||||
SqliteDataset::from_db_file(self.db_file(), split)
|
||||
}
|
||||
}
|
||||
|
||||
/// This `SqliteDatasetWriter` struct is a SQLite database writer dedicated to storing datasets.
|
||||
/// It retains the current writer's state and its database connection.
|
||||
///
|
||||
/// Being thread-safe, this writer can be concurrently used across multiple threads.
|
||||
///
|
||||
/// Typical applications include:
|
||||
///
|
||||
/// - Generation of a new dataset
|
||||
/// - Storage of preprocessed data or metadata
|
||||
/// - Enlargement of a dataset's item count post preprocessing
|
||||
#[derive(Debug)]
|
||||
pub struct SqliteDatasetWriter<I> {
|
||||
db_file: PathBuf,
|
||||
db_file_tmp: Option<Handle<Writable>>,
|
||||
splits: Arc<RwLock<HashSet<String>>>,
|
||||
overwrite: bool,
|
||||
conn_pool: Option<Pool<SqliteConnectionManager>>,
|
||||
is_completed: Arc<RwLock<bool>>,
|
||||
phantom: PhantomData<I>,
|
||||
}
|
||||
|
||||
impl<I> SqliteDatasetWriter<I>
|
||||
where
|
||||
I: Clone + Send + Sync + Serialize + DeserializeOwned,
|
||||
{
|
||||
/// Creates a new instance of `SqliteDatasetWriter`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `db_file` - A reference to the Path that represents the database file path.
|
||||
/// * `overwrite` - A boolean indicating if the existing database file should be overwritten.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * A `Result` which is `Ok` if the writer could be created, `Err` otherwise.
|
||||
pub fn new<P: AsRef<Path>>(db_file: P, overwrite: bool) -> Result<Self> {
|
||||
let writer = Self {
|
||||
db_file: db_file.as_ref().to_path_buf(),
|
||||
db_file_tmp: None,
|
||||
splits: Arc::new(RwLock::new(HashSet::new())),
|
||||
overwrite,
|
||||
conn_pool: None,
|
||||
is_completed: Arc::new(RwLock::new(false)),
|
||||
phantom: PhantomData,
|
||||
};
|
||||
|
||||
writer.init()
|
||||
}
|
||||
|
||||
/// Initializes the dataset writer by creating the database file, tables, and connection pool.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * A `Result` which is `Ok` if the writer could be initialized, `Err` otherwise.
|
||||
fn init(mut self) -> Result<Self> {
|
||||
// Remove the db file if it already exists
|
||||
if self.db_file.exists() {
|
||||
if self.overwrite {
|
||||
fs::remove_file(&self.db_file)?;
|
||||
} else {
|
||||
return Err(SqliteDatasetError::FileExists(self.db_file));
|
||||
}
|
||||
}
|
||||
|
||||
// Create the database file directory if it does not exist
|
||||
let db_file_dir = self
|
||||
.db_file
|
||||
.parent()
|
||||
.ok_or("Unable to get parent directory")?;
|
||||
|
||||
if !db_file_dir.exists() {
|
||||
fs::create_dir_all(db_file_dir)?;
|
||||
}
|
||||
|
||||
// Create a temp database file name as {base_dir}/{name}.db.tmp
|
||||
let mut db_file_tmp = self.db_file.clone();
|
||||
db_file_tmp.set_extension("db.tmp");
|
||||
if db_file_tmp.exists() {
|
||||
fs::remove_file(&db_file_tmp)?;
|
||||
}
|
||||
|
||||
// Create the temp database file and wrap it with a gix_tempfile::Handle
|
||||
// This will ensure that the temp file is deleted when the writer is dropped
|
||||
// or when process exits with SIGINT or SIGTERM (tempfile crate does not do this)
|
||||
gix_tempfile::signal::setup(Default::default());
|
||||
self.db_file_tmp = Some(gix_tempfile::writable_at(
|
||||
&db_file_tmp,
|
||||
ContainingDirectory::Exists,
|
||||
AutoRemove::Tempfile,
|
||||
)?);
|
||||
|
||||
let conn_pool = create_conn_pool(db_file_tmp, true)?;
|
||||
self.conn_pool = Some(conn_pool);
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Serializes and writes an item to the database. The item is written to the table for the
|
||||
/// specified split. If the table does not exist, it is created. If the table exists, the item
|
||||
/// is appended to the table. The serialization is done using the [MessagePack](https://msgpack.org/)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `split` - A string slice that defines the data split for writing (e.g., "train", "test").
|
||||
/// * `item` - A reference to the item to be written to the database.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * A `Result` containing the index of the inserted row if successful, an error otherwise.
|
||||
pub fn write(&self, split: &str, item: &I) -> Result<usize> {
|
||||
// Acquire the read lock (wont't block other reads)
|
||||
let is_completed = self.is_completed.read().unwrap();
|
||||
|
||||
// If the writer is completed, return an error
|
||||
if *is_completed {
|
||||
return Err(SqliteDatasetError::Other(
|
||||
"Cannot save to a completed dataset writer",
|
||||
));
|
||||
}
|
||||
|
||||
// create the table for the split if it does not exist
|
||||
if !self.splits.read().unwrap().contains(split) {
|
||||
self.create_table(split)?;
|
||||
}
|
||||
|
||||
// Get a connection from the pool
|
||||
let conn_pool = self.conn_pool.as_ref().unwrap();
|
||||
let conn = conn_pool.get()?;
|
||||
|
||||
// Serialize the item using MessagePack
|
||||
let serialized_item = rmp_serde::to_vec(item)?;
|
||||
|
||||
// Turn off the synchronous and journal mode for speed up
|
||||
// We are sacrificing durability for speed but it's okay because
|
||||
// we always recreate the dataset if it is not completed.
|
||||
pragma_update_with_error_handling(&conn, "synchronous", "OFF")?;
|
||||
pragma_update_with_error_handling(&conn, "journal_mode", "OFF")?;
|
||||
|
||||
// Insert the serialized item into the database
|
||||
let insert_statement = format!("insert into {split} (item) values (?)");
|
||||
conn.execute(insert_statement.as_str(), [serialized_item])?;
|
||||
|
||||
// Get the primary key of the last inserted row and convert to index (row_id-1)
|
||||
let index = (conn.last_insert_rowid() - 1) as usize;
|
||||
|
||||
Ok(index)
|
||||
}
|
||||
|
||||
/// Marks the dataset as completed and persists the temporary database file.
|
||||
pub fn set_completed(&mut self) -> Result<()> {
|
||||
let mut is_completed = self.is_completed.write().unwrap();
|
||||
|
||||
// Force close the connection pool
|
||||
// This is required on Windows platform where the connection pool prevents
|
||||
// from persisting the db by renaming the temp file.
|
||||
if let Some(pool) = self.conn_pool.take() {
|
||||
std::mem::drop(pool);
|
||||
}
|
||||
|
||||
// Rename the database file from tmp to db
|
||||
let _file_result = self
|
||||
.db_file_tmp
|
||||
.take() // take ownership of the temporary file and set to None
|
||||
.unwrap() // unwrap the temporary file
|
||||
.persist(&self.db_file)?
|
||||
.ok_or("Unable to persist the database file")?;
|
||||
|
||||
*is_completed = true;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Creates table for the data split.
|
||||
///
|
||||
/// Note: call is idempotent and thread-safe.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `split` - A string slice that defines the data split for the table (e.g., "train", "test").
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * A `Result` which is `Ok` if the table could be created, `Err` otherwise.
|
||||
///
|
||||
/// TODO (@antimora): add support creating a table with columns corresponding to the item fields
|
||||
fn create_table(&self, split: &str) -> Result<()> {
|
||||
// Check if the split already exists
|
||||
if self.splits.read().unwrap().contains(split) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let conn_pool = self.conn_pool.as_ref().unwrap();
|
||||
let connection = conn_pool.get()?;
|
||||
let create_table_statement = format!(
|
||||
"create table if not exists {split} (row_id integer primary key autoincrement not \
|
||||
null, item blob not null)"
|
||||
);
|
||||
|
||||
connection.execute(create_table_statement.as_str(), [])?;
|
||||
|
||||
// Add the split to the splits
|
||||
self.splits.write().unwrap().insert(split.to_string());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Runs a pragma update and ignores the `ExecuteReturnedResults` error.
|
||||
///
|
||||
/// Sometimes ExecuteReturnedResults is returned when running a pragma update. This is not an error
|
||||
/// and can be ignored. This function runs the pragma update and ignores the error if it is
|
||||
/// `ExecuteReturnedResults`.
|
||||
fn pragma_update_with_error_handling(
|
||||
conn: &PooledConnection<SqliteConnectionManager>,
|
||||
setting: &str,
|
||||
value: &str,
|
||||
) -> Result<()> {
|
||||
let result = conn.pragma_update(None, setting, value);
|
||||
if let Err(error) = result
|
||||
&& error != rusqlite::Error::ExecuteReturnedResults
|
||||
{
|
||||
return Err(SqliteDatasetError::Sql(error));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rayon::prelude::*;
|
||||
use rstest::{fixture, rstest};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tempfile::{NamedTempFile, TempDir, tempdir};
|
||||
|
||||
use super::*;
|
||||
|
||||
type SqlDs = SqliteDataset<Sample>;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct Sample {
|
||||
column_str: String,
|
||||
column_bytes: Vec<u8>,
|
||||
column_int: i64,
|
||||
column_bool: bool,
|
||||
column_float: f64,
|
||||
}
|
||||
|
||||
#[fixture]
|
||||
fn train_dataset() -> SqlDs {
|
||||
SqliteDataset::<Sample>::from_db_file("tests/data/sqlite-dataset.db", "train").unwrap()
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
pub fn len(train_dataset: SqlDs) {
|
||||
assert_eq!(train_dataset.len(), 2);
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
pub fn get_some(train_dataset: SqlDs) {
|
||||
let item = train_dataset.get(0).unwrap();
|
||||
assert_eq!(item.column_str, "HI1");
|
||||
assert_eq!(item.column_bytes, vec![55, 231, 159]);
|
||||
assert_eq!(item.column_int, 1);
|
||||
assert!(item.column_bool);
|
||||
assert_eq!(item.column_float, 1.0);
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
pub fn get_none(train_dataset: SqlDs) {
|
||||
assert_eq!(train_dataset.get(10), None);
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
pub fn multi_thread(train_dataset: SqlDs) {
|
||||
let indices: Vec<usize> = vec![0, 1, 1, 3, 4, 5, 6, 0, 8, 1];
|
||||
let results: Vec<Option<Sample>> =
|
||||
indices.par_iter().map(|&i| train_dataset.get(i)).collect();
|
||||
|
||||
let mut match_count = 0;
|
||||
for (_index, result) in indices.iter().zip(results.iter()) {
|
||||
if let Some(_val) = result {
|
||||
match_count += 1
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(match_count, 5);
|
||||
}
|
||||
|
||||
/// Exercises `SqliteDatasetStorage` lookups for missing and existing databases and
/// checks that both a reader and a writer can be obtained.
#[test]
fn sqlite_dataset_storage() {
    // A file path that does not exist.
    let missing_file = SqliteDatasetStorage::from_file("non-existing.db");
    assert!(!missing_file.exists());

    // A dataset name that does not exist.
    let missing_name = SqliteDatasetStorage::from_name("non-existing.db");
    assert!(!missing_name.exists());

    // An existing database file can be opened for reading.
    let existing = SqliteDatasetStorage::from_file("tests/data/sqlite-dataset.db");
    assert!(existing.exists());
    let reader = existing.reader::<Sample>("train");
    assert!(reader.is_ok());
    let train = reader.unwrap();
    assert_eq!(train.len(), 2);

    // A writer can be created over an existing (temporary) file when overwriting.
    let temp_file = NamedTempFile::new().unwrap();
    let temp_storage = SqliteDatasetStorage::from_file(temp_file.path());
    assert!(temp_storage.exists());
    let writer = temp_storage.writer::<Sample>(true);
    assert!(writer.is_ok());
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct Complex {
|
||||
column_str: String,
|
||||
column_bytes: Vec<u8>,
|
||||
column_int: i64,
|
||||
column_bool: bool,
|
||||
column_float: f64,
|
||||
column_complex: Vec<Vec<Vec<[u8; 3]>>>,
|
||||
}
|
||||
|
||||
/// Create a temporary directory.
|
||||
#[fixture]
|
||||
fn tmp_dir() -> TempDir {
|
||||
// Create a TempDir. This object will be automatically
|
||||
// deleted when it goes out of scope.
|
||||
tempdir().unwrap()
|
||||
}
|
||||
type Writer = SqliteDatasetWriter<Complex>;
|
||||
|
||||
/// Create a SqliteDatasetWriter with a temporary directory.
|
||||
/// Make sure to return the temporary directory so that it is not deleted.
|
||||
#[fixture]
|
||||
fn writer_fixture(tmp_dir: TempDir) -> (Writer, TempDir) {
|
||||
let temp_dir_str = tmp_dir.path();
|
||||
let storage = SqliteDatasetStorage::from_name("preprocessed").with_base_dir(temp_dir_str);
|
||||
let overwrite = true;
|
||||
let result = storage.writer::<Complex>(overwrite);
|
||||
assert!(result.is_ok());
|
||||
let writer = result.unwrap();
|
||||
(writer, tmp_dir)
|
||||
}
|
||||
|
||||
#[test]
fn test_new() {
    // overwrite = true on an existing file: construction succeeds and the file is removed.
    let existing = NamedTempFile::new().unwrap();
    let _overwriting = SqliteDatasetWriter::<Complex>::new(&existing, true).unwrap();
    assert!(!existing.path().exists());

    // overwrite = false on an existing file: construction must be rejected.
    let protected = NamedTempFile::new().unwrap();
    let rejected = SqliteDatasetWriter::<Complex>::new(&protected, false);
    assert!(rejected.is_err());

    // overwrite = true on a path with no file: construction succeeds and creates
    // nothing at the final path.
    let scratch = NamedTempFile::new().unwrap();
    let absent_path = scratch.path().to_path_buf();
    assert!(scratch.close().is_ok());
    assert!(!absent_path.exists());
    let _fresh = SqliteDatasetWriter::<Complex>::new(&absent_path, true).unwrap();
    assert!(!absent_path.exists());
}
|
||||
|
||||
#[rstest]
|
||||
pub fn sqlite_writer_write(writer_fixture: (Writer, TempDir)) {
|
||||
// Get the dataset_saver from the fixture and tmp_dir (will be deleted after scope)
|
||||
let (writer, _tmp_dir) = writer_fixture;
|
||||
|
||||
assert!(writer.overwrite);
|
||||
assert!(!writer.db_file.exists());
|
||||
|
||||
let new_item = Complex {
|
||||
column_str: "HI1".to_string(),
|
||||
column_bytes: vec![1_u8, 2, 3],
|
||||
column_int: 0,
|
||||
column_bool: true,
|
||||
column_float: 1.0,
|
||||
column_complex: vec![vec![vec![[1, 23_u8, 3]]]],
|
||||
};
|
||||
|
||||
let index = writer.write("train", &new_item).unwrap();
|
||||
assert_eq!(index, 0);
|
||||
|
||||
let mut writer = writer;
|
||||
|
||||
writer.set_completed().expect("Failed to set completed");
|
||||
|
||||
assert!(writer.db_file.exists());
|
||||
assert!(writer.db_file_tmp.is_none());
|
||||
|
||||
let result = writer.write("train", &new_item);
|
||||
|
||||
// Should fail because the writer is completed
|
||||
assert!(result.is_err());
|
||||
|
||||
let dataset = SqliteDataset::<Complex>::from_db_file(writer.db_file, "train").unwrap();
|
||||
|
||||
let fetched_item = dataset.get(0).unwrap();
|
||||
assert_eq!(fetched_item, new_item);
|
||||
assert_eq!(dataset.len(), 1);
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
pub fn sqlite_writer_write_multi_thread(writer_fixture: (Writer, TempDir)) {
|
||||
// Get the dataset_saver from the fixture and tmp_dir (will be deleted after scope)
|
||||
let (writer, _tmp_dir) = writer_fixture;
|
||||
|
||||
let writer = Arc::new(writer);
|
||||
let record_count = 20;
|
||||
|
||||
let splits = ["train", "test"];
|
||||
|
||||
(0..record_count).into_par_iter().for_each(|index: i64| {
|
||||
let thread_id: std::thread::ThreadId = std::thread::current().id();
|
||||
let sample = Complex {
|
||||
column_str: format!("test_{thread_id:?}_{index}"),
|
||||
column_bytes: vec![index as u8, 2, 3],
|
||||
column_int: index,
|
||||
column_bool: true,
|
||||
column_float: 1.0,
|
||||
column_complex: vec![vec![vec![[1, index as u8, 3]]]],
|
||||
};
|
||||
|
||||
// half for train and half for test
|
||||
let split = splits[index as usize % 2];
|
||||
|
||||
let _index = writer.write(split, &sample).unwrap();
|
||||
});
|
||||
|
||||
let mut writer = Arc::try_unwrap(writer).unwrap();
|
||||
|
||||
writer
|
||||
.set_completed()
|
||||
.expect("Should set completed successfully");
|
||||
|
||||
let train =
|
||||
SqliteDataset::<Complex>::from_db_file(writer.db_file.clone(), "train").unwrap();
|
||||
let test = SqliteDataset::<Complex>::from_db_file(writer.db_file, "test").unwrap();
|
||||
|
||||
assert_eq!(train.len(), record_count as usize / 2);
|
||||
assert_eq!(test.len(), record_count as usize / 2);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
#![warn(missing_docs)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
|
||||
//! # Burn Dataset
|
||||
//!
|
||||
//! Burn Dataset is a library for creating and loading datasets.
|
||||
|
||||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
extern crate alloc;
|
||||
extern crate dirs;
|
||||
|
||||
/// Sources for datasets.
|
||||
pub mod source;
|
||||
|
||||
/// Dataset transformations, such as the `Mapper` / `MapperDataset` pair.
pub mod transform;
|
||||
|
||||
/// Audio datasets.
|
||||
#[cfg(feature = "audio")]
|
||||
pub mod audio;
|
||||
|
||||
/// Vision datasets.
|
||||
#[cfg(feature = "vision")]
|
||||
pub mod vision;
|
||||
|
||||
/// Natural language processing datasets.
|
||||
#[cfg(feature = "nlp")]
|
||||
pub mod nlp;
|
||||
|
||||
/// Network dataset utilities.
///
/// Re-exports every public item of `burn_std::network`; only compiled with the
/// `network` feature.
#[cfg(feature = "network")]
pub mod network {
    pub use burn_std::network::*;
}
|
||||
|
||||
mod dataset;
|
||||
pub use dataset::*;
|
||||
#[cfg(any(feature = "sqlite", feature = "sqlite-bundled"))]
|
||||
pub use source::huggingface::downloader::*;
|
||||
|
||||
/// Shared fixtures for unit tests.
#[cfg(test)]
mod test_data {
    /// Returns the four sample strings `"1 Item"`, `"2 Items"`, `"3 Items"`, `"4 Items"`.
    pub fn string_items() -> Vec<String> {
        ["1 Item", "2 Items", "3 Items", "4 Items"]
            .iter()
            .map(|item| item.to_string())
            .collect()
    }
}
|
||||
@@ -0,0 +1,211 @@
|
||||
//! AG NEWS Dataset Module
|
||||
//!
|
||||
//! This module provides functionality for loading the AG NEWS text classification dataset.
|
||||
//! AG NEWS is a collection of news articles categorized into different topics.
|
||||
//! The dataset is split into training (120,000 articles) and test (7,600 articles) sets.
|
||||
//!
|
||||
//! ## Dataset Details
|
||||
//! - **Classes**: 4 categories (World, Sports, Business, Sci/Tech)
|
||||
//! - **AG NEWS mirror**: [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L83)
|
||||
//! - **License**: [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE)
|
||||
//!
|
||||
//! ## Usage Example
|
||||
//! ```no_run
|
||||
//! use burn_dataset::nlp::AgNewsDataset;
|
||||
//!
|
||||
//! // Create an AG NEWS dataset accessor
|
||||
//! let dataset = AgNewsDataset::new();
|
||||
//!
|
||||
//! // Access training and test sets
|
||||
//! let train_dataset = dataset.train();
|
||||
//! let test_dataset = dataset.test();
|
||||
//! ```
|
||||
|
||||
use std::{path::PathBuf, sync::Mutex};
|
||||
|
||||
use flate2::read::GzDecoder;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tar::Archive;
|
||||
|
||||
use crate::InMemDataset;
|
||||
use crate::network::downloader;
|
||||
|
||||
/// Download URL of the AG NEWS archive.
///
/// This is the mirror referenced by
/// [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L83),
/// licensed under the [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE).
const AG_NEWS_URL: &str = "https://s3.amazonaws.com/fast-ai-nlp/ag_news_csv.tgz";
|
||||
|
||||
/// Represents an item in the AG NEWS dataset.
|
||||
///
|
||||
/// Each item contains a label, title, and content of a news article.
|
||||
#[derive(Deserialize, Serialize, Debug, Clone)]
|
||||
pub struct AgNewsItem {
|
||||
/// The category label of the news article.
|
||||
pub label: String,
|
||||
/// The title of the news article.
|
||||
pub title: String,
|
||||
/// The content/body of the news article.
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
/// Accessor for the AG NEWS text classification dataset.
///
/// Creating an accessor downloads the archive (unless it is already cached) and extracts
/// it; the splits are then loaded on demand from the extracted CSV files.
///
/// The dataset is split into training (120,000 articles) and test (7,600 articles) sets.
pub struct AgNewsDataset {
    /// Directory containing the extracted `train.csv` and `test.csv`.
    agnews_dir: PathBuf,
}
|
||||
|
||||
/// Process-wide lock serialising AG NEWS downloads.
///
/// Holding it guarantees that only one thread downloads and extracts the dataset at a time.
static DOWNLOAD_LOCK: Mutex<()> = Mutex::new(());
|
||||
|
||||
impl AgNewsDataset {
|
||||
/// Creates a new AG NEWS dataset accessor.
|
||||
///
|
||||
/// This will download and extract the dataset if it's not already present.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
agnews_dir: Self::download(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Downloads and extracts the AG NEWS dataset.
|
||||
///
|
||||
/// # Returns
|
||||
/// Path to the directory containing the extracted dataset.
|
||||
fn download() -> PathBuf {
|
||||
// Acquire the lock. This will block if another thread already holds the lock.
|
||||
let _lock = DOWNLOAD_LOCK.lock().unwrap();
|
||||
|
||||
// Dataset files are stored in the burn-dataset cache directory
|
||||
let cache_dir = dirs::cache_dir()
|
||||
.expect("Could not get cache directory")
|
||||
.join("burn-dataset");
|
||||
|
||||
// AG NEWS dataset directory
|
||||
let agnews_dir = cache_dir.join("ag_news_csv");
|
||||
|
||||
// AG NEWS dataset url
|
||||
let url = AG_NEWS_URL;
|
||||
|
||||
// AG NEWS dataset archive filename
|
||||
let filename = "ag_news_csv.tgz";
|
||||
|
||||
// Check for already downloaded content
|
||||
if !agnews_dir.exists() {
|
||||
// Download gzip file
|
||||
let bytes = downloader::download_file_as_bytes(url, filename);
|
||||
|
||||
// Decode gzip file content and unpack archive
|
||||
let gz_buffer = GzDecoder::new(&bytes[..]);
|
||||
let mut archive = Archive::new(gz_buffer);
|
||||
archive.unpack(cache_dir).unwrap();
|
||||
}
|
||||
|
||||
agnews_dir
|
||||
}
|
||||
|
||||
/// Parses a CSV file into an in-memory dataset.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `file_path` - Path to the CSV file to parse.
|
||||
///
|
||||
/// # Returns
|
||||
/// An `InMemDataset` containing the parsed data.
|
||||
fn parse_csv(file_path: &str) -> InMemDataset<AgNewsItem> {
|
||||
let mut rdr = csv::ReaderBuilder::new();
|
||||
let rdr = rdr.has_headers(false);
|
||||
|
||||
InMemDataset::from_csv(file_path, &rdr).expect("Failed to parse CSV file")
|
||||
}
|
||||
|
||||
/// Gets the training dataset.
|
||||
///
|
||||
/// # Returns
|
||||
/// An `InMemDataset` instance containing 120,000 training articles.
|
||||
pub fn train(&self) -> InMemDataset<AgNewsItem> {
|
||||
let file_path = self.agnews_dir.join("train.csv");
|
||||
Self::parse_csv(file_path.to_str().unwrap())
|
||||
}
|
||||
|
||||
/// Gets the test dataset.
|
||||
///
|
||||
/// # Returns
|
||||
/// An `InMemDataset` instance containing 7,600 test articles.
|
||||
pub fn test(&self) -> InMemDataset<AgNewsItem> {
|
||||
let file_path = self.agnews_dir.join("test.csv");
|
||||
Self::parse_csv(file_path.to_str().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Dataset;

    /// Expected number of articles in the training split.
    const TRAIN_DATASET_LEN: usize = 120000;
    /// Expected number of articles in the test split.
    const TEST_DATASET_LEN: usize = 7600;

    #[test]
    fn test_agnews_download() {
        let agnews_dir = AgNewsDataset::download();
        assert!(agnews_dir.exists());
    }

    #[test]
    fn test_agnews_len() {
        let agnews = AgNewsDataset::new();

        // Load both splits before checking their sizes.
        let train_dataset = agnews.train();
        let test_dataset = agnews.test();

        assert_eq!(train_dataset.len(), TRAIN_DATASET_LEN);
        assert_eq!(test_dataset.len(), TEST_DATASET_LEN);
    }

    #[test]
    fn test_agnews_first_and_last_item() {
        let agnews = AgNewsDataset::new();

        // First and last record of the training split.
        let train_dataset = agnews.train();
        let first_item = train_dataset.get(0).unwrap();
        let last_item = train_dataset.get(train_dataset.len() - 1).unwrap();
        assert!(item_matches(
            &first_item,
            "3",
            "Wall St. Bears Claw Back Into the Black (Reuters)",
            "Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.",
        ));
        assert!(item_matches(
            &last_item,
            "2",
            "Nets get Carter from Raptors",
            "INDIANAPOLIS -- All-Star Vince Carter was traded by the Toronto Raptors to the New Jersey Nets for Alonzo Mourning, Eric Williams, Aaron Williams, and a pair of first-round draft picks yesterday.",
        ));

        // First and last record of the test split.
        let test_dataset = agnews.test();
        let first_item = test_dataset.get(0).unwrap();
        let last_item = test_dataset.get(test_dataset.len() - 1).unwrap();
        assert!(item_matches(
            &first_item,
            "3",
            "Fears for T N pension after talks",
            "Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.",
        ));
        assert!(item_matches(
            &last_item,
            "3",
            "EBay gets into rentals",
            "EBay plans to buy the apartment and home rental service Rent.com for \\$415 million, adding to its already exhaustive breadth of offerings.",
        ));
    }

    /// Returns `true` when all three fields of `item` equal the expected values.
    fn item_matches(item: &AgNewsItem, label: &str, title: &str, content: &str) -> bool {
        item.label == label && item.title == title && item.content == content
    }
}
|
||||
@@ -0,0 +1,7 @@
|
||||
#[cfg(feature = "builtin-sources")]
|
||||
mod ag_news;
|
||||
mod text_folder;
|
||||
|
||||
#[cfg(feature = "builtin-sources")]
|
||||
pub use ag_news::*;
|
||||
pub use text_folder::*;
|
||||
@@ -0,0 +1,421 @@
|
||||
use crate::transform::{Mapper, MapperDataset};
|
||||
use crate::{Dataset, InMemDataset};
|
||||
|
||||
use encoding_rs::{GB18030, GBK, UTF_8, UTF_16BE, UTF_16LE};
|
||||
use globwalk::{self, DirEntry};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::fs;
|
||||
use std::io::Read;
|
||||
use std::path::{Path, PathBuf};
|
||||
use thiserror::Error;
|
||||
|
||||
/// File extensions (without the leading dot) accepted by the text folder loader.
const SUPPORTED_FILES: [&str; 1] = ["txt"];
|
||||
|
||||
/// Text loaded from a file, together with where it came from.
#[derive(Clone, Debug, PartialEq)]
pub struct TextData {
    /// Decoded content of the file.
    pub text: String,

    /// Path of the file the text was read from.
    pub text_path: String,
}
|
||||
|
||||
/// Text dataset item.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct TextDatasetItem {
|
||||
/// Text content.
|
||||
pub text: TextData,
|
||||
|
||||
/// Label for the text.
|
||||
pub label: usize,
|
||||
}
|
||||
|
||||
/// Unloaded dataset item: the location of a text file plus its label name.
#[derive(Clone, Debug)]
struct TextDatasetItemRaw {
    /// Path of the text file.
    text_path: PathBuf,

    /// Label name of the text.
    label: String,
}

impl TextDatasetItemRaw {
    /// Builds a raw item, storing an owned copy of `text_path`.
    fn new<P: AsRef<Path>>(text_path: P, label: String) -> TextDatasetItemRaw {
        let text_path = text_path.as_ref().to_path_buf();
        Self { text_path, label }
    }
}
|
||||
|
||||
/// Mapper state used to turn raw items into loaded, labelled items.
struct PathToTextDatasetItem {
    /// Class name to class index.
    classes: HashMap<String, usize>,
}
|
||||
|
||||
/// Parse the text content from file with auto-detection of encoding.
|
||||
fn parse_text_content(text_path: &PathBuf) -> String {
|
||||
// Read raw bytes from disk
|
||||
let mut file = fs::File::open(text_path).unwrap();
|
||||
let mut bytes = Vec::new();
|
||||
file.read_to_end(&mut bytes).unwrap();
|
||||
|
||||
// Try to detect encoding and decode text
|
||||
// First try UTF-8 with BOM
|
||||
if bytes.starts_with(&[0xEF, 0xBB, 0xBF]) && bytes.len() >= 3 {
|
||||
let (result, _, had_errors) = UTF_8.decode(&bytes[3..]);
|
||||
if !had_errors {
|
||||
return result.into_owned();
|
||||
}
|
||||
}
|
||||
|
||||
// Try UTF-8 without BOM
|
||||
let (result, _, had_errors) = UTF_8.decode(&bytes);
|
||||
if !had_errors {
|
||||
return result.into_owned();
|
||||
}
|
||||
|
||||
// Try UTF-16LE with BOM
|
||||
if bytes.starts_with(&[0xFF, 0xFE]) && bytes.len() >= 2 {
|
||||
let (result, had_errors) = UTF_16LE.decode_with_bom_removal(&bytes[2..]);
|
||||
if !had_errors {
|
||||
return result.into_owned();
|
||||
}
|
||||
}
|
||||
|
||||
// Try UTF-16BE with BOM
|
||||
if bytes.starts_with(&[0xFE, 0xFF]) && bytes.len() >= 2 {
|
||||
let (result, had_errors) = UTF_16BE.decode_with_bom_removal(&bytes[2..]);
|
||||
if !had_errors {
|
||||
return result.into_owned();
|
||||
}
|
||||
}
|
||||
|
||||
// Try GB18030 encoding
|
||||
let (result, _, had_errors) = GB18030.decode(&bytes);
|
||||
if !had_errors {
|
||||
return result.into_owned();
|
||||
}
|
||||
|
||||
// Try GBK encoding
|
||||
let (result, _, had_errors) = GBK.decode(&bytes);
|
||||
if !had_errors {
|
||||
return result.into_owned();
|
||||
}
|
||||
|
||||
// Default fallback - use from_utf8_lossy for any remaining cases
|
||||
String::from_utf8_lossy(&bytes).to_string()
|
||||
}
|
||||
|
||||
impl Mapper<TextDatasetItemRaw, TextDatasetItem> for PathToTextDatasetItem {
|
||||
/// Convert a raw text dataset item (path-like) to text content with a target label.
|
||||
fn map(&self, item: &TextDatasetItemRaw) -> TextDatasetItem {
|
||||
let label = *self.classes.get(&item.label).unwrap();
|
||||
|
||||
// Load text from disk
|
||||
let text_content = parse_text_content(&item.text_path);
|
||||
|
||||
let text_data = TextData {
|
||||
text: text_content,
|
||||
text_path: item.text_path.display().to_string(),
|
||||
};
|
||||
|
||||
TextDatasetItem {
|
||||
text: text_data,
|
||||
label,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Error type for [TextFolderDataset](TextFolderDataset).
|
||||
#[derive(Error, Debug)]
|
||||
pub enum TextLoaderError {
|
||||
/// Unknown error.
|
||||
#[error("unknown: `{0}`")]
|
||||
Unknown(String),
|
||||
|
||||
/// I/O operation error.
|
||||
#[error("I/O error: `{0}`")]
|
||||
IOError(String),
|
||||
|
||||
/// Invalid file error.
|
||||
#[error("Invalid file extension: `{0}`")]
|
||||
InvalidFileExtensionError(String),
|
||||
|
||||
/// Encoding error.
|
||||
#[error("Encoding error: `{0}`")]
|
||||
EncodingError(String),
|
||||
}
|
||||
|
||||
type TextDatasetMapper =
|
||||
MapperDataset<InMemDataset<TextDatasetItemRaw>, PathToTextDatasetItem, TextDatasetItemRaw>;
|
||||
|
||||
/// A generic dataset to load texts from disk.
|
||||
pub struct TextFolderDataset {
|
||||
dataset: TextDatasetMapper,
|
||||
}
|
||||
|
||||
impl Dataset<TextDatasetItem> for TextFolderDataset {
    /// Returns the item at `index`, delegating to the wrapped dataset.
    fn get(&self, index: usize) -> Option<TextDatasetItem> {
        let inner = &self.dataset;
        inner.get(index)
    }

    /// Returns the number of items, delegating to the wrapped dataset.
    fn len(&self) -> usize {
        let inner = &self.dataset;
        inner.len()
    }
}
|
||||
|
||||
impl TextFolderDataset {
|
||||
/// Create a text classification dataset from the root folder.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `root` - Dataset root folder.
|
||||
///
|
||||
/// # Returns
|
||||
/// A new dataset instance.
|
||||
pub fn new_classification<P: AsRef<Path>>(root: P) -> Result<Self, TextLoaderError> {
|
||||
// New dataset containing any of the supported file types
|
||||
TextFolderDataset::new_classification_with(root, &SUPPORTED_FILES)
|
||||
}
|
||||
|
||||
/// Create a text classification dataset from the root folder.
|
||||
/// The included texts are filtered based on the provided extensions.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `root` - Dataset root folder.
|
||||
/// * `extensions` - List of allowed extensions.
|
||||
///
|
||||
/// # Returns
|
||||
/// A new dataset instance.
|
||||
pub fn new_classification_with<P, S>(root: P, extensions: &[S]) -> Result<Self, TextLoaderError>
|
||||
where
|
||||
P: AsRef<Path>,
|
||||
S: AsRef<str>,
|
||||
{
|
||||
// Glob all texts with extensions
|
||||
let walker = globwalk::GlobWalkerBuilder::from_patterns(
|
||||
root.as_ref(),
|
||||
&[format!(
|
||||
"*.{{{}}}", // "*.{ext1,ext2,ext3}
|
||||
extensions
|
||||
.iter()
|
||||
.map(Self::check_extension)
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.join(",")
|
||||
)],
|
||||
)
|
||||
.follow_links(true)
|
||||
.sort_by(|p1: &DirEntry, p2: &DirEntry| p1.path().cmp(p2.path())) // order by path
|
||||
.build()
|
||||
.map_err(|err| TextLoaderError::Unknown(format!("{err:?}")))?
|
||||
.filter_map(Result::ok);
|
||||
|
||||
// Get all dataset items
|
||||
let mut items = Vec::new();
|
||||
let mut classes = HashSet::new();
|
||||
for text in walker {
|
||||
let text_path = text.path();
|
||||
|
||||
// Label name is represented by the parent folder name
|
||||
let label = text_path
|
||||
.parent()
|
||||
.ok_or_else(|| {
|
||||
TextLoaderError::IOError("Could not resolve text parent folder".to_string())
|
||||
})?
|
||||
.file_name()
|
||||
.ok_or_else(|| {
|
||||
TextLoaderError::IOError(
|
||||
"Could not resolve text parent folder name".to_string(),
|
||||
)
|
||||
})?
|
||||
.to_string_lossy()
|
||||
.into_owned();
|
||||
|
||||
classes.insert(label.clone());
|
||||
|
||||
items.push(TextDatasetItemRaw::new(text_path, label))
|
||||
}
|
||||
|
||||
// Sort class names
|
||||
let mut classes = classes.into_iter().collect::<Vec<_>>();
|
||||
classes.sort();
|
||||
|
||||
Self::with_items(items, &classes)
|
||||
}
|
||||
|
||||
/// Create a text classification dataset with the specified items.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `items` - List of dataset items, each item represented by a tuple `(text path, label)`.
|
||||
/// * `classes` - Dataset class names.
|
||||
///
|
||||
/// # Returns
|
||||
/// A new dataset instance.
|
||||
pub fn new_classification_with_items<P: AsRef<Path>, S: AsRef<str>>(
|
||||
items: Vec<(P, String)>,
|
||||
classes: &[S],
|
||||
) -> Result<Self, TextLoaderError> {
|
||||
// Parse items and check valid text extension types
|
||||
let items = items
|
||||
.into_iter()
|
||||
.map(|(path, label)| {
|
||||
// Map text path and label
|
||||
let path = path.as_ref();
|
||||
let label = label;
|
||||
|
||||
Self::check_extension(&path.extension().unwrap().to_str().unwrap())?;
|
||||
|
||||
Ok(TextDatasetItemRaw::new(path, label))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
Self::with_items(items, classes)
|
||||
}
|
||||
|
||||
/// Create a text dataset with the specified items.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `items` - Raw dataset items.
|
||||
/// * `classes` - Dataset class names.
|
||||
///
|
||||
/// # Returns
|
||||
/// A new dataset instance.
|
||||
fn with_items<S: AsRef<str>>(
|
||||
items: Vec<TextDatasetItemRaw>,
|
||||
classes: &[S],
|
||||
) -> Result<Self, TextLoaderError> {
|
||||
// NOTE: right now we don't need to validate the supported text files since
|
||||
// the method is private. We assume it's already validated.
|
||||
let dataset = InMemDataset::new(items);
|
||||
|
||||
// Class names to index map
|
||||
let classes = classes.iter().map(|c| c.as_ref()).collect::<Vec<_>>();
|
||||
let classes_map: HashMap<_, _> = classes
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(idx, cls)| (cls.to_string(), idx))
|
||||
.collect();
|
||||
|
||||
let mapper = PathToTextDatasetItem {
|
||||
classes: classes_map,
|
||||
};
|
||||
let dataset = MapperDataset::new(dataset, mapper);
|
||||
|
||||
Ok(Self { dataset })
|
||||
}
|
||||
|
||||
/// Check if extension is supported.
|
||||
fn check_extension<S: AsRef<str>>(extension: &S) -> Result<String, TextLoaderError> {
|
||||
let extension = extension.as_ref();
|
||||
if !SUPPORTED_FILES.contains(&extension) {
|
||||
Err(TextLoaderError::InvalidFileExtensionError(
|
||||
extension.to_string(),
|
||||
))
|
||||
} else {
|
||||
Ok(extension.to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    /// Root of the text folder fixture.
    const TEXT_ROOT: &str = "tests/data/text_folder";

    #[test]
    fn test_text_folder_dataset() {
        let dataset = TextFolderDataset::new_classification(TEXT_ROOT).unwrap();

        // Dataset should have 4 elements (2 positive + 2 negative)
        assert_eq!(dataset.len(), 4);
        assert_eq!(dataset.get(4), None);

        // Track which classes were seen while checking every item.
        let mut found_positive = false;
        let mut found_negative = false;

        for i in 0..dataset.len() {
            let item = dataset.get(i).unwrap();
            match item.label {
                0 => {
                    found_negative = true;
                    // The text must be loaded and come from the "negative" folder.
                    assert!(!item.text.text.is_empty());
                    assert!(item.text.text_path.contains("negative"));
                }
                1 => {
                    found_positive = true;
                    // The text must be loaded and come from the "positive" folder.
                    assert!(!item.text.text.is_empty());
                    assert!(item.text.text_path.contains("positive"));
                }
                _ => {}
            }
        }

        // Verify we found items from both classes
        assert!(found_positive);
        assert!(found_negative);
    }

    #[test]
    fn test_text_folder_dataset_with_invalid_extension() {
        // An unsupported extension must be rejected.
        let result = TextFolderDataset::new_classification_with(TEXT_ROOT, &["invalid"]);
        assert!(result.is_err());
    }

    #[test]
    fn test_text_folder_dataset_with_items() {
        // Build the dataset from explicit (path, label) pairs.
        let root = Path::new(TEXT_ROOT);
        let positive_sample = root.join("positive").join("sample1.txt");
        let negative_sample = root.join("negative").join("sample2.txt");
        let items = vec![
            (positive_sample, "positive".to_string()),
            (negative_sample, "negative".to_string()),
        ];
        let classes = vec!["positive", "negative"];
        let dataset = TextFolderDataset::new_classification_with_items(items, &classes).unwrap();

        // Dataset should have 2 elements
        assert_eq!(dataset.len(), 2);
        assert_eq!(dataset.get(2), None);

        let item0 = dataset.get(0).unwrap();
        let item1 = dataset.get(1).unwrap();

        // Check item0
        assert!(item_matches(
            &item0,
            "This is a positive text sample for testing the text folder dataset functionality.",
            0,
        ));

        // Check item1
        assert_eq!(item1.label, 1);
        assert!(item1.text.text_path.contains("negative"));
        assert!(item_matches(
            &item1,
            "另一个负面文本样本,用以确保数据集能够处理同一类别中的多个文件。",
            1,
        ));
    }

    /// Returns `true` when both the text and the label of `item` equal the expected values.
    fn item_matches(item: &TextDatasetItem, text: &str, label: usize) -> bool {
        item.text.text == text && item.label == label
    }
}
|
||||
@@ -0,0 +1,367 @@
|
||||
use std::fs::{self, create_dir_all};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Command;
|
||||
|
||||
use crate::{SqliteDataset, SqliteDatasetError, SqliteDatasetStorage};
|
||||
|
||||
use sanitize_filename::sanitize;
|
||||
use serde::de::DeserializeOwned;
|
||||
use thiserror::Error;
|
||||
|
||||
const PYTHON_SOURCE: &str = include_str!("importer.py");
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
const VENV_BIN_PYTHON: &str = "bin/python3";
|
||||
#[cfg(target_os = "windows")]
|
||||
const VENV_BIN_PYTHON: &str = "Scripts\\python";
|
||||
|
||||
/// Error type for [HuggingfaceDatasetLoader](HuggingfaceDatasetLoader).
|
||||
#[derive(Error, Debug)]
|
||||
pub enum ImporterError {
|
||||
/// Unknown error.
|
||||
#[error("unknown: `{0}`")]
|
||||
Unknown(String),
|
||||
|
||||
/// Fail to download python dependencies.
|
||||
#[error("fail to download python dependencies: `{0}`")]
|
||||
FailToDownloadPythonDependencies(String),
|
||||
|
||||
/// Fail to create sqlite dataset.
|
||||
#[error("sqlite dataset: `{0}`")]
|
||||
SqliteDataset(#[from] SqliteDatasetError),
|
||||
|
||||
/// python3 is not installed.
|
||||
#[error("python3 is not installed")]
|
||||
PythonNotInstalled,
|
||||
|
||||
/// venv environment is not initialized.
|
||||
#[error("venv environment is not initialized")]
|
||||
VenvNotInitialized,
|
||||
}
|
||||
|
||||
/// Loader for datasets hosted on [huggingface datasets](https://huggingface.co/datasets).
///
/// All splits of a dataset are stored in one sqlite database (see
/// [SqliteDataset](SqliteDataset)); individual splits are then read from that file.
///
/// # Example
/// ```no_run
/// use burn_dataset::HuggingfaceDatasetLoader;
/// use burn_dataset::SqliteDataset;
/// use serde::{Deserialize, Serialize};
///
/// #[derive(Deserialize, Debug, Clone)]
/// struct MnistItemRaw {
///     pub image_bytes: Vec<u8>,
///     pub label: usize,
/// }
///
/// let train_ds:SqliteDataset<MnistItemRaw> = HuggingfaceDatasetLoader::new("mnist")
///     .dataset("train")
///     .unwrap();
/// ```
///
/// # Note
/// This loader relies on the [`datasets` library by HuggingFace](https://huggingface.co/docs/datasets/index)
/// to download datasets. This is a Python library, so you must have an existing Python installation.
pub struct HuggingfaceDatasetLoader {
    /// Dataset name.
    name: String,
    /// Optional subset of the dataset.
    subset: Option<String>,
    /// Optional directory in which the sqlite database is stored.
    base_dir: Option<PathBuf>,
    /// Optional token for datasets behind authentication.
    huggingface_token: Option<String>,
    /// Optional cache directory for the downloaded datasets.
    huggingface_cache_dir: Option<String>,
    /// Optional data directory, used by datasets with manual download steps.
    huggingface_data_dir: Option<String>,
    /// Whether remote code is trusted.
    trust_remote_code: bool,
    /// Whether the importer script runs in the burn-dataset Python virtualenv.
    use_python_venv: bool,
}
|
||||
|
||||
impl HuggingfaceDatasetLoader {
|
||||
/// Create a huggingface dataset loader.
|
||||
pub fn new(name: &str) -> Self {
|
||||
Self {
|
||||
name: name.to_string(),
|
||||
subset: None,
|
||||
base_dir: None,
|
||||
huggingface_token: None,
|
||||
huggingface_cache_dir: None,
|
||||
huggingface_data_dir: None,
|
||||
trust_remote_code: false,
|
||||
use_python_venv: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a huggingface dataset loader for a subset of the dataset.
|
||||
///
|
||||
/// The subset name must be one of the subsets listed in the dataset page.
|
||||
///
|
||||
/// If no subset names are listed, then do not use this method.
|
||||
pub fn with_subset(mut self, subset: &str) -> Self {
|
||||
self.subset = Some(subset.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
/// Specify a base directory to store the dataset.
|
||||
///
|
||||
/// If not specified, the dataset will be stored in the system cache directory under `burn-dataset`.
|
||||
pub fn with_base_dir(mut self, base_dir: &str) -> Self {
|
||||
self.base_dir = Some(base_dir.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Specify a huggingface token to download datasets behind authentication.
|
||||
///
|
||||
/// You can get a token from [tokens settings](https://huggingface.co/settings/tokens)
|
||||
pub fn with_huggingface_token(mut self, huggingface_token: &str) -> Self {
|
||||
self.huggingface_token = Some(huggingface_token.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
/// Specify a huggingface cache directory to store the downloaded datasets.
|
||||
///
|
||||
/// If not specified, the dataset will be stored in the system cache directory under `huggingface/datasets`.
|
||||
pub fn with_huggingface_cache_dir(mut self, huggingface_cache_dir: &str) -> Self {
|
||||
self.huggingface_cache_dir = Some(huggingface_cache_dir.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
/// Specify a relative path to a subset of a dataset. This is used in some datasets for the
|
||||
/// manual steps of dataset download process.
|
||||
///
|
||||
/// Unless you've encountered a ManualDownloadError
|
||||
/// when loading your dataset you probably don't have to worry about this setting.
|
||||
pub fn with_huggingface_data_dir(mut self, huggingface_data_dir: &str) -> Self {
|
||||
self.huggingface_data_dir = Some(huggingface_data_dir.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
/// Specify whether or not to trust remote code.
|
||||
///
|
||||
/// If not specified, trust remote code is set to true.
|
||||
pub fn with_trust_remote_code(mut self, trust_remote_code: bool) -> Self {
|
||||
self.trust_remote_code = trust_remote_code;
|
||||
self
|
||||
}
|
||||
|
||||
/// Specify whether or not to use the burn-dataset Python
|
||||
/// virtualenv for running the importer script. If false, local
|
||||
/// `python3`'s environment is used.
|
||||
///
|
||||
/// If not specified, the virtualenv is used.
|
||||
pub fn with_use_python_venv(mut self, use_python_venv: bool) -> Self {
|
||||
self.use_python_venv = use_python_venv;
|
||||
self
|
||||
}
|
||||
|
||||
/// Load the dataset.
|
||||
pub fn dataset<I: DeserializeOwned + Clone>(
|
||||
self,
|
||||
split: &str,
|
||||
) -> Result<SqliteDataset<I>, ImporterError> {
|
||||
let db_file = self.db_file()?;
|
||||
let dataset = SqliteDataset::from_db_file(db_file, split)?;
|
||||
Ok(dataset)
|
||||
}
|
||||
|
||||
/// Get the path to the sqlite database file.
|
||||
///
|
||||
/// If the database file does not exist, it will be downloaded and imported.
|
||||
pub fn db_file(self) -> Result<PathBuf, ImporterError> {
|
||||
// determine (and create if needed) the base directory
|
||||
let base_dir = SqliteDatasetStorage::base_dir(self.base_dir);
|
||||
|
||||
if !base_dir.exists() {
|
||||
create_dir_all(&base_dir).expect("Failed to create base directory");
|
||||
}
|
||||
|
||||
//sanitize the name and subset
|
||||
let name = sanitize(self.name.as_str());
|
||||
|
||||
// create the db file path
|
||||
let db_file_name = if let Some(subset) = self.subset.clone() {
|
||||
format!("{name}-{}.db", sanitize(subset.as_str()))
|
||||
} else {
|
||||
format!("{name}.db")
|
||||
};
|
||||
|
||||
let db_file = base_dir.join(db_file_name);
|
||||
|
||||
// import the dataset if needed
|
||||
if !Path::new(&db_file).exists() {
|
||||
import(
|
||||
self.name,
|
||||
self.subset,
|
||||
db_file.clone(),
|
||||
base_dir,
|
||||
self.huggingface_token,
|
||||
self.huggingface_cache_dir,
|
||||
self.huggingface_data_dir,
|
||||
self.trust_remote_code,
|
||||
self.use_python_venv,
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(db_file)
|
||||
}
|
||||
}
|
||||
|
||||
/// Import a dataset from huggingface. The transformed dataset is stored as sqlite database.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn import(
|
||||
name: String,
|
||||
subset: Option<String>,
|
||||
base_file: PathBuf,
|
||||
base_dir: PathBuf,
|
||||
huggingface_token: Option<String>,
|
||||
huggingface_cache_dir: Option<String>,
|
||||
huggingface_data_dir: Option<String>,
|
||||
trust_remote_code: bool,
|
||||
use_python_venv: bool,
|
||||
) -> Result<(), ImporterError> {
|
||||
let python_path = if use_python_venv {
|
||||
install_python_deps(&base_dir)?
|
||||
} else {
|
||||
get_python_name()?.into()
|
||||
};
|
||||
|
||||
let mut command = Command::new(python_path);
|
||||
|
||||
command.arg(importer_script_path(&base_dir));
|
||||
|
||||
command.arg("--name");
|
||||
command.arg(name);
|
||||
|
||||
command.arg("--file");
|
||||
command.arg(base_file);
|
||||
|
||||
if let Some(subset) = subset {
|
||||
command.arg("--subset");
|
||||
command.arg(subset);
|
||||
}
|
||||
|
||||
if let Some(huggingface_token) = huggingface_token {
|
||||
command.arg("--token");
|
||||
command.arg(huggingface_token);
|
||||
}
|
||||
|
||||
if let Some(huggingface_cache_dir) = huggingface_cache_dir {
|
||||
command.arg("--cache_dir");
|
||||
command.arg(huggingface_cache_dir);
|
||||
}
|
||||
if let Some(huggingface_data_dir) = huggingface_data_dir {
|
||||
command.arg("--data_dir");
|
||||
command.arg(huggingface_data_dir);
|
||||
}
|
||||
if trust_remote_code {
|
||||
command.arg("--trust_remote_code");
|
||||
command.arg("True");
|
||||
}
|
||||
let mut handle = command.spawn().unwrap();
|
||||
|
||||
let exit_status = handle
|
||||
.wait()
|
||||
.map_err(|err| ImporterError::Unknown(format!("{err:?}")))?;
|
||||
|
||||
if !exit_status.success() {
|
||||
return Err(ImporterError::Unknown(format!("{exit_status}")));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check that `<python> --version` succeeds and prints a `Python 3.x.x` version on stdout.
///
/// Returns `false` when the executable cannot be run, exits unsuccessfully, or prints
/// anything else.
fn check_python_version_is_3(python: &str) -> bool {
    let Ok(output) = Command::new(python).arg("--version").output() else {
        return false;
    };

    if !output.status.success() {
        return false;
    }

    // The version number is whatever follows the first space of the output.
    String::from_utf8_lossy(&output.stdout)
        .split_once(' ')
        .is_some_and(|(_, version)| version.starts_with("3."))
}
|
||||
|
||||
/// get python3 name `python` `python3` or `py`
|
||||
fn get_python_name() -> Result<&'static str, ImporterError> {
|
||||
let python_name_list = ["python3", "python", "py"];
|
||||
for python_name in python_name_list.iter() {
|
||||
if check_python_version_is_3(python_name) {
|
||||
return Ok(python_name);
|
||||
}
|
||||
}
|
||||
Err(ImporterError::PythonNotInstalled)
|
||||
}
|
||||
|
||||
fn importer_script_path(base_dir: &Path) -> PathBuf {
|
||||
let path_file = base_dir.join("importer.py");
|
||||
|
||||
fs::write(&path_file, PYTHON_SOURCE).expect("Write python dataset downloader");
|
||||
path_file
|
||||
}
|
||||
|
||||
fn install_python_deps(base_dir: &Path) -> Result<PathBuf, ImporterError> {
|
||||
let venv_dir = base_dir.join("venv");
|
||||
let venv_python_path = venv_dir.join(VENV_BIN_PYTHON);
|
||||
// If the venv environment is already initialized, skip the initialization.
|
||||
if !check_python_version_is_3(venv_python_path.to_str().unwrap()) {
|
||||
let python_name = get_python_name()?;
|
||||
let mut command = Command::new(python_name);
|
||||
command.args([
|
||||
"-m",
|
||||
"venv",
|
||||
venv_dir
|
||||
.as_os_str()
|
||||
.to_str()
|
||||
.expect("Path utf8 conversion should not fail"),
|
||||
]);
|
||||
|
||||
// Spawn the venv creation process and wait for it to complete.
|
||||
let mut handle = command.spawn().unwrap();
|
||||
|
||||
handle.wait().map_err(|err| {
|
||||
ImporterError::FailToDownloadPythonDependencies(format!(" error: {err}"))
|
||||
})?;
|
||||
// Check if the venv environment can be used successfully."
|
||||
if !check_python_version_is_3(venv_python_path.to_str().unwrap()) {
|
||||
return Err(ImporterError::VenvNotInitialized);
|
||||
}
|
||||
}
|
||||
|
||||
let mut ensurepip_cmd = Command::new(&venv_python_path);
|
||||
ensurepip_cmd.args(["-m", "ensurepip", "--upgrade"]);
|
||||
let status = ensurepip_cmd.status().map_err(|err| {
|
||||
ImporterError::FailToDownloadPythonDependencies(format!("failed to run ensurepip: {err}"))
|
||||
})?;
|
||||
if !status.success() {
|
||||
return Err(ImporterError::FailToDownloadPythonDependencies(
|
||||
"ensurepip failed to initialize pip".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut command = Command::new(&venv_python_path);
|
||||
command.args([
|
||||
"-m",
|
||||
"pip",
|
||||
"--quiet",
|
||||
"install",
|
||||
"pyarrow",
|
||||
"sqlalchemy",
|
||||
"Pillow",
|
||||
"soundfile",
|
||||
"datasets",
|
||||
]);
|
||||
|
||||
// Spawn the pip install process and wait for it to complete.
|
||||
let mut handle = command.spawn().unwrap();
|
||||
handle
|
||||
.wait()
|
||||
.map_err(|err| ImporterError::FailToDownloadPythonDependencies(format!(" error: {err}")))?;
|
||||
|
||||
Ok(venv_python_path)
|
||||
}
|
||||
@@ -0,0 +1,207 @@
|
||||
import argparse
|
||||
|
||||
import pyarrow as pa
|
||||
from datasets import Audio, Image, load_dataset
|
||||
from sqlalchemy import Column, Integer, Table, create_engine, event, inspect
|
||||
from sqlalchemy.types import LargeBinary
|
||||
|
||||
|
||||
def download_and_export(
    name: str,
    subset: str,
    db_file: str,
    token: str,
    cache_dir: str,
    data_dir: str | None,
    trust_remote_code: bool,
):
    """
    Download a dataset using HuggingFace datasets and export it to a sqlite database.

    Every split of the dataset becomes one table (named after the split) in `db_file`.
    Audio and image columns are stored undecoded, nested columns are flattened and dots
    in column names are replaced with underscores.
    """

    # TODO For media columns (Image and Audio) sometimes when decode=False,
    # bytes can be none {'bytes': None, 'path': 'healthy_train.265.jpg'}
    # We should handle this case, but unfortunately we did not come across this case yet to test it.

    print("*" * 80)
    print("Starting huggingface dataset download and export")
    print(f"Dataset Name: {name}")
    print(f"Subset Name: {subset}")
    print(f"Sqlite database file: {db_file}")
    print(f"Trust remote code: {trust_remote_code}")
    # Only report the cache directory when a custom one was actually provided.
    if cache_dir is not None:
        print(f"Custom cache dir: {cache_dir}")
    print("*" * 80)

    # Load the dataset
    dataset_all = load_dataset(
        name,
        subset,
        cache_dir=cache_dir,
        data_dir=data_dir,
        use_auth_token=token,
        trust_remote_code=trust_remote_code,
    )

    print(f"Dataset: {dataset_all}")

    # Create the database connection descriptor (sqlite)
    engine = create_engine(f"sqlite:///{db_file}")

    # Set some sqlite pragmas to speed up the database
    event.listen(engine, "connect", set_sqlite_pragma)

    # Add a row_id column to each table as primary key (datasets does not have API for this)
    event.listen(Table, "before_create", add_pk_column)

    # Export each split in the dataset
    for key in dataset_all.keys():
        dataset = dataset_all[key]

        # Disable decoding for audio and image fields
        dataset = disable_decoding(dataset)

        # Flatten the dataset
        dataset = dataset.flatten()

        # Rename columns to remove dots from the names
        dataset = rename_columns(dataset)

        print(f"Saving dataset: {name} - {key}")
        print(f"Dataset features: {dataset.features}")

        # Save the dataset to a sqlite database
        dataset.to_sql(
            key,  # table name
            engine,
            # don't save the index, use row_id instead (index is not unique)
            index=False,
            dtype=blob_columns(dataset),  # save binary columns as blob
        )

    # Print the schema of the database so we can reference the columns in the rust code
    print_table_info(engine)
|
||||
|
||||
|
||||
def disable_decoding(dataset):
    """
    Turn off decoding for every audio and image column so the raw file bytes are kept.
    """
    for column, feature in dataset.features.items():
        if isinstance(feature, Audio):
            undecoded = Audio(decode=False)
        elif isinstance(feature, Image):
            undecoded = Image(decode=False)
        else:
            # Other feature types are left untouched.
            continue
        dataset = dataset.cast_column(column, undecoded)

    return dataset
|
||||
|
||||
|
||||
def rename_columns(dataset):
    """
    Replace the dots in column names with underscores.

    Flattening a dataset introduces dots in the column names, which are not allowed in
    rust identifiers nor in (unquoted) sql column names. Using underscores keeps a simple
    mapping between the rust and the sql column names.
    """
    dotted = [column for column in dataset.features.keys() if "." in column]

    for column in dotted:
        dataset = dataset.rename_column(column, column.replace(".", "_"))

    return dataset
|
||||
|
||||
|
||||
def blob_columns(dataset):
    """
    Build the `dtype` mapping that forces binary columns to be stored as BLOB,
    because `to_sql` exports binary values as TEXT instead of BLOB.
    """
    return {
        column: LargeBinary
        for column, feature in dataset.features.items()
        if feature.pa_type is not None and pa.types.is_binary(feature.pa_type)
    }
|
||||
|
||||
|
||||
def set_sqlite_pragma(dbapi_connection, connection_record):
    """
    Connection hook: apply sqlite pragmas that speed up the database.
    """
    pragmas = (
        "PRAGMA synchronous = OFF",
        "PRAGMA journal_mode = OFF",
    )
    cursor = dbapi_connection.cursor()
    for statement in pragmas:
        cursor.execute(statement)
    cursor.close()
|
||||
|
||||
|
||||
def add_pk_column(target, connection, **kw):
    """
    Table-creation hook: append an integer `row_id` primary key column to the table.
    """
    row_id = Column("row_id", Integer, primary_key=True)
    target.append_column(row_id)
|
||||
|
||||
|
||||
def print_table_info(engine):
    """
    Print every table of the database with its columns and their types, so the columns
    can be referenced from the rust code.
    """
    print(f"Printing table schema for sqlite3 db ({engine})")
    schema = inspect(engine)
    for table in schema.get_table_names():
        print(f"Table: {table}")
        for col in schema.get_columns(table):
            print(f"Column: {col['name']} - {col['type']}")
        print("")
|
||||
|
||||
|
||||
def _str_to_bool(value: str) -> bool:
    """Parse a command line boolean; `bool("False")` would be truthy, so compare the text."""
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """
    Parse the command line arguments of the importer script.

    `--name` and `--file` are required; every other option defaults to None.
    """
    parser = argparse.ArgumentParser(
        description="Huggingface datasets downloader to use with burn-dataset"
    )
    parser.add_argument(
        "--name", type=str, help="Name of the dataset to download", required=True
    )
    parser.add_argument(
        "--file", type=str, help="Base file name where the data is saved", required=True
    )
    parser.add_argument(
        "--subset", type=str, help="Subset name", required=False, default=None
    )
    parser.add_argument(
        "--token",
        type=str,
        help="HuggingFace authentication token",
        required=False,
        default=None,
    )
    parser.add_argument(
        "--cache_dir", type=str, help="Cache directory", required=False, default=None
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        help="Relative path to a specific subset of your dataset",
        required=False,
        default=None,
    )
    parser.add_argument(
        "--trust_remote_code",
        # `type=bool` would turn any non-empty string (even "False") into True.
        type=_str_to_bool,
        help="Trust remote code",
        required=False,
        default=None,
    )

    return parser.parse_args()
|
||||
|
||||
|
||||
def run():
    """
    Entry point: parse the command line and run the download and export.
    """
    args = parse_args()

    # Keyword arguments keep `cache_dir` and `data_dir` from being swapped.
    download_and_export(
        name=args.name,
        subset=args.subset,
        db_file=args.file,
        token=args.token,
        cache_dir=args.cache_dir,
        data_dir=args.data_dir,
        trust_remote_code=args.trust_remote_code,
    )


if __name__ == "__main__":
    run()
|
||||
@@ -0,0 +1,3 @@
|
||||
pub(crate) mod downloader;
|
||||
|
||||
pub use downloader::*;
|
||||
@@ -0,0 +1,3 @@
|
||||
/// Huggingface source
|
||||
#[cfg(any(feature = "sqlite", feature = "sqlite-bundled"))]
|
||||
pub mod huggingface;
|
||||
@@ -0,0 +1,56 @@
|
||||
use crate::Dataset;
|
||||
|
||||
/// Compose multiple datasets together to create a bigger one.
|
||||
#[derive(new)]
|
||||
pub struct ComposedDataset<D> {
|
||||
datasets: Vec<D>,
|
||||
}
|
||||
|
||||
impl<D, I> Dataset<I> for ComposedDataset<D>
|
||||
where
|
||||
D: Dataset<I>,
|
||||
I: Clone,
|
||||
{
|
||||
fn get(&self, index: usize) -> Option<I> {
|
||||
let mut current_index = 0;
|
||||
for dataset in self.datasets.iter() {
|
||||
if index < dataset.len() + current_index {
|
||||
return dataset.get(index - current_index);
|
||||
}
|
||||
current_index += dataset.len();
|
||||
}
|
||||
None
|
||||
}
|
||||
fn len(&self) -> usize {
|
||||
let mut total = 0;
|
||||
for dataset in self.datasets.iter() {
|
||||
total += dataset.len();
|
||||
}
|
||||
total
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FakeDataset;

    /// The composition must expose the items of every part, in order.
    #[test]
    fn test_composed_dataset() {
        let first = FakeDataset::<String>::new(10);
        let second = FakeDataset::<String>::new(5);

        let mut expected_items: Vec<String> = first.iter().collect();
        expected_items.extend(second.iter());

        let composed = ComposedDataset::new(vec![first, second]);
        assert_eq!(composed.len(), 15);

        let items: Vec<String> = composed.iter().collect();
        assert_eq!(items, expected_items);
    }
}
|
||||
@@ -0,0 +1,60 @@
|
||||
use crate::Dataset;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// Item transformation used by the [mapper dataset](MapperDataset).
pub trait Mapper<I, O>: Send + Sync {
    /// Converts a borrowed item of type `I` into a new item of type `O`.
    fn map(&self, item: &I) -> O;
}
|
||||
|
||||
/// Dataset mapping each element in an inner dataset to another element type lazily.
|
||||
#[derive(new)]
|
||||
pub struct MapperDataset<D, M, I> {
|
||||
dataset: D,
|
||||
mapper: M,
|
||||
input: PhantomData<I>,
|
||||
}
|
||||
|
||||
impl<D, M, I, O> Dataset<O> for MapperDataset<D, M, I>
|
||||
where
|
||||
D: Dataset<I>,
|
||||
M: Mapper<I, O> + Send + Sync,
|
||||
I: Send + Sync,
|
||||
O: Send + Sync,
|
||||
{
|
||||
fn get(&self, index: usize) -> Option<O> {
|
||||
let item = self.dataset.get(index);
|
||||
item.map(|item| self.mapper.map(&item))
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.dataset.len()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{InMemDataset, test_data};

    /// Iterating a mapper dataset must yield the mapped version of every inner item.
    #[test]
    pub fn given_mapper_dataset_when_iterate_should_iterate_though_all_map_items() {
        struct StringToFirstChar;

        impl Mapper<String, String> for StringToFirstChar {
            fn map(&self, item: &String) -> String {
                let mut first = item.clone();
                first.truncate(1);
                first
            }
        }

        let source = InMemDataset::new(test_data::string_items());
        let mapped = MapperDataset::new(source, StringToFirstChar);

        let items: Vec<String> = mapped.iter().collect();

        assert_eq!(vec!["1", "2", "3", "4"], items);
    }
}
|
||||
@@ -0,0 +1,30 @@
|
||||
//! # Dataset Transformations
|
||||
//!
|
||||
//! This module provides a collection of [`crate::Dataset`] composition wrappers;
|
||||
//! providing composition, subset selection, sampling, random shuffling, and windowing.
|
||||
//!
|
||||
//! * [`ComposedDataset`] - composes a list of datasets.
//! * [`MapperDataset`] - lazily maps each item of a dataset to another item type.
|
||||
//! * [`PartialDataset`] - selects a contiguous index range subset of a dataset.
|
||||
//! * [`ShuffledDataset`] - a randomly shuffled / mutably shuffle-able dataset;
|
||||
//! a thin wrapper around [`SelectionDataset`].
|
||||
//! * [`SamplerDataset`] - samples a dataset; support for with/without replacement,
|
||||
//! and under/oversampling.
|
||||
//! * [`SelectionDataset`] - selects a subset of a dataset via indices; support for shuffling.
|
||||
//! * [`WindowsDataset`] - creates a sliding window over a dataset.
|
||||
mod composed;
|
||||
mod mapper;
|
||||
mod options;
|
||||
mod partial;
|
||||
mod sampler;
|
||||
mod selection;
|
||||
mod shuffle;
|
||||
mod window;
|
||||
|
||||
pub use composed::*;
|
||||
pub use mapper::*;
|
||||
pub use options::*;
|
||||
pub use partial::*;
|
||||
pub use sampler::*;
|
||||
pub use selection::*;
|
||||
pub use shuffle::*;
|
||||
pub use window::*;
|
||||
@@ -0,0 +1,199 @@
|
||||
use rand::SeedableRng;
|
||||
use rand::prelude::StdRng;
|
||||
use rand::rngs::SysRng;
|
||||
|
||||
/// Defines a source for a `StdRng`.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use rand::rngs::StdRng;
|
||||
/// use rand::SeedableRng;
|
||||
/// use burn_dataset::transform::RngSource;
|
||||
///
|
||||
/// // Default via `StdRng::from_os_rng()` (`RngSource::Default`)
|
||||
/// let system: RngSource = RngSource::default();
|
||||
///
|
||||
/// // From a fixed seed (`RngSource::Seed`)
|
||||
/// let seeded: RngSource = 42.into();
|
||||
///
|
||||
/// // From an existing rng (`RngSource::Rng`)
|
||||
/// let rng = StdRng::seed_from_u64(123);
|
||||
/// let with_rng: RngSource = rng.into();
|
||||
///
|
||||
/// // Forks the parent RNG to derive an independent, deterministic child RNG.
|
||||
/// // The original `rng` is modified, and the resulting `RngSource` contains
|
||||
/// // a new RNG starting from a unique state.
|
||||
/// let mut rng = StdRng::seed_from_u64(123);
|
||||
/// let forked: RngSource = (&mut rng).into();
|
||||
/// ```
|
||||
#[derive(Debug, Default, PartialEq, Eq)]
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
pub enum RngSource {
|
||||
/// Build a new rng from the system.
|
||||
#[default]
|
||||
Default,
|
||||
|
||||
/// The rng is passed as a seed.
|
||||
Seed(u64),
|
||||
|
||||
/// The rng is passed as an option.
|
||||
Rng(StdRng),
|
||||
}
|
||||
|
||||
impl From<RngSource> for StdRng {
|
||||
fn from(source: RngSource) -> Self {
|
||||
match source {
|
||||
RngSource::Default => StdRng::try_from_rng(&mut SysRng).unwrap(),
|
||||
RngSource::Rng(rng) => rng,
|
||||
RngSource::Seed(seed) => StdRng::seed_from_u64(seed),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u64> for RngSource {
|
||||
fn from(seed: u64) -> Self {
|
||||
Self::Seed(seed)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StdRng> for RngSource {
|
||||
fn from(rng: StdRng) -> Self {
|
||||
Self::Rng(rng)
|
||||
}
|
||||
}
|
||||
|
||||
/// Derive an independent RNG from a mutable parent RNG.
|
||||
///
|
||||
/// This advances the parent RNG and creates a new RNG seeded from its output.
|
||||
/// The derived RNG is *not* a clone of the parent's state, but an independent
|
||||
/// stream (equivalent to `SeedableRng::fork`).
|
||||
impl From<&mut StdRng> for RngSource {
|
||||
fn from(rng: &mut StdRng) -> Self {
|
||||
Self::Rng(rng.fork())
|
||||
}
|
||||
}
|
||||
|
||||
/// Describes the size of a wrapper dataset, relative to the dataset it wraps.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub enum SizeConfig {
    /// Same size as the source dataset.
    #[default]
    Default,

    /// Size expressed as a ratio of the source dataset size.
    ///
    /// Must be >= 0.
    Ratio(f64),

    /// Fixed size, independent of the source dataset.
    Fixed(usize),
}
|
||||
|
||||
impl SizeConfig {
|
||||
/// Construct a source which will have the same size as the source dataset.
|
||||
pub fn source() -> Self {
|
||||
Self::Default
|
||||
}
|
||||
|
||||
/// Resolve the effective size.
|
||||
///
|
||||
/// ## Arguments
|
||||
///
|
||||
/// - `source_size`: the size of the source dataset.
|
||||
///
|
||||
/// ## Returns
|
||||
///
|
||||
/// The resolved size of the wrapper dataset.
|
||||
pub fn resolve(self, source_size: usize) -> usize {
|
||||
match self {
|
||||
SizeConfig::Default => source_size,
|
||||
SizeConfig::Ratio(ratio) => {
|
||||
assert!(ratio >= 0.0, "Ratio must be positive: {ratio}");
|
||||
((source_size as f64) * ratio) as usize
|
||||
}
|
||||
SizeConfig::Fixed(size) => size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<usize> for SizeConfig {
|
||||
fn from(size: usize) -> Self {
|
||||
Self::Fixed(size)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<f64> for SizeConfig {
|
||||
fn from(ratio: f64) -> Self {
|
||||
Self::Ratio(ratio)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;

    /// The default source is `RngSource::Default` and can be turned into an rng.
    #[test]
    fn test_rng_source_default() {
        let source = RngSource::default();
        assert_eq!(&source, &RngSource::Default);
        assert_eq!(&source, &RngSource::default());

        // Exercise the system seeding path; its seed is unknown, so nothing to compare.
        let _rng: StdRng = StdRng::from(source);
    }

    /// A seed source yields the rng seeded with that value.
    #[test]
    fn test_rng_source_seed() {
        let source = RngSource::from(42);
        assert_eq!(&source, &RngSource::Seed(42));

        let actual = StdRng::from(source);
        assert_eq!(actual, StdRng::seed_from_u64(42));
    }

    /// An rng source keeps an owned rng untouched and forks a borrowed one.
    #[test]
    fn test_rng_source_rng() {
        // From StdRng (owned).
        {
            let source = RngSource::from(StdRng::seed_from_u64(42));
            let actual: StdRng = source.into();

            // The round trip through the source should not have advanced the state.
            assert_eq!(actual, StdRng::seed_from_u64(42));
        }

        // From &mut StdRng (forks parent)
        {
            let mut parent = StdRng::seed_from_u64(42);
            let mut reference = StdRng::seed_from_u64(42);
            let expected_child = reference.fork();

            let source = RngSource::from(&mut parent);

            // Ensure the parent was advanced
            assert_eq!(parent, reference);

            // Ensure the sourced RNG matches the fork
            let child: StdRng = source.into();
            assert_eq!(child, expected_child);
        }
    }

    /// Conversions and resolution of `SizeConfig`.
    #[test]
    fn test_size_config() {
        assert_eq!(SizeConfig::default(), SizeConfig::Default);
        assert_eq!(SizeConfig::from(42), SizeConfig::Fixed(42));
        assert_eq!(SizeConfig::from(1.5), SizeConfig::Ratio(1.5));

        let source = SizeConfig::source();
        assert_eq!(source, SizeConfig::Default);
        assert_eq!(source.resolve(50), 50);
    }
}
|
||||
@@ -0,0 +1,206 @@
|
||||
use crate::Dataset;
|
||||
use std::{marker::PhantomData, sync::Arc};
|
||||
|
||||
/// Only use a fraction of an existing dataset lazily.
|
||||
#[derive(new, Clone)]
|
||||
pub struct PartialDataset<D, I> {
|
||||
dataset: D,
|
||||
start_index: usize,
|
||||
end_index: usize,
|
||||
input: PhantomData<I>,
|
||||
}
|
||||
|
||||
impl<D, I> PartialDataset<D, I>
|
||||
where
|
||||
D: Dataset<I>,
|
||||
{
|
||||
/// Splits a dataset into multiple partial datasets.
|
||||
pub fn split(dataset: D, num: usize) -> Vec<PartialDataset<Arc<D>, I>> {
|
||||
let dataset = Arc::new(dataset); // cheap cloning.
|
||||
|
||||
let mut current = 0;
|
||||
let mut datasets = Vec::with_capacity(num);
|
||||
|
||||
let batch_size = dataset.len() / num;
|
||||
|
||||
for i in 0..num {
|
||||
let start = current;
|
||||
let mut end = current + batch_size;
|
||||
|
||||
if i == (num - 1) {
|
||||
end = dataset.len();
|
||||
}
|
||||
|
||||
let dataset = PartialDataset::new(dataset.clone(), start, end);
|
||||
|
||||
current += batch_size;
|
||||
datasets.push(dataset);
|
||||
}
|
||||
|
||||
datasets
|
||||
}
|
||||
|
||||
/// Splits a dataset by distributing complete chunks/batches across multiple partial datasets.
|
||||
pub fn split_chunks(
|
||||
dataset: D,
|
||||
num: usize,
|
||||
batch_size: usize,
|
||||
) -> Vec<PartialDataset<Arc<D>, I>> {
|
||||
let dataset = Arc::new(dataset); // cheap cloning.
|
||||
let total_items = dataset.len();
|
||||
|
||||
// Total number of complete batches
|
||||
let total_batches = total_items.div_ceil(batch_size);
|
||||
let batches_per_split = total_batches / num;
|
||||
let extra_batches = total_batches % num;
|
||||
|
||||
let mut datasets = Vec::with_capacity(num);
|
||||
let mut current_batch = 0;
|
||||
|
||||
for i in 0..num {
|
||||
// Extra batches distributed across first splits
|
||||
let split_batches = if i < extra_batches {
|
||||
batches_per_split + 1
|
||||
} else {
|
||||
batches_per_split
|
||||
};
|
||||
|
||||
let start_batch = current_batch;
|
||||
let end_batch = start_batch + split_batches;
|
||||
|
||||
let start_index = start_batch * batch_size;
|
||||
let end_index = core::cmp::min(end_batch * batch_size, total_items);
|
||||
|
||||
if start_index < total_items {
|
||||
datasets.push(PartialDataset::new(dataset.clone(), start_index, end_index));
|
||||
}
|
||||
|
||||
current_batch = end_batch;
|
||||
}
|
||||
|
||||
datasets
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, I> Dataset<I> for PartialDataset<D, I>
|
||||
where
|
||||
D: Dataset<I>,
|
||||
I: Clone + Send + Sync,
|
||||
{
|
||||
fn get(&self, index: usize) -> Option<I> {
|
||||
let index = index + self.start_index;
|
||||
if index < self.start_index || index >= self.end_index {
|
||||
return None;
|
||||
}
|
||||
self.dataset.get(index)
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
usize::min(self.end_index - self.start_index, self.dataset.len())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FakeDataset;
    use std::collections::HashSet;

    /// A range starting at zero exposes exactly the first items.
    #[test]
    fn test_start_from_beginning() {
        let source = FakeDataset::<String>::new(27);
        let all: Vec<String> = source.iter().collect();
        let inside: HashSet<String> = all[..10].iter().cloned().collect();
        let outside: HashSet<String> = all[10..].iter().cloned().collect();

        let partial = PartialDataset::<_, String>::new(source, 0, 10);
        let seen: HashSet<String> = partial.iter().collect();

        assert_eq!(partial.len(), 10);
        assert_eq!(inside, seen);
        for item in outside {
            assert!(!seen.contains(&item));
        }
    }

    /// A range in the middle exposes only the items of that range.
    #[test]
    fn test_start_inside() {
        let source = FakeDataset::<String>::new(27);
        let mut inside = HashSet::new();
        let mut outside = HashSet::new();

        for (i, item) in source.iter().enumerate() {
            if (10..20).contains(&i) {
                inside.insert(item);
            } else {
                outside.insert(item);
            }
        }

        let partial = PartialDataset::<_, String>::new(source, 10, 20);
        let seen: HashSet<String> = partial.iter().collect();

        assert_eq!(partial.len(), 10);
        assert_eq!(inside, seen);
        for item in outside {
            assert!(!seen.contains(&item));
        }
    }

    /// `split` must cover every item exactly once, in order.
    #[test]
    fn test_split_contains_all_items_without_duplicates() {
        let source = FakeDataset::<String>::new(27);
        let items_original: Vec<String> = source.iter().collect();

        let parts = PartialDataset::<_, String>::split(source, 4);
        let expected_len = [6, 6, 6, 9];

        let mut items_partial = Vec::new();
        for (part, expected) in parts.iter().zip(expected_len) {
            assert_eq!(part.len(), expected);
            items_partial.extend(part.iter());
        }

        assert_eq!(items_original, items_partial);
    }

    /// `split_chunks` must cover every item exactly once while keeping batches whole.
    #[test]
    fn test_split_chunks_contains_all_items_without_duplicates() {
        let source = FakeDataset::<String>::new(27);
        let items_original: Vec<String> = source.iter().collect();

        let parts = PartialDataset::<_, String>::split_chunks(source, 4, 5);
        // [(2 * 5), (2 * 5), 5, 2] -> 5 complete chunks + 1 incomplete with 2 remaining items
        // OTOH, `split(dataset, 4)` would yield [6, 6, 6, 9] -> 4 incomplete chunks + 4 incomplete with [1, 1, 1, 4]
        let expected_len = [10, 10, 5, 2];

        let mut items_partial = Vec::new();
        for (part, expected) in parts.iter().zip(expected_len) {
            assert_eq!(part.len(), expected);
            items_partial.extend(part.iter());
        }

        assert_eq!(items_original, items_partial);
    }
}
|
||||
@@ -0,0 +1,438 @@
|
||||
use crate::Dataset;
|
||||
use crate::transform::{RngSource, SizeConfig};
|
||||
use rand::prelude::SliceRandom;
|
||||
use rand::{RngExt, distr::Uniform, rngs::StdRng, seq::IteratorRandom};
|
||||
use std::{marker::PhantomData, ops::DerefMut, sync::Mutex};
|
||||
|
||||
/// Options to configure a [SamplerDataset].
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct SamplerDatasetOptions {
|
||||
/// The sampling mode.
|
||||
pub replace_samples: bool,
|
||||
|
||||
/// The size source of the wrapper relative to the dataset.
|
||||
pub size_config: SizeConfig,
|
||||
|
||||
/// The source of the random number generator.
|
||||
pub rng_source: RngSource,
|
||||
}
|
||||
|
||||
impl Default for SamplerDatasetOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
replace_samples: true,
|
||||
size_config: SizeConfig::Default,
|
||||
rng_source: RngSource::Default,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<Option<T>> for SamplerDatasetOptions
|
||||
where
|
||||
T: Into<SamplerDatasetOptions>,
|
||||
{
|
||||
fn from(option: Option<T>) -> Self {
|
||||
match option {
|
||||
Some(option) => option.into(),
|
||||
None => Self::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<usize> for SamplerDatasetOptions {
|
||||
fn from(size: usize) -> Self {
|
||||
Self::default().with_replacement().with_fixed_size(size)
|
||||
}
|
||||
}
|
||||
|
||||
impl SamplerDatasetOptions {
|
||||
/// Set the replacement mode.
|
||||
pub fn with_replace_samples(self, replace_samples: bool) -> Self {
|
||||
Self {
|
||||
replace_samples,
|
||||
..self
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the replacement mode to WithReplacement.
|
||||
pub fn with_replacement(self) -> Self {
|
||||
self.with_replace_samples(true)
|
||||
}
|
||||
|
||||
/// Set the replacement mode to WithoutReplacement.
|
||||
pub fn without_replacement(self) -> Self {
|
||||
self.with_replace_samples(false)
|
||||
}
|
||||
|
||||
/// Set the size source.
|
||||
pub fn with_size<S>(self, source: S) -> Self
|
||||
where
|
||||
S: Into<SizeConfig>,
|
||||
{
|
||||
Self {
|
||||
size_config: source.into(),
|
||||
..self
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the size to the size of the source.
|
||||
pub fn with_source_size(self) -> Self {
|
||||
self.with_size(SizeConfig::Default)
|
||||
}
|
||||
|
||||
/// Set the size to a fixed size.
|
||||
pub fn with_fixed_size(self, size: usize) -> Self {
|
||||
self.with_size(size)
|
||||
}
|
||||
|
||||
/// Set the size to be a multiple of the ration and the source size.
|
||||
pub fn with_size_ratio(self, size_ratio: f64) -> Self {
|
||||
self.with_size(size_ratio)
|
||||
}
|
||||
|
||||
/// Set the `RngSource`.
|
||||
pub fn with_rng<R>(self, rng: R) -> Self
|
||||
where
|
||||
R: Into<RngSource>,
|
||||
{
|
||||
Self {
|
||||
rng_source: rng.into(),
|
||||
..self
|
||||
}
|
||||
}
|
||||
|
||||
/// Use the system rng.
|
||||
pub fn with_system_rng(self) -> Self {
|
||||
self.with_rng(RngSource::Default)
|
||||
}
|
||||
|
||||
/// Use a rng, built from a seed.
|
||||
pub fn with_seed(self, seed: u64) -> Self {
|
||||
self.with_rng(seed)
|
||||
}
|
||||
}
|
||||
|
||||
/// Sample items from a dataset.
///
/// This is a convenient way of modeling a dataset as a probability distribution of a fixed size.
/// You have multiple options to instantiate the dataset sampler.
///
/// * With replacement (Default): This is the most efficient way of using the sampler because no state is
///   required to keep indices that have been selected.
///
/// * Without replacement: This has a similar effect to using a
///   [shuffled dataset](crate::transform::ShuffledDataset), but with more flexibility since you can
///   set the dataset to an arbitrary size. Once every item has been used, a new cycle is
///   created with a new random shuffle.
pub struct SamplerDataset<D, I> {
    /// The wrapped source dataset items are drawn from.
    dataset: D,
    /// The length reported by this sampler.
    size: usize,
    /// The sampling state (rng and, without replacement, the pending indices),
    /// behind a mutex because sampling happens through `&self`.
    state: Mutex<SamplerState>,
    /// Marker for the item type.
    input: PhantomData<I>,
}
|
||||
/// Internal sampling state of a [SamplerDataset].
enum SamplerState {
    /// Sampling with replacement: only the rng is needed, each index is drawn independently.
    WithReplacement(StdRng),
    /// Sampling without replacement: the rng plus the indices still to be handed out.
    /// The index buffer is refilled and shuffled whenever it runs empty.
    WithoutReplacement(StdRng, Vec<usize>),
}
|
||||
|
||||
impl<D, I> SamplerDataset<D, I>
|
||||
where
|
||||
D: Dataset<I>,
|
||||
I: Send + Sync,
|
||||
{
|
||||
/// Creates a new sampler dataset with replacement.
|
||||
///
|
||||
/// When the sample size is less than or equal to the source dataset size,
|
||||
/// data will be sampled without replacement from the source dataset in
|
||||
/// a uniformly shuffled order.
|
||||
///
|
||||
/// When the sample size is greater than the source dataset size,
|
||||
/// the entire source dataset will be sampled once for every multiple
|
||||
/// of the size ratios; with the remaining samples taken without replacement
|
||||
/// uniformly from the source. All samples will be returned uniformly shuffled.
|
||||
///
|
||||
/// ## Arguments
|
||||
///
|
||||
/// * `dataset`: the dataset to wrap.
|
||||
/// * `options`: the options to configure the sampler dataset.
|
||||
///
|
||||
/// ## Examples
|
||||
/// ```rust,ignore
|
||||
/// use burn_dataset::transform::{
|
||||
/// SamplerDataset,
|
||||
/// SamplerDatasetOptions,
|
||||
/// };
|
||||
///
|
||||
/// // Examples below assuming `dataset.len()` = `10`.
|
||||
///
|
||||
/// // sample size: 5
|
||||
/// // WithReplacement
|
||||
/// // rng: StdRng::from_os_rng()
|
||||
/// SamplerDataset::new(dataset, 5);
|
||||
///
|
||||
/// // sample size: 10 (source)
|
||||
/// // WithReplacement
|
||||
/// // rng: StdRng::from_os_rng()
|
||||
/// SamplerDataset::new(dataset, SamplerDatasetOptions::default());
|
||||
///
|
||||
/// // sample size: 15
|
||||
/// // WithoutReplacement
|
||||
/// // rng: StdRng::seed_from_u64(42)
|
||||
/// SamplerDataset::new(
|
||||
/// dataset,
|
||||
/// SamplerDatasetOptions::default()
|
||||
/// .with_size(1.5)
|
||||
/// .without_replacement()
|
||||
/// .with_rng(42),
|
||||
/// );
|
||||
/// ```
|
||||
pub fn new<O>(dataset: D, options: O) -> Self
|
||||
where
|
||||
O: Into<SamplerDatasetOptions>,
|
||||
{
|
||||
let options = options.into();
|
||||
let size = options.size_config.resolve(dataset.len());
|
||||
let rng = options.rng_source.into();
|
||||
Self {
|
||||
dataset,
|
||||
size,
|
||||
state: Mutex::new(match options.replace_samples {
|
||||
true => SamplerState::WithReplacement(rng),
|
||||
false => SamplerState::WithoutReplacement(rng, Vec::with_capacity(size)),
|
||||
}),
|
||||
input: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new sampler dataset with replacement.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `dataset`: the dataset to wrap.
|
||||
/// - `size`: the effective size of the sampled dataset.
|
||||
pub fn with_replacement(dataset: D, size: usize) -> Self {
|
||||
Self::new(
|
||||
dataset,
|
||||
SamplerDatasetOptions::default()
|
||||
.with_replacement()
|
||||
.with_fixed_size(size),
|
||||
)
|
||||
}
|
||||
|
||||
/// Creates a new sampler dataset without replacement.
|
||||
///
|
||||
/// When the sample size is less than or equal to the source dataset size,
|
||||
/// data will be sampled without replacement from the source dataset in
|
||||
/// a uniformly shuffled order.
|
||||
///
|
||||
/// When the sample size is greater than the source dataset size,
|
||||
/// the entire source dataset will be sampled once for every multiple
|
||||
/// of the size ratios; with the remaining samples taken without replacement
|
||||
/// uniformly from the source. All samples will be returned uniformly shuffled.
|
||||
///
|
||||
/// # Arguments
|
||||
/// - `dataset`: the dataset to wrap.
|
||||
/// - `size`: the effective size of the sampled dataset.
|
||||
pub fn without_replacement(dataset: D, size: usize) -> Self {
|
||||
Self::new(
|
||||
dataset,
|
||||
SamplerDatasetOptions::default()
|
||||
.without_replacement()
|
||||
.with_fixed_size(size),
|
||||
)
|
||||
}
|
||||
|
||||
/// Determines if the sampler is using the "with replacement" strategy.
|
||||
///
|
||||
/// # Returns
|
||||
/// - `true`: If the sampler is configured to sample with replacement.
|
||||
/// - `false`: If the sampler is configured to sample without replacement.
|
||||
pub fn is_with_replacement(&self) -> bool {
|
||||
match self.state.lock().unwrap().deref_mut() {
|
||||
SamplerState::WithReplacement(_) => true,
|
||||
SamplerState::WithoutReplacement(_, _) => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn index(&self) -> usize {
|
||||
match self.state.lock().unwrap().deref_mut() {
|
||||
SamplerState::WithReplacement(rng) => {
|
||||
rng.sample(Uniform::new(0, self.dataset.len()).unwrap())
|
||||
}
|
||||
SamplerState::WithoutReplacement(rng, indices) => {
|
||||
if indices.is_empty() {
|
||||
// Refill the state.
|
||||
let idx_range = 0..self.dataset.len();
|
||||
for _ in 0..(self.size / self.dataset.len()) {
|
||||
// No need to `.choose_multiple` here because we're using
|
||||
// the entire source range; and `.choose_multiple` will
|
||||
// not return a random sample anyway.
|
||||
indices.extend(idx_range.clone())
|
||||
}
|
||||
|
||||
// From `choose_multiple` documentation:
|
||||
// > Although the elements are selected randomly, the order of elements in
|
||||
// > the buffer is neither stable nor fully random. If random ordering is
|
||||
// > desired, shuffle the result.
|
||||
indices.extend(idx_range.sample(rng, self.size - indices.len()));
|
||||
|
||||
// The real shuffling is done here.
|
||||
indices.shuffle(rng);
|
||||
}
|
||||
|
||||
indices.pop().expect("Indices are refilled when empty.")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, I> Dataset<I> for SamplerDataset<D, I>
|
||||
where
|
||||
D: Dataset<I>,
|
||||
I: Send + Sync,
|
||||
{
|
||||
fn get(&self, index: usize) -> Option<I> {
|
||||
if index >= self.size {
|
||||
return None;
|
||||
}
|
||||
|
||||
self.dataset.get(self.index())
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FakeDataset;
    use rand::SeedableRng;
    use std::collections::HashMap;

    /// Walks through every builder method of `SamplerDatasetOptions`.
    #[test]
    fn test_samplerdataset_options() {
        let opts = SamplerDatasetOptions::default();
        assert!(opts.replace_samples);
        assert_eq!(opts.size_config, SizeConfig::Default);
        assert_eq!(opts.rng_source, RngSource::Default);

        // ReplacementMode
        let opts = opts.with_replace_samples(false);
        assert!(!opts.replace_samples);
        let opts = opts.with_replacement();
        assert!(opts.replace_samples);
        let opts = opts.without_replacement();
        assert!(!opts.replace_samples);

        // SourceSize
        let opts = opts.with_size(SizeConfig::Default);
        assert_eq!(opts.size_config, SizeConfig::Default);
        let opts = opts.with_source_size();
        assert_eq!(opts.size_config, SizeConfig::Default);
        let opts = opts.with_fixed_size(10);
        assert_eq!(opts.size_config, SizeConfig::Fixed(10));
        let opts = opts.with_size_ratio(1.5);
        assert_eq!(opts.size_config, SizeConfig::Ratio(1.5));

        // RngSource
        let opts = opts.with_system_rng();
        assert_eq!(opts.rng_source, RngSource::Default);
        let opts = opts.with_seed(42);
        assert_eq!(opts.rng_source, RngSource::Seed(42));
        let opts = opts.with_rng(StdRng::seed_from_u64(9));
        assert!(matches!(opts.rng_source, RngSource::Rng(_)));
    }

    /// Each constructor must report the requested size and replacement mode.
    #[test]
    fn sampler_dataset_constructors_test() {
        let sampler = SamplerDataset::new(FakeDataset::<u32>::new(10), 15);
        assert_eq!(sampler.len(), 15);
        assert_eq!(sampler.dataset.len(), 10);
        assert!(sampler.is_with_replacement());

        let sampler = SamplerDataset::with_replacement(FakeDataset::<u32>::new(10), 15);
        assert_eq!(sampler.len(), 15);
        assert_eq!(sampler.dataset.len(), 10);
        assert!(sampler.is_with_replacement());

        let sampler = SamplerDataset::without_replacement(FakeDataset::<u32>::new(10), 15);
        assert_eq!(sampler.len(), 15);
        assert_eq!(sampler.dataset.len(), 10);
        assert!(!sampler.is_with_replacement());
    }

    /// Iterating a sampler yields exactly as many items as its configured size.
    #[test]
    fn sampler_dataset_with_replacement_iter() {
        let factor = 3;
        let len_original = 10;
        let dataset_sampler = SamplerDataset::with_replacement(
            FakeDataset::<String>::new(len_original),
            len_original * factor,
        );

        let total = dataset_sampler.iter().count();

        assert_eq!(total, factor * len_original);
    }

    /// Without replacement and with a whole-number ratio, every source item
    /// must be seen exactly `factor` times.
    #[test]
    fn sampler_dataset_without_replacement_bucket_test() {
        let factor = 3;
        let len_original = 10;

        let dataset_sampler = SamplerDataset::new(
            FakeDataset::<String>::new(len_original),
            SamplerDatasetOptions::default()
                .without_replacement()
                .with_size_ratio(factor as f64),
        );

        // Count the occurrences of each distinct item.
        let mut buckets = HashMap::new();
        for item in dataset_sampler.iter() {
            *buckets.entry(item).or_insert(0) += 1;
        }

        let mut total = 0;
        for count in buckets.into_values() {
            assert_eq!(count, factor);
            total += count;
        }
        assert_eq!(total, factor * len_original);
    }

    #[test]
    fn sampler_dataset_without_replacement_uniform_order_test() {
        // This is a regression test on the indices.shuffle(rng) call in SamplerDataset::index().
        let size = 1000;
        let dataset_sampler =
            SamplerDataset::without_replacement(FakeDataset::<i32>::new(size), size);

        let indices: Vec<_> = (0..size).map(|_| dataset_sampler.index()).collect();

        // Mean absolute distance between consecutive indices.
        let total_delta: usize = indices
            .windows(2)
            .map(|pair| pair[1].abs_diff(pair[0]))
            .sum();
        let mean_delta = total_delta as f64 / (size - 1) as f64;

        let expected = (size + 2) as f64 / 3.0;

        assert!(
            (mean_delta - expected).abs() <= 0.25 * expected,
            "Sampled indices are not uniformly distributed: mean_delta: {mean_delta}, expected: {expected}"
        );
    }
}
|
||||
@@ -0,0 +1,374 @@
|
||||
use crate::Dataset;
|
||||
use crate::transform::RngSource;
|
||||
use rand::prelude::SliceRandom;
|
||||
use rand::rngs::StdRng;
|
||||
use std::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Builds the identity index vector `[0, 1, ..., size - 1]`.
///
/// # Arguments
///
/// * `size` - The number of indices to generate (the size of the dataset).
///
/// # Returns
///
/// A vector with `size` elements in increasing order; empty when `size` is 0.
#[inline(always)]
pub fn iota(size: usize) -> Vec<usize> {
    Vec::from_iter(0..size)
}
|
||||
|
||||
/// Generates a shuffled vector of indices up to a size.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `size` - The size of the dataset to shuffle.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A vector of shuffled indices.
|
||||
#[inline(always)]
|
||||
pub fn shuffled_indices(size: usize, rng: &mut StdRng) -> Vec<usize> {
|
||||
let mut indices = iota(size);
|
||||
indices.shuffle(rng);
|
||||
indices
|
||||
}
|
||||
|
||||
/// A dataset that selects a subset of indices from an existing dataset.
///
/// Item `i` of the selection is item `indices[i]` of the wrapped dataset.
/// Indices may appear multiple times, but they must be within the bounds of the original dataset.
#[derive(Clone)]
pub struct SelectionDataset<D, I>
where
    D: Dataset<I>,
    I: Clone + Send + Sync,
{
    /// The wrapped dataset from which to select indices.
    pub wrapped: Arc<D>,

    /// The indices to select from the wrapped dataset.
    pub indices: Vec<usize>,

    /// Marker for the item type.
    input: PhantomData<I>,
}
|
||||
|
||||
impl<D, I> SelectionDataset<D, I>
|
||||
where
|
||||
D: Dataset<I>,
|
||||
I: Clone + Send + Sync,
|
||||
{
|
||||
/// Creates a new selection dataset with the given dataset and indices.
|
||||
///
|
||||
/// Checks that all indices are within the bounds of the dataset.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dataset` - The original dataset to select from.
|
||||
/// * `indices` - A slice of indices to select from the dataset.
|
||||
/// These indices must be within the bounds of the dataset.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if any index is out of bounds for the dataset.
|
||||
pub fn from_indices_checked<S>(dataset: S, indices: Vec<usize>) -> Self
|
||||
where
|
||||
S: Into<Arc<D>>,
|
||||
{
|
||||
let dataset = dataset.into();
|
||||
|
||||
let size = dataset.len();
|
||||
if let Some(idx) = indices.iter().find(|&i| *i >= size) {
|
||||
panic!("Index out of bounds for wrapped dataset size: {idx} >= {size}");
|
||||
}
|
||||
|
||||
Self::from_indices_unchecked(dataset, indices)
|
||||
}
|
||||
|
||||
/// Creates a new selection dataset with the given dataset and indices without checking bounds.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dataset` - The original dataset to select from.
|
||||
/// * `indices` - A vector of indices to select from the dataset.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// This function does not check if the indices are within the bounds of the dataset.
|
||||
pub fn from_indices_unchecked<S>(dataset: S, indices: Vec<usize>) -> Self
|
||||
where
|
||||
S: Into<Arc<D>>,
|
||||
{
|
||||
Self {
|
||||
wrapped: dataset.into(),
|
||||
indices,
|
||||
input: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new selection dataset that selects all indices from the dataset.
|
||||
///
|
||||
/// This allocates a 1-to-1 mapping of indices to the dataset size,
|
||||
/// essentially functioning as a no-op selection. This is only useful
|
||||
/// when the dataset will later be shuffled or transformed in place.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dataset` - The original dataset to select from.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new `SelectionDataset` that selects all indices from the dataset.
|
||||
pub fn new_select_all<S>(dataset: S) -> Self
|
||||
where
|
||||
S: Into<Arc<D>>,
|
||||
{
|
||||
let dataset = dataset.into();
|
||||
let size = dataset.len();
|
||||
Self::from_indices_unchecked(dataset, iota(size))
|
||||
}
|
||||
|
||||
/// Creates a new selection dataset with shuffled indices.
|
||||
///
|
||||
/// Selects every index of the dataset and shuffles them
|
||||
/// with randomness from the provided random number generator.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dataset` - The original dataset to select from.
|
||||
/// * `rng` - A mutable reference to a random number generator.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new `SelectionDataset` with shuffled indices.
|
||||
pub fn new_shuffled<S, R>(dataset: S, rng_source: R) -> Self
|
||||
where
|
||||
S: Into<Arc<D>>,
|
||||
R: Into<RngSource>,
|
||||
{
|
||||
let mut this = Self::new_select_all(dataset);
|
||||
this.shuffle(rng_source);
|
||||
this
|
||||
}
|
||||
|
||||
/// Shuffles the indices of the dataset using a mutable random number generator.
|
||||
///
|
||||
/// This method modifies the dataset in place, shuffling the indices.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `rng` - A mutable reference to a random number generator.
|
||||
pub fn shuffle<R>(&mut self, rng_source: R)
|
||||
where
|
||||
R: Into<RngSource>,
|
||||
{
|
||||
let mut rng: StdRng = rng_source.into().into();
|
||||
self.indices.shuffle(&mut rng)
|
||||
}
|
||||
|
||||
/// Creates a new dataset that is a slice of the current selection dataset.
|
||||
///
|
||||
/// Slices the *selection indices* from ``[start..end]``.
|
||||
///
|
||||
/// Independent of future shuffles on the parent, but shares the same wrapped dataset.
|
||||
///
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `start` - The start of the range.
|
||||
/// * `end` - The end of the range (exclusive).
|
||||
// TODO: SliceArg in burn-tensor should be lifted to burn-std; this should use SliceArg.
|
||||
pub fn slice(&self, start: usize, end: usize) -> Self {
|
||||
Self::from_indices_unchecked(self.wrapped.clone(), self.indices[start..end].to_vec())
|
||||
}
|
||||
|
||||
/// Split into `num` datasets by slicing the selection indices evenly.
|
||||
///
|
||||
/// Split is done via `slice`, so the datasets share the same wrapped dataset.
|
||||
///
|
||||
/// Independent of future shuffles on the parent, but shares the same wrapped dataset.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `num` - The number of datasets to split into.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A vector of `SelectionDataset` instances, each containing a subset of the indices.
|
||||
pub fn split(&self, num: usize) -> Vec<Self> {
|
||||
let n = self.indices.len();
|
||||
|
||||
let mut current = 0;
|
||||
let mut datasets = Vec::with_capacity(num);
|
||||
|
||||
let batch_size = n / num;
|
||||
for i in 0..num {
|
||||
let start = current;
|
||||
let mut end = current + batch_size;
|
||||
|
||||
if i == (num - 1) {
|
||||
end = n;
|
||||
}
|
||||
|
||||
let dataset = self.slice(start, end);
|
||||
|
||||
current += batch_size;
|
||||
datasets.push(dataset);
|
||||
}
|
||||
|
||||
datasets
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, I> Dataset<I> for SelectionDataset<D, I>
|
||||
where
|
||||
D: Dataset<I>,
|
||||
I: Clone + Send + Sync,
|
||||
{
|
||||
fn get(&self, index: usize) -> Option<I> {
|
||||
let index = self.indices.get(index)?;
|
||||
self.wrapped.get(*index)
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.indices.len()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FakeDataset;
    use rand::SeedableRng;

    #[test]
    fn test_iota() {
        let size = 10;
        let identity = iota(size);
        assert_eq!(identity.len(), size);
        assert_eq!(identity, vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[test]
    fn test_shuffled_indices_same_seed_is_deterministic() {
        let size = 10;

        // `StdRng` is no longer `Clone`, so its internal state cannot be duplicated.
        // To test determinism, we must explicitly create two RNGs from the same seed.
        let mut reference_rng = StdRng::seed_from_u64(10);
        let mut tested_rng = StdRng::seed_from_u64(10);

        let mut expected = iota(size);
        expected.shuffle(&mut reference_rng);

        assert_eq!(shuffled_indices(size, &mut tested_rng), expected);
    }

    #[test]
    fn test_shuffled_indices_forked_rngs_differ() {
        let size = 10;

        let mut parent_rng = StdRng::seed_from_u64(10);
        let mut forked_rng = parent_rng.fork();

        let mut from_parent = iota(size);
        from_parent.shuffle(&mut parent_rng);

        let mut from_fork = iota(size);
        from_fork.shuffle(&mut forked_rng);

        assert_ne!(from_parent, from_fork);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds for wrapped dataset size: 300 >= 27")]
    fn test_from_indices_checked_panics() {
        // 300 is the only index outside of the 27-item source.
        let out_of_bounds: Vec<usize> = vec![15, 1, 12, 300];
        SelectionDataset::from_indices_checked(FakeDataset::<String>::new(27), out_of_bounds);
    }

    #[test]
    fn test_checked_selection_dataset() {
        let source_dataset = FakeDataset::<String>::new(27);

        // Duplicated indices are allowed.
        let indices: Vec<usize> = vec![15, 1, 12, 12];
        let mut expected: Vec<String> = Vec::with_capacity(indices.len());
        for &i in &indices {
            expected.push(source_dataset.get(i).unwrap());
        }

        let selection = SelectionDataset::from_indices_checked(source_dataset, indices.clone());
        assert_eq!(&selection.indices, &indices);

        let items: Vec<_> = selection.iter().collect();
        assert_eq!(items, expected);
    }

    #[test]
    fn test_shuffled_dataset() {
        let dataset = FakeDataset::<String>::new(27);
        let source_items: Vec<_> = dataset.iter().collect();

        let selection = SelectionDataset::new_shuffled(dataset, 42);

        // The same seed must reproduce the same permutation.
        let mut rng = StdRng::seed_from_u64(42);
        let indices = shuffled_indices(source_items.len(), &mut rng);

        assert_eq!(&selection.indices, &indices);
        assert_eq!(selection.len(), source_items.len());

        let expected_items: Vec<_> = indices
            .iter()
            .map(|&i| source_items[i].to_string())
            .collect();
        assert_eq!(&selection.iter().collect::<Vec<_>>(), &expected_items);
    }

    #[test]
    fn test_slice() {
        let dataset = FakeDataset::<String>::new(27);
        let source_items: Vec<_> = dataset.iter().collect();

        let selection = SelectionDataset::new_select_all(dataset);

        let (start, end) = (5, 15);
        let sliced_selection = selection.slice(start, end);

        assert_eq!(sliced_selection.len(), end - start);

        // Entry `offset` of the slice is entry `start + offset` of the source.
        for (offset, expected) in source_items[start..end].iter().enumerate() {
            assert_eq!(sliced_selection.get(offset), Some(expected.to_string()));
        }
    }

    #[test]
    fn test_split() {
        let dataset = FakeDataset::<String>::new(28);
        let source_items: Vec<_> = dataset.iter().collect();

        let selection = SelectionDataset::new_select_all(dataset);

        let mut split_contents: Vec<Vec<_>> = Vec::new();
        for part in selection.split(3).iter() {
            split_contents.push(part.iter().collect::<Vec<_>>());
        }

        // 28 / 3 = 9 per part; the last part also takes the remaining item.
        assert_eq!(
            split_contents,
            vec![
                source_items[0..9].to_vec(),
                source_items[9..18].to_vec(),
                source_items[18..28].to_vec(),
            ]
        );
    }
}
|
||||
@@ -0,0 +1,109 @@
|
||||
use crate::Dataset;
|
||||
use crate::transform::{RngSource, SelectionDataset};
|
||||
|
||||
/// A shuffled dataset.
///
/// This is a thin wrapper around a [SelectionDataset] which selects and shuffles
/// the full indices of the original dataset.
///
/// Consider using [SelectionDataset] if you are only interested in
/// shuffling mechanisms.
///
/// Consider using [sampler dataset](crate::transform::SamplerDataset) if you
/// want a probability distribution which is computed lazily.
pub struct ShuffledDataset<D, I>
where
    D: Dataset<I>,
    I: Clone + Send + Sync,
{
    /// The shuffled selection every access is delegated to.
    wrapped: SelectionDataset<D, I>,
}
|
||||
|
||||
impl<D, I> ShuffledDataset<D, I>
|
||||
where
|
||||
D: Dataset<I>,
|
||||
I: Clone + Send + Sync,
|
||||
{
|
||||
/// Creates a new selection dataset with shuffled indices.
|
||||
///
|
||||
/// This is a thin wrapper around `SelectionDataset::new_shuffled`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dataset` - The original dataset to select from.
|
||||
/// * `rng_source` - The source of the random number generator.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new `ShuffledDataset`.
|
||||
pub fn new<R>(dataset: D, rng_source: R) -> Self
|
||||
where
|
||||
R: Into<RngSource>,
|
||||
{
|
||||
Self {
|
||||
wrapped: SelectionDataset::new_shuffled(dataset, rng_source),
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new selection dataset with shuffled indices using a fixed seed.
|
||||
///
|
||||
/// This is a thin wrapper around `SelectionDataset::new_shuffled_with_seed`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dataset` - The original dataset to select from.
|
||||
/// * `seed` - A fixed seed for the random number generator.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new `ShuffledDataset`.
|
||||
#[deprecated(since = "0.19.0", note = "Use `new(dataset, seed)` instead`")]
|
||||
pub fn with_seed(dataset: D, seed: u64) -> Self {
|
||||
Self::new(dataset, seed)
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, I> Dataset<I> for ShuffledDataset<D, I>
|
||||
where
|
||||
D: Dataset<I>,
|
||||
I: Clone + Send + Sync,
|
||||
{
|
||||
fn get(&self, index: usize) -> Option<I> {
|
||||
self.wrapped.get(index)
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.wrapped.len()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FakeDataset;
    use crate::transform::selection::shuffled_indices;
    use rand::SeedableRng;
    use rand::prelude::StdRng;

    /// A seeded `ShuffledDataset` must order its items like `shuffled_indices`
    /// driven by a rng built from the same seed.
    #[test]
    fn test_shuffled_dataset() {
        let seed = 42;

        let dataset = FakeDataset::<String>::new(27);
        let source_items: Vec<_> = dataset.iter().collect();

        #[allow(deprecated)]
        let shuffled = ShuffledDataset::with_seed(dataset, seed);

        let indices = shuffled_indices(source_items.len(), &mut StdRng::seed_from_u64(seed));

        assert_eq!(shuffled.len(), source_items.len());

        let mut expected_items = Vec::with_capacity(indices.len());
        for &i in &indices {
            expected_items.push(source_items[i].to_string());
        }
        assert_eq!(&shuffled.iter().collect::<Vec<_>>(), &expected_items);
    }
}
|
||||
@@ -0,0 +1,290 @@
|
||||
use std::{cmp::max, marker::PhantomData, num::NonZeroUsize};
|
||||
|
||||
use crate::Dataset;
|
||||
|
||||
/// Functionality to create a window.
pub trait Window<I> {
    /// Creates a window of a collection.
    ///
    /// # Arguments
    ///
    /// * `current` - The position at which the window starts.
    /// * `size` - The number of items in the window; never zero.
    ///
    /// # Returns
    ///
    /// A `Vec<I>` representing the window, or `None` when no window is
    /// available for the request.
    fn window(&self, current: usize, size: NonZeroUsize) -> Option<Vec<I>>;
}
|
||||
|
||||
impl<I, T: Dataset<I> + ?Sized> Window<I> for T {
|
||||
fn window(&self, current: usize, size: NonZeroUsize) -> Option<Vec<I>> {
|
||||
(current..current + size.get())
|
||||
.map(|x| self.get(x))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Functionality to create a `WindowsIterator`.
pub trait Windows<I> {
    /// Creates and returns an iterator over all the windows of length `size`.
    ///
    /// # Arguments
    ///
    /// * `size` - The number of items in each window.
    fn windows(&self, size: usize) -> WindowsIterator<'_, I>;
}
|
||||
|
||||
impl<I, T: Dataset<I>> Windows<I> for T {
|
||||
/// Is empty if the `Dataset` is shorter than `size`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `size` is 0.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use crate::burn_dataset::{
|
||||
/// transform::{Windows, WindowsDataset},
|
||||
/// Dataset, InMemDataset,
|
||||
/// };
|
||||
///
|
||||
/// let items = [1, 2, 3, 4].to_vec();
|
||||
/// let dataset = InMemDataset::new(items.clone());
|
||||
///
|
||||
/// for window in dataset.windows(2) {
|
||||
/// // do sth with window
|
||||
/// }
|
||||
/// ```
|
||||
fn windows(&self, size: usize) -> WindowsIterator<'_, I> {
|
||||
let size = NonZeroUsize::new(size).expect("window size must be non-zero");
|
||||
WindowsIterator::new(self, size)
|
||||
}
|
||||
}
|
||||
|
||||
/// Overlapping windows iterator.
///
/// Yields the windows starting at positions 0, 1, 2, ... of the dataset.
pub struct WindowsIterator<'a, I> {
    /// The size of the windows.
    pub size: NonZeroUsize,
    /// The start position of the next window to yield.
    current: usize,
    /// The dataset the windows are read from.
    dataset: &'a dyn Dataset<I>,
}
|
||||
|
||||
impl<'a, I> WindowsIterator<'a, I> {
|
||||
/// Creates a new `WindowsIterator` instance. The windows overlap.
|
||||
/// Is empty if the input `Dataset` is shorter than `size`.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `dataset`: The dataset over which windows will be created.
|
||||
/// - `size`: The size of the windows.
|
||||
pub fn new(dataset: &'a dyn Dataset<I>, size: NonZeroUsize) -> Self {
|
||||
WindowsIterator {
|
||||
current: 0,
|
||||
dataset,
|
||||
size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> Iterator for WindowsIterator<'_, I> {
|
||||
type Item = Vec<I>;
|
||||
|
||||
fn next(&mut self) -> Option<Vec<I>> {
|
||||
self.current += 1;
|
||||
self.dataset.window(self.current - 1, self.size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> Clone for WindowsIterator<'_, I> {
|
||||
fn clone(&self) -> Self {
|
||||
WindowsIterator {
|
||||
size: self.size,
|
||||
dataset: self.dataset,
|
||||
current: self.current,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Dataset designed to work with overlapping windows of data.
///
/// Item `i` of this dataset is the window of `size` consecutive items of the
/// wrapped dataset starting at position `i`.
pub struct WindowsDataset<D, I> {
    /// The size of the windows.
    pub size: NonZeroUsize,
    /// The wrapped dataset the windows are read from.
    dataset: D,
    /// Marker for the item type of the wrapped dataset.
    input: PhantomData<I>,
}
|
||||
|
||||
impl<D, I> WindowsDataset<D, I>
|
||||
where
|
||||
D: Dataset<I>,
|
||||
{
|
||||
/// Creates a new `WindowsDataset` instance. The windows overlap.
|
||||
/// Is empty if the input `Dataset` is shorter than `size`.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `dataset`: The dataset over which windows will be created.
|
||||
/// - `size`: The size of the windows.
|
||||
pub fn new(dataset: D, size: usize) -> Self
|
||||
where
|
||||
D:,
|
||||
{
|
||||
let size = NonZeroUsize::new(size).expect("window size must be non-zero");
|
||||
WindowsDataset::<D, I> {
|
||||
size,
|
||||
dataset,
|
||||
input: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, I> Dataset<Vec<I>> for WindowsDataset<D, I>
|
||||
where
|
||||
D: Dataset<I>,
|
||||
I: Send + Sync,
|
||||
{
|
||||
/// Retrieves a window of items from the dataset.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `index`: The index of the window.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A vector representing the window.
|
||||
fn get(&self, index: usize) -> Option<Vec<I>> {
|
||||
self.dataset.window(index, self.size)
|
||||
}
|
||||
|
||||
/// Retrieves the number of windows in the dataset.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A size representing the number of windows.
|
||||
fn len(&self) -> usize {
|
||||
let len = self.dataset.len() as isize - self.size.get() as isize + 1;
|
||||
max(len, 0) as usize
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use crate::{
        Dataset, InMemDataset,
        transform::{Windows, WindowsDataset},
    };

    /// Reference result: the windows of `items` as produced by `slice::windows`.
    fn slice_windows(items: &[i32], size: usize) -> Vec<Vec<i32>> {
        items.windows(size).map(<[i32]>::to_vec).collect()
    }

    /// The windows iterator must agree with `slice::windows`.
    #[rstest]
    pub fn windows_should_be_equal_to_vec_windows() {
        let items = vec![1, 2, 3, 4, 5];
        let expected = slice_windows(&items, 3);
        let dataset = InMemDataset::new(items);

        let result: Vec<Vec<i32>> = dataset.windows(3).collect();

        assert_eq!(result, expected);
    }

    /// Iterating a `WindowsDataset` must agree with `slice::windows`.
    #[rstest]
    pub fn windows_dataset_should_be_equal_to_vec_windows() {
        let items = vec![1, 2, 3, 4, 5];
        let expected = slice_windows(&items, 3);
        let dataset = InMemDataset::new(items);

        let result = WindowsDataset::new(dataset, 3)
            .iter()
            .collect::<Vec<Vec<i32>>>();

        assert_eq!(result, expected);
    }

    /// A clone shares the dataset and starts from the same state.
    #[rstest]
    pub fn cloned_iterator_should_be_equal() {
        let dataset = InMemDataset::new(vec![1, 2, 3, 4, 5]);
        let original = dataset.windows(4);

        let cloned = original.clone();

        assert!(std::ptr::eq(cloned.dataset, original.dataset));
        assert_eq!(cloned.size, original.size);
        assert_eq!(cloned.current, original.current);
    }

    /// Moving the original forward must not move its clone.
    #[rstest]
    pub fn cloned_iterator_should_be_unaffected() {
        let dataset = InMemDataset::new(vec![1, 2, 3, 4, 5]);
        let mut original = dataset.windows(4);

        let cloned = original.clone();
        original.current = 2;

        assert_ne!(cloned.current, original.current);
    }

    /// A zero window size is rejected by the iterator constructor.
    #[rstest]
    #[should_panic(expected = "window size must be non-zero")]
    pub fn windows_should_panic() {
        let dataset = InMemDataset::new(vec![1, 2]);

        let _ = dataset.windows(0);
    }

    /// A zero window size is rejected by the dataset constructor.
    #[rstest]
    #[should_panic(expected = "window size must be non-zero")]
    pub fn new_window_dataset_should_panic() {
        let dataset = InMemDataset::new(vec![1, 2]);

        let _ = WindowsDataset::new(dataset, 0);
    }

    /// Four items hold three windows of size two.
    #[rstest]
    pub fn window_dataset_len_should_be_equal() {
        let dataset = InMemDataset::new(vec![1, 2, 3, 4]);

        let result = WindowsDataset::new(dataset, 2).len();

        assert_eq!(result, 3);
    }

    /// No window is yielded when the dataset is shorter than the window.
    #[rstest]
    pub fn window_iterator_should_be_empty() {
        let dataset = InMemDataset::new(vec![1, 2]);
        let mut windows = dataset.windows(4).peekable();

        let first = windows.peek();

        assert_eq!(first, None);
    }

    /// The length is zero when the dataset is shorter than the window.
    #[rstest]
    pub fn window_dataset_len_should_be_zero() {
        let dataset = InMemDataset::new(vec![1, 2]);

        let result = WindowsDataset::new(dataset, 4).len();

        assert_eq!(result, 0);
    }

    /// The first window holds the first `size` items.
    #[rstest]
    pub fn window_dataset_get_should_be_equal() {
        let dataset = InMemDataset::new(vec![1, 2, 3, 4]);
        let expected = Some(vec![1, 2, 3]);

        let result = WindowsDataset::new(dataset, 3).get(0);

        assert_eq!(result, expected);
    }

    /// `get` returns `None` when the dataset is shorter than the window.
    #[rstest]
    pub fn window_dataset_get_should_be_none() {
        let dataset = InMemDataset::new(vec![1, 2]);

        let result = WindowsDataset::new(dataset, 4).get(0);

        assert_eq!(result, None);
    }
}
|
||||
@@ -0,0 +1,241 @@
|
||||
//! CIFAR Dataset Module
|
||||
//!
|
||||
//! This module provides functionality for loading the CIFAR-10 and CIFAR-100 image classification datasets.
|
||||
//! CIFAR (Canadian Institute For Advanced Research) datasets are widely used benchmarks in computer vision,
|
||||
//! consisting of 32×32 pixel color images split into training (50,000 images) and test (10,000 images) sets.
|
||||
//!
|
||||
//! ## Dataset Variants
|
||||
//! - **CIFAR-10**: Contains 10 distinct classes (e.g., airplane, automobile, bird, cat)
|
||||
//! - CIFAR-10 mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L44).
|
||||
//! - Licensed under the [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE).
|
||||
//! - **CIFAR-100**: Contains 100 fine-grained classes (e.g., beaver, dolphin, oak tree)
|
||||
//! - CIFAR-100 mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L75).
|
||||
//! - Licensed under the [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE).
|
||||
//!
|
||||
//! ## Usage Example
|
||||
//! ```rust
|
||||
//! use burn_dataset::vision::CifarDataset;
|
||||
//! use burn_dataset::vision::CifarType;
|
||||
//!
|
||||
//! // Create a CIFAR-10 dataset accessor
|
||||
//! let dataset = CifarDataset::new(CifarType::Cifar10);
|
||||
//!
|
||||
//! // Access training and test sets
|
||||
//! let train_dataset = dataset.train();
|
||||
//! let test_dataset = dataset.test();
|
||||
//! ```
|
||||
//! ```rust
|
||||
//! use burn_dataset::vision::CifarDataset;
|
||||
//! use burn_dataset::vision::CifarType;
|
||||
//!
|
||||
//! // Create a CIFAR-100 dataset accessor
|
||||
//! let dataset = CifarDataset::new(CifarType::Cifar100);
|
||||
//!
|
||||
//! // Access training and test sets
|
||||
//! let train_dataset = dataset.train();
|
||||
//! let test_dataset = dataset.test();
|
||||
//! ```
|
||||
|
||||
use std::{path::PathBuf, sync::Mutex};
|
||||
|
||||
use flate2::read::GzDecoder;
|
||||
use tar::Archive;
|
||||
|
||||
use crate::network::downloader;
|
||||
use crate::vision::ImageFolderDataset;
|
||||
|
||||
/// Download URL of the CIFAR-10 archive.
///
/// CIFAR-10 mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L44).
/// Licensed under the [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE).
const CIFAR10_URL: &str = "https://s3.amazonaws.com/fast-ai-sample/cifar10.tgz";

/// Download URL of the CIFAR-100 archive.
///
/// CIFAR-100 mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L75).
/// Licensed under the [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE).
const CIFAR100_URL: &str = "https://s3.amazonaws.com/fast-ai-imageclas/cifar100.tgz";
|
||||
|
||||
/// Enum representing the types of CIFAR datasets available.
///
/// CIFAR (Canadian Institute For Advanced Research) datasets are widely used benchmarks for image classification.
/// This enum provides support for the two main CIFAR datasets.
///
/// The variants carry no data, so the type is `Copy` and can be compared, hashed and used as
/// a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
// NOTE(review): the reason for allowing dead code on a public enum is not visible here —
// confirm whether the attribute is still needed.
#[allow(dead_code)]
pub enum CifarType {
    /// CIFAR-10 dataset containing 10 classes with 60,000 images in total.
    Cifar10,
    /// CIFAR-100 dataset containing 100 classes with 60,000 images in total.
    Cifar100,
}
|
||||
|
||||
/// Accessor for the CIFAR image classification datasets.
///
/// Gives access to CIFAR-10 and CIFAR-100. Creating an accessor downloads and extracts the
/// archive into the cache directory unless the dataset directory already exists there.
///
/// All images in CIFAR datasets are 32×32 pixel color images, with 50,000 images in the training set
/// and 10,000 images in the test set.
///
/// ## Differences between datasets
/// - **CIFAR-10**: Contains 10 mutually exclusive classes such as airplane, automobile, bird, cat, etc.
/// - **CIFAR-100**: Contains 100 fine-grained classes such as beaver, dolphin, etc.
pub struct CifarDataset {
    /// Directory of the extracted dataset; the `train` and `test` folders are read from it.
    cifar_dir: PathBuf,
}
|
||||
|
||||
impl CifarDataset {
|
||||
/// Creates a new CIFAR dataset accessor.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `cifar_type` - Specifies whether to use CIFAR-10 or CIFAR-100 dataset
|
||||
pub fn new(cifar_type: CifarType) -> Self {
|
||||
Self {
|
||||
cifar_dir: download(&cifar_type),
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the training dataset.
|
||||
///
|
||||
/// # Returns
|
||||
/// An `ImageFolderDataset` instance containing 50,000 training images
|
||||
pub fn train(&self) -> ImageFolderDataset {
|
||||
ImageFolderDataset::new_classification(self.cifar_dir.join("train")).unwrap()
|
||||
}
|
||||
|
||||
/// Gets the test dataset.
|
||||
///
|
||||
/// # Returns
|
||||
/// An `ImageFolderDataset` instance containing 10,000 test images
|
||||
pub fn test(&self) -> ImageFolderDataset {
|
||||
ImageFolderDataset::new_classification(self.cifar_dir.join("test")).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Lock serializing CIFAR downloads.
///
/// `download` holds it from start to finish, so only one thread at a time checks the cache
/// and downloads or extracts an archive.
static DOWNLOAD_LOCK: Mutex<()> = Mutex::new(());
|
||||
|
||||
fn download(cifar_type: &CifarType) -> PathBuf {
|
||||
// Acquire the lock. This will block if another thread already holds the lock.
|
||||
let _lock = DOWNLOAD_LOCK.lock().unwrap();
|
||||
|
||||
// Dataset files are stored in the burn-dataset cache directory
|
||||
let cache_dir = dirs::cache_dir()
|
||||
.expect("Could not get cache directory")
|
||||
.join("burn-dataset");
|
||||
|
||||
// Cifar store directory
|
||||
let cifar_dir = match cifar_type {
|
||||
CifarType::Cifar10 => cache_dir.join("cifar10"),
|
||||
CifarType::Cifar100 => cache_dir.join("cifar100"),
|
||||
};
|
||||
|
||||
// Cifar dataset url
|
||||
let url = match cifar_type {
|
||||
CifarType::Cifar10 => CIFAR10_URL,
|
||||
CifarType::Cifar100 => CIFAR100_URL,
|
||||
};
|
||||
|
||||
// Cifar dataset archive filename
|
||||
let filename = match cifar_type {
|
||||
CifarType::Cifar10 => "cifar10.tgz",
|
||||
CifarType::Cifar100 => "cifar100.tgz",
|
||||
};
|
||||
|
||||
// Check for already downloaded content
|
||||
if !cifar_dir.exists() {
|
||||
// Download gzip file
|
||||
let bytes = downloader::download_file_as_bytes(url, filename);
|
||||
|
||||
// Decode gzip file content and unpack archive
|
||||
let gz_buffer = GzDecoder::new(&bytes[..]);
|
||||
let mut archive = Archive::new(gz_buffer);
|
||||
archive.unpack(cache_dir).unwrap();
|
||||
}
|
||||
|
||||
cifar_dir
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Dataset, vision::Annotation};

    /// CIFAR dataset length
    const TRAINDATASET_LEN: usize = 50000;
    const TESTDATASET_LEN: usize = 10000;

    /// CIFAR-10 label range
    const CIFAR10_LABEL_MIN: usize = 0;
    const CIFAR10_LABEL_MAX: usize = 9;

    /// CIFAR-100 label range
    const CIFAR100_LABEL_MIN: usize = 0;
    const CIFAR100_LABEL_MAX: usize = 99;

    /// Checks that both splits of `cifar_type` have the expected number of images.
    fn assert_split_lens(cifar_type: CifarType) {
        let dataset = CifarDataset::new(cifar_type);
        let train_dataset = dataset.train();
        let test_dataset = dataset.test();

        assert_eq!(train_dataset.len(), TRAINDATASET_LEN);
        assert_eq!(test_dataset.len(), TESTDATASET_LEN);
    }

    /// Checks that the test split labels of `cifar_type` span exactly `expected_min..=expected_max`.
    fn assert_label_range(cifar_type: CifarType, expected_min: usize, expected_max: usize) {
        let dataset = CifarDataset::new(cifar_type);
        let test_dataset = dataset.test();

        let (min, max) = get_label_range(&test_dataset);

        assert_eq!(min, expected_min);
        assert_eq!(max, expected_max);
    }

    #[test]
    fn test_cifar10_download() {
        assert!(download(&CifarType::Cifar10).exists());
    }

    #[test]
    fn test_cifar100_download() {
        assert!(download(&CifarType::Cifar100).exists());
    }

    #[test]
    fn test_cifar10_len() {
        assert_split_lens(CifarType::Cifar10);
    }

    #[test]
    fn test_cifar100_len() {
        assert_split_lens(CifarType::Cifar100);
    }

    #[test]
    fn test_cifar10_label_range() {
        assert_label_range(CifarType::Cifar10, CIFAR10_LABEL_MIN, CIFAR10_LABEL_MAX);
    }

    #[test]
    fn test_cifar100_label_range() {
        assert_label_range(CifarType::Cifar100, CIFAR100_LABEL_MIN, CIFAR100_LABEL_MAX);
    }

    /// Returns the smallest and largest label index found in `dataset`.
    ///
    /// Annotations that are not a `Label` count as index 0. The search starts from
    /// `(128, 0)`, which is also the result for an empty dataset.
    fn get_label_range(dataset: &ImageFolderDataset) -> (usize, usize) {
        dataset
            .iter()
            .map(|item| match item.annotation {
                Annotation::Label(index) => index,
                _ => 0,
            })
            .fold((128, 0), |(min, max), index| {
                (min.min(index), max.max(index))
            })
    }
}
|
||||
@@ -0,0 +1,221 @@
|
||||
use std::fs::{File, create_dir_all};
|
||||
use std::io::{Read, Seek, SeekFrom};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use flate2::read::GzDecoder;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
Dataset, InMemDataset,
|
||||
transform::{Mapper, MapperDataset},
|
||||
};
|
||||
|
||||
use crate::network::downloader::download_file_as_bytes;
|
||||
|
||||
/// Base URL of the files: CVDF mirror of <http://yann.lecun.com/exdb/mnist/>.
const URL: &str = "https://storage.googleapis.com/cvdf-datasets/mnist/";
/// Name of the training images file; `.gz` is appended to it for the download.
const TRAIN_IMAGES: &str = "train-images-idx3-ubyte";
/// Name of the training labels file; `.gz` is appended to it for the download.
const TRAIN_LABELS: &str = "train-labels-idx1-ubyte";
/// Name of the test images file; `.gz` is appended to it for the download.
const TEST_IMAGES: &str = "t10k-images-idx3-ubyte";
/// Name of the test labels file; `.gz` is appended to it for the download.
const TEST_LABELS: &str = "t10k-labels-idx1-ubyte";

/// Width of an image, in pixels.
const WIDTH: usize = 28;
/// Height of an image, in pixels.
const HEIGHT: usize = 28;
|
||||
|
||||
/// MNIST item.
|
||||
#[derive(Deserialize, Serialize, Debug, Clone)]
|
||||
pub struct MnistItem {
|
||||
/// Image as a 2D array of floats.
|
||||
pub image: [[f32; WIDTH]; HEIGHT],
|
||||
|
||||
/// Label of the image.
|
||||
pub label: u8,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
struct MnistItemRaw {
|
||||
pub image_bytes: Vec<u8>,
|
||||
pub label: u8,
|
||||
}
|
||||
|
||||
struct BytesToImage;
|
||||
|
||||
impl Mapper<MnistItemRaw, MnistItem> for BytesToImage {
|
||||
/// Convert a raw MNIST item (image bytes) to a MNIST item (2D array image).
|
||||
fn map(&self, item: &MnistItemRaw) -> MnistItem {
|
||||
// Ensure the image dimensions are correct.
|
||||
debug_assert_eq!(item.image_bytes.len(), WIDTH * HEIGHT);
|
||||
|
||||
// Convert the image to a 2D array of floats.
|
||||
let mut image_array = [[0f32; WIDTH]; HEIGHT];
|
||||
for (i, pixel) in item.image_bytes.iter().enumerate() {
|
||||
let x = i % WIDTH;
|
||||
let y = i / HEIGHT;
|
||||
image_array[y][x] = *pixel as f32;
|
||||
}
|
||||
|
||||
MnistItem {
|
||||
image: image_array,
|
||||
label: item.label,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type MappedDataset = MapperDataset<InMemDataset<MnistItemRaw>, BytesToImage, MnistItemRaw>;
|
||||
|
||||
/// The MNIST dataset consists of 70,000 28x28 black-and-white images in 10 classes (one for each digits), with 7,000
|
||||
/// images per class. There are 60,000 training images and 10,000 test images.
|
||||
///
|
||||
/// The data is downloaded from the web from the [CVDF mirror](https://github.com/cvdfoundation/mnist).
|
||||
pub struct MnistDataset {
|
||||
dataset: MappedDataset,
|
||||
}
|
||||
|
||||
impl Dataset<MnistItem> for MnistDataset {
    /// Returns the item at `index`, as provided by the wrapped dataset.
    fn get(&self, index: usize) -> Option<MnistItem> {
        let Self { dataset } = self;
        dataset.get(index)
    }

    /// Returns the number of items, as reported by the wrapped dataset.
    fn len(&self) -> usize {
        let Self { dataset } = self;
        dataset.len()
    }
}
|
||||
|
||||
impl MnistDataset {
|
||||
/// Creates a new train dataset.
|
||||
pub fn train() -> Self {
|
||||
Self::new("train")
|
||||
}
|
||||
|
||||
/// Creates a new test dataset.
|
||||
pub fn test() -> Self {
|
||||
Self::new("test")
|
||||
}
|
||||
|
||||
fn new(split: &str) -> Self {
|
||||
// Download dataset
|
||||
let root = MnistDataset::download(split);
|
||||
|
||||
// MNIST is tiny so we can load it in-memory
|
||||
// Train images (u8): 28 * 28 * 60000 = 47.04Mb
|
||||
// Test images (u8): 28 * 28 * 10000 = 7.84Mb
|
||||
let images = MnistDataset::read_images(&root, split);
|
||||
let labels = MnistDataset::read_labels(&root, split);
|
||||
|
||||
// Collect as vector of MnistItemRaw
|
||||
let items: Vec<_> = images
|
||||
.into_iter()
|
||||
.zip(labels)
|
||||
.map(|(image_bytes, label)| MnistItemRaw { image_bytes, label })
|
||||
.collect();
|
||||
|
||||
let dataset = InMemDataset::new(items);
|
||||
let dataset = MapperDataset::new(dataset, BytesToImage);
|
||||
|
||||
Self { dataset }
|
||||
}
|
||||
|
||||
/// Download the MNIST dataset files from the web.
|
||||
/// Panics if the download cannot be completed or the content of the file cannot be written to disk.
|
||||
fn download(split: &str) -> PathBuf {
|
||||
// Dataset files are stored in the burn-dataset cache directory
|
||||
let cache_dir = dirs::cache_dir()
|
||||
.expect("Could not get cache directory")
|
||||
.join("burn-dataset");
|
||||
let split_dir = cache_dir.join("mnist").join(split);
|
||||
|
||||
if !split_dir.exists() {
|
||||
create_dir_all(&split_dir).expect("Failed to create base directory");
|
||||
}
|
||||
|
||||
// Download split files
|
||||
match split {
|
||||
"train" => {
|
||||
MnistDataset::download_file(TRAIN_IMAGES, &split_dir);
|
||||
MnistDataset::download_file(TRAIN_LABELS, &split_dir);
|
||||
}
|
||||
"test" => {
|
||||
MnistDataset::download_file(TEST_IMAGES, &split_dir);
|
||||
MnistDataset::download_file(TEST_LABELS, &split_dir);
|
||||
}
|
||||
_ => panic!("Invalid split specified {split}"),
|
||||
};
|
||||
|
||||
split_dir
|
||||
}
|
||||
|
||||
/// Download a file from the MNIST dataset URL to the destination directory.
|
||||
/// File download progress is reported with the help of a [progress bar](indicatif).
|
||||
fn download_file<P: AsRef<Path>>(name: &str, dest_dir: &P) -> PathBuf {
|
||||
// Output file name
|
||||
let file_name = dest_dir.as_ref().join(name);
|
||||
|
||||
if !file_name.exists() {
|
||||
// Download gzip file
|
||||
let bytes = download_file_as_bytes(&format!("{URL}{name}.gz"), name);
|
||||
|
||||
// Create file to write the downloaded content to
|
||||
let mut output_file = File::create(&file_name).unwrap();
|
||||
|
||||
// Decode gzip file content and write to disk
|
||||
let mut gz_buffer = GzDecoder::new(&bytes[..]);
|
||||
std::io::copy(&mut gz_buffer, &mut output_file).unwrap();
|
||||
}
|
||||
|
||||
file_name
|
||||
}
|
||||
|
||||
/// Read images at the provided path for the specified split.
|
||||
/// Each image is a vector of bytes.
|
||||
fn read_images<P: AsRef<Path>>(root: &P, split: &str) -> Vec<Vec<u8>> {
|
||||
let file_name = if split == "train" {
|
||||
TRAIN_IMAGES
|
||||
} else {
|
||||
TEST_IMAGES
|
||||
};
|
||||
let file_name = root.as_ref().join(file_name);
|
||||
|
||||
// Read number of images from 16-byte header metadata
|
||||
let mut f = File::open(file_name).unwrap();
|
||||
let mut buf = [0u8; 4];
|
||||
let _ = f.seek(SeekFrom::Start(4)).unwrap();
|
||||
f.read_exact(&mut buf)
|
||||
.expect("Should be able to read image file header");
|
||||
let size = u32::from_be_bytes(buf);
|
||||
|
||||
let mut buf_images: Vec<u8> = vec![0u8; WIDTH * HEIGHT * (size as usize)];
|
||||
let _ = f.seek(SeekFrom::Start(16)).unwrap();
|
||||
f.read_exact(&mut buf_images)
|
||||
.expect("Should be able to read image file header");
|
||||
|
||||
buf_images
|
||||
.chunks(WIDTH * HEIGHT)
|
||||
.map(|chunk| chunk.to_vec())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Read labels at the provided path for the specified split.
|
||||
fn read_labels<P: AsRef<Path>>(root: &P, split: &str) -> Vec<u8> {
|
||||
let file_name = if split == "train" {
|
||||
TRAIN_LABELS
|
||||
} else {
|
||||
TEST_LABELS
|
||||
};
|
||||
let file_name = root.as_ref().join(file_name);
|
||||
|
||||
// Read number of labels from 8-byte header metadata
|
||||
let mut f = File::open(file_name).unwrap();
|
||||
let mut buf = [0u8; 4];
|
||||
let _ = f.seek(SeekFrom::Start(4)).unwrap();
|
||||
f.read_exact(&mut buf)
|
||||
.expect("Should be able to read label file header");
|
||||
let size = u32::from_be_bytes(buf);
|
||||
|
||||
let mut buf_labels: Vec<u8> = vec![0u8; size as usize];
|
||||
let _ = f.seek(SeekFrom::Start(8)).unwrap();
|
||||
f.read_exact(&mut buf_labels)
|
||||
.expect("Should be able to read labels from file");
|
||||
|
||||
buf_labels
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
#[cfg(feature = "builtin-sources")]
|
||||
mod cifar;
|
||||
mod image_folder;
|
||||
mod mnist;
|
||||
|
||||
#[cfg(feature = "builtin-sources")]
|
||||
pub use cifar::*;
|
||||
pub use image_folder::*;
|
||||
pub use mnist::*;
|
||||
@@ -0,0 +1,2 @@
|
||||
HI1 1 true 1.0
|
||||
HI2 1 false 1.0
|
||||
|
@@ -0,0 +1,3 @@
|
||||
column_str,column_int,column_bool,column_float
|
||||
HI1,1,true,1.0
|
||||
HI2,1,false,1.0
|
||||
|
@@ -0,0 +1,2 @@
|
||||
{"column_str":"HI1","column_bytes":[1,2,3,3],"column_int":1,"column_bool":true,"column_float":1.0}
|
||||
{"column_str":"HI2","column_bytes":[1,2,3,3],"column_int":1,"column_bool":false,"column_float":1.0}
|
||||
@@ -0,0 +1,132 @@
|
||||
{
|
||||
"images": [
|
||||
{
|
||||
"width": 32,
|
||||
"height": 32,
|
||||
"id": 0,
|
||||
"file_name": "two_dots_and_triangle.jpg"
|
||||
},
|
||||
{
|
||||
"width": 32,
|
||||
"height": 32,
|
||||
"id": 1,
|
||||
"file_name": "dot_triangle.jpg"
|
||||
},
|
||||
{
|
||||
"width": 32,
|
||||
"height": 32,
|
||||
"id": 2,
|
||||
"file_name": "one_dot.jpg"
|
||||
}
|
||||
],
|
||||
"categories": [
|
||||
{
|
||||
"id": 0,
|
||||
"name": "dot"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"name": "triangle"
|
||||
}
|
||||
],
|
||||
"annotations": [
|
||||
{
|
||||
"id": 0,
|
||||
"image_id": 0,
|
||||
"category_id": 0,
|
||||
"segmentation": [],
|
||||
"bbox": [
|
||||
3.1251719394773056,
|
||||
18.0907840440165,
|
||||
10.96011004126548,
|
||||
10.740027510316379
|
||||
],
|
||||
"ignore": 0,
|
||||
"iscrowd": 0,
|
||||
"area": 117.71188335928603
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"image_id": 0,
|
||||
"category_id": 0,
|
||||
"segmentation": [],
|
||||
"bbox": [
|
||||
3.2572214580467658,
|
||||
3.0371389270976605,
|
||||
10.563961485557085,
|
||||
10.828060522696012
|
||||
],
|
||||
"ignore": 0,
|
||||
"iscrowd": 0,
|
||||
"area": 114.38721432504178
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"image_id": 0,
|
||||
"category_id": 1,
|
||||
"segmentation": [],
|
||||
"bbox": [
|
||||
15.097661623108666,
|
||||
3.3892709766162312,
|
||||
12.632737276478679,
|
||||
11.18019257221458
|
||||
],
|
||||
"ignore": 0,
|
||||
"iscrowd": 0,
|
||||
"area": 141.23643546522516
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"image_id": 1,
|
||||
"category_id": 0,
|
||||
"segmentation": [],
|
||||
"bbox": [
|
||||
3.125171939477304,
|
||||
17.914718019257222,
|
||||
10.82806052269601,
|
||||
11.004126547455297
|
||||
],
|
||||
"ignore": 0,
|
||||
"iscrowd": 0,
|
||||
"area": 119.15334825525184
|
||||
},
|
||||
{
|
||||
"id": 4,
|
||||
"image_id": 1,
|
||||
"category_id": 1,
|
||||
"segmentation": [],
|
||||
"bbox": [
|
||||
15.27372764786794,
|
||||
3.301237964236589,
|
||||
12.192572214580478,
|
||||
11.708390646492433
|
||||
],
|
||||
"ignore": 0,
|
||||
"iscrowd": 0,
|
||||
"area": 142.7553984738776
|
||||
},
|
||||
{
|
||||
"id": 5,
|
||||
"image_id": 2,
|
||||
"category_id": 0,
|
||||
"segmentation": [],
|
||||
"bbox": [
|
||||
10.07977991746905,
|
||||
9.59559834938102,
|
||||
10.960110041265464,
|
||||
11.356258596973863
|
||||
],
|
||||
"ignore": 0,
|
||||
"iscrowd": 0,
|
||||
"area": 124.46584387990049
|
||||
}
|
||||
],
|
||||
"info": {
|
||||
"year": 2024,
|
||||
"version": "1.0",
|
||||
"description": "",
|
||||
"contributor": "",
|
||||
"url": "",
|
||||
"date_created": "2024-12-11 22:16:31.823494"
|
||||
}
|
||||
}
|
||||
|
After Width: | Height: | Size: 727 B |
|
After Width: | Height: | Size: 634 B |
|
After Width: | Height: | Size: 120 B |
|
After Width: | Height: | Size: 1.5 KiB |
|
After Width: | Height: | Size: 1.4 KiB |
|
After Width: | Height: | Size: 1.7 KiB |
|
After Width: | Height: | Size: 117 B |
@@ -0,0 +1,8 @@
|
||||
1 2 1 2 1 2 1 2
|
||||
2 1 2 1 2 1 2 1
|
||||
1 2 1 2 1 2 1 2
|
||||
2 1 2 1 2 1 2 1
|
||||
1 2 1 2 1 2 1 2
|
||||
2 1 2 1 2 1 2 1
|
||||
1 2 1 2 1 2 1 2
|
||||
2 1 2 1 2 1 2 1
|
||||
|
After Width: | Height: | Size: 123 B |
@@ -0,0 +1,8 @@
|
||||
1 2 1 1 1 2 1 1
|
||||
1 2 1 1 1 1 2 1
|
||||
2 2 2 1 2 1 2 2
|
||||
2 2 2 2 2 2 1 1
|
||||
2 2 2 1 2 1 1 1
|
||||
1 1 2 2 2 2 2 1
|
||||
2 2 1 2 1 2 1 2
|
||||
2 1 1 1 1 1 1 1
|
||||
|
After Width: | Height: | Size: 137 B |
@@ -0,0 +1,8 @@
|
||||
3 1 3 3 1 1 3 2
|
||||
3 3 3 3 1 3 2 1
|
||||
2 2 2 2 1 1 2 2
|
||||
1 1 1 3 3 3 2 3
|
||||
2 2 3 2 3 3 1 3
|
||||
1 3 3 1 1 3 2 1
|
||||
2 2 2 1 2 1 2 3
|
||||
3 1 3 3 2 1 2 2
|
||||
|
After Width: | Height: | Size: 165 B |
|
After Width: | Height: | Size: 133 B |
|
After Width: | Height: | Size: 204 B |
@@ -0,0 +1 @@
|
||||
This is a negative text sample for testing the text folder dataset functionality.
|
||||
@@ -0,0 +1 @@
|
||||
另一个负面文本样本,用以确保数据集能够处理同一类别中的多个文件。
|
||||
@@ -0,0 +1 @@
|
||||
This is a positive text sample for testing the text folder dataset functionality.
|
||||
@@ -0,0 +1 @@
|
||||
另一个正面文本样本,以确保数据集能够处理同一类别中的多个文件。
|
||||