From 6a73e7b27eabd60d11e2751f237d3e815fa1f333 Mon Sep 17 00:00:00 2001 From: rasmus Date: Tue, 3 Mar 2026 15:12:59 +0100 Subject: [PATCH] updated to the newest version of burn --- Cargo.toml | 16 ++++++---------- src/bin/sample/main.rs | 4 ++-- src/model/autoencoder/mod.rs | 18 +++++++++--------- src/model/clip/mod.rs | 6 +++--- src/model/groupnorm/mod.rs | 2 +- src/model/load.rs | 4 ++-- src/model/stablediffusion/mod.rs | 6 +++--- src/model/unet/mod.rs | 26 +++++++++++++------------- 8 files changed, 39 insertions(+), 43 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 00c23aa..ad988ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,17 +8,13 @@ edition = "2021" [features] wgpu-backend = ["burn-wgpu"] -[dependencies.burn-wgpu] -package = "burn-wgpu" -git = "https://github.com/burn-rs/burn.git" -optional = true - [dependencies] -burn = "0.14.0" -burn-ndarray = "0.14.0" -burn-tch = "0.14.0" -burn-autodiff = "0.14.0" -tch = "0.15.0" +burn = "0.20.1" +burn-ndarray = "0.20.1" +burn-tch = "0.20.1" +burn-autodiff = "0.20.1" +burn-wgpu = { version = "0.20.1", optional = true } +tch = "0.22.0" serde = {version = "1.0.171", features = ["std", "derive"]} npy = "0.4.0" num-traits = "0.2.15" diff --git a/src/bin/sample/main.rs b/src/bin/sample/main.rs index 97cc93f..e5a146c 100644 --- a/src/bin/sample/main.rs +++ b/src/bin/sample/main.rs @@ -12,7 +12,7 @@ use burn::{ cfg_if::cfg_if! { if #[cfg(feature = "wgpu-backend")] { - use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi}; + use burn_wgpu::{Wgpu, WgpuDevice}; } else { use burn_tch::{LibTorch, LibTorchDevice}; } @@ -58,7 +58,7 @@ fn main() { cfg_if::cfg_if! { if #[cfg(feature = "wgpu-backend")] { - type Backend = WgpuBackend; + type Backend = Wgpu; let device = WgpuDevice::BestAvailable; } else { type Backend = LibTorch; diff --git a/src/model/autoencoder/mod.rs b/src/model/autoencoder/mod.rs index 9c40774..e7cd3da 100644 --- a/src/model/autoencoder/mod.rs +++ b/src/model/autoencoder/mod.rs @@ -23,7 +23,7 @@ use crate::backend::{qkv_attention, attn_decoder_mask}; use std::iter; -#[derive(Config)] +#[derive(Config, Debug)] pub struct AutoencoderConfig {} impl AutoencoderConfig { @@ -71,7 +71,7 @@ impl Autoencoder { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct EncoderConfig { channels: Vec<(usize, usize)>, n_group: usize, @@ -144,7 +144,7 @@ impl Encoder { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct DecoderConfig { channels: Vec<(usize, usize)>, n_group: usize, @@ -216,7 +216,7 @@ impl Decoder { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct EncoderBlockConfig { n_channels_in: usize, n_channels_out: usize, @@ -265,7 +265,7 @@ impl EncoderBlock { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct DecoderBlockConfig { n_channels_in: usize, n_channels_out: usize, @@ -323,7 +323,7 @@ impl DecoderBlock { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct PaddedConv2dConfig { channels: [usize; 2], kernel_size: usize, @@ -427,7 +427,7 @@ pub struct Padding { pad_bottom: usize, } -#[derive(Config)] +#[derive(Config, Debug)] pub struct MidConfig { n_channel: usize, } @@ -462,7 +462,7 @@ impl Mid { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct ResnetBlockConfig { in_channels: usize, out_channels: usize, @@ -527,7 +527,7 @@ impl ResnetBlock { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct ConvSelfAttentionBlockConfig { n_channel: usize, } diff --git a/src/model/clip/mod.rs b/src/model/clip/mod.rs index a8256f4..5f9d44f 100644 --- a/src/model/clip/mod.rs +++ b/src/model/clip/mod.rs @@ -15,7 +15,7 @@ use burn::{ //use crate::backend::Backend as MyBackend; use crate::backend::{qkv_attention, attn_decoder_mask}; -#[derive(Config)] +#[derive(Config, Debug)] pub struct CLIPConfig { n_vocab: usize, n_state: usize, @@ -75,7 +75,7 @@ impl CLIP { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct ResidualDecoderAttentionBlockConfig { n_state: usize, n_head: usize, @@ -114,7 +114,7 @@ impl ResidualDecoderAttentionBlock { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct MultiHeadSelfAttentionConfig { n_state: usize, n_head: usize, diff --git a/src/model/groupnorm/mod.rs b/src/model/groupnorm/mod.rs index 303e5b8..92698c0 100644 --- a/src/model/groupnorm/mod.rs +++ b/src/model/groupnorm/mod.rs @@ -6,7 +6,7 @@ use burn::{ tensor::{backend::Backend, Tensor}, }; -#[derive(Config)] +#[derive(Config, Debug)] pub struct GroupNormConfig { n_group: usize, n_channel: usize, diff --git a/src/model/load.rs b/src/model/load.rs index 98d775c..3b51355 100644 --- a/src/model/load.rs +++ b/src/model/load.rs @@ -9,7 +9,7 @@ use burn::{ config::Config, module::{Module, Param}, nn::{self, conv}, - tensor::{backend::Backend, Data, Tensor}, + tensor::{backend::Backend, Tensor}, }; use burn::tensor::ElementConversion; @@ -98,7 +98,7 @@ pub fn load_layer_norm( let mut layer_norm = nn::LayerNormConfig::new(n_state).with_epsilon(eps).init(device); layer_norm.gamma = Param::from_tensor(weight); - layer_norm.beta = Param::from_tensor(bias); + layer_norm.beta = Some(Param::from_tensor(bias)); Ok(layer_norm) } diff --git a/src/model/stablediffusion/mod.rs b/src/model/stablediffusion/mod.rs index 5a7ce67..0f876f3 100644 --- a/src/model/stablediffusion/mod.rs +++ b/src/model/stablediffusion/mod.rs @@ -3,7 +3,7 @@ pub mod load; use burn::{ config::Config, module::{Module, Param}, - tensor::{backend::Backend, BasicOps, Data, Distribution, Float, Int, Tensor}, + tensor::{backend::Backend, BasicOps, Distribution, Float, Int, Tensor}, tensor::cast::ToElement, }; @@ -16,7 +16,7 @@ use super::clip::{CLIPConfig, CLIP}; use super::unet::{UNet, UNetConfig}; use crate::tokenizer::SimpleTokenizer; -#[derive(Config)] +#[derive(Config, Debug)] pub struct StableDiffusionConfig {} impl StableDiffusionConfig { @@ -192,7 +192,7 @@ impl StableDiffusion { } pub fn unconditional_context(&self, tokenizer: &SimpleTokenizer) -> Tensor { - self.context(tokenizer, "").squeeze(0) + self.context(tokenizer, "").squeeze::<2>() } pub fn context(&self, tokenizer: &SimpleTokenizer, text: &str) -> Tensor { diff --git a/src/model/unet/mod.rs b/src/model/unet/mod.rs index 07879bb..7a691dc 100644 --- a/src/model/unet/mod.rs +++ b/src/model/unet/mod.rs @@ -29,7 +29,7 @@ fn timestep_embedding( Tensor::cat(vec![args.clone().cos(), args.sin()], 0).unsqueeze() } -#[derive(Config)] +#[derive(Config, Debug)] pub struct UNetConfig {} impl UNetConfig { @@ -196,7 +196,7 @@ trait UNetBlock { fn forward(&self, x: Tensor, emb: Tensor, context: Tensor) -> Tensor; } -#[derive(Config)] +#[derive(Config, Debug)] pub struct ResTransformerConfig { n_channels_in: usize, n_channels_embed: usize, @@ -235,7 +235,7 @@ impl UNetBlock for ResTransformer { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct ResUpSampleConfig { n_channels_in: usize, n_channels_embed: usize, @@ -270,7 +270,7 @@ impl UNetBlock for ResUpSample { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct ResTransformerUpsampleConfig { n_channels_in: usize, n_channels_embed: usize, @@ -316,7 +316,7 @@ impl UNetBlock for ResTransformerUpsample { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct ResTransformerResConfig { n_channels_in: usize, n_channels_embed: usize, @@ -367,7 +367,7 @@ impl UNetBlock for ResTransformerRes { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct UpsampleConfig { n_channels: usize, } @@ -404,7 +404,7 @@ impl UNetBlock for Upsample { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct DownsampleConfig { n_channels: usize, } @@ -426,7 +426,7 @@ impl UNetBlock for Conv2d { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct SpatialTransformerConfig { n_channels: usize, n_context_state: usize, @@ -480,7 +480,7 @@ impl SpatialTransformer { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct TransformerBlockConfig { n_state: usize, n_context_state: usize, @@ -526,7 +526,7 @@ impl TransformerBlock { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct MLPConfig { n_state: usize, mult: usize, @@ -554,7 +554,7 @@ impl MLP { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct GEGLUConfig { n_state_in: usize, n_state_out: usize, @@ -591,7 +591,7 @@ impl GEGLU { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct MultiHeadAttentionConfig { n_state: usize, n_context_state: usize, @@ -652,7 +652,7 @@ impl MultiHeadAttention { } } -#[derive(Config)] +#[derive(Config, Debug)] pub struct ResBlockConfig { n_channels_in: usize, n_channels_embed: usize,