Update to burn v0.14.0 and switch to .mpk model file

This commit is contained in:
Hermes
2024-10-05 14:19:49 -04:00
committed by Ben_Kosytorz
parent 3c49b0a151
commit 75f0cedd9f
19 changed files with 366 additions and 311 deletions

View File

@@ -14,11 +14,11 @@ git = "https://github.com/burn-rs/burn.git"
optional = true optional = true
[dependencies] [dependencies]
burn = { git = "https://github.com/burn-rs/burn.git" } burn = "0.14.0"
burn-ndarray = { package = "burn-ndarray", git = "https://github.com/burn-rs/burn.git" } burn-ndarray = "0.14.0"
burn-tch = { package = "burn-tch", git = "https://github.com/burn-rs/burn.git" } burn-tch = "0.14.0"
burn-autodiff = { package = "burn-autodiff", git = "https://github.com/burn-rs/burn.git" } burn-autodiff = "0.14.0"
tch = "0.13.0" tch = "0.15.0"
serde = {version = "1.0.171", features = ["std", "derive"]} serde = {version = "1.0.171", features = ["std", "derive"]}
npy = "0.4.0" npy = "0.4.0"
num-traits = "0.2.15" num-traits = "0.2.15"

View File

@@ -4,12 +4,14 @@ Stable-Diffusion-Burn is a Rust-based project which ports the V1 stable diffusio
## How To Use ## How To Use
### Step 0: Install libtorch v2.4.1
### Step 1: Download the Model and Set Environment Variables ### Step 1: Download the Model and Set Environment Variables
Start by downloading the SDv1-4.bin model provided on HuggingFace. Start by downloading the SDv1-4 model provided on HuggingFace.
```bash ```bash
wget https://huggingface.co/Gadersd/Stable-Diffusion-Burn/resolve/main/V1/SDv1-4.bin wget https://huggingface.co/Gadersd/Stable-Diffusion-Burn/resolve/main/SDv1-4.mpk
``` ```
### Step 2: Run the Sample Binary ### Step 2: Run the Sample Binary
@@ -18,9 +20,13 @@ Invoke the sample binary provided in the rust code. By default, torch is used. T
```bash ```bash
# torch (at least 6 GB VRAM, possibly less) # torch (at least 6 GB VRAM, possibly less)
export TORCH_CUDA_VERSION=cu113 # Arguments: <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name> [cuda, mps, cpu]
# Arguments: <model_type(burn or dump)> <model> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image>
cargo run --release --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img # Cuda
cargo run --release --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img cuda
# Mps(Mac)
cargo run --release --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img mps
# wgpu (UNSTABLE) # wgpu (UNSTABLE)
# Arguments: <model_type(burn or dump)> <model> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image> # Arguments: <model_type(burn or dump)> <model> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image>
@@ -33,7 +39,7 @@ This command will generate an image according to the provided prompt, which will
### Optional: Extract and Convert a Fine-Tuned Model ### Optional: Extract and Convert a Fine-Tuned Model
If users are interested in using a fine-tuned version of stable diffusion, the Python scripts provided in this project can be used to transform a weight dump into a Burn model file. Note: the tinygrad dependency should be installed from source rather than with pip. If users are interested in using a fine-tuned version of stable diffusion, the Python scripts provided in this project can be used to transform a weight dump into a Burn model file. This does not work on Windows.
```bash ```bash
# Step into the Python directory # Step into the Python directory
@@ -42,6 +48,9 @@ cd python
# Download the model, this is just the base v1.4 model as an example # Download the model, this is just the base v1.4 model as an example
wget https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt wget https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt
# Install tinygrad
pip install -r requirements.txt
# Extract the weights # Extract the weights
CPU=1 python3 dump.py sd-v1-4.ckpt CPU=1 python3 dump.py sd-v1-4.ckpt

View File

@@ -13,10 +13,11 @@ from collections import namedtuple
from tqdm import tqdm from tqdm import tqdm
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, GlobalCounters from tinygrad.helpers import GlobalCounters
from tinygrad import dtypes
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
from extra.utils import download_file #from extra.utils import download_file
from tinygrad.state import torch_load, load_state_dict from tinygrad.nn.state import torch_load, load_state_dict
# TODO: refactor AttnBlock, CrossAttention, CLIPAttention to share code # TODO: refactor AttnBlock, CrossAttention, CLIPAttention to share code

1
python/requirements.txt Normal file
View File

@@ -0,0 +1 @@
tinygrad==0.9.2

View File

@@ -1,13 +1,16 @@
use burn::tensor::{activation::softmax, Tensor}; use burn::tensor::{activation::softmax, Tensor};
use burn::prelude::Backend;
/*pub type FloatTensor<B, const D: usize> = <B as burn::tensor::backend::Backend>::TensorPrimitive<D>;
pub trait Backend: burn::tensor::backend::Backend { pub trait Backend: burn::tensor::backend::Backend {
fn qkv_attention( fn qkv_attention(
q: Self::TensorPrimitive<3>, q: FloatTensor<Self, 3>,
k: Self::TensorPrimitive<3>, k: FloatTensor<Self, 3>,
v: Self::TensorPrimitive<3>, v: FloatTensor<Self, 3>,
mask: Option<Self::TensorPrimitive<2>>, mask: Option<FloatTensor<Self, 2>>,
n_head: usize, n_head: usize,
) -> Self::TensorPrimitive<3> { ) -> FloatTensor<Self, 3> {
qkv_attention( qkv_attention(
Tensor::<Self, 3>::from_primitive(q), Tensor::<Self, 3>::from_primitive(q),
Tensor::from_primitive(k), Tensor::from_primitive(k),
@@ -18,24 +21,23 @@ pub trait Backend: burn::tensor::backend::Backend {
.into_primitive() .into_primitive()
} }
fn attn_decoder_mask(seq_length: usize, device: &Self::Device) -> Self::TensorPrimitive<2> { fn attn_decoder_mask(seq_length: usize, device: &Self::Device) -> FloatTensor<Self, 2> {
attn_decoder_mask::<Self>(seq_length, device).into_primitive() attn_decoder_mask::<Self>(seq_length, device).into_primitive()
} }
} }
use burn::tensor::ops::TensorOps;
use burn::tensor::Float; use burn::tensor::Float;
use burn_tch::{self, TchElement, TchTensor}; use burn_tch::{self, TchElement, TchTensor};
use tch; use tch;
impl<E: TchElement> Backend for burn_tch::TchBackend<E> { impl<E: TchElement> Backend for burn_tch::LibTorch<E> {
fn qkv_attention( fn qkv_attention(
q: Self::TensorPrimitive<3>, q: FloatTensor<Self, 3>,
k: Self::TensorPrimitive<3>, k: FloatTensor<Self, 3>,
v: Self::TensorPrimitive<3>, v: FloatTensor<Self, 3>,
mask: Option<Self::TensorPrimitive<2>>, mask: Option<FloatTensor<Self, 2>>,
n_head: usize, n_head: usize,
) -> Self::TensorPrimitive<3> { ) -> FloatTensor<Self, 2> {
let q = Tensor::from_primitive(q); let q = Tensor::from_primitive(q);
let k = Tensor::from_primitive(k); let k = Tensor::from_primitive(k);
let v = Tensor::from_primitive(v); let v = Tensor::from_primitive(v);
@@ -56,7 +58,7 @@ impl<E: TchElement> Backend for burn_tch::TchBackend<E> {
// for some reason torch crashes when mask is None // for some reason torch crashes when mask is None
let mask = mask.unwrap_or_else(|| { let mask = mask.unwrap_or_else(|| {
Tensor::<Self, 2, Float>::zeros_device([q_ctx, k_ctx], &Self::device(&v)) Tensor::<Self, 2, Float>::zeros([q_ctx, k_ctx], &Self::device(&v))
.into_primitive() .into_primitive()
}); });
@@ -68,6 +70,7 @@ impl<E: TchElement> Backend for burn_tch::TchBackend<E> {
Some(mask.tensor), Some(mask.tensor),
0.0, 0.0,
false, false,
None,
), ),
)) ))
.swap_dims(1, 2) .swap_dims(1, 2)
@@ -78,11 +81,11 @@ impl<E: TchElement> Backend for burn_tch::TchBackend<E> {
use burn_autodiff; use burn_autodiff;
impl<B: Backend> Backend for burn_autodiff::ADBackendDecorator<B> {} impl<B: Backend> Backend for burn_autodiff::Autodiff<B> {}*/
use std::f32::NEG_INFINITY; use std::f32::NEG_INFINITY;
fn qkv_attention<B: Backend>( pub fn qkv_attention<B: Backend>(
q: Tensor<B, 3>, q: Tensor<B, 3>,
k: Tensor<B, 3>, k: Tensor<B, 3>,
v: Tensor<B, 3>, v: Tensor<B, 3>,
@@ -124,13 +127,13 @@ fn qkv_attention<B: Backend>(
return o; return o;
} }
fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> { pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length]); let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length], device);
for i in 0..(seq_length - 1) { for i in 0..(seq_length - 1) {
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)]).add_scalar(NEG_INFINITY); let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)], device).add_scalar(NEG_INFINITY);
mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values); mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
} }
return mask.to_device(device); return mask;
} }

View File

@@ -11,9 +11,9 @@ use burn::{
tensor::{backend::Backend, Tensor}, tensor::{backend::Backend, Tensor},
}; };
use burn_ndarray::{NdArrayBackend, NdArrayDevice}; use burn_ndarray::{NdArray, NdArrayDevice};
use burn::record::{self, BinFileRecorder, FullPrecisionSettings, Recorder}; use burn::record::{self, NamedMpkFileRecorder, FullPrecisionSettings, Recorder};
fn convert_dump_to_model<B: Backend>( fn convert_dump_to_model<B: Backend>(
dump_path: &str, dump_path: &str,
@@ -33,11 +33,11 @@ fn save_model_file<B: Backend>(
model: StableDiffusion<B>, model: StableDiffusion<B>,
name: &str, name: &str,
) -> Result<(), record::RecorderError> { ) -> Result<(), record::RecorderError> {
BinFileRecorder::<FullPrecisionSettings>::new().record(model.into_record(), name.into()) NamedMpkFileRecorder::<FullPrecisionSettings>::new().record(model.into_record(), name.into())
} }
fn main() { fn main() {
type Backend = NdArrayBackend<f32>; type Backend = NdArray<f32>;
let device = NdArrayDevice::Cpu; let device = NdArrayDevice::Cpu;
let args: Vec<String> = env::args().collect(); let args: Vec<String> = env::args().collect();

View File

@@ -14,7 +14,7 @@ cfg_if::cfg_if! {
if #[cfg(feature = "wgpu-backend")] { if #[cfg(feature = "wgpu-backend")] {
use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi}; use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi};
} else { } else {
use burn_tch::{TchBackend, TchDevice}; use burn_tch::{LibTorch, LibTorchDevice};
} }
} }
@@ -22,30 +22,21 @@ use std::env;
use std::io; use std::io;
use std::process; use std::process;
use burn::record::{self, BinFileRecorder, FullPrecisionSettings, Recorder}; use burn::record::{self, NamedMpkFileRecorder, FullPrecisionSettings, Recorder};
fn load_stable_diffusion_model_file<B: Backend>( fn load_stable_diffusion_model_file<B: Backend>(
filename: &str, filename: &str,
device: &B::Device,
) -> Result<StableDiffusion<B>, record::RecorderError> { ) -> Result<StableDiffusion<B>, record::RecorderError> {
BinFileRecorder::<FullPrecisionSettings>::new() NamedMpkFileRecorder::<FullPrecisionSettings>::new()
.load(filename.into()) .load(filename.into(), device)
.map(|record| StableDiffusionConfig::new().init().load_record(record)) .map(|record| StableDiffusionConfig::new().init(device).load_record(record))
} }
fn main() { fn main() {
cfg_if::cfg_if! {
if #[cfg(feature = "wgpu-backend")] {
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
let device = WgpuDevice::BestAvailable;
} else {
type Backend = TchBackend<f32>;
let device = TchDevice::Cuda(0);
}
}
let args: Vec<String> = std::env::args().collect(); let args: Vec<String> = std::env::args().collect();
if args.len() != 7 { if args.len() != 7 && args.len() != 8 {
eprintln!("Usage: {} <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name>", args[0]); eprintln!("Usage: {} <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name> [device(cuda, mps, cpu)]", args[0]);
process::exit(1); process::exit(1);
} }
@@ -62,11 +53,40 @@ fn main() {
let prompt = &args[5]; let prompt = &args[5];
let output_image_name = &args[6]; let output_image_name = &args[6];
// Optional device parameter
let device_arg = if args.len() == 8 { Some(&args[7]) } else { None };
cfg_if::cfg_if! {
if #[cfg(feature = "wgpu-backend")] {
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
let device = WgpuDevice::BestAvailable;
} else {
type Backend = LibTorch<f32>;
let device = if let Some(dev_str) = device_arg {
match dev_str.to_lowercase().as_str() {
"cpu" => LibTorchDevice::Cpu,
"mps" => LibTorchDevice::Mps,
s if s.starts_with("cuda") => {
let idx = s[4..].parse().unwrap_or(0);
LibTorchDevice::Cuda(idx)
}
_ => {
eprintln!("Unknown device: {}", dev_str);
process::exit(1);
}
}
} else {
LibTorchDevice::Cuda(0)
};
}
}
println!("Loading tokenizer..."); println!("Loading tokenizer...");
let tokenizer = SimpleTokenizer::new().unwrap(); let tokenizer = SimpleTokenizer::new().unwrap();
println!("Loading model..."); println!("Loading model...");
let sd: StableDiffusion<Backend> = if model_type == "burn" { let sd: StableDiffusion<Backend> = if model_type == "burn" {
load_stable_diffusion_model_file(model_name).unwrap_or_else(|err| { load_stable_diffusion_model_file(model_name, &device).unwrap_or_else(|err| {
eprintln!("Error loading model: {}", err); eprintln!("Error loading model: {}", err);
process::exit(1); process::exit(1);
}) })
@@ -77,8 +97,6 @@ fn main() {
}) })
}; };
let sd = sd.to_device(&device);
let unconditional_context = sd.unconditional_context(&tokenizer); let unconditional_context = sd.unconditional_context(&tokenizer);
let context = sd.context(&tokenizer, prompt).unsqueeze::<3>(); //.repeat(0, 2); // generate 2 samples let context = sd.context(&tokenizer, prompt).unsqueeze::<3>(); //.repeat(0, 2); // generate 2 samples

View File

@@ -45,12 +45,12 @@ pub fn qkv_attention<B: Backend>(
} }
pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> { pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length]); let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length], device);
for i in 0..(seq_length - 1) { for i in 0..(seq_length - 1) {
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)]).add_scalar(NEG_INFINITY); let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)], device).add_scalar(NEG_INFINITY);
mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values); mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
} }
return mask.to_device(device); return mask;
} }

View File

@@ -71,7 +71,7 @@ fn load_padded_conv2d<B: Backend>(
path: &str, path: &str,
device: &B::Device, device: &B::Device,
) -> Result<PaddedConv2d<B>, Box<dyn Error>> { ) -> Result<PaddedConv2d<B>, Box<dyn Error>> {
let conv = load_conv2d(&format!("{}/{}", path, "conv"), device)?; let mut conv = load_conv2d(&format!("{}/{}", path, "conv"), device)?;
let channels = load_tensor::<B, 1>("channels", path, device)?; let channels = load_tensor::<B, 1>("channels", path, device)?;
let channels = tensor_to_array_2(channels); let channels = tensor_to_array_2(channels);
@@ -81,18 +81,21 @@ fn load_padded_conv2d<B: Backend>(
let padding = load_tensor::<B, 1>("padding", path, device)?; let padding = load_tensor::<B, 1>("padding", path, device)?;
let padding: [usize; 4] = tensor_to_array(padding); let padding: [usize; 4] = tensor_to_array(padding);
let padding = Padding::new(padding[0], padding[1], padding[2], padding[3]); let padding = PaddingCfg::new(padding[0], padding[1], padding[2], padding[3]);
let mut record = conv.into_record(); //let mut record = conv.into_record();
let mut padded_conv: PaddedConv2d<B> = PaddedConv2dConfig::new(channels, kernel_size, padding) let mut padded_conv: PaddedConv2d<B> = PaddedConv2dConfig::new(channels, kernel_size, padding)
.with_stride(stride) .with_stride(stride)
.init(); .init(device);
let padding_actual = let padding_actual =
PaddingConfig2d::Explicit(padded_conv.padding_actual[0], padded_conv.padding_actual[1]); PaddingConfig2d::Explicit(padded_conv.padding_actual[0], padded_conv.padding_actual[1]);
record.padding = <PaddingConfig2d as Module<B>>::into_record(padding_actual); conv.padding = burn::module::Ignored(padding_actual);
padded_conv.conv = padded_conv.conv.load_record(record); padded_conv.conv = conv;
//record.padding = <PaddingConfig2d as Module<B>>::into_record(padding_actual);
//padded_conv.conv = padded_conv.conv.load_record(record);
Ok(padded_conv) Ok(padded_conv)
} }

View File

@@ -18,7 +18,8 @@ use burn::{
use super::groupnorm::*; use super::groupnorm::*;
use super::silu::*; use super::silu::*;
use crate::backend::Backend as MyBackend; //use crate::backend::Backend as MyBackend;
use crate::backend::{qkv_attention, attn_decoder_mask};
use std::iter; use std::iter;
@@ -26,13 +27,13 @@ use std::iter;
pub struct AutoencoderConfig {} pub struct AutoencoderConfig {}
impl AutoencoderConfig { impl AutoencoderConfig {
pub fn init<B: Backend>(&self) -> Autoencoder<B> { pub fn init<B: Backend>(&self, device: &B::Device) -> Autoencoder<B> {
let encoder = let encoder =
EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init(); EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init(device);
let decoder = let decoder =
DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init(); DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init(device);
let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init(); let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init(device);
let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init(); let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init(device);
Autoencoder { Autoencoder {
encoder, encoder,
@@ -51,7 +52,7 @@ pub struct Autoencoder<B: Backend> {
post_quant_conv: Conv2d<B>, post_quant_conv: Conv2d<B>,
} }
impl<B: MyBackend> Autoencoder<B> { impl<B: Backend> Autoencoder<B> {
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> { pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
self.decode_latent(self.encode_image(x)) self.decode_latent(self.encode_image(x))
} }
@@ -78,7 +79,7 @@ pub struct EncoderConfig {
} }
impl EncoderConfig { impl EncoderConfig {
fn init<B: Backend>(&self) -> Encoder<B> { fn init<B: Backend>(&self, device: &B::Device) -> Encoder<B> {
let n_expanded_channels_initial = self let n_expanded_channels_initial = self
.channels .channels
.first() .first()
@@ -88,7 +89,7 @@ impl EncoderConfig {
let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3]) let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1)) .with_padding(PaddingConfig2d::Explicit(1, 1))
.init(); .init(device);
let blocks = self let blocks = self
.channels .channels
@@ -96,16 +97,16 @@ impl EncoderConfig {
.enumerate() .enumerate()
.map(|(i, &(n_channel_in, n_channel_out))| { .map(|(i, &(n_channel_in, n_channel_out))| {
let downsample = i != self.channels.len() - 1; let downsample = i != self.channels.len() - 1;
EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init() EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init(device)
}) })
.collect(); .collect();
let mid = MidConfig::new(n_expanded_channels_final).init(); let mid = MidConfig::new(n_expanded_channels_final).init(device);
let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init(); let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init(device);
let silu = SILU::new(); let silu = SILU::new();
let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3]) let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1)) .with_padding(PaddingConfig2d::Explicit(1, 1))
.init(); .init(device);
Encoder { Encoder {
conv_in, conv_in,
@@ -128,7 +129,7 @@ pub struct Encoder<B: Backend> {
conv_out: Conv2d<B>, conv_out: Conv2d<B>,
} }
impl<B: MyBackend> Encoder<B> { impl<B: Backend> Encoder<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> { fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.conv_in.forward(x); let x = self.conv_in.forward(x);
@@ -150,7 +151,7 @@ pub struct DecoderConfig {
} }
impl DecoderConfig { impl DecoderConfig {
fn init<B: Backend>(&self) -> Decoder<B> { fn init<B: Backend>(&self, device: &B::Device) -> Decoder<B> {
let n_expanded_channels = self let n_expanded_channels = self
.channels .channels
.first() .first()
@@ -160,8 +161,8 @@ impl DecoderConfig {
let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3]) let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1)) .with_padding(PaddingConfig2d::Explicit(1, 1))
.init(); .init(device);
let mid = MidConfig::new(n_expanded_channels).init(); let mid = MidConfig::new(n_expanded_channels).init(device);
let blocks = self let blocks = self
.channels .channels
@@ -169,15 +170,15 @@ impl DecoderConfig {
.enumerate() .enumerate()
.map(|(i, &(n_channel_in, n_channel_out))| { .map(|(i, &(n_channel_in, n_channel_out))| {
let upsample = i != self.channels.len() - 1; let upsample = i != self.channels.len() - 1;
DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init() DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init(device)
}) })
.collect(); .collect();
let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init(); let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init(device);
let silu = SILU::new(); let silu = SILU::new();
let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3]) let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1)) .with_padding(PaddingConfig2d::Explicit(1, 1))
.init(); .init(device);
Decoder { Decoder {
conv_in, conv_in,
@@ -200,7 +201,7 @@ pub struct Decoder<B: Backend> {
conv_out: Conv2d<B>, conv_out: Conv2d<B>,
} }
impl<B: MyBackend> Decoder<B> { impl<B: Backend> Decoder<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> { fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.conv_in.forward(x); let x = self.conv_in.forward(x);
let x = self.mid.forward(x); let x = self.mid.forward(x);
@@ -223,15 +224,15 @@ pub struct EncoderBlockConfig {
} }
impl EncoderBlockConfig { impl EncoderBlockConfig {
fn init<B: Backend>(&self) -> EncoderBlock<B> { fn init<B: Backend>(&self, device: &B::Device) -> EncoderBlock<B> {
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(); let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(device);
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(); let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
let downsampler = if self.downsample { let downsampler = if self.downsample {
let padding = Padding::new(0, 1, 0, 1); let padding = PaddingCfg::new(0, 1, 0, 1);
Some( Some(
PaddedConv2dConfig::new([self.n_channels_out, self.n_channels_out], 3, padding) PaddedConv2dConfig::new([self.n_channels_out, self.n_channels_out], 3, padding)
.with_stride(2) .with_stride(2)
.init(), .init(device),
) )
} else { } else {
None None
@@ -272,15 +273,15 @@ pub struct DecoderBlockConfig {
} }
impl DecoderBlockConfig { impl DecoderBlockConfig {
fn init<B: Backend>(&self) -> DecoderBlock<B> { fn init<B: Backend>(&self, device: &B::Device) -> DecoderBlock<B> {
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(); let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(device);
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(); let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(); let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
let upsampler = if self.upsample { let upsampler = if self.upsample {
Some( Some(
Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3]) Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1)) .with_padding(PaddingConfig2d::Explicit(1, 1))
.init(), .init(device),
) )
} else { } else {
None None
@@ -313,8 +314,7 @@ impl<B: Backend> DecoderBlock<B> {
let [n_batch, n_channel, height, width] = x.dims(); let [n_batch, n_channel, height, width] = x.dims();
let x = x let x = x
.reshape([n_batch, n_channel, height, 1, width, 1]) .reshape([n_batch, n_channel, height, 1, width, 1])
.repeat(3, 2) .repeat(&[1, 1, 1, 2, 1, 2])
.repeat(5, 2)
.reshape([n_batch, n_channel, 2 * height, 2 * width]); .reshape([n_batch, n_channel, 2 * height, 2 * width]);
d.forward(x) d.forward(x)
} else { } else {
@@ -329,11 +329,11 @@ pub struct PaddedConv2dConfig {
kernel_size: usize, kernel_size: usize,
#[config(default = 1)] #[config(default = 1)]
stride: usize, stride: usize,
padding: Padding, padding: PaddingCfg,
} }
impl PaddedConv2dConfig { impl PaddedConv2dConfig {
fn init<B: Backend>(&self) -> PaddedConv2d<B> { fn init<B: Backend>(&self, device: &B::Device) -> PaddedConv2d<B> {
let calc_padding = |p_left, p_right| { let calc_padding = |p_left, p_right| {
let n = if p_left >= p_right { let n = if p_left >= p_right {
0 0
@@ -351,12 +351,17 @@ impl PaddedConv2dConfig {
let conv = Conv2dConfig::new(self.channels, [self.kernel_size, self.kernel_size]) let conv = Conv2dConfig::new(self.channels, [self.kernel_size, self.kernel_size])
.with_stride([self.stride, self.stride]) .with_stride([self.stride, self.stride])
.with_padding(PaddingConfig2d::Explicit(pad_vertical, pad_horizontal)) .with_padding(PaddingConfig2d::Explicit(pad_vertical, pad_horizontal))
.init(); .init(device);
let kernel_size = self.kernel_size; let kernel_size = self.kernel_size;
let stride = self.stride; let stride = self.stride;
let padding = self.padding; let padding = Padding {
pad_left: self.padding.pad_left,
pad_right: self.padding.pad_right,
pad_top: self.padding.pad_top,
pad_bottom: self.padding.pad_bottom,
};
PaddedConv2d { PaddedConv2d {
conv, conv,
@@ -406,7 +411,15 @@ impl<B: Backend> PaddedConv2d<B> {
} }
} }
#[derive(Config, Module, Copy, Debug)] #[derive(Config, Debug)]
pub struct PaddingCfg {
pad_left: usize,
pad_right: usize,
pad_top: usize,
pad_bottom: usize,
}
#[derive(Module, Clone, Debug)]
pub struct Padding { pub struct Padding {
pad_left: usize, pad_left: usize,
pad_right: usize, pad_right: usize,
@@ -420,10 +433,10 @@ pub struct MidConfig {
} }
impl MidConfig { impl MidConfig {
fn init<B: Backend>(&self) -> Mid<B> { fn init<B: Backend>(&self, device: &B::Device) -> Mid<B> {
let block_1 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(); let block_1 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(device);
let attn = ConvSelfAttentionBlockConfig::new(self.n_channel).init(); let attn = ConvSelfAttentionBlockConfig::new(self.n_channel).init(device);
let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(); let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(device);
Mid { Mid {
block_1, block_1,
@@ -440,7 +453,7 @@ pub struct Mid<B: Backend> {
block_2: ResnetBlock<B>, block_2: ResnetBlock<B>,
} }
impl<B: MyBackend> Mid<B> { impl<B: Backend> Mid<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> { fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.block_1.forward(x); let x = self.block_1.forward(x);
let x = self.attn.forward(x); let x = self.attn.forward(x);
@@ -456,17 +469,17 @@ pub struct ResnetBlockConfig {
} }
impl ResnetBlockConfig { impl ResnetBlockConfig {
fn init<B: Backend>(&self) -> ResnetBlock<B> { fn init<B: Backend>(&self, device: &B::Device) -> ResnetBlock<B> {
let norm1 = GroupNormConfig::new(32, self.in_channels).init(); let norm1 = GroupNormConfig::new(32, self.in_channels).init(device);
let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3]) let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1)) .with_padding(PaddingConfig2d::Explicit(1, 1))
.init(); .init(device);
let norm2 = GroupNormConfig::new(32, self.out_channels).init(); let norm2 = GroupNormConfig::new(32, self.out_channels).init(device);
let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3]) let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1)) .with_padding(PaddingConfig2d::Explicit(1, 1))
.init(); .init(device);
let nin_shortcut = if self.in_channels != self.out_channels { let nin_shortcut = if self.in_channels != self.out_channels {
Some(Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init()) Some(Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init(device))
} else { } else {
None None
}; };
@@ -520,12 +533,12 @@ pub struct ConvSelfAttentionBlockConfig {
} }
impl ConvSelfAttentionBlockConfig { impl ConvSelfAttentionBlockConfig {
fn init<B: Backend>(&self) -> ConvSelfAttentionBlock<B> { fn init<B: Backend>(&self, device: &B::Device) -> ConvSelfAttentionBlock<B> {
let norm = GroupNormConfig::new(32, self.n_channel).init(); let norm = GroupNormConfig::new(32, self.n_channel).init(device);
let q = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(); let q = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
let k = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(); let k = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
let v = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(); let v = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(); let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
ConvSelfAttentionBlock { ConvSelfAttentionBlock {
norm, norm,
@@ -546,7 +559,7 @@ pub struct ConvSelfAttentionBlock<B: Backend> {
proj_out: Conv2d<B>, proj_out: Conv2d<B>,
} }
impl<B: MyBackend> ConvSelfAttentionBlock<B> { impl<B: Backend> ConvSelfAttentionBlock<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> { fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let [n_batch, n_channel, height, width] = x.dims(); let [n_batch, n_channel, height, width] = x.dims();
@@ -568,7 +581,7 @@ impl<B: MyBackend> ConvSelfAttentionBlock<B> {
.reshape([n_batch, n_channel, height * width]) .reshape([n_batch, n_channel, height * width])
.swap_dims(1, 2); .swap_dims(1, 2);
let wv = Tensor::from_primitive(B::qkv_attention( /*let wv = Tensor::from_primitive(B::qkv_attention(
q.into_primitive(), q.into_primitive(),
k.into_primitive(), k.into_primitive(),
v.into_primitive(), v.into_primitive(),
@@ -576,6 +589,16 @@ impl<B: MyBackend> ConvSelfAttentionBlock<B> {
1, 1,
)) ))
.swap_dims(1, 2) .swap_dims(1, 2)
.reshape([n_batch, n_channel, height, width]);*/
let wv = qkv_attention(
q,
k,
v,
None,
1,
)
.swap_dims(1, 2)
.reshape([n_batch, n_channel, height, width]); .reshape([n_batch, n_channel, height, width]);
let projected = self.proj_out.forward(wv); let projected = self.proj_out.forward(wv);

View File

@@ -68,7 +68,7 @@ pub fn load_residual_decoder_attention_block<B: Backend>(
pub fn load_clip<B: Backend>(path: &str, device: &B::Device) -> Result<CLIP<B>, Box<dyn Error>> { pub fn load_clip<B: Backend>(path: &str, device: &B::Device) -> Result<CLIP<B>, Box<dyn Error>> {
let token_embedding = load_embedding(&format!("{}/{}", path, "token_embedding"), device)?; let token_embedding = load_embedding(&format!("{}/{}", path, "token_embedding"), device)?;
let position_embedding = let position_embedding =
load_tensor("weight", &format!("{}/position_embedding", path), device)?.into(); Param::from_tensor(load_tensor("weight", &format!("{}/position_embedding", path), device)?);
let n_layer = load_usize::<B>("n_layer", path, device)?; let n_layer = load_usize::<B>("n_layer", path, device)?;
let mut blocks = (0..n_layer) let mut blocks = (0..n_layer)

View File

@@ -12,7 +12,8 @@ use burn::{
}, },
}; };
use crate::backend::Backend as MyBackend; //use crate::backend::Backend as MyBackend;
use crate::backend::{qkv_attention, attn_decoder_mask};
#[derive(Config)] #[derive(Config)]
pub struct CLIPConfig { pub struct CLIPConfig {
@@ -24,15 +25,15 @@ pub struct CLIPConfig {
} }
impl CLIPConfig { impl CLIPConfig {
pub fn init<B: Backend>(&self) -> CLIP<B> { pub fn init<B: Backend>(&self, device: &B::Device) -> CLIP<B> {
let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init(); let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init(device);
let position_embedding = let position_embedding =
Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0)).into(); Param::from_tensor(Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0), device));
let blocks = (0..self.n_layer) let blocks = (0..self.n_layer)
.into_iter() .into_iter()
.map(|_| ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head).init()) .map(|_| ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head).init(device))
.collect(); .collect();
let layer_norm = nn::LayerNormConfig::new(self.n_state).init(); let layer_norm = nn::LayerNormConfig::new(self.n_state).init(device);
CLIP { CLIP {
token_embedding, token_embedding,
@@ -51,11 +52,12 @@ pub struct CLIP<B: Backend> {
layer_norm: nn::LayerNorm<B>, layer_norm: nn::LayerNorm<B>,
} }
impl<B: MyBackend> CLIP<B> { impl<B: Backend> CLIP<B> {
pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> { pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
let [n_batch, seq_len] = x.dims(); let [n_batch, seq_len] = x.dims();
let mask = Tensor::from_primitive(B::attn_decoder_mask(seq_len, &x.device())); //let mask = Tensor::from_primitive(B::attn_decoder_mask(seq_len, &x.device()));
let mask = attn_decoder_mask(seq_len, &x.device());
let embedded = self.token_embedding.forward(x) let embedded = self.token_embedding.forward(x)
+ self + self
@@ -80,12 +82,12 @@ pub struct ResidualDecoderAttentionBlockConfig {
} }
impl ResidualDecoderAttentionBlockConfig { impl ResidualDecoderAttentionBlockConfig {
pub fn init<B: Backend>(&self) -> ResidualDecoderAttentionBlock<B> { pub fn init<B: Backend>(&self, device: &B::Device) -> ResidualDecoderAttentionBlock<B> {
let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init(); let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init(device);
let attn_ln = nn::LayerNormConfig::new(self.n_state).init(); let attn_ln = nn::LayerNormConfig::new(self.n_state).init(device);
let mlp = MLPConfig::new(self.n_state, 4 * self.n_state).init(); let mlp = MLPConfig::new(self.n_state, 4 * self.n_state).init(device);
let mlp_ln = nn::LayerNormConfig::new(self.n_state).init(); let mlp_ln = nn::LayerNormConfig::new(self.n_state).init(device);
ResidualDecoderAttentionBlock { ResidualDecoderAttentionBlock {
attn, attn,
@@ -104,7 +106,7 @@ pub struct ResidualDecoderAttentionBlock<B: Backend> {
mlp_ln: nn::LayerNorm<B>, mlp_ln: nn::LayerNorm<B>,
} }
impl<B: MyBackend> ResidualDecoderAttentionBlock<B> { impl<B: Backend> ResidualDecoderAttentionBlock<B> {
fn forward(&self, x: Tensor<B, 3>, mask: Tensor<B, 2>) -> Tensor<B, 3> { fn forward(&self, x: Tensor<B, 3>, mask: Tensor<B, 2>) -> Tensor<B, 3> {
let x = x.clone() + self.attn.forward(self.attn_ln.forward(x), Some(mask)); let x = x.clone() + self.attn.forward(self.attn_ln.forward(x), Some(mask));
let x = x.clone() + self.mlp.forward(self.mlp_ln.forward(x)); let x = x.clone() + self.mlp.forward(self.mlp_ln.forward(x));
@@ -119,7 +121,7 @@ pub struct MultiHeadSelfAttentionConfig {
} }
impl MultiHeadSelfAttentionConfig { impl MultiHeadSelfAttentionConfig {
fn init<B: Backend>(&self) -> MultiHeadSelfAttention<B> { fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadSelfAttention<B> {
assert!( assert!(
self.n_state % self.n_head == 0, self.n_state % self.n_head == 0,
"State size {} must be a multiple of head size {}", "State size {} must be a multiple of head size {}",
@@ -128,10 +130,10 @@ impl MultiHeadSelfAttentionConfig {
); );
let n_head = self.n_head; let n_head = self.n_head;
let query = nn::LinearConfig::new(self.n_state, self.n_state).init(); let query = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
let key = nn::LinearConfig::new(self.n_state, self.n_state).init(); let key = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
let value = nn::LinearConfig::new(self.n_state, self.n_state).init(); let value = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
let out = nn::LinearConfig::new(self.n_state, self.n_state).init(); let out = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
MultiHeadSelfAttention { MultiHeadSelfAttention {
n_head, n_head,
@@ -152,19 +154,27 @@ pub struct MultiHeadSelfAttention<B: Backend> {
out: nn::Linear<B>, out: nn::Linear<B>,
} }
impl<B: MyBackend> MultiHeadSelfAttention<B> { impl<B: Backend> MultiHeadSelfAttention<B> {
pub fn forward(&self, x: Tensor<B, 3>, mask: Option<Tensor<B, 2>>) -> Tensor<B, 3> { pub fn forward(&self, x: Tensor<B, 3>, mask: Option<Tensor<B, 2>>) -> Tensor<B, 3> {
let q = self.query.forward(x.clone()); let q = self.query.forward(x.clone());
let k = self.key.forward(x.clone()); let k = self.key.forward(x.clone());
let v = self.value.forward(x); let v = self.value.forward(x);
let wv = Tensor::from_primitive(B::qkv_attention( /*let wv = Tensor::from_primitive(B::qkv_attention(
q.into_primitive(), q.into_primitive(),
k.into_primitive(), k.into_primitive(),
v.into_primitive(), v.into_primitive(),
mask.map(|m| m.into_primitive()), mask.map(|m| m.into_primitive()),
self.n_head, self.n_head,
)); ));*/
let wv = qkv_attention(
q,
k,
v,
mask,
self.n_head,
);
return self.out.forward(wv); return self.out.forward(wv);
} }
@@ -177,10 +187,10 @@ pub struct MLPConfig {
} }
impl MLPConfig { impl MLPConfig {
fn init<B: Backend>(&self) -> MLP<B> { fn init<B: Backend>(&self, device: &B::Device) -> MLP<B> {
let fc1 = nn::LinearConfig::new(self.input_size, self.hidden_size).init(); let fc1 = nn::LinearConfig::new(self.input_size, self.hidden_size).init(device);
let gelu = QuickGELU::new(); let gelu = QuickGELU::new();
let fc2 = nn::LinearConfig::new(self.hidden_size, self.input_size).init(); let fc2 = nn::LinearConfig::new(self.hidden_size, self.input_size).init(device);
MLP { fc1, gelu, fc2 } MLP { fc1, gelu, fc2 }
} }

View File

@@ -18,14 +18,14 @@ pub fn load_group_norm<B: Backend>(
let n_channel = load_usize::<B>("n_channel", path, device)?.into(); let n_channel = load_usize::<B>("n_channel", path, device)?.into();
let eps = load_f32::<B>("eps", path, device)?.into(); let eps = load_f32::<B>("eps", path, device)?.into();
let gamma = load_tensor::<B, 1>("weight", path, device) let gamma = Param::from_tensor(load_tensor::<B, 1>("weight", path, device)
.ok() .ok()
.unwrap_or_else(|| Tensor::ones_device([n_channel], device)) .unwrap_or_else(|| Tensor::ones([n_channel], device))
.into(); );
let beta = load_tensor::<B, 1>("bias", path, device) let beta = Param::from_tensor(load_tensor::<B, 1>("bias", path, device)
.ok() .ok()
.unwrap_or_else(|| Tensor::zeros_device([n_channel], device)) .unwrap_or_else(|| Tensor::zeros([n_channel], device))
.into(); );
Ok(GroupNorm { Ok(GroupNorm {
n_group, n_group,

View File

@@ -15,7 +15,7 @@ pub struct GroupNormConfig {
} }
impl GroupNormConfig { impl GroupNormConfig {
pub fn init<B: Backend>(&self) -> GroupNorm<B> { pub fn init<B: Backend>(&self, device: &B::Device) -> GroupNorm<B> {
assert!( assert!(
self.n_channel % self.n_group == 0, self.n_channel % self.n_group == 0,
"The number of channels {} must be divisible by the number of groups {}", "The number of channels {} must be divisible by the number of groups {}",
@@ -25,8 +25,8 @@ impl GroupNormConfig {
let n_per_group = self.n_channel / self.n_group; let n_per_group = self.n_channel / self.n_group;
let gamma = Tensor::ones([self.n_channel]).into(); let gamma = Param::from_tensor(Tensor::ones([self.n_channel], device));
let beta = Tensor::zeros([self.n_channel]).into(); let beta = Param::from_tensor(Tensor::zeros([self.n_channel], device));
let eps = self.eps; let eps = self.eps;

View File

@@ -1,5 +1,7 @@
use npy::{self, NpyData}; use npy::{self, NpyData};
use num_traits::cast::ToPrimitive; use num_traits::cast::ToPrimitive;
use burn::tensor::cast::ToElement;
use burn::prelude::TensorData;
use std::error::Error; use std::error::Error;
use std::io::Read; use std::io::Read;
@@ -21,7 +23,8 @@ pub fn numpy_to_tensor<B: Backend, const D: usize>(
let shape: Vec<_> = v[0..D].into_iter().map(|&v| v as usize).collect(); let shape: Vec<_> = v[0..D].into_iter().map(|&v| v as usize).collect();
let data: Vec<B::FloatElem> = v[D..].into_iter().map(|e| e.elem()).collect(); let data: Vec<B::FloatElem> = v[D..].into_iter().map(|e| e.elem()).collect();
Tensor::from_data_device(Data::new(data, shape.into()), device) //Tensor::from_data_device(Data::new(data, shape.into()), device)
Tensor::from_data(TensorData::new(data, shape), device)
} }
pub fn load_tensor<B: Backend, const D: usize>( pub fn load_tensor<B: Backend, const D: usize>(
@@ -48,7 +51,7 @@ pub fn load_f32<B: Backend>(
path: &str, path: &str,
device: &B::Device, device: &B::Device,
) -> Result<f32, Box<dyn Error>> { ) -> Result<f32, Box<dyn Error>> {
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_f32().unwrap()) load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_f32())
} }
pub fn load_usize<B: Backend>( pub fn load_usize<B: Backend>(
@@ -56,7 +59,7 @@ pub fn load_usize<B: Backend>(
path: &str, path: &str,
device: &B::Device, device: &B::Device,
) -> Result<usize, Box<dyn Error>> { ) -> Result<usize, Box<dyn Error>> {
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_usize().unwrap()) load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_usize())
} }
pub fn load_linear<B: Backend>( pub fn load_linear<B: Backend>(
@@ -66,13 +69,10 @@ pub fn load_linear<B: Backend>(
let weight = load_tensor::<B, 2>("weight", path, device)?; let weight = load_tensor::<B, 2>("weight", path, device)?;
let bias = load_tensor::<B, 1>("bias", path, device).ok(); let bias = load_tensor::<B, 1>("bias", path, device).ok();
let record = nn::LinearRecord { Ok(nn::Linear {
weight: weight.into(), weight: Param::from_tensor(weight),
bias: bias.map(|t| t.into()), bias: bias.map(|t| Param::from_tensor(t)),
}; })
let linear: nn::Linear<B> = nn::LinearConfig::new(3, 3).init_with(record);
Ok(linear)
} }
pub fn load_embedding<B: Backend>( pub fn load_embedding<B: Backend>(
@@ -80,14 +80,10 @@ pub fn load_embedding<B: Backend>(
device: &B::Device, device: &B::Device,
) -> Result<nn::Embedding<B>, Box<dyn Error>> { ) -> Result<nn::Embedding<B>, Box<dyn Error>> {
let weight = load_tensor::<B, 2>("weight", path, device)?; let weight = load_tensor::<B, 2>("weight", path, device)?;
let [n_vocab, n_state] = weight.dims();
let record = nn::EmbeddingRecord { Ok(nn::Embedding {
weight: weight.into(), weight: Param::from_tensor(weight),
}; })
let embedding = nn::EmbeddingConfig::new(n_vocab, n_state).init_with(record);
Ok(embedding)
} }
pub fn load_layer_norm<B: Backend>( pub fn load_layer_norm<B: Backend>(
@@ -100,13 +96,9 @@ pub fn load_layer_norm<B: Backend>(
let [n_state] = weight.dims(); let [n_state] = weight.dims();
let record = nn::LayerNormRecord { let mut layer_norm = nn::LayerNormConfig::new(n_state).with_epsilon(eps).init(device);
gamma: weight.into(), layer_norm.gamma = Param::from_tensor(weight);
beta: bias.into(), layer_norm.beta = Param::from_tensor(bias);
epsilon: <f64 as Module<B>>::into_record(eps),
};
let layer_norm: nn::LayerNorm<B> = nn::LayerNormConfig::new(n_state).init_with(record);
Ok(layer_norm) Ok(layer_norm)
} }
@@ -116,7 +108,7 @@ pub fn load_layer_norm<B: Backend>(
let eps = load_f32::<B>("eps", path, device)?.into(); let eps = load_f32::<B>("eps", path, device)?.into();
let rmsnorm = RMSNorm { let rmsnorm = RMSNorm {
weight: weight.into(), weight: Param::from_tensor(weight),
eps: eps eps: eps
}; };
@@ -148,40 +140,38 @@ pub fn load_conv2d<B: Backend>(
let padding = tensor_to_array_2(padding); let padding = tensor_to_array_2(padding);
let padding = nn::PaddingConfig2d::Explicit(padding[0], padding[1]); let padding = nn::PaddingConfig2d::Explicit(padding[0], padding[1]);
let record = conv::Conv2dRecord { let mut conv2d = conv::Conv2dConfig::new([n_channels_in, n_channels_out], kernel_size)
weight: weight.into(),
bias: bias.map(|t| t.into()),
stride: <[usize; 2] as Module<B>>::into_record(stride),
kernel_size: <[usize; 2] as Module<B>>::into_record(kernel_size),
dilation: <[usize; 2] as Module<B>>::into_record(dilation),
groups: <usize as Module<B>>::into_record(n_group),
padding: <nn::PaddingConfig2d as Module<B>>::into_record(padding.clone()),
};
let conv2d: conv::Conv2d<B> =
conv::Conv2dConfig::new([n_channels_in, n_channels_out], kernel_size)
.with_stride(stride) .with_stride(stride)
.with_dilation(dilation) .with_dilation(dilation)
.with_groups(n_group) .with_groups(n_group)
.with_padding(padding) .with_padding(padding.clone())
.with_bias(has_bias) .with_bias(has_bias)
.init_with(record); .init(device);
conv2d.weight = Param::from_tensor(weight);
conv2d.bias = bias.map(|t| Param::from_tensor(t));
conv2d.stride = stride;
conv2d.kernel_size = kernel_size;
conv2d.dilation = dilation;
conv2d.groups = n_group;
conv2d.padding = burn::module::Ignored(padding);
Ok(conv2d) Ok(conv2d)
} }
pub fn tensor_to_array_2<B: Backend>(x: Tensor<B, 1>) -> [usize; 2] { pub fn tensor_to_array_2<B: Backend>(x: Tensor<B, 1>) -> [usize; 2] {
let vec = x.into_data().value; let vec: Vec<<B as Backend>::FloatElem> = x.into_data().to_vec().unwrap();
assert!(vec.len() == 2, "Tensor length must be 2."); assert!(vec.len() == 2, "Tensor length must be 2.");
[vec[0].to_usize().unwrap(), vec[1].to_usize().unwrap()] [vec[0].to_usize(), vec[1].to_usize()]
} }
pub fn tensor_to_array<const N: usize, B: Backend>(x: Tensor<B, 1>) -> [usize; N] { pub fn tensor_to_array<const N: usize, B: Backend>(x: Tensor<B, 1>) -> [usize; N] {
let vec = x.into_data().value; let vec: Vec<<B as Backend>::FloatElem> = x.into_data().to_vec().unwrap();
assert!(vec.len() == N, "Tensor length must be {}.", N); assert!(vec.len() == N, "Tensor length must be {}.", N);
let mut arr = [0; N]; let mut arr = [0; N];
for (a, t) in arr.iter_mut().zip(vec) { for (a, t) in arr.iter_mut().zip(vec) {
*a = t.to_usize().unwrap(); *a = t.to_usize();
} }
arr arr

View File

@@ -18,7 +18,7 @@ pub fn load_stable_diffusion<B: Backend>(
device: &B::Device, device: &B::Device,
) -> Result<StableDiffusion<B>, Box<dyn Error>> { ) -> Result<StableDiffusion<B>, Box<dyn Error>> {
let n_steps = load_usize::<B>("n_steps", path, device)?; let n_steps = load_usize::<B>("n_steps", path, device)?;
let alpha_cumulative_products = load_tensor::<B, 1>("alphas_cumprod", path, device)?.into(); let alpha_cumulative_products = Param::from_tensor(load_tensor::<B, 1>("alphas_cumprod", path, device)?);
let autoencoder = load_autoencoder(&format!("{}/{}", path, "autoencoder"), device)?; let autoencoder = load_autoencoder(&format!("{}/{}", path, "autoencoder"), device)?;
let diffusion = load_unet(&format!("{}/{}", path, "unet"), device)?; let diffusion = load_unet(&format!("{}/{}", path, "unet"), device)?;
let clip = load_clip(&format!("{}/{}", path, "clip"), device)?; let clip = load_clip(&format!("{}/{}", path, "clip"), device)?;

View File

@@ -4,11 +4,12 @@ use burn::{
config::Config, config::Config,
module::{Module, Param}, module::{Module, Param},
tensor::{backend::Backend, BasicOps, Data, Distribution, Float, Int, Tensor}, tensor::{backend::Backend, BasicOps, Data, Distribution, Float, Int, Tensor},
tensor::cast::ToElement,
}; };
use num_traits::ToPrimitive; use num_traits::ToPrimitive;
use crate::backend::Backend as MyBackend; //use crate::backend::Backend as MyBackend;
use super::autoencoder::{Autoencoder, AutoencoderConfig}; use super::autoencoder::{Autoencoder, AutoencoderConfig};
use super::clip::{CLIPConfig, CLIP}; use super::clip::{CLIPConfig, CLIP};
@@ -19,13 +20,13 @@ use crate::tokenizer::SimpleTokenizer;
pub struct StableDiffusionConfig {} pub struct StableDiffusionConfig {}
impl StableDiffusionConfig { impl StableDiffusionConfig {
pub fn init<B: Backend>(&self) -> StableDiffusion<B> { pub fn init<B: Backend>(&self, device: &B::Device) -> StableDiffusion<B> {
let n_steps = 1000; let n_steps = 1000;
let alpha_cumulative_products = offset_cosine_schedule_cumprod::<B>(n_steps).into(); let alpha_cumulative_products = Param::from_tensor(offset_cosine_schedule_cumprod::<B>(n_steps as i64, device));
let autoencoder = AutoencoderConfig::new().init(); let autoencoder = AutoencoderConfig::new().init(device);
let diffusion = UNetConfig::new().init(); let diffusion = UNetConfig::new().init(device);
let clip = CLIPConfig::new(49408, 768, 12, 77, 12).init(); let clip = CLIPConfig::new(49408, 768, 12, 77, 12).init(device);
StableDiffusion { StableDiffusion {
n_steps, n_steps,
@@ -46,7 +47,7 @@ pub struct StableDiffusion<B: Backend> {
clip: CLIP<B>, clip: CLIP<B>,
} }
impl<B: MyBackend> StableDiffusion<B> { impl<B: Backend> StableDiffusion<B> {
pub fn sample_image( pub fn sample_image(
&self, &self,
context: Tensor<B, 3>, context: Tensor<B, 3>,
@@ -82,7 +83,7 @@ impl<B: MyBackend> StableDiffusion<B> {
.swap_dims(2, 3) .swap_dims(2, 3)
.mul_scalar(255.0); .mul_scalar(255.0);
let flattened: Vec<_> = image.into_data().value; let flattened: Vec<B::FloatElem> = image.into_data().to_vec().unwrap();
(0..n_batch) (0..n_batch)
.into_iter() .into_iter()
@@ -92,7 +93,7 @@ impl<B: MyBackend> StableDiffusion<B> {
flattened[start..end] flattened[start..end]
.into_iter() .into_iter()
.map(|v| v.to_f64().unwrap().min(255.0).max(0.0).to_u8().unwrap()) .map(|v| v.to_f64().min(255.0).max(0.0) as u8)
.collect() .collect()
}) })
.collect() .collect()
@@ -112,8 +113,7 @@ impl<B: MyBackend> StableDiffusion<B> {
let [n_batches, _, _] = context.dims(); let [n_batches, _, _] = context.dims();
let gen_noise = || { let gen_noise = || {
Tensor::random([n_batches, 4, 64, 64], Distribution::Normal(0.0, 1.0)) Tensor::random([n_batches, 4, 64, 64], Distribution::Normal(0.0, 1.0), &device)
.to_device(&device)
}; };
let sigma = 0.0; // Use deterministic diffusion let sigma = 0.0; // Use deterministic diffusion
@@ -126,8 +126,8 @@ impl<B: MyBackend> StableDiffusion<B> {
.val() .val()
.slice([t..t + 1]) .slice([t..t + 1])
.into_scalar() .into_scalar()
.to_f64() .to_f64();
.unwrap();
let prev_alpha: f64 = if t >= step_size { let prev_alpha: f64 = if t >= step_size {
let i = t - step_size; let i = t - step_size;
self.alpha_cumulative_products self.alpha_cumulative_products
@@ -135,14 +135,13 @@ impl<B: MyBackend> StableDiffusion<B> {
.slice([i..i + 1]) .slice([i..i + 1])
.into_scalar() .into_scalar()
.to_f64() .to_f64()
.unwrap()
} else { } else {
1.0 1.0
}; };
let sqrt_noise = (1.0 - current_alpha).sqrt(); let sqrt_noise = (1.0 - current_alpha).sqrt();
let timestep = Tensor::from_ints([t as i32]).to_device(&device); let timestep = Tensor::from_ints([t as i32], &device);
let pred_noise = self.forward_diffuser( let pred_noise = self.forward_diffuser(
latent.clone(), latent.clone(),
timestep, timestep,
@@ -174,7 +173,7 @@ impl<B: MyBackend> StableDiffusion<B> {
let unconditional_latent = self.diffusion.forward( let unconditional_latent = self.diffusion.forward(
latent.clone(), latent.clone(),
timestep.clone(), timestep.clone(),
unconditional_context.unsqueeze().repeat(0, n_batch), unconditional_context.unsqueeze().repeat(&[0, n_batch]),
); );
let conditional_latent = self.diffusion.forward(latent, timestep, context); let conditional_latent = self.diffusion.forward(latent, timestep, context);
@@ -206,8 +205,7 @@ impl<B: MyBackend> StableDiffusion<B> {
.collect(); .collect();
self.clip.forward( self.clip.forward(
Tensor::from_ints(&tokenized[..]) Tensor::<B, 1, Int>::from_ints(&tokenized[..], device)
.to_device(device)
.unsqueeze(), .unsqueeze(),
) )
} }
@@ -215,25 +213,25 @@ impl<B: MyBackend> StableDiffusion<B> {
use std::f64::consts::PI; use std::f64::consts::PI;
fn cosine_schedule<B: Backend>(n_steps: usize) -> Tensor<B, 1> { fn cosine_schedule<B: Backend>(n_steps: i64, device: &B::Device) -> Tensor<B, 1> {
Tensor::arange(1..n_steps + 1) Tensor::arange(1..n_steps + 1, device)
.float() .float()
.mul_scalar(PI * 0.5 / n_steps as f64) .mul_scalar(PI * 0.5 / n_steps as f64)
.cos() .cos()
} }
fn offset_cosine_schedule<B: Backend>(n_steps: usize) -> Tensor<B, 1> { fn offset_cosine_schedule<B: Backend>(n_steps: i64, device: &B::Device) -> Tensor<B, 1> {
let min_signal_rate: f64 = 0.02; let min_signal_rate: f64 = 0.02;
let max_signal_rate: f64 = 0.95; let max_signal_rate: f64 = 0.95;
let start_angle = max_signal_rate.acos(); let start_angle = max_signal_rate.acos();
let end_angle = min_signal_rate.acos(); let end_angle = min_signal_rate.acos();
let times = Tensor::arange(1..n_steps + 1).float(); let times = Tensor::arange(1..n_steps + 1, device).float();
let diffusion_angles = times * ((end_angle - start_angle) / n_steps as f64) + start_angle; let diffusion_angles = times * ((end_angle - start_angle) / n_steps as f64) + start_angle;
diffusion_angles.cos() diffusion_angles.cos()
} }
fn offset_cosine_schedule_cumprod<B: Backend>(n_steps: usize) -> Tensor<B, 1> { fn offset_cosine_schedule_cumprod<B: Backend>(n_steps: i64, device: &B::Device) -> Tensor<B, 1> {
offset_cosine_schedule::<B>(n_steps).powf(2.0) offset_cosine_schedule::<B>(n_steps, device).powf_scalar(2.0)
} }

View File

@@ -65,7 +65,7 @@ pub fn load_geglu<B: Backend>(path: &str, device: &B::Device) -> Result<GEGLU<B>
let geglue = GEGLU { let geglue = GEGLU {
proj: proj, proj: proj,
gelu: GELU::new(), // Assuming GELU::new() initializes a new GELU struct gelu: Gelu::new(), // Assuming Gelu::new() initializes a new Gelu struct
}; };
Ok(geglue) Ok(geglue)

View File

@@ -6,7 +6,7 @@ use burn::{
nn::{ nn::{
self, self,
conv::{Conv2d, Conv2dConfig}, conv::{Conv2d, Conv2dConfig},
PaddingConfig2d, GELU, PaddingConfig2d, Gelu,
}, },
tensor::{activation::softmax, backend::Backend, module::embedding, Distribution, Int, Tensor}, tensor::{activation::softmax, backend::Backend, module::embedding, Distribution, Int, Tensor},
}; };
@@ -22,7 +22,7 @@ fn timestep_embedding<B: Backend>(
max_period: usize, max_period: usize,
) -> Tensor<B, 2> { ) -> Tensor<B, 2> {
let half = dim / 2; let half = dim / 2;
let freqs = (Tensor::arange_device(0..half, &timesteps.device()).float() let freqs = (Tensor::arange(0..half as i64, &timesteps.device()).float()
* (-(max_period as f64).ln() / half as f64)) * (-(max_period as f64).ln() / half as f64))
.exp(); .exp();
let args = timesteps.float() * freqs; let args = timesteps.float() * freqs;
@@ -33,50 +33,50 @@ fn timestep_embedding<B: Backend>(
pub struct UNetConfig {} pub struct UNetConfig {}
impl UNetConfig { impl UNetConfig {
pub fn init<B: Backend>(&self) -> UNet<B> { pub fn init<B: Backend>(&self, device: &B::Device) -> UNet<B> {
let lin1_time_embed = nn::LinearConfig::new(320, 1280).init(); let lin1_time_embed = nn::LinearConfig::new(320, 1280).init(device);
let silu_time_embed = SILU::new(); let silu_time_embed = SILU::new();
let lin2_time_embed = nn::LinearConfig::new(1280, 1280).init(); let lin2_time_embed = nn::LinearConfig::new(1280, 1280).init(device);
let input_blocks = UNetInputBlocks { let input_blocks = UNetInputBlocks {
conv: Conv2dConfig::new([4, 320], [3, 3]) conv: Conv2dConfig::new([4, 320], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1)) .with_padding(PaddingConfig2d::Explicit(1, 1))
.init(), .init(device),
rt1: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(), rt1: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(device),
rt2: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(), rt2: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(device),
d1: DownsampleConfig::new(320).init(), d1: DownsampleConfig::new(320).init(device),
rt3: ResTransformerConfig::new(320, 1280, 640, 768, 8).init(), rt3: ResTransformerConfig::new(320, 1280, 640, 768, 8).init(device),
rt4: ResTransformerConfig::new(640, 1280, 640, 768, 8).init(), rt4: ResTransformerConfig::new(640, 1280, 640, 768, 8).init(device),
d2: DownsampleConfig::new(640).init(), d2: DownsampleConfig::new(640).init(device),
rt5: ResTransformerConfig::new(640, 1280, 1280, 768, 8).init(), rt5: ResTransformerConfig::new(640, 1280, 1280, 768, 8).init(device),
rt6: ResTransformerConfig::new(1280, 1280, 1280, 768, 8).init(), rt6: ResTransformerConfig::new(1280, 1280, 1280, 768, 8).init(device),
d3: DownsampleConfig::new(1280).init(), d3: DownsampleConfig::new(1280).init(device),
r1: ResBlockConfig::new(1280, 1280, 1280).init(), r1: ResBlockConfig::new(1280, 1280, 1280).init(device),
r2: ResBlockConfig::new(1280, 1280, 1280).init(), r2: ResBlockConfig::new(1280, 1280, 1280).init(device),
}; };
let middle_block = ResTransformerResConfig::new(1280, 1280, 1280, 768, 8).init(); let middle_block = ResTransformerResConfig::new(1280, 1280, 1280, 768, 8).init(device);
let output_blocks = UNetOutputBlocks { let output_blocks = UNetOutputBlocks {
r1: ResBlockConfig::new(2560, 1280, 1280).init(), r1: ResBlockConfig::new(2560, 1280, 1280).init(device),
r2: ResBlockConfig::new(2560, 1280, 1280).init(), r2: ResBlockConfig::new(2560, 1280, 1280).init(device),
ru: ResUpSampleConfig::new(2560, 1280, 1280).init(), ru: ResUpSampleConfig::new(2560, 1280, 1280).init(device),
rt1: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(), rt1: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(device),
rt2: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(), rt2: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(device),
rtu1: ResTransformerUpsampleConfig::new(1920, 1280, 1280, 768, 8).init(), rtu1: ResTransformerUpsampleConfig::new(1920, 1280, 1280, 768, 8).init(device),
rt3: ResTransformerConfig::new(1920, 1280, 640, 768, 8).init(), rt3: ResTransformerConfig::new(1920, 1280, 640, 768, 8).init(device),
rt4: ResTransformerConfig::new(1280, 1280, 640, 768, 8).init(), rt4: ResTransformerConfig::new(1280, 1280, 640, 768, 8).init(device),
rtu2: ResTransformerUpsampleConfig::new(960, 1280, 640, 768, 8).init(), rtu2: ResTransformerUpsampleConfig::new(960, 1280, 640, 768, 8).init(device),
rt5: ResTransformerConfig::new(960, 1280, 320, 768, 8).init(), rt5: ResTransformerConfig::new(960, 1280, 320, 768, 8).init(device),
rt6: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(), rt6: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(device),
rt7: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(), rt7: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(device),
}; };
let norm_out = GroupNormConfig::new(32, 320).init(); let norm_out = GroupNormConfig::new(32, 320).init(device);
let silu_out = SILU::new(); let silu_out = SILU::new();
let conv_out = Conv2dConfig::new([320, 4], [3, 3]) let conv_out = Conv2dConfig::new([320, 4], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1)) .with_padding(PaddingConfig2d::Explicit(1, 1))
.init(); .init(device);
UNet { UNet {
lin1_time_embed, lin1_time_embed,
@@ -206,16 +206,16 @@ pub struct ResTransformerConfig {
} }
impl ResTransformerConfig { impl ResTransformerConfig {
fn init<B: Backend>(&self) -> ResTransformer<B> { fn init<B: Backend>(&self, device: &B::Device) -> ResTransformer<B> {
let res = ResBlockConfig::new( let res = ResBlockConfig::new(
self.n_channels_in, self.n_channels_in,
self.n_channels_embed, self.n_channels_embed,
self.n_channels_out, self.n_channels_out,
) )
.init(); .init(device);
let transformer = let transformer =
SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head) SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head)
.init(); .init(device);
ResTransformer { res, transformer } ResTransformer { res, transformer }
} }
@@ -243,14 +243,14 @@ pub struct ResUpSampleConfig {
} }
impl ResUpSampleConfig { impl ResUpSampleConfig {
fn init<B: Backend>(&self) -> ResUpSample<B> { fn init<B: Backend>(&self, device: &B::Device) -> ResUpSample<B> {
let res = ResBlockConfig::new( let res = ResBlockConfig::new(
self.n_channels_in, self.n_channels_in,
self.n_channels_embed, self.n_channels_embed,
self.n_channels_out, self.n_channels_out,
) )
.init(); .init(device);
let upsample = UpsampleConfig::new(self.n_channels_out).init(); let upsample = UpsampleConfig::new(self.n_channels_out).init(device);
ResUpSample { res, upsample } ResUpSample { res, upsample }
} }
@@ -280,17 +280,17 @@ pub struct ResTransformerUpsampleConfig {
} }
impl ResTransformerUpsampleConfig { impl ResTransformerUpsampleConfig {
fn init<B: Backend>(&self) -> ResTransformerUpsample<B> { fn init<B: Backend>(&self, device: &B::Device) -> ResTransformerUpsample<B> {
let res = ResBlockConfig::new( let res = ResBlockConfig::new(
self.n_channels_in, self.n_channels_in,
self.n_channels_embed, self.n_channels_embed,
self.n_channels_out, self.n_channels_out,
) )
.init(); .init(device);
let transformer = let transformer =
SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head) SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head)
.init(); .init(device);
let upsample = UpsampleConfig::new(self.n_channels_out).init(); let upsample = UpsampleConfig::new(self.n_channels_out).init(device);
ResTransformerUpsample { ResTransformerUpsample {
res, res,
@@ -326,22 +326,22 @@ pub struct ResTransformerResConfig {
} }
impl ResTransformerResConfig { impl ResTransformerResConfig {
fn init<B: Backend>(&self) -> ResTransformerRes<B> { fn init<B: Backend>(&self, device: &B::Device) -> ResTransformerRes<B> {
let res1 = ResBlockConfig::new( let res1 = ResBlockConfig::new(
self.n_channels_in, self.n_channels_in,
self.n_channels_embed, self.n_channels_embed,
self.n_channels_out, self.n_channels_out,
) )
.init(); .init(device);
let transformer = let transformer =
SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head) SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head)
.init(); .init(device);
let res2 = ResBlockConfig::new( let res2 = ResBlockConfig::new(
self.n_channels_in, self.n_channels_in,
self.n_channels_embed, self.n_channels_embed,
self.n_channels_out, self.n_channels_out,
) )
.init(); .init(device);
ResTransformerRes { ResTransformerRes {
res1, res1,
@@ -373,10 +373,10 @@ pub struct UpsampleConfig {
} }
impl UpsampleConfig { impl UpsampleConfig {
fn init<B: Backend>(&self) -> Upsample<B> { fn init<B: Backend>(&self, device: &B::Device) -> Upsample<B> {
let conv = Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3]) let conv = Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1)) .with_padding(PaddingConfig2d::Explicit(1, 1))
.init(); .init(device);
Upsample { conv } Upsample { conv }
} }
@@ -392,8 +392,7 @@ impl<B: Backend> Upsample<B> {
let [n_batch, n_channel, height, width] = x.dims(); let [n_batch, n_channel, height, width] = x.dims();
let x = x let x = x
.reshape([n_batch, n_channel, height, 1, width, 1]) .reshape([n_batch, n_channel, height, 1, width, 1])
.repeat(3, 2) .repeat(&[1, 1, 1, 2, 1, 2])
.repeat(5, 2)
.reshape([n_batch, n_channel, 2 * height, 2 * width]); .reshape([n_batch, n_channel, 2 * height, 2 * width]);
self.conv.forward(x) self.conv.forward(x)
} }
@@ -411,11 +410,11 @@ pub struct DownsampleConfig {
} }
impl DownsampleConfig { impl DownsampleConfig {
fn init<B: Backend>(&self) -> Conv2d<B> { fn init<B: Backend>(&self, device: &B::Device) -> Conv2d<B> {
Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3]) Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3])
.with_stride([2, 2]) .with_stride([2, 2])
.with_padding(PaddingConfig2d::Explicit(1, 1)) .with_padding(PaddingConfig2d::Explicit(1, 1))
.init() .init(device)
} }
} }
@@ -435,12 +434,12 @@ pub struct SpatialTransformerConfig {
} }
impl SpatialTransformerConfig { impl SpatialTransformerConfig {
fn init<B: Backend>(&self) -> SpatialTransformer<B> { fn init<B: Backend>(&self, device: &B::Device) -> SpatialTransformer<B> {
let norm = GroupNormConfig::new(32, self.n_channels).init(); let norm = GroupNormConfig::new(32, self.n_channels).init(device);
let proj_in = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(); let proj_in = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(device);
let transformer = let transformer =
TransformerBlockConfig::new(self.n_channels, self.n_context_state, self.n_head).init(); TransformerBlockConfig::new(self.n_channels, self.n_context_state, self.n_head).init(device);
let proj_out = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(); let proj_out = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(device);
SpatialTransformer { SpatialTransformer {
norm, norm,
@@ -489,14 +488,14 @@ pub struct TransformerBlockConfig {
} }
impl TransformerBlockConfig { impl TransformerBlockConfig {
fn init<B: Backend>(&self) -> TransformerBlock<B> { fn init<B: Backend>(&self, device: &B::Device) -> TransformerBlock<B> {
let norm1 = nn::LayerNormConfig::new(self.n_state).init(); let norm1 = nn::LayerNormConfig::new(self.n_state).init(device);
let attn1 = MultiHeadAttentionConfig::new(self.n_state, self.n_state, self.n_head).init(); let attn1 = MultiHeadAttentionConfig::new(self.n_state, self.n_state, self.n_head).init(device);
let norm2 = nn::LayerNormConfig::new(self.n_state).init(); let norm2 = nn::LayerNormConfig::new(self.n_state).init(device);
let attn2 = let attn2 =
MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init(); MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init(device);
let norm3 = nn::LayerNormConfig::new(self.n_state).init(); let norm3 = nn::LayerNormConfig::new(self.n_state).init(device);
let mlp = MLPConfig::new(self.n_state, 4).init(); let mlp = MLPConfig::new(self.n_state, 4).init(device);
TransformerBlock { TransformerBlock {
norm1, norm1,
@@ -534,10 +533,10 @@ pub struct MLPConfig {
} }
impl MLPConfig { impl MLPConfig {
pub fn init<B: Backend>(&self) -> MLP<B> { pub fn init<B: Backend>(&self, device: &B::Device) -> MLP<B> {
let n_state_hidden = self.n_state * self.mult; let n_state_hidden = self.n_state * self.mult;
let geglu = GEGLUConfig::new(self.n_state, n_state_hidden).init(); let geglu = GEGLUConfig::new(self.n_state, n_state_hidden).init(device);
let lin = nn::LinearConfig::new(n_state_hidden, self.n_state).init(); let lin = nn::LinearConfig::new(n_state_hidden, self.n_state).init(device);
MLP { geglu, lin } MLP { geglu, lin }
} }
@@ -562,9 +561,9 @@ pub struct GEGLUConfig {
} }
impl GEGLUConfig { impl GEGLUConfig {
fn init<B: Backend>(&self) -> GEGLU<B> { fn init<B: Backend>(&self, device: &B::Device) -> GEGLU<B> {
let proj = nn::LinearConfig::new(self.n_state_in, 2 * self.n_state_out).init(); let proj = nn::LinearConfig::new(self.n_state_in, 2 * self.n_state_out).init(device);
let gelu = GELU::new(); let gelu = Gelu::new();
GEGLU { proj, gelu } GEGLU { proj, gelu }
} }
@@ -573,7 +572,7 @@ impl GEGLUConfig {
#[derive(Module, Debug)] #[derive(Module, Debug)]
pub struct GEGLU<B: Backend> { pub struct GEGLU<B: Backend> {
proj: nn::Linear<B>, proj: nn::Linear<B>,
gelu: GELU, gelu: Gelu,
} }
impl<B: Backend> GEGLU<B> { impl<B: Backend> GEGLU<B> {
@@ -600,7 +599,7 @@ pub struct MultiHeadAttentionConfig {
} }
impl MultiHeadAttentionConfig { impl MultiHeadAttentionConfig {
fn init<B: Backend>(&self) -> MultiHeadAttention<B> { fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadAttention<B> {
assert!( assert!(
self.n_state % self.n_head == 0, self.n_state % self.n_head == 0,
"State size {} must be a multiple of head size {}", "State size {} must be a multiple of head size {}",
@@ -611,14 +610,14 @@ impl MultiHeadAttentionConfig {
let n_head = self.n_head; let n_head = self.n_head;
let query = nn::LinearConfig::new(self.n_state, self.n_state) let query = nn::LinearConfig::new(self.n_state, self.n_state)
.with_bias(false) .with_bias(false)
.init(); .init(device);
let key = nn::LinearConfig::new(self.n_context_state, self.n_state) let key = nn::LinearConfig::new(self.n_context_state, self.n_state)
.with_bias(false) .with_bias(false)
.init(); .init(device);
let value = nn::LinearConfig::new(self.n_context_state, self.n_state) let value = nn::LinearConfig::new(self.n_context_state, self.n_state)
.with_bias(false) .with_bias(false)
.init(); .init(device);
let out = nn::LinearConfig::new(self.n_state, self.n_state).init(); let out = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
MultiHeadAttention { MultiHeadAttention {
n_head, n_head,
@@ -661,24 +660,24 @@ pub struct ResBlockConfig {
} }
impl ResBlockConfig { impl ResBlockConfig {
fn init<B: Backend>(&self) -> ResBlock<B> { fn init<B: Backend>(&self, device: &B::Device) -> ResBlock<B> {
let norm_in = GroupNormConfig::new(32, self.n_channels_in).init(); let norm_in = GroupNormConfig::new(32, self.n_channels_in).init(device);
let silu_in = SILU::new(); let silu_in = SILU::new();
let conv_in = Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [3, 3]) let conv_in = Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1)) .with_padding(PaddingConfig2d::Explicit(1, 1))
.init(); .init(device);
let silu_embed = SILU::new(); let silu_embed = SILU::new();
let lin_embed = nn::LinearConfig::new(self.n_channels_embed, self.n_channels_out).init(); let lin_embed = nn::LinearConfig::new(self.n_channels_embed, self.n_channels_out).init(device);
let norm_out = GroupNormConfig::new(32, self.n_channels_out).init(); let norm_out = GroupNormConfig::new(32, self.n_channels_out).init(device);
let silu_out = SILU::new(); let silu_out = SILU::new();
let conv_out = Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3]) let conv_out = Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1)) .with_padding(PaddingConfig2d::Explicit(1, 1))
.init(); .init(device);
let skip_connection = if self.n_channels_in != self.n_channels_out { let skip_connection = if self.n_channels_in != self.n_channels_out {
Some(Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [1, 1]).init()) Some(Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [1, 1]).init(device))
} else { } else {
None None
}; };