#!/usr/bin/python3
"""Train and sample a simple diffusion model on MNIST.

Operations (see ``mainswitch``): ``train``, ``plot_history``, ``infer``.
"""

import os
import sys
import math
import time

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

import mnist_dataloader
from mnistdiffusion_model import mndiff_rev
from mnistdiffusion_utils import *


#############
## Helpers ##
#############

def _pick_device():
    """Return the CUDA device when one is available, otherwise the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_latest(revnet, directory, saveseries):
    """Load the current save of ``saveseries`` into ``revnet``.

    Returns True if a save file was found and loaded, False otherwise.
    """
    fname = get_current_save(directory, saveseries)
    if fname is None:
        return False
    print("Loading {}".format(fname))
    revnet.load(fname)
    return True


################
## Dataloader ##
################

def mnist_load():
    """Load the MNIST images and rescale them from [0, 255] to [-1, 1].

    Returns:
        [images_train, images_test] as float tensors with gradients disabled.
        The labels returned by the underlying loader are discarded.
    """
    [_labels_train, _labels_test, images_train, images_test] = \
        mnist_dataloader.mnist_load()

    rescaled = []
    for images in (images_train, images_test):
        images = images.clone().detach().float()
        images = 2.0 * (images / 255.0) - 1.0
        images.requires_grad = False
        rescaled.append(images)
    return rescaled


################
## Operations ##
################

# A constant variance schedule is used to start with; variable schedules
# can be introduced later by changing how B_t is filled in.
def training_parameters():
    """Build the dictionary of training hyper-parameters and schedules.

    Keys:
        nts          number of diffusion time steps
        Beta         constant per-step noise variance
        t            normalized time (t/T), float32 tensor of length nts
        a_t          cumulative product of (1 - B_t), with a_t[0] = 1
        B_t          per-step variance schedule, float32 tensor of length nts
        lr, batchsize, save_every, record_every   training-loop settings
    """
    nts = 200
    Beta = 0.0050
    lr = 1E-4
    batchsize = 40
    save_every = 200
    record_every = 25

    B_t = np.ones((nts)) * Beta
    # a_t[I] = prod_{k < I} (1 - B_t[k]); the leading 1 is the empty product.
    a_t = np.concatenate(([1.0], np.cumprod(1.0 - B_t[:-1])))
    t = np.linspace(0, 1, nts)  # normalized time (t/T)

    lparas = dict()
    lparas["nts"] = nts
    lparas["Beta"] = Beta
    lparas["t"] = torch.tensor(t, dtype=torch.float32)
    lparas["a_t"] = torch.tensor(a_t, dtype=torch.float32)
    lparas["B_t"] = torch.tensor(B_t, dtype=torch.float32)
    lparas["lr"] = lr
    lparas["batchsize"] = batchsize
    lparas["save_every"] = save_every
    lparas["record_every"] = record_every
    return lparas


def generate_minibatch(imgs, Ndraw, lparas):
    """Draw random images and run the forward (noising) process on each.

    Args:
        imgs:   tensor of 28x28 images indexed along the first dimension.
        Ndraw:  number of images to draw (with replacement).
        lparas: parameter dictionary from ``training_parameters``.

    Returns:
        [imbatch, times] where imbatch has shape (Ndraw, nts, 28, 28) and
        holds the noising sequence of every drawn image (index 0 is the
        clean image), and times has shape (Ndraw, nts) holding the
        normalized time of every step.
    """
    nts = lparas["nts"]
    dev = imgs.device

    times = lparas["t"].to(dev).reshape([1, nts]).expand([Ndraw, nts])

    # Per-step variance, shaped (nts, 1, 1) so it broadcasts over the pixels.
    # (The previous kron + reshape only lined betas up with their time step
    # because the schedule is constant.)
    beta = lparas["B_t"].to(dev).reshape([nts, 1, 1])
    decay = torch.sqrt(1.0 - beta)
    sig = torch.sqrt(beta)

    draws = torch.randint(0, imgs.shape[0], (Ndraw,), device=dev)
    noise = torch.randn([Ndraw, nts, 28, 28], device=dev) * sig

    imbatch = torch.zeros([Ndraw, nts, 28, 28], dtype=torch.float32, device=dev)
    imbatch[:, 0] = imgs[draws]
    # The time recursion is sequential; the batch dimension is vectorized.
    for J in range(1, nts):
        imbatch[:, J] = imbatch[:, J - 1] * decay[J - 1] + noise[:, J - 1]

    return [imbatch, times]


def train_batch(revnet, batchimg, batchtime, lparas, optim=None):
    """Run one optimization step on a minibatch and return its loss.

    The network output at step j (scaled by Beta) is regressed onto the
    forward-process increment batchimg[:, j] - batchimg[:, j-1].

    Args:
        revnet:    the reverse-process network; called as
                   revnet(batchimg, batchtime).
        batchimg:  noising sequences from ``generate_minibatch``.
        batchtime: matching normalized times; its second dimension gives nts.
        lparas:    parameter dictionary (uses "Beta" and "lr").
        optim:     optimizer to reuse across calls. When None, a fresh Adam
                   optimizer is created for this call only, which discards
                   Adam's running moment estimates between calls.

    Returns:
        The MSE loss divided by Beta**2, as a Python float.
    """
    nts = batchtime.shape[1]
    beta = float(lparas["Beta"])

    if optim is None:
        optim = torch.optim.Adam(revnet.parameters(), lr=lparas["lr"])

    # Clear the gradients left on the model parameters by the previous step;
    # otherwise backward() keeps accumulating into them.
    optim.zero_grad()

    d_est = revnet(batchimg, batchtime)[:, 1:nts, :, :]
    targ = batchimg[:, 1:nts, :, :] - batchimg[:, 0:nts - 1, :, :]

    L2 = F.mse_loss(beta * d_est, targ)
    L2.backward()
    optim.step()

    return L2.item() / beta**2


def train(directory, saveseries):
    """Train the network indefinitely, recording losses and saving states.

    Every ``record_every`` minibatches the mean loss is recorded on the
    network and printed; every ``save_every`` minibatches (including the
    first) the network is saved to the next save file of the series.
    The loop only ends when the process is interrupted.
    """
    print("Loading dataset...")
    [images_train, _images_test] = mnist_load()

    lparas = training_parameters()
    batchsize = lparas["batchsize"]
    save_every = lparas["save_every"]
    record_every = lparas["record_every"]

    device = _pick_device()

    revnet = mndiff_rev()
    if not _load_latest(revnet, directory, saveseries):
        # Kept from the original flow; the returned name is not used here.
        get_next_save(directory, saveseries)

    images_train = images_train.to(device)
    revnet = revnet.to(device)

    # One optimizer for the whole run so Adam's state persists across steps.
    optim = torch.optim.Adam(revnet.parameters(), lr=lparas["lr"])

    ## Training Loop
    losses = []
    I = 0
    while True:
        if I % record_every == 0:
            print("Minibatch {}".format(I))

        # draw and train one minibatch
        [mb, mbt] = generate_minibatch(images_train, batchsize, lparas)
        losses.append(train_batch(revnet, mb, mbt, lparas, optim=optim))

        # record results
        if len(losses) == record_every:
            loss = np.mean(losses)
            losses = []
            revnet.recordloss(loss)
            print("\tloss={:1.5f}".format(loss))

        # save
        if I % save_every == 0:
            fnext = get_next_save(directory, saveseries)
            revnet.save(fnext)

        I = I + 1


def infer(directory, saveseries):
    """Sample one image by running the reverse process, then plot snapshots.

    Starts from Gaussian noise and applies the network for 15 * nts reverse
    steps, then shows the sequence at 90%, 75%, 25% and 0% of the way from
    the clean end (index 0) to the noise end.
    """
    ttt = int(time.time())
    np.random.seed(ttt % 50000)
    torch.manual_seed(ttt % 50000)

    lparas = training_parameters()

    revnet = mndiff_rev()
    if not _load_latest(revnet, directory, saveseries):
        print("Warning, no state loaded.")

    beta = lparas["Beta"]
    nts = lparas["nts"] * 15

    imgseq = torch.zeros([nts, 28, 28])
    imgseq[nts - 1, :, :] = torch.randn(28, 28)

    tl = torch.linspace(0, 1, nts)
    bb = torch.as_tensor(beta)

    # Sampling needs no gradients; without no_grad the autograd graph would
    # grow across all reverse steps.
    with torch.no_grad():
        for I in range(nts - 2, -1, -1):
            img = imgseq[I + 1, :, :]
            # NOTE(review): training pairs frame j with time t[j]; here the
            # frame at I+1 is paired with tl[I] - confirm this is intended.
            t = tl[I]
            eps = revnet(img, t)
            # NOTE(review): the injected noise is scaled by beta, whereas the
            # forward process uses sqrt(beta) - confirm this is intended.
            imgseq[I, :, :] = (img * torch.sqrt(1 - bb) - bb * eps
                               + 1.0 * bb * torch.randn(28, 28))

    # Snapshots ordered from noisiest to cleanest.
    fractions = (0.90, 0.75, 0.25, 0.0)
    plt.figure()
    for panel, frac in enumerate(fractions, start=1):
        snapshot = torch.squeeze(imgseq[int(nts * frac), :, :])
        plt.subplot(2, 2, panel)
        plt.imshow(snapshot.detach().numpy(), cmap='gray')
    plt.show()


def plot_history(directory, saveseries):
    """Load the current save of the series and plot its recorded history."""
    revnet = mndiff_rev()
    if not _load_latest(revnet, directory, saveseries):
        # Kept from the original flow; the returned name is not used here.
        get_next_save(directory, saveseries)

    revnet.plothistory()


def test1():
    """Visual check of ``generate_minibatch``: show steps 0, 1, 2 and 99."""
    [images_train, _images_test] = mnist_load()
    lparas = training_parameters()

    images_train = images_train.to(_pick_device())

    [mb, mbtime] = generate_minibatch(images_train, 50, lparas)

    print(mb.shape)
    print(mbtime.shape)

    plt.figure()
    for panel, step in enumerate((0, 1, 2, 99), start=1):
        frame = torch.squeeze(mb[0, step, :, :]).clone().detach().to("cpu")
        plt.subplot(2, 2, panel)
        plt.imshow(frame, cmap='gray')
    plt.show()


def test2(directory, saveseries):
    """Visual check of the network: apply it to one training image at t=0.5."""
    lparas = training_parameters()
    beta = lparas["Beta"]

    revnet = mndiff_rev()
    if not _load_latest(revnet, directory, saveseries):
        # Kept from the original flow; the returned name is not used here.
        get_next_save(directory, saveseries)

    [images_train, _images_test] = mnist_load()

    x = torch.squeeze(images_train[100, :, :])
    with torch.no_grad():
        y = revnet(x, 0.5)

    x = x.detach().numpy()
    y = y.detach().numpy()

    print("y mn={:1.3f} std={:1.3f} std2*beta={}".format(np.mean(y),np.std(y),np.std(y)**2*beta))

    plt.figure()
    plt.subplot(2, 1, 1)
    plt.imshow(x, cmap='gray')
    plt.subplot(2, 1, 2)
    plt.imshow(y, cmap='gray')
    plt.show()


##########
## Main ##
##########

def _print_usage():
    """Print the command-line usage summary."""
    print("mnistdiffusion [operation]")
    print("operations:")
    print("\ttrain")
    print("\tplot_history")
    print("\tinfer")


def mainswitch():
    """Dispatch on the first command-line argument.

    With no argument, prints usage and exits with status 0. With an
    unrecognized operation, prints usage and exits with status 1.
    """
    operations = {
        "train": train,
        "plot_history": plot_history,
        "infer": infer,
    }

    args = sys.argv
    if len(args) < 2:
        _print_usage()
        sys.exit(0)

    operation = operations.get(args[1])
    if operation is None:
        print("unknown operation: {}".format(args[1]))
        _print_usage()
        sys.exit(1)

    operation("./saves", "test_2a")


if __name__ == "__main__":
    ttt = int(time.time())
    np.random.seed(ttt % 50000)
    torch.manual_seed(ttt % 50000)

    #test1()
    #test2("./saves","test_2a")
    mainswitch()