Compare commits
10 Commits
c24d37df00
...
test
| Author | SHA1 | Date | |
|---|---|---|---|
| 754810ca88 | |||
|
|
6cfd6db5a5 | ||
|
|
893fb0950d | ||
|
|
9e4d7bd310 | ||
|
|
01b1aea897 | ||
|
|
f4c58c1790 | ||
|
|
a62795347f | ||
|
|
1830756917 | ||
|
|
b87273c2be | ||
|
|
31c24a82ef |
20
Cargo.toml
20
Cargo.toml
@@ -6,27 +6,17 @@ edition = "2021"
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[features]
|
||||
default = ["torch-backend"]
|
||||
torch-backend = ["burn-tch"]
|
||||
wgpu-backend = ["burn-wgpu"]
|
||||
|
||||
[dependencies.burn-tch]
|
||||
package = "burn-tch"
|
||||
git = "https://github.com/burn-rs/burn.git"
|
||||
optional = true
|
||||
|
||||
[dependencies.burn-wgpu]
|
||||
package = "burn-wgpu"
|
||||
git = "https://github.com/burn-rs/burn.git"
|
||||
optional = true
|
||||
default = ["wgpu-backend"]
|
||||
|
||||
[dependencies]
|
||||
burn = { git = "https://github.com/burn-rs/burn.git" }
|
||||
burn = "0.20.1"
|
||||
burn-autodiff = "0.20.1"
|
||||
burn-wgpu = { version = "0.20.1", optional = true }
|
||||
serde = {version = "1.0.171", features = ["std", "derive"]}
|
||||
npy = "0.4.0"
|
||||
num-traits = "0.2.15"
|
||||
rust_tokenizers = "8.1.0"
|
||||
regex = "1.9.1"
|
||||
image = "0.24.6"
|
||||
bincode = {version = "2.0.0-alpha.0", features = ["std"]}
|
||||
cfg-if = "0.1"
|
||||
cfg-if = "0.1"
|
||||
|
||||
33
README.md
33
README.md
@@ -2,37 +2,29 @@
|
||||
|
||||
Stable-Diffusion-Burn is a Rust-based project which ports the V1 stable diffusion model into the deep learning framework, Burn. This repository is licensed under the MIT Licence.
|
||||
|
||||
## Support The Project
|
||||
|
||||
Stable-Diffusion-Burn is a passion project that is open and free to all. I want to empower everyone with reliable AI that can be run by ourselves on our own hardware to ensure that great AI is not limited to the hands of the few. If you support this vision consider supporting me so that I can continue on this journey and produce more projects such as Stable Diffusion XL in Rust.
|
||||
|
||||
You can show your support by buying a shirt at https://www.bonfire.com/machine-learning/. The shirt image was, of course, generated by my Rust powered Stable Diffusion! I'd love to release more projects and any support will help make that happen!
|
||||
|
||||
Any contribution would be greatly appreciated. Thanks!
|
||||
|
||||
## How To Use
|
||||
|
||||
### Step 1: Download the Model and Set Environment Variables
|
||||
|
||||
Start by downloading the SDv1-4.bin model provided on HuggingFace.
|
||||
Start by downloading the SDv1-4 model provided on HuggingFace.
|
||||
|
||||
```bash
|
||||
wget https://huggingface.co/Gadersd/Stable-Diffusion-Burn/resolve/main/V1/SDv1-4.bin
|
||||
wget https://huggingface.co/Gadersd/Stable-Diffusion-Burn/resolve/main/SDv1-4.mpk
|
||||
```
|
||||
|
||||
Next, set the appropriate CUDA version. It may be possible to run the model using wgpu without the need for torch in the future using `cargo run --features wgpu-backend...` but currently wgpu doesn't support buffer sizes large enough for Stable Diffusion.
|
||||
|
||||
```bash
|
||||
export TORCH_CUDA_VERSION=cu113
|
||||
```
|
||||
### Step 2: Run the Sample Binary
|
||||
|
||||
Invoke the sample binary provided in the rust code, as shown below:
|
||||
Invoke the sample binary provided in the rust code. The application now uses a pure Rust backend (WGPU/Vulkan) instead of libtorch. The WGPU backend is unstable for SD but may work well in the future as burn-wpu is optimized.
|
||||
|
||||
```bash
|
||||
# Arguments: <model_type(burn or dump)> <model> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image>
|
||||
# WGPU/Vulkan backend (GPU accelerated, requires Vulkan-compatible GPU)
|
||||
# Arguments: <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name>
|
||||
|
||||
# GPU (Vulkan)
|
||||
cargo run --release --features wgpu-backend --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img
|
||||
|
||||
# CPU (UNSTABLE - fallback if GPU not available)
|
||||
cargo run --release --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img
|
||||
```
|
||||
|
||||
This command will generate an image according to the provided prompt, which will be saved as 'img0.png'.
|
||||
|
||||
@@ -40,7 +32,7 @@ This command will generate an image according to the provided prompt, which will
|
||||
|
||||
### Optional: Extract and Convert a Fine-Tuned Model
|
||||
|
||||
If users are interested in using a fine-tuned version of stable diffusion, the Python scripts provided in this project can be used to transform a weight dump into a Burn model file.
|
||||
If users are interested in using a fine-tuned version of stable diffusion, the Python scripts provided in this project can be used to transform a weight dump into a Burn model file. This does not work on Windows.
|
||||
|
||||
```bash
|
||||
# Step into the Python directory
|
||||
@@ -49,6 +41,9 @@ cd python
|
||||
# Download the model, this is just the base v1.4 model as an example
|
||||
wget https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt
|
||||
|
||||
# Install tinygrad
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Extract the weights
|
||||
CPU=1 python3 dump.py sd-v1-4.ckpt
|
||||
|
||||
|
||||
BIN
img0.png
BIN
img0.png
Binary file not shown.
|
Before Width: | Height: | Size: 671 KiB After Width: | Height: | Size: 677 KiB |
@@ -13,10 +13,11 @@ from collections import namedtuple
|
||||
|
||||
from tqdm import tqdm
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import dtypes, GlobalCounters
|
||||
from tinygrad.helpers import GlobalCounters
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
|
||||
from extra.utils import download_file
|
||||
from tinygrad.state import torch_load, load_state_dict
|
||||
#from extra.utils import download_file
|
||||
from tinygrad.nn.state import torch_load, load_state_dict
|
||||
|
||||
# TODO: refactor AttnBlock, CrossAttention, CLIPAttention to share code
|
||||
|
||||
|
||||
1
python/requirements.txt
Normal file
1
python/requirements.txt
Normal file
@@ -0,0 +1 @@
|
||||
tinygrad==0.9.2
|
||||
139
src/backend.rs
Normal file
139
src/backend.rs
Normal file
@@ -0,0 +1,139 @@
|
||||
use burn::tensor::{activation::softmax, Tensor};
|
||||
use burn::prelude::Backend;
|
||||
|
||||
/*pub type FloatTensor<B, const D: usize> = <B as burn::tensor::backend::Backend>::TensorPrimitive<D>;
|
||||
|
||||
pub trait Backend: burn::tensor::backend::Backend {
|
||||
fn qkv_attention(
|
||||
q: FloatTensor<Self, 3>,
|
||||
k: FloatTensor<Self, 3>,
|
||||
v: FloatTensor<Self, 3>,
|
||||
mask: Option<FloatTensor<Self, 2>>,
|
||||
n_head: usize,
|
||||
) -> FloatTensor<Self, 3> {
|
||||
qkv_attention(
|
||||
Tensor::<Self, 3>::from_primitive(q),
|
||||
Tensor::from_primitive(k),
|
||||
Tensor::from_primitive(v),
|
||||
mask.map(|m| Tensor::from_primitive(m)),
|
||||
n_head,
|
||||
)
|
||||
.into_primitive()
|
||||
}
|
||||
|
||||
fn attn_decoder_mask(seq_length: usize, device: &Self::Device) -> FloatTensor<Self, 2> {
|
||||
attn_decoder_mask::<Self>(seq_length, device).into_primitive()
|
||||
}
|
||||
}
|
||||
|
||||
use burn::tensor::Float;
|
||||
use burn_tch::{self, TchElement, TchTensor};
|
||||
use tch;
|
||||
|
||||
impl<E: TchElement> Backend for burn_tch::LibTorch<E> {
|
||||
fn qkv_attention(
|
||||
q: FloatTensor<Self, 3>,
|
||||
k: FloatTensor<Self, 3>,
|
||||
v: FloatTensor<Self, 3>,
|
||||
mask: Option<FloatTensor<Self, 2>>,
|
||||
n_head: usize,
|
||||
) -> FloatTensor<Self, 2> {
|
||||
let q = Tensor::from_primitive(q);
|
||||
let k = Tensor::from_primitive(k);
|
||||
let v = Tensor::from_primitive(v);
|
||||
|
||||
let [n_batch, q_ctx, n_state] = q.dims();
|
||||
let [_, k_ctx, _] = k.dims();
|
||||
let n_hstate = n_state / n_head;
|
||||
|
||||
let rearrange = |t: Tensor<Self, 3>| {
|
||||
let [_, n_ctx, _] = t.dims();
|
||||
t.reshape([n_batch, n_ctx, n_head, n_hstate])
|
||||
.swap_dims(1, 2)
|
||||
};
|
||||
|
||||
let q = rearrange(q).into_primitive();
|
||||
let k = rearrange(k).into_primitive();
|
||||
let v = rearrange(v).into_primitive();
|
||||
|
||||
// for some reason torch crashes when mask is None
|
||||
let mask = mask.unwrap_or_else(|| {
|
||||
Tensor::<Self, 2, Float>::zeros([q_ctx, k_ctx], &Self::device(&v))
|
||||
.into_primitive()
|
||||
});
|
||||
|
||||
Tensor::<Self, 4>::from_primitive(TchTensor::new(
|
||||
tch::Tensor::scaled_dot_product_attention(
|
||||
&q.tensor,
|
||||
&k.tensor,
|
||||
&v.tensor,
|
||||
Some(mask.tensor),
|
||||
0.0,
|
||||
false,
|
||||
None,
|
||||
),
|
||||
))
|
||||
.swap_dims(1, 2)
|
||||
.flatten(2, 3)
|
||||
.into_primitive()
|
||||
}
|
||||
}
|
||||
|
||||
use burn_autodiff;
|
||||
|
||||
impl<B: Backend> Backend for burn_autodiff::Autodiff<B> {}*/
|
||||
|
||||
use std::f32::NEG_INFINITY;
|
||||
|
||||
pub fn qkv_attention<B: Backend>(
|
||||
q: Tensor<B, 3>,
|
||||
k: Tensor<B, 3>,
|
||||
v: Tensor<B, 3>,
|
||||
mask: Option<Tensor<B, 2>>,
|
||||
n_head: usize,
|
||||
) -> Tensor<B, 3> {
|
||||
let [n_batch, n_qctx, n_state] = q.dims();
|
||||
let [_, n_ctx, _] = k.dims();
|
||||
|
||||
let scale = (n_state as f64 / n_head as f64).powf(-0.25);
|
||||
let n_hstate = n_state / n_head;
|
||||
|
||||
let q = q
|
||||
.reshape([n_batch, n_qctx, n_head, n_hstate])
|
||||
.swap_dims(1, 2)
|
||||
* scale;
|
||||
let k = k
|
||||
.reshape([n_batch, n_ctx, n_head, n_hstate])
|
||||
.swap_dims(1, 2)
|
||||
.transpose()
|
||||
* scale;
|
||||
let v = v
|
||||
.reshape([n_batch, n_ctx, n_head, n_hstate])
|
||||
.swap_dims(1, 2);
|
||||
|
||||
let qk = q.matmul(k);
|
||||
|
||||
// apply mask
|
||||
let qk = if let Some(mask) = mask {
|
||||
qk + mask.slice([0..n_qctx, 0..n_ctx]).unsqueeze::<4>()
|
||||
} else {
|
||||
qk
|
||||
};
|
||||
|
||||
// normalize value weightings
|
||||
let w = softmax(qk, 3);
|
||||
let o = w.matmul(v).swap_dims(1, 2).flatten(2, 3);
|
||||
|
||||
return o;
|
||||
}
|
||||
|
||||
pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
|
||||
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length], device);
|
||||
|
||||
for i in 0..(seq_length - 1) {
|
||||
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)], device).add_scalar(NEG_INFINITY);
|
||||
mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
|
||||
}
|
||||
|
||||
return mask;
|
||||
}
|
||||
@@ -1,32 +1,27 @@
|
||||
use std::env;
|
||||
use std::process;
|
||||
use std::error::Error;
|
||||
use std::process;
|
||||
|
||||
use stablediffusion::model::stablediffusion::{StableDiffusion, load::load_stable_diffusion};
|
||||
use stablediffusion::model::stablediffusion::{load::load_stable_diffusion, StableDiffusion};
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(feature = "torch-backend")] {
|
||||
use burn_tch::{TchBackend, TchDevice};
|
||||
} else if #[cfg(feature = "wgpu-backend")] {
|
||||
use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi};
|
||||
}
|
||||
}
|
||||
use burn_ndarray::{NdArray, NdArrayDevice};
|
||||
|
||||
use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings};
|
||||
use burn::record::{self, NamedMpkFileRecorder, FullPrecisionSettings, Recorder};
|
||||
|
||||
fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box<dyn Error>> {
|
||||
fn convert_dump_to_model<B: Backend>(
|
||||
dump_path: &str,
|
||||
model_name: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
println!("Loading dump...");
|
||||
let model: StableDiffusion::<B> = load_stable_diffusion(dump_path, device)?;
|
||||
let model: StableDiffusion<B> = load_stable_diffusion(dump_path, device)?;
|
||||
|
||||
println!("Saving model...");
|
||||
save_model_file(model, model_name)?;
|
||||
@@ -34,24 +29,16 @@ fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device:
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn save_model_file<B: Backend>(model: StableDiffusion<B>, name: &str) -> Result<(), record::RecorderError> {
|
||||
BinFileRecorder::<FullPrecisionSettings>::new()
|
||||
.record(
|
||||
model.into_record(),
|
||||
name.into(),
|
||||
)
|
||||
fn save_model_file<B: Backend>(
|
||||
model: StableDiffusion<B>,
|
||||
name: &str,
|
||||
) -> Result<(), record::RecorderError> {
|
||||
NamedMpkFileRecorder::<FullPrecisionSettings>::new().record(model.into_record(), name.into())
|
||||
}
|
||||
|
||||
fn main() {
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(feature = "torch-backend")] {
|
||||
type Backend = TchBackend<f32>;
|
||||
let device = TchDevice::Cpu;
|
||||
} else if #[cfg(feature = "wgpu-backend")] {
|
||||
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
|
||||
let device = WgpuDevice::CPU;
|
||||
}
|
||||
}
|
||||
type Backend = NdArray<f32>;
|
||||
let device = NdArrayDevice::Cpu;
|
||||
|
||||
let args: Vec<String> = env::args().collect();
|
||||
if args.len() != 3 {
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
use stablediffusion::{tokenizer::SimpleTokenizer, model::stablediffusion::{*, load::load_stable_diffusion}};
|
||||
use stablediffusion::{
|
||||
model::stablediffusion::{load::load_stable_diffusion, *},
|
||||
tokenizer::SimpleTokenizer,
|
||||
};
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(feature = "torch-backend")] {
|
||||
use burn_tch::{TchBackend, TchDevice};
|
||||
} else if #[cfg(feature = "wgpu-backend")] {
|
||||
use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi};
|
||||
if #[cfg(feature = "wgpu-backend")] {
|
||||
use burn_wgpu::{Wgpu, WgpuDevice};
|
||||
} else {
|
||||
use burn_ndarray::NdArrayDevice;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,28 +22,21 @@ use std::env;
|
||||
use std::io;
|
||||
use std::process;
|
||||
|
||||
use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings};
|
||||
use burn::record::{self, NamedMpkFileRecorder, FullPrecisionSettings, Recorder};
|
||||
|
||||
fn load_stable_diffusion_model_file<B: Backend>(filename: &str) -> Result<StableDiffusion<B>, record::RecorderError> {
|
||||
BinFileRecorder::<FullPrecisionSettings>::new()
|
||||
.load(filename.into())
|
||||
.map(|record| StableDiffusionConfig::new().init().load_record(record))
|
||||
fn load_stable_diffusion_model_file<B: Backend>(
|
||||
filename: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<StableDiffusion<B>, record::RecorderError> {
|
||||
NamedMpkFileRecorder::<FullPrecisionSettings>::new()
|
||||
.load(filename.into(), device)
|
||||
.map(|record| StableDiffusionConfig::new().init(device).load_record(record))
|
||||
}
|
||||
|
||||
fn main() {
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(feature = "torch-backend")] {
|
||||
type Backend = TchBackend<f32>;
|
||||
let device = TchDevice::Cuda(0);
|
||||
} else if #[cfg(feature = "wgpu-backend")] {
|
||||
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
|
||||
let device = WgpuDevice::BestAvailable;
|
||||
}
|
||||
}
|
||||
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() != 7 {
|
||||
eprintln!("Usage: {} <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name>", args[0]);
|
||||
if args.len() != 7 && args.len() != 8 {
|
||||
eprintln!("Usage: {} <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name> [device(cuda, mps, cpu)]", args[0]);
|
||||
process::exit(1);
|
||||
}
|
||||
|
||||
@@ -60,11 +53,24 @@ fn main() {
|
||||
let prompt = &args[5];
|
||||
let output_image_name = &args[6];
|
||||
|
||||
// Optional device parameter
|
||||
let device_arg = if args.len() == 8 { Some(&args[7]) } else { None };
|
||||
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(feature = "wgpu-backend")] {
|
||||
type Backend = Wgpu;
|
||||
let device = WgpuDevice::BestAvailable;
|
||||
} else {
|
||||
type Backend = burn::backend::ndarray::NdArray<f32>;
|
||||
let device = NdArrayDevice::Cpu;
|
||||
}
|
||||
}
|
||||
|
||||
println!("Loading tokenizer...");
|
||||
let tokenizer = SimpleTokenizer::new().unwrap();
|
||||
println!("Loading model...");
|
||||
let sd: StableDiffusion<Backend> = if model_type == "burn" {
|
||||
load_stable_diffusion_model_file(model_name).unwrap_or_else(|err| {
|
||||
load_stable_diffusion_model_file(model_name, &device).unwrap_or_else(|err| {
|
||||
eprintln!("Error loading model: {}", err);
|
||||
process::exit(1);
|
||||
})
|
||||
@@ -74,21 +80,24 @@ fn main() {
|
||||
process::exit(1);
|
||||
})
|
||||
};
|
||||
|
||||
let sd = sd.to_device(&device);
|
||||
|
||||
let unconditional_context = sd.unconditional_context(&tokenizer);
|
||||
let context = sd.context(&tokenizer, prompt).unsqueeze().repeat(0, 2); // generate 2 samples
|
||||
let context = sd.context(&tokenizer, prompt).unsqueeze::<3>(); //.repeat(0, 2); // generate 2 samples
|
||||
|
||||
println!("Sampling image...");
|
||||
let images = sd.sample_image(context, unconditional_context, unconditional_guidance_scale, n_steps);
|
||||
let images = sd.sample_image(
|
||||
context,
|
||||
unconditional_context,
|
||||
unconditional_guidance_scale,
|
||||
n_steps,
|
||||
);
|
||||
save_images(&images, output_image_name, 512, 512).unwrap_or_else(|err| {
|
||||
eprintln!("Error saving image: {}", err);
|
||||
process::exit(1);
|
||||
});
|
||||
}
|
||||
|
||||
use image::{self, ImageResult, ColorType::Rgb8};
|
||||
use image::{self, ColorType::Rgb8, ImageResult};
|
||||
|
||||
fn save_images(images: &Vec<Vec<u8>>, basepath: &str, width: u32, height: u32) -> ImageResult<()> {
|
||||
for (index, img_data) in images.iter().enumerate() {
|
||||
@@ -103,12 +112,15 @@ fn save_images(images: &Vec<Vec<u8>>, basepath: &str, width: u32, height: u32) -
|
||||
fn save_test_image() -> ImageResult<()> {
|
||||
let width = 256;
|
||||
let height = 256;
|
||||
let raw: Vec<_> = (0..width * height).into_iter().flat_map(|i| {
|
||||
let row = i / width;
|
||||
let red = (255.0 * row as f64 / height as f64) as u8;
|
||||
let raw: Vec<_> = (0..width * height)
|
||||
.into_iter()
|
||||
.flat_map(|i| {
|
||||
let row = i / width;
|
||||
let red = (255.0 * row as f64 / height as f64) as u8;
|
||||
|
||||
[red, 0, 0]
|
||||
}).collect();
|
||||
[red, 0, 0]
|
||||
})
|
||||
.collect();
|
||||
|
||||
image::save_buffer("red.png", &raw[..], width, height, Rgb8)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
use burn::{
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
activation::relu,
|
||||
Tensor,
|
||||
Int,
|
||||
Bool,
|
||||
Float,
|
||||
TensorKind,
|
||||
BasicOps,
|
||||
Numeric,
|
||||
Element,
|
||||
},
|
||||
};
|
||||
|
||||
use num_traits::ToPrimitive;
|
||||
|
||||
|
||||
pub fn tensor_max_scalar<B: Backend, const D: usize>(x: Tensor<B, D>, max: f64) -> Tensor<B, D> {
|
||||
relu(x.sub_scalar(max)).add_scalar(max)
|
||||
}
|
||||
|
||||
pub fn tensor_min_scalar<B: Backend, const D: usize>(x: Tensor<B, D>, min: f64) -> Tensor<B, D> {
|
||||
-tensor_max_scalar(-x, -min)
|
||||
}
|
||||
|
||||
pub fn tensor_max<B: Backend, const D: usize>(x: Tensor<B, D>, max: Tensor<B, D>) -> Tensor<B, D> {
|
||||
relu(x - max.clone()) + max
|
||||
}
|
||||
|
||||
pub fn tensor_min<B: Backend, const D: usize>(x: Tensor<B, D>, min: Tensor<B, D>) -> Tensor<B, D> {
|
||||
-tensor_max(-x, -min)
|
||||
}
|
||||
|
||||
pub fn tensor_log10<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
|
||||
let ln10 = (10.0f64).ln();
|
||||
x.log() / ln10
|
||||
}
|
||||
|
||||
pub fn tensor_max_element<B: Backend, const D: usize>(x: Tensor<B, D>) -> f64 {
|
||||
let flat: Tensor<B, 1> = x.flatten(0, D - 1);
|
||||
let max_index = flat.clone().argmax(0);
|
||||
|
||||
flat.select(0, max_index).into_scalar().to_f64().unwrap()
|
||||
}
|
||||
|
||||
pub fn all_zeros<B: Backend, const D: usize>(x: Tensor<B, D>) -> bool {
|
||||
x.powf(2.0).sum().into_scalar().to_f64().unwrap() == 0.0
|
||||
}
|
||||
|
||||
pub fn max_dim<B: Backend>(x: Tensor<B, 2>, dim: usize) -> Tensor<B, 2> {
|
||||
let indices = x.clone().argmax(dim).flatten(0, 1);
|
||||
x.select(dim, indices)
|
||||
}
|
||||
|
||||
pub fn _10pow<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
|
||||
let log10 = (10.0f64).ln();
|
||||
(x * log10).exp()
|
||||
}
|
||||
|
||||
pub fn to_float<B: Backend, const D: usize>(x: Tensor<B, D, Int>) -> Tensor<B, D, Float> {
|
||||
let device = x.device();
|
||||
Tensor::from_data(
|
||||
x
|
||||
.into_data()
|
||||
.convert()
|
||||
).to_device(&device)
|
||||
}
|
||||
|
||||
pub fn to_float_bool<B: Backend, const D: usize>(x: Tensor<B, D, Bool>) -> Tensor<B, D, Float> {
|
||||
let device = x.device();
|
||||
Tensor::from_data(
|
||||
x
|
||||
.into_data()
|
||||
.convert()
|
||||
).to_device(&device)
|
||||
}
|
||||
|
||||
pub fn reverse<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B> + Numeric<B>>(x: Tensor<B, D, K>, dim: usize) -> Tensor<B, D, K> where <K as BasicOps<B>>::Elem: Element {
|
||||
let len = x.dims()[dim];
|
||||
let indices = -Tensor::arange_device(0..len, &x.device()) + (len - 1) as i64;
|
||||
x.select(dim, indices)
|
||||
}
|
||||
|
||||
pub fn div_roundup(x: usize, y: usize) -> usize {
|
||||
(x + y - 1) / y
|
||||
}
|
||||
@@ -1,3 +1,3 @@
|
||||
pub mod backend;
|
||||
pub mod model;
|
||||
pub mod tokenizer;
|
||||
pub mod helper;
|
||||
@@ -1,23 +1,32 @@
|
||||
use burn::{
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
activation::softmax,
|
||||
Tensor,
|
||||
},
|
||||
};
|
||||
use burn::tensor::{activation::softmax, backend::Backend, Tensor};
|
||||
|
||||
use std::f32::NEG_INFINITY;
|
||||
|
||||
pub fn qkv_attention<B: Backend>(q: Tensor<B, 3>, k: Tensor<B, 3>, v: Tensor<B, 3>, mask: Option<Tensor<B, 2>>, n_head: usize) -> Tensor<B, 3> {
|
||||
pub fn qkv_attention<B: Backend>(
|
||||
q: Tensor<B, 3>,
|
||||
k: Tensor<B, 3>,
|
||||
v: Tensor<B, 3>,
|
||||
mask: Option<Tensor<B, 2>>,
|
||||
n_head: usize,
|
||||
) -> Tensor<B, 3> {
|
||||
let [n_batch, n_qctx, n_state] = q.dims();
|
||||
let [_, n_ctx, _] = k.dims();
|
||||
|
||||
let scale = (n_state as f64 / n_head as f64).powf(-0.25);
|
||||
let n_hstate = n_state / n_head;
|
||||
|
||||
let q = q.reshape([n_batch, n_qctx, n_head, n_hstate]).swap_dims(1, 2) * scale;
|
||||
let k = k.reshape([n_batch, n_ctx, n_head, n_hstate]).swap_dims(1, 2).transpose() * scale;
|
||||
let v = v.reshape([n_batch, n_ctx, n_head, n_hstate]).swap_dims(1, 2);
|
||||
let q = q
|
||||
.reshape([n_batch, n_qctx, n_head, n_hstate])
|
||||
.swap_dims(1, 2)
|
||||
* scale;
|
||||
let k = k
|
||||
.reshape([n_batch, n_ctx, n_head, n_hstate])
|
||||
.swap_dims(1, 2)
|
||||
.transpose()
|
||||
* scale;
|
||||
let v = v
|
||||
.reshape([n_batch, n_ctx, n_head, n_hstate])
|
||||
.swap_dims(1, 2);
|
||||
|
||||
let qk = q.matmul(k);
|
||||
|
||||
@@ -36,12 +45,12 @@ pub fn qkv_attention<B: Backend>(q: Tensor<B, 3>, k: Tensor<B, 3>, v: Tensor<B,
|
||||
}
|
||||
|
||||
pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
|
||||
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length]);
|
||||
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length], device);
|
||||
|
||||
for i in 0..(seq_length - 1) {
|
||||
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)]).add_scalar(NEG_INFINITY);
|
||||
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)], device).add_scalar(NEG_INFINITY);
|
||||
mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
|
||||
}
|
||||
|
||||
return mask.to_device(device);
|
||||
}
|
||||
return mask;
|
||||
}
|
||||
|
||||
@@ -4,29 +4,38 @@ use crate::model::load::*;
|
||||
use std::error::Error;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
use crate::model::groupnorm::load::load_group_norm;
|
||||
|
||||
fn load_conv_self_attention_block<B: Backend>(path: &str, device: &B::Device) -> Result<ConvSelfAttentionBlock<B>, Box<dyn Error>> {
|
||||
fn load_conv_self_attention_block<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<ConvSelfAttentionBlock<B>, Box<dyn Error>> {
|
||||
let norm = load_group_norm(&format!("{}/{}", path, "norm"), device)?;
|
||||
let q = load_conv2d(&format!("{}/{}", path, "q"), device)?;
|
||||
let k = load_conv2d(&format!("{}/{}", path, "k"), device)?;
|
||||
let v = load_conv2d(&format!("{}/{}", path, "v"), device)?;
|
||||
let proj_out = load_conv2d(&format!("{}/{}", path, "proj_out"), device)?;
|
||||
|
||||
Ok(ConvSelfAttentionBlock { norm, q, k, v, proj_out })
|
||||
Ok(ConvSelfAttentionBlock {
|
||||
norm,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
proj_out,
|
||||
})
|
||||
}
|
||||
|
||||
fn load_resnet_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResnetBlock<B>, Box<dyn Error>> {
|
||||
fn load_resnet_block<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<ResnetBlock<B>, Box<dyn Error>> {
|
||||
let norm1 = load_group_norm(&format!("{}/{}", path, "norm1"), device)?;
|
||||
let silu1 = SILU {};
|
||||
let conv1 = load_conv2d(&format!("{}/{}", path, "conv1"), device)?;
|
||||
@@ -35,7 +44,15 @@ fn load_resnet_block<B: Backend>(path: &str, device: &B::Device) -> Result<Resne
|
||||
let conv2 = load_conv2d(&format!("{}/{}", path, "conv2"), device)?;
|
||||
let nin_shortcut = load_conv2d(&format!("{}/{}", path, "nin_shortcut"), device).ok();
|
||||
|
||||
Ok(ResnetBlock { norm1, silu1, conv1, norm2, silu2, conv2, nin_shortcut })
|
||||
Ok(ResnetBlock {
|
||||
norm1,
|
||||
silu1,
|
||||
conv1,
|
||||
norm2,
|
||||
silu2,
|
||||
conv2,
|
||||
nin_shortcut,
|
||||
})
|
||||
}
|
||||
|
||||
fn load_mid<B: Backend>(path: &str, device: &B::Device) -> Result<Mid<B>, Box<dyn Error>> {
|
||||
@@ -43,49 +60,76 @@ fn load_mid<B: Backend>(path: &str, device: &B::Device) -> Result<Mid<B>, Box<dy
|
||||
let attn = load_conv_self_attention_block(&format!("{}/{}", path, "attn"), device)?;
|
||||
let block_2 = load_resnet_block(&format!("{}/{}", path, "block_2"), device)?;
|
||||
|
||||
Ok(Mid { block_1, attn, block_2 })
|
||||
Ok(Mid {
|
||||
block_1,
|
||||
attn,
|
||||
block_2,
|
||||
})
|
||||
}
|
||||
|
||||
fn load_padded_conv2d<B: Backend>(path: &str, device: &B::Device) -> Result<PaddedConv2d<B>, Box<dyn Error>> {
|
||||
let conv = load_conv2d(&format!("{}/{}", path, "conv"), device)?;
|
||||
fn load_padded_conv2d<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<PaddedConv2d<B>, Box<dyn Error>> {
|
||||
let mut conv = load_conv2d(&format!("{}/{}", path, "conv"), device)?;
|
||||
|
||||
let channels = load_tensor::<B, 1>("channels", path, device)?;
|
||||
let channels = tensor_to_array_2(channels);
|
||||
let channels = tensor_to_array_2(channels);
|
||||
|
||||
let kernel_size = load_usize::<B>("kernel_size", path, device)?;
|
||||
let stride = load_usize::<B>("stride", path, device)?;
|
||||
|
||||
let padding = load_tensor::<B, 1>("padding", path, device)?;
|
||||
let padding: [usize; 4] = tensor_to_array(padding);
|
||||
let padding = Padding::new(padding[0], padding[1], padding[2], padding[3]);
|
||||
let padding = PaddingCfg::new(padding[0], padding[1], padding[2], padding[3]);
|
||||
|
||||
let mut record = conv.into_record();
|
||||
//let mut record = conv.into_record();
|
||||
|
||||
let mut padded_conv: PaddedConv2d<B> = PaddedConv2dConfig::new(channels, kernel_size, padding).with_stride(stride).init();
|
||||
let padding_actual = PaddingConfig2d::Explicit(padded_conv.padding_actual[0], padded_conv.padding_actual[1]);
|
||||
let mut padded_conv: PaddedConv2d<B> = PaddedConv2dConfig::new(channels, kernel_size, padding)
|
||||
.with_stride(stride)
|
||||
.init(device);
|
||||
let padding_actual =
|
||||
PaddingConfig2d::Explicit(padded_conv.padding_actual[0], padded_conv.padding_actual[1]);
|
||||
|
||||
record.padding = <PaddingConfig2d as Module<B>>::into_record(padding_actual);
|
||||
padded_conv.conv = padded_conv.conv.load_record(record);
|
||||
|
||||
conv.padding = burn::module::Ignored(padding_actual);
|
||||
padded_conv.conv = conv;
|
||||
|
||||
//record.padding = <PaddingConfig2d as Module<B>>::into_record(padding_actual);
|
||||
//padded_conv.conv = padded_conv.conv.load_record(record);
|
||||
|
||||
Ok(padded_conv)
|
||||
}
|
||||
|
||||
fn load_decoder_block<B: Backend>(path: &str, device: &B::Device) -> Result<DecoderBlock<B>, Box<dyn Error>> {
|
||||
fn load_decoder_block<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<DecoderBlock<B>, Box<dyn Error>> {
|
||||
let res1 = load_resnet_block(&format!("{}/{}", path, "res1"), device)?;
|
||||
let res2 = load_resnet_block(&format!("{}/{}", path, "res2"), device)?;
|
||||
let res3 = load_resnet_block(&format!("{}/{}", path, "res3"), device)?;
|
||||
let upsampler = load_conv2d(&format!("{}/{}", path, "upsampler"), device).ok();
|
||||
|
||||
Ok(DecoderBlock { res1, res2, res3, upsampler })
|
||||
Ok(DecoderBlock {
|
||||
res1,
|
||||
res2,
|
||||
res3,
|
||||
upsampler,
|
||||
})
|
||||
}
|
||||
|
||||
fn load_encoder_block<B: Backend>(path: &str, device: &B::Device) -> Result<EncoderBlock<B>, Box<dyn Error>> {
|
||||
fn load_encoder_block<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<EncoderBlock<B>, Box<dyn Error>> {
|
||||
let res1 = load_resnet_block(&format!("{}/{}", path, "res1"), device)?;
|
||||
let res2 = load_resnet_block(&format!("{}/{}", path, "res2"), device)?;
|
||||
let downsampler = load_padded_conv2d(&format!("{}/{}", path, "downsampler"), device).ok();
|
||||
|
||||
Ok(EncoderBlock { res1, res2, downsampler })
|
||||
Ok(EncoderBlock {
|
||||
res1,
|
||||
res2,
|
||||
downsampler,
|
||||
})
|
||||
}
|
||||
|
||||
fn load_decoder<B: Backend>(path: &str, device: &B::Device) -> Result<Decoder<B>, Box<dyn Error>> {
|
||||
@@ -95,15 +139,21 @@ fn load_decoder<B: Backend>(path: &str, device: &B::Device) -> Result<Decoder<B>
|
||||
let n_block = load_usize::<B>("n_block", path, device)?;
|
||||
let mut blocks = (0..n_block)
|
||||
.into_iter()
|
||||
.map(|i| {
|
||||
load_decoder_block::<B>(&format!("{}/blocks/{}", path, i), device)
|
||||
}).collect::<Result<Vec<_>, _>>()?;
|
||||
.map(|i| load_decoder_block::<B>(&format!("{}/blocks/{}", path, i), device))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let norm_out = load_group_norm(&format!("{}/{}", path, "norm_out"), device)?;
|
||||
let silu = SILU {};
|
||||
let conv_out = load_conv2d(&format!("{}/{}", path, "conv_out"), device)?;
|
||||
|
||||
Ok(Decoder { conv_in, mid, blocks, norm_out, silu, conv_out })
|
||||
Ok(Decoder {
|
||||
conv_in,
|
||||
mid,
|
||||
blocks,
|
||||
norm_out,
|
||||
silu,
|
||||
conv_out,
|
||||
})
|
||||
}
|
||||
|
||||
fn load_encoder<B: Backend>(path: &str, device: &B::Device) -> Result<Encoder<B>, Box<dyn Error>> {
|
||||
@@ -113,22 +163,36 @@ fn load_encoder<B: Backend>(path: &str, device: &B::Device) -> Result<Encoder<B>
|
||||
let n_block = load_usize::<B>("n_block", path, device)?;
|
||||
let mut blocks = (0..n_block)
|
||||
.into_iter()
|
||||
.map(|i| {
|
||||
load_encoder_block::<B>(&format!("{}/blocks/{}", path, i), device)
|
||||
}).collect::<Result<Vec<_>, _>>()?;
|
||||
.map(|i| load_encoder_block::<B>(&format!("{}/blocks/{}", path, i), device))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let norm_out = load_group_norm(&format!("{}/{}", path, "norm_out"), device)?;
|
||||
let silu = SILU {};
|
||||
let conv_out = load_conv2d(&format!("{}/{}", path, "conv_out"), device)?;
|
||||
|
||||
Ok(Encoder { conv_in, mid, blocks, norm_out, silu, conv_out })
|
||||
Ok(Encoder {
|
||||
conv_in,
|
||||
mid,
|
||||
blocks,
|
||||
norm_out,
|
||||
silu,
|
||||
conv_out,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn load_autoencoder<B: Backend>(path: &str, device: &B::Device) -> Result<Autoencoder<B>, Box<dyn Error>> {
|
||||
pub fn load_autoencoder<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<Autoencoder<B>, Box<dyn Error>> {
|
||||
let encoder = load_encoder(&format!("{}/{}", path, "encoder"), device)?;
|
||||
let decoder = load_decoder(&format!("{}/{}", path, "decoder"), device)?;
|
||||
let quant_conv = load_conv2d(&format!("{}/{}", path, "quant_conv"), device)?;
|
||||
let post_quant_conv = load_conv2d(&format!("{}/{}", path, "post_quant_conv"), device)?;
|
||||
|
||||
Ok(Autoencoder { encoder, decoder, quant_conv, post_quant_conv })
|
||||
}
|
||||
Ok(Autoencoder {
|
||||
encoder,
|
||||
decoder,
|
||||
quant_conv,
|
||||
post_quant_conv,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,59 +1,60 @@
|
||||
pub mod load;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn::{self, PaddingConfig2d, conv::{Conv2d, Conv2dConfig, Conv2dRecord}},
|
||||
nn::{
|
||||
self,
|
||||
conv::{Conv2d, Conv2dConfig, Conv2dRecord},
|
||||
PaddingConfig2d,
|
||||
},
|
||||
tensor::{
|
||||
activation::{sigmoid, softmax},
|
||||
backend::Backend,
|
||||
activation::{softmax, sigmoid},
|
||||
module::embedding,
|
||||
Tensor,
|
||||
Distribution,
|
||||
Int,
|
||||
module::embedding,
|
||||
Distribution, Int, Tensor,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::helper::div_roundup;
|
||||
|
||||
use super::silu::*;
|
||||
use super::groupnorm::*;
|
||||
use super::attention::qkv_attention;
|
||||
use super::silu::*;
|
||||
//use crate::backend::Backend as MyBackend;
|
||||
use crate::backend::{qkv_attention, attn_decoder_mask};
|
||||
|
||||
use std::iter;
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct AutoencoderConfig {}
|
||||
|
||||
impl AutoencoderConfig {
|
||||
pub fn init<B: Backend>(&self) -> Autoencoder<B> {
|
||||
let encoder = EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init();
|
||||
let decoder = DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init();
|
||||
let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init();
|
||||
let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init();
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> Autoencoder<B> {
|
||||
let encoder =
|
||||
EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init(device);
|
||||
let decoder =
|
||||
DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init(device);
|
||||
let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init(device);
|
||||
let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init(device);
|
||||
|
||||
Autoencoder {
|
||||
encoder,
|
||||
decoder,
|
||||
quant_conv,
|
||||
post_quant_conv,
|
||||
encoder,
|
||||
decoder,
|
||||
quant_conv,
|
||||
post_quant_conv,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Autoencoder<B: Backend> {
|
||||
encoder: Encoder<B>,
|
||||
decoder: Decoder<B>,
|
||||
quant_conv: Conv2d<B>,
|
||||
post_quant_conv: Conv2d<B>,
|
||||
encoder: Encoder<B>,
|
||||
decoder: Decoder<B>,
|
||||
quant_conv: Conv2d<B>,
|
||||
post_quant_conv: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Autoencoder<B> {
|
||||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
self.decode_latent( self.encode_image(x) )
|
||||
self.decode_latent(self.encode_image(x))
|
||||
}
|
||||
|
||||
pub fn encode_image(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
@@ -70,50 +71,62 @@ impl<B: Backend> Autoencoder<B> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct EncoderConfig {
|
||||
channels: Vec<(usize, usize)>,
|
||||
n_group: usize,
|
||||
n_channels_out: usize,
|
||||
channels: Vec<(usize, usize)>,
|
||||
n_group: usize,
|
||||
n_channels_out: usize,
|
||||
}
|
||||
|
||||
impl EncoderConfig {
|
||||
fn init<B: Backend>(&self) -> Encoder<B> {
|
||||
let n_expanded_channels_initial = self.channels.first().map(|f| f.1).expect("Channels must not be empty.");
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> Encoder<B> {
|
||||
let n_expanded_channels_initial = self
|
||||
.channels
|
||||
.first()
|
||||
.map(|f| f.1)
|
||||
.expect("Channels must not be empty.");
|
||||
let n_expanded_channels_final = self.channels.first().unwrap().0;
|
||||
|
||||
let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init(device);
|
||||
|
||||
let blocks = self.channels.iter().enumerate().map(|(i, &(n_channel_in, n_channel_out))| {
|
||||
let downsample = i != self.channels.len() - 1;
|
||||
EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init()
|
||||
}).collect();
|
||||
let blocks = self
|
||||
.channels
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &(n_channel_in, n_channel_out))| {
|
||||
let downsample = i != self.channels.len() - 1;
|
||||
EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init(device)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mid = MidConfig::new(n_expanded_channels_final).init();
|
||||
let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init();
|
||||
let mid = MidConfig::new(n_expanded_channels_final).init(device);
|
||||
let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init(device);
|
||||
let silu = SILU::new();
|
||||
let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init(device);
|
||||
|
||||
Encoder {
|
||||
conv_in,
|
||||
mid,
|
||||
blocks,
|
||||
norm_out,
|
||||
silu,
|
||||
conv_out,
|
||||
conv_in,
|
||||
mid,
|
||||
blocks,
|
||||
norm_out,
|
||||
silu,
|
||||
conv_out,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Encoder<B: Backend> {
|
||||
conv_in: Conv2d<B>,
|
||||
mid: Mid<B>,
|
||||
blocks: Vec<EncoderBlock<B>>,
|
||||
norm_out: GroupNorm<B>,
|
||||
silu: SILU,
|
||||
conv_out: Conv2d<B>,
|
||||
conv_in: Conv2d<B>,
|
||||
mid: Mid<B>,
|
||||
blocks: Vec<EncoderBlock<B>>,
|
||||
norm_out: GroupNorm<B>,
|
||||
silu: SILU,
|
||||
conv_out: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Encoder<B> {
|
||||
@@ -126,55 +139,66 @@ impl<B: Backend> Encoder<B> {
|
||||
}
|
||||
|
||||
let x = self.mid.forward(x);
|
||||
self.conv_out.forward( self.silu.forward( self.norm_out.forward(x) ) )
|
||||
self.conv_out
|
||||
.forward(self.silu.forward(self.norm_out.forward(x)))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct DecoderConfig {
|
||||
channels: Vec<(usize, usize)>,
|
||||
n_group: usize,
|
||||
channels: Vec<(usize, usize)>,
|
||||
n_group: usize,
|
||||
}
|
||||
|
||||
impl DecoderConfig {
|
||||
fn init<B: Backend>(&self) -> Decoder<B> {
|
||||
let n_expanded_channels = self.channels.first().map(|f| f.0).expect("Channels must not be empty.");
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> Decoder<B> {
|
||||
let n_expanded_channels = self
|
||||
.channels
|
||||
.first()
|
||||
.map(|f| f.0)
|
||||
.expect("Channels must not be empty.");
|
||||
let n_condensed_channels = self.channels.last().unwrap().1;
|
||||
|
||||
let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
let mid = MidConfig::new(n_expanded_channels).init();
|
||||
let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init(device);
|
||||
let mid = MidConfig::new(n_expanded_channels).init(device);
|
||||
|
||||
let blocks = self.channels.iter().enumerate().map(|(i, &(n_channel_in, n_channel_out))| {
|
||||
let upsample = i != self.channels.len() - 1;
|
||||
DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init()
|
||||
}).collect();
|
||||
let blocks = self
|
||||
.channels
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &(n_channel_in, n_channel_out))| {
|
||||
let upsample = i != self.channels.len() - 1;
|
||||
DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init(device)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init();
|
||||
let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init(device);
|
||||
let silu = SILU::new();
|
||||
let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init(device);
|
||||
|
||||
Decoder {
|
||||
conv_in,
|
||||
mid,
|
||||
blocks,
|
||||
norm_out,
|
||||
silu,
|
||||
conv_out,
|
||||
conv_in,
|
||||
mid,
|
||||
blocks,
|
||||
norm_out,
|
||||
silu,
|
||||
conv_out,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Decoder<B: Backend> {
|
||||
conv_in: Conv2d<B>,
|
||||
mid: Mid<B>,
|
||||
blocks: Vec<DecoderBlock<B>>,
|
||||
norm_out: GroupNorm<B>,
|
||||
silu: SILU,
|
||||
conv_out: Conv2d<B>,
|
||||
conv_in: Conv2d<B>,
|
||||
mid: Mid<B>,
|
||||
blocks: Vec<DecoderBlock<B>>,
|
||||
norm_out: GroupNorm<B>,
|
||||
silu: SILU,
|
||||
conv_out: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Decoder<B> {
|
||||
@@ -187,41 +211,46 @@ impl<B: Backend> Decoder<B> {
|
||||
x = block.forward(x);
|
||||
}
|
||||
|
||||
self.conv_out.forward( self.silu.forward( self.norm_out.forward(x) ) )
|
||||
self.conv_out
|
||||
.forward(self.silu.forward(self.norm_out.forward(x)))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct EncoderBlockConfig {
|
||||
n_channels_in: usize,
|
||||
n_channels_out: usize,
|
||||
downsample: bool,
|
||||
n_channels_in: usize,
|
||||
n_channels_out: usize,
|
||||
downsample: bool,
|
||||
}
|
||||
|
||||
impl EncoderBlockConfig {
|
||||
fn init<B: Backend>(&self) -> EncoderBlock<B> {
|
||||
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init();
|
||||
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> EncoderBlock<B> {
|
||||
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(device);
|
||||
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
|
||||
let downsampler = if self.downsample {
|
||||
let padding = Padding::new(0, 1, 0, 1);
|
||||
Some( PaddedConv2dConfig::new([self.n_channels_out, self.n_channels_out], 3, padding).with_stride(2).init() )
|
||||
let padding = PaddingCfg::new(0, 1, 0, 1);
|
||||
Some(
|
||||
PaddedConv2dConfig::new([self.n_channels_out, self.n_channels_out], 3, padding)
|
||||
.with_stride(2)
|
||||
.init(device),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
EncoderBlock {
|
||||
res1,
|
||||
res2,
|
||||
downsampler,
|
||||
res1,
|
||||
res2,
|
||||
downsampler,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct EncoderBlock<B: Backend> {
|
||||
res1: ResnetBlock<B>,
|
||||
res2: ResnetBlock<B>,
|
||||
downsampler: Option<PaddedConv2d<B>>,
|
||||
res1: ResnetBlock<B>,
|
||||
res2: ResnetBlock<B>,
|
||||
downsampler: Option<PaddedConv2d<B>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> EncoderBlock<B> {
|
||||
@@ -236,39 +265,43 @@ impl<B: Backend> EncoderBlock<B> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct DecoderBlockConfig {
|
||||
n_channels_in: usize,
|
||||
n_channels_out: usize,
|
||||
upsample: bool,
|
||||
n_channels_in: usize,
|
||||
n_channels_out: usize,
|
||||
upsample: bool,
|
||||
}
|
||||
|
||||
impl DecoderBlockConfig {
|
||||
fn init<B: Backend>(&self) -> DecoderBlock<B> {
|
||||
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init();
|
||||
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
|
||||
let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> DecoderBlock<B> {
|
||||
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(device);
|
||||
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
|
||||
let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
|
||||
let upsampler = if self.upsample {
|
||||
Some( Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init() )
|
||||
Some(
|
||||
Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init(device),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
DecoderBlock {
|
||||
res1,
|
||||
res2,
|
||||
res3,
|
||||
upsampler,
|
||||
res1,
|
||||
res2,
|
||||
res3,
|
||||
upsampler,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct DecoderBlock<B: Backend> {
|
||||
res1: ResnetBlock<B>,
|
||||
res2: ResnetBlock<B>,
|
||||
res3: ResnetBlock<B>,
|
||||
upsampler: Option<Conv2d<B>>,
|
||||
res1: ResnetBlock<B>,
|
||||
res2: ResnetBlock<B>,
|
||||
res3: ResnetBlock<B>,
|
||||
upsampler: Option<Conv2d<B>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> DecoderBlock<B> {
|
||||
@@ -280,10 +313,9 @@ impl<B: Backend> DecoderBlock<B> {
|
||||
if let Some(d) = self.upsampler.as_ref() {
|
||||
let [n_batch, n_channel, height, width] = x.dims();
|
||||
let x = x
|
||||
.reshape([n_batch, n_channel, height, 1, width, 1])
|
||||
.repeat(3, 2)
|
||||
.repeat(5, 2)
|
||||
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
|
||||
.reshape([n_batch, n_channel, height, 1, width, 1])
|
||||
.repeat(&[1, 1, 1, 2, 1, 2])
|
||||
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
|
||||
d.forward(x)
|
||||
} else {
|
||||
x
|
||||
@@ -291,18 +323,17 @@ impl<B: Backend> DecoderBlock<B> {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct PaddedConv2dConfig {
|
||||
channels: [usize; 2],
|
||||
kernel_size: usize,
|
||||
channels: [usize; 2],
|
||||
kernel_size: usize,
|
||||
#[config(default = 1)]
|
||||
stride: usize,
|
||||
padding: Padding,
|
||||
stride: usize,
|
||||
padding: PaddingCfg,
|
||||
}
|
||||
|
||||
impl PaddedConv2dConfig {
|
||||
fn init<B: Backend>(&self) -> PaddedConv2d<B> {
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> PaddedConv2d<B> {
|
||||
let calc_padding = |p_left, p_right| {
|
||||
let n = if p_left >= p_right {
|
||||
0
|
||||
@@ -320,86 +351,106 @@ impl PaddedConv2dConfig {
|
||||
let conv = Conv2dConfig::new(self.channels, [self.kernel_size, self.kernel_size])
|
||||
.with_stride([self.stride, self.stride])
|
||||
.with_padding(PaddingConfig2d::Explicit(pad_vertical, pad_horizontal))
|
||||
.init();
|
||||
.init(device);
|
||||
|
||||
let kernel_size = self.kernel_size;
|
||||
let stride = self.stride;
|
||||
|
||||
let padding = self.padding;
|
||||
let padding = Padding {
|
||||
pad_left: self.padding.pad_left,
|
||||
pad_right: self.padding.pad_right,
|
||||
pad_top: self.padding.pad_top,
|
||||
pad_bottom: self.padding.pad_bottom,
|
||||
};
|
||||
|
||||
PaddedConv2d {
|
||||
conv,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
padding_actual,
|
||||
conv,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
padding_actual,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn div_roundup(x: usize, y: usize) -> usize {
|
||||
(x + y - 1) / y
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct PaddedConv2d<B: Backend> {
|
||||
conv: Conv2d<B>,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: Padding,
|
||||
padding_actual: [usize; 2],
|
||||
conv: Conv2d<B>,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: Padding,
|
||||
padding_actual: [usize; 2],
|
||||
}
|
||||
|
||||
impl<B: Backend> PaddedConv2d<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
println!("{} {} {:?} {:?}", self.kernel_size, self.stride, self.padding, self.padding_actual);
|
||||
let [n_batch, n_channel, height, width] = x.dims();
|
||||
|
||||
let desired_height = (self.padding.pad_top + self.padding.pad_bottom + height - self.kernel_size) / self.stride + 1;
|
||||
let desired_width = (self.padding.pad_left + self.padding.pad_right + width - self.kernel_size) / self.stride + 1;
|
||||
let desired_height = (self.padding.pad_top + self.padding.pad_bottom + height
|
||||
- self.kernel_size)
|
||||
/ self.stride
|
||||
+ 1;
|
||||
let desired_width = (self.padding.pad_left + self.padding.pad_right + width
|
||||
- self.kernel_size)
|
||||
/ self.stride
|
||||
+ 1;
|
||||
|
||||
let skip_vert = (self.padding_actual[0] - self.padding.pad_top) / self.stride;
|
||||
let skip_hor = (self.padding_actual[1] - self.padding.pad_left) / self.stride;
|
||||
|
||||
self.conv
|
||||
.forward(x)
|
||||
.slice([
|
||||
0..n_batch,
|
||||
0..n_channel,
|
||||
skip_vert..(skip_vert + desired_height),
|
||||
skip_hor..(skip_hor + desired_width)
|
||||
])
|
||||
self.conv.forward(x).slice([
|
||||
0..n_batch,
|
||||
0..n_channel,
|
||||
skip_vert..(skip_vert + desired_height),
|
||||
skip_hor..(skip_hor + desired_width),
|
||||
])
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config, Module, Copy, Debug)]
|
||||
pub struct Padding {
|
||||
pad_left: usize,
|
||||
pad_right: usize,
|
||||
pad_top: usize,
|
||||
#[derive(Config, Debug)]
|
||||
pub struct PaddingCfg {
|
||||
pad_left: usize,
|
||||
pad_right: usize,
|
||||
pad_top: usize,
|
||||
pad_bottom: usize,
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Module, Clone, Debug)]
|
||||
pub struct Padding {
|
||||
pad_left: usize,
|
||||
pad_right: usize,
|
||||
pad_top: usize,
|
||||
pad_bottom: usize,
|
||||
}
|
||||
|
||||
#[derive(Config, Debug)]
|
||||
pub struct MidConfig {
|
||||
n_channel: usize,
|
||||
n_channel: usize,
|
||||
}
|
||||
|
||||
impl MidConfig {
|
||||
fn init<B: Backend>(&self) -> Mid<B> {
|
||||
let block_1 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init();
|
||||
let attn = ConvSelfAttentionBlockConfig::new(self.n_channel).init();
|
||||
let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> Mid<B> {
|
||||
let block_1 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(device);
|
||||
let attn = ConvSelfAttentionBlockConfig::new(self.n_channel).init(device);
|
||||
let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(device);
|
||||
|
||||
Mid {
|
||||
block_1,
|
||||
attn,
|
||||
block_2,
|
||||
block_1,
|
||||
attn,
|
||||
block_2,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Mid<B: Backend> {
|
||||
block_1: ResnetBlock<B>,
|
||||
attn: ConvSelfAttentionBlock<B>,
|
||||
block_2: ResnetBlock<B>,
|
||||
block_1: ResnetBlock<B>,
|
||||
attn: ConvSelfAttentionBlock<B>,
|
||||
block_2: ResnetBlock<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Mid<B> {
|
||||
@@ -411,21 +462,24 @@ impl<B: Backend> Mid<B> {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct ResnetBlockConfig {
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
}
|
||||
|
||||
impl ResnetBlockConfig {
|
||||
fn init<B: Backend>(&self) -> ResnetBlock<B> {
|
||||
let norm1 = GroupNormConfig::new(32, self.in_channels).init();
|
||||
let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
let norm2 = GroupNormConfig::new(32, self.out_channels).init();
|
||||
let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> ResnetBlock<B> {
|
||||
let norm1 = GroupNormConfig::new(32, self.in_channels).init(device);
|
||||
let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init(device);
|
||||
let norm2 = GroupNormConfig::new(32, self.out_channels).init(device);
|
||||
let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init(device);
|
||||
let nin_shortcut = if self.in_channels != self.out_channels {
|
||||
Some( Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init() )
|
||||
Some(Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init(device))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -434,34 +488,37 @@ impl ResnetBlockConfig {
|
||||
let silu2 = SILU::new();
|
||||
|
||||
ResnetBlock {
|
||||
norm1,
|
||||
silu1,
|
||||
conv1,
|
||||
norm2,
|
||||
silu2,
|
||||
conv2,
|
||||
nin_shortcut,
|
||||
norm1,
|
||||
silu1,
|
||||
conv1,
|
||||
norm2,
|
||||
silu2,
|
||||
conv2,
|
||||
nin_shortcut,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ResnetBlock<B: Backend> {
|
||||
norm1: GroupNorm<B>,
|
||||
silu1: SILU,
|
||||
conv1: Conv2d<B>,
|
||||
norm2: GroupNorm<B>,
|
||||
silu2: SILU,
|
||||
conv2: Conv2d<B>,
|
||||
nin_shortcut: Option<Conv2d<B>>,
|
||||
norm1: GroupNorm<B>,
|
||||
silu1: SILU,
|
||||
conv1: Conv2d<B>,
|
||||
norm2: GroupNorm<B>,
|
||||
silu2: SILU,
|
||||
conv2: Conv2d<B>,
|
||||
nin_shortcut: Option<Conv2d<B>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ResnetBlock<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let h = self.conv1.forward( self.silu1.forward(self.norm1.forward(x.clone())) );
|
||||
let h = self.conv2.forward( self.silu2.forward(self.norm2.forward(h)) );
|
||||
let h = self
|
||||
.conv1
|
||||
.forward(self.silu1.forward(self.norm1.forward(x.clone())));
|
||||
let h = self
|
||||
.conv2
|
||||
.forward(self.silu2.forward(self.norm2.forward(h)));
|
||||
|
||||
|
||||
if let Some(ns) = self.nin_shortcut.as_ref() {
|
||||
ns.forward(x) + h
|
||||
} else {
|
||||
@@ -470,36 +527,36 @@ impl<B: Backend> ResnetBlock<B> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct ConvSelfAttentionBlockConfig {
|
||||
n_channel: usize,
|
||||
n_channel: usize,
|
||||
}
|
||||
|
||||
impl ConvSelfAttentionBlockConfig {
|
||||
fn init<B: Backend>(&self) -> ConvSelfAttentionBlock<B> {
|
||||
let norm = GroupNormConfig::new(32, self.n_channel).init();
|
||||
let q = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
|
||||
let k = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
|
||||
let v = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
|
||||
let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> ConvSelfAttentionBlock<B> {
|
||||
let norm = GroupNormConfig::new(32, self.n_channel).init(device);
|
||||
let q = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
|
||||
let k = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
|
||||
let v = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
|
||||
let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
|
||||
|
||||
ConvSelfAttentionBlock {
|
||||
norm,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
proj_out,
|
||||
norm,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
proj_out,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ConvSelfAttentionBlock<B: Backend> {
|
||||
norm: GroupNorm<B>,
|
||||
q: Conv2d<B>,
|
||||
k: Conv2d<B>,
|
||||
v: Conv2d<B>,
|
||||
proj_out: Conv2d<B>,
|
||||
norm: GroupNorm<B>,
|
||||
q: Conv2d<B>,
|
||||
k: Conv2d<B>,
|
||||
v: Conv2d<B>,
|
||||
proj_out: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ConvSelfAttentionBlock<B> {
|
||||
@@ -508,13 +565,41 @@ impl<B: Backend> ConvSelfAttentionBlock<B> {
|
||||
|
||||
let h = self.norm.forward(x.clone());
|
||||
|
||||
let q = self.q.forward(h.clone()).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2);
|
||||
let k = self.k.forward(h.clone()).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2);
|
||||
let v = self.v.forward(h).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2);
|
||||
let q = self
|
||||
.q
|
||||
.forward(h.clone())
|
||||
.reshape([n_batch, n_channel, height * width])
|
||||
.swap_dims(1, 2);
|
||||
let k = self
|
||||
.k
|
||||
.forward(h.clone())
|
||||
.reshape([n_batch, n_channel, height * width])
|
||||
.swap_dims(1, 2);
|
||||
let v = self
|
||||
.v
|
||||
.forward(h)
|
||||
.reshape([n_batch, n_channel, height * width])
|
||||
.swap_dims(1, 2);
|
||||
|
||||
let wv = qkv_attention(q, k, v, None, 1)
|
||||
.swap_dims(1, 2)
|
||||
.reshape([n_batch, n_channel, height, width]);
|
||||
/*let wv = Tensor::from_primitive(B::qkv_attention(
|
||||
q.into_primitive(),
|
||||
k.into_primitive(),
|
||||
v.into_primitive(),
|
||||
None,
|
||||
1,
|
||||
))
|
||||
.swap_dims(1, 2)
|
||||
.reshape([n_batch, n_channel, height, width]);*/
|
||||
|
||||
let wv = qkv_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
None,
|
||||
1,
|
||||
)
|
||||
.swap_dims(1, 2)
|
||||
.reshape([n_batch, n_channel, height, width]);
|
||||
|
||||
let projected = self.proj_out.forward(wv);
|
||||
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
use std::error::Error;
|
||||
use burn::tensor::ElementConversion;
|
||||
use std::error::Error;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
@@ -28,7 +25,10 @@ pub fn load_mlp<B: Backend>(path: &str, device: &B::Device) -> Result<MLP<B>, Bo
|
||||
Ok(mlp)
|
||||
}
|
||||
|
||||
pub fn load_multi_head_self_attention<B: Backend>(path: &str, device: &B::Device) -> Result<MultiHeadSelfAttention<B>, Box<dyn Error>> {
|
||||
pub fn load_multi_head_self_attention<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<MultiHeadSelfAttention<B>, Box<dyn Error>> {
|
||||
let n_head = load_usize::<B>("n_head", path, device)?;
|
||||
let query = load_linear(&format!("{}/{}", path, "query"), device)?;
|
||||
let key = load_linear(&format!("{}/{}", path, "key"), device)?;
|
||||
@@ -46,7 +46,10 @@ pub fn load_multi_head_self_attention<B: Backend>(path: &str, device: &B::Device
|
||||
Ok(mhsa)
|
||||
}
|
||||
|
||||
pub fn load_residual_decoder_attention_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResidualDecoderAttentionBlock<B>, Box<dyn Error>> {
|
||||
pub fn load_residual_decoder_attention_block<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<ResidualDecoderAttentionBlock<B>, Box<dyn Error>> {
|
||||
let mlp = load_mlp(&format!("{}/{}", path, "mlp"), device)?;
|
||||
let attn = load_multi_head_self_attention(&format!("{}/{}", path, "attn"), device)?;
|
||||
let attn_ln = load_layer_norm(&format!("{}/{}", path, "attn_ln"), device)?;
|
||||
@@ -64,15 +67,17 @@ pub fn load_residual_decoder_attention_block<B: Backend>(path: &str, device: &B:
|
||||
|
||||
pub fn load_clip<B: Backend>(path: &str, device: &B::Device) -> Result<CLIP<B>, Box<dyn Error>> {
|
||||
let token_embedding = load_embedding(&format!("{}/{}", path, "token_embedding"), device)?;
|
||||
let position_embedding = load_tensor("weight", &format!("{}/position_embedding", path), device)?.into();
|
||||
let position_embedding =
|
||||
Param::from_tensor(load_tensor("weight", &format!("{}/position_embedding", path), device)?);
|
||||
|
||||
let n_layer = load_usize::<B>("n_layer", path, device)?;
|
||||
let mut blocks = (0..n_layer)
|
||||
.into_iter()
|
||||
.map(|i| {
|
||||
load_residual_decoder_attention_block::<B>(&format!("{}/blocks/{}", path, i), device)
|
||||
}).collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let layer_norm = load_layer_norm(&format!("{}/{}", path, "layer_norm"), device)?;
|
||||
|
||||
let clip = CLIP {
|
||||
@@ -81,6 +86,6 @@ pub fn load_clip<B: Backend>(path: &str, device: &B::Device) -> Result<CLIP<B>,
|
||||
blocks: blocks,
|
||||
layer_norm: layer_norm,
|
||||
};
|
||||
|
||||
|
||||
Ok(clip)
|
||||
}
|
||||
|
||||
@@ -1,69 +1,71 @@
|
||||
pub mod load;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
tensor::{
|
||||
activation::{sigmoid, softmax},
|
||||
backend::Backend,
|
||||
activation::{softmax, sigmoid},
|
||||
module::embedding,
|
||||
Tensor,
|
||||
Distribution,
|
||||
Int,
|
||||
module::embedding,
|
||||
Distribution, Int, Tensor,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::model::attention::{qkv_attention, attn_decoder_mask};
|
||||
//use crate::backend::Backend as MyBackend;
|
||||
use crate::backend::{qkv_attention, attn_decoder_mask};
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct CLIPConfig {
|
||||
n_vocab: usize,
|
||||
n_state: usize,
|
||||
n_head: usize,
|
||||
n_ctx: usize,
|
||||
n_layer: usize,
|
||||
n_vocab: usize,
|
||||
n_state: usize,
|
||||
n_head: usize,
|
||||
n_ctx: usize,
|
||||
n_layer: usize,
|
||||
}
|
||||
|
||||
impl CLIPConfig {
|
||||
pub fn init<B: Backend>(&self) -> CLIP<B> {
|
||||
let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init();
|
||||
let position_embedding = Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0)).into();
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> CLIP<B> {
|
||||
let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init(device);
|
||||
let position_embedding =
|
||||
Param::from_tensor(Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0), device));
|
||||
let blocks = (0..self.n_layer)
|
||||
.into_iter()
|
||||
.map(|_| ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head).init())
|
||||
.map(|_| ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head).init(device))
|
||||
.collect();
|
||||
let layer_norm = nn::LayerNormConfig::new(self.n_state).init();
|
||||
let layer_norm = nn::LayerNormConfig::new(self.n_state).init(device);
|
||||
|
||||
CLIP {
|
||||
token_embedding,
|
||||
position_embedding,
|
||||
blocks,
|
||||
layer_norm,
|
||||
token_embedding,
|
||||
position_embedding,
|
||||
blocks,
|
||||
layer_norm,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct CLIP<B: Backend> {
|
||||
token_embedding: nn::Embedding<B>,
|
||||
position_embedding: Param<Tensor<B, 2>>,
|
||||
blocks: Vec<ResidualDecoderAttentionBlock<B>>,
|
||||
layer_norm: nn::LayerNorm<B>,
|
||||
token_embedding: nn::Embedding<B>,
|
||||
position_embedding: Param<Tensor<B, 2>>,
|
||||
blocks: Vec<ResidualDecoderAttentionBlock<B>>,
|
||||
layer_norm: nn::LayerNorm<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> CLIP<B> {
|
||||
pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
|
||||
let [n_batch, seq_len] = x.dims();
|
||||
|
||||
|
||||
//let mask = Tensor::from_primitive(B::attn_decoder_mask(seq_len, &x.device()));
|
||||
let mask = attn_decoder_mask(seq_len, &x.device());
|
||||
|
||||
let embedded = self.token_embedding.forward(x)
|
||||
+ self.position_embedding.val().slice([0..seq_len]).unsqueeze();
|
||||
|
||||
let embedded = self.token_embedding.forward(x)
|
||||
+ self
|
||||
.position_embedding
|
||||
.val()
|
||||
.slice([0..seq_len])
|
||||
.unsqueeze();
|
||||
|
||||
let mut x = embedded;
|
||||
for block in &self.blocks {
|
||||
x = block.forward(x, mask.clone());
|
||||
@@ -73,37 +75,35 @@ impl<B: Backend> CLIP<B> {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct ResidualDecoderAttentionBlockConfig {
|
||||
n_state: usize,
|
||||
n_head: usize,
|
||||
n_state: usize,
|
||||
n_head: usize,
|
||||
}
|
||||
|
||||
impl ResidualDecoderAttentionBlockConfig {
|
||||
pub fn init<B: Backend>(&self) -> ResidualDecoderAttentionBlock<B> {
|
||||
let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init();
|
||||
let attn_ln = nn::LayerNormConfig::new(self.n_state).init();
|
||||
|
||||
let mlp = MLPConfig::new(self.n_state, 4 * self.n_state).init();
|
||||
let mlp_ln = nn::LayerNormConfig::new(self.n_state).init();
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> ResidualDecoderAttentionBlock<B> {
|
||||
let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init(device);
|
||||
let attn_ln = nn::LayerNormConfig::new(self.n_state).init(device);
|
||||
|
||||
let mlp = MLPConfig::new(self.n_state, 4 * self.n_state).init(device);
|
||||
let mlp_ln = nn::LayerNormConfig::new(self.n_state).init(device);
|
||||
|
||||
ResidualDecoderAttentionBlock {
|
||||
attn,
|
||||
attn_ln,
|
||||
mlp,
|
||||
mlp_ln,
|
||||
attn,
|
||||
attn_ln,
|
||||
mlp,
|
||||
mlp_ln,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ResidualDecoderAttentionBlock<B: Backend> {
|
||||
attn: MultiHeadSelfAttention<B>,
|
||||
attn_ln: nn::LayerNorm<B>,
|
||||
mlp: MLP<B>,
|
||||
mlp_ln: nn::LayerNorm<B>,
|
||||
attn: MultiHeadSelfAttention<B>,
|
||||
attn_ln: nn::LayerNorm<B>,
|
||||
mlp: MLP<B>,
|
||||
mlp_ln: nn::LayerNorm<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ResidualDecoderAttentionBlock<B> {
|
||||
@@ -114,39 +114,44 @@ impl<B: Backend> ResidualDecoderAttentionBlock<B> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct MultiHeadSelfAttentionConfig {
|
||||
n_state: usize,
|
||||
n_head: usize,
|
||||
n_head: usize,
|
||||
}
|
||||
|
||||
impl MultiHeadSelfAttentionConfig {
|
||||
fn init<B: Backend>(&self) -> MultiHeadSelfAttention<B> {
|
||||
assert!(self.n_state % self.n_head == 0, "State size {} must be a multiple of head size {}", self.n_state, self.n_head);
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadSelfAttention<B> {
|
||||
assert!(
|
||||
self.n_state % self.n_head == 0,
|
||||
"State size {} must be a multiple of head size {}",
|
||||
self.n_state,
|
||||
self.n_head
|
||||
);
|
||||
|
||||
let n_head = self.n_head;
|
||||
let query = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
||||
let key = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
||||
let value = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
||||
let out = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
||||
let query = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
|
||||
let key = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
|
||||
let value = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
|
||||
let out = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
|
||||
|
||||
MultiHeadSelfAttention {
|
||||
n_head,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out
|
||||
MultiHeadSelfAttention {
|
||||
n_head,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct MultiHeadSelfAttention<B: Backend> {
|
||||
n_head: usize,
|
||||
query: nn::Linear<B>,
|
||||
key: nn::Linear<B>,
|
||||
value: nn::Linear<B>,
|
||||
out: nn::Linear<B>,
|
||||
n_head: usize,
|
||||
query: nn::Linear<B>,
|
||||
key: nn::Linear<B>,
|
||||
value: nn::Linear<B>,
|
||||
out: nn::Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> MultiHeadSelfAttention<B> {
|
||||
@@ -155,44 +160,47 @@ impl<B: Backend> MultiHeadSelfAttention<B> {
|
||||
let k = self.key.forward(x.clone());
|
||||
let v = self.value.forward(x);
|
||||
|
||||
let wv = qkv_attention(q, k, v, mask, self.n_head);
|
||||
/*let wv = Tensor::from_primitive(B::qkv_attention(
|
||||
q.into_primitive(),
|
||||
k.into_primitive(),
|
||||
v.into_primitive(),
|
||||
mask.map(|m| m.into_primitive()),
|
||||
self.n_head,
|
||||
));*/
|
||||
|
||||
let wv = qkv_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
mask,
|
||||
self.n_head,
|
||||
);
|
||||
|
||||
return self.out.forward(wv);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
#[derive(Config, Debug)]
|
||||
pub struct MLPConfig {
|
||||
input_size: usize,
|
||||
hidden_size: usize,
|
||||
input_size: usize,
|
||||
hidden_size: usize,
|
||||
}
|
||||
|
||||
impl MLPConfig {
|
||||
fn init<B: Backend>(&self) -> MLP<B> {
|
||||
let fc1 = nn::LinearConfig::new(self.input_size, self.hidden_size).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> MLP<B> {
|
||||
let fc1 = nn::LinearConfig::new(self.input_size, self.hidden_size).init(device);
|
||||
let gelu = QuickGELU::new();
|
||||
let fc2 = nn::LinearConfig::new(self.hidden_size, self.input_size).init();
|
||||
let fc2 = nn::LinearConfig::new(self.hidden_size, self.input_size).init(device);
|
||||
|
||||
MLP {
|
||||
fc1,
|
||||
gelu,
|
||||
fc2,
|
||||
}
|
||||
MLP { fc1, gelu, fc2 }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct MLP<B: Backend> {
|
||||
fc1: nn::Linear<B>,
|
||||
gelu: QuickGELU,
|
||||
fc2: nn::Linear<B>,
|
||||
fc1: nn::Linear<B>,
|
||||
gelu: QuickGELU,
|
||||
fc2: nn::Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> MLP<B> {
|
||||
@@ -217,4 +225,3 @@ impl QuickGELU {
|
||||
x.clone() * sigmoid(x * 1.702)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,30 +4,34 @@ use crate::model::load::*;
|
||||
use std::error::Error;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
pub fn load_group_norm<B: Backend>(path: &str, device: &B::Device) -> Result<GroupNorm<B>, Box<dyn Error>> {
|
||||
pub fn load_group_norm<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<GroupNorm<B>, Box<dyn Error>> {
|
||||
let n_group = load_usize::<B>("n_group", path, device)?.into();
|
||||
let n_channel = load_usize::<B>("n_channel", path, device)?.into();
|
||||
let eps = load_f32::<B>("eps", path, device)?.into();
|
||||
|
||||
let gamma = load_tensor::<B, 1>("weight", path, device).ok().unwrap_or_else(|| Tensor::ones_device([n_channel], device)).into();
|
||||
let beta = load_tensor::<B, 1>("bias", path, device).ok().unwrap_or_else(|| Tensor::zeros_device([n_channel], device)).into();
|
||||
let gamma = Param::from_tensor(load_tensor::<B, 1>("weight", path, device)
|
||||
.ok()
|
||||
.unwrap_or_else(|| Tensor::ones([n_channel], device))
|
||||
);
|
||||
let beta = Param::from_tensor(load_tensor::<B, 1>("bias", path, device)
|
||||
.ok()
|
||||
.unwrap_or_else(|| Tensor::zeros([n_channel], device))
|
||||
);
|
||||
|
||||
Ok(
|
||||
GroupNorm {
|
||||
n_group,
|
||||
n_channel,
|
||||
gamma,
|
||||
beta,
|
||||
eps,
|
||||
}
|
||||
)
|
||||
}
|
||||
Ok(GroupNorm {
|
||||
n_group,
|
||||
n_channel,
|
||||
gamma,
|
||||
beta,
|
||||
eps,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,50 +1,52 @@
|
||||
pub mod load;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct GroupNormConfig {
|
||||
n_group: usize,
|
||||
n_channel: usize,
|
||||
n_group: usize,
|
||||
n_channel: usize,
|
||||
#[config(default = 1e-5)]
|
||||
eps: f64,
|
||||
eps: f64,
|
||||
}
|
||||
|
||||
impl GroupNormConfig {
|
||||
pub fn init<B: Backend>(&self) -> GroupNorm<B> {
|
||||
assert!(self.n_channel % self.n_group == 0, "The number of channels {} must be divisible by the number of groups {}", self.n_channel, self.n_group);
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> GroupNorm<B> {
|
||||
assert!(
|
||||
self.n_channel % self.n_group == 0,
|
||||
"The number of channels {} must be divisible by the number of groups {}",
|
||||
self.n_channel,
|
||||
self.n_group
|
||||
);
|
||||
|
||||
let n_per_group = self.n_channel / self.n_group;
|
||||
|
||||
let gamma = Tensor::ones([self.n_channel]).into();
|
||||
let beta = Tensor::zeros([self.n_channel]).into();
|
||||
let gamma = Param::from_tensor(Tensor::ones([self.n_channel], device));
|
||||
let beta = Param::from_tensor(Tensor::zeros([self.n_channel], device));
|
||||
|
||||
let eps = self.eps;
|
||||
|
||||
GroupNorm {
|
||||
n_group: self.n_group,
|
||||
n_channel: self.n_channel,
|
||||
gamma,
|
||||
beta,
|
||||
eps,
|
||||
n_group: self.n_group,
|
||||
n_channel: self.n_channel,
|
||||
gamma,
|
||||
beta,
|
||||
eps,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct GroupNorm<B: Backend> {
|
||||
n_group: usize,
|
||||
n_channel: usize,
|
||||
gamma: Param<Tensor<B, 1>>,
|
||||
beta: Param<Tensor<B, 1>>,
|
||||
eps: f64,
|
||||
n_group: usize,
|
||||
n_channel: usize,
|
||||
gamma: Param<Tensor<B, 1>>,
|
||||
beta: Param<Tensor<B, 1>>,
|
||||
eps: f64,
|
||||
}
|
||||
|
||||
impl<B: Backend> GroupNorm<B> {
|
||||
@@ -56,10 +58,17 @@ impl<B: Backend> GroupNorm<B> {
|
||||
let mut affine_shape = [1; D];
|
||||
affine_shape[1] = self.n_channel;
|
||||
|
||||
layernorm( x.reshape([n_batch, self.n_group, num_elements / (n_batch * self.n_group) ]), self.eps )
|
||||
.reshape(shape)
|
||||
.mul(self.gamma.val().reshape(affine_shape))
|
||||
.add(self.beta.val().reshape(affine_shape))
|
||||
layernorm(
|
||||
x.reshape([
|
||||
n_batch,
|
||||
self.n_group,
|
||||
num_elements / (n_batch * self.n_group),
|
||||
]),
|
||||
self.eps,
|
||||
)
|
||||
.reshape(shape)
|
||||
.mul(self.gamma.val().reshape(affine_shape))
|
||||
.add(self.beta.val().reshape(affine_shape))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,5 +77,6 @@ pub fn layernorm<B: Backend, const D: usize>(x: Tensor<B, D>, eps: f64) -> Tenso
|
||||
//x.sub(mean).div(var.sqrt().add_scalar(eps))
|
||||
|
||||
let u = x.clone() - x.mean_dim(D - 1);
|
||||
u.clone().div( (u.clone() * u).mean_dim(D - 1).add_scalar(eps).sqrt() )
|
||||
}
|
||||
u.clone()
|
||||
.div((u.clone() * u).mean_dim(D - 1).add_scalar(eps).sqrt())
|
||||
}
|
||||
|
||||
@@ -1,36 +1,41 @@
|
||||
use std::error::Error;
|
||||
use std::io::Read;
|
||||
use npy::{self, NpyData};
|
||||
use num_traits::cast::ToPrimitive;
|
||||
use burn::tensor::cast::ToElement;
|
||||
use burn::prelude::TensorData;
|
||||
use std::error::Error;
|
||||
use std::io::Read;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn::{self, conv},
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
Data,
|
||||
},
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
use burn::tensor::ElementConversion;
|
||||
|
||||
pub fn numpy_to_tensor<B: Backend, const D: usize>(numpy_data: NpyData<f32>, device: &B::Device) -> Tensor<B, D> {
|
||||
pub fn numpy_to_tensor<B: Backend, const D: usize>(
|
||||
numpy_data: NpyData<f32>,
|
||||
device: &B::Device,
|
||||
) -> Tensor<B, D> {
|
||||
let mut v = numpy_data.to_vec();
|
||||
|
||||
let shape: Vec<_> = v[0..D].into_iter().map(|&v| v as usize).collect();
|
||||
let data: Vec<B::FloatElem> = v[D..].into_iter().map(|e| e.elem()).collect();
|
||||
|
||||
Tensor::from_data_device(Data::new(data, shape.into()), device)
|
||||
|
||||
//Tensor::from_data_device(Data::new(data, shape.into()), device)
|
||||
Tensor::from_data(TensorData::new(data, shape), device)
|
||||
}
|
||||
|
||||
pub fn load_tensor<B: Backend, const D: usize>(name: &str, path: &str, device: &B::Device) -> Result<Tensor<B, D>, Box<dyn Error>> {
|
||||
pub fn load_tensor<B: Backend, const D: usize>(
|
||||
name: &str,
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<Tensor<B, D>, Box<dyn Error>> {
|
||||
let tensor_path = format!("{}/{}.npy", path, name);
|
||||
|
||||
let mut buf = vec![];
|
||||
std::fs::File::open(&tensor_path)?
|
||||
.read_to_end(&mut buf)?;
|
||||
std::fs::File::open(&tensor_path)?.read_to_end(&mut buf)?;
|
||||
|
||||
let tensor_numpy: NpyData<f32> = NpyData::from_bytes(&buf)?;
|
||||
|
||||
@@ -41,71 +46,79 @@ pub fn load_tensor<B: Backend, const D: usize>(name: &str, path: &str, device: &
|
||||
Ok(tensor)
|
||||
}
|
||||
|
||||
pub fn load_f32<B: Backend>(name: &str, path: &str, device: &B::Device) -> Result<f32, Box<dyn Error>> {
|
||||
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_f32().unwrap())
|
||||
pub fn load_f32<B: Backend>(
|
||||
name: &str,
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<f32, Box<dyn Error>> {
|
||||
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_f32())
|
||||
}
|
||||
|
||||
pub fn load_usize<B: Backend>(name: &str, path: &str, device: &B::Device) -> Result<usize, Box<dyn Error>> {
|
||||
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_usize().unwrap())
|
||||
pub fn load_usize<B: Backend>(
|
||||
name: &str,
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<usize, Box<dyn Error>> {
|
||||
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_usize())
|
||||
}
|
||||
|
||||
pub fn load_linear<B: Backend>(path: &str, device: &B::Device) -> Result<nn::Linear<B>, Box<dyn Error>> {
|
||||
pub fn load_linear<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<nn::Linear<B>, Box<dyn Error>> {
|
||||
let weight = load_tensor::<B, 2>("weight", path, device)?;
|
||||
let bias = load_tensor::<B, 1>("bias", path, device).ok();
|
||||
|
||||
let record = nn::LinearRecord {
|
||||
weight: weight.into(),
|
||||
bias: bias.map(|t| t.into()),
|
||||
};
|
||||
|
||||
let linear: nn::Linear<B> = nn::LinearConfig::new(3, 3).init_with(record);
|
||||
Ok(linear)
|
||||
Ok(nn::Linear {
|
||||
weight: Param::from_tensor(weight),
|
||||
bias: bias.map(|t| Param::from_tensor(t)),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn load_embedding<B: Backend>(path: &str, device: &B::Device) -> Result<nn::Embedding<B>, Box<dyn Error>> {
|
||||
pub fn load_embedding<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<nn::Embedding<B>, Box<dyn Error>> {
|
||||
let weight = load_tensor::<B, 2>("weight", path, device)?;
|
||||
let [n_vocab, n_state] = weight.dims();
|
||||
|
||||
let record = nn::EmbeddingRecord {
|
||||
weight: weight.into(),
|
||||
};
|
||||
|
||||
let embedding = nn::EmbeddingConfig::new(n_vocab, n_state).init_with(record);
|
||||
Ok(embedding)
|
||||
Ok(nn::Embedding {
|
||||
weight: Param::from_tensor(weight),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn load_layer_norm<B: Backend>(path: &str, device: &B::Device) -> Result<nn::LayerNorm<B>, Box<dyn Error>> {
|
||||
pub fn load_layer_norm<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<nn::LayerNorm<B>, Box<dyn Error>> {
|
||||
let weight = load_tensor::<B, 1>("weight", path, device)?;
|
||||
let bias = load_tensor::<B, 1>("bias", path, device)?;
|
||||
let eps = load_f32::<B>("eps", path, device)? as f64;
|
||||
|
||||
let [n_state] = weight.dims();
|
||||
|
||||
let record = nn::LayerNormRecord {
|
||||
gamma: weight.into(),
|
||||
beta: bias.into(),
|
||||
epsilon: <f64 as Module<B>>::into_record(eps),
|
||||
};
|
||||
|
||||
let layer_norm: nn::LayerNorm<B> = nn::LayerNormConfig::new(n_state).init_with(record);
|
||||
let mut layer_norm = nn::LayerNormConfig::new(n_state).with_epsilon(eps).init(device);
|
||||
layer_norm.gamma = Param::from_tensor(weight);
|
||||
layer_norm.beta = Some(Param::from_tensor(bias));
|
||||
|
||||
Ok(layer_norm)
|
||||
}
|
||||
|
||||
|
||||
/*pub fn load_rmsnorm<B: Backend>(path: &str, device: &B::Device) -> Result<RMSNorm<B>, Box<dyn Error>> {
|
||||
let weight = load_tensor::<B, 1>("weight", path, device)?;
|
||||
let eps = load_f32::<B>("eps", path, device)?.into();
|
||||
|
||||
let rmsnorm = RMSNorm {
|
||||
weight: weight.into(),
|
||||
let rmsnorm = RMSNorm {
|
||||
weight: Param::from_tensor(weight),
|
||||
eps: eps
|
||||
};
|
||||
|
||||
|
||||
Ok(rmsnorm)
|
||||
}*/
|
||||
|
||||
pub fn load_conv2d<B: Backend>(path: &str, device: &B::Device) -> Result<conv::Conv2d<B>, Box<dyn Error>> {
|
||||
pub fn load_conv2d<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<conv::Conv2d<B>, Box<dyn Error>> {
|
||||
let weight = load_tensor::<B, 4>("weight", path, device)?;
|
||||
let bias = load_tensor::<B, 1>("bias", path, device).ok();
|
||||
let has_bias = bias.is_some();
|
||||
@@ -127,41 +140,39 @@ pub fn load_conv2d<B: Backend>(path: &str, device: &B::Device) -> Result<conv::C
|
||||
let padding = tensor_to_array_2(padding);
|
||||
let padding = nn::PaddingConfig2d::Explicit(padding[0], padding[1]);
|
||||
|
||||
|
||||
let record = conv::Conv2dRecord {
|
||||
weight: weight.into(),
|
||||
bias: bias.map(|t| t.into()),
|
||||
stride: <[usize; 2] as Module<B>>::into_record(stride),
|
||||
kernel_size: <[usize; 2] as Module<B>>::into_record(kernel_size),
|
||||
dilation: <[usize; 2] as Module<B>>::into_record(dilation),
|
||||
groups: <usize as Module<B>>::into_record(n_group),
|
||||
padding: <nn::PaddingConfig2d as Module<B>>::into_record(padding.clone()),
|
||||
};
|
||||
let mut conv2d = conv::Conv2dConfig::new([n_channels_in, n_channels_out], kernel_size)
|
||||
.with_stride(stride)
|
||||
.with_dilation(dilation)
|
||||
.with_groups(n_group)
|
||||
.with_padding(padding.clone())
|
||||
.with_bias(has_bias)
|
||||
.init(device);
|
||||
|
||||
let conv2d: conv::Conv2d<B> = conv::Conv2dConfig::new([n_channels_in, n_channels_out], kernel_size)
|
||||
.with_stride(stride)
|
||||
.with_dilation(dilation)
|
||||
.with_groups(n_group)
|
||||
.with_padding(padding)
|
||||
.with_bias(has_bias)
|
||||
.init_with(record);
|
||||
conv2d.weight = Param::from_tensor(weight);
|
||||
conv2d.bias = bias.map(|t| Param::from_tensor(t));
|
||||
conv2d.stride = stride;
|
||||
conv2d.kernel_size = kernel_size;
|
||||
conv2d.dilation = dilation;
|
||||
conv2d.groups = n_group;
|
||||
conv2d.padding = burn::module::Ignored(padding);
|
||||
|
||||
Ok(conv2d)
|
||||
}
|
||||
|
||||
pub fn tensor_to_array_2<B: Backend>(x: Tensor<B, 1>) -> [usize; 2] {
|
||||
let vec = x.into_data().value;
|
||||
let vec: Vec<<B as Backend>::FloatElem> = x.into_data().to_vec().unwrap();
|
||||
assert!(vec.len() == 2, "Tensor length must be 2.");
|
||||
[vec[0].to_usize().unwrap(), vec[1].to_usize().unwrap()]
|
||||
[vec[0].to_usize(), vec[1].to_usize()]
|
||||
}
|
||||
|
||||
pub fn tensor_to_array<const N: usize, B: Backend>(x: Tensor<B, 1>) -> [usize; N] {
|
||||
let vec = x.into_data().value;
|
||||
let vec: Vec<<B as Backend>::FloatElem> = x.into_data().to_vec().unwrap();
|
||||
assert!(vec.len() == N, "Tensor length must be {}.", N);
|
||||
|
||||
let mut arr = [0; N];
|
||||
for (a, t) in arr.iter_mut().zip(vec) {
|
||||
*a = t.to_usize().unwrap();
|
||||
*a = t.to_usize();
|
||||
}
|
||||
|
||||
arr
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
pub mod stablediffusion;
|
||||
|
||||
pub mod autoencoder;
|
||||
pub mod unet;
|
||||
pub mod clip;
|
||||
pub mod unet;
|
||||
|
||||
pub mod silu;
|
||||
pub mod groupnorm;
|
||||
pub mod attention;
|
||||
pub mod groupnorm;
|
||||
pub mod silu;
|
||||
|
||||
pub mod load;
|
||||
pub mod load;
|
||||
|
||||
@@ -1,13 +1,8 @@
|
||||
use burn::{
|
||||
module::Module,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
activation::sigmoid,
|
||||
Tensor,
|
||||
},
|
||||
tensor::{activation::sigmoid, backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
|
||||
#[derive(Module, Clone, Debug)]
|
||||
pub struct SILU {}
|
||||
|
||||
@@ -19,4 +14,4 @@ impl SILU {
|
||||
pub fn forward<B: Backend, const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
|
||||
x.clone() * sigmoid(x)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,32 +1,33 @@
|
||||
use std::error::Error;
|
||||
use burn::tensor::ElementConversion;
|
||||
use std::error::Error;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
use crate::model::{load::*, autoencoder::load::load_autoencoder, unet::load::load_unet, clip::load::load_clip};
|
||||
use crate::model::{
|
||||
autoencoder::load::load_autoencoder, clip::load::load_clip, load::*, unet::load::load_unet,
|
||||
};
|
||||
|
||||
pub fn load_stable_diffusion<B: Backend>(path: &str, device: &B::Device) -> Result<StableDiffusion<B>, Box<dyn Error>> {
|
||||
pub fn load_stable_diffusion<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<StableDiffusion<B>, Box<dyn Error>> {
|
||||
let n_steps = load_usize::<B>("n_steps", path, device)?;
|
||||
let alpha_cumulative_products = load_tensor::<B, 1>("alphas_cumprod", path, device)?.into();
|
||||
let alpha_cumulative_products = Param::from_tensor(load_tensor::<B, 1>("alphas_cumprod", path, device)?);
|
||||
let autoencoder = load_autoencoder(&format!("{}/{}", path, "autoencoder"), device)?;
|
||||
let diffusion = load_unet(&format!("{}/{}", path, "unet"), device)?;
|
||||
let clip = load_clip(&format!("{}/{}", path, "clip"), device)?;
|
||||
|
||||
Ok(StableDiffusion {
|
||||
n_steps,
|
||||
alpha_cumulative_products,
|
||||
autoencoder,
|
||||
diffusion,
|
||||
clip,
|
||||
n_steps,
|
||||
alpha_cumulative_products,
|
||||
autoencoder,
|
||||
diffusion,
|
||||
clip,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -1,64 +1,73 @@
|
||||
pub mod load;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
Int,
|
||||
Float,
|
||||
BasicOps,
|
||||
Data,
|
||||
Distribution,
|
||||
},
|
||||
tensor::{backend::Backend, BasicOps, Distribution, Float, Int, Tensor},
|
||||
tensor::cast::ToElement,
|
||||
};
|
||||
|
||||
use num_traits::ToPrimitive;
|
||||
|
||||
//use crate::backend::Backend as MyBackend;
|
||||
|
||||
use super::autoencoder::{Autoencoder, AutoencoderConfig};
|
||||
use super::clip::{CLIPConfig, CLIP};
|
||||
use super::unet::{UNet, UNetConfig};
|
||||
use super::clip::{CLIP, CLIPConfig};
|
||||
use crate::tokenizer::SimpleTokenizer;
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct StableDiffusionConfig {
|
||||
|
||||
}
|
||||
#[derive(Config, Debug)]
|
||||
pub struct StableDiffusionConfig {}
|
||||
|
||||
impl StableDiffusionConfig {
|
||||
pub fn init<B: Backend>(&self) -> StableDiffusion<B> {
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> StableDiffusion<B> {
|
||||
let n_steps = 1000;
|
||||
let alpha_cumulative_products = offset_cosine_schedule_cumprod::<B>(n_steps).into();
|
||||
let alpha_cumulative_products = Param::from_tensor(offset_cosine_schedule_cumprod::<B>(n_steps as i64, device));
|
||||
|
||||
let autoencoder = AutoencoderConfig::new().init();
|
||||
let diffusion = UNetConfig::new().init();
|
||||
let clip = CLIPConfig::new(49408, 768, 12, 77, 12).init();
|
||||
let autoencoder = AutoencoderConfig::new().init(device);
|
||||
let diffusion = UNetConfig::new().init(device);
|
||||
let clip = CLIPConfig::new(49408, 768, 12, 77, 12).init(device);
|
||||
|
||||
StableDiffusion {
|
||||
n_steps,
|
||||
alpha_cumulative_products,
|
||||
autoencoder,
|
||||
diffusion,
|
||||
clip,
|
||||
n_steps,
|
||||
alpha_cumulative_products,
|
||||
autoencoder,
|
||||
diffusion,
|
||||
clip,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct StableDiffusion<B: Backend> {
|
||||
n_steps: usize,
|
||||
alpha_cumulative_products: Param<Tensor<B, 1>>,
|
||||
autoencoder: Autoencoder<B>,
|
||||
diffusion: UNet<B>,
|
||||
clip: CLIP<B>,
|
||||
n_steps: usize,
|
||||
alpha_cumulative_products: Param<Tensor<B, 1>>,
|
||||
autoencoder: Autoencoder<B>,
|
||||
diffusion: UNet<B>,
|
||||
clip: CLIP<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> StableDiffusion<B> {
|
||||
pub fn sample_image(&self, context: Tensor<B, 3>, unconditional_context: Tensor<B, 2>, unconditional_guidance_scale: f64, n_steps: usize) -> Vec<Vec<u8>> {
|
||||
pub fn sample_image(
|
||||
&self,
|
||||
context: Tensor<B, 3>,
|
||||
unconditional_context: Tensor<B, 2>,
|
||||
unconditional_guidance_scale: f64,
|
||||
n_steps: usize,
|
||||
) -> Vec<Vec<u8>> {
|
||||
let [n_batch, _, _] = context.dims();
|
||||
|
||||
let latent = self.sample_latent(context, unconditional_context, unconditional_guidance_scale, n_steps);
|
||||
let latent = self.sample_latent(
|
||||
context,
|
||||
unconditional_context,
|
||||
unconditional_guidance_scale,
|
||||
n_steps,
|
||||
);
|
||||
self.latent_to_image(latent)
|
||||
}
|
||||
|
||||
pub fn latent_to_image(&self, latent: Tensor<B, 4>) -> Vec<Vec<u8>> {
|
||||
let [n_batch, _, _, _] = latent.dims();
|
||||
let image = self.autoencoder.decode_latent(latent * (1.0 / 0.18215));
|
||||
|
||||
let n_channel = 3;
|
||||
@@ -66,7 +75,7 @@ impl<B: Backend> StableDiffusion<B> {
|
||||
let width = 512;
|
||||
let num_elements_per_image = n_channel * height * width;
|
||||
|
||||
// correct size and scale and reorder to
|
||||
// correct size and scale and reorder to
|
||||
let image = (image + 1.0) / 2.0;
|
||||
let image = image
|
||||
.reshape([n_batch, n_channel, height, width])
|
||||
@@ -74,19 +83,29 @@ impl<B: Backend> StableDiffusion<B> {
|
||||
.swap_dims(2, 3)
|
||||
.mul_scalar(255.0);
|
||||
|
||||
let flattened: Vec<_> = image.
|
||||
into_data().
|
||||
value;
|
||||
let flattened: Vec<B::FloatElem> = image.into_data().to_vec().unwrap();
|
||||
|
||||
(0..n_batch).into_iter().map(|b| {
|
||||
let start = b * num_elements_per_image;
|
||||
let end = start + num_elements_per_image;
|
||||
(0..n_batch)
|
||||
.into_iter()
|
||||
.map(|b| {
|
||||
let start = b * num_elements_per_image;
|
||||
let end = start + num_elements_per_image;
|
||||
|
||||
flattened[start..end].into_iter().map(|v| v.to_f64().unwrap().min(255.0).max(0.0).to_u8().unwrap()).collect()
|
||||
}).collect()
|
||||
flattened[start..end]
|
||||
.into_iter()
|
||||
.map(|v| v.to_f64().min(255.0).max(0.0) as u8)
|
||||
.collect()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn sample_latent(&self, context: Tensor<B, 3>, unconditional_context: Tensor<B, 2>, unconditional_guidance_scale: f64, n_steps: usize) -> Tensor<B, 4> {
|
||||
pub fn sample_latent(
|
||||
&self,
|
||||
context: Tensor<B, 3>,
|
||||
unconditional_context: Tensor<B, 2>,
|
||||
unconditional_guidance_scale: f64,
|
||||
n_steps: usize,
|
||||
) -> Tensor<B, 4> {
|
||||
let device = context.device();
|
||||
|
||||
let step_size = self.n_steps / n_steps;
|
||||
@@ -94,7 +113,7 @@ impl<B: Backend> StableDiffusion<B> {
|
||||
let [n_batches, _, _] = context.dims();
|
||||
|
||||
let gen_noise = || {
|
||||
Tensor::random([n_batches, 4, 64, 64], Distribution::Normal(0.0, 1.0)).to_device(&device)
|
||||
Tensor::random([n_batches, 4, 64, 64], Distribution::Normal(0.0, 1.0), &device)
|
||||
};
|
||||
|
||||
let sigma = 0.0; // Use deterministic diffusion
|
||||
@@ -102,18 +121,34 @@ impl<B: Backend> StableDiffusion<B> {
|
||||
let mut latent = gen_noise();
|
||||
|
||||
for t in (0..self.n_steps).rev().step_by(step_size) {
|
||||
let current_alpha: f64 = self.alpha_cumulative_products.val().slice([t..t + 1]).into_scalar().to_f64().unwrap();
|
||||
let current_alpha: f64 = self
|
||||
.alpha_cumulative_products
|
||||
.val()
|
||||
.slice([t..t + 1])
|
||||
.into_scalar()
|
||||
.to_f64();
|
||||
|
||||
let prev_alpha: f64 = if t >= step_size {
|
||||
let i = t - step_size;
|
||||
self.alpha_cumulative_products.val().slice([i..i + 1]).into_scalar().to_f64().unwrap()
|
||||
self.alpha_cumulative_products
|
||||
.val()
|
||||
.slice([i..i + 1])
|
||||
.into_scalar()
|
||||
.to_f64()
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
|
||||
let sqrt_noise = (1.0 - current_alpha).sqrt();
|
||||
|
||||
let timestep = Tensor::from_ints([t as i32]).to_device(&device);
|
||||
let pred_noise = self.forward_diffuser(latent.clone(), timestep, context.clone(), unconditional_context.clone(), unconditional_guidance_scale);
|
||||
let timestep = Tensor::from_ints([t as i32], &device);
|
||||
let pred_noise = self.forward_diffuser(
|
||||
latent.clone(),
|
||||
timestep,
|
||||
context.clone(),
|
||||
unconditional_context.clone(),
|
||||
unconditional_guidance_scale,
|
||||
);
|
||||
let predx0 = (latent - pred_noise.clone() * sqrt_noise) / current_alpha.sqrt();
|
||||
let dir_latent = pred_noise * (1.0 - prev_alpha - sigma * sigma).sqrt();
|
||||
|
||||
@@ -124,68 +159,79 @@ impl<B: Backend> StableDiffusion<B> {
|
||||
latent
|
||||
}
|
||||
|
||||
fn forward_diffuser(&self, latent: Tensor<B, 4>, timestep: Tensor<B, 1, Int>, context: Tensor<B, 3>, unconditional_context: Tensor<B, 2>, unconditional_guidance_scale: f64) -> Tensor<B, 4> {
|
||||
fn forward_diffuser(
|
||||
&self,
|
||||
latent: Tensor<B, 4>,
|
||||
timestep: Tensor<B, 1, Int>,
|
||||
context: Tensor<B, 3>,
|
||||
unconditional_context: Tensor<B, 2>,
|
||||
unconditional_guidance_scale: f64,
|
||||
) -> Tensor<B, 4> {
|
||||
let [n_batch, _, _, _] = latent.dims();
|
||||
//let latent = latent.repeat(0, 2);
|
||||
|
||||
let unconditional_latent = self.diffusion.forward(
|
||||
latent.clone(),
|
||||
timestep.clone(),
|
||||
unconditional_context.unsqueeze().repeat(0, n_batch)
|
||||
latent.clone(),
|
||||
timestep.clone(),
|
||||
unconditional_context.unsqueeze().repeat(&[0, n_batch]),
|
||||
);
|
||||
|
||||
let conditional_latent = self.diffusion.forward(
|
||||
latent,
|
||||
timestep,
|
||||
context
|
||||
);
|
||||
let conditional_latent = self.diffusion.forward(latent, timestep, context);
|
||||
|
||||
/*let latent = self.diffusion.forward(
|
||||
latent.repeat(0, 2),
|
||||
timestep.repeat(0, 2),
|
||||
latent.repeat(0, 2),
|
||||
timestep.repeat(0, 2),
|
||||
Tensor::cat(vec![unconditional_context.unsqueeze::<3>(), context], 0)
|
||||
);
|
||||
|
||||
let unconditional_latent = latent.clone().slice([0..n_batch]);
|
||||
let conditional_latent = latent.slice([n_batch..2 * n_batch]);*/
|
||||
|
||||
unconditional_latent.clone() + (conditional_latent - unconditional_latent) * unconditional_guidance_scale
|
||||
unconditional_latent.clone()
|
||||
+ (conditional_latent - unconditional_latent) * unconditional_guidance_scale
|
||||
}
|
||||
|
||||
pub fn unconditional_context(&self, tokenizer: &SimpleTokenizer) -> Tensor<B, 2> {
|
||||
self.context(tokenizer, "").squeeze(0)
|
||||
self.context(tokenizer, "").squeeze::<2>()
|
||||
}
|
||||
|
||||
pub fn context(&self, tokenizer: &SimpleTokenizer, text: &str) -> Tensor<B, 3> {
|
||||
let device = &self.devices()[0];
|
||||
let device = &self.clip.devices()[0];
|
||||
let text = format!("<|startoftext|>{}<|endoftext|>", text);
|
||||
let tokenized: Vec<_> = tokenizer.encode(&text).into_iter().map(|v| v as i32).collect();
|
||||
let tokenized: Vec<_> = tokenizer
|
||||
.encode(&text)
|
||||
.into_iter()
|
||||
.map(|v| v as i32)
|
||||
.collect();
|
||||
|
||||
self.clip.forward(Tensor::from_ints(&tokenized[..]).to_device(device).unsqueeze())
|
||||
self.clip.forward(
|
||||
Tensor::<B, 1, Int>::from_ints(&tokenized[..], device)
|
||||
.unsqueeze(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
use crate::helper::to_float;
|
||||
use std::f64::consts::PI;
|
||||
|
||||
fn cosine_schedule<B: Backend>(n_steps: usize) -> Tensor<B, 1> {
|
||||
to_float(Tensor::arange(1..n_steps + 1))
|
||||
fn cosine_schedule<B: Backend>(n_steps: i64, device: &B::Device) -> Tensor<B, 1> {
|
||||
Tensor::arange(1..n_steps + 1, device)
|
||||
.float()
|
||||
.mul_scalar(PI * 0.5 / n_steps as f64)
|
||||
.cos()
|
||||
}
|
||||
|
||||
fn offset_cosine_schedule<B: Backend>(n_steps: usize) -> Tensor<B, 1> {
|
||||
fn offset_cosine_schedule<B: Backend>(n_steps: i64, device: &B::Device) -> Tensor<B, 1> {
|
||||
let min_signal_rate: f64 = 0.02;
|
||||
let max_signal_rate: f64 = 0.95;
|
||||
let start_angle = max_signal_rate.acos();
|
||||
let end_angle = min_signal_rate.acos();
|
||||
|
||||
let times = Tensor::arange(1..n_steps + 1);
|
||||
let times = Tensor::arange(1..n_steps + 1, device).float();
|
||||
|
||||
let diffusion_angles = to_float(times) * ( (end_angle - start_angle) / n_steps as f64) + start_angle;
|
||||
let diffusion_angles = times * ((end_angle - start_angle) / n_steps as f64) + start_angle;
|
||||
diffusion_angles.cos()
|
||||
}
|
||||
|
||||
fn offset_cosine_schedule_cumprod<B: Backend>(n_steps: usize) -> Tensor<B, 1> {
|
||||
offset_cosine_schedule::<B>(n_steps).powf(2.0)
|
||||
}
|
||||
fn offset_cosine_schedule_cumprod<B: Backend>(n_steps: i64, device: &B::Device) -> Tensor<B, 1> {
|
||||
offset_cosine_schedule::<B>(n_steps, device).powf_scalar(2.0)
|
||||
}
|
||||
|
||||
@@ -4,19 +4,19 @@ use crate::model::load::*;
|
||||
use std::error::Error;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
use crate::model::groupnorm::load::load_group_norm;
|
||||
|
||||
pub fn load_res_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResBlock<B>, Box<dyn Error>> {
|
||||
pub fn load_res_block<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<ResBlock<B>, Box<dyn Error>> {
|
||||
let norm_in = load_group_norm::<B>(&format!("{}/{}", path, "norm_in"), device)?;
|
||||
let conv_in = load_conv2d::<B>(&format!("{}/{}", path, "conv_in"), device)?;
|
||||
let lin_embed = load_linear::<B>(&format!("{}/{}", path, "lin_embed"), device)?;
|
||||
@@ -26,12 +26,12 @@ pub fn load_res_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResB
|
||||
|
||||
let res_block = ResBlock {
|
||||
norm_in: norm_in,
|
||||
silu_in: SILU::new(),
|
||||
silu_in: SILU::new(),
|
||||
conv_in: conv_in,
|
||||
silu_embed: SILU::new(),
|
||||
silu_embed: SILU::new(),
|
||||
lin_embed: lin_embed,
|
||||
norm_out: norm_out,
|
||||
silu_out: SILU::new(),
|
||||
silu_out: SILU::new(),
|
||||
conv_out: conv_out,
|
||||
skip_connection: skip_connection,
|
||||
};
|
||||
@@ -39,7 +39,10 @@ pub fn load_res_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResB
|
||||
Ok(res_block)
|
||||
}
|
||||
|
||||
pub fn load_multi_head_attention<B: Backend>(path: &str, device: &B::Device) -> Result<MultiHeadAttention<B>, Box<dyn Error>> {
|
||||
pub fn load_multi_head_attention<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<MultiHeadAttention<B>, Box<dyn Error>> {
|
||||
let n_head = load_usize::<B>("n_head", path, device)?;
|
||||
let query = load_linear::<B>(&format!("{}/{}", path, "query"), device)?;
|
||||
let key = load_linear::<B>(&format!("{}/{}", path, "key"), device)?;
|
||||
@@ -53,23 +56,21 @@ pub fn load_multi_head_attention<B: Backend>(path: &str, device: &B::Device) ->
|
||||
value: value,
|
||||
out: out,
|
||||
};
|
||||
|
||||
|
||||
Ok(multi_head_attention)
|
||||
}
|
||||
|
||||
|
||||
pub fn load_geglu<B: Backend>(path: &str, device: &B::Device) -> Result<GEGLU<B>, Box<dyn Error>> {
|
||||
let proj = load_linear::<B>(&format!("{}/{}", path, "proj"), device)?;
|
||||
|
||||
let geglue = GEGLU {
|
||||
proj: proj,
|
||||
gelu: GELU::new(), // Assuming GELU::new() initializes a new GELU struct
|
||||
gelu: Gelu::new(), // Assuming Gelu::new() initializes a new Gelu struct
|
||||
};
|
||||
|
||||
|
||||
Ok(geglue)
|
||||
}
|
||||
|
||||
|
||||
pub fn load_mlp<B: Backend>(path: &str, device: &B::Device) -> Result<MLP<B>, Box<dyn Error>> {
|
||||
let geglu = load_geglu::<B>(&format!("{}/{}", path, "geglu"), device)?;
|
||||
let lin = load_linear::<B>(&format!("{}/{}", path, "lin"), device)?;
|
||||
@@ -78,12 +79,14 @@ pub fn load_mlp<B: Backend>(path: &str, device: &B::Device) -> Result<MLP<B>, Bo
|
||||
geglu: geglu,
|
||||
lin: lin,
|
||||
};
|
||||
|
||||
|
||||
Ok(mlp)
|
||||
}
|
||||
|
||||
|
||||
pub fn load_transformer_block<B: Backend>(path: &str, device: &B::Device) -> Result<TransformerBlock<B>, Box<dyn Error>> {
|
||||
pub fn load_transformer_block<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<TransformerBlock<B>, Box<dyn Error>> {
|
||||
let norm1 = load_layer_norm::<B>(&format!("{}/{}", path, "norm1"), device)?;
|
||||
let attn1 = load_multi_head_attention::<B>(&format!("{}/{}", path, "attn1"), device)?;
|
||||
let norm2 = load_layer_norm::<B>(&format!("{}/{}", path, "norm2"), device)?;
|
||||
@@ -99,12 +102,14 @@ pub fn load_transformer_block<B: Backend>(path: &str, device: &B::Device) -> Res
|
||||
norm3: norm3,
|
||||
mlp: mlp,
|
||||
};
|
||||
|
||||
|
||||
Ok(transformer_block)
|
||||
}
|
||||
|
||||
|
||||
pub fn load_spatial_transformer<B: Backend>(path: &str, device: &B::Device) -> Result<SpatialTransformer<B>, Box<dyn Error>> {
|
||||
pub fn load_spatial_transformer<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<SpatialTransformer<B>, Box<dyn Error>> {
|
||||
let norm = load_group_norm::<B>(&format!("{}/{}", path, "norm"), device)?;
|
||||
let proj_in = load_conv2d::<B>(&format!("{}/{}", path, "proj_in"), device)?;
|
||||
let transformer = load_transformer_block::<B>(&format!("{}/{}", path, "transformer"), device)?;
|
||||
@@ -116,28 +121,35 @@ pub fn load_spatial_transformer<B: Backend>(path: &str, device: &B::Device) -> R
|
||||
transformer: transformer,
|
||||
proj_out: proj_out,
|
||||
};
|
||||
|
||||
|
||||
Ok(spatial_transformer)
|
||||
}
|
||||
|
||||
|
||||
pub fn load_upsample<B: Backend>(path: &str, device: &B::Device) -> Result<Upsample<B>, Box<dyn Error>> {
|
||||
pub fn load_upsample<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<Upsample<B>, Box<dyn Error>> {
|
||||
let conv = load_conv2d::<B>(&format!("{}/{}", path, "conv"), device)?;
|
||||
|
||||
let upsample = Upsample {
|
||||
conv: conv,
|
||||
};
|
||||
|
||||
let upsample = Upsample { conv: conv };
|
||||
|
||||
Ok(upsample)
|
||||
}
|
||||
|
||||
pub fn load_downsample<B: Backend>(path: &str, device: &B::Device) -> Result<Downsample<B>, Box<dyn Error>> {
|
||||
pub fn load_downsample<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<Downsample<B>, Box<dyn Error>> {
|
||||
load_conv2d(path, device)
|
||||
}
|
||||
|
||||
pub fn load_res_transformer_res<B: Backend>(path: &str, device: &B::Device) -> Result<ResTransformerRes<B>, Box<dyn Error>> {
|
||||
pub fn load_res_transformer_res<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<ResTransformerRes<B>, Box<dyn Error>> {
|
||||
let res1 = load_res_block::<B>(&format!("{}/{}", path, "res1"), device)?; // Assuming load_res_block function
|
||||
let transformer = load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?;
|
||||
let transformer =
|
||||
load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?;
|
||||
let res2 = load_res_block::<B>(&format!("{}/{}", path, "res2"), device)?;
|
||||
|
||||
let res_transformer_res = ResTransformerRes {
|
||||
@@ -145,13 +157,17 @@ pub fn load_res_transformer_res<B: Backend>(path: &str, device: &B::Device) -> R
|
||||
transformer: transformer,
|
||||
res2: res2,
|
||||
};
|
||||
|
||||
|
||||
Ok(res_transformer_res)
|
||||
}
|
||||
|
||||
pub fn load_res_transformer_upsample<B: Backend>(path: &str, device: &B::Device) -> Result<ResTransformerUpsample<B>, Box<dyn Error>> {
|
||||
pub fn load_res_transformer_upsample<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<ResTransformerUpsample<B>, Box<dyn Error>> {
|
||||
let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?;
|
||||
let transformer = load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?;
|
||||
let transformer =
|
||||
load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?;
|
||||
let upsample = load_upsample::<B>(&format!("{}/{}", path, "upsample"), device)?;
|
||||
|
||||
let res_transformer_upsample = ResTransformerUpsample {
|
||||
@@ -159,12 +175,14 @@ pub fn load_res_transformer_upsample<B: Backend>(path: &str, device: &B::Device)
|
||||
transformer: transformer,
|
||||
upsample: upsample,
|
||||
};
|
||||
|
||||
|
||||
Ok(res_transformer_upsample)
|
||||
}
|
||||
|
||||
|
||||
pub fn load_res_upsample<B: Backend>(path: &str, device: &B::Device) -> Result<ResUpSample<B>, Box<dyn Error>> {
|
||||
pub fn load_res_upsample<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<ResUpSample<B>, Box<dyn Error>> {
|
||||
let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?;
|
||||
let upsample = load_upsample::<B>(&format!("{}/{}", path, "upsample"), device)?;
|
||||
|
||||
@@ -172,25 +190,30 @@ pub fn load_res_upsample<B: Backend>(path: &str, device: &B::Device) -> Result<R
|
||||
res: res,
|
||||
upsample: upsample,
|
||||
};
|
||||
|
||||
|
||||
Ok(res_upsample)
|
||||
}
|
||||
|
||||
|
||||
pub fn load_res_transformer<B: Backend>(path: &str, device: &B::Device) -> Result<ResTransformer<B>, Box<dyn Error>> {
|
||||
pub fn load_res_transformer<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<ResTransformer<B>, Box<dyn Error>> {
|
||||
let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?;
|
||||
let transformer = load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?;
|
||||
let transformer =
|
||||
load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?;
|
||||
|
||||
let res_transformer = ResTransformer {
|
||||
res: res,
|
||||
transformer: transformer,
|
||||
};
|
||||
|
||||
|
||||
Ok(res_transformer)
|
||||
}
|
||||
|
||||
|
||||
pub fn load_unet_input_blocks<B: Backend>(path: &str, device: &B::Device) -> Result<UNetInputBlocks<B>, Box<dyn Error>> {
|
||||
pub fn load_unet_input_blocks<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<UNetInputBlocks<B>, Box<dyn Error>> {
|
||||
let conv = load_conv2d::<B>(&format!("{}/{}", path, "conv"), device)?;
|
||||
let rt1 = load_res_transformer::<B>(&format!("{}/{}", path, "rt1"), device)?;
|
||||
let rt2 = load_res_transformer::<B>(&format!("{}/{}", path, "rt2"), device)?;
|
||||
@@ -218,11 +241,14 @@ pub fn load_unet_input_blocks<B: Backend>(path: &str, device: &B::Device) -> Res
|
||||
r1: r1,
|
||||
r2: r2,
|
||||
};
|
||||
|
||||
|
||||
Ok(unet_input_blocks)
|
||||
}
|
||||
|
||||
pub fn load_unet_output_blocks<B: Backend>(path: &str, device: &B::Device) -> Result<UNetOutputBlocks<B>, Box<dyn Error>> {
|
||||
pub fn load_unet_output_blocks<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<UNetOutputBlocks<B>, Box<dyn Error>> {
|
||||
let r1 = load_res_block::<B>(&format!("{}/{}", path, "r1"), device)?;
|
||||
let r2 = load_res_block::<B>(&format!("{}/{}", path, "r2"), device)?;
|
||||
let ru = load_res_upsample::<B>(&format!("{}/{}", path, "ru"), device)?;
|
||||
@@ -252,14 +278,16 @@ pub fn load_unet_output_blocks<B: Backend>(path: &str, device: &B::Device) -> Re
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
pub fn load_unet<B: Backend>(path: &str, device: &B::Device) -> Result<UNet<B>, Box<dyn Error>> {
|
||||
let lin1_time_embed = load_linear::<B>(&format!("{}/{}", path, "lin1_time_embed"), device)?;
|
||||
let silu_time_embed = SILU::new(); // Assuming SILU::new() initializes a new SILU struct
|
||||
let lin2_time_embed = load_linear::<B>(&format!("{}/{}", path, "lin2_time_embed"), device)?;
|
||||
let input_blocks = load_unet_input_blocks::<B>(&format!("{}/{}", path, "input_blocks"), device)?;
|
||||
let middle_block = load_res_transformer_res::<B>(&format!("{}/{}", path, "middle_block"), device)?;
|
||||
let output_blocks = load_unet_output_blocks::<B>(&format!("{}/{}", path, "output_blocks"), device)?;
|
||||
let input_blocks =
|
||||
load_unet_input_blocks::<B>(&format!("{}/{}", path, "input_blocks"), device)?;
|
||||
let middle_block =
|
||||
load_res_transformer_res::<B>(&format!("{}/{}", path, "middle_block"), device)?;
|
||||
let output_blocks =
|
||||
load_unet_output_blocks::<B>(&format!("{}/{}", path, "output_blocks"), device)?;
|
||||
let norm_out = load_group_norm::<B>(&format!("{}/{}", path, "norm_out"), device)?;
|
||||
let silu_out = SILU::new(); // Assuming SILU::new() initializes a new SILU struct
|
||||
let conv_out = load_conv2d::<B>(&format!("{}/{}", path, "conv_out"), device)?;
|
||||
|
||||
@@ -1,108 +1,117 @@
|
||||
pub mod load;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn::{self, PaddingConfig2d, GELU, conv::{Conv2d, Conv2dConfig}},
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
activation::softmax,
|
||||
module::embedding,
|
||||
Tensor,
|
||||
Distribution,
|
||||
Int,
|
||||
nn::{
|
||||
self,
|
||||
conv::{Conv2d, Conv2dConfig},
|
||||
PaddingConfig2d, Gelu,
|
||||
},
|
||||
tensor::{activation::softmax, backend::Backend, module::embedding, Distribution, Int, Tensor},
|
||||
};
|
||||
|
||||
use super::silu::*;
|
||||
use super::groupnorm::*;
|
||||
use crate::helper::to_float;
|
||||
use super::silu::*;
|
||||
|
||||
use super::attention::qkv_attention;
|
||||
|
||||
|
||||
fn timestep_embedding<B: Backend>(timesteps: Tensor<B, 1, Int>, dim: usize, max_period: usize) -> Tensor<B, 2> {
|
||||
fn timestep_embedding<B: Backend>(
|
||||
timesteps: Tensor<B, 1, Int>,
|
||||
dim: usize,
|
||||
max_period: usize,
|
||||
) -> Tensor<B, 2> {
|
||||
let half = dim / 2;
|
||||
let freqs = ( to_float(Tensor::arange_device(0..half, ×teps.device())) * (-(max_period as f64).ln() / half as f64 ) ).exp();
|
||||
let args = to_float(timesteps) * freqs;
|
||||
let freqs = (Tensor::arange(0..half as i64, ×teps.device()).float()
|
||||
* (-(max_period as f64).ln() / half as f64))
|
||||
.exp();
|
||||
let args = timesteps.float() * freqs;
|
||||
Tensor::cat(vec![args.clone().cos(), args.sin()], 0).unsqueeze()
|
||||
}
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct UNetConfig {}
|
||||
|
||||
impl UNetConfig {
|
||||
pub fn init<B: Backend>(&self) -> UNet<B> {
|
||||
let lin1_time_embed = nn::LinearConfig::new(320, 1280).init();
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> UNet<B> {
|
||||
let lin1_time_embed = nn::LinearConfig::new(320, 1280).init(device);
|
||||
let silu_time_embed = SILU::new();
|
||||
let lin2_time_embed = nn::LinearConfig::new(1280, 1280).init();
|
||||
let lin2_time_embed = nn::LinearConfig::new(1280, 1280).init(device);
|
||||
|
||||
let input_blocks = UNetInputBlocks {
|
||||
conv: Conv2dConfig::new([4, 320], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(),
|
||||
rt1: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(),
|
||||
rt2: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(),
|
||||
d1: DownsampleConfig::new(320).init(),
|
||||
rt3: ResTransformerConfig::new(320, 1280, 640, 768, 8).init(),
|
||||
rt4: ResTransformerConfig::new(640, 1280, 640, 768, 8).init(),
|
||||
d2: DownsampleConfig::new(640).init(),
|
||||
rt5: ResTransformerConfig::new(640, 1280, 1280, 768, 8).init(),
|
||||
rt6: ResTransformerConfig::new(1280, 1280, 1280, 768, 8).init(),
|
||||
d3: DownsampleConfig::new(1280).init(),
|
||||
r1: ResBlockConfig::new(1280, 1280, 1280).init(),
|
||||
r2: ResBlockConfig::new(1280, 1280, 1280).init(),
|
||||
conv: Conv2dConfig::new([4, 320], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init(device),
|
||||
rt1: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(device),
|
||||
rt2: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(device),
|
||||
d1: DownsampleConfig::new(320).init(device),
|
||||
rt3: ResTransformerConfig::new(320, 1280, 640, 768, 8).init(device),
|
||||
rt4: ResTransformerConfig::new(640, 1280, 640, 768, 8).init(device),
|
||||
d2: DownsampleConfig::new(640).init(device),
|
||||
rt5: ResTransformerConfig::new(640, 1280, 1280, 768, 8).init(device),
|
||||
rt6: ResTransformerConfig::new(1280, 1280, 1280, 768, 8).init(device),
|
||||
d3: DownsampleConfig::new(1280).init(device),
|
||||
r1: ResBlockConfig::new(1280, 1280, 1280).init(device),
|
||||
r2: ResBlockConfig::new(1280, 1280, 1280).init(device),
|
||||
};
|
||||
|
||||
let middle_block = ResTransformerResConfig::new(1280, 1280, 1280, 768, 8).init();
|
||||
|
||||
let middle_block = ResTransformerResConfig::new(1280, 1280, 1280, 768, 8).init(device);
|
||||
|
||||
let output_blocks = UNetOutputBlocks {
|
||||
r1: ResBlockConfig::new(2560, 1280, 1280).init(),
|
||||
r2: ResBlockConfig::new(2560, 1280, 1280).init(),
|
||||
ru: ResUpSampleConfig::new(2560, 1280, 1280).init(),
|
||||
rt1: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(),
|
||||
rt2: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(),
|
||||
rtu1: ResTransformerUpsampleConfig::new(1920, 1280, 1280, 768, 8).init(),
|
||||
rt3: ResTransformerConfig::new(1920, 1280, 640, 768, 8).init(),
|
||||
rt4: ResTransformerConfig::new(1280, 1280, 640, 768, 8).init(),
|
||||
rtu2: ResTransformerUpsampleConfig::new(960, 1280, 640, 768, 8).init(),
|
||||
rt5: ResTransformerConfig::new(960, 1280, 320, 768, 8).init(),
|
||||
rt6: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(),
|
||||
rt7: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(),
|
||||
r1: ResBlockConfig::new(2560, 1280, 1280).init(device),
|
||||
r2: ResBlockConfig::new(2560, 1280, 1280).init(device),
|
||||
ru: ResUpSampleConfig::new(2560, 1280, 1280).init(device),
|
||||
rt1: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(device),
|
||||
rt2: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(device),
|
||||
rtu1: ResTransformerUpsampleConfig::new(1920, 1280, 1280, 768, 8).init(device),
|
||||
rt3: ResTransformerConfig::new(1920, 1280, 640, 768, 8).init(device),
|
||||
rt4: ResTransformerConfig::new(1280, 1280, 640, 768, 8).init(device),
|
||||
rtu2: ResTransformerUpsampleConfig::new(960, 1280, 640, 768, 8).init(device),
|
||||
rt5: ResTransformerConfig::new(960, 1280, 320, 768, 8).init(device),
|
||||
rt6: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(device),
|
||||
rt7: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(device),
|
||||
};
|
||||
|
||||
let norm_out = GroupNormConfig::new(32, 320).init();
|
||||
let norm_out = GroupNormConfig::new(32, 320).init(device);
|
||||
let silu_out = SILU::new();
|
||||
let conv_out = Conv2dConfig::new([320, 4], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
let conv_out = Conv2dConfig::new([320, 4], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init(device);
|
||||
|
||||
UNet {
|
||||
lin1_time_embed,
|
||||
silu_time_embed,
|
||||
lin2_time_embed,
|
||||
input_blocks,
|
||||
middle_block,
|
||||
output_blocks,
|
||||
norm_out,
|
||||
silu_out,
|
||||
conv_out,
|
||||
lin1_time_embed,
|
||||
silu_time_embed,
|
||||
lin2_time_embed,
|
||||
input_blocks,
|
||||
middle_block,
|
||||
output_blocks,
|
||||
norm_out,
|
||||
silu_out,
|
||||
conv_out,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct UNet<B: Backend> {
|
||||
lin1_time_embed: nn::Linear<B>,
|
||||
silu_time_embed: SILU,
|
||||
lin2_time_embed: nn::Linear<B>,
|
||||
input_blocks: UNetInputBlocks<B>,
|
||||
middle_block: ResTransformerRes<B>,
|
||||
output_blocks: UNetOutputBlocks<B>,
|
||||
norm_out: GroupNorm<B>,
|
||||
silu_out: SILU,
|
||||
conv_out: Conv2d<B>,
|
||||
lin1_time_embed: nn::Linear<B>,
|
||||
silu_time_embed: SILU,
|
||||
lin2_time_embed: nn::Linear<B>,
|
||||
input_blocks: UNetInputBlocks<B>,
|
||||
middle_block: ResTransformerRes<B>,
|
||||
output_blocks: UNetOutputBlocks<B>,
|
||||
norm_out: GroupNorm<B>,
|
||||
silu_out: SILU,
|
||||
conv_out: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> UNet<B> {
|
||||
pub fn forward(&self, x: Tensor<B, 4>, timesteps: Tensor<B, 1, Int>, context: Tensor<B, 3>) -> Tensor<B, 4> {
|
||||
pub fn forward(
|
||||
&self,
|
||||
x: Tensor<B, 4>,
|
||||
timesteps: Tensor<B, 1, Int>,
|
||||
context: Tensor<B, 3>,
|
||||
) -> Tensor<B, 4> {
|
||||
let t_emb = timestep_embedding(timesteps, 320, 10000);
|
||||
let emb = self.lin1_time_embed.forward(t_emb);
|
||||
let emb = self.silu_time_embed.forward(emb);
|
||||
@@ -133,39 +142,27 @@ impl<B: Backend> UNet<B> {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct UNetInputBlocks<B: Backend> {
|
||||
conv: Conv2d<B>,
|
||||
rt1: ResTransformer<B>,
|
||||
rt2: ResTransformer<B>,
|
||||
d1: Downsample<B>,
|
||||
rt3: ResTransformer<B>,
|
||||
rt4: ResTransformer<B>,
|
||||
d2: Downsample<B>,
|
||||
rt5: ResTransformer<B>,
|
||||
rt6: ResTransformer<B>,
|
||||
d3: Downsample<B>,
|
||||
r1: ResBlock<B>,
|
||||
r2: ResBlock<B>,
|
||||
conv: Conv2d<B>,
|
||||
rt1: ResTransformer<B>,
|
||||
rt2: ResTransformer<B>,
|
||||
d1: Downsample<B>,
|
||||
rt3: ResTransformer<B>,
|
||||
rt4: ResTransformer<B>,
|
||||
d2: Downsample<B>,
|
||||
rt5: ResTransformer<B>,
|
||||
rt6: ResTransformer<B>,
|
||||
d3: Downsample<B>,
|
||||
r1: ResBlock<B>,
|
||||
r2: ResBlock<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> UNetInputBlocks<B> {
|
||||
fn as_array(&self) -> [&dyn UNetBlock<B>; 12] {
|
||||
[
|
||||
&self.conv,
|
||||
&self.rt1,
|
||||
&self.rt2,
|
||||
&self.d1,
|
||||
&self.rt3,
|
||||
&self.rt4,
|
||||
&self.d2,
|
||||
&self.rt5,
|
||||
&self.rt6,
|
||||
&self.d3,
|
||||
&self.r1,
|
||||
&self.r2,
|
||||
&self.conv, &self.rt1, &self.rt2, &self.d1, &self.rt3, &self.rt4, &self.d2, &self.rt5,
|
||||
&self.rt6, &self.d3, &self.r1, &self.r2,
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -177,67 +174,57 @@ pub struct UNetOutputBlocks<B: Backend> {
|
||||
ru: ResUpSample<B>,
|
||||
rt1: ResTransformer<B>,
|
||||
rt2: ResTransformer<B>,
|
||||
rtu1: ResTransformerUpsample<B>,
|
||||
rtu1: ResTransformerUpsample<B>,
|
||||
rt3: ResTransformer<B>,
|
||||
rt4: ResTransformer<B>,
|
||||
rtu2: ResTransformerUpsample<B>,
|
||||
rt5: ResTransformer<B>,
|
||||
rt6: ResTransformer<B>,
|
||||
rt7: ResTransformer<B>,
|
||||
rtu2: ResTransformerUpsample<B>,
|
||||
rt5: ResTransformer<B>,
|
||||
rt6: ResTransformer<B>,
|
||||
rt7: ResTransformer<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> UNetOutputBlocks<B> {
|
||||
fn as_array(&self) -> [&dyn UNetBlock<B>; 12] {
|
||||
[
|
||||
&self.r1,
|
||||
&self.r2,
|
||||
&self.ru,
|
||||
&self.rt1,
|
||||
&self.rt2,
|
||||
&self.rtu1,
|
||||
&self.rt3,
|
||||
&self.rt4,
|
||||
&self.rtu2,
|
||||
&self.rt5,
|
||||
&self.rt6,
|
||||
&self.rt7,
|
||||
&self.r1, &self.r2, &self.ru, &self.rt1, &self.rt2, &self.rtu1, &self.rt3, &self.rt4,
|
||||
&self.rtu2, &self.rt5, &self.rt6, &self.rt7,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
trait UNetBlock<B: Backend> {
|
||||
fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4>;
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct ResTransformerConfig {
|
||||
n_channels_in: usize,
|
||||
n_channels_embed: usize,
|
||||
n_channels_out: usize,
|
||||
n_context_state: usize,
|
||||
n_head: usize,
|
||||
n_channels_in: usize,
|
||||
n_channels_embed: usize,
|
||||
n_channels_out: usize,
|
||||
n_context_state: usize,
|
||||
n_head: usize,
|
||||
}
|
||||
|
||||
impl ResTransformerConfig {
|
||||
fn init<B: Backend>(&self) -> ResTransformer<B> {
|
||||
let res = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init();
|
||||
let transformer = SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> ResTransformer<B> {
|
||||
let res = ResBlockConfig::new(
|
||||
self.n_channels_in,
|
||||
self.n_channels_embed,
|
||||
self.n_channels_out,
|
||||
)
|
||||
.init(device);
|
||||
let transformer =
|
||||
SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head)
|
||||
.init(device);
|
||||
|
||||
ResTransformer {
|
||||
res,
|
||||
transformer,
|
||||
}
|
||||
ResTransformer { res, transformer }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ResTransformer<B: Backend> {
|
||||
res: ResBlock<B>,
|
||||
transformer: SpatialTransformer<B>,
|
||||
res: ResBlock<B>,
|
||||
transformer: SpatialTransformer<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> UNetBlock<B> for ResTransformer<B> {
|
||||
@@ -248,29 +235,31 @@ impl<B: Backend> UNetBlock<B> for ResTransformer<B> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct ResUpSampleConfig {
|
||||
n_channels_in: usize,
|
||||
n_channels_embed: usize,
|
||||
n_channels_out: usize,
|
||||
n_channels_in: usize,
|
||||
n_channels_embed: usize,
|
||||
n_channels_out: usize,
|
||||
}
|
||||
|
||||
impl ResUpSampleConfig {
|
||||
fn init<B: Backend>(&self) -> ResUpSample<B> {
|
||||
let res = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init();
|
||||
let upsample = UpsampleConfig::new(self.n_channels_out).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> ResUpSample<B> {
|
||||
let res = ResBlockConfig::new(
|
||||
self.n_channels_in,
|
||||
self.n_channels_embed,
|
||||
self.n_channels_out,
|
||||
)
|
||||
.init(device);
|
||||
let upsample = UpsampleConfig::new(self.n_channels_out).init(device);
|
||||
|
||||
ResUpSample {
|
||||
res,
|
||||
upsample,
|
||||
}
|
||||
ResUpSample { res, upsample }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ResUpSample<B: Backend> {
|
||||
res: ResBlock<B>,
|
||||
upsample: Upsample<B>,
|
||||
res: ResBlock<B>,
|
||||
upsample: Upsample<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> UNetBlock<B> for ResUpSample<B> {
|
||||
@@ -281,34 +270,41 @@ impl<B: Backend> UNetBlock<B> for ResUpSample<B> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct ResTransformerUpsampleConfig {
|
||||
n_channels_in: usize,
|
||||
n_channels_embed: usize,
|
||||
n_channels_out: usize,
|
||||
n_context_state: usize,
|
||||
n_head: usize,
|
||||
n_channels_in: usize,
|
||||
n_channels_embed: usize,
|
||||
n_channels_out: usize,
|
||||
n_context_state: usize,
|
||||
n_head: usize,
|
||||
}
|
||||
|
||||
impl ResTransformerUpsampleConfig {
|
||||
fn init<B: Backend>(&self) -> ResTransformerUpsample<B> {
|
||||
let res = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init();
|
||||
let transformer = SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head).init();
|
||||
let upsample = UpsampleConfig::new(self.n_channels_out).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> ResTransformerUpsample<B> {
|
||||
let res = ResBlockConfig::new(
|
||||
self.n_channels_in,
|
||||
self.n_channels_embed,
|
||||
self.n_channels_out,
|
||||
)
|
||||
.init(device);
|
||||
let transformer =
|
||||
SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head)
|
||||
.init(device);
|
||||
let upsample = UpsampleConfig::new(self.n_channels_out).init(device);
|
||||
|
||||
ResTransformerUpsample {
|
||||
res,
|
||||
transformer,
|
||||
upsample,
|
||||
res,
|
||||
transformer,
|
||||
upsample,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ResTransformerUpsample<B: Backend> {
|
||||
res: ResBlock<B>,
|
||||
transformer: SpatialTransformer<B>,
|
||||
upsample: Upsample<B>,
|
||||
res: ResBlock<B>,
|
||||
transformer: SpatialTransformer<B>,
|
||||
upsample: Upsample<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> UNetBlock<B> for ResTransformerUpsample<B> {
|
||||
@@ -320,34 +316,46 @@ impl<B: Backend> UNetBlock<B> for ResTransformerUpsample<B> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct ResTransformerResConfig {
|
||||
n_channels_in: usize,
|
||||
n_channels_embed: usize,
|
||||
n_channels_out: usize,
|
||||
n_context_state: usize,
|
||||
n_head: usize,
|
||||
n_channels_in: usize,
|
||||
n_channels_embed: usize,
|
||||
n_channels_out: usize,
|
||||
n_context_state: usize,
|
||||
n_head: usize,
|
||||
}
|
||||
|
||||
impl ResTransformerResConfig {
|
||||
fn init<B: Backend>(&self) -> ResTransformerRes<B> {
|
||||
let res1 = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init();
|
||||
let transformer = SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head).init();
|
||||
let res2 = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> ResTransformerRes<B> {
|
||||
let res1 = ResBlockConfig::new(
|
||||
self.n_channels_in,
|
||||
self.n_channels_embed,
|
||||
self.n_channels_out,
|
||||
)
|
||||
.init(device);
|
||||
let transformer =
|
||||
SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head)
|
||||
.init(device);
|
||||
let res2 = ResBlockConfig::new(
|
||||
self.n_channels_in,
|
||||
self.n_channels_embed,
|
||||
self.n_channels_out,
|
||||
)
|
||||
.init(device);
|
||||
|
||||
ResTransformerRes {
|
||||
res1,
|
||||
transformer,
|
||||
res2,
|
||||
res1,
|
||||
transformer,
|
||||
res2,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ResTransformerRes<B: Backend> {
|
||||
res1: ResBlock<B>,
|
||||
transformer: SpatialTransformer<B>,
|
||||
res2: ResBlock<B>,
|
||||
res1: ResBlock<B>,
|
||||
transformer: SpatialTransformer<B>,
|
||||
res2: ResBlock<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> UNetBlock<B> for ResTransformerRes<B> {
|
||||
@@ -359,38 +367,33 @@ impl<B: Backend> UNetBlock<B> for ResTransformerRes<B> {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct UpsampleConfig {
|
||||
n_channels: usize,
|
||||
n_channels: usize,
|
||||
}
|
||||
|
||||
impl UpsampleConfig {
|
||||
fn init<B: Backend>(&self) -> Upsample<B> {
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> Upsample<B> {
|
||||
let conv = Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init();
|
||||
.init(device);
|
||||
|
||||
Upsample {
|
||||
conv,
|
||||
}
|
||||
Upsample { conv }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Upsample<B: Backend> {
|
||||
conv: Conv2d<B>,
|
||||
conv: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Upsample<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let [n_batch, n_channel, height, width] = x.dims();
|
||||
let x = x
|
||||
.reshape([n_batch, n_channel, height, 1, width, 1])
|
||||
.repeat(3, 2)
|
||||
.repeat(5, 2)
|
||||
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
|
||||
.reshape([n_batch, n_channel, height, 1, width, 1])
|
||||
.repeat(&[1, 1, 1, 2, 1, 2])
|
||||
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
|
||||
self.conv.forward(x)
|
||||
}
|
||||
}
|
||||
@@ -401,17 +404,17 @@ impl<B: Backend> UNetBlock<B> for Upsample<B> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct DownsampleConfig {
|
||||
n_channels: usize,
|
||||
n_channels: usize,
|
||||
}
|
||||
|
||||
impl DownsampleConfig {
|
||||
fn init<B: Backend>(&self) -> Conv2d<B> {
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> Conv2d<B> {
|
||||
Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3])
|
||||
.with_stride([2, 2])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init()
|
||||
.init(device)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -423,38 +426,36 @@ impl<B: Backend> UNetBlock<B> for Conv2d<B> {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct SpatialTransformerConfig {
|
||||
n_channels: usize,
|
||||
n_context_state: usize,
|
||||
n_head: usize,
|
||||
n_channels: usize,
|
||||
n_context_state: usize,
|
||||
n_head: usize,
|
||||
}
|
||||
|
||||
impl SpatialTransformerConfig {
|
||||
fn init<B: Backend>(&self) -> SpatialTransformer<B> {
|
||||
let norm = GroupNormConfig::new(32, self.n_channels).init();
|
||||
let proj_in = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init();
|
||||
let transformer = TransformerBlockConfig::new(self.n_channels, self.n_context_state, self.n_head).init();
|
||||
let proj_out = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> SpatialTransformer<B> {
|
||||
let norm = GroupNormConfig::new(32, self.n_channels).init(device);
|
||||
let proj_in = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(device);
|
||||
let transformer =
|
||||
TransformerBlockConfig::new(self.n_channels, self.n_context_state, self.n_head).init(device);
|
||||
let proj_out = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(device);
|
||||
|
||||
SpatialTransformer {
|
||||
norm,
|
||||
proj_in,
|
||||
transformer,
|
||||
proj_out,
|
||||
norm,
|
||||
proj_in,
|
||||
transformer,
|
||||
proj_out,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct SpatialTransformer<B: Backend> {
|
||||
norm: GroupNorm<B>,
|
||||
proj_in: Conv2d<B>,
|
||||
transformer: TransformerBlock<B>,
|
||||
proj_out: Conv2d<B>,
|
||||
norm: GroupNorm<B>,
|
||||
proj_in: Conv2d<B>,
|
||||
transformer: TransformerBlock<B>,
|
||||
proj_out: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> SpatialTransformer<B> {
|
||||
@@ -465,9 +466,13 @@ impl<B: Backend> SpatialTransformer<B> {
|
||||
|
||||
let x = self.norm.forward(x);
|
||||
let x = self.proj_in.forward(x);
|
||||
let x = x.reshape([n_batch, n_channel, height * width]).swap_dims(1, 2);
|
||||
let x = x
|
||||
.reshape([n_batch, n_channel, height * width])
|
||||
.swap_dims(1, 2);
|
||||
|
||||
let x = self.transformer.forward(x, context)
|
||||
let x = self
|
||||
.transformer
|
||||
.forward(x, context)
|
||||
.swap_dims(1, 2)
|
||||
.reshape([n_batch, n_channel, height, width]);
|
||||
|
||||
@@ -475,113 +480,99 @@ impl<B: Backend> SpatialTransformer<B> {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct TransformerBlockConfig {
|
||||
n_state: usize,
|
||||
n_context_state: usize,
|
||||
n_head: usize,
|
||||
n_state: usize,
|
||||
n_context_state: usize,
|
||||
n_head: usize,
|
||||
}
|
||||
|
||||
impl TransformerBlockConfig {
|
||||
fn init<B: Backend>(&self) -> TransformerBlock<B> {
|
||||
let norm1 = nn::LayerNormConfig::new(self.n_state).init();
|
||||
let attn1 = MultiHeadAttentionConfig::new(self.n_state, self.n_state, self.n_head).init();
|
||||
let norm2 = nn::LayerNormConfig::new(self.n_state).init();
|
||||
let attn2 = MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init();
|
||||
let norm3 = nn::LayerNormConfig::new(self.n_state).init();
|
||||
let mlp = MLPConfig::new(self.n_state, 4).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> TransformerBlock<B> {
|
||||
let norm1 = nn::LayerNormConfig::new(self.n_state).init(device);
|
||||
let attn1 = MultiHeadAttentionConfig::new(self.n_state, self.n_state, self.n_head).init(device);
|
||||
let norm2 = nn::LayerNormConfig::new(self.n_state).init(device);
|
||||
let attn2 =
|
||||
MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init(device);
|
||||
let norm3 = nn::LayerNormConfig::new(self.n_state).init(device);
|
||||
let mlp = MLPConfig::new(self.n_state, 4).init(device);
|
||||
|
||||
TransformerBlock {
|
||||
norm1,
|
||||
attn1,
|
||||
norm2,
|
||||
attn2,
|
||||
norm3,
|
||||
mlp,
|
||||
norm1,
|
||||
attn1,
|
||||
norm2,
|
||||
attn2,
|
||||
norm3,
|
||||
mlp,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct TransformerBlock<B: Backend> {
|
||||
norm1: nn::LayerNorm<B>,
|
||||
attn1: MultiHeadAttention<B>,
|
||||
norm2: nn::LayerNorm<B>,
|
||||
attn2: MultiHeadAttention<B>,
|
||||
norm3: nn::LayerNorm<B>,
|
||||
mlp: MLP<B>,
|
||||
norm1: nn::LayerNorm<B>,
|
||||
attn1: MultiHeadAttention<B>,
|
||||
norm2: nn::LayerNorm<B>,
|
||||
attn2: MultiHeadAttention<B>,
|
||||
norm3: nn::LayerNorm<B>,
|
||||
mlp: MLP<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> TransformerBlock<B> {
|
||||
fn forward(&self, x: Tensor<B, 3>, context: Tensor<B, 3>) -> Tensor<B, 3> {
|
||||
let x = x.clone() + self.attn1.forward( self.norm1.forward(x), None);
|
||||
let x = x.clone() + self.attn2.forward( self.norm2.forward(x), Some(context));
|
||||
x.clone() + self.mlp.forward( self.norm3.forward(x) )
|
||||
let x = x.clone() + self.attn1.forward(self.norm1.forward(x), None);
|
||||
let x = x.clone() + self.attn2.forward(self.norm2.forward(x), Some(context));
|
||||
x.clone() + self.mlp.forward(self.norm3.forward(x))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct MLPConfig {
|
||||
n_state: usize,
|
||||
mult: usize,
|
||||
n_state: usize,
|
||||
mult: usize,
|
||||
}
|
||||
|
||||
impl MLPConfig {
|
||||
pub fn init<B: Backend>(&self) -> MLP<B> {
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> MLP<B> {
|
||||
let n_state_hidden = self.n_state * self.mult;
|
||||
let geglu = GEGLUConfig::new(self.n_state, n_state_hidden).init();
|
||||
let lin = nn::LinearConfig::new(n_state_hidden, self.n_state).init();
|
||||
let geglu = GEGLUConfig::new(self.n_state, n_state_hidden).init(device);
|
||||
let lin = nn::LinearConfig::new(n_state_hidden, self.n_state).init(device);
|
||||
|
||||
MLP {
|
||||
geglu,
|
||||
lin,
|
||||
}
|
||||
MLP { geglu, lin }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct MLP<B: Backend> {
|
||||
geglu: GEGLU<B>,
|
||||
lin: nn::Linear<B>,
|
||||
geglu: GEGLU<B>,
|
||||
lin: nn::Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> MLP<B> {
|
||||
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
|
||||
self.lin.forward( self.geglu.forward(x) )
|
||||
self.lin.forward(self.geglu.forward(x))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct GEGLUConfig {
|
||||
n_state_in: usize,
|
||||
n_state_out: usize,
|
||||
n_state_in: usize,
|
||||
n_state_out: usize,
|
||||
}
|
||||
|
||||
impl GEGLUConfig {
|
||||
fn init<B: Backend>(&self) -> GEGLU<B> {
|
||||
let proj = nn::LinearConfig::new(self.n_state_in, 2 * self.n_state_out).init();
|
||||
let gelu = GELU::new();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> GEGLU<B> {
|
||||
let proj = nn::LinearConfig::new(self.n_state_in, 2 * self.n_state_out).init(device);
|
||||
let gelu = Gelu::new();
|
||||
|
||||
GEGLU {
|
||||
proj,
|
||||
gelu,
|
||||
}
|
||||
GEGLU { proj, gelu }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct GEGLU<B: Backend> {
|
||||
proj: nn::Linear<B>,
|
||||
gelu: GELU,
|
||||
proj: nn::Linear<B>,
|
||||
gelu: Gelu,
|
||||
}
|
||||
|
||||
impl<B: Backend> GEGLU<B> {
|
||||
@@ -591,51 +582,60 @@ impl<B: Backend> GEGLU<B> {
|
||||
|
||||
let n_state_out = n_state / 2;
|
||||
|
||||
let x = projected.clone().slice([0..n_batch, 0..n_ctx, 0..n_state_out]);
|
||||
let x = projected
|
||||
.clone()
|
||||
.slice([0..n_batch, 0..n_ctx, 0..n_state_out]);
|
||||
let gate = projected.slice([0..n_batch, 0..n_ctx, n_state_out..n_state]);
|
||||
|
||||
x * self.gelu.forward(gate)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct MultiHeadAttentionConfig {
|
||||
n_state: usize,
|
||||
n_context_state: usize,
|
||||
n_head: usize,
|
||||
n_state: usize,
|
||||
n_context_state: usize,
|
||||
n_head: usize,
|
||||
}
|
||||
|
||||
impl MultiHeadAttentionConfig {
|
||||
fn init<B: Backend>(&self) -> MultiHeadAttention<B> {
|
||||
assert!(self.n_state % self.n_head == 0, "State size {} must be a multiple of head size {}", self.n_state, self.n_head);
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadAttention<B> {
|
||||
assert!(
|
||||
self.n_state % self.n_head == 0,
|
||||
"State size {} must be a multiple of head size {}",
|
||||
self.n_state,
|
||||
self.n_head
|
||||
);
|
||||
|
||||
let n_head = self.n_head;
|
||||
let query = nn::LinearConfig::new(self.n_state, self.n_state).with_bias(false).init();
|
||||
let key = nn::LinearConfig::new(self.n_context_state, self.n_state).with_bias(false).init();
|
||||
let value = nn::LinearConfig::new(self.n_context_state, self.n_state).with_bias(false).init();
|
||||
let out = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
||||
let query = nn::LinearConfig::new(self.n_state, self.n_state)
|
||||
.with_bias(false)
|
||||
.init(device);
|
||||
let key = nn::LinearConfig::new(self.n_context_state, self.n_state)
|
||||
.with_bias(false)
|
||||
.init(device);
|
||||
let value = nn::LinearConfig::new(self.n_context_state, self.n_state)
|
||||
.with_bias(false)
|
||||
.init(device);
|
||||
let out = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
|
||||
|
||||
MultiHeadAttention {
|
||||
n_head,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out
|
||||
MultiHeadAttention {
|
||||
n_head,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct MultiHeadAttention<B: Backend> {
|
||||
n_head: usize,
|
||||
query: nn::Linear<B>,
|
||||
key: nn::Linear<B>,
|
||||
value: nn::Linear<B>,
|
||||
out: nn::Linear<B>,
|
||||
n_head: usize,
|
||||
query: nn::Linear<B>,
|
||||
key: nn::Linear<B>,
|
||||
value: nn::Linear<B>,
|
||||
out: nn::Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> MultiHeadAttention<B> {
|
||||
@@ -652,74 +652,61 @@ impl<B: Backend> MultiHeadAttention<B> {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct ResBlockConfig {
|
||||
n_channels_in: usize,
|
||||
n_channels_embed: usize,
|
||||
n_channels_out: usize,
|
||||
n_channels_in: usize,
|
||||
n_channels_embed: usize,
|
||||
n_channels_out: usize,
|
||||
}
|
||||
|
||||
|
||||
impl ResBlockConfig {
|
||||
fn init<B: Backend>(&self) -> ResBlock<B> {
|
||||
let norm_in = GroupNormConfig::new(32, self.n_channels_in).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> ResBlock<B> {
|
||||
let norm_in = GroupNormConfig::new(32, self.n_channels_in).init(device);
|
||||
let silu_in = SILU::new();
|
||||
let conv_in = Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
let conv_in = Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init(device);
|
||||
|
||||
let silu_embed = SILU::new();
|
||||
let lin_embed = nn::LinearConfig::new(self.n_channels_embed, self.n_channels_out).init();
|
||||
let lin_embed = nn::LinearConfig::new(self.n_channels_embed, self.n_channels_out).init(device);
|
||||
|
||||
let norm_out = GroupNormConfig::new(32, self.n_channels_out).init();
|
||||
let norm_out = GroupNormConfig::new(32, self.n_channels_out).init(device);
|
||||
let silu_out = SILU::new();
|
||||
let conv_out = Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
let conv_out = Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init(device);
|
||||
|
||||
let skip_connection = if self.n_channels_in != self.n_channels_out {
|
||||
Some( Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [1, 1]).init() )
|
||||
Some(Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [1, 1]).init(device))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
ResBlock {
|
||||
norm_in,
|
||||
silu_in,
|
||||
conv_in,
|
||||
silu_embed,
|
||||
lin_embed,
|
||||
norm_out,
|
||||
silu_out,
|
||||
conv_out,
|
||||
skip_connection,
|
||||
norm_in,
|
||||
silu_in,
|
||||
conv_in,
|
||||
silu_embed,
|
||||
lin_embed,
|
||||
norm_out,
|
||||
silu_out,
|
||||
conv_out,
|
||||
skip_connection,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ResBlock<B: Backend> {
|
||||
norm_in: GroupNorm<B>,
|
||||
silu_in: SILU,
|
||||
conv_in: Conv2d<B>,
|
||||
silu_embed: SILU,
|
||||
lin_embed: nn::Linear<B>,
|
||||
norm_out: GroupNorm<B>,
|
||||
silu_out: SILU,
|
||||
conv_out: Conv2d<B>,
|
||||
skip_connection: Option<Conv2d<B>>,
|
||||
norm_in: GroupNorm<B>,
|
||||
silu_in: SILU,
|
||||
conv_in: Conv2d<B>,
|
||||
silu_embed: SILU,
|
||||
lin_embed: nn::Linear<B>,
|
||||
norm_out: GroupNorm<B>,
|
||||
silu_out: SILU,
|
||||
conv_out: Conv2d<B>,
|
||||
skip_connection: Option<Conv2d<B>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ResBlock<B> {
|
||||
@@ -730,7 +717,7 @@ impl<B: Backend> ResBlock<B> {
|
||||
|
||||
let embed_out = self.silu_embed.forward(embed);
|
||||
let embed_out = self.lin_embed.forward(embed_out);
|
||||
|
||||
|
||||
let [n_batch_embed, n_state_embed] = embed_out.dims();
|
||||
let h = h + embed_out.reshape([n_batch_embed, n_state_embed, 1, 1]);
|
||||
|
||||
@@ -751,5 +738,3 @@ impl<B: Backend> UNetBlock<B> for ResBlock<B> {
|
||||
self.forward(x, emb)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
use std::collections::HashMap;
|
||||
use regex::Regex;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use std::fs::File;
|
||||
use std::io::{self, BufRead};
|
||||
|
||||
fn bytes_to_unicode() -> Vec<(u8, char)> {
|
||||
let mut bs: Vec<u8> = ('!' as u8 ..= '~' as u8).into_iter()
|
||||
.chain( ('¡' as u8..='¬' as u8).into_iter() )
|
||||
.chain( ('®' as u8..='ÿ' as u8).into_iter() )
|
||||
let mut bs: Vec<u8> = ('!' as u8..='~' as u8)
|
||||
.into_iter()
|
||||
.chain(('¡' as u8..='¬' as u8).into_iter())
|
||||
.chain(('®' as u8..='ÿ' as u8).into_iter())
|
||||
.collect();
|
||||
|
||||
let mut cs: Vec<_> = bs.iter().cloned().map(char::from).collect();
|
||||
@@ -16,25 +17,21 @@ fn bytes_to_unicode() -> Vec<(u8, char)> {
|
||||
for b in 0u8..=255u8 {
|
||||
if !bs.contains(&b) {
|
||||
bs.push(b);
|
||||
cs.push( char::from_u32(256 + n).unwrap() );
|
||||
cs.push(char::from_u32(256 + n).unwrap());
|
||||
n += 1;
|
||||
}
|
||||
}
|
||||
|
||||
bs.into_iter()
|
||||
.zip(
|
||||
cs.into_iter()
|
||||
.map(|c| c.into())
|
||||
).collect()
|
||||
.zip(cs.into_iter().map(|c| c.into()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn get_pairs(word: &[String]) -> Vec<(String, String)> {
|
||||
let prev = word.into_iter().cloned();
|
||||
let next = prev.clone().skip(1);
|
||||
|
||||
prev
|
||||
.zip(next)
|
||||
.collect()
|
||||
prev.zip(next).collect()
|
||||
}
|
||||
|
||||
fn whitespace_clean(text: &str) -> String {
|
||||
@@ -44,24 +41,27 @@ fn whitespace_clean(text: &str) -> String {
|
||||
fn load_merges(path: &str) -> io::Result<Vec<(String, String)>> {
|
||||
let file = File::open(&path)?;
|
||||
let reader = io::BufReader::new(file);
|
||||
|
||||
|
||||
let mut merges = Vec::new();
|
||||
|
||||
|
||||
for line in reader.lines() {
|
||||
let line = line?;
|
||||
let mut words = line.split_whitespace();
|
||||
|
||||
|
||||
if let (Some(word1), Some(word2)) = (words.next(), words.next()) {
|
||||
merges.push((word1.into(), word2.into()));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Ok(merges)
|
||||
}
|
||||
|
||||
fn construct_vocab(chars: impl Iterator<Item=char> + Clone, merges: &[(String, String)]) -> Vec<String> {
|
||||
fn construct_vocab(
|
||||
chars: impl Iterator<Item = char> + Clone,
|
||||
merges: &[(String, String)],
|
||||
) -> Vec<String> {
|
||||
let iter = chars.map(String::from);
|
||||
let mut vocab: Vec<_> = iter.clone().chain( iter.map(|c| c + "</w>") ).collect();
|
||||
let mut vocab: Vec<_> = iter.clone().chain(iter.map(|c| c + "</w>")).collect();
|
||||
|
||||
for merge in merges {
|
||||
vocab.push(format!("{}{}", merge.0, merge.1));
|
||||
@@ -79,7 +79,7 @@ pub struct SimpleTokenizer {
|
||||
decoder: HashMap<u32, String>,
|
||||
bpe_ranks: HashMap<(String, String), u32>,
|
||||
cache: HashMap<String, String>,
|
||||
pat: Regex,
|
||||
pat: Regex,
|
||||
}
|
||||
|
||||
impl SimpleTokenizer {
|
||||
@@ -87,10 +87,10 @@ impl SimpleTokenizer {
|
||||
let byte_unicode_values = bytes_to_unicode();
|
||||
|
||||
let byte_encoder: HashMap<_, _> = byte_unicode_values.iter().cloned().collect();
|
||||
let byte_decoder = byte_encoder.iter().map(|(k,v)| (*v,*k)).collect();
|
||||
let byte_decoder = byte_encoder.iter().map(|(k, v)| (*v, *k)).collect();
|
||||
|
||||
let merges = load_merges("bpe_simple_vocab_16e6.txt")?;
|
||||
let merges = merges[1..49152-256-2+1].to_vec();
|
||||
let merges = merges[1..49152 - 256 - 2 + 1].to_vec();
|
||||
|
||||
let vocab = construct_vocab(byte_unicode_values.into_iter().map(|(_, u)| u), &merges[..]);
|
||||
|
||||
@@ -98,38 +98,39 @@ impl SimpleTokenizer {
|
||||
let decoder: HashMap<u32, String> = encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
|
||||
let bpe_ranks = merges.iter().cloned().zip((0..).into_iter()).collect();
|
||||
let cache = HashMap::from([
|
||||
("<|startoftext|>".to_string(), "<|startoftext|>".to_string()),
|
||||
("<|endoftext|>".to_string(), "<|endoftext|>".to_string()),
|
||||
("<|startoftext|>".to_string(), "<|startoftext|>".to_string()),
|
||||
("<|endoftext|>".to_string(), "<|endoftext|>".to_string()),
|
||||
]);
|
||||
|
||||
let pat = Regex::new(r"(?i)<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|\p{L}+|\p{N}|[^\s\p{L}\p{N}]+").unwrap();
|
||||
|
||||
Ok( SimpleTokenizer {
|
||||
Ok(SimpleTokenizer {
|
||||
byte_encoder: byte_encoder,
|
||||
byte_decoder: byte_decoder,
|
||||
encoder: encoder,
|
||||
decoder: decoder,
|
||||
bpe_ranks: bpe_ranks,
|
||||
cache: cache,
|
||||
pat: pat,
|
||||
} )
|
||||
pat: pat,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn bpe(&self, token: &str) -> String {
|
||||
if let Some(word) = self.cache.get(token) {
|
||||
return word.clone();
|
||||
}
|
||||
|
||||
|
||||
let mut word: Vec<String> = token.chars().map(|c| c.to_string()).collect();
|
||||
word.last_mut().map(|w| *w += "</w>");
|
||||
let mut pairs = get_pairs(&word);
|
||||
|
||||
|
||||
if pairs.is_empty() {
|
||||
return format!("{}{}", token, "</w>");
|
||||
}
|
||||
|
||||
|
||||
loop {
|
||||
let bigram = pairs.iter()
|
||||
let bigram = pairs
|
||||
.iter()
|
||||
.filter(|pair| self.bpe_ranks.contains_key(pair))
|
||||
.min_by_key(|&pair| self.bpe_ranks[pair]);
|
||||
|
||||
@@ -141,14 +142,14 @@ impl SimpleTokenizer {
|
||||
let mut new_word = Vec::new();
|
||||
let mut i = 0;
|
||||
while i < word.len() {
|
||||
if let Some( (j, _) ) = word.iter().enumerate().skip(i).find(|(_, w)| w == &first) {
|
||||
if let Some((j, _)) = word.iter().enumerate().skip(i).find(|(_, w)| w == &first) {
|
||||
new_word.extend(word[i..j].iter().cloned());
|
||||
i = j;
|
||||
} else {
|
||||
new_word.extend(word[i..].iter().cloned());
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
if &word[i] == first && i < word.len() - 1 && &word[i + 1] == second {
|
||||
new_word.push(format!("{}{}", first, second));
|
||||
i += 2;
|
||||
@@ -157,7 +158,7 @@ impl SimpleTokenizer {
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
word = new_word;
|
||||
if word.len() == 1 {
|
||||
break;
|
||||
@@ -170,7 +171,7 @@ impl SimpleTokenizer {
|
||||
//self.cache.insert(token.into(), word);
|
||||
return word;
|
||||
}
|
||||
|
||||
|
||||
pub fn encode(&self, text: &str) -> Vec<u32> {
|
||||
let cleaned_text = whitespace_clean(text.trim()).to_lowercase();
|
||||
|
||||
@@ -178,8 +179,16 @@ impl SimpleTokenizer {
|
||||
|
||||
for m in self.pat.find_iter(&cleaned_text) {
|
||||
let token = m.as_str();
|
||||
let token: String = token.as_bytes().into_iter().map(|b| self.byte_encoder[b]).collect();
|
||||
bpe_tokens.extend(self.bpe(&token).split(' ').map(|bpe_token| self.encoder[bpe_token]))
|
||||
let token: String = token
|
||||
.as_bytes()
|
||||
.into_iter()
|
||||
.map(|b| self.byte_encoder[b])
|
||||
.collect();
|
||||
bpe_tokens.extend(
|
||||
self.bpe(&token)
|
||||
.split(' ')
|
||||
.map(|bpe_token| self.encoder[bpe_token]),
|
||||
)
|
||||
}
|
||||
|
||||
return bpe_tokens;
|
||||
@@ -187,9 +196,7 @@ impl SimpleTokenizer {
|
||||
|
||||
pub fn decode(&self, tokens: &[u32]) -> String {
|
||||
let text: String = tokens.iter().map(|t| self.decoder[t].as_str()).collect();
|
||||
let decoded_bytes: Vec<u8> = text.chars()
|
||||
.map(|c| self.byte_decoder[&c])
|
||||
.collect();
|
||||
let decoded_bytes: Vec<u8> = text.chars().map(|c| self.byte_decoder[&c]).collect();
|
||||
|
||||
String::from_utf8_lossy(&decoded_bytes[..]).replace("</w>", " ")
|
||||
}
|
||||
@@ -212,4 +219,4 @@ mod tests {
|
||||
let decoded = tokenizer.decode(&encoded[..]);
|
||||
assert_eq!(target_decode, decoded);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user