Making some of my models public.
This commit is contained in:
		
							
								
								
									
										2
									
								
								LICENSE
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								LICENSE
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,2 @@
 | 
			
		||||
Copyright 2023, Aaron M. Schinder
 | 
			
		||||
Released under the MIT/BSD License
 | 
			
		||||
							
								
								
									
										
											BIN
										
									
								
								attempt1/__pycache__/mnist_dataloader.cpython-36.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								attempt1/__pycache__/mnist_dataloader.cpython-36.pyc
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								attempt1/__pycache__/mnistdiffusion_model.cpython-36.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								attempt1/__pycache__/mnistdiffusion_model.cpython-36.pyc
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								attempt1/__pycache__/mnistdiffusion_utils.cpython-36.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								attempt1/__pycache__/mnistdiffusion_utils.cpython-36.pyc
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										91
									
								
								attempt1/mnist_dataloader.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										91
									
								
								attempt1/mnist_dataloader.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,91 @@
 | 
			
		||||
#!/usr/bin/python3
 | 
			
		||||
import os,sys,math
 | 
			
		||||
import numpy as np
 | 
			
		||||
import cv2
 | 
			
		||||
import gzip #need to use gzip.open instead of open
 | 
			
		||||
import struct
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
def read_MNIST_label_file(fname):
    """Read a gzip-compressed MNIST label file (e.g. train-labels-idx1-ubyte.gz).

    The file starts with a 4-byte magic number (read but not validated)
    followed by a 4-byte big-endian item count; one unsigned byte per
    label follows.

    Args:
        fname: Path to the gzip-compressed label file.

    Returns:
        1-D ``np.uint8`` array with the labels actually present in the
        file (at most the count announced in the header).

    Raises:
        struct.error: If the 8-byte header cannot be read completely.
    """
    # The context manager closes the file even when the header is malformed.
    with gzip.open(fname, 'rb') as fp:
        # The header integers are big-endian ('>'), not native byte order.
        _magic, nitems = struct.unpack('>II', fp.read(8))
        bts = fp.read(nitems)

    return np.frombuffer(bts, dtype=np.uint8, count=len(bts))
 | 
			
		||||
 | 
			
		||||
def read_MNIST_image_file(fname):
    """Read a gzip-compressed MNIST image file (e.g. train-images-idx3-ubyte.gz).

    The header holds four big-endian 32-bit integers: a magic number (read
    but not validated), the number of images, the row count and the column
    count. The pixel data follows, one unsigned byte per pixel.

    Args:
        fname: Path to the gzip-compressed image file.

    Returns:
        Writable ``np.uint8`` array of shape ``(nitems, nrows, ncols)``.

    Raises:
        struct.error: If the 16-byte header cannot be read completely.
        ValueError: If the file holds fewer pixels than the header announces.
    """
    # The context manager closes the file even when parsing fails.
    with gzip.open(fname, 'rb') as fp:
        _magic, nitems, nrows, ncols = struct.unpack('>IIII', fp.read(16))
        # Plain Python ints: no 32-bit overflow when multiplying the sizes.
        total = nitems * nrows * ncols
        bts = fp.read(total)

    images = np.zeros((nitems, nrows, ncols), dtype=np.uint8)
    if total:
        # One bulk conversion instead of a Python loop over every image;
        # frombuffer raises ValueError when the data is truncated.
        pixels = np.frombuffer(bts, dtype=np.uint8, count=total)
        images[...] = pixels.reshape(images.shape)

    return images
 | 
			
		||||
 | 
			
		||||
#The mnist dataset is small enough to fit entirely in memory
 | 
			
		||||
def mnist_load(baseloc="../training_data"):
    """Load the four MNIST files from ``baseloc`` as float32 tensors.

    Args:
        baseloc: Directory holding the four gzip-compressed MNIST files.
            Defaults to the location that used to be hard-coded.

    Returns:
        ``[labels_train, labels_test, images_train, images_test]``, each a
        ``torch.float32`` tensor with ``requires_grad=False``. Pixel values
        are converted as-is (no rescaling is done here).
    """
    traindatafile = os.path.join(baseloc, "train-images-idx3-ubyte.gz")
    trainlabelfile = os.path.join(baseloc, "train-labels-idx1-ubyte.gz")
    testdatafile = os.path.join(baseloc, "t10k-images-idx3-ubyte.gz")
    testlabelfile = os.path.join(baseloc, "t10k-labels-idx1-ubyte.gz")

    labels_train = read_MNIST_label_file(trainlabelfile)
    labels_test = read_MNIST_label_file(testlabelfile)
    images_train = read_MNIST_image_file(traindatafile)
    images_test = read_MNIST_image_file(testdatafile)

    labels_train = torch.tensor(labels_train, dtype=torch.float32, requires_grad=False)
    labels_test = torch.tensor(labels_test, dtype=torch.float32, requires_grad=False)
    images_train = torch.tensor(images_train, dtype=torch.float32, requires_grad=False)
    images_test = torch.tensor(images_test, dtype=torch.float32, requires_grad=False)

    return [labels_train, labels_test, images_train, images_test]
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
    # Smoke test: load the dataset and report the shape of every tensor
    # in the order labels_train, labels_test, images_train, images_test.
    loaded = mnist_load()
    print("Loaded MNIST Data")
    for tensor in loaded:
        print(tensor.shape)
 | 
			
		||||
							
								
								
									
										401
									
								
								attempt1/mnistdiffusion.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										401
									
								
								attempt1/mnistdiffusion.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,401 @@
 | 
			
		||||
#!/usr/bin/python3
 | 
			
		||||
 | 
			
		||||
import os,sys,math
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
import matplotlib.pyplot as plt
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
 | 
			
		||||
import mnist_dataloader
 | 
			
		||||
 | 
			
		||||
from mnistdiffusion_model import mndiff_rev
 | 
			
		||||
from mnistdiffusion_utils import *
 | 
			
		||||
 | 
			
		||||
import time
 | 
			
		||||
 | 
			
		||||
################
 | 
			
		||||
## Dataloader ##
 | 
			
		||||
################
 | 
			
		||||
 | 
			
		||||
def mnist_load():
    """Load the MNIST images and rescale them from [0, 255] to [-1, 1].

    The labels returned by ``mnist_dataloader.mnist_load`` are discarded.

    Returns:
        ``[images_train, images_test]`` as float tensors that do not
        require gradients.
    """
    _, _, raw_train, raw_test = mnist_dataloader.mnist_load()

    def _rescale(raw):
        # Work on a detached float copy so the loader's tensor is untouched.
        img = raw.clone().detach().float()
        img = 2.0*(img/255.0)-1.0
        img.requires_grad = False
        return img

    # The original notes 60000 x 28 x 28 for the training images - TODO confirm.
    return [_rescale(raw_train), _rescale(raw_test)]
 | 
			
		||||
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
################
 | 
			
		||||
## Operations ##
 | 
			
		||||
################
 | 
			
		||||
 | 
			
		||||
# Rather than killing myself figuring out variable variance schedules,
 | 
			
		||||
# just use a constant variance schedule to start with
 | 
			
		||||
 | 
			
		||||
def training_parameters():
    """Build the dictionary of training hyper-parameters.

    A constant variance schedule is used: every step has variance ``Beta``.

    Returns:
        dict with keys ``nts`` (number of time steps), ``Beta``, ``t``
        (normalized time t/T in [0, 1]), ``a_t`` (running product of
        ``1 - B_t``, starting at 1), ``B_t`` (per-step variance), ``lr``,
        ``batchsize``, ``save_every`` and ``record_every``. ``t``, ``a_t``
        and ``B_t`` are float32 tensors of length ``nts``.
    """
    nts = 200
    Beta = 0.0050

    B_t = np.full(nts, Beta)
    # a_t[0] = 1 and a_t[i] = a_t[i-1] * (1 - B_t[i-1])
    a_t = np.concatenate(([1.0], np.cumprod(1.0 - B_t[:-1])))
    t = np.linspace(0, 1, nts)  # normalized time (t/T)

    return {
        "nts": nts,
        "Beta": Beta,
        "t": torch.tensor(t, dtype=torch.float32),
        "a_t": torch.tensor(a_t, dtype=torch.float32),
        "B_t": torch.tensor(B_t, dtype=torch.float32),
        "lr": 1E-4,
        "batchsize": 80,
        "save_every": 200,
        "record_every": 25,
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_minibatch(imgs, Ndraw, lparas):
    """Draw random images and build their forward (noising) sequences.

    For every drawn image the sequence starts with the clean image and
    continues with
    ``x[J] = x[J-1]*sqrt(1 - B_t[J-1]) + sqrt(B_t[J-1])*N(0, 1)``.

    Args:
        imgs: Image tensor indexed as ``(image, row, column)``; images are
            drawn uniformly, with replacement, along the first dimension.
        Ndraw: Number of sequences to generate.
        lparas: Parameter dict; reads ``"nts"``, ``"B_t"`` and ``"t"``.

    Returns:
        ``[imbatch, times]`` where ``imbatch`` has shape
        ``(Ndraw, nts, rows, cols)`` and ``times`` is ``lparas["t"]``
        expanded to ``(Ndraw, nts)``, both on ``imgs.device``.
    """
    nts = lparas["nts"]
    dev = imgs.device
    nrows, ncols = imgs.shape[-2], imgs.shape[-1]

    draws = torch.randint(0, imgs.shape[0], (Ndraw,), device=dev)

    times = lparas["t"].to(dev)
    times = times.reshape([1, nts]).expand([Ndraw, nts])

    # Per-step variance, shaped to broadcast over the pixel dimensions.
    # (The previous kron + reshape only gave the right layout when every
    # entry of B_t was identical.)
    beta = lparas["B_t"].to(device=dev, dtype=torch.float32).reshape([nts, 1, 1])
    decay = torch.sqrt(1.0 - beta)
    sig = torch.sqrt(beta)

    imbatch = torch.zeros([Ndraw, nts, nrows, ncols], dtype=torch.float32, device=dev)
    imbatch[:, 0, :, :] = imgs[draws]
    noise = torch.randn((Ndraw, nts, nrows, ncols), device=dev)*sig

    # The recurrence is sequential in time but independent across samples,
    # so only the time loop remains.
    for J in range(1, nts):
        imbatch[:, J, :, :] = imbatch[:, J-1, :, :]*decay[J-1] + noise[:, J-1, :, :]

    return [imbatch, times]
 | 
			
		||||
 | 
			
		||||
def train_batch(revnet,batchimg,batchtime,lparas):
    """Run one optimization step of the reverse network on a minibatch.

    The network output (scaled by ``Beta``) is regressed with an MSE loss
    onto the difference between consecutive frames of each sequence.

    Args:
        revnet: Model called as ``revnet(batchimg, batchtime)``; must expose
            a ``device`` attribute and ``parameters()``.
        batchimg: Sequences, indexed ``(batch, time, row, column)``.
        batchtime: Times per frame; ``shape[1]`` is the number of steps.
        lparas: Parameter dict; reads ``"lr"`` and ``"Beta"``.

    Returns:
        The MSE loss divided by ``Beta**2`` as a NumPy scalar.
    """
    lr = lparas["lr"]
    beta = float(lparas["Beta"])
    nts = batchtime.shape[1]
    dev = revnet.device

    # Keep one Adam instance per model so its moment estimates survive
    # between calls; rebuild it only when the learning rate changes.
    optim = getattr(revnet, "_train_optim", None)
    if optim is None or optim.defaults["lr"] != lr:
        optim = torch.optim.Adam(revnet.parameters(),lr=lr)
        revnet._train_optim = optim

    loss = torch.nn.MSELoss().to(dev)

    # Clear the MODEL gradients; zeroing the loss module (which has no
    # parameters) left them accumulating from batch to batch.
    optim.zero_grad()

    d_est = revnet(batchimg,batchtime)
    d_est = d_est[:,1:nts,:,:]
    targ = batchimg[:,1:nts,:,:]-batchimg[:,0:nts-1,:,:]

    L2 = loss(targ,beta*d_est)
    L2.backward()
    optim.step()

    # Report the loss in units of the un-scaled network output.
    L = L2.detach().cpu().numpy()
    L = L/beta**2

    return L
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def train(directory,saveseries):
    """Train the reverse network indefinitely, saving checkpoints as it goes.

    Loads the most recent checkpoint of ``saveseries`` in ``directory`` when
    one exists. Every ``record_every`` minibatches the mean loss of that
    window is recorded on the model and printed; every ``save_every``
    minibatches the model is written to the next save file. The loop has no
    exit condition: stop it by interrupting the process.

    Args:
        directory: Directory holding the checkpoint files.
        saveseries: Name of the checkpoint series.
    """
    print("Loading dataset...")
    [images_train, images_test] = mnist_load()
    lparas = training_parameters()

    batchsize = lparas["batchsize"]
    save_every = lparas["save_every"]
    record_every = lparas["record_every"]
    # Use the GPU when present instead of failing on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    revnet = mndiff_rev()
    fname = get_current_save(directory,saveseries)
    if(fname is None):
        # NOTE(review): result unused; call kept in case get_next_save has
        # side effects - confirm against mnistdiffusion_utils.
        fname = get_next_save(directory,saveseries)
    else:
        print("Loading {}".format(fname))
        revnet.load(fname)

    images_train = images_train.to(device)
    revnet = revnet.to(device)

    ## Training Loop
    I = 0
    J = 0
    # Allocated once: re-creating this buffer on every iteration wiped the
    # losses collected so far, so the recorded mean only saw the last one.
    losses = np.zeros((record_every))
    while(True):

        if(I%record_every==0):
            print("Minibatch {}".format(I))

        #draw minibatch
        [mb,mbt] = generate_minibatch(images_train,batchsize,lparas)

        #train minibatch
        L = train_batch(revnet,mb,mbt,lparas)
        losses[J] = L
        J = J + 1

        #record results
        if(J%record_every==0):
            J = 0
            loss = np.mean(losses)
            revnet.recordloss(loss)
            print("\tloss={:1.5f}".format(loss))

        #save
        if(I%save_every==0):
            fnext = get_next_save(directory,saveseries)
            revnet.save(fnext)

        I = I + 1

    return
 | 
			
		||||
 | 
			
		||||
def infer(directory,saveseries):
    """Sample one image by running the reverse process and plot snapshots.

    Starts from a 28x28 standard-normal image and steps backwards through
    ``15*nts`` time steps, each applying
    ``x <- x*sqrt(1-Beta) - Beta*revnet(x, t) + 0.5*Beta*N(0, 1)``.
    Four frames are shown in a 2x2 figure: 90%, 75% and 25% of the way
    along the sequence index, and the final result (index 0).

    Args:
        directory: Directory holding the checkpoint files.
        saveseries: Name of the checkpoint series to load.
    """
    ttt = int(time.time())
    np.random.seed(ttt%50000)
    torch.manual_seed(ttt%50000)

    lparas = training_parameters()
    revnet = mndiff_rev()
    fname = get_current_save(directory,saveseries)
    if(fname is None):
        print("Warning, no state loaded.")
    else:
        print("Loading {}".format(fname))
        revnet.load(fname)

    beta = lparas["Beta"]
    nts = lparas["nts"]*15

    imgseq = torch.zeros([nts,28,28])
    imgseq[nts-1,:,:] = torch.randn(28,28)

    tl = torch.linspace(0,1,nts)
    bb = torch.as_tensor(beta)

    # Sampling needs no gradients; without no_grad every step would be
    # appended to one autograd graph spanning the whole sequence.
    with torch.no_grad():
        for I in range(nts-2,-1,-1):
            img = imgseq[I+1,:,:]
            eps = revnet(img,tl[I])
            imgseq[I,:,:] = img*torch.sqrt(1-bb) - bb*eps + 0.5*bb*torch.randn(28,28)

    im0 = torch.squeeze(imgseq[0,:,:]).detach().numpy()
    im1 = torch.squeeze(imgseq[int(nts*0.25),:,:]).detach().numpy()
    im2 = torch.squeeze(imgseq[int(nts*0.75),:,:]).detach().numpy()
    im3 = torch.squeeze(imgseq[int(nts*0.90),:,:]).detach().numpy()

    # Panels run from the noisiest frame shown to the final sample.
    plt.figure()
    plt.subplot(2,2,1)
    plt.imshow(im3,cmap='gray')
    plt.subplot(2,2,2)
    plt.imshow(im2,cmap='gray')
    plt.subplot(2,2,3)
    plt.imshow(im1,cmap='gray')
    plt.subplot(2,2,4)
    plt.imshow(im0,cmap='gray')
    plt.show()

    return
 | 
			
		||||
 | 
			
		||||
def plot_history(directory, saveseries):
    """Plot the loss history stored with the latest checkpoint of a series.

    When no checkpoint exists, the history of a freshly constructed model
    is plotted instead.

    Args:
        directory: Directory holding the checkpoint files.
        saveseries: Name of the checkpoint series.
    """
    revnet = mndiff_rev()

    fname = get_current_save(directory,saveseries)
    if fname is not None:
        print("Loading {}".format(fname))
        revnet.load(fname)
    else:
        # Result unused; the call itself is kept exactly as before.
        fname = get_next_save(directory,saveseries)

    revnet.plothistory()
    return
 | 
			
		||||
 | 
			
		||||
def test1():
    """Visual check of generate_minibatch: show frames 0, 1, 2 and 99.

    Generates 50 noising sequences on the "cuda" device, prints the shapes
    of the batch and of its time tensor, and shows four frames of the first
    sequence in a 2x2 figure.
    """
    [images_train, images_test] = mnist_load()
    lparas = training_parameters()

    # To inspect the schedules, plot lparas["a_t"] or lparas["B_t"]
    # against lparas["t"] here.

    images_train = images_train.to("cuda")
    [mb,mbtime] = generate_minibatch(images_train,50,lparas)

    frames = [
        torch.squeeze(mb[0,step,:,:]).clone().detach().to("cpu")
        for step in (0,1,2,99)
    ]

    print(mb.shape)
    print(mbtime.shape)

    plt.figure()
    for panel, frame in enumerate(frames, start=1):
        plt.subplot(2,2,panel)
        plt.imshow(frame,cmap='gray')
    plt.show()

    return
 | 
			
		||||
 | 
			
		||||
def test2(directory, saveseries):
    """Visual check of the network: show one training image and its output.

    Loads the latest checkpoint of the series (when one exists), evaluates
    the network on training image 100 at normalized time 0.5, prints the
    mean and standard deviation of the output, and plots input and output
    one above the other.

    Args:
        directory: Directory holding the checkpoint files.
        saveseries: Name of the checkpoint series.
    """
    lparas = training_parameters()
    nts = lparas["nts"]
    beta = lparas["Beta"]

    revnet = mndiff_rev()
    fname = get_current_save(directory,saveseries)
    if fname is not None:
        print("Loading {}".format(fname))
        revnet.load(fname)
    else:
        # Result unused; the call itself is kept exactly as before.
        fname = get_next_save(directory,saveseries)

    [images_train, images_test] = mnist_load()

    x = torch.squeeze(images_train[100,:,:])
    y = revnet(x,0.5)

    x = x.detach().numpy()
    y = y.detach().numpy()

    y_mean = np.mean(y)
    y_std = np.std(y)
    print("y mn={:1.3f} std={:1.3f} std2*beta={}".format(y_mean,y_std,np.std(y)**2*beta))

    plt.figure()
    for panel, image in ((1, x), (2, y)):
        plt.subplot(2,1,panel)
        plt.imshow(image,cmap='gray')
    plt.show()

    return
 | 
			
		||||
 | 
			
		||||
##########
 | 
			
		||||
## Main ##
 | 
			
		||||
##########
 | 
			
		||||
 | 
			
		||||
def mainswitch():
    """Dispatch the command-line operation given in ``sys.argv[1]``.

    Supported operations are ``train``, ``plot_history`` and ``infer``; each
    works on the ``test_f`` series in ``./saves``. Without an operation the
    usage text is printed and the process exits with status 0. An unknown
    operation prints the usage text and exits with status 2 instead of
    silently doing nothing.
    """
    usage = (
        "mnistdiffusion [operation]",
        "operations:",
        "\ttrain",
        "\tplot_history",
        "\tinfer",
    )

    args = sys.argv
    if(len(args)<2):
        for line in usage:
            print(line)
        # sys.exit rather than the site-provided exit() helper, which is
        # not guaranteed to exist in every interpreter setup.
        sys.exit(0)

    operation = args[1]
    if(operation=="train"):
        train("./saves","test_f")
    elif(operation=="plot_history"):
        plot_history("./saves","test_f")
    elif(operation=="infer"):
        infer("./saves","test_f")
    else:
        print("Unknown operation: {}".format(operation))
        for line in usage:
            print(line)
        sys.exit(2)
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
    # Seed NumPy and torch from the wall clock (kept below 50000).
    seed = int(time.time()) % 50000
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Manual checks, enable as needed:
    #test1()
    #test2("./saves","test_d")
    mainswitch()
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										289
									
								
								attempt1/mnistdiffusion_model.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										289
									
								
								attempt1/mnistdiffusion_model.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,289 @@
 | 
			
		||||
#!/usr/bin/python3
 | 
			
		||||
 | 
			
		||||
import os,sys,math
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
import matplotlib.pyplot as plt
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
 | 
			
		||||
import mnist_dataloader
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
## Sohl-Dickstein 2015 describes a gaussian "bump" function to 
 | 
			
		||||
#   implement learnable time-dependence
 | 
			
		||||
#
 | 
			
		||||
#network input vector and normalized time (tn e (0,1)) 
 | 
			
		||||
# input: (dimx, dimy, ..., features), (scalar) --> output: (dimx, dimy, ...)
 | 
			
		||||
# input: (ntimes, ..., features), (ntimes) --> output: (ntimes , ...)
 | 
			
		||||
class gaussian_timeselect(nn.Module):
    """Collapse a feature axis with time-dependent Gaussian "bump" weights.

    Feature ``j`` has a fixed centre ``tau_j`` (evenly spaced on [0, 1]).
    For a normalized time ``tn`` the weights are
    ``exp(-(tn - tau_j)**2 / (2*width))``, normalized to sum to one, and the
    output is the weighted sum of the input over its last (feature) axis.

    Supported inputs:
        x: (dimx, dimy, ..., features), tn scalar -> (dimx, dimy, ...)
        x: (ntimes, ..., features), tn with ntimes elements -> (ntimes, ...)
    """

    def __init__(self, nfeatures=5, width = 0.25):
        """Create the fixed (non-trainable) centres and width.

        Args:
            nfeatures: Number of features / bump centres.
            width: Denominator scale of the Gaussian exponent.
        """
        super(gaussian_timeselect,self).__init__()

        tau_j = torch.linspace(0,1,nfeatures).float()
        self.tau_j = torch.nn.parameter.Parameter(tau_j, requires_grad=False)

        width = torch.tensor(width,dtype=torch.float32)
        self.width = torch.nn.parameter.Parameter(width,requires_grad = False)

    def forward(self, x, tn):
        """Weight the last axis of ``x`` according to the time(s) ``tn``.

        Args:
            x: Input tensor whose last axis has one entry per feature.
            tn: Normalized time: a scalar (number or 0-dim tensor) applied
                to all of ``x``, or a tensor with one time per entry of the
                first axis of ``x``.

        Returns:
            ``x`` with the feature axis summed out.

        Raises:
            ValueError: In the tensor case, if the number of times does not
                match ``x.shape[0]`` or the feature count does not match
                ``x.shape[-1]``.
        """
        # detach() gives a gradient-free view without touching the caller's
        # tensor; assigning requires_grad fails on non-leaf tensors.
        tn = torch.as_tensor(tn).detach()
        tn = tn.to(device=self.tau_j.device, dtype=self.tau_j.dtype)

        if(len(tn.shape)==0):
            #case 0 - tn is a scalar: one weight vector for all of x
            g = torch.exp(-1.0/(2.0*self.width)*(tn-self.tau_j)**2)
            g = g/torch.clip(torch.sum(g),min=0.001)
            return torch.tensordot(x,g,dims=[[len(x.shape)-1],[0]])

        #case 1 - tn is a tensor: one weight vector per time entry
        nf = self.tau_j.shape[0]
        nt = tn.numel()
        if(nt!=x.shape[0]):
            raise ValueError("gaussian_timeselect time mismatch")
        if(nf!=x.shape[-1]):
            raise ValueError("gaussian_timeselect feature mismatch")

        inner = x.shape[1:len(x.shape)-1]

        # Weights depend only on (time, feature), so compute them at that
        # size and let broadcasting spread them over the inner dimensions.
        g = torch.exp(-1.0/(2.0*self.width)*(tn.reshape([nt,1])-self.tau_j.reshape([1,nf]))**2)
        # Clamp floor kept at its existing value (it differs from case 0).
        g = g/torch.clip(torch.sum(g,dim=1,keepdim=True),min=0.0001)

        x = x.reshape([nt,inner.numel(),nf])
        x = torch.sum(x*g.reshape([nt,1,nf]),dim=2)

        return x.reshape(torch.Size([nt,*inner]))
 | 
			
		||||
 | 
			
		||||
#The reverse network should output a mean image and a covariance image
 | 
			
		||||
#given an input image
 | 
			
		||||
 | 
			
		||||
class mndiff_rev(nn.Module):
    """Fully-connected reverse-step network for 28x28 images.

    Three linear layers expand a flattened image to 5 feature values per
    pixel, two small linear layers mix those features per pixel, and a
    ``gaussian_timeselect`` layer collapses them into one value per pixel
    according to the normalized time. The model also carries its loss
    history, which is saved and loaded together with the weights.
    """

    def __init__(self):
        super(mndiff_rev,self).__init__()

        self.L1 = nn.Linear(28*28,1000,bias=True)
        self.L2 = nn.Linear(1000,1000,bias=True)
        self.L3 = nn.Linear(1000,28*28*5,bias=True)

        # Per-pixel mixing of the 5 features; nn.Linear acts on the last
        # dimension, so the features must be last.
        self.L4a = nn.Linear(5,5)
        self.L4b = nn.Linear(5,5)

        self.L5 = gaussian_timeselect(5,0.1)

        self.tanh = nn.Tanh()

        # Device tracked by to(); callers read it to place their tensors.
        self.device = torch.device("cpu")

        # Loss history: recording counter and the recorded value.
        self.losstime = []
        self.lossrecord = []

        return

    def forward(self,x,t):
        """Evaluate the network for one image, one sequence, or a batch.

        Supported shapes:
            case 0: x (28, 28), t scalar
            case 1: x (ntimes, 28, 28), t (ntimes)
            case 2: x (nbatch, ntimes, 28, 28), t (nbatch, ntimes)

        Args:
            x: Image tensor in one of the shapes above.
            t: Normalized time(s) matching ``x``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` does not have 2, 3 or 4 dimensions.
        """
        ndim = len(x.shape)
        if(ndim not in (2,3,4)):
            raise ValueError("Error: Expect input to be dimension 2,3,4")

        # Same treatment in every case: a float tensor next to the input.
        t = torch.as_tensor(t).to(x.device).float()

        if(ndim==2):
            nbatch = 1
            ntimes = 1
            outshape = [28,28]
        elif(ndim==3):
            nbatch = 1
            ntimes = x.shape[0]
            outshape = [ntimes,28,28]
        else:
            nbatch = x.shape[0]
            ntimes = x.shape[1]
            outshape = [nbatch,ntimes,28,28]
            # Batch and time are flattened together, so t must follow.
            t = t.reshape([nbatch*ntimes])

        x = x.reshape([nbatch*ntimes,28*28])

        s1 = self.tanh(self.L1(x))
        s1 = self.tanh(self.L2(s1))
        s1 = self.tanh(self.L3(s1))
        s1 = s1.reshape([ntimes*nbatch,28*28,5])
        s1 = self.tanh(self.L4a(s1))
        s1 = self.L4b(s1)

        eps = self.L5(s1,t)

        return eps.reshape(outshape)

    def to(self,dev):
        """Move the model to ``dev`` and remember it in ``self.device``."""
        self.device = torch.device(dev)
        ret = super(mndiff_rev,self).to(self.device)
        return ret

    def save(self,fname):
        """Save weights and loss history to ``fname`` (best effort).

        The model is moved to the CPU for saving and always moved back to
        its previous device afterwards, even when saving fails. Failures
        are reported on stdout and not raised.
        """
        dev = self.device
        try:
            self.to("cpu")
            q = [self.state_dict(), self.losstime, self.lossrecord]
            torch.save(q,fname)
        except Exception as err:
            print("model save: problem saving {}".format(fname))
            print("\t{}".format(err))
        finally:
            # Never leave the model stranded on the CPU mid-training.
            self.to(dev)

        return

    def load(self,fname):
        """Load weights and loss history from ``fname`` (best effort).

        Failures are reported on stdout and not raised; the model then
        keeps whatever state it had.
        """
        try:
            [modelsd,losstime,lossrecord] = torch.load(fname)
            self.load_state_dict(modelsd)
            self.losstime = losstime
            self.lossrecord = lossrecord
        except Exception as err:
            print("model load: problem loading {}".format(fname))
            print("\t{}".format(err))

        return

    def plothistory(self):
        """Show the recorded loss values against their recording index."""
        x = self.losstime
        y = self.lossrecord
        plt.figure()
        plt.plot(x,y,'k.',markersize=1)
        plt.title('Loss History')
        plt.show()

        return

    def recordloss(self,loss):
        """Append ``loss`` to the history with the next counter value.

        The counter continues from the last recorded entry (starting at 1).
        The value is stored as a plain Python float so the history holds
        only built-in types.
        """
        t = self.losstime[-1] if self.losstime else 0
        self.losstime.append(t+1)
        self.lossrecord.append(float(loss))

        return
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
###########
 | 
			
		||||
## Tests ##
 | 
			
		||||
###########
 | 
			
		||||
 | 
			
		||||
def test_gaussian_timeselect():
    """Smoke-test gaussian_timeselect on CUDA and print the results.

    Calls the module with an integer time, a per-row time tensor and a float
    time, then prints outputs and shapes for inspection by eye.
    """
    dev = "cuda"
    selector = gaussian_timeselect(5,0.025).to(dev)

    x = torch.randn((8,8,5)).to(dev)
    t = torch.linspace(0,1,8).to(dev)

    y_int = selector(x,0)
    y_vec = selector(x,t)
    y_float = selector(x,0.111)  # result unused; only checks the call runs

    print(y_int)
    print(x[:,:,0])

    for tensor in (x, y_int, y_vec):
        print(tensor.shape)

    # Which entries of the first row differ between the t=0 and per-row results
    print(torch.abs(y_vec[0,:]-y_int[0,:])>1E-6)

    return
 | 
			
		||||
 | 
			
		||||
def test_mndiff_rev():
    """Smoke-test mndiff_rev on unbatched, 1-D batched and 2-D batched input.

    Prints the output shape for each case.
    """
    net = mndiff_rev()

    # Single image with a scalar time
    single_img = torch.randn(28,28)
    single_t = 0

    # Sequence of 5 images with one time per image
    seq_img = torch.randn(5,28,28)
    seq_t = torch.linspace(0,1,5)

    # 8 sequences of 5 images; the same 5 times repeated for each sequence
    batch_img = torch.randn(8,5,28,28)
    batch_t = torch.linspace(0,1,5).reshape([1,5]).expand([8,5])

    outputs = [
        net(single_img,single_t),
        net(seq_img,seq_t),
        net(batch_img,batch_t),
    ]

    for out in outputs:
        print(out.shape)

    return
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
    # Run the time-select smoke test when executed as a script.
    # test_mndiff_rev() can be called here instead to exercise the full model.
    test_gaussian_timeselect()
 | 
			
		||||
							
								
								
									
										78
									
								
								attempt1/mnistdiffusion_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										78
									
								
								attempt1/mnistdiffusion_utils.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,78 @@
 | 
			
		||||
#!/usr/bin/python3
 | 
			
		||||
 | 
			
		||||
import os,sys,math
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_savenumber(fname):
    """Return the integer formed by the trailing digits of a file's stem.

    The directory part and the extension are stripped first, so
    ``"./saves/model0012.pyt"`` gives 12. The scan covers the whole stem;
    the earlier version never looked at index 0, so names such as
    ``"a5.pyt"`` or ``"123.pyt"`` wrongly returned 0.

    Args:
        fname: File name or path.

    Returns:
        The trailing number, or 0 when the stem does not end in a number
        that ``int`` can parse.
    """
    stem = os.path.splitext(os.path.split(fname)[1])[0]

    # Walk back from the end while the characters are numeric
    cut = len(stem)
    while cut > 0 and stem[cut-1].isnumeric():
        cut -= 1

    try:
        return int(stem[cut:])
    except ValueError:
        # No trailing digits, or numeric characters int() does not accept
        return 0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def sorttwopythonlists(list1, list2):
    """Sort ``list1`` by the matching keys in ``list2``.

    Pairs are ordered by (key, item), so equal keys are tie-broken by the
    items themselves.

    Returns:
        A two-element list: ``list1`` reordered by key, and ``list2`` sorted.
    """
    ordered_pairs = sorted(zip(list2, list1))
    reordered = [item for _key, item in ordered_pairs]

    return [reordered, sorted(list2)]
 | 
			
		||||
 | 
			
		||||
def get_current_save(directory,saveseries):
    """Return the path of the highest-numbered save file in ``directory``.

    A file is a candidate when its name contains ``saveseries``. Candidates
    are ranked by the number from ``get_savenumber``, with ties broken by
    path.

    Returns:
        The joined path of the top-ranked candidate, or None if no file
        matches.
    """
    candidates = [
        os.path.join(directory, entry)
        for entry in os.listdir(directory)
        if saveseries in entry
    ]

    if not candidates:
        return None

    ranked = sorted((get_savenumber(path), path) for path in candidates)
    return ranked[-1][1]
 | 
			
		||||
 | 
			
		||||
def get_next_save(directory,saveseries):
    """Return the path for the next save file in the series.

    The name is ``<saveseries><NN>.pyt``, where NN is one more than the
    number of the current latest save, or 00 when the series has no saves.
    """
    latest = get_current_save(directory,saveseries)

    index = 0 if latest is None else get_savenumber(latest) + 1

    return os.path.join(directory,"{}{:02d}.pyt".format(saveseries,index))
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
    # Nothing runs when this module is executed directly.
    # Manual check: print(get_savenumber("./saves/helloworld05202.py"))
    pass
 | 
			
		||||
							
								
								
									
										161
									
								
								attempt1/old/mnistdiffusion.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										161
									
								
								attempt1/old/mnistdiffusion.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,161 @@
 | 
			
		||||
#!/usr/bin/python3
 | 
			
		||||
 | 
			
		||||
import os,sys,math
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
import matplotlib.pyplot as plt
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
 | 
			
		||||
import mnist_dataloader
 | 
			
		||||
 | 
			
		||||
from mnistdiffusion_model import mndiff_rev
 | 
			
		||||
 | 
			
		||||
################
 | 
			
		||||
## Dataloader ##
 | 
			
		||||
################
 | 
			
		||||
 | 
			
		||||
def mnist_load():
    """Load the MNIST images and rescale the pixel values from 0..255 to -1..1.

    The labels returned by ``mnist_dataloader.mnist_load`` are discarded.

    Returns:
        ``[images_train, images_test]`` as float tensors with
        ``requires_grad`` set to False.
    """
    _labels_train, _labels_test, raw_train, raw_test = mnist_dataloader.mnist_load()

    rescaled = []
    for raw in (raw_train, raw_test):
        img = raw.clone().detach().float()
        img = 2.0*(img/255.0)-1.0
        img.requires_grad = False
        rescaled.append(img)

    return rescaled
 | 
			
		||||
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
################
 | 
			
		||||
## Operations ##
 | 
			
		||||
################
 | 
			
		||||
 | 
			
		||||
def training_parameters():
    """Build the cosine variance schedule used for training.

    Cosine schedule after Nichol and Dhariwal (2021).

    Returns:
        dict with ``"nts"`` (number of time steps) and the float32 tensors
        ``"t"`` (step indices), ``"a_t"`` (schedule normalised to 1 at
        step 0) and ``"B_t"`` (per-step variance, clipped to [0, 0.999]).
    """
    nts = 200
    offset = 0.1

    steps = np.arange(0,nts)
    f_t = np.cos((steps/nts+offset)/(1+offset)*np.pi/2.0)**2
    a_t = f_t/f_t[0]

    # Variance of step k comes from the ratio of consecutive a_t values;
    # the last step has no successor and reuses the previous value.
    B_t = np.zeros(steps.shape[0])
    B_t[0:nts-1] = np.clip(1-a_t[1:nts]/a_t[0:nts-1],0,0.999)
    B_t[nts-1] = B_t[nts-2]

    return {
        "nts": nts,
        "t": torch.tensor(steps,dtype=torch.float32),
        "a_t": torch.tensor(a_t,dtype=torch.float32),
        "B_t": torch.tensor(B_t,dtype=torch.float32),
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_minibatch(imgs, Ndraw, lparas):
    """Draw random images and build a noised sequence for each of them.

    For each draw, frame 0 is the clean image and frame J is frame J-1 plus
    Gaussian noise of variance ``B_t[J-1]``.

    The per-step standard deviation is shaped (nts,1,1) so it broadcasts
    over the pixels. The earlier ``kron`` + ``reshape`` construction did not
    give one variance per time step: ``kron`` of a 1-D ``B_t`` with a 2-D
    block has shape (H, nts*W), so the reshape mixed schedule entries
    across rows and frames.

    Args:
        imgs: Image tensor indexed as (image, row, column).
        Ndraw: Number of images to draw (with replacement).
        lparas: Parameter dict; reads ``"nts"`` and ``"B_t"``.

    Returns:
        Tensor of shape (Ndraw*nts, rows, cols) on the device of ``imgs``;
        the sequence for draw I occupies rows I*nts to (I+1)*nts - 1.
    """
    nts = lparas["nts"]
    dev = imgs.device
    nrows, ncols = imgs.shape[1], imgs.shape[2]

    # Loop-invariant: per-step noise std, broadcast over the image plane
    sig = torch.sqrt(lparas["B_t"]).reshape([nts,1,1]).to(dev)

    imbatch = torch.zeros([Ndraw*nts,nrows,ncols],dtype=torch.float32).to(dev)
    draws = torch.randint(0,imgs.shape[0],(Ndraw,))

    for I in range(0,Ndraw):
        # View into the output; writes land directly in imbatch
        imgseq = imbatch[I*nts:(I+1)*nts,:,:]
        imgseq[0,:,:] = imgs[draws[I],:,:]

        noise = torch.randn((nts,nrows,ncols)).to(dev)*sig

        for J in range(1,nts):
            imgseq[J,:,:] = imgseq[J-1,:,:] + noise[J-1,:,:]

    return imbatch
 | 
			
		||||
 | 
			
		||||
def train_batch(imgbatch,lr,lparas):
    """Placeholder for a training step; does nothing and returns None."""
    return None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def train():
    """Placeholder for the training loop; does nothing and returns None."""
    return None
 | 
			
		||||
 | 
			
		||||
def infer():
    """Placeholder for inference; does nothing and returns None."""
    return None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test1():
    """Visual check: build a noised minibatch on CUDA and show four frames.

    Frames 0, 1, 2 and 30 of the flat minibatch are displayed in a 2x2 grid.
    """
    images_train, images_test = mnist_load()
    lparas = training_parameters()

    images_train = images_train.to("cuda")
    mb = generate_minibatch(images_train,50,lparas)

    frames = [
        torch.squeeze(mb[index,:,:]).clone().detach().to("cpu")
        for index in (0, 1, 2, 30)
    ]

    plt.figure()
    for slot, frame in enumerate(frames, start=1):
        plt.subplot(2,2,slot)
        plt.imshow(frame,cmap='gray')

    plt.show()

    return
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
##########
 | 
			
		||||
## Main ##
 | 
			
		||||
##########
 | 
			
		||||
 | 
			
		||||
if(__name__=="__main__"):

    # test_gaussian_timeselect is neither defined nor imported in this
    # script (only mndiff_rev is imported from mnistdiffusion_model), so the
    # call below raised NameError.
    # NOTE(review): assumes the function lives in mnistdiffusion_model next
    # to mndiff_rev — confirm.
    from mnistdiffusion_model import test_gaussian_timeselect

    #test1()
    test_gaussian_timeselect()
 | 
			
		||||
							
								
								
									
										85
									
								
								attempt1/scratchpad.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										85
									
								
								attempt1/scratchpad.ipynb
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,85 @@
 | 
			
		||||
{
 | 
			
		||||
 "metadata": {
 | 
			
		||||
  "language_info": {
 | 
			
		||||
   "codemirror_mode": {
 | 
			
		||||
    "name": "ipython",
 | 
			
		||||
    "version": 3
 | 
			
		||||
   },
 | 
			
		||||
   "file_extension": ".py",
 | 
			
		||||
   "mimetype": "text/x-python",
 | 
			
		||||
   "name": "python",
 | 
			
		||||
   "nbconvert_exporter": "python",
 | 
			
		||||
   "pygments_lexer": "ipython3",
 | 
			
		||||
   "version": "3.6.9-final"
 | 
			
		||||
  },
 | 
			
		||||
  "orig_nbformat": 2,
 | 
			
		||||
  "kernelspec": {
 | 
			
		||||
   "name": "python3",
 | 
			
		||||
   "display_name": "Python 3.6.9 64-bit",
 | 
			
		||||
   "metadata": {
 | 
			
		||||
    "interpreter": {
 | 
			
		||||
     "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
 | 
			
		||||
    }
 | 
			
		||||
   }
 | 
			
		||||
  }
 | 
			
		||||
 },
 | 
			
		||||
 "nbformat": 4,
 | 
			
		||||
 "nbformat_minor": 2,
 | 
			
		||||
 "cells": [
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 1,
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "import os,sys,math\n",
 | 
			
		||||
    "import numpy as np\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "import matplotlib.pyplot as plt\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "import torch\n",
 | 
			
		||||
    "import torch.nn as nn\n",
 | 
			
		||||
    "import torch.nn.functional as F\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "import mnist_dataloader\n"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 14,
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [
 | 
			
		||||
    {
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "torch.Size([])\ntorch.Size([])\n[1, 2, 3]\n"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "a = torch.randn(10,14)\n",
 | 
			
		||||
    "b = a.shape[1:1]\n",
 | 
			
		||||
    "print(b)\n",
 | 
			
		||||
    "b.numel()\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "print(b)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "b = torch.Size([1])\n",
 | 
			
		||||
    "c = torch.Size([2,3])\n",
 | 
			
		||||
    "d = torch.Size(torch.cat([torch.tensor(b),torch.tensor(c)]))\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "d = [*b,*c]\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "print(d)\n"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": null,
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
   "source": []
 | 
			
		||||
  }
 | 
			
		||||
 ]
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										
											BIN
										
									
								
								attempt2/__pycache__/mnist_dataloader.cpython-36.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								attempt2/__pycache__/mnist_dataloader.cpython-36.pyc
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								attempt2/__pycache__/mnistdiffusion_model.cpython-36.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								attempt2/__pycache__/mnistdiffusion_model.cpython-36.pyc
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								attempt2/__pycache__/mnistdiffusion_utils.cpython-36.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								attempt2/__pycache__/mnistdiffusion_utils.cpython-36.pyc
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										91
									
								
								attempt2/mnist_dataloader.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										91
									
								
								attempt2/mnist_dataloader.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,91 @@
 | 
			
		||||
#!/usr/bin/python3
 | 
			
		||||
import os,sys,math
 | 
			
		||||
import numpy as np
 | 
			
		||||
import cv2
 | 
			
		||||
import gzip #need to use gzip.open instead of open
 | 
			
		||||
import struct
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
def read_MNIST_label_file(fname):
    """Read a gzip-compressed MNIST label file.

    Layout as read here: 4 bytes of magic number (skipped, not validated),
    a big-endian unsigned 32-bit item count, then one byte per label.

    Args:
        fname: Path to the ``.gz`` label file.

    Returns:
        1-D ``np.uint8`` array of labels. If the file holds fewer label
        bytes than the header announces, only the bytes present are
        returned.
    """
    # The context manager closes the file even if a read or unpack fails
    with gzip.open(fname,'rb') as fp:
        fp.read(4)  # magic number
        # '>I': the count is stored big-endian, not in native byte order
        nitems = struct.unpack('>I',fp.read(4))[0]
        bts = fp.read(nitems)

    return np.frombuffer(bts,dtype=np.uint8,count=len(bts))
 | 
			
		||||
 | 
			
		||||
def read_MNIST_image_file(fname):
    """Read a gzip-compressed MNIST image file.

    Layout as read here: 4 bytes of magic number (skipped, not validated),
    then three big-endian unsigned 32-bit integers (item count, rows,
    columns), then one byte per pixel.

    Args:
        fname: Path to the ``.gz`` image file.

    Returns:
        Writable ``np.uint8`` array of shape (nitems, nrows, ncols).

    Raises:
        ValueError: If the file holds fewer pixel bytes than the header
            announces.
    """
    # The context manager closes the file even if a read or unpack fails
    with gzip.open(fname,'rb') as fp:
        fp.read(4)  # magic number
        nitems, nrows, ncols = struct.unpack('>III',fp.read(12))
        total = nitems*nrows*ncols
        # One bulk read instead of one read per image
        bts = fp.read(total)

    images = np.frombuffer(bts,dtype=np.uint8,count=total)

    # frombuffer yields a read-only view of the bytes; copy to hand back a
    # writable array.
    return images.reshape((nitems,nrows,ncols)).copy()
 | 
			
		||||
 | 
			
		||||
#The mnist dataset is small enough to fit entirely in memory
 | 
			
		||||
def mnist_load(baseloc="../training_data"):
    """Load the four MNIST files into float32 tensors.

    Args:
        baseloc: Directory holding the four ``.gz`` files. Defaults to the
            previously hard-coded ``"../training_data"``, which is resolved
            against the current working directory.

    Returns:
        ``[labels_train, labels_test, images_train, images_test]``, each a
        ``torch.float32`` tensor with ``requires_grad=False``.
    """
    traindatafile = os.path.join(baseloc,"train-images-idx3-ubyte.gz")
    trainlabelfile = os.path.join(baseloc,"train-labels-idx1-ubyte.gz")
    testdatafile = os.path.join(baseloc,"t10k-images-idx3-ubyte.gz")
    testlabelfile = os.path.join(baseloc,"t10k-labels-idx1-ubyte.gz")

    # Same read order as before: labels first, then images
    arrays = [
        read_MNIST_label_file(trainlabelfile),
        read_MNIST_label_file(testlabelfile),
        read_MNIST_image_file(traindatafile),
        read_MNIST_image_file(testdatafile),
    ]

    return [
        torch.tensor(arr,dtype=torch.float32,requires_grad=False)
        for arr in arrays
    ]
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
    # Script entry: load the dataset and print the shape of each tensor.
    loaded = mnist_load()
    [labels_train, labels_test, images_train, images_test] = loaded
    print("Loaded MNIST Data")
    for tensor in (labels_train, labels_test, images_train, images_test):
        print(tensor.shape)
 | 
			
		||||
							
								
								
									
										401
									
								
								attempt2/mnistdiffusion.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										401
									
								
								attempt2/mnistdiffusion.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,401 @@
 | 
			
		||||
#!/usr/bin/python3
 | 
			
		||||
 | 
			
		||||
import os,sys,math
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
import matplotlib.pyplot as plt
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
 | 
			
		||||
import mnist_dataloader
 | 
			
		||||
 | 
			
		||||
from mnistdiffusion_model import mndiff_rev
 | 
			
		||||
from mnistdiffusion_utils import *
 | 
			
		||||
 | 
			
		||||
import time
 | 
			
		||||
 | 
			
		||||
################
 | 
			
		||||
## Dataloader ##
 | 
			
		||||
################
 | 
			
		||||
 | 
			
		||||
def mnist_load():
    """Load the MNIST images and map the pixel values from 0..255 to -1..1.

    The labels returned by ``mnist_dataloader.mnist_load`` are not used.

    Returns:
        ``[images_train, images_test]`` as float tensors with
        ``requires_grad`` set to False.
    """
    _lab_train, _lab_test, raw_train, raw_test = mnist_dataloader.mnist_load()

    def _rescale(raw):
        # Work on a detached float copy so the loader's tensor is untouched
        img = raw.clone().detach().float()
        img = 2.0*(img/255.0)-1.0
        img.requires_grad = False
        return img

    return [_rescale(raw_train), _rescale(raw_test)]
 | 
			
		||||
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
################
 | 
			
		||||
## Operations ##
 | 
			
		||||
################
 | 
			
		||||
 | 
			
		||||
# Rather than killing myself figuring out variable variance schedules,
 | 
			
		||||
# just use a constant variance schedule to start with
 | 
			
		||||
 | 
			
		||||
def training_parameters():
    """Return the training settings for a constant variance schedule.

    Returns:
        dict with ``"nts"`` (number of time steps), ``"Beta"`` (per-step
        variance), the float32 tensors ``"t"`` (normalised time 0..1),
        ``"a_t"`` (running product of ``1 - B_t``, starting at 1) and
        ``"B_t"``, plus ``"lr"``, ``"batchsize"``, ``"save_every"`` and
        ``"record_every"``.
    """
    nts = 200
    Beta = 0.0050

    B_t = np.ones((nts))*Beta

    # a_t[0] = 1 and a_t[I] = a_t[I-1]*(1-B_t[I-1]), as a running product
    a_t = np.cumprod(np.concatenate(([1.0], 1-B_t[0:nts-1])))

    t = np.linspace(0,1,nts)

    return {
        "nts": nts,
        "Beta": Beta,
        "t": torch.tensor(t,dtype=torch.float32),
        "a_t": torch.tensor(a_t,dtype=torch.float32),
        "B_t": torch.tensor(B_t,dtype=torch.float32),
        "lr": 1E-4,
        "batchsize": 40,
        "save_every": 200,
        "record_every": 25,
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_minibatch(imgs, Ndraw, lparas):
    """Draw random images and build a forward-diffusion sequence for each.

    For each draw, frame 0 is the clean image and
    ``frame[J] = frame[J-1]*sqrt(1-B_t[J-1]) + sqrt(B_t[J-1])*noise``.

    The per-step variance is shaped (nts,1,1) so it broadcasts over the
    pixels. The earlier ``kron`` + ``reshape`` construction only gave the
    intended layout when every entry of ``B_t`` was equal: ``kron`` of a
    1-D ``B_t`` with a 2-D block has shape (H, nts*W), so the reshape mixed
    schedule entries across rows and frames.

    Args:
        imgs: Image tensor indexed as (image, row, column).
        Ndraw: Number of images to draw (with replacement).
        lparas: Parameter dict; reads ``"nts"``, ``"B_t"`` and ``"t"``.

    Returns:
        ``[imbatch, times]``: ``imbatch`` has shape (Ndraw, nts, rows, cols)
        and ``times`` is ``lparas["t"]`` expanded to (Ndraw, nts), both on
        the device of ``imgs``.
    """
    nts = lparas["nts"]
    dev = imgs.device
    nrows, ncols = imgs.shape[1], imgs.shape[2]

    # Loop-invariant schedule terms, broadcast over the image plane
    beta = lparas["B_t"].reshape([nts,1,1]).to(dev)
    sig = torch.sqrt(beta)
    keep = torch.sqrt(1.0-beta)

    imbatch = torch.zeros([Ndraw,nts,nrows,ncols],dtype=torch.float32).to(dev)
    draws = torch.randint(0,imgs.shape[0],(Ndraw,))

    times = lparas["t"].to(dev)
    times = times.reshape([1,nts]).expand([Ndraw,nts])

    for I in range(0,Ndraw):
        # View into the output; writes land directly in imbatch
        imgseq = imbatch[I,:,:,:]
        imgseq[0,:,:] = imgs[draws[I],:,:]

        noise = torch.randn((nts,nrows,ncols)).to(dev)*sig

        for J in range(1,nts):
            imgseq[J,:,:] = imgseq[J-1,:,:]*keep[J-1,:,:] + noise[J-1,:,:]

    return [imbatch,times]
 | 
			
		||||
 | 
			
		||||
def train_batch(revnet,batchimg,batchtime,lparas):
    """Run one optimisation step on a minibatch of diffusion sequences.

    The network output for frames 1..nts-1, scaled by Beta, is fitted to the
    difference between consecutive frames with an MSE loss.

    Args:
        revnet: Model called as ``revnet(batchimg, batchtime)``; must expose
            ``.device`` and ``.parameters()``.
        batchimg: Tensor indexed as (draw, time step, row, column).
        batchtime: Tensor whose second dimension is the number of time steps.
        lparas: Parameter dict; reads ``"lr"`` and ``"Beta"``.

    Returns:
        The MSE loss divided by Beta**2, as a NumPy scalar.
    """
    lr = lparas["lr"]
    nts = batchtime.shape[1]
    dev = revnet.device
    beta = float(lparas["Beta"])

    # NOTE(review): a new Adam optimizer is built on every call, so its
    # running moment estimates do not carry over between minibatches.
    # Keeping one optimizer alive needs a change in the caller.
    optim = torch.optim.Adam(revnet.parameters(),lr=lr)
    loss = torch.nn.MSELoss().to(dev)

    # Clear gradients left by the previous call. The earlier code called
    # zero_grad() on the loss module, which has no parameters, so the
    # model's gradients accumulated from batch to batch.
    optim.zero_grad()

    d_est = revnet(batchimg,batchtime)[:,1:nts,:,:]
    targ = batchimg[:,1:nts,:,:]-batchimg[:,0:nts-1,:,:]

    L2 = loss(targ,beta*d_est)
    L2.backward()
    optim.step()

    # Report the loss in units of the unscaled network output
    return L2.detach().cpu().numpy()/beta**2
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def train(directory,saveseries):
    """Train the reverse-diffusion network until interrupted.

    Resumes from the latest save of the series in ``directory`` when one
    exists. Every ``record_every`` minibatches the mean loss over that
    window is recorded on the model and printed. Every ``save_every``
    minibatches (including minibatch 0) the model is written to a new file
    in the series.

    Args:
        directory: Directory holding the save files.
        saveseries: Name prefix of the save series.

    Returns:
        Never returns normally; the loop runs until the process is stopped.
    """
    print("Loading dataset...")
    [images_train, images_test] = mnist_load()
    lparas = training_parameters()

    batchsize = lparas["batchsize"]
    save_every = lparas["save_every"]
    record_every = lparas["record_every"]
    device = torch.device("cuda")

    revnet = mndiff_rev()
    fname = get_current_save(directory,saveseries)
    if(fname is not None):
        print("Loading {}".format(fname))
        revnet.load(fname)

    images_train = images_train.to(device)
    revnet = revnet.to(device)

    # Loss window, allocated once. It used to be re-zeroed inside the loop,
    # so only the last loss survived and the recorded mean was that loss
    # divided by record_every.
    losses = np.zeros((record_every))

    ## Training Loop
    I = 0  # minibatches processed so far
    J = 0  # position inside the current loss window
    while(True):

        if(I%record_every==0):
            print("Minibatch {}".format(I))

        #draw minibatch
        [mb,mbt] = generate_minibatch(images_train,batchsize,lparas)

        #train minibatch
        losses[J] = train_batch(revnet,mb,mbt,lparas)
        J = J + 1

        #record results
        if(J%record_every==0):
            J = 0
            loss = np.mean(losses)
            revnet.recordloss(loss)
            print("\tloss={:1.5f}".format(loss))

        #save
        if(I%save_every==0):
            fnext = get_next_save(directory,saveseries)
            revnet.save(fnext)

        I = I + 1
 | 
			
		||||
 | 
			
		||||
def infer(directory,saveseries):
    """Sample one image by running the reverse process and plot four stages.

    Starts from Gaussian noise in the last frame and steps backwards using
    the latest saved model of the series, or an untrained model (with a
    warning) when no save exists. The run uses 15 times as many steps as
    the training schedule. Runs on the CPU.

    Args:
        directory: Directory holding the save files.
        saveseries: Name prefix of the save series.

    Returns:
        None.
    """
    # Seed from the clock so each run gives a different sample
    ttt = int(time.time())
    np.random.seed(ttt%50000)
    torch.manual_seed(ttt%50000)

    lparas = training_parameters()
    revnet = mndiff_rev()
    fname = get_current_save(directory,saveseries)
    if(fname is None):
        print("Warning, no state loaded.")
    else:
        print("Loading {}".format(fname))
        revnet.load(fname)

    beta = lparas["Beta"]
    nts = lparas["nts"]*15

    imgseq = torch.zeros([nts,28,28])
    imgseq[nts-1,:,:] = torch.randn(28,28)

    tl = torch.linspace(0,1,nts)
    bb = torch.as_tensor(beta)

    # No gradients are needed for sampling. Without no_grad every step
    # extended an autograd graph through the in-place writes into imgseq,
    # holding the activations of all nts steps in memory.
    with torch.no_grad():
        for I in range(nts-2,-1,-1):
            img = imgseq[I+1,:,:]
            eps = revnet(img,tl[I])

            imgseq[I,:,:] = img*torch.sqrt(1-bb) - bb*eps + 1.0*bb*torch.randn(28,28)

    # Final image plus three earlier stages of the reverse process
    im0 = torch.squeeze(imgseq[0,:,:]).detach().numpy()
    im1 = torch.squeeze(imgseq[int(nts*0.25),:,:]).detach().numpy()
    im2 = torch.squeeze(imgseq[int(nts*0.75),:,:]).detach().numpy()
    im3 = torch.squeeze(imgseq[int(nts*0.90),:,:]).detach().numpy()

    # Panels run from noisiest (top-left) to final sample (bottom-right)
    plt.figure()
    for slot, frame in enumerate((im3, im2, im1, im0), start=1):
        plt.subplot(2,2,slot)
        plt.imshow(frame,cmap='gray')
    plt.show()

    return
 | 
			
		||||
 | 
			
		||||
def plot_history(directory, saveseries):
    """Plot the loss history stored in the latest save of the series.

    When no save exists, the history of a freshly constructed model is
    plotted instead.
    """
    revnet = mndiff_rev()

    latest = get_current_save(directory,saveseries)
    if latest is not None:
        print("Loading {}".format(latest))
        revnet.load(latest)
    else:
        # Result is unused; the lookup is kept so behaviour is unchanged
        get_next_save(directory,saveseries)

    revnet.plothistory()

    return
 | 
			
		||||
 | 
			
		||||
def test1():
    """Visual check of generate_minibatch: show four frames of the first draw.

    Loads MNIST, moves the training images to "cuda", builds a 50-draw
    minibatch and displays frames 0, 1, 2 and 99 of draw 0 in a 2x2 grid.
    """
    images_train, images_test = mnist_load()
    lparas = training_parameters()

    images_train = images_train.to("cuda")
    mb, mbtime = generate_minibatch(images_train, 50, lparas)

    # Copy the selected frames back to the CPU for plotting.
    frames = [
        torch.squeeze(mb[0, step, :, :]).clone().detach().to("cpu")
        for step in (0, 1, 2, 99)
    ]

    print(mb.shape)
    print(mbtime.shape)

    plt.figure()
    for panel, frame in enumerate(frames, start=1):
        plt.subplot(2, 2, panel)
        plt.imshow(frame, cmap='gray')
    plt.show()

    return
 | 
			
		||||
 | 
			
		||||
def test2(directory, saveseries):
    """Run the reverse network once on a training image and show input/output.

    directory  -- folder that holds the save files
    saveseries -- series name used to pick the save files
    """
    lparas = training_parameters()
    n_steps = lparas["nts"]  # read as in the original; not used below
    beta = lparas["Beta"]

    net = mndiff_rev()
    latest = get_current_save(directory, saveseries)
    if latest is not None:
        print("Loading {}".format(latest))
        net.load(latest)
    else:
        # No save yet: the name of the first save is looked up but not used.
        latest = get_next_save(directory, saveseries)

    images_train, images_test = mnist_load()

    # Evaluate the network on training image 100 at normalized time 0.5.
    x = torch.squeeze(images_train[100, :, :])
    y = net(x, 0.5)

    x = x.detach().numpy()
    y = y.detach().numpy()

    y_mean = np.mean(y)
    y_std = np.std(y)
    print("y mn={:1.3f} std={:1.3f} std2*beta={}".format(y_mean, y_std, y_std**2*beta))

    plt.figure()
    for panel, picture in enumerate((x, y), start=1):
        plt.subplot(2, 1, panel)
        plt.imshow(picture, cmap='gray')
    plt.show()

    return
 | 
			
		||||
 | 
			
		||||
##########
 | 
			
		||||
## Main ##
 | 
			
		||||
##########
 | 
			
		||||
 | 
			
		||||
def mainswitch():
    """Dispatch on the first command-line argument.

    Recognized operations are "train", "plot_history" and "infer"; each is
    run on the "./saves" directory with the "test_2a" save series.

    With no operation the usage text is printed and the process exits with
    status 0.  An unrecognized operation is reported together with the usage
    text and the process exits with status 1 (previously it was silently
    ignored).
    """
    usage = (
        "mnistdiffusion [operation]",
        "operations:",
        "\ttrain",
        "\tplot_history",
        "\tinfer",
    )

    args = sys.argv
    if len(args) < 2:
        for line in usage:
            print(line)
        # sys.exit rather than the site-provided exit(), which is not
        # guaranteed to exist (e.g. under "python -S").
        sys.exit(0)

    operation = args[1]
    # An if/elif chain keeps the lookup of train/plot_history/infer lazy:
    # only the selected function has to exist.
    if operation == "train":
        train("./saves", "test_2a")
    elif operation == "plot_history":
        plot_history("./saves", "test_2a")
    elif operation == "infer":
        infer("./saves", "test_2a")
    else:
        print("unknown operation: {}".format(operation))
        for line in usage:
            print(line)
        sys.exit(1)
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":

    # Seed numpy and torch from the wall clock, folded into [0, 50000).
    ttt = int(time.time())
    seed = ttt % 50000
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Manual checks, enable by hand:
    #test1()
    #test2("./saves","test_2a")
    mainswitch()
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										298
									
								
								attempt2/mnistdiffusion_model.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										298
									
								
								attempt2/mnistdiffusion_model.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,298 @@
 | 
			
		||||
#!/usr/bin/python3
 | 
			
		||||
 | 
			
		||||
import os,sys,math
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
import matplotlib.pyplot as plt
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
 | 
			
		||||
import mnist_dataloader
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
## Sohl-Dickstein 2015 describes a gaussian "bump" function to 
 | 
			
		||||
#   implement learnable time-dependence
 | 
			
		||||
#
 | 
			
		||||
#network input vector and normalized time (tn in (0,1))
 | 
			
		||||
# input: (dimx, dimy, ..., features), (scalar) --> output: (dimx, dimy, ...)
 | 
			
		||||
# input: (ntimes, ..., features), (ntimes) --> output: (ntimes , ...)
 | 
			
		||||
class gaussian_timeselect(nn.Module):
    """Reduce the trailing feature axis with time-dependent Gaussian weights.

    ``nfeatures`` bump centres ``tau_j`` are spaced evenly on [0, 1].  For a
    normalized time ``tn`` feature j gets the weight
    ``exp(-(tn - tau_j)**2 / (2*width))``; the weights are normalized to sum
    to one and the last axis of the input is summed with them.

    Supported call shapes:
        x: (..., nfeatures),          tn: scalar    -> (...)
        x: (ntimes, ..., nfeatures),  tn: (ntimes)  -> (ntimes, ...)

    NOTE(review): ``width`` enters as ``2*width`` (a variance), not
    ``2*width**2``; kept as is so existing saves behave the same -- confirm
    that this is intended.
    """

    def __init__(self, nfeatures=5, width=0.25):
        super(gaussian_timeselect, self).__init__()

        # Both values are stored as frozen Parameters (requires_grad=False).
        tau_j = torch.linspace(0, 1, nfeatures).float()
        self.tau_j = torch.nn.parameter.Parameter(tau_j, requires_grad=False)

        width = torch.tensor(width, dtype=torch.float32)
        self.width = torch.nn.parameter.Parameter(width, requires_grad=False)

    def forward(self, x, tn):
        """Return x with its last axis collapsed by the weights for time tn.

        x  -- tensor whose last axis has ``nfeatures`` entries
        tn -- scalar (number or 0-d tensor) or a tensor with one time per
              leading-axis entry of x

        Raises ValueError (an Exception subclass, as before) when a tensor
        ``tn`` does not match x's leading axis or x's last axis does not
        match the number of bump centres.
        """
        ndim = len(x.shape)

        # detach() instead of "tn.requires_grad = False": the assignment
        # changed the caller's tensor and raises for non-leaf tensors.
        # The time is also brought to the device/dtype of the bump centres.
        tn = torch.as_tensor(tn).detach()
        tn = tn.to(device=self.tau_j.device, dtype=self.tau_j.dtype)

        if len(tn.shape) == 0:
            # Case 0 - tn is a scalar: one weight vector for the whole input.
            g = torch.exp(-1.0/(2.0*self.width)*(tn - self.tau_j)**2)
            g = g/torch.clip(torch.sum(g), min=0.001)
            return torch.tensordot(x, g, dims=[[ndim-1], [0]])

        # Case 1 - tn is a tensor: one weight vector per leading-axis entry.
        nf = self.tau_j.shape[0]
        nt = tn.numel()
        if nt != x.shape[0]:
            raise ValueError("gaussian_timeselect time mismatch")
        if nf != x.shape[ndim-1]:
            raise ValueError("gaussian_timeselect feature mismatch")

        # (nt, nf) weight table, normalized along the feature axis.
        g = torch.exp(
            -1.0/(2.0*self.width)*(tn.reshape([nt, 1]) - self.tau_j.reshape([1, nf]))**2
        )
        g = g/torch.clip(torch.sum(g, dim=1, keepdim=True), min=0.0001)

        # Insert singleton axes so the table broadcasts over x's middle axes.
        g = g.reshape([nt] + [1]*(ndim - 2) + [nf])
        return torch.sum(x*g, dim=ndim-1)
 | 
			
		||||
 | 
			
		||||
#The reverse network should output a mean image and a covariance image
 | 
			
		||||
#given an input image
 | 
			
		||||
 | 
			
		||||
class mndiff_rev(nn.Module):
    """Reverse network: maps image(s) plus normalized time to an image ``eps``.

    Two unpadded convolutions are followed by two transposed convolutions
    with the same kernel sizes, each followed by tanh; the ten output
    channels are then blended into one image by ``gaussian_timeselect``
    according to the time argument.

    The object also carries a loss history (``losstime`` / ``lossrecord``)
    that is written to and read from disk together with the weights.
    """

    def __init__(self):
        super(mndiff_rev, self).__init__()

        # Kernel sizes 3 and 5 shrink the image by 2 and 4 pixels; the
        # transposed layers grow it back by 4 and 2, so the output has the
        # height and width of the input.
        self.L1 = torch.nn.Conv2d(1, 10, 3)
        self.L2 = torch.nn.Conv2d(10, 10, 5)
        self.L3 = torch.nn.ConvTranspose2d(10, 10, 5)
        self.L4 = torch.nn.ConvTranspose2d(10, 10, 3)

        # Time-dependent blend of the 10 channels.
        self.L5 = gaussian_timeselect(10, 0.1)

        self.tanh = nn.Tanh()

        # Device last requested through to(); starts on the CPU.
        self.device = torch.device("cpu")

        # Loss bookkeeping: step counter values and the losses recorded at them.
        self.losstime = []
        self.lossrecord = []

        return

    def forward(self, x, t):
        """Evaluate the network.

        Accepted shapes:
            case 0: x (pix_x, pix_y),                 t scalar
            case 1: x (ntimes, pix_x, pix_y),         t (ntimes)
            case 2: x (nbatch, ntimes, pix_x, pix_y), t (nbatch, ntimes)

        The result has the same shape as x.  The image size is taken from
        x instead of being fixed to 28x28.  Raises ValueError (an Exception
        subclass, as before) for any other number of dimensions.
        """
        ndim = len(x.shape)
        if ndim not in (2, 3, 4):
            raise ValueError("Error: Expect input to be dimension 2,3,4")

        out_shape = x.shape
        height, width = x.shape[-2], x.shape[-1]

        # Normalize t the same way in every case (the original only did so
        # for case 0): tensor, on the weights' device, float32.
        t = torch.as_tensor(t).to(self.L1.weight.device).float()
        if ndim == 4:
            # One time per (batch, time) image, flattened like x below.
            t = t.reshape([x.shape[0]*x.shape[1]])

        # Every image becomes one single-channel entry of a conv batch.
        x = x.reshape([-1, 1, height, width])

        s1 = self.tanh(self.L1(x))
        s1 = self.tanh(self.L2(s1))
        s1 = self.tanh(self.L3(s1))
        s1 = self.tanh(self.L4(s1))

        # gaussian_timeselect reduces the last axis, so channels go last.
        s1 = s1.permute([0, 2, 3, 1])
        eps = self.L5(s1, t)

        return eps.reshape(out_shape)

    def to(self, dev):
        """Move the module to ``dev`` and remember it in ``self.device``."""
        self.device = torch.device(dev)
        ret = super(mndiff_rev, self).to(self.device)
        return ret

    def save(self, fname):
        """Write [state_dict, losstime, lossrecord] to ``fname`` (best effort).

        The weights are moved to the CPU for saving and moved back afterwards,
        also when saving fails.  Failures are reported on stdout, not raised.
        """
        dev = self.L1.weight.device
        try:
            self.to("cpu")
            q = [self.state_dict(), self.losstime, self.lossrecord]
            torch.save(q, fname)
        except Exception as err:
            print("model save: problem saving {}: {}".format(fname, err))
        finally:
            # Restore the device even if torch.save raised.
            self.to(dev)

        return

    def load(self, fname):
        """Read weights and loss history from ``fname`` (best effort).

        Failures are reported on stdout, not raised; the model then keeps
        its current weights and history.
        """
        try:
            # NOTE: torch.load unpickles the file -- only load trusted saves.
            [modelsd, losstime, lossrecord] = torch.load(fname, map_location="cpu")
            self.load_state_dict(modelsd)
            self.losstime = losstime
            self.lossrecord = lossrecord
        except Exception as err:
            print("model load: problem loading {}: {}".format(fname, err))

        return

    def plothistory(self):
        """Show the recorded losses against their step counter."""
        x = self.losstime
        y = self.lossrecord
        plt.figure()
        plt.plot(x, y, 'k.', markersize=1)
        plt.title('Loss History')
        plt.show()

        return

    def recordloss(self, loss):
        """Append ``loss`` with the next step number (last step + 1, from 1)."""
        t = self.losstime[-1] if len(self.losstime) > 0 else 0
        self.losstime.append(t+1)
        self.lossrecord.append(loss)

        return
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
###########
 | 
			
		||||
## Tests ##
 | 
			
		||||
###########
 | 
			
		||||
 | 
			
		||||
def test_gaussian_timeselect():
    """Smoke test of gaussian_timeselect on "cuda"; prints values and shapes."""
    selector = gaussian_timeselect(5, 0.025).to("cuda")

    x = torch.randn((8, 8, 5)).to("cuda")
    times = torch.linspace(0, 1, 8).to("cuda")

    y_scalar = selector(x, 0)         # scalar time
    y_series = selector(x, times)     # one time per leading-axis entry
    y_other = selector(x, 0.111)      # evaluated only; result is not printed

    for item in (y_scalar, x[:, :, 0], x.shape, y_scalar.shape, y_series.shape):
        print(item)

    # Element-wise flag: does row 0 of the two results differ by more than 1E-6?
    print(torch.abs(y_series[0, :] - y_scalar[0, :]) > 1E-6)

    return
 | 
			
		||||
 | 
			
		||||
def test_mndiff_rev():
    """Smoke test of mndiff_rev: print the output shape for each input case."""
    net = mndiff_rev()

    # case 0: one image, scalar time
    single = torch.randn(28, 28)
    # case 1: a time series of images
    series = torch.randn(5, 28, 28)
    series_t = torch.linspace(0, 1, 5)
    # case 2: a batch of time series, same times for every batch entry
    batch = torch.randn(8, 5, 28, 28)
    batch_t = torch.linspace(0, 1, 5).reshape([1, 5]).expand([8, 5])

    outputs = [net(single, 0), net(series, series_t), net(batch, batch_t)]
    for out in outputs:
        print(out.shape)

    return
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":

    # Running the module directly executes the network smoke test.
    # Manual check, enable by hand:
    #test_gaussian_timeselect()
    test_mndiff_rev()
 | 
			
		||||
							
								
								
									
										78
									
								
								attempt2/mnistdiffusion_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										78
									
								
								attempt2/mnistdiffusion_utils.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,78 @@
 | 
			
		||||
#!/usr/bin/python3
 | 
			
		||||
 | 
			
		||||
import os,sys,math
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_savenumber(fname):
    """Return the number formed by the digits that end a file name's stem.

    fname -- file name or path; directory and extension are ignored

    Returns 0 when the stem does not end in a digit.

    The original backward scan never looked at the first character and
    started its slice at the end of the string when no non-digit was found,
    so stems such as "a05" or "12" produced 0 instead of 5 or 12.
    """
    stem = os.path.splitext(os.path.split(fname)[1])[0]

    # Whatever rstrip removes is exactly the trailing run of digits.
    digits = stem[len(stem.rstrip("0123456789")):]

    return int(digits) if digits else 0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def sorttwopythonlists(list1, list2):
    """Order list1 by the matching entries of list2.

    Returns [list1 reordered, list2 sorted].  Pairs are sorted as
    (list2 entry, list1 entry) tuples, so equal list2 entries are ordered
    by their list1 entry.
    """
    ordered_pairs = sorted(zip(list2, list1))
    reordered = [item for _key, item in ordered_pairs]

    return [reordered, sorted(list2)]
 | 
			
		||||
 | 
			
		||||
def get_current_save(directory, saveseries):
    """Return the path of the highest-numbered save of a series, or None.

    directory  -- folder to search
    saveseries -- text that a file name must contain to count as a save
    """
    matches = [
        os.path.join(directory, entry)
        for entry in os.listdir(directory)
        if saveseries in entry
    ]
    numbers = [get_savenumber(path) for path in matches]

    # Order the paths by their save number; the newest ends up last.
    ranked, numbers = sorttwopythonlists(matches, numbers)

    if not ranked:
        return None
    return ranked[-1]
 | 
			
		||||
 | 
			
		||||
def get_next_save(directory, saveseries):
    """Return the path for the next save of a series.

    The number is 0 when the series has no save yet, otherwise one more than
    the number of the current save; it is written with at least two digits.
    """
    latest = get_current_save(directory, saveseries)
    index = 0 if latest is None else get_savenumber(latest) + 1

    return os.path.join(directory, "{}{:02d}.pyt".format(saveseries, index))
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":

    # Nothing runs when this module is executed directly.
    # Manual check, enable by hand:
    #print(get_savenumber("./saves/helloworld05202.py"))
    pass
 | 
			
		||||
							
								
								
									
										161
									
								
								attempt2/old/mnistdiffusion.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										161
									
								
								attempt2/old/mnistdiffusion.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,161 @@
 | 
			
		||||
#!/usr/bin/python3
 | 
			
		||||
 | 
			
		||||
import os,sys,math
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
import matplotlib.pyplot as plt
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
 | 
			
		||||
import mnist_dataloader
 | 
			
		||||
 | 
			
		||||
from mnistdiffusion_model import mndiff_rev
 | 
			
		||||
 | 
			
		||||
################
 | 
			
		||||
## Dataloader ##
 | 
			
		||||
################
 | 
			
		||||
 | 
			
		||||
def mnist_load():
    """Load the MNIST images and rescale them from 0..255 to -1..1.

    Returns [images_train, images_test] as float tensors with
    requires_grad switched off; the labels are discarded.
    """
    def _rescale(raw):
        # Work on a detached float copy so the loader's tensor is untouched.
        scaled = 2.0*(raw.clone().detach().float()/255.0) - 1.0
        scaled.requires_grad = False
        return scaled

    _labels_train, _labels_test, raw_train, raw_test = mnist_dataloader.mnist_load()

    return [_rescale(raw_train), _rescale(raw_test)]
 | 
			
		||||
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
################
 | 
			
		||||
## Operations ##
 | 
			
		||||
################
 | 
			
		||||
 | 
			
		||||
def training_parameters(nts=200, s=0.1):
    """Build the cosine variance schedule used for training.

    Training parameters for a cosine variance schedule
    (ref: Nichol and Dhariwal 2021).

    nts -- number of time steps (default 200, the previous fixed value)
    s   -- offset of the cosine argument (default 0.1, the previous fixed value)

    Returns a dict with
        "nts" -- the number of time steps
        "t"   -- float32 tensor 0 .. nts-1
        "a_t" -- float32 tensor, cos(((t/nts + s)/(1 + s))*pi/2)**2
                 normalized so that a_t[0] == 1
        "B_t" -- float32 tensor, 1 - a_t[k+1]/a_t[k] clipped to [0, 0.999];
                 the last entry repeats the one before it

    Raises ValueError when nts is smaller than 1.
    """
    if nts < 1:
        raise ValueError("training_parameters: nts must be at least 1")

    lparas = dict()

    t = np.arange(0, nts)
    f_t = np.cos((t/nts+s)/(1+s)*np.pi/2.0)**2
    a_t = f_t/f_t[0]

    B_t = np.zeros(t.shape[0])
    B_t[0:nts-1] = np.clip(1-a_t[1:nts]/a_t[0:nts-1], 0, 0.999)
    # There is no step after the last one, so its value is copied.
    B_t[nts-1] = B_t[nts-2]

    lparas["nts"] = nts
    lparas["t"] = torch.tensor(t, dtype=torch.float32)
    lparas["a_t"] = torch.tensor(a_t, dtype=torch.float32)
    lparas["B_t"] = torch.tensor(B_t, dtype=torch.float32)

    return lparas
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_minibatch(imgs, Ndraw, lparas):
    """Draw random images and build a noised sequence for each of them.

    imgs   -- tensor of images, indexed (image, row, column)
    Ndraw  -- number of images to draw (with replacement)
    lparas -- dict with "nts" (number of time steps) and "B_t" (per-step
              noise variance, one entry per time step)

    Returns a float32 tensor of shape (Ndraw*nts, rows, columns) on the
    device of imgs.  Entries [k*nts, (k+1)*nts) hold the sequence of draw k:
    the image itself first, then each step adds Gaussian noise with variance
    B_t[step-1] to the previous entry.

    Fix: the per-step variance used to be expanded with torch.kron of the
    1-D schedule and a 2-D ones matrix followed by a reshape.  kron pads the
    1-D input to (1, nts), so the result is laid out as (rows, nts*columns)
    and the reshape to (nts, rows, columns) assigned variances to the wrong
    steps and rows.  The schedule is now broadcast as (nts, 1, 1).  The
    image size is also taken from imgs instead of being fixed to 28x28.
    """
    nts = lparas["nts"]
    dev = imgs.device
    height, width = imgs.shape[1], imgs.shape[2]

    # Loop-invariant: per-step noise standard deviation, broadcastable over
    # the pixels of one image.
    sig = torch.sqrt(lparas["B_t"].to(dev)).reshape([nts, 1, 1])

    imbatch = torch.zeros([Ndraw*nts, height, width], dtype=torch.float32).to(dev)
    draws = torch.randint(0, imgs.shape[0], (Ndraw,))

    for I in range(0, Ndraw):
        # View into the output, so the sequence is written in place.
        imgseq = imbatch[I*nts:(I+1)*nts]
        imgseq[0, :, :] = imgs[draws[I], :, :]

        noise = torch.randn((nts, height, width)).to(dev)*sig

        for J in range(1, nts):
            imgseq[J, :, :] = imgseq[J-1, :, :] + noise[J-1, :, :]

    return imbatch
 | 
			
		||||
 | 
			
		||||
def train_batch(imgbatch,lr,lparas):
    """Placeholder: training on one batch is not implemented; returns None."""
    return None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def train():
    """Placeholder: training is not implemented; returns None."""
    return None
 | 
			
		||||
 | 
			
		||||
def infer():
    """Placeholder: inference is not implemented; returns None."""
    return None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test1():
    """Visual check of generate_minibatch: show four entries of the batch.

    Loads MNIST, moves the training images to "cuda", builds a 50-draw
    minibatch and displays entries 0, 1, 2 and 30 in a 2x2 grid.
    """
    images_train, images_test = mnist_load()
    lparas = training_parameters()

    images_train = images_train.to("cuda")
    mb = generate_minibatch(images_train, 50, lparas)

    # Copy the selected entries back to the CPU for plotting.
    frames = [
        torch.squeeze(mb[index, :, :]).clone().detach().to("cpu")
        for index in (0, 1, 2, 30)
    ]

    plt.figure()
    for panel, frame in enumerate(frames, start=1):
        plt.subplot(2, 2, panel)
        plt.imshow(frame, cmap='gray')
    plt.show()

    return
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
##########
 | 
			
		||||
## Main ##
 | 
			
		||||
##########
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":

    # This file only imports mndiff_rev from mnistdiffusion_model, so the
    # call below used to fail with NameError.
    # NOTE(review): assumes mnistdiffusion_model provides
    # test_gaussian_timeselect -- confirm for the module version used here.
    from mnistdiffusion_model import test_gaussian_timeselect

    #test1()
    test_gaussian_timeselect()
 | 
			
		||||
							
								
								
									
										
											BIN
										
									
								
								attempt2/saves/test_2a00.pyt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								attempt2/saves/test_2a00.pyt
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								attempt2/saves/test_2a01.pyt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								attempt2/saves/test_2a01.pyt
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								attempt2/saves/test_2a02.pyt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								attempt2/saves/test_2a02.pyt
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										85
									
								
								attempt2/scratchpad.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										85
									
								
								attempt2/scratchpad.ipynb
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,85 @@
 | 
			
		||||
{
 | 
			
		||||
 "metadata": {
 | 
			
		||||
  "language_info": {
 | 
			
		||||
   "codemirror_mode": {
 | 
			
		||||
    "name": "ipython",
 | 
			
		||||
    "version": 3
 | 
			
		||||
   },
 | 
			
		||||
   "file_extension": ".py",
 | 
			
		||||
   "mimetype": "text/x-python",
 | 
			
		||||
   "name": "python",
 | 
			
		||||
   "nbconvert_exporter": "python",
 | 
			
		||||
   "pygments_lexer": "ipython3",
 | 
			
		||||
   "version": "3.6.9-final"
 | 
			
		||||
  },
 | 
			
		||||
  "orig_nbformat": 2,
 | 
			
		||||
  "kernelspec": {
 | 
			
		||||
   "name": "python3",
 | 
			
		||||
   "display_name": "Python 3.6.9 64-bit",
 | 
			
		||||
   "metadata": {
 | 
			
		||||
    "interpreter": {
 | 
			
		||||
     "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
 | 
			
		||||
    }
 | 
			
		||||
   }
 | 
			
		||||
  }
 | 
			
		||||
 },
 | 
			
		||||
 "nbformat": 4,
 | 
			
		||||
 "nbformat_minor": 2,
 | 
			
		||||
 "cells": [
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 1,
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "import os,sys,math\n",
 | 
			
		||||
    "import numpy as np\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "import matplotlib.pyplot as plt\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "import torch\n",
 | 
			
		||||
    "import torch.nn as nn\n",
 | 
			
		||||
    "import torch.nn.functional as F\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "import mnist_dataloader\n"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 14,
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [
 | 
			
		||||
    {
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "torch.Size([])\ntorch.Size([])\n[1, 2, 3]\n"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "a = torch.randn(10,14)\n",
 | 
			
		||||
    "b = a.shape[1:1]\n",
 | 
			
		||||
    "print(b)\n",
 | 
			
		||||
    "b.numel()\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "print(b)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "b = torch.Size([1])\n",
 | 
			
		||||
    "c = torch.Size([2,3])\n",
 | 
			
		||||
    "d = torch.Size(torch.cat([torch.tensor(b),torch.tensor(c)]))\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "d = [*b,*c]\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "print(d)\n"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": null,
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
   "source": []
 | 
			
		||||
  }
 | 
			
		||||
 ]
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										
											BIN
										
									
								
								attempt3/__pycache__/mnist_dataloader.cpython-36.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								attempt3/__pycache__/mnist_dataloader.cpython-36.pyc
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								attempt3/__pycache__/mnistdiffusion_model.cpython-36.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								attempt3/__pycache__/mnistdiffusion_model.cpython-36.pyc
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								attempt3/__pycache__/mnistdiffusion_utils.cpython-36.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								attempt3/__pycache__/mnistdiffusion_utils.cpython-36.pyc
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										91
									
								
								attempt3/mnist_dataloader.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										91
									
								
								attempt3/mnist_dataloader.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,91 @@
 | 
			
		||||
#!/usr/bin/python3
 | 
			
		||||
import os,sys,math
 | 
			
		||||
import numpy as np
 | 
			
		||||
import cv2
 | 
			
		||||
import gzip #need to use gzip.open instead of open
 | 
			
		||||
import struct
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
def read_MNIST_label_file(fname):
    """Read a gzip-compressed MNIST label (IDX1) file.

    The file starts with a 4-byte magic number followed by a 4-byte
    big-endian item count; the labels follow, one unsigned byte each.

    Args:
        fname: Path to the gzip-compressed label file.

    Returns:
        A 1-D ``numpy.uint8`` array holding the labels. If the file contains
        fewer label bytes than the header announces, only the bytes that are
        present are returned.

    Raises:
        struct.error: If the header is shorter than 8 bytes.
    """
    # The context manager closes the handle even when header parsing raises.
    with gzip.open(fname, 'rb') as fp:
        fp.read(4)  # magic number; read to advance the stream, not validated
        # Header integers are stored big-endian, hence the '>' prefix.
        nitems = struct.unpack('>I', fp.read(4))[0]
        payload = fp.read(nitems)

    return np.frombuffer(payload, dtype=np.uint8, count=len(payload))
 | 
			
		||||
 | 
			
		||||
def read_MNIST_image_file(fname):
    """Read a gzip-compressed MNIST image (IDX3) file.

    The header consists of a 4-byte magic number followed by three 4-byte
    big-endian integers: image count, rows per image and columns per image.
    The pixel data follows as unsigned bytes, image after image.

    Args:
        fname: Path to the gzip-compressed image file.

    Returns:
        A writable ``numpy.uint8`` array of shape ``(nitems, nrows, ncols)``.

    Raises:
        struct.error: If the header is truncated.
        ValueError: If the file holds fewer pixel bytes than the header
            announces.
    """
    # The context manager closes the handle even when parsing raises.
    with gzip.open(fname, 'rb') as fp:
        fp.read(4)  # magic number; read to advance the stream, not validated
        nitems, nrows, ncols = struct.unpack('>III', fp.read(12))
        # Plain Python ints: the size product cannot overflow a fixed-width type.
        npix = nitems * nrows * ncols
        payload = fp.read(npix)

    # One bulk conversion replaces the per-image read/reshape loop.
    flat = np.frombuffer(payload, dtype=np.uint8, count=npix)
    # frombuffer yields a read-only view; copy so callers get a writable array.
    return flat.reshape((nitems, nrows, ncols)).copy()
 | 
			
		||||
 | 
			
		||||
#The mnist dataset is small enough to fit entirely in memory
 | 
			
		||||
def mnist_load(baseloc="../training_data"):
    """Load the four MNIST files into memory as float32 tensors.

    Args:
        baseloc: Directory containing the gzip-compressed MNIST files.
            Defaults to ``"../training_data"``, resolved against the current
            working directory.

    Returns:
        ``[labels_train, labels_test, images_train, images_test]``, each a
        ``torch.float32`` tensor that does not require gradients. Image
        values are the raw byte values converted to float, not rescaled.
    """
    def _path(name):
        return os.path.join(baseloc, name)

    def _as_tensor(array):
        return torch.tensor(array, dtype=torch.float32, requires_grad=False)

    labels_train = read_MNIST_label_file(_path("train-labels-idx1-ubyte.gz"))
    labels_test = read_MNIST_label_file(_path("t10k-labels-idx1-ubyte.gz"))
    images_train = read_MNIST_image_file(_path("train-images-idx3-ubyte.gz"))
    images_test = read_MNIST_image_file(_path("t10k-images-idx3-ubyte.gz"))

    return [
        _as_tensor(labels_train),
        _as_tensor(labels_test),
        _as_tensor(images_train),
        _as_tensor(images_test),
    ]
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
    # Smoke test: load the full dataset and report each tensor's shape.
    loaded = mnist_load()
    labels_train, labels_test, images_train, images_test = loaded
    print("Loaded MNIST Data")
    for tensor in loaded:
        print(tensor.shape)
 | 
			
		||||
							
								
								
									
										398
									
								
								attempt3/mnistdiffusion.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										398
									
								
								attempt3/mnistdiffusion.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,398 @@
 | 
			
		||||
#!/usr/bin/python3
 | 
			
		||||
 | 
			
		||||
import os,sys,math
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
import matplotlib.pyplot as plt
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
 | 
			
		||||
import mnist_dataloader
 | 
			
		||||
 | 
			
		||||
from mnistdiffusion_model import mndiff_rev
 | 
			
		||||
from mnistdiffusion_utils import *
 | 
			
		||||
 | 
			
		||||
import time
 | 
			
		||||
 | 
			
		||||
################
 | 
			
		||||
## Dataloader ##
 | 
			
		||||
################
 | 
			
		||||
 | 
			
		||||
def mnist_load():
    """Load the MNIST images and rescale them with ``2*(x/255) - 1``.

    The labels returned by ``mnist_dataloader.mnist_load`` are discarded.

    Returns:
        ``[images_train, images_test]`` as float tensors with
        ``requires_grad`` set to False.
    """
    _, _, raw_train, raw_test = mnist_dataloader.mnist_load()

    rescaled = []
    for raw in (raw_train, raw_test):
        imgs = raw.clone().detach().float()
        # Rescale data to (-1,1)
        imgs = 2.0*(imgs/255.0)-1.0
        imgs.requires_grad = False
        rescaled.append(imgs)

    return rescaled
 | 
			
		||||
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
################
 | 
			
		||||
## Operations ##
 | 
			
		||||
################
 | 
			
		||||
 | 
			
		||||
# Rather than killing myself figuring out variable variance schedules,
 | 
			
		||||
# just use a constant variance schedule to start with
 | 
			
		||||
 | 
			
		||||
def training_parameters():
    """Build the dictionary of diffusion-schedule and training settings.

    A constant variance schedule is used: every step has the same ``Beta``.

    Returns:
        A dict with the keys ``nts`` (number of time steps), ``Beta``,
        ``t`` (normalized time, 0..1), ``a_t`` (running product of
        ``1 - B_t``, starting at 1), ``B_t`` (per-step variance), ``lr``,
        ``batchsize``, ``save_every`` and ``record_every``. ``t``, ``a_t``
        and ``B_t`` are float32 tensors of length ``nts``.
    """
    nts = 200
    Beta = 0.0050

    # Constant schedule: the same variance at every step.
    B_t = np.full((nts), Beta)

    # a_t[0] = 1 and a_t[i] = a_t[i-1] * (1 - B_t[i-1]), as a running product.
    a_t = np.cumprod(np.concatenate(([1.0], 1 - B_t[:-1])))

    # normalized time (t/T)
    t = np.linspace(0, 1, nts)

    return {
        "nts": nts,
        "Beta": Beta,
        "t": torch.tensor(t, dtype=torch.float32),
        "a_t": torch.tensor(a_t, dtype=torch.float32),
        "B_t": torch.tensor(B_t, dtype=torch.float32),
        "lr": 1E-4,
        "batchsize": 30,
        "save_every": 200,
        "record_every": 25,
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_minibatch(imgs, Ndraw, lparas):
    """Draw random images and run the forward (noising) process on them.

    Each drawn image becomes frame 0 of a sequence; every later frame is
    ``previous * sqrt(1 - B_t[j-1]) + sqrt(B_t[j-1]) * gaussian_noise``.

    Args:
        imgs: Tensor of source images, indexed as ``(image, row, column)``.
        Ndraw: Number of images to draw (with replacement).
        lparas: Parameter dict; the keys ``"nts"``, ``"t"`` and ``"B_t"``
            are read.

    Returns:
        ``[imbatch, times]`` where ``imbatch`` has shape
        ``(Ndraw, nts, rows, cols)`` and ``times`` is ``lparas["t"]``
        expanded to ``(Ndraw, nts)``. Both live on the device of ``imgs``.
    """
    nts = lparas["nts"]
    dev = imgs.device
    # Take the image size from the data rather than assuming 28x28.
    nrows, ncols = imgs.shape[1], imgs.shape[2]

    draws = torch.randint(0, imgs.shape[0], (Ndraw,), device=dev)

    times = lparas["t"].to(dev).reshape([1, nts]).expand([Ndraw, nts])

    # Per-step variance, shaped to broadcast over batch and pixels. The
    # earlier torch.kron + reshape expansion only gave a correct per-step
    # layout when every entry of B_t was equal.
    beta = lparas["B_t"].to(dev).reshape([1, nts, 1, 1])
    decay = torch.sqrt(1.0 - beta)
    noise = torch.randn((Ndraw, nts, nrows, ncols),
                        dtype=torch.float32, device=dev) * torch.sqrt(beta)

    imbatch = torch.zeros([Ndraw, nts, nrows, ncols],
                          dtype=torch.float32, device=dev)
    imbatch[:, 0] = imgs[draws]

    # Each frame depends on the previous one, so only the time axis needs a
    # Python loop; the whole batch advances together.
    for J in range(1, nts):
        imbatch[:, J] = imbatch[:, J-1]*decay[:, J-1] + noise[:, J-1]

    return [imbatch, times]
 | 
			
		||||
 | 
			
		||||
def train_batch(revnet,batchimg,batchtime,lparas):
    """Run one optimization step of the reverse network on a minibatch.

    The network output for frames 1..nts-1, scaled by ``Beta``, is fitted
    to the frame-to-frame difference of the noised sequence with an MSE loss.

    Args:
        revnet: Model to train. It must be callable as
            ``revnet(batchimg, batchtime)`` and expose a ``device`` attribute.
        batchimg: Noised image sequences; the second axis is time.
        batchtime: Time values; ``batchtime.shape[1]`` gives the step count.
        lparas: Parameter dict; the keys ``"lr"`` and ``"Beta"`` are read.

    Returns:
        The loss of this step divided by ``Beta**2``, as a NumPy scalar.

    Note:
        The Adam optimizer is created once per model and cached on it as
        ``_train_batch_optim``, so its moment estimates persist between
        calls. A new one is built if ``lparas["lr"]`` changes.
    """
    lr = lparas["lr"]
    nts = batchtime.shape[1]
    dev = revnet.device
    beta = float(lparas["Beta"])

    # Reuse the optimizer across calls: rebuilding Adam every batch throws
    # away its running moment estimates.
    cached = getattr(revnet, "_train_batch_optim", None)
    if cached is None or cached[0] != lr:
        cached = (lr, torch.optim.Adam(revnet.parameters(), lr=lr))
        revnet._train_batch_optim = cached
    optim = cached[1]

    loss = torch.nn.MSELoss().to(dev)

    # Clear the parameter gradients before backward(). Zeroing the loss
    # module (as was done before) clears nothing, so gradients from every
    # earlier batch kept accumulating.
    optim.zero_grad()

    d_est = revnet(batchimg,batchtime)[:,1:nts,:,:]
    targ = batchimg[:,1:nts,:,:]-batchimg[:,0:nts-1,:,:]

    L2 = loss(beta*d_est,targ)
    L2.backward()
    optim.step()

    # Undo the Beta scaling so the reported loss is in units of the network output.
    L = L2.detach().cpu().numpy()
    L = L/beta**2

    return L
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def train(directory,saveseries):
    """Train the reverse network indefinitely, checkpointing periodically.

    Loads the newest checkpoint of ``saveseries`` from ``directory`` when
    one exists, then loops forever: draw a minibatch, take one training
    step, record the mean loss every ``record_every`` steps and save a new
    checkpoint every ``save_every`` steps. The loop has no exit condition;
    stop it by interrupting the process.

    Args:
        directory: Directory that holds the checkpoint files.
        saveseries: Name of the checkpoint series.
    """
    print("Loading dataset...")
    [images_train, images_test] = mnist_load()
    lparas = training_parameters()

    batchsize = lparas["batchsize"]
    save_every = lparas["save_every"]
    record_every = lparas["record_every"]
    # Fall back to the CPU instead of failing on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    revnet = mndiff_rev()
    fname = get_current_save(directory,saveseries)
    if(fname is None):
        # NOTE(review): the result is never used; the call is kept in case
        # get_next_save has side effects - confirm and remove if it has none.
        fname = get_next_save(directory,saveseries)
    else:
        print("Loading {}".format(fname))
        revnet.load(fname)

    images_train = images_train.to(device)
    revnet = revnet.to(device)

    ## Training Loop
    # The buffer must outlive a single iteration: re-creating it every step
    # left all but the last entry at zero, so the recorded "mean" was the
    # last loss divided by record_every.
    losses = np.zeros((record_every))
    I = 0
    J = 0
    while(True):

        if(I%record_every==0):
            print("Minibatch {}".format(I))

        #draw minibatch
        [mb,mbt] = generate_minibatch(images_train,batchsize,lparas)

        #train minibatch
        losses[J] = train_batch(revnet,mb,mbt,lparas)
        J = J + 1

        #record results once the buffer is full, then start refilling it
        if(J==record_every):
            J = 0
            loss = np.mean(losses)
            revnet.recordloss(loss)
            print("\tloss={:1.5f}".format(loss))

        #save (this also fires on the very first iteration, I == 0)
        if(I%save_every==0):
            fnext = get_next_save(directory,saveseries)
            revnet.save(fnext)

        I = I + 1
 | 
			
		||||
 | 
			
		||||
def infer(directory,saveseries):
    """Generate one image by running the reverse process from pure noise.

    Loads the newest checkpoint of ``saveseries`` (a warning is printed if
    there is none), starts from a Gaussian noise image and applies
    ``nts*15`` reverse steps. Four frames of the trajectory are then shown
    in a 2x2 matplotlib figure, from noisiest (top left) to final (bottom
    right).

    Args:
        directory: Directory that holds the checkpoint files.
        saveseries: Name of the checkpoint series.
    """
    # Re-seed from the wall clock so every call gives a different sample.
    ttt = int(time.time())
    np.random.seed(ttt%50000)
    torch.manual_seed(ttt%50000)

    lparas = training_parameters()
    revnet = mndiff_rev()
    fname = get_current_save(directory,saveseries)
    if(fname is None):
        print("Warning, no state loaded.")
    else:
        print("Loading {}".format(fname))
        revnet.load(fname)

    beta = lparas["Beta"]
    # Sampling uses 15x as many steps as the training schedule.
    nts = lparas["nts"]*15

    imgseq = torch.zeros([nts,28,28])
    imgseq[nts-1,:,:] = torch.randn(28,28)

    tl = torch.linspace(0,1,nts)
    bb = torch.as_tensor(beta)

    # Sampling needs no gradients. Without no_grad() every step is appended
    # to one autograd graph spanning the whole trajectory, which wastes
    # memory and time.
    with torch.no_grad():
        for I in range(nts-2,-1,-1):
            img = imgseq[I+1,:,:]
            t = tl[I]
            eps = revnet(img,t)

            img2 = img*torch.sqrt(1-bb) - bb*eps + 1.0*bb*torch.randn(28,28)

            imgseq[I,:,:] = img2

    # Frames to display, in subplot order: 90%, 75%, 25% and 0% of the way
    # back along the trajectory (index 0 is the final sample).
    frame_ids = [int(nts*0.90), int(nts*0.75), int(nts*0.25), 0]
    frames = [torch.squeeze(imgseq[k,:,:]).detach().numpy() for k in frame_ids]

    plt.figure()
    for slot, frame in enumerate(frames, start=1):
        plt.subplot(2,2,slot)
        plt.imshow(frame,cmap='gray')
    plt.show()

    return
 | 
			
		||||
 | 
			
		||||
def plot_history(directory, saveseries):
    """Plot the loss history stored in the newest checkpoint of a series.

    If no checkpoint exists, the history of a freshly constructed model is
    plotted instead.

    Args:
        directory: Directory that holds the checkpoint files.
        saveseries: Name of the checkpoint series.
    """
    revnet = mndiff_rev()
    fname = get_current_save(directory,saveseries)

    if fname is not None:
        print("Loading {}".format(fname))
        revnet.load(fname)
    else:
        # The returned name is not used afterwards.
        fname = get_next_save(directory,saveseries)

    revnet.plothistory()
    return
 | 
			
		||||
 | 
			
		||||
def test1():
    """Manual check: show frames 0, 1, 2 and 99 of one noised sequence.

    Moves the training images to CUDA, draws a minibatch of 50 sequences,
    prints the batch and time tensor shapes and displays four frames of the
    first sequence in a 2x2 figure.
    """
    images_train, _ = mnist_load()
    lparas = training_parameters()

    images_train = images_train.to("cuda")

    mb, mbtime = generate_minibatch(images_train,50,lparas)

    # Copy the frames of interest back to the CPU for plotting.
    frames = [
        torch.squeeze(mb[0,step,:,:]).clone().detach().to("cpu")
        for step in (0, 1, 2, 99)
    ]

    print(mb.shape)
    print(mbtime.shape)

    plt.figure()
    for slot, frame in enumerate(frames, start=1):
        plt.subplot(2,2,slot)
        plt.imshow(frame,cmap='gray')

    plt.show()

    return
 | 
			
		||||
 | 
			
		||||
def test2(directory, saveseries):
    """Manual check: run the network on one training image and show both.

    Loads the newest checkpoint of the series when one exists, evaluates
    the model on training image 100 at time 0.5, prints the mean and
    standard deviation of the output and plots input above output.

    Args:
        directory: Directory that holds the checkpoint files.
        saveseries: Name of the checkpoint series.
    """
    beta = training_parameters()["Beta"]

    revnet = mndiff_rev()
    fname = get_current_save(directory,saveseries)
    if fname is not None:
        print("Loading {}".format(fname))
        revnet.load(fname)
    else:
        # The returned name is not used afterwards.
        fname = get_next_save(directory,saveseries)

    images_train, _ = mnist_load()

    x = torch.squeeze(images_train[100,:,:])
    y = revnet(x,0.5)

    x = x.detach().numpy()
    y = y.detach().numpy()

    spread = np.std(y)
    print("y mn={:1.3f} std={:1.3f} std2*beta={}".format(np.mean(y),spread,spread**2*beta))

    plt.figure()
    for slot, picture in enumerate((x, y), start=1):
        plt.subplot(2,1,slot)
        plt.imshow(picture,cmap='gray')
    plt.show()

    return
 | 
			
		||||
 | 
			
		||||
##########
 | 
			
		||||
## Main ##
 | 
			
		||||
##########
 | 
			
		||||
 | 
			
		||||
def mainswitch():
    """Dispatch to the operation named by the first command-line argument.

    Supported operations are ``train``, ``plot_history`` and ``infer``;
    each runs on the ``test_3a`` series in ``./saves``. With no argument
    the usage text is printed and the process exits with status 0. An
    unknown operation prints an error plus the usage text and exits with
    status 1 instead of silently doing nothing.
    """
    operations = ("train", "plot_history", "infer")

    def _usage():
        print("mnistdiffusion [operation]")
        print("operations:")
        for name in operations:
            print("\t{}".format(name))

    args = sys.argv
    if(len(args)<2):
        _usage()
        # sys.exit rather than the interactive-only exit() helper, which
        # is absent when Python runs without the site module.
        sys.exit(0)

    operation = args[1]
    if(operation not in operations):
        print("unknown operation: {}".format(operation))
        _usage()
        sys.exit(1)

    if(operation=="train"):
        train("./saves","test_3a")
    elif(operation=="plot_history"):
        plot_history("./saves","test_3a")
    else:
        infer("./saves","test_3a")
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
    # Seed both random generators from the wall clock, folded into
    # [0, 50000), before handing control to the command dispatcher.
    ttt = int(time.time())
    seed = ttt % 50000
    np.random.seed(seed)
    torch.manual_seed(seed)

    # test1() and test2("./saves","test_2a") are manual checks, run by hand.
    mainswitch()
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										300
									
								
								attempt3/mnistdiffusion_model.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										300
									
								
								attempt3/mnistdiffusion_model.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,300 @@
 | 
			
		||||
#!/usr/bin/python3
 | 
			
		||||
 | 
			
		||||
import os,sys,math
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
import matplotlib.pyplot as plt
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
 | 
			
		||||
import mnist_dataloader
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
## Sohl-Dickstein 2015 describes a gaussian "bump" function to 
 | 
			
		||||
#   implement learnable time-dependence
 | 
			
		||||
#
 | 
			
		||||
#network input vector and normalized time (tn e (0,1)) 
 | 
			
		||||
# input: (dimx, dimy, ..., features), (scalar) --> output: (dimx, dimy, ...)
 | 
			
		||||
# input: (ntimes, ..., features), (ntimes) --> output: (ntimes , ...)
 | 
			
		||||
class gaussian_timeselect(nn.Module):
    """Collapse the last (feature) axis with Gaussian weights in time.

    Feature ``j`` has a fixed centre ``tau_j`` (evenly spaced on [0, 1]).
    For a normalized time ``tn`` the weights are
    ``exp(-(tn - tau_j)**2 / (2*width))``, normalized to sum to one, and
    the output is the weighted sum over the feature axis.

    Shapes:
        (dimx, dimy, ..., features), scalar tn  ->  (dimx, dimy, ...)
        (ntimes, ..., features), (ntimes) tn    ->  (ntimes, ...)
    """

    def __init__(self, nfeatures=5, width = 0.25):
        super().__init__()

        # Both are stored as non-trainable parameters so they follow the
        # module across devices and into its state dict.
        centres = torch.linspace(0,1,nfeatures).float()
        self.tau_j = nn.Parameter(centres, requires_grad=False)
        self.width = nn.Parameter(
            torch.tensor(width,dtype=torch.float32), requires_grad = False)

    def forward(self, x, tn):
        """Apply the time-dependent feature weighting to ``x``.

        Args:
            x: Input whose last axis has ``nfeatures`` entries.
            tn: Normalized time; a scalar, or a tensor with one entry per
                element of the first axis of ``x``.

        Returns:
            ``x`` with its last axis summed out under the Gaussian weights.

        Raises:
            Exception: If a tensor ``tn`` does not match the first axis of
                ``x``, or the last axis of ``x`` does not match the number
                of centres.
        """
        tn = torch.as_tensor(tn)
        tn.requires_grad = False

        last = len(x.shape)-1
        scale = -1.0/(2.0*self.width)

        # Scalar time: one weight vector shared by every position.
        if tn.dim() == 0:
            weights = torch.exp(scale*(tn-self.tau_j)**2)
            weights = weights/torch.clip(torch.sum(weights),min=0.001)
            return torch.tensordot(x,weights,dims=[[last],[0]])

        # Tensor time: a separate weight vector per entry of the first axis.
        nf = self.tau_j.shape[0]
        nt = tn.numel()

        if nt != x.shape[0]:
            raise Exception("gaussian_timeselect time mismatch")
        if nf != x.shape[last]:
            raise Exception("gaussian_timeselect feature mismatch")

        inner = x.shape[1:last]
        flat = x.reshape([nt,inner.numel(),nf])

        # (nt,1,1) - (1,1,nf) broadcasts to one offset per time and feature.
        offsets = tn.reshape([nt,1,1]) - self.tau_j.reshape([1,1,nf])
        weights = torch.exp(scale*offsets**2)
        norm = torch.clip(torch.sum(weights,dim=2,keepdim=True),min=0.0001)
        weights = weights/norm

        blended = torch.sum(flat*weights,dim=2)
        return blended.reshape(torch.Size([nt,*inner]))
 | 
			
		||||
 | 
			
		||||
#The reverse network should output a mean image and a covariance image
 | 
			
		||||
#given an input image
 | 
			
		||||
 | 
			
		||||
class mndiff_rev(nn.Module):
    """Convolutional reverse-process network for 28x28 images.

    A small encoder/decoder: two convolutions, a 2x max-pool, two more
    convolutions, two transposed convolutions, a 2x upsample, then a skip
    connection (channel concatenation with the pre-pool features) followed
    by two transposed convolutions back to one channel. ``tanh`` is used
    between layers; the final layer has no non-linearity.

    The module also keeps a loss history (``losstime``, ``lossrecord``)
    that is saved and restored together with the weights, and tracks its
    current device in ``self.device``.
    """

    def __init__(self):
        super(mndiff_rev,self).__init__()

        self.nchan = 10

        self.L1a = torch.nn.Conv2d(1,self.nchan,3,padding=1)
        self.L1b = torch.nn.Conv2d(self.nchan,self.nchan,3)
        self.L1c = torch.nn.MaxPool2d(kernel_size=2)

        self.L2a = torch.nn.Conv2d(self.nchan,self.nchan,3)
        self.L2b = torch.nn.Conv2d(self.nchan,self.nchan,3)

        self.L3a = torch.nn.ConvTranspose2d(self.nchan,self.nchan,3)
        self.L3b = torch.nn.ConvTranspose2d(self.nchan,self.nchan,3)

        # L4b takes 2*nchan channels: skip features concatenated with the
        # upsampled decoder features.
        self.L4b = torch.nn.ConvTranspose2d(2*self.nchan,self.nchan,3)
        self.L4c = torch.nn.ConvTranspose2d(self.nchan,1,3,padding=1)

        self.nl = nn.Tanh()

        self.device = torch.device("cpu")

        self.losstime = []
        self.lossrecord = []

        return

    def forward(self,x,t):
        """Evaluate the network on one image, a sequence, or a batch.

        Accepted input layouts:
            case 0: x (pix_x, pix_y), t scalar
            case 1: x (ntimes, pix_x, pix_y), t (ntimes)
            case 2: x (nbatch, ntimes, pix_x, pix_y), t (nbatch, ntimes)

        Args:
            x: Image tensor in one of the layouts above; images are 28x28.
            t: Time value(s). NOTE(review): ``t`` is reshaped or converted
                here but is not passed to any layer, so the output does not
                currently depend on it.

        Returns:
            A tensor with the same layout as ``x``.

        Raises:
            Exception: If ``x`` does not have 2, 3 or 4 dimensions.
        """
        # Fold every layout into (N, 1, 28, 28) for the convolutions.
        if(len(x.shape)==2):
            case = 0
            t = torch.as_tensor(t).to(self.device)
            t = t.float()
            x = x.reshape([1,1,28,28])
            ntimes = 1
            nbatch = 1
        elif(len(x.shape)==3):
            case = 1
            ntimes = x.shape[0]
            nbatch = 1
            x = x.reshape([ntimes,1,28,28])
        elif(len(x.shape)==4):
            case = 2
            nbatch = x.shape[0]
            ntimes = x.shape[1]
            x = x.reshape([nbatch*ntimes,1,28,28])
            t = t.reshape([nbatch*ntimes])
        else:
            raise Exception("Error: Expect input to be dimension 2,3,4")

        #block 1
        s1 = self.nl(self.L1a(x))
        s1 = self.nl(self.L1b(s1)) #kept for the skip connection
        s2 = self.L1c(s1)

        #block 2
        s2 = self.nl(self.L2a(s2))
        s2 = self.nl(self.L2b(s2))

        #block 3
        s2 = self.nl(self.L3a(s2))
        s2 = self.nl(self.L3b(s2))

        #block 4: upsample back to the resolution of s1, then concat skip
        s2 = torch.nn.functional.interpolate(s2,scale_factor=[2,2])
        s3 = torch.cat([s1,s2],axis=1)

        s3 = self.nl(self.L4b(s3))
        eps = self.L4c(s3)

        # Restore the caller's layout.
        if(case==0):
            eps = eps.reshape([28,28])
        elif(case==1):
            eps = eps.reshape([ntimes,28,28])
        elif(case==2):
            eps = eps.reshape([nbatch,ntimes,28,28])

        return eps

    def to(self,dev):
        """Move the module to ``dev`` and remember it in ``self.device``.

        Args:
            dev: Anything accepted by ``torch.device``.

        Returns:
            The module itself, as returned by ``nn.Module.to``.
        """
        self.device = torch.device(dev)
        ret = super(mndiff_rev,self).to(self.device)
        return ret

    def save(self,fname):
        """Write the weights and loss history to ``fname``.

        The file holds the list ``[state_dict, losstime, lossrecord]``,
        saved from the CPU. Saving is best-effort: a failure is reported on
        stdout, not raised. The model is always moved back to the device it
        was on, including after a failed save.

        Args:
            fname: Destination path.
        """
        dev = self.device
        try:
            self.to("cpu")
            q = [self.state_dict(), self.losstime, self.lossrecord]
            torch.save(q,fname)
        except Exception as err:
            # Report the cause; a bare message hides why the save failed.
            print("model save: problem saving {}: {}".format(fname,err))
        finally:
            # Without this, a failed save left the model stranded on the CPU.
            self.to(dev)

        return

    def load(self,fname):
        """Restore the weights and loss history written by ``save``.

        Loading is best-effort: a failure is reported on stdout, not
        raised, and the model keeps whatever state it had.

        Args:
            fname: Path of the checkpoint file.
        """
        try:
            [modelsd,losstime,lossrecord] = torch.load(fname)
            self.load_state_dict(modelsd)
            self.losstime = losstime
            self.lossrecord = lossrecord
        except Exception as err:
            print("model load: problem loading {}: {}".format(fname,err))

        return

    def plothistory(self):
        """Show the recorded loss values against their record index."""
        x = self.losstime
        y = self.lossrecord
        plt.figure()
        plt.plot(x,y,'k.',markersize=1)
        plt.title('Loss History')
        plt.show()

        return

    def recordloss(self,loss):
        """Append ``loss`` to the history under the next index.

        Indices start at 1 and continue from the last recorded index, so
        the history stays contiguous across save/load.

        Args:
            loss: Loss value to record.
        """
        t = self.losstime[-1] if self.losstime else 0
        self.losstime.append(t+1)
        self.lossrecord.append(loss)

        return
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
###########
 | 
			
		||||
## Tests ##
 | 
			
		||||
###########
 | 
			
		||||
 | 
			
		||||
def test_gaussian_timeselect():
    """Manual check of gaussian_timeselect on CUDA.

    Evaluates the module with a scalar time of 0, a per-row time tensor and
    a scalar time of 0.111, then prints the results, the shapes and an
    element-wise comparison of the t=0 row from the scalar and tensor paths.
    """
    target = "cuda"

    gts = gaussian_timeselect(5,0.025)
    gts = gts.to(target)

    x = torch.randn((8,8,5)).to(target)
    t = torch.linspace(0,1,8).to(target)

    y = gts(x,0)
    y2 = gts(x,t)
    y3 = gts(x,0.111)  # exercised only to confirm the call succeeds

    print(y)
    print(x[:,:,0])

    for tensor in (x, y, y2):
        print(tensor.shape)

    # Row 0 of the tensor path uses t=0, so it should match the scalar path.
    print(torch.abs(y2[0,:]-y[0,:])>1E-6)

    return
 | 
			
		||||
 | 
			
		||||
def test_mndiff_rev():
    """Smoke-test ``mndiff_rev`` on unbatched, batched and doubly-batched input.

    Calls a freshly constructed model with a single 28x28 image and scalar
    time, a stack of 5 images with 5 times, and an (8, 5) grid of images with
    a matching (8, 5) grid of times, then prints the three output shapes.
    """
    mnd = mndiff_rev()

    # Inputs are created in the same order as before so the random stream
    # is consumed identically.
    single_img = torch.randn(28,28)
    single_t = 0

    stack_img = torch.randn(5,28,28)
    stack_t = torch.linspace(0,1,5)

    grid_img = torch.randn(8,5,28,28)
    grid_t = torch.linspace(0,1,5).reshape([1,5]).expand([8,5])

    cases = [(single_img,single_t),(stack_img,stack_t),(grid_img,grid_t)]

    # Run every case before printing anything, as the original did.
    outputs = [mnd(img,tm) for img,tm in cases]
    for out in outputs:
        print(out.shape)

    return
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
    # Only the mndiff_rev smoke test runs by default; call
    # test_gaussian_timeselect() here instead to exercise that module.
    test_mndiff_rev()
 | 
			
		||||
							
								
								
									
										78
									
								
								attempt3/mnistdiffusion_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										78
									
								
								attempt3/mnistdiffusion_utils.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,78 @@
 | 
			
		||||
#!/usr/bin/python3
 | 
			
		||||
 | 
			
		||||
import os,sys,math
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_savenumber(fname):
    """Return the integer formed by the digits at the end of a file's stem.

    The directory part and the final extension of ``fname`` are dropped
    first, so ``"./saves/helloworld05202.py"`` gives ``5202``.

    Args:
        fname: file name or path.

    Returns:
        The trailing number as an ``int``, or ``0`` when the stem is empty,
        does not end in a digit, or the digits cannot be converted.
    """
    stem = os.path.splitext(os.path.split(fname)[1])[0]

    # Walk back from the end over the trailing run of decimal digits. The
    # scan must be allowed to reach index 0: the previous loop stopped at
    # index 1, so stems such as "a5" or "12" lost their number.
    start = len(stem)
    while start > 0 and stem[start-1].isdecimal():
        start -= 1

    digits = stem[start:]
    if not digits:
        return 0

    try:
        return int(digits)
    except ValueError:
        # Keep the historical "never raises" contract for odd digit strings.
        return 0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def sorttwopythonlists(list1, list2):
    """Reorder ``list1`` by the sort order of ``list2``.

    The pairs ``(list2[i], list1[i])`` are sorted as tuples, so ties in
    ``list2`` are broken by the ``list1`` element.

    Returns:
        ``[list1_reordered, list2_sorted]``.
    """
    keyed = list(zip(list2, list1))
    keyed.sort()
    reordered = [pair[1] for pair in keyed]

    keys = list(list2)
    keys.sort()

    return [reordered, keys]
 | 
			
		||||
 | 
			
		||||
def get_current_save(directory,saveseries):
    """Return the path of the highest-numbered save file of a series.

    Every entry of ``directory`` whose name contains ``saveseries`` is a
    candidate. Candidates are ranked by ``get_savenumber`` with the full
    path as tie-breaker, and the highest-ranked one is returned.

    Returns:
        The joined path of the latest save, or ``None`` when no entry matches.
    """
    candidates = [
        os.path.join(directory,entry)
        for entry in os.listdir(directory)
        if saveseries in entry
    ]

    if not candidates:
        return None

    # Same result as sorting (number, path) pairs and taking the last one.
    return max(candidates, key=lambda path: (get_savenumber(path), path))
 | 
			
		||||
 | 
			
		||||
def get_next_save(directory,saveseries):
    """Return the path the next save file of a series should be written to.

    The name is ``<saveseries><NN>.pyt`` inside ``directory``, where ``NN`` is
    0 when the series has no save yet and otherwise one more than the number
    of the current latest save (zero-padded to at least two digits).
    """
    latest = get_current_save(directory,saveseries)
    index = 0 if latest is None else get_savenumber(latest) + 1

    return os.path.join(directory,"{}{:02d}.pyt".format(saveseries,index))
 | 
			
		||||
 | 
			
		||||
if(__name__=="__main__"):
 | 
			
		||||
 | 
			
		||||
    #print(get_savenumber("./saves/helloworld05202.py"))
 | 
			
		||||
    pass
 | 
			
		||||
							
								
								
									
										161
									
								
								attempt3/old/mnistdiffusion.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										161
									
								
								attempt3/old/mnistdiffusion.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,161 @@
 | 
			
		||||
#!/usr/bin/python3
 | 
			
		||||
 | 
			
		||||
import os,sys,math
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
import matplotlib.pyplot as plt
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
 | 
			
		||||
import mnist_dataloader
 | 
			
		||||
 | 
			
		||||
from mnistdiffusion_model import mndiff_rev
 | 
			
		||||
 | 
			
		||||
################
 | 
			
		||||
## Dataloader ##
 | 
			
		||||
################
 | 
			
		||||
 | 
			
		||||
def mnist_load():
    """Load the MNIST images and rescale them for training.

    Takes the train and test image tensors returned by
    ``mnist_dataloader.mnist_load()`` (the labels are discarded), converts
    detached copies to float, maps pixel values from [0, 255] to [-1, 1] and
    turns gradient tracking off.

    Returns:
        ``[images_train, images_test]``.
    """
    # Per the original notes the training images are 60000 x 28 x 28.
    _, _, raw_train, raw_test = mnist_dataloader.mnist_load()

    def _rescale(raw):
        # Work on a detached float copy so the loader's tensor is untouched.
        scaled = raw.clone().detach().float()
        scaled = 2.0*(scaled/255.0)-1.0
        scaled.requires_grad = False
        return scaled

    return [_rescale(raw_train), _rescale(raw_test)]
 | 
			
		||||
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
################
 | 
			
		||||
## Operations ##
 | 
			
		||||
################
 | 
			
		||||
 | 
			
		||||
def training_parameters(nts=200, s=0.1):
    """Build the tables of a cosine variance schedule.

    Training parameters for a cosine variance schedule
    (ref: Nichol and Dhariwal 2021):

        f(t)   = cos(((t/nts + s) / (1 + s)) * pi/2) ** 2
        a_t    = f(t) / f(0)
        B_t[k] = clip(1 - a_t[k+1] / a_t[k], 0, 0.999)

    The last entry of ``B_t`` has no successor step and repeats the one
    before it.

    Args:
        nts: number of time steps (default 200, the previous fixed value).
        s: offset of the cosine schedule (default 0.1, the previous fixed
            value).

    Returns:
        dict with keys ``"nts"`` (the step count as passed in) and ``"t"``,
        ``"a_t"``, ``"B_t"`` (float32 tensors of length ``nts``).

    Raises:
        ValueError: if ``nts`` is smaller than 1.
    """
    if nts < 1:
        raise ValueError("nts must be at least 1, got {}".format(nts))

    lparas = dict()

    t = np.arange(0,nts)
    f_t = np.cos((t/nts+s)/(1+s)*np.pi/2.0)**2
    a_t = f_t/f_t[0]

    B_t = np.zeros(t.shape[0])
    B_t[0:nts-1] = np.clip(1-a_t[1:nts]/a_t[0:nts-1],0,0.999)
    # No a_t[nts] exists to form the final ratio, so reuse the previous value.
    B_t[nts-1] = B_t[nts-2]

    lparas["nts"] = nts
    lparas["t"] = torch.tensor(t,dtype=torch.float32)
    lparas["a_t"] = torch.tensor(a_t,dtype=torch.float32)
    lparas["B_t"] = torch.tensor(B_t,dtype=torch.float32)

    return lparas
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_minibatch(imgs, Ndraw, lparas):
    """Draw random images and build a cumulative-noise sequence for each.

    For each of ``Ndraw`` images picked uniformly at random (with
    replacement) from ``imgs``, frame 0 is the image itself and frame J is
    frame J-1 plus zero-mean Gaussian noise of variance
    ``lparas["B_t"][J-1]``. The last entry of ``B_t`` is therefore never
    applied.

    The per-step standard deviation is broadcast over the image axes. The
    earlier ``torch.kron`` + ``reshape`` construction did not lay the
    variances out as one value per time step, so steps received each
    other's noise levels.

    Args:
        imgs: tensor whose first axis indexes images (e.g. [N, 28, 28]).
        Ndraw: number of images to draw.
        lparas: dict with ``"nts"`` (number of frames per sequence) and
            ``"B_t"`` (1-D tensor of ``nts`` per-step variances).

    Returns:
        float32 tensor on ``imgs.device`` of shape
        ``[Ndraw*nts, *imgs.shape[1:]]``; the frames of draw I occupy rows
        ``I*nts`` to ``(I+1)*nts - 1``.
    """
    nts = lparas["nts"]
    dev = imgs.device
    imshape = tuple(imgs.shape[1:])

    draws = torch.randint(0,imgs.shape[0],(Ndraw,))
    start = imgs[draws.to(dev)].to(torch.float32)

    # One standard deviation per time step, shaped to broadcast over an image.
    sig = torch.sqrt(lparas["B_t"].to(device=dev,dtype=torch.float32))
    sig = sig.reshape(nts,*([1]*len(imshape)))

    noise = torch.randn((Ndraw,nts)+imshape,dtype=torch.float32,device=dev)*sig

    # Frame J is the start image plus noise[0] + ... + noise[J-1], i.e. an
    # exclusive running sum along the time axis.
    offsets = torch.zeros_like(noise)
    offsets[:,1:] = torch.cumsum(noise[:,:-1],dim=1)

    imbatch = (start.unsqueeze(1)+offsets).reshape((Ndraw*nts,)+imshape)

    return imbatch
 | 
			
		||||
 | 
			
		||||
def train_batch(imgbatch,lr,lparas):
    """Placeholder for a single training step; not implemented.

    All arguments are ignored and ``None`` is returned.
    """
    return None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def train():
    """Placeholder for the training loop; not implemented, returns ``None``."""
    return None
 | 
			
		||||
 | 
			
		||||
def infer():
    """Placeholder for inference; not implemented, returns ``None``."""
    return None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test1():
    """Visual check of ``generate_minibatch``.

    Loads MNIST, moves the training images to CUDA, generates a minibatch of
    50 noise sequences and shows frames 0, 1, 2 and 30 of the batch in a
    2x2 grid of grayscale plots.
    """
    images_train, images_test = mnist_load()
    lparas = training_parameters()

    # To inspect the schedule itself, plot lparas["a_t"] or lparas["B_t"]
    # against lparas["t"].

    images_train = images_train.to("cuda")

    mb = generate_minibatch(images_train,50,lparas)

    frames = [
        torch.squeeze(mb[k,:,:]).clone().detach().to("cpu")
        for k in (0,1,2,30)
    ]

    plt.figure()
    for slot,frame in enumerate(frames,start=1):
        plt.subplot(2,2,slot)
        plt.imshow(frame,cmap='gray')

    plt.show()

    return
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
##########
 | 
			
		||||
## Main ##
 | 
			
		||||
##########
 | 
			
		||||
 | 
			
		||||
if(__name__=="__main__"):

    # test_gaussian_timeselect() was called here, but that function is
    # neither defined nor imported in this module (only mndiff_rev is
    # imported), so running the script raised NameError. Run the test this
    # file actually defines instead.
    test1()
 | 
			
		||||
							
								
								
									
										85
									
								
								attempt3/scratchpad.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										85
									
								
								attempt3/scratchpad.ipynb
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,85 @@
 | 
			
		||||
{
 | 
			
		||||
 "metadata": {
 | 
			
		||||
  "language_info": {
 | 
			
		||||
   "codemirror_mode": {
 | 
			
		||||
    "name": "ipython",
 | 
			
		||||
    "version": 3
 | 
			
		||||
   },
 | 
			
		||||
   "file_extension": ".py",
 | 
			
		||||
   "mimetype": "text/x-python",
 | 
			
		||||
   "name": "python",
 | 
			
		||||
   "nbconvert_exporter": "python",
 | 
			
		||||
   "pygments_lexer": "ipython3",
 | 
			
		||||
   "version": "3.6.9-final"
 | 
			
		||||
  },
 | 
			
		||||
  "orig_nbformat": 2,
 | 
			
		||||
  "kernelspec": {
 | 
			
		||||
   "name": "python3",
 | 
			
		||||
   "display_name": "Python 3.6.9 64-bit",
 | 
			
		||||
   "metadata": {
 | 
			
		||||
    "interpreter": {
 | 
			
		||||
     "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
 | 
			
		||||
    }
 | 
			
		||||
   }
 | 
			
		||||
  }
 | 
			
		||||
 },
 | 
			
		||||
 "nbformat": 4,
 | 
			
		||||
 "nbformat_minor": 2,
 | 
			
		||||
 "cells": [
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 1,
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "import os,sys,math\n",
 | 
			
		||||
    "import numpy as np\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "import matplotlib.pyplot as plt\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "import torch\n",
 | 
			
		||||
    "import torch.nn as nn\n",
 | 
			
		||||
    "import torch.nn.functional as F\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "import mnist_dataloader\n"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 14,
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [
 | 
			
		||||
    {
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "torch.Size([])\ntorch.Size([])\n[1, 2, 3]\n"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "a = torch.randn(10,14)\n",
 | 
			
		||||
    "b = a.shape[1:1]\n",
 | 
			
		||||
    "print(b)\n",
 | 
			
		||||
    "b.numel()\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "print(b)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "b = torch.Size([1])\n",
 | 
			
		||||
    "c = torch.Size([2,3])\n",
 | 
			
		||||
    "d = torch.Size(torch.cat([torch.tensor(b),torch.tensor(c)]))\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "d = [*b,*c]\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "print(d)\n"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": null,
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
   "source": []
 | 
			
		||||
  }
 | 
			
		||||
 ]
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										
											BIN
										
									
								
								mnistdiffusion_attempt1_fig1.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								mnistdiffusion_attempt1_fig1.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 21 KiB  | 
							
								
								
									
										
											BIN
										
									
								
								training_data/MNISTDataset.docx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								training_data/MNISTDataset.docx
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										138
									
								
								training_data/ams_MNIST_load.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										138
									
								
								training_data/ams_MNIST_load.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,138 @@
 | 
			
		||||
#!/usr/bin/python3
 | 
			
		||||
 | 
			
		||||
import sys
 | 
			
		||||
sys.path.append('/home/aschinde/workspace/projects_python/library')
 | 
			
		||||
 | 
			
		||||
import os,sys,math
 | 
			
		||||
import numpy as np
 | 
			
		||||
import cv2;
 | 
			
		||||
 | 
			
		||||
import gzip #May need to use gzip.open instead of open
 | 
			
		||||
 | 
			
		||||
import struct
 | 
			
		||||
#struct unpack allows some interpretation of python binary data
 | 
			
		||||
#Example
 | 
			
		||||
##import struct
 | 
			
		||||
##
 | 
			
		||||
##data = open("from_fortran.bin", "rb").read()
 | 
			
		||||
##
 | 
			
		||||
##(eight, N) = struct.unpack("@II", data)
 | 
			
		||||
##
 | 
			
		||||
##This unpacks the first two fields, assuming they start at the very
 | 
			
		||||
##beginning of the file (no padding or extraneous data), and also assuming
 | 
			
		||||
##native byte-order (the @ symbol). The Is in the formatting string mean
 | 
			
		||||
##"unsigned integer, 32 bits".
 | 
			
		||||
 | 
			
		||||
#for integers
 | 
			
		||||
#a = int
 | 
			
		||||
#a.from_bytes(b'\xaf\xc2R',byteorder='little')
 | 
			
		||||
#a.to_bytes(nbytes,byteorder='big')
 | 
			
		||||
#analogous operation doesn't seem to exist for floats
 | 
			
		||||
#what about numpy?
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
#https://www.devdungeon.com/content/working-binary-data-python
 | 
			
		||||
 | 
			
		||||
#print("{:02d}".format(2))
 | 
			
		||||
#b = b.fromhex('010203040506')
 | 
			
		||||
#b.hex()
 | 
			
		||||
#c = b.decode(encoding='utf-8' or 'latin-1' or 'ascii'...)
 | 
			
		||||
#print(c)
 | 
			
		||||
 | 
			
		||||
#numpy arrays have tobytes
 | 
			
		||||
#numpy arrays have frombuffer (converts to dtypes)
 | 
			
		||||
#
 | 
			
		||||
#q = np.array([15],dtype=np.uint8);
 | 
			
		||||
#q.tobytes();
 | 
			
		||||
#q.tobytes(order='C') (options are 'C' and 'F')
 | 
			
		||||
#q2 = np.frombuffer(q.tobytes(),dtype=np.uint8)
 | 
			
		||||
#np.frombuffer(buffer,dtype=float,count=-1,offset=0)
 | 
			
		||||
 | 
			
		||||
##You could also use the < and > endianess format codes in the struct
 | 
			
		||||
##module to achieve the same result:
 | 
			
		||||
##
 | 
			
		||||
##>>> struct.pack('<2h', *struct.unpack('>2h', original))
 | 
			
		||||
##'\xde\xad\xc0\xde'
 | 
			
		||||
 | 
			
		||||
def bytereverse(bts):
    """Return the bytes of ``bts`` in reverse order.

    The function previously returned its input unchanged, although its name
    and the abandoned attempts inside it show that a reversal was intended.
    The readers in this module decode big-endian integers with
    ``struct.unpack('>I', ...)`` and do not call this helper.

    Args:
        bts: a ``bytes`` or ``bytearray`` object.

    Returns:
        A new object of the same type with the byte order reversed.
    """
    return bts[::-1]
 | 
			
		||||
 | 
			
		||||
#Read Labels
 | 
			
		||||
def read_MNIST_label_file(fname):
    """Read a gzip-compressed MNIST label file.

    The file starts with a 4-byte magic number (skipped, not validated) and
    a 4-byte big-endian item count, followed by one unsigned byte per label.

    Args:
        fname: path of the ``.gz`` label file.

    Returns:
        1-D ``numpy.uint8`` array of the labels actually present (shorter
        than the declared count if the file is truncated). The array is a
        read-only view of the bytes read.
    """
    # The context manager closes the file even if the header is malformed.
    with gzip.open(fname,'rb') as fp:
        fp.read(4)  # magic number
        # The header integers are big-endian, hence '>' rather than native.
        (nitems,) = struct.unpack('>I',fp.read(4))
        bts = fp.read(nitems)

    return np.frombuffer(bts,dtype=np.uint8,count=len(bts))
 | 
			
		||||
 | 
			
		||||
def read_MNIST_image_file(fname):
    """Read a gzip-compressed MNIST image file.

    The file starts with a 4-byte magic number (skipped, not validated) and
    three 4-byte big-endian integers: image count, rows and columns. The
    pixels follow as unsigned bytes, image after image, row by row.

    Args:
        fname: path of the ``.gz`` image file.

    Returns:
        Writable ``numpy.uint8`` array of shape ``(nitems, nrows, ncols)``.

    Raises:
        ValueError: if the file holds fewer pixel bytes than its header
            declares.
    """
    # The context manager closes the file even if parsing fails.
    with gzip.open(fname,'rb') as fp:
        fp.read(4)  # magic number
        # Plain Python ints keep the size arithmetic below free of overflow.
        nitems,nrows,ncols = struct.unpack('>III',fp.read(12))
        npix = nitems*nrows*ncols
        bts = fp.read(npix)

    if(len(bts)!=npix):
        raise ValueError(
            "MNIST image file {} is truncated: expected {} pixel bytes, got {}"
            .format(fname,npix,len(bts)))

    images = np.zeros((nitems,nrows,ncols),dtype=np.uint8)
    if(npix>0):
        # Copy into the owned array so the caller gets writable data.
        images[...] = np.frombuffer(bts,dtype=np.uint8).reshape(images.shape)

    return images
 | 
			
		||||
 | 
			
		||||
def read_training_data(rootdir='/home/aschinde/workspace/machinelearning/datasets/MNIST'):
    """Load the MNIST training labels and images.

    Args:
        rootdir: directory containing ``train-labels-idx1-ubyte.gz`` and
            ``train-images-idx3-ubyte.gz``. Defaults to the location that
            used to be hard-coded, so existing callers are unaffected.

    Returns:
        ``[labels, images]`` as produced by ``read_MNIST_label_file`` and
        ``read_MNIST_image_file``.
    """
    labels = read_MNIST_label_file(os.path.join(rootdir,'train-labels-idx1-ubyte.gz'))
    images = read_MNIST_image_file(os.path.join(rootdir,'train-images-idx3-ubyte.gz'))

    return [labels,images]
 | 
			
		||||
 | 
			
		||||
def read_test_data(rootdir='/home/aschinde/workspace/machinelearning/datasets/MNIST'):
    """Load the MNIST test labels and images.

    Args:
        rootdir: directory containing ``t10k-labels-idx1-ubyte.gz`` and
            ``t10k-images-idx3-ubyte.gz``. Defaults to the location that
            used to be hard-coded, so existing callers are unaffected.

    Returns:
        ``[labels, images]`` as produced by ``read_MNIST_label_file`` and
        ``read_MNIST_image_file``.
    """
    labels = read_MNIST_label_file(os.path.join(rootdir,'t10k-labels-idx1-ubyte.gz'))
    images = read_MNIST_image_file(os.path.join(rootdir,'t10k-images-idx3-ubyte.gz'))

    return [labels,images]
 | 
			
		||||
 | 
			
		||||
def show_MNIST_image(img):
    """Show one image in a new matplotlib figure with inverted gray levels.

    ``255-img`` is what gets displayed, using the 'gray' colormap.
    """
    # Imported here so the module can be used without matplotlib installed.
    import matplotlib.pyplot as plt

    plt.figure()
    inverted = 255-img
    plt.imshow(inverted,cmap='gray')
    plt.show()
    return
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										92
									
								
								training_data/an_mnist_loader.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										92
									
								
								training_data/an_mnist_loader.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,92 @@
 | 
			
		||||
#!/usr/bin/python3
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
mnist_loader
 | 
			
		||||
~~~~~~~~~~~~
 | 
			
		||||
 | 
			
		||||
A library to load the MNIST image data.  For details of the data
 | 
			
		||||
structures that are returned, see the doc strings for ``load_data``
 | 
			
		||||
and ``load_data_wrapper``.  In practice, ``load_data_wrapper`` is the
 | 
			
		||||
function usually called by our neural network code.
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
##sigh: If you want it to run today, write it in Python.
 | 
			
		||||
##If you want it to run tomorrow, write it in ANYTHING ELSE
 | 
			
		||||
 | 
			
		||||
#### Libraries
 | 
			
		||||
# Standard library
 | 
			
		||||
##import cPickle
 | 
			
		||||
import pickle as cPickle
 | 
			
		||||
import gzip
 | 
			
		||||
 | 
			
		||||
# Third-party libraries
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
def load_data(path='./t10k-images-idx3-ubyte.gz'):
    """Return the MNIST data as a tuple containing the training data,
    the validation data, and the test data.

    The ``training_data`` is returned as a tuple with two entries.
    The first entry contains the actual training images.  This is a
    numpy ndarray with 50,000 entries.  Each entry is, in turn, a
    numpy ndarray with 784 values, representing the 28 * 28 = 784
    pixels in a single MNIST image.

    The second entry in the ``training_data`` tuple is a numpy ndarray
    containing 50,000 entries.  Those entries are just the digit
    values (0...9) for the corresponding images contained in the first
    entry of the tuple.

    The ``validation_data`` and ``test_data`` are similar, except
    each contains only 10,000 images.

    This is a nice data format, but for use in neural networks it's
    helpful to modify the format of the ``training_data`` a little.
    That's done in the wrapper function ``load_data_wrapper()``, see
    below.

    ``path`` is the gzip-compressed pickle to read; it must unpickle to a
    3-tuple.  It defaults to the file name that used to be hard-coded.
    """
    # NOTE(review): the default names an IDX image file, not a pickle, so
    # unpickling it is expected to fail; the code this was adapted from read
    # '../data/mnist.pkl.gz'. Confirm which file is intended.
    # NOTE(review): pickle executes arbitrary code while loading -- only
    # point this at trusted files.
    with gzip.open(path, 'rb') as f:
        training_data, validation_data, test_data = cPickle.load(f)
    return (training_data, validation_data, test_data)
 | 
			
		||||
 | 
			
		||||
def load_data_wrapper():
    """Return a tuple containing ``(training_data, validation_data,
    test_data)``. Based on ``load_data``, but the format is more
    convenient for use in our implementation of neural networks.

    In particular, ``training_data`` is a list containing 50,000
    2-tuples ``(x, y)``.  ``x`` is a 784-dimensional numpy.ndarray
    containing the input image.  ``y`` is a 10-dimensional
    numpy.ndarray representing the unit vector corresponding to the
    correct digit for ``x``.

    ``validation_data`` and ``test_data`` are lists containing 10,000
    2-tuples ``(x, y)``.  In each case, ``x`` is a 784-dimensional
    numpy.ndarray containing the input image, and ``y`` is the
    corresponding classification, i.e., the digit values (integers)
    corresponding to ``x``.

    Obviously, this means we're using slightly different formats for
    the training data and the validation / test data.  These formats
    turn out to be the most convenient for use in our neural network
    code."""
    tr_d, va_d, te_d = load_data()

    # Under Python 3 zip() yields a one-shot iterator; materialize each one
    # so the results really are lists (len(), indexing and repeated
    # iteration work), as documented above.
    training_inputs = [np.reshape(x, (784, 1)) for x in tr_d[0]]
    training_results = [vectorized_result(y) for y in tr_d[1]]
    training_data = list(zip(training_inputs, training_results))

    validation_inputs = [np.reshape(x, (784, 1)) for x in va_d[0]]
    validation_data = list(zip(validation_inputs, va_d[1]))

    test_inputs = [np.reshape(x, (784, 1)) for x in te_d[0]]
    test_data = list(zip(test_inputs, te_d[1]))

    return (training_data, validation_data, test_data)
 | 
			
		||||
 | 
			
		||||
def vectorized_result(j):
    """Build the one-hot column vector for digit ``j``.

    Returns a (10, 1) float array that is 1.0 at row ``j`` and 0.0
    everywhere else; it is the desired network output for that digit.
    """
    one_hot = np.zeros(10)
    one_hot[j] = 1.0
    return one_hot.reshape(10, 1)
 | 
			
		||||
							
								
								
									
										
											BIN
										
									
								
								training_data/t10k-images-idx3-ubyte.gz
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								training_data/t10k-images-idx3-ubyte.gz
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								training_data/t10k-labels-idx1-ubyte.gz
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								training_data/t10k-labels-idx1-ubyte.gz
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								training_data/train-images-idx3-ubyte.gz
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								training_data/train-images-idx3-ubyte.gz
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								training_data/train-labels-idx1-ubyte.gz
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								training_data/train-labels-idx1-ubyte.gz
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
		Reference in New Issue
	
	Block a user