diff --git a/Cargo.toml b/Cargo.toml index 1dbeb84..6960636 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,15 +11,17 @@ torch-backend = ["burn-tch"] wgpu-backend = ["burn-wgpu"] [dependencies.burn-tch] -version = "0.8.0" +package = "burn-tch" +git = "https://github.com/burn-rs/burn.git" optional = true [dependencies.burn-wgpu] -version = "0.8.0" +package = "burn-wgpu" +git = "https://github.com/burn-rs/burn.git" optional = true [dependencies] -burn = "0.8.0" +burn = { git = "https://github.com/burn-rs/burn.git" } serde = {version = "1.0.171", features = ["std", "derive"]} npy = "0.4.0" num-traits = "0.2.15" diff --git a/src/bin/convert/main.rs b/src/bin/convert/main.rs index d2f6b53..8476f8d 100644 --- a/src/bin/convert/main.rs +++ b/src/bin/convert/main.rs @@ -22,8 +22,7 @@ cfg_if::cfg_if! { } } -use burn::record::{self, Recorder, FullPrecisionSettings}; -use stablediffusion::binrecorderfast::{BinFileRecorderBuffered}; +use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings}; fn convert_dump_to_model(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box> { println!("Loading dump..."); @@ -36,7 +35,7 @@ fn convert_dump_to_model(dump_path: &str, model_name: &str, device: } fn save_model_file(model: StableDiffusion, name: &str) -> Result<(), record::RecorderError> { - BinFileRecorderBuffered::::new() + BinFileRecorder::::new() .record( model.into_record(), name.into(), diff --git a/src/bin/sample/main.rs b/src/bin/sample/main.rs index 245e92f..9a40d8a 100644 --- a/src/bin/sample/main.rs +++ b/src/bin/sample/main.rs @@ -22,11 +22,10 @@ use std::env; use std::io; use std::process; -use burn::record::{self, Recorder, FullPrecisionSettings}; -use stablediffusion::binrecorderfast::{BinFileRecorderBuffered}; +use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings}; fn load_stable_diffusion_model_file(filename: &str) -> Result, record::RecorderError> { - BinFileRecorderBuffered::::new() + BinFileRecorder::::new() .load(filename.into()) .map(|record| StableDiffusionConfig::new().init().load_record(record)) } diff --git a/src/binrecorderfast.rs b/src/binrecorderfast.rs deleted file mode 100644 index b1c6ca2..0000000 --- a/src/binrecorderfast.rs +++ /dev/null @@ -1,86 +0,0 @@ -use bincode; -use burn::record::{PrecisionSettings, Recorder, RecorderError, FileRecorder}; -use std::fs::File; -use std::io::{BufReader, BufWriter}; -use std::path::PathBuf; -use std::marker::PhantomData; -use serde::{de::DeserializeOwned, Serialize}; -//use super::{bin_config, PrecisionSettings, Recorder, RecorderError}; - -fn bin_config() -> bincode::config::Configuration { - bincode::config::standard() -} - -macro_rules! str2reader { - ($file:expr) => {{ - $file.set_extension(::file_extension()); - let path = $file.as_path(); - File::open(path).map_err(|err| match err.kind() { - std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()), - _ => RecorderError::Unknown(err.to_string()), - }).map(|file| BufReader::new(file)) // wrap File in BufReader - }}; -} - -macro_rules! str2writer { - ($file:expr) => {{ - $file.set_extension(::file_extension()); - let path = $file.as_path(); - - if path.exists() { - //log::info!("File exists, replacing"); - std::fs::remove_file(path).map_err(|err| RecorderError::Unknown(err.to_string()))?; - } - - File::create(path).map_err(|err| match err.kind() { - std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()), - _ => RecorderError::Unknown(err.to_string()), - }).map(|file| BufWriter::new(file)) // wrap File in BufWriter - }}; -} - -#[derive(Debug, Default, Clone)] -pub struct BinFileRecorderBuffered { - _settings: PhantomData, -} - -impl BinFileRecorderBuffered { - pub fn new() -> Self { - BinFileRecorderBuffered { - _settings: PhantomData, - } - } -} - -impl FileRecorder for BinFileRecorderBuffered { - fn file_extension() -> &'static str { - "bin" - } -} - -impl Recorder for BinFileRecorderBuffered { - type Settings = S; - type RecordArgs = PathBuf; - type RecordOutput = (); - type LoadArgs = PathBuf; - - fn save_item( - &self, - item: I, - mut file: Self::RecordArgs, - ) -> Result<(), RecorderError> { - let config = bin_config(); - let mut writer = str2writer!(file)?; - bincode::serde::encode_into_std_write(&item, &mut writer, config) - .map_err(|err| RecorderError::Unknown(err.to_string()))?; - Ok(()) - } - - fn load_item(&self, mut file: Self::LoadArgs) -> Result { - let mut reader = str2reader!(file)?; - let state = - bincode::serde::decode_from_std_read(&mut reader, bin_config()) - .map_err(|err| RecorderError::Unknown(err.to_string()))?; - Ok(state) - } -} diff --git a/src/lib.rs b/src/lib.rs index 46cd752..993bb59 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,3 @@ pub mod model; pub mod tokenizer; -pub mod helper; -pub mod binrecorderfast; \ No newline at end of file +pub mod helper; \ No newline at end of file