From 058dd1d7ccb6c405bb6a9a8399203e0f67bc72b5 Mon Sep 17 00:00:00 2001 From: Gadersd Date: Sat, 5 Aug 2023 17:05:20 -0400 Subject: [PATCH] Fix allocation bug --- Cargo.toml | 2 +- README.md | 4 +--- src/bin/convert/main.rs | 2 +- src/bin/sample/main.rs | 2 +- src/{binrecorder.rs => binrecorderfast.rs} | 11 ++++++----- src/lib.rs | 2 +- 6 files changed, 11 insertions(+), 12 deletions(-) rename src/{binrecorder.rs => binrecorderfast.rs} (91%) diff --git a/Cargo.toml b/Cargo.toml index 2cacc75..6a2f234 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,4 +14,4 @@ num-traits = "0.2.15" rust_tokenizers = "8.1.0" regex = "1.9.1" image = "0.24.6" -bincode = "1.3.3" \ No newline at end of file +bincode = {version = "2.0.0-alpha.0", features = ["std"]} \ No newline at end of file diff --git a/README.md b/README.md index fb9e1f5..523abdf 100644 --- a/README.md +++ b/README.md @@ -19,9 +19,7 @@ export TORCH_CUDA_VERSION=cu113 ``` ### Step 2: Run the Sample Binary -Invoke the sample binary provided in the rust code, as shown below. Loading the burn model file -is currently very slow, but hopefully that will be rectified soon. You can also dump a torch model's weights and load that -which is currently much faster. +Invoke the sample binary provided in the rust code, as shown below: ```bash # Arguments: diff --git a/src/bin/convert/main.rs b/src/bin/convert/main.rs index 52d512b..223aa89 100644 --- a/src/bin/convert/main.rs +++ b/src/bin/convert/main.rs @@ -17,7 +17,7 @@ use burn::{ use burn_tch::{TchBackend, TchDevice}; use burn::record::{self, Recorder, FullPrecisionSettings}; -use stablediffusion::binrecorder::{BinFileRecorderBuffered}; +use stablediffusion::binrecorderfast::{BinFileRecorderBuffered}; fn convert_dump_to_model(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box> { println!("Loading dump..."); diff --git a/src/bin/sample/main.rs b/src/bin/sample/main.rs index 7495585..278c91d 100644 --- a/src/bin/sample/main.rs +++ b/src/bin/sample/main.rs @@ -16,7 +16,7 @@ use std::io; use std::process; use burn::record::{self, Recorder, FullPrecisionSettings}; -use stablediffusion::binrecorder::{BinFileRecorderBuffered}; +use stablediffusion::binrecorderfast::{BinFileRecorderBuffered}; fn load_stable_diffusion_model_file(filename: &str) -> Result, record::RecorderError> { BinFileRecorderBuffered::::new() diff --git a/src/binrecorder.rs b/src/binrecorderfast.rs similarity index 91% rename from src/binrecorder.rs rename to src/binrecorderfast.rs index 42814f1..b1c6ca2 100644 --- a/src/binrecorder.rs +++ b/src/binrecorderfast.rs @@ -1,4 +1,4 @@ -use bincode::{self, Options}; +use bincode; use burn::record::{PrecisionSettings, Recorder, RecorderError, FileRecorder}; use std::fs::File; use std::io::{BufReader, BufWriter}; @@ -7,9 +7,9 @@ use std::marker::PhantomData; use serde::{de::DeserializeOwned, Serialize}; //use super::{bin_config, PrecisionSettings, Recorder, RecorderError}; -/*fn bin_config() -> bincode::config::Configuration { +fn bin_config() -> bincode::config::Configuration { bincode::config::standard() -}*/ +} macro_rules! str2reader { ($file:expr) => {{ @@ -69,8 +69,9 @@ impl Recorder for BinFileRecorderBuffered { item: I, mut file: Self::RecordArgs, ) -> Result<(), RecorderError> { + let config = bin_config(); let mut writer = str2writer!(file)?; - bincode::serialize_into(&mut writer, &item) + bincode::serde::encode_into_std_write(&item, &mut writer, config) .map_err(|err| RecorderError::Unknown(err.to_string()))?; Ok(()) } @@ -78,7 +79,7 @@ impl Recorder for BinFileRecorderBuffered { fn load_item(&self, mut file: Self::LoadArgs) -> Result { let mut reader = str2reader!(file)?; let state = - bincode::deserialize_from(&mut reader) + bincode::serde::decode_from_std_read(&mut reader, bin_config()) .map_err(|err| RecorderError::Unknown(err.to_string()))?; Ok(state) } diff --git a/src/lib.rs b/src/lib.rs index 5ec3067..46cd752 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ pub mod model; pub mod tokenizer; pub mod helper; -pub mod binrecorder; \ No newline at end of file +pub mod binrecorderfast; \ No newline at end of file