You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

402 lines
8.7 KiB
Python

#!/usr/bin/python3
import os,sys,math
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import mnist_dataloader
from mnistdiffusion_model import mndiff_rev
from mnistdiffusion_utils import *
import time
################
## Dataloader ##
################
def mnist_load():
    """Load MNIST via ``mnist_dataloader`` and rescale the images.

    The four values returned by ``mnist_dataloader.mnist_load()`` are
    unpacked as (train labels, test labels, train images, test images); the
    labels are discarded.  Each image tensor is copied, detached, cast to
    float, mapped through ``2*(x/255) - 1`` and flagged as not requiring
    gradients.

    Returns:
        list: ``[images_train, images_test]``.
    """
    # Original author's notes: images are 60000 x 28 x 28 and labels are
    # 60000 x 1 -- not checked here, confirm against mnist_dataloader.
    _, _, raw_train, raw_test = mnist_dataloader.mnist_load()

    rescaled = []
    for raw in (raw_train, raw_test):
        # Rescale data to (-1,1); presumes raw pixel values span 0..255.
        pixels = raw.clone().detach().float()
        pixels = 2.0 * (pixels / 255.0) - 1.0
        pixels.requires_grad = False
        rescaled.append(pixels)
    return rescaled
################
## Operations ##
################
# Rather than killing myself figuring out variable variance schedules,
# just use a constant variance schedule to start with
def training_parameters(nts=200, beta=0.0050, lr=1E-4, batchsize=80,
                        save_every=200, record_every=25):
    """Build the dictionary of diffusion / training hyperparameters.

    A constant variance schedule is used: every step has variance ``beta``.
    Called with no arguments this returns the same values the script has
    always used.

    Args:
        nts: Number of diffusion time steps (must be >= 1).
        beta: Per-step noise variance of the constant schedule.
        lr: Optimizer learning rate.
        batchsize: Number of image sequences per minibatch.
        save_every: Save a checkpoint every this many minibatches.
        record_every: Average and record the loss every this many minibatches.

    Returns:
        dict with keys ``"nts"``, ``"Beta"``, ``"t"``, ``"a_t"``, ``"B_t"``,
        ``"lr"``, ``"batchsize"``, ``"save_every"``, ``"record_every"``.
        ``"t"``, ``"a_t"`` and ``"B_t"`` are float32 tensors of length
        ``nts``; the rest are the plain Python values passed in.

    Raises:
        ValueError: If ``nts`` is smaller than 1.
    """
    if nts < 1:
        raise ValueError("nts must be at least 1, got {}".format(nts))

    B_t = np.ones(nts) * beta
    # a_t[0] = 1 and a_t[k] = prod_{j<k} (1 - B_t[j]): the running product
    # of the per-step retention factors, shifted by one step.
    a_t = np.concatenate(([1.0], np.cumprod(1.0 - B_t[:-1])))
    t = np.linspace(0, 1, nts)  # normalized time (t/T)

    lparas = dict()
    lparas["nts"] = nts
    lparas["Beta"] = beta
    lparas["t"] = torch.tensor(t, dtype=torch.float32)
    lparas["a_t"] = torch.tensor(a_t, dtype=torch.float32)
    lparas["B_t"] = torch.tensor(B_t, dtype=torch.float32)
    lparas["lr"] = lr
    lparas["batchsize"] = batchsize
    lparas["save_every"] = save_every
    lparas["record_every"] = record_every
    return lparas
def generate_minibatch(imgs, Ndraw, lparas):
    """Draw random images and build their forward-diffusion sequences.

    ``Ndraw`` images are picked uniformly at random (with replacement) from
    ``imgs``.  For each one a sequence of ``nts`` frames is produced: frame 0
    is the clean image and

        frame[j] = frame[j-1] * sqrt(1 - B_t[j-1]) + sqrt(B_t[j-1]) * noise

    with fresh standard-normal noise at every step.

    Args:
        imgs: 3-D tensor indexed as (image, row, column).  The frame size is
            taken from ``imgs.shape[1:]``.
        Ndraw: Number of sequences to generate.
        lparas: Parameter dict; reads ``"nts"``, ``"B_t"`` (length ``nts``)
            and ``"t"`` (length ``nts``).

    Returns:
        list: ``[imbatch, times]`` where ``imbatch`` is a float32 tensor of
        shape (Ndraw, nts, rows, cols) on ``imgs.device`` and ``times`` is
        ``lparas["t"]`` expanded to shape (Ndraw, nts) on the same device.
    """
    nts = lparas["nts"]
    dev = imgs.device
    rows, cols = imgs.shape[1], imgs.shape[2]

    # Per-step factors shaped (nts, 1, 1) so that index j selects the value
    # for step j and broadcasts over every pixel.  (Expanding with
    # torch.kron + reshape interleaves the steps, which is only harmless for
    # a constant schedule.)
    beta = lparas["B_t"].to(device=dev, dtype=torch.float32).reshape(nts, 1, 1)
    keep = torch.sqrt(1.0 - beta)
    sig = torch.sqrt(beta)

    draws = torch.randint(0, imgs.shape[0], (Ndraw,), device=dev)

    imbatch = torch.zeros([Ndraw, nts, rows, cols], dtype=torch.float32, device=dev)
    imbatch[:, 0, :, :] = imgs[draws, :, :]
    # The chain is sequential in time, but all Ndraw samples advance together.
    for step in range(1, nts):
        noise = torch.randn((Ndraw, rows, cols), dtype=torch.float32, device=dev)
        imbatch[:, step, :, :] = (
            imbatch[:, step - 1, :, :] * keep[step - 1] + noise * sig[step - 1]
        )

    times = lparas["t"].to(dev)
    times = times.reshape([1, nts]).expand([Ndraw, nts])
    return [imbatch, times]
def train_batch(revnet, batchimg, batchtime, lparas, optim=None):
    """Run one optimization step on a minibatch and return the scaled loss.

    The network output at steps ``1..nts-1`` is multiplied by ``Beta`` and
    regressed (mean squared error) onto the difference between consecutive
    frames of ``batchimg``.

    Args:
        revnet: Callable module invoked as ``revnet(batchimg, batchtime)``;
            its output is sliced as ``[:, 1:nts, :, :]``.
        batchimg: Tensor sliced as (sample, step, row, column).
        batchtime: Tensor whose second dimension is the number of steps.
        lparas: Parameter dict; reads ``"lr"`` and ``"Beta"``.
        optim: Optional optimizer to use for the step.  When omitted a new
            Adam optimizer is built for this call only, so Adam's running
            moment estimates do not carry over between calls; pass a
            long-lived optimizer to keep them.

    Returns:
        The MSE loss divided by ``Beta**2``, as a NumPy scalar.
    """
    lr = lparas["lr"]
    nts = batchtime.shape[1]
    beta = float(lparas["Beta"])

    if optim is None:
        optim = torch.optim.Adam(revnet.parameters(), lr=lr)
    loss = torch.nn.MSELoss()

    # Clear gradients left on the parameters by a previous call; without
    # this, backward() keeps adding to them batch after batch.
    optim.zero_grad()

    d_est = revnet(batchimg, batchtime)
    d_est = d_est[:, 1:nts, :, :]
    targ = batchimg[:, 1:nts, :, :] - batchimg[:, 0:nts - 1, :, :]

    L2 = loss(targ, beta * d_est)
    L2.backward()
    optim.step()

    # Report the loss in units of the network output rather than beta*output.
    L = L2.detach().cpu().numpy()
    L = L / beta**2
    return L
def train(directory, saveseries):
    """Train the reverse-diffusion network, checkpointing as it goes.

    Loads the dataset and hyperparameters, restores the most recent save in
    ``directory`` for ``saveseries`` if one exists, then trains on freshly
    generated minibatches.  The loop has no exit condition: it runs until
    the process is interrupted.

    Args:
        directory: Passed to ``get_current_save`` / ``get_next_save``.
        saveseries: Save-series name passed to the same helpers.
    """
    print("Loading dataset...")
    images_train, _ = mnist_load()
    lparas = training_parameters()
    batchsize = lparas["batchsize"]
    save_every = lparas["save_every"]
    record_every = lparas["record_every"]

    # Use the GPU when one is present; otherwise run on the CPU instead of
    # failing on the first .to(device).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device {}".format(device))

    revnet = mndiff_rev()
    fname = get_current_save(directory, saveseries)
    if fname is None:
        # Kept from the original flow; the returned name is not used here.
        fname = get_next_save(directory, saveseries)
    else:
        print("Loading {}".format(fname))
        revnet.load(fname)

    images_train = images_train.to(device)
    revnet = revnet.to(device)

    ## Training Loop
    step = 0   # minibatches processed so far
    slot = 0   # write position inside the current loss window
    while True:
        if step % record_every == 0:
            print("Minibatch {}".format(step))
            # Start a new window of losses to average.
            losses = np.zeros(record_every)

        # Draw and train on one minibatch.
        mb, mbt = generate_minibatch(images_train, batchsize, lparas)
        losses[slot] = train_batch(revnet, mb, mbt, lparas)
        slot = slot + 1

        # Record the mean loss once the window is full.
        if slot % record_every == 0:
            slot = 0
            loss = np.mean(losses)
            revnet.recordloss(loss)
            print("\tloss={:1.5f}".format(loss))

        # Checkpoint.
        if step % save_every == 0:
            fnext = get_next_save(directory, saveseries)
            revnet.save(fnext)

        step = step + 1
def infer(directory, saveseries):
    """Sample one image by running the reverse process and plot snapshots.

    Starts from Gaussian noise in the last frame and steps backwards to
    frame 0, using 15 times as many steps as the training schedule.  Four
    frames (at 90%, 75%, 25% and 0% of the sequence) are shown in a 2x2
    grid.

    Args:
        directory: Passed to ``get_current_save``.
        saveseries: Save-series name passed to the same helper.
    """
    # Reseed from the wall clock so each call gives a different sample.
    ttt = int(time.time())
    np.random.seed(ttt % 50000)
    torch.manual_seed(ttt % 50000)

    lparas = training_parameters()
    revnet = mndiff_rev()
    fname = get_current_save(directory, saveseries)
    if fname is None:
        print("Warning, no state loaded.")
    else:
        print("Loading {}".format(fname))
        revnet.load(fname)

    nts = lparas["nts"] * 15
    bb = torch.as_tensor(lparas["Beta"])

    imgseq = torch.zeros([nts, 28, 28])
    imgseq[nts - 1, :, :] = torch.randn(28, 28)
    tl = torch.linspace(0, 1, nts)

    # Sampling needs no gradients.  Without no_grad every step would be
    # chained into one autograd graph through imgseq, growing memory with
    # the number of steps.
    with torch.no_grad():
        for I in range(nts - 2, -1, -1):
            img = imgseq[I + 1, :, :]
            t = tl[I]
            eps = revnet(img, t)
            img2 = img * torch.sqrt(1 - bb) - bb * eps + 0.5 * bb * torch.randn(28, 28)
            imgseq[I, :, :] = img2

    # Panels 1..4 run from noisiest to cleanest.
    frame_indices = (int(nts * 0.90), int(nts * 0.75), int(nts * 0.25), 0)
    plt.figure()
    for panel, idx in enumerate(frame_indices, start=1):
        frame = torch.squeeze(imgseq[idx, :, :]).detach().numpy()
        plt.subplot(2, 2, panel)
        plt.imshow(frame, cmap='gray')
    plt.show()
    return
def plot_history(directory, saveseries):
    """Restore the latest save (if any) and call the model's ``plothistory``.

    Args:
        directory: Passed to ``get_current_save`` / ``get_next_save``.
        saveseries: Save-series name passed to the same helpers.
    """
    revnet = mndiff_rev()
    checkpoint = get_current_save(directory, saveseries)
    if checkpoint is not None:
        print("Loading {}".format(checkpoint))
        revnet.load(checkpoint)
    else:
        # Nothing to restore; the next save name is requested but not used.
        checkpoint = get_next_save(directory, saveseries)
    revnet.plothistory()
def test1():
    """Visual check of ``generate_minibatch``.

    Generates 50 sequences on the "cuda" device, prints the shapes of the
    batch and time tensors, and shows steps 0, 1, 2 and 99 of the first
    sequence in a 2x2 grid.
    """
    images_train, _ = mnist_load()
    lparas = training_parameters()

    images_train = images_train.to("cuda")
    mb, mbtime = generate_minibatch(images_train, 50, lparas)

    # Copy the selected steps of sample 0 back to the CPU for plotting.
    shown_steps = (0, 1, 2, 99)
    frames = [
        torch.squeeze(mb[0, step, :, :]).clone().detach().to("cpu")
        for step in shown_steps
    ]

    print(mb.shape)
    print(mbtime.shape)

    plt.figure()
    for panel, frame in enumerate(frames, start=1):
        plt.subplot(2, 2, panel)
        plt.imshow(frame, cmap='gray')
    plt.show()
def test2(directory, saveseries):
    """Run the network once on training image 100 at time 0.5 and plot it.

    Prints the mean and standard deviation of the output, plus the variance
    multiplied by ``Beta``, then shows the input above the output.

    Args:
        directory: Passed to ``get_current_save`` / ``get_next_save``.
        saveseries: Save-series name passed to the same helpers.
    """
    lparas = training_parameters()
    beta = lparas["Beta"]

    revnet = mndiff_rev()
    checkpoint = get_current_save(directory, saveseries)
    if checkpoint is not None:
        print("Loading {}".format(checkpoint))
        revnet.load(checkpoint)
    else:
        # Nothing to restore; the next save name is requested but not used.
        checkpoint = get_next_save(directory, saveseries)

    images_train, _ = mnist_load()
    sample = torch.squeeze(images_train[100, :, :])
    response = revnet(sample, 0.5)

    x = sample.detach().numpy()
    y = response.detach().numpy()

    y_mean = np.mean(y)
    y_std = np.std(y)
    print("y mn={:1.3f} std={:1.3f} std2*beta={}".format(y_mean, y_std, y_std**2*beta))

    plt.figure()
    for panel, picture in enumerate((x, y), start=1):
        plt.subplot(2, 1, panel)
        plt.imshow(picture, cmap='gray')
    plt.show()
##########
## Main ##
##########
def mainswitch():
    """Dispatch on the first command-line argument.

    With no argument the usage text is printed and the process exits with
    status 0.  ``train``, ``plot_history`` and ``infer`` run the matching
    function on the ``./saves`` directory and the ``test_f`` save series.
    Any other argument prints an error plus the usage text and exits with
    status 1.
    """
    usage = "\n".join((
        "mnistdiffusion [operation]",
        "operations:",
        "\ttrain",
        "\tplot_history",
        "\tinfer",
    ))

    args = sys.argv
    if len(args) < 2:
        print(usage)
        # sys.exit rather than the interactive-only exit() builtin.
        sys.exit(0)

    operation = args[1]
    if operation == "train":
        train("./saves", "test_f")
    elif operation == "plot_history":
        plot_history("./saves", "test_f")
    elif operation == "infer":
        infer("./saves", "test_f")
    else:
        # Report a mistyped operation rather than doing nothing.
        print("Unknown operation: {}".format(operation))
        print(usage)
        sys.exit(1)
if __name__ == "__main__":
    # Seed NumPy and PyTorch from the wall clock so each run differs.
    ttt = int(time.time())
    seed = ttt % 50000
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Ad-hoc checks, disabled by default:
    #   test1()
    #   test2("./saves", "test_d")
    mainswitch()