Add files via upload
Add initial project files
This commit is contained in:
16
Cargo.toml
Normal file
16
Cargo.toml
Normal file
@@ -0,0 +1,16 @@
|
||||
[package]
|
||||
name = "stablediffusion"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
burn = "0.8.0"
|
||||
burn-tch = "0.8.0"
|
||||
serde = {version = "1.0.171", features = ["std", "derive"]}
|
||||
npy = "0.4.0"
|
||||
num-traits = "0.2.15"
|
||||
rust_tokenizers = "8.1.0"
|
||||
regex = "1.9.1"
|
||||
image = "0.24.6"
|
||||
262145
bpe_simple_vocab_16e6.txt
Normal file
262145
bpe_simple_vocab_16e6.txt
Normal file
File diff suppressed because it is too large
Load Diff
92
python/autoencoder.py
Normal file
92
python/autoencoder.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import pathlib
|
||||
import save
|
||||
from save import *
|
||||
|
||||
from tinygrad.nn import Conv2d
|
||||
|
||||
def save_resnet_block(resnet_block, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
save_group_norm(resnet_block.norm1, pathlib.Path(path, 'norm1'))
|
||||
save_conv2d(resnet_block.conv1, pathlib.Path(path, 'conv1'))
|
||||
save_group_norm(resnet_block.norm2, pathlib.Path(path, 'norm2'))
|
||||
save_conv2d(resnet_block.conv2, pathlib.Path(path, 'conv2'))
|
||||
|
||||
if isinstance(resnet_block.nin_shortcut, Conv2d):
|
||||
save_conv2d(resnet_block.nin_shortcut, pathlib.Path(path, 'nin_shortcut'))
|
||||
|
||||
def save_attn_block(attn_block, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
save_group_norm(attn_block.norm, pathlib.Path(path, 'norm'))
|
||||
save_conv2d(attn_block.q, pathlib.Path(path, 'q'))
|
||||
save_conv2d(attn_block.k, pathlib.Path(path, 'k'))
|
||||
save_conv2d(attn_block.v, pathlib.Path(path, 'v'))
|
||||
save_conv2d(attn_block.proj_out, pathlib.Path(path, 'proj_out'))
|
||||
|
||||
def save_mid(mid, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
save_resnet_block(mid.block_1, pathlib.Path(path, 'block_1'))
|
||||
save_attn_block(mid.attn_1, pathlib.Path(path, 'attn'))
|
||||
save_resnet_block(mid.block_2, pathlib.Path(path, 'block_2'))
|
||||
|
||||
def save_decoder_block(decoder_block, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if 'block' in decoder_block:
|
||||
save_resnet_block(decoder_block["block"][0], pathlib.Path(path, 'res1'))
|
||||
save_resnet_block(decoder_block["block"][1], pathlib.Path(path, 'res2'))
|
||||
save_resnet_block(decoder_block["block"][2], pathlib.Path(path, 'res3'))
|
||||
|
||||
if 'upsample' in decoder_block:
|
||||
save_conv2d(decoder_block['upsample']['conv'], pathlib.Path(path, 'upsampler'))
|
||||
|
||||
|
||||
def save_decoder(decoder, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
save_conv2d(decoder.conv_in, pathlib.Path(path, 'conv_in'))
|
||||
save_mid(decoder.mid, pathlib.Path(path, 'mid'))
|
||||
|
||||
for i, block in enumerate(decoder.up[::-1]):
|
||||
print(i)
|
||||
if isinstance(block['block'][0].nin_shortcut, Conv2d):
|
||||
print(block['block'][0].nin_shortcut.weight.shape)
|
||||
save_decoder_block(block, pathlib.Path(path, f'blocks/{i}'))
|
||||
|
||||
save_scalar(len(decoder.up), "n_block", path)
|
||||
save_group_norm(decoder.norm_out, pathlib.Path(path, 'norm_out'))
|
||||
save_conv2d(decoder.conv_out, pathlib.Path(path, 'conv_out'))
|
||||
|
||||
def save_encoder_block(encoder_block, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if 'block' in encoder_block:
|
||||
save_resnet_block(encoder_block["block"][0], pathlib.Path(path, 'res1'))
|
||||
save_resnet_block(encoder_block["block"][1], pathlib.Path(path, 'res2'))
|
||||
|
||||
if 'downsample' in encoder_block:
|
||||
save_padded_conv2d(encoder_block['downsample']['conv'], pathlib.Path(path, 'downsampler'))
|
||||
|
||||
def save_encoder(encoder, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
save_conv2d(encoder.conv_in, pathlib.Path(path, 'conv_in'))
|
||||
save_mid(encoder.mid, pathlib.Path(path, 'mid'))
|
||||
|
||||
for i, block in enumerate(encoder.down):
|
||||
save_encoder_block(block, pathlib.Path(path, f'blocks/{i}'))
|
||||
|
||||
save_scalar(len(encoder.down), "n_block", path)
|
||||
save_group_norm(encoder.norm_out, pathlib.Path(path, 'norm_out'))
|
||||
save_conv2d(encoder.conv_out, pathlib.Path(path, 'conv_out'))
|
||||
|
||||
|
||||
def save_autoencoder(autoencoder, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
save_encoder(autoencoder.encoder, pathlib.Path(path, 'encoder'))
|
||||
save_decoder(autoencoder.decoder, pathlib.Path(path, 'decoder'))
|
||||
save_conv2d(autoencoder.quant_conv, pathlib.Path(path, 'quant_conv'))
|
||||
save_conv2d(autoencoder.post_quant_conv, pathlib.Path(path, 'post_quant_conv'))
|
||||
BIN
python/bpe_simple_vocab_16e6.txt.gz
Normal file
BIN
python/bpe_simple_vocab_16e6.txt.gz
Normal file
Binary file not shown.
40
python/clip.py
Normal file
40
python/clip.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import pathlib
|
||||
import save
|
||||
from save import *
|
||||
|
||||
def save_clipmlp(clip_mlp, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
save_linear(clip_mlp.fc1, pathlib.Path(path, 'fc1'))
|
||||
save_linear(clip_mlp.fc2, pathlib.Path(path, 'fc2'))
|
||||
|
||||
def save_clip_attention(clip_attention, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
save_linear(clip_attention.k_proj, pathlib.Path(path, 'key'))
|
||||
save_linear(clip_attention.v_proj, pathlib.Path(path, 'value'))
|
||||
save_linear(clip_attention.q_proj, pathlib.Path(path, 'query'))
|
||||
save_linear(clip_attention.out_proj, pathlib.Path(path, 'out'))
|
||||
save_scalar(clip_attention.num_heads, 'n_head', path)
|
||||
|
||||
def save_clip_encoder_layer(clip_encoder_layer, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
save_clip_attention(clip_encoder_layer.self_attn, pathlib.Path(path, 'attn'))
|
||||
save_layer_norm(clip_encoder_layer.layer_norm1, pathlib.Path(path, 'attn_ln'))
|
||||
save_clipmlp(clip_encoder_layer.mlp, pathlib.Path(path, 'mlp'))
|
||||
save_layer_norm(clip_encoder_layer.layer_norm2, pathlib.Path(path, 'mlp_ln'))
|
||||
|
||||
def save_clip_encoder(clip_encoder, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
for i, layer in enumerate(clip_encoder.layers):
|
||||
save_clip_encoder_layer(layer, pathlib.Path(path, f'blocks/{i}'))
|
||||
save_scalar(len(clip_encoder.layers), "n_layer", path)
|
||||
|
||||
def save_clip_text_embeddings(clip_text_embeddings, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
save_embedding(clip_text_embeddings.token_embedding, pathlib.Path(path, 'token_embedding'))
|
||||
save_embedding(clip_text_embeddings.position_embedding, pathlib.Path(path, 'position_embedding'))
|
||||
|
||||
def save_clip_text_transformer(clip_text_transformer, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
save_clip_text_embeddings(clip_text_transformer.embeddings, path)
|
||||
save_clip_encoder(clip_text_transformer.encoder, path)
|
||||
save_layer_norm(clip_text_transformer.final_layer_norm, pathlib.Path(path, 'layer_norm'))
|
||||
100
python/save.py
Normal file
100
python/save.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import pathlib
|
||||
import numpy as np
|
||||
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
def save_scalar(s, name, path):
|
||||
s = np.array([1.0, float(s)]).astype(np.float32)
|
||||
np.save(pathlib.Path(path, f'{name}.npy'), s)
|
||||
|
||||
def save_tensor(tensor, name, path):
|
||||
tensor_numpy = tensor.numpy()
|
||||
tensor_dims = np.array(tensor_numpy.shape)
|
||||
tensor_values = tensor_numpy.flatten()
|
||||
tensor_to_save = np.concatenate((tensor_dims, tensor_values)).astype(np.float32)
|
||||
np.save(pathlib.Path(path, f'{name}.npy'), tensor_to_save)
|
||||
|
||||
def save_linear(linear, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
save_tensor(linear.weight.transpose(), 'weight', path) # PyTorch and Tinygrad strangely transpose linear weights so reverse that
|
||||
if linear.bias is not None:
|
||||
save_tensor(linear.bias, 'bias', path)
|
||||
|
||||
def save_layer_norm(layer_norm, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
save_tensor(layer_norm.weight, 'weight', path)
|
||||
save_tensor(layer_norm.bias, 'bias', path)
|
||||
save_scalar(layer_norm.eps, 'eps', path)
|
||||
|
||||
def save_group_norm(layer_norm, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
if layer_norm.weight is not None:
|
||||
save_tensor(layer_norm.weight, 'weight', path)
|
||||
if layer_norm.bias is not None:
|
||||
save_tensor(layer_norm.bias, 'bias', path)
|
||||
save_scalar(layer_norm.eps, 'eps', path)
|
||||
save_scalar(layer_norm.num_groups, 'n_group', path)
|
||||
save_scalar(layer_norm.num_channels, 'n_channel', path)
|
||||
|
||||
def to_tuple_tensor(val):
|
||||
if isinstance(val, tuple):
|
||||
# Convert tuple to Tensor
|
||||
if len(val) == 1:
|
||||
return Tensor([val[0], val[0]])
|
||||
elif len(val) == 2:
|
||||
return Tensor([val[0], val[1]])
|
||||
else:
|
||||
raise ValueError('Tuple should be of length 1 or 2 only.')
|
||||
else:
|
||||
# Treat as scalar and convert to Tensor
|
||||
return Tensor([val, val])
|
||||
|
||||
def save_conv2d(conv2d, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
save_tensor(conv2d.weight, 'weight', path)
|
||||
if conv2d.bias is not None:
|
||||
save_tensor(conv2d.bias, 'bias', path)
|
||||
save_tensor(to_tuple_tensor(conv2d.stride), 'stride', path)
|
||||
save_tensor(to_tuple_tensor(conv2d.padding), 'padding', path)
|
||||
save_tensor(to_tuple_tensor(conv2d.dilation), 'dilation', path)
|
||||
save_scalar(conv2d.groups, "n_group", path)
|
||||
save_tensor(to_tuple_tensor(conv2d.kernel_size), 'kernel_size', path)
|
||||
|
||||
assert conv2d.groups == 1
|
||||
in_channels = conv2d.weight.shape[1]
|
||||
out_channels = conv2d.weight.shape[0]
|
||||
save_scalar(in_channels, "n_channels_in", path)
|
||||
save_scalar(out_channels, "n_channels_out", path)
|
||||
|
||||
def save_padded_conv2d(padded_conv2d, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Store conv2d layer weights
|
||||
orig_padding = padded_conv2d.padding
|
||||
padded_conv2d.padding = (0, 0)
|
||||
save_conv2d(padded_conv2d, f"{path}/conv")
|
||||
padded_conv2d.padding = orig_padding
|
||||
|
||||
# Dimensions: in-channels and out-channels
|
||||
assert padded_conv2d.groups == 1
|
||||
channels = (padded_conv2d.weight.shape[1], padded_conv2d.weight.shape[0])
|
||||
save_tensor(to_tuple_tensor(channels), 'channels', path)
|
||||
|
||||
assert len(padded_conv2d.kernel_size) == 1 or padded_conv2d.kernel_size[0] == padded_conv2d.kernel_size[1]
|
||||
save_scalar(padded_conv2d.kernel_size[0], 'kernel_size', path)
|
||||
|
||||
# Stride
|
||||
assert not isinstance(padded_conv2d.stride, tuple) or len(padded_conv2d.stride) == 1
|
||||
save_scalar(padded_conv2d.stride, 'stride', path)
|
||||
|
||||
# Padding
|
||||
padding = [padded_conv2d.padding[0], padded_conv2d.padding[1],
|
||||
padded_conv2d.padding[2], padded_conv2d.padding[3]]
|
||||
save_tensor(Tensor(padding), 'padding', path)
|
||||
|
||||
|
||||
def save_embedding(embedding, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
save_tensor(embedding.weight, 'weight', path)
|
||||
|
||||
736
python/stable_diffusion.py
Normal file
736
python/stable_diffusion.py
Normal file
@@ -0,0 +1,736 @@
|
||||
# https://arxiv.org/pdf/2112.10752.pdf
|
||||
# https://github.com/ekagra-ranjan/huggingface-blog/blob/main/stable_diffusion.md
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
import gzip, argparse, math, re
|
||||
from functools import lru_cache
|
||||
from collections import namedtuple
|
||||
|
||||
from tqdm import tqdm
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import dtypes, GlobalCounters
|
||||
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
|
||||
from extra.utils import download_file
|
||||
from tinygrad.state import torch_load, load_state_dict
|
||||
|
||||
# TODO: refactor AttnBlock, CrossAttention, CLIPAttention to share code
|
||||
|
||||
class AttnBlock:
|
||||
def __init__(self, in_channels):
|
||||
self.norm = GroupNorm(32, in_channels)
|
||||
self.q = Conv2d(in_channels, in_channels, 1)
|
||||
self.k = Conv2d(in_channels, in_channels, 1)
|
||||
self.v = Conv2d(in_channels, in_channels, 1)
|
||||
self.proj_out = Conv2d(in_channels, in_channels, 1)
|
||||
|
||||
# copied from AttnBlock in ldm repo
|
||||
def __call__(self, x):
|
||||
h_ = self.norm(x)
|
||||
q,k,v = self.q(h_), self.k(h_), self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b,c,h,w = q.shape
|
||||
q = q.reshape(b,c,h*w)
|
||||
q = q.permute(0,2,1) # b,hw,c
|
||||
k = k.reshape(b,c,h*w) # b,c,hw
|
||||
w_ = q @ k
|
||||
w_ = w_ * (c**(-0.5))
|
||||
w_ = w_.softmax()
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b,c,h*w)
|
||||
w_ = w_.permute(0,2,1)
|
||||
h_ = v @ w_
|
||||
h_ = h_.reshape(b,c,h,w)
|
||||
|
||||
return x + self.proj_out(h_)
|
||||
|
||||
class ResnetBlock:
|
||||
def __init__(self, in_channels, out_channels=None):
|
||||
self.norm1 = GroupNorm(32, in_channels)
|
||||
self.conv1 = Conv2d(in_channels, out_channels, 3, padding=1)
|
||||
self.norm2 = GroupNorm(32, out_channels)
|
||||
self.conv2 = Conv2d(out_channels, out_channels, 3, padding=1)
|
||||
self.nin_shortcut = Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else lambda x: x
|
||||
|
||||
def __call__(self, x):
|
||||
h = self.conv1(self.norm1(x).swish())
|
||||
'''v = h
|
||||
print(v.shape)
|
||||
print(v[0, 0:10, :, :].numpy())'''
|
||||
h = self.conv2(self.norm2(h).swish())
|
||||
return self.nin_shortcut(x) + h
|
||||
|
||||
class Mid:
|
||||
def __init__(self, block_in):
|
||||
self.block_1 = ResnetBlock(block_in, block_in)
|
||||
self.attn_1 = AttnBlock(block_in)
|
||||
self.block_2 = ResnetBlock(block_in, block_in)
|
||||
|
||||
def __call__(self, x):
|
||||
return x.sequential([self.block_1, self.attn_1, self.block_2])
|
||||
|
||||
class Decoder:
|
||||
def __init__(self):
|
||||
sz = [(128, 256), (256, 512), (512, 512), (512, 512)]
|
||||
self.conv_in = Conv2d(4,512,3, padding=1)
|
||||
self.mid = Mid(512)
|
||||
|
||||
arr = []
|
||||
for i,s in enumerate(sz):
|
||||
arr.append({"block":
|
||||
[ResnetBlock(s[1], s[0]),
|
||||
ResnetBlock(s[0], s[0]),
|
||||
ResnetBlock(s[0], s[0])]})
|
||||
if i != 0: arr[-1]['upsample'] = {"conv": Conv2d(s[0], s[0], 3, padding=1)}
|
||||
self.up = arr
|
||||
|
||||
self.norm_out = GroupNorm(32, 128)
|
||||
self.conv_out = Conv2d(128, 3, 3, padding=1)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.conv_in(x)
|
||||
x = self.mid(x)
|
||||
|
||||
for l in self.up[::-1]:
|
||||
print("decode", x.shape)
|
||||
for b in l['block']:
|
||||
x = b(x)
|
||||
if 'upsample' in l:
|
||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html ?
|
||||
bs,c,py,px = x.shape
|
||||
x = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2)
|
||||
x = l['upsample']['conv'](x)
|
||||
x.realize()
|
||||
|
||||
return self.conv_out(self.norm_out(x).swish())
|
||||
|
||||
class Encoder:
|
||||
def __init__(self):
|
||||
sz = [(128, 128), (128, 256), (256, 512), (512, 512)]
|
||||
self.conv_in = Conv2d(3,128,3, padding=1)
|
||||
|
||||
arr = []
|
||||
for i,s in enumerate(sz):
|
||||
arr.append({"block":
|
||||
[ResnetBlock(s[0], s[1]),
|
||||
ResnetBlock(s[1], s[1])]})
|
||||
if i != 3: arr[-1]['downsample'] = {"conv": Conv2d(s[1], s[1], 3, stride=2, padding=(0,1,0,1))}
|
||||
self.down = arr
|
||||
|
||||
self.mid = Mid(512)
|
||||
self.norm_out = GroupNorm(32, 512)
|
||||
self.conv_out = Conv2d(512, 8, 3, padding=1)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.conv_in(x)
|
||||
|
||||
for i, l in enumerate(self.down):
|
||||
print("encode", x.shape)
|
||||
for b in l['block']: x = b(x)
|
||||
if 'downsample' in l: x = l['downsample']['conv'](x)
|
||||
|
||||
x = self.mid(x)
|
||||
return self.conv_out(self.norm_out(x).swish())
|
||||
|
||||
class AutoencoderKL:
|
||||
def __init__(self):
|
||||
self.encoder = Encoder()
|
||||
self.decoder = Decoder()
|
||||
self.quant_conv = Conv2d(8, 8, 1)
|
||||
self.post_quant_conv = Conv2d(4, 4, 1)
|
||||
|
||||
def __call__(self, x):
|
||||
latent = self.encoder(x)
|
||||
latent = self.quant_conv(latent)
|
||||
latent = latent[:, 0:4] # only the means
|
||||
print("latent", latent.shape)
|
||||
latent = self.post_quant_conv(latent)
|
||||
return self.decoder(latent)
|
||||
|
||||
# not to be confused with ResnetBlock
|
||||
class ResBlock:
|
||||
def __init__(self, channels, emb_channels, out_channels):
|
||||
self.in_layers = [
|
||||
GroupNorm(32, channels),
|
||||
Tensor.silu,
|
||||
Conv2d(channels, out_channels, 3, padding=1)
|
||||
]
|
||||
self.emb_layers = [
|
||||
Tensor.silu,
|
||||
Linear(emb_channels, out_channels)
|
||||
]
|
||||
self.out_layers = [
|
||||
GroupNorm(32, out_channels),
|
||||
Tensor.silu,
|
||||
lambda x: x, # needed for weights loading code to work
|
||||
Conv2d(out_channels, out_channels, 3, padding=1)
|
||||
]
|
||||
self.skip_connection = Conv2d(channels, out_channels, 1) if channels != out_channels else lambda x: x
|
||||
|
||||
def __call__(self, x, emb):
|
||||
h = x.sequential(self.in_layers)
|
||||
emb_out = emb.sequential(self.emb_layers)
|
||||
h = h + emb_out.reshape(*emb_out.shape, 1, 1)
|
||||
h = h.sequential(self.out_layers)
|
||||
ret = self.skip_connection(x) + h
|
||||
return ret
|
||||
|
||||
class CrossAttention:
|
||||
def __init__(self, query_dim, context_dim, n_heads, d_head):
|
||||
self.to_q = Linear(query_dim, n_heads*d_head, bias=False)
|
||||
self.to_k = Linear(context_dim, n_heads*d_head, bias=False)
|
||||
self.to_v = Linear(context_dim, n_heads*d_head, bias=False)
|
||||
self.scale = d_head ** -0.5
|
||||
self.num_heads = n_heads
|
||||
self.head_size = d_head
|
||||
self.to_out = [Linear(n_heads*d_head, query_dim)]
|
||||
|
||||
def __call__(self, x, context=None):
|
||||
context = x if context is None else context
|
||||
q,k,v = self.to_q(x), self.to_k(context), self.to_v(context)
|
||||
q = q.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,1,3) # (bs, num_heads, time, head_size)
|
||||
k = k.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,3,1) # (bs, num_heads, head_size, time)
|
||||
v = v.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,1,3) # (bs, num_heads, time, head_size)
|
||||
|
||||
score = q.dot(k) * self.scale
|
||||
weights = score.softmax() # (bs, num_heads, time, time)
|
||||
attention = weights.dot(v).permute(0,2,1,3) # (bs, time, num_heads, head_size)
|
||||
|
||||
h_ = attention.reshape(shape=(x.shape[0], -1, self.num_heads * self.head_size))
|
||||
return h_.sequential(self.to_out)
|
||||
|
||||
class GEGLU:
|
||||
def __init__(self, dim_in, dim_out):
|
||||
self.proj = Linear(dim_in, dim_out * 2)
|
||||
self.dim_out = dim_out
|
||||
|
||||
def __call__(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * gate.gelu()
|
||||
|
||||
class FeedForward:
|
||||
def __init__(self, dim, mult=4):
|
||||
self.net = [
|
||||
GEGLU(dim, dim*mult),
|
||||
lambda x: x, # needed for weights loading code to work
|
||||
Linear(dim*mult, dim)
|
||||
]
|
||||
|
||||
def __call__(self, x):
|
||||
return x.sequential(self.net)
|
||||
|
||||
class BasicTransformerBlock:
|
||||
def __init__(self, dim, context_dim, n_heads, d_head):
|
||||
self.attn1 = CrossAttention(dim, dim, n_heads, d_head)
|
||||
self.ff = FeedForward(dim)
|
||||
self.attn2 = CrossAttention(dim, context_dim, n_heads, d_head)
|
||||
self.norm1 = LayerNorm(dim)
|
||||
self.norm2 = LayerNorm(dim)
|
||||
self.norm3 = LayerNorm(dim)
|
||||
|
||||
def __call__(self, x, context=None):
|
||||
x = self.attn1(self.norm1(x)) + x
|
||||
x = self.attn2(self.norm2(x), context=context) + x
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
return x
|
||||
|
||||
class SpatialTransformer:
|
||||
def __init__(self, channels, context_dim, n_heads, d_head):
|
||||
self.norm = GroupNorm(32, channels)
|
||||
assert channels == n_heads * d_head
|
||||
self.proj_in = Conv2d(channels, n_heads * d_head, 1)
|
||||
self.transformer_blocks = [BasicTransformerBlock(channels, context_dim, n_heads, d_head)]
|
||||
self.proj_out = Conv2d(n_heads * d_head, channels, 1)
|
||||
|
||||
def __call__(self, x, context=None):
|
||||
b, c, h, w = x.shape
|
||||
x_in = x
|
||||
x = self.norm(x)
|
||||
x = self.proj_in(x)
|
||||
x = x.reshape(b, c, h*w).permute(0,2,1)
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, context=context)
|
||||
x = x.permute(0,2,1).reshape(b, c, h, w)
|
||||
ret = self.proj_out(x) + x_in
|
||||
return ret
|
||||
|
||||
class Downsample:
|
||||
def __init__(self, channels):
|
||||
self.op = Conv2d(channels, channels, 3, stride=2, padding=1)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.op(x)
|
||||
|
||||
class Upsample:
|
||||
def __init__(self, channels):
|
||||
self.conv = Conv2d(channels, channels, 3, padding=1)
|
||||
|
||||
def __call__(self, x):
|
||||
bs,c,py,px = x.shape
|
||||
x = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2)
|
||||
return self.conv(x)
|
||||
|
||||
def timestep_embedding(timesteps, dim, max_period=10000):
|
||||
half = dim // 2
|
||||
freqs = (-math.log(max_period) * Tensor.arange(half) / half).exp()
|
||||
args = timesteps * freqs
|
||||
return Tensor.cat(args.cos(), args.sin()).reshape(1, -1)
|
||||
|
||||
class UNetModel:
|
||||
def __init__(self):
|
||||
self.time_embed = [
|
||||
Linear(320, 1280),
|
||||
Tensor.silu,
|
||||
Linear(1280, 1280),
|
||||
]
|
||||
self.input_blocks = [
|
||||
[Conv2d(4, 320, kernel_size=3, padding=1)],
|
||||
[ResBlock(320, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
|
||||
[ResBlock(320, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
|
||||
[Downsample(320)],
|
||||
[ResBlock(320, 1280, 640), SpatialTransformer(640, 768, 8, 80)],
|
||||
[ResBlock(640, 1280, 640), SpatialTransformer(640, 768, 8, 80)],
|
||||
[Downsample(640)],
|
||||
[ResBlock(640, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
|
||||
[ResBlock(1280, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
|
||||
[Downsample(1280)],
|
||||
[ResBlock(1280, 1280, 1280)],
|
||||
[ResBlock(1280, 1280, 1280)]
|
||||
]
|
||||
self.middle_block = [
|
||||
ResBlock(1280, 1280, 1280),
|
||||
SpatialTransformer(1280, 768, 8, 160),
|
||||
ResBlock(1280, 1280, 1280)
|
||||
]
|
||||
self.output_blocks = [
|
||||
[ResBlock(2560, 1280, 1280)],
|
||||
[ResBlock(2560, 1280, 1280)],
|
||||
[ResBlock(2560, 1280, 1280), Upsample(1280)],
|
||||
[ResBlock(2560, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
|
||||
[ResBlock(2560, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
|
||||
[ResBlock(1920, 1280, 1280), SpatialTransformer(1280, 768, 8, 160), Upsample(1280)],
|
||||
[ResBlock(1920, 1280, 640), SpatialTransformer(640, 768, 8, 80)], # 6
|
||||
[ResBlock(1280, 1280, 640), SpatialTransformer(640, 768, 8, 80)],
|
||||
[ResBlock(960, 1280, 640), SpatialTransformer(640, 768, 8, 80), Upsample(640)],
|
||||
[ResBlock(960, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
|
||||
[ResBlock(640, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
|
||||
[ResBlock(640, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
|
||||
]
|
||||
self.out = [
|
||||
GroupNorm(32, 320),
|
||||
Tensor.silu,
|
||||
Conv2d(320, 4, kernel_size=3, padding=1)
|
||||
]
|
||||
|
||||
def __call__(self, x, timesteps=None, context=None):
|
||||
# TODO: real time embedding
|
||||
t_emb = timestep_embedding(timesteps, 320)
|
||||
emb = t_emb.sequential(self.time_embed)
|
||||
|
||||
|
||||
|
||||
def run(x, bb):
|
||||
if isinstance(bb, ResBlock): x = bb(x, emb)
|
||||
elif isinstance(bb, SpatialTransformer): x = bb(x, context)
|
||||
else: x = bb(x)
|
||||
return x
|
||||
|
||||
saved_inputs = []
|
||||
for i,b in enumerate(self.input_blocks):
|
||||
#print("input block", i)
|
||||
print(x.numpy())
|
||||
for bb in b:
|
||||
x = run(x, bb)
|
||||
saved_inputs.append(x)
|
||||
for bb in self.middle_block:
|
||||
x = run(x, bb)
|
||||
for i,b in enumerate(self.output_blocks):
|
||||
#print("output block", i)
|
||||
x = x.cat(saved_inputs.pop(), dim=1)
|
||||
for bb in b:
|
||||
x = run(x, bb)
|
||||
return x.sequential(self.out)
|
||||
|
||||
class CLIPMLP:
|
||||
def __init__(self):
|
||||
self.fc1 = Linear(768, 3072)
|
||||
self.fc2 = Linear(3072, 768)
|
||||
|
||||
def __call__(self, hidden_states):
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = hidden_states.quick_gelu()
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
class CLIPAttention:
|
||||
def __init__(self):
|
||||
self.embed_dim = 768
|
||||
self.num_heads = 12
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.k_proj = Linear(self.embed_dim, self.embed_dim)
|
||||
self.v_proj = Linear(self.embed_dim, self.embed_dim)
|
||||
self.q_proj = Linear(self.embed_dim, self.embed_dim)
|
||||
self.out_proj = Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
def _shape(self, tensor, seq_len: int, bsz: int):
|
||||
return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).permute(0,2,1,3)
|
||||
|
||||
def __call__(self, hidden_states, causal_attention_mask):
|
||||
bsz, tgt_len, embed_dim = hidden_states.shape
|
||||
|
||||
query_states = self.q_proj(hidden_states) * self.scale
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).reshape(*proj_shape)
|
||||
key_states = key_states.reshape(*proj_shape)
|
||||
src_len = key_states.shape[1]
|
||||
value_states = value_states.reshape(*proj_shape)
|
||||
|
||||
attn_weights = query_states @ key_states.permute(0,2,1)
|
||||
|
||||
attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
|
||||
attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
attn_weights = attn_weights.softmax()
|
||||
|
||||
attn_output = attn_weights @ value_states
|
||||
|
||||
attn_output = attn_output.reshape(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||
attn_output = attn_output.permute(0,2,1,3)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
class CLIPEncoderLayer:
|
||||
def __init__(self):
|
||||
self.self_attn = CLIPAttention()
|
||||
self.layer_norm1 = LayerNorm(768)
|
||||
self.mlp = CLIPMLP()
|
||||
self.layer_norm2 = LayerNorm(768)
|
||||
|
||||
def __call__(self, hidden_states, causal_attention_mask):
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states = self.self_attn(hidden_states, causal_attention_mask)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
class CLIPEncoder:
|
||||
def __init__(self):
|
||||
self.layers = [CLIPEncoderLayer() for i in range(12)]
|
||||
|
||||
def __call__(self, hidden_states, causal_attention_mask):
|
||||
for l in self.layers:
|
||||
hidden_states = l(hidden_states, causal_attention_mask)
|
||||
return hidden_states
|
||||
|
||||
class CLIPTextEmbeddings:
|
||||
def __init__(self):
|
||||
self.token_embedding = Embedding(49408, 768)
|
||||
self.position_embedding = Embedding(77, 768)
|
||||
|
||||
def __call__(self, input_ids, position_ids):
|
||||
return self.token_embedding(input_ids) + self.position_embedding(position_ids)
|
||||
|
||||
class CLIPTextTransformer:
|
||||
def __init__(self):
|
||||
self.embeddings = CLIPTextEmbeddings()
|
||||
self.encoder = CLIPEncoder()
|
||||
self.final_layer_norm = LayerNorm(768)
|
||||
|
||||
def __call__(self, input_ids):
|
||||
seq_len = input_ids.shape[1]
|
||||
x = self.embeddings(input_ids, Tensor.arange(seq_len).reshape(1, -1))
|
||||
mask = Tensor.full((1, 1, seq_len, seq_len), float("-inf")).triu(1)
|
||||
x = self.encoder(x, mask)
|
||||
return self.final_layer_norm(x)
|
||||
|
||||
# Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
|
||||
@lru_cache()
|
||||
def default_bpe():
|
||||
return Path(__file__).parent.parent / "weights/bpe_simple_vocab_16e6.txt.gz"
|
||||
|
||||
def get_pairs(word):
|
||||
"""Return set of symbol pairs in a word.
|
||||
Word is represented as tuple of symbols (symbols being variable-length strings).
|
||||
"""
|
||||
pairs = set()
|
||||
prev_char = word[0]
|
||||
for char in word[1:]:
|
||||
pairs.add((prev_char, char))
|
||||
prev_char = char
|
||||
return pairs
|
||||
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
def bytes_to_unicode():
|
||||
"""
|
||||
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||
The reversible bpe codes work on unicode strings.
|
||||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||
"""
|
||||
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2**8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2**8+n)
|
||||
n += 1
|
||||
cs = [chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
|
||||
|
||||
class ClipTokenizer:
|
||||
def __init__(self, bpe_path: str = default_bpe()):
|
||||
self.byte_encoder = bytes_to_unicode()
|
||||
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
||||
merges = merges[1:49152-256-2+1]
|
||||
merges = [tuple(merge.split()) for merge in merges]
|
||||
vocab = list(bytes_to_unicode().values())
|
||||
vocab = vocab + [v+'</w>' for v in vocab]
|
||||
for merge in merges:
|
||||
vocab.append(''.join(merge))
|
||||
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
||||
self.encoder = dict(zip(vocab, range(len(vocab))))
|
||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
||||
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE)
|
||||
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
||||
pairs = get_pairs(word)
|
||||
|
||||
if not pairs:
|
||||
return token+'</w>'
|
||||
|
||||
while True:
|
||||
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
new_word = []
|
||||
i = 0
|
||||
while i < len(word):
|
||||
try:
|
||||
j = word.index(first, i)
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
except Exception:
|
||||
new_word.extend(word[i:])
|
||||
break
|
||||
|
||||
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
||||
new_word.append(first+second)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
i += 1
|
||||
new_word = tuple(new_word)
|
||||
word = new_word
|
||||
if len(word) == 1:
|
||||
break
|
||||
pairs = get_pairs(word)
|
||||
word = ' '.join(word)
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def encode(self, text):
|
||||
bpe_tokens = []
|
||||
text = whitespace_clean(text.strip()).lower()
|
||||
for token in re.findall(self.pat, text):
|
||||
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
||||
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
||||
# Truncation, keeping two slots for start and end tokens.
|
||||
if len(bpe_tokens) > 75:
|
||||
bpe_tokens = bpe_tokens[:75]
|
||||
return [49406] + bpe_tokens + [49407] * (77 - len(bpe_tokens) - 1)
|
||||
|
||||
class StableDiffusion:
|
||||
def __init__(self):
|
||||
self.alphas_cumprod = Tensor.empty(1000)
|
||||
self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel())
|
||||
self.first_stage_model = AutoencoderKL()
|
||||
self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = CLIPTextTransformer()))
|
||||
|
||||
# TODO: make __call__ run the model
|
||||
|
||||
# ** ldm.models.autoencoder.AutoencoderKL (done!)
|
||||
# 3x512x512 <--> 4x64x64 (16384)
|
||||
# decode torch.Size([1, 4, 64, 64]) torch.Size([1, 3, 512, 512])
|
||||
# section 4.3 of paper
|
||||
# first_stage_model.encoder, first_stage_model.decoder
|
||||
|
||||
# ** ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
# this is what runs each time to sample. is this the LDM?
|
||||
# input: 4x64x64
|
||||
# output: 4x64x64
|
||||
# model.diffusion_model
|
||||
# it has attention?
|
||||
|
||||
# ** ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
# cond_stage_model.transformer.text_model
|
||||
|
||||
# this is sd-v1-4.ckpt
|
||||
FILENAME = Path(__file__).parent.parent / "weights/sd-v1-4.ckpt"
|
||||
|
||||
import clip as clipsave
|
||||
import autoencoder as autoencodersave
|
||||
import unet as unetsave
|
||||
import stablediffusion as sdsave
|
||||
|
||||
import numpy as np
|
||||
|
||||
if __name__ == "__main__":
|
||||
Tensor.no_grad = True
|
||||
'''clip = CLIPTextTransformer()
|
||||
|
||||
print('Saving model...')
|
||||
clipsave.save_clip_text_transformer(clip, "params")
|
||||
|
||||
input = Tensor([3, 1])
|
||||
output = clip(input.unsqueeze(0))
|
||||
|
||||
print(output[0, 0:2, 0:10].numpy())'''
|
||||
|
||||
'''autoencoder = AutoencoderKL()
|
||||
print('Saving model...')
|
||||
autoencodersave.save_autoencoder(autoencoder, "params")
|
||||
input = Tensor.zeros((1, 3, 10, 10))
|
||||
output = autoencoder(input)
|
||||
print(output.shape)
|
||||
print(output.numpy())'''
|
||||
|
||||
'''unet = UNetModel()
|
||||
print('Saving model...')
|
||||
unetsave.save_unet_model(unet, 'params')
|
||||
input = Tensor.zeros([1, 4, 64, 64])
|
||||
|
||||
context = np.array([0.5, 1.3], dtype=np.float32) # specify dtype when defining the array
|
||||
context = np.repeat(context, 768 // 2)
|
||||
context = np.expand_dims(context, axis=0)
|
||||
context = Tensor(context)
|
||||
|
||||
timesteps = Tensor([1.0])
|
||||
|
||||
output = unet(input, timesteps, context)
|
||||
#print(output.numpy())'''
|
||||
|
||||
|
||||
Tensor.no_grad = True
|
||||
model = StableDiffusion()
|
||||
|
||||
# load in weights
|
||||
download_file('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', FILENAME)
|
||||
load_state_dict(model, torch_load(FILENAME)['state_dict'], strict=False)
|
||||
|
||||
sdsave.save_stable_diffusion(model, "params")
|
||||
|
||||
|
||||
'''parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--steps', type=int, default=5, help="Number of steps in diffusion")
|
||||
parser.add_argument('--prompt', type=str, default="a horse sized cat eating a bagel", help="Phrase to render")
|
||||
parser.add_argument('--out', type=str, default=os.path.join(tempfile.gettempdir(), "rendered.png"), help="Output filename")
|
||||
args = parser.parse_args()
|
||||
|
||||
Tensor.no_grad = True
|
||||
model = StableDiffusion()
|
||||
|
||||
# load in weights
|
||||
download_file('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', FILENAME)
|
||||
load_state_dict(model, torch_load(FILENAME)['state_dict'], strict=False)
|
||||
|
||||
# run through CLIP to get context
|
||||
tokenizer = ClipTokenizer()
|
||||
prompt = Tensor([tokenizer.encode(args.prompt)])
|
||||
context = model.cond_stage_model.transformer.text_model(prompt).realize()
|
||||
print("got CLIP context", context.shape)
|
||||
|
||||
prompt = Tensor([tokenizer.encode("")])
|
||||
unconditional_context = model.cond_stage_model.transformer.text_model(prompt).realize()
|
||||
print("got unconditional CLIP context", unconditional_context.shape)
|
||||
|
||||
# done with clip model
|
||||
del model.cond_stage_model
|
||||
|
||||
def get_model_output(latent, timestep):
|
||||
# put into diffuser
|
||||
latents = model.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep.expand(2, *timestep.shape[1:]), unconditional_context.cat(context, dim=0))
|
||||
unconditional_latent, latent = latents[0:1], latents[1:2]
|
||||
|
||||
unconditional_guidance_scale = 7.5
|
||||
e_t = unconditional_latent + unconditional_guidance_scale * (latent - unconditional_latent)
|
||||
return e_t
|
||||
|
||||
timesteps = list(range(1, 1000, 1000//args.steps))
|
||||
print(f"running for {timesteps} timesteps")
|
||||
alphas = [model.alphas_cumprod.numpy()[t] for t in timesteps]
|
||||
alphas_prev = [1.0] + alphas[:-1]
|
||||
|
||||
def get_x_prev_and_pred_x0(x, e_t, index):
|
||||
temperature = 1
|
||||
a_t, a_prev = alphas[index], alphas_prev[index]
|
||||
sigma_t = 0
|
||||
sqrt_one_minus_at = math.sqrt(1-a_t)
|
||||
#print(a_t, a_prev, sigma_t, sqrt_one_minus_at)
|
||||
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / math.sqrt(a_t)
|
||||
|
||||
# direction pointing to x_t
|
||||
dir_xt = math.sqrt(1. - a_prev - sigma_t**2) * e_t
|
||||
noise = sigma_t * Tensor.randn(*x.shape) * temperature
|
||||
|
||||
x_prev = math.sqrt(a_prev) * pred_x0 + dir_xt #+ noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
# start with random noise
|
||||
latent = Tensor.randn(1,4,64,64)
|
||||
|
||||
# this is diffusion
|
||||
for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
|
||||
GlobalCounters.reset()
|
||||
t.set_description("%3d %3d" % (index, timestep))
|
||||
e_t = get_model_output(latent, Tensor([timestep]))
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t, index)
|
||||
#e_t_next = get_model_output(x_prev)
|
||||
#e_t_prime = (e_t + e_t_next) / 2
|
||||
#x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index)
|
||||
latent = x_prev
|
||||
latent.realize()
|
||||
|
||||
# upsample latent space to image with autoencoder
|
||||
x = model.first_stage_model.post_quant_conv(1/0.18215 * latent)
|
||||
x = model.first_stage_model.decoder(x)
|
||||
|
||||
# make image correct size and scale
|
||||
x = (x + 1.0) / 2.0
|
||||
x = (x.reshape(3,512,512).permute(1,2,0).clip(0,1)*255).cast(dtypes.uint8)
|
||||
print(x.shape)
|
||||
|
||||
# save image
|
||||
from PIL import Image
|
||||
im = Image.fromarray(x.numpy())
|
||||
print(f"saving {args.out}")
|
||||
im.save(args.out)
|
||||
# Open image.
|
||||
im.show()'''
|
||||
12
python/stablediffusion.py
Normal file
12
python/stablediffusion.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from autoencoder import save_autoencoder
|
||||
from unet import save_unet_model
|
||||
from clip import save_clip_text_transformer
|
||||
|
||||
from save import save_scalar, save_tensor
|
||||
|
||||
def save_stable_diffusion(stable_diffusion, path):
|
||||
save_scalar(stable_diffusion.alphas_cumprod.shape[0], "n_steps", path)
|
||||
save_tensor(stable_diffusion.alphas_cumprod, 'alphas_cumprod', path)
|
||||
save_autoencoder(stable_diffusion.autoencoder, 'autoencoder', path)
|
||||
save_unet_model(stable_diffusion.diffusion, 'unet', path)
|
||||
save_clip_text_transformer(stable_diffusion.clip, 'clip', path)
|
||||
54
python/test.py
Normal file
54
python/test.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
import math
|
||||
|
||||
'''import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import torch
|
||||
|
||||
norm = torch.nn.LayerNorm(3)
|
||||
|
||||
tensor = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape((2, 3))
|
||||
|
||||
out = norm(tensor)
|
||||
|
||||
print(out)'''
|
||||
|
||||
'''n_channel = 6
|
||||
norm = nn.LayerNorm(10)
|
||||
|
||||
height = 10
|
||||
width = 10
|
||||
n_elements = height * width * n_channel
|
||||
|
||||
t = torch.arange(0, n_elements, dtype=torch.float32).mul_(10.0 / n_elements).sin().reshape(1, n_channel, height, width)
|
||||
|
||||
out = norm(t)
|
||||
print(out)'''
|
||||
|
||||
def timestep_embedding(timesteps, dim, max_period=10000):
|
||||
half = dim // 2
|
||||
freqs = (-math.log(max_period) * torch.arange(half) / half).exp()
|
||||
args = timesteps * freqs
|
||||
return torch.cat( (args.cos(), args.sin()) ).reshape(1, -1)
|
||||
|
||||
timesteps = Tensor([1, 2, 3]).reshape((3, 1))
|
||||
dim = 10
|
||||
res = timestep_embedding(timesteps, dim)
|
||||
|
||||
print(res)
|
||||
|
||||
'''n_group = 3
|
||||
n_channel = 6
|
||||
norm = nn.GroupNorm(n_group, n_channel)
|
||||
|
||||
height = 10
|
||||
width = 10
|
||||
n_elements = height * width * n_channel
|
||||
|
||||
t = torch.arange(0, n_elements, dtype=torch.float32).mul_(10.0 / n_elements).sin().reshape(1, n_channel, height, width)
|
||||
|
||||
out = norm(t)
|
||||
print(out.flatten())'''
|
||||
41
python/test_tiny.py
Normal file
41
python/test_tiny.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
|
||||
import math
|
||||
|
||||
'''import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import torch
|
||||
|
||||
norm = torch.nn.LayerNorm(3)
|
||||
|
||||
tensor = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape((2, 3))
|
||||
|
||||
out = norm(tensor)
|
||||
|
||||
print(out)'''
|
||||
|
||||
n_channel = 6
|
||||
norm = LayerNorm(10)
|
||||
|
||||
height = 10
|
||||
width = 10
|
||||
n_elements = height * width * n_channel
|
||||
|
||||
t = Tensor.arange(n_elements).mul(10.0 / n_elements).sin().reshape(1, n_channel, height, width)
|
||||
|
||||
out = norm(t)
|
||||
print(out.numpy())
|
||||
|
||||
'''n_group = 3
|
||||
n_channel = 6
|
||||
norm = nn.GroupNorm(n_group, n_channel)
|
||||
|
||||
height = 10
|
||||
width = 10
|
||||
n_elements = height * width * n_channel
|
||||
|
||||
t = torch.arange(0, n_elements, dtype=torch.float32).mul_(10.0 / n_elements).sin().reshape(1, n_channel, height, width)
|
||||
|
||||
out = norm(t)
|
||||
print(out.flatten())'''
|
||||
145
python/tokenizer.py
Normal file
145
python/tokenizer.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import gzip
|
||||
import html
|
||||
import os
|
||||
from functools import lru_cache
|
||||
|
||||
import ftfy
|
||||
import regex as re
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def default_bpe():
|
||||
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def bytes_to_unicode():
|
||||
"""
|
||||
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||
The reversible bpe codes work on unicode strings.
|
||||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||
"""
|
||||
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2**8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2**8+n)
|
||||
n += 1
|
||||
cs = [chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
|
||||
|
||||
|
||||
def get_pairs(word):
|
||||
"""Return set of symbol pairs in a word.
|
||||
Word is represented as tuple of symbols (symbols being variable-length strings).
|
||||
"""
|
||||
pairs = set()
|
||||
prev_char = word[0]
|
||||
for char in word[1:]:
|
||||
pairs.add((prev_char, char))
|
||||
prev_char = char
|
||||
return pairs
|
||||
|
||||
|
||||
def basic_clean(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
class SimpleTokenizer(object):
|
||||
def __init__(self, bpe_path: str = default_bpe()):
|
||||
self.byte_encoder = bytes_to_unicode()
|
||||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
||||
merges = merges[1:49152-256-2+1]
|
||||
merges = [tuple(merge.split()) for merge in merges]
|
||||
vocab = list(bytes_to_unicode().values())
|
||||
vocab = vocab + [v+'</w>' for v in vocab]
|
||||
for merge in merges:
|
||||
vocab.append(''.join(merge))
|
||||
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
||||
self.encoder = dict(zip(vocab, range(len(vocab))))
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|startoftext|>'}
|
||||
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
||||
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
||||
pairs = get_pairs(word)
|
||||
|
||||
if not pairs:
|
||||
return token+'</w>'
|
||||
|
||||
while True:
|
||||
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
new_word = []
|
||||
i = 0
|
||||
while i < len(word):
|
||||
try:
|
||||
j = word.index(first, i)
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
except:
|
||||
new_word.extend(word[i:])
|
||||
break
|
||||
|
||||
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
||||
new_word.append(first+second)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
i += 1
|
||||
new_word = tuple(new_word)
|
||||
word = new_word
|
||||
if len(word) == 1:
|
||||
break
|
||||
else:
|
||||
pairs = get_pairs(word)
|
||||
word = ' '.join(word)
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def encode(self, text):
|
||||
bpe_tokens = []
|
||||
text = whitespace_clean(basic_clean(text)).lower()
|
||||
for token in re.findall(self.pat, text):
|
||||
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
||||
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
||||
return bpe_tokens
|
||||
|
||||
def decode(self, tokens):
|
||||
text = ''.join([self.decoder[token] for token in tokens])
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
||||
return text
|
||||
|
||||
if __name__ == "__main__":
|
||||
simple_tokenizer = SimpleTokenizer()
|
||||
#print([simple_tokenizer.byte_encoder[0], simple_tokenizer.byte_encoder[1]])
|
||||
tokens = [0, 1, 59, 67, 23]
|
||||
text = simple_tokenizer.decode(tokens)
|
||||
print(f"Text: {text}")
|
||||
|
||||
text = "Hello world! <|startoftext|>asdf<|startoftext|>"
|
||||
encoded = simple_tokenizer.encode(text)
|
||||
decoded = simple_tokenizer.decode(encoded)
|
||||
print(f"Encoded: {encoded}")
|
||||
print(f"Decoded: {decoded}")
|
||||
153
python/unet.py
Normal file
153
python/unet.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import pathlib
|
||||
import os
|
||||
import save
|
||||
from save import *
|
||||
|
||||
from tinygrad.nn import Conv2d
|
||||
|
||||
def save_res_block(res_block, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
# We can't directly save activation functions, but as they are just attribute of the block,
|
||||
# we don't need to save them separately, they will be recreated along with the block.
|
||||
|
||||
# saving group normalization layer
|
||||
save_group_norm(res_block.in_layers[0], os.path.join(path, 'norm_in'))
|
||||
|
||||
# saving the convolutional layer
|
||||
save_conv2d(res_block.in_layers[2], os.path.join(path, 'conv_in'))
|
||||
|
||||
# saving the linear layer
|
||||
save_linear(res_block.emb_layers[1], os.path.join(path, 'lin_embed'))
|
||||
|
||||
# saving group normalization in out_layers
|
||||
save_group_norm(res_block.out_layers[0], os.path.join(path, 'norm_out'))
|
||||
|
||||
# saving the convolutional layer in out_layers
|
||||
save_conv2d(res_block.out_layers[3], os.path.join(path, 'conv_out'))
|
||||
|
||||
# save skip_connection based on the object type
|
||||
if isinstance(res_block.skip_connection, Conv2d):
|
||||
save_conv2d(res_block.skip_connection, os.path.join(path, 'skip_connection'))
|
||||
|
||||
def save_cross_attention(cross_attention, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save Linear layers
|
||||
save_linear(cross_attention.to_q, os.path.join(path, 'query'))
|
||||
save_linear(cross_attention.to_k, os.path.join(path, 'key'))
|
||||
save_linear(cross_attention.to_v, os.path.join(path, 'value'))
|
||||
save_linear(cross_attention.to_out[0], os.path.join(path, 'out'))
|
||||
|
||||
# Save parameters
|
||||
save_scalar(cross_attention.num_heads, 'n_head', path)
|
||||
|
||||
def save_geglu(geglu, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save Linear layers
|
||||
save_linear(geglu.proj, os.path.join(path, 'proj'))
|
||||
|
||||
def save_feed_forward(feed_forward, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save GEGLU module
|
||||
save_geglu(feed_forward.net[0], os.path.join(path, 'geglu'))
|
||||
|
||||
# Save Linear layer
|
||||
save_linear(feed_forward.net[2], os.path.join(path, 'lin'))
|
||||
|
||||
|
||||
def save_basic_transformer_block(basic_transformer_block, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save CrossAttention, FeedForward and LayerNorm instances
|
||||
save_cross_attention(basic_transformer_block.attn1, os.path.join(path, 'attn1'))
|
||||
save_feed_forward(basic_transformer_block.ff, os.path.join(path, 'mlp'))
|
||||
save_cross_attention(basic_transformer_block.attn2, os.path.join(path, 'attn2'))
|
||||
|
||||
save_layer_norm(basic_transformer_block.norm1, os.path.join(path, 'norm1'))
|
||||
save_layer_norm(basic_transformer_block.norm2, os.path.join(path, 'norm2'))
|
||||
save_layer_norm(basic_transformer_block.norm3, os.path.join(path, 'norm3'))
|
||||
|
||||
|
||||
|
||||
def save_spatial_transformer(spatial_transformer, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save GroupNorm, Conv2d, BasicTransformerBlock instances
|
||||
save_group_norm(spatial_transformer.norm, os.path.join(path, 'norm'))
|
||||
save_conv2d(spatial_transformer.proj_in, os.path.join(path, 'proj_in'))
|
||||
save_basic_transformer_block(spatial_transformer.transformer_blocks[0], os.path.join(path, 'transformer'))
|
||||
save_conv2d(spatial_transformer.proj_out, os.path.join(path, 'proj_out'))
|
||||
|
||||
|
||||
def save_downsample(downsample, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save Conv2d instance
|
||||
save_conv2d(downsample.op, path)
|
||||
|
||||
def save_upsample(upsample, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save Conv2d instance
|
||||
save_conv2d(upsample.conv, os.path.join(path, 'conv'))
|
||||
|
||||
|
||||
def save_res_transformer_res(block, path):
|
||||
save_res_block(block[0], pathlib.Path(path, 'res1'))
|
||||
save_spatial_transformer(block[1], pathlib.Path(path, 'transformer'))
|
||||
save_res_block(block[2], pathlib.Path(path, 'res2'))
|
||||
|
||||
def save_res_upsample(block, path):
|
||||
save_res_block(block[0], pathlib.Path(path, 'res'))
|
||||
save_upsample(block[1], pathlib.Path(path, 'upsample'))
|
||||
|
||||
def save_res_transformer(block, path):
|
||||
save_res_block(block[0], pathlib.Path(path, 'res'))
|
||||
save_spatial_transformer(block[1], pathlib.Path(path, 'transformer'))
|
||||
|
||||
def save_res_transformer_upsample(block, path):
|
||||
save_res_block(block[0], pathlib.Path(path, 'res'))
|
||||
save_spatial_transformer(block[1], pathlib.Path(path, 'transformer'))
|
||||
save_upsample(block[2], pathlib.Path(path, 'upsample'))
|
||||
|
||||
|
||||
def save_unet_input_blocks(input_blocks, path):
|
||||
save_conv2d(input_blocks[0][0], pathlib.Path(path, 'conv'))
|
||||
save_res_transformer(input_blocks[1], pathlib.Path(path, 'rt1'))
|
||||
save_res_transformer(input_blocks[2], pathlib.Path(path, 'rt2'))
|
||||
save_downsample(input_blocks[3][0], pathlib.Path(path, 'd1'))
|
||||
save_res_transformer(input_blocks[4], pathlib.Path(path, 'rt3'))
|
||||
save_res_transformer(input_blocks[5], pathlib.Path(path, 'rt4'))
|
||||
save_downsample(input_blocks[6][0], pathlib.Path(path, 'd2'))
|
||||
save_res_transformer(input_blocks[7], pathlib.Path(path, 'rt5'))
|
||||
save_res_transformer(input_blocks[8], pathlib.Path(path, 'rt6'))
|
||||
save_downsample(input_blocks[9][0], pathlib.Path(path, 'd3'))
|
||||
save_res_block(input_blocks[10][0], pathlib.Path(path, 'r1'))
|
||||
save_res_block(input_blocks[11][0], pathlib.Path(path, 'r2'))
|
||||
|
||||
def save_unet_output_blocks(output_blocks, path):
|
||||
save_res_block(output_blocks[0][0], pathlib.Path(path, 'r1'))
|
||||
save_res_block(output_blocks[1][0], pathlib.Path(path, 'r2'))
|
||||
save_res_upsample(output_blocks[2], pathlib.Path(path, 'ru'))
|
||||
save_res_transformer(output_blocks[3], pathlib.Path(path, 'rt1'))
|
||||
save_res_transformer(output_blocks[4], pathlib.Path(path, 'rt2'))
|
||||
save_res_transformer_upsample(output_blocks[5], pathlib.Path(path, 'rtu1'))
|
||||
save_res_transformer(output_blocks[6], pathlib.Path(path, 'rt3'))
|
||||
save_res_transformer(output_blocks[7], pathlib.Path(path, 'rt4'))
|
||||
save_res_transformer_upsample(output_blocks[8], pathlib.Path(path, 'rtu2'))
|
||||
save_res_transformer(output_blocks[9], pathlib.Path(path, 'rt5'))
|
||||
save_res_transformer(output_blocks[10], pathlib.Path(path, 'rt6'))
|
||||
save_res_transformer(output_blocks[11], pathlib.Path(path, 'rt7'))
|
||||
|
||||
def save_unet_model(model, path):
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
save_linear(model.time_embed[0], pathlib.Path(path, 'lin1_time_embed'))
|
||||
save_linear(model.time_embed[2], pathlib.Path(path, 'lin2_time_embed'))
|
||||
save_unet_input_blocks(model.input_blocks, pathlib.Path(path, 'input_blocks'))
|
||||
save_res_transformer_res(model.middle_block, pathlib.Path(path, 'middle_block'))
|
||||
save_unet_output_blocks(model.output_blocks, pathlib.Path(path, 'output_blocks'))
|
||||
save_group_norm(model.out[0], pathlib.Path(path, 'norm_out'))
|
||||
save_conv2d(model.out[2], pathlib.Path(path, 'conv_out'))
|
||||
|
||||
87
src/helper.rs
Normal file
87
src/helper.rs
Normal file
@@ -0,0 +1,87 @@
|
||||
use burn::{
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
activation::relu,
|
||||
Tensor,
|
||||
Int,
|
||||
Bool,
|
||||
Float,
|
||||
TensorKind,
|
||||
BasicOps,
|
||||
Numeric,
|
||||
Element,
|
||||
},
|
||||
};
|
||||
|
||||
use num_traits::ToPrimitive;
|
||||
|
||||
|
||||
pub fn tensor_max_scalar<B: Backend, const D: usize>(x: Tensor<B, D>, max: f64) -> Tensor<B, D> {
|
||||
relu(x.sub_scalar(max)).add_scalar(max)
|
||||
}
|
||||
|
||||
pub fn tensor_min_scalar<B: Backend, const D: usize>(x: Tensor<B, D>, min: f64) -> Tensor<B, D> {
|
||||
-tensor_max_scalar(-x, -min)
|
||||
}
|
||||
|
||||
pub fn tensor_max<B: Backend, const D: usize>(x: Tensor<B, D>, max: Tensor<B, D>) -> Tensor<B, D> {
|
||||
relu(x - max.clone()) + max
|
||||
}
|
||||
|
||||
pub fn tensor_min<B: Backend, const D: usize>(x: Tensor<B, D>, min: Tensor<B, D>) -> Tensor<B, D> {
|
||||
-tensor_max(-x, -min)
|
||||
}
|
||||
|
||||
pub fn tensor_log10<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
|
||||
let ln10 = (10.0f64).ln();
|
||||
x.log() / ln10
|
||||
}
|
||||
|
||||
pub fn tensor_max_element<B: Backend, const D: usize>(x: Tensor<B, D>) -> f64 {
|
||||
let flat: Tensor<B, 1> = x.flatten(0, D - 1);
|
||||
let max_index = flat.clone().argmax(0);
|
||||
|
||||
flat.select(0, max_index).into_scalar().to_f64().unwrap()
|
||||
}
|
||||
|
||||
pub fn all_zeros<B: Backend, const D: usize>(x: Tensor<B, D>) -> bool {
|
||||
x.powf(2.0).sum().into_scalar().to_f64().unwrap() == 0.0
|
||||
}
|
||||
|
||||
pub fn max_dim<B: Backend>(x: Tensor<B, 2>, dim: usize) -> Tensor<B, 2> {
|
||||
let indices = x.clone().argmax(dim).flatten(0, 1);
|
||||
x.select(dim, indices)
|
||||
}
|
||||
|
||||
pub fn _10pow<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
|
||||
let log10 = (10.0f64).ln();
|
||||
(x * log10).exp()
|
||||
}
|
||||
|
||||
pub fn to_float<B: Backend, const D: usize>(x: Tensor<B, D, Int>) -> Tensor<B, D, Float> {
|
||||
let device = x.device();
|
||||
Tensor::from_data(
|
||||
x
|
||||
.into_data()
|
||||
.convert()
|
||||
).to_device(&device)
|
||||
}
|
||||
|
||||
pub fn to_float_bool<B: Backend, const D: usize>(x: Tensor<B, D, Bool>) -> Tensor<B, D, Float> {
|
||||
let device = x.device();
|
||||
Tensor::from_data(
|
||||
x
|
||||
.into_data()
|
||||
.convert()
|
||||
).to_device(&device)
|
||||
}
|
||||
|
||||
pub fn reverse<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B> + Numeric<B>>(x: Tensor<B, D, K>, dim: usize) -> Tensor<B, D, K> where <K as BasicOps<B>>::Elem: Element {
|
||||
let len = x.dims()[dim];
|
||||
let indices = -Tensor::arange_device(0..len, &x.device()) + (len - 1) as i64;
|
||||
x.select(dim, indices)
|
||||
}
|
||||
|
||||
pub fn div_roundup(x: usize, y: usize) -> usize {
|
||||
(x + y - 1) / y
|
||||
}
|
||||
5
src/lib.rs
Normal file
5
src/lib.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
#![feature(generic_const_exprs)]
|
||||
|
||||
pub mod model;
|
||||
pub mod tokenizer;
|
||||
pub mod helper;
|
||||
104
src/main.rs
Normal file
104
src/main.rs
Normal file
@@ -0,0 +1,104 @@
|
||||
use stablediffusion::{tokenizer::SimpleTokenizer, model::clip::{*, load::*},
|
||||
model::autoencoder::{*, load::*},
|
||||
model::groupnorm::*,
|
||||
model::unet::{*, load::*},
|
||||
model::stablediffusion::{*, load::*}};
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
};
|
||||
use burn_tch::{TchBackend, TchDevice};
|
||||
|
||||
fn print_tensor<B: Backend>(x: Tensor<B, 4>) {
|
||||
let data = x/*.slice([0..1, 0..4, 0..10])*/.into_data();
|
||||
println!("{:?}", data);
|
||||
}
|
||||
|
||||
use stablediffusion::helper::to_float;
|
||||
|
||||
fn main() {
|
||||
type Backend = TchBackend<f32>;
|
||||
//let device = TchDevice::Cpu;
|
||||
let device = TchDevice::Cuda(0);
|
||||
|
||||
/*let norm: nn::LayerNorm<Backend> = nn::LayerNormConfig::new(3).init();
|
||||
let tensor = Tensor::from_floats([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape([2, 3]);
|
||||
|
||||
let out = norm.forward(tensor);
|
||||
|
||||
println!("{:?}", out.into_data());
|
||||
|
||||
return;*/
|
||||
|
||||
/*let n_channel = 6;
|
||||
let norm: nn::LayerNorm<Backend> = nn::LayerNormConfig::new(10).init();
|
||||
let height = 10;
|
||||
let width = 10;
|
||||
let n_elements = height * width * n_channel;
|
||||
let t: Tensor<Backend, 4> = to_float(Tensor::arange(0..n_elements)).mul_scalar(10.0 / n_elements as f64).sin().reshape([1, n_channel, height, width]);
|
||||
let out = layernorm(t, 1e-5); //norm.forward(t);
|
||||
println!("{:?}", out.to_data());
|
||||
return;*/
|
||||
|
||||
/*let clip: CLIP<Backend> = load_clip("params", &device).unwrap();
|
||||
let input = Tensor::from_ints([3, 1]);
|
||||
|
||||
let output = clip.forward(input.unsqueeze());
|
||||
print_tensor(output);*/
|
||||
|
||||
/*let autoencoder: Autoencoder<Backend> = load_autoencoder("params", &device).unwrap();
|
||||
let input = Tensor::zeros([1, 3, 10, 10]);
|
||||
let output = autoencoder.forward(input);
|
||||
print_tensor(output);*/
|
||||
|
||||
/*let unet: UNet<Backend> = load_unet("params", &device).unwrap();
|
||||
let input = Tensor::zeros([1, 4, 64, 64]);
|
||||
let context = Tensor::from_floats([0.5, 1.3]).repeat(0, 768 / 2).unsqueeze();
|
||||
let timesteps = Tensor::from_floats([1.0]);
|
||||
|
||||
let output = unet.forward(input, timesteps, context);*/
|
||||
//print_tensor(output);
|
||||
|
||||
let tokenizer = SimpleTokenizer::new().unwrap();
|
||||
let sd: StableDiffusion<Backend> = load_stable_diffusion("params", &device).unwrap();
|
||||
|
||||
let unconditional_guidance_scale = 7.5;
|
||||
let unconditional_context = sd.unconditional_context(&tokenizer);
|
||||
let context = sd.context(&tokenizer, "A rainbow pony is flying.").unsqueeze();
|
||||
|
||||
let n_steps = 5;
|
||||
|
||||
let images = sd.sample_image(context, unconditional_context, unconditional_guidance_scale, n_steps);
|
||||
save_images(&images, "image_samples/", 512, 512).unwrap();
|
||||
}
|
||||
|
||||
use image::{self, ImageResult, ColorType::Rgb8};
|
||||
|
||||
fn save_images(images: &Vec<Vec<u8>>, basepath: &str, width: u32, height: u32) -> ImageResult<()> {
|
||||
for (index, img_data) in images.iter().enumerate() {
|
||||
let path = format!("{}{}.png", basepath, index);
|
||||
image::save_buffer(path, &img_data[..], width, height, Rgb8)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// save red test image
|
||||
fn save_test_image() -> ImageResult<()> {
|
||||
let width = 256;
|
||||
let height = 256;
|
||||
let raw: Vec<_> = (0..width * height).into_iter().flat_map(|i| {
|
||||
let row = i / width;
|
||||
let red = (255.0 * row as f64 / height as f64) as u8;
|
||||
|
||||
[red, 0, 0]
|
||||
}).collect();
|
||||
|
||||
image::save_buffer("red.png", &raw[..], width, height, Rgb8)
|
||||
}
|
||||
47
src/model/attention.rs
Normal file
47
src/model/attention.rs
Normal file
@@ -0,0 +1,47 @@
|
||||
use burn::{
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
activation::softmax,
|
||||
Tensor,
|
||||
},
|
||||
};
|
||||
|
||||
use std::f32::NEG_INFINITY;
|
||||
|
||||
pub fn qkv_attention<B: Backend>(q: Tensor<B, 3>, k: Tensor<B, 3>, v: Tensor<B, 3>, mask: Option<Tensor<B, 2>>, n_head: usize) -> Tensor<B, 3> {
|
||||
let [n_batch, n_qctx, n_state] = q.dims();
|
||||
let [_, n_ctx, _] = k.dims();
|
||||
|
||||
let scale = (n_state as f64 / n_head as f64).powf(-0.25);
|
||||
let n_hstate = n_state / n_head;
|
||||
|
||||
let q = q.reshape([n_batch, n_qctx, n_head, n_hstate]).swap_dims(1, 2) * scale;
|
||||
let k = k.reshape([n_batch, n_ctx, n_head, n_hstate]).swap_dims(1, 2).transpose() * scale;
|
||||
let v = v.reshape([n_batch, n_ctx, n_head, n_hstate]).swap_dims(1, 2);
|
||||
|
||||
let qk = q.matmul(k);
|
||||
|
||||
// apply mask
|
||||
let qk = if let Some(mask) = mask {
|
||||
qk + mask.slice([0..n_qctx, 0..n_ctx]).unsqueeze::<4>()
|
||||
} else {
|
||||
qk
|
||||
};
|
||||
|
||||
// normalize value weightings
|
||||
let w = softmax(qk, 3);
|
||||
let o = w.matmul(v).swap_dims(1, 2).flatten(2, 3);
|
||||
|
||||
return o;
|
||||
}
|
||||
|
||||
pub fn attn_decoder_mask<B: Backend>(seq_length: usize) -> Tensor<B, 2> {
|
||||
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length]);
|
||||
|
||||
for i in 0..(seq_length - 1) {
|
||||
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)]).add_scalar(NEG_INFINITY);
|
||||
mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
|
||||
}
|
||||
|
||||
return mask;
|
||||
}
|
||||
134
src/model/autoencoder/load.rs
Normal file
134
src/model/autoencoder/load.rs
Normal file
@@ -0,0 +1,134 @@
|
||||
use super::GroupNorm;
|
||||
use crate::model::load::*;
|
||||
|
||||
use std::error::Error;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
use crate::model::groupnorm::load::load_group_norm;
|
||||
|
||||
fn load_conv_self_attention_block<B: Backend>(path: &str, device: &B::Device) -> Result<ConvSelfAttentionBlock<B>, Box<dyn Error>> {
|
||||
let norm = load_group_norm(&format!("{}/{}", path, "norm"), device)?;
|
||||
let q = load_conv2d(&format!("{}/{}", path, "q"), device)?;
|
||||
let k = load_conv2d(&format!("{}/{}", path, "k"), device)?;
|
||||
let v = load_conv2d(&format!("{}/{}", path, "v"), device)?;
|
||||
let proj_out = load_conv2d(&format!("{}/{}", path, "proj_out"), device)?;
|
||||
|
||||
Ok(ConvSelfAttentionBlock { norm, q, k, v, proj_out })
|
||||
}
|
||||
|
||||
fn load_resnet_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResnetBlock<B>, Box<dyn Error>> {
|
||||
let norm1 = load_group_norm(&format!("{}/{}", path, "norm1"), device)?;
|
||||
let silu1 = SILU {};
|
||||
let conv1 = load_conv2d(&format!("{}/{}", path, "conv1"), device)?;
|
||||
let norm2 = load_group_norm(&format!("{}/{}", path, "norm2"), device)?;
|
||||
let silu2 = SILU {};
|
||||
let conv2 = load_conv2d(&format!("{}/{}", path, "conv2"), device)?;
|
||||
let nin_shortcut = load_conv2d(&format!("{}/{}", path, "nin_shortcut"), device).ok();
|
||||
|
||||
Ok(ResnetBlock { norm1, silu1, conv1, norm2, silu2, conv2, nin_shortcut })
|
||||
}
|
||||
|
||||
fn load_mid<B: Backend>(path: &str, device: &B::Device) -> Result<Mid<B>, Box<dyn Error>> {
|
||||
let block_1 = load_resnet_block(&format!("{}/{}", path, "block_1"), device)?;
|
||||
let attn = load_conv_self_attention_block(&format!("{}/{}", path, "attn"), device)?;
|
||||
let block_2 = load_resnet_block(&format!("{}/{}", path, "block_2"), device)?;
|
||||
|
||||
Ok(Mid { block_1, attn, block_2 })
|
||||
}
|
||||
|
||||
fn load_padded_conv2d<B: Backend>(path: &str, device: &B::Device) -> Result<PaddedConv2d<B>, Box<dyn Error>> {
|
||||
let conv = load_conv2d(&format!("{}/{}", path, "conv"), device)?;
|
||||
|
||||
let channels = load_tensor::<B, 1>("channels", path, device)?;
|
||||
let channels = tensor_to_array_2(channels);
|
||||
|
||||
let kernel_size = load_usize::<B>("kernel_size", path, device)?;
|
||||
let stride = load_usize::<B>("stride", path, device)?;
|
||||
|
||||
let padding = load_tensor::<B, 1>("padding", path, device)?;
|
||||
let padding: [usize; 4] = tensor_to_array(padding);
|
||||
let padding = Padding::new(padding[0], padding[1], padding[2], padding[3]);
|
||||
|
||||
let mut record = conv.into_record();
|
||||
|
||||
let mut padded_conv: PaddedConv2d<B> = PaddedConv2dConfig::new(channels, kernel_size, padding).with_stride(stride).init();
|
||||
let padding_actual = PaddingConfig2d::Explicit(padded_conv.padding_actual[0], padded_conv.padding_actual[1]);
|
||||
|
||||
record.padding = <PaddingConfig2d as Module<B>>::into_record(padding_actual);
|
||||
padded_conv.conv = padded_conv.conv.load_record(record);
|
||||
|
||||
|
||||
Ok(padded_conv)
|
||||
}
|
||||
|
||||
fn load_decoder_block<B: Backend>(path: &str, device: &B::Device) -> Result<DecoderBlock<B>, Box<dyn Error>> {
|
||||
let res1 = load_resnet_block(&format!("{}/{}", path, "res1"), device)?;
|
||||
let res2 = load_resnet_block(&format!("{}/{}", path, "res2"), device)?;
|
||||
let res3 = load_resnet_block(&format!("{}/{}", path, "res3"), device)?;
|
||||
let upsampler = load_conv2d(&format!("{}/{}", path, "upsampler"), device).ok();
|
||||
|
||||
Ok(DecoderBlock { res1, res2, res3, upsampler })
|
||||
}
|
||||
|
||||
fn load_encoder_block<B: Backend>(path: &str, device: &B::Device) -> Result<EncoderBlock<B>, Box<dyn Error>> {
|
||||
let res1 = load_resnet_block(&format!("{}/{}", path, "res1"), device)?;
|
||||
let res2 = load_resnet_block(&format!("{}/{}", path, "res2"), device)?;
|
||||
let downsampler = load_padded_conv2d(&format!("{}/{}", path, "downsampler"), device).ok();
|
||||
|
||||
Ok(EncoderBlock { res1, res2, downsampler })
|
||||
}
|
||||
|
||||
fn load_decoder<B: Backend>(path: &str, device: &B::Device) -> Result<Decoder<B>, Box<dyn Error>> {
|
||||
let conv_in = load_conv2d(&format!("{}/{}", path, "conv_in"), device)?;
|
||||
let mid = load_mid(&format!("{}/{}", path, "mid"), device)?;
|
||||
|
||||
let n_block = load_usize::<B>("n_block", path, device)?;
|
||||
let mut blocks = (0..n_block)
|
||||
.into_iter()
|
||||
.map(|i| {
|
||||
load_decoder_block::<B>(&format!("{}/blocks/{}", path, i), device)
|
||||
}).collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let norm_out = load_group_norm(&format!("{}/{}", path, "norm_out"), device)?;
|
||||
let silu = SILU {};
|
||||
let conv_out = load_conv2d(&format!("{}/{}", path, "conv_out"), device)?;
|
||||
|
||||
Ok(Decoder { conv_in, mid, blocks, norm_out, silu, conv_out })
|
||||
}
|
||||
|
||||
fn load_encoder<B: Backend>(path: &str, device: &B::Device) -> Result<Encoder<B>, Box<dyn Error>> {
|
||||
let conv_in = load_conv2d(&format!("{}/{}", path, "conv_in"), device)?;
|
||||
let mid = load_mid(&format!("{}/{}", path, "mid"), device)?;
|
||||
|
||||
let n_block = load_usize::<B>("n_block", path, device)?;
|
||||
let mut blocks = (0..n_block)
|
||||
.into_iter()
|
||||
.map(|i| {
|
||||
load_encoder_block::<B>(&format!("{}/blocks/{}", path, i), device)
|
||||
}).collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let norm_out = load_group_norm(&format!("{}/{}", path, "norm_out"), device)?;
|
||||
let silu = SILU {};
|
||||
let conv_out = load_conv2d(&format!("{}/{}", path, "conv_out"), device)?;
|
||||
|
||||
Ok(Encoder { conv_in, mid, blocks, norm_out, silu, conv_out })
|
||||
}
|
||||
|
||||
pub fn load_autoencoder<B: Backend>(path: &str, device: &B::Device) -> Result<Autoencoder<B>, Box<dyn Error>> {
|
||||
let encoder = load_encoder(&format!("{}/{}", path, "encoder"), device)?;
|
||||
let decoder = load_decoder(&format!("{}/{}", path, "decoder"), device)?;
|
||||
let quant_conv = load_conv2d(&format!("{}/{}", path, "quant_conv"), device)?;
|
||||
let post_quant_conv = load_conv2d(&format!("{}/{}", path, "post_quant_conv"), device)?;
|
||||
|
||||
Ok(Autoencoder { encoder, decoder, quant_conv, post_quant_conv })
|
||||
}
|
||||
530
src/model/autoencoder/mod.rs
Normal file
530
src/model/autoencoder/mod.rs
Normal file
@@ -0,0 +1,530 @@
|
||||
pub mod load;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn::{self, PaddingConfig2d, conv::{Conv2d, Conv2dConfig, Conv2dRecord}},
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
activation::{softmax, sigmoid},
|
||||
module::embedding,
|
||||
Tensor,
|
||||
Distribution,
|
||||
Int,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::helper::div_roundup;
|
||||
|
||||
use super::silu::*;
|
||||
use super::groupnorm::*;
|
||||
use super::attention::qkv_attention;
|
||||
|
||||
use std::iter;
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct AutoencoderConfig {}
|
||||
|
||||
impl AutoencoderConfig {
|
||||
pub fn init<B: Backend>(&self) -> Autoencoder<B> {
|
||||
let encoder = EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init();
|
||||
let decoder = DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init();
|
||||
let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init();
|
||||
let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init();
|
||||
|
||||
Autoencoder {
|
||||
encoder,
|
||||
decoder,
|
||||
quant_conv,
|
||||
post_quant_conv,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fn print_tensor<B: Backend>(x: Tensor<B, 4>) {
|
||||
let [_, channels, height, width] = x.dims();
|
||||
let channels = channels.min(10);
|
||||
let data = x.slice([0..1, 0..channels, 0..height, 0..width]).into_data();
|
||||
println!("{:?}", data);
|
||||
}
|
||||
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Autoencoder<B: Backend> {
|
||||
encoder: Encoder<B>,
|
||||
decoder: Decoder<B>,
|
||||
quant_conv: Conv2d<B>,
|
||||
post_quant_conv: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Autoencoder<B> {
|
||||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
self.decode_latent( self.encode_image(x) )
|
||||
}
|
||||
|
||||
pub fn encode_image(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let [n_batch, _, _, _] = x.dims();
|
||||
let latent = self.encoder.forward(x);
|
||||
let latent = self.quant_conv.forward(latent);
|
||||
let latent = latent.slice([0..n_batch, 0..4]);
|
||||
latent
|
||||
}
|
||||
|
||||
pub fn decode_latent(&self, latent: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let latent = self.post_quant_conv.forward(latent);
|
||||
self.decoder.forward(latent)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct EncoderConfig {
|
||||
channels: Vec<(usize, usize)>,
|
||||
n_group: usize,
|
||||
n_channels_out: usize,
|
||||
}
|
||||
|
||||
impl EncoderConfig {
|
||||
fn init<B: Backend>(&self) -> Encoder<B> {
|
||||
let n_expanded_channels_initial = self.channels.first().map(|f| f.1).expect("Channels must not be empty.");
|
||||
let n_expanded_channels_final = self.channels.first().unwrap().0;
|
||||
|
||||
let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
|
||||
let blocks = self.channels.iter().enumerate().map(|(i, &(n_channel_in, n_channel_out))| {
|
||||
let downsample = i != self.channels.len() - 1;
|
||||
EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init()
|
||||
}).collect();
|
||||
|
||||
let mid = MidConfig::new(n_expanded_channels_final).init();
|
||||
let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init();
|
||||
let silu = SILU::new();
|
||||
let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
|
||||
Encoder {
|
||||
conv_in,
|
||||
mid,
|
||||
blocks,
|
||||
norm_out,
|
||||
silu,
|
||||
conv_out,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Encoder<B: Backend> {
|
||||
conv_in: Conv2d<B>,
|
||||
mid: Mid<B>,
|
||||
blocks: Vec<EncoderBlock<B>>,
|
||||
norm_out: GroupNorm<B>,
|
||||
silu: SILU,
|
||||
conv_out: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Encoder<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.conv_in.forward(x);
|
||||
|
||||
let mut x = x;
|
||||
for block in &self.blocks {
|
||||
x = block.forward(x);
|
||||
}
|
||||
|
||||
let x = self.mid.forward(x);
|
||||
self.conv_out.forward( self.silu.forward( self.norm_out.forward(x) ) )
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct DecoderConfig {
|
||||
channels: Vec<(usize, usize)>,
|
||||
n_group: usize,
|
||||
}
|
||||
|
||||
impl DecoderConfig {
|
||||
fn init<B: Backend>(&self) -> Decoder<B> {
|
||||
let n_expanded_channels = self.channels.first().map(|f| f.0).expect("Channels must not be empty.");
|
||||
let n_condensed_channels = self.channels.last().unwrap().1;
|
||||
|
||||
let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
let mid = MidConfig::new(n_expanded_channels).init();
|
||||
|
||||
let blocks = self.channels.iter().enumerate().map(|(i, &(n_channel_in, n_channel_out))| {
|
||||
let upsample = i != self.channels.len() - 1;
|
||||
DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init()
|
||||
}).collect();
|
||||
|
||||
let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init();
|
||||
let silu = SILU::new();
|
||||
let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
|
||||
Decoder {
|
||||
conv_in,
|
||||
mid,
|
||||
blocks,
|
||||
norm_out,
|
||||
silu,
|
||||
conv_out,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Decoder<B: Backend> {
|
||||
conv_in: Conv2d<B>,
|
||||
mid: Mid<B>,
|
||||
blocks: Vec<DecoderBlock<B>>,
|
||||
norm_out: GroupNorm<B>,
|
||||
silu: SILU,
|
||||
conv_out: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Decoder<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.conv_in.forward(x);
|
||||
let x = self.mid.forward(x);
|
||||
|
||||
let mut x = x;
|
||||
for block in &self.blocks {
|
||||
x = block.forward(x);
|
||||
}
|
||||
|
||||
self.conv_out.forward( self.silu.forward( self.norm_out.forward(x) ) )
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct EncoderBlockConfig {
|
||||
n_channels_in: usize,
|
||||
n_channels_out: usize,
|
||||
downsample: bool,
|
||||
}
|
||||
|
||||
impl EncoderBlockConfig {
|
||||
fn init<B: Backend>(&self) -> EncoderBlock<B> {
|
||||
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init();
|
||||
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
|
||||
let downsampler = if self.downsample {
|
||||
let padding = Padding::new(0, 1, 0, 1);
|
||||
Some( PaddedConv2dConfig::new([self.n_channels_out, self.n_channels_out], 3, padding).with_stride(2).init() )
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
EncoderBlock {
|
||||
res1,
|
||||
res2,
|
||||
downsampler,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct EncoderBlock<B: Backend> {
|
||||
res1: ResnetBlock<B>,
|
||||
res2: ResnetBlock<B>,
|
||||
downsampler: Option<PaddedConv2d<B>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> EncoderBlock<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.res1.forward(x);
|
||||
let x = self.res2.forward(x);
|
||||
if let Some(d) = self.downsampler.as_ref() {
|
||||
d.forward(x)
|
||||
} else {
|
||||
x
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct DecoderBlockConfig {
|
||||
n_channels_in: usize,
|
||||
n_channels_out: usize,
|
||||
upsample: bool,
|
||||
}
|
||||
|
||||
impl DecoderBlockConfig {
|
||||
fn init<B: Backend>(&self) -> DecoderBlock<B> {
|
||||
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init();
|
||||
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
|
||||
let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
|
||||
let upsampler = if self.upsample {
|
||||
Some( Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init() )
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
DecoderBlock {
|
||||
res1,
|
||||
res2,
|
||||
res3,
|
||||
upsampler,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct DecoderBlock<B: Backend> {
|
||||
res1: ResnetBlock<B>,
|
||||
res2: ResnetBlock<B>,
|
||||
res3: ResnetBlock<B>,
|
||||
upsampler: Option<Conv2d<B>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> DecoderBlock<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.res1.forward(x);
|
||||
let x = self.res2.forward(x);
|
||||
let x = self.res3.forward(x);
|
||||
|
||||
if let Some(d) = self.upsampler.as_ref() {
|
||||
let [n_batch, n_channel, height, width] = x.dims();
|
||||
let x = x
|
||||
.reshape([n_batch, n_channel, height, 1, width, 1])
|
||||
.repeat(3, 2)
|
||||
.repeat(5, 2)
|
||||
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
|
||||
d.forward(x)
|
||||
} else {
|
||||
x
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct PaddedConv2dConfig {
|
||||
channels: [usize; 2],
|
||||
kernel_size: usize,
|
||||
#[config(default = 1)]
|
||||
stride: usize,
|
||||
padding: Padding,
|
||||
}
|
||||
|
||||
impl PaddedConv2dConfig {
|
||||
fn init<B: Backend>(&self) -> PaddedConv2d<B> {
|
||||
let calc_padding = |p_left, p_right| {
|
||||
let n = if p_left >= p_right {
|
||||
0
|
||||
} else {
|
||||
div_roundup(p_right - p_left, self.stride)
|
||||
};
|
||||
|
||||
n * self.stride + p_left
|
||||
};
|
||||
|
||||
let pad_vertical = calc_padding(self.padding.pad_top, self.padding.pad_bottom);
|
||||
let pad_horizontal = calc_padding(self.padding.pad_left, self.padding.pad_right);
|
||||
let padding_actual = [pad_vertical, pad_horizontal];
|
||||
|
||||
let conv = Conv2dConfig::new(self.channels, [self.kernel_size, self.kernel_size])
|
||||
.with_stride([self.stride, self.stride])
|
||||
.with_padding(PaddingConfig2d::Explicit(pad_vertical, pad_horizontal))
|
||||
.init();
|
||||
|
||||
let kernel_size = self.kernel_size;
|
||||
let stride = self.stride;
|
||||
|
||||
let padding = self.padding;
|
||||
|
||||
PaddedConv2d {
|
||||
conv,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
padding_actual,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct PaddedConv2d<B: Backend> {
|
||||
conv: Conv2d<B>,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: Padding,
|
||||
padding_actual: [usize; 2],
|
||||
}
|
||||
|
||||
impl<B: Backend> PaddedConv2d<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let [n_batch, n_channel, height, width] = x.dims();
|
||||
|
||||
let desired_height = (self.padding.pad_top + self.padding.pad_bottom + height - self.kernel_size) / self.stride + 1;
|
||||
let desired_width = (self.padding.pad_left + self.padding.pad_right + width - self.kernel_size) / self.stride + 1;
|
||||
|
||||
let skip_vert = (self.padding_actual[0] - self.padding.pad_top) / self.stride;
|
||||
let skip_hor = (self.padding_actual[1] - self.padding.pad_left) / self.stride;
|
||||
|
||||
self.conv
|
||||
.forward(x)
|
||||
.slice([
|
||||
0..n_batch,
|
||||
0..n_channel,
|
||||
skip_vert..(skip_vert + desired_height),
|
||||
skip_hor..(skip_hor + desired_width)
|
||||
])
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config, Module, Copy, Debug)]
|
||||
pub struct Padding {
|
||||
pad_left: usize,
|
||||
pad_right: usize,
|
||||
pad_top: usize,
|
||||
pad_bottom: usize,
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct MidConfig {
|
||||
n_channel: usize,
|
||||
}
|
||||
|
||||
impl MidConfig {
|
||||
fn init<B: Backend>(&self) -> Mid<B> {
|
||||
let block_1 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init();
|
||||
let attn = ConvSelfAttentionBlockConfig::new(self.n_channel).init();
|
||||
let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init();
|
||||
|
||||
Mid {
|
||||
block_1,
|
||||
attn,
|
||||
block_2,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Mid<B: Backend> {
|
||||
block_1: ResnetBlock<B>,
|
||||
attn: ConvSelfAttentionBlock<B>,
|
||||
block_2: ResnetBlock<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Mid<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.block_1.forward(x);
|
||||
let x = self.attn.forward(x);
|
||||
let x = self.block_2.forward(x);
|
||||
x
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct ResnetBlockConfig {
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
}
|
||||
|
||||
impl ResnetBlockConfig {
|
||||
fn init<B: Backend>(&self) -> ResnetBlock<B> {
|
||||
let norm1 = GroupNormConfig::new(32, self.in_channels).init();
|
||||
let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
let norm2 = GroupNormConfig::new(32, self.out_channels).init();
|
||||
let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
let nin_shortcut = if self.in_channels != self.out_channels {
|
||||
Some( Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init() )
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let silu1 = SILU::new();
|
||||
let silu2 = SILU::new();
|
||||
|
||||
ResnetBlock {
|
||||
norm1,
|
||||
silu1,
|
||||
conv1,
|
||||
norm2,
|
||||
silu2,
|
||||
conv2,
|
||||
nin_shortcut,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ResnetBlock<B: Backend> {
|
||||
norm1: GroupNorm<B>,
|
||||
silu1: SILU,
|
||||
conv1: Conv2d<B>,
|
||||
norm2: GroupNorm<B>,
|
||||
silu2: SILU,
|
||||
conv2: Conv2d<B>,
|
||||
nin_shortcut: Option<Conv2d<B>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ResnetBlock<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let h = self.conv1.forward( self.silu1.forward(self.norm1.forward(x.clone())) );
|
||||
let h = self.conv2.forward( self.silu2.forward(self.norm2.forward(h)) );
|
||||
|
||||
|
||||
if let Some(ns) = self.nin_shortcut.as_ref() {
|
||||
ns.forward(x) + h
|
||||
} else {
|
||||
x + h
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct ConvSelfAttentionBlockConfig {
|
||||
n_channel: usize,
|
||||
}
|
||||
|
||||
impl ConvSelfAttentionBlockConfig {
|
||||
fn init<B: Backend>(&self) -> ConvSelfAttentionBlock<B> {
|
||||
let norm = GroupNormConfig::new(32, self.n_channel).init();
|
||||
let q = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
|
||||
let k = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
|
||||
let v = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
|
||||
let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
|
||||
|
||||
ConvSelfAttentionBlock {
|
||||
norm,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
proj_out,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ConvSelfAttentionBlock<B: Backend> {
|
||||
norm: GroupNorm<B>,
|
||||
q: Conv2d<B>,
|
||||
k: Conv2d<B>,
|
||||
v: Conv2d<B>,
|
||||
proj_out: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ConvSelfAttentionBlock<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let [n_batch, n_channel, height, width] = x.dims();
|
||||
|
||||
let h = self.norm.forward(x.clone());
|
||||
|
||||
let q = self.q.forward(h.clone()).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2);
|
||||
let k = self.k.forward(h.clone()).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2);
|
||||
let v = self.v.forward(h).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2);
|
||||
|
||||
let wv = qkv_attention(q, k, v, None, 1)
|
||||
.swap_dims(1, 2)
|
||||
.reshape([n_batch, n_channel, height, width]);
|
||||
|
||||
let projected = self.proj_out.forward(wv);
|
||||
|
||||
x + projected
|
||||
}
|
||||
}
|
||||
86
src/model/clip/load.rs
Normal file
86
src/model/clip/load.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
use std::error::Error;
|
||||
use burn::tensor::ElementConversion;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
use crate::model::load::*;
|
||||
|
||||
pub fn load_mlp<B: Backend>(path: &str, device: &B::Device) -> Result<MLP<B>, Box<dyn Error>> {
|
||||
let fc1 = load_linear(&format!("{}/{}", path, "fc1"), device)?;
|
||||
let gelu = QuickGELU::new();
|
||||
let fc2 = load_linear(&format!("{}/{}", path, "fc2"), device)?;
|
||||
|
||||
let mlp = MLP {
|
||||
fc1: fc1,
|
||||
gelu: gelu,
|
||||
fc2: fc2,
|
||||
};
|
||||
|
||||
Ok(mlp)
|
||||
}
|
||||
|
||||
pub fn load_multi_head_self_attention<B: Backend>(path: &str, device: &B::Device) -> Result<MultiHeadSelfAttention<B>, Box<dyn Error>> {
|
||||
let n_head = load_usize::<B>("n_head", path, device)?;
|
||||
let query = load_linear(&format!("{}/{}", path, "query"), device)?;
|
||||
let key = load_linear(&format!("{}/{}", path, "key"), device)?;
|
||||
let value = load_linear(&format!("{}/{}", path, "value"), device)?;
|
||||
let out = load_linear(&format!("{}/{}", path, "out"), device)?;
|
||||
|
||||
let mhsa = MultiHeadSelfAttention {
|
||||
n_head: n_head,
|
||||
query: query,
|
||||
key: key,
|
||||
value: value,
|
||||
out: out,
|
||||
};
|
||||
|
||||
Ok(mhsa)
|
||||
}
|
||||
|
||||
pub fn load_residual_decoder_attention_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResidualDecoderAttentionBlock<B>, Box<dyn Error>> {
|
||||
let mlp = load_mlp(&format!("{}/{}", path, "mlp"), device)?;
|
||||
let attn = load_multi_head_self_attention(&format!("{}/{}", path, "attn"), device)?;
|
||||
let attn_ln = load_layer_norm(&format!("{}/{}", path, "attn_ln"), device)?;
|
||||
let mlp_ln = load_layer_norm(&format!("{}/{}", path, "mlp_ln"), device)?;
|
||||
|
||||
let rdab = ResidualDecoderAttentionBlock {
|
||||
attn: attn,
|
||||
attn_ln: attn_ln,
|
||||
mlp: mlp,
|
||||
mlp_ln: mlp_ln,
|
||||
};
|
||||
|
||||
Ok(rdab)
|
||||
}
|
||||
|
||||
pub fn load_clip<B: Backend>(path: &str, device: &B::Device) -> Result<CLIP<B>, Box<dyn Error>> {
|
||||
let token_embedding = load_embedding(&format!("{}/{}", path, "token_embedding"), device)?;
|
||||
let position_embedding = load_tensor("weight", &format!("{}/position_embedding", path), device)?.into();
|
||||
|
||||
let n_layer = load_usize::<B>("n_layer", path, device)?;
|
||||
let mut blocks = (0..n_layer)
|
||||
.into_iter()
|
||||
.map(|i| {
|
||||
load_residual_decoder_attention_block::<B>(&format!("{}/blocks/{}", path, i), device)
|
||||
}).collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let layer_norm = load_layer_norm(&format!("{}/{}", path, "layer_norm"), device)?;
|
||||
|
||||
let clip = CLIP {
|
||||
token_embedding: token_embedding,
|
||||
position_embedding: position_embedding,
|
||||
blocks: blocks,
|
||||
layer_norm: layer_norm,
|
||||
};
|
||||
|
||||
Ok(clip)
|
||||
}
|
||||
220
src/model/clip/mod.rs
Normal file
220
src/model/clip/mod.rs
Normal file
@@ -0,0 +1,220 @@
|
||||
pub mod load;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
activation::{softmax, sigmoid},
|
||||
module::embedding,
|
||||
Tensor,
|
||||
Distribution,
|
||||
Int,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::model::attention::{qkv_attention, attn_decoder_mask};
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct CLIPConfig {
|
||||
n_vocab: usize,
|
||||
n_state: usize,
|
||||
n_head: usize,
|
||||
n_ctx: usize,
|
||||
n_layer: usize,
|
||||
}
|
||||
|
||||
impl CLIPConfig {
|
||||
pub fn init<B: Backend>(&self) -> CLIP<B> {
|
||||
let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init();
|
||||
let position_embedding = Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0)).into();
|
||||
let blocks = (0..self.n_layer)
|
||||
.into_iter()
|
||||
.map(|_| ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head).init())
|
||||
.collect();
|
||||
let layer_norm = nn::LayerNormConfig::new(self.n_state).init();
|
||||
|
||||
CLIP {
|
||||
token_embedding,
|
||||
position_embedding,
|
||||
blocks,
|
||||
layer_norm,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct CLIP<B: Backend> {
|
||||
token_embedding: nn::Embedding<B>,
|
||||
position_embedding: Param<Tensor<B, 2>>,
|
||||
blocks: Vec<ResidualDecoderAttentionBlock<B>>,
|
||||
layer_norm: nn::LayerNorm<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> CLIP<B> {
|
||||
pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
|
||||
let [n_batch, seq_len] = x.dims();
|
||||
|
||||
let mask = attn_decoder_mask(seq_len);
|
||||
|
||||
let embedded = self.token_embedding.forward(x)
|
||||
+ self.position_embedding.val().slice([0..seq_len]).unsqueeze();
|
||||
|
||||
let mut x = embedded;
|
||||
for block in &self.blocks {
|
||||
x = block.forward(x, mask.clone());
|
||||
}
|
||||
|
||||
self.layer_norm.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct ResidualDecoderAttentionBlockConfig {
|
||||
n_state: usize,
|
||||
n_head: usize,
|
||||
}
|
||||
|
||||
impl ResidualDecoderAttentionBlockConfig {
|
||||
pub fn init<B: Backend>(&self) -> ResidualDecoderAttentionBlock<B> {
|
||||
let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init();
|
||||
let attn_ln = nn::LayerNormConfig::new(self.n_state).init();
|
||||
|
||||
let mlp = MLPConfig::new(self.n_state, 4 * self.n_state).init();
|
||||
let mlp_ln = nn::LayerNormConfig::new(self.n_state).init();
|
||||
|
||||
ResidualDecoderAttentionBlock {
|
||||
attn,
|
||||
attn_ln,
|
||||
mlp,
|
||||
mlp_ln,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ResidualDecoderAttentionBlock<B: Backend> {
|
||||
attn: MultiHeadSelfAttention<B>,
|
||||
attn_ln: nn::LayerNorm<B>,
|
||||
mlp: MLP<B>,
|
||||
mlp_ln: nn::LayerNorm<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ResidualDecoderAttentionBlock<B> {
|
||||
fn forward(&self, x: Tensor<B, 3>, mask: Tensor<B, 2>) -> Tensor<B, 3> {
|
||||
let x = x.clone() + self.attn.forward(self.attn_ln.forward(x), Some(mask));
|
||||
let x = x.clone() + self.mlp.forward(self.mlp_ln.forward(x));
|
||||
return x;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct MultiHeadSelfAttentionConfig {
|
||||
n_state: usize,
|
||||
n_head: usize,
|
||||
}
|
||||
|
||||
impl MultiHeadSelfAttentionConfig {
|
||||
fn init<B: Backend>(&self) -> MultiHeadSelfAttention<B> {
|
||||
assert!(self.n_state % self.n_head == 0, "State size {} must be a multiple of head size {}", self.n_state, self.n_head);
|
||||
|
||||
let n_head = self.n_head;
|
||||
let query = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
||||
let key = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
||||
let value = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
||||
let out = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
||||
|
||||
MultiHeadSelfAttention {
|
||||
n_head,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct MultiHeadSelfAttention<B: Backend> {
|
||||
n_head: usize,
|
||||
query: nn::Linear<B>,
|
||||
key: nn::Linear<B>,
|
||||
value: nn::Linear<B>,
|
||||
out: nn::Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> MultiHeadSelfAttention<B> {
|
||||
pub fn forward(&self, x: Tensor<B, 3>, mask: Option<Tensor<B, 2>>) -> Tensor<B, 3> {
|
||||
let q = self.query.forward(x.clone());
|
||||
let k = self.key.forward(x.clone());
|
||||
let v = self.value.forward(x);
|
||||
|
||||
let wv = qkv_attention(q, k, v, mask, self.n_head);
|
||||
|
||||
return self.out.forward(wv);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
#[derive(Config, Debug)]
|
||||
pub struct MLPConfig {
|
||||
input_size: usize,
|
||||
hidden_size: usize,
|
||||
}
|
||||
|
||||
impl MLPConfig {
|
||||
fn init<B: Backend>(&self) -> MLP<B> {
|
||||
let fc1 = nn::LinearConfig::new(self.input_size, self.hidden_size).init();
|
||||
let gelu = QuickGELU::new();
|
||||
let fc2 = nn::LinearConfig::new(self.hidden_size, self.input_size).init();
|
||||
|
||||
MLP {
|
||||
fc1,
|
||||
gelu,
|
||||
fc2,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct MLP<B: Backend> {
|
||||
fc1: nn::Linear<B>,
|
||||
gelu: QuickGELU,
|
||||
fc2: nn::Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> MLP<B> {
|
||||
fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
|
||||
let x = self.fc1.forward(x);
|
||||
let x = self.gelu.forward(x);
|
||||
let x = self.fc2.forward(x);
|
||||
|
||||
x
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Clone, Debug)]
|
||||
pub struct QuickGELU {}
|
||||
|
||||
impl QuickGELU {
|
||||
fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
|
||||
fn forward<B: Backend, const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
|
||||
x.clone() * sigmoid(x * 1.702)
|
||||
}
|
||||
}
|
||||
|
||||
33
src/model/groupnorm/load.rs
Normal file
33
src/model/groupnorm/load.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
use super::GroupNorm;
|
||||
use crate::model::load::*;
|
||||
|
||||
use std::error::Error;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
};
|
||||
|
||||
pub fn load_group_norm<B: Backend>(path: &str, device: &B::Device) -> Result<GroupNorm<B>, Box<dyn Error>> {
|
||||
let n_group = load_usize::<B>("n_group", path, device)?.into();
|
||||
let n_channel = load_usize::<B>("n_channel", path, device)?.into();
|
||||
let eps = load_f32::<B>("eps", path, device)?.into();
|
||||
|
||||
let gamma = load_tensor::<B, 1>("weight", path, device).ok().unwrap_or_else(|| Tensor::ones_device([n_channel], device)).into();
|
||||
let beta = load_tensor::<B, 1>("bias", path, device).ok().unwrap_or_else(|| Tensor::zeros_device([n_channel], device)).into();
|
||||
|
||||
Ok(
|
||||
GroupNorm {
|
||||
n_group,
|
||||
n_channel,
|
||||
gamma,
|
||||
beta,
|
||||
eps,
|
||||
}
|
||||
)
|
||||
}
|
||||
72
src/model/groupnorm/mod.rs
Normal file
72
src/model/groupnorm/mod.rs
Normal file
@@ -0,0 +1,72 @@
|
||||
pub mod load;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct GroupNormConfig {
|
||||
n_group: usize,
|
||||
n_channel: usize,
|
||||
#[config(default = 1e-5)]
|
||||
eps: f64,
|
||||
}
|
||||
|
||||
impl GroupNormConfig {
|
||||
pub fn init<B: Backend>(&self) -> GroupNorm<B> {
|
||||
assert!(self.n_channel % self.n_group == 0, "The number of channels {} must be divisible by the number of groups {}", self.n_channel, self.n_group);
|
||||
|
||||
let n_per_group = self.n_channel / self.n_group;
|
||||
|
||||
let gamma = Tensor::ones([self.n_channel]).into();
|
||||
let beta = Tensor::zeros([self.n_channel]).into();
|
||||
|
||||
let eps = self.eps;
|
||||
|
||||
GroupNorm {
|
||||
n_group: self.n_group,
|
||||
n_channel: self.n_channel,
|
||||
gamma,
|
||||
beta,
|
||||
eps,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct GroupNorm<B: Backend> {
|
||||
n_group: usize,
|
||||
n_channel: usize,
|
||||
gamma: Param<Tensor<B, 1>>,
|
||||
beta: Param<Tensor<B, 1>>,
|
||||
eps: f64,
|
||||
}
|
||||
|
||||
impl<B: Backend> GroupNorm<B> {
|
||||
pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
|
||||
let shape = x.shape();
|
||||
let n_batch = shape.dims[0];
|
||||
let num_elements = shape.num_elements();
|
||||
|
||||
let mut affine_shape = [1; D];
|
||||
affine_shape[1] = self.n_channel;
|
||||
|
||||
layernorm( x.reshape([n_batch, self.n_group, num_elements / (n_batch * self.n_group) ]), self.eps )
|
||||
.reshape(shape)
|
||||
.mul(self.gamma.val().reshape(affine_shape))
|
||||
.add(self.beta.val().reshape(affine_shape))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn layernorm<B: Backend, const D: usize>(x: Tensor<B, D>, eps: f64) -> Tensor<B, D> {
|
||||
//let (var, mean) = x.clone().var_mean_bias(D - 1);
|
||||
//x.sub(mean).div(var.sqrt().add_scalar(eps))
|
||||
|
||||
let u = x.clone() - x.mean_dim(D - 1);
|
||||
u.clone().div( (u.clone() * u).mean_dim(D - 1).add_scalar(eps).sqrt() )
|
||||
}
|
||||
167
src/model/load.rs
Normal file
167
src/model/load.rs
Normal file
@@ -0,0 +1,167 @@
|
||||
use std::error::Error;
|
||||
use std::io::Read;
|
||||
use npy::{self, NpyData};
|
||||
use num_traits::cast::ToPrimitive;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn::{self, conv},
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
Data,
|
||||
},
|
||||
};
|
||||
|
||||
use burn::tensor::ElementConversion;
|
||||
|
||||
pub fn numpy_to_tensor<B: Backend, const D: usize>(numpy_data: NpyData<f32>, device: &B::Device) -> Tensor<B, D> {
|
||||
let mut v = numpy_data.to_vec();
|
||||
|
||||
let shape: Vec<_> = v[0..D].into_iter().map(|&v| v as usize).collect();
|
||||
let data: Vec<B::FloatElem> = v[D..].into_iter().map(|e| e.elem()).collect();
|
||||
|
||||
Tensor::from_data_device(Data::new(data, shape.into()), device)
|
||||
}
|
||||
|
||||
pub fn load_tensor<B: Backend, const D: usize>(name: &str, path: &str, device: &B::Device) -> Result<Tensor<B, D>, Box<dyn Error>> {
|
||||
let tensor_path = format!("{}/{}.npy", path, name);
|
||||
|
||||
let mut buf = vec![];
|
||||
std::fs::File::open(&tensor_path)?
|
||||
.read_to_end(&mut buf)?;
|
||||
|
||||
let tensor_numpy: NpyData<f32> = NpyData::from_bytes(&buf)?;
|
||||
|
||||
let tensor = numpy_to_tensor(tensor_numpy, device);
|
||||
|
||||
println!("{}", tensor_path);
|
||||
|
||||
Ok(tensor)
|
||||
}
|
||||
|
||||
pub fn load_f32<B: Backend>(name: &str, path: &str, device: &B::Device) -> Result<f32, Box<dyn Error>> {
|
||||
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_f32().unwrap())
|
||||
}
|
||||
|
||||
pub fn load_usize<B: Backend>(name: &str, path: &str, device: &B::Device) -> Result<usize, Box<dyn Error>> {
|
||||
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_usize().unwrap())
|
||||
}
|
||||
|
||||
pub fn load_linear<B: Backend>(path: &str, device: &B::Device) -> Result<nn::Linear<B>, Box<dyn Error>> {
|
||||
let weight = load_tensor::<B, 2>("weight", path, device)?;
|
||||
let bias = load_tensor::<B, 1>("bias", path, device).ok();
|
||||
|
||||
let record = nn::LinearRecord {
|
||||
weight: weight.into(),
|
||||
bias: bias.map(|t| t.into()),
|
||||
};
|
||||
|
||||
let linear: nn::Linear<B> = nn::LinearConfig::new(3, 3).init_with(record);
|
||||
Ok(linear)
|
||||
}
|
||||
|
||||
pub fn load_embedding<B: Backend>(path: &str, device: &B::Device) -> Result<nn::Embedding<B>, Box<dyn Error>> {
|
||||
let weight = load_tensor::<B, 2>("weight", path, device)?;
|
||||
let [n_vocab, n_state] = weight.dims();
|
||||
|
||||
let record = nn::EmbeddingRecord {
|
||||
weight: weight.into(),
|
||||
};
|
||||
|
||||
let embedding = nn::EmbeddingConfig::new(n_vocab, n_state).init_with(record);
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
pub fn load_layer_norm<B: Backend>(path: &str, device: &B::Device) -> Result<nn::LayerNorm<B>, Box<dyn Error>> {
|
||||
let weight = load_tensor::<B, 1>("weight", path, device)?;
|
||||
let bias = load_tensor::<B, 1>("bias", path, device)?;
|
||||
let eps = load_f32::<B>("eps", path, device)? as f64;
|
||||
|
||||
let [n_state] = weight.dims();
|
||||
|
||||
let record = nn::LayerNormRecord {
|
||||
gamma: weight.into(),
|
||||
beta: bias.into(),
|
||||
epsilon: <f64 as Module<B>>::into_record(eps),
|
||||
};
|
||||
|
||||
let layer_norm: nn::LayerNorm<B> = nn::LayerNormConfig::new(n_state).init_with(record);
|
||||
|
||||
Ok(layer_norm)
|
||||
}
|
||||
|
||||
|
||||
/*pub fn load_rmsnorm<B: Backend>(path: &str, device: &B::Device) -> Result<RMSNorm<B>, Box<dyn Error>> {
|
||||
let weight = load_tensor::<B, 1>("weight", path, device)?;
|
||||
let eps = load_f32::<B>("eps", path, device)?.into();
|
||||
|
||||
let rmsnorm = RMSNorm {
|
||||
weight: weight.into(),
|
||||
eps: eps
|
||||
};
|
||||
|
||||
Ok(rmsnorm)
|
||||
}*/
|
||||
|
||||
pub fn load_conv2d<B: Backend>(path: &str, device: &B::Device) -> Result<conv::Conv2d<B>, Box<dyn Error>> {
|
||||
let weight = load_tensor::<B, 4>("weight", path, device)?;
|
||||
let bias = load_tensor::<B, 1>("bias", path, device).ok();
|
||||
let has_bias = bias.is_some();
|
||||
|
||||
let stride = load_tensor::<B, 1>("stride", path, device)?;
|
||||
let stride = tensor_to_array_2(stride);
|
||||
|
||||
let kernel_size = load_tensor::<B, 1>("kernel_size", path, device)?;
|
||||
let kernel_size = tensor_to_array_2(kernel_size);
|
||||
|
||||
let dilation = load_tensor::<B, 1>("dilation", path, device)?;
|
||||
let dilation = tensor_to_array_2(dilation);
|
||||
|
||||
let n_group = load_usize::<B>("n_group", path, device)?.into();
|
||||
let n_channels_in = load_usize::<B>("n_channels_in", path, device)?.into();
|
||||
let n_channels_out = load_usize::<B>("n_channels_out", path, device)?.into();
|
||||
|
||||
let padding = load_tensor::<B, 1>("padding", path, device)?;
|
||||
let padding = tensor_to_array_2(padding);
|
||||
let padding = nn::PaddingConfig2d::Explicit(padding[0], padding[1]);
|
||||
|
||||
|
||||
let record = conv::Conv2dRecord {
|
||||
weight: weight.into(),
|
||||
bias: bias.map(|t| t.into()),
|
||||
stride: <[usize; 2] as Module<B>>::into_record(stride),
|
||||
kernel_size: <[usize; 2] as Module<B>>::into_record(kernel_size),
|
||||
dilation: <[usize; 2] as Module<B>>::into_record(dilation),
|
||||
groups: <usize as Module<B>>::into_record(n_group),
|
||||
padding: <nn::PaddingConfig2d as Module<B>>::into_record(padding.clone()),
|
||||
};
|
||||
|
||||
let conv2d: conv::Conv2d<B> = conv::Conv2dConfig::new([n_channels_in, n_channels_out], kernel_size)
|
||||
.with_stride(stride)
|
||||
.with_dilation(dilation)
|
||||
.with_groups(n_group)
|
||||
.with_padding(padding)
|
||||
.with_bias(has_bias)
|
||||
.init_with(record);
|
||||
Ok(conv2d)
|
||||
}
|
||||
|
||||
pub fn tensor_to_array_2<B: Backend>(x: Tensor<B, 1>) -> [usize; 2] {
|
||||
let vec = x.into_data().value;
|
||||
assert!(vec.len() == 2, "Tensor length must be 2.");
|
||||
[vec[0].to_usize().unwrap(), vec[1].to_usize().unwrap()]
|
||||
}
|
||||
|
||||
pub fn tensor_to_array<const N: usize, B: Backend>(x: Tensor<B, 1>) -> [usize; N] {
|
||||
let vec = x.into_data().value;
|
||||
assert!(vec.len() == N, "Tensor length must be {}.", N);
|
||||
|
||||
let mut arr = [0; N];
|
||||
for (a, t) in arr.iter_mut().zip(vec) {
|
||||
*a = t.to_usize().unwrap();
|
||||
}
|
||||
|
||||
arr
|
||||
}
|
||||
11
src/model/mod.rs
Normal file
11
src/model/mod.rs
Normal file
@@ -0,0 +1,11 @@
|
||||
pub mod stablediffusion;
|
||||
|
||||
pub mod autoencoder;
|
||||
pub mod unet;
|
||||
pub mod clip;
|
||||
|
||||
pub mod silu;
|
||||
pub mod groupnorm;
|
||||
pub mod attention;
|
||||
|
||||
pub mod load;
|
||||
22
src/model/silu.rs
Normal file
22
src/model/silu.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
use burn::{
|
||||
module::Module,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
activation::sigmoid,
|
||||
Tensor,
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
#[derive(Module, Clone, Debug)]
|
||||
pub struct SILU {}
|
||||
|
||||
impl SILU {
|
||||
pub fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
|
||||
pub fn forward<B: Backend, const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
|
||||
x.clone() * sigmoid(x)
|
||||
}
|
||||
}
|
||||
34
src/model/stablediffusion/load.rs
Normal file
34
src/model/stablediffusion/load.rs
Normal file
@@ -0,0 +1,34 @@
|
||||
use std::error::Error;
|
||||
use burn::tensor::ElementConversion;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
use crate::model::{load::*, autoencoder::load::load_autoencoder, unet::load::load_unet, clip::load::load_clip};
|
||||
|
||||
pub fn load_stable_diffusion<B: Backend>(path: &str, device: &B::Device) -> Result<StableDiffusion<B>, Box<dyn Error>> {
|
||||
let n_steps = load_usize::<B>("n_steps", path, device)?;
|
||||
let alpha_cumulative_products: Vec<_> = load_tensor::<B, 1>("alphas_cumprod", path, device)?.into_data().value.into_iter()
|
||||
.map(|v: <Float as BasicOps<B>>::Elem| v.to_f64().unwrap())
|
||||
.collect();
|
||||
let autoencoder = load_autoencoder(&format!("{}/{}", path, "autoencoder"), device)?;
|
||||
let diffusion = load_unet(&format!("{}/{}", path, "unet"), device)?;
|
||||
let clip = load_clip(&format!("{}/{}", path, "clip"), device)?;
|
||||
|
||||
Ok(StableDiffusion {
|
||||
n_steps,
|
||||
alpha_cumulative_products,
|
||||
autoencoder,
|
||||
diffusion,
|
||||
clip,
|
||||
})
|
||||
}
|
||||
|
||||
181
src/model/stablediffusion/mod.rs
Normal file
181
src/model/stablediffusion/mod.rs
Normal file
@@ -0,0 +1,181 @@
|
||||
pub mod load;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::Module,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
Int,
|
||||
Float,
|
||||
BasicOps,
|
||||
Data,
|
||||
Distribution,
|
||||
},
|
||||
};
|
||||
|
||||
use num_traits::ToPrimitive;
|
||||
|
||||
use super::autoencoder::{Autoencoder, AutoencoderConfig};
|
||||
use super::unet::{UNet, UNetConfig};
|
||||
use super::clip::{CLIP, CLIPConfig};
|
||||
use crate::tokenizer::SimpleTokenizer;
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct StableDiffusionConfig {
|
||||
|
||||
}
|
||||
|
||||
impl StableDiffusionConfig {
|
||||
fn init<B: Backend>(&self) -> StableDiffusion<B> {
|
||||
let n_steps = 1000;
|
||||
let alpha_cumulative_products = offset_cosine_schedule_cumprod::<B>(n_steps)
|
||||
.into_data().value
|
||||
.into_iter()
|
||||
.map(|v: <Float as BasicOps<B>>::Elem| v.to_f64().unwrap()).collect();
|
||||
|
||||
let autoencoder = AutoencoderConfig::new().init();
|
||||
let diffusion = UNetConfig::new().init();
|
||||
let clip = CLIPConfig::new(49408, 768, 12, 77, 12).init();
|
||||
|
||||
StableDiffusion {
|
||||
n_steps,
|
||||
alpha_cumulative_products,
|
||||
autoencoder,
|
||||
diffusion,
|
||||
clip,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct StableDiffusion<B: Backend> {
|
||||
n_steps: usize,
|
||||
alpha_cumulative_products: Vec<f64>,
|
||||
autoencoder: Autoencoder<B>,
|
||||
diffusion: UNet<B>,
|
||||
clip: CLIP<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> StableDiffusion<B> {
|
||||
pub fn sample_image(&self, context: Tensor<B, 3>, unconditional_context: Tensor<B, 2>, unconditional_guidance_scale: f64, n_steps: usize) -> Vec<Vec<u8>> {
|
||||
let [n_batch, _, _] = context.dims();
|
||||
|
||||
let latent = self.sample_latent(context, unconditional_context, unconditional_guidance_scale, n_steps);
|
||||
let image = self.autoencoder.decode_latent(latent * (1.0 / 0.18215));
|
||||
|
||||
let n_channel = 3;
|
||||
let height = 512;
|
||||
let width = 512;
|
||||
let num_elements_per_image = n_channel * height * width;
|
||||
|
||||
// correct size and scale and reorder to
|
||||
let image = (image + 1.0) / 2.0;
|
||||
let image = image
|
||||
.reshape([n_batch, n_channel, height, width])
|
||||
.swap_dims(1, 2)
|
||||
.swap_dims(2, 3)
|
||||
.mul_scalar(255.0);
|
||||
|
||||
let flattened: Vec<_> = image.
|
||||
into_data().
|
||||
value;
|
||||
|
||||
(0..n_batch).into_iter().map(|b| {
|
||||
let start = b * num_elements_per_image;
|
||||
let end = start + num_elements_per_image;
|
||||
|
||||
flattened[start..end].into_iter().map(|v| v.to_u8().unwrap()).collect()
|
||||
}).collect()
|
||||
}
|
||||
|
||||
pub fn sample_latent(&self, context: Tensor<B, 3>, unconditional_context: Tensor<B, 2>, unconditional_guidance_scale: f64, n_steps: usize) -> Tensor<B, 4> {
|
||||
assert!(self.n_steps % n_steps == 0);
|
||||
|
||||
let step_size = self.n_steps / n_steps;
|
||||
|
||||
let [n_batches, _, _] = context.dims();
|
||||
|
||||
let gen_noise = || {
|
||||
Tensor::random([n_batches, 4, 64, 64], Distribution::Normal(0.0, 1.0) )
|
||||
};
|
||||
|
||||
let sigma = 0.0; // Use deterministic diffusion
|
||||
|
||||
let mut latent = gen_noise();
|
||||
|
||||
for t in (0..self.n_steps).rev().step_by(step_size) {
|
||||
let current_alpha = self.alpha_cumulative_products[t];
|
||||
let prev_alpha = if t >= step_size {
|
||||
self.alpha_cumulative_products[t - step_size]
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
|
||||
let sqrt_noise = (1.0 - current_alpha).sqrt();
|
||||
|
||||
let timestep = Tensor::from_ints([t as i32]);
|
||||
let pred_noise = self.forward_diffuser(latent.clone(), timestep, context.clone(), unconditional_context.clone(), unconditional_guidance_scale);
|
||||
|
||||
let predx0 = (latent - pred_noise.clone() * sqrt_noise) / current_alpha.sqrt();
|
||||
let dir_latent = pred_noise * (1.0 - prev_alpha - sigma * sigma).sqrt();
|
||||
|
||||
let prev_latent = predx0 * prev_alpha.sqrt() + dir_latent + gen_noise() * sigma;
|
||||
latent = prev_latent;
|
||||
}
|
||||
|
||||
latent
|
||||
}
|
||||
|
||||
fn forward_diffuser(&self, latent: Tensor<B, 4>, timestep: Tensor<B, 1, Int>, context: Tensor<B, 3>, unconditional_context: Tensor<B, 2>, unconditional_guidance_scale: f64) -> Tensor<B, 4> {
|
||||
let [n_batch, n_channel, height, width] = latent.dims();
|
||||
let latent = latent.repeat(0, 2);
|
||||
|
||||
let latent = self.diffusion.forward(
|
||||
latent.repeat(0, 2),
|
||||
timestep.repeat(0, 2),
|
||||
Tensor::cat(vec![unconditional_context.unsqueeze::<3>(), context], 0)
|
||||
);
|
||||
|
||||
let unconditional_latent = latent.clone().slice([0..n_batch]);
|
||||
let conditional_latent = latent.slice([n_batch..2 * n_batch]);
|
||||
|
||||
unconditional_latent.clone() + (conditional_latent - unconditional_latent) * unconditional_guidance_scale
|
||||
}
|
||||
|
||||
pub fn unconditional_context(&self, tokenizer: &SimpleTokenizer) -> Tensor<B, 2> {
|
||||
self.context(tokenizer, "").squeeze(0)
|
||||
}
|
||||
|
||||
pub fn context(&self, tokenizer: &SimpleTokenizer, text: &str) -> Tensor<B, 3> {
|
||||
let text = format!("<|startoftext|>{}<|endoftext|>", text);
|
||||
let tokenized: Vec<_> = tokenizer.encode(&text).into_iter().map(|v| v as i32).collect();
|
||||
|
||||
self.clip.forward(Tensor::from_ints(&tokenized[..]).unsqueeze())
|
||||
}
|
||||
}
|
||||
|
||||
use crate::helper::to_float;
|
||||
use std::f64::consts::PI;
|
||||
|
||||
fn cosine_schedule<B: Backend>(n_steps: usize) -> Tensor<B, 1> {
|
||||
to_float(Tensor::arange(1..n_steps + 1))
|
||||
.mul_scalar(PI * 0.5 / n_steps as f64)
|
||||
.cos()
|
||||
}
|
||||
|
||||
fn offset_cosine_schedule<B: Backend>(n_steps: usize) -> Tensor<B, 1> {
|
||||
let min_signal_rate: f64 = 0.02;
|
||||
let max_signal_rate: f64 = 0.95;
|
||||
let start_angle = max_signal_rate.acos();
|
||||
let end_angle = min_signal_rate.acos();
|
||||
|
||||
let times = Tensor::arange(1..n_steps + 1);
|
||||
|
||||
let diffusion_angles = to_float(times) * ( (end_angle - start_angle) / n_steps as f64) + start_angle;
|
||||
diffusion_angles.cos()
|
||||
}
|
||||
|
||||
fn offset_cosine_schedule_cumprod<B: Backend>(n_steps: usize) -> Tensor<B, 1> {
|
||||
offset_cosine_schedule::<B>(n_steps).powf(2.0)
|
||||
}
|
||||
278
src/model/unet/load.rs
Normal file
278
src/model/unet/load.rs
Normal file
@@ -0,0 +1,278 @@
|
||||
use super::GroupNorm;
|
||||
use crate::model::load::*;
|
||||
|
||||
use std::error::Error;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
use crate::model::groupnorm::load::load_group_norm;
|
||||
|
||||
pub fn load_res_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResBlock<B>, Box<dyn Error>> {
|
||||
let norm_in = load_group_norm::<B>(&format!("{}/{}", path, "norm_in"), device)?;
|
||||
let conv_in = load_conv2d::<B>(&format!("{}/{}", path, "conv_in"), device)?;
|
||||
let lin_embed = load_linear::<B>(&format!("{}/{}", path, "lin_embed"), device)?;
|
||||
let norm_out = load_group_norm::<B>(&format!("{}/{}", path, "norm_out"), device)?;
|
||||
let conv_out = load_conv2d::<B>(&format!("{}/{}", path, "conv_out"), device)?;
|
||||
let skip_connection = load_conv2d::<B>(&format!("{}/{}", path, "skip_connection"), device).ok();
|
||||
|
||||
let res_block = ResBlock {
|
||||
norm_in: norm_in,
|
||||
silu_in: SILU::new(),
|
||||
conv_in: conv_in,
|
||||
silu_embed: SILU::new(),
|
||||
lin_embed: lin_embed,
|
||||
norm_out: norm_out,
|
||||
silu_out: SILU::new(),
|
||||
conv_out: conv_out,
|
||||
skip_connection: skip_connection,
|
||||
};
|
||||
|
||||
Ok(res_block)
|
||||
}
|
||||
|
||||
pub fn load_multi_head_attention<B: Backend>(path: &str, device: &B::Device) -> Result<MultiHeadAttention<B>, Box<dyn Error>> {
|
||||
let n_head = load_usize::<B>("n_head", path, device)?;
|
||||
let query = load_linear::<B>(&format!("{}/{}", path, "query"), device)?;
|
||||
let key = load_linear::<B>(&format!("{}/{}", path, "key"), device)?;
|
||||
let value = load_linear::<B>(&format!("{}/{}", path, "value"), device)?;
|
||||
let out = load_linear::<B>(&format!("{}/{}", path, "out"), device)?;
|
||||
|
||||
let multi_head_attention = MultiHeadAttention {
|
||||
n_head: n_head,
|
||||
query: query,
|
||||
key: key,
|
||||
value: value,
|
||||
out: out,
|
||||
};
|
||||
|
||||
Ok(multi_head_attention)
|
||||
}
|
||||
|
||||
|
||||
pub fn load_geglu<B: Backend>(path: &str, device: &B::Device) -> Result<GEGLU<B>, Box<dyn Error>> {
|
||||
let proj = load_linear::<B>(&format!("{}/{}", path, "proj"), device)?;
|
||||
|
||||
let geglue = GEGLU {
|
||||
proj: proj,
|
||||
gelu: GELU::new(), // Assuming GELU::new() initializes a new GELU struct
|
||||
};
|
||||
|
||||
Ok(geglue)
|
||||
}
|
||||
|
||||
|
||||
pub fn load_mlp<B: Backend>(path: &str, device: &B::Device) -> Result<MLP<B>, Box<dyn Error>> {
|
||||
let geglu = load_geglu::<B>(&format!("{}/{}", path, "geglu"), device)?;
|
||||
let lin = load_linear::<B>(&format!("{}/{}", path, "lin"), device)?;
|
||||
|
||||
let mlp = MLP {
|
||||
geglu: geglu,
|
||||
lin: lin,
|
||||
};
|
||||
|
||||
Ok(mlp)
|
||||
}
|
||||
|
||||
|
||||
pub fn load_transformer_block<B: Backend>(path: &str, device: &B::Device) -> Result<TransformerBlock<B>, Box<dyn Error>> {
|
||||
let norm1 = load_layer_norm::<B>(&format!("{}/{}", path, "norm1"), device)?;
|
||||
let attn1 = load_multi_head_attention::<B>(&format!("{}/{}", path, "attn1"), device)?;
|
||||
let norm2 = load_layer_norm::<B>(&format!("{}/{}", path, "norm2"), device)?;
|
||||
let attn2 = load_multi_head_attention::<B>(&format!("{}/{}", path, "attn2"), device)?;
|
||||
let norm3 = load_layer_norm::<B>(&format!("{}/{}", path, "norm3"), device)?;
|
||||
let mlp = load_mlp::<B>(&format!("{}/{}", path, "mlp"), device)?;
|
||||
|
||||
let transformer_block = TransformerBlock {
|
||||
norm1: norm1,
|
||||
attn1: attn1,
|
||||
norm2: norm2,
|
||||
attn2: attn2,
|
||||
norm3: norm3,
|
||||
mlp: mlp,
|
||||
};
|
||||
|
||||
Ok(transformer_block)
|
||||
}
|
||||
|
||||
|
||||
pub fn load_spatial_transformer<B: Backend>(path: &str, device: &B::Device) -> Result<SpatialTransformer<B>, Box<dyn Error>> {
|
||||
let norm = load_group_norm::<B>(&format!("{}/{}", path, "norm"), device)?;
|
||||
let proj_in = load_conv2d::<B>(&format!("{}/{}", path, "proj_in"), device)?;
|
||||
let transformer = load_transformer_block::<B>(&format!("{}/{}", path, "transformer"), device)?;
|
||||
let proj_out = load_conv2d::<B>(&format!("{}/{}", path, "proj_out"), device)?;
|
||||
|
||||
let spatial_transformer = SpatialTransformer {
|
||||
norm: norm,
|
||||
proj_in: proj_in,
|
||||
transformer: transformer,
|
||||
proj_out: proj_out,
|
||||
};
|
||||
|
||||
Ok(spatial_transformer)
|
||||
}
|
||||
|
||||
|
||||
pub fn load_upsample<B: Backend>(path: &str, device: &B::Device) -> Result<Upsample<B>, Box<dyn Error>> {
|
||||
let conv = load_conv2d::<B>(&format!("{}/{}", path, "conv"), device)?;
|
||||
|
||||
let upsample = Upsample {
|
||||
conv: conv,
|
||||
};
|
||||
|
||||
Ok(upsample)
|
||||
}
|
||||
|
||||
pub fn load_downsample<B: Backend>(path: &str, device: &B::Device) -> Result<Downsample<B>, Box<dyn Error>> {
|
||||
load_conv2d(path, device)
|
||||
}
|
||||
|
||||
pub fn load_res_transformer_res<B: Backend>(path: &str, device: &B::Device) -> Result<ResTransformerRes<B>, Box<dyn Error>> {
|
||||
let res1 = load_res_block::<B>(&format!("{}/{}", path, "res1"), device)?; // Assuming load_res_block function
|
||||
let transformer = load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?;
|
||||
let res2 = load_res_block::<B>(&format!("{}/{}", path, "res2"), device)?;
|
||||
|
||||
let res_transformer_res = ResTransformerRes {
|
||||
res1: res1,
|
||||
transformer: transformer,
|
||||
res2: res2,
|
||||
};
|
||||
|
||||
Ok(res_transformer_res)
|
||||
}
|
||||
|
||||
pub fn load_res_transformer_upsample<B: Backend>(path: &str, device: &B::Device) -> Result<ResTransformerUpsample<B>, Box<dyn Error>> {
|
||||
let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?;
|
||||
let transformer = load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?;
|
||||
let upsample = load_upsample::<B>(&format!("{}/{}", path, "upsample"), device)?;
|
||||
|
||||
let res_transformer_upsample = ResTransformerUpsample {
|
||||
res: res,
|
||||
transformer: transformer,
|
||||
upsample: upsample,
|
||||
};
|
||||
|
||||
Ok(res_transformer_upsample)
|
||||
}
|
||||
|
||||
|
||||
pub fn load_res_upsample<B: Backend>(path: &str, device: &B::Device) -> Result<ResUpSample<B>, Box<dyn Error>> {
|
||||
let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?;
|
||||
let upsample = load_upsample::<B>(&format!("{}/{}", path, "upsample"), device)?;
|
||||
|
||||
let res_upsample = ResUpSample {
|
||||
res: res,
|
||||
upsample: upsample,
|
||||
};
|
||||
|
||||
Ok(res_upsample)
|
||||
}
|
||||
|
||||
|
||||
pub fn load_res_transformer<B: Backend>(path: &str, device: &B::Device) -> Result<ResTransformer<B>, Box<dyn Error>> {
|
||||
let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?;
|
||||
let transformer = load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?;
|
||||
|
||||
let res_transformer = ResTransformer {
|
||||
res: res,
|
||||
transformer: transformer,
|
||||
};
|
||||
|
||||
Ok(res_transformer)
|
||||
}
|
||||
|
||||
|
||||
pub fn load_unet_input_blocks<B: Backend>(path: &str, device: &B::Device) -> Result<UNetInputBlocks<B>, Box<dyn Error>> {
|
||||
let conv = load_conv2d::<B>(&format!("{}/{}", path, "conv"), device)?;
|
||||
let rt1 = load_res_transformer::<B>(&format!("{}/{}", path, "rt1"), device)?;
|
||||
let rt2 = load_res_transformer::<B>(&format!("{}/{}", path, "rt2"), device)?;
|
||||
let d1 = load_downsample::<B>(&format!("{}/{}", path, "d1"), device)?;
|
||||
let rt3 = load_res_transformer::<B>(&format!("{}/{}", path, "rt3"), device)?;
|
||||
let rt4 = load_res_transformer::<B>(&format!("{}/{}", path, "rt4"), device)?;
|
||||
let d2 = load_downsample::<B>(&format!("{}/{}", path, "d2"), device)?;
|
||||
let rt5 = load_res_transformer::<B>(&format!("{}/{}", path, "rt5"), device)?;
|
||||
let rt6 = load_res_transformer::<B>(&format!("{}/{}", path, "rt6"), device)?;
|
||||
let d3 = load_downsample::<B>(&format!("{}/{}", path, "d3"), device)?;
|
||||
let r1 = load_res_block::<B>(&format!("{}/{}", path, "r1"), device)?;
|
||||
let r2 = load_res_block::<B>(&format!("{}/{}", path, "r2"), device)?;
|
||||
|
||||
let unet_input_blocks = UNetInputBlocks {
|
||||
conv: conv,
|
||||
rt1: rt1,
|
||||
rt2: rt2,
|
||||
d1: d1,
|
||||
rt3: rt3,
|
||||
rt4: rt4,
|
||||
d2: d2,
|
||||
rt5: rt5,
|
||||
rt6: rt6,
|
||||
d3: d3,
|
||||
r1: r1,
|
||||
r2: r2,
|
||||
};
|
||||
|
||||
Ok(unet_input_blocks)
|
||||
}
|
||||
|
||||
pub fn load_unet_output_blocks<B: Backend>(path: &str, device: &B::Device) -> Result<UNetOutputBlocks<B>, Box<dyn Error>> {
|
||||
let r1 = load_res_block::<B>(&format!("{}/{}", path, "r1"), device)?;
|
||||
let r2 = load_res_block::<B>(&format!("{}/{}", path, "r2"), device)?;
|
||||
let ru = load_res_upsample::<B>(&format!("{}/{}", path, "ru"), device)?;
|
||||
let rt1 = load_res_transformer::<B>(&format!("{}/{}", path, "rt1"), device)?;
|
||||
let rt2 = load_res_transformer::<B>(&format!("{}/{}", path, "rt2"), device)?;
|
||||
let rtu1 = load_res_transformer_upsample::<B>(&format!("{}/{}", path, "rtu1"), device)?;
|
||||
let rt3 = load_res_transformer::<B>(&format!("{}/{}", path, "rt3"), device)?;
|
||||
let rt4 = load_res_transformer::<B>(&format!("{}/{}", path, "rt4"), device)?;
|
||||
let rtu2 = load_res_transformer_upsample::<B>(&format!("{}/{}", path, "rtu2"), device)?;
|
||||
let rt5 = load_res_transformer::<B>(&format!("{}/{}", path, "rt5"), device)?;
|
||||
let rt6 = load_res_transformer::<B>(&format!("{}/{}", path, "rt6"), device)?;
|
||||
let rt7 = load_res_transformer::<B>(&format!("{}/{}", path, "rt7"), device)?;
|
||||
|
||||
Ok(UNetOutputBlocks {
|
||||
r1,
|
||||
r2,
|
||||
ru,
|
||||
rt1,
|
||||
rt2,
|
||||
rtu1,
|
||||
rt3,
|
||||
rt4,
|
||||
rtu2,
|
||||
rt5,
|
||||
rt6,
|
||||
rt7,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
pub fn load_unet<B: Backend>(path: &str, device: &B::Device) -> Result<UNet<B>, Box<dyn Error>> {
|
||||
let lin1_time_embed = load_linear::<B>(&format!("{}/{}", path, "lin1_time_embed"), device)?;
|
||||
let silu_time_embed = SILU::new(); // Assuming SILU::new() initializes a new SILU struct
|
||||
let lin2_time_embed = load_linear::<B>(&format!("{}/{}", path, "lin2_time_embed"), device)?;
|
||||
let input_blocks = load_unet_input_blocks::<B>(&format!("{}/{}", path, "input_blocks"), device)?;
|
||||
let middle_block = load_res_transformer_res::<B>(&format!("{}/{}", path, "middle_block"), device)?;
|
||||
let output_blocks = load_unet_output_blocks::<B>(&format!("{}/{}", path, "output_blocks"), device)?;
|
||||
let norm_out = load_group_norm::<B>(&format!("{}/{}", path, "norm_out"), device)?;
|
||||
let silu_out = SILU::new(); // Assuming SILU::new() initializes a new SILU struct
|
||||
let conv_out = load_conv2d::<B>(&format!("{}/{}", path, "conv_out"), device)?;
|
||||
|
||||
Ok(UNet {
|
||||
lin1_time_embed,
|
||||
silu_time_embed,
|
||||
lin2_time_embed,
|
||||
input_blocks,
|
||||
middle_block,
|
||||
output_blocks,
|
||||
norm_out,
|
||||
silu_out,
|
||||
conv_out,
|
||||
})
|
||||
}
|
||||
757
src/model/unet/mod.rs
Normal file
757
src/model/unet/mod.rs
Normal file
@@ -0,0 +1,757 @@
|
||||
pub mod load;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn::{self, PaddingConfig2d, GELU, conv::{Conv2d, Conv2dConfig}},
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
activation::softmax,
|
||||
module::embedding,
|
||||
Tensor,
|
||||
Distribution,
|
||||
Int,
|
||||
},
|
||||
};
|
||||
|
||||
use super::silu::*;
|
||||
use super::groupnorm::*;
|
||||
use crate::helper::to_float;
|
||||
|
||||
use super::attention::qkv_attention;
|
||||
|
||||
|
||||
fn timestep_embedding<B: Backend>(timesteps: Tensor<B, 1, Int>, dim: usize, max_period: usize) -> Tensor<B, 2> {
|
||||
let half = dim / 2;
|
||||
let freqs = ( to_float(Tensor::arange_device(0..half, ×teps.device())) * (-(max_period as f64).ln() / half as f64 ) ).exp();
|
||||
let args = to_float(timesteps) * freqs;
|
||||
Tensor::cat(vec![args.clone().cos(), args.sin()], 0).unsqueeze()
|
||||
}
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct UNetConfig {}
|
||||
|
||||
impl UNetConfig {
|
||||
pub fn init<B: Backend>(&self) -> UNet<B> {
|
||||
let lin1_time_embed = nn::LinearConfig::new(320, 1280).init();
|
||||
let silu_time_embed = SILU::new();
|
||||
let lin2_time_embed = nn::LinearConfig::new(1280, 1280).init();
|
||||
|
||||
let input_blocks = UNetInputBlocks {
|
||||
conv: Conv2dConfig::new([4, 320], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(),
|
||||
rt1: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(),
|
||||
rt2: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(),
|
||||
d1: DownsampleConfig::new(320).init(),
|
||||
rt3: ResTransformerConfig::new(320, 1280, 640, 768, 8).init(),
|
||||
rt4: ResTransformerConfig::new(640, 1280, 640, 768, 8).init(),
|
||||
d2: DownsampleConfig::new(640).init(),
|
||||
rt5: ResTransformerConfig::new(640, 1280, 1280, 768, 8).init(),
|
||||
rt6: ResTransformerConfig::new(1280, 1280, 1280, 768, 8).init(),
|
||||
d3: DownsampleConfig::new(1280).init(),
|
||||
r1: ResBlockConfig::new(1280, 1280, 1280).init(),
|
||||
r2: ResBlockConfig::new(1280, 1280, 1280).init(),
|
||||
};
|
||||
|
||||
let middle_block = ResTransformerResConfig::new(1280, 1280, 1280, 768, 8).init();
|
||||
|
||||
let output_blocks = UNetOutputBlocks {
|
||||
r1: ResBlockConfig::new(2560, 1280, 1280).init(),
|
||||
r2: ResBlockConfig::new(2560, 1280, 1280).init(),
|
||||
ru: ResUpSampleConfig::new(2560, 1280, 1280).init(),
|
||||
rt1: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(),
|
||||
rt2: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(),
|
||||
rtu1: ResTransformerUpsampleConfig::new(1920, 1280, 1280, 768, 8).init(),
|
||||
rt3: ResTransformerConfig::new(1920, 1280, 640, 768, 8).init(),
|
||||
rt4: ResTransformerConfig::new(1280, 1280, 640, 768, 8).init(),
|
||||
rtu2: ResTransformerUpsampleConfig::new(960, 1280, 640, 768, 8).init(),
|
||||
rt5: ResTransformerConfig::new(960, 1280, 320, 768, 8).init(),
|
||||
rt6: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(),
|
||||
rt7: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(),
|
||||
};
|
||||
|
||||
let norm_out = GroupNormConfig::new(32, 320).init();
|
||||
let silu_out = SILU::new();
|
||||
let conv_out = Conv2dConfig::new([320, 4], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
|
||||
UNet {
|
||||
lin1_time_embed,
|
||||
silu_time_embed,
|
||||
lin2_time_embed,
|
||||
input_blocks,
|
||||
middle_block,
|
||||
output_blocks,
|
||||
norm_out,
|
||||
silu_out,
|
||||
conv_out,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct UNet<B: Backend> {
|
||||
lin1_time_embed: nn::Linear<B>,
|
||||
silu_time_embed: SILU,
|
||||
lin2_time_embed: nn::Linear<B>,
|
||||
input_blocks: UNetInputBlocks<B>,
|
||||
middle_block: ResTransformerRes<B>,
|
||||
output_blocks: UNetOutputBlocks<B>,
|
||||
norm_out: GroupNorm<B>,
|
||||
silu_out: SILU,
|
||||
conv_out: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> UNet<B> {
|
||||
pub fn forward(&self, x: Tensor<B, 4>, timesteps: Tensor<B, 1, Int>, context: Tensor<B, 3>) -> Tensor<B, 4> {
|
||||
let t_emb = timestep_embedding(timesteps, 320, 10000);
|
||||
let emb = self.lin1_time_embed.forward(t_emb);
|
||||
let emb = self.silu_time_embed.forward(emb);
|
||||
let emb = self.lin2_time_embed.forward(emb);
|
||||
|
||||
let mut saved_inputs = Vec::new();
|
||||
let mut x = x;
|
||||
|
||||
// input blocks
|
||||
for block in self.input_blocks.as_array() {
|
||||
println!("{:?}", x.clone().flatten::<1>(0, 3).slice([0..100]).into_data());
|
||||
x = block.forward(x, emb.clone(), context.clone());
|
||||
saved_inputs.push(x.clone())
|
||||
}
|
||||
|
||||
// middle block
|
||||
x = self.middle_block.forward(x, emb.clone(), context.clone());
|
||||
|
||||
// output blocks
|
||||
for block in self.output_blocks.as_array() {
|
||||
x = Tensor::cat(vec![x, saved_inputs.pop().unwrap()], 1);
|
||||
x = block.forward(x, emb.clone(), context.clone());
|
||||
}
|
||||
|
||||
let x = self.norm_out.forward(x);
|
||||
let x = self.silu_out.forward(x);
|
||||
let x = self.conv_out.forward(x);
|
||||
x
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct UNetInputBlocks<B: Backend> {
|
||||
conv: Conv2d<B>,
|
||||
rt1: ResTransformer<B>,
|
||||
rt2: ResTransformer<B>,
|
||||
d1: Downsample<B>,
|
||||
rt3: ResTransformer<B>,
|
||||
rt4: ResTransformer<B>,
|
||||
d2: Downsample<B>,
|
||||
rt5: ResTransformer<B>,
|
||||
rt6: ResTransformer<B>,
|
||||
d3: Downsample<B>,
|
||||
r1: ResBlock<B>,
|
||||
r2: ResBlock<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> UNetInputBlocks<B> {
|
||||
fn as_array(&self) -> [&dyn UNetBlock<B>; 12] {
|
||||
[
|
||||
&self.conv,
|
||||
&self.rt1,
|
||||
&self.rt2,
|
||||
&self.d1,
|
||||
&self.rt3,
|
||||
&self.rt4,
|
||||
&self.d2,
|
||||
&self.rt5,
|
||||
&self.rt6,
|
||||
&self.d3,
|
||||
&self.r1,
|
||||
&self.r2,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct UNetOutputBlocks<B: Backend> {
|
||||
r1: ResBlock<B>,
|
||||
r2: ResBlock<B>,
|
||||
ru: ResUpSample<B>,
|
||||
rt1: ResTransformer<B>,
|
||||
rt2: ResTransformer<B>,
|
||||
rtu1: ResTransformerUpsample<B>,
|
||||
rt3: ResTransformer<B>,
|
||||
rt4: ResTransformer<B>,
|
||||
rtu2: ResTransformerUpsample<B>,
|
||||
rt5: ResTransformer<B>,
|
||||
rt6: ResTransformer<B>,
|
||||
rt7: ResTransformer<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> UNetOutputBlocks<B> {
|
||||
fn as_array(&self) -> [&dyn UNetBlock<B>; 12] {
|
||||
[
|
||||
&self.r1,
|
||||
&self.r2,
|
||||
&self.ru,
|
||||
&self.rt1,
|
||||
&self.rt2,
|
||||
&self.rtu1,
|
||||
&self.rt3,
|
||||
&self.rt4,
|
||||
&self.rtu2,
|
||||
&self.rt5,
|
||||
&self.rt6,
|
||||
&self.rt7,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
trait UNetBlock<B: Backend> {
|
||||
fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4>;
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct ResTransformerConfig {
|
||||
n_channels_in: usize,
|
||||
n_channels_embed: usize,
|
||||
n_channels_out: usize,
|
||||
n_context_state: usize,
|
||||
n_head: usize,
|
||||
}
|
||||
|
||||
impl ResTransformerConfig {
|
||||
fn init<B: Backend>(&self) -> ResTransformer<B> {
|
||||
let res = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init();
|
||||
let transformer = SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head).init();
|
||||
|
||||
ResTransformer {
|
||||
res,
|
||||
transformer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ResTransformer<B: Backend> {
|
||||
res: ResBlock<B>,
|
||||
transformer: SpatialTransformer<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> UNetBlock<B> for ResTransformer<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4> {
|
||||
let x = self.res.forward(x, emb);
|
||||
let x = self.transformer.forward(x, context);
|
||||
x
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct ResUpSampleConfig {
|
||||
n_channels_in: usize,
|
||||
n_channels_embed: usize,
|
||||
n_channels_out: usize,
|
||||
}
|
||||
|
||||
impl ResUpSampleConfig {
|
||||
fn init<B: Backend>(&self) -> ResUpSample<B> {
|
||||
let res = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init();
|
||||
let upsample = UpsampleConfig::new(self.n_channels_out).init();
|
||||
|
||||
ResUpSample {
|
||||
res,
|
||||
upsample,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ResUpSample<B: Backend> {
|
||||
res: ResBlock<B>,
|
||||
upsample: Upsample<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> UNetBlock<B> for ResUpSample<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4> {
|
||||
let x = self.res.forward(x, emb);
|
||||
let x = self.upsample.forward(x);
|
||||
x
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct ResTransformerUpsampleConfig {
|
||||
n_channels_in: usize,
|
||||
n_channels_embed: usize,
|
||||
n_channels_out: usize,
|
||||
n_context_state: usize,
|
||||
n_head: usize,
|
||||
}
|
||||
|
||||
impl ResTransformerUpsampleConfig {
|
||||
fn init<B: Backend>(&self) -> ResTransformerUpsample<B> {
|
||||
let res = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init();
|
||||
let transformer = SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head).init();
|
||||
let upsample = UpsampleConfig::new(self.n_channels_out).init();
|
||||
|
||||
ResTransformerUpsample {
|
||||
res,
|
||||
transformer,
|
||||
upsample,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ResTransformerUpsample<B: Backend> {
|
||||
res: ResBlock<B>,
|
||||
transformer: SpatialTransformer<B>,
|
||||
upsample: Upsample<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> UNetBlock<B> for ResTransformerUpsample<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4> {
|
||||
let x = self.res.forward(x, emb);
|
||||
let x = self.transformer.forward(x, context);
|
||||
let x = self.upsample.forward(x);
|
||||
x
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct ResTransformerResConfig {
|
||||
n_channels_in: usize,
|
||||
n_channels_embed: usize,
|
||||
n_channels_out: usize,
|
||||
n_context_state: usize,
|
||||
n_head: usize,
|
||||
}
|
||||
|
||||
impl ResTransformerResConfig {
|
||||
fn init<B: Backend>(&self) -> ResTransformerRes<B> {
|
||||
let res1 = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init();
|
||||
let transformer = SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head).init();
|
||||
let res2 = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init();
|
||||
|
||||
ResTransformerRes {
|
||||
res1,
|
||||
transformer,
|
||||
res2,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ResTransformerRes<B: Backend> {
|
||||
res1: ResBlock<B>,
|
||||
transformer: SpatialTransformer<B>,
|
||||
res2: ResBlock<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> UNetBlock<B> for ResTransformerRes<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4> {
|
||||
let x = self.res1.forward(x, emb.clone());
|
||||
let x = self.transformer.forward(x, context);
|
||||
let x = self.res2.forward(x, emb);
|
||||
x
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct UpsampleConfig {
|
||||
n_channels: usize,
|
||||
}
|
||||
|
||||
impl UpsampleConfig {
|
||||
fn init<B: Backend>(&self) -> Upsample<B> {
|
||||
let conv = Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3])
|
||||
.with_stride([2, 2])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init();
|
||||
|
||||
Upsample {
|
||||
conv,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Upsample<B: Backend> {
|
||||
conv: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Upsample<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let [n_batch, n_channel, height, width] = x.dims();
|
||||
let x = x
|
||||
.reshape([n_batch, n_channel, height, 1, width, 1])
|
||||
.repeat(3, 2)
|
||||
.repeat(5, 2)
|
||||
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
|
||||
self.conv.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> UNetBlock<B> for Upsample<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4> {
|
||||
self.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct DownsampleConfig {
|
||||
n_channels: usize,
|
||||
}
|
||||
|
||||
impl DownsampleConfig {
|
||||
fn init<B: Backend>(&self) -> Conv2d<B> {
|
||||
Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3])
|
||||
.with_stride([2, 2])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init()
|
||||
}
|
||||
}
|
||||
|
||||
type Downsample<B> = Conv2d<B>;
|
||||
|
||||
impl<B: Backend> UNetBlock<B> for Conv2d<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4> {
|
||||
self.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct SpatialTransformerConfig {
|
||||
n_channels: usize,
|
||||
n_context_state: usize,
|
||||
n_head: usize,
|
||||
}
|
||||
|
||||
impl SpatialTransformerConfig {
|
||||
fn init<B: Backend>(&self) -> SpatialTransformer<B> {
|
||||
let norm = GroupNormConfig::new(32, self.n_channels).init();
|
||||
let proj_in = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init();
|
||||
let transformer = TransformerBlockConfig::new(self.n_channels, self.n_context_state, self.n_head).init();
|
||||
let proj_out = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init();
|
||||
|
||||
SpatialTransformer {
|
||||
norm,
|
||||
proj_in,
|
||||
transformer,
|
||||
proj_out,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct SpatialTransformer<B: Backend> {
|
||||
norm: GroupNorm<B>,
|
||||
proj_in: Conv2d<B>,
|
||||
transformer: TransformerBlock<B>,
|
||||
proj_out: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> SpatialTransformer<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>, context: Tensor<B, 3>) -> Tensor<B, 4> {
|
||||
let [n_batch, n_channel, height, width] = x.dims();
|
||||
|
||||
let x_in = x.clone();
|
||||
|
||||
let x = self.norm.forward(x);
|
||||
let x = self.proj_in.forward(x);
|
||||
let x = x.reshape([n_batch, n_channel, height * width]).swap_dims(1, 2);
|
||||
|
||||
let x = self.transformer.forward(x, context)
|
||||
.swap_dims(1, 2)
|
||||
.reshape([n_batch, n_channel, height, width]);
|
||||
|
||||
x_in + self.proj_out.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct TransformerBlockConfig {
|
||||
n_state: usize,
|
||||
n_context_state: usize,
|
||||
n_head: usize,
|
||||
}
|
||||
|
||||
impl TransformerBlockConfig {
|
||||
fn init<B: Backend>(&self) -> TransformerBlock<B> {
|
||||
let norm1 = nn::LayerNormConfig::new(self.n_state).init();
|
||||
let attn1 = MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init();
|
||||
let norm2 = nn::LayerNormConfig::new(self.n_state).init();
|
||||
let attn2 = MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init();
|
||||
let norm3 = nn::LayerNormConfig::new(self.n_state).init();
|
||||
let mlp = MLPConfig::new(self.n_state, 4).init();
|
||||
|
||||
TransformerBlock {
|
||||
norm1,
|
||||
attn1,
|
||||
norm2,
|
||||
attn2,
|
||||
norm3,
|
||||
mlp,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct TransformerBlock<B: Backend> {
|
||||
norm1: nn::LayerNorm<B>,
|
||||
attn1: MultiHeadAttention<B>,
|
||||
norm2: nn::LayerNorm<B>,
|
||||
attn2: MultiHeadAttention<B>,
|
||||
norm3: nn::LayerNorm<B>,
|
||||
mlp: MLP<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> TransformerBlock<B> {
|
||||
fn forward(&self, x: Tensor<B, 3>, context: Tensor<B, 3>) -> Tensor<B, 3> {
|
||||
let x = x.clone() + self.attn1.forward( self.norm1.forward(x), None);
|
||||
let x = x.clone() + self.attn2.forward( self.norm2.forward(x), Some(context));
|
||||
x.clone() + self.mlp.forward( self.norm3.forward(x) )
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct MLPConfig {
|
||||
n_state: usize,
|
||||
mult: usize,
|
||||
}
|
||||
|
||||
impl MLPConfig {
|
||||
pub fn init<B: Backend>(&self) -> MLP<B> {
|
||||
let n_state_hidden = self.n_state * self.mult;
|
||||
let geglu = GEGLUConfig::new(self.n_state, n_state_hidden).init();
|
||||
let lin = nn::LinearConfig::new(n_state_hidden, self.n_state).init();
|
||||
|
||||
MLP {
|
||||
geglu,
|
||||
lin,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct MLP<B: Backend> {
|
||||
geglu: GEGLU<B>,
|
||||
lin: nn::Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> MLP<B> {
|
||||
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
|
||||
self.lin.forward( self.geglu.forward(x) )
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct GEGLUConfig {
|
||||
n_state_in: usize,
|
||||
n_state_out: usize,
|
||||
}
|
||||
|
||||
impl GEGLUConfig {
|
||||
fn init<B: Backend>(&self) -> GEGLU<B> {
|
||||
let proj = nn::LinearConfig::new(self.n_state_in, 2 * self.n_state_out).init();
|
||||
let gelu = GELU::new();
|
||||
|
||||
GEGLU {
|
||||
proj,
|
||||
gelu,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct GEGLU<B: Backend> {
|
||||
proj: nn::Linear<B>,
|
||||
gelu: GELU,
|
||||
}
|
||||
|
||||
impl<B: Backend> GEGLU<B> {
|
||||
fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
|
||||
let projected = self.proj.forward(x);
|
||||
let [n_batch, n_ctx, n_state] = projected.dims();
|
||||
|
||||
let n_state_out = n_state / 2;
|
||||
|
||||
let x = projected.clone().slice([0..n_batch, 0..n_ctx, 0..n_state_out]);
|
||||
let gate = projected.slice([0..n_batch, 0..n_ctx, n_state_out..n_state]);
|
||||
|
||||
x * self.gelu.forward(gate)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct MultiHeadAttentionConfig {
|
||||
n_state: usize,
|
||||
n_context_state: usize,
|
||||
n_head: usize,
|
||||
}
|
||||
|
||||
impl MultiHeadAttentionConfig {
|
||||
fn init<B: Backend>(&self) -> MultiHeadAttention<B> {
|
||||
assert!(self.n_state % self.n_head == 0, "State size {} must be a multiple of head size {}", self.n_state, self.n_head);
|
||||
|
||||
let n_head = self.n_head;
|
||||
let query = nn::LinearConfig::new(self.n_state, self.n_state).with_bias(false).init();
|
||||
let key = nn::LinearConfig::new(self.n_context_state, self.n_state).with_bias(false).init();
|
||||
let value = nn::LinearConfig::new(self.n_context_state, self.n_state).with_bias(false).init();
|
||||
let out = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
||||
|
||||
MultiHeadAttention {
|
||||
n_head,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct MultiHeadAttention<B: Backend> {
|
||||
n_head: usize,
|
||||
query: nn::Linear<B>,
|
||||
key: nn::Linear<B>,
|
||||
value: nn::Linear<B>,
|
||||
out: nn::Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> MultiHeadAttention<B> {
|
||||
pub fn forward(&self, x: Tensor<B, 3>, context: Option<Tensor<B, 3>>) -> Tensor<B, 3> {
|
||||
let xa = context.unwrap_or_else(|| x.clone());
|
||||
|
||||
let q = self.query.forward(x);
|
||||
let k = self.key.forward(xa.clone());
|
||||
let v = self.value.forward(xa);
|
||||
|
||||
let wv = qkv_attention(q, k, v, None, self.n_head);
|
||||
|
||||
self.out.forward(wv)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct ResBlockConfig {
|
||||
n_channels_in: usize,
|
||||
n_channels_embed: usize,
|
||||
n_channels_out: usize,
|
||||
}
|
||||
|
||||
|
||||
impl ResBlockConfig {
|
||||
fn init<B: Backend>(&self) -> ResBlock<B> {
|
||||
let norm_in = GroupNormConfig::new(32, self.n_channels_in).init();
|
||||
let silu_in = SILU::new();
|
||||
let conv_in = Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
|
||||
let silu_embed = SILU::new();
|
||||
let lin_embed = nn::LinearConfig::new(self.n_channels_embed, self.n_channels_out).init();
|
||||
|
||||
let norm_out = GroupNormConfig::new(32, self.n_channels_out).init();
|
||||
let silu_out = SILU::new();
|
||||
let conv_out = Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
|
||||
let skip_connection = if self.n_channels_in != self.n_channels_out {
|
||||
Some( Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [1, 1]).init() )
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
ResBlock {
|
||||
norm_in,
|
||||
silu_in,
|
||||
conv_in,
|
||||
silu_embed,
|
||||
lin_embed,
|
||||
norm_out,
|
||||
silu_out,
|
||||
conv_out,
|
||||
skip_connection,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ResBlock<B: Backend> {
|
||||
norm_in: GroupNorm<B>,
|
||||
silu_in: SILU,
|
||||
conv_in: Conv2d<B>,
|
||||
silu_embed: SILU,
|
||||
lin_embed: nn::Linear<B>,
|
||||
norm_out: GroupNorm<B>,
|
||||
silu_out: SILU,
|
||||
conv_out: Conv2d<B>,
|
||||
skip_connection: Option<Conv2d<B>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ResBlock<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>, embed: Tensor<B, 2>) -> Tensor<B, 4> {
|
||||
let h = self.norm_in.forward(x.clone());
|
||||
let h = self.silu_in.forward(h);
|
||||
let h = self.conv_in.forward(h);
|
||||
|
||||
let embed_out = self.silu_embed.forward(embed);
|
||||
let embed_out = self.lin_embed.forward(embed_out);
|
||||
|
||||
let [n_batch_embed, n_state_embed] = embed_out.dims();
|
||||
let h = h + embed_out.reshape([n_batch_embed, n_state_embed, 1, 1]);
|
||||
|
||||
let h = self.norm_out.forward(h);
|
||||
let h = self.silu_out.forward(h);
|
||||
let h = self.conv_out.forward(h);
|
||||
|
||||
if let Some(skipc) = self.skip_connection.as_ref() {
|
||||
skipc.forward(x) + h
|
||||
} else {
|
||||
x + h
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> UNetBlock<B> for ResBlock<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4> {
|
||||
self.forward(x, emb)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
54
src/token.rs
Normal file
54
src/token.rs
Normal file
@@ -0,0 +1,54 @@
|
||||
use std::result;
|
||||
use rust_tokenizers::{error::TokenizerError, tokenizer::{Tokenizer, SentencePieceBpeTokenizer, TruncationStrategy}, vocab::Vocab};
|
||||
|
||||
const BOS_TOKEN_ID: i64 = 1;
|
||||
const EOS_TOKEN_ID: i64 = 2;
|
||||
|
||||
pub type Result<T> = result::Result<T, TokenizerError>;
|
||||
|
||||
pub struct LlamaTokenizer {
|
||||
spm: SentencePieceBpeTokenizer,
|
||||
}
|
||||
|
||||
impl LlamaTokenizer {
|
||||
pub fn new(tokenizer_path: &str) -> Result<Self> {
|
||||
let lower_case = false;
|
||||
SentencePieceBpeTokenizer::from_file(tokenizer_path, lower_case)
|
||||
.map(|spm| Self { spm } )
|
||||
}
|
||||
|
||||
pub fn encode(&self, text: &str, include_bos: bool, include_eos: bool) -> Vec<i64> {
|
||||
let pre = if include_bos {
|
||||
vec![BOS_TOKEN_ID]
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
let post = if include_eos {
|
||||
vec![EOS_TOKEN_ID]
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
let token_ids = self.spm.encode(text, None, std::usize::MAX, &TruncationStrategy::LongestFirst, 0).token_ids;
|
||||
|
||||
[pre, token_ids, post]
|
||||
.into_iter()
|
||||
.flat_map(|v| v.into_iter())
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn decode(&self, tokens: &[i64], skip_special_tokens: bool) -> String {
|
||||
let clean_spaces = false;
|
||||
self.spm.decode(tokens, skip_special_tokens, clean_spaces)
|
||||
}
|
||||
|
||||
pub fn vocab_size(&self, include_special_tokens: bool) -> usize {
|
||||
let vocab = self.spm.vocab();
|
||||
if include_special_tokens {
|
||||
vocab.values().len() + vocab.special_values().len()
|
||||
} else {
|
||||
vocab.values().len()
|
||||
}
|
||||
}
|
||||
}
|
||||
215
src/tokenizer.rs
Normal file
215
src/tokenizer.rs
Normal file
@@ -0,0 +1,215 @@
|
||||
use std::collections::HashMap;
|
||||
use regex::Regex;
|
||||
|
||||
use std::fs::File;
|
||||
use std::io::{self, BufRead};
|
||||
|
||||
fn bytes_to_unicode() -> Vec<(u8, char)> {
|
||||
let mut bs: Vec<u8> = ('!' as u8 ..= '~' as u8).into_iter()
|
||||
.chain( ('¡' as u8..='¬' as u8).into_iter() )
|
||||
.chain( ('®' as u8..='ÿ' as u8).into_iter() )
|
||||
.collect();
|
||||
|
||||
let mut cs: Vec<_> = bs.iter().cloned().map(char::from).collect();
|
||||
|
||||
let mut n = 0;
|
||||
for b in 0u8..=255u8 {
|
||||
if !bs.contains(&b) {
|
||||
bs.push(b);
|
||||
cs.push( char::from_u32(256 + n).unwrap() );
|
||||
n += 1;
|
||||
}
|
||||
}
|
||||
|
||||
bs.into_iter()
|
||||
.zip(
|
||||
cs.into_iter()
|
||||
.map(|c| c.into())
|
||||
).collect()
|
||||
}
|
||||
|
||||
fn get_pairs(word: &[String]) -> Vec<(String, String)> {
|
||||
let prev = word.into_iter().cloned();
|
||||
let next = prev.clone().skip(1);
|
||||
|
||||
prev
|
||||
.zip(next)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn whitespace_clean(text: &str) -> String {
|
||||
text.split_whitespace().collect::<Vec<&str>>().join(" ")
|
||||
}
|
||||
|
||||
fn load_merges(path: &str) -> io::Result<Vec<(String, String)>> {
|
||||
let file = File::open(&path)?;
|
||||
let reader = io::BufReader::new(file);
|
||||
|
||||
let mut merges = Vec::new();
|
||||
|
||||
for line in reader.lines() {
|
||||
let line = line?;
|
||||
let mut words = line.split_whitespace();
|
||||
|
||||
if let (Some(word1), Some(word2)) = (words.next(), words.next()) {
|
||||
merges.push((word1.into(), word2.into()));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(merges)
|
||||
}
|
||||
|
||||
fn construct_vocab(chars: impl Iterator<Item=char> + Clone, merges: &[(String, String)]) -> Vec<String> {
|
||||
let iter = chars.map(String::from);
|
||||
let mut vocab: Vec<_> = iter.clone().chain( iter.map(|c| c + "</w>") ).collect();
|
||||
|
||||
for merge in merges {
|
||||
vocab.push(format!("{}{}", merge.0, merge.1));
|
||||
}
|
||||
|
||||
vocab.extend(["<|startoftext|>".to_string(), "<|endoftext|>".to_string()]);
|
||||
|
||||
return vocab;
|
||||
}
|
||||
|
||||
pub struct SimpleTokenizer {
|
||||
byte_encoder: HashMap<u8, char>,
|
||||
byte_decoder: HashMap<char, u8>,
|
||||
encoder: HashMap<String, u32>,
|
||||
decoder: HashMap<u32, String>,
|
||||
bpe_ranks: HashMap<(String, String), u32>,
|
||||
cache: HashMap<String, String>,
|
||||
pat: Regex,
|
||||
}
|
||||
|
||||
impl SimpleTokenizer {
|
||||
pub fn new() -> io::Result<Self> {
|
||||
let byte_unicode_values = bytes_to_unicode();
|
||||
|
||||
let byte_encoder: HashMap<_, _> = byte_unicode_values.iter().cloned().collect();
|
||||
let byte_decoder = byte_encoder.iter().map(|(k,v)| (*v,*k)).collect();
|
||||
|
||||
let merges = load_merges("bpe_simple_vocab_16e6.txt")?;
|
||||
let merges = merges[1..49152-256-2+1].to_vec();
|
||||
|
||||
let vocab = construct_vocab(byte_unicode_values.into_iter().map(|(_, u)| u), &merges[..]);
|
||||
|
||||
let encoder: HashMap<String, u32> = vocab.iter().cloned().zip((0..).into_iter()).collect();
|
||||
let decoder: HashMap<u32, String> = encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
|
||||
let bpe_ranks = merges.iter().cloned().zip((0..).into_iter()).collect();
|
||||
let cache = HashMap::from([
|
||||
("<|startoftext|>".to_string(), "<|startoftext|>".to_string()),
|
||||
("<|endoftext|>".to_string(), "<|endoftext|>".to_string()),
|
||||
]);
|
||||
|
||||
let pat = Regex::new(r"(?i)<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|\p{L}+|\p{N}|[^\s\p{L}\p{N}]+").unwrap();
|
||||
|
||||
Ok( SimpleTokenizer {
|
||||
byte_encoder: byte_encoder,
|
||||
byte_decoder: byte_decoder,
|
||||
encoder: encoder,
|
||||
decoder: decoder,
|
||||
bpe_ranks: bpe_ranks,
|
||||
cache: cache,
|
||||
pat: pat,
|
||||
} )
|
||||
}
|
||||
|
||||
pub fn bpe(&self, token: &str) -> String {
|
||||
if let Some(word) = self.cache.get(token) {
|
||||
return word.clone();
|
||||
}
|
||||
|
||||
let mut word: Vec<String> = token.chars().map(|c| c.to_string()).collect();
|
||||
word.last_mut().map(|w| *w += "</w>");
|
||||
let mut pairs = get_pairs(&word);
|
||||
|
||||
if pairs.is_empty() {
|
||||
return format!("{}{}", token, "</w>");
|
||||
}
|
||||
|
||||
loop {
|
||||
let bigram = pairs.iter()
|
||||
.filter(|pair| self.bpe_ranks.contains_key(pair))
|
||||
.min_by_key(|&pair| self.bpe_ranks[pair]);
|
||||
|
||||
if bigram.is_none() {
|
||||
break;
|
||||
}
|
||||
|
||||
let (first, second) = bigram.unwrap();
|
||||
let mut new_word = Vec::new();
|
||||
let mut i = 0;
|
||||
while i < word.len() {
|
||||
if let Some( (j, _) ) = word.iter().enumerate().skip(i).find(|(_, w)| w == &first) {
|
||||
new_word.extend(word[i..j].iter().cloned());
|
||||
i = j;
|
||||
} else {
|
||||
new_word.extend(word[i..].iter().cloned());
|
||||
break;
|
||||
}
|
||||
|
||||
if &word[i] == first && i < word.len() - 1 && &word[i + 1] == second {
|
||||
new_word.push(format!("{}{}", first, second));
|
||||
i += 2;
|
||||
} else {
|
||||
new_word.push(word[i].clone());
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
word = new_word;
|
||||
if word.len() == 1 {
|
||||
break;
|
||||
} else {
|
||||
pairs = get_pairs(&word[..])
|
||||
}
|
||||
}
|
||||
|
||||
let word = word.join(" ");
|
||||
//self.cache.insert(token.into(), word);
|
||||
return word;
|
||||
}
|
||||
|
||||
pub fn encode(&self, text: &str) -> Vec<u32> {
|
||||
let cleaned_text = whitespace_clean(text.trim()).to_lowercase();
|
||||
|
||||
let mut bpe_tokens: Vec<u32> = Vec::new();
|
||||
|
||||
for m in self.pat.find_iter(&cleaned_text) {
|
||||
let token = m.as_str();
|
||||
let token: String = token.as_bytes().into_iter().map(|b| self.byte_encoder[b]).collect();
|
||||
bpe_tokens.extend(self.bpe(&token).split(' ').map(|bpe_token| self.encoder[bpe_token]))
|
||||
}
|
||||
|
||||
return bpe_tokens;
|
||||
}
|
||||
|
||||
pub fn decode(&self, tokens: &[u32]) -> String {
|
||||
let text: String = tokens.iter().map(|t| self.decoder[t].as_str()).collect();
|
||||
let decoded_bytes: Vec<u8> = text.chars()
|
||||
.map(|c| self.byte_decoder[&c])
|
||||
.collect();
|
||||
|
||||
String::from_utf8_lossy(&decoded_bytes[..]).replace("</w>", " ")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_encode_decode() {
|
||||
let tokenizer = SimpleTokenizer::new().unwrap();
|
||||
|
||||
let text = "Hello world! <|startoftext|>asdf<|startoftext|>";
|
||||
let target_encode = [3306, 1002, 256, 49406, 587, 10468, 49406];
|
||||
let target_decode = "hello world ! <|startoftext|>asdf <|startoftext|>"; // extra spaces sometimes
|
||||
|
||||
let encoded = tokenizer.encode(&text);
|
||||
assert_eq!(&target_encode[..], &encoded[..]);
|
||||
let decoded = tokenizer.decode(&encoded[..]);
|
||||
assert_eq!(target_decode, decoded);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user