Fix allocation bug
This commit is contained in:
@@ -14,4 +14,4 @@ num-traits = "0.2.15"
|
|||||||
rust_tokenizers = "8.1.0"
|
rust_tokenizers = "8.1.0"
|
||||||
regex = "1.9.1"
|
regex = "1.9.1"
|
||||||
image = "0.24.6"
|
image = "0.24.6"
|
||||||
bincode = "1.3.3"
|
bincode = {version = "2.0.0-alpha.0", features = ["std"]}
|
||||||
@@ -19,9 +19,7 @@ export TORCH_CUDA_VERSION=cu113
|
|||||||
```
|
```
|
||||||
### Step 2: Run the Sample Binary
|
### Step 2: Run the Sample Binary
|
||||||
|
|
||||||
Invoke the sample binary provided in the rust code, as shown below. Loading the burn model file
|
Invoke the sample binary provided in the rust code, as shown below:
|
||||||
is currently very slow, but hopefully that will be rectified soon. You can also dump a torch model's weights and load that
|
|
||||||
which is currently much faster.
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Arguments: <model_type(burn or dump)> <model> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image>
|
# Arguments: <model_type(burn or dump)> <model> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image>
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ use burn::{
|
|||||||
use burn_tch::{TchBackend, TchDevice};
|
use burn_tch::{TchBackend, TchDevice};
|
||||||
|
|
||||||
use burn::record::{self, Recorder, FullPrecisionSettings};
|
use burn::record::{self, Recorder, FullPrecisionSettings};
|
||||||
use stablediffusion::binrecorder::{BinFileRecorderBuffered};
|
use stablediffusion::binrecorderfast::{BinFileRecorderBuffered};
|
||||||
|
|
||||||
fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box<dyn Error>> {
|
fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box<dyn Error>> {
|
||||||
println!("Loading dump...");
|
println!("Loading dump...");
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ use std::io;
|
|||||||
use std::process;
|
use std::process;
|
||||||
|
|
||||||
use burn::record::{self, Recorder, FullPrecisionSettings};
|
use burn::record::{self, Recorder, FullPrecisionSettings};
|
||||||
use stablediffusion::binrecorder::{BinFileRecorderBuffered};
|
use stablediffusion::binrecorderfast::{BinFileRecorderBuffered};
|
||||||
|
|
||||||
fn load_stable_diffusion_model_file<B: Backend>(filename: &str) -> Result<StableDiffusion<B>, record::RecorderError> {
|
fn load_stable_diffusion_model_file<B: Backend>(filename: &str) -> Result<StableDiffusion<B>, record::RecorderError> {
|
||||||
BinFileRecorderBuffered::<FullPrecisionSettings>::new()
|
BinFileRecorderBuffered::<FullPrecisionSettings>::new()
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use bincode::{self, Options};
|
use bincode;
|
||||||
use burn::record::{PrecisionSettings, Recorder, RecorderError, FileRecorder};
|
use burn::record::{PrecisionSettings, Recorder, RecorderError, FileRecorder};
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::{BufReader, BufWriter};
|
use std::io::{BufReader, BufWriter};
|
||||||
@@ -7,9 +7,9 @@ use std::marker::PhantomData;
|
|||||||
use serde::{de::DeserializeOwned, Serialize};
|
use serde::{de::DeserializeOwned, Serialize};
|
||||||
//use super::{bin_config, PrecisionSettings, Recorder, RecorderError};
|
//use super::{bin_config, PrecisionSettings, Recorder, RecorderError};
|
||||||
|
|
||||||
/*fn bin_config() -> bincode::config::Configuration {
|
fn bin_config() -> bincode::config::Configuration {
|
||||||
bincode::config::standard()
|
bincode::config::standard()
|
||||||
}*/
|
}
|
||||||
|
|
||||||
macro_rules! str2reader {
|
macro_rules! str2reader {
|
||||||
($file:expr) => {{
|
($file:expr) => {{
|
||||||
@@ -69,8 +69,9 @@ impl<S: PrecisionSettings> Recorder for BinFileRecorderBuffered<S> {
|
|||||||
item: I,
|
item: I,
|
||||||
mut file: Self::RecordArgs,
|
mut file: Self::RecordArgs,
|
||||||
) -> Result<(), RecorderError> {
|
) -> Result<(), RecorderError> {
|
||||||
|
let config = bin_config();
|
||||||
let mut writer = str2writer!(file)?;
|
let mut writer = str2writer!(file)?;
|
||||||
bincode::serialize_into(&mut writer, &item)
|
bincode::serde::encode_into_std_write(&item, &mut writer, config)
|
||||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -78,7 +79,7 @@ impl<S: PrecisionSettings> Recorder for BinFileRecorderBuffered<S> {
|
|||||||
fn load_item<I: DeserializeOwned>(&self, mut file: Self::LoadArgs) -> Result<I, RecorderError> {
|
fn load_item<I: DeserializeOwned>(&self, mut file: Self::LoadArgs) -> Result<I, RecorderError> {
|
||||||
let mut reader = str2reader!(file)?;
|
let mut reader = str2reader!(file)?;
|
||||||
let state =
|
let state =
|
||||||
bincode::deserialize_from(&mut reader)
|
bincode::serde::decode_from_std_read(&mut reader, bin_config())
|
||||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
||||||
Ok(state)
|
Ok(state)
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
pub mod model;
|
pub mod model;
|
||||||
pub mod tokenizer;
|
pub mod tokenizer;
|
||||||
pub mod helper;
|
pub mod helper;
|
||||||
pub mod binrecorder;
|
pub mod binrecorderfast;
|
||||||
Reference in New Issue
Block a user