Add batching to generate two images at a time

This commit is contained in:
Gadersd
2023-08-05 19:41:11 -04:00
parent 57b446c08d
commit 6e065035cd
2 changed files with 3 additions and 3 deletions

View File

@@ -66,7 +66,7 @@ fn main() {
let sd = sd.to_device(&device);
let unconditional_context = sd.unconditional_context(&tokenizer);
let context = sd.context(&tokenizer, prompt).unsqueeze();
let context = sd.context(&tokenizer, prompt).unsqueeze().repeat(0, 2); // generate 2 samples
println!("Sampling image...");
let images = sd.sample_image(context, unconditional_context, unconditional_guidance_scale, n_steps);

View File

@@ -125,13 +125,13 @@ impl<B: Backend> StableDiffusion<B> {
}
fn forward_diffuser(&self, latent: Tensor<B, 4>, timestep: Tensor<B, 1, Int>, context: Tensor<B, 3>, unconditional_context: Tensor<B, 2>, unconditional_guidance_scale: f64) -> Tensor<B, 4> {
///let [n_batch, n_channel, height, width] = latent.dims();
let [n_batch, _, _, _] = latent.dims();
//let latent = latent.repeat(0, 2);
let unconditional_latent = self.diffusion.forward(
latent.clone(),
timestep.clone(),
unconditional_context.unsqueeze()
unconditional_context.unsqueeze().repeat(0, n_batch)
);
let conditional_latent = self.diffusion.forward(