From 75f0cedd9fb60e40e412f43eed8cbd30a6e028bc Mon Sep 17 00:00:00 2001 From: Hermes Date: Sat, 5 Oct 2024 14:19:49 -0400 Subject: [PATCH] Update to burn v0.14.0 and switch to .mpk model file --- Cargo.toml | 10 +- README.md | 21 ++-- python/dump.py | 7 +- python/requirements.txt | 1 + src/backend.rs | 43 ++++---- src/bin/convert/main.rs | 8 +- src/bin/sample/main.rs | 58 +++++++---- src/model/attention.rs | 6 +- src/model/autoencoder/load.rs | 15 +-- src/model/autoencoder/mod.rs | 137 ++++++++++++++----------- src/model/clip/load.rs | 2 +- src/model/clip/mod.rs | 60 ++++++----- src/model/groupnorm/load.rs | 12 +-- src/model/groupnorm/mod.rs | 6 +- src/model/load.rs | 76 ++++++-------- src/model/stablediffusion/load.rs | 2 +- src/model/stablediffusion/mod.rs | 46 ++++----- src/model/unet/load.rs | 2 +- src/model/unet/mod.rs | 165 +++++++++++++++--------------- 19 files changed, 366 insertions(+), 311 deletions(-) create mode 100644 python/requirements.txt diff --git a/Cargo.toml b/Cargo.toml index 3737c0e..00c23aa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,11 +14,11 @@ git = "https://github.com/burn-rs/burn.git" optional = true [dependencies] -burn = { git = "https://github.com/burn-rs/burn.git" } -burn-ndarray = { package = "burn-ndarray", git = "https://github.com/burn-rs/burn.git" } -burn-tch = { package = "burn-tch", git = "https://github.com/burn-rs/burn.git" } -burn-autodiff = { package = "burn-autodiff", git = "https://github.com/burn-rs/burn.git" } -tch = "0.13.0" +burn = "0.14.0" +burn-ndarray = "0.14.0" +burn-tch = "0.14.0" +burn-autodiff = "0.14.0" +tch = "0.15.0" serde = {version = "1.0.171", features = ["std", "derive"]} npy = "0.4.0" num-traits = "0.2.15" diff --git a/README.md b/README.md index fa564ee..e92c52c 100644 --- a/README.md +++ b/README.md @@ -4,12 +4,14 @@ Stable-Diffusion-Burn is a Rust-based project which ports the V1 stable diffusio ## How To Use +### Step 0: Install libtorch v2.4.1 + ### Step 1: Download the Model and Set Environment Variables -Start by downloading the SDv1-4.bin model provided on HuggingFace. +Start by downloading the SDv1-4 model provided on HuggingFace. ```bash -wget https://huggingface.co/Gadersd/Stable-Diffusion-Burn/resolve/main/V1/SDv1-4.bin +wget https://huggingface.co/Gadersd/Stable-Diffusion-Burn/resolve/main/SDv1-4.mpk ``` ### Step 2: Run the Sample Binary @@ -18,9 +20,13 @@ Invoke the sample binary provided in the rust code. By default, torch is used. T ```bash # torch (at least 6 GB VRAM, possibly less) -export TORCH_CUDA_VERSION=cu113 -# Arguments: -cargo run --release --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img +# Arguments: [cuda, mps, cpu] + +# Cuda +cargo run --release --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img cuda + +# Mps(Mac) +cargo run --release --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img mps # wgpu (UNSTABLE) # Arguments: @@ -33,7 +39,7 @@ This command will generate an image according to the provided prompt, which will ### Optional: Extract and Convert a Fine-Tuned Model -If users are interested in using a fine-tuned version of stable diffusion, the Python scripts provided in this project can be used to transform a weight dump into a Burn model file. Note: the tinygrad dependency should be installed from source rather than with pip. +If users are interested in using a fine-tuned version of stable diffusion, the Python scripts provided in this project can be used to transform a weight dump into a Burn model file. This does not work on Windows. ```bash # Step into the Python directory @@ -42,6 +48,9 @@ cd python # Download the model, this is just the base v1.4 model as an example wget https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt +# Install tinygrad +pip install -r requirements.txt + # Extract the weights CPU=1 python3 dump.py sd-v1-4.ckpt diff --git a/python/dump.py b/python/dump.py index f604b35..66a994a 100644 --- a/python/dump.py +++ b/python/dump.py @@ -13,10 +13,11 @@ from collections import namedtuple from tqdm import tqdm from tinygrad.tensor import Tensor -from tinygrad.helpers import dtypes, GlobalCounters +from tinygrad.helpers import GlobalCounters +from tinygrad import dtypes from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding -from extra.utils import download_file -from tinygrad.state import torch_load, load_state_dict +#from extra.utils import download_file +from tinygrad.nn.state import torch_load, load_state_dict # TODO: refactor AttnBlock, CrossAttention, CLIPAttention to share code diff --git a/python/requirements.txt b/python/requirements.txt new file mode 100644 index 0000000..d1c7ebd --- /dev/null +++ b/python/requirements.txt @@ -0,0 +1 @@ +tinygrad==0.9.2 \ No newline at end of file diff --git a/src/backend.rs b/src/backend.rs index c33710d..ff41e40 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -1,13 +1,16 @@ use burn::tensor::{activation::softmax, Tensor}; +use burn::prelude::Backend; + +/*pub type FloatTensor = ::TensorPrimitive; pub trait Backend: burn::tensor::backend::Backend { fn qkv_attention( - q: Self::TensorPrimitive<3>, - k: Self::TensorPrimitive<3>, - v: Self::TensorPrimitive<3>, - mask: Option>, + q: FloatTensor, + k: FloatTensor, + v: FloatTensor, + mask: Option>, n_head: usize, - ) -> Self::TensorPrimitive<3> { + ) -> FloatTensor { qkv_attention( Tensor::::from_primitive(q), Tensor::from_primitive(k), @@ -18,24 +21,23 @@ pub trait Backend: burn::tensor::backend::Backend { .into_primitive() } - fn attn_decoder_mask(seq_length: usize, device: &Self::Device) -> Self::TensorPrimitive<2> { + fn attn_decoder_mask(seq_length: usize, device: &Self::Device) -> FloatTensor { attn_decoder_mask::(seq_length, device).into_primitive() } } -use burn::tensor::ops::TensorOps; use burn::tensor::Float; use burn_tch::{self, TchElement, TchTensor}; use tch; -impl Backend for burn_tch::TchBackend { +impl Backend for burn_tch::LibTorch { fn qkv_attention( - q: Self::TensorPrimitive<3>, - k: Self::TensorPrimitive<3>, - v: Self::TensorPrimitive<3>, - mask: Option>, + q: FloatTensor, + k: FloatTensor, + v: FloatTensor, + mask: Option>, n_head: usize, - ) -> Self::TensorPrimitive<3> { + ) -> FloatTensor { let q = Tensor::from_primitive(q); let k = Tensor::from_primitive(k); let v = Tensor::from_primitive(v); @@ -56,7 +58,7 @@ impl Backend for burn_tch::TchBackend { // for some reason torch crashes when mask is None let mask = mask.unwrap_or_else(|| { - Tensor::::zeros_device([q_ctx, k_ctx], &Self::device(&v)) + Tensor::::zeros([q_ctx, k_ctx], &Self::device(&v)) .into_primitive() }); @@ -68,6 +70,7 @@ impl Backend for burn_tch::TchBackend { Some(mask.tensor), 0.0, false, + None, ), )) .swap_dims(1, 2) @@ -78,11 +81,11 @@ impl Backend for burn_tch::TchBackend { use burn_autodiff; -impl Backend for burn_autodiff::ADBackendDecorator {} +impl Backend for burn_autodiff::Autodiff {}*/ use std::f32::NEG_INFINITY; -fn qkv_attention( +pub fn qkv_attention( q: Tensor, k: Tensor, v: Tensor, @@ -124,13 +127,13 @@ fn qkv_attention( return o; } -fn attn_decoder_mask(seq_length: usize, device: &B::Device) -> Tensor { - let mut mask = Tensor::::zeros([seq_length, seq_length]); +pub fn attn_decoder_mask(seq_length: usize, device: &B::Device) -> Tensor { + let mut mask = Tensor::::zeros([seq_length, seq_length], device); for i in 0..(seq_length - 1) { - let values = Tensor::::zeros([1, seq_length - (i + 1)]).add_scalar(NEG_INFINITY); + let values = Tensor::::zeros([1, seq_length - (i + 1)], device).add_scalar(NEG_INFINITY); mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values); } - return mask.to_device(device); + return mask; } diff --git a/src/bin/convert/main.rs b/src/bin/convert/main.rs index f6b9697..b133e33 100644 --- a/src/bin/convert/main.rs +++ b/src/bin/convert/main.rs @@ -11,9 +11,9 @@ use burn::{ tensor::{backend::Backend, Tensor}, }; -use burn_ndarray::{NdArrayBackend, NdArrayDevice}; +use burn_ndarray::{NdArray, NdArrayDevice}; -use burn::record::{self, BinFileRecorder, FullPrecisionSettings, Recorder}; +use burn::record::{self, NamedMpkFileRecorder, FullPrecisionSettings, Recorder}; fn convert_dump_to_model( dump_path: &str, @@ -33,11 +33,11 @@ fn save_model_file( model: StableDiffusion, name: &str, ) -> Result<(), record::RecorderError> { - BinFileRecorder::::new().record(model.into_record(), name.into()) + NamedMpkFileRecorder::::new().record(model.into_record(), name.into()) } fn main() { - type Backend = NdArrayBackend; + type Backend = NdArray; let device = NdArrayDevice::Cpu; let args: Vec = env::args().collect(); diff --git a/src/bin/sample/main.rs b/src/bin/sample/main.rs index 214dc5c..97cc93f 100644 --- a/src/bin/sample/main.rs +++ b/src/bin/sample/main.rs @@ -14,7 +14,7 @@ cfg_if::cfg_if! { if #[cfg(feature = "wgpu-backend")] { use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi}; } else { - use burn_tch::{TchBackend, TchDevice}; + use burn_tch::{LibTorch, LibTorchDevice}; } } @@ -22,30 +22,21 @@ use std::env; use std::io; use std::process; -use burn::record::{self, BinFileRecorder, FullPrecisionSettings, Recorder}; +use burn::record::{self, NamedMpkFileRecorder, FullPrecisionSettings, Recorder}; fn load_stable_diffusion_model_file( filename: &str, + device: &B::Device, ) -> Result, record::RecorderError> { - BinFileRecorder::::new() - .load(filename.into()) - .map(|record| StableDiffusionConfig::new().init().load_record(record)) + NamedMpkFileRecorder::::new() + .load(filename.into(), device) + .map(|record| StableDiffusionConfig::new().init(device).load_record(record)) } fn main() { - cfg_if::cfg_if! { - if #[cfg(feature = "wgpu-backend")] { - type Backend = WgpuBackend; - let device = WgpuDevice::BestAvailable; - } else { - type Backend = TchBackend; - let device = TchDevice::Cuda(0); - } - } - let args: Vec = std::env::args().collect(); - if args.len() != 7 { - eprintln!("Usage: {} ", args[0]); + if args.len() != 7 && args.len() != 8 { + eprintln!("Usage: {} [device(cuda, mps, cpu)]", args[0]); process::exit(1); } @@ -62,11 +53,40 @@ fn main() { let prompt = &args[5]; let output_image_name = &args[6]; + // Optional device parameter + let device_arg = if args.len() == 8 { Some(&args[7]) } else { None }; + + cfg_if::cfg_if! { + if #[cfg(feature = "wgpu-backend")] { + type Backend = WgpuBackend; + let device = WgpuDevice::BestAvailable; + } else { + type Backend = LibTorch; + + let device = if let Some(dev_str) = device_arg { + match dev_str.to_lowercase().as_str() { + "cpu" => LibTorchDevice::Cpu, + "mps" => LibTorchDevice::Mps, + s if s.starts_with("cuda") => { + let idx = s[4..].parse().unwrap_or(0); + LibTorchDevice::Cuda(idx) + } + _ => { + eprintln!("Unknown device: {}", dev_str); + process::exit(1); + } + } + } else { + LibTorchDevice::Cuda(0) + }; + } + } + println!("Loading tokenizer..."); let tokenizer = SimpleTokenizer::new().unwrap(); println!("Loading model..."); let sd: StableDiffusion = if model_type == "burn" { - load_stable_diffusion_model_file(model_name).unwrap_or_else(|err| { + load_stable_diffusion_model_file(model_name, &device).unwrap_or_else(|err| { eprintln!("Error loading model: {}", err); process::exit(1); }) @@ -77,8 +97,6 @@ fn main() { }) }; - let sd = sd.to_device(&device); - let unconditional_context = sd.unconditional_context(&tokenizer); let context = sd.context(&tokenizer, prompt).unsqueeze::<3>(); //.repeat(0, 2); // generate 2 samples diff --git a/src/model/attention.rs b/src/model/attention.rs index 158afee..f9d73c5 100644 --- a/src/model/attention.rs +++ b/src/model/attention.rs @@ -45,12 +45,12 @@ pub fn qkv_attention( } pub fn attn_decoder_mask(seq_length: usize, device: &B::Device) -> Tensor { - let mut mask = Tensor::::zeros([seq_length, seq_length]); + let mut mask = Tensor::::zeros([seq_length, seq_length], device); for i in 0..(seq_length - 1) { - let values = Tensor::::zeros([1, seq_length - (i + 1)]).add_scalar(NEG_INFINITY); + let values = Tensor::::zeros([1, seq_length - (i + 1)], device).add_scalar(NEG_INFINITY); mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values); } - return mask.to_device(device); + return mask; } diff --git a/src/model/autoencoder/load.rs b/src/model/autoencoder/load.rs index d06aa1e..2ae3eaa 100644 --- a/src/model/autoencoder/load.rs +++ b/src/model/autoencoder/load.rs @@ -71,7 +71,7 @@ fn load_padded_conv2d( path: &str, device: &B::Device, ) -> Result, Box> { - let conv = load_conv2d(&format!("{}/{}", path, "conv"), device)?; + let mut conv = load_conv2d(&format!("{}/{}", path, "conv"), device)?; let channels = load_tensor::("channels", path, device)?; let channels = tensor_to_array_2(channels); @@ -81,18 +81,21 @@ fn load_padded_conv2d( let padding = load_tensor::("padding", path, device)?; let padding: [usize; 4] = tensor_to_array(padding); - let padding = Padding::new(padding[0], padding[1], padding[2], padding[3]); + let padding = PaddingCfg::new(padding[0], padding[1], padding[2], padding[3]); - let mut record = conv.into_record(); + //let mut record = conv.into_record(); let mut padded_conv: PaddedConv2d = PaddedConv2dConfig::new(channels, kernel_size, padding) .with_stride(stride) - .init(); + .init(device); let padding_actual = PaddingConfig2d::Explicit(padded_conv.padding_actual[0], padded_conv.padding_actual[1]); - record.padding = >::into_record(padding_actual); - padded_conv.conv = padded_conv.conv.load_record(record); + conv.padding = burn::module::Ignored(padding_actual); + padded_conv.conv = conv; + + //record.padding = >::into_record(padding_actual); + //padded_conv.conv = padded_conv.conv.load_record(record); Ok(padded_conv) } diff --git a/src/model/autoencoder/mod.rs b/src/model/autoencoder/mod.rs index 8a7a58b..9c40774 100644 --- a/src/model/autoencoder/mod.rs +++ b/src/model/autoencoder/mod.rs @@ -18,7 +18,8 @@ use burn::{ use super::groupnorm::*; use super::silu::*; -use crate::backend::Backend as MyBackend; +//use crate::backend::Backend as MyBackend; +use crate::backend::{qkv_attention, attn_decoder_mask}; use std::iter; @@ -26,13 +27,13 @@ use std::iter; pub struct AutoencoderConfig {} impl AutoencoderConfig { - pub fn init(&self) -> Autoencoder { + pub fn init(&self, device: &B::Device) -> Autoencoder { let encoder = - EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init(); + EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init(device); let decoder = - DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init(); - let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init(); - let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init(); + DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init(device); + let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init(device); + let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init(device); Autoencoder { encoder, @@ -51,7 +52,7 @@ pub struct Autoencoder { post_quant_conv: Conv2d, } -impl Autoencoder { +impl Autoencoder { pub fn forward(&self, x: Tensor) -> Tensor { self.decode_latent(self.encode_image(x)) } @@ -78,7 +79,7 @@ pub struct EncoderConfig { } impl EncoderConfig { - fn init(&self) -> Encoder { + fn init(&self, device: &B::Device) -> Encoder { let n_expanded_channels_initial = self .channels .first() @@ -88,7 +89,7 @@ impl EncoderConfig { let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3]) .with_padding(PaddingConfig2d::Explicit(1, 1)) - .init(); + .init(device); let blocks = self .channels @@ -96,16 +97,16 @@ impl EncoderConfig { .enumerate() .map(|(i, &(n_channel_in, n_channel_out))| { let downsample = i != self.channels.len() - 1; - EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init() + EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init(device) }) .collect(); - let mid = MidConfig::new(n_expanded_channels_final).init(); - let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init(); + let mid = MidConfig::new(n_expanded_channels_final).init(device); + let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init(device); let silu = SILU::new(); let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3]) .with_padding(PaddingConfig2d::Explicit(1, 1)) - .init(); + .init(device); Encoder { conv_in, @@ -128,7 +129,7 @@ pub struct Encoder { conv_out: Conv2d, } -impl Encoder { +impl Encoder { fn forward(&self, x: Tensor) -> Tensor { let x = self.conv_in.forward(x); @@ -150,7 +151,7 @@ pub struct DecoderConfig { } impl DecoderConfig { - fn init(&self) -> Decoder { + fn init(&self, device: &B::Device) -> Decoder { let n_expanded_channels = self .channels .first() @@ -160,8 +161,8 @@ impl DecoderConfig { let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3]) .with_padding(PaddingConfig2d::Explicit(1, 1)) - .init(); - let mid = MidConfig::new(n_expanded_channels).init(); + .init(device); + let mid = MidConfig::new(n_expanded_channels).init(device); let blocks = self .channels @@ -169,15 +170,15 @@ impl DecoderConfig { .enumerate() .map(|(i, &(n_channel_in, n_channel_out))| { let upsample = i != self.channels.len() - 1; - DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init() + DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init(device) }) .collect(); - let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init(); + let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init(device); let silu = SILU::new(); let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3]) .with_padding(PaddingConfig2d::Explicit(1, 1)) - .init(); + .init(device); Decoder { conv_in, @@ -200,7 +201,7 @@ pub struct Decoder { conv_out: Conv2d, } -impl Decoder { +impl Decoder { fn forward(&self, x: Tensor) -> Tensor { let x = self.conv_in.forward(x); let x = self.mid.forward(x); @@ -223,15 +224,15 @@ pub struct EncoderBlockConfig { } impl EncoderBlockConfig { - fn init(&self) -> EncoderBlock { - let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(); - let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(); + fn init(&self, device: &B::Device) -> EncoderBlock { + let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(device); + let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device); let downsampler = if self.downsample { - let padding = Padding::new(0, 1, 0, 1); + let padding = PaddingCfg::new(0, 1, 0, 1); Some( PaddedConv2dConfig::new([self.n_channels_out, self.n_channels_out], 3, padding) .with_stride(2) - .init(), + .init(device), ) } else { None @@ -272,15 +273,15 @@ pub struct DecoderBlockConfig { } impl DecoderBlockConfig { - fn init(&self) -> DecoderBlock { - let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(); - let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(); - let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(); + fn init(&self, device: &B::Device) -> DecoderBlock { + let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(device); + let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device); + let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device); let upsampler = if self.upsample { Some( Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3]) .with_padding(PaddingConfig2d::Explicit(1, 1)) - .init(), + .init(device), ) } else { None @@ -313,8 +314,7 @@ impl DecoderBlock { let [n_batch, n_channel, height, width] = x.dims(); let x = x .reshape([n_batch, n_channel, height, 1, width, 1]) - .repeat(3, 2) - .repeat(5, 2) + .repeat(&[1, 1, 1, 2, 1, 2]) .reshape([n_batch, n_channel, 2 * height, 2 * width]); d.forward(x) } else { @@ -329,11 +329,11 @@ pub struct PaddedConv2dConfig { kernel_size: usize, #[config(default = 1)] stride: usize, - padding: Padding, + padding: PaddingCfg, } impl PaddedConv2dConfig { - fn init(&self) -> PaddedConv2d { + fn init(&self, device: &B::Device) -> PaddedConv2d { let calc_padding = |p_left, p_right| { let n = if p_left >= p_right { 0 @@ -351,12 +351,17 @@ impl PaddedConv2dConfig { let conv = Conv2dConfig::new(self.channels, [self.kernel_size, self.kernel_size]) .with_stride([self.stride, self.stride]) .with_padding(PaddingConfig2d::Explicit(pad_vertical, pad_horizontal)) - .init(); + .init(device); let kernel_size = self.kernel_size; let stride = self.stride; - let padding = self.padding; + let padding = Padding { + pad_left: self.padding.pad_left, + pad_right: self.padding.pad_right, + pad_top: self.padding.pad_top, + pad_bottom: self.padding.pad_bottom, + }; PaddedConv2d { conv, @@ -406,7 +411,15 @@ impl PaddedConv2d { } } -#[derive(Config, Module, Copy, Debug)] +#[derive(Config, Debug)] +pub struct PaddingCfg { + pad_left: usize, + pad_right: usize, + pad_top: usize, + pad_bottom: usize, +} + +#[derive(Module, Clone, Debug)] pub struct Padding { pad_left: usize, pad_right: usize, @@ -420,10 +433,10 @@ pub struct MidConfig { } impl MidConfig { - fn init(&self) -> Mid { - let block_1 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(); - let attn = ConvSelfAttentionBlockConfig::new(self.n_channel).init(); - let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(); + fn init(&self, device: &B::Device) -> Mid { + let block_1 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(device); + let attn = ConvSelfAttentionBlockConfig::new(self.n_channel).init(device); + let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(device); Mid { block_1, @@ -440,7 +453,7 @@ pub struct Mid { block_2: ResnetBlock, } -impl Mid { +impl Mid { fn forward(&self, x: Tensor) -> Tensor { let x = self.block_1.forward(x); let x = self.attn.forward(x); @@ -456,17 +469,17 @@ pub struct ResnetBlockConfig { } impl ResnetBlockConfig { - fn init(&self) -> ResnetBlock { - let norm1 = GroupNormConfig::new(32, self.in_channels).init(); + fn init(&self, device: &B::Device) -> ResnetBlock { + let norm1 = GroupNormConfig::new(32, self.in_channels).init(device); let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3]) .with_padding(PaddingConfig2d::Explicit(1, 1)) - .init(); - let norm2 = GroupNormConfig::new(32, self.out_channels).init(); + .init(device); + let norm2 = GroupNormConfig::new(32, self.out_channels).init(device); let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3]) .with_padding(PaddingConfig2d::Explicit(1, 1)) - .init(); + .init(device); let nin_shortcut = if self.in_channels != self.out_channels { - Some(Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init()) + Some(Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init(device)) } else { None }; @@ -520,12 +533,12 @@ pub struct ConvSelfAttentionBlockConfig { } impl ConvSelfAttentionBlockConfig { - fn init(&self) -> ConvSelfAttentionBlock { - let norm = GroupNormConfig::new(32, self.n_channel).init(); - let q = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(); - let k = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(); - let v = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(); - let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(); + fn init(&self, device: &B::Device) -> ConvSelfAttentionBlock { + let norm = GroupNormConfig::new(32, self.n_channel).init(device); + let q = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device); + let k = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device); + let v = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device); + let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device); ConvSelfAttentionBlock { norm, @@ -546,7 +559,7 @@ pub struct ConvSelfAttentionBlock { proj_out: Conv2d, } -impl ConvSelfAttentionBlock { +impl ConvSelfAttentionBlock { fn forward(&self, x: Tensor) -> Tensor { let [n_batch, n_channel, height, width] = x.dims(); @@ -568,7 +581,7 @@ impl ConvSelfAttentionBlock { .reshape([n_batch, n_channel, height * width]) .swap_dims(1, 2); - let wv = Tensor::from_primitive(B::qkv_attention( + /*let wv = Tensor::from_primitive(B::qkv_attention( q.into_primitive(), k.into_primitive(), v.into_primitive(), @@ -576,6 +589,16 @@ impl ConvSelfAttentionBlock { 1, )) .swap_dims(1, 2) + .reshape([n_batch, n_channel, height, width]);*/ + + let wv = qkv_attention( + q, + k, + v, + None, + 1, + ) + .swap_dims(1, 2) .reshape([n_batch, n_channel, height, width]); let projected = self.proj_out.forward(wv); diff --git a/src/model/clip/load.rs b/src/model/clip/load.rs index d8f20d7..d917f93 100644 --- a/src/model/clip/load.rs +++ b/src/model/clip/load.rs @@ -68,7 +68,7 @@ pub fn load_residual_decoder_attention_block( pub fn load_clip(path: &str, device: &B::Device) -> Result, Box> { let token_embedding = load_embedding(&format!("{}/{}", path, "token_embedding"), device)?; let position_embedding = - load_tensor("weight", &format!("{}/position_embedding", path), device)?.into(); + Param::from_tensor(load_tensor("weight", &format!("{}/position_embedding", path), device)?); let n_layer = load_usize::("n_layer", path, device)?; let mut blocks = (0..n_layer) diff --git a/src/model/clip/mod.rs b/src/model/clip/mod.rs index 583e3ef..a8256f4 100644 --- a/src/model/clip/mod.rs +++ b/src/model/clip/mod.rs @@ -12,7 +12,8 @@ use burn::{ }, }; -use crate::backend::Backend as MyBackend; +//use crate::backend::Backend as MyBackend; +use crate::backend::{qkv_attention, attn_decoder_mask}; #[derive(Config)] pub struct CLIPConfig { @@ -24,15 +25,15 @@ pub struct CLIPConfig { } impl CLIPConfig { - pub fn init(&self) -> CLIP { - let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init(); + pub fn init(&self, device: &B::Device) -> CLIP { + let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init(device); let position_embedding = - Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0)).into(); + Param::from_tensor(Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0), device)); let blocks = (0..self.n_layer) .into_iter() - .map(|_| ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head).init()) + .map(|_| ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head).init(device)) .collect(); - let layer_norm = nn::LayerNormConfig::new(self.n_state).init(); + let layer_norm = nn::LayerNormConfig::new(self.n_state).init(device); CLIP { token_embedding, @@ -51,11 +52,12 @@ pub struct CLIP { layer_norm: nn::LayerNorm, } -impl CLIP { +impl CLIP { pub fn forward(&self, x: Tensor) -> Tensor { let [n_batch, seq_len] = x.dims(); - let mask = Tensor::from_primitive(B::attn_decoder_mask(seq_len, &x.device())); + //let mask = Tensor::from_primitive(B::attn_decoder_mask(seq_len, &x.device())); + let mask = attn_decoder_mask(seq_len, &x.device()); let embedded = self.token_embedding.forward(x) + self @@ -80,12 +82,12 @@ pub struct ResidualDecoderAttentionBlockConfig { } impl ResidualDecoderAttentionBlockConfig { - pub fn init(&self) -> ResidualDecoderAttentionBlock { - let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init(); - let attn_ln = nn::LayerNormConfig::new(self.n_state).init(); + pub fn init(&self, device: &B::Device) -> ResidualDecoderAttentionBlock { + let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init(device); + let attn_ln = nn::LayerNormConfig::new(self.n_state).init(device); - let mlp = MLPConfig::new(self.n_state, 4 * self.n_state).init(); - let mlp_ln = nn::LayerNormConfig::new(self.n_state).init(); + let mlp = MLPConfig::new(self.n_state, 4 * self.n_state).init(device); + let mlp_ln = nn::LayerNormConfig::new(self.n_state).init(device); ResidualDecoderAttentionBlock { attn, @@ -104,7 +106,7 @@ pub struct ResidualDecoderAttentionBlock { mlp_ln: nn::LayerNorm, } -impl ResidualDecoderAttentionBlock { +impl ResidualDecoderAttentionBlock { fn forward(&self, x: Tensor, mask: Tensor) -> Tensor { let x = x.clone() + self.attn.forward(self.attn_ln.forward(x), Some(mask)); let x = x.clone() + self.mlp.forward(self.mlp_ln.forward(x)); @@ -119,7 +121,7 @@ pub struct MultiHeadSelfAttentionConfig { } impl MultiHeadSelfAttentionConfig { - fn init(&self) -> MultiHeadSelfAttention { + fn init(&self, device: &B::Device) -> MultiHeadSelfAttention { assert!( self.n_state % self.n_head == 0, "State size {} must be a multiple of head size {}", @@ -128,10 +130,10 @@ impl MultiHeadSelfAttentionConfig { ); let n_head = self.n_head; - let query = nn::LinearConfig::new(self.n_state, self.n_state).init(); - let key = nn::LinearConfig::new(self.n_state, self.n_state).init(); - let value = nn::LinearConfig::new(self.n_state, self.n_state).init(); - let out = nn::LinearConfig::new(self.n_state, self.n_state).init(); + let query = nn::LinearConfig::new(self.n_state, self.n_state).init(device); + let key = nn::LinearConfig::new(self.n_state, self.n_state).init(device); + let value = nn::LinearConfig::new(self.n_state, self.n_state).init(device); + let out = nn::LinearConfig::new(self.n_state, self.n_state).init(device); MultiHeadSelfAttention { n_head, @@ -152,19 +154,27 @@ pub struct MultiHeadSelfAttention { out: nn::Linear, } -impl MultiHeadSelfAttention { +impl MultiHeadSelfAttention { pub fn forward(&self, x: Tensor, mask: Option>) -> Tensor { let q = self.query.forward(x.clone()); let k = self.key.forward(x.clone()); let v = self.value.forward(x); - let wv = Tensor::from_primitive(B::qkv_attention( + /*let wv = Tensor::from_primitive(B::qkv_attention( q.into_primitive(), k.into_primitive(), v.into_primitive(), mask.map(|m| m.into_primitive()), self.n_head, - )); + ));*/ + + let wv = qkv_attention( + q, + k, + v, + mask, + self.n_head, + ); return self.out.forward(wv); } @@ -177,10 +187,10 @@ pub struct MLPConfig { } impl MLPConfig { - fn init(&self) -> MLP { - let fc1 = nn::LinearConfig::new(self.input_size, self.hidden_size).init(); + fn init(&self, device: &B::Device) -> MLP { + let fc1 = nn::LinearConfig::new(self.input_size, self.hidden_size).init(device); let gelu = QuickGELU::new(); - let fc2 = nn::LinearConfig::new(self.hidden_size, self.input_size).init(); + let fc2 = nn::LinearConfig::new(self.hidden_size, self.input_size).init(device); MLP { fc1, gelu, fc2 } } diff --git a/src/model/groupnorm/load.rs b/src/model/groupnorm/load.rs index a57d9e8..39166e1 100644 --- a/src/model/groupnorm/load.rs +++ b/src/model/groupnorm/load.rs @@ -18,14 +18,14 @@ pub fn load_group_norm( let n_channel = load_usize::("n_channel", path, device)?.into(); let eps = load_f32::("eps", path, device)?.into(); - let gamma = load_tensor::("weight", path, device) + let gamma = Param::from_tensor(load_tensor::("weight", path, device) .ok() - .unwrap_or_else(|| Tensor::ones_device([n_channel], device)) - .into(); - let beta = load_tensor::("bias", path, device) + .unwrap_or_else(|| Tensor::ones([n_channel], device)) + ); + let beta = Param::from_tensor(load_tensor::("bias", path, device) .ok() - .unwrap_or_else(|| Tensor::zeros_device([n_channel], device)) - .into(); + .unwrap_or_else(|| Tensor::zeros([n_channel], device)) + ); Ok(GroupNorm { n_group, diff --git a/src/model/groupnorm/mod.rs b/src/model/groupnorm/mod.rs index 8d2947d..303e5b8 100644 --- a/src/model/groupnorm/mod.rs +++ b/src/model/groupnorm/mod.rs @@ -15,7 +15,7 @@ pub struct GroupNormConfig { } impl GroupNormConfig { - pub fn init(&self) -> GroupNorm { + pub fn init(&self, device: &B::Device) -> GroupNorm { assert!( self.n_channel % self.n_group == 0, "The number of channels {} must be divisible by the number of groups {}", @@ -25,8 +25,8 @@ impl GroupNormConfig { let n_per_group = self.n_channel / self.n_group; - let gamma = Tensor::ones([self.n_channel]).into(); - let beta = Tensor::zeros([self.n_channel]).into(); + let gamma = Param::from_tensor(Tensor::ones([self.n_channel], device)); + let beta = Param::from_tensor(Tensor::zeros([self.n_channel], device)); let eps = self.eps; diff --git a/src/model/load.rs b/src/model/load.rs index 41ad96e..98d775c 100644 --- a/src/model/load.rs +++ b/src/model/load.rs @@ -1,5 +1,7 @@ use npy::{self, NpyData}; use num_traits::cast::ToPrimitive; +use burn::tensor::cast::ToElement; +use burn::prelude::TensorData; use std::error::Error; use std::io::Read; @@ -21,7 +23,8 @@ pub fn numpy_to_tensor( let shape: Vec<_> = v[0..D].into_iter().map(|&v| v as usize).collect(); let data: Vec = v[D..].into_iter().map(|e| e.elem()).collect(); - Tensor::from_data_device(Data::new(data, shape.into()), device) + //Tensor::from_data_device(Data::new(data, shape.into()), device) + Tensor::from_data(TensorData::new(data, shape), device) } pub fn load_tensor( @@ -48,7 +51,7 @@ pub fn load_f32( path: &str, device: &B::Device, ) -> Result> { - load_tensor::(name, path, device).map(|t| t.into_scalar().to_f32().unwrap()) + load_tensor::(name, path, device).map(|t| t.into_scalar().to_f32()) } pub fn load_usize( @@ -56,7 +59,7 @@ pub fn load_usize( path: &str, device: &B::Device, ) -> Result> { - load_tensor::(name, path, device).map(|t| t.into_scalar().to_usize().unwrap()) + load_tensor::(name, path, device).map(|t| t.into_scalar().to_usize()) } pub fn load_linear( @@ -66,13 +69,10 @@ pub fn load_linear( let weight = load_tensor::("weight", path, device)?; let bias = load_tensor::("bias", path, device).ok(); - let record = nn::LinearRecord { - weight: weight.into(), - bias: bias.map(|t| t.into()), - }; - - let linear: nn::Linear = nn::LinearConfig::new(3, 3).init_with(record); - Ok(linear) + Ok(nn::Linear { + weight: Param::from_tensor(weight), + bias: bias.map(|t| Param::from_tensor(t)), + }) } pub fn load_embedding( @@ -80,14 +80,10 @@ pub fn load_embedding( device: &B::Device, ) -> Result, Box> { let weight = load_tensor::("weight", path, device)?; - let [n_vocab, n_state] = weight.dims(); - let record = nn::EmbeddingRecord { - weight: weight.into(), - }; - - let embedding = nn::EmbeddingConfig::new(n_vocab, n_state).init_with(record); - Ok(embedding) + Ok(nn::Embedding { + weight: Param::from_tensor(weight), + }) } pub fn load_layer_norm( @@ -100,13 +96,9 @@ pub fn load_layer_norm( let [n_state] = weight.dims(); - let record = nn::LayerNormRecord { - gamma: weight.into(), - beta: bias.into(), - epsilon: >::into_record(eps), - }; - - let layer_norm: nn::LayerNorm = nn::LayerNormConfig::new(n_state).init_with(record); + let mut layer_norm = nn::LayerNormConfig::new(n_state).with_epsilon(eps).init(device); + layer_norm.gamma = Param::from_tensor(weight); + layer_norm.beta = Param::from_tensor(bias); Ok(layer_norm) } @@ -116,7 +108,7 @@ pub fn load_layer_norm( let eps = load_f32::("eps", path, device)?.into(); let rmsnorm = RMSNorm { - weight: weight.into(), + weight: Param::from_tensor(weight), eps: eps }; @@ -148,40 +140,38 @@ pub fn load_conv2d( let padding = tensor_to_array_2(padding); let padding = nn::PaddingConfig2d::Explicit(padding[0], padding[1]); - let record = conv::Conv2dRecord { - weight: weight.into(), - bias: bias.map(|t| t.into()), - stride: <[usize; 2] as Module>::into_record(stride), - kernel_size: <[usize; 2] as Module>::into_record(kernel_size), - dilation: <[usize; 2] as Module>::into_record(dilation), - groups: >::into_record(n_group), - padding: >::into_record(padding.clone()), - }; - - let conv2d: conv::Conv2d = - conv::Conv2dConfig::new([n_channels_in, n_channels_out], kernel_size) + let mut conv2d = conv::Conv2dConfig::new([n_channels_in, n_channels_out], kernel_size) .with_stride(stride) .with_dilation(dilation) .with_groups(n_group) - .with_padding(padding) + .with_padding(padding.clone()) .with_bias(has_bias) - .init_with(record); + .init(device); + + conv2d.weight = Param::from_tensor(weight); + conv2d.bias = bias.map(|t| Param::from_tensor(t)); + conv2d.stride = stride; + conv2d.kernel_size = kernel_size; + conv2d.dilation = dilation; + conv2d.groups = n_group; + conv2d.padding = burn::module::Ignored(padding); + Ok(conv2d) } pub fn tensor_to_array_2(x: Tensor) -> [usize; 2] { - let vec = x.into_data().value; + let vec: Vec<::FloatElem> = x.into_data().to_vec().unwrap(); assert!(vec.len() == 2, "Tensor length must be 2."); - [vec[0].to_usize().unwrap(), vec[1].to_usize().unwrap()] + [vec[0].to_usize(), vec[1].to_usize()] } pub fn tensor_to_array(x: Tensor) -> [usize; N] { - let vec = x.into_data().value; + let vec: Vec<::FloatElem> = x.into_data().to_vec().unwrap(); assert!(vec.len() == N, "Tensor length must be {}.", N); let mut arr = [0; N]; for (a, t) in arr.iter_mut().zip(vec) { - *a = t.to_usize().unwrap(); + *a = t.to_usize(); } arr diff --git a/src/model/stablediffusion/load.rs b/src/model/stablediffusion/load.rs index de31fab..6f3d719 100644 --- a/src/model/stablediffusion/load.rs +++ b/src/model/stablediffusion/load.rs @@ -18,7 +18,7 @@ pub fn load_stable_diffusion( device: &B::Device, ) -> Result, Box> { let n_steps = load_usize::("n_steps", path, device)?; - let alpha_cumulative_products = load_tensor::("alphas_cumprod", path, device)?.into(); + let alpha_cumulative_products = Param::from_tensor(load_tensor::("alphas_cumprod", path, device)?); let autoencoder = load_autoencoder(&format!("{}/{}", path, "autoencoder"), device)?; let diffusion = load_unet(&format!("{}/{}", path, "unet"), device)?; let clip = load_clip(&format!("{}/{}", path, "clip"), device)?; diff --git a/src/model/stablediffusion/mod.rs b/src/model/stablediffusion/mod.rs index 32c708d..5a7ce67 100644 --- a/src/model/stablediffusion/mod.rs +++ b/src/model/stablediffusion/mod.rs @@ -4,11 +4,12 @@ use burn::{ config::Config, module::{Module, Param}, tensor::{backend::Backend, BasicOps, Data, Distribution, Float, Int, Tensor}, + tensor::cast::ToElement, }; use num_traits::ToPrimitive; -use crate::backend::Backend as MyBackend; +//use crate::backend::Backend as MyBackend; use super::autoencoder::{Autoencoder, AutoencoderConfig}; use super::clip::{CLIPConfig, CLIP}; @@ -19,13 +20,13 @@ use crate::tokenizer::SimpleTokenizer; pub struct StableDiffusionConfig {} impl StableDiffusionConfig { - pub fn init(&self) -> StableDiffusion { + pub fn init(&self, device: &B::Device) -> StableDiffusion { let n_steps = 1000; - let alpha_cumulative_products = offset_cosine_schedule_cumprod::(n_steps).into(); + let alpha_cumulative_products = Param::from_tensor(offset_cosine_schedule_cumprod::(n_steps as i64, device)); - let autoencoder = AutoencoderConfig::new().init(); - let diffusion = UNetConfig::new().init(); - let clip = CLIPConfig::new(49408, 768, 12, 77, 12).init(); + let autoencoder = AutoencoderConfig::new().init(device); + let diffusion = UNetConfig::new().init(device); + let clip = CLIPConfig::new(49408, 768, 12, 77, 12).init(device); StableDiffusion { n_steps, @@ -46,7 +47,7 @@ pub struct StableDiffusion { clip: CLIP, } -impl StableDiffusion { +impl StableDiffusion { pub fn sample_image( &self, context: Tensor, @@ -82,7 +83,7 @@ impl StableDiffusion { .swap_dims(2, 3) .mul_scalar(255.0); - let flattened: Vec<_> = image.into_data().value; + let flattened: Vec = image.into_data().to_vec().unwrap(); (0..n_batch) .into_iter() @@ -92,7 +93,7 @@ impl StableDiffusion { flattened[start..end] .into_iter() - .map(|v| v.to_f64().unwrap().min(255.0).max(0.0).to_u8().unwrap()) + .map(|v| v.to_f64().min(255.0).max(0.0) as u8) .collect() }) .collect() @@ -112,8 +113,7 @@ impl StableDiffusion { let [n_batches, _, _] = context.dims(); let gen_noise = || { - Tensor::random([n_batches, 4, 64, 64], Distribution::Normal(0.0, 1.0)) - .to_device(&device) + Tensor::random([n_batches, 4, 64, 64], Distribution::Normal(0.0, 1.0), &device) }; let sigma = 0.0; // Use deterministic diffusion @@ -126,8 +126,8 @@ impl StableDiffusion { .val() .slice([t..t + 1]) .into_scalar() - .to_f64() - .unwrap(); + .to_f64(); + let prev_alpha: f64 = if t >= step_size { let i = t - step_size; self.alpha_cumulative_products @@ -135,14 +135,13 @@ impl StableDiffusion { .slice([i..i + 1]) .into_scalar() .to_f64() - .unwrap() } else { 1.0 }; let sqrt_noise = (1.0 - current_alpha).sqrt(); - let timestep = Tensor::from_ints([t as i32]).to_device(&device); + let timestep = Tensor::from_ints([t as i32], &device); let pred_noise = self.forward_diffuser( latent.clone(), timestep, @@ -174,7 +173,7 @@ impl StableDiffusion { let unconditional_latent = self.diffusion.forward( latent.clone(), timestep.clone(), - unconditional_context.unsqueeze().repeat(0, n_batch), + unconditional_context.unsqueeze().repeat(&[0, n_batch]), ); let conditional_latent = self.diffusion.forward(latent, timestep, context); @@ -206,8 +205,7 @@ impl StableDiffusion { .collect(); self.clip.forward( - Tensor::from_ints(&tokenized[..]) - .to_device(device) + Tensor::::from_ints(&tokenized[..], device) .unsqueeze(), ) } @@ -215,25 +213,25 @@ impl StableDiffusion { use std::f64::consts::PI; -fn cosine_schedule(n_steps: usize) -> Tensor { - Tensor::arange(1..n_steps + 1) +fn cosine_schedule(n_steps: i64, device: &B::Device) -> Tensor { + Tensor::arange(1..n_steps + 1, device) .float() .mul_scalar(PI * 0.5 / n_steps as f64) .cos() } -fn offset_cosine_schedule(n_steps: usize) -> Tensor { +fn offset_cosine_schedule(n_steps: i64, device: &B::Device) -> Tensor { let min_signal_rate: f64 = 0.02; let max_signal_rate: f64 = 0.95; let start_angle = max_signal_rate.acos(); let end_angle = min_signal_rate.acos(); - let times = Tensor::arange(1..n_steps + 1).float(); + let times = Tensor::arange(1..n_steps + 1, device).float(); let diffusion_angles = times * ((end_angle - start_angle) / n_steps as f64) + start_angle; diffusion_angles.cos() } -fn offset_cosine_schedule_cumprod(n_steps: usize) -> Tensor { - offset_cosine_schedule::(n_steps).powf(2.0) +fn offset_cosine_schedule_cumprod(n_steps: i64, device: &B::Device) -> Tensor { + offset_cosine_schedule::(n_steps, device).powf_scalar(2.0) } diff --git a/src/model/unet/load.rs b/src/model/unet/load.rs index b2821dd..f59e966 100644 --- a/src/model/unet/load.rs +++ b/src/model/unet/load.rs @@ -65,7 +65,7 @@ pub fn load_geglu(path: &str, device: &B::Device) -> Result let geglue = GEGLU { proj: proj, - gelu: GELU::new(), // Assuming GELU::new() initializes a new GELU struct + gelu: Gelu::new(), // Assuming Gelu::new() initializes a new Gelu struct }; Ok(geglue) diff --git a/src/model/unet/mod.rs b/src/model/unet/mod.rs index 5cec18a..07879bb 100644 --- a/src/model/unet/mod.rs +++ b/src/model/unet/mod.rs @@ -6,7 +6,7 @@ use burn::{ nn::{ self, conv::{Conv2d, Conv2dConfig}, - PaddingConfig2d, GELU, + PaddingConfig2d, Gelu, }, tensor::{activation::softmax, backend::Backend, module::embedding, Distribution, Int, Tensor}, }; @@ -22,7 +22,7 @@ fn timestep_embedding( max_period: usize, ) -> Tensor { let half = dim / 2; - let freqs = (Tensor::arange_device(0..half, ×teps.device()).float() + let freqs = (Tensor::arange(0..half as i64, ×teps.device()).float() * (-(max_period as f64).ln() / half as f64)) .exp(); let args = timesteps.float() * freqs; @@ -33,50 +33,50 @@ fn timestep_embedding( pub struct UNetConfig {} impl UNetConfig { - pub fn init(&self) -> UNet { - let lin1_time_embed = nn::LinearConfig::new(320, 1280).init(); + pub fn init(&self, device: &B::Device) -> UNet { + let lin1_time_embed = nn::LinearConfig::new(320, 1280).init(device); let silu_time_embed = SILU::new(); - let lin2_time_embed = nn::LinearConfig::new(1280, 1280).init(); + let lin2_time_embed = nn::LinearConfig::new(1280, 1280).init(device); let input_blocks = UNetInputBlocks { conv: Conv2dConfig::new([4, 320], [3, 3]) .with_padding(PaddingConfig2d::Explicit(1, 1)) - .init(), - rt1: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(), - rt2: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(), - d1: DownsampleConfig::new(320).init(), - rt3: ResTransformerConfig::new(320, 1280, 640, 768, 8).init(), - rt4: ResTransformerConfig::new(640, 1280, 640, 768, 8).init(), - d2: DownsampleConfig::new(640).init(), - rt5: ResTransformerConfig::new(640, 1280, 1280, 768, 8).init(), - rt6: ResTransformerConfig::new(1280, 1280, 1280, 768, 8).init(), - d3: DownsampleConfig::new(1280).init(), - r1: ResBlockConfig::new(1280, 1280, 1280).init(), - r2: ResBlockConfig::new(1280, 1280, 1280).init(), + .init(device), + rt1: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(device), + rt2: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(device), + d1: DownsampleConfig::new(320).init(device), + rt3: ResTransformerConfig::new(320, 1280, 640, 768, 8).init(device), + rt4: ResTransformerConfig::new(640, 1280, 640, 768, 8).init(device), + d2: DownsampleConfig::new(640).init(device), + rt5: ResTransformerConfig::new(640, 1280, 1280, 768, 8).init(device), + rt6: ResTransformerConfig::new(1280, 1280, 1280, 768, 8).init(device), + d3: DownsampleConfig::new(1280).init(device), + r1: ResBlockConfig::new(1280, 1280, 1280).init(device), + r2: ResBlockConfig::new(1280, 1280, 1280).init(device), }; - let middle_block = ResTransformerResConfig::new(1280, 1280, 1280, 768, 8).init(); + let middle_block = ResTransformerResConfig::new(1280, 1280, 1280, 768, 8).init(device); let output_blocks = UNetOutputBlocks { - r1: ResBlockConfig::new(2560, 1280, 1280).init(), - r2: ResBlockConfig::new(2560, 1280, 1280).init(), - ru: ResUpSampleConfig::new(2560, 1280, 1280).init(), - rt1: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(), - rt2: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(), - rtu1: ResTransformerUpsampleConfig::new(1920, 1280, 1280, 768, 8).init(), - rt3: ResTransformerConfig::new(1920, 1280, 640, 768, 8).init(), - rt4: ResTransformerConfig::new(1280, 1280, 640, 768, 8).init(), - rtu2: ResTransformerUpsampleConfig::new(960, 1280, 640, 768, 8).init(), - rt5: ResTransformerConfig::new(960, 1280, 320, 768, 8).init(), - rt6: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(), - rt7: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(), + r1: ResBlockConfig::new(2560, 1280, 1280).init(device), + r2: ResBlockConfig::new(2560, 1280, 1280).init(device), + ru: ResUpSampleConfig::new(2560, 1280, 1280).init(device), + rt1: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(device), + rt2: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(device), + rtu1: ResTransformerUpsampleConfig::new(1920, 1280, 1280, 768, 8).init(device), + rt3: ResTransformerConfig::new(1920, 1280, 640, 768, 8).init(device), + rt4: ResTransformerConfig::new(1280, 1280, 640, 768, 8).init(device), + rtu2: ResTransformerUpsampleConfig::new(960, 1280, 640, 768, 8).init(device), + rt5: ResTransformerConfig::new(960, 1280, 320, 768, 8).init(device), + rt6: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(device), + rt7: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(device), }; - let norm_out = GroupNormConfig::new(32, 320).init(); + let norm_out = GroupNormConfig::new(32, 320).init(device); let silu_out = SILU::new(); let conv_out = Conv2dConfig::new([320, 4], [3, 3]) .with_padding(PaddingConfig2d::Explicit(1, 1)) - .init(); + .init(device); UNet { lin1_time_embed, @@ -206,16 +206,16 @@ pub struct ResTransformerConfig { } impl ResTransformerConfig { - fn init(&self) -> ResTransformer { + fn init(&self, device: &B::Device) -> ResTransformer { let res = ResBlockConfig::new( self.n_channels_in, self.n_channels_embed, self.n_channels_out, ) - .init(); + .init(device); let transformer = SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head) - .init(); + .init(device); ResTransformer { res, transformer } } @@ -243,14 +243,14 @@ pub struct ResUpSampleConfig { } impl ResUpSampleConfig { - fn init(&self) -> ResUpSample { + fn init(&self, device: &B::Device) -> ResUpSample { let res = ResBlockConfig::new( self.n_channels_in, self.n_channels_embed, self.n_channels_out, ) - .init(); - let upsample = UpsampleConfig::new(self.n_channels_out).init(); + .init(device); + let upsample = UpsampleConfig::new(self.n_channels_out).init(device); ResUpSample { res, upsample } } @@ -280,17 +280,17 @@ pub struct ResTransformerUpsampleConfig { } impl ResTransformerUpsampleConfig { - fn init(&self) -> ResTransformerUpsample { + fn init(&self, device: &B::Device) -> ResTransformerUpsample { let res = ResBlockConfig::new( self.n_channels_in, self.n_channels_embed, self.n_channels_out, ) - .init(); + .init(device); let transformer = SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head) - .init(); - let upsample = UpsampleConfig::new(self.n_channels_out).init(); + .init(device); + let upsample = UpsampleConfig::new(self.n_channels_out).init(device); ResTransformerUpsample { res, @@ -326,22 +326,22 @@ pub struct ResTransformerResConfig { } impl ResTransformerResConfig { - fn init(&self) -> ResTransformerRes { + fn init(&self, device: &B::Device) -> ResTransformerRes { let res1 = ResBlockConfig::new( self.n_channels_in, self.n_channels_embed, self.n_channels_out, ) - .init(); + .init(device); let transformer = SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head) - .init(); + .init(device); let res2 = ResBlockConfig::new( self.n_channels_in, self.n_channels_embed, self.n_channels_out, ) - .init(); + .init(device); ResTransformerRes { res1, @@ -373,10 +373,10 @@ pub struct UpsampleConfig { } impl UpsampleConfig { - fn init(&self) -> Upsample { + fn init(&self, device: &B::Device) -> Upsample { let conv = Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3]) .with_padding(PaddingConfig2d::Explicit(1, 1)) - .init(); + .init(device); Upsample { conv } } @@ -392,8 +392,7 @@ impl Upsample { let [n_batch, n_channel, height, width] = x.dims(); let x = x .reshape([n_batch, n_channel, height, 1, width, 1]) - .repeat(3, 2) - .repeat(5, 2) + .repeat(&[1, 1, 1, 2, 1, 2]) .reshape([n_batch, n_channel, 2 * height, 2 * width]); self.conv.forward(x) } @@ -411,11 +410,11 @@ pub struct DownsampleConfig { } impl DownsampleConfig { - fn init(&self) -> Conv2d { + fn init(&self, device: &B::Device) -> Conv2d { Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3]) .with_stride([2, 2]) .with_padding(PaddingConfig2d::Explicit(1, 1)) - .init() + .init(device) } } @@ -435,12 +434,12 @@ pub struct SpatialTransformerConfig { } impl SpatialTransformerConfig { - fn init(&self) -> SpatialTransformer { - let norm = GroupNormConfig::new(32, self.n_channels).init(); - let proj_in = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(); + fn init(&self, device: &B::Device) -> SpatialTransformer { + let norm = GroupNormConfig::new(32, self.n_channels).init(device); + let proj_in = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(device); let transformer = - TransformerBlockConfig::new(self.n_channels, self.n_context_state, self.n_head).init(); - let proj_out = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(); + TransformerBlockConfig::new(self.n_channels, self.n_context_state, self.n_head).init(device); + let proj_out = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(device); SpatialTransformer { norm, @@ -489,14 +488,14 @@ pub struct TransformerBlockConfig { } impl TransformerBlockConfig { - fn init(&self) -> TransformerBlock { - let norm1 = nn::LayerNormConfig::new(self.n_state).init(); - let attn1 = MultiHeadAttentionConfig::new(self.n_state, self.n_state, self.n_head).init(); - let norm2 = nn::LayerNormConfig::new(self.n_state).init(); + fn init(&self, device: &B::Device) -> TransformerBlock { + let norm1 = nn::LayerNormConfig::new(self.n_state).init(device); + let attn1 = MultiHeadAttentionConfig::new(self.n_state, self.n_state, self.n_head).init(device); + let norm2 = nn::LayerNormConfig::new(self.n_state).init(device); let attn2 = - MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init(); - let norm3 = nn::LayerNormConfig::new(self.n_state).init(); - let mlp = MLPConfig::new(self.n_state, 4).init(); + MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init(device); + let norm3 = nn::LayerNormConfig::new(self.n_state).init(device); + let mlp = MLPConfig::new(self.n_state, 4).init(device); TransformerBlock { norm1, @@ -534,10 +533,10 @@ pub struct MLPConfig { } impl MLPConfig { - pub fn init(&self) -> MLP { + pub fn init(&self, device: &B::Device) -> MLP { let n_state_hidden = self.n_state * self.mult; - let geglu = GEGLUConfig::new(self.n_state, n_state_hidden).init(); - let lin = nn::LinearConfig::new(n_state_hidden, self.n_state).init(); + let geglu = GEGLUConfig::new(self.n_state, n_state_hidden).init(device); + let lin = nn::LinearConfig::new(n_state_hidden, self.n_state).init(device); MLP { geglu, lin } } @@ -562,9 +561,9 @@ pub struct GEGLUConfig { } impl GEGLUConfig { - fn init(&self) -> GEGLU { - let proj = nn::LinearConfig::new(self.n_state_in, 2 * self.n_state_out).init(); - let gelu = GELU::new(); + fn init(&self, device: &B::Device) -> GEGLU { + let proj = nn::LinearConfig::new(self.n_state_in, 2 * self.n_state_out).init(device); + let gelu = Gelu::new(); GEGLU { proj, gelu } } @@ -573,7 +572,7 @@ impl GEGLUConfig { #[derive(Module, Debug)] pub struct GEGLU { proj: nn::Linear, - gelu: GELU, + gelu: Gelu, } impl GEGLU { @@ -600,7 +599,7 @@ pub struct MultiHeadAttentionConfig { } impl MultiHeadAttentionConfig { - fn init(&self) -> MultiHeadAttention { + fn init(&self, device: &B::Device) -> MultiHeadAttention { assert!( self.n_state % self.n_head == 0, "State size {} must be a multiple of head size {}", @@ -611,14 +610,14 @@ impl MultiHeadAttentionConfig { let n_head = self.n_head; let query = nn::LinearConfig::new(self.n_state, self.n_state) .with_bias(false) - .init(); + .init(device); let key = nn::LinearConfig::new(self.n_context_state, self.n_state) .with_bias(false) - .init(); + .init(device); let value = nn::LinearConfig::new(self.n_context_state, self.n_state) .with_bias(false) - .init(); - let out = nn::LinearConfig::new(self.n_state, self.n_state).init(); + .init(device); + let out = nn::LinearConfig::new(self.n_state, self.n_state).init(device); MultiHeadAttention { n_head, @@ -661,24 +660,24 @@ pub struct ResBlockConfig { } impl ResBlockConfig { - fn init(&self) -> ResBlock { - let norm_in = GroupNormConfig::new(32, self.n_channels_in).init(); + fn init(&self, device: &B::Device) -> ResBlock { + let norm_in = GroupNormConfig::new(32, self.n_channels_in).init(device); let silu_in = SILU::new(); let conv_in = Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [3, 3]) .with_padding(PaddingConfig2d::Explicit(1, 1)) - .init(); + .init(device); let silu_embed = SILU::new(); - let lin_embed = nn::LinearConfig::new(self.n_channels_embed, self.n_channels_out).init(); + let lin_embed = nn::LinearConfig::new(self.n_channels_embed, self.n_channels_out).init(device); - let norm_out = GroupNormConfig::new(32, self.n_channels_out).init(); + let norm_out = GroupNormConfig::new(32, self.n_channels_out).init(device); let silu_out = SILU::new(); let conv_out = Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3]) .with_padding(PaddingConfig2d::Explicit(1, 1)) - .init(); + .init(device); let skip_connection = if self.n_channels_in != self.n_channels_out { - Some(Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [1, 1]).init()) + Some(Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [1, 1]).init(device)) } else { None };