Update to burn v0.14.0 and switch to .mpk model file

This commit is contained in:
Hermes
2024-10-05 14:19:49 -04:00
parent 9e4d7bd310
commit 893fb0950d
19 changed files with 366 additions and 311 deletions

View File

@@ -14,11 +14,11 @@ git = "https://github.com/burn-rs/burn.git"
optional = true
[dependencies]
burn = { git = "https://github.com/burn-rs/burn.git" }
burn-ndarray = { package = "burn-ndarray", git = "https://github.com/burn-rs/burn.git" }
burn-tch = { package = "burn-tch", git = "https://github.com/burn-rs/burn.git" }
burn-autodiff = { package = "burn-autodiff", git = "https://github.com/burn-rs/burn.git" }
tch = "0.13.0"
burn = "0.14.0"
burn-ndarray = "0.14.0"
burn-tch = "0.14.0"
burn-autodiff = "0.14.0"
tch = "0.15.0"
serde = {version = "1.0.171", features = ["std", "derive"]}
npy = "0.4.0"
num-traits = "0.2.15"

View File

@@ -4,12 +4,14 @@ Stable-Diffusion-Burn is a Rust-based project which ports the V1 stable diffusio
## How To Use
### Step 0: Install libtorch v2.4.1
### Step 1: Download the Model and Set Environment Variables
Start by downloading the SDv1-4.bin model provided on HuggingFace.
Start by downloading the SDv1-4 model provided on HuggingFace.
```bash
wget https://huggingface.co/Gadersd/Stable-Diffusion-Burn/resolve/main/V1/SDv1-4.bin
wget https://huggingface.co/Gadersd/Stable-Diffusion-Burn/resolve/main/SDv1-4.mpk
```
### Step 2: Run the Sample Binary
@@ -18,9 +20,13 @@ Invoke the sample binary provided in the rust code. By default, torch is used. T
```bash
# torch (at least 6 GB VRAM, possibly less)
export TORCH_CUDA_VERSION=cu113
# Arguments: <model_type(burn or dump)> <model> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image>
cargo run --release --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img
# Arguments: <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name> [cuda, mps, cpu]
# CUDA
cargo run --release --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img cuda
# MPS (Mac)
cargo run --release --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img mps
# wgpu (UNSTABLE)
# Arguments: <model_type(burn or dump)> <model> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image>
@@ -33,7 +39,7 @@ This command will generate an image according to the provided prompt, which will
### Optional: Extract and Convert a Fine-Tuned Model
If users are interested in using a fine-tuned version of stable diffusion, the Python scripts provided in this project can be used to transform a weight dump into a Burn model file. Note: the tinygrad dependency should be installed from source rather than with pip.
If users are interested in using a fine-tuned version of Stable Diffusion, the Python scripts provided in this project can be used to transform a weight dump into a Burn model file. Note: this conversion process does not work on Windows.
```bash
# Step into the Python directory
@@ -42,6 +48,9 @@ cd python
# Download the model, this is just the base v1.4 model as an example
wget https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt
# Install tinygrad (version pinned in requirements.txt)
pip install -r requirements.txt
# Extract the weights
CPU=1 python3 dump.py sd-v1-4.ckpt

View File

@@ -13,10 +13,11 @@ from collections import namedtuple
from tqdm import tqdm
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, GlobalCounters
from tinygrad.helpers import GlobalCounters
from tinygrad import dtypes
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
from extra.utils import download_file
from tinygrad.state import torch_load, load_state_dict
#from extra.utils import download_file
from tinygrad.nn.state import torch_load, load_state_dict
# TODO: refactor AttnBlock, CrossAttention, CLIPAttention to share code

1
python/requirements.txt Normal file
View File

@@ -0,0 +1 @@
tinygrad==0.9.2

View File

@@ -1,13 +1,16 @@
use burn::tensor::{activation::softmax, Tensor};
use burn::prelude::Backend;
/*pub type FloatTensor<B, const D: usize> = <B as burn::tensor::backend::Backend>::TensorPrimitive<D>;
pub trait Backend: burn::tensor::backend::Backend {
fn qkv_attention(
q: Self::TensorPrimitive<3>,
k: Self::TensorPrimitive<3>,
v: Self::TensorPrimitive<3>,
mask: Option<Self::TensorPrimitive<2>>,
q: FloatTensor<Self, 3>,
k: FloatTensor<Self, 3>,
v: FloatTensor<Self, 3>,
mask: Option<FloatTensor<Self, 2>>,
n_head: usize,
) -> Self::TensorPrimitive<3> {
) -> FloatTensor<Self, 3> {
qkv_attention(
Tensor::<Self, 3>::from_primitive(q),
Tensor::from_primitive(k),
@@ -18,24 +21,23 @@ pub trait Backend: burn::tensor::backend::Backend {
.into_primitive()
}
fn attn_decoder_mask(seq_length: usize, device: &Self::Device) -> Self::TensorPrimitive<2> {
fn attn_decoder_mask(seq_length: usize, device: &Self::Device) -> FloatTensor<Self, 2> {
attn_decoder_mask::<Self>(seq_length, device).into_primitive()
}
}
use burn::tensor::ops::TensorOps;
use burn::tensor::Float;
use burn_tch::{self, TchElement, TchTensor};
use tch;
impl<E: TchElement> Backend for burn_tch::TchBackend<E> {
impl<E: TchElement> Backend for burn_tch::LibTorch<E> {
fn qkv_attention(
q: Self::TensorPrimitive<3>,
k: Self::TensorPrimitive<3>,
v: Self::TensorPrimitive<3>,
mask: Option<Self::TensorPrimitive<2>>,
q: FloatTensor<Self, 3>,
k: FloatTensor<Self, 3>,
v: FloatTensor<Self, 3>,
mask: Option<FloatTensor<Self, 2>>,
n_head: usize,
) -> Self::TensorPrimitive<3> {
) -> FloatTensor<Self, 3> {
let q = Tensor::from_primitive(q);
let k = Tensor::from_primitive(k);
let v = Tensor::from_primitive(v);
@@ -56,7 +58,7 @@ impl<E: TchElement> Backend for burn_tch::TchBackend<E> {
// for some reason torch crashes when mask is None
let mask = mask.unwrap_or_else(|| {
Tensor::<Self, 2, Float>::zeros_device([q_ctx, k_ctx], &Self::device(&v))
Tensor::<Self, 2, Float>::zeros([q_ctx, k_ctx], &Self::device(&v))
.into_primitive()
});
@@ -68,6 +70,7 @@ impl<E: TchElement> Backend for burn_tch::TchBackend<E> {
Some(mask.tensor),
0.0,
false,
None,
),
))
.swap_dims(1, 2)
@@ -78,11 +81,11 @@ impl<E: TchElement> Backend for burn_tch::TchBackend<E> {
use burn_autodiff;
impl<B: Backend> Backend for burn_autodiff::ADBackendDecorator<B> {}
impl<B: Backend> Backend for burn_autodiff::Autodiff<B> {}*/
use std::f32::NEG_INFINITY;
fn qkv_attention<B: Backend>(
pub fn qkv_attention<B: Backend>(
q: Tensor<B, 3>,
k: Tensor<B, 3>,
v: Tensor<B, 3>,
@@ -124,13 +127,13 @@ fn qkv_attention<B: Backend>(
return o;
}
fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length]);
pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length], device);
for i in 0..(seq_length - 1) {
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)]).add_scalar(NEG_INFINITY);
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)], device).add_scalar(NEG_INFINITY);
mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
}
return mask.to_device(device);
return mask;
}

View File

@@ -11,9 +11,9 @@ use burn::{
tensor::{backend::Backend, Tensor},
};
use burn_ndarray::{NdArrayBackend, NdArrayDevice};
use burn_ndarray::{NdArray, NdArrayDevice};
use burn::record::{self, BinFileRecorder, FullPrecisionSettings, Recorder};
use burn::record::{self, NamedMpkFileRecorder, FullPrecisionSettings, Recorder};
fn convert_dump_to_model<B: Backend>(
dump_path: &str,
@@ -33,11 +33,11 @@ fn save_model_file<B: Backend>(
model: StableDiffusion<B>,
name: &str,
) -> Result<(), record::RecorderError> {
BinFileRecorder::<FullPrecisionSettings>::new().record(model.into_record(), name.into())
NamedMpkFileRecorder::<FullPrecisionSettings>::new().record(model.into_record(), name.into())
}
fn main() {
type Backend = NdArrayBackend<f32>;
type Backend = NdArray<f32>;
let device = NdArrayDevice::Cpu;
let args: Vec<String> = env::args().collect();

View File

@@ -14,7 +14,7 @@ cfg_if::cfg_if! {
if #[cfg(feature = "wgpu-backend")] {
use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi};
} else {
use burn_tch::{TchBackend, TchDevice};
use burn_tch::{LibTorch, LibTorchDevice};
}
}
@@ -22,30 +22,21 @@ use std::env;
use std::io;
use std::process;
use burn::record::{self, BinFileRecorder, FullPrecisionSettings, Recorder};
use burn::record::{self, NamedMpkFileRecorder, FullPrecisionSettings, Recorder};
fn load_stable_diffusion_model_file<B: Backend>(
filename: &str,
device: &B::Device,
) -> Result<StableDiffusion<B>, record::RecorderError> {
BinFileRecorder::<FullPrecisionSettings>::new()
.load(filename.into())
.map(|record| StableDiffusionConfig::new().init().load_record(record))
NamedMpkFileRecorder::<FullPrecisionSettings>::new()
.load(filename.into(), device)
.map(|record| StableDiffusionConfig::new().init(device).load_record(record))
}
fn main() {
cfg_if::cfg_if! {
if #[cfg(feature = "wgpu-backend")] {
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
let device = WgpuDevice::BestAvailable;
} else {
type Backend = TchBackend<f32>;
let device = TchDevice::Cuda(0);
}
}
let args: Vec<String> = std::env::args().collect();
if args.len() != 7 {
eprintln!("Usage: {} <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name>", args[0]);
if args.len() != 7 && args.len() != 8 {
eprintln!("Usage: {} <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name> [device(cuda, mps, cpu)]", args[0]);
process::exit(1);
}
@@ -62,11 +53,40 @@ fn main() {
let prompt = &args[5];
let output_image_name = &args[6];
// Optional device parameter
let device_arg = if args.len() == 8 { Some(&args[7]) } else { None };
cfg_if::cfg_if! {
if #[cfg(feature = "wgpu-backend")] {
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
let device = WgpuDevice::BestAvailable;
} else {
type Backend = LibTorch<f32>;
let device = if let Some(dev_str) = device_arg {
match dev_str.to_lowercase().as_str() {
"cpu" => LibTorchDevice::Cpu,
"mps" => LibTorchDevice::Mps,
s if s.starts_with("cuda") => {
let idx = s[4..].parse().unwrap_or(0);
LibTorchDevice::Cuda(idx)
}
_ => {
eprintln!("Unknown device: {}", dev_str);
process::exit(1);
}
}
} else {
LibTorchDevice::Cuda(0)
};
}
}
println!("Loading tokenizer...");
let tokenizer = SimpleTokenizer::new().unwrap();
println!("Loading model...");
let sd: StableDiffusion<Backend> = if model_type == "burn" {
load_stable_diffusion_model_file(model_name).unwrap_or_else(|err| {
load_stable_diffusion_model_file(model_name, &device).unwrap_or_else(|err| {
eprintln!("Error loading model: {}", err);
process::exit(1);
})
@@ -77,8 +97,6 @@ fn main() {
})
};
let sd = sd.to_device(&device);
let unconditional_context = sd.unconditional_context(&tokenizer);
let context = sd.context(&tokenizer, prompt).unsqueeze::<3>(); //.repeat(0, 2); // generate 2 samples

View File

@@ -45,12 +45,12 @@ pub fn qkv_attention<B: Backend>(
}
pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length]);
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length], device);
for i in 0..(seq_length - 1) {
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)]).add_scalar(NEG_INFINITY);
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)], device).add_scalar(NEG_INFINITY);
mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
}
return mask.to_device(device);
return mask;
}

View File

@@ -71,7 +71,7 @@ fn load_padded_conv2d<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<PaddedConv2d<B>, Box<dyn Error>> {
let conv = load_conv2d(&format!("{}/{}", path, "conv"), device)?;
let mut conv = load_conv2d(&format!("{}/{}", path, "conv"), device)?;
let channels = load_tensor::<B, 1>("channels", path, device)?;
let channels = tensor_to_array_2(channels);
@@ -81,18 +81,21 @@ fn load_padded_conv2d<B: Backend>(
let padding = load_tensor::<B, 1>("padding", path, device)?;
let padding: [usize; 4] = tensor_to_array(padding);
let padding = Padding::new(padding[0], padding[1], padding[2], padding[3]);
let padding = PaddingCfg::new(padding[0], padding[1], padding[2], padding[3]);
let mut record = conv.into_record();
//let mut record = conv.into_record();
let mut padded_conv: PaddedConv2d<B> = PaddedConv2dConfig::new(channels, kernel_size, padding)
.with_stride(stride)
.init();
.init(device);
let padding_actual =
PaddingConfig2d::Explicit(padded_conv.padding_actual[0], padded_conv.padding_actual[1]);
record.padding = <PaddingConfig2d as Module<B>>::into_record(padding_actual);
padded_conv.conv = padded_conv.conv.load_record(record);
conv.padding = burn::module::Ignored(padding_actual);
padded_conv.conv = conv;
//record.padding = <PaddingConfig2d as Module<B>>::into_record(padding_actual);
//padded_conv.conv = padded_conv.conv.load_record(record);
Ok(padded_conv)
}

View File

@@ -18,7 +18,8 @@ use burn::{
use super::groupnorm::*;
use super::silu::*;
use crate::backend::Backend as MyBackend;
//use crate::backend::Backend as MyBackend;
use crate::backend::{qkv_attention, attn_decoder_mask};
use std::iter;
@@ -26,13 +27,13 @@ use std::iter;
pub struct AutoencoderConfig {}
impl AutoencoderConfig {
pub fn init<B: Backend>(&self) -> Autoencoder<B> {
pub fn init<B: Backend>(&self, device: &B::Device) -> Autoencoder<B> {
let encoder =
EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init();
EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init(device);
let decoder =
DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init();
let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init();
let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init();
DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init(device);
let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init(device);
let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init(device);
Autoencoder {
encoder,
@@ -51,7 +52,7 @@ pub struct Autoencoder<B: Backend> {
post_quant_conv: Conv2d<B>,
}
impl<B: MyBackend> Autoencoder<B> {
impl<B: Backend> Autoencoder<B> {
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
self.decode_latent(self.encode_image(x))
}
@@ -78,7 +79,7 @@ pub struct EncoderConfig {
}
impl EncoderConfig {
fn init<B: Backend>(&self) -> Encoder<B> {
fn init<B: Backend>(&self, device: &B::Device) -> Encoder<B> {
let n_expanded_channels_initial = self
.channels
.first()
@@ -88,7 +89,7 @@ impl EncoderConfig {
let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
.init(device);
let blocks = self
.channels
@@ -96,16 +97,16 @@ impl EncoderConfig {
.enumerate()
.map(|(i, &(n_channel_in, n_channel_out))| {
let downsample = i != self.channels.len() - 1;
EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init()
EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init(device)
})
.collect();
let mid = MidConfig::new(n_expanded_channels_final).init();
let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init();
let mid = MidConfig::new(n_expanded_channels_final).init(device);
let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init(device);
let silu = SILU::new();
let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
.init(device);
Encoder {
conv_in,
@@ -128,7 +129,7 @@ pub struct Encoder<B: Backend> {
conv_out: Conv2d<B>,
}
impl<B: MyBackend> Encoder<B> {
impl<B: Backend> Encoder<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.conv_in.forward(x);
@@ -150,7 +151,7 @@ pub struct DecoderConfig {
}
impl DecoderConfig {
fn init<B: Backend>(&self) -> Decoder<B> {
fn init<B: Backend>(&self, device: &B::Device) -> Decoder<B> {
let n_expanded_channels = self
.channels
.first()
@@ -160,8 +161,8 @@ impl DecoderConfig {
let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
let mid = MidConfig::new(n_expanded_channels).init();
.init(device);
let mid = MidConfig::new(n_expanded_channels).init(device);
let blocks = self
.channels
@@ -169,15 +170,15 @@ impl DecoderConfig {
.enumerate()
.map(|(i, &(n_channel_in, n_channel_out))| {
let upsample = i != self.channels.len() - 1;
DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init()
DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init(device)
})
.collect();
let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init();
let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init(device);
let silu = SILU::new();
let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
.init(device);
Decoder {
conv_in,
@@ -200,7 +201,7 @@ pub struct Decoder<B: Backend> {
conv_out: Conv2d<B>,
}
impl<B: MyBackend> Decoder<B> {
impl<B: Backend> Decoder<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.conv_in.forward(x);
let x = self.mid.forward(x);
@@ -223,15 +224,15 @@ pub struct EncoderBlockConfig {
}
impl EncoderBlockConfig {
fn init<B: Backend>(&self) -> EncoderBlock<B> {
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init();
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
fn init<B: Backend>(&self, device: &B::Device) -> EncoderBlock<B> {
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(device);
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
let downsampler = if self.downsample {
let padding = Padding::new(0, 1, 0, 1);
let padding = PaddingCfg::new(0, 1, 0, 1);
Some(
PaddedConv2dConfig::new([self.n_channels_out, self.n_channels_out], 3, padding)
.with_stride(2)
.init(),
.init(device),
)
} else {
None
@@ -272,15 +273,15 @@ pub struct DecoderBlockConfig {
}
impl DecoderBlockConfig {
fn init<B: Backend>(&self) -> DecoderBlock<B> {
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init();
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
fn init<B: Backend>(&self, device: &B::Device) -> DecoderBlock<B> {
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(device);
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
let upsampler = if self.upsample {
Some(
Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init(),
.init(device),
)
} else {
None
@@ -313,8 +314,7 @@ impl<B: Backend> DecoderBlock<B> {
let [n_batch, n_channel, height, width] = x.dims();
let x = x
.reshape([n_batch, n_channel, height, 1, width, 1])
.repeat(3, 2)
.repeat(5, 2)
.repeat(&[1, 1, 1, 2, 1, 2])
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
d.forward(x)
} else {
@@ -329,11 +329,11 @@ pub struct PaddedConv2dConfig {
kernel_size: usize,
#[config(default = 1)]
stride: usize,
padding: Padding,
padding: PaddingCfg,
}
impl PaddedConv2dConfig {
fn init<B: Backend>(&self) -> PaddedConv2d<B> {
fn init<B: Backend>(&self, device: &B::Device) -> PaddedConv2d<B> {
let calc_padding = |p_left, p_right| {
let n = if p_left >= p_right {
0
@@ -351,12 +351,17 @@ impl PaddedConv2dConfig {
let conv = Conv2dConfig::new(self.channels, [self.kernel_size, self.kernel_size])
.with_stride([self.stride, self.stride])
.with_padding(PaddingConfig2d::Explicit(pad_vertical, pad_horizontal))
.init();
.init(device);
let kernel_size = self.kernel_size;
let stride = self.stride;
let padding = self.padding;
let padding = Padding {
pad_left: self.padding.pad_left,
pad_right: self.padding.pad_right,
pad_top: self.padding.pad_top,
pad_bottom: self.padding.pad_bottom,
};
PaddedConv2d {
conv,
@@ -406,7 +411,15 @@ impl<B: Backend> PaddedConv2d<B> {
}
}
#[derive(Config, Module, Copy, Debug)]
#[derive(Config, Debug)]
pub struct PaddingCfg {
pad_left: usize,
pad_right: usize,
pad_top: usize,
pad_bottom: usize,
}
#[derive(Module, Clone, Debug)]
pub struct Padding {
pad_left: usize,
pad_right: usize,
@@ -420,10 +433,10 @@ pub struct MidConfig {
}
impl MidConfig {
fn init<B: Backend>(&self) -> Mid<B> {
let block_1 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init();
let attn = ConvSelfAttentionBlockConfig::new(self.n_channel).init();
let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init();
fn init<B: Backend>(&self, device: &B::Device) -> Mid<B> {
let block_1 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(device);
let attn = ConvSelfAttentionBlockConfig::new(self.n_channel).init(device);
let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(device);
Mid {
block_1,
@@ -440,7 +453,7 @@ pub struct Mid<B: Backend> {
block_2: ResnetBlock<B>,
}
impl<B: MyBackend> Mid<B> {
impl<B: Backend> Mid<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.block_1.forward(x);
let x = self.attn.forward(x);
@@ -456,17 +469,17 @@ pub struct ResnetBlockConfig {
}
impl ResnetBlockConfig {
fn init<B: Backend>(&self) -> ResnetBlock<B> {
let norm1 = GroupNormConfig::new(32, self.in_channels).init();
fn init<B: Backend>(&self, device: &B::Device) -> ResnetBlock<B> {
let norm1 = GroupNormConfig::new(32, self.in_channels).init(device);
let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
let norm2 = GroupNormConfig::new(32, self.out_channels).init();
.init(device);
let norm2 = GroupNormConfig::new(32, self.out_channels).init(device);
let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
.init(device);
let nin_shortcut = if self.in_channels != self.out_channels {
Some(Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init())
Some(Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init(device))
} else {
None
};
@@ -520,12 +533,12 @@ pub struct ConvSelfAttentionBlockConfig {
}
impl ConvSelfAttentionBlockConfig {
fn init<B: Backend>(&self) -> ConvSelfAttentionBlock<B> {
let norm = GroupNormConfig::new(32, self.n_channel).init();
let q = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
let k = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
let v = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
fn init<B: Backend>(&self, device: &B::Device) -> ConvSelfAttentionBlock<B> {
let norm = GroupNormConfig::new(32, self.n_channel).init(device);
let q = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
let k = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
let v = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
ConvSelfAttentionBlock {
norm,
@@ -546,7 +559,7 @@ pub struct ConvSelfAttentionBlock<B: Backend> {
proj_out: Conv2d<B>,
}
impl<B: MyBackend> ConvSelfAttentionBlock<B> {
impl<B: Backend> ConvSelfAttentionBlock<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let [n_batch, n_channel, height, width] = x.dims();
@@ -568,7 +581,7 @@ impl<B: MyBackend> ConvSelfAttentionBlock<B> {
.reshape([n_batch, n_channel, height * width])
.swap_dims(1, 2);
let wv = Tensor::from_primitive(B::qkv_attention(
/*let wv = Tensor::from_primitive(B::qkv_attention(
q.into_primitive(),
k.into_primitive(),
v.into_primitive(),
@@ -576,6 +589,16 @@ impl<B: MyBackend> ConvSelfAttentionBlock<B> {
1,
))
.swap_dims(1, 2)
.reshape([n_batch, n_channel, height, width]);*/
let wv = qkv_attention(
q,
k,
v,
None,
1,
)
.swap_dims(1, 2)
.reshape([n_batch, n_channel, height, width]);
let projected = self.proj_out.forward(wv);

View File

@@ -68,7 +68,7 @@ pub fn load_residual_decoder_attention_block<B: Backend>(
pub fn load_clip<B: Backend>(path: &str, device: &B::Device) -> Result<CLIP<B>, Box<dyn Error>> {
let token_embedding = load_embedding(&format!("{}/{}", path, "token_embedding"), device)?;
let position_embedding =
load_tensor("weight", &format!("{}/position_embedding", path), device)?.into();
Param::from_tensor(load_tensor("weight", &format!("{}/position_embedding", path), device)?);
let n_layer = load_usize::<B>("n_layer", path, device)?;
let mut blocks = (0..n_layer)

View File

@@ -12,7 +12,8 @@ use burn::{
},
};
use crate::backend::Backend as MyBackend;
//use crate::backend::Backend as MyBackend;
use crate::backend::{qkv_attention, attn_decoder_mask};
#[derive(Config)]
pub struct CLIPConfig {
@@ -24,15 +25,15 @@ pub struct CLIPConfig {
}
impl CLIPConfig {
pub fn init<B: Backend>(&self) -> CLIP<B> {
let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init();
pub fn init<B: Backend>(&self, device: &B::Device) -> CLIP<B> {
let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init(device);
let position_embedding =
Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0)).into();
Param::from_tensor(Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0), device));
let blocks = (0..self.n_layer)
.into_iter()
.map(|_| ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head).init())
.map(|_| ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head).init(device))
.collect();
let layer_norm = nn::LayerNormConfig::new(self.n_state).init();
let layer_norm = nn::LayerNormConfig::new(self.n_state).init(device);
CLIP {
token_embedding,
@@ -51,11 +52,12 @@ pub struct CLIP<B: Backend> {
layer_norm: nn::LayerNorm<B>,
}
impl<B: MyBackend> CLIP<B> {
impl<B: Backend> CLIP<B> {
pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
let [n_batch, seq_len] = x.dims();
let mask = Tensor::from_primitive(B::attn_decoder_mask(seq_len, &x.device()));
//let mask = Tensor::from_primitive(B::attn_decoder_mask(seq_len, &x.device()));
let mask = attn_decoder_mask(seq_len, &x.device());
let embedded = self.token_embedding.forward(x)
+ self
@@ -80,12 +82,12 @@ pub struct ResidualDecoderAttentionBlockConfig {
}
impl ResidualDecoderAttentionBlockConfig {
pub fn init<B: Backend>(&self) -> ResidualDecoderAttentionBlock<B> {
let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init();
let attn_ln = nn::LayerNormConfig::new(self.n_state).init();
pub fn init<B: Backend>(&self, device: &B::Device) -> ResidualDecoderAttentionBlock<B> {
let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init(device);
let attn_ln = nn::LayerNormConfig::new(self.n_state).init(device);
let mlp = MLPConfig::new(self.n_state, 4 * self.n_state).init();
let mlp_ln = nn::LayerNormConfig::new(self.n_state).init();
let mlp = MLPConfig::new(self.n_state, 4 * self.n_state).init(device);
let mlp_ln = nn::LayerNormConfig::new(self.n_state).init(device);
ResidualDecoderAttentionBlock {
attn,
@@ -104,7 +106,7 @@ pub struct ResidualDecoderAttentionBlock<B: Backend> {
mlp_ln: nn::LayerNorm<B>,
}
impl<B: MyBackend> ResidualDecoderAttentionBlock<B> {
impl<B: Backend> ResidualDecoderAttentionBlock<B> {
fn forward(&self, x: Tensor<B, 3>, mask: Tensor<B, 2>) -> Tensor<B, 3> {
let x = x.clone() + self.attn.forward(self.attn_ln.forward(x), Some(mask));
let x = x.clone() + self.mlp.forward(self.mlp_ln.forward(x));
@@ -119,7 +121,7 @@ pub struct MultiHeadSelfAttentionConfig {
}
impl MultiHeadSelfAttentionConfig {
fn init<B: Backend>(&self) -> MultiHeadSelfAttention<B> {
fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadSelfAttention<B> {
assert!(
self.n_state % self.n_head == 0,
"State size {} must be a multiple of head size {}",
@@ -128,10 +130,10 @@ impl MultiHeadSelfAttentionConfig {
);
let n_head = self.n_head;
let query = nn::LinearConfig::new(self.n_state, self.n_state).init();
let key = nn::LinearConfig::new(self.n_state, self.n_state).init();
let value = nn::LinearConfig::new(self.n_state, self.n_state).init();
let out = nn::LinearConfig::new(self.n_state, self.n_state).init();
let query = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
let key = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
let value = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
let out = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
MultiHeadSelfAttention {
n_head,
@@ -152,19 +154,27 @@ pub struct MultiHeadSelfAttention<B: Backend> {
out: nn::Linear<B>,
}
impl<B: MyBackend> MultiHeadSelfAttention<B> {
impl<B: Backend> MultiHeadSelfAttention<B> {
pub fn forward(&self, x: Tensor<B, 3>, mask: Option<Tensor<B, 2>>) -> Tensor<B, 3> {
let q = self.query.forward(x.clone());
let k = self.key.forward(x.clone());
let v = self.value.forward(x);
let wv = Tensor::from_primitive(B::qkv_attention(
/*let wv = Tensor::from_primitive(B::qkv_attention(
q.into_primitive(),
k.into_primitive(),
v.into_primitive(),
mask.map(|m| m.into_primitive()),
self.n_head,
));
));*/
let wv = qkv_attention(
q,
k,
v,
mask,
self.n_head,
);
return self.out.forward(wv);
}
@@ -177,10 +187,10 @@ pub struct MLPConfig {
}
impl MLPConfig {
fn init<B: Backend>(&self) -> MLP<B> {
let fc1 = nn::LinearConfig::new(self.input_size, self.hidden_size).init();
fn init<B: Backend>(&self, device: &B::Device) -> MLP<B> {
let fc1 = nn::LinearConfig::new(self.input_size, self.hidden_size).init(device);
let gelu = QuickGELU::new();
let fc2 = nn::LinearConfig::new(self.hidden_size, self.input_size).init();
let fc2 = nn::LinearConfig::new(self.hidden_size, self.input_size).init(device);
MLP { fc1, gelu, fc2 }
}

View File

@@ -18,14 +18,14 @@ pub fn load_group_norm<B: Backend>(
let n_channel = load_usize::<B>("n_channel", path, device)?.into();
let eps = load_f32::<B>("eps", path, device)?.into();
let gamma = load_tensor::<B, 1>("weight", path, device)
let gamma = Param::from_tensor(load_tensor::<B, 1>("weight", path, device)
.ok()
.unwrap_or_else(|| Tensor::ones_device([n_channel], device))
.into();
let beta = load_tensor::<B, 1>("bias", path, device)
.unwrap_or_else(|| Tensor::ones([n_channel], device))
);
let beta = Param::from_tensor(load_tensor::<B, 1>("bias", path, device)
.ok()
.unwrap_or_else(|| Tensor::zeros_device([n_channel], device))
.into();
.unwrap_or_else(|| Tensor::zeros([n_channel], device))
);
Ok(GroupNorm {
n_group,

View File

@@ -15,7 +15,7 @@ pub struct GroupNormConfig {
}
impl GroupNormConfig {
pub fn init<B: Backend>(&self) -> GroupNorm<B> {
pub fn init<B: Backend>(&self, device: &B::Device) -> GroupNorm<B> {
assert!(
self.n_channel % self.n_group == 0,
"The number of channels {} must be divisible by the number of groups {}",
@@ -25,8 +25,8 @@ impl GroupNormConfig {
let n_per_group = self.n_channel / self.n_group;
let gamma = Tensor::ones([self.n_channel]).into();
let beta = Tensor::zeros([self.n_channel]).into();
let gamma = Param::from_tensor(Tensor::ones([self.n_channel], device));
let beta = Param::from_tensor(Tensor::zeros([self.n_channel], device));
let eps = self.eps;

View File

@@ -1,5 +1,7 @@
use npy::{self, NpyData};
use num_traits::cast::ToPrimitive;
use burn::tensor::cast::ToElement;
use burn::prelude::TensorData;
use std::error::Error;
use std::io::Read;
@@ -21,7 +23,8 @@ pub fn numpy_to_tensor<B: Backend, const D: usize>(
let shape: Vec<_> = v[0..D].into_iter().map(|&v| v as usize).collect();
let data: Vec<B::FloatElem> = v[D..].into_iter().map(|e| e.elem()).collect();
Tensor::from_data_device(Data::new(data, shape.into()), device)
//Tensor::from_data_device(Data::new(data, shape.into()), device)
Tensor::from_data(TensorData::new(data, shape), device)
}
pub fn load_tensor<B: Backend, const D: usize>(
@@ -48,7 +51,7 @@ pub fn load_f32<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<f32, Box<dyn Error>> {
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_f32().unwrap())
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_f32())
}
pub fn load_usize<B: Backend>(
@@ -56,7 +59,7 @@ pub fn load_usize<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<usize, Box<dyn Error>> {
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_usize().unwrap())
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_usize())
}
pub fn load_linear<B: Backend>(
@@ -66,13 +69,10 @@ pub fn load_linear<B: Backend>(
let weight = load_tensor::<B, 2>("weight", path, device)?;
let bias = load_tensor::<B, 1>("bias", path, device).ok();
let record = nn::LinearRecord {
weight: weight.into(),
bias: bias.map(|t| t.into()),
};
let linear: nn::Linear<B> = nn::LinearConfig::new(3, 3).init_with(record);
Ok(linear)
Ok(nn::Linear {
weight: Param::from_tensor(weight),
bias: bias.map(|t| Param::from_tensor(t)),
})
}
pub fn load_embedding<B: Backend>(
@@ -80,14 +80,10 @@ pub fn load_embedding<B: Backend>(
device: &B::Device,
) -> Result<nn::Embedding<B>, Box<dyn Error>> {
let weight = load_tensor::<B, 2>("weight", path, device)?;
let [n_vocab, n_state] = weight.dims();
let record = nn::EmbeddingRecord {
weight: weight.into(),
};
let embedding = nn::EmbeddingConfig::new(n_vocab, n_state).init_with(record);
Ok(embedding)
Ok(nn::Embedding {
weight: Param::from_tensor(weight),
})
}
pub fn load_layer_norm<B: Backend>(
@@ -100,13 +96,9 @@ pub fn load_layer_norm<B: Backend>(
let [n_state] = weight.dims();
let record = nn::LayerNormRecord {
gamma: weight.into(),
beta: bias.into(),
epsilon: <f64 as Module<B>>::into_record(eps),
};
let layer_norm: nn::LayerNorm<B> = nn::LayerNormConfig::new(n_state).init_with(record);
let mut layer_norm = nn::LayerNormConfig::new(n_state).with_epsilon(eps).init(device);
layer_norm.gamma = Param::from_tensor(weight);
layer_norm.beta = Param::from_tensor(bias);
Ok(layer_norm)
}
@@ -116,7 +108,7 @@ pub fn load_layer_norm<B: Backend>(
let eps = load_f32::<B>("eps", path, device)?.into();
let rmsnorm = RMSNorm {
weight: weight.into(),
weight: Param::from_tensor(weight),
eps: eps
};
@@ -148,40 +140,38 @@ pub fn load_conv2d<B: Backend>(
let padding = tensor_to_array_2(padding);
let padding = nn::PaddingConfig2d::Explicit(padding[0], padding[1]);
let record = conv::Conv2dRecord {
weight: weight.into(),
bias: bias.map(|t| t.into()),
stride: <[usize; 2] as Module<B>>::into_record(stride),
kernel_size: <[usize; 2] as Module<B>>::into_record(kernel_size),
dilation: <[usize; 2] as Module<B>>::into_record(dilation),
groups: <usize as Module<B>>::into_record(n_group),
padding: <nn::PaddingConfig2d as Module<B>>::into_record(padding.clone()),
};
let conv2d: conv::Conv2d<B> =
conv::Conv2dConfig::new([n_channels_in, n_channels_out], kernel_size)
let mut conv2d = conv::Conv2dConfig::new([n_channels_in, n_channels_out], kernel_size)
.with_stride(stride)
.with_dilation(dilation)
.with_groups(n_group)
.with_padding(padding)
.with_padding(padding.clone())
.with_bias(has_bias)
.init_with(record);
.init(device);
conv2d.weight = Param::from_tensor(weight);
conv2d.bias = bias.map(|t| Param::from_tensor(t));
conv2d.stride = stride;
conv2d.kernel_size = kernel_size;
conv2d.dilation = dilation;
conv2d.groups = n_group;
conv2d.padding = burn::module::Ignored(padding);
Ok(conv2d)
}
pub fn tensor_to_array_2<B: Backend>(x: Tensor<B, 1>) -> [usize; 2] {
let vec = x.into_data().value;
let vec: Vec<<B as Backend>::FloatElem> = x.into_data().to_vec().unwrap();
assert!(vec.len() == 2, "Tensor length must be 2.");
[vec[0].to_usize().unwrap(), vec[1].to_usize().unwrap()]
[vec[0].to_usize(), vec[1].to_usize()]
}
pub fn tensor_to_array<const N: usize, B: Backend>(x: Tensor<B, 1>) -> [usize; N] {
let vec = x.into_data().value;
let vec: Vec<<B as Backend>::FloatElem> = x.into_data().to_vec().unwrap();
assert!(vec.len() == N, "Tensor length must be {}.", N);
let mut arr = [0; N];
for (a, t) in arr.iter_mut().zip(vec) {
*a = t.to_usize().unwrap();
*a = t.to_usize();
}
arr

View File

@@ -18,7 +18,7 @@ pub fn load_stable_diffusion<B: Backend>(
device: &B::Device,
) -> Result<StableDiffusion<B>, Box<dyn Error>> {
let n_steps = load_usize::<B>("n_steps", path, device)?;
let alpha_cumulative_products = load_tensor::<B, 1>("alphas_cumprod", path, device)?.into();
let alpha_cumulative_products = Param::from_tensor(load_tensor::<B, 1>("alphas_cumprod", path, device)?);
let autoencoder = load_autoencoder(&format!("{}/{}", path, "autoencoder"), device)?;
let diffusion = load_unet(&format!("{}/{}", path, "unet"), device)?;
let clip = load_clip(&format!("{}/{}", path, "clip"), device)?;

View File

@@ -4,11 +4,12 @@ use burn::{
config::Config,
module::{Module, Param},
tensor::{backend::Backend, BasicOps, Data, Distribution, Float, Int, Tensor},
tensor::cast::ToElement,
};
use num_traits::ToPrimitive;
use crate::backend::Backend as MyBackend;
//use crate::backend::Backend as MyBackend;
use super::autoencoder::{Autoencoder, AutoencoderConfig};
use super::clip::{CLIPConfig, CLIP};
@@ -19,13 +20,13 @@ use crate::tokenizer::SimpleTokenizer;
pub struct StableDiffusionConfig {}
impl StableDiffusionConfig {
pub fn init<B: Backend>(&self) -> StableDiffusion<B> {
pub fn init<B: Backend>(&self, device: &B::Device) -> StableDiffusion<B> {
let n_steps = 1000;
let alpha_cumulative_products = offset_cosine_schedule_cumprod::<B>(n_steps).into();
let alpha_cumulative_products = Param::from_tensor(offset_cosine_schedule_cumprod::<B>(n_steps as i64, device));
let autoencoder = AutoencoderConfig::new().init();
let diffusion = UNetConfig::new().init();
let clip = CLIPConfig::new(49408, 768, 12, 77, 12).init();
let autoencoder = AutoencoderConfig::new().init(device);
let diffusion = UNetConfig::new().init(device);
let clip = CLIPConfig::new(49408, 768, 12, 77, 12).init(device);
StableDiffusion {
n_steps,
@@ -46,7 +47,7 @@ pub struct StableDiffusion<B: Backend> {
clip: CLIP<B>,
}
impl<B: MyBackend> StableDiffusion<B> {
impl<B: Backend> StableDiffusion<B> {
pub fn sample_image(
&self,
context: Tensor<B, 3>,
@@ -82,7 +83,7 @@ impl<B: MyBackend> StableDiffusion<B> {
.swap_dims(2, 3)
.mul_scalar(255.0);
let flattened: Vec<_> = image.into_data().value;
let flattened: Vec<B::FloatElem> = image.into_data().to_vec().unwrap();
(0..n_batch)
.into_iter()
@@ -92,7 +93,7 @@ impl<B: MyBackend> StableDiffusion<B> {
flattened[start..end]
.into_iter()
.map(|v| v.to_f64().unwrap().min(255.0).max(0.0).to_u8().unwrap())
.map(|v| v.to_f64().min(255.0).max(0.0) as u8)
.collect()
})
.collect()
@@ -112,8 +113,7 @@ impl<B: MyBackend> StableDiffusion<B> {
let [n_batches, _, _] = context.dims();
let gen_noise = || {
Tensor::random([n_batches, 4, 64, 64], Distribution::Normal(0.0, 1.0))
.to_device(&device)
Tensor::random([n_batches, 4, 64, 64], Distribution::Normal(0.0, 1.0), &device)
};
let sigma = 0.0; // Use deterministic diffusion
@@ -126,8 +126,8 @@ impl<B: MyBackend> StableDiffusion<B> {
.val()
.slice([t..t + 1])
.into_scalar()
.to_f64()
.unwrap();
.to_f64();
let prev_alpha: f64 = if t >= step_size {
let i = t - step_size;
self.alpha_cumulative_products
@@ -135,14 +135,13 @@ impl<B: MyBackend> StableDiffusion<B> {
.slice([i..i + 1])
.into_scalar()
.to_f64()
.unwrap()
} else {
1.0
};
let sqrt_noise = (1.0 - current_alpha).sqrt();
let timestep = Tensor::from_ints([t as i32]).to_device(&device);
let timestep = Tensor::from_ints([t as i32], &device);
let pred_noise = self.forward_diffuser(
latent.clone(),
timestep,
@@ -174,7 +173,7 @@ impl<B: MyBackend> StableDiffusion<B> {
let unconditional_latent = self.diffusion.forward(
latent.clone(),
timestep.clone(),
unconditional_context.unsqueeze().repeat(0, n_batch),
unconditional_context.unsqueeze().repeat(&[0, n_batch]),
);
let conditional_latent = self.diffusion.forward(latent, timestep, context);
@@ -206,8 +205,7 @@ impl<B: MyBackend> StableDiffusion<B> {
.collect();
self.clip.forward(
Tensor::from_ints(&tokenized[..])
.to_device(device)
Tensor::<B, 1, Int>::from_ints(&tokenized[..], device)
.unsqueeze(),
)
}
@@ -215,25 +213,25 @@ impl<B: MyBackend> StableDiffusion<B> {
use std::f64::consts::PI;
fn cosine_schedule<B: Backend>(n_steps: usize) -> Tensor<B, 1> {
Tensor::arange(1..n_steps + 1)
fn cosine_schedule<B: Backend>(n_steps: i64, device: &B::Device) -> Tensor<B, 1> {
Tensor::arange(1..n_steps + 1, device)
.float()
.mul_scalar(PI * 0.5 / n_steps as f64)
.cos()
}
fn offset_cosine_schedule<B: Backend>(n_steps: usize) -> Tensor<B, 1> {
fn offset_cosine_schedule<B: Backend>(n_steps: i64, device: &B::Device) -> Tensor<B, 1> {
let min_signal_rate: f64 = 0.02;
let max_signal_rate: f64 = 0.95;
let start_angle = max_signal_rate.acos();
let end_angle = min_signal_rate.acos();
let times = Tensor::arange(1..n_steps + 1).float();
let times = Tensor::arange(1..n_steps + 1, device).float();
let diffusion_angles = times * ((end_angle - start_angle) / n_steps as f64) + start_angle;
diffusion_angles.cos()
}
fn offset_cosine_schedule_cumprod<B: Backend>(n_steps: usize) -> Tensor<B, 1> {
offset_cosine_schedule::<B>(n_steps).powf(2.0)
fn offset_cosine_schedule_cumprod<B: Backend>(n_steps: i64, device: &B::Device) -> Tensor<B, 1> {
offset_cosine_schedule::<B>(n_steps, device).powf_scalar(2.0)
}

View File

@@ -65,7 +65,7 @@ pub fn load_geglu<B: Backend>(path: &str, device: &B::Device) -> Result<GEGLU<B>
let geglue = GEGLU {
proj: proj,
gelu: GELU::new(), // Assuming GELU::new() initializes a new GELU struct
gelu: Gelu::new(), // Assuming Gelu::new() initializes a new Gelu struct
};
Ok(geglue)

View File

@@ -6,7 +6,7 @@ use burn::{
nn::{
self,
conv::{Conv2d, Conv2dConfig},
PaddingConfig2d, GELU,
PaddingConfig2d, Gelu,
},
tensor::{activation::softmax, backend::Backend, module::embedding, Distribution, Int, Tensor},
};
@@ -22,7 +22,7 @@ fn timestep_embedding<B: Backend>(
max_period: usize,
) -> Tensor<B, 2> {
let half = dim / 2;
let freqs = (Tensor::arange_device(0..half, &timesteps.device()).float()
let freqs = (Tensor::arange(0..half as i64, &timesteps.device()).float()
* (-(max_period as f64).ln() / half as f64))
.exp();
let args = timesteps.float() * freqs;
@@ -33,50 +33,50 @@ fn timestep_embedding<B: Backend>(
pub struct UNetConfig {}
impl UNetConfig {
pub fn init<B: Backend>(&self) -> UNet<B> {
let lin1_time_embed = nn::LinearConfig::new(320, 1280).init();
pub fn init<B: Backend>(&self, device: &B::Device) -> UNet<B> {
let lin1_time_embed = nn::LinearConfig::new(320, 1280).init(device);
let silu_time_embed = SILU::new();
let lin2_time_embed = nn::LinearConfig::new(1280, 1280).init();
let lin2_time_embed = nn::LinearConfig::new(1280, 1280).init(device);
let input_blocks = UNetInputBlocks {
conv: Conv2dConfig::new([4, 320], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init(),
rt1: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(),
rt2: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(),
d1: DownsampleConfig::new(320).init(),
rt3: ResTransformerConfig::new(320, 1280, 640, 768, 8).init(),
rt4: ResTransformerConfig::new(640, 1280, 640, 768, 8).init(),
d2: DownsampleConfig::new(640).init(),
rt5: ResTransformerConfig::new(640, 1280, 1280, 768, 8).init(),
rt6: ResTransformerConfig::new(1280, 1280, 1280, 768, 8).init(),
d3: DownsampleConfig::new(1280).init(),
r1: ResBlockConfig::new(1280, 1280, 1280).init(),
r2: ResBlockConfig::new(1280, 1280, 1280).init(),
.init(device),
rt1: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(device),
rt2: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(device),
d1: DownsampleConfig::new(320).init(device),
rt3: ResTransformerConfig::new(320, 1280, 640, 768, 8).init(device),
rt4: ResTransformerConfig::new(640, 1280, 640, 768, 8).init(device),
d2: DownsampleConfig::new(640).init(device),
rt5: ResTransformerConfig::new(640, 1280, 1280, 768, 8).init(device),
rt6: ResTransformerConfig::new(1280, 1280, 1280, 768, 8).init(device),
d3: DownsampleConfig::new(1280).init(device),
r1: ResBlockConfig::new(1280, 1280, 1280).init(device),
r2: ResBlockConfig::new(1280, 1280, 1280).init(device),
};
let middle_block = ResTransformerResConfig::new(1280, 1280, 1280, 768, 8).init();
let middle_block = ResTransformerResConfig::new(1280, 1280, 1280, 768, 8).init(device);
let output_blocks = UNetOutputBlocks {
r1: ResBlockConfig::new(2560, 1280, 1280).init(),
r2: ResBlockConfig::new(2560, 1280, 1280).init(),
ru: ResUpSampleConfig::new(2560, 1280, 1280).init(),
rt1: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(),
rt2: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(),
rtu1: ResTransformerUpsampleConfig::new(1920, 1280, 1280, 768, 8).init(),
rt3: ResTransformerConfig::new(1920, 1280, 640, 768, 8).init(),
rt4: ResTransformerConfig::new(1280, 1280, 640, 768, 8).init(),
rtu2: ResTransformerUpsampleConfig::new(960, 1280, 640, 768, 8).init(),
rt5: ResTransformerConfig::new(960, 1280, 320, 768, 8).init(),
rt6: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(),
rt7: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(),
r1: ResBlockConfig::new(2560, 1280, 1280).init(device),
r2: ResBlockConfig::new(2560, 1280, 1280).init(device),
ru: ResUpSampleConfig::new(2560, 1280, 1280).init(device),
rt1: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(device),
rt2: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(device),
rtu1: ResTransformerUpsampleConfig::new(1920, 1280, 1280, 768, 8).init(device),
rt3: ResTransformerConfig::new(1920, 1280, 640, 768, 8).init(device),
rt4: ResTransformerConfig::new(1280, 1280, 640, 768, 8).init(device),
rtu2: ResTransformerUpsampleConfig::new(960, 1280, 640, 768, 8).init(device),
rt5: ResTransformerConfig::new(960, 1280, 320, 768, 8).init(device),
rt6: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(device),
rt7: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(device),
};
let norm_out = GroupNormConfig::new(32, 320).init();
let norm_out = GroupNormConfig::new(32, 320).init(device);
let silu_out = SILU::new();
let conv_out = Conv2dConfig::new([320, 4], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
.init(device);
UNet {
lin1_time_embed,
@@ -206,16 +206,16 @@ pub struct ResTransformerConfig {
}
impl ResTransformerConfig {
fn init<B: Backend>(&self) -> ResTransformer<B> {
fn init<B: Backend>(&self, device: &B::Device) -> ResTransformer<B> {
let res = ResBlockConfig::new(
self.n_channels_in,
self.n_channels_embed,
self.n_channels_out,
)
.init();
.init(device);
let transformer =
SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head)
.init();
.init(device);
ResTransformer { res, transformer }
}
@@ -243,14 +243,14 @@ pub struct ResUpSampleConfig {
}
impl ResUpSampleConfig {
fn init<B: Backend>(&self) -> ResUpSample<B> {
fn init<B: Backend>(&self, device: &B::Device) -> ResUpSample<B> {
let res = ResBlockConfig::new(
self.n_channels_in,
self.n_channels_embed,
self.n_channels_out,
)
.init();
let upsample = UpsampleConfig::new(self.n_channels_out).init();
.init(device);
let upsample = UpsampleConfig::new(self.n_channels_out).init(device);
ResUpSample { res, upsample }
}
@@ -280,17 +280,17 @@ pub struct ResTransformerUpsampleConfig {
}
impl ResTransformerUpsampleConfig {
fn init<B: Backend>(&self) -> ResTransformerUpsample<B> {
fn init<B: Backend>(&self, device: &B::Device) -> ResTransformerUpsample<B> {
let res = ResBlockConfig::new(
self.n_channels_in,
self.n_channels_embed,
self.n_channels_out,
)
.init();
.init(device);
let transformer =
SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head)
.init();
let upsample = UpsampleConfig::new(self.n_channels_out).init();
.init(device);
let upsample = UpsampleConfig::new(self.n_channels_out).init(device);
ResTransformerUpsample {
res,
@@ -326,22 +326,22 @@ pub struct ResTransformerResConfig {
}
impl ResTransformerResConfig {
fn init<B: Backend>(&self) -> ResTransformerRes<B> {
fn init<B: Backend>(&self, device: &B::Device) -> ResTransformerRes<B> {
let res1 = ResBlockConfig::new(
self.n_channels_in,
self.n_channels_embed,
self.n_channels_out,
)
.init();
.init(device);
let transformer =
SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head)
.init();
.init(device);
let res2 = ResBlockConfig::new(
self.n_channels_in,
self.n_channels_embed,
self.n_channels_out,
)
.init();
.init(device);
ResTransformerRes {
res1,
@@ -373,10 +373,10 @@ pub struct UpsampleConfig {
}
impl UpsampleConfig {
fn init<B: Backend>(&self) -> Upsample<B> {
fn init<B: Backend>(&self, device: &B::Device) -> Upsample<B> {
let conv = Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
.init(device);
Upsample { conv }
}
@@ -392,8 +392,7 @@ impl<B: Backend> Upsample<B> {
let [n_batch, n_channel, height, width] = x.dims();
let x = x
.reshape([n_batch, n_channel, height, 1, width, 1])
.repeat(3, 2)
.repeat(5, 2)
.repeat(&[1, 1, 1, 2, 1, 2])
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
self.conv.forward(x)
}
@@ -411,11 +410,11 @@ pub struct DownsampleConfig {
}
impl DownsampleConfig {
fn init<B: Backend>(&self) -> Conv2d<B> {
fn init<B: Backend>(&self, device: &B::Device) -> Conv2d<B> {
Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3])
.with_stride([2, 2])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init()
.init(device)
}
}
@@ -435,12 +434,12 @@ pub struct SpatialTransformerConfig {
}
impl SpatialTransformerConfig {
fn init<B: Backend>(&self) -> SpatialTransformer<B> {
let norm = GroupNormConfig::new(32, self.n_channels).init();
let proj_in = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init();
fn init<B: Backend>(&self, device: &B::Device) -> SpatialTransformer<B> {
let norm = GroupNormConfig::new(32, self.n_channels).init(device);
let proj_in = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(device);
let transformer =
TransformerBlockConfig::new(self.n_channels, self.n_context_state, self.n_head).init();
let proj_out = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init();
TransformerBlockConfig::new(self.n_channels, self.n_context_state, self.n_head).init(device);
let proj_out = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(device);
SpatialTransformer {
norm,
@@ -489,14 +488,14 @@ pub struct TransformerBlockConfig {
}
impl TransformerBlockConfig {
fn init<B: Backend>(&self) -> TransformerBlock<B> {
let norm1 = nn::LayerNormConfig::new(self.n_state).init();
let attn1 = MultiHeadAttentionConfig::new(self.n_state, self.n_state, self.n_head).init();
let norm2 = nn::LayerNormConfig::new(self.n_state).init();
fn init<B: Backend>(&self, device: &B::Device) -> TransformerBlock<B> {
let norm1 = nn::LayerNormConfig::new(self.n_state).init(device);
let attn1 = MultiHeadAttentionConfig::new(self.n_state, self.n_state, self.n_head).init(device);
let norm2 = nn::LayerNormConfig::new(self.n_state).init(device);
let attn2 =
MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init();
let norm3 = nn::LayerNormConfig::new(self.n_state).init();
let mlp = MLPConfig::new(self.n_state, 4).init();
MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init(device);
let norm3 = nn::LayerNormConfig::new(self.n_state).init(device);
let mlp = MLPConfig::new(self.n_state, 4).init(device);
TransformerBlock {
norm1,
@@ -534,10 +533,10 @@ pub struct MLPConfig {
}
impl MLPConfig {
pub fn init<B: Backend>(&self) -> MLP<B> {
pub fn init<B: Backend>(&self, device: &B::Device) -> MLP<B> {
let n_state_hidden = self.n_state * self.mult;
let geglu = GEGLUConfig::new(self.n_state, n_state_hidden).init();
let lin = nn::LinearConfig::new(n_state_hidden, self.n_state).init();
let geglu = GEGLUConfig::new(self.n_state, n_state_hidden).init(device);
let lin = nn::LinearConfig::new(n_state_hidden, self.n_state).init(device);
MLP { geglu, lin }
}
@@ -562,9 +561,9 @@ pub struct GEGLUConfig {
}
impl GEGLUConfig {
fn init<B: Backend>(&self) -> GEGLU<B> {
let proj = nn::LinearConfig::new(self.n_state_in, 2 * self.n_state_out).init();
let gelu = GELU::new();
fn init<B: Backend>(&self, device: &B::Device) -> GEGLU<B> {
let proj = nn::LinearConfig::new(self.n_state_in, 2 * self.n_state_out).init(device);
let gelu = Gelu::new();
GEGLU { proj, gelu }
}
@@ -573,7 +572,7 @@ impl GEGLUConfig {
#[derive(Module, Debug)]
pub struct GEGLU<B: Backend> {
proj: nn::Linear<B>,
gelu: GELU,
gelu: Gelu,
}
impl<B: Backend> GEGLU<B> {
@@ -600,7 +599,7 @@ pub struct MultiHeadAttentionConfig {
}
impl MultiHeadAttentionConfig {
fn init<B: Backend>(&self) -> MultiHeadAttention<B> {
fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadAttention<B> {
assert!(
self.n_state % self.n_head == 0,
"State size {} must be a multiple of head size {}",
@@ -611,14 +610,14 @@ impl MultiHeadAttentionConfig {
let n_head = self.n_head;
let query = nn::LinearConfig::new(self.n_state, self.n_state)
.with_bias(false)
.init();
.init(device);
let key = nn::LinearConfig::new(self.n_context_state, self.n_state)
.with_bias(false)
.init();
.init(device);
let value = nn::LinearConfig::new(self.n_context_state, self.n_state)
.with_bias(false)
.init();
let out = nn::LinearConfig::new(self.n_state, self.n_state).init();
.init(device);
let out = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
MultiHeadAttention {
n_head,
@@ -661,24 +660,24 @@ pub struct ResBlockConfig {
}
impl ResBlockConfig {
fn init<B: Backend>(&self) -> ResBlock<B> {
let norm_in = GroupNormConfig::new(32, self.n_channels_in).init();
fn init<B: Backend>(&self, device: &B::Device) -> ResBlock<B> {
let norm_in = GroupNormConfig::new(32, self.n_channels_in).init(device);
let silu_in = SILU::new();
let conv_in = Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
.init(device);
let silu_embed = SILU::new();
let lin_embed = nn::LinearConfig::new(self.n_channels_embed, self.n_channels_out).init();
let lin_embed = nn::LinearConfig::new(self.n_channels_embed, self.n_channels_out).init(device);
let norm_out = GroupNormConfig::new(32, self.n_channels_out).init();
let norm_out = GroupNormConfig::new(32, self.n_channels_out).init(device);
let silu_out = SILU::new();
let conv_out = Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
.init(device);
let skip_connection = if self.n_channels_in != self.n_channels_out {
Some(Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [1, 1]).init())
Some(Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [1, 1]).init(device))
} else {
None
};