#!/usr/bin/python3
"""MNIST forward-diffusion experiments.

Loads MNIST, builds a cosine variance schedule (Nichol & Dhariwal 2021) and
generates minibatches of forward-noised image trajectories for training a
reverse (denoising) model.
"""

import os
import sys
import math

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

import mnist_dataloader
from mnistdiffusion_model import mndiff_rev


################
## Dataloader ##
################

def mnist_load():
    """Load MNIST images and rescale pixel values from [0, 255] to [-1, 1].

    Labels returned by ``mnist_dataloader.mnist_load`` are discarded.

    Returns:
        [images_train, images_test]: float32 tensors with values in [-1, 1]
        and ``requires_grad`` disabled.
    """
    [labels_train, labels_test, images_train, images_test] = mnist_dataloader.mnist_load()
    # NOTE(review): original comments state 60000 x 28 x 28 images and
    # 60000 x 1 labels for the training split — confirm against the loader.

    images_train = images_train.clone().detach().float()
    images_test = images_test.clone().detach().float()
    # Rescale data to (-1, 1).
    images_train = 2.0 * (images_train / 255.0) - 1.0
    images_test = 2.0 * (images_test / 255.0) - 1.0
    images_train.requires_grad = False
    images_test.requires_grad = False

    return [images_train, images_test]


################
## Operations ##
################

def training_parameters(nts=200, s=0.1):
    """Build the cosine variance schedule used for forward diffusion.

    Ref: Nichol and Dhariwal 2021, "Improved Denoising Diffusion
    Probabilistic Models" (cosine schedule).

    Args:
        nts: number of diffusion timesteps (must be >= 2).
        s: small offset in the cosine argument, keeping beta away from 0
           near t = 0.

    Returns:
        dict with keys
            "nts": number of timesteps (int),
            "t":   timestep indices 0..nts-1 (float32 tensor),
            "a_t": cumulative signal fraction alpha_bar_t, normalised so that
                   a_t[0] == 1 (float32 tensor),
            "B_t": per-step variance; B_t[k] = 1 - a_t[k+1]/a_t[k] is the
                   variance added going from step k to k+1, clipped to
                   [0, 0.999]; the last entry repeats the previous one
                   (float32 tensor).

    Raises:
        ValueError: if nts < 2.
    """
    if nts < 2:
        raise ValueError(f"nts must be >= 2, got {nts}")

    lparas = dict()
    t = np.arange(0, nts)
    f_t = np.cos((t / nts + s) / (1 + s) * np.pi / 2.0) ** 2
    a_t = f_t / f_t[0]

    B_t = np.zeros(t.shape[0])
    # Clip so no step destroys the signal completely (beta -> 1 at the end).
    B_t[0:nts - 1] = np.clip(1 - a_t[1:nts] / a_t[0:nts - 1], 0, 0.999)
    # The final step has no successor in a_t; reuse the previous beta.
    B_t[nts - 1] = B_t[nts - 2]

    lparas["nts"] = nts
    lparas["t"] = torch.tensor(t, dtype=torch.float32)
    lparas["a_t"] = torch.tensor(a_t, dtype=torch.float32)
    lparas["B_t"] = torch.tensor(B_t, dtype=torch.float32)

    return lparas


def generate_minibatch(imgs, Ndraw, lparas):
    """Generate forward-diffusion trajectories for randomly drawn images.

    ``Ndraw`` images are drawn (with replacement) from ``imgs``. For each one
    a trajectory of ``nts`` states follows the DDPM forward process:

        x_0 = image
        x_J = sqrt(1 - B_t[J-1]) * x_{J-1} + sqrt(B_t[J-1]) * eps,
        eps ~ N(0, I)

    With this update x_J | x_0 ~ N(sqrt(a_t[J]) * x_0, (1 - a_t[J]) * I),
    consistent with the schedule built by ``training_parameters``.

    Args:
        imgs: image tensor of shape (N, *img_shape), e.g. (N, 28, 28).
        Ndraw: number of images to draw.
        lparas: schedule dict from ``training_parameters``; uses "nts" and
            "B_t".

    Returns:
        float32 tensor of shape (Ndraw * nts, *img_shape) on ``imgs.device``.
        Rows I*nts .. (I+1)*nts - 1 hold the trajectory of the I-th draw,
        ordered by timestep.
    """
    nts = lparas["nts"]
    dev = imgs.device
    img_shape = tuple(imgs.shape[1:])

    # Per-step variance broadcast over the image dimensions: (nts, 1, ..., 1).
    # (torch.kron(B_t, ones(28, 28)) gives a (28, 28*nts) tensor whose reshape
    # to (nts, 28, 28) mixes timesteps, so it cannot be used here.)
    beta = lparas["B_t"].to(device=dev, dtype=torch.float32)
    beta = beta.reshape((nts,) + (1,) * len(img_shape))
    sig = torch.sqrt(beta)          # noise std added at each step
    keep = torch.sqrt(1.0 - beta)   # signal decay at each step

    draws = torch.randint(0, imgs.shape[0], (Ndraw,), device=dev)
    traj = torch.zeros((Ndraw, nts) + img_shape, dtype=torch.float32, device=dev)
    traj[:, 0] = imgs[draws].float()

    # Time is inherently sequential, but all draws advance together.
    for J in range(1, nts):
        eps = torch.randn((Ndraw,) + img_shape, dtype=torch.float32, device=dev)
        traj[:, J] = keep[J - 1] * traj[:, J - 1] + sig[J - 1] * eps

    return traj.reshape((Ndraw * nts,) + img_shape)


def train_batch(imgbatch, lr, lparas):
    """Train the model on one minibatch. TODO: not implemented; returns None."""
    return


def train():
    """Run the training loop. TODO: not implemented; returns None."""
    return


def infer():
    """Sample images with the reverse model. TODO: not implemented; returns None."""
    return


def test1(plot_schedule=False):
    """Visual check: plot four states from the first generated trajectory.

    Shows timesteps 0, 1, 2 and 30 of the first draw (minibatch rows
    0, 1, 2, 30). Uses CUDA when available, otherwise the CPU.

    Args:
        plot_schedule: if True, first plot a_t and B_t against t.
    """
    [images_train, images_test] = mnist_load()
    lparas = training_parameters()

    if plot_schedule:
        plt.figure()
        plt.plot(lparas["t"], lparas["a_t"])
        plt.show()
        plt.figure()
        plt.plot(lparas["t"], lparas["B_t"])
        plt.show()

    dev = "cuda" if torch.cuda.is_available() else "cpu"
    images_train = images_train.to(dev)
    mb = generate_minibatch(images_train, 50, lparas)

    plt.figure()
    for panel, row in enumerate((0, 1, 2, 30), start=1):
        img = torch.squeeze(mb[row, :, :]).clone().detach().to("cpu")
        plt.subplot(2, 2, panel)
        plt.imshow(img, cmap='gray')
    plt.show()

    return


##########
## Main ##
##########

if __name__ == "__main__":
    # test_gaussian_timeselect() is not defined in this file (NameError);
    # run the available visual check instead.
    test1()