Switch to burn 0.9.0 to gain fast model io without the need for a custom recorder

This commit is contained in:
Gadersd
2023-08-06 19:44:13 -04:00
committed by Ben_Kosytorz
parent 1b5426c887
commit 68d4310aa8
5 changed files with 10 additions and 97 deletions

View File

@@ -11,15 +11,17 @@ torch-backend = ["burn-tch"]
wgpu-backend = ["burn-wgpu"] wgpu-backend = ["burn-wgpu"]
[dependencies.burn-tch] [dependencies.burn-tch]
version = "0.8.0" package = "burn-tch"
git = "https://github.com/burn-rs/burn.git"
optional = true optional = true
[dependencies.burn-wgpu] [dependencies.burn-wgpu]
version = "0.8.0" package = "burn-wgpu"
git = "https://github.com/burn-rs/burn.git"
optional = true optional = true
[dependencies] [dependencies]
burn = "0.8.0" burn = { git = "https://github.com/burn-rs/burn.git" }
serde = {version = "1.0.171", features = ["std", "derive"]} serde = {version = "1.0.171", features = ["std", "derive"]}
npy = "0.4.0" npy = "0.4.0"
num-traits = "0.2.15" num-traits = "0.2.15"

View File

@@ -22,8 +22,7 @@ cfg_if::cfg_if! {
} }
} }
use burn::record::{self, Recorder, FullPrecisionSettings}; use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings};
use stablediffusion::binrecorderfast::{BinFileRecorderBuffered};
fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box<dyn Error>> { fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box<dyn Error>> {
println!("Loading dump..."); println!("Loading dump...");
@@ -36,7 +35,7 @@ fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device:
} }
fn save_model_file<B: Backend>(model: StableDiffusion<B>, name: &str) -> Result<(), record::RecorderError> { fn save_model_file<B: Backend>(model: StableDiffusion<B>, name: &str) -> Result<(), record::RecorderError> {
BinFileRecorderBuffered::<FullPrecisionSettings>::new() BinFileRecorder::<FullPrecisionSettings>::new()
.record( .record(
model.into_record(), model.into_record(),
name.into(), name.into(),

View File

@@ -22,11 +22,10 @@ use std::env;
use std::io; use std::io;
use std::process; use std::process;
use burn::record::{self, Recorder, FullPrecisionSettings}; use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings};
use stablediffusion::binrecorderfast::{BinFileRecorderBuffered};
fn load_stable_diffusion_model_file<B: Backend>(filename: &str) -> Result<StableDiffusion<B>, record::RecorderError> { fn load_stable_diffusion_model_file<B: Backend>(filename: &str) -> Result<StableDiffusion<B>, record::RecorderError> {
BinFileRecorderBuffered::<FullPrecisionSettings>::new() BinFileRecorder::<FullPrecisionSettings>::new()
.load(filename.into()) .load(filename.into())
.map(|record| StableDiffusionConfig::new().init().load_record(record)) .map(|record| StableDiffusionConfig::new().init().load_record(record))
} }

View File

@@ -1,86 +0,0 @@
use bincode;
use burn::record::{PrecisionSettings, Recorder, RecorderError, FileRecorder};
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::PathBuf;
use std::marker::PhantomData;
use serde::{de::DeserializeOwned, Serialize};
//use super::{bin_config, PrecisionSettings, Recorder, RecorderError};
fn bin_config() -> bincode::config::Configuration {
bincode::config::standard()
}
macro_rules! str2reader {
($file:expr) => {{
$file.set_extension(<Self as FileRecorder>::file_extension());
let path = $file.as_path();
File::open(path).map_err(|err| match err.kind() {
std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),
_ => RecorderError::Unknown(err.to_string()),
}).map(|file| BufReader::new(file)) // wrap File in BufReader
}};
}
macro_rules! str2writer {
($file:expr) => {{
$file.set_extension(<Self as FileRecorder>::file_extension());
let path = $file.as_path();
if path.exists() {
//log::info!("File exists, replacing");
std::fs::remove_file(path).map_err(|err| RecorderError::Unknown(err.to_string()))?;
}
File::create(path).map_err(|err| match err.kind() {
std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),
_ => RecorderError::Unknown(err.to_string()),
}).map(|file| BufWriter::new(file)) // wrap File in BufWriter
}};
}
#[derive(Debug, Default, Clone)]
pub struct BinFileRecorderBuffered<S: PrecisionSettings> {
_settings: PhantomData<S>,
}
impl<S: PrecisionSettings> BinFileRecorderBuffered<S> {
pub fn new() -> Self {
BinFileRecorderBuffered {
_settings: PhantomData,
}
}
}
impl<S: PrecisionSettings> FileRecorder for BinFileRecorderBuffered<S> {
fn file_extension() -> &'static str {
"bin"
}
}
impl<S: PrecisionSettings> Recorder for BinFileRecorderBuffered<S> {
type Settings = S;
type RecordArgs = PathBuf;
type RecordOutput = ();
type LoadArgs = PathBuf;
fn save_item<I: Serialize>(
&self,
item: I,
mut file: Self::RecordArgs,
) -> Result<(), RecorderError> {
let config = bin_config();
let mut writer = str2writer!(file)?;
bincode::serde::encode_into_std_write(&item, &mut writer, config)
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
Ok(())
}
fn load_item<I: DeserializeOwned>(&self, mut file: Self::LoadArgs) -> Result<I, RecorderError> {
let mut reader = str2reader!(file)?;
let state =
bincode::serde::decode_from_std_read(&mut reader, bin_config())
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
Ok(state)
}
}

View File

@@ -1,4 +1,3 @@
pub mod model; pub mod model;
pub mod tokenizer; pub mod tokenizer;
pub mod helper; pub mod helper;
pub mod binrecorderfast;