Fix allocation bug

This commit is contained in:
Gadersd
2023-08-05 17:05:20 -04:00
committed by Ben_Kosytorz
parent a84825838d
commit ebf7c32e88
6 changed files with 11 additions and 12 deletions

View File

@@ -14,4 +14,4 @@ num-traits = "0.2.15"
rust_tokenizers = "8.1.0"
regex = "1.9.1"
image = "0.24.6"
bincode = "1.3.3"
bincode = {version = "2.0.0-alpha.0", features = ["std"]}

View File

@@ -19,9 +19,7 @@ export TORCH_CUDA_VERSION=cu113
```
### Step 2: Run the Sample Binary
Invoke the sample binary provided in the rust code, as shown below. Loading the burn model file
is currently very slow, but hopefully that will be rectified soon. You can also dump a torch model's weights and load that
which is currently much faster.
Invoke the sample binary provided in the rust code, as shown below:
```bash
# Arguments: <model_type(burn or dump)> <model> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image>

View File

@@ -17,7 +17,7 @@ use burn::{
use burn_tch::{TchBackend, TchDevice};
use burn::record::{self, Recorder, FullPrecisionSettings};
use stablediffusion::binrecorder::{BinFileRecorderBuffered};
use stablediffusion::binrecorderfast::{BinFileRecorderBuffered};
fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box<dyn Error>> {
println!("Loading dump...");

View File

@@ -16,7 +16,7 @@ use std::io;
use std::process;
use burn::record::{self, Recorder, FullPrecisionSettings};
use stablediffusion::binrecorder::{BinFileRecorderBuffered};
use stablediffusion::binrecorderfast::{BinFileRecorderBuffered};
fn load_stable_diffusion_model_file<B: Backend>(filename: &str) -> Result<StableDiffusion<B>, record::RecorderError> {
BinFileRecorderBuffered::<FullPrecisionSettings>::new()

View File

@@ -1,4 +1,4 @@
use bincode::{self, Options};
use bincode;
use burn::record::{PrecisionSettings, Recorder, RecorderError, FileRecorder};
use std::fs::File;
use std::io::{BufReader, BufWriter};
@@ -7,9 +7,9 @@ use std::marker::PhantomData;
use serde::{de::DeserializeOwned, Serialize};
//use super::{bin_config, PrecisionSettings, Recorder, RecorderError};
/*fn bin_config() -> bincode::config::Configuration {
fn bin_config() -> bincode::config::Configuration {
bincode::config::standard()
}*/
}
macro_rules! str2reader {
($file:expr) => {{
@@ -69,8 +69,9 @@ impl<S: PrecisionSettings> Recorder for BinFileRecorderBuffered<S> {
item: I,
mut file: Self::RecordArgs,
) -> Result<(), RecorderError> {
let config = bin_config();
let mut writer = str2writer!(file)?;
bincode::serialize_into(&mut writer, &item)
bincode::serde::encode_into_std_write(&item, &mut writer, config)
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
Ok(())
}
@@ -78,7 +79,7 @@ impl<S: PrecisionSettings> Recorder for BinFileRecorderBuffered<S> {
fn load_item<I: DeserializeOwned>(&self, mut file: Self::LoadArgs) -> Result<I, RecorderError> {
let mut reader = str2reader!(file)?;
let state =
bincode::deserialize_from(&mut reader)
bincode::serde::decode_from_std_read(&mut reader, bin_config())
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
Ok(state)
}

View File

@@ -1,4 +1,4 @@
pub mod model;
pub mod tokenizer;
pub mod helper;
pub mod binrecorder;
pub mod binrecorderfast;