From 3c2c43dd93c5291dc7471abfcfe0befbfd3bf69d Mon Sep 17 00:00:00 2001 From: Gadersd Date: Fri, 4 Aug 2023 21:22:47 -0400 Subject: [PATCH] Remove stable_diffusion.py --- python/stable_diffusion.py | 731 ------------------------------------- 1 file changed, 731 deletions(-) delete mode 100644 python/stable_diffusion.py diff --git a/python/stable_diffusion.py b/python/stable_diffusion.py deleted file mode 100644 index 10c6622..0000000 --- a/python/stable_diffusion.py +++ /dev/null @@ -1,731 +0,0 @@ -# https://arxiv.org/pdf/2112.10752.pdf -# https://github.com/ekagra-ranjan/huggingface-blog/blob/main/stable_diffusion.md -import os -import tempfile -from pathlib import Path -import gzip, argparse, math, re -from functools import lru_cache -from collections import namedtuple - -from tqdm import tqdm -from tinygrad.tensor import Tensor -from tinygrad.helpers import dtypes, GlobalCounters -from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding -from extra.utils import download_file -from tinygrad.state import torch_load, load_state_dict - -# TODO: refactor AttnBlock, CrossAttention, CLIPAttention to share code - -class AttnBlock: - def __init__(self, in_channels): - self.norm = GroupNorm(32, in_channels) - self.q = Conv2d(in_channels, in_channels, 1) - self.k = Conv2d(in_channels, in_channels, 1) - self.v = Conv2d(in_channels, in_channels, 1) - self.proj_out = Conv2d(in_channels, in_channels, 1) - - # copied from AttnBlock in ldm repo - def __call__(self, x): - h_ = self.norm(x) - q,k,v = self.q(h_), self.k(h_), self.v(h_) - - # compute attention - b,c,h,w = q.shape - q = q.reshape(b,c,h*w) - q = q.permute(0,2,1) # b,hw,c - k = k.reshape(b,c,h*w) # b,c,hw - w_ = q @ k - w_ = w_ * (c**(-0.5)) - w_ = w_.softmax() - - # attend to values - v = v.reshape(b,c,h*w) - w_ = w_.permute(0,2,1) - h_ = v @ w_ - h_ = h_.reshape(b,c,h,w) - - return x + self.proj_out(h_) - -class ResnetBlock: - def __init__(self, in_channels, out_channels=None): - self.norm1 = GroupNorm(32, in_channels) - self.conv1 = Conv2d(in_channels, out_channels, 3, padding=1) - self.norm2 = GroupNorm(32, out_channels) - self.conv2 = Conv2d(out_channels, out_channels, 3, padding=1) - self.nin_shortcut = Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else lambda x: x - - def __call__(self, x): - h = self.conv1(self.norm1(x).swish()) - h = self.conv2(self.norm2(h).swish()) - return self.nin_shortcut(x) + h - -class Mid: - def __init__(self, block_in): - self.block_1 = ResnetBlock(block_in, block_in) - self.attn_1 = AttnBlock(block_in) - self.block_2 = ResnetBlock(block_in, block_in) - - def __call__(self, x): - return x.sequential([self.block_1, self.attn_1, self.block_2]) - -class Decoder: - def __init__(self): - sz = [(128, 256), (256, 512), (512, 512), (512, 512)] - self.conv_in = Conv2d(4,512,3, padding=1) - self.mid = Mid(512) - - arr = [] - for i,s in enumerate(sz): - arr.append({"block": - [ResnetBlock(s[1], s[0]), - ResnetBlock(s[0], s[0]), - ResnetBlock(s[0], s[0])]}) - if i != 0: arr[-1]['upsample'] = {"conv": Conv2d(s[0], s[0], 3, padding=1)} - self.up = arr - - self.norm_out = GroupNorm(32, 128) - self.conv_out = Conv2d(128, 3, 3, padding=1) - - def __call__(self, x): - x = self.conv_in(x) - x = self.mid(x) - - for l in self.up[::-1]: - print("decode", x.shape) - for b in l['block']: - x = b(x) - if 'upsample' in l: - # https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html ? - bs,c,py,px = x.shape - x = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2) - x = l['upsample']['conv'](x) - x.realize() - - return self.conv_out(self.norm_out(x).swish()) - -class Encoder: - def __init__(self): - sz = [(128, 128), (128, 256), (256, 512), (512, 512)] - self.conv_in = Conv2d(3,128,3, padding=1) - - arr = [] - for i,s in enumerate(sz): - arr.append({"block": - [ResnetBlock(s[0], s[1]), - ResnetBlock(s[1], s[1])]}) - if i != 3: arr[-1]['downsample'] = {"conv": Conv2d(s[1], s[1], 3, stride=2, padding=(0,1,0,1))} - self.down = arr - - self.mid = Mid(512) - self.norm_out = GroupNorm(32, 512) - self.conv_out = Conv2d(512, 8, 3, padding=1) - - def __call__(self, x): - x = self.conv_in(x) - - for i, l in enumerate(self.down): - print("encode", x.shape) - for b in l['block']: x = b(x) - if 'downsample' in l: x = l['downsample']['conv'](x) - - x = self.mid(x) - return self.conv_out(self.norm_out(x).swish()) - -class AutoencoderKL: - def __init__(self): - self.encoder = Encoder() - self.decoder = Decoder() - self.quant_conv = Conv2d(8, 8, 1) - self.post_quant_conv = Conv2d(4, 4, 1) - - def __call__(self, x): - latent = self.encoder(x) - latent = self.quant_conv(latent) - latent = latent[:, 0:4] # only the means - latent = self.post_quant_conv(latent) - return self.decoder(latent) - -# not to be confused with ResnetBlock -class ResBlock: - def __init__(self, channels, emb_channels, out_channels): - self.in_layers = [ - GroupNorm(32, channels), - Tensor.silu, - Conv2d(channels, out_channels, 3, padding=1) - ] - self.emb_layers = [ - Tensor.silu, - Linear(emb_channels, out_channels) - ] - self.out_layers = [ - GroupNorm(32, out_channels), - Tensor.silu, - lambda x: x, # needed for weights loading code to work - Conv2d(out_channels, out_channels, 3, padding=1) - ] - self.skip_connection = Conv2d(channels, out_channels, 1) if channels != out_channels else lambda x: x - - def __call__(self, x, emb): - h = x.sequential(self.in_layers) - emb_out = emb.sequential(self.emb_layers) - h = h + emb_out.reshape(*emb_out.shape, 1, 1) - h = h.sequential(self.out_layers) - ret = self.skip_connection(x) + h - return ret - -class CrossAttention: - def __init__(self, query_dim, context_dim, n_heads, d_head): - self.to_q = Linear(query_dim, n_heads*d_head, bias=False) - self.to_k = Linear(context_dim, n_heads*d_head, bias=False) - self.to_v = Linear(context_dim, n_heads*d_head, bias=False) - self.scale = d_head ** -0.5 - self.num_heads = n_heads - self.head_size = d_head - self.to_out = [Linear(n_heads*d_head, query_dim)] - - def __call__(self, x, context=None): - context = x if context is None else context - q,k,v = self.to_q(x), self.to_k(context), self.to_v(context) - q = q.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,1,3) # (bs, num_heads, time, head_size) - k = k.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,3,1) # (bs, num_heads, head_size, time) - v = v.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,1,3) # (bs, num_heads, time, head_size) - - score = q.dot(k) * self.scale - weights = score.softmax() # (bs, num_heads, time, time) - attention = weights.dot(v).permute(0,2,1,3) # (bs, time, num_heads, head_size) - - h_ = attention.reshape(shape=(x.shape[0], -1, self.num_heads * self.head_size)) - return h_.sequential(self.to_out) - -class GEGLU: - def __init__(self, dim_in, dim_out): - self.proj = Linear(dim_in, dim_out * 2) - self.dim_out = dim_out - - def __call__(self, x): - x, gate = self.proj(x).chunk(2, dim=-1) - return x * gate.gelu() - -class FeedForward: - def __init__(self, dim, mult=4): - self.net = [ - GEGLU(dim, dim*mult), - lambda x: x, # needed for weights loading code to work - Linear(dim*mult, dim) - ] - - def __call__(self, x): - return x.sequential(self.net) - -class BasicTransformerBlock: - def __init__(self, dim, context_dim, n_heads, d_head): - self.attn1 = CrossAttention(dim, dim, n_heads, d_head) - self.ff = FeedForward(dim) - self.attn2 = CrossAttention(dim, context_dim, n_heads, d_head) - self.norm1 = LayerNorm(dim) - self.norm2 = LayerNorm(dim) - self.norm3 = LayerNorm(dim) - - def __call__(self, x, context=None): - x = self.attn1(self.norm1(x)) + x - x = self.attn2(self.norm2(x), context=context) + x - x = self.ff(self.norm3(x)) + x - return x - -class SpatialTransformer: - def __init__(self, channels, context_dim, n_heads, d_head): - self.norm = GroupNorm(32, channels) - assert channels == n_heads * d_head - self.proj_in = Conv2d(channels, n_heads * d_head, 1) - self.transformer_blocks = [BasicTransformerBlock(channels, context_dim, n_heads, d_head)] - self.proj_out = Conv2d(n_heads * d_head, channels, 1) - - def __call__(self, x, context=None): - b, c, h, w = x.shape - x_in = x - x = self.norm(x) - x = self.proj_in(x) - x = x.reshape(b, c, h*w).permute(0,2,1) - for block in self.transformer_blocks: - x = block(x, context=context) - x = x.permute(0,2,1).reshape(b, c, h, w) - ret = self.proj_out(x) + x_in - return ret - -class Downsample: - def __init__(self, channels): - self.op = Conv2d(channels, channels, 3, stride=2, padding=1) - - def __call__(self, x): - return self.op(x) - -class Upsample: - def __init__(self, channels): - self.conv = Conv2d(channels, channels, 3, padding=1) - - def __call__(self, x): - bs,c,py,px = x.shape - x = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2) - return self.conv(x) - -def timestep_embedding(timesteps, dim, max_period=10000): - half = dim // 2 - freqs = (-math.log(max_period) * Tensor.arange(half) / half).exp() - args = timesteps * freqs - return Tensor.cat(args.cos(), args.sin()).reshape(1, -1) - -class UNetModel: - def __init__(self): - self.time_embed = [ - Linear(320, 1280), - Tensor.silu, - Linear(1280, 1280), - ] - self.input_blocks = [ - [Conv2d(4, 320, kernel_size=3, padding=1)], - [ResBlock(320, 1280, 320), SpatialTransformer(320, 768, 8, 40)], - [ResBlock(320, 1280, 320), SpatialTransformer(320, 768, 8, 40)], - [Downsample(320)], - [ResBlock(320, 1280, 640), SpatialTransformer(640, 768, 8, 80)], - [ResBlock(640, 1280, 640), SpatialTransformer(640, 768, 8, 80)], - [Downsample(640)], - [ResBlock(640, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)], - [ResBlock(1280, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)], - [Downsample(1280)], - [ResBlock(1280, 1280, 1280)], - [ResBlock(1280, 1280, 1280)] - ] - self.middle_block = [ - ResBlock(1280, 1280, 1280), - SpatialTransformer(1280, 768, 8, 160), - ResBlock(1280, 1280, 1280) - ] - self.output_blocks = [ - [ResBlock(2560, 1280, 1280)], - [ResBlock(2560, 1280, 1280)], - [ResBlock(2560, 1280, 1280), Upsample(1280)], - [ResBlock(2560, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)], - [ResBlock(2560, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)], - [ResBlock(1920, 1280, 1280), SpatialTransformer(1280, 768, 8, 160), Upsample(1280)], - [ResBlock(1920, 1280, 640), SpatialTransformer(640, 768, 8, 80)], # 6 - [ResBlock(1280, 1280, 640), SpatialTransformer(640, 768, 8, 80)], - [ResBlock(960, 1280, 640), SpatialTransformer(640, 768, 8, 80), Upsample(640)], - [ResBlock(960, 1280, 320), SpatialTransformer(320, 768, 8, 40)], - [ResBlock(640, 1280, 320), SpatialTransformer(320, 768, 8, 40)], - [ResBlock(640, 1280, 320), SpatialTransformer(320, 768, 8, 40)], - ] - self.out = [ - GroupNorm(32, 320), - Tensor.silu, - Conv2d(320, 4, kernel_size=3, padding=1) - ] - - def __call__(self, x, timesteps=None, context=None): - # TODO: real time embedding - t_emb = timestep_embedding(timesteps, 320) - emb = t_emb.sequential(self.time_embed) - - - - def run(x, bb): - if isinstance(bb, ResBlock): x = bb(x, emb) - elif isinstance(bb, SpatialTransformer): x = bb(x, context) - else: x = bb(x) - return x - - saved_inputs = [] - for i,b in enumerate(self.input_blocks): - for bb in b: - x = run(x, bb) - saved_inputs.append(x) - for bb in self.middle_block: - x = run(x, bb) - for i,b in enumerate(self.output_blocks): - x = x.cat(saved_inputs.pop(), dim=1) - for bb in b: - x = run(x, bb) - return x.sequential(self.out) - -class CLIPMLP: - def __init__(self): - self.fc1 = Linear(768, 3072) - self.fc2 = Linear(3072, 768) - - def __call__(self, hidden_states): - hidden_states = self.fc1(hidden_states) - hidden_states = hidden_states.quick_gelu() - hidden_states = self.fc2(hidden_states) - return hidden_states - -class CLIPAttention: - def __init__(self): - self.embed_dim = 768 - self.num_heads = 12 - self.head_dim = self.embed_dim // self.num_heads - self.scale = self.head_dim**-0.5 - self.k_proj = Linear(self.embed_dim, self.embed_dim) - self.v_proj = Linear(self.embed_dim, self.embed_dim) - self.q_proj = Linear(self.embed_dim, self.embed_dim) - self.out_proj = Linear(self.embed_dim, self.embed_dim) - - def _shape(self, tensor, seq_len: int, bsz: int): - return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).permute(0,2,1,3) - - def __call__(self, hidden_states, causal_attention_mask): - bsz, tgt_len, embed_dim = hidden_states.shape - - query_states = self.q_proj(hidden_states) * self.scale - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).reshape(*proj_shape) - key_states = key_states.reshape(*proj_shape) - src_len = key_states.shape[1] - value_states = value_states.reshape(*proj_shape) - - attn_weights = query_states @ key_states.permute(0,2,1) - - attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask - attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = attn_weights.softmax() - - attn_output = attn_weights @ value_states - - attn_output = attn_output.reshape(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.permute(0,2,1,3) - attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) - - attn_output = self.out_proj(attn_output) - return attn_output - -class CLIPEncoderLayer: - def __init__(self): - self.self_attn = CLIPAttention() - self.layer_norm1 = LayerNorm(768) - self.mlp = CLIPMLP() - self.layer_norm2 = LayerNorm(768) - - def __call__(self, hidden_states, causal_attention_mask): - residual = hidden_states - hidden_states = self.layer_norm1(hidden_states) - hidden_states = self.self_attn(hidden_states, causal_attention_mask) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.layer_norm2(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - return hidden_states - -class CLIPEncoder: - def __init__(self): - self.layers = [CLIPEncoderLayer() for i in range(12)] - - def __call__(self, hidden_states, causal_attention_mask): - for l in self.layers: - hidden_states = l(hidden_states, causal_attention_mask) - return hidden_states - -class CLIPTextEmbeddings: - def __init__(self): - self.token_embedding = Embedding(49408, 768) - self.position_embedding = Embedding(77, 768) - - def __call__(self, input_ids, position_ids): - return self.token_embedding(input_ids) + self.position_embedding(position_ids) - -class CLIPTextTransformer: - def __init__(self): - self.embeddings = CLIPTextEmbeddings() - self.encoder = CLIPEncoder() - self.final_layer_norm = LayerNorm(768) - - def __call__(self, input_ids): - seq_len = input_ids.shape[1] - x = self.embeddings(input_ids, Tensor.arange(seq_len).reshape(1, -1)) - mask = Tensor.full((1, 1, seq_len, seq_len), float("-inf")).triu(1) - x = self.encoder(x, mask) - return self.final_layer_norm(x) - -# Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license) -@lru_cache() -def default_bpe(): - return Path(__file__).parent.parent / "weights/bpe_simple_vocab_16e6.txt.gz" - -def get_pairs(word): - """Return set of symbol pairs in a word. - Word is represented as tuple of symbols (symbols being variable-length strings). - """ - pairs = set() - prev_char = word[0] - for char in word[1:]: - pairs.add((prev_char, char)) - prev_char = char - return pairs - -def whitespace_clean(text): - text = re.sub(r'\s+', ' ', text) - text = text.strip() - return text - -def bytes_to_unicode(): - """ - Returns list of utf-8 byte and a corresponding list of unicode strings. - The reversible bpe codes work on unicode strings. - This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. - When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. - This is a signficant percentage of your normal, say, 32K bpe vocab. - To avoid that, we want lookup tables between utf-8 bytes and unicode strings. - And avoids mapping to whitespace/control characters the bpe code barfs on. - """ - bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) - cs = bs[:] - n = 0 - for b in range(2**8): - if b not in bs: - bs.append(b) - cs.append(2**8+n) - n += 1 - cs = [chr(n) for n in cs] - return dict(zip(bs, cs)) - -class ClipTokenizer: - def __init__(self, bpe_path: str = default_bpe()): - self.byte_encoder = bytes_to_unicode() - merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') - merges = merges[1:49152-256-2+1] - merges = [tuple(merge.split()) for merge in merges] - vocab = list(bytes_to_unicode().values()) - vocab = vocab + [v+'' for v in vocab] - for merge in merges: - vocab.append(''.join(merge)) - vocab.extend(['<|startoftext|>', '<|endoftext|>']) - self.encoder = dict(zip(vocab, range(len(vocab)))) - self.bpe_ranks = dict(zip(merges, range(len(merges)))) - self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} - self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE) - - def bpe(self, token): - if token in self.cache: - return self.cache[token] - word = tuple(token[:-1]) + ( token[-1] + '',) - pairs = get_pairs(word) - - if not pairs: - return token+'' - - while True: - bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) - if bigram not in self.bpe_ranks: - break - first, second = bigram - new_word = [] - i = 0 - while i < len(word): - try: - j = word.index(first, i) - new_word.extend(word[i:j]) - i = j - except Exception: - new_word.extend(word[i:]) - break - - if word[i] == first and i < len(word)-1 and word[i+1] == second: - new_word.append(first+second) - i += 2 - else: - new_word.append(word[i]) - i += 1 - new_word = tuple(new_word) - word = new_word - if len(word) == 1: - break - pairs = get_pairs(word) - word = ' '.join(word) - self.cache[token] = word - return word - - def encode(self, text): - bpe_tokens = [] - text = whitespace_clean(text.strip()).lower() - for token in re.findall(self.pat, text): - token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) - bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) - # Truncation, keeping two slots for start and end tokens. - if len(bpe_tokens) > 75: - bpe_tokens = bpe_tokens[:75] - return [49406] + bpe_tokens + [49407] * (77 - len(bpe_tokens) - 1) - -class StableDiffusion: - def __init__(self): - self.alphas_cumprod = Tensor.empty(1000) - self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel()) - self.first_stage_model = AutoencoderKL() - self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = CLIPTextTransformer())) - - # TODO: make __call__ run the model - -# ** ldm.models.autoencoder.AutoencoderKL (done!) -# 3x512x512 <--> 4x64x64 (16384) -# decode torch.Size([1, 4, 64, 64]) torch.Size([1, 3, 512, 512]) -# section 4.3 of paper -# first_stage_model.encoder, first_stage_model.decoder - -# ** ldm.modules.diffusionmodules.openaimodel.UNetModel -# this is what runs each time to sample. is this the LDM? -# input: 4x64x64 -# output: 4x64x64 -# model.diffusion_model -# it has attention? - -# ** ldm.modules.encoders.modules.FrozenCLIPEmbedder -# cond_stage_model.transformer.text_model - -# this is sd-v1-4.ckpt -FILENAME = Path(__file__).parent.parent / "weights/sd-v1-4.ckpt" - -import clip as clipsave -import autoencoder as autoencodersave -import unet as unetsave -import stablediffusion as sdsave - -import numpy as np - -if __name__ == "__main__": - Tensor.no_grad = True - '''clip = CLIPTextTransformer() - - print('Saving model...') - clipsave.save_clip_text_transformer(clip, "params") - - input = Tensor([3, 1]) - output = clip(input.unsqueeze(0)) - - print(output[0, 0:2, 0:10].numpy())''' - - '''autoencoder = AutoencoderKL() - print('Saving model...') - autoencodersave.save_autoencoder(autoencoder, "params") - input = Tensor.zeros((1, 3, 10, 10)) - output = autoencoder(input) - print(output.shape) - print(output.numpy())''' - - '''unet = UNetModel() - print('Saving model...') - unetsave.save_unet_model(unet, 'params') - input = Tensor.zeros([1, 4, 64, 64]) - - context = np.array([0.5, 1.3], dtype=np.float32) # specify dtype when defining the array - context = np.repeat(context, 768 // 2) - context = np.expand_dims(context, axis=0) - context = Tensor(context) - - timesteps = Tensor([1.0]) - - output = unet(input, timesteps, context) - #print(output.numpy())''' - - - Tensor.no_grad = True - model = StableDiffusion() - - # load in weights - download_file('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', FILENAME) - load_state_dict(model, torch_load(FILENAME)['state_dict'], strict=False) - - print('Saving model...') - sdsave.save_stable_diffusion(model, "params") - print('Model saved.') - - - '''parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--steps', type=int, default=5, help="Number of steps in diffusion") - parser.add_argument('--prompt', type=str, default="a horse sized cat eating a bagel", help="Phrase to render") - parser.add_argument('--out', type=str, default=os.path.join(tempfile.gettempdir(), "rendered.png"), help="Output filename") - args = parser.parse_args() - - Tensor.no_grad = True - model = StableDiffusion() - - # load in weights - download_file('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', FILENAME) - load_state_dict(model, torch_load(FILENAME)['state_dict'], strict=False) - - # run through CLIP to get context - tokenizer = ClipTokenizer() - prompt = Tensor([tokenizer.encode(args.prompt)]) - context = model.cond_stage_model.transformer.text_model(prompt).realize() - print("got CLIP context", context.shape) - - prompt = Tensor([tokenizer.encode("")]) - unconditional_context = model.cond_stage_model.transformer.text_model(prompt).realize() - print("got unconditional CLIP context", unconditional_context.shape) - - # done with clip model - del model.cond_stage_model - - def get_model_output(latent, timestep): - # put into diffuser - latents = model.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep.expand(2, *timestep.shape[1:]), unconditional_context.cat(context, dim=0)) - unconditional_latent, latent = latents[0:1], latents[1:2] - - unconditional_guidance_scale = 7.5 - e_t = unconditional_latent + unconditional_guidance_scale * (latent - unconditional_latent) - return e_t - - timesteps = list(range(1, 1000, 1000//args.steps)) - print(f"running for {timesteps} timesteps") - alphas = [model.alphas_cumprod.numpy()[t] for t in timesteps] - alphas_prev = [1.0] + alphas[:-1] - - def get_x_prev_and_pred_x0(x, e_t, index): - temperature = 1 - a_t, a_prev = alphas[index], alphas_prev[index] - sigma_t = 0 - sqrt_one_minus_at = math.sqrt(1-a_t) - #print(a_t, a_prev, sigma_t, sqrt_one_minus_at) - - pred_x0 = (x - sqrt_one_minus_at * e_t) / math.sqrt(a_t) - - # direction pointing to x_t - dir_xt = math.sqrt(1. - a_prev - sigma_t**2) * e_t - noise = sigma_t * Tensor.randn(*x.shape) * temperature - - x_prev = math.sqrt(a_prev) * pred_x0 + dir_xt #+ noise - return x_prev, pred_x0 - - # start with random noise - latent = Tensor.randn(1,4,64,64) - - # this is diffusion - for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])): - GlobalCounters.reset() - t.set_description("%3d %3d" % (index, timestep)) - e_t = get_model_output(latent, Tensor([timestep])) - x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t, index) - #e_t_next = get_model_output(x_prev) - #e_t_prime = (e_t + e_t_next) / 2 - #x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index) - latent = x_prev - latent.realize() - - # upsample latent space to image with autoencoder - x = model.first_stage_model.post_quant_conv(1/0.18215 * latent) - x = model.first_stage_model.decoder(x) - - # make image correct size and scale - x = (x + 1.0) / 2.0 - x = (x.reshape(3,512,512).permute(1,2,0).clip(0,1)*255).cast(dtypes.uint8) - print(x.shape) - - # save image - from PIL import Image - im = Image.fromarray(x.numpy()) - print(f"saving {args.out}") - im.save(args.out) - # Open image. - im.show()'''