Add model loading, saving, and conversion functionality

This commit is contained in:
Gadersd
2023-08-05 10:14:03 -04:00
parent 41fce2a47e
commit 9e247777fa
5 changed files with 160 additions and 126 deletions

View File

@@ -2,7 +2,7 @@ pub mod load;
use burn::{
config::Config,
module::Module,
module::{Module, Param},
tensor::{
backend::Backend,
Tensor,
@@ -27,12 +27,9 @@ pub struct StableDiffusionConfig {
}
impl StableDiffusionConfig {
fn init<B: Backend>(&self) -> StableDiffusion<B> {
pub fn init<B: Backend>(&self) -> StableDiffusion<B> {
let n_steps = 1000;
let alpha_cumulative_products = offset_cosine_schedule_cumprod::<B>(n_steps)
.into_data().value
.into_iter()
.map(|v: <Float as BasicOps<B>>::Elem| v.to_f64().unwrap()).collect();
let alpha_cumulative_products = offset_cosine_schedule_cumprod::<B>(n_steps).into();
let autoencoder = AutoencoderConfig::new().init();
let diffusion = UNetConfig::new().init();
@@ -51,7 +48,7 @@ impl StableDiffusionConfig {
#[derive(Module, Debug)]
pub struct StableDiffusion<B: Backend> {
n_steps: usize,
alpha_cumulative_products: Vec<f64>,
alpha_cumulative_products: Param<Tensor<B, 1>>,
autoencoder: Autoencoder<B>,
diffusion: UNet<B>,
clip: CLIP<B>,
@@ -90,8 +87,6 @@ impl<B: Backend> StableDiffusion<B> {
}
pub fn sample_latent(&self, context: Tensor<B, 3>, unconditional_context: Tensor<B, 2>, unconditional_guidance_scale: f64, n_steps: usize) -> Tensor<B, 4> {
assert!(self.n_steps % n_steps == 0);
let device = context.device();
let step_size = self.n_steps / n_steps;
@@ -107,9 +102,10 @@ impl<B: Backend> StableDiffusion<B> {
let mut latent = gen_noise();
for t in (0..self.n_steps).rev().step_by(step_size) {
let current_alpha = self.alpha_cumulative_products[t];
let prev_alpha = if t >= step_size {
self.alpha_cumulative_products[t - step_size]
let current_alpha: f64 = self.alpha_cumulative_products.val().slice([t..t + 1]).into_scalar().to_f64().unwrap();
let prev_alpha: f64 = if t >= step_size {
let i = t - step_size;
self.alpha_cumulative_products.val().slice([i..i + 1]).into_scalar().to_f64().unwrap()
} else {
1.0
};