mirror of
https://gitea.hainer-ernst.de/rasmus/burn-stablediffusion-vibecode.git
synced 2026-06-10 17:59:22 +00:00
Use buffer to increase model loading/saving speed tremendously
This commit is contained in:
@@ -13,4 +13,5 @@ npy = "0.4.0"
|
||||
num-traits = "0.2.15"
|
||||
rust_tokenizers = "8.1.0"
|
||||
regex = "1.9.1"
|
||||
image = "0.24.6"
|
||||
image = "0.24.6"
|
||||
bincode = "1.3.3"
|
||||
@@ -16,7 +16,8 @@ use burn::{
|
||||
|
||||
use burn_tch::{TchBackend, TchDevice};
|
||||
|
||||
use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings};
|
||||
use burn::record::{self, Recorder, FullPrecisionSettings};
|
||||
use stablediffusion::binrecorder::{BinFileRecorderBuffered};
|
||||
|
||||
fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box<dyn Error>> {
|
||||
println!("Loading dump...");
|
||||
@@ -29,7 +30,7 @@ fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device:
|
||||
}
|
||||
|
||||
fn save_model_file<B: Backend>(model: StableDiffusion<B>, name: &str) -> Result<(), record::RecorderError> {
|
||||
BinFileRecorder::<FullPrecisionSettings>::new()
|
||||
BinFileRecorderBuffered::<FullPrecisionSettings>::new()
|
||||
.record(
|
||||
model.into_record(),
|
||||
name.into(),
|
||||
|
||||
@@ -15,10 +15,11 @@ use std::env;
|
||||
use std::io;
|
||||
use std::process;
|
||||
|
||||
use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings};
|
||||
use burn::record::{self, Recorder, FullPrecisionSettings};
|
||||
use stablediffusion::binrecorder::{BinFileRecorderBuffered};
|
||||
|
||||
fn load_stable_diffusion_model_file<B: Backend>(filename: &str) -> Result<StableDiffusion<B>, record::RecorderError> {
|
||||
BinFileRecorder::<FullPrecisionSettings>::new()
|
||||
BinFileRecorderBuffered::<FullPrecisionSettings>::new()
|
||||
.load(filename.into())
|
||||
.map(|record| StableDiffusionConfig::new().init().load_record(record))
|
||||
}
|
||||
|
||||
85
src/binrecorder.rs
Normal file
85
src/binrecorder.rs
Normal file
@@ -0,0 +1,85 @@
|
||||
use bincode::{self, Options};
|
||||
use burn::record::{PrecisionSettings, Recorder, RecorderError, FileRecorder};
|
||||
use std::fs::File;
|
||||
use std::io::{BufReader, BufWriter};
|
||||
use std::path::PathBuf;
|
||||
use std::marker::PhantomData;
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
//use super::{bin_config, PrecisionSettings, Recorder, RecorderError};
|
||||
|
||||
/*fn bin_config() -> bincode::config::Configuration {
|
||||
bincode::config::standard()
|
||||
}*/
|
||||
|
||||
macro_rules! str2reader {
|
||||
($file:expr) => {{
|
||||
$file.set_extension(<Self as FileRecorder>::file_extension());
|
||||
let path = $file.as_path();
|
||||
File::open(path).map_err(|err| match err.kind() {
|
||||
std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),
|
||||
_ => RecorderError::Unknown(err.to_string()),
|
||||
}).map(|file| BufReader::new(file)) // wrap File in BufReader
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! str2writer {
|
||||
($file:expr) => {{
|
||||
$file.set_extension(<Self as FileRecorder>::file_extension());
|
||||
let path = $file.as_path();
|
||||
|
||||
if path.exists() {
|
||||
//log::info!("File exists, replacing");
|
||||
std::fs::remove_file(path).map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
||||
}
|
||||
|
||||
File::create(path).map_err(|err| match err.kind() {
|
||||
std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),
|
||||
_ => RecorderError::Unknown(err.to_string()),
|
||||
}).map(|file| BufWriter::new(file)) // wrap File in BufWriter
|
||||
}};
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct BinFileRecorderBuffered<S: PrecisionSettings> {
|
||||
_settings: PhantomData<S>,
|
||||
}
|
||||
|
||||
impl<S: PrecisionSettings> BinFileRecorderBuffered<S> {
|
||||
pub fn new() -> Self {
|
||||
BinFileRecorderBuffered {
|
||||
_settings: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: PrecisionSettings> FileRecorder for BinFileRecorderBuffered<S> {
|
||||
fn file_extension() -> &'static str {
|
||||
"bin"
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: PrecisionSettings> Recorder for BinFileRecorderBuffered<S> {
|
||||
type Settings = S;
|
||||
type RecordArgs = PathBuf;
|
||||
type RecordOutput = ();
|
||||
type LoadArgs = PathBuf;
|
||||
|
||||
fn save_item<I: Serialize>(
|
||||
&self,
|
||||
item: I,
|
||||
mut file: Self::RecordArgs,
|
||||
) -> Result<(), RecorderError> {
|
||||
let mut writer = str2writer!(file)?;
|
||||
bincode::serialize_into(&mut writer, &item)
|
||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_item<I: DeserializeOwned>(&self, mut file: Self::LoadArgs) -> Result<I, RecorderError> {
|
||||
let mut reader = str2reader!(file)?;
|
||||
let state =
|
||||
bincode::deserialize_from(&mut reader)
|
||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
||||
Ok(state)
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod model;
|
||||
pub mod tokenizer;
|
||||
pub mod helper;
|
||||
pub mod helper;
|
||||
pub mod binrecorder;
|
||||
Reference in New Issue
Block a user