mirror of
https://gitea.hainer-ernst.de/rasmus/burn-stablediffusion-vibecode.git
synced 2026-06-11 02:09:21 +00:00
Switch to burn 0.9.0 to gain fast model io without the need for a custom recorder
This commit is contained in:
@@ -11,15 +11,17 @@ torch-backend = ["burn-tch"]
|
|||||||
wgpu-backend = ["burn-wgpu"]
|
wgpu-backend = ["burn-wgpu"]
|
||||||
|
|
||||||
[dependencies.burn-tch]
|
[dependencies.burn-tch]
|
||||||
version = "0.8.0"
|
package = "burn-tch"
|
||||||
|
git = "https://github.com/burn-rs/burn.git"
|
||||||
optional = true
|
optional = true
|
||||||
|
|
||||||
[dependencies.burn-wgpu]
|
[dependencies.burn-wgpu]
|
||||||
version = "0.8.0"
|
package = "burn-wgpu"
|
||||||
|
git = "https://github.com/burn-rs/burn.git"
|
||||||
optional = true
|
optional = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
burn = "0.8.0"
|
burn = { git = "https://github.com/burn-rs/burn.git" }
|
||||||
serde = {version = "1.0.171", features = ["std", "derive"]}
|
serde = {version = "1.0.171", features = ["std", "derive"]}
|
||||||
npy = "0.4.0"
|
npy = "0.4.0"
|
||||||
num-traits = "0.2.15"
|
num-traits = "0.2.15"
|
||||||
|
|||||||
@@ -22,8 +22,7 @@ cfg_if::cfg_if! {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
use burn::record::{self, Recorder, FullPrecisionSettings};
|
use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings};
|
||||||
use stablediffusion::binrecorderfast::{BinFileRecorderBuffered};
|
|
||||||
|
|
||||||
fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box<dyn Error>> {
|
fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box<dyn Error>> {
|
||||||
println!("Loading dump...");
|
println!("Loading dump...");
|
||||||
@@ -36,7 +35,7 @@ fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device:
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn save_model_file<B: Backend>(model: StableDiffusion<B>, name: &str) -> Result<(), record::RecorderError> {
|
fn save_model_file<B: Backend>(model: StableDiffusion<B>, name: &str) -> Result<(), record::RecorderError> {
|
||||||
BinFileRecorderBuffered::<FullPrecisionSettings>::new()
|
BinFileRecorder::<FullPrecisionSettings>::new()
|
||||||
.record(
|
.record(
|
||||||
model.into_record(),
|
model.into_record(),
|
||||||
name.into(),
|
name.into(),
|
||||||
|
|||||||
@@ -22,11 +22,10 @@ use std::env;
|
|||||||
use std::io;
|
use std::io;
|
||||||
use std::process;
|
use std::process;
|
||||||
|
|
||||||
use burn::record::{self, Recorder, FullPrecisionSettings};
|
use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings};
|
||||||
use stablediffusion::binrecorderfast::{BinFileRecorderBuffered};
|
|
||||||
|
|
||||||
fn load_stable_diffusion_model_file<B: Backend>(filename: &str) -> Result<StableDiffusion<B>, record::RecorderError> {
|
fn load_stable_diffusion_model_file<B: Backend>(filename: &str) -> Result<StableDiffusion<B>, record::RecorderError> {
|
||||||
BinFileRecorderBuffered::<FullPrecisionSettings>::new()
|
BinFileRecorder::<FullPrecisionSettings>::new()
|
||||||
.load(filename.into())
|
.load(filename.into())
|
||||||
.map(|record| StableDiffusionConfig::new().init().load_record(record))
|
.map(|record| StableDiffusionConfig::new().init().load_record(record))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,86 +0,0 @@
|
|||||||
use bincode;
|
|
||||||
use burn::record::{PrecisionSettings, Recorder, RecorderError, FileRecorder};
|
|
||||||
use std::fs::File;
|
|
||||||
use std::io::{BufReader, BufWriter};
|
|
||||||
use std::path::PathBuf;
|
|
||||||
use std::marker::PhantomData;
|
|
||||||
use serde::{de::DeserializeOwned, Serialize};
|
|
||||||
//use super::{bin_config, PrecisionSettings, Recorder, RecorderError};
|
|
||||||
|
|
||||||
fn bin_config() -> bincode::config::Configuration {
|
|
||||||
bincode::config::standard()
|
|
||||||
}
|
|
||||||
|
|
||||||
macro_rules! str2reader {
|
|
||||||
($file:expr) => {{
|
|
||||||
$file.set_extension(<Self as FileRecorder>::file_extension());
|
|
||||||
let path = $file.as_path();
|
|
||||||
File::open(path).map_err(|err| match err.kind() {
|
|
||||||
std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),
|
|
||||||
_ => RecorderError::Unknown(err.to_string()),
|
|
||||||
}).map(|file| BufReader::new(file)) // wrap File in BufReader
|
|
||||||
}};
|
|
||||||
}
|
|
||||||
|
|
||||||
macro_rules! str2writer {
|
|
||||||
($file:expr) => {{
|
|
||||||
$file.set_extension(<Self as FileRecorder>::file_extension());
|
|
||||||
let path = $file.as_path();
|
|
||||||
|
|
||||||
if path.exists() {
|
|
||||||
//log::info!("File exists, replacing");
|
|
||||||
std::fs::remove_file(path).map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
|
||||||
}
|
|
||||||
|
|
||||||
File::create(path).map_err(|err| match err.kind() {
|
|
||||||
std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),
|
|
||||||
_ => RecorderError::Unknown(err.to_string()),
|
|
||||||
}).map(|file| BufWriter::new(file)) // wrap File in BufWriter
|
|
||||||
}};
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Default, Clone)]
|
|
||||||
pub struct BinFileRecorderBuffered<S: PrecisionSettings> {
|
|
||||||
_settings: PhantomData<S>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S: PrecisionSettings> BinFileRecorderBuffered<S> {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
BinFileRecorderBuffered {
|
|
||||||
_settings: PhantomData,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S: PrecisionSettings> FileRecorder for BinFileRecorderBuffered<S> {
|
|
||||||
fn file_extension() -> &'static str {
|
|
||||||
"bin"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S: PrecisionSettings> Recorder for BinFileRecorderBuffered<S> {
|
|
||||||
type Settings = S;
|
|
||||||
type RecordArgs = PathBuf;
|
|
||||||
type RecordOutput = ();
|
|
||||||
type LoadArgs = PathBuf;
|
|
||||||
|
|
||||||
fn save_item<I: Serialize>(
|
|
||||||
&self,
|
|
||||||
item: I,
|
|
||||||
mut file: Self::RecordArgs,
|
|
||||||
) -> Result<(), RecorderError> {
|
|
||||||
let config = bin_config();
|
|
||||||
let mut writer = str2writer!(file)?;
|
|
||||||
bincode::serde::encode_into_std_write(&item, &mut writer, config)
|
|
||||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn load_item<I: DeserializeOwned>(&self, mut file: Self::LoadArgs) -> Result<I, RecorderError> {
|
|
||||||
let mut reader = str2reader!(file)?;
|
|
||||||
let state =
|
|
||||||
bincode::serde::decode_from_std_read(&mut reader, bin_config())
|
|
||||||
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
|
|
||||||
Ok(state)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,4 +1,3 @@
|
|||||||
pub mod model;
|
pub mod model;
|
||||||
pub mod tokenizer;
|
pub mod tokenizer;
|
||||||
pub mod helper;
|
pub mod helper;
|
||||||
pub mod binrecorderfast;
|
|
||||||
Reference in New Issue
Block a user