Updated to the newest version of burn (0.20.1)

This commit is contained in:
rasmus
2026-03-03 15:12:59 +01:00
parent 893fb0950d
commit 6cfd6db5a5
8 changed files with 39 additions and 43 deletions

View File

@@ -8,17 +8,13 @@ edition = "2021"
[features]
wgpu-backend = ["burn-wgpu"]
[dependencies.burn-wgpu]
package = "burn-wgpu"
git = "https://github.com/burn-rs/burn.git"
optional = true
[dependencies]
burn = "0.14.0"
burn-ndarray = "0.14.0"
burn-tch = "0.14.0"
burn-autodiff = "0.14.0"
tch = "0.15.0"
burn = "0.20.1"
burn-ndarray = "0.20.1"
burn-tch = "0.20.1"
burn-autodiff = "0.20.1"
burn-wgpu = { version = "0.20.1", optional = true }
tch = "0.22.0"
serde = {version = "1.0.171", features = ["std", "derive"]}
npy = "0.4.0"
num-traits = "0.2.15"

View File

@@ -12,7 +12,7 @@ use burn::{
cfg_if::cfg_if! {
if #[cfg(feature = "wgpu-backend")] {
use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi};
use burn_wgpu::{Wgpu, WgpuDevice};
} else {
use burn_tch::{LibTorch, LibTorchDevice};
}
@@ -58,7 +58,7 @@ fn main() {
cfg_if::cfg_if! {
if #[cfg(feature = "wgpu-backend")] {
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
type Backend = Wgpu;
let device = WgpuDevice::BestAvailable;
} else {
type Backend = LibTorch<f32>;

View File

@@ -23,7 +23,7 @@ use crate::backend::{qkv_attention, attn_decoder_mask};
use std::iter;
#[derive(Config)]
#[derive(Config, Debug)]
pub struct AutoencoderConfig {}
impl AutoencoderConfig {
@@ -71,7 +71,7 @@ impl<B: Backend> Autoencoder<B> {
}
}
#[derive(Config)]
#[derive(Config, Debug)]
pub struct EncoderConfig {
channels: Vec<(usize, usize)>,
n_group: usize,
@@ -144,7 +144,7 @@ impl<B: Backend> Encoder<B> {
}
}
#[derive(Config)]
#[derive(Config, Debug)]
pub struct DecoderConfig {
channels: Vec<(usize, usize)>,
n_group: usize,
@@ -216,7 +216,7 @@ impl<B: Backend> Decoder<B> {
}
}
#[derive(Config)]
#[derive(Config, Debug)]
pub struct EncoderBlockConfig {
n_channels_in: usize,
n_channels_out: usize,
@@ -265,7 +265,7 @@ impl<B: Backend> EncoderBlock<B> {
}
}
#[derive(Config)]
#[derive(Config, Debug)]
pub struct DecoderBlockConfig {
n_channels_in: usize,
n_channels_out: usize,
@@ -323,7 +323,7 @@ impl<B: Backend> DecoderBlock<B> {
}
}
#[derive(Config)]
#[derive(Config, Debug)]
pub struct PaddedConv2dConfig {
channels: [usize; 2],
kernel_size: usize,
@@ -427,7 +427,7 @@ pub struct Padding {
pad_bottom: usize,
}
#[derive(Config)]
#[derive(Config, Debug)]
pub struct MidConfig {
n_channel: usize,
}
@@ -462,7 +462,7 @@ impl<B: Backend> Mid<B> {
}
}
#[derive(Config)]
#[derive(Config, Debug)]
pub struct ResnetBlockConfig {
in_channels: usize,
out_channels: usize,
@@ -527,7 +527,7 @@ impl<B: Backend> ResnetBlock<B> {
}
}
#[derive(Config)]
#[derive(Config, Debug)]
pub struct ConvSelfAttentionBlockConfig {
n_channel: usize,
}

View File

@@ -15,7 +15,7 @@ use burn::{
//use crate::backend::Backend as MyBackend;
use crate::backend::{qkv_attention, attn_decoder_mask};
#[derive(Config)]
#[derive(Config, Debug)]
pub struct CLIPConfig {
n_vocab: usize,
n_state: usize,
@@ -75,7 +75,7 @@ impl<B: Backend> CLIP<B> {
}
}
#[derive(Config)]
#[derive(Config, Debug)]
pub struct ResidualDecoderAttentionBlockConfig {
n_state: usize,
n_head: usize,
@@ -114,7 +114,7 @@ impl<B: Backend> ResidualDecoderAttentionBlock<B> {
}
}
#[derive(Config)]
#[derive(Config, Debug)]
pub struct MultiHeadSelfAttentionConfig {
n_state: usize,
n_head: usize,

View File

@@ -6,7 +6,7 @@ use burn::{
tensor::{backend::Backend, Tensor},
};
#[derive(Config)]
#[derive(Config, Debug)]
pub struct GroupNormConfig {
n_group: usize,
n_channel: usize,

View File

@@ -9,7 +9,7 @@ use burn::{
config::Config,
module::{Module, Param},
nn::{self, conv},
tensor::{backend::Backend, Data, Tensor},
tensor::{backend::Backend, Tensor},
};
use burn::tensor::ElementConversion;
@@ -98,7 +98,7 @@ pub fn load_layer_norm<B: Backend>(
let mut layer_norm = nn::LayerNormConfig::new(n_state).with_epsilon(eps).init(device);
layer_norm.gamma = Param::from_tensor(weight);
layer_norm.beta = Param::from_tensor(bias);
layer_norm.beta = Some(Param::from_tensor(bias));
Ok(layer_norm)
}

View File

@@ -3,7 +3,7 @@ pub mod load;
use burn::{
config::Config,
module::{Module, Param},
tensor::{backend::Backend, BasicOps, Data, Distribution, Float, Int, Tensor},
tensor::{backend::Backend, BasicOps, Distribution, Float, Int, Tensor},
tensor::cast::ToElement,
};
@@ -16,7 +16,7 @@ use super::clip::{CLIPConfig, CLIP};
use super::unet::{UNet, UNetConfig};
use crate::tokenizer::SimpleTokenizer;
#[derive(Config)]
#[derive(Config, Debug)]
pub struct StableDiffusionConfig {}
impl StableDiffusionConfig {
@@ -192,7 +192,7 @@ impl<B: Backend> StableDiffusion<B> {
}
pub fn unconditional_context(&self, tokenizer: &SimpleTokenizer) -> Tensor<B, 2> {
self.context(tokenizer, "").squeeze(0)
self.context(tokenizer, "").squeeze::<2>()
}
pub fn context(&self, tokenizer: &SimpleTokenizer, text: &str) -> Tensor<B, 3> {

View File

@@ -29,7 +29,7 @@ fn timestep_embedding<B: Backend>(
Tensor::cat(vec![args.clone().cos(), args.sin()], 0).unsqueeze()
}
#[derive(Config)]
#[derive(Config, Debug)]
pub struct UNetConfig {}
impl UNetConfig {
@@ -196,7 +196,7 @@ trait UNetBlock<B: Backend> {
fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4>;
}
#[derive(Config)]
#[derive(Config, Debug)]
pub struct ResTransformerConfig {
n_channels_in: usize,
n_channels_embed: usize,
@@ -235,7 +235,7 @@ impl<B: Backend> UNetBlock<B> for ResTransformer<B> {
}
}
#[derive(Config)]
#[derive(Config, Debug)]
pub struct ResUpSampleConfig {
n_channels_in: usize,
n_channels_embed: usize,
@@ -270,7 +270,7 @@ impl<B: Backend> UNetBlock<B> for ResUpSample<B> {
}
}
#[derive(Config)]
#[derive(Config, Debug)]
pub struct ResTransformerUpsampleConfig {
n_channels_in: usize,
n_channels_embed: usize,
@@ -316,7 +316,7 @@ impl<B: Backend> UNetBlock<B> for ResTransformerUpsample<B> {
}
}
#[derive(Config)]
#[derive(Config, Debug)]
pub struct ResTransformerResConfig {
n_channels_in: usize,
n_channels_embed: usize,
@@ -367,7 +367,7 @@ impl<B: Backend> UNetBlock<B> for ResTransformerRes<B> {
}
}
#[derive(Config)]
#[derive(Config, Debug)]
pub struct UpsampleConfig {
n_channels: usize,
}
@@ -404,7 +404,7 @@ impl<B: Backend> UNetBlock<B> for Upsample<B> {
}
}
#[derive(Config)]
#[derive(Config, Debug)]
pub struct DownsampleConfig {
n_channels: usize,
}
@@ -426,7 +426,7 @@ impl<B: Backend> UNetBlock<B> for Conv2d<B> {
}
}
#[derive(Config)]
#[derive(Config, Debug)]
pub struct SpatialTransformerConfig {
n_channels: usize,
n_context_state: usize,
@@ -480,7 +480,7 @@ impl<B: Backend> SpatialTransformer<B> {
}
}
#[derive(Config)]
#[derive(Config, Debug)]
pub struct TransformerBlockConfig {
n_state: usize,
n_context_state: usize,
@@ -526,7 +526,7 @@ impl<B: Backend> TransformerBlock<B> {
}
}
#[derive(Config)]
#[derive(Config, Debug)]
pub struct MLPConfig {
n_state: usize,
mult: usize,
@@ -554,7 +554,7 @@ impl<B: Backend> MLP<B> {
}
}
#[derive(Config)]
#[derive(Config, Debug)]
pub struct GEGLUConfig {
n_state_in: usize,
n_state_out: usize,
@@ -591,7 +591,7 @@ impl<B: Backend> GEGLU<B> {
}
}
#[derive(Config)]
#[derive(Config, Debug)]
pub struct MultiHeadAttentionConfig {
n_state: usize,
n_context_state: usize,
@@ -652,7 +652,7 @@ impl<B: Backend> MultiHeadAttention<B> {
}
}
#[derive(Config)]
#[derive(Config, Debug)]
pub struct ResBlockConfig {
n_channels_in: usize,
n_channels_embed: usize,