mirror of
https://gitea.hainer-ernst.de/rasmus/burn-stablediffusion-vibecode.git
synced 2026-06-10 17:59:22 +00:00
Add batching to generate two images at a time
This commit is contained in:
@@ -66,7 +66,7 @@ fn main() {
|
|||||||
let sd = sd.to_device(&device);
|
let sd = sd.to_device(&device);
|
||||||
|
|
||||||
let unconditional_context = sd.unconditional_context(&tokenizer);
|
let unconditional_context = sd.unconditional_context(&tokenizer);
|
||||||
let context = sd.context(&tokenizer, prompt).unsqueeze();
|
let context = sd.context(&tokenizer, prompt).unsqueeze().repeat(0, 2); // generate 2 samples
|
||||||
|
|
||||||
println!("Sampling image...");
|
println!("Sampling image...");
|
||||||
let images = sd.sample_image(context, unconditional_context, unconditional_guidance_scale, n_steps);
|
let images = sd.sample_image(context, unconditional_context, unconditional_guidance_scale, n_steps);
|
||||||
|
|||||||
@@ -125,13 +125,13 @@ impl<B: Backend> StableDiffusion<B> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn forward_diffuser(&self, latent: Tensor<B, 4>, timestep: Tensor<B, 1, Int>, context: Tensor<B, 3>, unconditional_context: Tensor<B, 2>, unconditional_guidance_scale: f64) -> Tensor<B, 4> {
|
fn forward_diffuser(&self, latent: Tensor<B, 4>, timestep: Tensor<B, 1, Int>, context: Tensor<B, 3>, unconditional_context: Tensor<B, 2>, unconditional_guidance_scale: f64) -> Tensor<B, 4> {
|
||||||
///let [n_batch, n_channel, height, width] = latent.dims();
|
let [n_batch, _, _, _] = latent.dims();
|
||||||
//let latent = latent.repeat(0, 2);
|
//let latent = latent.repeat(0, 2);
|
||||||
|
|
||||||
let unconditional_latent = self.diffusion.forward(
|
let unconditional_latent = self.diffusion.forward(
|
||||||
latent.clone(),
|
latent.clone(),
|
||||||
timestep.clone(),
|
timestep.clone(),
|
||||||
unconditional_context.unsqueeze()
|
unconditional_context.unsqueeze().repeat(0, n_batch)
|
||||||
);
|
);
|
||||||
|
|
||||||
let conditional_latent = self.diffusion.forward(
|
let conditional_latent = self.diffusion.forward(
|
||||||
|
|||||||
Reference in New Issue
Block a user