From 6e065035cd25a02a84cb33e8696d0582346084dd Mon Sep 17 00:00:00 2001 From: Gadersd Date: Sat, 5 Aug 2023 19:41:11 -0400 Subject: [PATCH] Add batching to generate two images at a time --- src/bin/sample/main.rs | 2 +- src/model/stablediffusion/mod.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/bin/sample/main.rs b/src/bin/sample/main.rs index 278c91d..8efeea0 100644 --- a/src/bin/sample/main.rs +++ b/src/bin/sample/main.rs @@ -66,7 +66,7 @@ fn main() { let sd = sd.to_device(&device); let unconditional_context = sd.unconditional_context(&tokenizer); - let context = sd.context(&tokenizer, prompt).unsqueeze(); + let context = sd.context(&tokenizer, prompt).unsqueeze().repeat(0, 2); // generate 2 samples println!("Sampling image..."); let images = sd.sample_image(context, unconditional_context, unconditional_guidance_scale, n_steps); diff --git a/src/model/stablediffusion/mod.rs b/src/model/stablediffusion/mod.rs index 5390590..72b5731 100644 --- a/src/model/stablediffusion/mod.rs +++ b/src/model/stablediffusion/mod.rs @@ -125,13 +125,13 @@ impl StableDiffusion { } fn forward_diffuser(&self, latent: Tensor, timestep: Tensor, context: Tensor, unconditional_context: Tensor, unconditional_guidance_scale: f64) -> Tensor { - ///let [n_batch, n_channel, height, width] = latent.dims(); + let [n_batch, _, _, _] = latent.dims(); //let latent = latent.repeat(0, 2); let unconditional_latent = self.diffusion.forward( latent.clone(), timestep.clone(), - unconditional_context.unsqueeze() + unconditional_context.unsqueeze().repeat(0, n_batch) ); let conditional_latent = self.diffusion.forward(