mirror of
https://gitea.hainer-ernst.de/rasmus/burn-stablediffusion-vibecode.git
synced 2026-06-10 17:59:22 +00:00
Update to burn v0.14.0 and switch to .mpk model file
This commit is contained in:
10
Cargo.toml
10
Cargo.toml
@@ -14,11 +14,11 @@ git = "https://github.com/burn-rs/burn.git"
|
||||
optional = true
|
||||
|
||||
[dependencies]
|
||||
burn = { git = "https://github.com/burn-rs/burn.git" }
|
||||
burn-ndarray = { package = "burn-ndarray", git = "https://github.com/burn-rs/burn.git" }
|
||||
burn-tch = { package = "burn-tch", git = "https://github.com/burn-rs/burn.git" }
|
||||
burn-autodiff = { package = "burn-autodiff", git = "https://github.com/burn-rs/burn.git" }
|
||||
tch = "0.13.0"
|
||||
burn = "0.14.0"
|
||||
burn-ndarray = "0.14.0"
|
||||
burn-tch = "0.14.0"
|
||||
burn-autodiff = "0.14.0"
|
||||
tch = "0.15.0"
|
||||
serde = {version = "1.0.171", features = ["std", "derive"]}
|
||||
npy = "0.4.0"
|
||||
num-traits = "0.2.15"
|
||||
|
||||
21
README.md
21
README.md
@@ -4,12 +4,14 @@ Stable-Diffusion-Burn is a Rust-based project which ports the V1 stable diffusio
|
||||
|
||||
## How To Use
|
||||
|
||||
### Step 0: Install libtorch v2.4.1
|
||||
|
||||
### Step 1: Download the Model and Set Environment Variables
|
||||
|
||||
Start by downloading the SDv1-4.bin model provided on HuggingFace.
|
||||
Start by downloading the SDv1-4.mpk model file provided on Hugging Face.
|
||||
|
||||
```bash
|
||||
wget https://huggingface.co/Gadersd/Stable-Diffusion-Burn/resolve/main/V1/SDv1-4.bin
|
||||
wget https://huggingface.co/Gadersd/Stable-Diffusion-Burn/resolve/main/SDv1-4.mpk
|
||||
```
|
||||
|
||||
### Step 2: Run the Sample Binary
|
||||
@@ -18,9 +20,13 @@ Invoke the sample binary provided in the rust code. By default, torch is used. T
|
||||
|
||||
```bash
|
||||
# torch (at least 6 GB VRAM, possibly less)
|
||||
export TORCH_CUDA_VERSION=cu113
|
||||
# Arguments: <model_type(burn or dump)> <model> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image>
|
||||
cargo run --release --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img
|
||||
# Arguments: <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name> [cuda, mps, cpu]
|
||||
|
||||
# CUDA
|
||||
cargo run --release --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img cuda
|
||||
|
||||
# MPS (Mac)
|
||||
cargo run --release --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img mps
|
||||
|
||||
# wgpu (UNSTABLE)
|
||||
# Arguments: <model_type(burn or dump)> <model> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image>
|
||||
@@ -33,7 +39,7 @@ This command will generate an image according to the provided prompt, which will
|
||||
|
||||
### Optional: Extract and Convert a Fine-Tuned Model
|
||||
|
||||
If users are interested in using a fine-tuned version of stable diffusion, the Python scripts provided in this project can be used to transform a weight dump into a Burn model file. Note: the tinygrad dependency should be installed from source rather than with pip.
|
||||
If users are interested in using a fine-tuned version of stable diffusion, the Python scripts provided in this project can be used to transform a weight dump into a Burn model file. Note: the conversion scripts do not work on Windows.
|
||||
|
||||
```bash
|
||||
# Step into the Python directory
|
||||
@@ -42,6 +48,9 @@ cd python
|
||||
# Download the model, this is just the base v1.4 model as an example
|
||||
wget https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt
|
||||
|
||||
# Install tinygrad
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Extract the weights
|
||||
CPU=1 python3 dump.py sd-v1-4.ckpt
|
||||
|
||||
|
||||
@@ -13,10 +13,11 @@ from collections import namedtuple
|
||||
|
||||
from tqdm import tqdm
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import dtypes, GlobalCounters
|
||||
from tinygrad.helpers import GlobalCounters
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
|
||||
from extra.utils import download_file
|
||||
from tinygrad.state import torch_load, load_state_dict
|
||||
#from extra.utils import download_file
|
||||
from tinygrad.nn.state import torch_load, load_state_dict
|
||||
|
||||
# TODO: refactor AttnBlock, CrossAttention, CLIPAttention to share code
|
||||
|
||||
|
||||
1
python/requirements.txt
Normal file
1
python/requirements.txt
Normal file
@@ -0,0 +1 @@
|
||||
tinygrad==0.9.2
|
||||
@@ -1,13 +1,16 @@
|
||||
use burn::tensor::{activation::softmax, Tensor};
|
||||
use burn::prelude::Backend;
|
||||
|
||||
/*pub type FloatTensor<B, const D: usize> = <B as burn::tensor::backend::Backend>::TensorPrimitive<D>;
|
||||
|
||||
pub trait Backend: burn::tensor::backend::Backend {
|
||||
fn qkv_attention(
|
||||
q: Self::TensorPrimitive<3>,
|
||||
k: Self::TensorPrimitive<3>,
|
||||
v: Self::TensorPrimitive<3>,
|
||||
mask: Option<Self::TensorPrimitive<2>>,
|
||||
q: FloatTensor<Self, 3>,
|
||||
k: FloatTensor<Self, 3>,
|
||||
v: FloatTensor<Self, 3>,
|
||||
mask: Option<FloatTensor<Self, 2>>,
|
||||
n_head: usize,
|
||||
) -> Self::TensorPrimitive<3> {
|
||||
) -> FloatTensor<Self, 3> {
|
||||
qkv_attention(
|
||||
Tensor::<Self, 3>::from_primitive(q),
|
||||
Tensor::from_primitive(k),
|
||||
@@ -18,24 +21,23 @@ pub trait Backend: burn::tensor::backend::Backend {
|
||||
.into_primitive()
|
||||
}
|
||||
|
||||
fn attn_decoder_mask(seq_length: usize, device: &Self::Device) -> Self::TensorPrimitive<2> {
|
||||
fn attn_decoder_mask(seq_length: usize, device: &Self::Device) -> FloatTensor<Self, 2> {
|
||||
attn_decoder_mask::<Self>(seq_length, device).into_primitive()
|
||||
}
|
||||
}
|
||||
|
||||
use burn::tensor::ops::TensorOps;
|
||||
use burn::tensor::Float;
|
||||
use burn_tch::{self, TchElement, TchTensor};
|
||||
use tch;
|
||||
|
||||
impl<E: TchElement> Backend for burn_tch::TchBackend<E> {
|
||||
impl<E: TchElement> Backend for burn_tch::LibTorch<E> {
|
||||
fn qkv_attention(
|
||||
q: Self::TensorPrimitive<3>,
|
||||
k: Self::TensorPrimitive<3>,
|
||||
v: Self::TensorPrimitive<3>,
|
||||
mask: Option<Self::TensorPrimitive<2>>,
|
||||
q: FloatTensor<Self, 3>,
|
||||
k: FloatTensor<Self, 3>,
|
||||
v: FloatTensor<Self, 3>,
|
||||
mask: Option<FloatTensor<Self, 2>>,
|
||||
n_head: usize,
|
||||
) -> Self::TensorPrimitive<3> {
|
||||
) -> FloatTensor<Self, 3> {
|
||||
let q = Tensor::from_primitive(q);
|
||||
let k = Tensor::from_primitive(k);
|
||||
let v = Tensor::from_primitive(v);
|
||||
@@ -56,7 +58,7 @@ impl<E: TchElement> Backend for burn_tch::TchBackend<E> {
|
||||
|
||||
// for some reason torch crashes when mask is None
|
||||
let mask = mask.unwrap_or_else(|| {
|
||||
Tensor::<Self, 2, Float>::zeros_device([q_ctx, k_ctx], &Self::device(&v))
|
||||
Tensor::<Self, 2, Float>::zeros([q_ctx, k_ctx], &Self::device(&v))
|
||||
.into_primitive()
|
||||
});
|
||||
|
||||
@@ -68,6 +70,7 @@ impl<E: TchElement> Backend for burn_tch::TchBackend<E> {
|
||||
Some(mask.tensor),
|
||||
0.0,
|
||||
false,
|
||||
None,
|
||||
),
|
||||
))
|
||||
.swap_dims(1, 2)
|
||||
@@ -78,11 +81,11 @@ impl<E: TchElement> Backend for burn_tch::TchBackend<E> {
|
||||
|
||||
use burn_autodiff;
|
||||
|
||||
impl<B: Backend> Backend for burn_autodiff::ADBackendDecorator<B> {}
|
||||
impl<B: Backend> Backend for burn_autodiff::Autodiff<B> {}*/
|
||||
|
||||
use std::f32::NEG_INFINITY;
|
||||
|
||||
fn qkv_attention<B: Backend>(
|
||||
pub fn qkv_attention<B: Backend>(
|
||||
q: Tensor<B, 3>,
|
||||
k: Tensor<B, 3>,
|
||||
v: Tensor<B, 3>,
|
||||
@@ -124,13 +127,13 @@ fn qkv_attention<B: Backend>(
|
||||
return o;
|
||||
}
|
||||
|
||||
fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
|
||||
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length]);
|
||||
pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
|
||||
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length], device);
|
||||
|
||||
for i in 0..(seq_length - 1) {
|
||||
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)]).add_scalar(NEG_INFINITY);
|
||||
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)], device).add_scalar(NEG_INFINITY);
|
||||
mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
|
||||
}
|
||||
|
||||
return mask.to_device(device);
|
||||
return mask;
|
||||
}
|
||||
|
||||
@@ -11,9 +11,9 @@ use burn::{
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
use burn_ndarray::{NdArrayBackend, NdArrayDevice};
|
||||
use burn_ndarray::{NdArray, NdArrayDevice};
|
||||
|
||||
use burn::record::{self, BinFileRecorder, FullPrecisionSettings, Recorder};
|
||||
use burn::record::{self, NamedMpkFileRecorder, FullPrecisionSettings, Recorder};
|
||||
|
||||
fn convert_dump_to_model<B: Backend>(
|
||||
dump_path: &str,
|
||||
@@ -33,11 +33,11 @@ fn save_model_file<B: Backend>(
|
||||
model: StableDiffusion<B>,
|
||||
name: &str,
|
||||
) -> Result<(), record::RecorderError> {
|
||||
BinFileRecorder::<FullPrecisionSettings>::new().record(model.into_record(), name.into())
|
||||
NamedMpkFileRecorder::<FullPrecisionSettings>::new().record(model.into_record(), name.into())
|
||||
}
|
||||
|
||||
fn main() {
|
||||
type Backend = NdArrayBackend<f32>;
|
||||
type Backend = NdArray<f32>;
|
||||
let device = NdArrayDevice::Cpu;
|
||||
|
||||
let args: Vec<String> = env::args().collect();
|
||||
|
||||
@@ -14,7 +14,7 @@ cfg_if::cfg_if! {
|
||||
if #[cfg(feature = "wgpu-backend")] {
|
||||
use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi};
|
||||
} else {
|
||||
use burn_tch::{TchBackend, TchDevice};
|
||||
use burn_tch::{LibTorch, LibTorchDevice};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,30 +22,21 @@ use std::env;
|
||||
use std::io;
|
||||
use std::process;
|
||||
|
||||
use burn::record::{self, BinFileRecorder, FullPrecisionSettings, Recorder};
|
||||
use burn::record::{self, NamedMpkFileRecorder, FullPrecisionSettings, Recorder};
|
||||
|
||||
fn load_stable_diffusion_model_file<B: Backend>(
|
||||
filename: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<StableDiffusion<B>, record::RecorderError> {
|
||||
BinFileRecorder::<FullPrecisionSettings>::new()
|
||||
.load(filename.into())
|
||||
.map(|record| StableDiffusionConfig::new().init().load_record(record))
|
||||
NamedMpkFileRecorder::<FullPrecisionSettings>::new()
|
||||
.load(filename.into(), device)
|
||||
.map(|record| StableDiffusionConfig::new().init(device).load_record(record))
|
||||
}
|
||||
|
||||
fn main() {
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(feature = "wgpu-backend")] {
|
||||
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
|
||||
let device = WgpuDevice::BestAvailable;
|
||||
} else {
|
||||
type Backend = TchBackend<f32>;
|
||||
let device = TchDevice::Cuda(0);
|
||||
}
|
||||
}
|
||||
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() != 7 {
|
||||
eprintln!("Usage: {} <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name>", args[0]);
|
||||
if args.len() != 7 && args.len() != 8 {
|
||||
eprintln!("Usage: {} <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name> [device(cuda, mps, cpu)]", args[0]);
|
||||
process::exit(1);
|
||||
}
|
||||
|
||||
@@ -62,11 +53,40 @@ fn main() {
|
||||
let prompt = &args[5];
|
||||
let output_image_name = &args[6];
|
||||
|
||||
// Optional device parameter
|
||||
let device_arg = if args.len() == 8 { Some(&args[7]) } else { None };
|
||||
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(feature = "wgpu-backend")] {
|
||||
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
|
||||
let device = WgpuDevice::BestAvailable;
|
||||
} else {
|
||||
type Backend = LibTorch<f32>;
|
||||
|
||||
let device = if let Some(dev_str) = device_arg {
|
||||
match dev_str.to_lowercase().as_str() {
|
||||
"cpu" => LibTorchDevice::Cpu,
|
||||
"mps" => LibTorchDevice::Mps,
|
||||
s if s.starts_with("cuda") => {
|
||||
let idx = s[4..].parse().unwrap_or(0);
|
||||
LibTorchDevice::Cuda(idx)
|
||||
}
|
||||
_ => {
|
||||
eprintln!("Unknown device: {}", dev_str);
|
||||
process::exit(1);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
LibTorchDevice::Cuda(0)
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
println!("Loading tokenizer...");
|
||||
let tokenizer = SimpleTokenizer::new().unwrap();
|
||||
println!("Loading model...");
|
||||
let sd: StableDiffusion<Backend> = if model_type == "burn" {
|
||||
load_stable_diffusion_model_file(model_name).unwrap_or_else(|err| {
|
||||
load_stable_diffusion_model_file(model_name, &device).unwrap_or_else(|err| {
|
||||
eprintln!("Error loading model: {}", err);
|
||||
process::exit(1);
|
||||
})
|
||||
@@ -77,8 +97,6 @@ fn main() {
|
||||
})
|
||||
};
|
||||
|
||||
let sd = sd.to_device(&device);
|
||||
|
||||
let unconditional_context = sd.unconditional_context(&tokenizer);
|
||||
let context = sd.context(&tokenizer, prompt).unsqueeze::<3>(); //.repeat(0, 2); // generate 2 samples
|
||||
|
||||
|
||||
@@ -45,12 +45,12 @@ pub fn qkv_attention<B: Backend>(
|
||||
}
|
||||
|
||||
pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
|
||||
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length]);
|
||||
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length], device);
|
||||
|
||||
for i in 0..(seq_length - 1) {
|
||||
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)]).add_scalar(NEG_INFINITY);
|
||||
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)], device).add_scalar(NEG_INFINITY);
|
||||
mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
|
||||
}
|
||||
|
||||
return mask.to_device(device);
|
||||
return mask;
|
||||
}
|
||||
|
||||
@@ -71,7 +71,7 @@ fn load_padded_conv2d<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<PaddedConv2d<B>, Box<dyn Error>> {
|
||||
let conv = load_conv2d(&format!("{}/{}", path, "conv"), device)?;
|
||||
let mut conv = load_conv2d(&format!("{}/{}", path, "conv"), device)?;
|
||||
|
||||
let channels = load_tensor::<B, 1>("channels", path, device)?;
|
||||
let channels = tensor_to_array_2(channels);
|
||||
@@ -81,18 +81,21 @@ fn load_padded_conv2d<B: Backend>(
|
||||
|
||||
let padding = load_tensor::<B, 1>("padding", path, device)?;
|
||||
let padding: [usize; 4] = tensor_to_array(padding);
|
||||
let padding = Padding::new(padding[0], padding[1], padding[2], padding[3]);
|
||||
let padding = PaddingCfg::new(padding[0], padding[1], padding[2], padding[3]);
|
||||
|
||||
let mut record = conv.into_record();
|
||||
//let mut record = conv.into_record();
|
||||
|
||||
let mut padded_conv: PaddedConv2d<B> = PaddedConv2dConfig::new(channels, kernel_size, padding)
|
||||
.with_stride(stride)
|
||||
.init();
|
||||
.init(device);
|
||||
let padding_actual =
|
||||
PaddingConfig2d::Explicit(padded_conv.padding_actual[0], padded_conv.padding_actual[1]);
|
||||
|
||||
record.padding = <PaddingConfig2d as Module<B>>::into_record(padding_actual);
|
||||
padded_conv.conv = padded_conv.conv.load_record(record);
|
||||
conv.padding = burn::module::Ignored(padding_actual);
|
||||
padded_conv.conv = conv;
|
||||
|
||||
//record.padding = <PaddingConfig2d as Module<B>>::into_record(padding_actual);
|
||||
//padded_conv.conv = padded_conv.conv.load_record(record);
|
||||
|
||||
Ok(padded_conv)
|
||||
}
|
||||
|
||||
@@ -18,7 +18,8 @@ use burn::{
|
||||
|
||||
use super::groupnorm::*;
|
||||
use super::silu::*;
|
||||
use crate::backend::Backend as MyBackend;
|
||||
//use crate::backend::Backend as MyBackend;
|
||||
use crate::backend::{qkv_attention, attn_decoder_mask};
|
||||
|
||||
use std::iter;
|
||||
|
||||
@@ -26,13 +27,13 @@ use std::iter;
|
||||
pub struct AutoencoderConfig {}
|
||||
|
||||
impl AutoencoderConfig {
|
||||
pub fn init<B: Backend>(&self) -> Autoencoder<B> {
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> Autoencoder<B> {
|
||||
let encoder =
|
||||
EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init();
|
||||
EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init(device);
|
||||
let decoder =
|
||||
DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init();
|
||||
let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init();
|
||||
let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init();
|
||||
DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init(device);
|
||||
let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init(device);
|
||||
let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init(device);
|
||||
|
||||
Autoencoder {
|
||||
encoder,
|
||||
@@ -51,7 +52,7 @@ pub struct Autoencoder<B: Backend> {
|
||||
post_quant_conv: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: MyBackend> Autoencoder<B> {
|
||||
impl<B: Backend> Autoencoder<B> {
|
||||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
self.decode_latent(self.encode_image(x))
|
||||
}
|
||||
@@ -78,7 +79,7 @@ pub struct EncoderConfig {
|
||||
}
|
||||
|
||||
impl EncoderConfig {
|
||||
fn init<B: Backend>(&self) -> Encoder<B> {
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> Encoder<B> {
|
||||
let n_expanded_channels_initial = self
|
||||
.channels
|
||||
.first()
|
||||
@@ -88,7 +89,7 @@ impl EncoderConfig {
|
||||
|
||||
let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init();
|
||||
.init(device);
|
||||
|
||||
let blocks = self
|
||||
.channels
|
||||
@@ -96,16 +97,16 @@ impl EncoderConfig {
|
||||
.enumerate()
|
||||
.map(|(i, &(n_channel_in, n_channel_out))| {
|
||||
let downsample = i != self.channels.len() - 1;
|
||||
EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init()
|
||||
EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init(device)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mid = MidConfig::new(n_expanded_channels_final).init();
|
||||
let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init();
|
||||
let mid = MidConfig::new(n_expanded_channels_final).init(device);
|
||||
let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init(device);
|
||||
let silu = SILU::new();
|
||||
let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init();
|
||||
.init(device);
|
||||
|
||||
Encoder {
|
||||
conv_in,
|
||||
@@ -128,7 +129,7 @@ pub struct Encoder<B: Backend> {
|
||||
conv_out: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: MyBackend> Encoder<B> {
|
||||
impl<B: Backend> Encoder<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.conv_in.forward(x);
|
||||
|
||||
@@ -150,7 +151,7 @@ pub struct DecoderConfig {
|
||||
}
|
||||
|
||||
impl DecoderConfig {
|
||||
fn init<B: Backend>(&self) -> Decoder<B> {
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> Decoder<B> {
|
||||
let n_expanded_channels = self
|
||||
.channels
|
||||
.first()
|
||||
@@ -160,8 +161,8 @@ impl DecoderConfig {
|
||||
|
||||
let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init();
|
||||
let mid = MidConfig::new(n_expanded_channels).init();
|
||||
.init(device);
|
||||
let mid = MidConfig::new(n_expanded_channels).init(device);
|
||||
|
||||
let blocks = self
|
||||
.channels
|
||||
@@ -169,15 +170,15 @@ impl DecoderConfig {
|
||||
.enumerate()
|
||||
.map(|(i, &(n_channel_in, n_channel_out))| {
|
||||
let upsample = i != self.channels.len() - 1;
|
||||
DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init()
|
||||
DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init(device)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init();
|
||||
let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init(device);
|
||||
let silu = SILU::new();
|
||||
let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init();
|
||||
.init(device);
|
||||
|
||||
Decoder {
|
||||
conv_in,
|
||||
@@ -200,7 +201,7 @@ pub struct Decoder<B: Backend> {
|
||||
conv_out: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: MyBackend> Decoder<B> {
|
||||
impl<B: Backend> Decoder<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.conv_in.forward(x);
|
||||
let x = self.mid.forward(x);
|
||||
@@ -223,15 +224,15 @@ pub struct EncoderBlockConfig {
|
||||
}
|
||||
|
||||
impl EncoderBlockConfig {
|
||||
fn init<B: Backend>(&self) -> EncoderBlock<B> {
|
||||
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init();
|
||||
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> EncoderBlock<B> {
|
||||
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(device);
|
||||
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
|
||||
let downsampler = if self.downsample {
|
||||
let padding = Padding::new(0, 1, 0, 1);
|
||||
let padding = PaddingCfg::new(0, 1, 0, 1);
|
||||
Some(
|
||||
PaddedConv2dConfig::new([self.n_channels_out, self.n_channels_out], 3, padding)
|
||||
.with_stride(2)
|
||||
.init(),
|
||||
.init(device),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
@@ -272,15 +273,15 @@ pub struct DecoderBlockConfig {
|
||||
}
|
||||
|
||||
impl DecoderBlockConfig {
|
||||
fn init<B: Backend>(&self) -> DecoderBlock<B> {
|
||||
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init();
|
||||
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
|
||||
let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> DecoderBlock<B> {
|
||||
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(device);
|
||||
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
|
||||
let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
|
||||
let upsampler = if self.upsample {
|
||||
Some(
|
||||
Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init(),
|
||||
.init(device),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
@@ -313,8 +314,7 @@ impl<B: Backend> DecoderBlock<B> {
|
||||
let [n_batch, n_channel, height, width] = x.dims();
|
||||
let x = x
|
||||
.reshape([n_batch, n_channel, height, 1, width, 1])
|
||||
.repeat(3, 2)
|
||||
.repeat(5, 2)
|
||||
.repeat(&[1, 1, 1, 2, 1, 2])
|
||||
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
|
||||
d.forward(x)
|
||||
} else {
|
||||
@@ -329,11 +329,11 @@ pub struct PaddedConv2dConfig {
|
||||
kernel_size: usize,
|
||||
#[config(default = 1)]
|
||||
stride: usize,
|
||||
padding: Padding,
|
||||
padding: PaddingCfg,
|
||||
}
|
||||
|
||||
impl PaddedConv2dConfig {
|
||||
fn init<B: Backend>(&self) -> PaddedConv2d<B> {
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> PaddedConv2d<B> {
|
||||
let calc_padding = |p_left, p_right| {
|
||||
let n = if p_left >= p_right {
|
||||
0
|
||||
@@ -351,12 +351,17 @@ impl PaddedConv2dConfig {
|
||||
let conv = Conv2dConfig::new(self.channels, [self.kernel_size, self.kernel_size])
|
||||
.with_stride([self.stride, self.stride])
|
||||
.with_padding(PaddingConfig2d::Explicit(pad_vertical, pad_horizontal))
|
||||
.init();
|
||||
.init(device);
|
||||
|
||||
let kernel_size = self.kernel_size;
|
||||
let stride = self.stride;
|
||||
|
||||
let padding = self.padding;
|
||||
let padding = Padding {
|
||||
pad_left: self.padding.pad_left,
|
||||
pad_right: self.padding.pad_right,
|
||||
pad_top: self.padding.pad_top,
|
||||
pad_bottom: self.padding.pad_bottom,
|
||||
};
|
||||
|
||||
PaddedConv2d {
|
||||
conv,
|
||||
@@ -406,7 +411,15 @@ impl<B: Backend> PaddedConv2d<B> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config, Module, Copy, Debug)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct PaddingCfg {
|
||||
pad_left: usize,
|
||||
pad_right: usize,
|
||||
pad_top: usize,
|
||||
pad_bottom: usize,
|
||||
}
|
||||
|
||||
#[derive(Module, Clone, Debug)]
|
||||
pub struct Padding {
|
||||
pad_left: usize,
|
||||
pad_right: usize,
|
||||
@@ -420,10 +433,10 @@ pub struct MidConfig {
|
||||
}
|
||||
|
||||
impl MidConfig {
|
||||
fn init<B: Backend>(&self) -> Mid<B> {
|
||||
let block_1 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init();
|
||||
let attn = ConvSelfAttentionBlockConfig::new(self.n_channel).init();
|
||||
let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> Mid<B> {
|
||||
let block_1 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(device);
|
||||
let attn = ConvSelfAttentionBlockConfig::new(self.n_channel).init(device);
|
||||
let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(device);
|
||||
|
||||
Mid {
|
||||
block_1,
|
||||
@@ -440,7 +453,7 @@ pub struct Mid<B: Backend> {
|
||||
block_2: ResnetBlock<B>,
|
||||
}
|
||||
|
||||
impl<B: MyBackend> Mid<B> {
|
||||
impl<B: Backend> Mid<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.block_1.forward(x);
|
||||
let x = self.attn.forward(x);
|
||||
@@ -456,17 +469,17 @@ pub struct ResnetBlockConfig {
|
||||
}
|
||||
|
||||
impl ResnetBlockConfig {
|
||||
fn init<B: Backend>(&self) -> ResnetBlock<B> {
|
||||
let norm1 = GroupNormConfig::new(32, self.in_channels).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> ResnetBlock<B> {
|
||||
let norm1 = GroupNormConfig::new(32, self.in_channels).init(device);
|
||||
let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init();
|
||||
let norm2 = GroupNormConfig::new(32, self.out_channels).init();
|
||||
.init(device);
|
||||
let norm2 = GroupNormConfig::new(32, self.out_channels).init(device);
|
||||
let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init();
|
||||
.init(device);
|
||||
let nin_shortcut = if self.in_channels != self.out_channels {
|
||||
Some(Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init())
|
||||
Some(Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init(device))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -520,12 +533,12 @@ pub struct ConvSelfAttentionBlockConfig {
|
||||
}
|
||||
|
||||
impl ConvSelfAttentionBlockConfig {
|
||||
fn init<B: Backend>(&self) -> ConvSelfAttentionBlock<B> {
|
||||
let norm = GroupNormConfig::new(32, self.n_channel).init();
|
||||
let q = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
|
||||
let k = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
|
||||
let v = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
|
||||
let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> ConvSelfAttentionBlock<B> {
|
||||
let norm = GroupNormConfig::new(32, self.n_channel).init(device);
|
||||
let q = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
|
||||
let k = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
|
||||
let v = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
|
||||
let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
|
||||
|
||||
ConvSelfAttentionBlock {
|
||||
norm,
|
||||
@@ -546,7 +559,7 @@ pub struct ConvSelfAttentionBlock<B: Backend> {
|
||||
proj_out: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: MyBackend> ConvSelfAttentionBlock<B> {
|
||||
impl<B: Backend> ConvSelfAttentionBlock<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let [n_batch, n_channel, height, width] = x.dims();
|
||||
|
||||
@@ -568,7 +581,7 @@ impl<B: MyBackend> ConvSelfAttentionBlock<B> {
|
||||
.reshape([n_batch, n_channel, height * width])
|
||||
.swap_dims(1, 2);
|
||||
|
||||
let wv = Tensor::from_primitive(B::qkv_attention(
|
||||
/*let wv = Tensor::from_primitive(B::qkv_attention(
|
||||
q.into_primitive(),
|
||||
k.into_primitive(),
|
||||
v.into_primitive(),
|
||||
@@ -576,6 +589,16 @@ impl<B: MyBackend> ConvSelfAttentionBlock<B> {
|
||||
1,
|
||||
))
|
||||
.swap_dims(1, 2)
|
||||
.reshape([n_batch, n_channel, height, width]);*/
|
||||
|
||||
let wv = qkv_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
None,
|
||||
1,
|
||||
)
|
||||
.swap_dims(1, 2)
|
||||
.reshape([n_batch, n_channel, height, width]);
|
||||
|
||||
let projected = self.proj_out.forward(wv);
|
||||
|
||||
@@ -68,7 +68,7 @@ pub fn load_residual_decoder_attention_block<B: Backend>(
|
||||
pub fn load_clip<B: Backend>(path: &str, device: &B::Device) -> Result<CLIP<B>, Box<dyn Error>> {
|
||||
let token_embedding = load_embedding(&format!("{}/{}", path, "token_embedding"), device)?;
|
||||
let position_embedding =
|
||||
load_tensor("weight", &format!("{}/position_embedding", path), device)?.into();
|
||||
Param::from_tensor(load_tensor("weight", &format!("{}/position_embedding", path), device)?);
|
||||
|
||||
let n_layer = load_usize::<B>("n_layer", path, device)?;
|
||||
let mut blocks = (0..n_layer)
|
||||
|
||||
@@ -12,7 +12,8 @@ use burn::{
|
||||
},
|
||||
};
|
||||
|
||||
use crate::backend::Backend as MyBackend;
|
||||
//use crate::backend::Backend as MyBackend;
|
||||
use crate::backend::{qkv_attention, attn_decoder_mask};
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct CLIPConfig {
|
||||
@@ -24,15 +25,15 @@ pub struct CLIPConfig {
|
||||
}
|
||||
|
||||
impl CLIPConfig {
|
||||
pub fn init<B: Backend>(&self) -> CLIP<B> {
|
||||
let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init();
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> CLIP<B> {
|
||||
let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init(device);
|
||||
let position_embedding =
|
||||
Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0)).into();
|
||||
Param::from_tensor(Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0), device));
|
||||
let blocks = (0..self.n_layer)
|
||||
.into_iter()
|
||||
.map(|_| ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head).init())
|
||||
.map(|_| ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head).init(device))
|
||||
.collect();
|
||||
let layer_norm = nn::LayerNormConfig::new(self.n_state).init();
|
||||
let layer_norm = nn::LayerNormConfig::new(self.n_state).init(device);
|
||||
|
||||
CLIP {
|
||||
token_embedding,
|
||||
@@ -51,11 +52,12 @@ pub struct CLIP<B: Backend> {
|
||||
layer_norm: nn::LayerNorm<B>,
|
||||
}
|
||||
|
||||
impl<B: MyBackend> CLIP<B> {
|
||||
impl<B: Backend> CLIP<B> {
|
||||
pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
|
||||
let [n_batch, seq_len] = x.dims();
|
||||
|
||||
let mask = Tensor::from_primitive(B::attn_decoder_mask(seq_len, &x.device()));
|
||||
//let mask = Tensor::from_primitive(B::attn_decoder_mask(seq_len, &x.device()));
|
||||
let mask = attn_decoder_mask(seq_len, &x.device());
|
||||
|
||||
let embedded = self.token_embedding.forward(x)
|
||||
+ self
|
||||
@@ -80,12 +82,12 @@ pub struct ResidualDecoderAttentionBlockConfig {
|
||||
}
|
||||
|
||||
impl ResidualDecoderAttentionBlockConfig {
|
||||
pub fn init<B: Backend>(&self) -> ResidualDecoderAttentionBlock<B> {
|
||||
let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init();
|
||||
let attn_ln = nn::LayerNormConfig::new(self.n_state).init();
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> ResidualDecoderAttentionBlock<B> {
|
||||
let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init(device);
|
||||
let attn_ln = nn::LayerNormConfig::new(self.n_state).init(device);
|
||||
|
||||
let mlp = MLPConfig::new(self.n_state, 4 * self.n_state).init();
|
||||
let mlp_ln = nn::LayerNormConfig::new(self.n_state).init();
|
||||
let mlp = MLPConfig::new(self.n_state, 4 * self.n_state).init(device);
|
||||
let mlp_ln = nn::LayerNormConfig::new(self.n_state).init(device);
|
||||
|
||||
ResidualDecoderAttentionBlock {
|
||||
attn,
|
||||
@@ -104,7 +106,7 @@ pub struct ResidualDecoderAttentionBlock<B: Backend> {
|
||||
mlp_ln: nn::LayerNorm<B>,
|
||||
}
|
||||
|
||||
impl<B: MyBackend> ResidualDecoderAttentionBlock<B> {
|
||||
impl<B: Backend> ResidualDecoderAttentionBlock<B> {
|
||||
fn forward(&self, x: Tensor<B, 3>, mask: Tensor<B, 2>) -> Tensor<B, 3> {
|
||||
let x = x.clone() + self.attn.forward(self.attn_ln.forward(x), Some(mask));
|
||||
let x = x.clone() + self.mlp.forward(self.mlp_ln.forward(x));
|
||||
@@ -119,7 +121,7 @@ pub struct MultiHeadSelfAttentionConfig {
|
||||
}
|
||||
|
||||
impl MultiHeadSelfAttentionConfig {
|
||||
fn init<B: Backend>(&self) -> MultiHeadSelfAttention<B> {
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadSelfAttention<B> {
|
||||
assert!(
|
||||
self.n_state % self.n_head == 0,
|
||||
"State size {} must be a multiple of head size {}",
|
||||
@@ -128,10 +130,10 @@ impl MultiHeadSelfAttentionConfig {
|
||||
);
|
||||
|
||||
let n_head = self.n_head;
|
||||
let query = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
||||
let key = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
||||
let value = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
||||
let out = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
||||
let query = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
|
||||
let key = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
|
||||
let value = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
|
||||
let out = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
|
||||
|
||||
MultiHeadSelfAttention {
|
||||
n_head,
|
||||
@@ -152,19 +154,27 @@ pub struct MultiHeadSelfAttention<B: Backend> {
|
||||
out: nn::Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: MyBackend> MultiHeadSelfAttention<B> {
|
||||
impl<B: Backend> MultiHeadSelfAttention<B> {
|
||||
pub fn forward(&self, x: Tensor<B, 3>, mask: Option<Tensor<B, 2>>) -> Tensor<B, 3> {
|
||||
let q = self.query.forward(x.clone());
|
||||
let k = self.key.forward(x.clone());
|
||||
let v = self.value.forward(x);
|
||||
|
||||
let wv = Tensor::from_primitive(B::qkv_attention(
|
||||
/*let wv = Tensor::from_primitive(B::qkv_attention(
|
||||
q.into_primitive(),
|
||||
k.into_primitive(),
|
||||
v.into_primitive(),
|
||||
mask.map(|m| m.into_primitive()),
|
||||
self.n_head,
|
||||
));
|
||||
));*/
|
||||
|
||||
let wv = qkv_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
mask,
|
||||
self.n_head,
|
||||
);
|
||||
|
||||
return self.out.forward(wv);
|
||||
}
|
||||
@@ -177,10 +187,10 @@ pub struct MLPConfig {
|
||||
}
|
||||
|
||||
impl MLPConfig {
|
||||
fn init<B: Backend>(&self) -> MLP<B> {
|
||||
let fc1 = nn::LinearConfig::new(self.input_size, self.hidden_size).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> MLP<B> {
|
||||
let fc1 = nn::LinearConfig::new(self.input_size, self.hidden_size).init(device);
|
||||
let gelu = QuickGELU::new();
|
||||
let fc2 = nn::LinearConfig::new(self.hidden_size, self.input_size).init();
|
||||
let fc2 = nn::LinearConfig::new(self.hidden_size, self.input_size).init(device);
|
||||
|
||||
MLP { fc1, gelu, fc2 }
|
||||
}
|
||||
|
||||
@@ -18,14 +18,14 @@ pub fn load_group_norm<B: Backend>(
|
||||
let n_channel = load_usize::<B>("n_channel", path, device)?.into();
|
||||
let eps = load_f32::<B>("eps", path, device)?.into();
|
||||
|
||||
let gamma = load_tensor::<B, 1>("weight", path, device)
|
||||
let gamma = Param::from_tensor(load_tensor::<B, 1>("weight", path, device)
|
||||
.ok()
|
||||
.unwrap_or_else(|| Tensor::ones_device([n_channel], device))
|
||||
.into();
|
||||
let beta = load_tensor::<B, 1>("bias", path, device)
|
||||
.unwrap_or_else(|| Tensor::ones([n_channel], device))
|
||||
);
|
||||
let beta = Param::from_tensor(load_tensor::<B, 1>("bias", path, device)
|
||||
.ok()
|
||||
.unwrap_or_else(|| Tensor::zeros_device([n_channel], device))
|
||||
.into();
|
||||
.unwrap_or_else(|| Tensor::zeros([n_channel], device))
|
||||
);
|
||||
|
||||
Ok(GroupNorm {
|
||||
n_group,
|
||||
|
||||
@@ -15,7 +15,7 @@ pub struct GroupNormConfig {
|
||||
}
|
||||
|
||||
impl GroupNormConfig {
|
||||
pub fn init<B: Backend>(&self) -> GroupNorm<B> {
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> GroupNorm<B> {
|
||||
assert!(
|
||||
self.n_channel % self.n_group == 0,
|
||||
"The number of channels {} must be divisible by the number of groups {}",
|
||||
@@ -25,8 +25,8 @@ impl GroupNormConfig {
|
||||
|
||||
let n_per_group = self.n_channel / self.n_group;
|
||||
|
||||
let gamma = Tensor::ones([self.n_channel]).into();
|
||||
let beta = Tensor::zeros([self.n_channel]).into();
|
||||
let gamma = Param::from_tensor(Tensor::ones([self.n_channel], device));
|
||||
let beta = Param::from_tensor(Tensor::zeros([self.n_channel], device));
|
||||
|
||||
let eps = self.eps;
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use npy::{self, NpyData};
|
||||
use num_traits::cast::ToPrimitive;
|
||||
use burn::tensor::cast::ToElement;
|
||||
use burn::prelude::TensorData;
|
||||
use std::error::Error;
|
||||
use std::io::Read;
|
||||
|
||||
@@ -21,7 +23,8 @@ pub fn numpy_to_tensor<B: Backend, const D: usize>(
|
||||
let shape: Vec<_> = v[0..D].into_iter().map(|&v| v as usize).collect();
|
||||
let data: Vec<B::FloatElem> = v[D..].into_iter().map(|e| e.elem()).collect();
|
||||
|
||||
Tensor::from_data_device(Data::new(data, shape.into()), device)
|
||||
//Tensor::from_data_device(Data::new(data, shape.into()), device)
|
||||
Tensor::from_data(TensorData::new(data, shape), device)
|
||||
}
|
||||
|
||||
pub fn load_tensor<B: Backend, const D: usize>(
|
||||
@@ -48,7 +51,7 @@ pub fn load_f32<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<f32, Box<dyn Error>> {
|
||||
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_f32().unwrap())
|
||||
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_f32())
|
||||
}
|
||||
|
||||
pub fn load_usize<B: Backend>(
|
||||
@@ -56,7 +59,7 @@ pub fn load_usize<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<usize, Box<dyn Error>> {
|
||||
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_usize().unwrap())
|
||||
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_usize())
|
||||
}
|
||||
|
||||
pub fn load_linear<B: Backend>(
|
||||
@@ -66,13 +69,10 @@ pub fn load_linear<B: Backend>(
|
||||
let weight = load_tensor::<B, 2>("weight", path, device)?;
|
||||
let bias = load_tensor::<B, 1>("bias", path, device).ok();
|
||||
|
||||
let record = nn::LinearRecord {
|
||||
weight: weight.into(),
|
||||
bias: bias.map(|t| t.into()),
|
||||
};
|
||||
|
||||
let linear: nn::Linear<B> = nn::LinearConfig::new(3, 3).init_with(record);
|
||||
Ok(linear)
|
||||
Ok(nn::Linear {
|
||||
weight: Param::from_tensor(weight),
|
||||
bias: bias.map(|t| Param::from_tensor(t)),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn load_embedding<B: Backend>(
|
||||
@@ -80,14 +80,10 @@ pub fn load_embedding<B: Backend>(
|
||||
device: &B::Device,
|
||||
) -> Result<nn::Embedding<B>, Box<dyn Error>> {
|
||||
let weight = load_tensor::<B, 2>("weight", path, device)?;
|
||||
let [n_vocab, n_state] = weight.dims();
|
||||
|
||||
let record = nn::EmbeddingRecord {
|
||||
weight: weight.into(),
|
||||
};
|
||||
|
||||
let embedding = nn::EmbeddingConfig::new(n_vocab, n_state).init_with(record);
|
||||
Ok(embedding)
|
||||
Ok(nn::Embedding {
|
||||
weight: Param::from_tensor(weight),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn load_layer_norm<B: Backend>(
|
||||
@@ -100,13 +96,9 @@ pub fn load_layer_norm<B: Backend>(
|
||||
|
||||
let [n_state] = weight.dims();
|
||||
|
||||
let record = nn::LayerNormRecord {
|
||||
gamma: weight.into(),
|
||||
beta: bias.into(),
|
||||
epsilon: <f64 as Module<B>>::into_record(eps),
|
||||
};
|
||||
|
||||
let layer_norm: nn::LayerNorm<B> = nn::LayerNormConfig::new(n_state).init_with(record);
|
||||
let mut layer_norm = nn::LayerNormConfig::new(n_state).with_epsilon(eps).init(device);
|
||||
layer_norm.gamma = Param::from_tensor(weight);
|
||||
layer_norm.beta = Param::from_tensor(bias);
|
||||
|
||||
Ok(layer_norm)
|
||||
}
|
||||
@@ -116,7 +108,7 @@ pub fn load_layer_norm<B: Backend>(
|
||||
let eps = load_f32::<B>("eps", path, device)?.into();
|
||||
|
||||
let rmsnorm = RMSNorm {
|
||||
weight: weight.into(),
|
||||
weight: Param::from_tensor(weight),
|
||||
eps: eps
|
||||
};
|
||||
|
||||
@@ -148,40 +140,38 @@ pub fn load_conv2d<B: Backend>(
|
||||
let padding = tensor_to_array_2(padding);
|
||||
let padding = nn::PaddingConfig2d::Explicit(padding[0], padding[1]);
|
||||
|
||||
let record = conv::Conv2dRecord {
|
||||
weight: weight.into(),
|
||||
bias: bias.map(|t| t.into()),
|
||||
stride: <[usize; 2] as Module<B>>::into_record(stride),
|
||||
kernel_size: <[usize; 2] as Module<B>>::into_record(kernel_size),
|
||||
dilation: <[usize; 2] as Module<B>>::into_record(dilation),
|
||||
groups: <usize as Module<B>>::into_record(n_group),
|
||||
padding: <nn::PaddingConfig2d as Module<B>>::into_record(padding.clone()),
|
||||
};
|
||||
|
||||
let conv2d: conv::Conv2d<B> =
|
||||
conv::Conv2dConfig::new([n_channels_in, n_channels_out], kernel_size)
|
||||
let mut conv2d = conv::Conv2dConfig::new([n_channels_in, n_channels_out], kernel_size)
|
||||
.with_stride(stride)
|
||||
.with_dilation(dilation)
|
||||
.with_groups(n_group)
|
||||
.with_padding(padding)
|
||||
.with_padding(padding.clone())
|
||||
.with_bias(has_bias)
|
||||
.init_with(record);
|
||||
.init(device);
|
||||
|
||||
conv2d.weight = Param::from_tensor(weight);
|
||||
conv2d.bias = bias.map(|t| Param::from_tensor(t));
|
||||
conv2d.stride = stride;
|
||||
conv2d.kernel_size = kernel_size;
|
||||
conv2d.dilation = dilation;
|
||||
conv2d.groups = n_group;
|
||||
conv2d.padding = burn::module::Ignored(padding);
|
||||
|
||||
Ok(conv2d)
|
||||
}
|
||||
|
||||
pub fn tensor_to_array_2<B: Backend>(x: Tensor<B, 1>) -> [usize; 2] {
|
||||
let vec = x.into_data().value;
|
||||
let vec: Vec<<B as Backend>::FloatElem> = x.into_data().to_vec().unwrap();
|
||||
assert!(vec.len() == 2, "Tensor length must be 2.");
|
||||
[vec[0].to_usize().unwrap(), vec[1].to_usize().unwrap()]
|
||||
[vec[0].to_usize(), vec[1].to_usize()]
|
||||
}
|
||||
|
||||
pub fn tensor_to_array<const N: usize, B: Backend>(x: Tensor<B, 1>) -> [usize; N] {
|
||||
let vec = x.into_data().value;
|
||||
let vec: Vec<<B as Backend>::FloatElem> = x.into_data().to_vec().unwrap();
|
||||
assert!(vec.len() == N, "Tensor length must be {}.", N);
|
||||
|
||||
let mut arr = [0; N];
|
||||
for (a, t) in arr.iter_mut().zip(vec) {
|
||||
*a = t.to_usize().unwrap();
|
||||
*a = t.to_usize();
|
||||
}
|
||||
|
||||
arr
|
||||
|
||||
@@ -18,7 +18,7 @@ pub fn load_stable_diffusion<B: Backend>(
|
||||
device: &B::Device,
|
||||
) -> Result<StableDiffusion<B>, Box<dyn Error>> {
|
||||
let n_steps = load_usize::<B>("n_steps", path, device)?;
|
||||
let alpha_cumulative_products = load_tensor::<B, 1>("alphas_cumprod", path, device)?.into();
|
||||
let alpha_cumulative_products = Param::from_tensor(load_tensor::<B, 1>("alphas_cumprod", path, device)?);
|
||||
let autoencoder = load_autoencoder(&format!("{}/{}", path, "autoencoder"), device)?;
|
||||
let diffusion = load_unet(&format!("{}/{}", path, "unet"), device)?;
|
||||
let clip = load_clip(&format!("{}/{}", path, "clip"), device)?;
|
||||
|
||||
@@ -4,11 +4,12 @@ use burn::{
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
tensor::{backend::Backend, BasicOps, Data, Distribution, Float, Int, Tensor},
|
||||
tensor::cast::ToElement,
|
||||
};
|
||||
|
||||
use num_traits::ToPrimitive;
|
||||
|
||||
use crate::backend::Backend as MyBackend;
|
||||
//use crate::backend::Backend as MyBackend;
|
||||
|
||||
use super::autoencoder::{Autoencoder, AutoencoderConfig};
|
||||
use super::clip::{CLIPConfig, CLIP};
|
||||
@@ -19,13 +20,13 @@ use crate::tokenizer::SimpleTokenizer;
|
||||
pub struct StableDiffusionConfig {}
|
||||
|
||||
impl StableDiffusionConfig {
|
||||
pub fn init<B: Backend>(&self) -> StableDiffusion<B> {
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> StableDiffusion<B> {
|
||||
let n_steps = 1000;
|
||||
let alpha_cumulative_products = offset_cosine_schedule_cumprod::<B>(n_steps).into();
|
||||
let alpha_cumulative_products = Param::from_tensor(offset_cosine_schedule_cumprod::<B>(n_steps as i64, device));
|
||||
|
||||
let autoencoder = AutoencoderConfig::new().init();
|
||||
let diffusion = UNetConfig::new().init();
|
||||
let clip = CLIPConfig::new(49408, 768, 12, 77, 12).init();
|
||||
let autoencoder = AutoencoderConfig::new().init(device);
|
||||
let diffusion = UNetConfig::new().init(device);
|
||||
let clip = CLIPConfig::new(49408, 768, 12, 77, 12).init(device);
|
||||
|
||||
StableDiffusion {
|
||||
n_steps,
|
||||
@@ -46,7 +47,7 @@ pub struct StableDiffusion<B: Backend> {
|
||||
clip: CLIP<B>,
|
||||
}
|
||||
|
||||
impl<B: MyBackend> StableDiffusion<B> {
|
||||
impl<B: Backend> StableDiffusion<B> {
|
||||
pub fn sample_image(
|
||||
&self,
|
||||
context: Tensor<B, 3>,
|
||||
@@ -82,7 +83,7 @@ impl<B: MyBackend> StableDiffusion<B> {
|
||||
.swap_dims(2, 3)
|
||||
.mul_scalar(255.0);
|
||||
|
||||
let flattened: Vec<_> = image.into_data().value;
|
||||
let flattened: Vec<B::FloatElem> = image.into_data().to_vec().unwrap();
|
||||
|
||||
(0..n_batch)
|
||||
.into_iter()
|
||||
@@ -92,7 +93,7 @@ impl<B: MyBackend> StableDiffusion<B> {
|
||||
|
||||
flattened[start..end]
|
||||
.into_iter()
|
||||
.map(|v| v.to_f64().unwrap().min(255.0).max(0.0).to_u8().unwrap())
|
||||
.map(|v| v.to_f64().min(255.0).max(0.0) as u8)
|
||||
.collect()
|
||||
})
|
||||
.collect()
|
||||
@@ -112,8 +113,7 @@ impl<B: MyBackend> StableDiffusion<B> {
|
||||
let [n_batches, _, _] = context.dims();
|
||||
|
||||
let gen_noise = || {
|
||||
Tensor::random([n_batches, 4, 64, 64], Distribution::Normal(0.0, 1.0))
|
||||
.to_device(&device)
|
||||
Tensor::random([n_batches, 4, 64, 64], Distribution::Normal(0.0, 1.0), &device)
|
||||
};
|
||||
|
||||
let sigma = 0.0; // Use deterministic diffusion
|
||||
@@ -126,8 +126,8 @@ impl<B: MyBackend> StableDiffusion<B> {
|
||||
.val()
|
||||
.slice([t..t + 1])
|
||||
.into_scalar()
|
||||
.to_f64()
|
||||
.unwrap();
|
||||
.to_f64();
|
||||
|
||||
let prev_alpha: f64 = if t >= step_size {
|
||||
let i = t - step_size;
|
||||
self.alpha_cumulative_products
|
||||
@@ -135,14 +135,13 @@ impl<B: MyBackend> StableDiffusion<B> {
|
||||
.slice([i..i + 1])
|
||||
.into_scalar()
|
||||
.to_f64()
|
||||
.unwrap()
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
|
||||
let sqrt_noise = (1.0 - current_alpha).sqrt();
|
||||
|
||||
let timestep = Tensor::from_ints([t as i32]).to_device(&device);
|
||||
let timestep = Tensor::from_ints([t as i32], &device);
|
||||
let pred_noise = self.forward_diffuser(
|
||||
latent.clone(),
|
||||
timestep,
|
||||
@@ -174,7 +173,7 @@ impl<B: MyBackend> StableDiffusion<B> {
|
||||
let unconditional_latent = self.diffusion.forward(
|
||||
latent.clone(),
|
||||
timestep.clone(),
|
||||
unconditional_context.unsqueeze().repeat(0, n_batch),
|
||||
unconditional_context.unsqueeze().repeat(&[0, n_batch]),
|
||||
);
|
||||
|
||||
let conditional_latent = self.diffusion.forward(latent, timestep, context);
|
||||
@@ -206,8 +205,7 @@ impl<B: MyBackend> StableDiffusion<B> {
|
||||
.collect();
|
||||
|
||||
self.clip.forward(
|
||||
Tensor::from_ints(&tokenized[..])
|
||||
.to_device(device)
|
||||
Tensor::<B, 1, Int>::from_ints(&tokenized[..], device)
|
||||
.unsqueeze(),
|
||||
)
|
||||
}
|
||||
@@ -215,25 +213,25 @@ impl<B: MyBackend> StableDiffusion<B> {
|
||||
|
||||
use std::f64::consts::PI;
|
||||
|
||||
fn cosine_schedule<B: Backend>(n_steps: usize) -> Tensor<B, 1> {
|
||||
Tensor::arange(1..n_steps + 1)
|
||||
fn cosine_schedule<B: Backend>(n_steps: i64, device: &B::Device) -> Tensor<B, 1> {
|
||||
Tensor::arange(1..n_steps + 1, device)
|
||||
.float()
|
||||
.mul_scalar(PI * 0.5 / n_steps as f64)
|
||||
.cos()
|
||||
}
|
||||
|
||||
fn offset_cosine_schedule<B: Backend>(n_steps: usize) -> Tensor<B, 1> {
|
||||
fn offset_cosine_schedule<B: Backend>(n_steps: i64, device: &B::Device) -> Tensor<B, 1> {
|
||||
let min_signal_rate: f64 = 0.02;
|
||||
let max_signal_rate: f64 = 0.95;
|
||||
let start_angle = max_signal_rate.acos();
|
||||
let end_angle = min_signal_rate.acos();
|
||||
|
||||
let times = Tensor::arange(1..n_steps + 1).float();
|
||||
let times = Tensor::arange(1..n_steps + 1, device).float();
|
||||
|
||||
let diffusion_angles = times * ((end_angle - start_angle) / n_steps as f64) + start_angle;
|
||||
diffusion_angles.cos()
|
||||
}
|
||||
|
||||
fn offset_cosine_schedule_cumprod<B: Backend>(n_steps: usize) -> Tensor<B, 1> {
|
||||
offset_cosine_schedule::<B>(n_steps).powf(2.0)
|
||||
fn offset_cosine_schedule_cumprod<B: Backend>(n_steps: i64, device: &B::Device) -> Tensor<B, 1> {
|
||||
offset_cosine_schedule::<B>(n_steps, device).powf_scalar(2.0)
|
||||
}
|
||||
|
||||
@@ -65,7 +65,7 @@ pub fn load_geglu<B: Backend>(path: &str, device: &B::Device) -> Result<GEGLU<B>
|
||||
|
||||
let geglue = GEGLU {
|
||||
proj: proj,
|
||||
gelu: GELU::new(), // Assuming GELU::new() initializes a new GELU struct
|
||||
gelu: Gelu::new(), // Assuming Gelu::new() initializes a new Gelu struct
|
||||
};
|
||||
|
||||
Ok(geglue)
|
||||
|
||||
@@ -6,7 +6,7 @@ use burn::{
|
||||
nn::{
|
||||
self,
|
||||
conv::{Conv2d, Conv2dConfig},
|
||||
PaddingConfig2d, GELU,
|
||||
PaddingConfig2d, Gelu,
|
||||
},
|
||||
tensor::{activation::softmax, backend::Backend, module::embedding, Distribution, Int, Tensor},
|
||||
};
|
||||
@@ -22,7 +22,7 @@ fn timestep_embedding<B: Backend>(
|
||||
max_period: usize,
|
||||
) -> Tensor<B, 2> {
|
||||
let half = dim / 2;
|
||||
let freqs = (Tensor::arange_device(0..half, &timesteps.device()).float()
|
||||
let freqs = (Tensor::arange(0..half as i64, &timesteps.device()).float()
|
||||
* (-(max_period as f64).ln() / half as f64))
|
||||
.exp();
|
||||
let args = timesteps.float() * freqs;
|
||||
@@ -33,50 +33,50 @@ fn timestep_embedding<B: Backend>(
|
||||
pub struct UNetConfig {}
|
||||
|
||||
impl UNetConfig {
|
||||
pub fn init<B: Backend>(&self) -> UNet<B> {
|
||||
let lin1_time_embed = nn::LinearConfig::new(320, 1280).init();
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> UNet<B> {
|
||||
let lin1_time_embed = nn::LinearConfig::new(320, 1280).init(device);
|
||||
let silu_time_embed = SILU::new();
|
||||
let lin2_time_embed = nn::LinearConfig::new(1280, 1280).init();
|
||||
let lin2_time_embed = nn::LinearConfig::new(1280, 1280).init(device);
|
||||
|
||||
let input_blocks = UNetInputBlocks {
|
||||
conv: Conv2dConfig::new([4, 320], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init(),
|
||||
rt1: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(),
|
||||
rt2: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(),
|
||||
d1: DownsampleConfig::new(320).init(),
|
||||
rt3: ResTransformerConfig::new(320, 1280, 640, 768, 8).init(),
|
||||
rt4: ResTransformerConfig::new(640, 1280, 640, 768, 8).init(),
|
||||
d2: DownsampleConfig::new(640).init(),
|
||||
rt5: ResTransformerConfig::new(640, 1280, 1280, 768, 8).init(),
|
||||
rt6: ResTransformerConfig::new(1280, 1280, 1280, 768, 8).init(),
|
||||
d3: DownsampleConfig::new(1280).init(),
|
||||
r1: ResBlockConfig::new(1280, 1280, 1280).init(),
|
||||
r2: ResBlockConfig::new(1280, 1280, 1280).init(),
|
||||
.init(device),
|
||||
rt1: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(device),
|
||||
rt2: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(device),
|
||||
d1: DownsampleConfig::new(320).init(device),
|
||||
rt3: ResTransformerConfig::new(320, 1280, 640, 768, 8).init(device),
|
||||
rt4: ResTransformerConfig::new(640, 1280, 640, 768, 8).init(device),
|
||||
d2: DownsampleConfig::new(640).init(device),
|
||||
rt5: ResTransformerConfig::new(640, 1280, 1280, 768, 8).init(device),
|
||||
rt6: ResTransformerConfig::new(1280, 1280, 1280, 768, 8).init(device),
|
||||
d3: DownsampleConfig::new(1280).init(device),
|
||||
r1: ResBlockConfig::new(1280, 1280, 1280).init(device),
|
||||
r2: ResBlockConfig::new(1280, 1280, 1280).init(device),
|
||||
};
|
||||
|
||||
let middle_block = ResTransformerResConfig::new(1280, 1280, 1280, 768, 8).init();
|
||||
let middle_block = ResTransformerResConfig::new(1280, 1280, 1280, 768, 8).init(device);
|
||||
|
||||
let output_blocks = UNetOutputBlocks {
|
||||
r1: ResBlockConfig::new(2560, 1280, 1280).init(),
|
||||
r2: ResBlockConfig::new(2560, 1280, 1280).init(),
|
||||
ru: ResUpSampleConfig::new(2560, 1280, 1280).init(),
|
||||
rt1: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(),
|
||||
rt2: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(),
|
||||
rtu1: ResTransformerUpsampleConfig::new(1920, 1280, 1280, 768, 8).init(),
|
||||
rt3: ResTransformerConfig::new(1920, 1280, 640, 768, 8).init(),
|
||||
rt4: ResTransformerConfig::new(1280, 1280, 640, 768, 8).init(),
|
||||
rtu2: ResTransformerUpsampleConfig::new(960, 1280, 640, 768, 8).init(),
|
||||
rt5: ResTransformerConfig::new(960, 1280, 320, 768, 8).init(),
|
||||
rt6: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(),
|
||||
rt7: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(),
|
||||
r1: ResBlockConfig::new(2560, 1280, 1280).init(device),
|
||||
r2: ResBlockConfig::new(2560, 1280, 1280).init(device),
|
||||
ru: ResUpSampleConfig::new(2560, 1280, 1280).init(device),
|
||||
rt1: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(device),
|
||||
rt2: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(device),
|
||||
rtu1: ResTransformerUpsampleConfig::new(1920, 1280, 1280, 768, 8).init(device),
|
||||
rt3: ResTransformerConfig::new(1920, 1280, 640, 768, 8).init(device),
|
||||
rt4: ResTransformerConfig::new(1280, 1280, 640, 768, 8).init(device),
|
||||
rtu2: ResTransformerUpsampleConfig::new(960, 1280, 640, 768, 8).init(device),
|
||||
rt5: ResTransformerConfig::new(960, 1280, 320, 768, 8).init(device),
|
||||
rt6: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(device),
|
||||
rt7: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(device),
|
||||
};
|
||||
|
||||
let norm_out = GroupNormConfig::new(32, 320).init();
|
||||
let norm_out = GroupNormConfig::new(32, 320).init(device);
|
||||
let silu_out = SILU::new();
|
||||
let conv_out = Conv2dConfig::new([320, 4], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init();
|
||||
.init(device);
|
||||
|
||||
UNet {
|
||||
lin1_time_embed,
|
||||
@@ -206,16 +206,16 @@ pub struct ResTransformerConfig {
|
||||
}
|
||||
|
||||
impl ResTransformerConfig {
|
||||
fn init<B: Backend>(&self) -> ResTransformer<B> {
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> ResTransformer<B> {
|
||||
let res = ResBlockConfig::new(
|
||||
self.n_channels_in,
|
||||
self.n_channels_embed,
|
||||
self.n_channels_out,
|
||||
)
|
||||
.init();
|
||||
.init(device);
|
||||
let transformer =
|
||||
SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head)
|
||||
.init();
|
||||
.init(device);
|
||||
|
||||
ResTransformer { res, transformer }
|
||||
}
|
||||
@@ -243,14 +243,14 @@ pub struct ResUpSampleConfig {
|
||||
}
|
||||
|
||||
impl ResUpSampleConfig {
|
||||
fn init<B: Backend>(&self) -> ResUpSample<B> {
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> ResUpSample<B> {
|
||||
let res = ResBlockConfig::new(
|
||||
self.n_channels_in,
|
||||
self.n_channels_embed,
|
||||
self.n_channels_out,
|
||||
)
|
||||
.init();
|
||||
let upsample = UpsampleConfig::new(self.n_channels_out).init();
|
||||
.init(device);
|
||||
let upsample = UpsampleConfig::new(self.n_channels_out).init(device);
|
||||
|
||||
ResUpSample { res, upsample }
|
||||
}
|
||||
@@ -280,17 +280,17 @@ pub struct ResTransformerUpsampleConfig {
|
||||
}
|
||||
|
||||
impl ResTransformerUpsampleConfig {
|
||||
fn init<B: Backend>(&self) -> ResTransformerUpsample<B> {
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> ResTransformerUpsample<B> {
|
||||
let res = ResBlockConfig::new(
|
||||
self.n_channels_in,
|
||||
self.n_channels_embed,
|
||||
self.n_channels_out,
|
||||
)
|
||||
.init();
|
||||
.init(device);
|
||||
let transformer =
|
||||
SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head)
|
||||
.init();
|
||||
let upsample = UpsampleConfig::new(self.n_channels_out).init();
|
||||
.init(device);
|
||||
let upsample = UpsampleConfig::new(self.n_channels_out).init(device);
|
||||
|
||||
ResTransformerUpsample {
|
||||
res,
|
||||
@@ -326,22 +326,22 @@ pub struct ResTransformerResConfig {
|
||||
}
|
||||
|
||||
impl ResTransformerResConfig {
|
||||
fn init<B: Backend>(&self) -> ResTransformerRes<B> {
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> ResTransformerRes<B> {
|
||||
let res1 = ResBlockConfig::new(
|
||||
self.n_channels_in,
|
||||
self.n_channels_embed,
|
||||
self.n_channels_out,
|
||||
)
|
||||
.init();
|
||||
.init(device);
|
||||
let transformer =
|
||||
SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head)
|
||||
.init();
|
||||
.init(device);
|
||||
let res2 = ResBlockConfig::new(
|
||||
self.n_channels_in,
|
||||
self.n_channels_embed,
|
||||
self.n_channels_out,
|
||||
)
|
||||
.init();
|
||||
.init(device);
|
||||
|
||||
ResTransformerRes {
|
||||
res1,
|
||||
@@ -373,10 +373,10 @@ pub struct UpsampleConfig {
|
||||
}
|
||||
|
||||
impl UpsampleConfig {
|
||||
fn init<B: Backend>(&self) -> Upsample<B> {
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> Upsample<B> {
|
||||
let conv = Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init();
|
||||
.init(device);
|
||||
|
||||
Upsample { conv }
|
||||
}
|
||||
@@ -392,8 +392,7 @@ impl<B: Backend> Upsample<B> {
|
||||
let [n_batch, n_channel, height, width] = x.dims();
|
||||
let x = x
|
||||
.reshape([n_batch, n_channel, height, 1, width, 1])
|
||||
.repeat(3, 2)
|
||||
.repeat(5, 2)
|
||||
.repeat(&[1, 1, 1, 2, 1, 2])
|
||||
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
|
||||
self.conv.forward(x)
|
||||
}
|
||||
@@ -411,11 +410,11 @@ pub struct DownsampleConfig {
|
||||
}
|
||||
|
||||
impl DownsampleConfig {
|
||||
fn init<B: Backend>(&self) -> Conv2d<B> {
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> Conv2d<B> {
|
||||
Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3])
|
||||
.with_stride([2, 2])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init()
|
||||
.init(device)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -435,12 +434,12 @@ pub struct SpatialTransformerConfig {
|
||||
}
|
||||
|
||||
impl SpatialTransformerConfig {
|
||||
fn init<B: Backend>(&self) -> SpatialTransformer<B> {
|
||||
let norm = GroupNormConfig::new(32, self.n_channels).init();
|
||||
let proj_in = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> SpatialTransformer<B> {
|
||||
let norm = GroupNormConfig::new(32, self.n_channels).init(device);
|
||||
let proj_in = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(device);
|
||||
let transformer =
|
||||
TransformerBlockConfig::new(self.n_channels, self.n_context_state, self.n_head).init();
|
||||
let proj_out = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init();
|
||||
TransformerBlockConfig::new(self.n_channels, self.n_context_state, self.n_head).init(device);
|
||||
let proj_out = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(device);
|
||||
|
||||
SpatialTransformer {
|
||||
norm,
|
||||
@@ -489,14 +488,14 @@ pub struct TransformerBlockConfig {
|
||||
}
|
||||
|
||||
impl TransformerBlockConfig {
|
||||
fn init<B: Backend>(&self) -> TransformerBlock<B> {
|
||||
let norm1 = nn::LayerNormConfig::new(self.n_state).init();
|
||||
let attn1 = MultiHeadAttentionConfig::new(self.n_state, self.n_state, self.n_head).init();
|
||||
let norm2 = nn::LayerNormConfig::new(self.n_state).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> TransformerBlock<B> {
|
||||
let norm1 = nn::LayerNormConfig::new(self.n_state).init(device);
|
||||
let attn1 = MultiHeadAttentionConfig::new(self.n_state, self.n_state, self.n_head).init(device);
|
||||
let norm2 = nn::LayerNormConfig::new(self.n_state).init(device);
|
||||
let attn2 =
|
||||
MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init();
|
||||
let norm3 = nn::LayerNormConfig::new(self.n_state).init();
|
||||
let mlp = MLPConfig::new(self.n_state, 4).init();
|
||||
MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init(device);
|
||||
let norm3 = nn::LayerNormConfig::new(self.n_state).init(device);
|
||||
let mlp = MLPConfig::new(self.n_state, 4).init(device);
|
||||
|
||||
TransformerBlock {
|
||||
norm1,
|
||||
@@ -534,10 +533,10 @@ pub struct MLPConfig {
|
||||
}
|
||||
|
||||
impl MLPConfig {
|
||||
pub fn init<B: Backend>(&self) -> MLP<B> {
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> MLP<B> {
|
||||
let n_state_hidden = self.n_state * self.mult;
|
||||
let geglu = GEGLUConfig::new(self.n_state, n_state_hidden).init();
|
||||
let lin = nn::LinearConfig::new(n_state_hidden, self.n_state).init();
|
||||
let geglu = GEGLUConfig::new(self.n_state, n_state_hidden).init(device);
|
||||
let lin = nn::LinearConfig::new(n_state_hidden, self.n_state).init(device);
|
||||
|
||||
MLP { geglu, lin }
|
||||
}
|
||||
@@ -562,9 +561,9 @@ pub struct GEGLUConfig {
|
||||
}
|
||||
|
||||
impl GEGLUConfig {
|
||||
fn init<B: Backend>(&self) -> GEGLU<B> {
|
||||
let proj = nn::LinearConfig::new(self.n_state_in, 2 * self.n_state_out).init();
|
||||
let gelu = GELU::new();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> GEGLU<B> {
|
||||
let proj = nn::LinearConfig::new(self.n_state_in, 2 * self.n_state_out).init(device);
|
||||
let gelu = Gelu::new();
|
||||
|
||||
GEGLU { proj, gelu }
|
||||
}
|
||||
@@ -573,7 +572,7 @@ impl GEGLUConfig {
|
||||
#[derive(Module, Debug)]
|
||||
pub struct GEGLU<B: Backend> {
|
||||
proj: nn::Linear<B>,
|
||||
gelu: GELU,
|
||||
gelu: Gelu,
|
||||
}
|
||||
|
||||
impl<B: Backend> GEGLU<B> {
|
||||
@@ -600,7 +599,7 @@ pub struct MultiHeadAttentionConfig {
|
||||
}
|
||||
|
||||
impl MultiHeadAttentionConfig {
|
||||
fn init<B: Backend>(&self) -> MultiHeadAttention<B> {
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadAttention<B> {
|
||||
assert!(
|
||||
self.n_state % self.n_head == 0,
|
||||
"State size {} must be a multiple of head size {}",
|
||||
@@ -611,14 +610,14 @@ impl MultiHeadAttentionConfig {
|
||||
let n_head = self.n_head;
|
||||
let query = nn::LinearConfig::new(self.n_state, self.n_state)
|
||||
.with_bias(false)
|
||||
.init();
|
||||
.init(device);
|
||||
let key = nn::LinearConfig::new(self.n_context_state, self.n_state)
|
||||
.with_bias(false)
|
||||
.init();
|
||||
.init(device);
|
||||
let value = nn::LinearConfig::new(self.n_context_state, self.n_state)
|
||||
.with_bias(false)
|
||||
.init();
|
||||
let out = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
||||
.init(device);
|
||||
let out = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
|
||||
|
||||
MultiHeadAttention {
|
||||
n_head,
|
||||
@@ -661,24 +660,24 @@ pub struct ResBlockConfig {
|
||||
}
|
||||
|
||||
impl ResBlockConfig {
|
||||
fn init<B: Backend>(&self) -> ResBlock<B> {
|
||||
let norm_in = GroupNormConfig::new(32, self.n_channels_in).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> ResBlock<B> {
|
||||
let norm_in = GroupNormConfig::new(32, self.n_channels_in).init(device);
|
||||
let silu_in = SILU::new();
|
||||
let conv_in = Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init();
|
||||
.init(device);
|
||||
|
||||
let silu_embed = SILU::new();
|
||||
let lin_embed = nn::LinearConfig::new(self.n_channels_embed, self.n_channels_out).init();
|
||||
let lin_embed = nn::LinearConfig::new(self.n_channels_embed, self.n_channels_out).init(device);
|
||||
|
||||
let norm_out = GroupNormConfig::new(32, self.n_channels_out).init();
|
||||
let norm_out = GroupNormConfig::new(32, self.n_channels_out).init(device);
|
||||
let silu_out = SILU::new();
|
||||
let conv_out = Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init();
|
||||
.init(device);
|
||||
|
||||
let skip_connection = if self.n_channels_in != self.n_channels_out {
|
||||
Some(Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [1, 1]).init())
|
||||
Some(Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [1, 1]).init(device))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user