From 77f30aefa7650326194d4135f4392b5f33a0af91 Mon Sep 17 00:00:00 2001 From: Gadersd Date: Fri, 4 Aug 2023 17:01:44 -0400 Subject: [PATCH] Add first successful sampling implementation --- python/stable_diffusion.py | 9 ++------- python/stablediffusion.py | 8 +++++--- src/lib.rs | 2 -- src/main.rs | 11 +++++++++-- src/model/attention.rs | 4 ++-- src/model/autoencoder/mod.rs | 8 -------- src/model/clip/mod.rs | 2 +- src/model/stablediffusion/mod.rs | 31 +++++++++++++++++++++++-------- src/model/unet/mod.rs | 1 - 9 files changed, 42 insertions(+), 34 deletions(-) diff --git a/python/stable_diffusion.py b/python/stable_diffusion.py index d825fb4..10c6622 100644 --- a/python/stable_diffusion.py +++ b/python/stable_diffusion.py @@ -56,9 +56,6 @@ class ResnetBlock: def __call__(self, x): h = self.conv1(self.norm1(x).swish()) - '''v = h - print(v.shape) - print(v[0, 0:10, :, :].numpy())''' h = self.conv2(self.norm2(h).swish()) return self.nin_shortcut(x) + h @@ -145,7 +142,6 @@ class AutoencoderKL: latent = self.encoder(x) latent = self.quant_conv(latent) latent = latent[:, 0:4] # only the means - print("latent", latent.shape) latent = self.post_quant_conv(latent) return self.decoder(latent) @@ -339,15 +335,12 @@ class UNetModel: saved_inputs = [] for i,b in enumerate(self.input_blocks): - #print("input block", i) - print(x.numpy()) for bb in b: x = run(x, bb) saved_inputs.append(x) for bb in self.middle_block: x = run(x, bb) for i,b in enumerate(self.output_blocks): - #print("output block", i) x = x.cat(saved_inputs.pop(), dim=1) for bb in b: x = run(x, bb) @@ -644,7 +637,9 @@ if __name__ == "__main__": download_file('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', FILENAME) load_state_dict(model, torch_load(FILENAME)['state_dict'], strict=False) + print('Saving model...') sdsave.save_stable_diffusion(model, "params") + print('Model saved.') '''parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter) diff --git a/python/stablediffusion.py b/python/stablediffusion.py index 63634f1..fec6528 100644 --- a/python/stablediffusion.py +++ b/python/stablediffusion.py @@ -1,3 +1,4 @@ +import pathlib from autoencoder import save_autoencoder from unet import save_unet_model from clip import save_clip_text_transformer @@ -5,8 +6,9 @@ from clip import save_clip_text_transformer from save import save_scalar, save_tensor def save_stable_diffusion(stable_diffusion, path): + pathlib.Path(path).mkdir(parents=True, exist_ok=True) save_scalar(stable_diffusion.alphas_cumprod.shape[0], "n_steps", path) save_tensor(stable_diffusion.alphas_cumprod, 'alphas_cumprod', path) - save_autoencoder(stable_diffusion.autoencoder, 'autoencoder', path) - save_unet_model(stable_diffusion.diffusion, 'unet', path) - save_clip_text_transformer(stable_diffusion.clip, 'clip', path) \ No newline at end of file + save_autoencoder(stable_diffusion.first_stage_model, pathlib.Path(path, 'autoencoder')) + save_unet_model(stable_diffusion.model.diffusion_model, pathlib.Path(path, 'unet')) + save_clip_text_transformer(stable_diffusion.cond_stage_model.transformer.text_model, pathlib.Path(path, 'clip')) \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index d7dd5e9..993bb59 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,3 @@ -#![feature(generic_const_exprs)] - pub mod model; pub mod tokenizer; pub mod helper; \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 4483088..98020e6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -65,16 +65,23 @@ fn main() { let output = unet.forward(input, timesteps, context);*/ //print_tensor(output); + println!("Loading tokenizer..."); let tokenizer = SimpleTokenizer::new().unwrap(); + + println!("Loading Stable Diffusion..."); let sd: StableDiffusion = load_stable_diffusion("params", &device).unwrap(); + let sd = sd.to_device(&device); let unconditional_guidance_scale = 7.5; let unconditional_context = sd.unconditional_context(&tokenizer); - let context = sd.context(&tokenizer, "A rainbow pony is flying.").unsqueeze(); + let context = sd.context(&tokenizer, "A wine glass filled with pink flower petals.").unsqueeze(); - let n_steps = 5; + let n_steps = 100; + println!("Sampling images..."); let images = sd.sample_image(context, unconditional_context, unconditional_guidance_scale, n_steps); + + println!("Saving images..."); save_images(&images, "image_samples/", 512, 512).unwrap(); } diff --git a/src/model/attention.rs b/src/model/attention.rs index ba13109..e516c2e 100644 --- a/src/model/attention.rs +++ b/src/model/attention.rs @@ -35,7 +35,7 @@ pub fn qkv_attention(q: Tensor, k: Tensor, v: Tensor(seq_length: usize) -> Tensor { +pub fn attn_decoder_mask(seq_length: usize, device: &B::Device) -> Tensor { let mut mask = Tensor::::zeros([seq_length, seq_length]); for i in 0..(seq_length - 1) { @@ -43,5 +43,5 @@ pub fn attn_decoder_mask(seq_length: usize) -> Tensor { mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values); } - return mask; + return mask.to_device(device); } \ No newline at end of file diff --git a/src/model/autoencoder/mod.rs b/src/model/autoencoder/mod.rs index 1c47d7e..8daec5b 100644 --- a/src/model/autoencoder/mod.rs +++ b/src/model/autoencoder/mod.rs @@ -43,14 +43,6 @@ impl AutoencoderConfig { } -fn print_tensor(x: Tensor) { - let [_, channels, height, width] = x.dims(); - let channels = channels.min(10); - let data = x.slice([0..1, 0..channels, 0..height, 0..width]).into_data(); - println!("{:?}", data); -} - - #[derive(Module, Debug)] pub struct Autoencoder { encoder: Encoder, diff --git a/src/model/clip/mod.rs b/src/model/clip/mod.rs index c9a22a7..9c6e8f9 100644 --- a/src/model/clip/mod.rs +++ b/src/model/clip/mod.rs @@ -59,7 +59,7 @@ impl CLIP { pub fn forward(&self, x: Tensor) -> Tensor { let [n_batch, seq_len] = x.dims(); - let mask = attn_decoder_mask(seq_len); + let mask = attn_decoder_mask(seq_len, &x.device()); let embedded = self.token_embedding.forward(x) + self.position_embedding.val().slice([0..seq_len]).unsqueeze(); diff --git a/src/model/stablediffusion/mod.rs b/src/model/stablediffusion/mod.rs index 96df0fb..aa28d05 100644 --- a/src/model/stablediffusion/mod.rs +++ b/src/model/stablediffusion/mod.rs @@ -85,19 +85,21 @@ impl StableDiffusion { let start = b * num_elements_per_image; let end = start + num_elements_per_image; - flattened[start..end].into_iter().map(|v| v.to_u8().unwrap()).collect() + flattened[start..end].into_iter().map(|v| v.to_f64().unwrap().min(255.0).max(0.0).to_u8().unwrap()).collect() }).collect() } pub fn sample_latent(&self, context: Tensor, unconditional_context: Tensor, unconditional_guidance_scale: f64, n_steps: usize) -> Tensor { assert!(self.n_steps % n_steps == 0); + let device = context.device(); + let step_size = self.n_steps / n_steps; let [n_batches, _, _] = context.dims(); let gen_noise = || { - Tensor::random([n_batches, 4, 64, 64], Distribution::Normal(0.0, 1.0) ) + Tensor::random([n_batches, 4, 64, 64], Distribution::Normal(0.0, 1.0)).to_device(&device) }; let sigma = 0.0; // Use deterministic diffusion @@ -114,7 +116,7 @@ impl StableDiffusion { let sqrt_noise = (1.0 - current_alpha).sqrt(); - let timestep = Tensor::from_ints([t as i32]); + let timestep = Tensor::from_ints([t as i32]).to_device(&device); let pred_noise = self.forward_diffuser(latent.clone(), timestep, context.clone(), unconditional_context.clone(), unconditional_guidance_scale); let predx0 = (latent - pred_noise.clone() * sqrt_noise) / current_alpha.sqrt(); @@ -128,17 +130,29 @@ impl StableDiffusion { } fn forward_diffuser(&self, latent: Tensor, timestep: Tensor, context: Tensor, unconditional_context: Tensor, unconditional_guidance_scale: f64) -> Tensor { - let [n_batch, n_channel, height, width] = latent.dims(); - let latent = latent.repeat(0, 2); + ///let [n_batch, n_channel, height, width] = latent.dims(); + //let latent = latent.repeat(0, 2); - let latent = self.diffusion.forward( + let unconditional_latent = self.diffusion.forward( + latent.clone(), + timestep.clone(), + unconditional_context.unsqueeze() + ); + + let conditional_latent = self.diffusion.forward( + latent, + timestep, + context + ); + + /*let latent = self.diffusion.forward( latent.repeat(0, 2), timestep.repeat(0, 2), Tensor::cat(vec![unconditional_context.unsqueeze::<3>(), context], 0) ); let unconditional_latent = latent.clone().slice([0..n_batch]); - let conditional_latent = latent.slice([n_batch..2 * n_batch]); + let conditional_latent = latent.slice([n_batch..2 * n_batch]);*/ unconditional_latent.clone() + (conditional_latent - unconditional_latent) * unconditional_guidance_scale } @@ -148,10 +162,11 @@ impl StableDiffusion { } pub fn context(&self, tokenizer: &SimpleTokenizer, text: &str) -> Tensor { + let device = &self.devices()[0]; let text = format!("<|startoftext|>{}<|endoftext|>", text); let tokenized: Vec<_> = tokenizer.encode(&text).into_iter().map(|v| v as i32).collect(); - self.clip.forward(Tensor::from_ints(&tokenized[..]).unsqueeze()) + self.clip.forward(Tensor::from_ints(&tokenized[..]).to_device(device).unsqueeze()) } } diff --git a/src/model/unet/mod.rs b/src/model/unet/mod.rs index 4aff773..033cc9b 100644 --- a/src/model/unet/mod.rs +++ b/src/model/unet/mod.rs @@ -113,7 +113,6 @@ impl UNet { // input blocks for block in self.input_blocks.as_array() { - println!("{:?}", x.clone().flatten::<1>(0, 3).slice([0..100]).into_data()); x = block.forward(x, emb.clone(), context.clone()); saved_inputs.push(x.clone()) }