Add first successful sampling implementation

This commit is contained in:
Gadersd
2023-08-04 17:01:44 -04:00
committed by Ben_Kosytorz
parent b794e9a9ec
commit 8e7a8d9be4
9 changed files with 42 additions and 34 deletions

View File

@@ -56,9 +56,6 @@ class ResnetBlock:
def __call__(self, x):
h = self.conv1(self.norm1(x).swish())
'''v = h
print(v.shape)
print(v[0, 0:10, :, :].numpy())'''
h = self.conv2(self.norm2(h).swish())
return self.nin_shortcut(x) + h
@@ -145,7 +142,6 @@ class AutoencoderKL:
latent = self.encoder(x)
latent = self.quant_conv(latent)
latent = latent[:, 0:4] # only the means
print("latent", latent.shape)
latent = self.post_quant_conv(latent)
return self.decoder(latent)
@@ -339,15 +335,12 @@ class UNetModel:
saved_inputs = []
for i,b in enumerate(self.input_blocks):
#print("input block", i)
print(x.numpy())
for bb in b:
x = run(x, bb)
saved_inputs.append(x)
for bb in self.middle_block:
x = run(x, bb)
for i,b in enumerate(self.output_blocks):
#print("output block", i)
x = x.cat(saved_inputs.pop(), dim=1)
for bb in b:
x = run(x, bb)
@@ -644,7 +637,9 @@ if __name__ == "__main__":
download_file('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', FILENAME)
load_state_dict(model, torch_load(FILENAME)['state_dict'], strict=False)
print('Saving model...')
sdsave.save_stable_diffusion(model, "params")
print('Model saved.')
'''parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)