diff --git a/Cargo.toml b/Cargo.toml index b0ae503..f78722c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = ["torch-backend"] +default = ["wgpu-backend"] torch-backend = ["burn-tch"] wgpu-backend = ["burn-wgpu"] @@ -22,6 +22,7 @@ optional = true [dependencies] burn = { git = "https://github.com/burn-rs/burn.git" } +burn-ndarray = { package = "burn-ndarray", git = "https://github.com/burn-rs/burn.git" } serde = {version = "1.0.171", features = ["std", "derive"]} npy = "0.4.0" num-traits = "0.2.15" diff --git a/README.md b/README.md index 53738df..d5e3fdf 100644 --- a/README.md +++ b/README.md @@ -20,18 +20,19 @@ Start by downloading the SDv1-4.bin model provided on HuggingFace. wget https://huggingface.co/Gadersd/Stable-Diffusion-Burn/resolve/main/V1/SDv1-4.bin ``` -Next, set the appropriate CUDA version. It may be possible to run the model using wgpu without the need for torch in the future using `cargo run --features wgpu-backend...` but currently wgpu doesn't support buffer sizes large enough for Stable Diffusion. - -```bash -export TORCH_CUDA_VERSION=cu113 -``` ### Step 2: Run the Sample Binary -Invoke the sample binary provided in the rust code, as shown below: +Invoke the sample binary provided in the rust code. By default, wgpu is used which requires a gpu with at least 10 GB of VRAM (will be lower in the future), but torch can be used with the `torch-backend` feature and can run on a 6 GB gpu. ```bash +# wgpu (NEEDS >= 10 GB VRAM) # Arguments: cargo run --release --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img + +# torch (at least 6 GB VRAM, possibly less) +export TORCH_CUDA_VERSION=cu113 +# Arguments: +cargo run --release --features torch-backend --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img ``` This command will generate an image according to the provided prompt, which will be saved as 'img0.png'. diff --git a/src/bin/convert/main.rs b/src/bin/convert/main.rs index 8476f8d..4885dab 100644 --- a/src/bin/convert/main.rs +++ b/src/bin/convert/main.rs @@ -14,13 +14,7 @@ use burn::{ }, }; -cfg_if::cfg_if! { - if #[cfg(feature = "torch-backend")] { - use burn_tch::{TchBackend, TchDevice}; - } else if #[cfg(feature = "wgpu-backend")] { - use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi}; - } -} +use burn_ndarray::{NdArrayBackend, NdArrayDevice}; use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings}; @@ -43,15 +37,8 @@ fn save_model_file(model: StableDiffusion, name: &str) -> Result< } fn main() { - cfg_if::cfg_if! { - if #[cfg(feature = "torch-backend")] { - type Backend = TchBackend; - let device = TchDevice::Cpu; - } else if #[cfg(feature = "wgpu-backend")] { - type Backend = WgpuBackend; - let device = WgpuDevice::CPU; - } - } + type Backend = NdArrayBackend; + let device = NdArrayDevice::Cpu; let args: Vec = env::args().collect(); if args.len() != 3 { diff --git a/src/bin/sample/main.rs b/src/bin/sample/main.rs index 9a40d8a..c92cb3c 100644 --- a/src/bin/sample/main.rs +++ b/src/bin/sample/main.rs @@ -74,11 +74,11 @@ fn main() { process::exit(1); }) }; - + let sd = sd.to_device(&device); let unconditional_context = sd.unconditional_context(&tokenizer); - let context = sd.context(&tokenizer, prompt).unsqueeze().repeat(0, 2); // generate 2 samples + let context = sd.context(&tokenizer, prompt).unsqueeze::<3>();//.repeat(0, 2); // generate 2 samples println!("Sampling image..."); let images = sd.sample_image(context, unconditional_context, unconditional_guidance_scale, n_steps); diff --git a/src/model/stablediffusion/mod.rs b/src/model/stablediffusion/mod.rs index 72b5731..fb3dc0a 100644 --- a/src/model/stablediffusion/mod.rs +++ b/src/model/stablediffusion/mod.rs @@ -59,6 +59,11 @@ impl StableDiffusion { let [n_batch, _, _] = context.dims(); let latent = self.sample_latent(context, unconditional_context, unconditional_guidance_scale, n_steps); + self.latent_to_image(latent) + } + + pub fn latent_to_image(&self, latent: Tensor) -> Vec> { + let [n_batch, _, _, _] = latent.dims(); let image = self.autoencoder.decode_latent(latent * (1.0 / 0.18215)); let n_channel = 3; @@ -157,7 +162,7 @@ impl StableDiffusion { } pub fn context(&self, tokenizer: &SimpleTokenizer, text: &str) -> Tensor { - let device = &self.devices()[0]; + let device = &self.clip.devices()[0]; let text = format!("<|startoftext|>{}<|endoftext|>", text); let tokenized: Vec<_> = tokenizer.encode(&text).into_iter().map(|v| v as i32).collect();