commit eb2c910e296f4d607aa4de816a7626458de83b25 Author: Aaron Date: Tue Feb 4 21:49:19 2025 -0500 Making some of my models public. diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..002a26d --- /dev/null +++ b/LICENSE @@ -0,0 +1,2 @@ +Copyright 2023, Aaron M. Schinder +Released under the MIT/BSD License diff --git a/attempt1/__pycache__/mnist_dataloader.cpython-36.pyc b/attempt1/__pycache__/mnist_dataloader.cpython-36.pyc new file mode 100644 index 0000000..2775a98 Binary files /dev/null and b/attempt1/__pycache__/mnist_dataloader.cpython-36.pyc differ diff --git a/attempt1/__pycache__/mnistdiffusion_model.cpython-36.pyc b/attempt1/__pycache__/mnistdiffusion_model.cpython-36.pyc new file mode 100644 index 0000000..a028052 Binary files /dev/null and b/attempt1/__pycache__/mnistdiffusion_model.cpython-36.pyc differ diff --git a/attempt1/__pycache__/mnistdiffusion_utils.cpython-36.pyc b/attempt1/__pycache__/mnistdiffusion_utils.cpython-36.pyc new file mode 100644 index 0000000..7b7d8bc Binary files /dev/null and b/attempt1/__pycache__/mnistdiffusion_utils.cpython-36.pyc differ diff --git a/attempt1/mnist_dataloader.py b/attempt1/mnist_dataloader.py new file mode 100644 index 0000000..be0dff3 --- /dev/null +++ b/attempt1/mnist_dataloader.py @@ -0,0 +1,91 @@ +#!/usr/bin/python3 +import os,sys,math +import numpy as np +import cv2 +import gzip #need to use gzip.open instead of open +import struct + +import torch + +def read_MNIST_label_file(fname): + #fp = gzip.open('./train-labels-idx1-ubyte.gz','rb'); + fp = gzip.open(fname,'rb'); + magic = fp.read(4); + #nitems = np.frombuffer(fp.read(4),dtype=np.int32)[0]; #some sort of endiannes problem + bts = fp.read(4); + #bts = bytereverse(bts); + #nitems = np.frombuffer(bts,dtype=np.int32); + nitems = np.int32(struct.unpack('>I',bts)[0]); #it was a non-native endianness in teh integer encoding + #> < @ - endianness + + bts = fp.read(nitems); + N = len(bts); + labels = np.zeros((N),dtype=np.uint8); + labels = np.frombuffer(bts,dtype=np.uint8,count=N); + #for i in range(0,10): + # bt = fp.read(1); + # labels[i] = np.frombuffer(bt,dtype=np.uint8); + fp.close(); + return labels; + +def read_MNIST_image_file(fname): + fp = gzip.open(fname,'rb'); + magic = fp.read(4); + bts = fp.read(4); + nitems = np.int32(struct.unpack('>I',bts)[0]); + bts = fp.read(4); + nrows = np.int32(struct.unpack('>I',bts)[0]); + bts = fp.read(4); + ncols = np.int32(struct.unpack('>I',bts)[0]); + + images = np.zeros((nitems,nrows,ncols),dtype=np.uint8); + for I in range(0,nitems): + bts = fp.read(nrows*ncols); + img1 = np.frombuffer(bts,dtype=np.uint8,count=nrows*ncols); + img1 = img1.reshape((nrows,ncols)); + images[I,:,:] = img1; + + fp.close(); + + return images; + +#The mnist dataset is small enough to fit entirely in memory +def mnist_load(): + baseloc = "../training_data" + + traindatafile = "train-images-idx3-ubyte.gz" + trainlabelfile = "train-labels-idx1-ubyte.gz" + testdatafile = "t10k-images-idx3-ubyte.gz" + testlabelfile = "t10k-labels-idx1-ubyte.gz" + + traindatafile = os.path.join(baseloc,traindatafile) + trainlabelfile = os.path.join(baseloc,trainlabelfile) + testdatafile = os.path.join(baseloc,testdatafile) + testlabelfile = os.path.join(baseloc,testlabelfile) + + labels_train = read_MNIST_label_file(trainlabelfile) + labels_test = read_MNIST_label_file(testlabelfile) + images_train = read_MNIST_image_file(traindatafile) + images_test = read_MNIST_image_file(testdatafile) + + labels_train = torch.tensor(labels_train,dtype=torch.float32,requires_grad=False) + labels_test = torch.tensor(labels_test,dtype=torch.float32,requires_grad=False) + images_train = torch.tensor(images_train,dtype=torch.float32,requires_grad=False) + images_test = torch.tensor(images_test,dtype=torch.float32,requires_grad=False) + + # #debug + # print(labels_train.shape) + # print(labels_test.shape) + # print(images_train.shape) + # print(images_test.shape) + + + return [labels_train, labels_test, images_train, images_test] + +if(__name__ == "__main__"): + [labels_train, labels_test, images_train, images_test] = mnist_load() + print("Loaded MNIST Data") + print(labels_train.shape) + print(labels_test.shape) + print(images_train.shape) + print(images_test.shape) diff --git a/attempt1/mnistdiffusion.py b/attempt1/mnistdiffusion.py new file mode 100644 index 0000000..07aaf0d --- /dev/null +++ b/attempt1/mnistdiffusion.py @@ -0,0 +1,401 @@ +#!/usr/bin/python3 + +import os,sys,math +import numpy as np + +import matplotlib.pyplot as plt + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import mnist_dataloader + +from mnistdiffusion_model import mndiff_rev +from mnistdiffusion_utils import * + +import time + +################ +## Dataloader ## +################ + +def mnist_load(): + + [labels_train, labels_test, images_train, images_test] = mnist_dataloader.mnist_load() + + #60000 x 28 x 28 + #60000 x 1 + + #Rescale data to (-1,1) + images_train = images_train.clone().detach().float() + images_test = images_test.clone().detach().float() + + images_train = 2.0*(images_train/255.0)-1.0 + images_test = 2.0*(images_test/255.0)-1.0 + + images_train.requires_grad = False + images_test.requires_grad = False + + return [images_train, images_test] + + + + +################ +## Operations ## +################ + +# Rather than killing myself figuring out variable variance schedules, +# just use a constant variance schedule to start with + +def training_parameters(): + lparas = dict() + + nts = 200 + Beta = 0.0050 + + lr = 1E-4 + batchsize = 80 + save_every = 200 + record_every = 25 + + B_t = np.ones((nts))*Beta + a_t = np.zeros((nts)) + a_t[0] = 1 + for I in range(1,nts): + a_t[I] = a_t[I-1]*(1-B_t[I-1]) + + t = np.linspace(0,1,nts) #normalized time (t/T) + + lparas["nts"] = nts + lparas["Beta"] = Beta + lparas["t"] = torch.tensor(t,dtype=torch.float32) + lparas["a_t"] = torch.tensor(a_t,dtype=torch.float32) + lparas["B_t"] = torch.tensor(B_t,dtype=torch.float32) + lparas["lr"] = lr + lparas["batchsize"] = batchsize + lparas["save_every"] = save_every + lparas["record_every"] = record_every + return lparas + + +def generate_minibatch(imgs, Ndraw, lparas): + + #I can probably speed this up with compiled CUDA code + + nts = lparas["nts"] + Beta = lparas["Beta"] + + dev = imgs.device + + imbatch = torch.zeros([Ndraw,nts,28,28],dtype=torch.float32).to(dev) + draws = torch.randint(0,imgs.shape[0],(Ndraw,)) + + times = lparas["t"].to(dev) + times = times.reshape([1,nts]).expand([Ndraw,nts]) + + for I in range(0,Ndraw): + imgseq = torch.zeros([nts,28,28],dtype=torch.float32).to(dev) + imgseq[0,:,:] = imgs[draws[I],:,:] + + beta = torch.kron(lparas["B_t"],torch.ones(28,28)).to(dev) + beta = beta.reshape([nts,28,28]) + + sig = torch.sqrt(beta) + noise = torch.randn((nts,28,28)).to(dev)*sig + + for J in range(1,nts): + imgseq[J,:,:] = imgseq[J-1,:,:]*torch.sqrt(1.0-beta[J-1,:,:]) + noise[J-1,:,:] + + + imbatch[I,:,:,:] = imgseq[:,:,:] + + return [imbatch,times] + +def train_batch(revnet,batchimg,batchtime,lparas): + + L = 0.0 + lr = lparas["lr"] + nts = batchtime.shape[1] + + dev = revnet.device + #beta = torch.as_tensor(lparas["Beta"]).float().to(dev) + beta = float(lparas["Beta"]) + + optim = torch.optim.SGD(revnet.parameters(),lr=lr) + optim = torch.optim.Adam(revnet.parameters(),lr=lr) + + loss = torch.nn.MSELoss().to(dev) + + d_est = revnet(batchimg,batchtime) + d_est = d_est[:,1:nts,:,:] + targ = batchimg[:,1:nts,:,:]-batchimg[:,0:nts-1,:,:] + + L2 = loss(targ,beta*d_est) + + L2.backward() + + optim.step() + L = L2.detach().cpu().numpy() + L = L/beta**2 + + loss.zero_grad() + + + return L + + + + +def train(directory,saveseries): + + print("Loading dataset...") + [images_train, images_test] = mnist_load() + lparas = training_parameters() + + batchsize = lparas["batchsize"] + save_every = lparas["save_every"] + record_every = lparas["record_every"] + device = torch.device("cuda") + + revnet = mndiff_rev() + fname = get_current_save(directory,saveseries) + if(fname is None): + fname = get_next_save(directory,saveseries) + else: + print("Loading {}".format(fname)) + revnet.load(fname) + + images_train = images_train.to(device) + revnet = revnet.to(device) + + ## Training Loop + I = 0 + J = 0 + while(True): + + if(I%record_every==0): + print("Minibatch {}".format(I)) + + #New minibatch + losses = np.zeros((record_every)) + + + #draw minibatch + [mb,mbt] = generate_minibatch(images_train,batchsize,lparas) + + #train minibatch + L = train_batch(revnet,mb,mbt,lparas) + losses[J] = L + J = J + 1 + + #record results + if(J%record_every==0): + J = 0 + loss = np.mean(losses) + revnet.recordloss(loss) + print("\tloss={:1.5f}".format(loss)) + + #save + if(I%save_every==0): + fnext = get_next_save(directory,saveseries) + revnet.save(fnext) + + I = I + 1 + + return + +def infer(directory,saveseries): + + ttt = int(time.time()) + np.random.seed(ttt%50000) + torch.manual_seed(ttt%50000) + + + + lparas = training_parameters() + revnet = mndiff_rev() + fname = get_current_save(directory,saveseries) + if(fname is None): + print("Warning, no state loaded.") + else: + print("Loading {}".format(fname)) + revnet.load(fname) + + + nts = lparas["nts"] + beta = lparas["Beta"] + + nts = nts*15 + + imgseq = torch.zeros([nts,28,28]) + imgseq[nts-1,:,:] = torch.randn(28,28) + + #imgseq[nts-1,:,17:20] = 3.0 + + tl = torch.linspace(0,1,nts) + + bb = torch.as_tensor(beta) + + for I in range(nts-2,-1,-1): + img = imgseq[I+1,:,:] + t = tl[I] + eps = revnet(img,t) + + + + img2 = img*torch.sqrt(1-bb) - bb*eps + 0.5*bb*torch.randn(28,28) + + imgseq[I,:,:] = img2 + + im0 = torch.squeeze(imgseq[0,:,:]) + im1 = torch.squeeze(imgseq[int(nts*0.25),:,:]) + im2 = torch.squeeze(imgseq[int(nts*0.75),:,:]) + im3 = torch.squeeze(imgseq[int(nts*0.90),:,:]) + + im0 = im0.detach().numpy() + im1 = im1.detach().numpy() + im2 = im2.detach().numpy() + im3 = im3.detach().numpy() + + plt.figure() + plt.subplot(2,2,1) + plt.imshow(im3,cmap='gray') + plt.subplot(2,2,2) + plt.imshow(im2,cmap='gray') + plt.subplot(2,2,3) + plt.imshow(im1,cmap='gray') + plt.subplot(2,2,4) + plt.imshow(im0,cmap='gray') + plt.show() + + + return + +def plot_history(directory, saveseries): + + revnet = mndiff_rev() + fname = get_current_save(directory,saveseries) + if(fname is None): + fname = get_next_save(directory,saveseries) + else: + print("Loading {}".format(fname)) + revnet.load(fname) + + revnet.plothistory() + + return + +def test1(): + + [images_train, images_test] = mnist_load() + lparas = training_parameters() + + #plt.figure() + #plt.plot(lparas["t"],lparas["a_t"]) + #plt.show() + + #plt.figure() + #plt.plot(lparas["t"],lparas["B_t"]) + #plt.show() + + images_train = images_train.to("cuda") + + [mb,mbtime] = generate_minibatch(images_train,50,lparas) + + img0 = torch.squeeze(mb[0,0,:,:]).clone().detach().to("cpu") + img1 = torch.squeeze(mb[0,1,:,:]).clone().detach().to("cpu") + img2 = torch.squeeze(mb[0,2,:,:]).clone().detach().to("cpu") + img3 = torch.squeeze(mb[0,99,:,:]).clone().detach().to("cpu") + + print(mb.shape) + print(mbtime.shape) + + plt.figure() + plt.subplot(2,2,1) + plt.imshow(img0,cmap='gray') + plt.subplot(2,2,2) + plt.imshow(img1,cmap='gray') + plt.subplot(2,2,3) + plt.imshow(img2,cmap='gray') + plt.subplot(2,2,4) + plt.imshow(img3,cmap='gray') + + plt.show() + + return + +def test2(directory, saveseries): + + lparas = training_parameters() + nts = lparas["nts"] + beta = lparas["Beta"] + + revnet = mndiff_rev() + fname = get_current_save(directory,saveseries) + if(fname is None): + fname = get_next_save(directory,saveseries) + else: + print("Loading {}".format(fname)) + revnet.load(fname) + + + [images_train, images_test] = mnist_load() + + x = torch.squeeze(images_train[100,:,:]) + #x = beta*torch.randn(28,28) + y = revnet(x,0.5) + + x = x.detach().numpy() + y = y.detach().numpy() + + print("y mn={:1.3f} std={:1.3f} std2*beta={}".format(np.mean(y),np.std(y),np.std(y)**2*beta)) + + + plt.figure() + plt.subplot(2,1,1) + plt.imshow(x,cmap='gray') + plt.subplot(2,1,2) + plt.imshow(y,cmap='gray') + plt.show() + + + return + +########## +## Main ## +########## + +def mainswitch(): + args = sys.argv + if(len(args)<2): + print("mnistdiffusion [operation]") + print("operations:") + print("\ttrain") + print("\tplot_history") + print("\tinfer") + + exit(0) + + if(sys.argv[1]=="train"): + train("./saves","test_f") + if(sys.argv[1]=="plot_history"): + plot_history("./saves","test_f") + if(sys.argv[1]=="infer"): + infer("./saves","test_f") + + + + + +if(__name__=="__main__"): + + ttt = int(time.time()) + np.random.seed(ttt%50000) + torch.manual_seed(ttt%50000) + + #test1() + #test2("./saves","test_d") + mainswitch() + diff --git a/attempt1/mnistdiffusion_model.py b/attempt1/mnistdiffusion_model.py new file mode 100644 index 0000000..705f5f1 --- /dev/null +++ b/attempt1/mnistdiffusion_model.py @@ -0,0 +1,289 @@ +#!/usr/bin/python3 + +import os,sys,math +import numpy as np + +import matplotlib.pyplot as plt + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import mnist_dataloader + + +## Sohl-Dickstein 2015 describes a gaussian "bump" function to +# implement learnable time-dependence +# +#network input vector and normalized time (tn e (0,1)) +# input: (dimx, dimy, ..., features), (scalar) --> output: (dimx, dimy, ...) +# input: (ntimes, ..., features), (ntimes) --> output: (ntimes , ...) +class gaussian_timeselect(nn.Module): + + def __init__(self, nfeatures=5, width = 0.25): + super(gaussian_timeselect,self).__init__() + + tau_j = torch.linspace(0,1,nfeatures).float() + self.tau_j = torch.nn.parameter.Parameter(tau_j, requires_grad=False) + + width = torch.tensor(width,dtype=torch.float32) + self.width = torch.nn.parameter.Parameter(width,requires_grad = False) + + #network input vector and normalized time (tn e (0,1)) + # input: (dimx, dimy, ..., features), (scalar) --> output: (dimx, dimy, ...) + # input: (ntimes, ..., features), (ntimes) --> output: (ntimes , ...) + def forward(self, x, tn): + + ndim = len(x.shape) + + tn = torch.as_tensor(tn) + tn.requires_grad = False + + + if(len(tn.shape)==0): + #case 0 - if tn is a scalar + g = torch.exp(-1.0/(2.0*self.width)*(tn-self.tau_j)**2) + g = g/torch.clip(torch.sum(g),min=0.001) + x = torch.tensordot(x,g,dims=[[ndim-1],[0]]) + else: + #Case 1 - tn is a tensor + nf = self.tau_j.shape[0] + nt = tn.numel() + nt2 = x.shape[0] + ndim = len(x.shape) + nf2 = x.shape[ndim-1] + + if(nt!=nt2): + raise Exception("gaussian_timeselect time mismatch") + if(nf!=nf2): + raise Exception("gaussian_timeselect feature mismatch") + + shp2 = x.shape[1:ndim-1] + nshp2 = shp2.numel() + + x = x.reshape([nt,nshp2,nf]) + tau = self.tau_j.reshape([1,1,nf]) + tau = tau.expand([nt,nshp2,nf]) + tn = tn.reshape([nt,1,1]) + tn = tn.expand([nt,nshp2,nf]) + + g = torch.exp(-1.0/(2.0*self.width)*(tn-tau)**2) + gs = torch.clip(torch.sum(g,axis=2),min=0.0001) + gs = gs.reshape([nt,nshp2,1]).expand([nt,nshp2,nf]) + g = g/gs + + x = torch.sum(x*g,axis=2) + #shp3 = torch.Size(torch.cat(torch.tensor([nt]),torch.tensor(shp2))) + # contact [nt, *shp2] + shp3 = torch.Size([nt,*shp2]) + x = x.reshape(shp3) + + return x + +#The reverse network should output a mean image and a covariance image +#given an input image + +class mndiff_rev(nn.Module): + + def __init__(self): + super(mndiff_rev,self).__init__() + + self.L1 = nn.Linear(28*28,1000,bias=True) + self.L2 = nn.Linear(1000,1000,bias=True) + self.L3 = nn.Linear(1000,28*28*5,bias=True) + + #self.L4 = nn.Conv2d(10,10,kernel_size=1) + #channels must be 2nd dimension + self.L4a = nn.Linear(5,5) + #equivalent, channels must be last dimension + self.L4b = nn.Linear(5,5) + + + self.L5 = gaussian_timeselect(5,0.1) + + #self.L5alt = nn.Linear(5,1) + + #self.RL = nn.LeakyReLU(0.01) + self.tanh = nn.Tanh() + + self.device = torch.device("cpu") + + self.losstime = [] + self.lossrecord = [] + + return + + + def forward(self,x,t): + + #handle data-shape cases + #case 0: x (pix_x, pix_y), t is scalar + #case 1: x (ntimes, pix_x, pix_y), t is (ntimes) + #case 2: x (nbatch, ntimes, pix_x, pix_y), t is (nbatch, ntimes) + + if(len(x.shape)==2): + case = 0 + t = torch.as_tensor(t).to(self.device) + t = t.float() + x = x.reshape([1,28*28]) + ntimes = 1 + nbatch = 1 + elif(len(x.shape)==3): + case = 1 + ntimes = x.shape[0] + nbatch = 1 + x = x.reshape([ntimes,28*28]) + elif(len(x.shape)==4): + case = 2 + nbatch = x.shape[0] + ntimes = x.shape[1] + #x = x.reshape([nbatch,ntimes,28*28]) + x = x.reshape([nbatch*ntimes,28*28]) + t = t.reshape([nbatch*ntimes]) + + y = torch.randn(28*28,device=self.device) + + else: + print("Error: Expect input to be dimension 2,3,4") + raise Exception("Error: Expect input to be dimension 2,3,4") + + s1 = self.L1(x) + s1 = self.tanh(s1) + s1 = self.L2(s1) + s1 = self.tanh(s1) + s1 = self.L3(s1) + s1 = self.tanh(s1) + s1 = s1.reshape([ntimes*nbatch,28*28,5]) + s1 = self.L4a(s1) + s1 = self.tanh(s1) + s1 = self.L4b(s1) + #s1 = self.tanh(s1) + + s2 = self.L5(s1,t) + #s2 = self.L5alt(s1) + + eps = s2 + #eps = self.tanh(s2) + + if(case==0): + eps = eps.reshape([28,28]) + elif(case==1): + eps = eps.reshape([ntimes,28,28]) + elif(case==2): + eps = eps.reshape([nbatch,ntimes,28,28]) + + return eps + + def to(self,dev): + self.device = torch.device(dev) + ret = super(mndiff_rev,self).to(self.device) + return ret + + + def save(self,fname): + try: + dev = self.device + self.to("cpu") + q = [self.state_dict(), self.losstime, self.lossrecord] + torch.save(q,fname) + self.to(dev) + except: + print("model save: problem saving {}".format(fname)) + + return + + def load(self,fname): + try: + [modelsd,losstime,lossrecord] = torch.load(fname) + self.load_state_dict(modelsd) + self.losstime = losstime + self.lossrecord = lossrecord + except: + print("model load: problem loading {}".format(fname)) + + return + + + def plothistory(self): + x = self.losstime + y = self.lossrecord + plt.figure() + plt.plot(x,y,'k.',markersize=1) + plt.title('Loss History') + plt.show() + + return + + def recordloss(self,loss): + + t = 0 + L = len(self.losstime) + if(L>0): + t = self.losstime[L-1] + self.losstime.append(t+1) + self.lossrecord.append(loss) + + return + + + +########### +## Tests ## +########### + +def test_gaussian_timeselect(): + gts = gaussian_timeselect(5,0.025) + gts = gts.to("cuda") + + x = torch.randn((8,8,5)).to("cuda") + #t = torch.tensor([0,0.5,1]).to("cuda") + t = torch.linspace(0,1,8).to("cuda") + + y = gts(x,0) + y2 = gts(x,t) + y3 = gts(x,0.111) + + print(y) + print(x[:,:,0]) + + print(x.shape) + print(y.shape) + print(y2.shape) + + print(torch.abs(y2[0,:]-y[0,:])>1E-6) + + + + return + +def test_mndiff_rev(): + + mnd = mndiff_rev() + + x0 = torch.randn(28,28) + t0 = 0 + x1 = torch.randn(5,28,28) + t1 = torch.linspace(0,1,5) + + x2 = torch.randn(8,5,28,28) + t2 = torch.linspace(0,1,5) + t2 = t2.reshape([1,5]).expand([8,5]) + + y0 = mnd(x0,t0) + y1 = mnd(x1,t1) + y2 = mnd(x2,t2) + + print(y0.shape) + print(y1.shape) + print(y2.shape) + + + + return + +if(__name__=="__main__"): + + test_gaussian_timeselect() + #test_mndiff_rev() + + pass \ No newline at end of file diff --git a/attempt1/mnistdiffusion_utils.py b/attempt1/mnistdiffusion_utils.py new file mode 100644 index 0000000..bdc488e --- /dev/null +++ b/attempt1/mnistdiffusion_utils.py @@ -0,0 +1,78 @@ +#!/usr/bin/python3 + +import os,sys,math + + +def get_savenumber(fname): + num = 0 + + fn2 = os.path.split(fname)[1] + fn2 = os.path.splitext(fn2)[0] + + L = len(fn2) + L2 = L + for I in range(L-1,0,-1): + c = fn2[I] + if(not c.isnumeric()): + L2 = I+1 + break + nm = fn2[L2:L] + + try: + num = int(nm) + except: + num = 0 + + return num + + +def sorttwopythonlists(list1, list2): + + zipped_pairs = zip(list2, list1) + + z = [x for _, x in sorted(zipped_pairs)] + z2 = sorted(list2) + + return [z,z2] + +def get_current_save(directory,saveseries): + fname = None + + fl = os.listdir(directory) + fl2 = [] + nums = [] + + for f in fl: + if(f.find(saveseries)>=0): + fl2.append(os.path.join(directory,f)) + + for f in fl2: + n = get_savenumber(f) + nums.append(n) + + [fl2,nums] = sorttwopythonlists(fl2,nums) + + if(len(fl2)>0): + fname = fl2[len(fl2)-1] + else: + fname = None + + return fname + +def get_next_save(directory,saveseries): + fname = None + fncurrent = get_current_save(directory,saveseries) + if(fncurrent is None): + fname = os.path.join(directory,"{}{:02d}.pyt".format(saveseries,0)) + else: + N = get_savenumber(fncurrent) + N = N + 1 + fname = os.path.join(directory,"{}{:02d}.pyt".format(saveseries,N)) + + + return fname + +if(__name__=="__main__"): + + #print(get_savenumber("./saves/helloworld05202.py")) + pass diff --git a/attempt1/old/mnistdiffusion.py b/attempt1/old/mnistdiffusion.py new file mode 100644 index 0000000..45ed2b2 --- /dev/null +++ b/attempt1/old/mnistdiffusion.py @@ -0,0 +1,161 @@ +#!/usr/bin/python3 + +import os,sys,math +import numpy as np + +import matplotlib.pyplot as plt + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import mnist_dataloader + +from mnistdiffusion_model import mndiff_rev + +################ +## Dataloader ## +################ + +def mnist_load(): + + [labels_train, labels_test, images_train, images_test] = mnist_dataloader.mnist_load() + + #60000 x 28 x 28 + #60000 x 1 + + #Rescale data to (-1,1) + images_train = images_train.clone().detach().float() + images_test = images_test.clone().detach().float() + + images_train = 2.0*(images_train/255.0)-1.0 + images_test = 2.0*(images_test/255.0)-1.0 + + images_train.requires_grad = False + images_test.requires_grad = False + + return [images_train, images_test] + + + + +################ +## Operations ## +################ + +def training_parameters(): + lparas = dict() + + nts = 200 + t = np.arange(0,nts) + s = 0.1 + f_t = np.cos((t/nts+s)/(1+s)*np.pi/2.0)**2 + a_t = f_t/f_t[0] + B_t = np.zeros(t.shape[0]) + B_t[0:nts-1] = np.clip(1-a_t[1:nts]/a_t[0:nts-1],0,0.999) + B_t[nts-1] = B_t[nts-2] + + lparas["nts"] = nts + lparas["t"] = torch.tensor(t,dtype=torch.float32) + lparas["a_t"] = torch.tensor(a_t,dtype=torch.float32) + lparas["B_t"] = torch.tensor(B_t,dtype=torch.float32) + + #Traiing parameters for a cosine variance schedule + #ref: Nichol and Dharawal 2021 + + return lparas + + +def generate_minibatch(imgs, Ndraw, lparas): + + #I can probably speed this up with compiled CUDA code + + nts = lparas["nts"] + dev = imgs.device + + imbatch = torch.zeros([Ndraw*nts,28,28],dtype=torch.float32).to(dev) + draws = torch.randint(0,imgs.shape[0],(Ndraw,)) + + for I in range(0,Ndraw): + imgseq = torch.zeros([nts,28,28],dtype=torch.float32).to(dev) + imgseq[0,:,:] = imgs[draws[I],:,:] + + beta = torch.kron(lparas["B_t"],torch.ones(28,28)).to(dev) + beta = beta.reshape([nts,28,28]) + + sig = torch.sqrt(beta) + noise = torch.randn((nts,28,28)).to(dev)*sig + + for J in range(1,nts): + imgseq[J,:,:] = imgseq[J-1,:,:] + noise[J-1,:,:] + + I1 = I*nts + I2 = (I+1)*nts + imbatch[I1:I2,:,:] = imgseq[:,:,:] + + return imbatch + +def train_batch(imgbatch,lr,lparas): + + + + + return + + +def train(): + + return + +def infer(): + + return + + + +def test1(): + + [images_train, images_test] = mnist_load() + lparas = training_parameters() + + #plt.figure() + #plt.plot(lparas["t"],lparas["a_t"]) + #plt.show() + + #plt.figure() + #plt.plot(lparas["t"],lparas["B_t"]) + #plt.show() + + images_train = images_train.to("cuda") + + mb = generate_minibatch(images_train,50,lparas) + + img0 = torch.squeeze(mb[0,:,:]).clone().detach().to("cpu") + img1 = torch.squeeze(mb[1,:,:]).clone().detach().to("cpu") + img2 = torch.squeeze(mb[2,:,:]).clone().detach().to("cpu") + img3 = torch.squeeze(mb[30,:,:]).clone().detach().to("cpu") + + + plt.figure() + plt.subplot(2,2,1) + plt.imshow(img0,cmap='gray') + plt.subplot(2,2,2) + plt.imshow(img1,cmap='gray') + plt.subplot(2,2,3) + plt.imshow(img2,cmap='gray') + plt.subplot(2,2,4) + plt.imshow(img3,cmap='gray') + + plt.show() + + return + + +########## +## Main ## +########## + +if(__name__=="__main__"): + + #test1() + test_gaussian_timeselect() diff --git a/attempt1/scratchpad.ipynb b/attempt1/scratchpad.ipynb new file mode 100644 index 0000000..8968d3c --- /dev/null +++ b/attempt1/scratchpad.ipynb @@ -0,0 +1,85 @@ +{ + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9-final" + }, + "orig_nbformat": 2, + "kernelspec": { + "name": "python3", + "display_name": "Python 3.6.9 64-bit", + "metadata": { + "interpreter": { + "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 2, + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os,sys,math\n", + "import numpy as np\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "import mnist_dataloader\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([])\ntorch.Size([])\n[1, 2, 3]\n" + ] + } + ], + "source": [ + "a = torch.randn(10,14)\n", + "b = a.shape[1:1]\n", + "print(b)\n", + "b.numel()\n", + "\n", + "print(b)\n", + "\n", + "b = torch.Size([1])\n", + "c = torch.Size([2,3])\n", + "d = torch.Size(torch.cat([torch.tensor(b),torch.tensor(c)]))\n", + "\n", + "d = [*b,*c]\n", + "\n", + "print(d)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ] +} \ No newline at end of file diff --git a/attempt2/__pycache__/mnist_dataloader.cpython-36.pyc b/attempt2/__pycache__/mnist_dataloader.cpython-36.pyc new file mode 100644 index 0000000..85ab9db Binary files /dev/null and b/attempt2/__pycache__/mnist_dataloader.cpython-36.pyc differ diff --git a/attempt2/__pycache__/mnistdiffusion_model.cpython-36.pyc b/attempt2/__pycache__/mnistdiffusion_model.cpython-36.pyc new file mode 100644 index 0000000..be927b4 Binary files /dev/null and b/attempt2/__pycache__/mnistdiffusion_model.cpython-36.pyc differ diff --git a/attempt2/__pycache__/mnistdiffusion_utils.cpython-36.pyc b/attempt2/__pycache__/mnistdiffusion_utils.cpython-36.pyc new file mode 100644 index 0000000..15e1a7f Binary files /dev/null and b/attempt2/__pycache__/mnistdiffusion_utils.cpython-36.pyc differ diff --git a/attempt2/mnist_dataloader.py b/attempt2/mnist_dataloader.py new file mode 100644 index 0000000..be0dff3 --- /dev/null +++ b/attempt2/mnist_dataloader.py @@ -0,0 +1,91 @@ +#!/usr/bin/python3 +import os,sys,math +import numpy as np +import cv2 +import gzip #need to use gzip.open instead of open +import struct + +import torch + +def read_MNIST_label_file(fname): + #fp = gzip.open('./train-labels-idx1-ubyte.gz','rb'); + fp = gzip.open(fname,'rb'); + magic = fp.read(4); + #nitems = np.frombuffer(fp.read(4),dtype=np.int32)[0]; #some sort of endiannes problem + bts = fp.read(4); + #bts = bytereverse(bts); + #nitems = np.frombuffer(bts,dtype=np.int32); + nitems = np.int32(struct.unpack('>I',bts)[0]); #it was a non-native endianness in teh integer encoding + #> < @ - endianness + + bts = fp.read(nitems); + N = len(bts); + labels = np.zeros((N),dtype=np.uint8); + labels = np.frombuffer(bts,dtype=np.uint8,count=N); + #for i in range(0,10): + # bt = fp.read(1); + # labels[i] = np.frombuffer(bt,dtype=np.uint8); + fp.close(); + return labels; + +def read_MNIST_image_file(fname): + fp = gzip.open(fname,'rb'); + magic = fp.read(4); + bts = fp.read(4); + nitems = np.int32(struct.unpack('>I',bts)[0]); + bts = fp.read(4); + nrows = np.int32(struct.unpack('>I',bts)[0]); + bts = fp.read(4); + ncols = np.int32(struct.unpack('>I',bts)[0]); + + images = np.zeros((nitems,nrows,ncols),dtype=np.uint8); + for I in range(0,nitems): + bts = fp.read(nrows*ncols); + img1 = np.frombuffer(bts,dtype=np.uint8,count=nrows*ncols); + img1 = img1.reshape((nrows,ncols)); + images[I,:,:] = img1; + + fp.close(); + + return images; + +#The mnist dataset is small enough to fit entirely in memory +def mnist_load(): + baseloc = "../training_data" + + traindatafile = "train-images-idx3-ubyte.gz" + trainlabelfile = "train-labels-idx1-ubyte.gz" + testdatafile = "t10k-images-idx3-ubyte.gz" + testlabelfile = "t10k-labels-idx1-ubyte.gz" + + traindatafile = os.path.join(baseloc,traindatafile) + trainlabelfile = os.path.join(baseloc,trainlabelfile) + testdatafile = os.path.join(baseloc,testdatafile) + testlabelfile = os.path.join(baseloc,testlabelfile) + + labels_train = read_MNIST_label_file(trainlabelfile) + labels_test = read_MNIST_label_file(testlabelfile) + images_train = read_MNIST_image_file(traindatafile) + images_test = read_MNIST_image_file(testdatafile) + + labels_train = torch.tensor(labels_train,dtype=torch.float32,requires_grad=False) + labels_test = torch.tensor(labels_test,dtype=torch.float32,requires_grad=False) + images_train = torch.tensor(images_train,dtype=torch.float32,requires_grad=False) + images_test = torch.tensor(images_test,dtype=torch.float32,requires_grad=False) + + # #debug + # print(labels_train.shape) + # print(labels_test.shape) + # print(images_train.shape) + # print(images_test.shape) + + + return [labels_train, labels_test, images_train, images_test] + +if(__name__ == "__main__"): + [labels_train, labels_test, images_train, images_test] = mnist_load() + print("Loaded MNIST Data") + print(labels_train.shape) + print(labels_test.shape) + print(images_train.shape) + print(images_test.shape) diff --git a/attempt2/mnistdiffusion.py b/attempt2/mnistdiffusion.py new file mode 100644 index 0000000..0c37dc5 --- /dev/null +++ b/attempt2/mnistdiffusion.py @@ -0,0 +1,401 @@ +#!/usr/bin/python3 + +import os,sys,math +import numpy as np + +import matplotlib.pyplot as plt + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import mnist_dataloader + +from mnistdiffusion_model import mndiff_rev +from mnistdiffusion_utils import * + +import time + +################ +## Dataloader ## +################ + +def mnist_load(): + + [labels_train, labels_test, images_train, images_test] = mnist_dataloader.mnist_load() + + #60000 x 28 x 28 + #60000 x 1 + + #Rescale data to (-1,1) + images_train = images_train.clone().detach().float() + images_test = images_test.clone().detach().float() + + images_train = 2.0*(images_train/255.0)-1.0 + images_test = 2.0*(images_test/255.0)-1.0 + + images_train.requires_grad = False + images_test.requires_grad = False + + return [images_train, images_test] + + + + +################ +## Operations ## +################ + +# Rather than killing myself figuring out variable variance schedules, +# just use a constant variance schedule to start with + +def training_parameters(): + lparas = dict() + + nts = 200 + Beta = 0.0050 + + lr = 1E-4 + batchsize = 40 + save_every = 200 + record_every = 25 + + B_t = np.ones((nts))*Beta + a_t = np.zeros((nts)) + a_t[0] = 1 + for I in range(1,nts): + a_t[I] = a_t[I-1]*(1-B_t[I-1]) + + t = np.linspace(0,1,nts) #normalized time (t/T) + + lparas["nts"] = nts + lparas["Beta"] = Beta + lparas["t"] = torch.tensor(t,dtype=torch.float32) + lparas["a_t"] = torch.tensor(a_t,dtype=torch.float32) + lparas["B_t"] = torch.tensor(B_t,dtype=torch.float32) + lparas["lr"] = lr + lparas["batchsize"] = batchsize + lparas["save_every"] = save_every + lparas["record_every"] = record_every + return lparas + + +def generate_minibatch(imgs, Ndraw, lparas): + + #I can probably speed this up with compiled CUDA code + + nts = lparas["nts"] + Beta = lparas["Beta"] + + dev = imgs.device + + imbatch = torch.zeros([Ndraw,nts,28,28],dtype=torch.float32).to(dev) + draws = torch.randint(0,imgs.shape[0],(Ndraw,)) + + times = lparas["t"].to(dev) + times = times.reshape([1,nts]).expand([Ndraw,nts]) + + for I in range(0,Ndraw): + imgseq = torch.zeros([nts,28,28],dtype=torch.float32).to(dev) + imgseq[0,:,:] = imgs[draws[I],:,:] + + beta = torch.kron(lparas["B_t"],torch.ones(28,28)).to(dev) + beta = beta.reshape([nts,28,28]) + + sig = torch.sqrt(beta) + noise = torch.randn((nts,28,28)).to(dev)*sig + + for J in range(1,nts): + imgseq[J,:,:] = imgseq[J-1,:,:]*torch.sqrt(1.0-beta[J-1,:,:]) + noise[J-1,:,:] + + + imbatch[I,:,:,:] = imgseq[:,:,:] + + return [imbatch,times] + +def train_batch(revnet,batchimg,batchtime,lparas): + + L = 0.0 + lr = lparas["lr"] + nts = batchtime.shape[1] + + dev = revnet.device + #beta = torch.as_tensor(lparas["Beta"]).float().to(dev) + beta = float(lparas["Beta"]) + + optim = torch.optim.SGD(revnet.parameters(),lr=lr) + optim = torch.optim.Adam(revnet.parameters(),lr=lr) + + loss = torch.nn.MSELoss().to(dev) + + d_est = revnet(batchimg,batchtime) + d_est = d_est[:,1:nts,:,:] + targ = batchimg[:,1:nts,:,:]-batchimg[:,0:nts-1,:,:] + + L2 = loss(targ,beta*d_est) + + L2.backward() + + optim.step() + L = L2.detach().cpu().numpy() + L = L/beta**2 + + loss.zero_grad() + + + return L + + + + +def train(directory,saveseries): + + print("Loading dataset...") + [images_train, images_test] = mnist_load() + lparas = training_parameters() + + batchsize = lparas["batchsize"] + save_every = lparas["save_every"] + record_every = lparas["record_every"] + device = torch.device("cuda") + + revnet = mndiff_rev() + fname = get_current_save(directory,saveseries) + if(fname is None): + fname = get_next_save(directory,saveseries) + else: + print("Loading {}".format(fname)) + revnet.load(fname) + + images_train = images_train.to(device) + revnet = revnet.to(device) + + ## Training Loop + I = 0 + J = 0 + while(True): + + if(I%record_every==0): + print("Minibatch {}".format(I)) + + #New minibatch + losses = np.zeros((record_every)) + + + #draw minibatch + [mb,mbt] = generate_minibatch(images_train,batchsize,lparas) + + #train minibatch + L = train_batch(revnet,mb,mbt,lparas) + losses[J] = L + J = J + 1 + + #record results + if(J%record_every==0): + J = 0 + loss = np.mean(losses) + revnet.recordloss(loss) + print("\tloss={:1.5f}".format(loss)) + + #save + if(I%save_every==0): + fnext = get_next_save(directory,saveseries) + revnet.save(fnext) + + I = I + 1 + + return + +def infer(directory,saveseries): + + ttt = int(time.time()) + np.random.seed(ttt%50000) + torch.manual_seed(ttt%50000) + + + + lparas = training_parameters() + revnet = mndiff_rev() + fname = get_current_save(directory,saveseries) + if(fname is None): + print("Warning, no state loaded.") + else: + print("Loading {}".format(fname)) + revnet.load(fname) + + + nts = lparas["nts"] + beta = lparas["Beta"] + + nts = nts*15 + + imgseq = torch.zeros([nts,28,28]) + imgseq[nts-1,:,:] = torch.randn(28,28) + + #imgseq[nts-1,:,17:20] = 3.0 + + tl = torch.linspace(0,1,nts) + + bb = torch.as_tensor(beta) + + for I in range(nts-2,-1,-1): + img = imgseq[I+1,:,:] + t = tl[I] + eps = revnet(img,t) + + + + img2 = img*torch.sqrt(1-bb) - bb*eps + 1.0*bb*torch.randn(28,28) + + imgseq[I,:,:] = img2 + + im0 = torch.squeeze(imgseq[0,:,:]) + im1 = torch.squeeze(imgseq[int(nts*0.25),:,:]) + im2 = torch.squeeze(imgseq[int(nts*0.75),:,:]) + im3 = torch.squeeze(imgseq[int(nts*0.90),:,:]) + + im0 = im0.detach().numpy() + im1 = im1.detach().numpy() + im2 = im2.detach().numpy() + im3 = im3.detach().numpy() + + plt.figure() + plt.subplot(2,2,1) + plt.imshow(im3,cmap='gray') + plt.subplot(2,2,2) + plt.imshow(im2,cmap='gray') + plt.subplot(2,2,3) + plt.imshow(im1,cmap='gray') + plt.subplot(2,2,4) + plt.imshow(im0,cmap='gray') + plt.show() + + + return + +def plot_history(directory, saveseries): + + revnet = mndiff_rev() + fname = get_current_save(directory,saveseries) + if(fname is None): + fname = get_next_save(directory,saveseries) + else: + print("Loading {}".format(fname)) + revnet.load(fname) + + revnet.plothistory() + + return + +def test1(): + + [images_train, images_test] = mnist_load() + lparas = training_parameters() + + #plt.figure() + #plt.plot(lparas["t"],lparas["a_t"]) + #plt.show() + + #plt.figure() + #plt.plot(lparas["t"],lparas["B_t"]) + #plt.show() + + images_train = images_train.to("cuda") + + [mb,mbtime] = generate_minibatch(images_train,50,lparas) + + img0 = torch.squeeze(mb[0,0,:,:]).clone().detach().to("cpu") + img1 = torch.squeeze(mb[0,1,:,:]).clone().detach().to("cpu") + img2 = torch.squeeze(mb[0,2,:,:]).clone().detach().to("cpu") + img3 = torch.squeeze(mb[0,99,:,:]).clone().detach().to("cpu") + + print(mb.shape) + print(mbtime.shape) + + plt.figure() + plt.subplot(2,2,1) + plt.imshow(img0,cmap='gray') + plt.subplot(2,2,2) + plt.imshow(img1,cmap='gray') + plt.subplot(2,2,3) + plt.imshow(img2,cmap='gray') + plt.subplot(2,2,4) + plt.imshow(img3,cmap='gray') + + plt.show() + + return + +def test2(directory, saveseries): + + lparas = training_parameters() + nts = lparas["nts"] + beta = lparas["Beta"] + + revnet = mndiff_rev() + fname = get_current_save(directory,saveseries) + if(fname is None): + fname = get_next_save(directory,saveseries) + else: + print("Loading {}".format(fname)) + revnet.load(fname) + + + [images_train, images_test] = mnist_load() + + x = torch.squeeze(images_train[100,:,:]) + #x = beta*torch.randn(28,28) + y = revnet(x,0.5) + + x = x.detach().numpy() + y = y.detach().numpy() + + print("y mn={:1.3f} std={:1.3f} std2*beta={}".format(np.mean(y),np.std(y),np.std(y)**2*beta)) + + + plt.figure() + plt.subplot(2,1,1) + plt.imshow(x,cmap='gray') + plt.subplot(2,1,2) + plt.imshow(y,cmap='gray') + plt.show() + + + return + +########## +## Main ## +########## + +def mainswitch(): + args = sys.argv + if(len(args)<2): + print("mnistdiffusion [operation]") + print("operations:") + print("\ttrain") + print("\tplot_history") + print("\tinfer") + + exit(0) + + if(sys.argv[1]=="train"): + train("./saves","test_2a") + if(sys.argv[1]=="plot_history"): + plot_history("./saves","test_2a") + if(sys.argv[1]=="infer"): + infer("./saves","test_2a") + + + + + +if(__name__=="__main__"): + + ttt = int(time.time()) + np.random.seed(ttt%50000) + torch.manual_seed(ttt%50000) + + #test1() + #test2("./saves","test_2a") + mainswitch() + diff --git a/attempt2/mnistdiffusion_model.py b/attempt2/mnistdiffusion_model.py new file mode 100644 index 0000000..3d33769 --- /dev/null +++ b/attempt2/mnistdiffusion_model.py @@ -0,0 +1,298 @@ +#!/usr/bin/python3 + +import os,sys,math +import numpy as np + +import matplotlib.pyplot as plt + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import mnist_dataloader + + +## Sohl-Dickstein 2015 describes a gaussian "bump" function to +# implement learnable time-dependence +# +#network input vector and normalized time (tn e (0,1)) +# input: (dimx, dimy, ..., features), (scalar) --> output: (dimx, dimy, ...) +# input: (ntimes, ..., features), (ntimes) --> output: (ntimes , ...) +class gaussian_timeselect(nn.Module): + + def __init__(self, nfeatures=5, width = 0.25): + super(gaussian_timeselect,self).__init__() + + tau_j = torch.linspace(0,1,nfeatures).float() + self.tau_j = torch.nn.parameter.Parameter(tau_j, requires_grad=False) + + width = torch.tensor(width,dtype=torch.float32) + self.width = torch.nn.parameter.Parameter(width,requires_grad = False) + + #network input vector and normalized time (tn e (0,1)) + # input: (dimx, dimy, ..., features), (scalar) --> output: (dimx, dimy, ...) + # input: (ntimes, ..., features), (ntimes) --> output: (ntimes , ...) + def forward(self, x, tn): + + ndim = len(x.shape) + + tn = torch.as_tensor(tn) + tn.requires_grad = False + + + if(len(tn.shape)==0): + #case 0 - if tn is a scalar + g = torch.exp(-1.0/(2.0*self.width)*(tn-self.tau_j)**2) + g = g/torch.clip(torch.sum(g),min=0.001) + x = torch.tensordot(x,g,dims=[[ndim-1],[0]]) + else: + #Case 1 - tn is a tensor + nf = self.tau_j.shape[0] + nt = tn.numel() + nt2 = x.shape[0] + ndim = len(x.shape) + nf2 = x.shape[ndim-1] + + if(nt!=nt2): + raise Exception("gaussian_timeselect time mismatch") + if(nf!=nf2): + raise Exception("gaussian_timeselect feature mismatch") + + shp2 = x.shape[1:ndim-1] + nshp2 = shp2.numel() + + x = x.reshape([nt,nshp2,nf]) + tau = self.tau_j.reshape([1,1,nf]) + tau = tau.expand([nt,nshp2,nf]) + tn = tn.reshape([nt,1,1]) + tn = tn.expand([nt,nshp2,nf]) + + g = torch.exp(-1.0/(2.0*self.width)*(tn-tau)**2) + gs = torch.clip(torch.sum(g,axis=2),min=0.0001) + gs = gs.reshape([nt,nshp2,1]).expand([nt,nshp2,nf]) + g = g/gs + + x = torch.sum(x*g,axis=2) + #shp3 = torch.Size(torch.cat(torch.tensor([nt]),torch.tensor(shp2))) + # contact [nt, *shp2] + shp3 = torch.Size([nt,*shp2]) + x = x.reshape(shp3) + + return x + +#The reverse network should output a mean image and a covariance image +#given an input image + +class mndiff_rev(nn.Module): + + def __init__(self): + super(mndiff_rev,self).__init__() + + self.L1 = torch.nn.Conv2d(1,10,3) + self.L2 = torch.nn.Conv2d(10,10,5) + self.L3 = torch.nn.ConvTranspose2d(10,10,5) + self.L4 = torch.nn.ConvTranspose2d(10,10,3) + + #bypass branch + #self.Lb1 = torch.nn.Linear(1,10) + #self.Lb2 = torch.nn.Linear(10,2) + #self.Lb4f = torch.nn.Linear(12,12) + + self.L5 = gaussian_timeselect(10,0.1) + + + #self.L4 = nn.Conv2d(10,10,kernel_size=1) + #channels must be 2nd dimension + # self.L4a = nn.Linear(5,5) + # #equivalent, channels must be last dimension + # self.L4b = nn.Linear(5,5) + + + # self.L5 = gaussian_timeselect(5,0.1) + + #self.L5alt = nn.Linear(5,1) + + #self.RL = nn.LeakyReLU(0.01) + self.tanh = nn.Tanh() + + self.device = torch.device("cpu") + + self.losstime = [] + self.lossrecord = [] + + return + + + def forward(self,x,t): + + #handle data-shape cases + #case 0: x (pix_x, pix_y), t is scalar + #case 1: x (ntimes, pix_x, pix_y), t is (ntimes) + #case 2: x (nbatch, ntimes, pix_x, pix_y), t is (nbatch, ntimes) + + if(len(x.shape)==2): + case = 0 + t = torch.as_tensor(t).to(self.device) + t = t.float() + x = x.reshape([1,1,28,28]) + ntimes = 1 + nbatch = 1 + elif(len(x.shape)==3): + case = 1 + ntimes = x.shape[0] + nbatch = 1 + x = x.reshape([ntimes,1,28,28]) + elif(len(x.shape)==4): + case = 2 + nbatch = x.shape[0] + ntimes = x.shape[1] + #x = x.reshape([nbatch,ntimes,28*28]) + x = x.reshape([nbatch*ntimes,1,28,28]) + t = t.reshape([nbatch*ntimes]) + + y = torch.randn(28*28,device=self.device) + + else: + print("Error: Expect input to be dimension 2,3,4") + raise Exception("Error: Expect input to be dimension 2,3,4") + + s1 = self.L1(x) + s1 = self.tanh(s1) + #print(s1.shape) + s1 = self.L2(s1) + s1 = self.tanh(s1) + #print(s1.shape) + s1 = self.L3(s1) + s1 = self.tanh(s1) + #print(s1.shape) + s1 = self.L4(s1) + s1 = self.tanh(s1) + #print(s1.shape) + s1 = s1.permute([0,2,3,1]) + s2 = self.L5(s1,t) + + eps = s2 + + + if(case==0): + eps = eps.reshape([28,28]) + elif(case==1): + eps = eps.reshape([ntimes,28,28]) + elif(case==2): + eps = eps.reshape([nbatch,ntimes,28,28]) + + return eps + + def to(self,dev): + self.device = torch.device(dev) + ret = super(mndiff_rev,self).to(self.device) + return ret + + + def save(self,fname): + try: + dev = self.device + self.to("cpu") + q = [self.state_dict(), self.losstime, self.lossrecord] + torch.save(q,fname) + self.to(dev) + except: + print("model save: problem saving {}".format(fname)) + + return + + def load(self,fname): + try: + [modelsd,losstime,lossrecord] = torch.load(fname) + self.load_state_dict(modelsd) + self.losstime = losstime + self.lossrecord = lossrecord + except: + print("model load: problem loading {}".format(fname)) + + return + + + def plothistory(self): + x = self.losstime + y = self.lossrecord + plt.figure() + plt.plot(x,y,'k.',markersize=1) + plt.title('Loss History') + plt.show() + + return + + def recordloss(self,loss): + + t = 0 + L = len(self.losstime) + if(L>0): + t = self.losstime[L-1] + self.losstime.append(t+1) + self.lossrecord.append(loss) + + return + + + +########### +## Tests ## +########### + +def test_gaussian_timeselect(): + gts = gaussian_timeselect(5,0.025) + gts = gts.to("cuda") + + x = torch.randn((8,8,5)).to("cuda") + #t = torch.tensor([0,0.5,1]).to("cuda") + t = torch.linspace(0,1,8).to("cuda") + + y = gts(x,0) + y2 = gts(x,t) + y3 = gts(x,0.111) + + print(y) + print(x[:,:,0]) + + print(x.shape) + print(y.shape) + print(y2.shape) + + print(torch.abs(y2[0,:]-y[0,:])>1E-6) + + + + return + +def test_mndiff_rev(): + + mnd = mndiff_rev() + + x0 = torch.randn(28,28) + t0 = 0 + x1 = torch.randn(5,28,28) + t1 = torch.linspace(0,1,5) + + x2 = torch.randn(8,5,28,28) + t2 = torch.linspace(0,1,5) + t2 = t2.reshape([1,5]).expand([8,5]) + + y0 = mnd(x0,t0) + y1 = mnd(x1,t1) + y2 = mnd(x2,t2) + + print(y0.shape) + print(y1.shape) + print(y2.shape) + + + + return + +if(__name__=="__main__"): + + #test_gaussian_timeselect() + test_mndiff_rev() + + pass \ No newline at end of file diff --git a/attempt2/mnistdiffusion_utils.py b/attempt2/mnistdiffusion_utils.py new file mode 100644 index 0000000..bdc488e --- /dev/null +++ b/attempt2/mnistdiffusion_utils.py @@ -0,0 +1,78 @@ +#!/usr/bin/python3 + +import os,sys,math + + +def get_savenumber(fname): + num = 0 + + fn2 = os.path.split(fname)[1] + fn2 = os.path.splitext(fn2)[0] + + L = len(fn2) + L2 = L + for I in range(L-1,0,-1): + c = fn2[I] + if(not c.isnumeric()): + L2 = I+1 + break + nm = fn2[L2:L] + + try: + num = int(nm) + except: + num = 0 + + return num + + +def sorttwopythonlists(list1, list2): + + zipped_pairs = zip(list2, list1) + + z = [x for _, x in sorted(zipped_pairs)] + z2 = sorted(list2) + + return [z,z2] + +def get_current_save(directory,saveseries): + fname = None + + fl = os.listdir(directory) + fl2 = [] + nums = [] + + for f in fl: + if(f.find(saveseries)>=0): + fl2.append(os.path.join(directory,f)) + + for f in fl2: + n = get_savenumber(f) + nums.append(n) + + [fl2,nums] = sorttwopythonlists(fl2,nums) + + if(len(fl2)>0): + fname = fl2[len(fl2)-1] + else: + fname = None + + return fname + +def get_next_save(directory,saveseries): + fname = None + fncurrent = get_current_save(directory,saveseries) + if(fncurrent is None): + fname = os.path.join(directory,"{}{:02d}.pyt".format(saveseries,0)) + else: + N = get_savenumber(fncurrent) + N = N + 1 + fname = os.path.join(directory,"{}{:02d}.pyt".format(saveseries,N)) + + + return fname + +if(__name__=="__main__"): + + #print(get_savenumber("./saves/helloworld05202.py")) + pass diff --git a/attempt2/old/mnistdiffusion.py b/attempt2/old/mnistdiffusion.py new file mode 100644 index 0000000..45ed2b2 --- /dev/null +++ b/attempt2/old/mnistdiffusion.py @@ -0,0 +1,161 @@ +#!/usr/bin/python3 + +import os,sys,math +import numpy as np + +import matplotlib.pyplot as plt + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import mnist_dataloader + +from mnistdiffusion_model import mndiff_rev + +################ +## Dataloader ## +################ + +def mnist_load(): + + [labels_train, labels_test, images_train, images_test] = mnist_dataloader.mnist_load() + + #60000 x 28 x 28 + #60000 x 1 + + #Rescale data to (-1,1) + images_train = images_train.clone().detach().float() + images_test = images_test.clone().detach().float() + + images_train = 2.0*(images_train/255.0)-1.0 + images_test = 2.0*(images_test/255.0)-1.0 + + images_train.requires_grad = False + images_test.requires_grad = False + + return [images_train, images_test] + + + + +################ +## Operations ## +################ + +def training_parameters(): + lparas = dict() + + nts = 200 + t = np.arange(0,nts) + s = 0.1 + f_t = np.cos((t/nts+s)/(1+s)*np.pi/2.0)**2 + a_t = f_t/f_t[0] + B_t = np.zeros(t.shape[0]) + B_t[0:nts-1] = np.clip(1-a_t[1:nts]/a_t[0:nts-1],0,0.999) + B_t[nts-1] = B_t[nts-2] + + lparas["nts"] = nts + lparas["t"] = torch.tensor(t,dtype=torch.float32) + lparas["a_t"] = torch.tensor(a_t,dtype=torch.float32) + lparas["B_t"] = torch.tensor(B_t,dtype=torch.float32) + + #Traiing parameters for a cosine variance schedule + #ref: Nichol and Dharawal 2021 + + return lparas + + +def generate_minibatch(imgs, Ndraw, lparas): + + #I can probably speed this up with compiled CUDA code + + nts = lparas["nts"] + dev = imgs.device + + imbatch = torch.zeros([Ndraw*nts,28,28],dtype=torch.float32).to(dev) + draws = torch.randint(0,imgs.shape[0],(Ndraw,)) + + for I in range(0,Ndraw): + imgseq = torch.zeros([nts,28,28],dtype=torch.float32).to(dev) + imgseq[0,:,:] = imgs[draws[I],:,:] + + beta = torch.kron(lparas["B_t"],torch.ones(28,28)).to(dev) + beta = beta.reshape([nts,28,28]) + + sig = torch.sqrt(beta) + noise = torch.randn((nts,28,28)).to(dev)*sig + + for J in range(1,nts): + imgseq[J,:,:] = imgseq[J-1,:,:] + noise[J-1,:,:] + + I1 = I*nts + I2 = (I+1)*nts + imbatch[I1:I2,:,:] = imgseq[:,:,:] + + return imbatch + +def train_batch(imgbatch,lr,lparas): + + + + + return + + +def train(): + + return + +def infer(): + + return + + + +def test1(): + + [images_train, images_test] = mnist_load() + lparas = training_parameters() + + #plt.figure() + #plt.plot(lparas["t"],lparas["a_t"]) + #plt.show() + + #plt.figure() + #plt.plot(lparas["t"],lparas["B_t"]) + #plt.show() + + images_train = images_train.to("cuda") + + mb = generate_minibatch(images_train,50,lparas) + + img0 = torch.squeeze(mb[0,:,:]).clone().detach().to("cpu") + img1 = torch.squeeze(mb[1,:,:]).clone().detach().to("cpu") + img2 = torch.squeeze(mb[2,:,:]).clone().detach().to("cpu") + img3 = torch.squeeze(mb[30,:,:]).clone().detach().to("cpu") + + + plt.figure() + plt.subplot(2,2,1) + plt.imshow(img0,cmap='gray') + plt.subplot(2,2,2) + plt.imshow(img1,cmap='gray') + plt.subplot(2,2,3) + plt.imshow(img2,cmap='gray') + plt.subplot(2,2,4) + plt.imshow(img3,cmap='gray') + + plt.show() + + return + + +########## +## Main ## +########## + +if(__name__=="__main__"): + + #test1() + test_gaussian_timeselect() diff --git a/attempt2/saves/test_2a00.pyt b/attempt2/saves/test_2a00.pyt new file mode 100644 index 0000000..b77944a Binary files /dev/null and b/attempt2/saves/test_2a00.pyt differ diff --git a/attempt2/saves/test_2a01.pyt b/attempt2/saves/test_2a01.pyt new file mode 100644 index 0000000..a5397f7 Binary files /dev/null and b/attempt2/saves/test_2a01.pyt differ diff --git a/attempt2/saves/test_2a02.pyt b/attempt2/saves/test_2a02.pyt new file mode 100644 index 0000000..aa24c24 Binary files /dev/null and b/attempt2/saves/test_2a02.pyt differ diff --git a/attempt2/scratchpad.ipynb b/attempt2/scratchpad.ipynb new file mode 100644 index 0000000..8968d3c --- /dev/null +++ b/attempt2/scratchpad.ipynb @@ -0,0 +1,85 @@ +{ + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9-final" + }, + "orig_nbformat": 2, + "kernelspec": { + "name": "python3", + "display_name": "Python 3.6.9 64-bit", + "metadata": { + "interpreter": { + "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 2, + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os,sys,math\n", + "import numpy as np\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "import mnist_dataloader\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([])\ntorch.Size([])\n[1, 2, 3]\n" + ] + } + ], + "source": [ + "a = torch.randn(10,14)\n", + "b = a.shape[1:1]\n", + "print(b)\n", + "b.numel()\n", + "\n", + "print(b)\n", + "\n", + "b = torch.Size([1])\n", + "c = torch.Size([2,3])\n", + "d = torch.Size(torch.cat([torch.tensor(b),torch.tensor(c)]))\n", + "\n", + "d = [*b,*c]\n", + "\n", + "print(d)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ] +} \ No newline at end of file diff --git a/attempt3/__pycache__/mnist_dataloader.cpython-36.pyc b/attempt3/__pycache__/mnist_dataloader.cpython-36.pyc new file mode 100644 index 0000000..bdf0c11 Binary files /dev/null and b/attempt3/__pycache__/mnist_dataloader.cpython-36.pyc differ diff --git a/attempt3/__pycache__/mnistdiffusion_model.cpython-36.pyc b/attempt3/__pycache__/mnistdiffusion_model.cpython-36.pyc new file mode 100644 index 0000000..66ae9e3 Binary files /dev/null and b/attempt3/__pycache__/mnistdiffusion_model.cpython-36.pyc differ diff --git a/attempt3/__pycache__/mnistdiffusion_utils.cpython-36.pyc b/attempt3/__pycache__/mnistdiffusion_utils.cpython-36.pyc new file mode 100644 index 0000000..6d838e5 Binary files /dev/null and b/attempt3/__pycache__/mnistdiffusion_utils.cpython-36.pyc differ diff --git a/attempt3/mnist_dataloader.py b/attempt3/mnist_dataloader.py new file mode 100644 index 0000000..be0dff3 --- /dev/null +++ b/attempt3/mnist_dataloader.py @@ -0,0 +1,91 @@ +#!/usr/bin/python3 +import os,sys,math +import numpy as np +import cv2 +import gzip #need to use gzip.open instead of open +import struct + +import torch + +def read_MNIST_label_file(fname): + #fp = gzip.open('./train-labels-idx1-ubyte.gz','rb'); + fp = gzip.open(fname,'rb'); + magic = fp.read(4); + #nitems = np.frombuffer(fp.read(4),dtype=np.int32)[0]; #some sort of endiannes problem + bts = fp.read(4); + #bts = bytereverse(bts); + #nitems = np.frombuffer(bts,dtype=np.int32); + nitems = np.int32(struct.unpack('>I',bts)[0]); #it was a non-native endianness in teh integer encoding + #> < @ - endianness + + bts = fp.read(nitems); + N = len(bts); + labels = np.zeros((N),dtype=np.uint8); + labels = np.frombuffer(bts,dtype=np.uint8,count=N); + #for i in range(0,10): + # bt = fp.read(1); + # labels[i] = np.frombuffer(bt,dtype=np.uint8); + fp.close(); + return labels; + +def read_MNIST_image_file(fname): + fp = gzip.open(fname,'rb'); + magic = fp.read(4); + bts = fp.read(4); + nitems = np.int32(struct.unpack('>I',bts)[0]); + bts = fp.read(4); + nrows = np.int32(struct.unpack('>I',bts)[0]); + bts = fp.read(4); + ncols = np.int32(struct.unpack('>I',bts)[0]); + + images = np.zeros((nitems,nrows,ncols),dtype=np.uint8); + for I in range(0,nitems): + bts = fp.read(nrows*ncols); + img1 = np.frombuffer(bts,dtype=np.uint8,count=nrows*ncols); + img1 = img1.reshape((nrows,ncols)); + images[I,:,:] = img1; + + fp.close(); + + return images; + +#The mnist dataset is small enough to fit entirely in memory +def mnist_load(): + baseloc = "../training_data" + + traindatafile = "train-images-idx3-ubyte.gz" + trainlabelfile = "train-labels-idx1-ubyte.gz" + testdatafile = "t10k-images-idx3-ubyte.gz" + testlabelfile = "t10k-labels-idx1-ubyte.gz" + + traindatafile = os.path.join(baseloc,traindatafile) + trainlabelfile = os.path.join(baseloc,trainlabelfile) + testdatafile = os.path.join(baseloc,testdatafile) + testlabelfile = os.path.join(baseloc,testlabelfile) + + labels_train = read_MNIST_label_file(trainlabelfile) + labels_test = read_MNIST_label_file(testlabelfile) + images_train = read_MNIST_image_file(traindatafile) + images_test = read_MNIST_image_file(testdatafile) + + labels_train = torch.tensor(labels_train,dtype=torch.float32,requires_grad=False) + labels_test = torch.tensor(labels_test,dtype=torch.float32,requires_grad=False) + images_train = torch.tensor(images_train,dtype=torch.float32,requires_grad=False) + images_test = torch.tensor(images_test,dtype=torch.float32,requires_grad=False) + + # #debug + # print(labels_train.shape) + # print(labels_test.shape) + # print(images_train.shape) + # print(images_test.shape) + + + return [labels_train, labels_test, images_train, images_test] + +if(__name__ == "__main__"): + [labels_train, labels_test, images_train, images_test] = mnist_load() + print("Loaded MNIST Data") + print(labels_train.shape) + print(labels_test.shape) + print(images_train.shape) + print(images_test.shape) diff --git a/attempt3/mnistdiffusion.py b/attempt3/mnistdiffusion.py new file mode 100644 index 0000000..0297d8a --- /dev/null +++ b/attempt3/mnistdiffusion.py @@ -0,0 +1,398 @@ +#!/usr/bin/python3 + +import os,sys,math +import numpy as np + +import matplotlib.pyplot as plt + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import mnist_dataloader + +from mnistdiffusion_model import mndiff_rev +from mnistdiffusion_utils import * + +import time + +################ +## Dataloader ## +################ + +def mnist_load(): + + [labels_train, labels_test, images_train, images_test] = mnist_dataloader.mnist_load() + + #60000 x 28 x 28 + #60000 x 1 + + #Rescale data to (-1,1) + images_train = images_train.clone().detach().float() + images_test = images_test.clone().detach().float() + + images_train = 2.0*(images_train/255.0)-1.0 + images_test = 2.0*(images_test/255.0)-1.0 + + images_train.requires_grad = False + images_test.requires_grad = False + + return [images_train, images_test] + + + + +################ +## Operations ## +################ + +# Rather than killing myself figuring out variable variance schedules, +# just use a constant variance schedule to start with + +def training_parameters(): + lparas = dict() + + nts = 200 + Beta = 0.0050 + + lr = 1E-4 + batchsize = 30 + save_every = 200 + record_every = 25 + + B_t = np.ones((nts))*Beta + a_t = np.zeros((nts)) + a_t[0] = 1 + for I in range(1,nts): + a_t[I] = a_t[I-1]*(1-B_t[I-1]) + + t = np.linspace(0,1,nts) #normalized time (t/T) + + lparas["nts"] = nts + lparas["Beta"] = Beta + lparas["t"] = torch.tensor(t,dtype=torch.float32) + lparas["a_t"] = torch.tensor(a_t,dtype=torch.float32) + lparas["B_t"] = torch.tensor(B_t,dtype=torch.float32) + lparas["lr"] = lr + lparas["batchsize"] = batchsize + lparas["save_every"] = save_every + lparas["record_every"] = record_every + return lparas + + +def generate_minibatch(imgs, Ndraw, lparas): + + #I can probably speed this up with compiled CUDA code + + nts = lparas["nts"] + Beta = lparas["Beta"] + + dev = imgs.device + + imbatch = torch.zeros([Ndraw,nts,28,28],dtype=torch.float32).to(dev) + draws = torch.randint(0,imgs.shape[0],(Ndraw,)) + + times = lparas["t"].to(dev) + times = times.reshape([1,nts]).expand([Ndraw,nts]) + + for I in range(0,Ndraw): + imgseq = torch.zeros([nts,28,28],dtype=torch.float32).to(dev) + imgseq[0,:,:] = imgs[draws[I],:,:] + + beta = torch.kron(lparas["B_t"],torch.ones(28,28)).to(dev) + beta = beta.reshape([nts,28,28]) + + sig = torch.sqrt(beta) + noise = torch.randn((nts,28,28)).to(dev)*sig + + for J in range(1,nts): + imgseq[J,:,:] = imgseq[J-1,:,:]*torch.sqrt(1.0-beta[J-1,:,:]) + noise[J-1,:,:] + + + imbatch[I,:,:,:] = imgseq[:,:,:] + + return [imbatch,times] + +def train_batch(revnet,batchimg,batchtime,lparas): + + L = 0.0 + lr = lparas["lr"] + nts = batchtime.shape[1] + + dev = revnet.device + #beta = torch.as_tensor(lparas["Beta"]).float().to(dev) + beta = float(lparas["Beta"]) + + optim = torch.optim.SGD(revnet.parameters(),lr=lr) + optim = torch.optim.Adam(revnet.parameters(),lr=lr) + + loss = torch.nn.MSELoss().to(dev) + + d_est = revnet(batchimg,batchtime) + d_est = d_est[:,1:nts,:,:] + targ = batchimg[:,1:nts,:,:]-batchimg[:,0:nts-1,:,:] + + L2 = loss(targ,beta*d_est) + + L2.backward() + + optim.step() + L = L2.detach().cpu().numpy() + L = L/beta**2 + + loss.zero_grad() + + return L + + +def train(directory,saveseries): + + print("Loading dataset...") + [images_train, images_test] = mnist_load() + lparas = training_parameters() + + batchsize = lparas["batchsize"] + save_every = lparas["save_every"] + record_every = lparas["record_every"] + device = torch.device("cuda") + + revnet = mndiff_rev() + fname = get_current_save(directory,saveseries) + if(fname is None): + fname = get_next_save(directory,saveseries) + else: + print("Loading {}".format(fname)) + revnet.load(fname) + + images_train = images_train.to(device) + revnet = revnet.to(device) + + ## Training Loop + I = 0 + J = 0 + while(True): + + if(I%record_every==0): + print("Minibatch {}".format(I)) + + #New minibatch + losses = np.zeros((record_every)) + + + #draw minibatch + [mb,mbt] = generate_minibatch(images_train,batchsize,lparas) + + #train minibatch + L = train_batch(revnet,mb,mbt,lparas) + losses[J] = L + J = J + 1 + + #record results + if(J%record_every==0): + J = 0 + loss = np.mean(losses) + revnet.recordloss(loss) + print("\tloss={:1.5f}".format(loss)) + + #save + if(I%save_every==0): + fnext = get_next_save(directory,saveseries) + revnet.save(fnext) + + I = I + 1 + + return + +def infer(directory,saveseries): + + ttt = int(time.time()) + np.random.seed(ttt%50000) + torch.manual_seed(ttt%50000) + + + + lparas = training_parameters() + revnet = mndiff_rev() + fname = get_current_save(directory,saveseries) + if(fname is None): + print("Warning, no state loaded.") + else: + print("Loading {}".format(fname)) + revnet.load(fname) + + + nts = lparas["nts"] + beta = lparas["Beta"] + + nts = nts*15 + + imgseq = torch.zeros([nts,28,28]) + imgseq[nts-1,:,:] = torch.randn(28,28) + + #imgseq[nts-1,:,17:20] = 3.0 + + tl = torch.linspace(0,1,nts) + + bb = torch.as_tensor(beta) + + for I in range(nts-2,-1,-1): + img = imgseq[I+1,:,:] + t = tl[I] + eps = revnet(img,t) + + + + img2 = img*torch.sqrt(1-bb) - bb*eps + 1.0*bb*torch.randn(28,28) + + imgseq[I,:,:] = img2 + + im0 = torch.squeeze(imgseq[0,:,:]) + im1 = torch.squeeze(imgseq[int(nts*0.25),:,:]) + im2 = torch.squeeze(imgseq[int(nts*0.75),:,:]) + im3 = torch.squeeze(imgseq[int(nts*0.90),:,:]) + + im0 = im0.detach().numpy() + im1 = im1.detach().numpy() + im2 = im2.detach().numpy() + im3 = im3.detach().numpy() + + plt.figure() + plt.subplot(2,2,1) + plt.imshow(im3,cmap='gray') + plt.subplot(2,2,2) + plt.imshow(im2,cmap='gray') + plt.subplot(2,2,3) + plt.imshow(im1,cmap='gray') + plt.subplot(2,2,4) + plt.imshow(im0,cmap='gray') + plt.show() + + + return + +def plot_history(directory, saveseries): + + revnet = mndiff_rev() + fname = get_current_save(directory,saveseries) + if(fname is None): + fname = get_next_save(directory,saveseries) + else: + print("Loading {}".format(fname)) + revnet.load(fname) + + revnet.plothistory() + + return + +def test1(): + + [images_train, images_test] = mnist_load() + lparas = training_parameters() + + #plt.figure() + #plt.plot(lparas["t"],lparas["a_t"]) + #plt.show() + + #plt.figure() + #plt.plot(lparas["t"],lparas["B_t"]) + #plt.show() + + images_train = images_train.to("cuda") + + [mb,mbtime] = generate_minibatch(images_train,50,lparas) + + img0 = torch.squeeze(mb[0,0,:,:]).clone().detach().to("cpu") + img1 = torch.squeeze(mb[0,1,:,:]).clone().detach().to("cpu") + img2 = torch.squeeze(mb[0,2,:,:]).clone().detach().to("cpu") + img3 = torch.squeeze(mb[0,99,:,:]).clone().detach().to("cpu") + + print(mb.shape) + print(mbtime.shape) + + plt.figure() + plt.subplot(2,2,1) + plt.imshow(img0,cmap='gray') + plt.subplot(2,2,2) + plt.imshow(img1,cmap='gray') + plt.subplot(2,2,3) + plt.imshow(img2,cmap='gray') + plt.subplot(2,2,4) + plt.imshow(img3,cmap='gray') + + plt.show() + + return + +def test2(directory, saveseries): + + lparas = training_parameters() + nts = lparas["nts"] + beta = lparas["Beta"] + + revnet = mndiff_rev() + fname = get_current_save(directory,saveseries) + if(fname is None): + fname = get_next_save(directory,saveseries) + else: + print("Loading {}".format(fname)) + revnet.load(fname) + + + [images_train, images_test] = mnist_load() + + x = torch.squeeze(images_train[100,:,:]) + #x = beta*torch.randn(28,28) + y = revnet(x,0.5) + + x = x.detach().numpy() + y = y.detach().numpy() + + print("y mn={:1.3f} std={:1.3f} std2*beta={}".format(np.mean(y),np.std(y),np.std(y)**2*beta)) + + + plt.figure() + plt.subplot(2,1,1) + plt.imshow(x,cmap='gray') + plt.subplot(2,1,2) + plt.imshow(y,cmap='gray') + plt.show() + + + return + +########## +## Main ## +########## + +def mainswitch(): + args = sys.argv + if(len(args)<2): + print("mnistdiffusion [operation]") + print("operations:") + print("\ttrain") + print("\tplot_history") + print("\tinfer") + + exit(0) + + if(sys.argv[1]=="train"): + train("./saves","test_3a") + if(sys.argv[1]=="plot_history"): + plot_history("./saves","test_3a") + if(sys.argv[1]=="infer"): + infer("./saves","test_3a") + + + + + +if(__name__=="__main__"): + + ttt = int(time.time()) + np.random.seed(ttt%50000) + torch.manual_seed(ttt%50000) + + #test1() + #test2("./saves","test_2a") + mainswitch() + diff --git a/attempt3/mnistdiffusion_model.py b/attempt3/mnistdiffusion_model.py new file mode 100644 index 0000000..87298ea --- /dev/null +++ b/attempt3/mnistdiffusion_model.py @@ -0,0 +1,300 @@ +#!/usr/bin/python3 + +import os,sys,math +import numpy as np + +import matplotlib.pyplot as plt + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import mnist_dataloader + + +## Sohl-Dickstein 2015 describes a gaussian "bump" function to +# implement learnable time-dependence +# +#network input vector and normalized time (tn e (0,1)) +# input: (dimx, dimy, ..., features), (scalar) --> output: (dimx, dimy, ...) +# input: (ntimes, ..., features), (ntimes) --> output: (ntimes , ...) +class gaussian_timeselect(nn.Module): + + def __init__(self, nfeatures=5, width = 0.25): + super(gaussian_timeselect,self).__init__() + + tau_j = torch.linspace(0,1,nfeatures).float() + self.tau_j = torch.nn.parameter.Parameter(tau_j, requires_grad=False) + + width = torch.tensor(width,dtype=torch.float32) + self.width = torch.nn.parameter.Parameter(width,requires_grad = False) + + #network input vector and normalized time (tn e (0,1)) + # input: (dimx, dimy, ..., features), (scalar) --> output: (dimx, dimy, ...) + # input: (ntimes, ..., features), (ntimes) --> output: (ntimes , ...) + def forward(self, x, tn): + + ndim = len(x.shape) + + tn = torch.as_tensor(tn) + tn.requires_grad = False + + + if(len(tn.shape)==0): + #case 0 - if tn is a scalar + g = torch.exp(-1.0/(2.0*self.width)*(tn-self.tau_j)**2) + g = g/torch.clip(torch.sum(g),min=0.001) + x = torch.tensordot(x,g,dims=[[ndim-1],[0]]) + else: + #Case 1 - tn is a tensor + nf = self.tau_j.shape[0] + nt = tn.numel() + nt2 = x.shape[0] + ndim = len(x.shape) + nf2 = x.shape[ndim-1] + + if(nt!=nt2): + raise Exception("gaussian_timeselect time mismatch") + if(nf!=nf2): + raise Exception("gaussian_timeselect feature mismatch") + + shp2 = x.shape[1:ndim-1] + nshp2 = shp2.numel() + + x = x.reshape([nt,nshp2,nf]) + tau = self.tau_j.reshape([1,1,nf]) + tau = tau.expand([nt,nshp2,nf]) + tn = tn.reshape([nt,1,1]) + tn = tn.expand([nt,nshp2,nf]) + + g = torch.exp(-1.0/(2.0*self.width)*(tn-tau)**2) + gs = torch.clip(torch.sum(g,axis=2),min=0.0001) + gs = gs.reshape([nt,nshp2,1]).expand([nt,nshp2,nf]) + g = g/gs + + x = torch.sum(x*g,axis=2) + #shp3 = torch.Size(torch.cat(torch.tensor([nt]),torch.tensor(shp2))) + # contact [nt, *shp2] + shp3 = torch.Size([nt,*shp2]) + x = x.reshape(shp3) + + return x + +#The reverse network should output a mean image and a covariance image +#given an input image + +class mndiff_rev(nn.Module): + + def __init__(self): + super(mndiff_rev,self).__init__() + + self.nchan = 10 + + self.L1a = torch.nn.Conv2d(1,self.nchan,3,padding=1) + self.L1b = torch.nn.Conv2d(self.nchan,self.nchan,3) + self.L1c = torch.nn.MaxPool2d(kernel_size=2) + + self.L2a = torch.nn.Conv2d(self.nchan,self.nchan,3) + self.L2b = torch.nn.Conv2d(self.nchan,self.nchan,3) + + self.L3a = torch.nn.ConvTranspose2d(self.nchan,self.nchan,3) + self.L3b = torch.nn.ConvTranspose2d(self.nchan,self.nchan,3) + + #self.L4a = torch.nn.Upsample(scale_factor=2) + self.L4b = torch.nn.ConvTranspose2d(2*self.nchan,self.nchan,3) + self.L4c = torch.nn.ConvTranspose2d(self.nchan,1,3,padding=1) + + self.nl = nn.Tanh() + + self.device = torch.device("cpu") + + self.losstime = [] + self.lossrecord = [] + + return + + + def forward(self,x,t): + + #handle data-shape cases + #case 0: x (pix_x, pix_y), t is scalar + #case 1: x (ntimes, pix_x, pix_y), t is (ntimes) + #case 2: x (nbatch, ntimes, pix_x, pix_y), t is (nbatch, ntimes) + + if(len(x.shape)==2): + case = 0 + t = torch.as_tensor(t).to(self.device) + t = t.float() + x = x.reshape([1,1,28,28]) + ntimes = 1 + nbatch = 1 + elif(len(x.shape)==3): + case = 1 + ntimes = x.shape[0] + nbatch = 1 + x = x.reshape([ntimes,1,28,28]) + elif(len(x.shape)==4): + case = 2 + nbatch = x.shape[0] + ntimes = x.shape[1] + #x = x.reshape([nbatch,ntimes,28*28]) + x = x.reshape([nbatch*ntimes,1,28,28]) + t = t.reshape([nbatch*ntimes]) + else: + print("Error: Expect input to be dimension 2,3,4") + raise Exception("Error: Expect input to be dimension 2,3,4") + + #block 1 + s1 = self.L1a(x) + s1 = self.nl(s1) + s1 = self.L1b(s1) #skip connection + s1 = self.nl(s1) + s2 = self.L1c(s1) + + #block 2 + s2 = self.L2a(s2) + s2 = self.nl(s2) + s2 = self.L2b(s2) + s2 = self.nl(s2) + + #block 3 + s2 = self.L3a(s2) + s2 = self.nl(s2) + s2 = self.L3b(s2) + s2 = self.nl(s2) + + #block 4 + #s2 = self.L4a(s2) + s2 = torch.nn.functional.interpolate(s2,scale_factor=[2,2]) + #concat skip connection + s3 = torch.cat([s1,s2],axis=1) + + s3 = self.L4b(s3) + s3 = self.nl(s3) + s3 = self.L4c(s3) + + #output + eps = s3 + + if(case==0): + eps = eps.reshape([28,28]) + elif(case==1): + eps = eps.reshape([ntimes,28,28]) + elif(case==2): + eps = eps.reshape([nbatch,ntimes,28,28]) + + return eps + + def to(self,dev): + self.device = torch.device(dev) + ret = super(mndiff_rev,self).to(self.device) + return ret + + + def save(self,fname): + try: + dev = self.device + self.to("cpu") + q = [self.state_dict(), self.losstime, self.lossrecord] + torch.save(q,fname) + self.to(dev) + except: + print("model save: problem saving {}".format(fname)) + + return + + def load(self,fname): + try: + [modelsd,losstime,lossrecord] = torch.load(fname) + self.load_state_dict(modelsd) + self.losstime = losstime + self.lossrecord = lossrecord + except: + print("model load: problem loading {}".format(fname)) + + return + + + def plothistory(self): + x = self.losstime + y = self.lossrecord + plt.figure() + plt.plot(x,y,'k.',markersize=1) + plt.title('Loss History') + plt.show() + + return + + def recordloss(self,loss): + + t = 0 + L = len(self.losstime) + if(L>0): + t = self.losstime[L-1] + self.losstime.append(t+1) + self.lossrecord.append(loss) + + return + + + +########### +## Tests ## +########### + +def test_gaussian_timeselect(): + gts = gaussian_timeselect(5,0.025) + gts = gts.to("cuda") + + x = torch.randn((8,8,5)).to("cuda") + #t = torch.tensor([0,0.5,1]).to("cuda") + t = torch.linspace(0,1,8).to("cuda") + + y = gts(x,0) + y2 = gts(x,t) + y3 = gts(x,0.111) + + print(y) + print(x[:,:,0]) + + print(x.shape) + print(y.shape) + print(y2.shape) + + print(torch.abs(y2[0,:]-y[0,:])>1E-6) + + + + return + +def test_mndiff_rev(): + + mnd = mndiff_rev() + + x0 = torch.randn(28,28) + t0 = 0 + x1 = torch.randn(5,28,28) + t1 = torch.linspace(0,1,5) + + x2 = torch.randn(8,5,28,28) + t2 = torch.linspace(0,1,5) + t2 = t2.reshape([1,5]).expand([8,5]) + + y0 = mnd(x0,t0) + y1 = mnd(x1,t1) + y2 = mnd(x2,t2) + + print(y0.shape) + print(y1.shape) + print(y2.shape) + + + + return + +if(__name__=="__main__"): + + #test_gaussian_timeselect() + test_mndiff_rev() + + pass \ No newline at end of file diff --git a/attempt3/mnistdiffusion_utils.py b/attempt3/mnistdiffusion_utils.py new file mode 100644 index 0000000..bdc488e --- /dev/null +++ b/attempt3/mnistdiffusion_utils.py @@ -0,0 +1,78 @@ +#!/usr/bin/python3 + +import os,sys,math + + +def get_savenumber(fname): + num = 0 + + fn2 = os.path.split(fname)[1] + fn2 = os.path.splitext(fn2)[0] + + L = len(fn2) + L2 = L + for I in range(L-1,0,-1): + c = fn2[I] + if(not c.isnumeric()): + L2 = I+1 + break + nm = fn2[L2:L] + + try: + num = int(nm) + except: + num = 0 + + return num + + +def sorttwopythonlists(list1, list2): + + zipped_pairs = zip(list2, list1) + + z = [x for _, x in sorted(zipped_pairs)] + z2 = sorted(list2) + + return [z,z2] + +def get_current_save(directory,saveseries): + fname = None + + fl = os.listdir(directory) + fl2 = [] + nums = [] + + for f in fl: + if(f.find(saveseries)>=0): + fl2.append(os.path.join(directory,f)) + + for f in fl2: + n = get_savenumber(f) + nums.append(n) + + [fl2,nums] = sorttwopythonlists(fl2,nums) + + if(len(fl2)>0): + fname = fl2[len(fl2)-1] + else: + fname = None + + return fname + +def get_next_save(directory,saveseries): + fname = None + fncurrent = get_current_save(directory,saveseries) + if(fncurrent is None): + fname = os.path.join(directory,"{}{:02d}.pyt".format(saveseries,0)) + else: + N = get_savenumber(fncurrent) + N = N + 1 + fname = os.path.join(directory,"{}{:02d}.pyt".format(saveseries,N)) + + + return fname + +if(__name__=="__main__"): + + #print(get_savenumber("./saves/helloworld05202.py")) + pass diff --git a/attempt3/old/mnistdiffusion.py b/attempt3/old/mnistdiffusion.py new file mode 100644 index 0000000..45ed2b2 --- /dev/null +++ b/attempt3/old/mnistdiffusion.py @@ -0,0 +1,161 @@ +#!/usr/bin/python3 + +import os,sys,math +import numpy as np + +import matplotlib.pyplot as plt + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import mnist_dataloader + +from mnistdiffusion_model import mndiff_rev + +################ +## Dataloader ## +################ + +def mnist_load(): + + [labels_train, labels_test, images_train, images_test] = mnist_dataloader.mnist_load() + + #60000 x 28 x 28 + #60000 x 1 + + #Rescale data to (-1,1) + images_train = images_train.clone().detach().float() + images_test = images_test.clone().detach().float() + + images_train = 2.0*(images_train/255.0)-1.0 + images_test = 2.0*(images_test/255.0)-1.0 + + images_train.requires_grad = False + images_test.requires_grad = False + + return [images_train, images_test] + + + + +################ +## Operations ## +################ + +def training_parameters(): + lparas = dict() + + nts = 200 + t = np.arange(0,nts) + s = 0.1 + f_t = np.cos((t/nts+s)/(1+s)*np.pi/2.0)**2 + a_t = f_t/f_t[0] + B_t = np.zeros(t.shape[0]) + B_t[0:nts-1] = np.clip(1-a_t[1:nts]/a_t[0:nts-1],0,0.999) + B_t[nts-1] = B_t[nts-2] + + lparas["nts"] = nts + lparas["t"] = torch.tensor(t,dtype=torch.float32) + lparas["a_t"] = torch.tensor(a_t,dtype=torch.float32) + lparas["B_t"] = torch.tensor(B_t,dtype=torch.float32) + + #Traiing parameters for a cosine variance schedule + #ref: Nichol and Dharawal 2021 + + return lparas + + +def generate_minibatch(imgs, Ndraw, lparas): + + #I can probably speed this up with compiled CUDA code + + nts = lparas["nts"] + dev = imgs.device + + imbatch = torch.zeros([Ndraw*nts,28,28],dtype=torch.float32).to(dev) + draws = torch.randint(0,imgs.shape[0],(Ndraw,)) + + for I in range(0,Ndraw): + imgseq = torch.zeros([nts,28,28],dtype=torch.float32).to(dev) + imgseq[0,:,:] = imgs[draws[I],:,:] + + beta = torch.kron(lparas["B_t"],torch.ones(28,28)).to(dev) + beta = beta.reshape([nts,28,28]) + + sig = torch.sqrt(beta) + noise = torch.randn((nts,28,28)).to(dev)*sig + + for J in range(1,nts): + imgseq[J,:,:] = imgseq[J-1,:,:] + noise[J-1,:,:] + + I1 = I*nts + I2 = (I+1)*nts + imbatch[I1:I2,:,:] = imgseq[:,:,:] + + return imbatch + +def train_batch(imgbatch,lr,lparas): + + + + + return + + +def train(): + + return + +def infer(): + + return + + + +def test1(): + + [images_train, images_test] = mnist_load() + lparas = training_parameters() + + #plt.figure() + #plt.plot(lparas["t"],lparas["a_t"]) + #plt.show() + + #plt.figure() + #plt.plot(lparas["t"],lparas["B_t"]) + #plt.show() + + images_train = images_train.to("cuda") + + mb = generate_minibatch(images_train,50,lparas) + + img0 = torch.squeeze(mb[0,:,:]).clone().detach().to("cpu") + img1 = torch.squeeze(mb[1,:,:]).clone().detach().to("cpu") + img2 = torch.squeeze(mb[2,:,:]).clone().detach().to("cpu") + img3 = torch.squeeze(mb[30,:,:]).clone().detach().to("cpu") + + + plt.figure() + plt.subplot(2,2,1) + plt.imshow(img0,cmap='gray') + plt.subplot(2,2,2) + plt.imshow(img1,cmap='gray') + plt.subplot(2,2,3) + plt.imshow(img2,cmap='gray') + plt.subplot(2,2,4) + plt.imshow(img3,cmap='gray') + + plt.show() + + return + + +########## +## Main ## +########## + +if(__name__=="__main__"): + + #test1() + test_gaussian_timeselect() diff --git a/attempt3/scratchpad.ipynb b/attempt3/scratchpad.ipynb new file mode 100644 index 0000000..8968d3c --- /dev/null +++ b/attempt3/scratchpad.ipynb @@ -0,0 +1,85 @@ +{ + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9-final" + }, + "orig_nbformat": 2, + "kernelspec": { + "name": "python3", + "display_name": "Python 3.6.9 64-bit", + "metadata": { + "interpreter": { + "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 2, + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os,sys,math\n", + "import numpy as np\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "import mnist_dataloader\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([])\ntorch.Size([])\n[1, 2, 3]\n" + ] + } + ], + "source": [ + "a = torch.randn(10,14)\n", + "b = a.shape[1:1]\n", + "print(b)\n", + "b.numel()\n", + "\n", + "print(b)\n", + "\n", + "b = torch.Size([1])\n", + "c = torch.Size([2,3])\n", + "d = torch.Size(torch.cat([torch.tensor(b),torch.tensor(c)]))\n", + "\n", + "d = [*b,*c]\n", + "\n", + "print(d)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ] +} \ No newline at end of file diff --git a/mnistdiffusion_attempt1_fig1.png b/mnistdiffusion_attempt1_fig1.png new file mode 100644 index 0000000..355789d Binary files /dev/null and b/mnistdiffusion_attempt1_fig1.png differ diff --git a/training_data/MNISTDataset.docx b/training_data/MNISTDataset.docx new file mode 100644 index 0000000..4ff4cb2 Binary files /dev/null and b/training_data/MNISTDataset.docx differ diff --git a/training_data/ams_MNIST_load.py b/training_data/ams_MNIST_load.py new file mode 100644 index 0000000..177cd9c --- /dev/null +++ b/training_data/ams_MNIST_load.py @@ -0,0 +1,138 @@ +#!/usr/bin/python3 + +import sys +sys.path.append('/home/aschinde/workspace/projects_python/library') + +import os,sys,math +import numpy as np +import cv2; + +import gzip #May need to use gzip.open instead of open + +import struct +#struct unpack allows some interpretation of python binary data +#Example +##import struct +## +##data = open("from_fortran.bin", "rb").read() +## +##(eight, N) = struct.unpack("@II", data) +## +##This unpacks the first two fields, assuming they start at the very +##beginning of the file (no padding or extraneous data), and also assuming +##native byte-order (the @ symbol). The Is in the formatting string mean +##"unsigned integer, 32 bits". + +#for integers +#a = int +#a.from_bytes(b'\xaf\xc2R',byteorder='little') +#a.to_bytes(nbytes,byteorder='big') +#analagous operation doens't seem to exist for floats +#what about numpy? + + +#https://www.devdungeon.com/content/working-binary-data-python + +#print("{:02d}".format(2)) +#b = b.fromhex('010203040506') +#b.hex() +#c = b.decode(encoding='utf-8' or 'latin-1' or 'ascii'...) +#print(c) + +#numpy arrays have tobytes +#numpy arrays have frombuffer (converts to dtypes) +# +#q = np.array([15],dtype=np.uint8); +#q.tobytes(); +#q.tobytes(order='C') (options are 'C' and 'F' +#q2 = np.buffer(q.tobytes(),dtype=np.uint8) +#np.frombuffer(buffer,dtype=float,count=-1,offset=0) + +##You could also use the < and > endianess format codes in the struct +##module to achieve the same result: +## +##>>> struct.pack('<2h', *struct.unpack('>2h', original)) +##'\xde\xad\xc0\xde' + +def bytereverse(bts): +## bts2 = bytes(len(bts)); +## for I in range(0,len(bts)): +## bts2[len(bts)-I-1] = bts[I]; + N = len(bts); +## print(N); +## print(bts); +## bts2 = struct.pack('<{}h'.format(N), *struct.unpack('>{}h'.format(N), bts)) + bts2 = bts; + return bts2; + +#Read Labels +def read_MNIST_label_file(fname): + #fp = gzip.open('./train-labels-idx1-ubyte.gz','rb'); + fp = gzip.open(fname,'rb'); + magic = fp.read(4); + #nitems = np.frombuffer(fp.read(4),dtype=np.int32)[0]; #some sort of endiannes problem + bts = fp.read(4); + #bts = bytereverse(bts); + #nitems = np.frombuffer(bts,dtype=np.int32); + nitems = np.int32(struct.unpack('>I',bts)[0]); #it was a non-native endianness in teh integer encoding + #> < @ - endianness + + bts = fp.read(nitems); + N = len(bts); + labels = np.zeros((N),dtype=np.uint8); + labels = np.frombuffer(bts,dtype=np.uint8,count=N); + #for i in range(0,10): + # bt = fp.read(1); + # labels[i] = np.frombuffer(bt,dtype=np.uint8); + fp.close(); + return labels; + +def read_MNIST_image_file(fname): + fp = gzip.open(fname,'rb'); + magic = fp.read(4); + bts = fp.read(4); + nitems = np.int32(struct.unpack('>I',bts)[0]); + bts = fp.read(4); + nrows = np.int32(struct.unpack('>I',bts)[0]); + bts = fp.read(4); + ncols = np.int32(struct.unpack('>I',bts)[0]); + + images = np.zeros((nitems,nrows,ncols),dtype=np.uint8); + for I in range(0,nitems): + bts = fp.read(nrows*ncols); + img1 = np.frombuffer(bts,dtype=np.uint8,count=nrows*ncols); + img1 = img1.reshape((nrows,ncols)); + images[I,:,:] = img1; + + fp.close(); + + return images; + +def read_training_data(): + rootdir = '/home/aschinde/workspace/machinelearning/datasets/MNIST'; + fname1 = 'train-labels-idx1-ubyte.gz'; + fname2 = 'train-images-idx3-ubyte.gz'; + + labels = read_MNIST_label_file(os.path.join(rootdir,fname1)); + images = read_MNIST_image_file(os.path.join(rootdir,fname2)); + + return [labels,images]; + +def read_test_data(): + rootdir = '/home/aschinde/workspace/machinelearning/datasets/MNIST'; + + fname1 = 't10k-labels-idx1-ubyte.gz'; + fname2 = 't10k-images-idx3-ubyte.gz'; + + labels = read_MNIST_label_file(os.path.join(rootdir,fname1)); + images = read_MNIST_image_file(os.path.join(rootdir,fname2)); + + return [labels,images]; + +def show_MNIST_image(img): + import matplotlib.pyplot as plt; + plt.figure(); + plt.imshow(255-img,cmap='gray'); + plt.show(); + return; + diff --git a/training_data/an_mnist_loader.py b/training_data/an_mnist_loader.py new file mode 100644 index 0000000..91c15d4 --- /dev/null +++ b/training_data/an_mnist_loader.py @@ -0,0 +1,92 @@ +#!/usr/bin/python3 + +""" +mnist_loader +~~~~~~~~~~~~ + +A library to load the MNIST image data. For details of the data +structures that are returned, see the doc strings for ``load_data`` +and ``load_data_wrapper``. In practice, ``load_data_wrapper`` is the +function usually called by our neural network code. +""" + +##sigh: If you want it to run today, write it in Python. +##If you want it to run tomorrow, write it in ANYTHING ELSE + +#### Libraries +# Standard library +##import cPickle +import pickle as cPickle +import gzip + +# Third-party libraries +import numpy as np + +def load_data(): + """Return the MNIST data as a tuple containing the training data, + the validation data, and the test data. + + The ``training_data`` is returned as a tuple with two entries. + The first entry contains the actual training images. This is a + numpy ndarray with 50,000 entries. Each entry is, in turn, a + numpy ndarray with 784 values, representing the 28 * 28 = 784 + pixels in a single MNIST image. + + The second entry in the ``training_data`` tuple is a numpy ndarray + containing 50,000 entries. Those entries are just the digit + values (0...9) for the corresponding images contained in the first + entry of the tuple. + + The ``validation_data`` and ``test_data`` are similar, except + each contains only 10,000 images. + + This is a nice data format, but for use in neural networks it's + helpful to modify the format of the ``training_data`` a little. + That's done in the wrapper function ``load_data_wrapper()``, see + below. + """ + #f = gzip.open('../data/mnist.pkl.gz', 'rb') + f = gzip.open('./t10k-images-idx3-ubyte.gz','rb'); + training_data, validation_data, test_data = cPickle.load(f) + f.close() + return (training_data, validation_data, test_data) + +def load_data_wrapper(): + """Return a tuple containing ``(training_data, validation_data, + test_data)``. Based on ``load_data``, but the format is more + convenient for use in our implementation of neural networks. + + In particular, ``training_data`` is a list containing 50,000 + 2-tuples ``(x, y)``. ``x`` is a 784-dimensional numpy.ndarray + containing the input image. ``y`` is a 10-dimensional + numpy.ndarray representing the unit vector corresponding to the + correct digit for ``x``. + + ``validation_data`` and ``test_data`` are lists containing 10,000 + 2-tuples ``(x, y)``. In each case, ``x`` is a 784-dimensional + numpy.ndarry containing the input image, and ``y`` is the + corresponding classification, i.e., the digit values (integers) + corresponding to ``x``. + + Obviously, this means we're using slightly different formats for + the training data and the validation / test data. These formats + turn out to be the most convenient for use in our neural network + code.""" + tr_d, va_d, te_d = load_data() + training_inputs = [np.reshape(x, (784, 1)) for x in tr_d[0]] + training_results = [vectorized_result(y) for y in tr_d[1]] + training_data = zip(training_inputs, training_results) + validation_inputs = [np.reshape(x, (784, 1)) for x in va_d[0]] + validation_data = zip(validation_inputs, va_d[1]) + test_inputs = [np.reshape(x, (784, 1)) for x in te_d[0]] + test_data = zip(test_inputs, te_d[1]) + return (training_data, validation_data, test_data) + +def vectorized_result(j): + """Return a 10-dimensional unit vector with a 1.0 in the jth + position and zeroes elsewhere. This is used to convert a digit + (0...9) into a corresponding desired output from the neural + network.""" + e = np.zeros((10, 1)) + e[j] = 1.0 + return e diff --git a/training_data/t10k-images-idx3-ubyte.gz b/training_data/t10k-images-idx3-ubyte.gz new file mode 100644 index 0000000..5ace8ea Binary files /dev/null and b/training_data/t10k-images-idx3-ubyte.gz differ diff --git a/training_data/t10k-labels-idx1-ubyte.gz b/training_data/t10k-labels-idx1-ubyte.gz new file mode 100644 index 0000000..a7e1415 Binary files /dev/null and b/training_data/t10k-labels-idx1-ubyte.gz differ diff --git a/training_data/train-images-idx3-ubyte.gz b/training_data/train-images-idx3-ubyte.gz new file mode 100644 index 0000000..b50e4b6 Binary files /dev/null and b/training_data/train-images-idx3-ubyte.gz differ diff --git a/training_data/train-labels-idx1-ubyte.gz b/training_data/train-labels-idx1-ubyte.gz new file mode 100644 index 0000000..707a576 Binary files /dev/null and b/training_data/train-labels-idx1-ubyte.gz differ