Compare commits
10 Commits
c24d37df00
...
test
| Author | SHA1 | Date | |
|---|---|---|---|
| 754810ca88 | |||
|
|
6cfd6db5a5 | ||
|
|
893fb0950d | ||
|
|
9e4d7bd310 | ||
|
|
01b1aea897 | ||
|
|
f4c58c1790 | ||
|
|
a62795347f | ||
|
|
1830756917 | ||
|
|
b87273c2be | ||
|
|
31c24a82ef |
18
Cargo.toml
18
Cargo.toml
@@ -6,27 +6,17 @@ edition = "2021"
|
|||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["torch-backend"]
|
|
||||||
torch-backend = ["burn-tch"]
|
|
||||||
wgpu-backend = ["burn-wgpu"]
|
wgpu-backend = ["burn-wgpu"]
|
||||||
|
default = ["wgpu-backend"]
|
||||||
[dependencies.burn-tch]
|
|
||||||
package = "burn-tch"
|
|
||||||
git = "https://github.com/burn-rs/burn.git"
|
|
||||||
optional = true
|
|
||||||
|
|
||||||
[dependencies.burn-wgpu]
|
|
||||||
package = "burn-wgpu"
|
|
||||||
git = "https://github.com/burn-rs/burn.git"
|
|
||||||
optional = true
|
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
burn = { git = "https://github.com/burn-rs/burn.git" }
|
burn = "0.20.1"
|
||||||
|
burn-autodiff = "0.20.1"
|
||||||
|
burn-wgpu = { version = "0.20.1", optional = true }
|
||||||
serde = {version = "1.0.171", features = ["std", "derive"]}
|
serde = {version = "1.0.171", features = ["std", "derive"]}
|
||||||
npy = "0.4.0"
|
npy = "0.4.0"
|
||||||
num-traits = "0.2.15"
|
num-traits = "0.2.15"
|
||||||
rust_tokenizers = "8.1.0"
|
rust_tokenizers = "8.1.0"
|
||||||
regex = "1.9.1"
|
regex = "1.9.1"
|
||||||
image = "0.24.6"
|
image = "0.24.6"
|
||||||
bincode = {version = "2.0.0-alpha.0", features = ["std"]}
|
|
||||||
cfg-if = "0.1"
|
cfg-if = "0.1"
|
||||||
33
README.md
33
README.md
@@ -2,37 +2,29 @@
|
|||||||
|
|
||||||
Stable-Diffusion-Burn is a Rust-based project which ports the V1 stable diffusion model into the deep learning framework, Burn. This repository is licensed under the MIT Licence.
|
Stable-Diffusion-Burn is a Rust-based project which ports the V1 stable diffusion model into the deep learning framework, Burn. This repository is licensed under the MIT Licence.
|
||||||
|
|
||||||
## Support The Project
|
|
||||||
|
|
||||||
Stable-Diffusion-Burn is a passion project that is open and free to all. I want to empower everyone with reliable AI that can be run by ourselves on our own hardware to ensure that great AI is not limited to the hands of the few. If you support this vision consider supporting me so that I can continue on this journey and produce more projects such as Stable Diffusion XL in Rust.
|
|
||||||
|
|
||||||
You can show your support by buying a shirt at https://www.bonfire.com/machine-learning/. The shirt image was, of course, generated by my Rust powered Stable Diffusion! I'd love to release more projects and any support will help make that happen!
|
|
||||||
|
|
||||||
Any contribution would be greatly appreciated. Thanks!
|
|
||||||
|
|
||||||
## How To Use
|
## How To Use
|
||||||
|
|
||||||
### Step 1: Download the Model and Set Environment Variables
|
### Step 1: Download the Model and Set Environment Variables
|
||||||
|
|
||||||
Start by downloading the SDv1-4.bin model provided on HuggingFace.
|
Start by downloading the SDv1-4 model provided on HuggingFace.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
wget https://huggingface.co/Gadersd/Stable-Diffusion-Burn/resolve/main/V1/SDv1-4.bin
|
wget https://huggingface.co/Gadersd/Stable-Diffusion-Burn/resolve/main/SDv1-4.mpk
|
||||||
```
|
```
|
||||||
|
|
||||||
Next, set the appropriate CUDA version. It may be possible to run the model using wgpu without the need for torch in the future using `cargo run --features wgpu-backend...` but currently wgpu doesn't support buffer sizes large enough for Stable Diffusion.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
export TORCH_CUDA_VERSION=cu113
|
|
||||||
```
|
|
||||||
### Step 2: Run the Sample Binary
|
### Step 2: Run the Sample Binary
|
||||||
|
|
||||||
Invoke the sample binary provided in the rust code, as shown below:
|
Invoke the sample binary provided in the rust code. The application now uses a pure Rust backend (WGPU/Vulkan) instead of libtorch. The WGPU backend is unstable for SD but may work well in the future as burn-wpu is optimized.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Arguments: <model_type(burn or dump)> <model> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image>
|
# WGPU/Vulkan backend (GPU accelerated, requires Vulkan-compatible GPU)
|
||||||
|
# Arguments: <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name>
|
||||||
|
|
||||||
|
# GPU (Vulkan)
|
||||||
|
cargo run --release --features wgpu-backend --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img
|
||||||
|
|
||||||
|
# CPU (UNSTABLE - fallback if GPU not available)
|
||||||
cargo run --release --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img
|
cargo run --release --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img
|
||||||
```
|
|
||||||
|
|
||||||
This command will generate an image according to the provided prompt, which will be saved as 'img0.png'.
|
This command will generate an image according to the provided prompt, which will be saved as 'img0.png'.
|
||||||
|
|
||||||
@@ -40,7 +32,7 @@ This command will generate an image according to the provided prompt, which will
|
|||||||
|
|
||||||
### Optional: Extract and Convert a Fine-Tuned Model
|
### Optional: Extract and Convert a Fine-Tuned Model
|
||||||
|
|
||||||
If users are interested in using a fine-tuned version of stable diffusion, the Python scripts provided in this project can be used to transform a weight dump into a Burn model file.
|
If users are interested in using a fine-tuned version of stable diffusion, the Python scripts provided in this project can be used to transform a weight dump into a Burn model file. This does not work on Windows.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Step into the Python directory
|
# Step into the Python directory
|
||||||
@@ -49,6 +41,9 @@ cd python
|
|||||||
# Download the model, this is just the base v1.4 model as an example
|
# Download the model, this is just the base v1.4 model as an example
|
||||||
wget https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt
|
wget https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt
|
||||||
|
|
||||||
|
# Install tinygrad
|
||||||
|
pip install -r requirements.txt
|
||||||
|
|
||||||
# Extract the weights
|
# Extract the weights
|
||||||
CPU=1 python3 dump.py sd-v1-4.ckpt
|
CPU=1 python3 dump.py sd-v1-4.ckpt
|
||||||
|
|
||||||
|
|||||||
BIN
img0.png
BIN
img0.png
Binary file not shown.
|
Before Width: | Height: | Size: 671 KiB After Width: | Height: | Size: 677 KiB |
@@ -13,10 +13,11 @@ from collections import namedtuple
|
|||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad.helpers import dtypes, GlobalCounters
|
from tinygrad.helpers import GlobalCounters
|
||||||
|
from tinygrad import dtypes
|
||||||
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
|
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
|
||||||
from extra.utils import download_file
|
#from extra.utils import download_file
|
||||||
from tinygrad.state import torch_load, load_state_dict
|
from tinygrad.nn.state import torch_load, load_state_dict
|
||||||
|
|
||||||
# TODO: refactor AttnBlock, CrossAttention, CLIPAttention to share code
|
# TODO: refactor AttnBlock, CrossAttention, CLIPAttention to share code
|
||||||
|
|
||||||
|
|||||||
1
python/requirements.txt
Normal file
1
python/requirements.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
tinygrad==0.9.2
|
||||||
139
src/backend.rs
Normal file
139
src/backend.rs
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
use burn::tensor::{activation::softmax, Tensor};
|
||||||
|
use burn::prelude::Backend;
|
||||||
|
|
||||||
|
/*pub type FloatTensor<B, const D: usize> = <B as burn::tensor::backend::Backend>::TensorPrimitive<D>;
|
||||||
|
|
||||||
|
pub trait Backend: burn::tensor::backend::Backend {
|
||||||
|
fn qkv_attention(
|
||||||
|
q: FloatTensor<Self, 3>,
|
||||||
|
k: FloatTensor<Self, 3>,
|
||||||
|
v: FloatTensor<Self, 3>,
|
||||||
|
mask: Option<FloatTensor<Self, 2>>,
|
||||||
|
n_head: usize,
|
||||||
|
) -> FloatTensor<Self, 3> {
|
||||||
|
qkv_attention(
|
||||||
|
Tensor::<Self, 3>::from_primitive(q),
|
||||||
|
Tensor::from_primitive(k),
|
||||||
|
Tensor::from_primitive(v),
|
||||||
|
mask.map(|m| Tensor::from_primitive(m)),
|
||||||
|
n_head,
|
||||||
|
)
|
||||||
|
.into_primitive()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn attn_decoder_mask(seq_length: usize, device: &Self::Device) -> FloatTensor<Self, 2> {
|
||||||
|
attn_decoder_mask::<Self>(seq_length, device).into_primitive()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
use burn::tensor::Float;
|
||||||
|
use burn_tch::{self, TchElement, TchTensor};
|
||||||
|
use tch;
|
||||||
|
|
||||||
|
impl<E: TchElement> Backend for burn_tch::LibTorch<E> {
|
||||||
|
fn qkv_attention(
|
||||||
|
q: FloatTensor<Self, 3>,
|
||||||
|
k: FloatTensor<Self, 3>,
|
||||||
|
v: FloatTensor<Self, 3>,
|
||||||
|
mask: Option<FloatTensor<Self, 2>>,
|
||||||
|
n_head: usize,
|
||||||
|
) -> FloatTensor<Self, 2> {
|
||||||
|
let q = Tensor::from_primitive(q);
|
||||||
|
let k = Tensor::from_primitive(k);
|
||||||
|
let v = Tensor::from_primitive(v);
|
||||||
|
|
||||||
|
let [n_batch, q_ctx, n_state] = q.dims();
|
||||||
|
let [_, k_ctx, _] = k.dims();
|
||||||
|
let n_hstate = n_state / n_head;
|
||||||
|
|
||||||
|
let rearrange = |t: Tensor<Self, 3>| {
|
||||||
|
let [_, n_ctx, _] = t.dims();
|
||||||
|
t.reshape([n_batch, n_ctx, n_head, n_hstate])
|
||||||
|
.swap_dims(1, 2)
|
||||||
|
};
|
||||||
|
|
||||||
|
let q = rearrange(q).into_primitive();
|
||||||
|
let k = rearrange(k).into_primitive();
|
||||||
|
let v = rearrange(v).into_primitive();
|
||||||
|
|
||||||
|
// for some reason torch crashes when mask is None
|
||||||
|
let mask = mask.unwrap_or_else(|| {
|
||||||
|
Tensor::<Self, 2, Float>::zeros([q_ctx, k_ctx], &Self::device(&v))
|
||||||
|
.into_primitive()
|
||||||
|
});
|
||||||
|
|
||||||
|
Tensor::<Self, 4>::from_primitive(TchTensor::new(
|
||||||
|
tch::Tensor::scaled_dot_product_attention(
|
||||||
|
&q.tensor,
|
||||||
|
&k.tensor,
|
||||||
|
&v.tensor,
|
||||||
|
Some(mask.tensor),
|
||||||
|
0.0,
|
||||||
|
false,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
))
|
||||||
|
.swap_dims(1, 2)
|
||||||
|
.flatten(2, 3)
|
||||||
|
.into_primitive()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
use burn_autodiff;
|
||||||
|
|
||||||
|
impl<B: Backend> Backend for burn_autodiff::Autodiff<B> {}*/
|
||||||
|
|
||||||
|
use std::f32::NEG_INFINITY;
|
||||||
|
|
||||||
|
pub fn qkv_attention<B: Backend>(
|
||||||
|
q: Tensor<B, 3>,
|
||||||
|
k: Tensor<B, 3>,
|
||||||
|
v: Tensor<B, 3>,
|
||||||
|
mask: Option<Tensor<B, 2>>,
|
||||||
|
n_head: usize,
|
||||||
|
) -> Tensor<B, 3> {
|
||||||
|
let [n_batch, n_qctx, n_state] = q.dims();
|
||||||
|
let [_, n_ctx, _] = k.dims();
|
||||||
|
|
||||||
|
let scale = (n_state as f64 / n_head as f64).powf(-0.25);
|
||||||
|
let n_hstate = n_state / n_head;
|
||||||
|
|
||||||
|
let q = q
|
||||||
|
.reshape([n_batch, n_qctx, n_head, n_hstate])
|
||||||
|
.swap_dims(1, 2)
|
||||||
|
* scale;
|
||||||
|
let k = k
|
||||||
|
.reshape([n_batch, n_ctx, n_head, n_hstate])
|
||||||
|
.swap_dims(1, 2)
|
||||||
|
.transpose()
|
||||||
|
* scale;
|
||||||
|
let v = v
|
||||||
|
.reshape([n_batch, n_ctx, n_head, n_hstate])
|
||||||
|
.swap_dims(1, 2);
|
||||||
|
|
||||||
|
let qk = q.matmul(k);
|
||||||
|
|
||||||
|
// apply mask
|
||||||
|
let qk = if let Some(mask) = mask {
|
||||||
|
qk + mask.slice([0..n_qctx, 0..n_ctx]).unsqueeze::<4>()
|
||||||
|
} else {
|
||||||
|
qk
|
||||||
|
};
|
||||||
|
|
||||||
|
// normalize value weightings
|
||||||
|
let w = softmax(qk, 3);
|
||||||
|
let o = w.matmul(v).swap_dims(1, 2).flatten(2, 3);
|
||||||
|
|
||||||
|
return o;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
|
||||||
|
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length], device);
|
||||||
|
|
||||||
|
for i in 0..(seq_length - 1) {
|
||||||
|
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)], device).add_scalar(NEG_INFINITY);
|
||||||
|
mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
|
||||||
|
}
|
||||||
|
|
||||||
|
return mask;
|
||||||
|
}
|
||||||
@@ -1,32 +1,27 @@
|
|||||||
use std::env;
|
use std::env;
|
||||||
use std::process;
|
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
|
use std::process;
|
||||||
|
|
||||||
use stablediffusion::model::stablediffusion::{StableDiffusion, load::load_stable_diffusion};
|
use stablediffusion::model::stablediffusion::{load::load_stable_diffusion, StableDiffusion};
|
||||||
|
|
||||||
use burn::{
|
use burn::{
|
||||||
config::Config,
|
config::Config,
|
||||||
module::{Module, Param},
|
module::{Module, Param},
|
||||||
nn,
|
nn,
|
||||||
tensor::{
|
tensor::{backend::Backend, Tensor},
|
||||||
backend::Backend,
|
|
||||||
Tensor,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
cfg_if::cfg_if! {
|
use burn_ndarray::{NdArray, NdArrayDevice};
|
||||||
if #[cfg(feature = "torch-backend")] {
|
|
||||||
use burn_tch::{TchBackend, TchDevice};
|
|
||||||
} else if #[cfg(feature = "wgpu-backend")] {
|
|
||||||
use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings};
|
use burn::record::{self, NamedMpkFileRecorder, FullPrecisionSettings, Recorder};
|
||||||
|
|
||||||
fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box<dyn Error>> {
|
fn convert_dump_to_model<B: Backend>(
|
||||||
|
dump_path: &str,
|
||||||
|
model_name: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<(), Box<dyn Error>> {
|
||||||
println!("Loading dump...");
|
println!("Loading dump...");
|
||||||
let model: StableDiffusion::<B> = load_stable_diffusion(dump_path, device)?;
|
let model: StableDiffusion<B> = load_stable_diffusion(dump_path, device)?;
|
||||||
|
|
||||||
println!("Saving model...");
|
println!("Saving model...");
|
||||||
save_model_file(model, model_name)?;
|
save_model_file(model, model_name)?;
|
||||||
@@ -34,24 +29,16 @@ fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device:
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn save_model_file<B: Backend>(model: StableDiffusion<B>, name: &str) -> Result<(), record::RecorderError> {
|
fn save_model_file<B: Backend>(
|
||||||
BinFileRecorder::<FullPrecisionSettings>::new()
|
model: StableDiffusion<B>,
|
||||||
.record(
|
name: &str,
|
||||||
model.into_record(),
|
) -> Result<(), record::RecorderError> {
|
||||||
name.into(),
|
NamedMpkFileRecorder::<FullPrecisionSettings>::new().record(model.into_record(), name.into())
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
cfg_if::cfg_if! {
|
type Backend = NdArray<f32>;
|
||||||
if #[cfg(feature = "torch-backend")] {
|
let device = NdArrayDevice::Cpu;
|
||||||
type Backend = TchBackend<f32>;
|
|
||||||
let device = TchDevice::Cpu;
|
|
||||||
} else if #[cfg(feature = "wgpu-backend")] {
|
|
||||||
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
|
|
||||||
let device = WgpuDevice::CPU;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let args: Vec<String> = env::args().collect();
|
let args: Vec<String> = env::args().collect();
|
||||||
if args.len() != 3 {
|
if args.len() != 3 {
|
||||||
|
|||||||
@@ -1,20 +1,20 @@
|
|||||||
use stablediffusion::{tokenizer::SimpleTokenizer, model::stablediffusion::{*, load::load_stable_diffusion}};
|
use stablediffusion::{
|
||||||
|
model::stablediffusion::{load::load_stable_diffusion, *},
|
||||||
|
tokenizer::SimpleTokenizer,
|
||||||
|
};
|
||||||
|
|
||||||
use burn::{
|
use burn::{
|
||||||
config::Config,
|
config::Config,
|
||||||
module::{Module, Param},
|
module::{Module, Param},
|
||||||
nn,
|
nn,
|
||||||
tensor::{
|
tensor::{backend::Backend, Tensor},
|
||||||
backend::Backend,
|
|
||||||
Tensor,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
cfg_if::cfg_if! {
|
cfg_if::cfg_if! {
|
||||||
if #[cfg(feature = "torch-backend")] {
|
if #[cfg(feature = "wgpu-backend")] {
|
||||||
use burn_tch::{TchBackend, TchDevice};
|
use burn_wgpu::{Wgpu, WgpuDevice};
|
||||||
} else if #[cfg(feature = "wgpu-backend")] {
|
} else {
|
||||||
use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi};
|
use burn_ndarray::NdArrayDevice;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -22,28 +22,21 @@ use std::env;
|
|||||||
use std::io;
|
use std::io;
|
||||||
use std::process;
|
use std::process;
|
||||||
|
|
||||||
use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings};
|
use burn::record::{self, NamedMpkFileRecorder, FullPrecisionSettings, Recorder};
|
||||||
|
|
||||||
fn load_stable_diffusion_model_file<B: Backend>(filename: &str) -> Result<StableDiffusion<B>, record::RecorderError> {
|
fn load_stable_diffusion_model_file<B: Backend>(
|
||||||
BinFileRecorder::<FullPrecisionSettings>::new()
|
filename: &str,
|
||||||
.load(filename.into())
|
device: &B::Device,
|
||||||
.map(|record| StableDiffusionConfig::new().init().load_record(record))
|
) -> Result<StableDiffusion<B>, record::RecorderError> {
|
||||||
|
NamedMpkFileRecorder::<FullPrecisionSettings>::new()
|
||||||
|
.load(filename.into(), device)
|
||||||
|
.map(|record| StableDiffusionConfig::new().init(device).load_record(record))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
cfg_if::cfg_if! {
|
|
||||||
if #[cfg(feature = "torch-backend")] {
|
|
||||||
type Backend = TchBackend<f32>;
|
|
||||||
let device = TchDevice::Cuda(0);
|
|
||||||
} else if #[cfg(feature = "wgpu-backend")] {
|
|
||||||
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
|
|
||||||
let device = WgpuDevice::BestAvailable;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let args: Vec<String> = std::env::args().collect();
|
let args: Vec<String> = std::env::args().collect();
|
||||||
if args.len() != 7 {
|
if args.len() != 7 && args.len() != 8 {
|
||||||
eprintln!("Usage: {} <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name>", args[0]);
|
eprintln!("Usage: {} <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name> [device(cuda, mps, cpu)]", args[0]);
|
||||||
process::exit(1);
|
process::exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -60,11 +53,24 @@ fn main() {
|
|||||||
let prompt = &args[5];
|
let prompt = &args[5];
|
||||||
let output_image_name = &args[6];
|
let output_image_name = &args[6];
|
||||||
|
|
||||||
|
// Optional device parameter
|
||||||
|
let device_arg = if args.len() == 8 { Some(&args[7]) } else { None };
|
||||||
|
|
||||||
|
cfg_if::cfg_if! {
|
||||||
|
if #[cfg(feature = "wgpu-backend")] {
|
||||||
|
type Backend = Wgpu;
|
||||||
|
let device = WgpuDevice::BestAvailable;
|
||||||
|
} else {
|
||||||
|
type Backend = burn::backend::ndarray::NdArray<f32>;
|
||||||
|
let device = NdArrayDevice::Cpu;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
println!("Loading tokenizer...");
|
println!("Loading tokenizer...");
|
||||||
let tokenizer = SimpleTokenizer::new().unwrap();
|
let tokenizer = SimpleTokenizer::new().unwrap();
|
||||||
println!("Loading model...");
|
println!("Loading model...");
|
||||||
let sd: StableDiffusion<Backend> = if model_type == "burn" {
|
let sd: StableDiffusion<Backend> = if model_type == "burn" {
|
||||||
load_stable_diffusion_model_file(model_name).unwrap_or_else(|err| {
|
load_stable_diffusion_model_file(model_name, &device).unwrap_or_else(|err| {
|
||||||
eprintln!("Error loading model: {}", err);
|
eprintln!("Error loading model: {}", err);
|
||||||
process::exit(1);
|
process::exit(1);
|
||||||
})
|
})
|
||||||
@@ -75,20 +81,23 @@ fn main() {
|
|||||||
})
|
})
|
||||||
};
|
};
|
||||||
|
|
||||||
let sd = sd.to_device(&device);
|
|
||||||
|
|
||||||
let unconditional_context = sd.unconditional_context(&tokenizer);
|
let unconditional_context = sd.unconditional_context(&tokenizer);
|
||||||
let context = sd.context(&tokenizer, prompt).unsqueeze().repeat(0, 2); // generate 2 samples
|
let context = sd.context(&tokenizer, prompt).unsqueeze::<3>(); //.repeat(0, 2); // generate 2 samples
|
||||||
|
|
||||||
println!("Sampling image...");
|
println!("Sampling image...");
|
||||||
let images = sd.sample_image(context, unconditional_context, unconditional_guidance_scale, n_steps);
|
let images = sd.sample_image(
|
||||||
|
context,
|
||||||
|
unconditional_context,
|
||||||
|
unconditional_guidance_scale,
|
||||||
|
n_steps,
|
||||||
|
);
|
||||||
save_images(&images, output_image_name, 512, 512).unwrap_or_else(|err| {
|
save_images(&images, output_image_name, 512, 512).unwrap_or_else(|err| {
|
||||||
eprintln!("Error saving image: {}", err);
|
eprintln!("Error saving image: {}", err);
|
||||||
process::exit(1);
|
process::exit(1);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
use image::{self, ImageResult, ColorType::Rgb8};
|
use image::{self, ColorType::Rgb8, ImageResult};
|
||||||
|
|
||||||
fn save_images(images: &Vec<Vec<u8>>, basepath: &str, width: u32, height: u32) -> ImageResult<()> {
|
fn save_images(images: &Vec<Vec<u8>>, basepath: &str, width: u32, height: u32) -> ImageResult<()> {
|
||||||
for (index, img_data) in images.iter().enumerate() {
|
for (index, img_data) in images.iter().enumerate() {
|
||||||
@@ -103,12 +112,15 @@ fn save_images(images: &Vec<Vec<u8>>, basepath: &str, width: u32, height: u32) -
|
|||||||
fn save_test_image() -> ImageResult<()> {
|
fn save_test_image() -> ImageResult<()> {
|
||||||
let width = 256;
|
let width = 256;
|
||||||
let height = 256;
|
let height = 256;
|
||||||
let raw: Vec<_> = (0..width * height).into_iter().flat_map(|i| {
|
let raw: Vec<_> = (0..width * height)
|
||||||
|
.into_iter()
|
||||||
|
.flat_map(|i| {
|
||||||
let row = i / width;
|
let row = i / width;
|
||||||
let red = (255.0 * row as f64 / height as f64) as u8;
|
let red = (255.0 * row as f64 / height as f64) as u8;
|
||||||
|
|
||||||
[red, 0, 0]
|
[red, 0, 0]
|
||||||
}).collect();
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
image::save_buffer("red.png", &raw[..], width, height, Rgb8)
|
image::save_buffer("red.png", &raw[..], width, height, Rgb8)
|
||||||
}
|
}
|
||||||
@@ -1,87 +0,0 @@
|
|||||||
use burn::{
|
|
||||||
tensor::{
|
|
||||||
backend::Backend,
|
|
||||||
activation::relu,
|
|
||||||
Tensor,
|
|
||||||
Int,
|
|
||||||
Bool,
|
|
||||||
Float,
|
|
||||||
TensorKind,
|
|
||||||
BasicOps,
|
|
||||||
Numeric,
|
|
||||||
Element,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
use num_traits::ToPrimitive;
|
|
||||||
|
|
||||||
|
|
||||||
pub fn tensor_max_scalar<B: Backend, const D: usize>(x: Tensor<B, D>, max: f64) -> Tensor<B, D> {
|
|
||||||
relu(x.sub_scalar(max)).add_scalar(max)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn tensor_min_scalar<B: Backend, const D: usize>(x: Tensor<B, D>, min: f64) -> Tensor<B, D> {
|
|
||||||
-tensor_max_scalar(-x, -min)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn tensor_max<B: Backend, const D: usize>(x: Tensor<B, D>, max: Tensor<B, D>) -> Tensor<B, D> {
|
|
||||||
relu(x - max.clone()) + max
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn tensor_min<B: Backend, const D: usize>(x: Tensor<B, D>, min: Tensor<B, D>) -> Tensor<B, D> {
|
|
||||||
-tensor_max(-x, -min)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn tensor_log10<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
|
|
||||||
let ln10 = (10.0f64).ln();
|
|
||||||
x.log() / ln10
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn tensor_max_element<B: Backend, const D: usize>(x: Tensor<B, D>) -> f64 {
|
|
||||||
let flat: Tensor<B, 1> = x.flatten(0, D - 1);
|
|
||||||
let max_index = flat.clone().argmax(0);
|
|
||||||
|
|
||||||
flat.select(0, max_index).into_scalar().to_f64().unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn all_zeros<B: Backend, const D: usize>(x: Tensor<B, D>) -> bool {
|
|
||||||
x.powf(2.0).sum().into_scalar().to_f64().unwrap() == 0.0
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn max_dim<B: Backend>(x: Tensor<B, 2>, dim: usize) -> Tensor<B, 2> {
|
|
||||||
let indices = x.clone().argmax(dim).flatten(0, 1);
|
|
||||||
x.select(dim, indices)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn _10pow<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
|
|
||||||
let log10 = (10.0f64).ln();
|
|
||||||
(x * log10).exp()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn to_float<B: Backend, const D: usize>(x: Tensor<B, D, Int>) -> Tensor<B, D, Float> {
|
|
||||||
let device = x.device();
|
|
||||||
Tensor::from_data(
|
|
||||||
x
|
|
||||||
.into_data()
|
|
||||||
.convert()
|
|
||||||
).to_device(&device)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn to_float_bool<B: Backend, const D: usize>(x: Tensor<B, D, Bool>) -> Tensor<B, D, Float> {
|
|
||||||
let device = x.device();
|
|
||||||
Tensor::from_data(
|
|
||||||
x
|
|
||||||
.into_data()
|
|
||||||
.convert()
|
|
||||||
).to_device(&device)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn reverse<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B> + Numeric<B>>(x: Tensor<B, D, K>, dim: usize) -> Tensor<B, D, K> where <K as BasicOps<B>>::Elem: Element {
|
|
||||||
let len = x.dims()[dim];
|
|
||||||
let indices = -Tensor::arange_device(0..len, &x.device()) + (len - 1) as i64;
|
|
||||||
x.select(dim, indices)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn div_roundup(x: usize, y: usize) -> usize {
|
|
||||||
(x + y - 1) / y
|
|
||||||
}
|
|
||||||
@@ -1,3 +1,3 @@
|
|||||||
|
pub mod backend;
|
||||||
pub mod model;
|
pub mod model;
|
||||||
pub mod tokenizer;
|
pub mod tokenizer;
|
||||||
pub mod helper;
|
|
||||||
@@ -1,23 +1,32 @@
|
|||||||
use burn::{
|
use burn::tensor::{activation::softmax, backend::Backend, Tensor};
|
||||||
tensor::{
|
|
||||||
backend::Backend,
|
|
||||||
activation::softmax,
|
|
||||||
Tensor,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
use std::f32::NEG_INFINITY;
|
use std::f32::NEG_INFINITY;
|
||||||
|
|
||||||
pub fn qkv_attention<B: Backend>(q: Tensor<B, 3>, k: Tensor<B, 3>, v: Tensor<B, 3>, mask: Option<Tensor<B, 2>>, n_head: usize) -> Tensor<B, 3> {
|
pub fn qkv_attention<B: Backend>(
|
||||||
|
q: Tensor<B, 3>,
|
||||||
|
k: Tensor<B, 3>,
|
||||||
|
v: Tensor<B, 3>,
|
||||||
|
mask: Option<Tensor<B, 2>>,
|
||||||
|
n_head: usize,
|
||||||
|
) -> Tensor<B, 3> {
|
||||||
let [n_batch, n_qctx, n_state] = q.dims();
|
let [n_batch, n_qctx, n_state] = q.dims();
|
||||||
let [_, n_ctx, _] = k.dims();
|
let [_, n_ctx, _] = k.dims();
|
||||||
|
|
||||||
let scale = (n_state as f64 / n_head as f64).powf(-0.25);
|
let scale = (n_state as f64 / n_head as f64).powf(-0.25);
|
||||||
let n_hstate = n_state / n_head;
|
let n_hstate = n_state / n_head;
|
||||||
|
|
||||||
let q = q.reshape([n_batch, n_qctx, n_head, n_hstate]).swap_dims(1, 2) * scale;
|
let q = q
|
||||||
let k = k.reshape([n_batch, n_ctx, n_head, n_hstate]).swap_dims(1, 2).transpose() * scale;
|
.reshape([n_batch, n_qctx, n_head, n_hstate])
|
||||||
let v = v.reshape([n_batch, n_ctx, n_head, n_hstate]).swap_dims(1, 2);
|
.swap_dims(1, 2)
|
||||||
|
* scale;
|
||||||
|
let k = k
|
||||||
|
.reshape([n_batch, n_ctx, n_head, n_hstate])
|
||||||
|
.swap_dims(1, 2)
|
||||||
|
.transpose()
|
||||||
|
* scale;
|
||||||
|
let v = v
|
||||||
|
.reshape([n_batch, n_ctx, n_head, n_hstate])
|
||||||
|
.swap_dims(1, 2);
|
||||||
|
|
||||||
let qk = q.matmul(k);
|
let qk = q.matmul(k);
|
||||||
|
|
||||||
@@ -36,12 +45,12 @@ pub fn qkv_attention<B: Backend>(q: Tensor<B, 3>, k: Tensor<B, 3>, v: Tensor<B,
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
|
pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
|
||||||
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length]);
|
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length], device);
|
||||||
|
|
||||||
for i in 0..(seq_length - 1) {
|
for i in 0..(seq_length - 1) {
|
||||||
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)]).add_scalar(NEG_INFINITY);
|
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)], device).add_scalar(NEG_INFINITY);
|
||||||
mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
|
mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
|
||||||
}
|
}
|
||||||
|
|
||||||
return mask.to_device(device);
|
return mask;
|
||||||
}
|
}
|
||||||
@@ -7,26 +7,35 @@ use burn::{
|
|||||||
config::Config,
|
config::Config,
|
||||||
module::{Module, Param},
|
module::{Module, Param},
|
||||||
nn,
|
nn,
|
||||||
tensor::{
|
tensor::{backend::Backend, Tensor},
|
||||||
backend::Backend,
|
|
||||||
Tensor,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::model::groupnorm::load::load_group_norm;
|
use crate::model::groupnorm::load::load_group_norm;
|
||||||
|
|
||||||
fn load_conv_self_attention_block<B: Backend>(path: &str, device: &B::Device) -> Result<ConvSelfAttentionBlock<B>, Box<dyn Error>> {
|
fn load_conv_self_attention_block<B: Backend>(
|
||||||
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<ConvSelfAttentionBlock<B>, Box<dyn Error>> {
|
||||||
let norm = load_group_norm(&format!("{}/{}", path, "norm"), device)?;
|
let norm = load_group_norm(&format!("{}/{}", path, "norm"), device)?;
|
||||||
let q = load_conv2d(&format!("{}/{}", path, "q"), device)?;
|
let q = load_conv2d(&format!("{}/{}", path, "q"), device)?;
|
||||||
let k = load_conv2d(&format!("{}/{}", path, "k"), device)?;
|
let k = load_conv2d(&format!("{}/{}", path, "k"), device)?;
|
||||||
let v = load_conv2d(&format!("{}/{}", path, "v"), device)?;
|
let v = load_conv2d(&format!("{}/{}", path, "v"), device)?;
|
||||||
let proj_out = load_conv2d(&format!("{}/{}", path, "proj_out"), device)?;
|
let proj_out = load_conv2d(&format!("{}/{}", path, "proj_out"), device)?;
|
||||||
|
|
||||||
Ok(ConvSelfAttentionBlock { norm, q, k, v, proj_out })
|
Ok(ConvSelfAttentionBlock {
|
||||||
|
norm,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
proj_out,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn load_resnet_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResnetBlock<B>, Box<dyn Error>> {
|
fn load_resnet_block<B: Backend>(
|
||||||
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<ResnetBlock<B>, Box<dyn Error>> {
|
||||||
let norm1 = load_group_norm(&format!("{}/{}", path, "norm1"), device)?;
|
let norm1 = load_group_norm(&format!("{}/{}", path, "norm1"), device)?;
|
||||||
let silu1 = SILU {};
|
let silu1 = SILU {};
|
||||||
let conv1 = load_conv2d(&format!("{}/{}", path, "conv1"), device)?;
|
let conv1 = load_conv2d(&format!("{}/{}", path, "conv1"), device)?;
|
||||||
@@ -35,7 +44,15 @@ fn load_resnet_block<B: Backend>(path: &str, device: &B::Device) -> Result<Resne
|
|||||||
let conv2 = load_conv2d(&format!("{}/{}", path, "conv2"), device)?;
|
let conv2 = load_conv2d(&format!("{}/{}", path, "conv2"), device)?;
|
||||||
let nin_shortcut = load_conv2d(&format!("{}/{}", path, "nin_shortcut"), device).ok();
|
let nin_shortcut = load_conv2d(&format!("{}/{}", path, "nin_shortcut"), device).ok();
|
||||||
|
|
||||||
Ok(ResnetBlock { norm1, silu1, conv1, norm2, silu2, conv2, nin_shortcut })
|
Ok(ResnetBlock {
|
||||||
|
norm1,
|
||||||
|
silu1,
|
||||||
|
conv1,
|
||||||
|
norm2,
|
||||||
|
silu2,
|
||||||
|
conv2,
|
||||||
|
nin_shortcut,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn load_mid<B: Backend>(path: &str, device: &B::Device) -> Result<Mid<B>, Box<dyn Error>> {
|
fn load_mid<B: Backend>(path: &str, device: &B::Device) -> Result<Mid<B>, Box<dyn Error>> {
|
||||||
@@ -43,11 +60,18 @@ fn load_mid<B: Backend>(path: &str, device: &B::Device) -> Result<Mid<B>, Box<dy
|
|||||||
let attn = load_conv_self_attention_block(&format!("{}/{}", path, "attn"), device)?;
|
let attn = load_conv_self_attention_block(&format!("{}/{}", path, "attn"), device)?;
|
||||||
let block_2 = load_resnet_block(&format!("{}/{}", path, "block_2"), device)?;
|
let block_2 = load_resnet_block(&format!("{}/{}", path, "block_2"), device)?;
|
||||||
|
|
||||||
Ok(Mid { block_1, attn, block_2 })
|
Ok(Mid {
|
||||||
|
block_1,
|
||||||
|
attn,
|
||||||
|
block_2,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn load_padded_conv2d<B: Backend>(path: &str, device: &B::Device) -> Result<PaddedConv2d<B>, Box<dyn Error>> {
|
fn load_padded_conv2d<B: Backend>(
|
||||||
let conv = load_conv2d(&format!("{}/{}", path, "conv"), device)?;
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<PaddedConv2d<B>, Box<dyn Error>> {
|
||||||
|
let mut conv = load_conv2d(&format!("{}/{}", path, "conv"), device)?;
|
||||||
|
|
||||||
let channels = load_tensor::<B, 1>("channels", path, device)?;
|
let channels = load_tensor::<B, 1>("channels", path, device)?;
|
||||||
let channels = tensor_to_array_2(channels);
|
let channels = tensor_to_array_2(channels);
|
||||||
@@ -57,35 +81,55 @@ fn load_padded_conv2d<B: Backend>(path: &str, device: &B::Device) -> Result<Padd
|
|||||||
|
|
||||||
let padding = load_tensor::<B, 1>("padding", path, device)?;
|
let padding = load_tensor::<B, 1>("padding", path, device)?;
|
||||||
let padding: [usize; 4] = tensor_to_array(padding);
|
let padding: [usize; 4] = tensor_to_array(padding);
|
||||||
let padding = Padding::new(padding[0], padding[1], padding[2], padding[3]);
|
let padding = PaddingCfg::new(padding[0], padding[1], padding[2], padding[3]);
|
||||||
|
|
||||||
let mut record = conv.into_record();
|
//let mut record = conv.into_record();
|
||||||
|
|
||||||
let mut padded_conv: PaddedConv2d<B> = PaddedConv2dConfig::new(channels, kernel_size, padding).with_stride(stride).init();
|
let mut padded_conv: PaddedConv2d<B> = PaddedConv2dConfig::new(channels, kernel_size, padding)
|
||||||
let padding_actual = PaddingConfig2d::Explicit(padded_conv.padding_actual[0], padded_conv.padding_actual[1]);
|
.with_stride(stride)
|
||||||
|
.init(device);
|
||||||
|
let padding_actual =
|
||||||
|
PaddingConfig2d::Explicit(padded_conv.padding_actual[0], padded_conv.padding_actual[1]);
|
||||||
|
|
||||||
record.padding = <PaddingConfig2d as Module<B>>::into_record(padding_actual);
|
conv.padding = burn::module::Ignored(padding_actual);
|
||||||
padded_conv.conv = padded_conv.conv.load_record(record);
|
padded_conv.conv = conv;
|
||||||
|
|
||||||
|
//record.padding = <PaddingConfig2d as Module<B>>::into_record(padding_actual);
|
||||||
|
//padded_conv.conv = padded_conv.conv.load_record(record);
|
||||||
|
|
||||||
Ok(padded_conv)
|
Ok(padded_conv)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn load_decoder_block<B: Backend>(path: &str, device: &B::Device) -> Result<DecoderBlock<B>, Box<dyn Error>> {
|
fn load_decoder_block<B: Backend>(
|
||||||
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<DecoderBlock<B>, Box<dyn Error>> {
|
||||||
let res1 = load_resnet_block(&format!("{}/{}", path, "res1"), device)?;
|
let res1 = load_resnet_block(&format!("{}/{}", path, "res1"), device)?;
|
||||||
let res2 = load_resnet_block(&format!("{}/{}", path, "res2"), device)?;
|
let res2 = load_resnet_block(&format!("{}/{}", path, "res2"), device)?;
|
||||||
let res3 = load_resnet_block(&format!("{}/{}", path, "res3"), device)?;
|
let res3 = load_resnet_block(&format!("{}/{}", path, "res3"), device)?;
|
||||||
let upsampler = load_conv2d(&format!("{}/{}", path, "upsampler"), device).ok();
|
let upsampler = load_conv2d(&format!("{}/{}", path, "upsampler"), device).ok();
|
||||||
|
|
||||||
Ok(DecoderBlock { res1, res2, res3, upsampler })
|
Ok(DecoderBlock {
|
||||||
|
res1,
|
||||||
|
res2,
|
||||||
|
res3,
|
||||||
|
upsampler,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn load_encoder_block<B: Backend>(path: &str, device: &B::Device) -> Result<EncoderBlock<B>, Box<dyn Error>> {
|
fn load_encoder_block<B: Backend>(
|
||||||
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<EncoderBlock<B>, Box<dyn Error>> {
|
||||||
let res1 = load_resnet_block(&format!("{}/{}", path, "res1"), device)?;
|
let res1 = load_resnet_block(&format!("{}/{}", path, "res1"), device)?;
|
||||||
let res2 = load_resnet_block(&format!("{}/{}", path, "res2"), device)?;
|
let res2 = load_resnet_block(&format!("{}/{}", path, "res2"), device)?;
|
||||||
let downsampler = load_padded_conv2d(&format!("{}/{}", path, "downsampler"), device).ok();
|
let downsampler = load_padded_conv2d(&format!("{}/{}", path, "downsampler"), device).ok();
|
||||||
|
|
||||||
Ok(EncoderBlock { res1, res2, downsampler })
|
Ok(EncoderBlock {
|
||||||
|
res1,
|
||||||
|
res2,
|
||||||
|
downsampler,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn load_decoder<B: Backend>(path: &str, device: &B::Device) -> Result<Decoder<B>, Box<dyn Error>> {
|
fn load_decoder<B: Backend>(path: &str, device: &B::Device) -> Result<Decoder<B>, Box<dyn Error>> {
|
||||||
@@ -95,15 +139,21 @@ fn load_decoder<B: Backend>(path: &str, device: &B::Device) -> Result<Decoder<B>
|
|||||||
let n_block = load_usize::<B>("n_block", path, device)?;
|
let n_block = load_usize::<B>("n_block", path, device)?;
|
||||||
let mut blocks = (0..n_block)
|
let mut blocks = (0..n_block)
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|i| {
|
.map(|i| load_decoder_block::<B>(&format!("{}/blocks/{}", path, i), device))
|
||||||
load_decoder_block::<B>(&format!("{}/blocks/{}", path, i), device)
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
}).collect::<Result<Vec<_>, _>>()?;
|
|
||||||
|
|
||||||
let norm_out = load_group_norm(&format!("{}/{}", path, "norm_out"), device)?;
|
let norm_out = load_group_norm(&format!("{}/{}", path, "norm_out"), device)?;
|
||||||
let silu = SILU {};
|
let silu = SILU {};
|
||||||
let conv_out = load_conv2d(&format!("{}/{}", path, "conv_out"), device)?;
|
let conv_out = load_conv2d(&format!("{}/{}", path, "conv_out"), device)?;
|
||||||
|
|
||||||
Ok(Decoder { conv_in, mid, blocks, norm_out, silu, conv_out })
|
Ok(Decoder {
|
||||||
|
conv_in,
|
||||||
|
mid,
|
||||||
|
blocks,
|
||||||
|
norm_out,
|
||||||
|
silu,
|
||||||
|
conv_out,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn load_encoder<B: Backend>(path: &str, device: &B::Device) -> Result<Encoder<B>, Box<dyn Error>> {
|
fn load_encoder<B: Backend>(path: &str, device: &B::Device) -> Result<Encoder<B>, Box<dyn Error>> {
|
||||||
@@ -113,22 +163,36 @@ fn load_encoder<B: Backend>(path: &str, device: &B::Device) -> Result<Encoder<B>
|
|||||||
let n_block = load_usize::<B>("n_block", path, device)?;
|
let n_block = load_usize::<B>("n_block", path, device)?;
|
||||||
let mut blocks = (0..n_block)
|
let mut blocks = (0..n_block)
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|i| {
|
.map(|i| load_encoder_block::<B>(&format!("{}/blocks/{}", path, i), device))
|
||||||
load_encoder_block::<B>(&format!("{}/blocks/{}", path, i), device)
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
}).collect::<Result<Vec<_>, _>>()?;
|
|
||||||
|
|
||||||
let norm_out = load_group_norm(&format!("{}/{}", path, "norm_out"), device)?;
|
let norm_out = load_group_norm(&format!("{}/{}", path, "norm_out"), device)?;
|
||||||
let silu = SILU {};
|
let silu = SILU {};
|
||||||
let conv_out = load_conv2d(&format!("{}/{}", path, "conv_out"), device)?;
|
let conv_out = load_conv2d(&format!("{}/{}", path, "conv_out"), device)?;
|
||||||
|
|
||||||
Ok(Encoder { conv_in, mid, blocks, norm_out, silu, conv_out })
|
Ok(Encoder {
|
||||||
|
conv_in,
|
||||||
|
mid,
|
||||||
|
blocks,
|
||||||
|
norm_out,
|
||||||
|
silu,
|
||||||
|
conv_out,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_autoencoder<B: Backend>(path: &str, device: &B::Device) -> Result<Autoencoder<B>, Box<dyn Error>> {
|
pub fn load_autoencoder<B: Backend>(
|
||||||
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<Autoencoder<B>, Box<dyn Error>> {
|
||||||
let encoder = load_encoder(&format!("{}/{}", path, "encoder"), device)?;
|
let encoder = load_encoder(&format!("{}/{}", path, "encoder"), device)?;
|
||||||
let decoder = load_decoder(&format!("{}/{}", path, "decoder"), device)?;
|
let decoder = load_decoder(&format!("{}/{}", path, "decoder"), device)?;
|
||||||
let quant_conv = load_conv2d(&format!("{}/{}", path, "quant_conv"), device)?;
|
let quant_conv = load_conv2d(&format!("{}/{}", path, "quant_conv"), device)?;
|
||||||
let post_quant_conv = load_conv2d(&format!("{}/{}", path, "post_quant_conv"), device)?;
|
let post_quant_conv = load_conv2d(&format!("{}/{}", path, "post_quant_conv"), device)?;
|
||||||
|
|
||||||
Ok(Autoencoder { encoder, decoder, quant_conv, post_quant_conv })
|
Ok(Autoencoder {
|
||||||
|
encoder,
|
||||||
|
decoder,
|
||||||
|
quant_conv,
|
||||||
|
post_quant_conv,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
@@ -3,35 +3,37 @@ pub mod load;
|
|||||||
use burn::{
|
use burn::{
|
||||||
config::Config,
|
config::Config,
|
||||||
module::{Module, Param},
|
module::{Module, Param},
|
||||||
nn::{self, PaddingConfig2d, conv::{Conv2d, Conv2dConfig, Conv2dRecord}},
|
nn::{
|
||||||
|
self,
|
||||||
|
conv::{Conv2d, Conv2dConfig, Conv2dRecord},
|
||||||
|
PaddingConfig2d,
|
||||||
|
},
|
||||||
tensor::{
|
tensor::{
|
||||||
|
activation::{sigmoid, softmax},
|
||||||
backend::Backend,
|
backend::Backend,
|
||||||
activation::{softmax, sigmoid},
|
|
||||||
module::embedding,
|
module::embedding,
|
||||||
Tensor,
|
Distribution, Int, Tensor,
|
||||||
Distribution,
|
|
||||||
Int,
|
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::helper::div_roundup;
|
|
||||||
|
|
||||||
use super::silu::*;
|
|
||||||
use super::groupnorm::*;
|
use super::groupnorm::*;
|
||||||
use super::attention::qkv_attention;
|
use super::silu::*;
|
||||||
|
//use crate::backend::Backend as MyBackend;
|
||||||
|
use crate::backend::{qkv_attention, attn_decoder_mask};
|
||||||
|
|
||||||
use std::iter;
|
use std::iter;
|
||||||
|
|
||||||
|
#[derive(Config, Debug)]
|
||||||
#[derive(Config)]
|
|
||||||
pub struct AutoencoderConfig {}
|
pub struct AutoencoderConfig {}
|
||||||
|
|
||||||
impl AutoencoderConfig {
|
impl AutoencoderConfig {
|
||||||
pub fn init<B: Backend>(&self) -> Autoencoder<B> {
|
pub fn init<B: Backend>(&self, device: &B::Device) -> Autoencoder<B> {
|
||||||
let encoder = EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init();
|
let encoder =
|
||||||
let decoder = DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init();
|
EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init(device);
|
||||||
let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init();
|
let decoder =
|
||||||
let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init();
|
DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init(device);
|
||||||
|
let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init(device);
|
||||||
|
let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init(device);
|
||||||
|
|
||||||
Autoencoder {
|
Autoencoder {
|
||||||
encoder,
|
encoder,
|
||||||
@@ -42,7 +44,6 @@ impl AutoencoderConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[derive(Module, Debug)]
|
#[derive(Module, Debug)]
|
||||||
pub struct Autoencoder<B: Backend> {
|
pub struct Autoencoder<B: Backend> {
|
||||||
encoder: Encoder<B>,
|
encoder: Encoder<B>,
|
||||||
@@ -70,7 +71,7 @@ impl<B: Backend> Autoencoder<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct EncoderConfig {
|
pub struct EncoderConfig {
|
||||||
channels: Vec<(usize, usize)>,
|
channels: Vec<(usize, usize)>,
|
||||||
n_group: usize,
|
n_group: usize,
|
||||||
@@ -78,21 +79,34 @@ pub struct EncoderConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl EncoderConfig {
|
impl EncoderConfig {
|
||||||
fn init<B: Backend>(&self) -> Encoder<B> {
|
fn init<B: Backend>(&self, device: &B::Device) -> Encoder<B> {
|
||||||
let n_expanded_channels_initial = self.channels.first().map(|f| f.1).expect("Channels must not be empty.");
|
let n_expanded_channels_initial = self
|
||||||
|
.channels
|
||||||
|
.first()
|
||||||
|
.map(|f| f.1)
|
||||||
|
.expect("Channels must not be empty.");
|
||||||
let n_expanded_channels_final = self.channels.first().unwrap().0;
|
let n_expanded_channels_final = self.channels.first().unwrap().0;
|
||||||
|
|
||||||
let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3])
|
||||||
|
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||||
|
.init(device);
|
||||||
|
|
||||||
let blocks = self.channels.iter().enumerate().map(|(i, &(n_channel_in, n_channel_out))| {
|
let blocks = self
|
||||||
|
.channels
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, &(n_channel_in, n_channel_out))| {
|
||||||
let downsample = i != self.channels.len() - 1;
|
let downsample = i != self.channels.len() - 1;
|
||||||
EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init()
|
EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init(device)
|
||||||
}).collect();
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
let mid = MidConfig::new(n_expanded_channels_final).init();
|
let mid = MidConfig::new(n_expanded_channels_final).init(device);
|
||||||
let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init();
|
let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init(device);
|
||||||
let silu = SILU::new();
|
let silu = SILU::new();
|
||||||
let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3])
|
||||||
|
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||||
|
.init(device);
|
||||||
|
|
||||||
Encoder {
|
Encoder {
|
||||||
conv_in,
|
conv_in,
|
||||||
@@ -105,7 +119,6 @@ impl EncoderConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[derive(Module, Debug)]
|
#[derive(Module, Debug)]
|
||||||
pub struct Encoder<B: Backend> {
|
pub struct Encoder<B: Backend> {
|
||||||
conv_in: Conv2d<B>,
|
conv_in: Conv2d<B>,
|
||||||
@@ -126,34 +139,46 @@ impl<B: Backend> Encoder<B> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let x = self.mid.forward(x);
|
let x = self.mid.forward(x);
|
||||||
self.conv_out.forward( self.silu.forward( self.norm_out.forward(x) ) )
|
self.conv_out
|
||||||
|
.forward(self.silu.forward(self.norm_out.forward(x)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Config, Debug)]
|
||||||
|
|
||||||
#[derive(Config)]
|
|
||||||
pub struct DecoderConfig {
|
pub struct DecoderConfig {
|
||||||
channels: Vec<(usize, usize)>,
|
channels: Vec<(usize, usize)>,
|
||||||
n_group: usize,
|
n_group: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DecoderConfig {
|
impl DecoderConfig {
|
||||||
fn init<B: Backend>(&self) -> Decoder<B> {
|
fn init<B: Backend>(&self, device: &B::Device) -> Decoder<B> {
|
||||||
let n_expanded_channels = self.channels.first().map(|f| f.0).expect("Channels must not be empty.");
|
let n_expanded_channels = self
|
||||||
|
.channels
|
||||||
|
.first()
|
||||||
|
.map(|f| f.0)
|
||||||
|
.expect("Channels must not be empty.");
|
||||||
let n_condensed_channels = self.channels.last().unwrap().1;
|
let n_condensed_channels = self.channels.last().unwrap().1;
|
||||||
|
|
||||||
let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3])
|
||||||
let mid = MidConfig::new(n_expanded_channels).init();
|
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||||
|
.init(device);
|
||||||
|
let mid = MidConfig::new(n_expanded_channels).init(device);
|
||||||
|
|
||||||
let blocks = self.channels.iter().enumerate().map(|(i, &(n_channel_in, n_channel_out))| {
|
let blocks = self
|
||||||
|
.channels
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, &(n_channel_in, n_channel_out))| {
|
||||||
let upsample = i != self.channels.len() - 1;
|
let upsample = i != self.channels.len() - 1;
|
||||||
DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init()
|
DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init(device)
|
||||||
}).collect();
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init();
|
let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init(device);
|
||||||
let silu = SILU::new();
|
let silu = SILU::new();
|
||||||
let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3])
|
||||||
|
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||||
|
.init(device);
|
||||||
|
|
||||||
Decoder {
|
Decoder {
|
||||||
conv_in,
|
conv_in,
|
||||||
@@ -166,7 +191,6 @@ impl DecoderConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[derive(Module, Debug)]
|
#[derive(Module, Debug)]
|
||||||
pub struct Decoder<B: Backend> {
|
pub struct Decoder<B: Backend> {
|
||||||
conv_in: Conv2d<B>,
|
conv_in: Conv2d<B>,
|
||||||
@@ -187,11 +211,12 @@ impl<B: Backend> Decoder<B> {
|
|||||||
x = block.forward(x);
|
x = block.forward(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
self.conv_out.forward( self.silu.forward( self.norm_out.forward(x) ) )
|
self.conv_out
|
||||||
|
.forward(self.silu.forward(self.norm_out.forward(x)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct EncoderBlockConfig {
|
pub struct EncoderBlockConfig {
|
||||||
n_channels_in: usize,
|
n_channels_in: usize,
|
||||||
n_channels_out: usize,
|
n_channels_out: usize,
|
||||||
@@ -199,12 +224,16 @@ pub struct EncoderBlockConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl EncoderBlockConfig {
|
impl EncoderBlockConfig {
|
||||||
fn init<B: Backend>(&self) -> EncoderBlock<B> {
|
fn init<B: Backend>(&self, device: &B::Device) -> EncoderBlock<B> {
|
||||||
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init();
|
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(device);
|
||||||
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
|
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
|
||||||
let downsampler = if self.downsample {
|
let downsampler = if self.downsample {
|
||||||
let padding = Padding::new(0, 1, 0, 1);
|
let padding = PaddingCfg::new(0, 1, 0, 1);
|
||||||
Some( PaddedConv2dConfig::new([self.n_channels_out, self.n_channels_out], 3, padding).with_stride(2).init() )
|
Some(
|
||||||
|
PaddedConv2dConfig::new([self.n_channels_out, self.n_channels_out], 3, padding)
|
||||||
|
.with_stride(2)
|
||||||
|
.init(device),
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
@@ -236,7 +265,7 @@ impl<B: Backend> EncoderBlock<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct DecoderBlockConfig {
|
pub struct DecoderBlockConfig {
|
||||||
n_channels_in: usize,
|
n_channels_in: usize,
|
||||||
n_channels_out: usize,
|
n_channels_out: usize,
|
||||||
@@ -244,12 +273,16 @@ pub struct DecoderBlockConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl DecoderBlockConfig {
|
impl DecoderBlockConfig {
|
||||||
fn init<B: Backend>(&self) -> DecoderBlock<B> {
|
fn init<B: Backend>(&self, device: &B::Device) -> DecoderBlock<B> {
|
||||||
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init();
|
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(device);
|
||||||
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
|
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
|
||||||
let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
|
let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
|
||||||
let upsampler = if self.upsample {
|
let upsampler = if self.upsample {
|
||||||
Some( Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init() )
|
Some(
|
||||||
|
Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3])
|
||||||
|
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||||
|
.init(device),
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
@@ -281,8 +314,7 @@ impl<B: Backend> DecoderBlock<B> {
|
|||||||
let [n_batch, n_channel, height, width] = x.dims();
|
let [n_batch, n_channel, height, width] = x.dims();
|
||||||
let x = x
|
let x = x
|
||||||
.reshape([n_batch, n_channel, height, 1, width, 1])
|
.reshape([n_batch, n_channel, height, 1, width, 1])
|
||||||
.repeat(3, 2)
|
.repeat(&[1, 1, 1, 2, 1, 2])
|
||||||
.repeat(5, 2)
|
|
||||||
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
|
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
|
||||||
d.forward(x)
|
d.forward(x)
|
||||||
} else {
|
} else {
|
||||||
@@ -291,18 +323,17 @@ impl<B: Backend> DecoderBlock<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Config, Debug)]
|
||||||
#[derive(Config)]
|
|
||||||
pub struct PaddedConv2dConfig {
|
pub struct PaddedConv2dConfig {
|
||||||
channels: [usize; 2],
|
channels: [usize; 2],
|
||||||
kernel_size: usize,
|
kernel_size: usize,
|
||||||
#[config(default = 1)]
|
#[config(default = 1)]
|
||||||
stride: usize,
|
stride: usize,
|
||||||
padding: Padding,
|
padding: PaddingCfg,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PaddedConv2dConfig {
|
impl PaddedConv2dConfig {
|
||||||
fn init<B: Backend>(&self) -> PaddedConv2d<B> {
|
fn init<B: Backend>(&self, device: &B::Device) -> PaddedConv2d<B> {
|
||||||
let calc_padding = |p_left, p_right| {
|
let calc_padding = |p_left, p_right| {
|
||||||
let n = if p_left >= p_right {
|
let n = if p_left >= p_right {
|
||||||
0
|
0
|
||||||
@@ -320,12 +351,17 @@ impl PaddedConv2dConfig {
|
|||||||
let conv = Conv2dConfig::new(self.channels, [self.kernel_size, self.kernel_size])
|
let conv = Conv2dConfig::new(self.channels, [self.kernel_size, self.kernel_size])
|
||||||
.with_stride([self.stride, self.stride])
|
.with_stride([self.stride, self.stride])
|
||||||
.with_padding(PaddingConfig2d::Explicit(pad_vertical, pad_horizontal))
|
.with_padding(PaddingConfig2d::Explicit(pad_vertical, pad_horizontal))
|
||||||
.init();
|
.init(device);
|
||||||
|
|
||||||
let kernel_size = self.kernel_size;
|
let kernel_size = self.kernel_size;
|
||||||
let stride = self.stride;
|
let stride = self.stride;
|
||||||
|
|
||||||
let padding = self.padding;
|
let padding = Padding {
|
||||||
|
pad_left: self.padding.pad_left,
|
||||||
|
pad_right: self.padding.pad_right,
|
||||||
|
pad_top: self.padding.pad_top,
|
||||||
|
pad_bottom: self.padding.pad_bottom,
|
||||||
|
};
|
||||||
|
|
||||||
PaddedConv2d {
|
PaddedConv2d {
|
||||||
conv,
|
conv,
|
||||||
@@ -337,6 +373,10 @@ impl PaddedConv2dConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn div_roundup(x: usize, y: usize) -> usize {
|
||||||
|
(x + y - 1) / y
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Module, Debug)]
|
#[derive(Module, Debug)]
|
||||||
pub struct PaddedConv2d<B: Backend> {
|
pub struct PaddedConv2d<B: Backend> {
|
||||||
conv: Conv2d<B>,
|
conv: Conv2d<B>,
|
||||||
@@ -348,27 +388,38 @@ pub struct PaddedConv2d<B: Backend> {
|
|||||||
|
|
||||||
impl<B: Backend> PaddedConv2d<B> {
|
impl<B: Backend> PaddedConv2d<B> {
|
||||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||||
println!("{} {} {:?} {:?}", self.kernel_size, self.stride, self.padding, self.padding_actual);
|
|
||||||
let [n_batch, n_channel, height, width] = x.dims();
|
let [n_batch, n_channel, height, width] = x.dims();
|
||||||
|
|
||||||
let desired_height = (self.padding.pad_top + self.padding.pad_bottom + height - self.kernel_size) / self.stride + 1;
|
let desired_height = (self.padding.pad_top + self.padding.pad_bottom + height
|
||||||
let desired_width = (self.padding.pad_left + self.padding.pad_right + width - self.kernel_size) / self.stride + 1;
|
- self.kernel_size)
|
||||||
|
/ self.stride
|
||||||
|
+ 1;
|
||||||
|
let desired_width = (self.padding.pad_left + self.padding.pad_right + width
|
||||||
|
- self.kernel_size)
|
||||||
|
/ self.stride
|
||||||
|
+ 1;
|
||||||
|
|
||||||
let skip_vert = (self.padding_actual[0] - self.padding.pad_top) / self.stride;
|
let skip_vert = (self.padding_actual[0] - self.padding.pad_top) / self.stride;
|
||||||
let skip_hor = (self.padding_actual[1] - self.padding.pad_left) / self.stride;
|
let skip_hor = (self.padding_actual[1] - self.padding.pad_left) / self.stride;
|
||||||
|
|
||||||
self.conv
|
self.conv.forward(x).slice([
|
||||||
.forward(x)
|
|
||||||
.slice([
|
|
||||||
0..n_batch,
|
0..n_batch,
|
||||||
0..n_channel,
|
0..n_channel,
|
||||||
skip_vert..(skip_vert + desired_height),
|
skip_vert..(skip_vert + desired_height),
|
||||||
skip_hor..(skip_hor + desired_width)
|
skip_hor..(skip_hor + desired_width),
|
||||||
])
|
])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config, Module, Copy, Debug)]
|
#[derive(Config, Debug)]
|
||||||
|
pub struct PaddingCfg {
|
||||||
|
pad_left: usize,
|
||||||
|
pad_right: usize,
|
||||||
|
pad_top: usize,
|
||||||
|
pad_bottom: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Module, Clone, Debug)]
|
||||||
pub struct Padding {
|
pub struct Padding {
|
||||||
pad_left: usize,
|
pad_left: usize,
|
||||||
pad_right: usize,
|
pad_right: usize,
|
||||||
@@ -376,16 +427,16 @@ pub struct Padding {
|
|||||||
pad_bottom: usize,
|
pad_bottom: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct MidConfig {
|
pub struct MidConfig {
|
||||||
n_channel: usize,
|
n_channel: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MidConfig {
|
impl MidConfig {
|
||||||
fn init<B: Backend>(&self) -> Mid<B> {
|
fn init<B: Backend>(&self, device: &B::Device) -> Mid<B> {
|
||||||
let block_1 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init();
|
let block_1 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(device);
|
||||||
let attn = ConvSelfAttentionBlockConfig::new(self.n_channel).init();
|
let attn = ConvSelfAttentionBlockConfig::new(self.n_channel).init(device);
|
||||||
let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init();
|
let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(device);
|
||||||
|
|
||||||
Mid {
|
Mid {
|
||||||
block_1,
|
block_1,
|
||||||
@@ -411,21 +462,24 @@ impl<B: Backend> Mid<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Config, Debug)]
|
||||||
#[derive(Config)]
|
|
||||||
pub struct ResnetBlockConfig {
|
pub struct ResnetBlockConfig {
|
||||||
in_channels: usize,
|
in_channels: usize,
|
||||||
out_channels: usize,
|
out_channels: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ResnetBlockConfig {
|
impl ResnetBlockConfig {
|
||||||
fn init<B: Backend>(&self) -> ResnetBlock<B> {
|
fn init<B: Backend>(&self, device: &B::Device) -> ResnetBlock<B> {
|
||||||
let norm1 = GroupNormConfig::new(32, self.in_channels).init();
|
let norm1 = GroupNormConfig::new(32, self.in_channels).init(device);
|
||||||
let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3])
|
||||||
let norm2 = GroupNormConfig::new(32, self.out_channels).init();
|
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||||
let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
.init(device);
|
||||||
|
let norm2 = GroupNormConfig::new(32, self.out_channels).init(device);
|
||||||
|
let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3])
|
||||||
|
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||||
|
.init(device);
|
||||||
let nin_shortcut = if self.in_channels != self.out_channels {
|
let nin_shortcut = if self.in_channels != self.out_channels {
|
||||||
Some( Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init() )
|
Some(Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init(device))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
@@ -458,9 +512,12 @@ pub struct ResnetBlock<B: Backend> {
|
|||||||
|
|
||||||
impl<B: Backend> ResnetBlock<B> {
|
impl<B: Backend> ResnetBlock<B> {
|
||||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||||
let h = self.conv1.forward( self.silu1.forward(self.norm1.forward(x.clone())) );
|
let h = self
|
||||||
let h = self.conv2.forward( self.silu2.forward(self.norm2.forward(h)) );
|
.conv1
|
||||||
|
.forward(self.silu1.forward(self.norm1.forward(x.clone())));
|
||||||
|
let h = self
|
||||||
|
.conv2
|
||||||
|
.forward(self.silu2.forward(self.norm2.forward(h)));
|
||||||
|
|
||||||
if let Some(ns) = self.nin_shortcut.as_ref() {
|
if let Some(ns) = self.nin_shortcut.as_ref() {
|
||||||
ns.forward(x) + h
|
ns.forward(x) + h
|
||||||
@@ -470,18 +527,18 @@ impl<B: Backend> ResnetBlock<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct ConvSelfAttentionBlockConfig {
|
pub struct ConvSelfAttentionBlockConfig {
|
||||||
n_channel: usize,
|
n_channel: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ConvSelfAttentionBlockConfig {
|
impl ConvSelfAttentionBlockConfig {
|
||||||
fn init<B: Backend>(&self) -> ConvSelfAttentionBlock<B> {
|
fn init<B: Backend>(&self, device: &B::Device) -> ConvSelfAttentionBlock<B> {
|
||||||
let norm = GroupNormConfig::new(32, self.n_channel).init();
|
let norm = GroupNormConfig::new(32, self.n_channel).init(device);
|
||||||
let q = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
|
let q = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
|
||||||
let k = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
|
let k = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
|
||||||
let v = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
|
let v = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
|
||||||
let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
|
let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
|
||||||
|
|
||||||
ConvSelfAttentionBlock {
|
ConvSelfAttentionBlock {
|
||||||
norm,
|
norm,
|
||||||
@@ -508,11 +565,39 @@ impl<B: Backend> ConvSelfAttentionBlock<B> {
|
|||||||
|
|
||||||
let h = self.norm.forward(x.clone());
|
let h = self.norm.forward(x.clone());
|
||||||
|
|
||||||
let q = self.q.forward(h.clone()).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2);
|
let q = self
|
||||||
let k = self.k.forward(h.clone()).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2);
|
.q
|
||||||
let v = self.v.forward(h).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2);
|
.forward(h.clone())
|
||||||
|
.reshape([n_batch, n_channel, height * width])
|
||||||
|
.swap_dims(1, 2);
|
||||||
|
let k = self
|
||||||
|
.k
|
||||||
|
.forward(h.clone())
|
||||||
|
.reshape([n_batch, n_channel, height * width])
|
||||||
|
.swap_dims(1, 2);
|
||||||
|
let v = self
|
||||||
|
.v
|
||||||
|
.forward(h)
|
||||||
|
.reshape([n_batch, n_channel, height * width])
|
||||||
|
.swap_dims(1, 2);
|
||||||
|
|
||||||
let wv = qkv_attention(q, k, v, None, 1)
|
/*let wv = Tensor::from_primitive(B::qkv_attention(
|
||||||
|
q.into_primitive(),
|
||||||
|
k.into_primitive(),
|
||||||
|
v.into_primitive(),
|
||||||
|
None,
|
||||||
|
1,
|
||||||
|
))
|
||||||
|
.swap_dims(1, 2)
|
||||||
|
.reshape([n_batch, n_channel, height, width]);*/
|
||||||
|
|
||||||
|
let wv = qkv_attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
None,
|
||||||
|
1,
|
||||||
|
)
|
||||||
.swap_dims(1, 2)
|
.swap_dims(1, 2)
|
||||||
.reshape([n_batch, n_channel, height, width]);
|
.reshape([n_batch, n_channel, height, width]);
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,11 @@
|
|||||||
use std::error::Error;
|
|
||||||
use burn::tensor::ElementConversion;
|
use burn::tensor::ElementConversion;
|
||||||
|
use std::error::Error;
|
||||||
|
|
||||||
use burn::{
|
use burn::{
|
||||||
config::Config,
|
config::Config,
|
||||||
module::{Module, Param},
|
module::{Module, Param},
|
||||||
nn,
|
nn,
|
||||||
tensor::{
|
tensor::{backend::Backend, Tensor},
|
||||||
backend::Backend,
|
|
||||||
Tensor,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -28,7 +25,10 @@ pub fn load_mlp<B: Backend>(path: &str, device: &B::Device) -> Result<MLP<B>, Bo
|
|||||||
Ok(mlp)
|
Ok(mlp)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_multi_head_self_attention<B: Backend>(path: &str, device: &B::Device) -> Result<MultiHeadSelfAttention<B>, Box<dyn Error>> {
|
pub fn load_multi_head_self_attention<B: Backend>(
|
||||||
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<MultiHeadSelfAttention<B>, Box<dyn Error>> {
|
||||||
let n_head = load_usize::<B>("n_head", path, device)?;
|
let n_head = load_usize::<B>("n_head", path, device)?;
|
||||||
let query = load_linear(&format!("{}/{}", path, "query"), device)?;
|
let query = load_linear(&format!("{}/{}", path, "query"), device)?;
|
||||||
let key = load_linear(&format!("{}/{}", path, "key"), device)?;
|
let key = load_linear(&format!("{}/{}", path, "key"), device)?;
|
||||||
@@ -46,7 +46,10 @@ pub fn load_multi_head_self_attention<B: Backend>(path: &str, device: &B::Device
|
|||||||
Ok(mhsa)
|
Ok(mhsa)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_residual_decoder_attention_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResidualDecoderAttentionBlock<B>, Box<dyn Error>> {
|
pub fn load_residual_decoder_attention_block<B: Backend>(
|
||||||
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<ResidualDecoderAttentionBlock<B>, Box<dyn Error>> {
|
||||||
let mlp = load_mlp(&format!("{}/{}", path, "mlp"), device)?;
|
let mlp = load_mlp(&format!("{}/{}", path, "mlp"), device)?;
|
||||||
let attn = load_multi_head_self_attention(&format!("{}/{}", path, "attn"), device)?;
|
let attn = load_multi_head_self_attention(&format!("{}/{}", path, "attn"), device)?;
|
||||||
let attn_ln = load_layer_norm(&format!("{}/{}", path, "attn_ln"), device)?;
|
let attn_ln = load_layer_norm(&format!("{}/{}", path, "attn_ln"), device)?;
|
||||||
@@ -64,14 +67,16 @@ pub fn load_residual_decoder_attention_block<B: Backend>(path: &str, device: &B:
|
|||||||
|
|
||||||
pub fn load_clip<B: Backend>(path: &str, device: &B::Device) -> Result<CLIP<B>, Box<dyn Error>> {
|
pub fn load_clip<B: Backend>(path: &str, device: &B::Device) -> Result<CLIP<B>, Box<dyn Error>> {
|
||||||
let token_embedding = load_embedding(&format!("{}/{}", path, "token_embedding"), device)?;
|
let token_embedding = load_embedding(&format!("{}/{}", path, "token_embedding"), device)?;
|
||||||
let position_embedding = load_tensor("weight", &format!("{}/position_embedding", path), device)?.into();
|
let position_embedding =
|
||||||
|
Param::from_tensor(load_tensor("weight", &format!("{}/position_embedding", path), device)?);
|
||||||
|
|
||||||
let n_layer = load_usize::<B>("n_layer", path, device)?;
|
let n_layer = load_usize::<B>("n_layer", path, device)?;
|
||||||
let mut blocks = (0..n_layer)
|
let mut blocks = (0..n_layer)
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|i| {
|
.map(|i| {
|
||||||
load_residual_decoder_attention_block::<B>(&format!("{}/blocks/{}", path, i), device)
|
load_residual_decoder_attention_block::<B>(&format!("{}/blocks/{}", path, i), device)
|
||||||
}).collect::<Result<Vec<_>, _>>()?;
|
})
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
|
||||||
let layer_norm = load_layer_norm(&format!("{}/{}", path, "layer_norm"), device)?;
|
let layer_norm = load_layer_norm(&format!("{}/{}", path, "layer_norm"), device)?;
|
||||||
|
|
||||||
|
|||||||
@@ -5,19 +5,17 @@ use burn::{
|
|||||||
module::{Module, Param},
|
module::{Module, Param},
|
||||||
nn,
|
nn,
|
||||||
tensor::{
|
tensor::{
|
||||||
|
activation::{sigmoid, softmax},
|
||||||
backend::Backend,
|
backend::Backend,
|
||||||
activation::{softmax, sigmoid},
|
|
||||||
module::embedding,
|
module::embedding,
|
||||||
Tensor,
|
Distribution, Int, Tensor,
|
||||||
Distribution,
|
|
||||||
Int,
|
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::model::attention::{qkv_attention, attn_decoder_mask};
|
//use crate::backend::Backend as MyBackend;
|
||||||
|
use crate::backend::{qkv_attention, attn_decoder_mask};
|
||||||
|
|
||||||
|
#[derive(Config, Debug)]
|
||||||
#[derive(Config)]
|
|
||||||
pub struct CLIPConfig {
|
pub struct CLIPConfig {
|
||||||
n_vocab: usize,
|
n_vocab: usize,
|
||||||
n_state: usize,
|
n_state: usize,
|
||||||
@@ -27,14 +25,15 @@ pub struct CLIPConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl CLIPConfig {
|
impl CLIPConfig {
|
||||||
pub fn init<B: Backend>(&self) -> CLIP<B> {
|
pub fn init<B: Backend>(&self, device: &B::Device) -> CLIP<B> {
|
||||||
let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init();
|
let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init(device);
|
||||||
let position_embedding = Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0)).into();
|
let position_embedding =
|
||||||
|
Param::from_tensor(Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0), device));
|
||||||
let blocks = (0..self.n_layer)
|
let blocks = (0..self.n_layer)
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|_| ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head).init())
|
.map(|_| ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head).init(device))
|
||||||
.collect();
|
.collect();
|
||||||
let layer_norm = nn::LayerNormConfig::new(self.n_state).init();
|
let layer_norm = nn::LayerNormConfig::new(self.n_state).init(device);
|
||||||
|
|
||||||
CLIP {
|
CLIP {
|
||||||
token_embedding,
|
token_embedding,
|
||||||
@@ -45,8 +44,6 @@ impl CLIPConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#[derive(Module, Debug)]
|
#[derive(Module, Debug)]
|
||||||
pub struct CLIP<B: Backend> {
|
pub struct CLIP<B: Backend> {
|
||||||
token_embedding: nn::Embedding<B>,
|
token_embedding: nn::Embedding<B>,
|
||||||
@@ -59,10 +56,15 @@ impl<B: Backend> CLIP<B> {
|
|||||||
pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
|
pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
|
||||||
let [n_batch, seq_len] = x.dims();
|
let [n_batch, seq_len] = x.dims();
|
||||||
|
|
||||||
|
//let mask = Tensor::from_primitive(B::attn_decoder_mask(seq_len, &x.device()));
|
||||||
let mask = attn_decoder_mask(seq_len, &x.device());
|
let mask = attn_decoder_mask(seq_len, &x.device());
|
||||||
|
|
||||||
let embedded = self.token_embedding.forward(x)
|
let embedded = self.token_embedding.forward(x)
|
||||||
+ self.position_embedding.val().slice([0..seq_len]).unsqueeze();
|
+ self
|
||||||
|
.position_embedding
|
||||||
|
.val()
|
||||||
|
.slice([0..seq_len])
|
||||||
|
.unsqueeze();
|
||||||
|
|
||||||
let mut x = embedded;
|
let mut x = embedded;
|
||||||
for block in &self.blocks {
|
for block in &self.blocks {
|
||||||
@@ -73,21 +75,19 @@ impl<B: Backend> CLIP<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Config, Debug)]
|
||||||
|
|
||||||
#[derive(Config)]
|
|
||||||
pub struct ResidualDecoderAttentionBlockConfig {
|
pub struct ResidualDecoderAttentionBlockConfig {
|
||||||
n_state: usize,
|
n_state: usize,
|
||||||
n_head: usize,
|
n_head: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ResidualDecoderAttentionBlockConfig {
|
impl ResidualDecoderAttentionBlockConfig {
|
||||||
pub fn init<B: Backend>(&self) -> ResidualDecoderAttentionBlock<B> {
|
pub fn init<B: Backend>(&self, device: &B::Device) -> ResidualDecoderAttentionBlock<B> {
|
||||||
let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init();
|
let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init(device);
|
||||||
let attn_ln = nn::LayerNormConfig::new(self.n_state).init();
|
let attn_ln = nn::LayerNormConfig::new(self.n_state).init(device);
|
||||||
|
|
||||||
let mlp = MLPConfig::new(self.n_state, 4 * self.n_state).init();
|
let mlp = MLPConfig::new(self.n_state, 4 * self.n_state).init(device);
|
||||||
let mlp_ln = nn::LayerNormConfig::new(self.n_state).init();
|
let mlp_ln = nn::LayerNormConfig::new(self.n_state).init(device);
|
||||||
|
|
||||||
ResidualDecoderAttentionBlock {
|
ResidualDecoderAttentionBlock {
|
||||||
attn,
|
attn,
|
||||||
@@ -114,28 +114,33 @@ impl<B: Backend> ResidualDecoderAttentionBlock<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct MultiHeadSelfAttentionConfig {
|
pub struct MultiHeadSelfAttentionConfig {
|
||||||
n_state: usize,
|
n_state: usize,
|
||||||
n_head: usize,
|
n_head: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MultiHeadSelfAttentionConfig {
|
impl MultiHeadSelfAttentionConfig {
|
||||||
fn init<B: Backend>(&self) -> MultiHeadSelfAttention<B> {
|
fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadSelfAttention<B> {
|
||||||
assert!(self.n_state % self.n_head == 0, "State size {} must be a multiple of head size {}", self.n_state, self.n_head);
|
assert!(
|
||||||
|
self.n_state % self.n_head == 0,
|
||||||
|
"State size {} must be a multiple of head size {}",
|
||||||
|
self.n_state,
|
||||||
|
self.n_head
|
||||||
|
);
|
||||||
|
|
||||||
let n_head = self.n_head;
|
let n_head = self.n_head;
|
||||||
let query = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
let query = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
|
||||||
let key = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
let key = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
|
||||||
let value = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
let value = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
|
||||||
let out = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
let out = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
|
||||||
|
|
||||||
MultiHeadSelfAttention {
|
MultiHeadSelfAttention {
|
||||||
n_head,
|
n_head,
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
out
|
out,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -155,19 +160,26 @@ impl<B: Backend> MultiHeadSelfAttention<B> {
|
|||||||
let k = self.key.forward(x.clone());
|
let k = self.key.forward(x.clone());
|
||||||
let v = self.value.forward(x);
|
let v = self.value.forward(x);
|
||||||
|
|
||||||
let wv = qkv_attention(q, k, v, mask, self.n_head);
|
/*let wv = Tensor::from_primitive(B::qkv_attention(
|
||||||
|
q.into_primitive(),
|
||||||
|
k.into_primitive(),
|
||||||
|
v.into_primitive(),
|
||||||
|
mask.map(|m| m.into_primitive()),
|
||||||
|
self.n_head,
|
||||||
|
));*/
|
||||||
|
|
||||||
|
let wv = qkv_attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
mask,
|
||||||
|
self.n_head,
|
||||||
|
);
|
||||||
|
|
||||||
return self.out.forward(wv);
|
return self.out.forward(wv);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#[derive(Config, Debug)]
|
#[derive(Config, Debug)]
|
||||||
pub struct MLPConfig {
|
pub struct MLPConfig {
|
||||||
input_size: usize,
|
input_size: usize,
|
||||||
@@ -175,16 +187,12 @@ pub struct MLPConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl MLPConfig {
|
impl MLPConfig {
|
||||||
fn init<B: Backend>(&self) -> MLP<B> {
|
fn init<B: Backend>(&self, device: &B::Device) -> MLP<B> {
|
||||||
let fc1 = nn::LinearConfig::new(self.input_size, self.hidden_size).init();
|
let fc1 = nn::LinearConfig::new(self.input_size, self.hidden_size).init(device);
|
||||||
let gelu = QuickGELU::new();
|
let gelu = QuickGELU::new();
|
||||||
let fc2 = nn::LinearConfig::new(self.hidden_size, self.input_size).init();
|
let fc2 = nn::LinearConfig::new(self.hidden_size, self.input_size).init(device);
|
||||||
|
|
||||||
MLP {
|
MLP { fc1, gelu, fc2 }
|
||||||
fc1,
|
|
||||||
gelu,
|
|
||||||
fc2,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -217,4 +225,3 @@ impl QuickGELU {
|
|||||||
x.clone() * sigmoid(x * 1.702)
|
x.clone() * sigmoid(x * 1.702)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,27 +7,31 @@ use burn::{
|
|||||||
config::Config,
|
config::Config,
|
||||||
module::{Module, Param},
|
module::{Module, Param},
|
||||||
nn,
|
nn,
|
||||||
tensor::{
|
tensor::{backend::Backend, Tensor},
|
||||||
backend::Backend,
|
|
||||||
Tensor,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn load_group_norm<B: Backend>(path: &str, device: &B::Device) -> Result<GroupNorm<B>, Box<dyn Error>> {
|
pub fn load_group_norm<B: Backend>(
|
||||||
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<GroupNorm<B>, Box<dyn Error>> {
|
||||||
let n_group = load_usize::<B>("n_group", path, device)?.into();
|
let n_group = load_usize::<B>("n_group", path, device)?.into();
|
||||||
let n_channel = load_usize::<B>("n_channel", path, device)?.into();
|
let n_channel = load_usize::<B>("n_channel", path, device)?.into();
|
||||||
let eps = load_f32::<B>("eps", path, device)?.into();
|
let eps = load_f32::<B>("eps", path, device)?.into();
|
||||||
|
|
||||||
let gamma = load_tensor::<B, 1>("weight", path, device).ok().unwrap_or_else(|| Tensor::ones_device([n_channel], device)).into();
|
let gamma = Param::from_tensor(load_tensor::<B, 1>("weight", path, device)
|
||||||
let beta = load_tensor::<B, 1>("bias", path, device).ok().unwrap_or_else(|| Tensor::zeros_device([n_channel], device)).into();
|
.ok()
|
||||||
|
.unwrap_or_else(|| Tensor::ones([n_channel], device))
|
||||||
|
);
|
||||||
|
let beta = Param::from_tensor(load_tensor::<B, 1>("bias", path, device)
|
||||||
|
.ok()
|
||||||
|
.unwrap_or_else(|| Tensor::zeros([n_channel], device))
|
||||||
|
);
|
||||||
|
|
||||||
Ok(
|
Ok(GroupNorm {
|
||||||
GroupNorm {
|
|
||||||
n_group,
|
n_group,
|
||||||
n_channel,
|
n_channel,
|
||||||
gamma,
|
gamma,
|
||||||
beta,
|
beta,
|
||||||
eps,
|
eps,
|
||||||
}
|
})
|
||||||
)
|
|
||||||
}
|
}
|
||||||
@@ -3,13 +3,10 @@ pub mod load;
|
|||||||
use burn::{
|
use burn::{
|
||||||
config::Config,
|
config::Config,
|
||||||
module::{Module, Param},
|
module::{Module, Param},
|
||||||
tensor::{
|
tensor::{backend::Backend, Tensor},
|
||||||
backend::Backend,
|
|
||||||
Tensor,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct GroupNormConfig {
|
pub struct GroupNormConfig {
|
||||||
n_group: usize,
|
n_group: usize,
|
||||||
n_channel: usize,
|
n_channel: usize,
|
||||||
@@ -18,13 +15,18 @@ pub struct GroupNormConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl GroupNormConfig {
|
impl GroupNormConfig {
|
||||||
pub fn init<B: Backend>(&self) -> GroupNorm<B> {
|
pub fn init<B: Backend>(&self, device: &B::Device) -> GroupNorm<B> {
|
||||||
assert!(self.n_channel % self.n_group == 0, "The number of channels {} must be divisible by the number of groups {}", self.n_channel, self.n_group);
|
assert!(
|
||||||
|
self.n_channel % self.n_group == 0,
|
||||||
|
"The number of channels {} must be divisible by the number of groups {}",
|
||||||
|
self.n_channel,
|
||||||
|
self.n_group
|
||||||
|
);
|
||||||
|
|
||||||
let n_per_group = self.n_channel / self.n_group;
|
let n_per_group = self.n_channel / self.n_group;
|
||||||
|
|
||||||
let gamma = Tensor::ones([self.n_channel]).into();
|
let gamma = Param::from_tensor(Tensor::ones([self.n_channel], device));
|
||||||
let beta = Tensor::zeros([self.n_channel]).into();
|
let beta = Param::from_tensor(Tensor::zeros([self.n_channel], device));
|
||||||
|
|
||||||
let eps = self.eps;
|
let eps = self.eps;
|
||||||
|
|
||||||
@@ -56,7 +58,14 @@ impl<B: Backend> GroupNorm<B> {
|
|||||||
let mut affine_shape = [1; D];
|
let mut affine_shape = [1; D];
|
||||||
affine_shape[1] = self.n_channel;
|
affine_shape[1] = self.n_channel;
|
||||||
|
|
||||||
layernorm( x.reshape([n_batch, self.n_group, num_elements / (n_batch * self.n_group) ]), self.eps )
|
layernorm(
|
||||||
|
x.reshape([
|
||||||
|
n_batch,
|
||||||
|
self.n_group,
|
||||||
|
num_elements / (n_batch * self.n_group),
|
||||||
|
]),
|
||||||
|
self.eps,
|
||||||
|
)
|
||||||
.reshape(shape)
|
.reshape(shape)
|
||||||
.mul(self.gamma.val().reshape(affine_shape))
|
.mul(self.gamma.val().reshape(affine_shape))
|
||||||
.add(self.beta.val().reshape(affine_shape))
|
.add(self.beta.val().reshape(affine_shape))
|
||||||
@@ -68,5 +77,6 @@ pub fn layernorm<B: Backend, const D: usize>(x: Tensor<B, D>, eps: f64) -> Tenso
|
|||||||
//x.sub(mean).div(var.sqrt().add_scalar(eps))
|
//x.sub(mean).div(var.sqrt().add_scalar(eps))
|
||||||
|
|
||||||
let u = x.clone() - x.mean_dim(D - 1);
|
let u = x.clone() - x.mean_dim(D - 1);
|
||||||
u.clone().div( (u.clone() * u).mean_dim(D - 1).add_scalar(eps).sqrt() )
|
u.clone()
|
||||||
|
.div((u.clone() * u).mean_dim(D - 1).add_scalar(eps).sqrt())
|
||||||
}
|
}
|
||||||
@@ -1,36 +1,41 @@
|
|||||||
use std::error::Error;
|
|
||||||
use std::io::Read;
|
|
||||||
use npy::{self, NpyData};
|
use npy::{self, NpyData};
|
||||||
use num_traits::cast::ToPrimitive;
|
use num_traits::cast::ToPrimitive;
|
||||||
|
use burn::tensor::cast::ToElement;
|
||||||
|
use burn::prelude::TensorData;
|
||||||
|
use std::error::Error;
|
||||||
|
use std::io::Read;
|
||||||
|
|
||||||
use burn::{
|
use burn::{
|
||||||
config::Config,
|
config::Config,
|
||||||
module::{Module, Param},
|
module::{Module, Param},
|
||||||
nn::{self, conv},
|
nn::{self, conv},
|
||||||
tensor::{
|
tensor::{backend::Backend, Tensor},
|
||||||
backend::Backend,
|
|
||||||
Tensor,
|
|
||||||
Data,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use burn::tensor::ElementConversion;
|
use burn::tensor::ElementConversion;
|
||||||
|
|
||||||
pub fn numpy_to_tensor<B: Backend, const D: usize>(numpy_data: NpyData<f32>, device: &B::Device) -> Tensor<B, D> {
|
pub fn numpy_to_tensor<B: Backend, const D: usize>(
|
||||||
|
numpy_data: NpyData<f32>,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Tensor<B, D> {
|
||||||
let mut v = numpy_data.to_vec();
|
let mut v = numpy_data.to_vec();
|
||||||
|
|
||||||
let shape: Vec<_> = v[0..D].into_iter().map(|&v| v as usize).collect();
|
let shape: Vec<_> = v[0..D].into_iter().map(|&v| v as usize).collect();
|
||||||
let data: Vec<B::FloatElem> = v[D..].into_iter().map(|e| e.elem()).collect();
|
let data: Vec<B::FloatElem> = v[D..].into_iter().map(|e| e.elem()).collect();
|
||||||
|
|
||||||
Tensor::from_data_device(Data::new(data, shape.into()), device)
|
//Tensor::from_data_device(Data::new(data, shape.into()), device)
|
||||||
|
Tensor::from_data(TensorData::new(data, shape), device)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_tensor<B: Backend, const D: usize>(name: &str, path: &str, device: &B::Device) -> Result<Tensor<B, D>, Box<dyn Error>> {
|
pub fn load_tensor<B: Backend, const D: usize>(
|
||||||
|
name: &str,
|
||||||
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<Tensor<B, D>, Box<dyn Error>> {
|
||||||
let tensor_path = format!("{}/{}.npy", path, name);
|
let tensor_path = format!("{}/{}.npy", path, name);
|
||||||
|
|
||||||
let mut buf = vec![];
|
let mut buf = vec![];
|
||||||
std::fs::File::open(&tensor_path)?
|
std::fs::File::open(&tensor_path)?.read_to_end(&mut buf)?;
|
||||||
.read_to_end(&mut buf)?;
|
|
||||||
|
|
||||||
let tensor_numpy: NpyData<f32> = NpyData::from_bytes(&buf)?;
|
let tensor_numpy: NpyData<f32> = NpyData::from_bytes(&buf)?;
|
||||||
|
|
||||||
@@ -41,71 +46,79 @@ pub fn load_tensor<B: Backend, const D: usize>(name: &str, path: &str, device: &
|
|||||||
Ok(tensor)
|
Ok(tensor)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_f32<B: Backend>(name: &str, path: &str, device: &B::Device) -> Result<f32, Box<dyn Error>> {
|
pub fn load_f32<B: Backend>(
|
||||||
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_f32().unwrap())
|
name: &str,
|
||||||
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<f32, Box<dyn Error>> {
|
||||||
|
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_f32())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_usize<B: Backend>(name: &str, path: &str, device: &B::Device) -> Result<usize, Box<dyn Error>> {
|
pub fn load_usize<B: Backend>(
|
||||||
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_usize().unwrap())
|
name: &str,
|
||||||
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<usize, Box<dyn Error>> {
|
||||||
|
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_usize())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_linear<B: Backend>(path: &str, device: &B::Device) -> Result<nn::Linear<B>, Box<dyn Error>> {
|
pub fn load_linear<B: Backend>(
|
||||||
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<nn::Linear<B>, Box<dyn Error>> {
|
||||||
let weight = load_tensor::<B, 2>("weight", path, device)?;
|
let weight = load_tensor::<B, 2>("weight", path, device)?;
|
||||||
let bias = load_tensor::<B, 1>("bias", path, device).ok();
|
let bias = load_tensor::<B, 1>("bias", path, device).ok();
|
||||||
|
|
||||||
let record = nn::LinearRecord {
|
Ok(nn::Linear {
|
||||||
weight: weight.into(),
|
weight: Param::from_tensor(weight),
|
||||||
bias: bias.map(|t| t.into()),
|
bias: bias.map(|t| Param::from_tensor(t)),
|
||||||
};
|
})
|
||||||
|
|
||||||
let linear: nn::Linear<B> = nn::LinearConfig::new(3, 3).init_with(record);
|
|
||||||
Ok(linear)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_embedding<B: Backend>(path: &str, device: &B::Device) -> Result<nn::Embedding<B>, Box<dyn Error>> {
|
pub fn load_embedding<B: Backend>(
|
||||||
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<nn::Embedding<B>, Box<dyn Error>> {
|
||||||
let weight = load_tensor::<B, 2>("weight", path, device)?;
|
let weight = load_tensor::<B, 2>("weight", path, device)?;
|
||||||
let [n_vocab, n_state] = weight.dims();
|
|
||||||
|
|
||||||
let record = nn::EmbeddingRecord {
|
Ok(nn::Embedding {
|
||||||
weight: weight.into(),
|
weight: Param::from_tensor(weight),
|
||||||
};
|
})
|
||||||
|
|
||||||
let embedding = nn::EmbeddingConfig::new(n_vocab, n_state).init_with(record);
|
|
||||||
Ok(embedding)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_layer_norm<B: Backend>(path: &str, device: &B::Device) -> Result<nn::LayerNorm<B>, Box<dyn Error>> {
|
pub fn load_layer_norm<B: Backend>(
|
||||||
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<nn::LayerNorm<B>, Box<dyn Error>> {
|
||||||
let weight = load_tensor::<B, 1>("weight", path, device)?;
|
let weight = load_tensor::<B, 1>("weight", path, device)?;
|
||||||
let bias = load_tensor::<B, 1>("bias", path, device)?;
|
let bias = load_tensor::<B, 1>("bias", path, device)?;
|
||||||
let eps = load_f32::<B>("eps", path, device)? as f64;
|
let eps = load_f32::<B>("eps", path, device)? as f64;
|
||||||
|
|
||||||
let [n_state] = weight.dims();
|
let [n_state] = weight.dims();
|
||||||
|
|
||||||
let record = nn::LayerNormRecord {
|
let mut layer_norm = nn::LayerNormConfig::new(n_state).with_epsilon(eps).init(device);
|
||||||
gamma: weight.into(),
|
layer_norm.gamma = Param::from_tensor(weight);
|
||||||
beta: bias.into(),
|
layer_norm.beta = Some(Param::from_tensor(bias));
|
||||||
epsilon: <f64 as Module<B>>::into_record(eps),
|
|
||||||
};
|
|
||||||
|
|
||||||
let layer_norm: nn::LayerNorm<B> = nn::LayerNormConfig::new(n_state).init_with(record);
|
|
||||||
|
|
||||||
Ok(layer_norm)
|
Ok(layer_norm)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/*pub fn load_rmsnorm<B: Backend>(path: &str, device: &B::Device) -> Result<RMSNorm<B>, Box<dyn Error>> {
|
/*pub fn load_rmsnorm<B: Backend>(path: &str, device: &B::Device) -> Result<RMSNorm<B>, Box<dyn Error>> {
|
||||||
let weight = load_tensor::<B, 1>("weight", path, device)?;
|
let weight = load_tensor::<B, 1>("weight", path, device)?;
|
||||||
let eps = load_f32::<B>("eps", path, device)?.into();
|
let eps = load_f32::<B>("eps", path, device)?.into();
|
||||||
|
|
||||||
let rmsnorm = RMSNorm {
|
let rmsnorm = RMSNorm {
|
||||||
weight: weight.into(),
|
weight: Param::from_tensor(weight),
|
||||||
eps: eps
|
eps: eps
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(rmsnorm)
|
Ok(rmsnorm)
|
||||||
}*/
|
}*/
|
||||||
|
|
||||||
pub fn load_conv2d<B: Backend>(path: &str, device: &B::Device) -> Result<conv::Conv2d<B>, Box<dyn Error>> {
|
pub fn load_conv2d<B: Backend>(
|
||||||
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<conv::Conv2d<B>, Box<dyn Error>> {
|
||||||
let weight = load_tensor::<B, 4>("weight", path, device)?;
|
let weight = load_tensor::<B, 4>("weight", path, device)?;
|
||||||
let bias = load_tensor::<B, 1>("bias", path, device).ok();
|
let bias = load_tensor::<B, 1>("bias", path, device).ok();
|
||||||
let has_bias = bias.is_some();
|
let has_bias = bias.is_some();
|
||||||
@@ -127,40 +140,38 @@ pub fn load_conv2d<B: Backend>(path: &str, device: &B::Device) -> Result<conv::C
|
|||||||
let padding = tensor_to_array_2(padding);
|
let padding = tensor_to_array_2(padding);
|
||||||
let padding = nn::PaddingConfig2d::Explicit(padding[0], padding[1]);
|
let padding = nn::PaddingConfig2d::Explicit(padding[0], padding[1]);
|
||||||
|
|
||||||
|
let mut conv2d = conv::Conv2dConfig::new([n_channels_in, n_channels_out], kernel_size)
|
||||||
let record = conv::Conv2dRecord {
|
|
||||||
weight: weight.into(),
|
|
||||||
bias: bias.map(|t| t.into()),
|
|
||||||
stride: <[usize; 2] as Module<B>>::into_record(stride),
|
|
||||||
kernel_size: <[usize; 2] as Module<B>>::into_record(kernel_size),
|
|
||||||
dilation: <[usize; 2] as Module<B>>::into_record(dilation),
|
|
||||||
groups: <usize as Module<B>>::into_record(n_group),
|
|
||||||
padding: <nn::PaddingConfig2d as Module<B>>::into_record(padding.clone()),
|
|
||||||
};
|
|
||||||
|
|
||||||
let conv2d: conv::Conv2d<B> = conv::Conv2dConfig::new([n_channels_in, n_channels_out], kernel_size)
|
|
||||||
.with_stride(stride)
|
.with_stride(stride)
|
||||||
.with_dilation(dilation)
|
.with_dilation(dilation)
|
||||||
.with_groups(n_group)
|
.with_groups(n_group)
|
||||||
.with_padding(padding)
|
.with_padding(padding.clone())
|
||||||
.with_bias(has_bias)
|
.with_bias(has_bias)
|
||||||
.init_with(record);
|
.init(device);
|
||||||
|
|
||||||
|
conv2d.weight = Param::from_tensor(weight);
|
||||||
|
conv2d.bias = bias.map(|t| Param::from_tensor(t));
|
||||||
|
conv2d.stride = stride;
|
||||||
|
conv2d.kernel_size = kernel_size;
|
||||||
|
conv2d.dilation = dilation;
|
||||||
|
conv2d.groups = n_group;
|
||||||
|
conv2d.padding = burn::module::Ignored(padding);
|
||||||
|
|
||||||
Ok(conv2d)
|
Ok(conv2d)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn tensor_to_array_2<B: Backend>(x: Tensor<B, 1>) -> [usize; 2] {
|
pub fn tensor_to_array_2<B: Backend>(x: Tensor<B, 1>) -> [usize; 2] {
|
||||||
let vec = x.into_data().value;
|
let vec: Vec<<B as Backend>::FloatElem> = x.into_data().to_vec().unwrap();
|
||||||
assert!(vec.len() == 2, "Tensor length must be 2.");
|
assert!(vec.len() == 2, "Tensor length must be 2.");
|
||||||
[vec[0].to_usize().unwrap(), vec[1].to_usize().unwrap()]
|
[vec[0].to_usize(), vec[1].to_usize()]
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn tensor_to_array<const N: usize, B: Backend>(x: Tensor<B, 1>) -> [usize; N] {
|
pub fn tensor_to_array<const N: usize, B: Backend>(x: Tensor<B, 1>) -> [usize; N] {
|
||||||
let vec = x.into_data().value;
|
let vec: Vec<<B as Backend>::FloatElem> = x.into_data().to_vec().unwrap();
|
||||||
assert!(vec.len() == N, "Tensor length must be {}.", N);
|
assert!(vec.len() == N, "Tensor length must be {}.", N);
|
||||||
|
|
||||||
let mut arr = [0; N];
|
let mut arr = [0; N];
|
||||||
for (a, t) in arr.iter_mut().zip(vec) {
|
for (a, t) in arr.iter_mut().zip(vec) {
|
||||||
*a = t.to_usize().unwrap();
|
*a = t.to_usize();
|
||||||
}
|
}
|
||||||
|
|
||||||
arr
|
arr
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
pub mod stablediffusion;
|
pub mod stablediffusion;
|
||||||
|
|
||||||
pub mod autoencoder;
|
pub mod autoencoder;
|
||||||
pub mod unet;
|
|
||||||
pub mod clip;
|
pub mod clip;
|
||||||
|
pub mod unet;
|
||||||
|
|
||||||
pub mod silu;
|
|
||||||
pub mod groupnorm;
|
|
||||||
pub mod attention;
|
pub mod attention;
|
||||||
|
pub mod groupnorm;
|
||||||
|
pub mod silu;
|
||||||
|
|
||||||
pub mod load;
|
pub mod load;
|
||||||
@@ -1,13 +1,8 @@
|
|||||||
use burn::{
|
use burn::{
|
||||||
module::Module,
|
module::Module,
|
||||||
tensor::{
|
tensor::{activation::sigmoid, backend::Backend, Tensor},
|
||||||
backend::Backend,
|
|
||||||
activation::sigmoid,
|
|
||||||
Tensor,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
#[derive(Module, Clone, Debug)]
|
#[derive(Module, Clone, Debug)]
|
||||||
pub struct SILU {}
|
pub struct SILU {}
|
||||||
|
|
||||||
|
|||||||
@@ -1,22 +1,24 @@
|
|||||||
use std::error::Error;
|
|
||||||
use burn::tensor::ElementConversion;
|
use burn::tensor::ElementConversion;
|
||||||
|
use std::error::Error;
|
||||||
|
|
||||||
use burn::{
|
use burn::{
|
||||||
config::Config,
|
config::Config,
|
||||||
module::{Module, Param},
|
module::{Module, Param},
|
||||||
nn,
|
nn,
|
||||||
tensor::{
|
tensor::{backend::Backend, Tensor},
|
||||||
backend::Backend,
|
|
||||||
Tensor,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::model::{load::*, autoencoder::load::load_autoencoder, unet::load::load_unet, clip::load::load_clip};
|
use crate::model::{
|
||||||
|
autoencoder::load::load_autoencoder, clip::load::load_clip, load::*, unet::load::load_unet,
|
||||||
|
};
|
||||||
|
|
||||||
pub fn load_stable_diffusion<B: Backend>(path: &str, device: &B::Device) -> Result<StableDiffusion<B>, Box<dyn Error>> {
|
pub fn load_stable_diffusion<B: Backend>(
|
||||||
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<StableDiffusion<B>, Box<dyn Error>> {
|
||||||
let n_steps = load_usize::<B>("n_steps", path, device)?;
|
let n_steps = load_usize::<B>("n_steps", path, device)?;
|
||||||
let alpha_cumulative_products = load_tensor::<B, 1>("alphas_cumprod", path, device)?.into();
|
let alpha_cumulative_products = Param::from_tensor(load_tensor::<B, 1>("alphas_cumprod", path, device)?);
|
||||||
let autoencoder = load_autoencoder(&format!("{}/{}", path, "autoencoder"), device)?;
|
let autoencoder = load_autoencoder(&format!("{}/{}", path, "autoencoder"), device)?;
|
||||||
let diffusion = load_unet(&format!("{}/{}", path, "unet"), device)?;
|
let diffusion = load_unet(&format!("{}/{}", path, "unet"), device)?;
|
||||||
let clip = load_clip(&format!("{}/{}", path, "clip"), device)?;
|
let clip = load_clip(&format!("{}/{}", path, "clip"), device)?;
|
||||||
@@ -29,4 +31,3 @@ pub fn load_stable_diffusion<B: Backend>(path: &str, device: &B::Device) -> Resu
|
|||||||
clip,
|
clip,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,37 +3,30 @@ pub mod load;
|
|||||||
use burn::{
|
use burn::{
|
||||||
config::Config,
|
config::Config,
|
||||||
module::{Module, Param},
|
module::{Module, Param},
|
||||||
tensor::{
|
tensor::{backend::Backend, BasicOps, Distribution, Float, Int, Tensor},
|
||||||
backend::Backend,
|
tensor::cast::ToElement,
|
||||||
Tensor,
|
|
||||||
Int,
|
|
||||||
Float,
|
|
||||||
BasicOps,
|
|
||||||
Data,
|
|
||||||
Distribution,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use num_traits::ToPrimitive;
|
use num_traits::ToPrimitive;
|
||||||
|
|
||||||
|
//use crate::backend::Backend as MyBackend;
|
||||||
|
|
||||||
use super::autoencoder::{Autoencoder, AutoencoderConfig};
|
use super::autoencoder::{Autoencoder, AutoencoderConfig};
|
||||||
|
use super::clip::{CLIPConfig, CLIP};
|
||||||
use super::unet::{UNet, UNetConfig};
|
use super::unet::{UNet, UNetConfig};
|
||||||
use super::clip::{CLIP, CLIPConfig};
|
|
||||||
use crate::tokenizer::SimpleTokenizer;
|
use crate::tokenizer::SimpleTokenizer;
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct StableDiffusionConfig {
|
pub struct StableDiffusionConfig {}
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StableDiffusionConfig {
|
impl StableDiffusionConfig {
|
||||||
pub fn init<B: Backend>(&self) -> StableDiffusion<B> {
|
pub fn init<B: Backend>(&self, device: &B::Device) -> StableDiffusion<B> {
|
||||||
let n_steps = 1000;
|
let n_steps = 1000;
|
||||||
let alpha_cumulative_products = offset_cosine_schedule_cumprod::<B>(n_steps).into();
|
let alpha_cumulative_products = Param::from_tensor(offset_cosine_schedule_cumprod::<B>(n_steps as i64, device));
|
||||||
|
|
||||||
let autoencoder = AutoencoderConfig::new().init();
|
let autoencoder = AutoencoderConfig::new().init(device);
|
||||||
let diffusion = UNetConfig::new().init();
|
let diffusion = UNetConfig::new().init(device);
|
||||||
let clip = CLIPConfig::new(49408, 768, 12, 77, 12).init();
|
let clip = CLIPConfig::new(49408, 768, 12, 77, 12).init(device);
|
||||||
|
|
||||||
StableDiffusion {
|
StableDiffusion {
|
||||||
n_steps,
|
n_steps,
|
||||||
@@ -55,10 +48,26 @@ pub struct StableDiffusion<B: Backend> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend> StableDiffusion<B> {
|
impl<B: Backend> StableDiffusion<B> {
|
||||||
pub fn sample_image(&self, context: Tensor<B, 3>, unconditional_context: Tensor<B, 2>, unconditional_guidance_scale: f64, n_steps: usize) -> Vec<Vec<u8>> {
|
pub fn sample_image(
|
||||||
|
&self,
|
||||||
|
context: Tensor<B, 3>,
|
||||||
|
unconditional_context: Tensor<B, 2>,
|
||||||
|
unconditional_guidance_scale: f64,
|
||||||
|
n_steps: usize,
|
||||||
|
) -> Vec<Vec<u8>> {
|
||||||
let [n_batch, _, _] = context.dims();
|
let [n_batch, _, _] = context.dims();
|
||||||
|
|
||||||
let latent = self.sample_latent(context, unconditional_context, unconditional_guidance_scale, n_steps);
|
let latent = self.sample_latent(
|
||||||
|
context,
|
||||||
|
unconditional_context,
|
||||||
|
unconditional_guidance_scale,
|
||||||
|
n_steps,
|
||||||
|
);
|
||||||
|
self.latent_to_image(latent)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn latent_to_image(&self, latent: Tensor<B, 4>) -> Vec<Vec<u8>> {
|
||||||
|
let [n_batch, _, _, _] = latent.dims();
|
||||||
let image = self.autoencoder.decode_latent(latent * (1.0 / 0.18215));
|
let image = self.autoencoder.decode_latent(latent * (1.0 / 0.18215));
|
||||||
|
|
||||||
let n_channel = 3;
|
let n_channel = 3;
|
||||||
@@ -74,19 +83,29 @@ impl<B: Backend> StableDiffusion<B> {
|
|||||||
.swap_dims(2, 3)
|
.swap_dims(2, 3)
|
||||||
.mul_scalar(255.0);
|
.mul_scalar(255.0);
|
||||||
|
|
||||||
let flattened: Vec<_> = image.
|
let flattened: Vec<B::FloatElem> = image.into_data().to_vec().unwrap();
|
||||||
into_data().
|
|
||||||
value;
|
|
||||||
|
|
||||||
(0..n_batch).into_iter().map(|b| {
|
(0..n_batch)
|
||||||
|
.into_iter()
|
||||||
|
.map(|b| {
|
||||||
let start = b * num_elements_per_image;
|
let start = b * num_elements_per_image;
|
||||||
let end = start + num_elements_per_image;
|
let end = start + num_elements_per_image;
|
||||||
|
|
||||||
flattened[start..end].into_iter().map(|v| v.to_f64().unwrap().min(255.0).max(0.0).to_u8().unwrap()).collect()
|
flattened[start..end]
|
||||||
}).collect()
|
.into_iter()
|
||||||
|
.map(|v| v.to_f64().min(255.0).max(0.0) as u8)
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn sample_latent(&self, context: Tensor<B, 3>, unconditional_context: Tensor<B, 2>, unconditional_guidance_scale: f64, n_steps: usize) -> Tensor<B, 4> {
|
pub fn sample_latent(
|
||||||
|
&self,
|
||||||
|
context: Tensor<B, 3>,
|
||||||
|
unconditional_context: Tensor<B, 2>,
|
||||||
|
unconditional_guidance_scale: f64,
|
||||||
|
n_steps: usize,
|
||||||
|
) -> Tensor<B, 4> {
|
||||||
let device = context.device();
|
let device = context.device();
|
||||||
|
|
||||||
let step_size = self.n_steps / n_steps;
|
let step_size = self.n_steps / n_steps;
|
||||||
@@ -94,7 +113,7 @@ impl<B: Backend> StableDiffusion<B> {
|
|||||||
let [n_batches, _, _] = context.dims();
|
let [n_batches, _, _] = context.dims();
|
||||||
|
|
||||||
let gen_noise = || {
|
let gen_noise = || {
|
||||||
Tensor::random([n_batches, 4, 64, 64], Distribution::Normal(0.0, 1.0)).to_device(&device)
|
Tensor::random([n_batches, 4, 64, 64], Distribution::Normal(0.0, 1.0), &device)
|
||||||
};
|
};
|
||||||
|
|
||||||
let sigma = 0.0; // Use deterministic diffusion
|
let sigma = 0.0; // Use deterministic diffusion
|
||||||
@@ -102,18 +121,34 @@ impl<B: Backend> StableDiffusion<B> {
|
|||||||
let mut latent = gen_noise();
|
let mut latent = gen_noise();
|
||||||
|
|
||||||
for t in (0..self.n_steps).rev().step_by(step_size) {
|
for t in (0..self.n_steps).rev().step_by(step_size) {
|
||||||
let current_alpha: f64 = self.alpha_cumulative_products.val().slice([t..t + 1]).into_scalar().to_f64().unwrap();
|
let current_alpha: f64 = self
|
||||||
|
.alpha_cumulative_products
|
||||||
|
.val()
|
||||||
|
.slice([t..t + 1])
|
||||||
|
.into_scalar()
|
||||||
|
.to_f64();
|
||||||
|
|
||||||
let prev_alpha: f64 = if t >= step_size {
|
let prev_alpha: f64 = if t >= step_size {
|
||||||
let i = t - step_size;
|
let i = t - step_size;
|
||||||
self.alpha_cumulative_products.val().slice([i..i + 1]).into_scalar().to_f64().unwrap()
|
self.alpha_cumulative_products
|
||||||
|
.val()
|
||||||
|
.slice([i..i + 1])
|
||||||
|
.into_scalar()
|
||||||
|
.to_f64()
|
||||||
} else {
|
} else {
|
||||||
1.0
|
1.0
|
||||||
};
|
};
|
||||||
|
|
||||||
let sqrt_noise = (1.0 - current_alpha).sqrt();
|
let sqrt_noise = (1.0 - current_alpha).sqrt();
|
||||||
|
|
||||||
let timestep = Tensor::from_ints([t as i32]).to_device(&device);
|
let timestep = Tensor::from_ints([t as i32], &device);
|
||||||
let pred_noise = self.forward_diffuser(latent.clone(), timestep, context.clone(), unconditional_context.clone(), unconditional_guidance_scale);
|
let pred_noise = self.forward_diffuser(
|
||||||
|
latent.clone(),
|
||||||
|
timestep,
|
||||||
|
context.clone(),
|
||||||
|
unconditional_context.clone(),
|
||||||
|
unconditional_guidance_scale,
|
||||||
|
);
|
||||||
let predx0 = (latent - pred_noise.clone() * sqrt_noise) / current_alpha.sqrt();
|
let predx0 = (latent - pred_noise.clone() * sqrt_noise) / current_alpha.sqrt();
|
||||||
let dir_latent = pred_noise * (1.0 - prev_alpha - sigma * sigma).sqrt();
|
let dir_latent = pred_noise * (1.0 - prev_alpha - sigma * sigma).sqrt();
|
||||||
|
|
||||||
@@ -124,21 +159,24 @@ impl<B: Backend> StableDiffusion<B> {
|
|||||||
latent
|
latent
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward_diffuser(&self, latent: Tensor<B, 4>, timestep: Tensor<B, 1, Int>, context: Tensor<B, 3>, unconditional_context: Tensor<B, 2>, unconditional_guidance_scale: f64) -> Tensor<B, 4> {
|
fn forward_diffuser(
|
||||||
|
&self,
|
||||||
|
latent: Tensor<B, 4>,
|
||||||
|
timestep: Tensor<B, 1, Int>,
|
||||||
|
context: Tensor<B, 3>,
|
||||||
|
unconditional_context: Tensor<B, 2>,
|
||||||
|
unconditional_guidance_scale: f64,
|
||||||
|
) -> Tensor<B, 4> {
|
||||||
let [n_batch, _, _, _] = latent.dims();
|
let [n_batch, _, _, _] = latent.dims();
|
||||||
//let latent = latent.repeat(0, 2);
|
//let latent = latent.repeat(0, 2);
|
||||||
|
|
||||||
let unconditional_latent = self.diffusion.forward(
|
let unconditional_latent = self.diffusion.forward(
|
||||||
latent.clone(),
|
latent.clone(),
|
||||||
timestep.clone(),
|
timestep.clone(),
|
||||||
unconditional_context.unsqueeze().repeat(0, n_batch)
|
unconditional_context.unsqueeze().repeat(&[0, n_batch]),
|
||||||
);
|
);
|
||||||
|
|
||||||
let conditional_latent = self.diffusion.forward(
|
let conditional_latent = self.diffusion.forward(latent, timestep, context);
|
||||||
latent,
|
|
||||||
timestep,
|
|
||||||
context
|
|
||||||
);
|
|
||||||
|
|
||||||
/*let latent = self.diffusion.forward(
|
/*let latent = self.diffusion.forward(
|
||||||
latent.repeat(0, 2),
|
latent.repeat(0, 2),
|
||||||
@@ -149,43 +187,51 @@ impl<B: Backend> StableDiffusion<B> {
|
|||||||
let unconditional_latent = latent.clone().slice([0..n_batch]);
|
let unconditional_latent = latent.clone().slice([0..n_batch]);
|
||||||
let conditional_latent = latent.slice([n_batch..2 * n_batch]);*/
|
let conditional_latent = latent.slice([n_batch..2 * n_batch]);*/
|
||||||
|
|
||||||
unconditional_latent.clone() + (conditional_latent - unconditional_latent) * unconditional_guidance_scale
|
unconditional_latent.clone()
|
||||||
|
+ (conditional_latent - unconditional_latent) * unconditional_guidance_scale
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn unconditional_context(&self, tokenizer: &SimpleTokenizer) -> Tensor<B, 2> {
|
pub fn unconditional_context(&self, tokenizer: &SimpleTokenizer) -> Tensor<B, 2> {
|
||||||
self.context(tokenizer, "").squeeze(0)
|
self.context(tokenizer, "").squeeze::<2>()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn context(&self, tokenizer: &SimpleTokenizer, text: &str) -> Tensor<B, 3> {
|
pub fn context(&self, tokenizer: &SimpleTokenizer, text: &str) -> Tensor<B, 3> {
|
||||||
let device = &self.devices()[0];
|
let device = &self.clip.devices()[0];
|
||||||
let text = format!("<|startoftext|>{}<|endoftext|>", text);
|
let text = format!("<|startoftext|>{}<|endoftext|>", text);
|
||||||
let tokenized: Vec<_> = tokenizer.encode(&text).into_iter().map(|v| v as i32).collect();
|
let tokenized: Vec<_> = tokenizer
|
||||||
|
.encode(&text)
|
||||||
|
.into_iter()
|
||||||
|
.map(|v| v as i32)
|
||||||
|
.collect();
|
||||||
|
|
||||||
self.clip.forward(Tensor::from_ints(&tokenized[..]).to_device(device).unsqueeze())
|
self.clip.forward(
|
||||||
|
Tensor::<B, 1, Int>::from_ints(&tokenized[..], device)
|
||||||
|
.unsqueeze(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
use crate::helper::to_float;
|
|
||||||
use std::f64::consts::PI;
|
use std::f64::consts::PI;
|
||||||
|
|
||||||
fn cosine_schedule<B: Backend>(n_steps: usize) -> Tensor<B, 1> {
|
fn cosine_schedule<B: Backend>(n_steps: i64, device: &B::Device) -> Tensor<B, 1> {
|
||||||
to_float(Tensor::arange(1..n_steps + 1))
|
Tensor::arange(1..n_steps + 1, device)
|
||||||
|
.float()
|
||||||
.mul_scalar(PI * 0.5 / n_steps as f64)
|
.mul_scalar(PI * 0.5 / n_steps as f64)
|
||||||
.cos()
|
.cos()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn offset_cosine_schedule<B: Backend>(n_steps: usize) -> Tensor<B, 1> {
|
fn offset_cosine_schedule<B: Backend>(n_steps: i64, device: &B::Device) -> Tensor<B, 1> {
|
||||||
let min_signal_rate: f64 = 0.02;
|
let min_signal_rate: f64 = 0.02;
|
||||||
let max_signal_rate: f64 = 0.95;
|
let max_signal_rate: f64 = 0.95;
|
||||||
let start_angle = max_signal_rate.acos();
|
let start_angle = max_signal_rate.acos();
|
||||||
let end_angle = min_signal_rate.acos();
|
let end_angle = min_signal_rate.acos();
|
||||||
|
|
||||||
let times = Tensor::arange(1..n_steps + 1);
|
let times = Tensor::arange(1..n_steps + 1, device).float();
|
||||||
|
|
||||||
let diffusion_angles = to_float(times) * ( (end_angle - start_angle) / n_steps as f64) + start_angle;
|
let diffusion_angles = times * ((end_angle - start_angle) / n_steps as f64) + start_angle;
|
||||||
diffusion_angles.cos()
|
diffusion_angles.cos()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn offset_cosine_schedule_cumprod<B: Backend>(n_steps: usize) -> Tensor<B, 1> {
|
fn offset_cosine_schedule_cumprod<B: Backend>(n_steps: i64, device: &B::Device) -> Tensor<B, 1> {
|
||||||
offset_cosine_schedule::<B>(n_steps).powf(2.0)
|
offset_cosine_schedule::<B>(n_steps, device).powf_scalar(2.0)
|
||||||
}
|
}
|
||||||
@@ -7,16 +7,16 @@ use burn::{
|
|||||||
config::Config,
|
config::Config,
|
||||||
module::{Module, Param},
|
module::{Module, Param},
|
||||||
nn,
|
nn,
|
||||||
tensor::{
|
tensor::{backend::Backend, Tensor},
|
||||||
backend::Backend,
|
|
||||||
Tensor,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::model::groupnorm::load::load_group_norm;
|
use crate::model::groupnorm::load::load_group_norm;
|
||||||
|
|
||||||
pub fn load_res_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResBlock<B>, Box<dyn Error>> {
|
pub fn load_res_block<B: Backend>(
|
||||||
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<ResBlock<B>, Box<dyn Error>> {
|
||||||
let norm_in = load_group_norm::<B>(&format!("{}/{}", path, "norm_in"), device)?;
|
let norm_in = load_group_norm::<B>(&format!("{}/{}", path, "norm_in"), device)?;
|
||||||
let conv_in = load_conv2d::<B>(&format!("{}/{}", path, "conv_in"), device)?;
|
let conv_in = load_conv2d::<B>(&format!("{}/{}", path, "conv_in"), device)?;
|
||||||
let lin_embed = load_linear::<B>(&format!("{}/{}", path, "lin_embed"), device)?;
|
let lin_embed = load_linear::<B>(&format!("{}/{}", path, "lin_embed"), device)?;
|
||||||
@@ -39,7 +39,10 @@ pub fn load_res_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResB
|
|||||||
Ok(res_block)
|
Ok(res_block)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_multi_head_attention<B: Backend>(path: &str, device: &B::Device) -> Result<MultiHeadAttention<B>, Box<dyn Error>> {
|
pub fn load_multi_head_attention<B: Backend>(
|
||||||
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<MultiHeadAttention<B>, Box<dyn Error>> {
|
||||||
let n_head = load_usize::<B>("n_head", path, device)?;
|
let n_head = load_usize::<B>("n_head", path, device)?;
|
||||||
let query = load_linear::<B>(&format!("{}/{}", path, "query"), device)?;
|
let query = load_linear::<B>(&format!("{}/{}", path, "query"), device)?;
|
||||||
let key = load_linear::<B>(&format!("{}/{}", path, "key"), device)?;
|
let key = load_linear::<B>(&format!("{}/{}", path, "key"), device)?;
|
||||||
@@ -57,19 +60,17 @@ pub fn load_multi_head_attention<B: Backend>(path: &str, device: &B::Device) ->
|
|||||||
Ok(multi_head_attention)
|
Ok(multi_head_attention)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
pub fn load_geglu<B: Backend>(path: &str, device: &B::Device) -> Result<GEGLU<B>, Box<dyn Error>> {
|
pub fn load_geglu<B: Backend>(path: &str, device: &B::Device) -> Result<GEGLU<B>, Box<dyn Error>> {
|
||||||
let proj = load_linear::<B>(&format!("{}/{}", path, "proj"), device)?;
|
let proj = load_linear::<B>(&format!("{}/{}", path, "proj"), device)?;
|
||||||
|
|
||||||
let geglue = GEGLU {
|
let geglue = GEGLU {
|
||||||
proj: proj,
|
proj: proj,
|
||||||
gelu: GELU::new(), // Assuming GELU::new() initializes a new GELU struct
|
gelu: Gelu::new(), // Assuming Gelu::new() initializes a new Gelu struct
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(geglue)
|
Ok(geglue)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
pub fn load_mlp<B: Backend>(path: &str, device: &B::Device) -> Result<MLP<B>, Box<dyn Error>> {
|
pub fn load_mlp<B: Backend>(path: &str, device: &B::Device) -> Result<MLP<B>, Box<dyn Error>> {
|
||||||
let geglu = load_geglu::<B>(&format!("{}/{}", path, "geglu"), device)?;
|
let geglu = load_geglu::<B>(&format!("{}/{}", path, "geglu"), device)?;
|
||||||
let lin = load_linear::<B>(&format!("{}/{}", path, "lin"), device)?;
|
let lin = load_linear::<B>(&format!("{}/{}", path, "lin"), device)?;
|
||||||
@@ -82,8 +83,10 @@ pub fn load_mlp<B: Backend>(path: &str, device: &B::Device) -> Result<MLP<B>, Bo
|
|||||||
Ok(mlp)
|
Ok(mlp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn load_transformer_block<B: Backend>(
|
||||||
pub fn load_transformer_block<B: Backend>(path: &str, device: &B::Device) -> Result<TransformerBlock<B>, Box<dyn Error>> {
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<TransformerBlock<B>, Box<dyn Error>> {
|
||||||
let norm1 = load_layer_norm::<B>(&format!("{}/{}", path, "norm1"), device)?;
|
let norm1 = load_layer_norm::<B>(&format!("{}/{}", path, "norm1"), device)?;
|
||||||
let attn1 = load_multi_head_attention::<B>(&format!("{}/{}", path, "attn1"), device)?;
|
let attn1 = load_multi_head_attention::<B>(&format!("{}/{}", path, "attn1"), device)?;
|
||||||
let norm2 = load_layer_norm::<B>(&format!("{}/{}", path, "norm2"), device)?;
|
let norm2 = load_layer_norm::<B>(&format!("{}/{}", path, "norm2"), device)?;
|
||||||
@@ -103,8 +106,10 @@ pub fn load_transformer_block<B: Backend>(path: &str, device: &B::Device) -> Res
|
|||||||
Ok(transformer_block)
|
Ok(transformer_block)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn load_spatial_transformer<B: Backend>(
|
||||||
pub fn load_spatial_transformer<B: Backend>(path: &str, device: &B::Device) -> Result<SpatialTransformer<B>, Box<dyn Error>> {
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<SpatialTransformer<B>, Box<dyn Error>> {
|
||||||
let norm = load_group_norm::<B>(&format!("{}/{}", path, "norm"), device)?;
|
let norm = load_group_norm::<B>(&format!("{}/{}", path, "norm"), device)?;
|
||||||
let proj_in = load_conv2d::<B>(&format!("{}/{}", path, "proj_in"), device)?;
|
let proj_in = load_conv2d::<B>(&format!("{}/{}", path, "proj_in"), device)?;
|
||||||
let transformer = load_transformer_block::<B>(&format!("{}/{}", path, "transformer"), device)?;
|
let transformer = load_transformer_block::<B>(&format!("{}/{}", path, "transformer"), device)?;
|
||||||
@@ -120,24 +125,31 @@ pub fn load_spatial_transformer<B: Backend>(path: &str, device: &B::Device) -> R
|
|||||||
Ok(spatial_transformer)
|
Ok(spatial_transformer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn load_upsample<B: Backend>(
|
||||||
pub fn load_upsample<B: Backend>(path: &str, device: &B::Device) -> Result<Upsample<B>, Box<dyn Error>> {
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<Upsample<B>, Box<dyn Error>> {
|
||||||
let conv = load_conv2d::<B>(&format!("{}/{}", path, "conv"), device)?;
|
let conv = load_conv2d::<B>(&format!("{}/{}", path, "conv"), device)?;
|
||||||
|
|
||||||
let upsample = Upsample {
|
let upsample = Upsample { conv: conv };
|
||||||
conv: conv,
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(upsample)
|
Ok(upsample)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_downsample<B: Backend>(path: &str, device: &B::Device) -> Result<Downsample<B>, Box<dyn Error>> {
|
pub fn load_downsample<B: Backend>(
|
||||||
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<Downsample<B>, Box<dyn Error>> {
|
||||||
load_conv2d(path, device)
|
load_conv2d(path, device)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_res_transformer_res<B: Backend>(path: &str, device: &B::Device) -> Result<ResTransformerRes<B>, Box<dyn Error>> {
|
pub fn load_res_transformer_res<B: Backend>(
|
||||||
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<ResTransformerRes<B>, Box<dyn Error>> {
|
||||||
let res1 = load_res_block::<B>(&format!("{}/{}", path, "res1"), device)?; // Assuming load_res_block function
|
let res1 = load_res_block::<B>(&format!("{}/{}", path, "res1"), device)?; // Assuming load_res_block function
|
||||||
let transformer = load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?;
|
let transformer =
|
||||||
|
load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?;
|
||||||
let res2 = load_res_block::<B>(&format!("{}/{}", path, "res2"), device)?;
|
let res2 = load_res_block::<B>(&format!("{}/{}", path, "res2"), device)?;
|
||||||
|
|
||||||
let res_transformer_res = ResTransformerRes {
|
let res_transformer_res = ResTransformerRes {
|
||||||
@@ -149,9 +161,13 @@ pub fn load_res_transformer_res<B: Backend>(path: &str, device: &B::Device) -> R
|
|||||||
Ok(res_transformer_res)
|
Ok(res_transformer_res)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_res_transformer_upsample<B: Backend>(path: &str, device: &B::Device) -> Result<ResTransformerUpsample<B>, Box<dyn Error>> {
|
pub fn load_res_transformer_upsample<B: Backend>(
|
||||||
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<ResTransformerUpsample<B>, Box<dyn Error>> {
|
||||||
let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?;
|
let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?;
|
||||||
let transformer = load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?;
|
let transformer =
|
||||||
|
load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?;
|
||||||
let upsample = load_upsample::<B>(&format!("{}/{}", path, "upsample"), device)?;
|
let upsample = load_upsample::<B>(&format!("{}/{}", path, "upsample"), device)?;
|
||||||
|
|
||||||
let res_transformer_upsample = ResTransformerUpsample {
|
let res_transformer_upsample = ResTransformerUpsample {
|
||||||
@@ -163,8 +179,10 @@ pub fn load_res_transformer_upsample<B: Backend>(path: &str, device: &B::Device)
|
|||||||
Ok(res_transformer_upsample)
|
Ok(res_transformer_upsample)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn load_res_upsample<B: Backend>(
|
||||||
pub fn load_res_upsample<B: Backend>(path: &str, device: &B::Device) -> Result<ResUpSample<B>, Box<dyn Error>> {
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<ResUpSample<B>, Box<dyn Error>> {
|
||||||
let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?;
|
let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?;
|
||||||
let upsample = load_upsample::<B>(&format!("{}/{}", path, "upsample"), device)?;
|
let upsample = load_upsample::<B>(&format!("{}/{}", path, "upsample"), device)?;
|
||||||
|
|
||||||
@@ -176,10 +194,13 @@ pub fn load_res_upsample<B: Backend>(path: &str, device: &B::Device) -> Result<R
|
|||||||
Ok(res_upsample)
|
Ok(res_upsample)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn load_res_transformer<B: Backend>(
|
||||||
pub fn load_res_transformer<B: Backend>(path: &str, device: &B::Device) -> Result<ResTransformer<B>, Box<dyn Error>> {
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<ResTransformer<B>, Box<dyn Error>> {
|
||||||
let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?;
|
let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?;
|
||||||
let transformer = load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?;
|
let transformer =
|
||||||
|
load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?;
|
||||||
|
|
||||||
let res_transformer = ResTransformer {
|
let res_transformer = ResTransformer {
|
||||||
res: res,
|
res: res,
|
||||||
@@ -189,8 +210,10 @@ pub fn load_res_transformer<B: Backend>(path: &str, device: &B::Device) -> Resul
|
|||||||
Ok(res_transformer)
|
Ok(res_transformer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn load_unet_input_blocks<B: Backend>(
|
||||||
pub fn load_unet_input_blocks<B: Backend>(path: &str, device: &B::Device) -> Result<UNetInputBlocks<B>, Box<dyn Error>> {
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<UNetInputBlocks<B>, Box<dyn Error>> {
|
||||||
let conv = load_conv2d::<B>(&format!("{}/{}", path, "conv"), device)?;
|
let conv = load_conv2d::<B>(&format!("{}/{}", path, "conv"), device)?;
|
||||||
let rt1 = load_res_transformer::<B>(&format!("{}/{}", path, "rt1"), device)?;
|
let rt1 = load_res_transformer::<B>(&format!("{}/{}", path, "rt1"), device)?;
|
||||||
let rt2 = load_res_transformer::<B>(&format!("{}/{}", path, "rt2"), device)?;
|
let rt2 = load_res_transformer::<B>(&format!("{}/{}", path, "rt2"), device)?;
|
||||||
@@ -222,7 +245,10 @@ pub fn load_unet_input_blocks<B: Backend>(path: &str, device: &B::Device) -> Res
|
|||||||
Ok(unet_input_blocks)
|
Ok(unet_input_blocks)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_unet_output_blocks<B: Backend>(path: &str, device: &B::Device) -> Result<UNetOutputBlocks<B>, Box<dyn Error>> {
|
pub fn load_unet_output_blocks<B: Backend>(
|
||||||
|
path: &str,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Result<UNetOutputBlocks<B>, Box<dyn Error>> {
|
||||||
let r1 = load_res_block::<B>(&format!("{}/{}", path, "r1"), device)?;
|
let r1 = load_res_block::<B>(&format!("{}/{}", path, "r1"), device)?;
|
||||||
let r2 = load_res_block::<B>(&format!("{}/{}", path, "r2"), device)?;
|
let r2 = load_res_block::<B>(&format!("{}/{}", path, "r2"), device)?;
|
||||||
let ru = load_res_upsample::<B>(&format!("{}/{}", path, "ru"), device)?;
|
let ru = load_res_upsample::<B>(&format!("{}/{}", path, "ru"), device)?;
|
||||||
@@ -252,14 +278,16 @@ pub fn load_unet_output_blocks<B: Backend>(path: &str, device: &B::Device) -> Re
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
pub fn load_unet<B: Backend>(path: &str, device: &B::Device) -> Result<UNet<B>, Box<dyn Error>> {
|
pub fn load_unet<B: Backend>(path: &str, device: &B::Device) -> Result<UNet<B>, Box<dyn Error>> {
|
||||||
let lin1_time_embed = load_linear::<B>(&format!("{}/{}", path, "lin1_time_embed"), device)?;
|
let lin1_time_embed = load_linear::<B>(&format!("{}/{}", path, "lin1_time_embed"), device)?;
|
||||||
let silu_time_embed = SILU::new(); // Assuming SILU::new() initializes a new SILU struct
|
let silu_time_embed = SILU::new(); // Assuming SILU::new() initializes a new SILU struct
|
||||||
let lin2_time_embed = load_linear::<B>(&format!("{}/{}", path, "lin2_time_embed"), device)?;
|
let lin2_time_embed = load_linear::<B>(&format!("{}/{}", path, "lin2_time_embed"), device)?;
|
||||||
let input_blocks = load_unet_input_blocks::<B>(&format!("{}/{}", path, "input_blocks"), device)?;
|
let input_blocks =
|
||||||
let middle_block = load_res_transformer_res::<B>(&format!("{}/{}", path, "middle_block"), device)?;
|
load_unet_input_blocks::<B>(&format!("{}/{}", path, "input_blocks"), device)?;
|
||||||
let output_blocks = load_unet_output_blocks::<B>(&format!("{}/{}", path, "output_blocks"), device)?;
|
let middle_block =
|
||||||
|
load_res_transformer_res::<B>(&format!("{}/{}", path, "middle_block"), device)?;
|
||||||
|
let output_blocks =
|
||||||
|
load_unet_output_blocks::<B>(&format!("{}/{}", path, "output_blocks"), device)?;
|
||||||
let norm_out = load_group_norm::<B>(&format!("{}/{}", path, "norm_out"), device)?;
|
let norm_out = load_group_norm::<B>(&format!("{}/{}", path, "norm_out"), device)?;
|
||||||
let silu_out = SILU::new(); // Assuming SILU::new() initializes a new SILU struct
|
let silu_out = SILU::new(); // Assuming SILU::new() initializes a new SILU struct
|
||||||
let conv_out = load_conv2d::<B>(&format!("{}/{}", path, "conv_out"), device)?;
|
let conv_out = load_conv2d::<B>(&format!("{}/{}", path, "conv_out"), device)?;
|
||||||
|
|||||||
@@ -3,76 +3,80 @@ pub mod load;
|
|||||||
use burn::{
|
use burn::{
|
||||||
config::Config,
|
config::Config,
|
||||||
module::{Module, Param},
|
module::{Module, Param},
|
||||||
nn::{self, PaddingConfig2d, GELU, conv::{Conv2d, Conv2dConfig}},
|
nn::{
|
||||||
tensor::{
|
self,
|
||||||
backend::Backend,
|
conv::{Conv2d, Conv2dConfig},
|
||||||
activation::softmax,
|
PaddingConfig2d, Gelu,
|
||||||
module::embedding,
|
|
||||||
Tensor,
|
|
||||||
Distribution,
|
|
||||||
Int,
|
|
||||||
},
|
},
|
||||||
|
tensor::{activation::softmax, backend::Backend, module::embedding, Distribution, Int, Tensor},
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::silu::*;
|
|
||||||
use super::groupnorm::*;
|
use super::groupnorm::*;
|
||||||
use crate::helper::to_float;
|
use super::silu::*;
|
||||||
|
|
||||||
use super::attention::qkv_attention;
|
use super::attention::qkv_attention;
|
||||||
|
|
||||||
|
fn timestep_embedding<B: Backend>(
|
||||||
fn timestep_embedding<B: Backend>(timesteps: Tensor<B, 1, Int>, dim: usize, max_period: usize) -> Tensor<B, 2> {
|
timesteps: Tensor<B, 1, Int>,
|
||||||
|
dim: usize,
|
||||||
|
max_period: usize,
|
||||||
|
) -> Tensor<B, 2> {
|
||||||
let half = dim / 2;
|
let half = dim / 2;
|
||||||
let freqs = ( to_float(Tensor::arange_device(0..half, ×teps.device())) * (-(max_period as f64).ln() / half as f64 ) ).exp();
|
let freqs = (Tensor::arange(0..half as i64, ×teps.device()).float()
|
||||||
let args = to_float(timesteps) * freqs;
|
* (-(max_period as f64).ln() / half as f64))
|
||||||
|
.exp();
|
||||||
|
let args = timesteps.float() * freqs;
|
||||||
Tensor::cat(vec![args.clone().cos(), args.sin()], 0).unsqueeze()
|
Tensor::cat(vec![args.clone().cos(), args.sin()], 0).unsqueeze()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Config, Debug)]
|
||||||
#[derive(Config)]
|
|
||||||
pub struct UNetConfig {}
|
pub struct UNetConfig {}
|
||||||
|
|
||||||
impl UNetConfig {
|
impl UNetConfig {
|
||||||
pub fn init<B: Backend>(&self) -> UNet<B> {
|
pub fn init<B: Backend>(&self, device: &B::Device) -> UNet<B> {
|
||||||
let lin1_time_embed = nn::LinearConfig::new(320, 1280).init();
|
let lin1_time_embed = nn::LinearConfig::new(320, 1280).init(device);
|
||||||
let silu_time_embed = SILU::new();
|
let silu_time_embed = SILU::new();
|
||||||
let lin2_time_embed = nn::LinearConfig::new(1280, 1280).init();
|
let lin2_time_embed = nn::LinearConfig::new(1280, 1280).init(device);
|
||||||
|
|
||||||
let input_blocks = UNetInputBlocks {
|
let input_blocks = UNetInputBlocks {
|
||||||
conv: Conv2dConfig::new([4, 320], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(),
|
conv: Conv2dConfig::new([4, 320], [3, 3])
|
||||||
rt1: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(),
|
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||||
rt2: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(),
|
.init(device),
|
||||||
d1: DownsampleConfig::new(320).init(),
|
rt1: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(device),
|
||||||
rt3: ResTransformerConfig::new(320, 1280, 640, 768, 8).init(),
|
rt2: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(device),
|
||||||
rt4: ResTransformerConfig::new(640, 1280, 640, 768, 8).init(),
|
d1: DownsampleConfig::new(320).init(device),
|
||||||
d2: DownsampleConfig::new(640).init(),
|
rt3: ResTransformerConfig::new(320, 1280, 640, 768, 8).init(device),
|
||||||
rt5: ResTransformerConfig::new(640, 1280, 1280, 768, 8).init(),
|
rt4: ResTransformerConfig::new(640, 1280, 640, 768, 8).init(device),
|
||||||
rt6: ResTransformerConfig::new(1280, 1280, 1280, 768, 8).init(),
|
d2: DownsampleConfig::new(640).init(device),
|
||||||
d3: DownsampleConfig::new(1280).init(),
|
rt5: ResTransformerConfig::new(640, 1280, 1280, 768, 8).init(device),
|
||||||
r1: ResBlockConfig::new(1280, 1280, 1280).init(),
|
rt6: ResTransformerConfig::new(1280, 1280, 1280, 768, 8).init(device),
|
||||||
r2: ResBlockConfig::new(1280, 1280, 1280).init(),
|
d3: DownsampleConfig::new(1280).init(device),
|
||||||
|
r1: ResBlockConfig::new(1280, 1280, 1280).init(device),
|
||||||
|
r2: ResBlockConfig::new(1280, 1280, 1280).init(device),
|
||||||
};
|
};
|
||||||
|
|
||||||
let middle_block = ResTransformerResConfig::new(1280, 1280, 1280, 768, 8).init();
|
let middle_block = ResTransformerResConfig::new(1280, 1280, 1280, 768, 8).init(device);
|
||||||
|
|
||||||
let output_blocks = UNetOutputBlocks {
|
let output_blocks = UNetOutputBlocks {
|
||||||
r1: ResBlockConfig::new(2560, 1280, 1280).init(),
|
r1: ResBlockConfig::new(2560, 1280, 1280).init(device),
|
||||||
r2: ResBlockConfig::new(2560, 1280, 1280).init(),
|
r2: ResBlockConfig::new(2560, 1280, 1280).init(device),
|
||||||
ru: ResUpSampleConfig::new(2560, 1280, 1280).init(),
|
ru: ResUpSampleConfig::new(2560, 1280, 1280).init(device),
|
||||||
rt1: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(),
|
rt1: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(device),
|
||||||
rt2: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(),
|
rt2: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(device),
|
||||||
rtu1: ResTransformerUpsampleConfig::new(1920, 1280, 1280, 768, 8).init(),
|
rtu1: ResTransformerUpsampleConfig::new(1920, 1280, 1280, 768, 8).init(device),
|
||||||
rt3: ResTransformerConfig::new(1920, 1280, 640, 768, 8).init(),
|
rt3: ResTransformerConfig::new(1920, 1280, 640, 768, 8).init(device),
|
||||||
rt4: ResTransformerConfig::new(1280, 1280, 640, 768, 8).init(),
|
rt4: ResTransformerConfig::new(1280, 1280, 640, 768, 8).init(device),
|
||||||
rtu2: ResTransformerUpsampleConfig::new(960, 1280, 640, 768, 8).init(),
|
rtu2: ResTransformerUpsampleConfig::new(960, 1280, 640, 768, 8).init(device),
|
||||||
rt5: ResTransformerConfig::new(960, 1280, 320, 768, 8).init(),
|
rt5: ResTransformerConfig::new(960, 1280, 320, 768, 8).init(device),
|
||||||
rt6: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(),
|
rt6: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(device),
|
||||||
rt7: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(),
|
rt7: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(device),
|
||||||
};
|
};
|
||||||
|
|
||||||
let norm_out = GroupNormConfig::new(32, 320).init();
|
let norm_out = GroupNormConfig::new(32, 320).init(device);
|
||||||
let silu_out = SILU::new();
|
let silu_out = SILU::new();
|
||||||
let conv_out = Conv2dConfig::new([320, 4], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
let conv_out = Conv2dConfig::new([320, 4], [3, 3])
|
||||||
|
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||||
|
.init(device);
|
||||||
|
|
||||||
UNet {
|
UNet {
|
||||||
lin1_time_embed,
|
lin1_time_embed,
|
||||||
@@ -102,7 +106,12 @@ pub struct UNet<B: Backend> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend> UNet<B> {
|
impl<B: Backend> UNet<B> {
|
||||||
pub fn forward(&self, x: Tensor<B, 4>, timesteps: Tensor<B, 1, Int>, context: Tensor<B, 3>) -> Tensor<B, 4> {
|
pub fn forward(
|
||||||
|
&self,
|
||||||
|
x: Tensor<B, 4>,
|
||||||
|
timesteps: Tensor<B, 1, Int>,
|
||||||
|
context: Tensor<B, 3>,
|
||||||
|
) -> Tensor<B, 4> {
|
||||||
let t_emb = timestep_embedding(timesteps, 320, 10000);
|
let t_emb = timestep_embedding(timesteps, 320, 10000);
|
||||||
let emb = self.lin1_time_embed.forward(t_emb);
|
let emb = self.lin1_time_embed.forward(t_emb);
|
||||||
let emb = self.silu_time_embed.forward(emb);
|
let emb = self.silu_time_embed.forward(emb);
|
||||||
@@ -133,8 +142,6 @@ impl<B: Backend> UNet<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#[derive(Module, Debug)]
|
#[derive(Module, Debug)]
|
||||||
pub struct UNetInputBlocks<B: Backend> {
|
pub struct UNetInputBlocks<B: Backend> {
|
||||||
conv: Conv2d<B>,
|
conv: Conv2d<B>,
|
||||||
@@ -154,18 +161,8 @@ pub struct UNetInputBlocks<B: Backend> {
|
|||||||
impl<B: Backend> UNetInputBlocks<B> {
|
impl<B: Backend> UNetInputBlocks<B> {
|
||||||
fn as_array(&self) -> [&dyn UNetBlock<B>; 12] {
|
fn as_array(&self) -> [&dyn UNetBlock<B>; 12] {
|
||||||
[
|
[
|
||||||
&self.conv,
|
&self.conv, &self.rt1, &self.rt2, &self.d1, &self.rt3, &self.rt4, &self.d2, &self.rt5,
|
||||||
&self.rt1,
|
&self.rt6, &self.d3, &self.r1, &self.r2,
|
||||||
&self.rt2,
|
|
||||||
&self.d1,
|
|
||||||
&self.rt3,
|
|
||||||
&self.rt4,
|
|
||||||
&self.d2,
|
|
||||||
&self.rt5,
|
|
||||||
&self.rt6,
|
|
||||||
&self.d3,
|
|
||||||
&self.r1,
|
|
||||||
&self.r2,
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -189,31 +186,17 @@ pub struct UNetOutputBlocks<B: Backend> {
|
|||||||
impl<B: Backend> UNetOutputBlocks<B> {
|
impl<B: Backend> UNetOutputBlocks<B> {
|
||||||
fn as_array(&self) -> [&dyn UNetBlock<B>; 12] {
|
fn as_array(&self) -> [&dyn UNetBlock<B>; 12] {
|
||||||
[
|
[
|
||||||
&self.r1,
|
&self.r1, &self.r2, &self.ru, &self.rt1, &self.rt2, &self.rtu1, &self.rt3, &self.rt4,
|
||||||
&self.r2,
|
&self.rtu2, &self.rt5, &self.rt6, &self.rt7,
|
||||||
&self.ru,
|
|
||||||
&self.rt1,
|
|
||||||
&self.rt2,
|
|
||||||
&self.rtu1,
|
|
||||||
&self.rt3,
|
|
||||||
&self.rt4,
|
|
||||||
&self.rtu2,
|
|
||||||
&self.rt5,
|
|
||||||
&self.rt6,
|
|
||||||
&self.rt7,
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
trait UNetBlock<B: Backend> {
|
trait UNetBlock<B: Backend> {
|
||||||
fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4>;
|
fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4>;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct ResTransformerConfig {
|
pub struct ResTransformerConfig {
|
||||||
n_channels_in: usize,
|
n_channels_in: usize,
|
||||||
n_channels_embed: usize,
|
n_channels_embed: usize,
|
||||||
@@ -223,14 +206,18 @@ pub struct ResTransformerConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl ResTransformerConfig {
|
impl ResTransformerConfig {
|
||||||
fn init<B: Backend>(&self) -> ResTransformer<B> {
|
fn init<B: Backend>(&self, device: &B::Device) -> ResTransformer<B> {
|
||||||
let res = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init();
|
let res = ResBlockConfig::new(
|
||||||
let transformer = SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head).init();
|
self.n_channels_in,
|
||||||
|
self.n_channels_embed,
|
||||||
|
self.n_channels_out,
|
||||||
|
)
|
||||||
|
.init(device);
|
||||||
|
let transformer =
|
||||||
|
SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head)
|
||||||
|
.init(device);
|
||||||
|
|
||||||
ResTransformer {
|
ResTransformer { res, transformer }
|
||||||
res,
|
|
||||||
transformer,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -248,7 +235,7 @@ impl<B: Backend> UNetBlock<B> for ResTransformer<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct ResUpSampleConfig {
|
pub struct ResUpSampleConfig {
|
||||||
n_channels_in: usize,
|
n_channels_in: usize,
|
||||||
n_channels_embed: usize,
|
n_channels_embed: usize,
|
||||||
@@ -256,14 +243,16 @@ pub struct ResUpSampleConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl ResUpSampleConfig {
|
impl ResUpSampleConfig {
|
||||||
fn init<B: Backend>(&self) -> ResUpSample<B> {
|
fn init<B: Backend>(&self, device: &B::Device) -> ResUpSample<B> {
|
||||||
let res = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init();
|
let res = ResBlockConfig::new(
|
||||||
let upsample = UpsampleConfig::new(self.n_channels_out).init();
|
self.n_channels_in,
|
||||||
|
self.n_channels_embed,
|
||||||
|
self.n_channels_out,
|
||||||
|
)
|
||||||
|
.init(device);
|
||||||
|
let upsample = UpsampleConfig::new(self.n_channels_out).init(device);
|
||||||
|
|
||||||
ResUpSample {
|
ResUpSample { res, upsample }
|
||||||
res,
|
|
||||||
upsample,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -281,7 +270,7 @@ impl<B: Backend> UNetBlock<B> for ResUpSample<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct ResTransformerUpsampleConfig {
|
pub struct ResTransformerUpsampleConfig {
|
||||||
n_channels_in: usize,
|
n_channels_in: usize,
|
||||||
n_channels_embed: usize,
|
n_channels_embed: usize,
|
||||||
@@ -291,10 +280,17 @@ pub struct ResTransformerUpsampleConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl ResTransformerUpsampleConfig {
|
impl ResTransformerUpsampleConfig {
|
||||||
fn init<B: Backend>(&self) -> ResTransformerUpsample<B> {
|
fn init<B: Backend>(&self, device: &B::Device) -> ResTransformerUpsample<B> {
|
||||||
let res = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init();
|
let res = ResBlockConfig::new(
|
||||||
let transformer = SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head).init();
|
self.n_channels_in,
|
||||||
let upsample = UpsampleConfig::new(self.n_channels_out).init();
|
self.n_channels_embed,
|
||||||
|
self.n_channels_out,
|
||||||
|
)
|
||||||
|
.init(device);
|
||||||
|
let transformer =
|
||||||
|
SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head)
|
||||||
|
.init(device);
|
||||||
|
let upsample = UpsampleConfig::new(self.n_channels_out).init(device);
|
||||||
|
|
||||||
ResTransformerUpsample {
|
ResTransformerUpsample {
|
||||||
res,
|
res,
|
||||||
@@ -320,7 +316,7 @@ impl<B: Backend> UNetBlock<B> for ResTransformerUpsample<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct ResTransformerResConfig {
|
pub struct ResTransformerResConfig {
|
||||||
n_channels_in: usize,
|
n_channels_in: usize,
|
||||||
n_channels_embed: usize,
|
n_channels_embed: usize,
|
||||||
@@ -330,10 +326,22 @@ pub struct ResTransformerResConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl ResTransformerResConfig {
|
impl ResTransformerResConfig {
|
||||||
fn init<B: Backend>(&self) -> ResTransformerRes<B> {
|
fn init<B: Backend>(&self, device: &B::Device) -> ResTransformerRes<B> {
|
||||||
let res1 = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init();
|
let res1 = ResBlockConfig::new(
|
||||||
let transformer = SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head).init();
|
self.n_channels_in,
|
||||||
let res2 = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init();
|
self.n_channels_embed,
|
||||||
|
self.n_channels_out,
|
||||||
|
)
|
||||||
|
.init(device);
|
||||||
|
let transformer =
|
||||||
|
SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head)
|
||||||
|
.init(device);
|
||||||
|
let res2 = ResBlockConfig::new(
|
||||||
|
self.n_channels_in,
|
||||||
|
self.n_channels_embed,
|
||||||
|
self.n_channels_out,
|
||||||
|
)
|
||||||
|
.init(device);
|
||||||
|
|
||||||
ResTransformerRes {
|
ResTransformerRes {
|
||||||
res1,
|
res1,
|
||||||
@@ -359,22 +367,18 @@ impl<B: Backend> UNetBlock<B> for ResTransformerRes<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Config, Debug)]
|
||||||
|
|
||||||
#[derive(Config)]
|
|
||||||
pub struct UpsampleConfig {
|
pub struct UpsampleConfig {
|
||||||
n_channels: usize,
|
n_channels: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UpsampleConfig {
|
impl UpsampleConfig {
|
||||||
fn init<B: Backend>(&self) -> Upsample<B> {
|
fn init<B: Backend>(&self, device: &B::Device) -> Upsample<B> {
|
||||||
let conv = Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3])
|
let conv = Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3])
|
||||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||||
.init();
|
.init(device);
|
||||||
|
|
||||||
Upsample {
|
Upsample { conv }
|
||||||
conv,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -388,8 +392,7 @@ impl<B: Backend> Upsample<B> {
|
|||||||
let [n_batch, n_channel, height, width] = x.dims();
|
let [n_batch, n_channel, height, width] = x.dims();
|
||||||
let x = x
|
let x = x
|
||||||
.reshape([n_batch, n_channel, height, 1, width, 1])
|
.reshape([n_batch, n_channel, height, 1, width, 1])
|
||||||
.repeat(3, 2)
|
.repeat(&[1, 1, 1, 2, 1, 2])
|
||||||
.repeat(5, 2)
|
|
||||||
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
|
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
|
||||||
self.conv.forward(x)
|
self.conv.forward(x)
|
||||||
}
|
}
|
||||||
@@ -401,17 +404,17 @@ impl<B: Backend> UNetBlock<B> for Upsample<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct DownsampleConfig {
|
pub struct DownsampleConfig {
|
||||||
n_channels: usize,
|
n_channels: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DownsampleConfig {
|
impl DownsampleConfig {
|
||||||
fn init<B: Backend>(&self) -> Conv2d<B> {
|
fn init<B: Backend>(&self, device: &B::Device) -> Conv2d<B> {
|
||||||
Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3])
|
Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3])
|
||||||
.with_stride([2, 2])
|
.with_stride([2, 2])
|
||||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||||
.init()
|
.init(device)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -423,10 +426,7 @@ impl<B: Backend> UNetBlock<B> for Conv2d<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Config, Debug)]
|
||||||
|
|
||||||
|
|
||||||
#[derive(Config)]
|
|
||||||
pub struct SpatialTransformerConfig {
|
pub struct SpatialTransformerConfig {
|
||||||
n_channels: usize,
|
n_channels: usize,
|
||||||
n_context_state: usize,
|
n_context_state: usize,
|
||||||
@@ -434,11 +434,12 @@ pub struct SpatialTransformerConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl SpatialTransformerConfig {
|
impl SpatialTransformerConfig {
|
||||||
fn init<B: Backend>(&self) -> SpatialTransformer<B> {
|
fn init<B: Backend>(&self, device: &B::Device) -> SpatialTransformer<B> {
|
||||||
let norm = GroupNormConfig::new(32, self.n_channels).init();
|
let norm = GroupNormConfig::new(32, self.n_channels).init(device);
|
||||||
let proj_in = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init();
|
let proj_in = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(device);
|
||||||
let transformer = TransformerBlockConfig::new(self.n_channels, self.n_context_state, self.n_head).init();
|
let transformer =
|
||||||
let proj_out = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init();
|
TransformerBlockConfig::new(self.n_channels, self.n_context_state, self.n_head).init(device);
|
||||||
|
let proj_out = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(device);
|
||||||
|
|
||||||
SpatialTransformer {
|
SpatialTransformer {
|
||||||
norm,
|
norm,
|
||||||
@@ -465,9 +466,13 @@ impl<B: Backend> SpatialTransformer<B> {
|
|||||||
|
|
||||||
let x = self.norm.forward(x);
|
let x = self.norm.forward(x);
|
||||||
let x = self.proj_in.forward(x);
|
let x = self.proj_in.forward(x);
|
||||||
let x = x.reshape([n_batch, n_channel, height * width]).swap_dims(1, 2);
|
let x = x
|
||||||
|
.reshape([n_batch, n_channel, height * width])
|
||||||
|
.swap_dims(1, 2);
|
||||||
|
|
||||||
let x = self.transformer.forward(x, context)
|
let x = self
|
||||||
|
.transformer
|
||||||
|
.forward(x, context)
|
||||||
.swap_dims(1, 2)
|
.swap_dims(1, 2)
|
||||||
.reshape([n_batch, n_channel, height, width]);
|
.reshape([n_batch, n_channel, height, width]);
|
||||||
|
|
||||||
@@ -475,14 +480,7 @@ impl<B: Backend> SpatialTransformer<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Config, Debug)]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#[derive(Config)]
|
|
||||||
pub struct TransformerBlockConfig {
|
pub struct TransformerBlockConfig {
|
||||||
n_state: usize,
|
n_state: usize,
|
||||||
n_context_state: usize,
|
n_context_state: usize,
|
||||||
@@ -490,13 +488,14 @@ pub struct TransformerBlockConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl TransformerBlockConfig {
|
impl TransformerBlockConfig {
|
||||||
fn init<B: Backend>(&self) -> TransformerBlock<B> {
|
fn init<B: Backend>(&self, device: &B::Device) -> TransformerBlock<B> {
|
||||||
let norm1 = nn::LayerNormConfig::new(self.n_state).init();
|
let norm1 = nn::LayerNormConfig::new(self.n_state).init(device);
|
||||||
let attn1 = MultiHeadAttentionConfig::new(self.n_state, self.n_state, self.n_head).init();
|
let attn1 = MultiHeadAttentionConfig::new(self.n_state, self.n_state, self.n_head).init(device);
|
||||||
let norm2 = nn::LayerNormConfig::new(self.n_state).init();
|
let norm2 = nn::LayerNormConfig::new(self.n_state).init(device);
|
||||||
let attn2 = MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init();
|
let attn2 =
|
||||||
let norm3 = nn::LayerNormConfig::new(self.n_state).init();
|
MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init(device);
|
||||||
let mlp = MLPConfig::new(self.n_state, 4).init();
|
let norm3 = nn::LayerNormConfig::new(self.n_state).init(device);
|
||||||
|
let mlp = MLPConfig::new(self.n_state, 4).init(device);
|
||||||
|
|
||||||
TransformerBlock {
|
TransformerBlock {
|
||||||
norm1,
|
norm1,
|
||||||
@@ -527,23 +526,19 @@ impl<B: Backend> TransformerBlock<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Config, Debug)]
|
||||||
#[derive(Config)]
|
|
||||||
pub struct MLPConfig {
|
pub struct MLPConfig {
|
||||||
n_state: usize,
|
n_state: usize,
|
||||||
mult: usize,
|
mult: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MLPConfig {
|
impl MLPConfig {
|
||||||
pub fn init<B: Backend>(&self) -> MLP<B> {
|
pub fn init<B: Backend>(&self, device: &B::Device) -> MLP<B> {
|
||||||
let n_state_hidden = self.n_state * self.mult;
|
let n_state_hidden = self.n_state * self.mult;
|
||||||
let geglu = GEGLUConfig::new(self.n_state, n_state_hidden).init();
|
let geglu = GEGLUConfig::new(self.n_state, n_state_hidden).init(device);
|
||||||
let lin = nn::LinearConfig::new(n_state_hidden, self.n_state).init();
|
let lin = nn::LinearConfig::new(n_state_hidden, self.n_state).init(device);
|
||||||
|
|
||||||
MLP {
|
MLP { geglu, lin }
|
||||||
geglu,
|
|
||||||
lin,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -559,29 +554,25 @@ impl<B: Backend> MLP<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Config, Debug)]
|
||||||
#[derive(Config)]
|
|
||||||
pub struct GEGLUConfig {
|
pub struct GEGLUConfig {
|
||||||
n_state_in: usize,
|
n_state_in: usize,
|
||||||
n_state_out: usize,
|
n_state_out: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GEGLUConfig {
|
impl GEGLUConfig {
|
||||||
fn init<B: Backend>(&self) -> GEGLU<B> {
|
fn init<B: Backend>(&self, device: &B::Device) -> GEGLU<B> {
|
||||||
let proj = nn::LinearConfig::new(self.n_state_in, 2 * self.n_state_out).init();
|
let proj = nn::LinearConfig::new(self.n_state_in, 2 * self.n_state_out).init(device);
|
||||||
let gelu = GELU::new();
|
let gelu = Gelu::new();
|
||||||
|
|
||||||
GEGLU {
|
GEGLU { proj, gelu }
|
||||||
proj,
|
|
||||||
gelu,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Module, Debug)]
|
#[derive(Module, Debug)]
|
||||||
pub struct GEGLU<B: Backend> {
|
pub struct GEGLU<B: Backend> {
|
||||||
proj: nn::Linear<B>,
|
proj: nn::Linear<B>,
|
||||||
gelu: GELU,
|
gelu: Gelu,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend> GEGLU<B> {
|
impl<B: Backend> GEGLU<B> {
|
||||||
@@ -591,18 +582,16 @@ impl<B: Backend> GEGLU<B> {
|
|||||||
|
|
||||||
let n_state_out = n_state / 2;
|
let n_state_out = n_state / 2;
|
||||||
|
|
||||||
let x = projected.clone().slice([0..n_batch, 0..n_ctx, 0..n_state_out]);
|
let x = projected
|
||||||
|
.clone()
|
||||||
|
.slice([0..n_batch, 0..n_ctx, 0..n_state_out]);
|
||||||
let gate = projected.slice([0..n_batch, 0..n_ctx, n_state_out..n_state]);
|
let gate = projected.slice([0..n_batch, 0..n_ctx, n_state_out..n_state]);
|
||||||
|
|
||||||
x * self.gelu.forward(gate)
|
x * self.gelu.forward(gate)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Config, Debug)]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#[derive(Config)]
|
|
||||||
pub struct MultiHeadAttentionConfig {
|
pub struct MultiHeadAttentionConfig {
|
||||||
n_state: usize,
|
n_state: usize,
|
||||||
n_context_state: usize,
|
n_context_state: usize,
|
||||||
@@ -610,21 +599,32 @@ pub struct MultiHeadAttentionConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl MultiHeadAttentionConfig {
|
impl MultiHeadAttentionConfig {
|
||||||
fn init<B: Backend>(&self) -> MultiHeadAttention<B> {
|
fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadAttention<B> {
|
||||||
assert!(self.n_state % self.n_head == 0, "State size {} must be a multiple of head size {}", self.n_state, self.n_head);
|
assert!(
|
||||||
|
self.n_state % self.n_head == 0,
|
||||||
|
"State size {} must be a multiple of head size {}",
|
||||||
|
self.n_state,
|
||||||
|
self.n_head
|
||||||
|
);
|
||||||
|
|
||||||
let n_head = self.n_head;
|
let n_head = self.n_head;
|
||||||
let query = nn::LinearConfig::new(self.n_state, self.n_state).with_bias(false).init();
|
let query = nn::LinearConfig::new(self.n_state, self.n_state)
|
||||||
let key = nn::LinearConfig::new(self.n_context_state, self.n_state).with_bias(false).init();
|
.with_bias(false)
|
||||||
let value = nn::LinearConfig::new(self.n_context_state, self.n_state).with_bias(false).init();
|
.init(device);
|
||||||
let out = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
let key = nn::LinearConfig::new(self.n_context_state, self.n_state)
|
||||||
|
.with_bias(false)
|
||||||
|
.init(device);
|
||||||
|
let value = nn::LinearConfig::new(self.n_context_state, self.n_state)
|
||||||
|
.with_bias(false)
|
||||||
|
.init(device);
|
||||||
|
let out = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
|
||||||
|
|
||||||
MultiHeadAttention {
|
MultiHeadAttention {
|
||||||
n_head,
|
n_head,
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
out
|
out,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -652,44 +652,32 @@ impl<B: Backend> MultiHeadAttention<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Config, Debug)]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#[derive(Config)]
|
|
||||||
pub struct ResBlockConfig {
|
pub struct ResBlockConfig {
|
||||||
n_channels_in: usize,
|
n_channels_in: usize,
|
||||||
n_channels_embed: usize,
|
n_channels_embed: usize,
|
||||||
n_channels_out: usize,
|
n_channels_out: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
impl ResBlockConfig {
|
impl ResBlockConfig {
|
||||||
fn init<B: Backend>(&self) -> ResBlock<B> {
|
fn init<B: Backend>(&self, device: &B::Device) -> ResBlock<B> {
|
||||||
let norm_in = GroupNormConfig::new(32, self.n_channels_in).init();
|
let norm_in = GroupNormConfig::new(32, self.n_channels_in).init(device);
|
||||||
let silu_in = SILU::new();
|
let silu_in = SILU::new();
|
||||||
let conv_in = Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
let conv_in = Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [3, 3])
|
||||||
|
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||||
|
.init(device);
|
||||||
|
|
||||||
let silu_embed = SILU::new();
|
let silu_embed = SILU::new();
|
||||||
let lin_embed = nn::LinearConfig::new(self.n_channels_embed, self.n_channels_out).init();
|
let lin_embed = nn::LinearConfig::new(self.n_channels_embed, self.n_channels_out).init(device);
|
||||||
|
|
||||||
let norm_out = GroupNormConfig::new(32, self.n_channels_out).init();
|
let norm_out = GroupNormConfig::new(32, self.n_channels_out).init(device);
|
||||||
let silu_out = SILU::new();
|
let silu_out = SILU::new();
|
||||||
let conv_out = Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
let conv_out = Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3])
|
||||||
|
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||||
|
.init(device);
|
||||||
|
|
||||||
let skip_connection = if self.n_channels_in != self.n_channels_out {
|
let skip_connection = if self.n_channels_in != self.n_channels_out {
|
||||||
Some( Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [1, 1]).init() )
|
Some(Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [1, 1]).init(device))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
@@ -708,7 +696,6 @@ impl ResBlockConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[derive(Module, Debug)]
|
#[derive(Module, Debug)]
|
||||||
pub struct ResBlock<B: Backend> {
|
pub struct ResBlock<B: Backend> {
|
||||||
norm_in: GroupNorm<B>,
|
norm_in: GroupNorm<B>,
|
||||||
@@ -751,5 +738,3 @@ impl<B: Backend> UNetBlock<B> for ResBlock<B> {
|
|||||||
self.forward(x, emb)
|
self.forward(x, emb)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::{self, BufRead};
|
use std::io::{self, BufRead};
|
||||||
|
|
||||||
fn bytes_to_unicode() -> Vec<(u8, char)> {
|
fn bytes_to_unicode() -> Vec<(u8, char)> {
|
||||||
let mut bs: Vec<u8> = ('!' as u8 ..= '~' as u8).into_iter()
|
let mut bs: Vec<u8> = ('!' as u8..='~' as u8)
|
||||||
|
.into_iter()
|
||||||
.chain(('¡' as u8..='¬' as u8).into_iter())
|
.chain(('¡' as u8..='¬' as u8).into_iter())
|
||||||
.chain(('®' as u8..='ÿ' as u8).into_iter())
|
.chain(('®' as u8..='ÿ' as u8).into_iter())
|
||||||
.collect();
|
.collect();
|
||||||
@@ -22,19 +23,15 @@ fn bytes_to_unicode() -> Vec<(u8, char)> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bs.into_iter()
|
bs.into_iter()
|
||||||
.zip(
|
.zip(cs.into_iter().map(|c| c.into()))
|
||||||
cs.into_iter()
|
.collect()
|
||||||
.map(|c| c.into())
|
|
||||||
).collect()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_pairs(word: &[String]) -> Vec<(String, String)> {
|
fn get_pairs(word: &[String]) -> Vec<(String, String)> {
|
||||||
let prev = word.into_iter().cloned();
|
let prev = word.into_iter().cloned();
|
||||||
let next = prev.clone().skip(1);
|
let next = prev.clone().skip(1);
|
||||||
|
|
||||||
prev
|
prev.zip(next).collect()
|
||||||
.zip(next)
|
|
||||||
.collect()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn whitespace_clean(text: &str) -> String {
|
fn whitespace_clean(text: &str) -> String {
|
||||||
@@ -59,7 +56,10 @@ fn load_merges(path: &str) -> io::Result<Vec<(String, String)>> {
|
|||||||
Ok(merges)
|
Ok(merges)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn construct_vocab(chars: impl Iterator<Item=char> + Clone, merges: &[(String, String)]) -> Vec<String> {
|
fn construct_vocab(
|
||||||
|
chars: impl Iterator<Item = char> + Clone,
|
||||||
|
merges: &[(String, String)],
|
||||||
|
) -> Vec<String> {
|
||||||
let iter = chars.map(String::from);
|
let iter = chars.map(String::from);
|
||||||
let mut vocab: Vec<_> = iter.clone().chain(iter.map(|c| c + "</w>")).collect();
|
let mut vocab: Vec<_> = iter.clone().chain(iter.map(|c| c + "</w>")).collect();
|
||||||
|
|
||||||
@@ -129,7 +129,8 @@ impl SimpleTokenizer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let bigram = pairs.iter()
|
let bigram = pairs
|
||||||
|
.iter()
|
||||||
.filter(|pair| self.bpe_ranks.contains_key(pair))
|
.filter(|pair| self.bpe_ranks.contains_key(pair))
|
||||||
.min_by_key(|&pair| self.bpe_ranks[pair]);
|
.min_by_key(|&pair| self.bpe_ranks[pair]);
|
||||||
|
|
||||||
@@ -178,8 +179,16 @@ impl SimpleTokenizer {
|
|||||||
|
|
||||||
for m in self.pat.find_iter(&cleaned_text) {
|
for m in self.pat.find_iter(&cleaned_text) {
|
||||||
let token = m.as_str();
|
let token = m.as_str();
|
||||||
let token: String = token.as_bytes().into_iter().map(|b| self.byte_encoder[b]).collect();
|
let token: String = token
|
||||||
bpe_tokens.extend(self.bpe(&token).split(' ').map(|bpe_token| self.encoder[bpe_token]))
|
.as_bytes()
|
||||||
|
.into_iter()
|
||||||
|
.map(|b| self.byte_encoder[b])
|
||||||
|
.collect();
|
||||||
|
bpe_tokens.extend(
|
||||||
|
self.bpe(&token)
|
||||||
|
.split(' ')
|
||||||
|
.map(|bpe_token| self.encoder[bpe_token]),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return bpe_tokens;
|
return bpe_tokens;
|
||||||
@@ -187,9 +196,7 @@ impl SimpleTokenizer {
|
|||||||
|
|
||||||
pub fn decode(&self, tokens: &[u32]) -> String {
|
pub fn decode(&self, tokens: &[u32]) -> String {
|
||||||
let text: String = tokens.iter().map(|t| self.decoder[t].as_str()).collect();
|
let text: String = tokens.iter().map(|t| self.decoder[t].as_str()).collect();
|
||||||
let decoded_bytes: Vec<u8> = text.chars()
|
let decoded_bytes: Vec<u8> = text.chars().map(|c| self.byte_decoder[&c]).collect();
|
||||||
.map(|c| self.byte_decoder[&c])
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
String::from_utf8_lossy(&decoded_bytes[..]).replace("</w>", " ")
|
String::from_utf8_lossy(&decoded_bytes[..]).replace("</w>", " ")
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user