You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
162 lines
3.3 KiB
Python
162 lines
3.3 KiB
Python
3 weeks ago
|
#!/usr/bin/python3
|
||
|
|
||
|
import os,sys,math
|
||
|
import numpy as np
|
||
|
|
||
|
import matplotlib.pyplot as plt
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
import mnist_dataloader
|
||
|
|
||
|
from mnistdiffusion_model import mndiff_rev
|
||
|
|
||
|
################
|
||
|
## Dataloader ##
|
||
|
################
|
||
|
|
||
|
def mnist_load():
    """Load MNIST via ``mnist_dataloader`` and return the rescaled image sets.

    The labels returned by the loader are discarded; only the images are kept.

    Returns:
        list: ``[images_train, images_test]`` as float tensors with gradient
        tracking switched off.
    """
    def _rescale(raw):
        # Work on a detached float copy so the loader's tensor is not modified.
        scaled = raw.clone().detach().float()
        # Rescale data to (-1,1); assumes raw pixel values span 0..255.
        scaled = 2.0*(scaled/255.0)-1.0
        scaled.requires_grad = False
        return scaled

    # Per the original notes: images are 60000 x 28 x 28, labels 60000 x 1
    # (training split) -- TODO confirm against mnist_dataloader.
    [_, _, raw_train, raw_test] = mnist_dataloader.mnist_load()

    return [_rescale(raw_train), _rescale(raw_test)]
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
################
|
||
|
## Operations ##
|
||
|
################
|
||
|
|
||
|
def training_parameters(nts=200, s=0.1):
    """Build the training parameters for a cosine variance schedule.

    ref: Nichol and Dhariwal 2021

    Args:
        nts: Number of diffusion time steps (default 200). Must be >= 1.
        s: Offset added inside the cosine (default 0.1).

    Returns:
        dict with the keys
            "nts": the number of time steps (int),
            "t":   float32 tensor ``[0, 1, ..., nts-1]``,
            "a_t": float32 tensor, cosine curve normalised so ``a_t[0] == 1``,
            "B_t": float32 tensor, per-step ``1 - a_t[k+1]/a_t[k]`` clipped
                   to ``[0, 0.999]``; the last entry repeats the one before it.

    Raises:
        ValueError: if ``nts`` is smaller than 1.
    """
    if nts < 1:
        raise ValueError("nts must be at least 1")

    lparas = dict()

    t = np.arange(0,nts)
    f_t = np.cos((t/nts+s)/(1+s)*np.pi/2.0)**2
    a_t = f_t/f_t[0]

    # B_t[k] is derived from the ratio of consecutive a_t values; there is no
    # a_t[nts], so the final entry is copied from its predecessor.
    B_t = np.zeros(t.shape[0])
    B_t[0:nts-1] = np.clip(1-a_t[1:nts]/a_t[0:nts-1],0,0.999)
    if nts > 1:
        B_t[nts-1] = B_t[nts-2]

    lparas["nts"] = nts
    lparas["t"] = torch.tensor(t,dtype=torch.float32)
    lparas["a_t"] = torch.tensor(a_t,dtype=torch.float32)
    lparas["B_t"] = torch.tensor(B_t,dtype=torch.float32)

    return lparas
|
||
|
|
||
|
|
||
|
def generate_minibatch(imgs, Ndraw, lparas):
    """Draw random images and build a noised sequence of ``nts`` frames for each.

    For every drawn image, frame 0 is the image itself and frame ``J`` is
    frame ``J-1`` plus Gaussian noise with standard deviation
    ``sqrt(B_t[J-1])``.

    Args:
        imgs: Image tensor indexed as ``imgs[k, :, :]``; draws are uniform
            over the first dimension.
        Ndraw: Number of images to draw (with replacement).
        lparas: Parameter dict providing ``"nts"`` (int) and ``"B_t"``
            (1-D tensor of length ``nts``).

    Returns:
        float32 tensor of shape ``(Ndraw*nts, H, W)`` on ``imgs.device``;
        rows ``I*nts .. (I+1)*nts-1`` hold the sequence of the I-th draw.
    """
    nts = lparas["nts"]
    dev = imgs.device
    height, width = imgs.shape[-2], imgs.shape[-1]

    # Per-step noise standard deviation, shaped (nts, 1, 1) so it broadcasts
    # over the pixels of each frame.  The previous torch.kron(B_t, ones(28,28))
    # on a 1-D B_t produced a (28, nts*28) tensor whose reshape mixed B_t
    # values across rows and time steps.
    sig = torch.sqrt(lparas["B_t"]).reshape(nts,1,1).to(dev)

    imbatch = torch.zeros([Ndraw*nts,height,width],dtype=torch.float32).to(dev)
    draws = torch.randint(0,imgs.shape[0],(Ndraw,))

    for I in range(0,Ndraw):
        noise = torch.randn((nts,height,width)).to(dev)*sig

        # View into the output buffer for this draw's sequence.
        imgseq = imbatch[I*nts:(I+1)*nts]
        imgseq[0,:,:] = imgs[draws[I],:,:]

        # Frame J = frame 0 + noise[0] + ... + noise[J-1]; the running sum
        # replaces the explicit Python loop over time steps.
        # NOTE(review): noise is purely additive here (no sqrt(1-B_t) scaling
        # of the previous frame) -- confirm this is the intended forward process.
        if nts > 1:
            imgseq[1:,:,:] = imgseq[0,:,:] + torch.cumsum(noise[0:nts-1,:,:],dim=0)

    return imbatch
|
||
|
|
||
|
def train_batch(imgbatch,lr,lparas):
    """Placeholder for a single training step; not implemented yet.

    All arguments are currently ignored and the function returns ``None``.
    """
    return None
|
||
|
|
||
|
|
||
|
def train():
    """Placeholder for the training loop; not implemented yet (returns ``None``)."""
    return None
|
||
|
|
||
|
def infer():
    """Placeholder for inference/sampling; not implemented yet (returns ``None``)."""
    return None
|
||
|
|
||
|
|
||
|
|
||
|
def test1():
    """Visual smoke test for the noising pipeline.

    Loads MNIST, builds the schedule, generates a minibatch of 50 noised
    sequences and shows frames 0, 1, 2 and 30 of the batch in a 2x2 grid.
    With the default schedule (200 steps per sequence) these frames all
    belong to the first drawn image.
    """
    [images_train, images_test] = mnist_load()
    lparas = training_parameters()

    # Optional schedule diagnostics:
    #plt.figure()
    #plt.plot(lparas["t"],lparas["a_t"])
    #plt.show()

    #plt.figure()
    #plt.plot(lparas["t"],lparas["B_t"])
    #plt.show()

    # Use the GPU when one is present; fall back to the CPU instead of
    # crashing on machines without CUDA.
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    images_train = images_train.to(dev)

    mb = generate_minibatch(images_train,50,lparas)

    plt.figure()
    for panel, frame_idx in enumerate((0, 1, 2, 30), start=1):
        # Bring the frame back to host memory for matplotlib.
        frame = torch.squeeze(mb[frame_idx,:,:]).clone().detach().to("cpu")
        plt.subplot(2,2,panel)
        plt.imshow(frame,cmap='gray')

    plt.show()

    return
|
||
|
|
||
|
|
||
|
##########
|
||
|
## Main ##
|
||
|
##########
|
||
|
|
||
|
if __name__ == "__main__":
    # test_gaussian_timeselect() is neither defined nor imported in this
    # file, so calling it raised NameError at startup; run the smoke test
    # that does exist instead.
    test1()
|