mirror of
https://gitea.hainer-ernst.de/rasmus/burn-stablediffusion-vibecode.git
synced 2026-06-11 02:09:21 +00:00
Add batching to generate two images at a time
This commit is contained in:
@@ -66,7 +66,7 @@ fn main() {
|
||||
let sd = sd.to_device(&device);
|
||||
|
||||
let unconditional_context = sd.unconditional_context(&tokenizer);
|
||||
let context = sd.context(&tokenizer, prompt).unsqueeze();
|
||||
let context = sd.context(&tokenizer, prompt).unsqueeze().repeat(0, 2); // generate 2 samples
|
||||
|
||||
println!("Sampling image...");
|
||||
let images = sd.sample_image(context, unconditional_context, unconditional_guidance_scale, n_steps);
|
||||
|
||||
@@ -125,13 +125,13 @@ impl<B: Backend> StableDiffusion<B> {
|
||||
}
|
||||
|
||||
fn forward_diffuser(&self, latent: Tensor<B, 4>, timestep: Tensor<B, 1, Int>, context: Tensor<B, 3>, unconditional_context: Tensor<B, 2>, unconditional_guidance_scale: f64) -> Tensor<B, 4> {
|
||||
///let [n_batch, n_channel, height, width] = latent.dims();
|
||||
let [n_batch, _, _, _] = latent.dims();
|
||||
//let latent = latent.repeat(0, 2);
|
||||
|
||||
let unconditional_latent = self.diffusion.forward(
|
||||
latent.clone(),
|
||||
timestep.clone(),
|
||||
unconditional_context.unsqueeze()
|
||||
unconditional_context.unsqueeze().repeat(0, n_batch)
|
||||
);
|
||||
|
||||
let conditional_latent = self.diffusion.forward(
|
||||
|
||||
Reference in New Issue
Block a user