Fix allocation bug

This commit is contained in:
Gadersd
2023-08-05 17:05:20 -04:00
parent 2f3d94f9ad
commit 058dd1d7cc
6 changed files with 11 additions and 12 deletions

View File

@@ -14,4 +14,4 @@ num-traits = "0.2.15"
rust_tokenizers = "8.1.0" rust_tokenizers = "8.1.0"
regex = "1.9.1" regex = "1.9.1"
image = "0.24.6" image = "0.24.6"
bincode = "1.3.3" bincode = {version = "2.0.0-alpha.0", features = ["std"]}

View File

@@ -19,9 +19,7 @@ export TORCH_CUDA_VERSION=cu113
``` ```
### Step 2: Run the Sample Binary ### Step 2: Run the Sample Binary
Invoke the sample binary provided in the rust code, as shown below. Loading the burn model file Invoke the sample binary provided in the rust code, as shown below:
is currently very slow, but hopefully that will be rectified soon. You can also dump a torch model's weights and load that
which is currently much faster.
```bash ```bash
# Arguments: <model_type(burn or dump)> <model> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image> # Arguments: <model_type(burn or dump)> <model> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image>

View File

@@ -17,7 +17,7 @@ use burn::{
use burn_tch::{TchBackend, TchDevice}; use burn_tch::{TchBackend, TchDevice};
use burn::record::{self, Recorder, FullPrecisionSettings}; use burn::record::{self, Recorder, FullPrecisionSettings};
use stablediffusion::binrecorder::{BinFileRecorderBuffered}; use stablediffusion::binrecorderfast::{BinFileRecorderBuffered};
fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box<dyn Error>> { fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box<dyn Error>> {
println!("Loading dump..."); println!("Loading dump...");

View File

@@ -16,7 +16,7 @@ use std::io;
use std::process; use std::process;
use burn::record::{self, Recorder, FullPrecisionSettings}; use burn::record::{self, Recorder, FullPrecisionSettings};
use stablediffusion::binrecorder::{BinFileRecorderBuffered}; use stablediffusion::binrecorderfast::{BinFileRecorderBuffered};
fn load_stable_diffusion_model_file<B: Backend>(filename: &str) -> Result<StableDiffusion<B>, record::RecorderError> { fn load_stable_diffusion_model_file<B: Backend>(filename: &str) -> Result<StableDiffusion<B>, record::RecorderError> {
BinFileRecorderBuffered::<FullPrecisionSettings>::new() BinFileRecorderBuffered::<FullPrecisionSettings>::new()

View File

@@ -1,4 +1,4 @@
use bincode::{self, Options}; use bincode;
use burn::record::{PrecisionSettings, Recorder, RecorderError, FileRecorder}; use burn::record::{PrecisionSettings, Recorder, RecorderError, FileRecorder};
use std::fs::File; use std::fs::File;
use std::io::{BufReader, BufWriter}; use std::io::{BufReader, BufWriter};
@@ -7,9 +7,9 @@ use std::marker::PhantomData;
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
//use super::{bin_config, PrecisionSettings, Recorder, RecorderError}; //use super::{bin_config, PrecisionSettings, Recorder, RecorderError};
/*fn bin_config() -> bincode::config::Configuration { fn bin_config() -> bincode::config::Configuration {
bincode::config::standard() bincode::config::standard()
}*/ }
macro_rules! str2reader { macro_rules! str2reader {
($file:expr) => {{ ($file:expr) => {{
@@ -69,8 +69,9 @@ impl<S: PrecisionSettings> Recorder for BinFileRecorderBuffered<S> {
item: I, item: I,
mut file: Self::RecordArgs, mut file: Self::RecordArgs,
) -> Result<(), RecorderError> { ) -> Result<(), RecorderError> {
let config = bin_config();
let mut writer = str2writer!(file)?; let mut writer = str2writer!(file)?;
bincode::serialize_into(&mut writer, &item) bincode::serde::encode_into_std_write(&item, &mut writer, config)
.map_err(|err| RecorderError::Unknown(err.to_string()))?; .map_err(|err| RecorderError::Unknown(err.to_string()))?;
Ok(()) Ok(())
} }
@@ -78,7 +79,7 @@ impl<S: PrecisionSettings> Recorder for BinFileRecorderBuffered<S> {
fn load_item<I: DeserializeOwned>(&self, mut file: Self::LoadArgs) -> Result<I, RecorderError> { fn load_item<I: DeserializeOwned>(&self, mut file: Self::LoadArgs) -> Result<I, RecorderError> {
let mut reader = str2reader!(file)?; let mut reader = str2reader!(file)?;
let state = let state =
bincode::deserialize_from(&mut reader) bincode::serde::decode_from_std_read(&mut reader, bin_config())
.map_err(|err| RecorderError::Unknown(err.to_string()))?; .map_err(|err| RecorderError::Unknown(err.to_string()))?;
Ok(state) Ok(state)
} }

View File

@@ -1,4 +1,4 @@
pub mod model; pub mod model;
pub mod tokenizer; pub mod tokenizer;
pub mod helper; pub mod helper;
pub mod binrecorder; pub mod binrecorderfast;