From a84825838d1543fa807f899f4ac4dfd0c4666b98 Mon Sep 17 00:00:00 2001 From: Gadersd Date: Sat, 5 Aug 2023 16:48:26 -0400 Subject: [PATCH] Use buffer to increase model loading/saving speed tremendously --- Cargo.toml | 3 +- src/bin/convert/main.rs | 5 ++- src/bin/sample/main.rs | 5 ++- src/binrecorder.rs | 85 +++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 3 +- 5 files changed, 95 insertions(+), 6 deletions(-) create mode 100644 src/binrecorder.rs diff --git a/Cargo.toml b/Cargo.toml index 1d8fb1b..2cacc75 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,4 +13,5 @@ npy = "0.4.0" num-traits = "0.2.15" rust_tokenizers = "8.1.0" regex = "1.9.1" -image = "0.24.6" \ No newline at end of file +image = "0.24.6" +bincode = "1.3.3" \ No newline at end of file diff --git a/src/bin/convert/main.rs b/src/bin/convert/main.rs index 924c9d4..52d512b 100644 --- a/src/bin/convert/main.rs +++ b/src/bin/convert/main.rs @@ -16,7 +16,8 @@ use burn::{ use burn_tch::{TchBackend, TchDevice}; -use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings}; +use burn::record::{self, Recorder, FullPrecisionSettings}; +use stablediffusion::binrecorder::{BinFileRecorderBuffered}; fn convert_dump_to_model(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box> { println!("Loading dump..."); @@ -29,7 +30,7 @@ fn convert_dump_to_model(dump_path: &str, model_name: &str, device: } fn save_model_file(model: StableDiffusion, name: &str) -> Result<(), record::RecorderError> { - BinFileRecorder::::new() + BinFileRecorderBuffered::::new() .record( model.into_record(), name.into(), diff --git a/src/bin/sample/main.rs b/src/bin/sample/main.rs index 6754a3e..7495585 100644 --- a/src/bin/sample/main.rs +++ b/src/bin/sample/main.rs @@ -15,10 +15,11 @@ use std::env; use std::io; use std::process; -use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings}; +use burn::record::{self, Recorder, FullPrecisionSettings}; +use stablediffusion::binrecorder::{BinFileRecorderBuffered}; fn load_stable_diffusion_model_file(filename: &str) -> Result, record::RecorderError> { - BinFileRecorder::::new() + BinFileRecorderBuffered::::new() .load(filename.into()) .map(|record| StableDiffusionConfig::new().init().load_record(record)) } diff --git a/src/binrecorder.rs b/src/binrecorder.rs new file mode 100644 index 0000000..42814f1 --- /dev/null +++ b/src/binrecorder.rs @@ -0,0 +1,85 @@ +use bincode::{self, Options}; +use burn::record::{PrecisionSettings, Recorder, RecorderError, FileRecorder}; +use std::fs::File; +use std::io::{BufReader, BufWriter}; +use std::path::PathBuf; +use std::marker::PhantomData; +use serde::{de::DeserializeOwned, Serialize}; +//use super::{bin_config, PrecisionSettings, Recorder, RecorderError}; + +/*fn bin_config() -> bincode::config::Configuration { + bincode::config::standard() +}*/ + +macro_rules! str2reader { + ($file:expr) => {{ + $file.set_extension(::file_extension()); + let path = $file.as_path(); + File::open(path).map_err(|err| match err.kind() { + std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()), + _ => RecorderError::Unknown(err.to_string()), + }).map(|file| BufReader::new(file)) // wrap File in BufReader + }}; +} + +macro_rules! str2writer { + ($file:expr) => {{ + $file.set_extension(::file_extension()); + let path = $file.as_path(); + + if path.exists() { + //log::info!("File exists, replacing"); + std::fs::remove_file(path).map_err(|err| RecorderError::Unknown(err.to_string()))?; + } + + File::create(path).map_err(|err| match err.kind() { + std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()), + _ => RecorderError::Unknown(err.to_string()), + }).map(|file| BufWriter::new(file)) // wrap File in BufWriter + }}; +} + +#[derive(Debug, Default, Clone)] +pub struct BinFileRecorderBuffered { + _settings: PhantomData, +} + +impl BinFileRecorderBuffered { + pub fn new() -> Self { + BinFileRecorderBuffered { + _settings: PhantomData, + } + } +} + +impl FileRecorder for BinFileRecorderBuffered { + fn file_extension() -> &'static str { + "bin" + } +} + +impl Recorder for BinFileRecorderBuffered { + type Settings = S; + type RecordArgs = PathBuf; + type RecordOutput = (); + type LoadArgs = PathBuf; + + fn save_item( + &self, + item: I, + mut file: Self::RecordArgs, + ) -> Result<(), RecorderError> { + let mut writer = str2writer!(file)?; + bincode::serialize_into(&mut writer, &item) + .map_err(|err| RecorderError::Unknown(err.to_string()))?; + Ok(()) + } + + fn load_item(&self, mut file: Self::LoadArgs) -> Result { + let mut reader = str2reader!(file)?; + let state = + bincode::deserialize_from(&mut reader) + .map_err(|err| RecorderError::Unknown(err.to_string()))?; + Ok(state) + } +} diff --git a/src/lib.rs b/src/lib.rs index 993bb59..5ec3067 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ pub mod model; pub mod tokenizer; -pub mod helper; \ No newline at end of file +pub mod helper; +pub mod binrecorder; \ No newline at end of file