You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
399 lines
8.7 KiB
Python
399 lines
8.7 KiB
Python
3 weeks ago
|
#!/usr/bin/python3
|
||
|
|
||
|
import os,sys,math
|
||
|
import numpy as np
|
||
|
|
||
|
import matplotlib.pyplot as plt
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
import mnist_dataloader
|
||
|
|
||
|
from mnistdiffusion_model import mndiff_rev
|
||
|
from mnistdiffusion_utils import *
|
||
|
|
||
|
import time
|
||
|
|
||
|
################
|
||
|
## Dataloader ##
|
||
|
################
|
||
|
|
||
|
def mnist_load():
    """Load MNIST via mnist_dataloader and return the images rescaled to [-1, 1].

    The labels returned by ``mnist_dataloader.mnist_load`` are discarded.
    (Original comments note 60000 x 28 x 28 training images and 60000 x 1
    labels.)

    Returns:
        [images_train, images_test]: detached float tensors with
        requires_grad set to False.
    """
    _labels_tr, _labels_te, raw_train, raw_test = mnist_dataloader.mnist_load()

    def _rescale(raw):
        # Detached float copy, then map x -> 2*(x/255) - 1, i.e. [0, 255] -> [-1, 1].
        scaled = 2.0 * (raw.clone().detach().float() / 255.0) - 1.0
        scaled.requires_grad = False
        return scaled

    return [_rescale(raw_train), _rescale(raw_test)]
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
################
|
||
|
## Operations ##
|
||
|
################
|
||
|
|
||
|
# Rather than working out a variable (time-dependent) variance schedule,
# just use a constant variance schedule to start with.
|
||
|
|
||
|
def training_parameters():
    """Build the hyper-parameter dictionary shared by training and inference.

    The noise schedule is constant: every step uses the same variance Beta.

    Returns:
        dict with keys
          "nts"          number of diffusion time steps
          "Beta"         constant per-step variance
          "t"            float32 tensor, normalized time t/T = linspace(0, 1, nts)
          "a_t"          float32 tensor, a_t[0] = 1 and
                         a_t[i] = a_t[i-1] * (1 - B_t[i-1])
          "B_t"          float32 tensor of length nts filled with Beta
          "lr"           learning rate
          "batchsize"    images per minibatch
          "save_every"   minibatches between checkpoints
          "record_every" minibatches between loss reports
    """
    n_steps = 200
    beta_const = 0.0050

    beta_sched = np.full(n_steps, beta_const, dtype=np.float64)
    # Running product of (1 - beta), starting from 1 at step 0.
    alpha_bar = np.concatenate(([1.0], np.cumprod(1.0 - beta_sched[:-1])))
    t_norm = np.linspace(0, 1, n_steps)  # normalized time (t/T)

    return {
        "nts": n_steps,
        "Beta": beta_const,
        "t": torch.tensor(t_norm, dtype=torch.float32),
        "a_t": torch.tensor(alpha_bar, dtype=torch.float32),
        "B_t": torch.tensor(beta_sched, dtype=torch.float32),
        "lr": 1E-4,
        "batchsize": 30,
        "save_every": 200,
        "record_every": 25,
    }
|
||
|
|
||
|
|
||
|
def generate_minibatch(imgs, Ndraw, lparas):
    """Draw random images and run the forward (noising) process on each.

    For every drawn image x_0 the sequence is built as
        x_j = x_{j-1} * sqrt(1 - B_t[j-1]) + sqrt(B_t[j-1]) * eps,
    where eps ~ N(0, 1) is drawn independently for every pixel and step.

    Args:
        imgs: image tensor indexed as imgs[k, :, :]; each image must be
            28 x 28 (the output buffer is allocated with that size).
        Ndraw: number of images to draw (uniformly, with replacement).
        lparas: dict from training_parameters(); uses "nts", "t" and "B_t".

    Returns:
        [imbatch, times]: imbatch is a float32 tensor of shape
        (Ndraw, nts, 28, 28) on imgs.device, and times is lparas["t"]
        expanded to shape (Ndraw, nts).
    """
    nts = lparas["nts"]
    dev = imgs.device

    draws = torch.randint(0, imgs.shape[0], (Ndraw,))

    times = lparas["t"].to(dev)
    times = times.reshape([1, nts]).expand([Ndraw, nts])

    # Per-step variance shaped (nts, 1, 1) so that step j really uses B_t[j].
    # (A kron(B_t, ones(28, 28)) + reshape does not place B_t[j] in slice j;
    # it only looked right because the schedule is constant.)
    beta = lparas["B_t"].to(device=dev, dtype=torch.float32).reshape(nts, 1, 1)
    keep = torch.sqrt(1.0 - beta)  # signal retained per step
    sig = torch.sqrt(beta)         # noise standard deviation per step

    imbatch = torch.zeros([Ndraw, nts, 28, 28], dtype=torch.float32, device=dev)
    imbatch[:, 0, :, :] = imgs[draws.to(dev), :, :]

    noise = torch.randn((Ndraw, nts, 28, 28), dtype=torch.float32, device=dev) * sig

    # The recurrence is sequential in time but vectorized over the batch.
    for j in range(1, nts):
        imbatch[:, j] = imbatch[:, j - 1] * keep[j - 1] + noise[:, j - 1]

    return [imbatch, times]
|
||
|
|
||
|
def train_batch(revnet, batchimg, batchtime, lparas, optim=None):
    """Run one optimization step on a batch of forward-diffusion sequences.

    For every step j >= 1, the network output at j, scaled by Beta, is
    regressed onto the forward increment batchimg[:, j] - batchimg[:, j-1].

    Args:
        revnet: model called as revnet(batchimg, batchtime); its output is
            sliced as [:, 1:nts, :, :].
        batchimg: sequence batch, e.g. from generate_minibatch; indexed as
            [batch, step, row, col].
        batchtime: times with shape (batch, nts); only its second dimension
            is read here, to get nts.
        lparas: dict from training_parameters(); uses "lr" and "Beta".
        optim: optimizer to step. If None, a fresh Adam(lr=lparas["lr"]) is
            created for this call only, and its moment estimates are lost
            afterwards. Callers training over many batches should create one
            optimizer and pass it in.

    Returns:
        The MSE loss divided by Beta**2 (equivalently, the MSE between the
        network output and increment / Beta), as a numpy scalar.
    """
    beta = float(lparas["Beta"])
    if optim is None:
        optim = torch.optim.Adam(revnet.parameters(), lr=lparas["lr"])
    criterion = torch.nn.MSELoss()

    nts = batchtime.shape[1]

    # Clear gradients from any previous call; backward() accumulates into
    # .grad, so stale gradients would otherwise be added to this step's.
    revnet.zero_grad()

    d_est = revnet(batchimg, batchtime)[:, 1:nts, :, :]
    targ = batchimg[:, 1:nts, :, :] - batchimg[:, 0:nts - 1, :, :]

    mse = criterion(targ, beta * d_est)
    mse.backward()
    optim.step()

    return mse.detach().cpu().numpy() / beta ** 2


def train(directory, saveseries):
    """Train the reverse-diffusion model indefinitely, checkpointing as it goes.

    Resumes from the checkpoint returned by get_current_save(), if any. Every
    ``record_every`` minibatches the mean loss is printed and passed to
    revnet.recordloss(). Every ``save_every`` minibatches (including the first)
    the model is saved to get_next_save(directory, saveseries). The loop never
    exits on its own.
    """
    print("Loading dataset...")
    [images_train, images_test] = mnist_load()
    lparas = training_parameters()

    batchsize = lparas["batchsize"]
    save_every = lparas["save_every"]
    record_every = lparas["record_every"]
    device = torch.device("cuda")

    revnet = mndiff_rev()
    fname = get_current_save(directory, saveseries)
    if fname is None:
        fname = get_next_save(directory, saveseries)
    else:
        print("Loading {}".format(fname))
        revnet.load(fname)

    images_train = images_train.to(device)
    revnet = revnet.to(device)

    # One optimizer for the whole run, so Adam's running moment estimates
    # persist across minibatches. It is created after the move to the device.
    optim = torch.optim.Adam(revnet.parameters(), lr=lparas["lr"])

    step = 0  # total minibatches processed
    slot = 0  # position inside the current loss-recording window
    losses = np.zeros(record_every)
    while True:
        if step % record_every == 0:
            print("Minibatch {}".format(step))
            # Start a new recording window.
            losses = np.zeros(record_every)

        # Draw and train one minibatch.
        [mb, mbt] = generate_minibatch(images_train, batchsize, lparas)
        losses[slot] = train_batch(revnet, mb, mbt, lparas, optim=optim)
        slot = slot + 1

        # Record the mean loss over the completed window.
        if slot % record_every == 0:
            slot = 0
            loss = np.mean(losses)
            revnet.recordloss(loss)
            print("\tloss={:1.5f}".format(loss))

        if step % save_every == 0:
            fnext = get_next_save(directory, saveseries)
            revnet.save(fnext)

        step = step + 1
|
||
|
|
||
|
def infer(directory, saveseries):
    """Sample one image with the learned reverse process and plot snapshots.

    Starts from N(0, 1) noise at the last index and steps backwards with
        x_i = x_{i+1} * sqrt(1 - Beta) - Beta * revnet(x_{i+1}, t_i) + Beta * z,
    where z ~ N(0, 1). The chain uses 15x the training number of steps, with
    times t spanning [0, 1]. It then shows the chain at time indices
    0.90*n, 0.75*n, 0.25*n and 0 (the final sample).
    """
    ttt = int(time.time())
    np.random.seed(ttt % 50000)
    torch.manual_seed(ttt % 50000)

    lparas = training_parameters()
    revnet = mndiff_rev()
    fname = get_current_save(directory, saveseries)
    if fname is None:
        print("Warning, no state loaded.")
    else:
        print("Loading {}".format(fname))
        revnet.load(fname)

    beta = lparas["Beta"]
    nts = lparas["nts"] * 15  # 15x as many steps as in training

    imgseq = torch.zeros([nts, 28, 28])
    imgseq[nts - 1, :, :] = torch.randn(28, 28)
    #imgseq[nts-1,:,17:20] = 3.0

    tl = torch.linspace(0, 1, nts)
    bb = torch.as_tensor(beta)

    # Pure sampling: without no_grad, every network call here would be chained
    # into one autograd graph through imgseq, and memory would grow each step.
    with torch.no_grad():
        for I in range(nts - 2, -1, -1):
            img = imgseq[I + 1, :, :]
            # NOTE(review): train_batch pairs the output at time index j with
            # the increment x_j - x_{j-1}, which suggests t = tl[I + 1] here.
            # Kept as tl[I]; confirm intent.
            t = tl[I]
            eps = revnet(img, t)
            img2 = img * torch.sqrt(1 - bb) - bb * eps + 1.0 * bb * torch.randn(28, 28)
            imgseq[I, :, :] = img2

    # Panels from noisiest (time index 0.90*nts) to the final sample (index 0).
    snapshot_idx = [int(nts * 0.90), int(nts * 0.75), int(nts * 0.25), 0]

    plt.figure()
    for pos, idx in enumerate(snapshot_idx, start=1):
        plt.subplot(2, 2, pos)
        plt.imshow(torch.squeeze(imgseq[idx, :, :]).detach().numpy(), cmap='gray')
    plt.show()

    return
|
||
|
|
||
|
def plot_history(directory, saveseries):
    """Plot the loss history of the latest checkpoint in a save series.

    If get_current_save() finds no checkpoint, plothistory() is called on a
    freshly constructed, unloaded mndiff_rev.
    """
    model = mndiff_rev()
    current = get_current_save(directory, saveseries)
    if current is not None:
        print("Loading {}".format(current))
        model.load(current)
    else:
        # Result unused; the call is kept to preserve the original call sequence.
        current = get_next_save(directory, saveseries)

    model.plothistory()
    return
|
||
|
|
||
|
def test1():
    """Visual check of the forward noising process.

    Builds a 50-image minibatch on CUDA, prints the shapes of the batch and
    time tensors, and shows steps 0, 1, 2 and 99 of the first sequence.
    """
    images_train, images_test = mnist_load()
    lparas = training_parameters()

    # Optional schedule plots:
    #plt.figure(); plt.plot(lparas["t"], lparas["a_t"]); plt.show()
    #plt.figure(); plt.plot(lparas["t"], lparas["B_t"]); plt.show()

    images_train = images_train.to("cuda")
    mb, mbtime = generate_minibatch(images_train, 50, lparas)

    frames = [
        torch.squeeze(mb[0, step, :, :]).clone().detach().to("cpu")
        for step in (0, 1, 2, 99)
    ]

    print(mb.shape)
    print(mbtime.shape)

    plt.figure()
    for pos, frame in enumerate(frames, start=1):
        plt.subplot(2, 2, pos)
        plt.imshow(frame, cmap='gray')
    plt.show()
    return
|
||
|
|
||
|
def test2(directory, saveseries):
    """Visual check of the network on a single training image.

    Loads the latest checkpoint (if any) and evaluates the model on training
    image 100 at t = 0.5. Prints the output's mean, standard deviation and
    std**2 * Beta, then plots the input above the output.
    """
    params = training_parameters()
    beta = params["Beta"]

    revnet = mndiff_rev()
    fname = get_current_save(directory, saveseries)
    if fname is not None:
        print("Loading {}".format(fname))
        revnet.load(fname)
    else:
        fname = get_next_save(directory, saveseries)

    images_train, images_test = mnist_load()

    x = torch.squeeze(images_train[100, :, :])
    #x = beta*torch.randn(28,28)
    y = revnet(x, 0.5)

    x_np = x.detach().numpy()
    y_np = y.detach().numpy()

    y_mean = np.mean(y_np)
    y_std = np.std(y_np)
    print("y mn={:1.3f} std={:1.3f} std2*beta={}".format(y_mean, y_std, y_std**2*beta))

    plt.figure()
    for row, panel in enumerate((x_np, y_np), start=1):
        plt.subplot(2, 1, row)
        plt.imshow(panel, cmap='gray')
    plt.show()
    return
|
||
|
|
||
|
##########
|
||
|
## Main ##
|
||
|
##########
|
||
|
|
||
|
def mainswitch(argv=None):
    """Dispatch the command-line operation named in argv[1].

    Args:
        argv: argument vector to parse; defaults to sys.argv.

    With no operation, prints usage and exits with status 0. With an
    unrecognized operation, prints an error plus usage and exits with
    status 1. Every operation uses the "./saves" directory and the "test_3a"
    save series.
    """
    args = sys.argv if argv is None else argv

    # Lambdas defer looking up the operation functions until one is chosen.
    operations = {
        "train": lambda: train("./saves", "test_3a"),
        "plot_history": lambda: plot_history("./saves", "test_3a"),
        "infer": lambda: infer("./saves", "test_3a"),
    }

    def _usage():
        print("mnistdiffusion [operation]")
        print("operations:")
        for name in operations:
            print("\t{}".format(name))

    if len(args) < 2:
        _usage()
        # sys.exit rather than the site-module builtin exit(), which is not
        # guaranteed to exist (e.g. under `python -S` or in frozen apps).
        sys.exit(0)

    action = operations.get(args[1])
    if action is None:
        print("Unknown operation: {}".format(args[1]))
        _usage()
        sys.exit(1)

    action()
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Seed NumPy and PyTorch from the wall clock (mod 50000) so runs differ.
    seed = int(time.time()) % 50000
    np.random.seed(seed)
    torch.manual_seed(seed)

    #test1()
    #test2("./saves","test_2a")
    mainswitch()
|
||
|
|