Merge pull request 'Vibed Changes' (#1) from test into main

Reviewed-on: #1
This commit was merged in pull request #1.
This commit is contained in:
2026-03-03 22:17:25 +01:00
37 changed files with 266987 additions and 0 deletions

16
.gitignore vendored Normal file
View File

@@ -0,0 +1,16 @@
# Generated by Cargo
# will have compiled files and executables
debug/
target/
# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
Cargo.lock
# These are backup files generated by rustfmt
**/*.rs.bk
# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb
**/*.DS_Store

22
Cargo.toml Normal file
View File

@@ -0,0 +1,22 @@
[package]
name = "stablediffusion"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
wgpu-backend = ["burn-wgpu"]
default = ["wgpu-backend"]
[dependencies]
burn = "0.20.1"
burn-autodiff = "0.20.1"
burn-wgpu = { version = "0.20.1", optional = true }
serde = {version = "1.0.171", features = ["std", "derive"]}
npy = "0.4.0"
num-traits = "0.2.15"
rust_tokenizers = "8.1.0"
regex = "1.9.1"
image = "0.24.6"
cfg-if = "0.1"

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Gadersd
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

68
README.md Normal file
View File

@@ -0,0 +1,68 @@
# Stable-Diffusion-Burn
Stable-Diffusion-Burn is a Rust-based project which ports the V1 stable diffusion model into the deep learning framework, Burn. This repository is licensed under the MIT Licence.
## How To Use
### Step 1: Download the Model and Set Environment Variables
Start by downloading the SDv1-4 model provided on HuggingFace.
```bash
wget https://huggingface.co/Gadersd/Stable-Diffusion-Burn/resolve/main/SDv1-4.mpk
```
### Step 2: Run the Sample Binary
Invoke the sample binary provided in the Rust code. The application now uses a pure Rust backend (WGPU/Vulkan) instead of libtorch. The WGPU backend is unstable for SD but may work well in the future as burn-wgpu is optimized.
```bash
# WGPU/Vulkan backend (GPU accelerated, requires Vulkan-compatible GPU)
# Arguments: <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name>
# GPU (Vulkan)
cargo run --release --features wgpu-backend --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img
# CPU (UNSTABLE - fallback if GPU not available)
cargo run --release --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img
```
This command will generate an image according to the provided prompt, which will be saved as 'img0.png'.
![An image of an ancient mossy stone](img0.png)
### Optional: Extract and Convert a Fine-Tuned Model
If users are interested in using a fine-tuned version of stable diffusion, the Python scripts provided in this project can be used to transform a weight dump into a Burn model file. This does not work on Windows.
```bash
# Step into the Python directory
cd python
# Download the model, this is just the base v1.4 model as an example
wget https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt
# Install tinygrad
pip install -r requirements.txt
# Extract the weights
CPU=1 python3 dump.py sd-v1-4.ckpt
# Move the extracted weight folder out
mv params ..
# Step out of the Python directory
cd ..
# Convert the weights into a usable form
cargo run --release --bin convert params SDv1-4
```
The binaries 'convert' and 'sample' are written in Rust. Convert works on CPU, whereas sample runs on the WGPU/Vulkan backend described in Step 2.
Remember, `convert` should be used if you're planning on using the fine-tuned version of the stable diffusion.
## License
This project is licensed under the MIT license.
We wish you a productive time using this project. Enjoy!

262145
bpe_simple_vocab_16e6.txt Normal file

File diff suppressed because it is too large Load Diff

BIN
img0.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 677 KiB

92
python/autoencoder.py Normal file
View File

@@ -0,0 +1,92 @@
import pathlib
import save
from save import *
from tinygrad.nn import Conv2d
def save_resnet_block(resnet_block, path):
    """Dump an autoencoder ResnetBlock into the directory ``path``.

    The two group norms and two convolutions each go into their own
    sub-directory.  ``nin_shortcut`` is written only when it is a real
    ``Conv2d`` (it may otherwise be a plain identity callable).
    """
    root = pathlib.Path(path)
    root.mkdir(parents=True, exist_ok=True)

    layers = (
        (save_group_norm, 'norm1'),
        (save_conv2d, 'conv1'),
        (save_group_norm, 'norm2'),
        (save_conv2d, 'conv2'),
    )
    for saver, attr in layers:
        saver(getattr(resnet_block, attr), root / attr)

    shortcut = resnet_block.nin_shortcut
    if isinstance(shortcut, Conv2d):
        save_conv2d(shortcut, root / 'nin_shortcut')
def save_attn_block(attn_block, path):
    """Dump an autoencoder attention block (norm + q/k/v/proj_out convs) under ``path``."""
    root = pathlib.Path(path)
    root.mkdir(parents=True, exist_ok=True)

    save_group_norm(attn_block.norm, root / 'norm')
    # Each projection is stored in a sub-directory named after its attribute.
    for attr in ('q', 'k', 'v', 'proj_out'):
        save_conv2d(getattr(attn_block, attr), root / attr)
def save_mid(mid, path):
    """Dump the mid section (resnet, attention, resnet) under ``path``.

    Note that the ``attn_1`` attribute is stored in a directory called
    ``attn`` (not ``attn_1``).
    """
    root = pathlib.Path(path)
    root.mkdir(parents=True, exist_ok=True)

    first, attention, second = mid.block_1, mid.attn_1, mid.block_2
    save_resnet_block(first, root / 'block_1')
    save_attn_block(attention, root / 'attn')
    save_resnet_block(second, root / 'block_2')
def save_decoder_block(decoder_block, path):
    """Dump one decoder stage, given as a dict, under ``path``.

    A ``'block'`` entry holds three resnet blocks saved as ``res1``..``res3``;
    an optional ``'upsample'`` entry holds a conv saved as ``upsampler``.
    """
    root = pathlib.Path(path)
    root.mkdir(parents=True, exist_ok=True)

    if 'block' in decoder_block:
        resnets = decoder_block["block"]
        for position in range(3):
            save_resnet_block(resnets[position], root / f'res{position + 1}')

    if 'upsample' in decoder_block:
        save_conv2d(decoder_block['upsample']['conv'], root / 'upsampler')
def save_decoder(decoder, path):
    """Dump the whole autoencoder decoder under ``path``.

    Stages are written as ``blocks/<i>`` while iterating ``decoder.up`` in
    reverse, so index 0 on disk is the last entry of ``decoder.up``.
    """
    root = pathlib.Path(path)
    root.mkdir(parents=True, exist_ok=True)

    save_conv2d(decoder.conv_in, root / 'conv_in')
    save_mid(decoder.mid, root / 'mid')

    reversed_stages = decoder.up[::-1]
    for index, stage in enumerate(reversed_stages):
        # Progress/diagnostic output: stage index and, when present,
        # the shape of the first resnet's shortcut conv weight.
        print(index)
        shortcut = stage['block'][0].nin_shortcut
        if isinstance(shortcut, Conv2d):
            print(shortcut.weight.shape)
        save_decoder_block(stage, root / f'blocks/{index}')
    save_scalar(len(decoder.up), "n_block", path)

    save_group_norm(decoder.norm_out, root / 'norm_out')
    save_conv2d(decoder.conv_out, root / 'conv_out')
def save_encoder_block(encoder_block, path):
    """Dump one encoder stage, given as a dict, under ``path``.

    A ``'block'`` entry holds two resnet blocks saved as ``res1`` and
    ``res2``; an optional ``'downsample'`` entry holds a conv that is saved
    through ``save_padded_conv2d`` as ``downsampler``.
    """
    root = pathlib.Path(path)
    root.mkdir(parents=True, exist_ok=True)

    if 'block' in encoder_block:
        resnets = encoder_block["block"]
        for position in range(2):
            save_resnet_block(resnets[position], root / f'res{position + 1}')

    if 'downsample' in encoder_block:
        save_padded_conv2d(encoder_block['downsample']['conv'], root / 'downsampler')
def save_encoder(encoder, path):
    """Dump the whole autoencoder encoder under ``path``.

    Stages from ``encoder.down`` are written in order as ``blocks/<i>`` and
    their count is stored as the scalar ``n_block``.
    """
    root = pathlib.Path(path)
    root.mkdir(parents=True, exist_ok=True)

    save_conv2d(encoder.conv_in, root / 'conv_in')
    save_mid(encoder.mid, root / 'mid')

    stages = encoder.down
    for index, stage in enumerate(stages):
        save_encoder_block(stage, root / f'blocks/{index}')
    save_scalar(len(stages), "n_block", path)

    save_group_norm(encoder.norm_out, root / 'norm_out')
    save_conv2d(encoder.conv_out, root / 'conv_out')
def save_autoencoder(autoencoder, path):
    """Dump encoder, decoder and the two quantisation convs of the autoencoder under ``path``."""
    root = pathlib.Path(path)
    root.mkdir(parents=True, exist_ok=True)

    save_encoder(autoencoder.encoder, root / 'encoder')
    save_decoder(autoencoder.decoder, root / 'decoder')
    for attr in ('quant_conv', 'post_quant_conv'):
        save_conv2d(getattr(autoencoder, attr), root / attr)

Binary file not shown.

40
python/clip.py Normal file
View File

@@ -0,0 +1,40 @@
import pathlib
import save
from save import *
def save_clipmlp(clip_mlp, path):
    """Dump the two linear layers (``fc1``, ``fc2``) of a CLIP MLP under ``path``."""
    root = pathlib.Path(path)
    root.mkdir(parents=True, exist_ok=True)
    for attr in ('fc1', 'fc2'):
        save_linear(getattr(clip_mlp, attr), root / attr)
def save_clip_attention(clip_attention, path):
    """Dump a CLIP attention layer under ``path``.

    The projections are renamed on disk: ``k_proj`` -> ``key``,
    ``v_proj`` -> ``value``, ``q_proj`` -> ``query``, ``out_proj`` -> ``out``.
    The head count is stored as the scalar ``n_head``.
    """
    root = pathlib.Path(path)
    root.mkdir(parents=True, exist_ok=True)

    renames = (
        ('k_proj', 'key'),
        ('v_proj', 'value'),
        ('q_proj', 'query'),
        ('out_proj', 'out'),
    )
    for attr, directory in renames:
        save_linear(getattr(clip_attention, attr), root / directory)
    save_scalar(clip_attention.num_heads, 'n_head', path)
def save_clip_encoder_layer(clip_encoder_layer, path):
    """Dump one CLIP encoder layer (attention, MLP and their layer norms) under ``path``."""
    root = pathlib.Path(path)
    root.mkdir(parents=True, exist_ok=True)

    layer = clip_encoder_layer
    save_clip_attention(layer.self_attn, root / 'attn')
    save_layer_norm(layer.layer_norm1, root / 'attn_ln')
    save_clipmlp(layer.mlp, root / 'mlp')
    save_layer_norm(layer.layer_norm2, root / 'mlp_ln')
def save_clip_encoder(clip_encoder, path):
    """Dump every layer of the CLIP encoder as ``blocks/<i>`` plus the scalar ``n_layer``."""
    root = pathlib.Path(path)
    root.mkdir(parents=True, exist_ok=True)

    layers = clip_encoder.layers
    for index, layer in enumerate(layers):
        save_clip_encoder_layer(layer, root / f'blocks/{index}')
    save_scalar(len(layers), "n_layer", path)
def save_clip_text_embeddings(clip_text_embeddings, path):
    """Dump the token and position embedding tables under ``path``."""
    root = pathlib.Path(path)
    root.mkdir(parents=True, exist_ok=True)
    for attr in ('token_embedding', 'position_embedding'):
        save_embedding(getattr(clip_text_embeddings, attr), root / attr)
def save_clip_text_transformer(clip_text_transformer, path):
    """Dump the CLIP text transformer under ``path``.

    Embeddings and encoder layers are written directly into ``path``
    (no extra sub-directory); the final layer norm goes into ``layer_norm``.
    """
    root = pathlib.Path(path)
    root.mkdir(parents=True, exist_ok=True)

    model = clip_text_transformer
    save_clip_text_embeddings(model.embeddings, path)
    save_clip_encoder(model.encoder, path)
    save_layer_norm(model.final_layer_norm, root / 'layer_norm')

652
python/dump.py Normal file
View File

@@ -0,0 +1,652 @@
# This code is modified from the tinygrad stable diffusion example
# (https://github.com/tinygrad/tinygrad/blob/master/examples/stable_diffusion.py)
# used under the MIT license.
# https://arxiv.org/pdf/2112.10752.pdf
# https://github.com/ekagra-ranjan/huggingface-blog/blob/main/stable_diffusion.md
import os
import tempfile
from pathlib import Path
import gzip, argparse, math, re
from functools import lru_cache
from collections import namedtuple
from tqdm import tqdm
from tinygrad.tensor import Tensor
from tinygrad.helpers import GlobalCounters
from tinygrad import dtypes
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
#from extra.utils import download_file
from tinygrad.nn.state import torch_load, load_state_dict
# TODO: refactor AttnBlock, CrossAttention, CLIPAttention to share code
class AttnBlock:
    """Self-attention over the spatial positions of a feature map.

    Query, key, value and output projections are kernel-size-1 convolutions
    that keep the channel count unchanged.
    """

    def __init__(self, in_channels):
        self.norm = GroupNorm(32, in_channels)
        self.q = Conv2d(in_channels, in_channels, 1)
        self.k = Conv2d(in_channels, in_channels, 1)
        self.v = Conv2d(in_channels, in_channels, 1)
        self.proj_out = Conv2d(in_channels, in_channels, 1)

    # mirrors AttnBlock in the ldm repo
    def __call__(self, x):
        """Return ``x`` plus the projected attention output (residual connection)."""
        normed = self.norm(x)
        query, key, value = self.q(normed), self.k(normed), self.v(normed)

        # Flatten the two spatial axes into one "position" axis.
        batch, channels, height, width = query.shape
        positions = height * width
        query = query.reshape(batch, channels, positions).permute(0, 2, 1)  # b,hw,c
        key = key.reshape(batch, channels, positions)  # b,c,hw
        scores = ((query @ key) * (channels ** (-0.5))).softmax()

        # Weight the values by the attention scores.
        value = value.reshape(batch, channels, positions)
        attended = value @ scores.permute(0, 2, 1)
        return x + self.proj_out(attended.reshape(batch, channels, height, width))
class ResnetBlock:
    """Residual block of the autoencoder: two (GroupNorm, swish, 3x3 conv) stages.

    Args:
        in_channels: channel count of the input.
        out_channels: channel count of the output.  ``None`` (the default)
            means "same as ``in_channels``"; previously the default crashed
            while building the layers.

    When the channel counts differ, ``nin_shortcut`` is a kernel-size-1
    ``Conv2d`` that projects the input; otherwise it is an identity callable.
    """

    def __init__(self, in_channels, out_channels=None):
        # Make the documented default usable instead of passing None to the layers.
        if out_channels is None:
            out_channels = in_channels
        self.norm1 = GroupNorm(32, in_channels)
        self.conv1 = Conv2d(in_channels, out_channels, 3, padding=1)
        self.norm2 = GroupNorm(32, out_channels)
        self.conv2 = Conv2d(out_channels, out_channels, 3, padding=1)
        self.nin_shortcut = Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else lambda x: x

    def __call__(self, x):
        """Return shortcut(x) + residual branch output."""
        h = self.conv1(self.norm1(x).swish())
        h = self.conv2(self.norm2(h).swish())
        return self.nin_shortcut(x) + h
class Mid:
    """Middle section of the autoencoder: resnet block, attention block, resnet block."""

    def __init__(self, block_in):
        self.block_1 = ResnetBlock(block_in, block_in)
        self.attn_1 = AttnBlock(block_in)
        self.block_2 = ResnetBlock(block_in, block_in)

    def __call__(self, x):
        """Apply the three sub-blocks in order."""
        stages = [self.block_1, self.attn_1, self.block_2]
        return x.sequential(stages)
class Decoder:
    """Autoencoder decoder: conv_in -> mid -> upsampling stages -> norm/conv_out.

    ``self.up`` is a list of dicts.  Every dict has a ``'block'`` list of
    three ``ResnetBlock`` objects; all but the first also carry an
    ``'upsample'`` dict holding a 3x3 conv.  The list is traversed in
    reverse when the decoder is called.
    """

    def __init__(self):
        # (output channels, input channels) per stage.
        stage_channels = [(128, 256), (256, 512), (512, 512), (512, 512)]
        self.conv_in = Conv2d(4, 512, 3, padding=1)
        self.mid = Mid(512)

        stages = []
        for index, (out_ch, in_ch) in enumerate(stage_channels):
            stage = {"block": [ResnetBlock(in_ch, out_ch),
                               ResnetBlock(out_ch, out_ch),
                               ResnetBlock(out_ch, out_ch)]}
            stages.append(stage)
            if index != 0:
                stage['upsample'] = {"conv": Conv2d(out_ch, out_ch, 3, padding=1)}
        self.up = stages

        self.norm_out = GroupNorm(32, 128)
        self.conv_out = Conv2d(128, 3, 3, padding=1)

    def __call__(self, x):
        """Decode ``x`` through all stages and return the output of ``conv_out``."""
        x = self.mid(self.conv_in(x))

        for stage in reversed(self.up):
            for resnet in stage['block']:
                x = resnet(x)
            if 'upsample' in stage:
                # Double both spatial axes by repeating every element 2x2
                # (reshape + expand), then smooth with the stage's conv.
                # https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html ?
                batch, channels, rows, cols = x.shape
                x = x.reshape(batch, channels, rows, 1, cols, 1)
                x = x.expand(batch, channels, rows, 2, cols, 2)
                x = x.reshape(batch, channels, rows * 2, cols * 2)
                x = stage['upsample']['conv'](x)
            x.realize()

        return self.conv_out(self.norm_out(x).swish())
class Encoder:
    """Autoencoder encoder: conv_in -> downsampling stages -> mid -> norm/conv_out.

    ``self.down`` is a list of dicts.  Every dict has a ``'block'`` list of
    two ``ResnetBlock`` objects; all but the last also carry a
    ``'downsample'`` dict holding a stride-2 conv built with the padding
    tuple ``(0,1,0,1)``.
    """

    def __init__(self):
        # (input channels, output channels) per stage.
        stage_channels = [(128, 128), (128, 256), (256, 512), (512, 512)]
        self.conv_in = Conv2d(3, 128, 3, padding=1)

        stages = []
        for index, (in_ch, out_ch) in enumerate(stage_channels):
            stage = {"block": [ResnetBlock(in_ch, out_ch),
                               ResnetBlock(out_ch, out_ch)]}
            stages.append(stage)
            if index != 3:
                stage['downsample'] = {"conv": Conv2d(out_ch, out_ch, 3, stride=2, padding=(0,1,0,1))}
        self.down = stages

        self.mid = Mid(512)
        self.norm_out = GroupNorm(32, 512)
        self.conv_out = Conv2d(512, 8, 3, padding=1)

    def __call__(self, x):
        """Encode ``x`` through all stages and return the output of ``conv_out``."""
        x = self.conv_in(x)

        for stage in self.down:
            for resnet in stage['block']:
                x = resnet(x)
            if 'downsample' in stage:
                x = stage['downsample']['conv'](x)

        x = self.mid(x)
        return self.conv_out(self.norm_out(x).swish())
class AutoencoderKL:
    """Encoder/decoder pair joined by two kernel-size-1 quantisation convs."""

    def __init__(self):
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.quant_conv = Conv2d(8, 8, 1)
        self.post_quant_conv = Conv2d(4, 4, 1)

    def __call__(self, x):
        """Encode ``x``, keep the first four latent channels, and decode them again."""
        moments = self.quant_conv(self.encoder(x))
        means = moments[:, 0:4]  # only the means
        return self.decoder(self.post_quant_conv(means))
# not to be confused with ResnetBlock
class ResBlock:
    """UNet residual block conditioned on an embedding vector.

    The layer lists keep fixed positions (including identity placeholders)
    because the weight-loading code addresses layers by list index.
    """

    def __init__(self, channels, emb_channels, out_channels):
        self.in_layers = [
            GroupNorm(32, channels),
            Tensor.silu,
            Conv2d(channels, out_channels, 3, padding=1)
        ]
        self.emb_layers = [
            Tensor.silu,
            Linear(emb_channels, out_channels)
        ]
        self.out_layers = [
            GroupNorm(32, out_channels),
            Tensor.silu,
            lambda x: x,  # placeholder so list indices match the checkpoint layout
            Conv2d(out_channels, out_channels, 3, padding=1)
        ]
        self.skip_connection = Conv2d(channels, out_channels, 1) if channels != out_channels else lambda x: x

    def __call__(self, x, emb):
        """Return skip(x) + out_layers(in_layers(x) + projected embedding)."""
        hidden = x.sequential(self.in_layers)
        projected = emb.sequential(self.emb_layers)
        # Append two trailing size-1 axes so the embedding broadcasts over the spatial axes.
        hidden = hidden + projected.reshape(*projected.shape, 1, 1)
        hidden = hidden.sequential(self.out_layers)
        return self.skip_connection(x) + hidden
class CrossAttention:
    """Multi-head attention of ``x`` over ``context`` (or over ``x`` itself when no context is given)."""

    def __init__(self, query_dim, context_dim, n_heads, d_head):
        self.to_q = Linear(query_dim, n_heads*d_head, bias=False)
        self.to_k = Linear(context_dim, n_heads*d_head, bias=False)
        self.to_v = Linear(context_dim, n_heads*d_head, bias=False)
        self.scale = d_head ** -0.5
        self.num_heads = n_heads
        self.head_size = d_head
        # Kept as a one-element list; the weight-loading code addresses it by index.
        self.to_out = [Linear(n_heads*d_head, query_dim)]

    def __call__(self, x, context=None):
        """Attend from ``x`` to ``context`` and project back to ``query_dim``."""
        if context is None:
            context = x
        q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)

        batch = x.shape[0]
        heads, size = self.num_heads, self.head_size
        q = q.reshape(batch, -1, heads, size).permute(0,2,1,3)  # (bs, num_heads, time, head_size)
        k = k.reshape(batch, -1, heads, size).permute(0,2,3,1)  # (bs, num_heads, head_size, time)
        v = v.reshape(batch, -1, heads, size).permute(0,2,1,3)  # (bs, num_heads, time, head_size)

        weights = (q.dot(k) * self.scale).softmax()  # (bs, num_heads, time, time)
        attended = weights.dot(v).permute(0,2,1,3)  # (bs, time, num_heads, head_size)
        merged = attended.reshape(shape=(batch, -1, heads * size))
        return merged.sequential(self.to_out)
class GEGLU:
    """Gated linear unit: one projection split in two halves, the second passed through GELU as a gate."""

    def __init__(self, dim_in, dim_out):
        # A single Linear produces both the value half and the gate half.
        self.proj = Linear(dim_in, dim_out * 2)
        self.dim_out = dim_out

    def __call__(self, x):
        """Return ``value * gelu(gate)``."""
        value, gate = self.proj(x).chunk(2, dim=-1)
        return value * gate.gelu()
class FeedForward:
    """GEGLU expansion by ``mult`` followed by a Linear back to ``dim``."""

    def __init__(self, dim, mult=4):
        inner_dim = dim*mult
        self.net = [
            GEGLU(dim, inner_dim),
            lambda x: x,  # placeholder so list indices match the checkpoint layout
            Linear(inner_dim, dim)
        ]

    def __call__(self, x):
        """Run ``x`` through the layers of ``self.net`` in order."""
        out = x
        for layer in self.net:
            out = layer(out)
        return out
class BasicTransformerBlock:
    """Self-attention, cross-attention and feed-forward, each pre-normalised and residual."""

    def __init__(self, dim, context_dim, n_heads, d_head):
        self.attn1 = CrossAttention(dim, dim, n_heads, d_head)
        self.ff = FeedForward(dim)
        self.attn2 = CrossAttention(dim, context_dim, n_heads, d_head)
        self.norm1 = LayerNorm(dim)
        self.norm2 = LayerNorm(dim)
        self.norm3 = LayerNorm(dim)

    def __call__(self, x, context=None):
        """Apply the three residual sub-layers; only ``attn2`` receives ``context``."""
        self_attended = self.attn1(self.norm1(x)) + x
        cross_attended = self.attn2(self.norm2(self_attended), context=context) + self_attended
        return self.ff(self.norm3(cross_attended)) + cross_attended
class SpatialTransformer:
    """Runs a transformer block over the flattened spatial positions of a feature map."""

    def __init__(self, channels, context_dim, n_heads, d_head):
        self.norm = GroupNorm(32, channels)
        assert channels == n_heads * d_head
        inner_dim = n_heads * d_head
        self.proj_in = Conv2d(channels, inner_dim, 1)
        self.transformer_blocks = [BasicTransformerBlock(channels, context_dim, n_heads, d_head)]
        self.proj_out = Conv2d(inner_dim, channels, 1)

    def __call__(self, x, context=None):
        """Return ``proj_out(transformer(proj_in(norm(x)))) + x``."""
        batch, channels, height, width = x.shape
        residual = x

        hidden = self.proj_in(self.norm(x))
        # (b, c, h, w) -> (b, h*w, c) so that every spatial position is a token.
        hidden = hidden.reshape(batch, channels, height*width).permute(0,2,1)
        for block in self.transformer_blocks:
            hidden = block(hidden, context=context)
        # Back to the (b, c, h, w) layout.
        hidden = hidden.permute(0,2,1).reshape(batch, channels, height, width)

        return self.proj_out(hidden) + residual
class Downsample:
    """Channel-preserving 3x3 convolution with stride 2 and padding 1."""

    def __init__(self, channels):
        self.op = Conv2d(channels, channels, 3, stride=2, padding=1)

    def __call__(self, x):
        """Apply the strided convolution to ``x``."""
        downsampled = self.op(x)
        return downsampled
class Upsample:
    """Doubles both spatial axes by element repetition, then applies a 3x3 conv."""

    def __init__(self, channels):
        self.conv = Conv2d(channels, channels, 3, padding=1)

    def __call__(self, x):
        """Return the convolved 2x-upsampled version of ``x``."""
        batch, channels, rows, cols = x.shape
        # Insert size-1 axes, expand them to 2, and fold them into the spatial axes.
        x = x.reshape(batch, channels, rows, 1, cols, 1)
        x = x.expand(batch, channels, rows, 2, cols, 2)
        x = x.reshape(batch, channels, rows*2, cols*2)
        return self.conv(x)
def timestep_embedding(timesteps, dim, max_period=10000):
    """Build a sinusoidal embedding of ``timesteps``.

    ``dim // 2`` frequencies are generated; the cosines and sines of
    ``timesteps * frequency`` are concatenated and reshaped to a single row.
    """
    half = dim // 2
    exponents = -math.log(max_period) * Tensor.arange(half) / half
    angles = timesteps * exponents.exp()
    return Tensor.cat(angles.cos(), angles.sin()).reshape(1, -1)
class UNetModel:
    """Diffusion UNet with skip connections between input and output blocks.

    The nested list layout of ``input_blocks``, ``middle_block``,
    ``output_blocks`` and ``out`` must stay exactly as written because the
    weight-loading code addresses layers by list index.
    """

    def __init__(self):
        self.time_embed = [
            Linear(320, 1280),
            Tensor.silu,
            Linear(1280, 1280),
        ]
        self.input_blocks = [
            [Conv2d(4, 320, kernel_size=3, padding=1)],
            [ResBlock(320, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
            [ResBlock(320, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
            [Downsample(320)],
            [ResBlock(320, 1280, 640), SpatialTransformer(640, 768, 8, 80)],
            [ResBlock(640, 1280, 640), SpatialTransformer(640, 768, 8, 80)],
            [Downsample(640)],
            [ResBlock(640, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
            [ResBlock(1280, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
            [Downsample(1280)],
            [ResBlock(1280, 1280, 1280)],
            [ResBlock(1280, 1280, 1280)]
        ]
        self.middle_block = [
            ResBlock(1280, 1280, 1280),
            SpatialTransformer(1280, 768, 8, 160),
            ResBlock(1280, 1280, 1280)
        ]
        self.output_blocks = [
            [ResBlock(2560, 1280, 1280)],
            [ResBlock(2560, 1280, 1280)],
            [ResBlock(2560, 1280, 1280), Upsample(1280)],
            [ResBlock(2560, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
            [ResBlock(2560, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
            [ResBlock(1920, 1280, 1280), SpatialTransformer(1280, 768, 8, 160), Upsample(1280)],
            [ResBlock(1920, 1280, 640), SpatialTransformer(640, 768, 8, 80)],  # 6
            [ResBlock(1280, 1280, 640), SpatialTransformer(640, 768, 8, 80)],
            [ResBlock(960, 1280, 640), SpatialTransformer(640, 768, 8, 80), Upsample(640)],
            [ResBlock(960, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
            [ResBlock(640, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
            [ResBlock(640, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
        ]
        self.out = [
            GroupNorm(32, 320),
            Tensor.silu,
            Conv2d(320, 4, kernel_size=3, padding=1)
        ]

    def __call__(self, x, timesteps=None, context=None):
        """Run the UNet on ``x`` for the given ``timesteps`` and conditioning ``context``."""
        # TODO: real time embedding
        emb = timestep_embedding(timesteps, 320).sequential(self.time_embed)

        def apply_layer(hidden, layer):
            # ResBlocks consume the time embedding, transformers the context,
            # everything else (convs, up/down-sampling) only the activations.
            if isinstance(layer, ResBlock):
                return layer(hidden, emb)
            if isinstance(layer, SpatialTransformer):
                return layer(hidden, context)
            return layer(hidden)

        # Down path: remember the output of every input block for the skip connections.
        skips = []
        for group in self.input_blocks:
            for layer in group:
                x = apply_layer(x, layer)
            skips.append(x)

        for layer in self.middle_block:
            x = apply_layer(x, layer)

        # Up path: concatenate the most recent skip on the channel axis before each group.
        for group in self.output_blocks:
            x = x.cat(skips.pop(), dim=1)
            for layer in group:
                x = apply_layer(x, layer)

        return x.sequential(self.out)
class CLIPMLP:
    """Two-layer MLP (768 -> 3072 -> 768) with a quick-GELU in between."""

    def __init__(self):
        self.fc1 = Linear(768, 3072)
        self.fc2 = Linear(3072, 768)

    def __call__(self, hidden_states):
        """Return ``fc2(quick_gelu(fc1(hidden_states)))``."""
        expanded = self.fc1(hidden_states)
        activated = expanded.quick_gelu()
        return self.fc2(activated)
class CLIPAttention:
    """Multi-head self-attention of the CLIP text encoder (768 dims, 12 heads)."""

    def __init__(self):
        self.embed_dim = 768
        self.num_heads = 12
        self.head_dim = self.embed_dim // self.num_heads
        self.scale = self.head_dim**-0.5
        self.k_proj = Linear(self.embed_dim, self.embed_dim)
        self.v_proj = Linear(self.embed_dim, self.embed_dim)
        self.q_proj = Linear(self.embed_dim, self.embed_dim)
        self.out_proj = Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor, seq_len: int, bsz: int):
        """Split the last axis into heads: result is (bsz, num_heads, seq_len, head_dim)."""
        split = tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim)
        return split.permute(0,2,1,3)

    def __call__(self, hidden_states, causal_attention_mask):
        """Attend within ``hidden_states``; ``causal_attention_mask`` is added to the raw scores."""
        bsz, tgt_len, embed_dim = hidden_states.shape
        heads, head_dim = self.num_heads, self.head_dim
        # Heads are folded into the batch axis for the matrix products.
        folded = (bsz * heads, -1, head_dim)

        # The query is scaled before the dot product.
        query = self.q_proj(hidden_states) * self.scale
        key = self._shape(self.k_proj(hidden_states), -1, bsz)
        value = self._shape(self.v_proj(hidden_states), -1, bsz)

        query = self._shape(query, tgt_len, bsz).reshape(*folded)
        key = key.reshape(*folded)
        src_len = key.shape[1]
        value = value.reshape(*folded)

        scores = query @ key.permute(0,2,1)
        # Unfold the heads to add the mask, then fold them back for the softmax.
        scores = scores.reshape(bsz, heads, tgt_len, src_len) + causal_attention_mask
        scores = scores.reshape(bsz * heads, tgt_len, src_len)
        weights = scores.softmax()

        context = weights @ value
        context = context.reshape(bsz, heads, tgt_len, head_dim)
        context = context.permute(0,2,1,3)
        context = context.reshape(bsz, tgt_len, embed_dim)
        return self.out_proj(context)
class CLIPEncoderLayer:
    """One CLIP encoder layer: pre-norm self-attention and pre-norm MLP, both residual."""

    def __init__(self):
        self.self_attn = CLIPAttention()
        self.layer_norm1 = LayerNorm(768)
        self.mlp = CLIPMLP()
        self.layer_norm2 = LayerNorm(768)

    def __call__(self, hidden_states, causal_attention_mask):
        """Apply the attention and MLP sub-layers, each added onto its own input."""
        attended = self.self_attn(self.layer_norm1(hidden_states), causal_attention_mask)
        after_attention = hidden_states + attended

        transformed = self.mlp(self.layer_norm2(after_attention))
        return after_attention + transformed
class CLIPEncoder:
    """Stack of twelve ``CLIPEncoderLayer`` objects applied in sequence."""

    def __init__(self):
        self.layers = [CLIPEncoderLayer() for _ in range(12)]

    def __call__(self, hidden_states, causal_attention_mask):
        """Feed ``hidden_states`` through every layer with the same mask."""
        out = hidden_states
        for layer in self.layers:
            out = layer(out, causal_attention_mask)
        return out
class CLIPTextEmbeddings:
    """Sum of a token embedding (49408 x 768) and a position embedding (77 x 768)."""

    def __init__(self):
        self.token_embedding = Embedding(49408, 768)
        self.position_embedding = Embedding(77, 768)

    def __call__(self, input_ids, position_ids):
        """Return token embeddings plus position embeddings."""
        tokens = self.token_embedding(input_ids)
        positions = self.position_embedding(position_ids)
        return tokens + positions
class CLIPTextTransformer:
    """CLIP text model: embeddings -> masked encoder stack -> final layer norm."""

    def __init__(self):
        self.embeddings = CLIPTextEmbeddings()
        self.encoder = CLIPEncoder()
        self.final_layer_norm = LayerNorm(768)

    def __call__(self, input_ids):
        """Encode ``input_ids`` (sequence length taken from axis 1)."""
        seq_len = input_ids.shape[1]
        position_ids = Tensor.arange(seq_len).reshape(1, -1)
        hidden = self.embeddings(input_ids, position_ids)
        # Additive mask built from -inf with triu(1) applied, i.e. -inf above the diagonal.
        causal_mask = Tensor.full((1, 1, seq_len, seq_len), float("-inf")).triu(1)
        hidden = self.encoder(hidden, causal_mask)
        return self.final_layer_norm(hidden)
# Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
@lru_cache()
def default_bpe():
    """Return the default path of the gzipped BPE vocabulary.

    The path is ``weights/bpe_simple_vocab_16e6.txt.gz`` relative to the
    directory two levels above this file; the result is cached.
    """
    project_root = Path(__file__).parent.parent
    return project_root / "weights/bpe_simple_vocab_16e6.txt.gz"
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    ``word`` is a sequence of symbols (symbols being variable-length
    strings), typically a tuple.  An empty or single-symbol word has no
    pairs and yields an empty set; an empty word used to raise IndexError.
    """
    # Pair every symbol with its successor; zip stops at the shorter
    # sequence, so empty and one-element words naturally give no pairs.
    return set(zip(word, word[1:]))
def whitespace_clean(text):
    """Collapse every run of whitespace into a single space and trim both ends."""
    collapsed = re.sub(r'\s+', ' ', text)
    return collapsed.strip()
def bytes_to_unicode():
    """
    Return a dict mapping each of the 256 byte values to a unicode character.

    The reversible bpe codes work on unicode strings, so every byte needs a
    character that is neither whitespace nor a control character (the bpe
    code barfs on those).  Bytes in the three printable ranges map to
    themselves; every other byte is assigned the next free code point
    starting at 256.

    The insertion order (printable bytes first, then the remaining bytes in
    ascending order) is significant: callers build the vocabulary from
    ``.values()``.
    """
    printable = [
        *range(ord("!"), ord("~")+1),
        *range(ord("¡"), ord("¬")+1),
        *range(ord("®"), ord("ÿ")+1),
    ]
    mapping = {byte: chr(byte) for byte in printable}

    next_code = 2**8
    for byte in range(2**8):
        if byte in mapping:
            continue
        mapping[byte] = chr(next_code)
        next_code += 1
    return mapping
class ClipTokenizer:
    """Byte-level BPE tokenizer for CLIP prompts.

    ``encode`` always returns 77 ids: the start token 49406, up to 75 BPE
    ids, and the end token 49407 repeated as padding.
    """

    def __init__(self, bpe_path: str = default_bpe()):
        """Load the merge table from the gzipped vocabulary file at ``bpe_path``."""
        self.byte_encoder = bytes_to_unicode()
        # Read through a context manager so the gzip handle is closed
        # deterministically instead of being leaked.
        with gzip.open(bpe_path) as vocab_file:
            merges = vocab_file.read().decode("utf-8").split('\n')
        # Skip the header line and keep exactly the merges that fit the vocabulary size.
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        # Vocabulary order: byte symbols, byte symbols + '</w>', merges, special tokens.
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE)

    def bpe(self, token):
        """Return ``token`` split into space-separated BPE symbols (results are cached)."""
        if token in self.cache:
            return self.cache[token]
        # The final symbol carries the end-of-word marker.
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            # Merge the pair with the lowest rank; stop when no known pair is left.
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # `first` does not occur again: copy the tail unchanged.
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Tokenize ``text`` into a fixed-length list of 77 token ids."""
        bpe_tokens = []
        text = whitespace_clean(text.strip()).lower()
        for token in re.findall(self.pat, text):
            # Map raw UTF-8 bytes to their printable stand-ins before applying BPE.
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        # Truncation, keeping two slots for start and end tokens.
        if len(bpe_tokens) > 75:
            bpe_tokens = bpe_tokens[:75]
        return [49406] + bpe_tokens + [49407] * (77 - len(bpe_tokens) - 1)
class StableDiffusion:
    """Container whose attribute layout mirrors the checkpoint's state-dict keys.

    The nested namedtuples only exist to reproduce the key paths
    ``model.diffusion_model`` and ``cond_stage_model.transformer.text_model``.
    """

    def __init__(self):
        self.alphas_cumprod = Tensor.empty(1000)

        diffusion_wrapper = namedtuple("DiffusionModel", ["diffusion_model"])
        self.model = diffusion_wrapper(diffusion_model = UNetModel())

        self.first_stage_model = AutoencoderKL()

        cond_wrapper = namedtuple("CondStageModel", ["transformer"])
        transformer_wrapper = namedtuple("Transformer", ["text_model"])
        text_transformer = transformer_wrapper(text_model = CLIPTextTransformer())
        self.cond_stage_model = cond_wrapper(transformer = text_transformer)
# TODO: make __call__ run the model
# ** ldm.models.autoencoder.AutoencoderKL (done!)
# 3x512x512 <--> 4x64x64 (16384)
# decode torch.Size([1, 4, 64, 64]) torch.Size([1, 3, 512, 512])
# section 4.3 of paper
# first_stage_model.encoder, first_stage_model.decoder
# ** ldm.modules.diffusionmodules.openaimodel.UNetModel
# this is what runs each time to sample. is this the LDM?
# input: 4x64x64
# output: 4x64x64
# model.diffusion_model
# it has attention?
# ** ldm.modules.encoders.modules.FrozenCLIPEmbedder
# cond_stage_model.transformer.text_model
# this is sd-v1-4.ckpt
FILENAME = Path(__file__).parent.parent / "weights/sd-v1-4.ckpt"
import sys
import clip as clipsave
import autoencoder as autoencodersave
import unet as unetsave
import stablediffusion as sdsave
import numpy as np
if __name__ == "__main__":
    # Script entry point: load a checkpoint given on the command line and
    # dump its weights into the "params" directory.

    # No gradients are needed; the script only loads and dumps weights.
    Tensor.no_grad = True

    if len(sys.argv) != 2:
        print(f"Wrong command line parameters, Usage: python3 {sys.argv[0]} <model_filename>")
        # Exit with a non-zero status so shells and scripts can detect the usage error.
        sys.exit(1)
    FILENAME = sys.argv[1]

    model = StableDiffusion()

    # load in weights
    #download_file('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', FILENAME)
    load_state_dict(model, torch_load(FILENAME)['state_dict'], strict=False)

    print('Dumping model...')
    sdsave.save_stable_diffusion(model, "params")
    print('Model weights saved in params.')

1
python/requirements.txt Normal file
View File

@@ -0,0 +1 @@
tinygrad==0.9.2

100
python/save.py Normal file
View File

@@ -0,0 +1,100 @@
import pathlib
import numpy as np
from tinygrad.tensor import Tensor
def save_scalar(s, name, path):
    """Write the scalar ``s`` to ``<path>/<name>.npy`` as a float32 array ``[1.0, s]``.

    The leading 1.0 occupies the same position that ``save_tensor`` uses for
    the shape header.
    """
    payload = np.array([1.0, float(s)]).astype(np.float32)
    target = pathlib.Path(path, f'{name}.npy')
    np.save(target, payload)
def save_tensor(tensor, name, path):
    """Write ``tensor`` to ``<path>/<name>.npy`` as one flat float32 array.

    The array starts with the tensor's dimensions, followed by its
    flattened values.
    """
    values = tensor.numpy()
    header = np.array(values.shape)
    payload = np.concatenate((header, values.flatten())).astype(np.float32)
    target = pathlib.Path(path, f'{name}.npy')
    np.save(target, payload)
def save_linear(linear, path):
    """Dump a linear layer's transposed weight and, if present, its bias under ``path``."""
    pathlib.Path(path).mkdir(parents=True, exist_ok=True)

    # PyTorch and Tinygrad store linear weights transposed, so undo that before saving.
    transposed_weight = linear.weight.transpose()
    save_tensor(transposed_weight, 'weight', path)

    bias = linear.bias
    if bias is not None:
        save_tensor(bias, 'bias', path)
def save_layer_norm(layer_norm, path):
    """Dump a layer norm's ``weight``, ``bias`` and ``eps`` under ``path``."""
    pathlib.Path(path).mkdir(parents=True, exist_ok=True)
    for attr in ('weight', 'bias'):
        save_tensor(getattr(layer_norm, attr), attr, path)
    save_scalar(layer_norm.eps, 'eps', path)
def save_group_norm(layer_norm, path):
    """Dump a group norm under ``path``.

    ``weight`` and ``bias`` are written only when they are not ``None``;
    ``eps``, ``n_group`` and ``n_channel`` are always written.
    """
    pathlib.Path(path).mkdir(parents=True, exist_ok=True)

    for attr in ('weight', 'bias'):
        parameter = getattr(layer_norm, attr)
        if parameter is not None:
            save_tensor(parameter, attr, path)

    scalars = (
        (layer_norm.eps, 'eps'),
        (layer_norm.num_groups, 'n_group'),
        (layer_norm.num_channels, 'n_channel'),
    )
    for value, name in scalars:
        save_scalar(value, name, path)
def to_tuple_tensor(val):
    """Normalise a scalar or a 1-/2-tuple into a two-element ``Tensor``.

    A scalar or 1-tuple is duplicated; a 2-tuple is used as is.

    Raises:
        ValueError: If ``val`` is a tuple whose length is not 1 or 2.
    """
    if not isinstance(val, tuple):
        # Plain scalar: same value for both entries.
        return Tensor([val, val])
    if len(val) == 1:
        return Tensor([val[0], val[0]])
    if len(val) == 2:
        return Tensor([val[0], val[1]])
    raise ValueError('Tuple should be of length 1 or 2 only.')
def save_conv2d(conv2d, path):
    """Dump a 2D convolution into the directory ``path`` (created if missing).

    Writes the weight, the optional bias, the hyper-parameters ``stride``,
    ``padding``, ``dilation`` and ``kernel_size`` (each as a two-element
    tensor), the group count, and the input/output channel counts taken
    from the weight's shape. Only ``groups == 1`` is supported.
    """
    pathlib.Path(path).mkdir(parents=True, exist_ok=True)
    save_tensor(conv2d.weight, 'weight', path)
    if conv2d.bias is not None:
        save_tensor(conv2d.bias, 'bias', path)
    for hyper in ('stride', 'padding', 'dilation'):
        save_tensor(to_tuple_tensor(getattr(conv2d, hyper)), hyper, path)
    save_scalar(conv2d.groups, "n_group", path)
    save_tensor(to_tuple_tensor(conv2d.kernel_size), 'kernel_size', path)
    # With a single group the weight's second axis is the input channel count.
    assert conv2d.groups == 1
    weight_shape = conv2d.weight.shape
    save_scalar(weight_shape[1], "n_channels_in", path)
    save_scalar(weight_shape[0], "n_channels_out", path)
def save_padded_conv2d(padded_conv2d, path):
    """Dump a convolution that uses explicit four-sided padding.

    The wrapped convolution is written to ``<path>/conv`` with its padding
    temporarily replaced by ``(0, 0)`` (``save_conv2d`` only accepts 1- or
    2-element paddings); the real four-element padding is stored separately
    next to ``channels``, ``kernel_size`` and ``stride``.

    Args:
        padded_conv2d: Convolution whose ``padding`` has four entries, whose
            kernel is square and whose stride is a scalar or a 1-tuple.
        path: Output directory (created if missing).
    """
    pathlib.Path(path).mkdir(parents=True, exist_ok=True)

    # Store conv2d layer weights. The padding override must be undone even if
    # saving fails, otherwise the caller's model is left silently modified.
    orig_padding = padded_conv2d.padding
    padded_conv2d.padding = (0, 0)
    try:
        save_conv2d(padded_conv2d, f"{path}/conv")
    finally:
        padded_conv2d.padding = orig_padding

    # Dimensions: in-channels and out-channels
    assert padded_conv2d.groups == 1
    channels = (padded_conv2d.weight.shape[1], padded_conv2d.weight.shape[0])
    save_tensor(to_tuple_tensor(channels), 'channels', path)

    kernel_size = padded_conv2d.kernel_size
    assert len(kernel_size) == 1 or kernel_size[0] == kernel_size[1]
    save_scalar(kernel_size[0], 'kernel_size', path)

    # Stride: the assertion admits a 1-tuple, so unwrap it before it reaches
    # float() inside save_scalar.
    stride = padded_conv2d.stride
    assert not isinstance(stride, tuple) or len(stride) == 1
    if isinstance(stride, tuple):
        stride = stride[0]
    save_scalar(stride, 'stride', path)

    # Padding (all four entries, in the order the layer stores them)
    padding = [orig_padding[0], orig_padding[1],
               orig_padding[2], orig_padding[3]]
    save_tensor(Tensor(padding), 'padding', path)
def save_embedding(embedding, path):
    """Dump an embedding table: only its ``weight`` tensor is written."""
    out_dir = pathlib.Path(path)
    out_dir.mkdir(parents=True, exist_ok=True)
    save_tensor(embedding.weight, 'weight', path)

14
python/stablediffusion.py Normal file
View File

@@ -0,0 +1,14 @@
import pathlib
from autoencoder import save_autoencoder
from unet import save_unet_model
from clip import save_clip_text_transformer
from save import save_scalar, save_tensor
def save_stable_diffusion(stable_diffusion, path):
    """Dump a whole Stable Diffusion model under the directory ``path``.

    Writes the ``alphas_cumprod`` schedule and its length (``n_steps``) at
    the top level, then the autoencoder, the UNet and the CLIP text
    transformer into the ``autoencoder``, ``unet`` and ``clip``
    sub-directories.
    """
    pathlib.Path(path).mkdir(parents=True, exist_ok=True)

    schedule = stable_diffusion.alphas_cumprod
    save_scalar(schedule.shape[0], "n_steps", path)
    save_tensor(schedule, 'alphas_cumprod', path)

    save_autoencoder(
        stable_diffusion.first_stage_model,
        pathlib.Path(path, 'autoencoder'),
    )
    save_unet_model(
        stable_diffusion.model.diffusion_model,
        pathlib.Path(path, 'unet'),
    )
    save_clip_text_transformer(
        stable_diffusion.cond_stage_model.transformer.text_model,
        pathlib.Path(path, 'clip'),
    )

54
python/test.py Normal file
View File

@@ -0,0 +1,54 @@
import torch
import torch.nn as nn
from torch import Tensor
import math
'''import torch
import torch.nn as nn
import torch
norm = torch.nn.LayerNorm(3)
tensor = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape((2, 3))
out = norm(tensor)
print(out)'''
'''n_channel = 6
norm = nn.LayerNorm(10)
height = 10
width = 10
n_elements = height * width * n_channel
t = torch.arange(0, n_elements, dtype=torch.float32).mul_(10.0 / n_elements).sin().reshape(1, n_channel, height, width)
out = norm(t)
print(out)'''
def timestep_embedding(timesteps, dim, max_period=10000):
    """Build a sinusoidal embedding for ``timesteps``.

    ``dim // 2`` frequencies are generated as
    ``exp(-log(max_period) * i / half)``; the cosine and sine of
    ``timesteps * freqs`` are concatenated along the first axis and the
    result is flattened into a single row.

    Returns:
        A tensor of shape ``(1, N)`` containing every cosine followed by
        every sine.
    """
    half = dim // 2
    exponents = -math.log(max_period) * torch.arange(half) / half
    freqs = torch.exp(exponents)
    args = timesteps * freqs
    stacked = torch.cat((torch.cos(args), torch.sin(args)))
    return stacked.reshape(1, -1)
# Demo: embed three timesteps (as a 3x1 column) with a 10-wide embedding and
# print the result.
dim = 10
timesteps = Tensor([1, 2, 3]).reshape((3, 1))
res = timestep_embedding(timesteps, dim)
print(res)
'''n_group = 3
n_channel = 6
norm = nn.GroupNorm(n_group, n_channel)
height = 10
width = 10
n_elements = height * width * n_channel
t = torch.arange(0, n_elements, dtype=torch.float32).mul_(10.0 / n_elements).sin().reshape(1, n_channel, height, width)
out = norm(t)
print(out.flatten())'''

41
python/test_tiny.py Normal file
View File

@@ -0,0 +1,41 @@
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
import math
'''import torch
import torch.nn as nn
import torch
norm = torch.nn.LayerNorm(3)
tensor = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape((2, 3))
out = norm(tensor)
print(out)'''
# Demo: run tinygrad's LayerNorm(10) over a deterministic sine-wave input of
# shape (1, n_channel, height, width) and print the result as a NumPy array.
n_channel = 6
height = 10
width = 10
n_elements = height * width * n_channel
norm = LayerNorm(10)
ramp = Tensor.arange(n_elements).mul(10.0 / n_elements)
t = ramp.sin().reshape(1, n_channel, height, width)
out = norm(t)
print(out.numpy())
'''n_group = 3
n_channel = 6
norm = nn.GroupNorm(n_group, n_channel)
height = 10
width = 10
n_elements = height * width * n_channel
t = torch.arange(0, n_elements, dtype=torch.float32).mul_(10.0 / n_elements).sin().reshape(1, n_channel, height, width)
out = norm(t)
print(out.flatten())'''

145
python/tokenizer.py Normal file
View File

@@ -0,0 +1,145 @@
import gzip
import html
import os
from functools import lru_cache
import ftfy
import regex as re
@lru_cache()
def default_bpe():
    """Return the path of ``bpe_simple_vocab_16e6.txt.gz`` next to this file."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(module_dir, "bpe_simple_vocab_16e6.txt.gz")
@lru_cache()
def bytes_to_unicode():
    """Map every byte value (0-255) to a distinct single unicode character.

    Bytes in the ranges ``!``-``~``, ``¡``-``¬`` and ``®``-``ÿ`` map to the
    character with the same code point. Every other byte is assigned the
    next unused code point starting at 256, in increasing byte order, so
    that no byte maps to a whitespace or control character.

    Returns:
        dict[int, str] with exactly 256 entries.
    """
    kept = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    # Bytes that need a substitute character, in increasing order.
    shifted = [b for b in range(2**8) if b not in kept]
    byte_values = kept + shifted
    code_points = kept + [2**8 + offset for offset in range(len(shifted))]
    return {b: chr(cp) for b, cp in zip(byte_values, code_points)}
def get_pairs(word):
    """Return the set of adjacent symbol pairs in ``word``.

    ``word`` is a non-empty sequence of symbols (each symbol being a
    variable-length string). An empty ``word`` raises ``IndexError``.
    """
    word[0]  # keep the IndexError on empty input
    return {(left, right) for left, right in zip(word, word[1:])}
def basic_clean(text):
    """Repair text with ``ftfy``, undo HTML escaping twice, and strip it."""
    repaired = ftfy.fix_text(text)
    # Unescape two levels, e.g. "&amp;lt;" -> "&lt;" -> "<".
    unescaped = html.unescape(html.unescape(repaired))
    return unescaped.strip()
def whitespace_clean(text):
    """Collapse each whitespace run to one space and trim both ends."""
    return re.sub(r'\s+', ' ', text).strip()
class SimpleTokenizer(object):
    """Byte-level BPE tokenizer driven by a gzip-compressed merges file.

    The vocabulary is built as: the 256 byte characters, the same characters
    with an end-of-word marker ``</w>``, one entry per merge rule, and finally
    the two special tokens ``<|startoftext|>`` and ``<|endoftext|>``.
    """

    def __init__(self, bpe_path: str = default_bpe()):
        """Load the merge rules from ``bpe_path`` and build the lookup tables."""
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Close the archive deterministically instead of leaking the handle.
        with gzip.open(bpe_path) as bpe_file:
            merges = bpe_file.read().decode("utf-8").split('\n')
        # Skip the first line and keep as many merges as fit the vocabulary.
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # Special tokens bypass BPE and must map to themselves; mapping the
        # end marker to the start marker would encode it with the wrong id.
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        """Apply the merge rules to one token.

        Returns the resulting symbols joined by single spaces; results are
        memoised in ``self.cache``.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            # Pick the adjacent pair with the lowest (earliest) merge rank.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # `first` does not occur again: copy the tail unchanged.
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean and lowercase ``text`` and return its list of token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Turn token ids back into text; ``</w>`` markers become spaces."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
if __name__ == "__main__":
    # Smoke test: decode a few raw ids, then round-trip a sample sentence.
    simple_tokenizer = SimpleTokenizer()
    #print([simple_tokenizer.byte_encoder[0], simple_tokenizer.byte_encoder[1]])

    text = simple_tokenizer.decode([0, 1, 59, 67, 23])
    print(f"Text: {text}")

    encoded = simple_tokenizer.encode("Hello world! <|startoftext|>asdf<|startoftext|>")
    decoded = simple_tokenizer.decode(encoded)
    print(f"Encoded: {encoded}")
    print(f"Decoded: {decoded}")

153
python/unet.py Normal file
View File

@@ -0,0 +1,153 @@
import pathlib
import os
import save
from save import *
from tinygrad.nn import Conv2d
def save_res_block(res_block, path):
    """Dump a UNet residual block into the directory ``path``.

    Only parameterised layers are written (activations carry no state):
    ``norm_in``/``conv_in`` from ``in_layers``, ``lin_embed`` from
    ``emb_layers``, ``norm_out``/``conv_out`` from ``out_layers`` and, when
    the skip connection is a ``Conv2d``, ``skip_connection``.
    """
    pathlib.Path(path).mkdir(parents=True, exist_ok=True)

    def sub(name):
        return os.path.join(path, name)

    # Input half: group norm followed by convolution.
    save_group_norm(res_block.in_layers[0], sub('norm_in'))
    save_conv2d(res_block.in_layers[2], sub('conv_in'))
    # Embedding projection.
    save_linear(res_block.emb_layers[1], sub('lin_embed'))
    # Output half: group norm followed by convolution.
    save_group_norm(res_block.out_layers[0], sub('norm_out'))
    save_conv2d(res_block.out_layers[3], sub('conv_out'))
    # Any non-Conv2d skip connection is not written.
    skip = res_block.skip_connection
    if isinstance(skip, Conv2d):
        save_conv2d(skip, sub('skip_connection'))
def save_cross_attention(cross_attention, path):
    """Dump a cross-attention module: four linear layers plus ``n_head``."""
    pathlib.Path(path).mkdir(parents=True, exist_ok=True)

    def sub(name):
        return os.path.join(path, name)

    # Projections
    save_linear(cross_attention.to_q, sub('query'))
    save_linear(cross_attention.to_k, sub('key'))
    save_linear(cross_attention.to_v, sub('value'))
    save_linear(cross_attention.to_out[0], sub('out'))
    # Hyper-parameter
    save_scalar(cross_attention.num_heads, 'n_head', path)
def save_geglu(geglu, path):
    """Dump a GEGLU module: its single projection goes to ``<path>/proj``."""
    pathlib.Path(path).mkdir(parents=True, exist_ok=True)
    proj_dir = os.path.join(path, 'proj')
    save_linear(geglu.proj, proj_dir)
def save_feed_forward(feed_forward, path):
    """Dump a feed-forward module: ``net[0]`` as ``geglu``, ``net[2]`` as ``lin``."""
    pathlib.Path(path).mkdir(parents=True, exist_ok=True)
    gate, projection = feed_forward.net[0], feed_forward.net[2]
    save_geglu(gate, os.path.join(path, 'geglu'))
    save_linear(projection, os.path.join(path, 'lin'))
def save_basic_transformer_block(basic_transformer_block, path):
    """Dump a transformer block: two attentions, the MLP and three layer norms."""
    pathlib.Path(path).mkdir(parents=True, exist_ok=True)

    def sub(name):
        return os.path.join(path, name)

    block = basic_transformer_block
    save_cross_attention(block.attn1, sub('attn1'))
    save_feed_forward(block.ff, sub('mlp'))
    save_cross_attention(block.attn2, sub('attn2'))
    for norm_name in ('norm1', 'norm2', 'norm3'):
        save_layer_norm(getattr(block, norm_name), sub(norm_name))
def save_spatial_transformer(spatial_transformer, path):
    """Dump a spatial transformer.

    Writes ``norm``, ``proj_in``, the first entry of ``transformer_blocks``
    (as ``transformer``) and ``proj_out``.
    """
    pathlib.Path(path).mkdir(parents=True, exist_ok=True)

    def sub(name):
        return os.path.join(path, name)

    save_group_norm(spatial_transformer.norm, sub('norm'))
    save_conv2d(spatial_transformer.proj_in, sub('proj_in'))
    # Only transformer_blocks[0] is saved.
    save_basic_transformer_block(spatial_transformer.transformer_blocks[0], sub('transformer'))
    save_conv2d(spatial_transformer.proj_out, sub('proj_out'))
def save_downsample(downsample, path):
    """Dump a downsample module: its ``op`` convolution is written directly into ``path``."""
    out_dir = pathlib.Path(path)
    out_dir.mkdir(parents=True, exist_ok=True)
    save_conv2d(downsample.op, path)
def save_upsample(upsample, path):
    """Dump an upsample module: its convolution goes to ``<path>/conv``."""
    pathlib.Path(path).mkdir(parents=True, exist_ok=True)
    conv_dir = os.path.join(path, 'conv')
    save_conv2d(upsample.conv, conv_dir)
def save_res_transformer_res(block, path):
    """Dump a (res block, spatial transformer, res block) triple as ``res1``/``transformer``/``res2``."""
    first_res, transformer, second_res = block[0], block[1], block[2]
    save_res_block(first_res, pathlib.Path(path, 'res1'))
    save_spatial_transformer(transformer, pathlib.Path(path, 'transformer'))
    save_res_block(second_res, pathlib.Path(path, 'res2'))
def save_res_upsample(block, path):
    """Dump a (res block, upsample) pair as ``res``/``upsample``."""
    res, upsample = block[0], block[1]
    save_res_block(res, pathlib.Path(path, 'res'))
    save_upsample(upsample, pathlib.Path(path, 'upsample'))
def save_res_transformer(block, path):
    """Dump a (res block, spatial transformer) pair as ``res``/``transformer``."""
    res, transformer = block[0], block[1]
    save_res_block(res, pathlib.Path(path, 'res'))
    save_spatial_transformer(transformer, pathlib.Path(path, 'transformer'))
def save_res_transformer_upsample(block, path):
    """Dump a (res block, spatial transformer, upsample) triple as ``res``/``transformer``/``upsample``."""
    res, transformer, upsample = block[0], block[1], block[2]
    save_res_block(res, pathlib.Path(path, 'res'))
    save_spatial_transformer(transformer, pathlib.Path(path, 'transformer'))
    save_upsample(upsample, pathlib.Path(path, 'upsample'))
def save_unet_input_blocks(input_blocks, path):
    """Dump the twelve UNet input blocks under ``path``.

    The layout is fixed: an initial convolution, three groups of two
    res+transformer blocks each followed by a downsample, and two final
    plain res blocks.
    """
    # (saver, index into input_blocks, pass only element [0]?, directory name)
    layout = (
        (save_conv2d, 0, True, 'conv'),
        (save_res_transformer, 1, False, 'rt1'),
        (save_res_transformer, 2, False, 'rt2'),
        (save_downsample, 3, True, 'd1'),
        (save_res_transformer, 4, False, 'rt3'),
        (save_res_transformer, 5, False, 'rt4'),
        (save_downsample, 6, True, 'd2'),
        (save_res_transformer, 7, False, 'rt5'),
        (save_res_transformer, 8, False, 'rt6'),
        (save_downsample, 9, True, 'd3'),
        (save_res_block, 10, True, 'r1'),
        (save_res_block, 11, True, 'r2'),
    )
    for saver, index, first_only, name in layout:
        block = input_blocks[index]
        saver(block[0] if first_only else block, pathlib.Path(path, name))
def save_unet_output_blocks(output_blocks, path):
    """Dump the twelve UNet output blocks under ``path``.

    The layout is fixed: two plain res blocks, a res+upsample block, then
    res+transformer blocks with an upsampling variant at positions 5 and 8.
    """
    # (saver, index into output_blocks, pass only element [0]?, directory name)
    layout = (
        (save_res_block, 0, True, 'r1'),
        (save_res_block, 1, True, 'r2'),
        (save_res_upsample, 2, False, 'ru'),
        (save_res_transformer, 3, False, 'rt1'),
        (save_res_transformer, 4, False, 'rt2'),
        (save_res_transformer_upsample, 5, False, 'rtu1'),
        (save_res_transformer, 6, False, 'rt3'),
        (save_res_transformer, 7, False, 'rt4'),
        (save_res_transformer_upsample, 8, False, 'rtu2'),
        (save_res_transformer, 9, False, 'rt5'),
        (save_res_transformer, 10, False, 'rt6'),
        (save_res_transformer, 11, False, 'rt7'),
    )
    for saver, index, first_only, name in layout:
        block = output_blocks[index]
        saver(block[0] if first_only else block, pathlib.Path(path, name))
def save_unet_model(model, path):
    """Dump the whole UNet under ``path``.

    Writes the two time-embedding linears, the input blocks, the middle
    block, the output blocks and the final norm/convolution pair.
    """
    root = pathlib.Path(path)
    root.mkdir(parents=True, exist_ok=True)

    def sub(name):
        return pathlib.Path(path, name)

    # Time embedding: linears at positions 0 and 2 of `time_embed`.
    save_linear(model.time_embed[0], sub('lin1_time_embed'))
    save_linear(model.time_embed[2], sub('lin2_time_embed'))
    # Main body.
    save_unet_input_blocks(model.input_blocks, sub('input_blocks'))
    save_res_transformer_res(model.middle_block, sub('middle_block'))
    save_unet_output_blocks(model.output_blocks, sub('output_blocks'))
    # Output head: norm at position 0, convolution at position 2 of `out`.
    save_group_norm(model.out[0], sub('norm_out'))
    save_conv2d(model.out[2], sub('conv_out'))

139
src/backend.rs Normal file
View File

@@ -0,0 +1,139 @@
use burn::tensor::{activation::softmax, Tensor};
use burn::prelude::Backend;
/*pub type FloatTensor<B, const D: usize> = <B as burn::tensor::backend::Backend>::TensorPrimitive<D>;
pub trait Backend: burn::tensor::backend::Backend {
fn qkv_attention(
q: FloatTensor<Self, 3>,
k: FloatTensor<Self, 3>,
v: FloatTensor<Self, 3>,
mask: Option<FloatTensor<Self, 2>>,
n_head: usize,
) -> FloatTensor<Self, 3> {
qkv_attention(
Tensor::<Self, 3>::from_primitive(q),
Tensor::from_primitive(k),
Tensor::from_primitive(v),
mask.map(|m| Tensor::from_primitive(m)),
n_head,
)
.into_primitive()
}
fn attn_decoder_mask(seq_length: usize, device: &Self::Device) -> FloatTensor<Self, 2> {
attn_decoder_mask::<Self>(seq_length, device).into_primitive()
}
}
use burn::tensor::Float;
use burn_tch::{self, TchElement, TchTensor};
use tch;
impl<E: TchElement> Backend for burn_tch::LibTorch<E> {
fn qkv_attention(
q: FloatTensor<Self, 3>,
k: FloatTensor<Self, 3>,
v: FloatTensor<Self, 3>,
mask: Option<FloatTensor<Self, 2>>,
n_head: usize,
) -> FloatTensor<Self, 2> {
let q = Tensor::from_primitive(q);
let k = Tensor::from_primitive(k);
let v = Tensor::from_primitive(v);
let [n_batch, q_ctx, n_state] = q.dims();
let [_, k_ctx, _] = k.dims();
let n_hstate = n_state / n_head;
let rearrange = |t: Tensor<Self, 3>| {
let [_, n_ctx, _] = t.dims();
t.reshape([n_batch, n_ctx, n_head, n_hstate])
.swap_dims(1, 2)
};
let q = rearrange(q).into_primitive();
let k = rearrange(k).into_primitive();
let v = rearrange(v).into_primitive();
// for some reason torch crashes when mask is None
let mask = mask.unwrap_or_else(|| {
Tensor::<Self, 2, Float>::zeros([q_ctx, k_ctx], &Self::device(&v))
.into_primitive()
});
Tensor::<Self, 4>::from_primitive(TchTensor::new(
tch::Tensor::scaled_dot_product_attention(
&q.tensor,
&k.tensor,
&v.tensor,
Some(mask.tensor),
0.0,
false,
None,
),
))
.swap_dims(1, 2)
.flatten(2, 3)
.into_primitive()
}
}
use burn_autodiff;
impl<B: Backend> Backend for burn_autodiff::Autodiff<B> {}*/
use std::f32::NEG_INFINITY;
/// Multi-head scaled dot-product attention.
///
/// `q` is `[n_batch, n_qctx, n_state]`; `k` and `v` are
/// `[n_batch, n_ctx, n_state]`. The state axis is split into `n_head` heads
/// of `n_state / n_head` features each. Both `q` and `k` are multiplied by
/// `(n_state / n_head)^-0.25`, so their product carries the usual
/// `head_dim^-0.5` factor.
///
/// When `mask` is given, its top-left `[n_qctx, n_ctx]` corner is added to the
/// attention scores before the softmax over the key axis.
///
/// Returns a tensor of shape `[n_batch, n_qctx, n_state]`.
pub fn qkv_attention<B: Backend>(
    q: Tensor<B, 3>,
    k: Tensor<B, 3>,
    v: Tensor<B, 3>,
    mask: Option<Tensor<B, 2>>,
    n_head: usize,
) -> Tensor<B, 3> {
    let [n_batch, n_qctx, n_state] = q.dims();
    let [_, n_ctx, _] = k.dims();

    let n_hstate = n_state / n_head;
    let scale = (n_state as f64 / n_head as f64).powf(-0.25);

    // [batch, ctx, state] -> [batch, head, ctx, head_state]
    let split_heads = |t: Tensor<B, 3>, ctx: usize| -> Tensor<B, 4> {
        t.reshape([n_batch, ctx, n_head, n_hstate]).swap_dims(1, 2)
    };

    let q = split_heads(q, n_qctx) * scale;
    let k = split_heads(k, n_ctx).transpose() * scale;
    let v = split_heads(v, n_ctx);

    let scores = q.matmul(k);

    // apply mask
    let scores = match mask {
        Some(mask) => scores + mask.slice([0..n_qctx, 0..n_ctx]).unsqueeze::<4>(),
        None => scores,
    };

    // normalize value weightings, then merge the heads back into one state axis
    softmax(scores, 3)
        .matmul(v)
        .swap_dims(1, 2)
        .flatten(2, 3)
}
/// Builds an additive causal mask of shape `[seq_length, seq_length]`.
///
/// Entry `(i, j)` is `-inf` when `j > i` and `0` otherwise, so adding the mask
/// to attention scores removes every position after the current one.
///
/// A `seq_length` of zero produces a tensor with shape `[0, 0]` rather than
/// underflowing the loop bound.
pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
    let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length], device);

    // The last row has nothing to hide, hence `seq_length - 1` rows; the
    // saturating form keeps `seq_length == 0` from wrapping around.
    for i in 0..seq_length.saturating_sub(1) {
        let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)], device)
            .add_scalar(NEG_INFINITY);
        mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
    }

    mask
}

58
src/bin/convert/main.rs Normal file
View File

@@ -0,0 +1,58 @@
use std::env;
use std::error::Error;
use std::process;
use stablediffusion::model::stablediffusion::{load::load_stable_diffusion, StableDiffusion};
use burn::{
config::Config,
module::{Module, Param},
nn,
tensor::{backend::Backend, Tensor},
};
use burn_ndarray::{NdArray, NdArrayDevice};
use burn::record::{self, NamedMpkFileRecorder, FullPrecisionSettings, Recorder};
/// Loads the dump at `dump_path` and records it as a model file named
/// `model_name`, printing progress along the way.
///
/// # Errors
///
/// Returns the first error raised while loading the dump or while recording
/// the model.
fn convert_dump_to_model<B: Backend>(
    dump_path: &str,
    model_name: &str,
    device: &B::Device,
) -> Result<(), Box<dyn Error>> {
    println!("Loading dump...");
    let model: StableDiffusion<B> = load_stable_diffusion(dump_path, device)?;

    println!("Saving model...");
    Ok(save_model_file(model, model_name)?)
}
/// Records `model` at full precision with the named-MessagePack file recorder,
/// using `name` as the target path.
fn save_model_file<B: Backend>(
    model: StableDiffusion<B>,
    name: &str,
) -> Result<(), record::RecorderError> {
    let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
    recorder.record(model.into_record(), name.into())
}
/// Command-line entry point: `<program> <dump_path> <model_name>`.
///
/// Converts the dump on the CPU ndarray backend and exits with status 1 on a
/// usage error or a failed conversion.
fn main() {
    type Backend = NdArray<f32>;

    let args: Vec<String> = env::args().collect();
    let [_, dump_path, model_name] = args.as_slice() else {
        eprintln!("Usage: {} <dump_path> <model_name>", args[0]);
        process::exit(1);
    };

    let device = NdArrayDevice::Cpu;
    match convert_dump_to_model::<Backend>(dump_path, model_name, &device) {
        Ok(()) => println!("Successfully converted {} to {}", dump_path, model_name),
        Err(e) => {
            eprintln!("Failed to convert dump to model: {:?}", e);
            process::exit(1);
        }
    }
}

126
src/bin/sample/main.rs Normal file
View File

@@ -0,0 +1,126 @@
use stablediffusion::{
model::stablediffusion::{load::load_stable_diffusion, *},
tokenizer::SimpleTokenizer,
};
use burn::{
config::Config,
module::{Module, Param},
nn,
tensor::{backend::Backend, Tensor},
};
cfg_if::cfg_if! {
if #[cfg(feature = "wgpu-backend")] {
use burn_wgpu::{Wgpu, WgpuDevice};
} else {
use burn_ndarray::NdArrayDevice;
}
}
use std::env;
use std::io;
use std::process;
use burn::record::{self, NamedMpkFileRecorder, FullPrecisionSettings, Recorder};
/// Reads a full-precision named-MessagePack record from `filename` and loads
/// it into a freshly initialised `StableDiffusion` model on `device`.
///
/// # Errors
///
/// Returns the recorder's error when the record cannot be loaded.
fn load_stable_diffusion_model_file<B: Backend>(
    filename: &str,
    device: &B::Device,
) -> Result<StableDiffusion<B>, record::RecorderError> {
    let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
    let record = recorder.load(filename.into(), device)?;
    Ok(StableDiffusionConfig::new().init(device).load_record(record))
}
/// Command-line entry point for sampling an image.
///
/// Arguments: `<model_type> <model_name> <unconditional_guidance_scale>
/// <n_diffusion_steps> <prompt> <output_image_name> [device]`.
///
/// `model_type` `"burn"` loads a recorded model file; anything else loads a
/// dump directory. The backend is chosen at compile time by the
/// `wgpu-backend` feature, so the optional `device` argument cannot take
/// effect; a warning is printed when it is supplied. Every failure prints a
/// message and exits with status 1.
fn main() {
    let args: Vec<String> = std::env::args().collect();
    if args.len() != 7 && args.len() != 8 {
        eprintln!("Usage: {} <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name> [device(cuda, mps, cpu)]", args[0]);
        process::exit(1);
    }

    let model_type = &args[1];
    let model_name = &args[2];
    let unconditional_guidance_scale: f64 = args[3].parse().unwrap_or_else(|_| {
        eprintln!("Error: Invalid unconditional guidance scale.");
        process::exit(1);
    });
    let n_steps: usize = args[4].parse().unwrap_or_else(|_| {
        eprintln!("Error: Invalid number of diffusion steps.");
        process::exit(1);
    });
    let prompt = &args[5];
    let output_image_name = &args[6];

    // Optional device parameter: accepted for compatibility, but the backend
    // below is fixed at compile time, so tell the user instead of silently
    // dropping the request.
    if let Some(requested_device) = args.get(7) {
        eprintln!(
            "Warning: device argument '{}' is ignored; the backend is selected at compile time.",
            requested_device
        );
    }

    cfg_if::cfg_if! {
        if #[cfg(feature = "wgpu-backend")] {
            type Backend = Wgpu;
            let device = WgpuDevice::BestAvailable;
        } else {
            type Backend = burn::backend::ndarray::NdArray<f32>;
            let device = NdArrayDevice::Cpu;
        }
    }

    println!("Loading tokenizer...");
    let tokenizer = SimpleTokenizer::new().unwrap();

    println!("Loading model...");
    let sd: StableDiffusion<Backend> = if model_type == "burn" {
        load_stable_diffusion_model_file(model_name, &device).unwrap_or_else(|err| {
            eprintln!("Error loading model: {}", err);
            process::exit(1);
        })
    } else {
        load_stable_diffusion(model_name, &device).unwrap_or_else(|err| {
            eprintln!("Error loading model dump: {}", err);
            process::exit(1);
        })
    };

    let unconditional_context = sd.unconditional_context(&tokenizer);
    let context = sd.context(&tokenizer, prompt).unsqueeze::<3>(); //.repeat(0, 2); // generate 2 samples

    println!("Sampling image...");
    let images = sd.sample_image(
        context,
        unconditional_context,
        unconditional_guidance_scale,
        n_steps,
    );

    save_images(&images, output_image_name, 512, 512).unwrap_or_else(|err| {
        eprintln!("Error saving image: {}", err);
        process::exit(1);
    });
}
use image::{self, ColorType::Rgb8, ImageResult};
/// Writes each RGB8 buffer in `images` to `<basepath><index>.png`.
///
/// Every buffer is interpreted as a `width` x `height` image with three
/// bytes per pixel. Takes a slice so callers may pass a `&Vec<Vec<u8>>` or any
/// other contiguous collection of buffers.
///
/// # Errors
///
/// Stops at, and returns, the first error reported while saving an image.
fn save_images(images: &[Vec<u8>], basepath: &str, width: u32, height: u32) -> ImageResult<()> {
    for (index, img_data) in images.iter().enumerate() {
        let path = format!("{}{}.png", basepath, index);
        image::save_buffer(path, &img_data[..], width, height, Rgb8)?;
    }

    Ok(())
}
/// Writes `red.png`: a 256x256 RGB8 test image whose red channel grows from
/// top to bottom while green and blue stay at zero.
fn save_test_image() -> ImageResult<()> {
    let (width, height) = (256, 256);

    let mut raw: Vec<u8> = Vec::new();
    for i in 0..width * height {
        // Pixels are laid out row by row, so the row index is `i / width`.
        let row = i / width;
        let red = (255.0 * row as f64 / height as f64) as u8;
        raw.extend([red, 0, 0]);
    }

    image::save_buffer("red.png", &raw[..], width, height, Rgb8)
}

3
src/lib.rs Normal file
View File

@@ -0,0 +1,3 @@
// Attention helpers (`qkv_attention`, `attn_decoder_mask`).
pub mod backend;
// Model definitions and their loaders (e.g. `model::stablediffusion`).
pub mod model;
// Text tokenizer (`SimpleTokenizer`) used by the sampling binary.
pub mod tokenizer;

56
src/model/attention.rs Normal file
View File

@@ -0,0 +1,56 @@
use burn::tensor::{activation::softmax, backend::Backend, Tensor};
use std::f32::NEG_INFINITY;
/// Multi-head scaled dot-product attention.
///
/// `q` is `[n_batch, n_qctx, n_state]`; `k` and `v` are
/// `[n_batch, n_ctx, n_state]`. The state axis is split into `n_head` heads
/// of `n_state / n_head` features each. Both `q` and `k` are multiplied by
/// `(n_state / n_head)^-0.25`, so their product carries the usual
/// `head_dim^-0.5` factor.
///
/// When `mask` is given, its top-left `[n_qctx, n_ctx]` corner is added to the
/// attention scores before the softmax over the key axis.
///
/// Returns a tensor of shape `[n_batch, n_qctx, n_state]`.
pub fn qkv_attention<B: Backend>(
    q: Tensor<B, 3>,
    k: Tensor<B, 3>,
    v: Tensor<B, 3>,
    mask: Option<Tensor<B, 2>>,
    n_head: usize,
) -> Tensor<B, 3> {
    let [n_batch, n_qctx, n_state] = q.dims();
    let [_, n_ctx, _] = k.dims();

    let n_hstate = n_state / n_head;
    let scale = (n_state as f64 / n_head as f64).powf(-0.25);

    // [batch, ctx, state] -> [batch, head, ctx, head_state]
    let split_heads = |t: Tensor<B, 3>, ctx: usize| -> Tensor<B, 4> {
        t.reshape([n_batch, ctx, n_head, n_hstate]).swap_dims(1, 2)
    };

    let q = split_heads(q, n_qctx) * scale;
    let k = split_heads(k, n_ctx).transpose() * scale;
    let v = split_heads(v, n_ctx);

    let scores = q.matmul(k);

    // apply mask
    let scores = match mask {
        Some(mask) => scores + mask.slice([0..n_qctx, 0..n_ctx]).unsqueeze::<4>(),
        None => scores,
    };

    // normalize value weightings, then merge the heads back into one state axis
    softmax(scores, 3)
        .matmul(v)
        .swap_dims(1, 2)
        .flatten(2, 3)
}
/// Builds an additive causal mask of shape `[seq_length, seq_length]`.
///
/// Entry `(i, j)` is `-inf` when `j > i` and `0` otherwise, so adding the mask
/// to attention scores removes every position after the current one.
///
/// A `seq_length` of zero produces a tensor with shape `[0, 0]` rather than
/// underflowing the loop bound.
pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
    let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length], device);

    // The last row has nothing to hide, hence `seq_length - 1` rows; the
    // saturating form keeps `seq_length == 0` from wrapping around.
    for i in 0..seq_length.saturating_sub(1) {
        let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)], device)
            .add_scalar(NEG_INFINITY);
        mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
    }

    mask
}

View File

@@ -0,0 +1,198 @@
use super::GroupNorm;
use crate::model::load::*;
use std::error::Error;
use burn::{
config::Config,
module::{Module, Param},
nn,
tensor::{backend::Backend, Tensor},
};
use super::*;
use crate::model::groupnorm::load::load_group_norm;
/// Loads a convolutional self-attention block from the sub-directories
/// `norm`, `q`, `k`, `v` and `proj_out` of `path`.
///
/// # Errors
///
/// Returns the first error raised by any of the component loaders.
fn load_conv_self_attention_block<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<ConvSelfAttentionBlock<B>, Box<dyn Error>> {
    let sub = |name: &str| format!("{path}/{name}");

    Ok(ConvSelfAttentionBlock {
        norm: load_group_norm(&sub("norm"), device)?,
        q: load_conv2d(&sub("q"), device)?,
        k: load_conv2d(&sub("k"), device)?,
        v: load_conv2d(&sub("v"), device)?,
        proj_out: load_conv2d(&sub("proj_out"), device)?,
    })
}
/// Loads a ResNet block from `path`.
///
/// `norm1`, `conv1`, `norm2` and `conv2` are required. `nin_shortcut` is
/// optional: any failure to load it is treated as "not present".
///
/// # Errors
///
/// Returns the first error raised while loading a required component.
fn load_resnet_block<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<ResnetBlock<B>, Box<dyn Error>> {
    let sub = |name: &str| format!("{path}/{name}");

    Ok(ResnetBlock {
        norm1: load_group_norm(&sub("norm1"), device)?,
        silu1: SILU {},
        conv1: load_conv2d(&sub("conv1"), device)?,
        norm2: load_group_norm(&sub("norm2"), device)?,
        silu2: SILU {},
        conv2: load_conv2d(&sub("conv2"), device)?,
        // Optional component: a load error simply yields `None`.
        nin_shortcut: load_conv2d(&sub("nin_shortcut"), device).ok(),
    })
}
/// Loads the middle section (`block_1`, `attn`, `block_2`) from `path`.
///
/// # Errors
///
/// Returns the first error raised by any of the component loaders.
fn load_mid<B: Backend>(path: &str, device: &B::Device) -> Result<Mid<B>, Box<dyn Error>> {
    Ok(Mid {
        block_1: load_resnet_block(&format!("{path}/block_1"), device)?,
        attn: load_conv_self_attention_block(&format!("{path}/attn"), device)?,
        block_2: load_resnet_block(&format!("{path}/block_2"), device)?,
    })
}
/// Loads a `PaddedConv2d` from `path`.
///
/// The convolution weights come from `<path>/conv`; `channels`,
/// `kernel_size`, `stride` and the four-element `padding` are read from
/// `path` and used to initialise a fresh `PaddedConv2d`. The loaded
/// convolution then replaces the initialised one, after its padding is set
/// to the explicit padding the initialised module computed
/// (`padding_actual`).
///
/// # Errors
///
/// Returns the first error raised while reading any of the files.
fn load_padded_conv2d<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<PaddedConv2d<B>, Box<dyn Error>> {
    let mut loaded_conv = load_conv2d(&format!("{path}/conv"), device)?;

    let channels = tensor_to_array_2(load_tensor::<B, 1>("channels", path, device)?);
    let kernel_size = load_usize::<B>("kernel_size", path, device)?;
    let stride = load_usize::<B>("stride", path, device)?;

    let sides: [usize; 4] = tensor_to_array(load_tensor::<B, 1>("padding", path, device)?);
    let padding = PaddingCfg::new(sides[0], sides[1], sides[2], sides[3]);

    let mut padded_conv: PaddedConv2d<B> = PaddedConv2dConfig::new(channels, kernel_size, padding)
        .with_stride(stride)
        .init(device);

    // Carry the padding chosen by the initialised module over to the loaded
    // convolution before swapping it in.
    let padding_actual =
        PaddingConfig2d::Explicit(padded_conv.padding_actual[0], padded_conv.padding_actual[1]);
    loaded_conv.padding = burn::module::Ignored(padding_actual);
    padded_conv.conv = loaded_conv;

    Ok(padded_conv)
}
/// Loads a decoder block: three required ResNet blocks (`res1`..`res3`) and
/// an optional `upsampler` convolution.
///
/// Any failure to load the upsampler is treated as "not present".
///
/// # Errors
///
/// Returns the first error raised while loading a ResNet block.
fn load_decoder_block<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<DecoderBlock<B>, Box<dyn Error>> {
    let sub = |name: &str| format!("{path}/{name}");

    Ok(DecoderBlock {
        res1: load_resnet_block(&sub("res1"), device)?,
        res2: load_resnet_block(&sub("res2"), device)?,
        res3: load_resnet_block(&sub("res3"), device)?,
        upsampler: load_conv2d(&sub("upsampler"), device).ok(),
    })
}
/// Loads an encoder block: two required ResNet blocks (`res1`, `res2`) and
/// an optional `downsampler` padded convolution.
///
/// Any failure to load the downsampler is treated as "not present".
///
/// # Errors
///
/// Returns the first error raised while loading a ResNet block.
fn load_encoder_block<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<EncoderBlock<B>, Box<dyn Error>> {
    let sub = |name: &str| format!("{path}/{name}");

    Ok(EncoderBlock {
        res1: load_resnet_block(&sub("res1"), device)?,
        res2: load_resnet_block(&sub("res2"), device)?,
        downsampler: load_padded_conv2d(&sub("downsampler"), device).ok(),
    })
}
/// Loads the decoder from `path`.
///
/// Reads `conv_in`, `mid`, the block count `n_block`, then the blocks
/// `blocks/0` .. `blocks/<n_block - 1>`, and finally `norm_out` and
/// `conv_out`.
///
/// # Errors
///
/// Returns the first error raised by any of the component loaders.
fn load_decoder<B: Backend>(path: &str, device: &B::Device) -> Result<Decoder<B>, Box<dyn Error>> {
    let conv_in = load_conv2d(&format!("{path}/conv_in"), device)?;
    let mid = load_mid(&format!("{path}/mid"), device)?;

    let n_block = load_usize::<B>("n_block", path, device)?;
    let mut blocks = Vec::new();
    for i in 0..n_block {
        blocks.push(load_decoder_block::<B>(&format!("{path}/blocks/{i}"), device)?);
    }

    let norm_out = load_group_norm(&format!("{path}/norm_out"), device)?;
    let conv_out = load_conv2d(&format!("{path}/conv_out"), device)?;

    Ok(Decoder {
        conv_in,
        mid,
        blocks,
        norm_out,
        silu: SILU {},
        conv_out,
    })
}
/// Loads the encoder from `path`.
///
/// Reads `conv_in`, `mid`, the block count `n_block`, then the blocks
/// `blocks/0` .. `blocks/<n_block - 1>`, and finally `norm_out` and
/// `conv_out`.
///
/// # Errors
///
/// Returns the first error raised by any of the component loaders.
fn load_encoder<B: Backend>(path: &str, device: &B::Device) -> Result<Encoder<B>, Box<dyn Error>> {
    let conv_in = load_conv2d(&format!("{path}/conv_in"), device)?;
    let mid = load_mid(&format!("{path}/mid"), device)?;

    let n_block = load_usize::<B>("n_block", path, device)?;
    let mut blocks = Vec::new();
    for i in 0..n_block {
        blocks.push(load_encoder_block::<B>(&format!("{path}/blocks/{i}"), device)?);
    }

    let norm_out = load_group_norm(&format!("{path}/norm_out"), device)?;
    let conv_out = load_conv2d(&format!("{path}/conv_out"), device)?;

    Ok(Encoder {
        conv_in,
        mid,
        blocks,
        norm_out,
        silu: SILU {},
        conv_out,
    })
}
/// Loads the complete autoencoder from `path`: the `encoder`, the `decoder`
/// and the two 1x1-style convolutions `quant_conv` and `post_quant_conv`
/// stored in sub-directories of the same names.
///
/// # Errors
///
/// Returns the first error raised by any of the component loaders.
pub fn load_autoencoder<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<Autoencoder<B>, Box<dyn Error>> {
    let sub = |name: &str| format!("{path}/{name}");

    Ok(Autoencoder {
        encoder: load_encoder(&sub("encoder"), device)?,
        decoder: load_decoder(&sub("decoder"), device)?,
        quant_conv: load_conv2d(&sub("quant_conv"), device)?,
        post_quant_conv: load_conv2d(&sub("post_quant_conv"), device)?,
    })
}

View File

@@ -0,0 +1,608 @@
pub mod load;
use burn::{
config::Config,
module::{Module, Param},
nn::{
self,
conv::{Conv2d, Conv2dConfig, Conv2dRecord},
PaddingConfig2d,
},
tensor::{
activation::{sigmoid, softmax},
backend::Backend,
module::embedding,
Distribution, Int, Tensor,
},
};
use super::groupnorm::*;
use super::silu::*;
//use crate::backend::Backend as MyBackend;
use crate::backend::{qkv_attention, attn_decoder_mask};
use std::iter;
/// Configuration of the autoencoder; it has no tunable fields because the
/// architecture is fixed in [`AutoencoderConfig::init`].
#[derive(Config, Debug)]
pub struct AutoencoderConfig {}

impl AutoencoderConfig {
    /// Builds an autoencoder with freshly initialised parameters on `device`.
    ///
    /// The encoder and decoder use four channel stages between 128 and 512
    /// channels with 32 normalisation groups; the encoder emits 8 channels.
    /// `quant_conv` maps 8 -> 8 channels and `post_quant_conv` 4 -> 4, both
    /// with a 1x1 kernel.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Autoencoder<B> {
        let encoder_channels = vec![(128, 128), (128, 256), (256, 512), (512, 512)];
        let decoder_channels = vec![(512, 512), (512, 512), (512, 256), (256, 128)];

        // Fields are initialised in declaration order.
        Autoencoder {
            encoder: EncoderConfig::new(encoder_channels, 32, 8).init(device),
            decoder: DecoderConfig::new(decoder_channels, 32).init(device),
            quant_conv: Conv2dConfig::new([8, 8], [1, 1]).init(device),
            post_quant_conv: Conv2dConfig::new([4, 4], [1, 1]).init(device),
        }
    }
}
/// Image autoencoder: an encoder/decoder pair with a convolution on each side
/// of the latent.
#[derive(Module, Debug)]
pub struct Autoencoder<B: Backend> {
    encoder: Encoder<B>,
    decoder: Decoder<B>,
    quant_conv: Conv2d<B>,
    post_quant_conv: Conv2d<B>,
}

impl<B: Backend> Autoencoder<B> {
    /// Encodes `x` to a latent and decodes it straight back.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let latent = self.encode_image(x);
        self.decode_latent(latent)
    }

    /// Encodes an image batch and keeps only the first four channels of the
    /// `quant_conv` output.
    pub fn encode_image(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let [n_batch, ..] = x.dims();
        let encoded = self.encoder.forward(x);
        self.quant_conv
            .forward(encoded)
            .slice([0..n_batch, 0..4])
    }

    /// Runs `post_quant_conv` on the latent and decodes the result.
    pub fn decode_latent(&self, latent: Tensor<B, 4>) -> Tensor<B, 4> {
        self.decoder.forward(self.post_quant_conv.forward(latent))
    }
}
/// Configuration for the autoencoder's [`Encoder`].
#[derive(Config, Debug)]
pub struct EncoderConfig {
    /// `(in, out)` channel counts of each encoder block, in execution order.
    channels: Vec<(usize, usize)>,
    /// Number of groups used by the output group normalisation.
    n_group: usize,
    /// Number of channels produced by the final convolution.
    n_channels_out: usize,
}

impl EncoderConfig {
    /// Builds a freshly initialised [`Encoder`] on `device`.
    ///
    /// # Panics
    ///
    /// Panics if `channels` is empty.
    fn init<B: Backend>(&self, device: &B::Device) -> Encoder<B> {
        // `conv_in` feeds the first block, so it must emit that block's
        // *input* width.
        let n_expanded_channels_initial = self
            .channels
            .first()
            .map(|c| c.0)
            .expect("Channels must not be empty.");
        // `mid`, `norm_out` and `conv_out` run after the whole block stack,
        // so they operate on the *output* width of the last block.
        let n_expanded_channels_final = self
            .channels
            .last()
            .map(|c| c.1)
            .expect("Channels must not be empty.");

        let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3])
            .with_padding(PaddingConfig2d::Explicit(1, 1))
            .init(device);

        // Every block except the last one halves the spatial resolution.
        let last_block = self.channels.len() - 1;
        let blocks = self
            .channels
            .iter()
            .enumerate()
            .map(|(i, &(n_channel_in, n_channel_out))| {
                let downsample = i != last_block;
                EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init(device)
            })
            .collect();

        let mid = MidConfig::new(n_expanded_channels_final).init(device);
        let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init(device);
        let silu = SILU::new();
        let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3])
            .with_padding(PaddingConfig2d::Explicit(1, 1))
            .init(device);

        Encoder {
            conv_in,
            mid,
            blocks,
            norm_out,
            silu,
            conv_out,
        }
    }
}
/// Image encoder: input convolution, a stack of [`EncoderBlock`]s, a [`Mid`]
/// stage and a normalised output convolution.
#[derive(Module, Debug)]
pub struct Encoder<B: Backend> {
    conv_in: Conv2d<B>,
    mid: Mid<B>,
    blocks: Vec<EncoderBlock<B>>,
    norm_out: GroupNorm<B>,
    silu: SILU,
    conv_out: Conv2d<B>,
}

impl<B: Backend> Encoder<B> {
    /// Applies `conv_in`, every block in order, `mid`, then
    /// `norm_out` -> SiLU -> `conv_out`.
    fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let stem = self.conv_in.forward(x);
        let features = self
            .blocks
            .iter()
            .fold(stem, |h, block| block.forward(h));
        let h = self.mid.forward(features);
        let h = self.norm_out.forward(h);
        let h = self.silu.forward(h);
        self.conv_out.forward(h)
    }
}
/// Configuration for the autoencoder's [`Decoder`].
#[derive(Config, Debug)]
pub struct DecoderConfig {
    /// `(in, out)` channel counts of each decoder block, in execution order.
    channels: Vec<(usize, usize)>,
    /// Number of groups used by the output group normalisation.
    n_group: usize,
}

impl DecoderConfig {
    /// Builds a freshly initialised [`Decoder`] on `device`.
    ///
    /// # Panics
    ///
    /// Panics if `channels` is empty.
    fn init<B: Backend>(&self, device: &B::Device) -> Decoder<B> {
        // Width entering the block stack (input of the first block) ...
        let &(n_expanded_channels, _) = self
            .channels
            .first()
            .expect("Channels must not be empty.");
        // ... and width leaving it (output of the last block).
        let &(_, n_condensed_channels) = self.channels.last().unwrap();
        let final_index = self.channels.len() - 1;

        let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3])
            .with_padding(PaddingConfig2d::Explicit(1, 1))
            .init(device);
        let mid = MidConfig::new(n_expanded_channels).init(device);

        // Every block except the last one upsamples.
        let mut blocks = Vec::with_capacity(self.channels.len());
        for (index, &(n_in, n_out)) in self.channels.iter().enumerate() {
            let upsample = index != final_index;
            blocks.push(DecoderBlockConfig::new(n_in, n_out, upsample).init(device));
        }

        let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init(device);
        let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3])
            .with_padding(PaddingConfig2d::Explicit(1, 1))
            .init(device);

        Decoder {
            conv_in,
            mid,
            blocks,
            norm_out,
            silu: SILU::new(),
            conv_out,
        }
    }
}
/// Latent decoder: input convolution, a [`Mid`] stage, a stack of
/// [`DecoderBlock`]s and a normalised output convolution.
#[derive(Module, Debug)]
pub struct Decoder<B: Backend> {
    conv_in: Conv2d<B>,
    mid: Mid<B>,
    blocks: Vec<DecoderBlock<B>>,
    norm_out: GroupNorm<B>,
    silu: SILU,
    conv_out: Conv2d<B>,
}

impl<B: Backend> Decoder<B> {
    /// Applies `conv_in`, `mid`, every block in order, then
    /// `norm_out` -> SiLU -> `conv_out`.
    fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let stem = self.mid.forward(self.conv_in.forward(x));
        let features = self
            .blocks
            .iter()
            .fold(stem, |h, block| block.forward(h));
        let h = self.norm_out.forward(features);
        let h = self.silu.forward(h);
        self.conv_out.forward(h)
    }
}
/// Configuration for one [`EncoderBlock`].
#[derive(Config, Debug)]
pub struct EncoderBlockConfig {
    /// Channels entering the block.
    n_channels_in: usize,
    /// Channels leaving the block.
    n_channels_out: usize,
    /// Whether the block ends with a stride-2 downsampling convolution.
    downsample: bool,
}

impl EncoderBlockConfig {
    /// Builds a freshly initialised [`EncoderBlock`] on `device`.
    fn init<B: Backend>(&self, device: &B::Device) -> EncoderBlock<B> {
        let (c_in, c_out) = (self.n_channels_in, self.n_channels_out);

        EncoderBlock {
            res1: ResnetBlockConfig::new(c_in, c_out).init(device),
            res2: ResnetBlockConfig::new(c_out, c_out).init(device),
            // 3x3 stride-2 convolution padded only on the right and bottom.
            downsampler: self.downsample.then(|| {
                let padding = PaddingCfg::new(0, 1, 0, 1);
                PaddedConv2dConfig::new([c_out, c_out], 3, padding)
                    .with_stride(2)
                    .init(device)
            }),
        }
    }
}
/// Two residual blocks followed by an optional downsampling convolution.
#[derive(Module, Debug)]
pub struct EncoderBlock<B: Backend> {
    res1: ResnetBlock<B>,
    res2: ResnetBlock<B>,
    downsampler: Option<PaddedConv2d<B>>,
}

impl<B: Backend> EncoderBlock<B> {
    /// Applies `res1`, `res2` and, when present, the downsampler.
    fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let h = self.res2.forward(self.res1.forward(x));
        match &self.downsampler {
            Some(downsampler) => downsampler.forward(h),
            None => h,
        }
    }
}
/// Configuration for one [`DecoderBlock`].
#[derive(Config, Debug)]
pub struct DecoderBlockConfig {
    /// Channels entering the block.
    n_channels_in: usize,
    /// Channels leaving the block.
    n_channels_out: usize,
    /// Whether the block ends with an upsampling step.
    upsample: bool,
}

impl DecoderBlockConfig {
    /// Builds a freshly initialised [`DecoderBlock`] on `device`.
    fn init<B: Backend>(&self, device: &B::Device) -> DecoderBlock<B> {
        let (c_in, c_out) = (self.n_channels_in, self.n_channels_out);

        DecoderBlock {
            res1: ResnetBlockConfig::new(c_in, c_out).init(device),
            res2: ResnetBlockConfig::new(c_out, c_out).init(device),
            res3: ResnetBlockConfig::new(c_out, c_out).init(device),
            // 3x3 convolution applied after the tensor has been upscaled.
            upsampler: self.upsample.then(|| {
                Conv2dConfig::new([c_out, c_out], [3, 3])
                    .with_padding(PaddingConfig2d::Explicit(1, 1))
                    .init(device)
            }),
        }
    }
}
/// Three residual blocks followed by an optional 2x upsampling step.
#[derive(Module, Debug)]
pub struct DecoderBlock<B: Backend> {
    res1: ResnetBlock<B>,
    res2: ResnetBlock<B>,
    res3: ResnetBlock<B>,
    upsampler: Option<Conv2d<B>>,
}

impl<B: Backend> DecoderBlock<B> {
    /// Applies the three residual blocks; when an upsampler is present the
    /// result is doubled in height and width and passed through it.
    fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let h = self.res3.forward(self.res2.forward(self.res1.forward(x)));

        let Some(conv) = &self.upsampler else {
            return h;
        };

        // Duplicate every row and column: insert unit axes after the height
        // and width axes, repeat them twice, then fold them back in.
        let [n_batch, n_channel, height, width] = h.dims();
        let upscaled = h
            .reshape([n_batch, n_channel, height, 1, width, 1])
            .repeat(&[1, 1, 1, 2, 1, 2])
            .reshape([n_batch, n_channel, 2 * height, 2 * width]);
        conv.forward(upscaled)
    }
}
/// Configuration for a [`PaddedConv2d`], a square-kernel convolution that
/// supports different padding on each side.
#[derive(Config, Debug)]
pub struct PaddedConv2dConfig {
    /// `[in, out]` channel counts.
    channels: [usize; 2],
    /// Side length of the square kernel.
    kernel_size: usize,
    /// Stride used in both spatial directions.
    #[config(default = 1)]
    stride: usize,
    /// Requested per-side padding.
    padding: PaddingCfg,
}

impl PaddedConv2dConfig {
    /// Builds a freshly initialised [`PaddedConv2d`] on `device`.
    ///
    /// The inner convolution is created with a symmetric padding large enough
    /// to cover the requested per-side padding; both the requested and the
    /// actual padding are stored on the module.
    fn init<B: Backend>(&self, device: &B::Device) -> PaddedConv2d<B> {
        let stride = self.stride;
        let requested = &self.padding;

        // Symmetric padding: `leading` plus the smallest multiple of `stride`
        // that makes the total at least `trailing`.
        let symmetric_pad = |leading: usize, trailing: usize| {
            let extra_strides = if trailing > leading {
                div_roundup(trailing - leading, stride)
            } else {
                0
            };
            extra_strides * stride + leading
        };
        let pad_vertical = symmetric_pad(requested.pad_top, requested.pad_bottom);
        let pad_horizontal = symmetric_pad(requested.pad_left, requested.pad_right);

        let conv = Conv2dConfig::new(self.channels, [self.kernel_size, self.kernel_size])
            .with_stride([stride, stride])
            .with_padding(PaddingConfig2d::Explicit(pad_vertical, pad_horizontal))
            .init(device);

        PaddedConv2d {
            conv,
            kernel_size: self.kernel_size,
            stride,
            padding: Padding {
                pad_left: requested.pad_left,
                pad_right: requested.pad_right,
                pad_top: requested.pad_top,
                pad_bottom: requested.pad_bottom,
            },
            padding_actual: [pad_vertical, pad_horizontal],
        }
    }
}
/// Returns `x / y` rounded up to the next integer.
///
/// Uses [`usize::div_ceil`], so the result is correct even when `x + y - 1`
/// would overflow `usize`.
///
/// # Panics
///
/// Panics if `y` is zero.
fn div_roundup(x: usize, y: usize) -> usize {
    x.div_ceil(y)
}
/// Convolution with independent padding on each side.
///
/// The inner convolution uses the symmetric padding `padding_actual`
/// (`[vertical, horizontal]`); `forward` crops its output to the window that
/// the requested asymmetric `padding` would have produced.
#[derive(Module, Debug)]
pub struct PaddedConv2d<B: Backend> {
    conv: Conv2d<B>,
    kernel_size: usize,
    stride: usize,
    padding: Padding,
    padding_actual: [usize; 2],
}

impl<B: Backend> PaddedConv2d<B> {
    /// Convolves `x` and crops the result to the requested padding.
    fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let [_, _, height, width] = x.dims();

        // Output size the requested (asymmetric) padding would give.
        let desired_height = (self.padding.pad_top + self.padding.pad_bottom + height
            - self.kernel_size)
            / self.stride
            + 1;
        let desired_width = (self.padding.pad_left + self.padding.pad_right + width
            - self.kernel_size)
            / self.stride
            + 1;

        // Leading output rows/columns that exist only because the symmetric
        // padding is larger than the requested top/left padding.
        let skip_vert = (self.padding_actual[0] - self.padding.pad_top) / self.stride;
        let skip_hor = (self.padding_actual[1] - self.padding.pad_left) / self.stride;

        let convolved = self.conv.forward(x);
        // Take batch and channel extents from the convolution output: the
        // output channel count differs from the input's whenever
        // `channels[0] != channels[1]`.
        let [n_batch, n_channel_out, _, _] = convolved.dims();
        convolved.slice([
            0..n_batch,
            0..n_channel_out,
            skip_vert..(skip_vert + desired_height),
            skip_hor..(skip_hor + desired_width),
        ])
    }
}
/// Per-side padding requested for a `PaddedConv2d`, in elements.
///
/// The derived `new` takes the fields in declaration order:
/// left, right, top, bottom.
#[derive(Config, Debug)]
pub struct PaddingCfg {
/// Padding before the first column.
pad_left: usize,
/// Padding after the last column.
pad_right: usize,
/// Padding before the first row.
pad_top: usize,
/// Padding after the last row.
pad_bottom: usize,
}
/// Per-side padding stored on a `PaddedConv2d` module; mirrors `PaddingCfg`
/// field for field.
#[derive(Module, Clone, Debug)]
pub struct Padding {
/// Padding before the first column.
pad_left: usize,
/// Padding after the last column.
pad_right: usize,
/// Padding before the first row.
pad_top: usize,
/// Padding after the last row.
pad_bottom: usize,
}
/// Configuration for the [`Mid`] stage.
#[derive(Config, Debug)]
pub struct MidConfig {
    /// Channel count used throughout the stage.
    n_channel: usize,
}

impl MidConfig {
    /// Builds a freshly initialised [`Mid`] on `device`.
    fn init<B: Backend>(&self, device: &B::Device) -> Mid<B> {
        let width = self.n_channel;
        // Field order below is also the initialisation order.
        Mid {
            block_1: ResnetBlockConfig::new(width, width).init(device),
            attn: ConvSelfAttentionBlockConfig::new(width).init(device),
            block_2: ResnetBlockConfig::new(width, width).init(device),
        }
    }
}
/// Residual block, self-attention block, residual block.
#[derive(Module, Debug)]
pub struct Mid<B: Backend> {
    block_1: ResnetBlock<B>,
    attn: ConvSelfAttentionBlock<B>,
    block_2: ResnetBlock<B>,
}

impl<B: Backend> Mid<B> {
    /// Applies `block_1`, `attn` and `block_2` in sequence.
    fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let attended = self.attn.forward(self.block_1.forward(x));
        self.block_2.forward(attended)
    }
}
#[derive(Config, Debug)]
pub struct ResnetBlockConfig {
in_channels: usize,
out_channels: usize,
}
impl ResnetBlockConfig {
fn init<B: Backend>(&self, device: &B::Device) -> ResnetBlock<B> {
let norm1 = GroupNormConfig::new(32, self.in_channels).init(device);
let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init(device);
let norm2 = GroupNormConfig::new(32, self.out_channels).init(device);
let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init(device);
let nin_shortcut = if self.in_channels != self.out_channels {
Some(Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init(device))
} else {
None
};
let silu1 = SILU::new();
let silu2 = SILU::new();
ResnetBlock {
norm1,
silu1,
conv1,
norm2,
silu2,
conv2,
nin_shortcut,
}
}
}
/// Residual block: two norm -> SiLU -> conv stages plus a skip connection.
#[derive(Module, Debug)]
pub struct ResnetBlock<B: Backend> {
    norm1: GroupNorm<B>,
    silu1: SILU,
    conv1: Conv2d<B>,
    norm2: GroupNorm<B>,
    silu2: SILU,
    conv2: Conv2d<B>,
    nin_shortcut: Option<Conv2d<B>>,
}

impl<B: Backend> ResnetBlock<B> {
    /// Returns the residual branch added to `x`; `x` is first passed through
    /// `nin_shortcut` when that convolution exists.
    fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let h = self.norm1.forward(x.clone());
        let h = self.conv1.forward(self.silu1.forward(h));
        let h = self.norm2.forward(h);
        let h = self.conv2.forward(self.silu2.forward(h));

        let skip = match &self.nin_shortcut {
            Some(shortcut) => shortcut.forward(x),
            None => x,
        };
        skip + h
    }
}
#[derive(Config, Debug)]
pub struct ConvSelfAttentionBlockConfig {
n_channel: usize,
}
impl ConvSelfAttentionBlockConfig {
fn init<B: Backend>(&self, device: &B::Device) -> ConvSelfAttentionBlock<B> {
let norm = GroupNormConfig::new(32, self.n_channel).init(device);
let q = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
let k = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
let v = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
ConvSelfAttentionBlock {
norm,
q,
k,
v,
proj_out,
}
}
}
/// Single-head self-attention over the spatial positions of a feature map,
/// with a residual connection.
#[derive(Module, Debug)]
pub struct ConvSelfAttentionBlock<B: Backend> {
    norm: GroupNorm<B>,
    q: Conv2d<B>,
    k: Conv2d<B>,
    v: Conv2d<B>,
    proj_out: Conv2d<B>,
}

impl<B: Backend> ConvSelfAttentionBlock<B> {
    /// Normalises `x`, attends across all `height * width` positions and adds
    /// the projected result back onto `x`.
    fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let [n_batch, n_channel, height, width] = x.dims();
        let n_positions = height * width;

        // (batch, channel, h, w) -> (batch, h * w, channel)
        let to_sequence = |t: Tensor<B, 4>| -> Tensor<B, 3> {
            t.reshape([n_batch, n_channel, n_positions]).swap_dims(1, 2)
        };

        let normed = self.norm.forward(x.clone());
        let q = to_sequence(self.q.forward(normed.clone()));
        let k = to_sequence(self.k.forward(normed.clone()));
        let v = to_sequence(self.v.forward(normed));

        // No mask, one head.
        let attended = qkv_attention(q, k, v, None, 1);

        // (batch, h * w, channel) -> (batch, channel, h, w)
        let attended = attended
            .swap_dims(1, 2)
            .reshape([n_batch, n_channel, height, width]);

        x + self.proj_out.forward(attended)
    }
}

91
src/model/clip/load.rs Normal file
View File

@@ -0,0 +1,91 @@
use burn::tensor::ElementConversion;
use std::error::Error;
use burn::{
config::Config,
module::{Module, Param},
nn,
tensor::{backend::Backend, Tensor},
};
use super::*;
use crate::model::load::*;
/// Loads an [`MLP`] from `path`, reading the `fc1` and `fc2` linear layers
/// from the sub-directories of the same name.
///
/// # Errors
///
/// Returns the error of the first linear layer that fails to load.
pub fn load_mlp<B: Backend>(path: &str, device: &B::Device) -> Result<MLP<B>, Box<dyn Error>> {
    let layer_dir = |name: &str| format!("{}/{}", path, name);
    Ok(MLP {
        fc1: load_linear(&layer_dir("fc1"), device)?,
        gelu: QuickGELU::new(),
        fc2: load_linear(&layer_dir("fc2"), device)?,
    })
}
/// Loads a [`MultiHeadSelfAttention`] from `path`.
///
/// Reads the scalar `n_head` first, then the `query`, `key`, `value` and
/// `out` linear layers from their sub-directories.
///
/// # Errors
///
/// Returns the error of the first item that fails to load.
pub fn load_multi_head_self_attention<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<MultiHeadSelfAttention<B>, Box<dyn Error>> {
    let projection = |name: &str| load_linear::<B>(&format!("{}/{}", path, name), device);

    Ok(MultiHeadSelfAttention {
        n_head: load_usize::<B>("n_head", path, device)?,
        query: projection("query")?,
        key: projection("key")?,
        value: projection("value")?,
        out: projection("out")?,
    })
}
/// Loads a [`ResidualDecoderAttentionBlock`] from `path`.
///
/// Sub-modules are read in the order `mlp`, `attn`, `attn_ln`, `mlp_ln`.
///
/// # Errors
///
/// Returns the error of the first sub-module that fails to load.
pub fn load_residual_decoder_attention_block<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<ResidualDecoderAttentionBlock<B>, Box<dyn Error>> {
    let sub_dir = |name: &str| format!("{}/{}", path, name);

    // Fields are written in load order, not declaration order.
    Ok(ResidualDecoderAttentionBlock {
        mlp: load_mlp(&sub_dir("mlp"), device)?,
        attn: load_multi_head_self_attention(&sub_dir("attn"), device)?,
        attn_ln: load_layer_norm(&sub_dir("attn_ln"), device)?,
        mlp_ln: load_layer_norm(&sub_dir("mlp_ln"), device)?,
    })
}
/// Loads a [`CLIP`] text encoder from `path`.
///
/// Reads the token embedding, the `position_embedding/weight` tensor, the
/// scalar `n_layer`, then `blocks/0 .. blocks/{n_layer - 1}` and finally the
/// output layer norm.
///
/// # Errors
///
/// Returns the error of the first item that fails to load.
pub fn load_clip<B: Backend>(path: &str, device: &B::Device) -> Result<CLIP<B>, Box<dyn Error>> {
    let token_embedding = load_embedding(&format!("{}/{}", path, "token_embedding"), device)?;

    let position_weights = load_tensor("weight", &format!("{}/position_embedding", path), device)?;
    let position_embedding = Param::from_tensor(position_weights);

    let n_layer = load_usize::<B>("n_layer", path, device)?;
    let mut blocks = Vec::new();
    for index in 0..n_layer {
        let block_path = format!("{}/blocks/{}", path, index);
        blocks.push(load_residual_decoder_attention_block::<B>(&block_path, device)?);
    }

    let layer_norm = load_layer_norm(&format!("{}/{}", path, "layer_norm"), device)?;

    Ok(CLIP {
        token_embedding,
        position_embedding,
        blocks,
        layer_norm,
    })
}

227
src/model/clip/mod.rs Normal file
View File

@@ -0,0 +1,227 @@
pub mod load;
use burn::{
config::Config,
module::{Module, Param},
nn,
tensor::{
activation::{sigmoid, softmax},
backend::Backend,
module::embedding,
Distribution, Int, Tensor,
},
};
//use crate::backend::Backend as MyBackend;
use crate::backend::{qkv_attention, attn_decoder_mask};
/// Configuration for the [`CLIP`] text encoder.
#[derive(Config, Debug)]
pub struct CLIPConfig {
    /// Number of entries in the token embedding table.
    n_vocab: usize,
    /// Width of the token/position embeddings and of every layer.
    n_state: usize,
    /// Number of attention heads per block.
    n_head: usize,
    /// Number of rows in the position embedding.
    n_ctx: usize,
    /// Number of attention blocks.
    n_layer: usize,
}

impl CLIPConfig {
    /// Builds a freshly initialised [`CLIP`] on `device`.
    ///
    /// The position embedding is drawn from a standard normal distribution.
    pub fn init<B: Backend>(&self, device: &B::Device) -> CLIP<B> {
        let block_config = ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head);

        // Fields are initialised in the order written.
        CLIP {
            token_embedding: nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init(device),
            position_embedding: Param::from_tensor(Tensor::random(
                [self.n_ctx, self.n_state],
                Distribution::Normal(0.0, 1.0),
                device,
            )),
            blocks: (0..self.n_layer)
                .map(|_| block_config.init(device))
                .collect(),
            layer_norm: nn::LayerNormConfig::new(self.n_state).init(device),
        }
    }
}
/// Text encoder: token + position embeddings, a stack of masked attention
/// blocks and a final layer norm.
#[derive(Module, Debug)]
pub struct CLIP<B: Backend> {
    token_embedding: nn::Embedding<B>,
    position_embedding: Param<Tensor<B, 2>>,
    blocks: Vec<ResidualDecoderAttentionBlock<B>>,
    layer_norm: nn::LayerNorm<B>,
}

impl<B: Backend> CLIP<B> {
    /// Encodes a `(batch, sequence)` tensor of token ids.
    ///
    /// The first `sequence` rows of the position embedding are added to the
    /// token embeddings, and every block receives the same mask produced by
    /// `attn_decoder_mask`.
    pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
        let seq_len = x.dims()[1];
        let mask = attn_decoder_mask(seq_len, &x.device());

        let tokens = self.token_embedding.forward(x);
        // Leading unit axis so the positions broadcast over the batch.
        let positions: Tensor<B, 3> = self
            .position_embedding
            .val()
            .slice([0..seq_len])
            .unsqueeze();

        let hidden = self
            .blocks
            .iter()
            .fold(tokens + positions, |h, block| block.forward(h, mask.clone()));
        self.layer_norm.forward(hidden)
    }
}
/// Configuration for a [`ResidualDecoderAttentionBlock`].
#[derive(Config, Debug)]
pub struct ResidualDecoderAttentionBlockConfig {
    /// Width of the block's input and output.
    n_state: usize,
    /// Number of attention heads.
    n_head: usize,
}

impl ResidualDecoderAttentionBlockConfig {
    /// Builds a freshly initialised [`ResidualDecoderAttentionBlock`] on
    /// `device`. The MLP's hidden layer is four times as wide as `n_state`.
    pub fn init<B: Backend>(&self, device: &B::Device) -> ResidualDecoderAttentionBlock<B> {
        let width = self.n_state;
        // Fields are initialised in the order written.
        ResidualDecoderAttentionBlock {
            attn: MultiHeadSelfAttentionConfig::new(width, self.n_head).init(device),
            attn_ln: nn::LayerNormConfig::new(width).init(device),
            mlp: MLPConfig::new(width, 4 * width).init(device),
            mlp_ln: nn::LayerNormConfig::new(width).init(device),
        }
    }
}
/// Pre-norm transformer block: masked self-attention and an MLP, each
/// wrapped in a residual connection.
#[derive(Module, Debug)]
pub struct ResidualDecoderAttentionBlock<B: Backend> {
    attn: MultiHeadSelfAttention<B>,
    attn_ln: nn::LayerNorm<B>,
    mlp: MLP<B>,
    mlp_ln: nn::LayerNorm<B>,
}

impl<B: Backend> ResidualDecoderAttentionBlock<B> {
    /// Computes `x + attn(attn_ln(x), mask)` and then adds `mlp(mlp_ln(..))`
    /// of that result onto it.
    fn forward(&self, x: Tensor<B, 3>, mask: Tensor<B, 2>) -> Tensor<B, 3> {
        let attended = self.attn.forward(self.attn_ln.forward(x.clone()), Some(mask));
        let x = x + attended;

        let transformed = self.mlp.forward(self.mlp_ln.forward(x.clone()));
        x + transformed
    }
}
/// Configuration for a [`MultiHeadSelfAttention`] layer.
#[derive(Config, Debug)]
pub struct MultiHeadSelfAttentionConfig {
    /// Width of the layer's input and output.
    n_state: usize,
    /// Number of attention heads; must divide `n_state`.
    n_head: usize,
}

impl MultiHeadSelfAttentionConfig {
    /// Builds a freshly initialised [`MultiHeadSelfAttention`] on `device`.
    ///
    /// # Panics
    ///
    /// Panics if `n_state` is not a multiple of `n_head`.
    fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadSelfAttention<B> {
        assert!(
            self.n_state % self.n_head == 0,
            "State size {} must be a multiple of head size {}",
            self.n_state,
            self.n_head
        );

        let projection = || -> nn::Linear<B> {
            nn::LinearConfig::new(self.n_state, self.n_state).init(device)
        };

        MultiHeadSelfAttention {
            n_head: self.n_head,
            query: projection(),
            key: projection(),
            value: projection(),
            out: projection(),
        }
    }
}
/// Multi-head self-attention with separate query/key/value projections and
/// an output projection.
#[derive(Module, Debug)]
pub struct MultiHeadSelfAttention<B: Backend> {
    n_head: usize,
    query: nn::Linear<B>,
    key: nn::Linear<B>,
    value: nn::Linear<B>,
    out: nn::Linear<B>,
}

impl<B: Backend> MultiHeadSelfAttention<B> {
    /// Projects `x` to queries, keys and values, runs `qkv_attention` with
    /// the optional `mask` and `n_head` heads, and applies the output
    /// projection.
    pub fn forward(&self, x: Tensor<B, 3>, mask: Option<Tensor<B, 2>>) -> Tensor<B, 3> {
        let queries = self.query.forward(x.clone());
        let keys = self.key.forward(x.clone());
        let values = self.value.forward(x);

        let attended = qkv_attention(queries, keys, values, mask, self.n_head);
        self.out.forward(attended)
    }
}
/// Configuration for the two-layer [`MLP`].
#[derive(Config, Debug)]
pub struct MLPConfig {
    /// Width of the MLP's input and output.
    input_size: usize,
    /// Width of the hidden layer.
    hidden_size: usize,
}

impl MLPConfig {
    /// Builds a freshly initialised [`MLP`] on `device`.
    fn init<B: Backend>(&self, device: &B::Device) -> MLP<B> {
        let (outer, inner) = (self.input_size, self.hidden_size);
        MLP {
            fc1: nn::LinearConfig::new(outer, inner).init(device),
            gelu: QuickGELU::new(),
            fc2: nn::LinearConfig::new(inner, outer).init(device),
        }
    }
}
/// Feed-forward network: linear, [`QuickGELU`], linear.
#[derive(Module, Debug)]
pub struct MLP<B: Backend> {
    fc1: nn::Linear<B>,
    gelu: QuickGELU,
    fc2: nn::Linear<B>,
}

impl<B: Backend> MLP<B> {
    /// Applies `fc1`, the activation and `fc2` to `x`.
    fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
        let hidden = self.gelu.forward(self.fc1.forward(x));
        self.fc2.forward(hidden)
    }
}
/// Sigmoid-based GELU approximation: `x * sigmoid(1.702 * x)`.
#[derive(Module, Clone, Debug)]
pub struct QuickGELU {}

impl QuickGELU {
    /// Creates the (stateless) activation.
    fn new() -> Self {
        QuickGELU {}
    }

    /// Applies the activation element-wise.
    fn forward<B: Backend, const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
        let input = x.clone();
        let gate = sigmoid(x * 1.702);
        input * gate
    }
}

View File

@@ -0,0 +1,37 @@
use super::GroupNorm;
use crate::model::load::*;
use std::error::Error;
use burn::{
config::Config,
module::{Module, Param},
nn,
tensor::{backend::Backend, Tensor},
};
/// Loads a [`GroupNorm`] from `path`.
///
/// `n_group`, `n_channel` and `eps` are required. If the `weight` or `bias`
/// tensor cannot be loaded, it falls back to all ones or all zeros
/// respectively.
///
/// # Errors
///
/// Returns an error if `n_group`, `n_channel` or `eps` cannot be loaded.
pub fn load_group_norm<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<GroupNorm<B>, Box<dyn Error>> {
    let n_group: usize = load_usize::<B>("n_group", path, device)?;
    let n_channel: usize = load_usize::<B>("n_channel", path, device)?;
    let eps = f64::from(load_f32::<B>("eps", path, device)?);

    // Any load failure for the affine parameters selects the identity transform.
    let scale = match load_tensor::<B, 1>("weight", path, device) {
        Ok(weight) => weight,
        Err(_) => Tensor::ones([n_channel], device),
    };
    let shift = match load_tensor::<B, 1>("bias", path, device) {
        Ok(bias) => bias,
        Err(_) => Tensor::zeros([n_channel], device),
    };

    Ok(GroupNorm {
        n_group,
        n_channel,
        gamma: Param::from_tensor(scale),
        beta: Param::from_tensor(shift),
        eps,
    })
}

View File

@@ -0,0 +1,82 @@
pub mod load;
use burn::{
config::Config,
module::{Module, Param},
tensor::{backend::Backend, Tensor},
};
/// Configuration for a [`GroupNorm`] layer.
#[derive(Config, Debug)]
pub struct GroupNormConfig {
    /// Number of groups the channels are split into.
    n_group: usize,
    /// Number of channels; must be divisible by `n_group`.
    n_channel: usize,
    /// Value added to the variance before taking the square root.
    #[config(default = 1e-5)]
    eps: f64,
}

impl GroupNormConfig {
    /// Builds a [`GroupNorm`] on `device` with `gamma` set to ones and `beta`
    /// set to zeros.
    ///
    /// # Panics
    ///
    /// Panics if `n_channel` is not divisible by `n_group`.
    pub fn init<B: Backend>(&self, device: &B::Device) -> GroupNorm<B> {
        assert!(
            self.n_channel % self.n_group == 0,
            "The number of channels {} must be divisible by the number of groups {}",
            self.n_channel,
            self.n_group
        );

        GroupNorm {
            n_group: self.n_group,
            n_channel: self.n_channel,
            gamma: Param::from_tensor(Tensor::ones([self.n_channel], device)),
            beta: Param::from_tensor(Tensor::zeros([self.n_channel], device)),
            eps: self.eps,
        }
    }
}
/// Group normalisation with a per-channel scale (`gamma`) and shift (`beta`).
#[derive(Module, Debug)]
pub struct GroupNorm<B: Backend> {
    n_group: usize,
    n_channel: usize,
    gamma: Param<Tensor<B, 1>>,
    beta: Param<Tensor<B, 1>>,
    eps: f64,
}

impl<B: Backend> GroupNorm<B> {
    /// Normalises `x` within each group, then applies the affine parameters
    /// along axis 1.
    ///
    /// Axis 0 is treated as the batch axis and axis 1 as the channel axis.
    pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
        let input_shape = x.shape();
        let n_batch = input_shape.dims[0];
        let group_len = input_shape.num_elements() / (n_batch * self.n_group);

        // Shape that broadcasts the per-channel parameters over all other axes.
        let mut affine_shape = [1; D];
        affine_shape[1] = self.n_channel;

        // Flatten each group into one row, normalise the rows, restore the shape.
        let grouped = x.reshape([n_batch, self.n_group, group_len]);
        let normalised: Tensor<B, D> = layernorm(grouped, self.eps).reshape(input_shape);

        let scale = self.gamma.val().reshape(affine_shape);
        let shift = self.beta.val().reshape(affine_shape);
        normalised.mul(scale).add(shift)
    }
}
/// Normalises `x` along its last axis to zero mean and unit variance.
///
/// The variance is the mean of the squared deviations, and `eps` is added to
/// it before the square root is taken.
pub fn layernorm<B: Backend, const D: usize>(x: Tensor<B, D>, eps: f64) -> Tensor<B, D> {
    let last_axis = D - 1;
    let centered = x.clone() - x.mean_dim(last_axis);
    let variance = (centered.clone() * centered.clone()).mean_dim(last_axis);
    centered.div(variance.add_scalar(eps).sqrt())
}

178
src/model/load.rs Normal file
View File

@@ -0,0 +1,178 @@
use npy::{self, NpyData};
use num_traits::cast::ToPrimitive;
use burn::tensor::cast::ToElement;
use burn::prelude::TensorData;
use std::error::Error;
use std::io::Read;
use burn::{
config::Config,
module::{Module, Param},
nn::{self, conv},
tensor::{backend::Backend, Tensor},
};
use burn::tensor::ElementConversion;
/// Converts a flat `f32` array into a `D`-dimensional tensor on `device`.
///
/// The first `D` values of `numpy_data` are read as the tensor's shape; all
/// remaining values are the tensor's elements.
///
/// # Panics
///
/// Panics if `numpy_data` holds fewer than `D` values.
pub fn numpy_to_tensor<B: Backend, const D: usize>(
    numpy_data: NpyData<f32>,
    device: &B::Device,
) -> Tensor<B, D> {
    let values = numpy_data.to_vec();
    let header = &values[0..D];
    let payload = &values[D..];

    let shape: Vec<usize> = header.iter().map(|&dim| dim as usize).collect();
    let data: Vec<B::FloatElem> = payload.iter().map(|value| value.elem()).collect();
    Tensor::from_data(TensorData::new(data, shape), device)
}
/// Loads the tensor stored in `{path}/{name}.npy` onto `device`.
///
/// The file path is printed to standard output after a successful load.
///
/// # Errors
///
/// Returns an error if the file cannot be opened or read, or if its contents
/// cannot be parsed as an `f32` npy array.
pub fn load_tensor<B: Backend, const D: usize>(
    name: &str,
    path: &str,
    device: &B::Device,
) -> Result<Tensor<B, D>, Box<dyn Error>> {
    let tensor_path = format!("{}/{}.npy", path, name);

    let mut file = std::fs::File::open(&tensor_path)?;
    let mut bytes = Vec::new();
    file.read_to_end(&mut bytes)?;

    let parsed: NpyData<f32> = NpyData::from_bytes(&bytes)?;
    let tensor = numpy_to_tensor(parsed, device);
    println!("{}", tensor_path);
    Ok(tensor)
}
/// Loads the tensor `{path}/{name}.npy` and returns its scalar value as `f32`.
///
/// # Errors
///
/// Returns an error if the tensor cannot be loaded.
pub fn load_f32<B: Backend>(
    name: &str,
    path: &str,
    device: &B::Device,
) -> Result<f32, Box<dyn Error>> {
    let scalar = load_tensor::<B, 1>(name, path, device)?.into_scalar();
    Ok(scalar.to_f32())
}
/// Loads the tensor `{path}/{name}.npy` and returns its scalar value as
/// `usize`.
///
/// # Errors
///
/// Returns an error if the tensor cannot be loaded.
pub fn load_usize<B: Backend>(
    name: &str,
    path: &str,
    device: &B::Device,
) -> Result<usize, Box<dyn Error>> {
    let scalar = load_tensor::<B, 1>(name, path, device)?.into_scalar();
    Ok(scalar.to_usize())
}
/// Loads a linear layer from `path`.
///
/// `weight` is required; the layer has no bias if `bias` cannot be loaded.
///
/// # Errors
///
/// Returns an error if `weight` cannot be loaded.
pub fn load_linear<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<nn::Linear<B>, Box<dyn Error>> {
    let weight = Param::from_tensor(load_tensor::<B, 2>("weight", path, device)?);
    // A missing or unreadable bias file simply means "no bias".
    let bias = match load_tensor::<B, 1>("bias", path, device) {
        Ok(tensor) => Some(Param::from_tensor(tensor)),
        Err(_) => None,
    };
    Ok(nn::Linear { weight, bias })
}
/// Loads an embedding table from the `weight` tensor in `path`.
///
/// # Errors
///
/// Returns an error if `weight` cannot be loaded.
pub fn load_embedding<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<nn::Embedding<B>, Box<dyn Error>> {
    let table = load_tensor::<B, 2>("weight", path, device)?;
    let weight = Param::from_tensor(table);
    Ok(nn::Embedding { weight })
}
/// Loads a layer norm from the `weight`, `bias` and `eps` tensors in `path`.
///
/// The layer's width is taken from the length of `weight`.
///
/// # Errors
///
/// Returns an error if any of the three tensors cannot be loaded.
pub fn load_layer_norm<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<nn::LayerNorm<B>, Box<dyn Error>> {
    let gamma = load_tensor::<B, 1>("weight", path, device)?;
    let beta = load_tensor::<B, 1>("bias", path, device)?;
    let epsilon = load_f32::<B>("eps", path, device)? as f64;

    let width = gamma.dims()[0];
    let mut norm = nn::LayerNormConfig::new(width)
        .with_epsilon(epsilon)
        .init(device);
    // Replace the freshly initialised parameters with the loaded ones.
    norm.gamma = Param::from_tensor(gamma);
    norm.beta = Some(Param::from_tensor(beta));
    Ok(norm)
}
/*pub fn load_rmsnorm<B: Backend>(path: &str, device: &B::Device) -> Result<RMSNorm<B>, Box<dyn Error>> {
let weight = load_tensor::<B, 1>("weight", path, device)?;
let eps = load_f32::<B>("eps", path, device)?.into();
let rmsnorm = RMSNorm {
weight: Param::from_tensor(weight),
eps: eps
};
Ok(rmsnorm)
}*/
/// Loads a 2D convolution from `path`.
///
/// Reads, in order: `weight`, the optional `bias`, `stride`, `kernel_size`,
/// `dilation`, `n_group`, `n_channels_in`, `n_channels_out` and `padding`.
/// The layer has no bias if `bias` cannot be loaded.
///
/// # Errors
///
/// Returns an error if any item other than `bias` cannot be loaded.
///
/// # Panics
///
/// Panics if `stride`, `kernel_size`, `dilation` or `padding` does not hold
/// exactly two values.
pub fn load_conv2d<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<conv::Conv2d<B>, Box<dyn Error>> {
    // Reads a two-element tensor such as the stride as `[usize; 2]`.
    let read_pair = |name: &str| -> Result<[usize; 2], Box<dyn Error>> {
        let tensor = load_tensor::<B, 1>(name, path, device)?;
        Ok(tensor_to_array_2(tensor))
    };

    let weight = load_tensor::<B, 4>("weight", path, device)?;
    let bias = load_tensor::<B, 1>("bias", path, device).ok();
    let has_bias = bias.is_some();

    let stride = read_pair("stride")?;
    let kernel_size = read_pair("kernel_size")?;
    let dilation = read_pair("dilation")?;

    let n_group: usize = load_usize::<B>("n_group", path, device)?;
    let n_channels_in: usize = load_usize::<B>("n_channels_in", path, device)?;
    let n_channels_out: usize = load_usize::<B>("n_channels_out", path, device)?;

    let [pad_vertical, pad_horizontal] = read_pair("padding")?;
    let padding = nn::PaddingConfig2d::Explicit(pad_vertical, pad_horizontal);

    // Build a layer with the loaded hyper-parameters, then overwrite its
    // parameters and settings with the loaded values.
    let mut layer = conv::Conv2dConfig::new([n_channels_in, n_channels_out], kernel_size)
        .with_stride(stride)
        .with_dilation(dilation)
        .with_groups(n_group)
        .with_padding(padding.clone())
        .with_bias(has_bias)
        .init(device);

    layer.weight = Param::from_tensor(weight);
    layer.bias = bias.map(Param::from_tensor);
    layer.stride = stride;
    layer.kernel_size = kernel_size;
    layer.dilation = dilation;
    layer.groups = n_group;
    layer.padding = burn::module::Ignored(padding);
    Ok(layer)
}
pub fn tensor_to_array_2<B: Backend>(x: Tensor<B, 1>) -> [usize; 2] {
let vec: Vec<<B as Backend>::FloatElem> = x.into_data().to_vec().unwrap();
assert!(vec.len() == 2, "Tensor length must be 2.");
[vec[0].to_usize(), vec[1].to_usize()]
}
/// Converts an `N`-element tensor into `[usize; N]`.
///
/// # Panics
///
/// Panics if the tensor does not hold exactly `N` elements.
pub fn tensor_to_array<const N: usize, B: Backend>(x: Tensor<B, 1>) -> [usize; N] {
    let values: Vec<<B as Backend>::FloatElem> = x.into_data().to_vec().unwrap();
    assert!(values.len() == N, "Tensor length must be {}.", N);
    std::array::from_fn(|index| values[index].to_usize())
}

11
src/model/mod.rs Normal file
View File

@@ -0,0 +1,11 @@
pub mod stablediffusion;
pub mod autoencoder;
pub mod clip;
pub mod unet;
pub mod attention;
pub mod groupnorm;
pub mod silu;
pub mod load;

17
src/model/silu.rs Normal file
View File

@@ -0,0 +1,17 @@
use burn::{
module::Module,
tensor::{activation::sigmoid, backend::Backend, Tensor},
};
/// SiLU activation: `x * sigmoid(x)`.
#[derive(Module, Clone, Debug)]
pub struct SILU {}

impl SILU {
    /// Creates the (stateless) activation.
    pub fn new() -> Self {
        SILU {}
    }

    /// Applies the activation element-wise.
    pub fn forward<B: Backend, const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
        let input = x.clone();
        let gate = sigmoid(x);
        input * gate
    }
}

View File

@@ -0,0 +1,33 @@
use burn::tensor::ElementConversion;
use std::error::Error;
use burn::{
config::Config,
module::{Module, Param},
nn,
tensor::{backend::Backend, Tensor},
};
use super::*;
use crate::model::{
autoencoder::load::load_autoencoder, clip::load::load_clip, load::*, unet::load::load_unet,
};
/// Loads a complete [`StableDiffusion`] model from `path`.
///
/// Reads the scalar `n_steps` and the `alphas_cumprod` tensor from `path`
/// itself, then the `autoencoder`, `unet` and `clip` sub-directories.
///
/// # Errors
///
/// Returns the error of the first item that fails to load.
pub fn load_stable_diffusion<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<StableDiffusion<B>, Box<dyn Error>> {
    let sub_dir = |name: &str| format!("{}/{}", path, name);

    // Fields are written in load order.
    Ok(StableDiffusion {
        n_steps: load_usize::<B>("n_steps", path, device)?,
        alpha_cumulative_products: Param::from_tensor(load_tensor::<B, 1>(
            "alphas_cumprod",
            path,
            device,
        )?),
        autoencoder: load_autoencoder(&sub_dir("autoencoder"), device)?,
        diffusion: load_unet(&sub_dir("unet"), device)?,
        clip: load_clip(&sub_dir("clip"), device)?,
    })
}

View File

@@ -0,0 +1,237 @@
pub mod load;
use burn::{
config::Config,
module::{Module, Param},
tensor::{backend::Backend, BasicOps, Distribution, Float, Int, Tensor},
tensor::cast::ToElement,
};
use num_traits::ToPrimitive;
//use crate::backend::Backend as MyBackend;
use super::autoencoder::{Autoencoder, AutoencoderConfig};
use super::clip::{CLIPConfig, CLIP};
use super::unet::{UNet, UNetConfig};
use crate::tokenizer::SimpleTokenizer;
/// Configuration for [`StableDiffusion`]; it has no tunable fields.
#[derive(Config, Debug)]
pub struct StableDiffusionConfig {}

impl StableDiffusionConfig {
    /// Builds a freshly initialised [`StableDiffusion`] on `device`.
    ///
    /// Uses 1000 diffusion steps, with cumulative alpha products taken from
    /// `offset_cosine_schedule_cumprod`.
    pub fn init<B: Backend>(&self, device: &B::Device) -> StableDiffusion<B> {
        let n_steps = 1000;
        let schedule = offset_cosine_schedule_cumprod::<B>(n_steps as i64, device);

        // Fields are initialised in the order written.
        StableDiffusion {
            n_steps,
            alpha_cumulative_products: Param::from_tensor(schedule),
            autoencoder: AutoencoderConfig::new().init(device),
            diffusion: UNetConfig::new().init(device),
            clip: CLIPConfig::new(49408, 768, 12, 77, 12).init(device),
        }
    }
}
/// Complete text-to-image model: CLIP text encoder, UNet noise predictor and
/// image autoencoder, together with the diffusion schedule.
#[derive(Module, Debug)]
pub struct StableDiffusion<B: Backend> {
/// Number of timesteps in the schedule; sampling walks `0..n_steps` in reverse.
n_steps: usize,
/// Cumulative alpha product per timestep, indexed by timestep during sampling.
alpha_cumulative_products: Param<Tensor<B, 1>>,
/// Decodes sampled latents into images.
autoencoder: Autoencoder<B>,
/// Predicts the noise in a latent for a given timestep and context.
diffusion: UNet<B>,
/// Encodes tokenised prompts into the conditioning context.
clip: CLIP<B>,
}
impl<B: Backend> StableDiffusion<B> {
/// Samples a latent for `context` and decodes it into images.
///
/// Returns one byte buffer per batch element; see `latent_to_image` for the
/// layout. The arguments are forwarded unchanged to `sample_latent`.
pub fn sample_image(
    &self,
    context: Tensor<B, 3>,
    unconditional_context: Tensor<B, 2>,
    unconditional_guidance_scale: f64,
    n_steps: usize,
) -> Vec<Vec<u8>> {
    let latent = self.sample_latent(
        context,
        unconditional_context,
        unconditional_guidance_scale,
        n_steps,
    );
    self.latent_to_image(latent)
}
/// Decodes `latent` into one byte buffer per batch element.
///
/// The latent is divided by 0.18215 before decoding. The decoded values are
/// mapped from `[-1, 1]` to `[0, 255]`, clamped, and stored in
/// height-major, then width, then channel order. The image dimensions are
/// taken from the decoder output, so any latent size the decoder accepts
/// is supported.
pub fn latent_to_image(&self, latent: Tensor<B, 4>) -> Vec<Vec<u8>> {
    let decoded = self.autoencoder.decode_latent(latent * (1.0 / 0.18215));
    let [n_batch, n_channel, height, width] = decoded.dims();
    let num_elements_per_image = n_channel * height * width;

    // Rescale to [0, 255] and reorder (batch, channel, h, w) -> (batch, h, w, channel).
    let pixels = ((decoded + 1.0) / 2.0)
        .swap_dims(1, 2)
        .swap_dims(2, 3)
        .mul_scalar(255.0);
    let flattened: Vec<B::FloatElem> = pixels.into_data().to_vec().unwrap();

    (0..n_batch)
        .map(|b| {
            let start = b * num_elements_per_image;
            let end = start + num_elements_per_image;
            flattened[start..end]
                .iter()
                .map(|v| v.to_f64().min(255.0).max(0.0) as u8)
                .collect()
        })
        .collect()
}
/// Denoises a random `[batch, 4, 64, 64]` latent conditioned on `context`.
///
/// Starting from Gaussian noise, the schedule is walked backwards from the
/// last timestep in strides of `self.n_steps / n_steps`. Each step predicts
/// the noise with `forward_diffuser` and moves the latent to the previous
/// timestep's noise level. `sigma` is fixed at 0, so no noise is re-injected
/// between steps.
///
/// If `n_steps` exceeds the schedule length, every timestep is visited.
///
/// # Panics
///
/// Panics if `n_steps` is zero.
pub fn sample_latent(
    &self,
    context: Tensor<B, 3>,
    unconditional_context: Tensor<B, 2>,
    unconditional_guidance_scale: f64,
    n_steps: usize,
) -> Tensor<B, 4> {
    assert!(n_steps > 0, "n_steps must be greater than zero.");

    let device = context.device();
    // A stride of 0 would make `step_by` panic, so never go below 1.
    let step_size = (self.n_steps / n_steps).max(1);
    let [n_batches, _, _] = context.dims();

    let gen_noise = || -> Tensor<B, 4> {
        Tensor::random([n_batches, 4, 64, 64], Distribution::Normal(0.0, 1.0), &device)
    };
    // Reads the cumulative alpha product of one timestep as a host scalar.
    let alpha_at = |index: usize| -> f64 {
        self.alpha_cumulative_products
            .val()
            .slice([index..index + 1])
            .into_scalar()
            .to_f64()
    };

    let sigma = 0.0; // Use deterministic diffusion
    let mut latent = gen_noise();

    for t in (0..self.n_steps).rev().step_by(step_size) {
        let current_alpha = alpha_at(t);
        // Before the first timestep there is no noise left.
        let prev_alpha = if t >= step_size {
            alpha_at(t - step_size)
        } else {
            1.0
        };
        let sqrt_noise = (1.0 - current_alpha).sqrt();

        let timestep = Tensor::from_ints([t as i32], &device);
        let pred_noise = self.forward_diffuser(
            latent.clone(),
            timestep,
            context.clone(),
            unconditional_context.clone(),
            unconditional_guidance_scale,
        );

        // Estimate the clean latent, then re-noise it to the previous level.
        let predx0 = (latent - pred_noise.clone() * sqrt_noise) / current_alpha.sqrt();
        let dir_latent = pred_noise * (1.0 - prev_alpha - sigma * sigma).sqrt();
        latent = predx0 * prev_alpha.sqrt() + dir_latent + gen_noise() * sigma;
    }
    latent
}
/// Predicts noise with classifier-free guidance.
///
/// The diffusion model is evaluated twice on the same `latent` and `timestep`:
/// once with `unconditional_context` and once with `context`. The result is
/// `uncond + (cond - uncond) * unconditional_guidance_scale`.
///
/// `unconditional_context` is a single rank-2 context; it gains a leading batch
/// axis and is tiled `n_batch` times along that axis so it lines up with
/// `latent`.
fn forward_diffuser(
    &self,
    latent: Tensor<B, 4>,
    timestep: Tensor<B, 1, Int>,
    context: Tensor<B, 3>,
    unconditional_context: Tensor<B, 2>,
    unconditional_guidance_scale: f64,
) -> Tensor<B, 4> {
    let [n_batch, _, _, _] = latent.dims();

    // `repeat` takes one repeat count per dimension, so `repeat(&[0, n_batch])`
    // did not tile the batch axis. `repeat_dim(0, n_batch)` repeats along the
    // leading (batch) dimension, which is what the guidance step needs.
    let unconditional_context = unconditional_context
        .unsqueeze::<3>()
        .repeat_dim(0, n_batch);

    let unconditional_latent = self.diffusion.forward(
        latent.clone(),
        timestep.clone(),
        unconditional_context,
    );

    let conditional_latent = self.diffusion.forward(latent, timestep, context);

    // TODO: both passes could be merged into one forward call on a doubled batch.
    unconditional_latent.clone()
        + (conditional_latent - unconditional_latent) * unconditional_guidance_scale
}
/// Returns the context produced by the empty prompt, squeezed to rank 2.
pub fn unconditional_context(&self, tokenizer: &SimpleTokenizer) -> Tensor<B, 2> {
    let empty_prompt_context = self.context(tokenizer, "");
    empty_prompt_context.squeeze::<2>()
}
/// Encodes `text` into a context tensor with the CLIP model.
///
/// The prompt is wrapped in `<|startoftext|>` / `<|endoftext|>`, tokenized,
/// turned into an integer tensor on the first device reported by `self.clip`,
/// given a leading axis via `unsqueeze`, and passed to `self.clip.forward`.
pub fn context(&self, tokenizer: &SimpleTokenizer, text: &str) -> Tensor<B, 3> {
    let device = &self.clip.devices()[0];

    let wrapped = format!("<|startoftext|>{}<|endoftext|>", text);

    let mut token_ids: Vec<i32> = Vec::new();
    for id in tokenizer.encode(&wrapped) {
        token_ids.push(id as i32);
    }

    let tokens = Tensor::<B, 1, Int>::from_ints(&token_ids[..], device).unsqueeze();
    self.clip.forward(tokens)
}
}
use std::f64::consts::PI;
fn cosine_schedule<B: Backend>(n_steps: i64, device: &B::Device) -> Tensor<B, 1> {
Tensor::arange(1..n_steps + 1, device)
.float()
.mul_scalar(PI * 0.5 / n_steps as f64)
.cos()
}
fn offset_cosine_schedule<B: Backend>(n_steps: i64, device: &B::Device) -> Tensor<B, 1> {
let min_signal_rate: f64 = 0.02;
let max_signal_rate: f64 = 0.95;
let start_angle = max_signal_rate.acos();
let end_angle = min_signal_rate.acos();
let times = Tensor::arange(1..n_steps + 1, device).float();
let diffusion_angles = times * ((end_angle - start_angle) / n_steps as f64) + start_angle;
diffusion_angles.cos()
}
/// Returns the element-wise square of [`offset_cosine_schedule`].
fn offset_cosine_schedule_cumprod<B: Backend>(n_steps: i64, device: &B::Device) -> Tensor<B, 1> {
    let signal_rates = offset_cosine_schedule::<B>(n_steps, device);
    signal_rates.powf_scalar(2.0)
}

306
src/model/unet/load.rs Normal file
View File

@@ -0,0 +1,306 @@
use super::GroupNorm;
use crate::model::load::*;
use std::error::Error;
use burn::{
config::Config,
module::{Module, Param},
nn,
tensor::{backend::Backend, Tensor},
};
use super::*;
use crate::model::groupnorm::load::load_group_norm;
/// Loads a [`ResBlock`] whose parts live under `path/<name>`.
///
/// `skip_connection` is optional: if loading it fails for any reason the block
/// is built without one. Every other part propagates its load error.
pub fn load_res_block<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<ResBlock<B>, Box<dyn Error>> {
    let child = |name: &str| format!("{}/{}", path, name);

    Ok(ResBlock {
        norm_in: load_group_norm::<B>(&child("norm_in"), device)?,
        silu_in: SILU::new(),
        conv_in: load_conv2d::<B>(&child("conv_in"), device)?,
        silu_embed: SILU::new(),
        lin_embed: load_linear::<B>(&child("lin_embed"), device)?,
        norm_out: load_group_norm::<B>(&child("norm_out"), device)?,
        silu_out: SILU::new(),
        conv_out: load_conv2d::<B>(&child("conv_out"), device)?,
        skip_connection: load_conv2d::<B>(&child("skip_connection"), device).ok(),
    })
}
/// Loads a [`MultiHeadAttention`]: the `n_head` value stored at `path` plus the
/// `query`, `key`, `value` and `out` linear layers under `path/<name>`.
pub fn load_multi_head_attention<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<MultiHeadAttention<B>, Box<dyn Error>> {
    let child = |name: &str| format!("{}/{}", path, name);

    Ok(MultiHeadAttention {
        n_head: load_usize::<B>("n_head", path, device)?,
        query: load_linear::<B>(&child("query"), device)?,
        key: load_linear::<B>(&child("key"), device)?,
        value: load_linear::<B>(&child("value"), device)?,
        out: load_linear::<B>(&child("out"), device)?,
    })
}
/// Loads a [`GEGLU`]: the `proj` linear layer from `path/proj` plus a fresh `Gelu`.
pub fn load_geglu<B: Backend>(path: &str, device: &B::Device) -> Result<GEGLU<B>, Box<dyn Error>> {
    let proj_path = format!("{}/{}", path, "proj");

    Ok(GEGLU {
        proj: load_linear::<B>(&proj_path, device)?,
        // The activation has no stored parameters, so it is simply constructed.
        gelu: Gelu::new(),
    })
}
/// Loads an [`MLP`] from `path/geglu` and `path/lin`.
pub fn load_mlp<B: Backend>(path: &str, device: &B::Device) -> Result<MLP<B>, Box<dyn Error>> {
    let child = |name: &str| format!("{}/{}", path, name);

    Ok(MLP {
        geglu: load_geglu::<B>(&child("geglu"), device)?,
        lin: load_linear::<B>(&child("lin"), device)?,
    })
}
/// Loads a [`TransformerBlock`]: three layer norms, two attention modules and
/// the MLP, each from `path/<name>`.
pub fn load_transformer_block<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<TransformerBlock<B>, Box<dyn Error>> {
    let child = |name: &str| format!("{}/{}", path, name);

    Ok(TransformerBlock {
        norm1: load_layer_norm::<B>(&child("norm1"), device)?,
        attn1: load_multi_head_attention::<B>(&child("attn1"), device)?,
        norm2: load_layer_norm::<B>(&child("norm2"), device)?,
        attn2: load_multi_head_attention::<B>(&child("attn2"), device)?,
        norm3: load_layer_norm::<B>(&child("norm3"), device)?,
        mlp: load_mlp::<B>(&child("mlp"), device)?,
    })
}
/// Loads a [`SpatialTransformer`] from `path/norm`, `path/proj_in`,
/// `path/transformer` and `path/proj_out`.
pub fn load_spatial_transformer<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<SpatialTransformer<B>, Box<dyn Error>> {
    let child = |name: &str| format!("{}/{}", path, name);

    Ok(SpatialTransformer {
        norm: load_group_norm::<B>(&child("norm"), device)?,
        proj_in: load_conv2d::<B>(&child("proj_in"), device)?,
        transformer: load_transformer_block::<B>(&child("transformer"), device)?,
        proj_out: load_conv2d::<B>(&child("proj_out"), device)?,
    })
}
/// Loads an [`Upsample`], whose only stored part is the convolution at `path/conv`.
pub fn load_upsample<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<Upsample<B>, Box<dyn Error>> {
    let conv_path = format!("{}/{}", path, "conv");

    Ok(Upsample {
        conv: load_conv2d::<B>(&conv_path, device)?,
    })
}
/// Loads a [`Downsample`] by loading the convolution stored directly at `path`.
pub fn load_downsample<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<Downsample<B>, Box<dyn Error>> {
    load_conv2d::<B>(path, device)
}
/// Loads a [`ResTransformerRes`] from `path/res1`, `path/transformer` and `path/res2`.
pub fn load_res_transformer_res<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<ResTransformerRes<B>, Box<dyn Error>> {
    let child = |name: &str| format!("{}/{}", path, name);

    Ok(ResTransformerRes {
        res1: load_res_block::<B>(&child("res1"), device)?,
        transformer: load_spatial_transformer::<B>(&child("transformer"), device)?,
        res2: load_res_block::<B>(&child("res2"), device)?,
    })
}
/// Loads a [`ResTransformerUpsample`] from `path/res`, `path/transformer` and
/// `path/upsample`.
pub fn load_res_transformer_upsample<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<ResTransformerUpsample<B>, Box<dyn Error>> {
    let child = |name: &str| format!("{}/{}", path, name);

    Ok(ResTransformerUpsample {
        res: load_res_block::<B>(&child("res"), device)?,
        transformer: load_spatial_transformer::<B>(&child("transformer"), device)?,
        upsample: load_upsample::<B>(&child("upsample"), device)?,
    })
}
/// Loads a [`ResUpSample`] from `path/res` and `path/upsample`.
pub fn load_res_upsample<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<ResUpSample<B>, Box<dyn Error>> {
    let child = |name: &str| format!("{}/{}", path, name);

    Ok(ResUpSample {
        res: load_res_block::<B>(&child("res"), device)?,
        upsample: load_upsample::<B>(&child("upsample"), device)?,
    })
}
/// Loads a [`ResTransformer`] from `path/res` and `path/transformer`.
pub fn load_res_transformer<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<ResTransformer<B>, Box<dyn Error>> {
    let child = |name: &str| format!("{}/{}", path, name);

    Ok(ResTransformer {
        res: load_res_block::<B>(&child("res"), device)?,
        transformer: load_spatial_transformer::<B>(&child("transformer"), device)?,
    })
}
/// Loads the twelve stages of [`UNetInputBlocks`], each from `path/<field name>`.
pub fn load_unet_input_blocks<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<UNetInputBlocks<B>, Box<dyn Error>> {
    let child = |name: &str| format!("{}/{}", path, name);

    Ok(UNetInputBlocks {
        conv: load_conv2d::<B>(&child("conv"), device)?,
        rt1: load_res_transformer::<B>(&child("rt1"), device)?,
        rt2: load_res_transformer::<B>(&child("rt2"), device)?,
        d1: load_downsample::<B>(&child("d1"), device)?,
        rt3: load_res_transformer::<B>(&child("rt3"), device)?,
        rt4: load_res_transformer::<B>(&child("rt4"), device)?,
        d2: load_downsample::<B>(&child("d2"), device)?,
        rt5: load_res_transformer::<B>(&child("rt5"), device)?,
        rt6: load_res_transformer::<B>(&child("rt6"), device)?,
        d3: load_downsample::<B>(&child("d3"), device)?,
        r1: load_res_block::<B>(&child("r1"), device)?,
        r2: load_res_block::<B>(&child("r2"), device)?,
    })
}
/// Loads the twelve stages of [`UNetOutputBlocks`], each from `path/<field name>`.
pub fn load_unet_output_blocks<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<UNetOutputBlocks<B>, Box<dyn Error>> {
    let child = |name: &str| format!("{}/{}", path, name);

    Ok(UNetOutputBlocks {
        r1: load_res_block::<B>(&child("r1"), device)?,
        r2: load_res_block::<B>(&child("r2"), device)?,
        ru: load_res_upsample::<B>(&child("ru"), device)?,
        rt1: load_res_transformer::<B>(&child("rt1"), device)?,
        rt2: load_res_transformer::<B>(&child("rt2"), device)?,
        rtu1: load_res_transformer_upsample::<B>(&child("rtu1"), device)?,
        rt3: load_res_transformer::<B>(&child("rt3"), device)?,
        rt4: load_res_transformer::<B>(&child("rt4"), device)?,
        rtu2: load_res_transformer_upsample::<B>(&child("rtu2"), device)?,
        rt5: load_res_transformer::<B>(&child("rt5"), device)?,
        rt6: load_res_transformer::<B>(&child("rt6"), device)?,
        rt7: load_res_transformer::<B>(&child("rt7"), device)?,
    })
}
/// Loads a complete [`UNet`]: time-embedding layers, input/middle/output
/// blocks and the output head, each from `path/<field name>`.
///
/// The SiLU activations carry no stored parameters and are constructed in place.
pub fn load_unet<B: Backend>(path: &str, device: &B::Device) -> Result<UNet<B>, Box<dyn Error>> {
    let child = |name: &str| format!("{}/{}", path, name);

    Ok(UNet {
        lin1_time_embed: load_linear::<B>(&child("lin1_time_embed"), device)?,
        silu_time_embed: SILU::new(),
        lin2_time_embed: load_linear::<B>(&child("lin2_time_embed"), device)?,
        input_blocks: load_unet_input_blocks::<B>(&child("input_blocks"), device)?,
        middle_block: load_res_transformer_res::<B>(&child("middle_block"), device)?,
        output_blocks: load_unet_output_blocks::<B>(&child("output_blocks"), device)?,
        norm_out: load_group_norm::<B>(&child("norm_out"), device)?,
        silu_out: SILU::new(),
        conv_out: load_conv2d::<B>(&child("conv_out"), device)?,
    })
}

740
src/model/unet/mod.rs Normal file
View File

@@ -0,0 +1,740 @@
pub mod load;
use burn::{
config::Config,
module::{Module, Param},
nn::{
self,
conv::{Conv2d, Conv2dConfig},
PaddingConfig2d, Gelu,
},
tensor::{activation::softmax, backend::Backend, module::embedding, Distribution, Int, Tensor},
};
use super::groupnorm::*;
use super::silu::*;
use super::attention::qkv_attention;
/// Builds a sinusoidal embedding of `timesteps`.
///
/// `dim / 2` frequencies `exp(-ln(max_period) * i / half)` are multiplied by the
/// timestep; the cosines and sines of those products are concatenated along
/// axis 0 and the result gains one extra axis through `unsqueeze`.
fn timestep_embedding<B: Backend>(
    timesteps: Tensor<B, 1, Int>,
    dim: usize,
    max_period: usize,
) -> Tensor<B, 2> {
    let half = dim / 2;
    let device = timesteps.device();

    let log_step = -(max_period as f64).ln() / half as f64;
    let freqs = (Tensor::<B, 1, Int>::arange(0..half as i64, &device).float() * log_step).exp();

    let args = timesteps.float() * freqs;
    let cos_part = args.clone().cos();
    let sin_part = args.sin();

    Tensor::cat(vec![cos_part, sin_part], 0).unsqueeze()
}
#[derive(Config, Debug)]
pub struct UNetConfig {}
impl UNetConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> UNet<B> {
let lin1_time_embed = nn::LinearConfig::new(320, 1280).init(device);
let silu_time_embed = SILU::new();
let lin2_time_embed = nn::LinearConfig::new(1280, 1280).init(device);
let input_blocks = UNetInputBlocks {
conv: Conv2dConfig::new([4, 320], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init(device),
rt1: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(device),
rt2: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(device),
d1: DownsampleConfig::new(320).init(device),
rt3: ResTransformerConfig::new(320, 1280, 640, 768, 8).init(device),
rt4: ResTransformerConfig::new(640, 1280, 640, 768, 8).init(device),
d2: DownsampleConfig::new(640).init(device),
rt5: ResTransformerConfig::new(640, 1280, 1280, 768, 8).init(device),
rt6: ResTransformerConfig::new(1280, 1280, 1280, 768, 8).init(device),
d3: DownsampleConfig::new(1280).init(device),
r1: ResBlockConfig::new(1280, 1280, 1280).init(device),
r2: ResBlockConfig::new(1280, 1280, 1280).init(device),
};
let middle_block = ResTransformerResConfig::new(1280, 1280, 1280, 768, 8).init(device);
let output_blocks = UNetOutputBlocks {
r1: ResBlockConfig::new(2560, 1280, 1280).init(device),
r2: ResBlockConfig::new(2560, 1280, 1280).init(device),
ru: ResUpSampleConfig::new(2560, 1280, 1280).init(device),
rt1: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(device),
rt2: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(device),
rtu1: ResTransformerUpsampleConfig::new(1920, 1280, 1280, 768, 8).init(device),
rt3: ResTransformerConfig::new(1920, 1280, 640, 768, 8).init(device),
rt4: ResTransformerConfig::new(1280, 1280, 640, 768, 8).init(device),
rtu2: ResTransformerUpsampleConfig::new(960, 1280, 640, 768, 8).init(device),
rt5: ResTransformerConfig::new(960, 1280, 320, 768, 8).init(device),
rt6: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(device),
rt7: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(device),
};
let norm_out = GroupNormConfig::new(32, 320).init(device);
let silu_out = SILU::new();
let conv_out = Conv2dConfig::new([320, 4], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init(device);
UNet {
lin1_time_embed,
silu_time_embed,
lin2_time_embed,
input_blocks,
middle_block,
output_blocks,
norm_out,
silu_out,
conv_out,
}
}
}
/// Denoising UNet: a timestep-embedding MLP, an encoder (`input_blocks`), a
/// middle block, a decoder (`output_blocks`) fed by skip connections, and an
/// output head.
#[derive(Module, Debug)]
pub struct UNet<B: Backend> {
    /// First linear layer applied to the sinusoidal timestep embedding.
    lin1_time_embed: nn::Linear<B>,
    /// Activation between the two timestep-embedding layers.
    silu_time_embed: SILU,
    /// Second linear layer of the timestep embedding.
    lin2_time_embed: nn::Linear<B>,
    /// Encoder stages; each stage's output is saved as a skip tensor.
    input_blocks: UNetInputBlocks<B>,
    /// Block applied between encoder and decoder.
    middle_block: ResTransformerRes<B>,
    /// Decoder stages; each consumes one saved skip tensor.
    output_blocks: UNetOutputBlocks<B>,
    /// Normalisation of the output head.
    norm_out: GroupNorm<B>,
    /// Activation of the output head.
    silu_out: SILU,
    /// Final convolution of the output head.
    conv_out: Conv2d<B>,
}

impl<B: Backend> UNet<B> {
    /// Runs the network on `x` for the given `timesteps` and `context`.
    ///
    /// Encoder outputs are pushed onto a stack; every decoder stage pops one and
    /// concatenates it to its input along axis 1 before running.
    pub fn forward(
        &self,
        x: Tensor<B, 4>,
        timesteps: Tensor<B, 1, Int>,
        context: Tensor<B, 3>,
    ) -> Tensor<B, 4> {
        // Timestep embedding: sinusoid -> linear -> SiLU -> linear.
        let emb = timestep_embedding(timesteps, 320, 10000);
        let emb = self.lin2_time_embed.forward(
            self.silu_time_embed
                .forward(self.lin1_time_embed.forward(emb)),
        );

        let mut skips = Vec::new();
        let mut h = x;

        // Encoder.
        for stage in self.input_blocks.as_array() {
            h = stage.forward(h, emb.clone(), context.clone());
            skips.push(h.clone());
        }

        // Middle.
        h = self.middle_block.forward(h, emb.clone(), context.clone());

        // Decoder: skips are consumed in reverse order of production.
        for stage in self.output_blocks.as_array() {
            let joined = Tensor::cat(vec![h, skips.pop().unwrap()], 1);
            h = stage.forward(joined, emb.clone(), context.clone());
        }

        self.conv_out
            .forward(self.silu_out.forward(self.norm_out.forward(h)))
    }
}
/// The twelve encoder stages of the UNet, declared in execution order.
#[derive(Module, Debug)]
pub struct UNetInputBlocks<B: Backend> {
    conv: Conv2d<B>,
    rt1: ResTransformer<B>,
    rt2: ResTransformer<B>,
    d1: Downsample<B>,
    rt3: ResTransformer<B>,
    rt4: ResTransformer<B>,
    d2: Downsample<B>,
    rt5: ResTransformer<B>,
    rt6: ResTransformer<B>,
    d3: Downsample<B>,
    r1: ResBlock<B>,
    r2: ResBlock<B>,
}

impl<B: Backend> UNetInputBlocks<B> {
    /// Returns every stage as a [`UNetBlock`] trait object, in field order.
    fn as_array(&self) -> [&dyn UNetBlock<B>; 12] {
        let stages: [&dyn UNetBlock<B>; 12] = [
            &self.conv,
            &self.rt1,
            &self.rt2,
            &self.d1,
            &self.rt3,
            &self.rt4,
            &self.d2,
            &self.rt5,
            &self.rt6,
            &self.d3,
            &self.r1,
            &self.r2,
        ];
        stages
    }
}
/// The twelve decoder stages of the UNet, declared in execution order.
#[derive(Module, Debug)]
pub struct UNetOutputBlocks<B: Backend> {
    r1: ResBlock<B>,
    r2: ResBlock<B>,
    ru: ResUpSample<B>,
    rt1: ResTransformer<B>,
    rt2: ResTransformer<B>,
    rtu1: ResTransformerUpsample<B>,
    rt3: ResTransformer<B>,
    rt4: ResTransformer<B>,
    rtu2: ResTransformerUpsample<B>,
    rt5: ResTransformer<B>,
    rt6: ResTransformer<B>,
    rt7: ResTransformer<B>,
}

impl<B: Backend> UNetOutputBlocks<B> {
    /// Returns every stage as a [`UNetBlock`] trait object, in field order.
    fn as_array(&self) -> [&dyn UNetBlock<B>; 12] {
        let stages: [&dyn UNetBlock<B>; 12] = [
            &self.r1,
            &self.r2,
            &self.ru,
            &self.rt1,
            &self.rt2,
            &self.rtu1,
            &self.rt3,
            &self.rt4,
            &self.rtu2,
            &self.rt5,
            &self.rt6,
            &self.rt7,
        ];
        stages
    }
}
/// Uniform interface for the UNet stages, so that differently-typed blocks can
/// be stored together as `&dyn UNetBlock<B>` and run in a loop.
trait UNetBlock<B: Backend> {
/// Applies the block to `x`. `emb` is the timestep embedding and `context` the
/// conditioning tensor; implementations that do not need one of them ignore it.
fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4>;
}
/// Configuration for a [`ResTransformer`] (residual block followed by a
/// spatial transformer).
#[derive(Config, Debug)]
pub struct ResTransformerConfig {
    /// Channels entering the residual block.
    n_channels_in: usize,
    /// Width of the embedding fed to the residual block.
    n_channels_embed: usize,
    /// Channels leaving the residual block and used by the transformer.
    n_channels_out: usize,
    /// Feature width of the context given to the transformer.
    n_context_state: usize,
    /// Number of attention heads.
    n_head: usize,
}

impl ResTransformerConfig {
    /// Builds the residual block first, then the transformer, on `device`.
    fn init<B: Backend>(&self, device: &B::Device) -> ResTransformer<B> {
        let res_config = ResBlockConfig::new(
            self.n_channels_in,
            self.n_channels_embed,
            self.n_channels_out,
        );
        let transformer_config =
            SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head);

        ResTransformer {
            res: res_config.init(device),
            transformer: transformer_config.init(device),
        }
    }
}
/// A residual block followed by a spatial transformer.
#[derive(Module, Debug)]
pub struct ResTransformer<B: Backend> {
    res: ResBlock<B>,
    transformer: SpatialTransformer<B>,
}

impl<B: Backend> UNetBlock<B> for ResTransformer<B> {
    /// Applies the residual block (using `emb`), then the transformer (using `context`).
    fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4> {
        let after_res = self.res.forward(x, emb);
        self.transformer.forward(after_res, context)
    }
}
/// Configuration for a [`ResUpSample`] (residual block followed by an upsample).
#[derive(Config, Debug)]
pub struct ResUpSampleConfig {
    /// Channels entering the residual block.
    n_channels_in: usize,
    /// Width of the embedding fed to the residual block.
    n_channels_embed: usize,
    /// Channels leaving the residual block and kept by the upsample.
    n_channels_out: usize,
}

impl ResUpSampleConfig {
    /// Builds the residual block first, then the upsample, on `device`.
    fn init<B: Backend>(&self, device: &B::Device) -> ResUpSample<B> {
        let res_config = ResBlockConfig::new(
            self.n_channels_in,
            self.n_channels_embed,
            self.n_channels_out,
        );

        ResUpSample {
            res: res_config.init(device),
            upsample: UpsampleConfig::new(self.n_channels_out).init(device),
        }
    }
}
/// A residual block followed by an upsample.
#[derive(Module, Debug)]
pub struct ResUpSample<B: Backend> {
    res: ResBlock<B>,
    upsample: Upsample<B>,
}

impl<B: Backend> UNetBlock<B> for ResUpSample<B> {
    /// Applies the residual block (using `emb`), then upsamples; `context` is unused.
    fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, _context: Tensor<B, 3>) -> Tensor<B, 4> {
        let after_res = self.res.forward(x, emb);
        self.upsample.forward(after_res)
    }
}
/// Configuration for a [`ResTransformerUpsample`] (residual block, spatial
/// transformer, then upsample).
#[derive(Config, Debug)]
pub struct ResTransformerUpsampleConfig {
    /// Channels entering the residual block.
    n_channels_in: usize,
    /// Width of the embedding fed to the residual block.
    n_channels_embed: usize,
    /// Channels leaving the residual block; used by the transformer and upsample.
    n_channels_out: usize,
    /// Feature width of the context given to the transformer.
    n_context_state: usize,
    /// Number of attention heads.
    n_head: usize,
}

impl ResTransformerUpsampleConfig {
    /// Builds the residual block, transformer and upsample, in that order, on `device`.
    fn init<B: Backend>(&self, device: &B::Device) -> ResTransformerUpsample<B> {
        let res_config = ResBlockConfig::new(
            self.n_channels_in,
            self.n_channels_embed,
            self.n_channels_out,
        );
        let transformer_config =
            SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head);
        let upsample_config = UpsampleConfig::new(self.n_channels_out);

        ResTransformerUpsample {
            res: res_config.init(device),
            transformer: transformer_config.init(device),
            upsample: upsample_config.init(device),
        }
    }
}
/// A residual block, a spatial transformer and an upsample, applied in that order.
#[derive(Module, Debug)]
pub struct ResTransformerUpsample<B: Backend> {
    res: ResBlock<B>,
    transformer: SpatialTransformer<B>,
    upsample: Upsample<B>,
}

impl<B: Backend> UNetBlock<B> for ResTransformerUpsample<B> {
    /// Residual block (with `emb`) -> transformer (with `context`) -> upsample.
    fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4> {
        let after_res = self.res.forward(x, emb);
        let attended = self.transformer.forward(after_res, context);
        self.upsample.forward(attended)
    }
}
/// Configuration for a [`ResTransformerRes`] (residual block, spatial
/// transformer, residual block).
#[derive(Config, Debug)]
pub struct ResTransformerResConfig {
    /// Channels entering the first residual block.
    n_channels_in: usize,
    /// Width of the embedding fed to both residual blocks.
    n_channels_embed: usize,
    /// Channels produced by the first residual block and kept by every later stage.
    n_channels_out: usize,
    /// Feature width of the context given to the transformer.
    n_context_state: usize,
    /// Number of attention heads.
    n_head: usize,
}

impl ResTransformerResConfig {
    /// Builds the first residual block, the transformer and the second residual
    /// block, in that order, on `device`.
    fn init<B: Backend>(&self, device: &B::Device) -> ResTransformerRes<B> {
        let res1 = ResBlockConfig::new(
            self.n_channels_in,
            self.n_channels_embed,
            self.n_channels_out,
        )
        .init(device);

        let transformer =
            SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head)
                .init(device);

        // `res2` runs on the transformer output, which has `n_channels_out`
        // channels, so its input width must be `n_channels_out` (not
        // `n_channels_in`). The two coincide only when in == out.
        let res2 = ResBlockConfig::new(
            self.n_channels_out,
            self.n_channels_embed,
            self.n_channels_out,
        )
        .init(device);

        ResTransformerRes {
            res1,
            transformer,
            res2,
        }
    }
}
/// A residual block, a spatial transformer and a second residual block.
#[derive(Module, Debug)]
pub struct ResTransformerRes<B: Backend> {
    res1: ResBlock<B>,
    transformer: SpatialTransformer<B>,
    res2: ResBlock<B>,
}

impl<B: Backend> UNetBlock<B> for ResTransformerRes<B> {
    /// `res1` (with `emb`) -> transformer (with `context`) -> `res2` (with `emb`).
    fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4> {
        let after_res1 = self.res1.forward(x, emb.clone());
        let attended = self.transformer.forward(after_res1, context);
        self.res2.forward(attended, emb)
    }
}
/// Configuration for an [`Upsample`] layer.
#[derive(Config, Debug)]
pub struct UpsampleConfig {
    /// Channel count, identical on input and output.
    n_channels: usize,
}

impl UpsampleConfig {
    /// Builds the 3x3, padding-1 convolution that follows the upscaling step.
    fn init<B: Backend>(&self, device: &B::Device) -> Upsample<B> {
        let conv_config = Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3])
            .with_padding(PaddingConfig2d::Explicit(1, 1));

        Upsample {
            conv: conv_config.init(device),
        }
    }
}
/// Doubles the spatial resolution by repeating each pixel, then applies a convolution.
#[derive(Module, Debug)]
pub struct Upsample<B: Backend> {
    conv: Conv2d<B>,
}

impl<B: Backend> Upsample<B> {
    /// Upscales `x` from `[n, c, h, w]` to `[n, c, 2h, 2w]` and convolves it.
    fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let [n_batch, n_channel, height, width] = x.dims();

        // Insert a unit axis after height and after width, repeat each of those
        // axes twice, then fold them back into the spatial axes.
        let expanded = x.reshape([n_batch, n_channel, height, 1, width, 1]);
        let repeated = expanded.repeat(&[1, 1, 1, 2, 1, 2]);
        let upscaled = repeated.reshape([n_batch, n_channel, 2 * height, 2 * width]);

        self.conv.forward(upscaled)
    }
}

impl<B: Backend> UNetBlock<B> for Upsample<B> {
    /// Delegates to [`Upsample::forward`]; the embedding and context are unused.
    fn forward(&self, x: Tensor<B, 4>, _emb: Tensor<B, 2>, _context: Tensor<B, 3>) -> Tensor<B, 4> {
        self.forward(x)
    }
}
/// Configuration for a [`Downsample`] layer.
#[derive(Config, Debug)]
pub struct DownsampleConfig {
    /// Channel count, identical on input and output.
    n_channels: usize,
}

impl DownsampleConfig {
    /// Builds a 3x3 convolution with stride 2 and padding 1.
    fn init<B: Backend>(&self, device: &B::Device) -> Conv2d<B> {
        let channels = [self.n_channels, self.n_channels];
        let config = Conv2dConfig::new(channels, [3, 3])
            .with_stride([2, 2])
            .with_padding(PaddingConfig2d::Explicit(1, 1));
        config.init(device)
    }
}

/// A downsample stage is just a strided convolution.
type Downsample<B> = Conv2d<B>;

impl<B: Backend> UNetBlock<B> for Conv2d<B> {
    /// Applies the convolution; the embedding and context are unused.
    fn forward(&self, x: Tensor<B, 4>, _emb: Tensor<B, 2>, _context: Tensor<B, 3>) -> Tensor<B, 4> {
        self.forward(x)
    }
}
/// Configuration for a [`SpatialTransformer`].
#[derive(Config, Debug)]
pub struct SpatialTransformerConfig {
    /// Channel count of the feature map (input and output).
    n_channels: usize,
    /// Feature width of the context tensor.
    n_context_state: usize,
    /// Number of attention heads.
    n_head: usize,
}

impl SpatialTransformerConfig {
    /// Builds the group norm, the 1x1 input projection, the transformer block and
    /// the 1x1 output projection, in that order, on `device`.
    fn init<B: Backend>(&self, device: &B::Device) -> SpatialTransformer<B> {
        let pointwise = || Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]);

        SpatialTransformer {
            norm: GroupNormConfig::new(32, self.n_channels).init(device),
            proj_in: pointwise().init(device),
            transformer: TransformerBlockConfig::new(
                self.n_channels,
                self.n_context_state,
                self.n_head,
            )
            .init(device),
            proj_out: pointwise().init(device),
        }
    }
}
/// Applies a transformer block over the spatial positions of a feature map,
/// with a residual connection around the whole module.
#[derive(Module, Debug)]
pub struct SpatialTransformer<B: Backend> {
    norm: GroupNorm<B>,
    proj_in: Conv2d<B>,
    transformer: TransformerBlock<B>,
    proj_out: Conv2d<B>,
}

impl<B: Backend> SpatialTransformer<B> {
    /// Normalises and projects `x`, flattens `[n, c, h, w]` into a sequence of
    /// `h * w` tokens with `c` features, runs the transformer against `context`,
    /// restores the original layout, projects, and adds the input back.
    fn forward(&self, x: Tensor<B, 4>, context: Tensor<B, 3>) -> Tensor<B, 4> {
        let [n_batch, n_channel, height, width] = x.dims();
        let residual = x.clone();

        let projected = self.proj_in.forward(self.norm.forward(x));

        // [n, c, h, w] -> [n, h*w, c]
        let tokens = projected
            .reshape([n_batch, n_channel, height * width])
            .swap_dims(1, 2);

        let attended = self.transformer.forward(tokens, context);

        // [n, h*w, c] -> [n, c, h, w]
        let feature_map = attended
            .swap_dims(1, 2)
            .reshape([n_batch, n_channel, height, width]);

        residual + self.proj_out.forward(feature_map)
    }
}
/// Configuration for a [`TransformerBlock`].
#[derive(Config, Debug)]
pub struct TransformerBlockConfig {
    /// Feature width of the token sequence.
    n_state: usize,
    /// Feature width of the context used by the second attention.
    n_context_state: usize,
    /// Number of attention heads.
    n_head: usize,
}

impl TransformerBlockConfig {
    /// Builds the three layer norms, the two attention modules and the MLP
    /// (hidden multiplier 4) on `device`.
    fn init<B: Backend>(&self, device: &B::Device) -> TransformerBlock<B> {
        let layer_norm = || nn::LayerNormConfig::new(self.n_state);

        TransformerBlock {
            norm1: layer_norm().init(device),
            // Keys and values come from the sequence itself.
            attn1: MultiHeadAttentionConfig::new(self.n_state, self.n_state, self.n_head)
                .init(device),
            norm2: layer_norm().init(device),
            // Keys and values come from the context.
            attn2: MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head)
                .init(device),
            norm3: layer_norm().init(device),
            mlp: MLPConfig::new(self.n_state, 4).init(device),
        }
    }
}
/// Pre-norm transformer block: self-attention, context attention and an MLP,
/// each wrapped in a residual connection.
#[derive(Module, Debug)]
pub struct TransformerBlock<B: Backend> {
    norm1: nn::LayerNorm<B>,
    attn1: MultiHeadAttention<B>,
    norm2: nn::LayerNorm<B>,
    attn2: MultiHeadAttention<B>,
    norm3: nn::LayerNorm<B>,
    mlp: MLP<B>,
}

impl<B: Backend> TransformerBlock<B> {
    /// Applies the three residual sub-layers to `x`; `context` feeds the second attention.
    fn forward(&self, x: Tensor<B, 3>, context: Tensor<B, 3>) -> Tensor<B, 3> {
        let self_attended = self.attn1.forward(self.norm1.forward(x.clone()), None);
        let x = x + self_attended;

        let cross_attended = self
            .attn2
            .forward(self.norm2.forward(x.clone()), Some(context));
        let x = x + cross_attended;

        let fed_forward = self.mlp.forward(self.norm3.forward(x.clone()));
        x + fed_forward
    }
}
/// Configuration for an [`MLP`].
#[derive(Config, Debug)]
pub struct MLPConfig {
    /// Feature width on input and output.
    n_state: usize,
    /// Hidden width as a multiple of `n_state`.
    mult: usize,
}

impl MLPConfig {
    /// Builds the GEGLU (`n_state -> n_state * mult`) and the closing linear
    /// layer (`n_state * mult -> n_state`) on `device`.
    pub fn init<B: Backend>(&self, device: &B::Device) -> MLP<B> {
        let hidden_width = self.n_state * self.mult;

        MLP {
            geglu: GEGLUConfig::new(self.n_state, hidden_width).init(device),
            lin: nn::LinearConfig::new(hidden_width, self.n_state).init(device),
        }
    }
}
/// Feed-forward network: a GEGLU followed by a linear layer.
#[derive(Module, Debug)]
pub struct MLP<B: Backend> {
    geglu: GEGLU<B>,
    lin: nn::Linear<B>,
}

impl<B: Backend> MLP<B> {
    /// Applies the GEGLU, then the linear layer.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let gated = self.geglu.forward(x);
        self.lin.forward(gated)
    }
}
/// Configuration for a [`GEGLU`] layer.
#[derive(Config, Debug)]
pub struct GEGLUConfig {
    /// Feature width of the input.
    n_state_in: usize,
    /// Feature width of the output.
    n_state_out: usize,
}

impl GEGLUConfig {
    /// Builds the projection to `2 * n_state_out` features (value half plus gate
    /// half) and the GELU activation.
    fn init<B: Backend>(&self, device: &B::Device) -> GEGLU<B> {
        let projection_width = 2 * self.n_state_out;

        GEGLU {
            proj: nn::LinearConfig::new(self.n_state_in, projection_width).init(device),
            gelu: Gelu::new(),
        }
    }
}
/// Gated linear unit whose gate passes through GELU.
#[derive(Module, Debug)]
pub struct GEGLU<B: Backend> {
    proj: nn::Linear<B>,
    gelu: Gelu,
}

impl<B: Backend> GEGLU<B> {
    /// Projects `x`, splits the last axis in half, and returns
    /// `first_half * gelu(second_half)`.
    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let projected = self.proj.forward(x);
        let [n_batch, n_ctx, n_state] = projected.dims();
        let half = n_state / 2;

        let gate = projected
            .clone()
            .slice([0..n_batch, 0..n_ctx, half..n_state]);
        let value = projected.slice([0..n_batch, 0..n_ctx, 0..half]);

        value * self.gelu.forward(gate)
    }
}
/// Configuration for a [`MultiHeadAttention`] module.
#[derive(Config, Debug)]
pub struct MultiHeadAttentionConfig {
    /// Feature width of the queries and of the output.
    n_state: usize,
    /// Feature width of the tensor that keys and values are computed from.
    n_context_state: usize,
    /// Number of attention heads.
    n_head: usize,
}

impl MultiHeadAttentionConfig {
    /// Builds the bias-free query/key/value projections and the biased output
    /// projection on `device`.
    ///
    /// # Panics
    ///
    /// Panics if `n_state` is not a multiple of `n_head`.
    fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadAttention<B> {
        assert!(
            self.n_state % self.n_head == 0,
            "State size {} must be a multiple of head size {}",
            self.n_state,
            self.n_head
        );

        // Query, key and value all project to `n_state` features without bias.
        let bias_free = |n_in: usize| {
            nn::LinearConfig::new(n_in, self.n_state)
                .with_bias(false)
                .init::<B>(device)
        };

        MultiHeadAttention {
            n_head: self.n_head,
            query: bias_free(self.n_state),
            key: bias_free(self.n_context_state),
            value: bias_free(self.n_context_state),
            out: nn::LinearConfig::new(self.n_state, self.n_state).init(device),
        }
    }
}
/// Multi-head attention usable as self-attention or as attention over a
/// separate context.
#[derive(Module, Debug)]
pub struct MultiHeadAttention<B: Backend> {
    n_head: usize,
    query: nn::Linear<B>,
    key: nn::Linear<B>,
    value: nn::Linear<B>,
    out: nn::Linear<B>,
}

impl<B: Backend> MultiHeadAttention<B> {
    /// Attends from `x` to `context`, or from `x` to itself when `context` is
    /// `None`, and applies the output projection.
    pub fn forward(&self, x: Tensor<B, 3>, context: Option<Tensor<B, 3>>) -> Tensor<B, 3> {
        // Keys and values come from the context when there is one.
        let kv_source = match context {
            Some(ctx) => ctx,
            None => x.clone(),
        };

        let q = self.query.forward(x);
        let k = self.key.forward(kv_source.clone());
        let v = self.value.forward(kv_source);

        self.out
            .forward(qkv_attention(q, k, v, None, self.n_head))
    }
}
/// Configuration for a [`ResBlock`].
#[derive(Config, Debug)]
pub struct ResBlockConfig {
    /// Channels of the incoming feature map.
    n_channels_in: usize,
    /// Width of the embedding vector.
    n_channels_embed: usize,
    /// Channels of the produced feature map.
    n_channels_out: usize,
}

impl ResBlockConfig {
    /// Builds every layer of the block on `device`.
    ///
    /// A 1x1 skip convolution is created only when the channel count changes;
    /// otherwise the input is added to the output unchanged.
    fn init<B: Backend>(&self, device: &B::Device) -> ResBlock<B> {
        let (n_in, n_out) = (self.n_channels_in, self.n_channels_out);

        // 3x3 convolution with padding 1.
        let conv3x3 = |from: usize, to: usize| {
            Conv2dConfig::new([from, to], [3, 3])
                .with_padding(PaddingConfig2d::Explicit(1, 1))
                .init::<B>(device)
        };

        ResBlock {
            norm_in: GroupNormConfig::new(32, n_in).init(device),
            silu_in: SILU::new(),
            conv_in: conv3x3(n_in, n_out),
            silu_embed: SILU::new(),
            lin_embed: nn::LinearConfig::new(self.n_channels_embed, n_out).init(device),
            norm_out: GroupNormConfig::new(32, n_out).init(device),
            silu_out: SILU::new(),
            conv_out: conv3x3(n_out, n_out),
            skip_connection: (n_in != n_out)
                .then(|| Conv2dConfig::new([n_in, n_out], [1, 1]).init(device)),
        }
    }
}
/// Residual block conditioned on an embedding vector.
#[derive(Module, Debug)]
pub struct ResBlock<B: Backend> {
    norm_in: GroupNorm<B>,
    silu_in: SILU,
    conv_in: Conv2d<B>,
    silu_embed: SILU,
    lin_embed: nn::Linear<B>,
    norm_out: GroupNorm<B>,
    silu_out: SILU,
    conv_out: Conv2d<B>,
    /// Present only when the input must be projected to match the output.
    skip_connection: Option<Conv2d<B>>,
}

impl<B: Backend> ResBlock<B> {
    /// Runs norm -> SiLU -> conv on `x`, adds the projected embedding broadcast
    /// over the spatial axes, runs norm -> SiLU -> conv again, and adds the
    /// (optionally projected) input.
    fn forward(&self, x: Tensor<B, 4>, embed: Tensor<B, 2>) -> Tensor<B, 4> {
        let h = self
            .conv_in
            .forward(self.silu_in.forward(self.norm_in.forward(x.clone())));

        let embed_out = self.lin_embed.forward(self.silu_embed.forward(embed));
        let [n_batch_embed, n_state_embed] = embed_out.dims();

        // Trailing unit axes let the embedding broadcast over height and width.
        let h = h + embed_out.reshape([n_batch_embed, n_state_embed, 1, 1]);

        let h = self
            .conv_out
            .forward(self.silu_out.forward(self.norm_out.forward(h)));

        let residual = match self.skip_connection.as_ref() {
            Some(projection) => projection.forward(x),
            None => x,
        };

        residual + h
    }
}

impl<B: Backend> UNetBlock<B> for ResBlock<B> {
    /// Delegates to [`ResBlock::forward`]; the context is unused.
    fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, _context: Tensor<B, 3>) -> Tensor<B, 4> {
        self.forward(x, emb)
    }
}

54
src/token.rs Normal file
View File

@@ -0,0 +1,54 @@
use std::result;
use rust_tokenizers::{error::TokenizerError, tokenizer::{Tokenizer, SentencePieceBpeTokenizer, TruncationStrategy}, vocab::Vocab};
/// Token id that `LlamaTokenizer::encode` prepends when `include_bos` is set.
const BOS_TOKEN_ID: i64 = 1;
/// Token id that `LlamaTokenizer::encode` appends when `include_eos` is set.
const EOS_TOKEN_ID: i64 = 2;
/// Result alias for this module; the error type is always `TokenizerError`.
pub type Result<T> = result::Result<T, TokenizerError>;
/// Thin wrapper around a SentencePiece BPE tokenizer that can add BOS/EOS ids.
pub struct LlamaTokenizer {
    spm: SentencePieceBpeTokenizer,
}

impl LlamaTokenizer {
    /// Loads the tokenizer model from `tokenizer_path` without lower-casing.
    pub fn new(tokenizer_path: &str) -> Result<Self> {
        let spm = SentencePieceBpeTokenizer::from_file(tokenizer_path, false)?;
        Ok(Self { spm })
    }

    /// Encodes `text` into token ids, optionally wrapped in the BOS and EOS ids.
    pub fn encode(&self, text: &str, include_bos: bool, include_eos: bool) -> Vec<i64> {
        // No second sequence, no length limit, no stride.
        let encoded = self.spm.encode(
            text,
            None,
            usize::MAX,
            &TruncationStrategy::LongestFirst,
            0,
        );

        let mut ids = Vec::with_capacity(encoded.token_ids.len() + 2);
        if include_bos {
            ids.push(BOS_TOKEN_ID);
        }
        ids.extend(encoded.token_ids);
        if include_eos {
            ids.push(EOS_TOKEN_ID);
        }
        ids
    }

    /// Decodes `tokens` back into text without cleaning up spaces.
    pub fn decode(&self, tokens: &[i64], skip_special_tokens: bool) -> String {
        self.spm.decode(tokens, skip_special_tokens, false)
    }

    /// Returns the number of vocabulary entries, adding the number of special
    /// entries when `include_special_tokens` is set.
    pub fn vocab_size(&self, include_special_tokens: bool) -> usize {
        let vocab = self.spm.vocab();
        let regular = vocab.values().len();

        if include_special_tokens {
            regular + vocab.special_values().len()
        } else {
            regular
        }
    }
}

222
src/tokenizer.rs Normal file
View File

@@ -0,0 +1,222 @@
use regex::Regex;
use std::collections::HashMap;
use std::fs::File;
use std::io::{self, BufRead};
/// Returns a reversible mapping from every byte value to a `char`.
///
/// Bytes in `'!'..='~'`, `0xA1..=0xAC` and `0xAE..=0xFF` map to the character
/// with the same code point and come first, in ascending order. Every other
/// byte follows, also ascending, mapped to consecutive code points starting at
/// 256.
fn bytes_to_unicode() -> Vec<(u8, char)> {
    let keeps_own_char = |b: u8| matches!(b, b'!'..=b'~' | 0xA1..=0xAC | 0xAE..=0xFF);

    let mut table: Vec<(u8, char)> = (0u8..=255u8)
        .filter(|&b| keeps_own_char(b))
        .map(|b| (b, char::from(b)))
        .collect();

    let shifted = (0u8..=255u8)
        .filter(|&b| !keeps_own_char(b))
        .zip(256u32..)
        .map(|(b, code)| (b, char::from_u32(code).unwrap()));
    table.extend(shifted);

    table
}
/// Returns every pair of adjacent symbols in `word`, in order.
///
/// A word with fewer than two symbols yields an empty vector.
fn get_pairs(word: &[String]) -> Vec<(String, String)> {
    word.windows(2)
        .map(|pair| (pair[0].clone(), pair[1].clone()))
        .collect()
}
/// Collapses every run of whitespace in `text` into a single space and drops
/// leading and trailing whitespace.
fn whitespace_clean(text: &str) -> String {
    let mut cleaned = String::with_capacity(text.len());
    for piece in text.split_whitespace() {
        // Pieces are never empty, so a non-empty buffer means a piece precedes this one.
        if !cleaned.is_empty() {
            cleaned.push(' ');
        }
        cleaned.push_str(piece);
    }
    cleaned
}
/// Reads a merges file and returns the first two whitespace-separated fields
/// of every line as a pair, in file order.
///
/// Lines with fewer than two fields are skipped; any fields after the second
/// are ignored.
///
/// # Errors
/// Returns the I/O error raised while opening the file or reading a line.
fn load_merges(path: &str) -> io::Result<Vec<(String, String)>> {
    io::BufReader::new(File::open(path)?)
        .lines()
        .filter_map(|line| match line {
            Err(err) => Some(Err(err)),
            Ok(text) => {
                let mut fields = text.split_whitespace();
                let left = fields.next()?;
                let right = fields.next()?;
                Some(Ok((left.to_string(), right.to_string())))
            }
        })
        .collect()
}
/// Assembles the vocabulary, in id order.
///
/// The result contains, in this order: each character of `chars` as a
/// one-character string; each character again with `</w>` appended; the
/// concatenation of every pair in `merges`; and finally the two special
/// tokens `<|startoftext|>` and `<|endoftext|>`.
fn construct_vocab(
    chars: impl Iterator<Item = char> + Clone,
    merges: &[(String, String)],
) -> Vec<String> {
    let mut vocab: Vec<String> = Vec::new();
    vocab.extend(chars.clone().map(|c| c.to_string()));
    vocab.extend(chars.map(|c| format!("{}</w>", c)));
    vocab.extend(
        merges
            .iter()
            .map(|(left, right)| [left.as_str(), right.as_str()].concat()),
    );
    vocab.push("<|startoftext|>".to_string());
    vocab.push("<|endoftext|>".to_string());
    vocab
}
/// Byte-level BPE tokenizer whose tables are built by `SimpleTokenizer::new`
/// from `bytes_to_unicode` and a merges file.
pub struct SimpleTokenizer {
// Byte -> stand-in character, as produced by `bytes_to_unicode`.
byte_encoder: HashMap<u8, char>,
// Inverse of `byte_encoder`: stand-in character -> byte.
byte_decoder: HashMap<char, u8>,
// Vocabulary entry -> token id.
encoder: HashMap<String, u32>,
// Token id -> vocabulary entry (inverse of `encoder`).
decoder: HashMap<u32, String>,
// Merge pair -> rank; `bpe` always applies the lowest-ranked pair first.
bpe_ranks: HashMap<(String, String), u32>,
// Pre-computed `bpe` results. Seeded in `new` with the two special tokens
// mapped to themselves; `bpe` only reads it and never inserts.
cache: HashMap<String, String>,
// Regex used by `encode` to split cleaned text into pieces before BPE.
pat: Regex,
}
impl SimpleTokenizer {
    /// Builds the tokenizer from `bpe_simple_vocab_16e6.txt`, read from the
    /// current working directory.
    ///
    /// Merge entries `1..48895` of the file are kept (48 894 merges), so the
    /// vocabulary holds 256 byte symbols + 256 `</w>` symbols + 48 894 merges
    /// + 2 special tokens = 49 408 entries. Token ids follow the order
    /// produced by `construct_vocab`.
    ///
    /// # Errors
    /// Returns the I/O error from reading the merges file, or an
    /// `ErrorKind::InvalidData` error when the file has fewer merge entries
    /// than required (previously this case panicked on an out-of-range slice).
    pub fn new() -> io::Result<Self> {
        const MERGES_PATH: &str = "bpe_simple_vocab_16e6.txt";
        // Exclusive end of the range of file entries that is kept.
        const MERGES_END: usize = 49152 - 256 - 2 + 1;

        let byte_unicode_values = bytes_to_unicode();
        let byte_encoder: HashMap<u8, char> = byte_unicode_values.iter().copied().collect();
        let byte_decoder: HashMap<char, u8> = byte_encoder
            .iter()
            .map(|(&byte, &ch)| (ch, byte))
            .collect();

        let all_merges = load_merges(MERGES_PATH)?;
        if all_merges.len() < MERGES_END {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "{} contains {} merge entries, expected at least {}",
                    MERGES_PATH,
                    all_merges.len(),
                    MERGES_END
                ),
            ));
        }
        // Entry 0 is dropped - presumably the file's version header parsed as
        // a pair; confirm against the data file.
        let merges: Vec<(String, String)> = all_merges
            .into_iter()
            .skip(1)
            .take(MERGES_END - 1)
            .collect();

        let vocab = construct_vocab(byte_unicode_values.into_iter().map(|(_, ch)| ch), &merges);
        let encoder: HashMap<String, u32> = vocab.into_iter().zip(0u32..).collect();
        let decoder: HashMap<u32, String> = encoder
            .iter()
            .map(|(entry, &id)| (id, entry.clone()))
            .collect();
        // A merge's rank is its position in the kept list; lower ranks merge first.
        let bpe_ranks: HashMap<(String, String), u32> = merges.into_iter().zip(0u32..).collect();

        // Special tokens map to themselves so `bpe` never splits them.
        let cache = HashMap::from([
            ("<|startoftext|>".to_string(), "<|startoftext|>".to_string()),
            ("<|endoftext|>".to_string(), "<|endoftext|>".to_string()),
        ]);
        let pat = Regex::new(r"(?i)<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|\p{L}+|\p{N}|[^\s\p{L}\p{N}]+")
            .expect("tokenizer pattern is a valid regex");

        Ok(SimpleTokenizer {
            byte_encoder,
            byte_decoder,
            encoder,
            decoder,
            bpe_ranks,
            cache,
            pat,
        })
    }

    /// Applies the BPE merges to a single byte-encoded `token`.
    ///
    /// Returns the resulting symbols joined by single spaces; the last symbol
    /// carries the `</w>` end-of-word marker. Tokens present in the cache
    /// (the two special tokens) are returned unchanged.
    pub fn bpe(&self, token: &str) -> String {
        if let Some(cached) = self.cache.get(token) {
            return cached.clone();
        }

        // Start from one symbol per character, marking the final one as word end.
        let mut word: Vec<String> = token.chars().map(|c| c.to_string()).collect();
        if let Some(last) = word.last_mut() {
            last.push_str("</w>");
        }

        let mut pairs = get_pairs(&word);
        if pairs.is_empty() {
            // Zero or one character: nothing can be merged.
            return format!("{}</w>", token);
        }

        loop {
            // Pick the adjacent pair with the lowest merge rank, if any is known.
            let best = pairs
                .iter()
                .filter_map(|pair| self.bpe_ranks.get(pair).map(|&rank| (rank, pair)))
                .min_by_key(|&(rank, _)| rank);
            let (first, second) = match best {
                Some((_, pair)) => pair,
                None => break,
            };

            // Left-to-right pass replacing each `first second` occurrence by
            // the merged symbol; every other symbol is carried over unchanged.
            let mut merged = Vec::with_capacity(word.len());
            let mut symbols = word.into_iter().peekable();
            while let Some(symbol) = symbols.next() {
                let partner = if symbol == *first {
                    symbols.next_if(|next| next == second)
                } else {
                    None
                };
                merged.push(match partner {
                    Some(next) => symbol + &next,
                    None => symbol,
                });
            }
            word = merged;

            if word.len() == 1 {
                break;
            }
            pairs = get_pairs(&word);
        }

        word.join(" ")
    }

    /// Converts `text` into token ids.
    ///
    /// The text is trimmed, whitespace-normalised and lower-cased, split into
    /// pieces by the tokenizer regex, byte-encoded, and run through `bpe`;
    /// each resulting symbol is looked up in the vocabulary.
    ///
    /// # Panics
    /// Panics if a symbol produced by `bpe` is missing from the vocabulary.
    pub fn encode(&self, text: &str) -> Vec<u32> {
        let cleaned_text = whitespace_clean(text.trim()).to_lowercase();
        let mut bpe_tokens: Vec<u32> = Vec::new();
        for piece in self.pat.find_iter(&cleaned_text) {
            // Re-express the piece's UTF-8 bytes as their stand-in characters.
            let token: String = piece
                .as_str()
                .bytes()
                .map(|byte| self.byte_encoder[&byte])
                .collect();
            bpe_tokens.extend(
                self.bpe(&token)
                    .split(' ')
                    .map(|bpe_token| self.encoder[bpe_token]),
            );
        }
        bpe_tokens
    }

    /// Converts token ids back into text.
    ///
    /// Vocabulary entries are concatenated, mapped back to bytes, decoded as
    /// UTF-8 (lossily), and every `</w>` marker is replaced by a space, so the
    /// output can contain extra spaces compared with the original input.
    ///
    /// # Panics
    /// Panics if a token id is not in the vocabulary, or if a vocabulary entry
    /// contains a character that has no byte mapping.
    pub fn decode(&self, tokens: &[u32]) -> String {
        let text: String = tokens.iter().map(|id| self.decoder[id].as_str()).collect();
        let decoded_bytes: Vec<u8> = text.chars().map(|ch| self.byte_decoder[&ch]).collect();
        String::from_utf8_lossy(&decoded_bytes).replace("</w>", " ")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes a prompt containing special tokens, checks the ids, then
    /// decodes them and checks the text. Needs `bpe_simple_vocab_16e6.txt`
    /// in the working directory.
    #[test]
    fn test_encode_decode() {
        let tokenizer = SimpleTokenizer::new().unwrap();

        let prompt = "Hello world! <|startoftext|>asdf<|startoftext|>";
        let expected_ids: [u32; 7] = [3306, 1002, 256, 49406, 587, 10468, 49406];
        // Decoding turns each end-of-word marker into a space, so the text
        // gains spaces that were not in the prompt.
        let expected_text = "hello world ! <|startoftext|>asdf <|startoftext|>";

        let ids = tokenizer.encode(prompt);
        assert_eq!(&expected_ids[..], &ids[..]);

        let round_trip = tokenizer.decode(&ids);
        assert_eq!(expected_text, round_trip);
    }
}