Making some of my models public.

master
Aaron 3 weeks ago
commit eb2c910e29

@ -0,0 +1,2 @@
Copyright 2023, Aaron M. Schinder
Released under the MIT/BSD License

@ -0,0 +1,91 @@
#!/usr/bin/python3
import os,sys,math
import numpy as np
import cv2
import gzip #need to use gzip.open instead of open
import struct
import torch
def read_MNIST_label_file(fname):
#fp = gzip.open('./train-labels-idx1-ubyte.gz','rb');
fp = gzip.open(fname,'rb');
magic = fp.read(4);
#nitems = np.frombuffer(fp.read(4),dtype=np.int32)[0]; #some sort of endianness problem
bts = fp.read(4);
#bts = bytereverse(bts);
#nitems = np.frombuffer(bts,dtype=np.int32);
nitems = np.int32(struct.unpack('>I',bts)[0]); #the count field is stored with non-native (big-endian) byte order
#> < @ - endianness
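# For instance (illustrative only):
#   struct.unpack('>I', b'\x00\x00\x00\x01')[0] == 1         #big-endian, as used by the IDX header
#   struct.unpack('<I', b'\x00\x00\x00\x01')[0] == 16777216  #little-endian reading of the same bytes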
bts = fp.read(nitems);
N = len(bts);
labels = np.zeros((N),dtype=np.uint8);
labels = np.frombuffer(bts,dtype=np.uint8,count=N);
#for i in range(0,10):
# bt = fp.read(1);
# labels[i] = np.frombuffer(bt,dtype=np.uint8);
fp.close();
return labels;
def read_MNIST_image_file(fname):
fp = gzip.open(fname,'rb');
magic = fp.read(4);
bts = fp.read(4);
nitems = np.int32(struct.unpack('>I',bts)[0]);
bts = fp.read(4);
nrows = np.int32(struct.unpack('>I',bts)[0]);
bts = fp.read(4);
ncols = np.int32(struct.unpack('>I',bts)[0]);
images = np.zeros((nitems,nrows,ncols),dtype=np.uint8);
for I in range(0,nitems):
bts = fp.read(nrows*ncols);
img1 = np.frombuffer(bts,dtype=np.uint8,count=nrows*ncols);
img1 = img1.reshape((nrows,ncols));
images[I,:,:] = img1;
fp.close();
return images;
#The mnist dataset is small enough to fit entirely in memory
def mnist_load():
baseloc = "../training_data"
traindatafile = "train-images-idx3-ubyte.gz"
trainlabelfile = "train-labels-idx1-ubyte.gz"
testdatafile = "t10k-images-idx3-ubyte.gz"
testlabelfile = "t10k-labels-idx1-ubyte.gz"
traindatafile = os.path.join(baseloc,traindatafile)
trainlabelfile = os.path.join(baseloc,trainlabelfile)
testdatafile = os.path.join(baseloc,testdatafile)
testlabelfile = os.path.join(baseloc,testlabelfile)
labels_train = read_MNIST_label_file(trainlabelfile)
labels_test = read_MNIST_label_file(testlabelfile)
images_train = read_MNIST_image_file(traindatafile)
images_test = read_MNIST_image_file(testdatafile)
labels_train = torch.tensor(labels_train,dtype=torch.float32,requires_grad=False)
labels_test = torch.tensor(labels_test,dtype=torch.float32,requires_grad=False)
images_train = torch.tensor(images_train,dtype=torch.float32,requires_grad=False)
images_test = torch.tensor(images_test,dtype=torch.float32,requires_grad=False)
# #debug
# print(labels_train.shape)
# print(labels_test.shape)
# print(images_train.shape)
# print(images_test.shape)
return [labels_train, labels_test, images_train, images_test]
if(__name__ == "__main__"):
[labels_train, labels_test, images_train, images_test] = mnist_load()
print("Loaded MNIST Data")
print(labels_train.shape)
print(labels_test.shape)
print(images_train.shape)
print(images_test.shape)

@ -0,0 +1,401 @@
#!/usr/bin/python3
import os,sys,math
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import mnist_dataloader
from mnistdiffusion_model import mndiff_rev
from mnistdiffusion_utils import *
import time
################
## Dataloader ##
################
def mnist_load():
[labels_train, labels_test, images_train, images_test] = mnist_dataloader.mnist_load()
#60000 x 28 x 28
#60000 x 1
#Rescale data to (-1,1)
images_train = images_train.clone().detach().float()
images_test = images_test.clone().detach().float()
images_train = 2.0*(images_train/255.0)-1.0
images_test = 2.0*(images_test/255.0)-1.0
images_train.requires_grad = False
images_test.requires_grad = False
return [images_train, images_test]
################
## Operations ##
################
# Rather than killing myself figuring out variable variance schedules,
# just use a constant variance schedule to start with
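# In formulas, the forward (noising) step applied in generate_minibatch() below is
#   x_t = sqrt(1 - Beta)*x_{t-1} + sqrt(Beta)*eps,   eps ~ N(0, I),
# and the a_t array built in training_parameters() is the usual DDPM cumulative
# product abar_t = prod_{s<=t}(1 - B_s), so that (up to the indexing convention
# used here) x_t ~ N(sqrt(abar_t)*x_0, (1 - abar_t)*I).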
def training_parameters():
lparas = dict()
nts = 200
Beta = 0.0050
lr = 1E-4
batchsize = 80
save_every = 200
record_every = 25
B_t = np.ones((nts))*Beta
a_t = np.zeros((nts))
a_t[0] = 1
for I in range(1,nts):
a_t[I] = a_t[I-1]*(1-B_t[I-1])
t = np.linspace(0,1,nts) #normalized time (t/T)
lparas["nts"] = nts
lparas["Beta"] = Beta
lparas["t"] = torch.tensor(t,dtype=torch.float32)
lparas["a_t"] = torch.tensor(a_t,dtype=torch.float32)
lparas["B_t"] = torch.tensor(B_t,dtype=torch.float32)
lparas["lr"] = lr
lparas["batchsize"] = batchsize
lparas["save_every"] = save_every
lparas["record_every"] = record_every
return lparas
def generate_minibatch(imgs, Ndraw, lparas):
#I can probably speed this up with compiled CUDA code
nts = lparas["nts"]
Beta = lparas["Beta"]
dev = imgs.device
imbatch = torch.zeros([Ndraw,nts,28,28],dtype=torch.float32).to(dev)
draws = torch.randint(0,imgs.shape[0],(Ndraw,))
times = lparas["t"].to(dev)
times = times.reshape([1,nts]).expand([Ndraw,nts])
for I in range(0,Ndraw):
imgseq = torch.zeros([nts,28,28],dtype=torch.float32).to(dev)
imgseq[0,:,:] = imgs[draws[I],:,:]
beta = torch.kron(lparas["B_t"],torch.ones(28,28)).to(dev)
beta = beta.reshape([nts,28,28])
sig = torch.sqrt(beta)
noise = torch.randn((nts,28,28)).to(dev)*sig
for J in range(1,nts):
imgseq[J,:,:] = imgseq[J-1,:,:]*torch.sqrt(1.0-beta[J-1,:,:]) + noise[J-1,:,:]
imbatch[I,:,:,:] = imgseq[:,:,:]
return [imbatch,times]
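# A minimal vectorized sketch (illustrative; the helper name generate_minibatch_vectorized
# is hypothetical and not used elsewhere in this file): the same sequential forward
# process as generate_minibatch() above, but all Ndraw trajectories are stepped at once,
# removing the outer Python loop over draws.
def generate_minibatch_vectorized(imgs, Ndraw, lparas):
    nts = lparas["nts"]
    dev = imgs.device
    draws = torch.randint(0, imgs.shape[0], (Ndraw,), device=dev)
    beta = lparas["B_t"].to(dev).reshape([1, nts, 1, 1])          #per-step variances
    noise = torch.randn([Ndraw, nts, 28, 28], device=dev)*torch.sqrt(beta)
    imbatch = torch.zeros([Ndraw, nts, 28, 28], dtype=torch.float32, device=dev)
    imbatch[:, 0, :, :] = imgs[draws, :, :]                       #clean images at t=0
    for J in range(1, nts):
        #same update as above, applied to the whole batch at once
        imbatch[:, J, :, :] = imbatch[:, J-1, :, :]*torch.sqrt(1.0 - beta[0, J-1]) + noise[:, J-1, :, :]
    times = lparas["t"].to(dev).reshape([1, nts]).expand([Ndraw, nts])
    return [imbatch, times]
# Usage would mirror the original: [mb, mbt] = generate_minibatch_vectorized(images_train, batchsize, lparas)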
def train_batch(revnet,batchimg,batchtime,lparas):
L = 0.0
lr = lparas["lr"]
nts = batchtime.shape[1]
dev = revnet.device
#beta = torch.as_tensor(lparas["Beta"]).float().to(dev)
beta = float(lparas["Beta"])
#optim = torch.optim.SGD(revnet.parameters(),lr=lr) #unused; Adam below is the optimizer actually used
optim = torch.optim.Adam(revnet.parameters(),lr=lr)
loss = torch.nn.MSELoss().to(dev)
d_est = revnet(batchimg,batchtime)
d_est = d_est[:,1:nts,:,:]
targ = batchimg[:,1:nts,:,:]-batchimg[:,0:nts-1,:,:]
L2 = loss(targ,beta*d_est)
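# The line above matches Beta*d_est (the scaled network output) against the actual
# per-step increment of the noised trajectory; L is divided by Beta**2 further down
# so the recorded loss is reported in units of the unscaled estimate.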
L2.backward()
optim.step()
L = L2.detach().cpu().numpy()
L = L/beta**2
optim.zero_grad() #clear accumulated model gradients before the next batch (MSELoss itself has no parameters to zero)
return L
def train(directory,saveseries):
print("Loading dataset...")
[images_train, images_test] = mnist_load()
lparas = training_parameters()
batchsize = lparas["batchsize"]
save_every = lparas["save_every"]
record_every = lparas["record_every"]
device = torch.device("cuda")
revnet = mndiff_rev()
fname = get_current_save(directory,saveseries)
if(fname is None):
fname = get_next_save(directory,saveseries)
else:
print("Loading {}".format(fname))
revnet.load(fname)
images_train = images_train.to(device)
revnet = revnet.to(device)
## Training Loop
I = 0
J = 0
while(True):
if(I%record_every==0):
print("Minibatch {}".format(I))
#New minibatch
losses = np.zeros((record_every))
#draw minibatch
[mb,mbt] = generate_minibatch(images_train,batchsize,lparas)
#train minibatch
L = train_batch(revnet,mb,mbt,lparas)
losses[J] = L
J = J + 1
#record results
if(J%record_every==0):
J = 0
loss = np.mean(losses)
revnet.recordloss(loss)
print("\tloss={:1.5f}".format(loss))
#save
if(I%save_every==0):
fnext = get_next_save(directory,saveseries)
revnet.save(fnext)
I = I + 1
return
def infer(directory,saveseries):
ttt = int(time.time())
np.random.seed(ttt%50000)
torch.manual_seed(ttt%50000)
lparas = training_parameters()
revnet = mndiff_rev()
fname = get_current_save(directory,saveseries)
if(fname is None):
print("Warning, no state loaded.")
else:
print("Loading {}".format(fname))
revnet.load(fname)
nts = lparas["nts"]
beta = lparas["Beta"]
nts = nts*15
imgseq = torch.zeros([nts,28,28])
imgseq[nts-1,:,:] = torch.randn(28,28)
#imgseq[nts-1,:,17:20] = 3.0
tl = torch.linspace(0,1,nts)
bb = torch.as_tensor(beta)
for I in range(nts-2,-1,-1):
img = imgseq[I+1,:,:]
t = tl[I]
eps = revnet(img,t)
img2 = img*torch.sqrt(1-bb) - bb*eps + 0.5*bb*torch.randn(28,28)
imgseq[I,:,:] = img2
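# The loop above applies the constant-Beta reverse update
#   x_I = sqrt(1 - Beta)*x_{I+1} - Beta*eps_hat + 0.5*Beta*z,   z ~ N(0, I),
# working backwards from pure noise at I = nts-1 to the generated sample at I = 0.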
im0 = torch.squeeze(imgseq[0,:,:])
im1 = torch.squeeze(imgseq[int(nts*0.25),:,:])
im2 = torch.squeeze(imgseq[int(nts*0.75),:,:])
im3 = torch.squeeze(imgseq[int(nts*0.90),:,:])
im0 = im0.detach().numpy()
im1 = im1.detach().numpy()
im2 = im2.detach().numpy()
im3 = im3.detach().numpy()
plt.figure()
plt.subplot(2,2,1)
plt.imshow(im3,cmap='gray')
plt.subplot(2,2,2)
plt.imshow(im2,cmap='gray')
plt.subplot(2,2,3)
plt.imshow(im1,cmap='gray')
plt.subplot(2,2,4)
plt.imshow(im0,cmap='gray')
plt.show()
return
def plot_history(directory, saveseries):
revnet = mndiff_rev()
fname = get_current_save(directory,saveseries)
if(fname is None):
fname = get_next_save(directory,saveseries)
else:
print("Loading {}".format(fname))
revnet.load(fname)
revnet.plothistory()
return
def test1():
[images_train, images_test] = mnist_load()
lparas = training_parameters()
#plt.figure()
#plt.plot(lparas["t"],lparas["a_t"])
#plt.show()
#plt.figure()
#plt.plot(lparas["t"],lparas["B_t"])
#plt.show()
images_train = images_train.to("cuda")
[mb,mbtime] = generate_minibatch(images_train,50,lparas)
img0 = torch.squeeze(mb[0,0,:,:]).clone().detach().to("cpu")
img1 = torch.squeeze(mb[0,1,:,:]).clone().detach().to("cpu")
img2 = torch.squeeze(mb[0,2,:,:]).clone().detach().to("cpu")
img3 = torch.squeeze(mb[0,99,:,:]).clone().detach().to("cpu")
print(mb.shape)
print(mbtime.shape)
plt.figure()
plt.subplot(2,2,1)
plt.imshow(img0,cmap='gray')
plt.subplot(2,2,2)
plt.imshow(img1,cmap='gray')
plt.subplot(2,2,3)
plt.imshow(img2,cmap='gray')
plt.subplot(2,2,4)
plt.imshow(img3,cmap='gray')
plt.show()
return
def test2(directory, saveseries):
lparas = training_parameters()
nts = lparas["nts"]
beta = lparas["Beta"]
revnet = mndiff_rev()
fname = get_current_save(directory,saveseries)
if(fname is None):
fname = get_next_save(directory,saveseries)
else:
print("Loading {}".format(fname))
revnet.load(fname)
[images_train, images_test] = mnist_load()
x = torch.squeeze(images_train[100,:,:])
#x = beta*torch.randn(28,28)
y = revnet(x,0.5)
x = x.detach().numpy()
y = y.detach().numpy()
print("y mn={:1.3f} std={:1.3f} std2*beta={}".format(np.mean(y),np.std(y),np.std(y)**2*beta))
plt.figure()
plt.subplot(2,1,1)
plt.imshow(x,cmap='gray')
plt.subplot(2,1,2)
plt.imshow(y,cmap='gray')
plt.show()
return
##########
## Main ##
##########
def mainswitch():
args = sys.argv
if(len(args)<2):
print("mnistdiffusion [operation]")
print("operations:")
print("\ttrain")
print("\tplot_history")
print("\tinfer")
exit(0)
if(sys.argv[1]=="train"):
train("./saves","test_f")
if(sys.argv[1]=="plot_history"):
plot_history("./saves","test_f")
if(sys.argv[1]=="infer"):
infer("./saves","test_f")
if(__name__=="__main__"):
ttt = int(time.time())
np.random.seed(ttt%50000)
torch.manual_seed(ttt%50000)
#test1()
#test2("./saves","test_d")
mainswitch()

@ -0,0 +1,289 @@
#!/usr/bin/python3
import os,sys,math
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import mnist_dataloader
## Sohl-Dickstein 2015 describes a gaussian "bump" function to
# implement learnable time-dependence
#
#network input vector and normalized time (tn e (0,1))
# input: (dimx, dimy, ..., features), (scalar) --> output: (dimx, dimy, ...)
# input: (ntimes, ..., features), (ntimes) --> output: (ntimes , ...)
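# Concretely, with centers tau_j = linspace(0,1,nfeatures) and width w, the forward
# pass below computes normalized Gaussian weights
#   g_j(t) = exp(-(t - tau_j)^2/(2w)) / sum_k exp(-(t - tau_k)^2/(2w))
# and returns the weighted sum over the feature axis, y = sum_j x[...,j]*g_j(t),
# i.e. a soft selection of one of the nfeatures channels as a function of time.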
class gaussian_timeselect(nn.Module):
def __init__(self, nfeatures=5, width = 0.25):
super(gaussian_timeselect,self).__init__()
tau_j = torch.linspace(0,1,nfeatures).float()
self.tau_j = torch.nn.parameter.Parameter(tau_j, requires_grad=False)
width = torch.tensor(width,dtype=torch.float32)
self.width = torch.nn.parameter.Parameter(width,requires_grad = False)
#network input vector and normalized time (tn e (0,1))
# input: (dimx, dimy, ..., features), (scalar) --> output: (dimx, dimy, ...)
# input: (ntimes, ..., features), (ntimes) --> output: (ntimes , ...)
def forward(self, x, tn):
ndim = len(x.shape)
tn = torch.as_tensor(tn)
tn.requires_grad = False
if(len(tn.shape)==0):
#case 0 - if tn is a scalar
g = torch.exp(-1.0/(2.0*self.width)*(tn-self.tau_j)**2)
g = g/torch.clip(torch.sum(g),min=0.001)
x = torch.tensordot(x,g,dims=[[ndim-1],[0]])
else:
#Case 1 - tn is a tensor
nf = self.tau_j.shape[0]
nt = tn.numel()
nt2 = x.shape[0]
ndim = len(x.shape)
nf2 = x.shape[ndim-1]
if(nt!=nt2):
raise Exception("gaussian_timeselect time mismatch")
if(nf!=nf2):
raise Exception("gaussian_timeselect feature mismatch")
shp2 = x.shape[1:ndim-1]
nshp2 = shp2.numel()
x = x.reshape([nt,nshp2,nf])
tau = self.tau_j.reshape([1,1,nf])
tau = tau.expand([nt,nshp2,nf])
tn = tn.reshape([nt,1,1])
tn = tn.expand([nt,nshp2,nf])
g = torch.exp(-1.0/(2.0*self.width)*(tn-tau)**2)
gs = torch.clip(torch.sum(g,axis=2),min=0.0001)
gs = gs.reshape([nt,nshp2,1]).expand([nt,nshp2,nf])
g = g/gs
x = torch.sum(x*g,axis=2)
#shp3 = torch.Size(torch.cat(torch.tensor([nt]),torch.tensor(shp2)))
# concat [nt, *shp2]
shp3 = torch.Size([nt,*shp2])
x = x.reshape(shp3)
return x
#The reverse network should output a mean image and a covariance image
#given an input image
class mndiff_rev(nn.Module):
def __init__(self):
super(mndiff_rev,self).__init__()
self.L1 = nn.Linear(28*28,1000,bias=True)
self.L2 = nn.Linear(1000,1000,bias=True)
self.L3 = nn.Linear(1000,28*28*5,bias=True)
#self.L4 = nn.Conv2d(10,10,kernel_size=1)
#channels must be 2nd dimension
self.L4a = nn.Linear(5,5)
#equivalent, channels must be last dimension
self.L4b = nn.Linear(5,5)
self.L5 = gaussian_timeselect(5,0.1)
#self.L5alt = nn.Linear(5,1)
#self.RL = nn.LeakyReLU(0.01)
self.tanh = nn.Tanh()
self.device = torch.device("cpu")
self.losstime = []
self.lossrecord = []
return
def forward(self,x,t):
#handle data-shape cases
#case 0: x (pix_x, pix_y), t is scalar
#case 1: x (ntimes, pix_x, pix_y), t is (ntimes)
#case 2: x (nbatch, ntimes, pix_x, pix_y), t is (nbatch, ntimes)
if(len(x.shape)==2):
case = 0
t = torch.as_tensor(t).to(self.device)
t = t.float()
x = x.reshape([1,28*28])
ntimes = 1
nbatch = 1
elif(len(x.shape)==3):
case = 1
ntimes = x.shape[0]
nbatch = 1
x = x.reshape([ntimes,28*28])
elif(len(x.shape)==4):
case = 2
nbatch = x.shape[0]
ntimes = x.shape[1]
#x = x.reshape([nbatch,ntimes,28*28])
x = x.reshape([nbatch*ntimes,28*28])
t = t.reshape([nbatch*ntimes])
y = torch.randn(28*28,device=self.device)
else:
print("Error: expected input to have 2, 3, or 4 dimensions")
raise Exception("Error: expected input to have 2, 3, or 4 dimensions")
s1 = self.L1(x)
s1 = self.tanh(s1)
s1 = self.L2(s1)
s1 = self.tanh(s1)
s1 = self.L3(s1)
s1 = self.tanh(s1)
s1 = s1.reshape([ntimes*nbatch,28*28,5])
s1 = self.L4a(s1)
s1 = self.tanh(s1)
s1 = self.L4b(s1)
#s1 = self.tanh(s1)
s2 = self.L5(s1,t)
#s2 = self.L5alt(s1)
eps = s2
#eps = self.tanh(s2)
if(case==0):
eps = eps.reshape([28,28])
elif(case==1):
eps = eps.reshape([ntimes,28,28])
elif(case==2):
eps = eps.reshape([nbatch,ntimes,28,28])
return eps
def to(self,dev):
self.device = torch.device(dev)
ret = super(mndiff_rev,self).to(self.device)
return ret
def save(self,fname):
try:
dev = self.device
self.to("cpu")
q = [self.state_dict(), self.losstime, self.lossrecord]
torch.save(q,fname)
self.to(dev)
except:
print("model save: problem saving {}".format(fname))
return
def load(self,fname):
try:
[modelsd,losstime,lossrecord] = torch.load(fname)
self.load_state_dict(modelsd)
self.losstime = losstime
self.lossrecord = lossrecord
except:
print("model load: problem loading {}".format(fname))
return
def plothistory(self):
x = self.losstime
y = self.lossrecord
plt.figure()
plt.plot(x,y,'k.',markersize=1)
plt.title('Loss History')
plt.show()
return
def recordloss(self,loss):
t = 0
L = len(self.losstime)
if(L>0):
t = self.losstime[L-1]
self.losstime.append(t+1)
self.lossrecord.append(loss)
return
###########
## Tests ##
###########
def test_gaussian_timeselect():
gts = gaussian_timeselect(5,0.025)
gts = gts.to("cuda")
x = torch.randn((8,8,5)).to("cuda")
#t = torch.tensor([0,0.5,1]).to("cuda")
t = torch.linspace(0,1,8).to("cuda")
y = gts(x,0)
y2 = gts(x,t)
y3 = gts(x,0.111)
print(y)
print(x[:,:,0])
print(x.shape)
print(y.shape)
print(y2.shape)
print(torch.abs(y2[0,:]-y[0,:])>1E-6)
return
def test_mndiff_rev():
mnd = mndiff_rev()
x0 = torch.randn(28,28)
t0 = 0
x1 = torch.randn(5,28,28)
t1 = torch.linspace(0,1,5)
x2 = torch.randn(8,5,28,28)
t2 = torch.linspace(0,1,5)
t2 = t2.reshape([1,5]).expand([8,5])
y0 = mnd(x0,t0)
y1 = mnd(x1,t1)
y2 = mnd(x2,t2)
print(y0.shape)
print(y1.shape)
print(y2.shape)
return
if(__name__=="__main__"):
test_gaussian_timeselect()
#test_mndiff_rev()
pass

@ -0,0 +1,78 @@
#!/usr/bin/python3
import os,sys,math
def get_savenumber(fname):
num = 0
fn2 = os.path.split(fname)[1]
fn2 = os.path.splitext(fn2)[0]
L = len(fn2)
L2 = L
for I in range(L-1,0,-1):
c = fn2[I]
if(not c.isnumeric()):
L2 = I+1
break
nm = fn2[L2:L]
try:
num = int(nm)
except:
num = 0
return num
def sorttwopythonlists(list1, list2):
zipped_pairs = zip(list2, list1)
z = [x for _, x in sorted(zipped_pairs)]
z2 = sorted(list2)
return [z,z2]
def get_current_save(directory,saveseries):
fname = None
fl = os.listdir(directory)
fl2 = []
nums = []
for f in fl:
if(f.find(saveseries)>=0):
fl2.append(os.path.join(directory,f))
for f in fl2:
n = get_savenumber(f)
nums.append(n)
[fl2,nums] = sorttwopythonlists(fl2,nums)
if(len(fl2)>0):
fname = fl2[len(fl2)-1]
else:
fname = None
return fname
def get_next_save(directory,saveseries):
fname = None
fncurrent = get_current_save(directory,saveseries)
if(fncurrent is None):
fname = os.path.join(directory,"{}{:02d}.pyt".format(saveseries,0))
else:
N = get_savenumber(fncurrent)
N = N + 1
fname = os.path.join(directory,"{}{:02d}.pyt".format(saveseries,N))
return fname
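# Illustrative usage (hypothetical file names): with "test_f00.pyt" and
# "test_f01.pyt" already present in ./saves,
#   get_current_save("./saves","test_f")  -> "./saves/test_f01.pyt"
#   get_next_save("./saves","test_f")     -> "./saves/test_f02.pyt"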
if(__name__=="__main__"):
#print(get_savenumber("./saves/helloworld05202.py"))
pass

@ -0,0 +1,161 @@
#!/usr/bin/python3
import os,sys,math
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import mnist_dataloader
from mnistdiffusion_model import mndiff_rev
################
## Dataloader ##
################
def mnist_load():
[labels_train, labels_test, images_train, images_test] = mnist_dataloader.mnist_load()
#60000 x 28 x 28
#60000 x 1
#Rescale data to (-1,1)
images_train = images_train.clone().detach().float()
images_test = images_test.clone().detach().float()
images_train = 2.0*(images_train/255.0)-1.0
images_test = 2.0*(images_test/255.0)-1.0
images_train.requires_grad = False
images_test.requires_grad = False
return [images_train, images_test]
################
## Operations ##
################
def training_parameters():
lparas = dict()
nts = 200
t = np.arange(0,nts)
s = 0.1
f_t = np.cos((t/nts+s)/(1+s)*np.pi/2.0)**2
a_t = f_t/f_t[0]
B_t = np.zeros(t.shape[0])
B_t[0:nts-1] = np.clip(1-a_t[1:nts]/a_t[0:nts-1],0,0.999)
B_t[nts-1] = B_t[nts-2]
lparas["nts"] = nts
lparas["t"] = torch.tensor(t,dtype=torch.float32)
lparas["a_t"] = torch.tensor(a_t,dtype=torch.float32)
lparas["B_t"] = torch.tensor(B_t,dtype=torch.float32)
#Training parameters for a cosine variance schedule
#ref: Nichol and Dhariwal 2021
return lparas
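# In formulas (Nichol and Dhariwal 2021): f(t) = cos^2(((t/T + s)/(1 + s))*pi/2),
# abar_t = f(t)/f(0), and B_t = clip(1 - abar_{t+1}/abar_t, 0, 0.999), with the
# final entry of B_t copied from the one before it.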
def generate_minibatch(imgs, Ndraw, lparas):
#I can probably speed this up with compiled CUDA code
nts = lparas["nts"]
dev = imgs.device
imbatch = torch.zeros([Ndraw*nts,28,28],dtype=torch.float32).to(dev)
draws = torch.randint(0,imgs.shape[0],(Ndraw,))
for I in range(0,Ndraw):
imgseq = torch.zeros([nts,28,28],dtype=torch.float32).to(dev)
imgseq[0,:,:] = imgs[draws[I],:,:]
beta = torch.kron(lparas["B_t"],torch.ones(28,28)).to(dev)
beta = beta.reshape([nts,28,28])
sig = torch.sqrt(beta)
noise = torch.randn((nts,28,28)).to(dev)*sig
for J in range(1,nts):
imgseq[J,:,:] = imgseq[J-1,:,:] + noise[J-1,:,:]
I1 = I*nts
I2 = (I+1)*nts
imbatch[I1:I2,:,:] = imgseq[:,:,:]
return imbatch
def train_batch(imgbatch,lr,lparas):
return
def train():
return
def infer():
return
def test1():
[images_train, images_test] = mnist_load()
lparas = training_parameters()
#plt.figure()
#plt.plot(lparas["t"],lparas["a_t"])
#plt.show()
#plt.figure()
#plt.plot(lparas["t"],lparas["B_t"])
#plt.show()
images_train = images_train.to("cuda")
mb = generate_minibatch(images_train,50,lparas)
img0 = torch.squeeze(mb[0,:,:]).clone().detach().to("cpu")
img1 = torch.squeeze(mb[1,:,:]).clone().detach().to("cpu")
img2 = torch.squeeze(mb[2,:,:]).clone().detach().to("cpu")
img3 = torch.squeeze(mb[30,:,:]).clone().detach().to("cpu")
plt.figure()
plt.subplot(2,2,1)
plt.imshow(img0,cmap='gray')
plt.subplot(2,2,2)
plt.imshow(img1,cmap='gray')
plt.subplot(2,2,3)
plt.imshow(img2,cmap='gray')
plt.subplot(2,2,4)
plt.imshow(img3,cmap='gray')
plt.show()
return
##########
## Main ##
##########
if(__name__=="__main__"):
#test1()
test_gaussian_timeselect()

@ -0,0 +1,85 @@
{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9-final"
},
"orig_nbformat": 2,
"kernelspec": {
"name": "python3",
"display_name": "Python 3.6.9 64-bit",
"metadata": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
}
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os,sys,math\n",
"import numpy as np\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"import mnist_dataloader\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"torch.Size([])\ntorch.Size([])\n[1, 2, 3]\n"
]
}
],
"source": [
"a = torch.randn(10,14)\n",
"b = a.shape[1:1]\n",
"print(b)\n",
"b.numel()\n",
"\n",
"print(b)\n",
"\n",
"b = torch.Size([1])\n",
"c = torch.Size([2,3])\n",
"d = torch.Size(torch.cat([torch.tensor(b),torch.tensor(c)]))\n",
"\n",
"d = [*b,*c]\n",
"\n",
"print(d)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
]
}

@ -0,0 +1,91 @@
#!/usr/bin/python3
import os,sys,math
import numpy as np
import cv2
import gzip #need to use gzip.open instead of open
import struct
import torch
def read_MNIST_label_file(fname):
#fp = gzip.open('./train-labels-idx1-ubyte.gz','rb');
fp = gzip.open(fname,'rb');
magic = fp.read(4);
#nitems = np.frombuffer(fp.read(4),dtype=np.int32)[0]; #some sort of endianness problem
bts = fp.read(4);
#bts = bytereverse(bts);
#nitems = np.frombuffer(bts,dtype=np.int32);
nitems = np.int32(struct.unpack('>I',bts)[0]); #the count field is stored with non-native (big-endian) byte order
#> < @ - endianness
bts = fp.read(nitems);
N = len(bts);
labels = np.zeros((N),dtype=np.uint8);
labels = np.frombuffer(bts,dtype=np.uint8,count=N);
#for i in range(0,10):
# bt = fp.read(1);
# labels[i] = np.frombuffer(bt,dtype=np.uint8);
fp.close();
return labels;
def read_MNIST_image_file(fname):
fp = gzip.open(fname,'rb');
magic = fp.read(4);
bts = fp.read(4);
nitems = np.int32(struct.unpack('>I',bts)[0]);
bts = fp.read(4);
nrows = np.int32(struct.unpack('>I',bts)[0]);
bts = fp.read(4);
ncols = np.int32(struct.unpack('>I',bts)[0]);
images = np.zeros((nitems,nrows,ncols),dtype=np.uint8);
for I in range(0,nitems):
bts = fp.read(nrows*ncols);
img1 = np.frombuffer(bts,dtype=np.uint8,count=nrows*ncols);
img1 = img1.reshape((nrows,ncols));
images[I,:,:] = img1;
fp.close();
return images;
#The mnist dataset is small enough to fit entirely in memory
def mnist_load():
baseloc = "../training_data"
traindatafile = "train-images-idx3-ubyte.gz"
trainlabelfile = "train-labels-idx1-ubyte.gz"
testdatafile = "t10k-images-idx3-ubyte.gz"
testlabelfile = "t10k-labels-idx1-ubyte.gz"
traindatafile = os.path.join(baseloc,traindatafile)
trainlabelfile = os.path.join(baseloc,trainlabelfile)
testdatafile = os.path.join(baseloc,testdatafile)
testlabelfile = os.path.join(baseloc,testlabelfile)
labels_train = read_MNIST_label_file(trainlabelfile)
labels_test = read_MNIST_label_file(testlabelfile)
images_train = read_MNIST_image_file(traindatafile)
images_test = read_MNIST_image_file(testdatafile)
labels_train = torch.tensor(labels_train,dtype=torch.float32,requires_grad=False)
labels_test = torch.tensor(labels_test,dtype=torch.float32,requires_grad=False)
images_train = torch.tensor(images_train,dtype=torch.float32,requires_grad=False)
images_test = torch.tensor(images_test,dtype=torch.float32,requires_grad=False)
# #debug
# print(labels_train.shape)
# print(labels_test.shape)
# print(images_train.shape)
# print(images_test.shape)
return [labels_train, labels_test, images_train, images_test]
if(__name__ == "__main__"):
[labels_train, labels_test, images_train, images_test] = mnist_load()
print("Loaded MNIST Data")
print(labels_train.shape)
print(labels_test.shape)
print(images_train.shape)
print(images_test.shape)

@ -0,0 +1,401 @@
#!/usr/bin/python3
import os,sys,math
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import mnist_dataloader
from mnistdiffusion_model import mndiff_rev
from mnistdiffusion_utils import *
import time
################
## Dataloader ##
################
def mnist_load():
[labels_train, labels_test, images_train, images_test] = mnist_dataloader.mnist_load()
#60000 x 28 x 28
#60000 x 1
#Rescale data to (-1,1)
images_train = images_train.clone().detach().float()
images_test = images_test.clone().detach().float()
images_train = 2.0*(images_train/255.0)-1.0
images_test = 2.0*(images_test/255.0)-1.0
images_train.requires_grad = False
images_test.requires_grad = False
return [images_train, images_test]
################
## Operations ##
################
# Rather than killing myself figuring out variable variance schedules,
# just use a constant variance schedule to start with
def training_parameters():
lparas = dict()
nts = 200
Beta = 0.0050
lr = 1E-4
batchsize = 40
save_every = 200
record_every = 25
B_t = np.ones((nts))*Beta
a_t = np.zeros((nts))
a_t[0] = 1
for I in range(1,nts):
a_t[I] = a_t[I-1]*(1-B_t[I-1])
t = np.linspace(0,1,nts) #normalized time (t/T)
lparas["nts"] = nts
lparas["Beta"] = Beta
lparas["t"] = torch.tensor(t,dtype=torch.float32)
lparas["a_t"] = torch.tensor(a_t,dtype=torch.float32)
lparas["B_t"] = torch.tensor(B_t,dtype=torch.float32)
lparas["lr"] = lr
lparas["batchsize"] = batchsize
lparas["save_every"] = save_every
lparas["record_every"] = record_every
return lparas
def generate_minibatch(imgs, Ndraw, lparas):
#I can probably speed this up with compiled CUDA code
nts = lparas["nts"]
Beta = lparas["Beta"]
dev = imgs.device
imbatch = torch.zeros([Ndraw,nts,28,28],dtype=torch.float32).to(dev)
draws = torch.randint(0,imgs.shape[0],(Ndraw,))
times = lparas["t"].to(dev)
times = times.reshape([1,nts]).expand([Ndraw,nts])
for I in range(0,Ndraw):
imgseq = torch.zeros([nts,28,28],dtype=torch.float32).to(dev)
imgseq[0,:,:] = imgs[draws[I],:,:]
beta = torch.kron(lparas["B_t"],torch.ones(28,28)).to(dev)
beta = beta.reshape([nts,28,28])
sig = torch.sqrt(beta)
noise = torch.randn((nts,28,28)).to(dev)*sig
for J in range(1,nts):
imgseq[J,:,:] = imgseq[J-1,:,:]*torch.sqrt(1.0-beta[J-1,:,:]) + noise[J-1,:,:]
imbatch[I,:,:,:] = imgseq[:,:,:]
return [imbatch,times]
def train_batch(revnet,batchimg,batchtime,lparas):
L = 0.0
lr = lparas["lr"]
nts = batchtime.shape[1]
dev = revnet.device
#beta = torch.as_tensor(lparas["Beta"]).float().to(dev)
beta = float(lparas["Beta"])
#optim = torch.optim.SGD(revnet.parameters(),lr=lr) #unused; Adam below is the optimizer actually used
optim = torch.optim.Adam(revnet.parameters(),lr=lr)
loss = torch.nn.MSELoss().to(dev)
d_est = revnet(batchimg,batchtime)
d_est = d_est[:,1:nts,:,:]
targ = batchimg[:,1:nts,:,:]-batchimg[:,0:nts-1,:,:]
L2 = loss(targ,beta*d_est)
L2.backward()
optim.step()
L = L2.detach().cpu().numpy()
L = L/beta**2
optim.zero_grad() #clear accumulated model gradients before the next batch (MSELoss itself has no parameters to zero)
return L
def train(directory,saveseries):
print("Loading dataset...")
[images_train, images_test] = mnist_load()
lparas = training_parameters()
batchsize = lparas["batchsize"]
save_every = lparas["save_every"]
record_every = lparas["record_every"]
device = torch.device("cuda")
revnet = mndiff_rev()
fname = get_current_save(directory,saveseries)
if(fname is None):
fname = get_next_save(directory,saveseries)
else:
print("Loading {}".format(fname))
revnet.load(fname)
images_train = images_train.to(device)
revnet = revnet.to(device)
## Training Loop
I = 0
J = 0
while(True):
if(I%record_every==0):
print("Minibatch {}".format(I))
#New minibatch
losses = np.zeros((record_every))
#draw minibatch
[mb,mbt] = generate_minibatch(images_train,batchsize,lparas)
#train minibatch
L = train_batch(revnet,mb,mbt,lparas)
losses[J] = L
J = J + 1
#record results
if(J%record_every==0):
J = 0
loss = np.mean(losses)
revnet.recordloss(loss)
print("\tloss={:1.5f}".format(loss))
#save
if(I%save_every==0):
fnext = get_next_save(directory,saveseries)
revnet.save(fnext)
I = I + 1
return
def infer(directory,saveseries):
ttt = int(time.time())
np.random.seed(ttt%50000)
torch.manual_seed(ttt%50000)
lparas = training_parameters()
revnet = mndiff_rev()
fname = get_current_save(directory,saveseries)
if(fname is None):
print("Warning, no state loaded.")
else:
print("Loading {}".format(fname))
revnet.load(fname)
nts = lparas["nts"]
beta = lparas["Beta"]
nts = nts*15
imgseq = torch.zeros([nts,28,28])
imgseq[nts-1,:,:] = torch.randn(28,28)
#imgseq[nts-1,:,17:20] = 3.0
tl = torch.linspace(0,1,nts)
bb = torch.as_tensor(beta)
for I in range(nts-2,-1,-1):
img = imgseq[I+1,:,:]
t = tl[I]
eps = revnet(img,t)
img2 = img*torch.sqrt(1-bb) - bb*eps + 1.0*bb*torch.randn(28,28)
imgseq[I,:,:] = img2
im0 = torch.squeeze(imgseq[0,:,:])
im1 = torch.squeeze(imgseq[int(nts*0.25),:,:])
im2 = torch.squeeze(imgseq[int(nts*0.75),:,:])
im3 = torch.squeeze(imgseq[int(nts*0.90),:,:])
im0 = im0.detach().numpy()
im1 = im1.detach().numpy()
im2 = im2.detach().numpy()
im3 = im3.detach().numpy()
plt.figure()
plt.subplot(2,2,1)
plt.imshow(im3,cmap='gray')
plt.subplot(2,2,2)
plt.imshow(im2,cmap='gray')
plt.subplot(2,2,3)
plt.imshow(im1,cmap='gray')
plt.subplot(2,2,4)
plt.imshow(im0,cmap='gray')
plt.show()
return
def plot_history(directory, saveseries):
revnet = mndiff_rev()
fname = get_current_save(directory,saveseries)
if(fname is None):
fname = get_next_save(directory,saveseries)
else:
print("Loading {}".format(fname))
revnet.load(fname)
revnet.plothistory()
return
def test1():
[images_train, images_test] = mnist_load()
lparas = training_parameters()
#plt.figure()
#plt.plot(lparas["t"],lparas["a_t"])
#plt.show()
#plt.figure()
#plt.plot(lparas["t"],lparas["B_t"])
#plt.show()
images_train = images_train.to("cuda")
[mb,mbtime] = generate_minibatch(images_train,50,lparas)
img0 = torch.squeeze(mb[0,0,:,:]).clone().detach().to("cpu")
img1 = torch.squeeze(mb[0,1,:,:]).clone().detach().to("cpu")
img2 = torch.squeeze(mb[0,2,:,:]).clone().detach().to("cpu")
img3 = torch.squeeze(mb[0,99,:,:]).clone().detach().to("cpu")
print(mb.shape)
print(mbtime.shape)
plt.figure()
plt.subplot(2,2,1)
plt.imshow(img0,cmap='gray')
plt.subplot(2,2,2)
plt.imshow(img1,cmap='gray')
plt.subplot(2,2,3)
plt.imshow(img2,cmap='gray')
plt.subplot(2,2,4)
plt.imshow(img3,cmap='gray')
plt.show()
return
def test2(directory, saveseries):
lparas = training_parameters()
nts = lparas["nts"]
beta = lparas["Beta"]
revnet = mndiff_rev()
fname = get_current_save(directory,saveseries)
if(fname is None):
fname = get_next_save(directory,saveseries)
else:
print("Loading {}".format(fname))
revnet.load(fname)
[images_train, images_test] = mnist_load()
x = torch.squeeze(images_train[100,:,:])
#x = beta*torch.randn(28,28)
y = revnet(x,0.5)
x = x.detach().numpy()
y = y.detach().numpy()
print("y mn={:1.3f} std={:1.3f} std2*beta={}".format(np.mean(y),np.std(y),np.std(y)**2*beta))
plt.figure()
plt.subplot(2,1,1)
plt.imshow(x,cmap='gray')
plt.subplot(2,1,2)
plt.imshow(y,cmap='gray')
plt.show()
return
##########
## Main ##
##########
def mainswitch():
args = sys.argv
if(len(args)<2):
print("mnistdiffusion [operation]")
print("operations:")
print("\ttrain")
print("\tplot_history")
print("\tinfer")
exit(0)
if(sys.argv[1]=="train"):
train("./saves","test_2a")
if(sys.argv[1]=="plot_history"):
plot_history("./saves","test_2a")
if(sys.argv[1]=="infer"):
infer("./saves","test_2a")
if(__name__=="__main__"):
ttt = int(time.time())
np.random.seed(ttt%50000)
torch.manual_seed(ttt%50000)
#test1()
#test2("./saves","test_2a")
mainswitch()

@ -0,0 +1,298 @@
#!/usr/bin/python3
import os,sys,math
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import mnist_dataloader
## Sohl-Dickstein 2015 describes a gaussian "bump" function to
# implement learnable time-dependence
#
#network input vector and normalized time (tn e (0,1))
# input: (dimx, dimy, ..., features), (scalar) --> output: (dimx, dimy, ...)
# input: (ntimes, ..., features), (ntimes) --> output: (ntimes , ...)
class gaussian_timeselect(nn.Module):
def __init__(self, nfeatures=5, width = 0.25):
super(gaussian_timeselect,self).__init__()
tau_j = torch.linspace(0,1,nfeatures).float()
self.tau_j = torch.nn.parameter.Parameter(tau_j, requires_grad=False)
width = torch.tensor(width,dtype=torch.float32)
self.width = torch.nn.parameter.Parameter(width,requires_grad = False)
#network input vector and normalized time (tn e (0,1))
# input: (dimx, dimy, ..., features), (scalar) --> output: (dimx, dimy, ...)
# input: (ntimes, ..., features), (ntimes) --> output: (ntimes , ...)
def forward(self, x, tn):
ndim = len(x.shape)
tn = torch.as_tensor(tn)
tn.requires_grad = False
if(len(tn.shape)==0):
#case 0 - if tn is a scalar
g = torch.exp(-1.0/(2.0*self.width)*(tn-self.tau_j)**2)
g = g/torch.clip(torch.sum(g),min=0.001)
x = torch.tensordot(x,g,dims=[[ndim-1],[0]])
else:
#Case 1 - tn is a tensor
nf = self.tau_j.shape[0]
nt = tn.numel()
nt2 = x.shape[0]
ndim = len(x.shape)
nf2 = x.shape[ndim-1]
if(nt!=nt2):
raise Exception("gaussian_timeselect time mismatch")
if(nf!=nf2):
raise Exception("gaussian_timeselect feature mismatch")
shp2 = x.shape[1:ndim-1]
nshp2 = shp2.numel()
x = x.reshape([nt,nshp2,nf])
tau = self.tau_j.reshape([1,1,nf])
tau = tau.expand([nt,nshp2,nf])
tn = tn.reshape([nt,1,1])
tn = tn.expand([nt,nshp2,nf])
g = torch.exp(-1.0/(2.0*self.width)*(tn-tau)**2)
gs = torch.clip(torch.sum(g,axis=2),min=0.0001)
gs = gs.reshape([nt,nshp2,1]).expand([nt,nshp2,nf])
g = g/gs
x = torch.sum(x*g,axis=2)
#shp3 = torch.Size(torch.cat(torch.tensor([nt]),torch.tensor(shp2)))
# concat [nt, *shp2]
shp3 = torch.Size([nt,*shp2])
x = x.reshape(shp3)
return x
#The reverse network should output a mean image and a covariance image
#given an input image
class mndiff_rev(nn.Module):
def __init__(self):
super(mndiff_rev,self).__init__()
self.L1 = torch.nn.Conv2d(1,10,3)
self.L2 = torch.nn.Conv2d(10,10,5)
self.L3 = torch.nn.ConvTranspose2d(10,10,5)
self.L4 = torch.nn.ConvTranspose2d(10,10,3)
#bypass branch
#self.Lb1 = torch.nn.Linear(1,10)
#self.Lb2 = torch.nn.Linear(10,2)
#self.Lb4f = torch.nn.Linear(12,12)
self.L5 = gaussian_timeselect(10,0.1)
#self.L4 = nn.Conv2d(10,10,kernel_size=1)
#channels must be 2nd dimension
# self.L4a = nn.Linear(5,5)
# #equivalent, channels must be last dimension
# self.L4b = nn.Linear(5,5)
# self.L5 = gaussian_timeselect(5,0.1)
#self.L5alt = nn.Linear(5,1)
#self.RL = nn.LeakyReLU(0.01)
self.tanh = nn.Tanh()
self.device = torch.device("cpu")
self.losstime = []
self.lossrecord = []
return
def forward(self,x,t):
#handle data-shape cases
#case 0: x (pix_x, pix_y), t is scalar
#case 1: x (ntimes, pix_x, pix_y), t is (ntimes)
#case 2: x (nbatch, ntimes, pix_x, pix_y), t is (nbatch, ntimes)
if(len(x.shape)==2):
case = 0
t = torch.as_tensor(t).to(self.device)
t = t.float()
x = x.reshape([1,1,28,28])
ntimes = 1
nbatch = 1
elif(len(x.shape)==3):
case = 1
ntimes = x.shape[0]
nbatch = 1
x = x.reshape([ntimes,1,28,28])
elif(len(x.shape)==4):
case = 2
nbatch = x.shape[0]
ntimes = x.shape[1]
#x = x.reshape([nbatch,ntimes,28*28])
x = x.reshape([nbatch*ntimes,1,28,28])
t = t.reshape([nbatch*ntimes])
y = torch.randn(28*28,device=self.device)
else:
print("Error: expected input to have 2, 3, or 4 dimensions")
raise Exception("Error: expected input to have 2, 3, or 4 dimensions")
s1 = self.L1(x)
s1 = self.tanh(s1)
#print(s1.shape)
s1 = self.L2(s1)
s1 = self.tanh(s1)
#print(s1.shape)
s1 = self.L3(s1)
s1 = self.tanh(s1)
#print(s1.shape)
s1 = self.L4(s1)
s1 = self.tanh(s1)
#print(s1.shape)
s1 = s1.permute([0,2,3,1])
s2 = self.L5(s1,t)
eps = s2
if(case==0):
eps = eps.reshape([28,28])
elif(case==1):
eps = eps.reshape([ntimes,28,28])
elif(case==2):
eps = eps.reshape([nbatch,ntimes,28,28])
return eps
def to(self,dev):
self.device = torch.device(dev)
ret = super(mndiff_rev,self).to(self.device)
return ret
def save(self,fname):
try:
dev = self.device
self.to("cpu")
q = [self.state_dict(), self.losstime, self.lossrecord]
torch.save(q,fname)
self.to(dev)
except:
print("model save: problem saving {}".format(fname))
return
def load(self,fname):
try:
[modelsd,losstime,lossrecord] = torch.load(fname)
self.load_state_dict(modelsd)
self.losstime = losstime
self.lossrecord = lossrecord
except:
print("model load: problem loading {}".format(fname))
return
def plothistory(self):
x = self.losstime
y = self.lossrecord
plt.figure()
plt.plot(x,y,'k.',markersize=1)
plt.title('Loss History')
plt.show()
return
def recordloss(self,loss):
t = 0
L = len(self.losstime)
if(L>0):
t = self.losstime[L-1]
self.losstime.append(t+1)
self.lossrecord.append(loss)
return
###########
## Tests ##
###########
def test_gaussian_timeselect():
gts = gaussian_timeselect(5,0.025)
gts = gts.to("cuda")
x = torch.randn((8,8,5)).to("cuda")
#t = torch.tensor([0,0.5,1]).to("cuda")
t = torch.linspace(0,1,8).to("cuda")
y = gts(x,0)
y2 = gts(x,t)
y3 = gts(x,0.111)
print(y)
print(x[:,:,0])
print(x.shape)
print(y.shape)
print(y2.shape)
print(torch.abs(y2[0,:]-y[0,:])>1E-6)
return
def test_mndiff_rev():
mnd = mndiff_rev()
x0 = torch.randn(28,28)
t0 = 0
x1 = torch.randn(5,28,28)
t1 = torch.linspace(0,1,5)
x2 = torch.randn(8,5,28,28)
t2 = torch.linspace(0,1,5)
t2 = t2.reshape([1,5]).expand([8,5])
y0 = mnd(x0,t0)
y1 = mnd(x1,t1)
y2 = mnd(x2,t2)
print(y0.shape)
print(y1.shape)
print(y2.shape)
return
if(__name__=="__main__"):
#test_gaussian_timeselect()
test_mndiff_rev()
pass

@ -0,0 +1,78 @@
#!/usr/bin/python3
import os,sys,math
def get_savenumber(fname):
num = 0
fn2 = os.path.split(fname)[1]
fn2 = os.path.splitext(fn2)[0]
L = len(fn2)
L2 = L
for I in range(L-1,0,-1):
c = fn2[I]
if(not c.isnumeric()):
L2 = I+1
break
nm = fn2[L2:L]
try:
num = int(nm)
except:
num = 0
return num
def sorttwopythonlists(list1, list2):
zipped_pairs = zip(list2, list1)
z = [x for _, x in sorted(zipped_pairs)]
z2 = sorted(list2)
return [z,z2]
def get_current_save(directory,saveseries):
fname = None
fl = os.listdir(directory)
fl2 = []
nums = []
for f in fl:
if(f.find(saveseries)>=0):
fl2.append(os.path.join(directory,f))
for f in fl2:
n = get_savenumber(f)
nums.append(n)
[fl2,nums] = sorttwopythonlists(fl2,nums)
if(len(fl2)>0):
fname = fl2[len(fl2)-1]
else:
fname = None
return fname
def get_next_save(directory,saveseries):
fname = None
fncurrent = get_current_save(directory,saveseries)
if(fncurrent is None):
fname = os.path.join(directory,"{}{:02d}.pyt".format(saveseries,0))
else:
N = get_savenumber(fncurrent)
N = N + 1
fname = os.path.join(directory,"{}{:02d}.pyt".format(saveseries,N))
return fname
if(__name__=="__main__"):
#print(get_savenumber("./saves/helloworld05202.py"))
pass

@ -0,0 +1,161 @@
#!/usr/bin/python3
import os,sys,math
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import mnist_dataloader
from mnistdiffusion_model import mndiff_rev
################
## Dataloader ##
################
def mnist_load():
[labels_train, labels_test, images_train, images_test] = mnist_dataloader.mnist_load()
#60000 x 28 x 28
#60000 x 1
#Rescale data to (-1,1)
images_train = images_train.clone().detach().float()
images_test = images_test.clone().detach().float()
images_train = 2.0*(images_train/255.0)-1.0
images_test = 2.0*(images_test/255.0)-1.0
images_train.requires_grad = False
images_test.requires_grad = False
return [images_train, images_test]
################
## Operations ##
################
def training_parameters():
lparas = dict()
nts = 200
t = np.arange(0,nts)
s = 0.1
f_t = np.cos((t/nts+s)/(1+s)*np.pi/2.0)**2
a_t = f_t/f_t[0]
B_t = np.zeros(t.shape[0])
B_t[0:nts-1] = np.clip(1-a_t[1:nts]/a_t[0:nts-1],0,0.999)
B_t[nts-1] = B_t[nts-2]
lparas["nts"] = nts
lparas["t"] = torch.tensor(t,dtype=torch.float32)
lparas["a_t"] = torch.tensor(a_t,dtype=torch.float32)
lparas["B_t"] = torch.tensor(B_t,dtype=torch.float32)
#Training parameters for a cosine variance schedule
#ref: Nichol and Dhariwal 2021
return lparas
def generate_minibatch(imgs, Ndraw, lparas):
#I can probably speed this up with compiled CUDA code
nts = lparas["nts"]
dev = imgs.device
imbatch = torch.zeros([Ndraw*nts,28,28],dtype=torch.float32).to(dev)
draws = torch.randint(0,imgs.shape[0],(Ndraw,))
for I in range(0,Ndraw):
imgseq = torch.zeros([nts,28,28],dtype=torch.float32).to(dev)
imgseq[0,:,:] = imgs[draws[I],:,:]
beta = torch.kron(lparas["B_t"],torch.ones(28,28)).to(dev)
beta = beta.reshape([nts,28,28])
sig = torch.sqrt(beta)
noise = torch.randn((nts,28,28)).to(dev)*sig
for J in range(1,nts):
imgseq[J,:,:] = imgseq[J-1,:,:] + noise[J-1,:,:]
I1 = I*nts
I2 = (I+1)*nts
imbatch[I1:I2,:,:] = imgseq[:,:,:]
return imbatch
def train_batch(imgbatch,lr,lparas):
return
def train():
return
def infer():
return
def test1():
[images_train, images_test] = mnist_load()
lparas = training_parameters()
#plt.figure()
#plt.plot(lparas["t"],lparas["a_t"])
#plt.show()
#plt.figure()
#plt.plot(lparas["t"],lparas["B_t"])
#plt.show()
images_train = images_train.to("cuda")
mb = generate_minibatch(images_train,50,lparas)
img0 = torch.squeeze(mb[0,:,:]).clone().detach().to("cpu")
img1 = torch.squeeze(mb[1,:,:]).clone().detach().to("cpu")
img2 = torch.squeeze(mb[2,:,:]).clone().detach().to("cpu")
img3 = torch.squeeze(mb[30,:,:]).clone().detach().to("cpu")
plt.figure()
plt.subplot(2,2,1)
plt.imshow(img0,cmap='gray')
plt.subplot(2,2,2)
plt.imshow(img1,cmap='gray')
plt.subplot(2,2,3)
plt.imshow(img2,cmap='gray')
plt.subplot(2,2,4)
plt.imshow(img3,cmap='gray')
plt.show()
return
##########
## Main ##
##########
if(__name__=="__main__"):
#test1()
test_gaussian_timeselect()

Binary file not shown.

Binary file not shown.

Binary file not shown.

@ -0,0 +1,85 @@
{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9-final"
},
"orig_nbformat": 2,
"kernelspec": {
"name": "python3",
"display_name": "Python 3.6.9 64-bit",
"metadata": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
}
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os,sys,math\n",
"import numpy as np\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"import mnist_dataloader\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"torch.Size([])\ntorch.Size([])\n[1, 2, 3]\n"
]
}
],
"source": [
"a = torch.randn(10,14)\n",
"b = a.shape[1:1]\n",
"print(b)\n",
"b.numel()\n",
"\n",
"print(b)\n",
"\n",
"b = torch.Size([1])\n",
"c = torch.Size([2,3])\n",
"d = torch.Size(torch.cat([torch.tensor(b),torch.tensor(c)]))\n",
"\n",
"d = [*b,*c]\n",
"\n",
"print(d)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
]
}

@ -0,0 +1,91 @@
#!/usr/bin/python3
import os,sys,math
import numpy as np
import cv2
import gzip #need to use gzip.open instead of open
import struct
import torch
def read_MNIST_label_file(fname):
#fp = gzip.open('./train-labels-idx1-ubyte.gz','rb');
fp = gzip.open(fname,'rb');
magic = fp.read(4);
#nitems = np.frombuffer(fp.read(4),dtype=np.int32)[0]; #some sort of endianness problem
bts = fp.read(4);
#bts = bytereverse(bts);
#nitems = np.frombuffer(bts,dtype=np.int32);
nitems = np.int32(struct.unpack('>I',bts)[0]); #the count field is stored with non-native (big-endian) byte order
#> < @ - endianness
bts = fp.read(nitems);
N = len(bts);
labels = np.zeros((N),dtype=np.uint8);
labels = np.frombuffer(bts,dtype=np.uint8,count=N);
#for i in range(0,10):
# bt = fp.read(1);
# labels[i] = np.frombuffer(bt,dtype=np.uint8);
fp.close();
return labels;
def read_MNIST_image_file(fname):
fp = gzip.open(fname,'rb');
magic = fp.read(4);
bts = fp.read(4);
nitems = np.int32(struct.unpack('>I',bts)[0]);
bts = fp.read(4);
nrows = np.int32(struct.unpack('>I',bts)[0]);
bts = fp.read(4);
ncols = np.int32(struct.unpack('>I',bts)[0]);
images = np.zeros((nitems,nrows,ncols),dtype=np.uint8);
for I in range(0,nitems):
bts = fp.read(nrows*ncols);
img1 = np.frombuffer(bts,dtype=np.uint8,count=nrows*ncols);
img1 = img1.reshape((nrows,ncols));
images[I,:,:] = img1;
fp.close();
return images;
#The mnist dataset is small enough to fit entirely in memory
def mnist_load():
baseloc = "../training_data"
traindatafile = "train-images-idx3-ubyte.gz"
trainlabelfile = "train-labels-idx1-ubyte.gz"
testdatafile = "t10k-images-idx3-ubyte.gz"
testlabelfile = "t10k-labels-idx1-ubyte.gz"
traindatafile = os.path.join(baseloc,traindatafile)
trainlabelfile = os.path.join(baseloc,trainlabelfile)
testdatafile = os.path.join(baseloc,testdatafile)
testlabelfile = os.path.join(baseloc,testlabelfile)
labels_train = read_MNIST_label_file(trainlabelfile)
labels_test = read_MNIST_label_file(testlabelfile)
images_train = read_MNIST_image_file(traindatafile)
images_test = read_MNIST_image_file(testdatafile)
labels_train = torch.tensor(labels_train,dtype=torch.float32,requires_grad=False)
labels_test = torch.tensor(labels_test,dtype=torch.float32,requires_grad=False)
images_train = torch.tensor(images_train,dtype=torch.float32,requires_grad=False)
images_test = torch.tensor(images_test,dtype=torch.float32,requires_grad=False)
# #debug
# print(labels_train.shape)
# print(labels_test.shape)
# print(images_train.shape)
# print(images_test.shape)
return [labels_train, labels_test, images_train, images_test]
if(__name__ == "__main__"):
[labels_train, labels_test, images_train, images_test] = mnist_load()
print("Loaded MNIST Data")
print(labels_train.shape)
print(labels_test.shape)
print(images_train.shape)
print(images_test.shape)

@ -0,0 +1,398 @@
#!/usr/bin/python3
import os,sys,math
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import mnist_dataloader
from mnistdiffusion_model import mndiff_rev
from mnistdiffusion_utils import *
import time
################
## Dataloader ##
################
def mnist_load():
[labels_train, labels_test, images_train, images_test] = mnist_dataloader.mnist_load()
#60000 x 28 x 28
#60000 x 1
#Rescale data to (-1,1)
images_train = images_train.clone().detach().float()
images_test = images_test.clone().detach().float()
images_train = 2.0*(images_train/255.0)-1.0
images_test = 2.0*(images_test/255.0)-1.0
images_train.requires_grad = False
images_test.requires_grad = False
return [images_train, images_test]
################
## Operations ##
################
# Rather than killing myself figuring out variable variance schedules,
# just use a constant variance schedule to start with
def training_parameters():
lparas = dict()
nts = 200
Beta = 0.0050
lr = 1E-4
batchsize = 30
save_every = 200
record_every = 25
B_t = np.ones((nts))*Beta
a_t = np.zeros((nts))
a_t[0] = 1
for I in range(1,nts):
a_t[I] = a_t[I-1]*(1-B_t[I-1])
t = np.linspace(0,1,nts) #normalized time (t/T)
lparas["nts"] = nts
lparas["Beta"] = Beta
lparas["t"] = torch.tensor(t,dtype=torch.float32)
lparas["a_t"] = torch.tensor(a_t,dtype=torch.float32)
lparas["B_t"] = torch.tensor(B_t,dtype=torch.float32)
lparas["lr"] = lr
lparas["batchsize"] = batchsize
lparas["save_every"] = save_every
lparas["record_every"] = record_every
return lparas
def generate_minibatch(imgs, Ndraw, lparas):
#I can probably speed this up with compiled CUDA code
nts = lparas["nts"]
Beta = lparas["Beta"]
dev = imgs.device
imbatch = torch.zeros([Ndraw,nts,28,28],dtype=torch.float32).to(dev)
draws = torch.randint(0,imgs.shape[0],(Ndraw,))
times = lparas["t"].to(dev)
times = times.reshape([1,nts]).expand([Ndraw,nts])
for I in range(0,Ndraw):
imgseq = torch.zeros([nts,28,28],dtype=torch.float32).to(dev)
imgseq[0,:,:] = imgs[draws[I],:,:]
beta = torch.kron(lparas["B_t"],torch.ones(28,28)).to(dev)
beta = beta.reshape([nts,28,28])
sig = torch.sqrt(beta)
noise = torch.randn((nts,28,28)).to(dev)*sig
for J in range(1,nts):
imgseq[J,:,:] = imgseq[J-1,:,:]*torch.sqrt(1.0-beta[J-1,:,:]) + noise[J-1,:,:]
imbatch[I,:,:,:] = imgseq[:,:,:]
return [imbatch,times]
def train_batch(revnet,batchimg,batchtime,lparas):
L = 0.0
lr = lparas["lr"]
nts = batchtime.shape[1]
dev = revnet.device
#beta = torch.as_tensor(lparas["Beta"]).float().to(dev)
beta = float(lparas["Beta"])
#optim = torch.optim.SGD(revnet.parameters(),lr=lr) #unused; Adam below is the optimizer actually used
optim = torch.optim.Adam(revnet.parameters(),lr=lr)
loss = torch.nn.MSELoss().to(dev)
d_est = revnet(batchimg,batchtime)
d_est = d_est[:,1:nts,:,:]
targ = batchimg[:,1:nts,:,:]-batchimg[:,0:nts-1,:,:]
L2 = loss(targ,beta*d_est)
L2.backward()
optim.step()
L = L2.detach().cpu().numpy()
L = L/beta**2
optim.zero_grad() #clear accumulated model gradients before the next batch (MSELoss itself has no parameters to zero)
return L
def train(directory,saveseries):
print("Loading dataset...")
[images_train, images_test] = mnist_load()
lparas = training_parameters()
batchsize = lparas["batchsize"]
save_every = lparas["save_every"]
record_every = lparas["record_every"]
device = torch.device("cuda")
revnet = mndiff_rev()
fname = get_current_save(directory,saveseries)
if(fname is None):
fname = get_next_save(directory,saveseries)
else:
print("Loading {}".format(fname))
revnet.load(fname)
images_train = images_train.to(device)
revnet = revnet.to(device)
## Training Loop
I = 0
J = 0
while(True):
if(I%record_every==0):
print("Minibatch {}".format(I))
#New minibatch
losses = np.zeros((record_every))
#draw minibatch
[mb,mbt] = generate_minibatch(images_train,batchsize,lparas)
#train minibatch
L = train_batch(revnet,mb,mbt,lparas)
losses[J] = L
J = J + 1
#record results
if(J%record_every==0):
J = 0
loss = np.mean(losses)
revnet.recordloss(loss)
print("\tloss={:1.5f}".format(loss))
#save
if(I%save_every==0):
fnext = get_next_save(directory,saveseries)
revnet.save(fnext)
I = I + 1
return
def infer(directory,saveseries):
ttt = int(time.time())
np.random.seed(ttt%50000)
torch.manual_seed(ttt%50000)
lparas = training_parameters()
revnet = mndiff_rev()
fname = get_current_save(directory,saveseries)
if(fname is None):
print("Warning, no state loaded.")
else:
print("Loading {}".format(fname))
revnet.load(fname)
nts = lparas["nts"]
beta = lparas["Beta"]
nts = nts*15
imgseq = torch.zeros([nts,28,28])
imgseq[nts-1,:,:] = torch.randn(28,28)
#imgseq[nts-1,:,17:20] = 3.0
tl = torch.linspace(0,1,nts)
bb = torch.as_tensor(beta)
for I in range(nts-2,-1,-1):
img = imgseq[I+1,:,:]
t = tl[I]
eps = revnet(img,t)
img2 = img*torch.sqrt(1-bb) - bb*eps + 1.0*bb*torch.randn(28,28)
imgseq[I,:,:] = img2
im0 = torch.squeeze(imgseq[0,:,:])
im1 = torch.squeeze(imgseq[int(nts*0.25),:,:])
im2 = torch.squeeze(imgseq[int(nts*0.75),:,:])
im3 = torch.squeeze(imgseq[int(nts*0.90),:,:])
im0 = im0.detach().numpy()
im1 = im1.detach().numpy()
im2 = im2.detach().numpy()
im3 = im3.detach().numpy()
plt.figure()
plt.subplot(2,2,1)
plt.imshow(im3,cmap='gray')
plt.subplot(2,2,2)
plt.imshow(im2,cmap='gray')
plt.subplot(2,2,3)
plt.imshow(im1,cmap='gray')
plt.subplot(2,2,4)
plt.imshow(im0,cmap='gray')
plt.show()
return
def plot_history(directory, saveseries):
revnet = mndiff_rev()
fname = get_current_save(directory,saveseries)
if(fname is None):
fname = get_next_save(directory,saveseries)
else:
print("Loading {}".format(fname))
revnet.load(fname)
revnet.plothistory()
return
def test1():
[images_train, images_test] = mnist_load()
lparas = training_parameters()
#plt.figure()
#plt.plot(lparas["t"],lparas["a_t"])
#plt.show()
#plt.figure()
#plt.plot(lparas["t"],lparas["B_t"])
#plt.show()
images_train = images_train.to("cuda")
[mb,mbtime] = generate_minibatch(images_train,50,lparas)
img0 = torch.squeeze(mb[0,0,:,:]).clone().detach().to("cpu")
img1 = torch.squeeze(mb[0,1,:,:]).clone().detach().to("cpu")
img2 = torch.squeeze(mb[0,2,:,:]).clone().detach().to("cpu")
img3 = torch.squeeze(mb[0,99,:,:]).clone().detach().to("cpu")
print(mb.shape)
print(mbtime.shape)
plt.figure()
plt.subplot(2,2,1)
plt.imshow(img0,cmap='gray')
plt.subplot(2,2,2)
plt.imshow(img1,cmap='gray')
plt.subplot(2,2,3)
plt.imshow(img2,cmap='gray')
plt.subplot(2,2,4)
plt.imshow(img3,cmap='gray')
plt.show()
return
def test2(directory, saveseries):
lparas = training_parameters()
nts = lparas["nts"]
beta = lparas["Beta"]
revnet = mndiff_rev()
fname = get_current_save(directory,saveseries)
if(fname is None):
fname = get_next_save(directory,saveseries)
else:
print("Loading {}".format(fname))
revnet.load(fname)
[images_train, images_test] = mnist_load()
x = torch.squeeze(images_train[100,:,:])
#x = beta*torch.randn(28,28)
y = revnet(x,0.5)
x = x.detach().numpy()
y = y.detach().numpy()
print("y mn={:1.3f} std={:1.3f} std2*beta={}".format(np.mean(y),np.std(y),np.std(y)**2*beta))
plt.figure()
plt.subplot(2,1,1)
plt.imshow(x,cmap='gray')
plt.subplot(2,1,2)
plt.imshow(y,cmap='gray')
plt.show()
return
##########
## Main ##
##########
def mainswitch():
args = sys.argv
if(len(args)<2):
print("mnistdiffusion [operation]")
print("operations:")
print("\ttrain")
print("\tplot_history")
print("\tinfer")
exit(0)
if(sys.argv[1]=="train"):
train("./saves","test_3a")
if(sys.argv[1]=="plot_history"):
plot_history("./saves","test_3a")
if(sys.argv[1]=="infer"):
infer("./saves","test_3a")
if(__name__=="__main__"):
ttt = int(time.time())
np.random.seed(ttt%50000)
torch.manual_seed(ttt%50000)
#test1()
#test2("./saves","test_2a")
mainswitch()

@ -0,0 +1,300 @@
#!/usr/bin/python3
import os,sys,math
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import mnist_dataloader
## Sohl-Dickstein 2015 describes a gaussian "bump" function to
# implement learnable time-dependence
#
#network input vector and normalized time (tn e (0,1))
# input: (dimx, dimy, ..., features), (scalar) --> output: (dimx, dimy, ...)
# input: (ntimes, ..., features), (ntimes) --> output: (ntimes , ...)
class gaussian_timeselect(nn.Module):
def __init__(self, nfeatures=5, width = 0.25):
super(gaussian_timeselect,self).__init__()
tau_j = torch.linspace(0,1,nfeatures).float()
self.tau_j = torch.nn.parameter.Parameter(tau_j, requires_grad=False)
width = torch.tensor(width,dtype=torch.float32)
self.width = torch.nn.parameter.Parameter(width,requires_grad = False)
#network input vector and normalized time (tn e (0,1))
# input: (dimx, dimy, ..., features), (scalar) --> output: (dimx, dimy, ...)
# input: (ntimes, ..., features), (ntimes) --> output: (ntimes , ...)
def forward(self, x, tn):
ndim = len(x.shape)
tn = torch.as_tensor(tn)
tn.requires_grad = False
if(len(tn.shape)==0):
#case 0 - if tn is a scalar
g = torch.exp(-1.0/(2.0*self.width)*(tn-self.tau_j)**2)
g = g/torch.clip(torch.sum(g),min=0.001)
x = torch.tensordot(x,g,dims=[[ndim-1],[0]])
else:
#Case 1 - tn is a tensor
nf = self.tau_j.shape[0]
nt = tn.numel()
nt2 = x.shape[0]
ndim = len(x.shape)
nf2 = x.shape[ndim-1]
if(nt!=nt2):
raise Exception("gaussian_timeselect time mismatch")
if(nf!=nf2):
raise Exception("gaussian_timeselect feature mismatch")
shp2 = x.shape[1:ndim-1]
nshp2 = shp2.numel()
x = x.reshape([nt,nshp2,nf])
tau = self.tau_j.reshape([1,1,nf])
tau = tau.expand([nt,nshp2,nf])
tn = tn.reshape([nt,1,1])
tn = tn.expand([nt,nshp2,nf])
g = torch.exp(-1.0/(2.0*self.width)*(tn-tau)**2)
gs = torch.clip(torch.sum(g,axis=2),min=0.0001)
gs = gs.reshape([nt,nshp2,1]).expand([nt,nshp2,nf])
g = g/gs
x = torch.sum(x*g,axis=2)
#shp3 = torch.Size(torch.cat(torch.tensor([nt]),torch.tensor(shp2)))
# concat [nt, *shp2]
shp3 = torch.Size([nt,*shp2])
x = x.reshape(shp3)
return x
#The reverse network should output a mean image and a covariance image
#given an input image
class mndiff_rev(nn.Module):
def __init__(self):
super(mndiff_rev,self).__init__()
self.nchan = 10
self.L1a = torch.nn.Conv2d(1,self.nchan,3,padding=1)
self.L1b = torch.nn.Conv2d(self.nchan,self.nchan,3)
self.L1c = torch.nn.MaxPool2d(kernel_size=2)
self.L2a = torch.nn.Conv2d(self.nchan,self.nchan,3)
self.L2b = torch.nn.Conv2d(self.nchan,self.nchan,3)
self.L3a = torch.nn.ConvTranspose2d(self.nchan,self.nchan,3)
self.L3b = torch.nn.ConvTranspose2d(self.nchan,self.nchan,3)
#self.L4a = torch.nn.Upsample(scale_factor=2)
self.L4b = torch.nn.ConvTranspose2d(2*self.nchan,self.nchan,3)
self.L4c = torch.nn.ConvTranspose2d(self.nchan,1,3,padding=1)
self.nl = nn.Tanh()
self.device = torch.device("cpu")
self.losstime = []
self.lossrecord = []
return
def forward(self,x,t):
#handle data-shape cases
#case 0: x (pix_x, pix_y), t is scalar
#case 1: x (ntimes, pix_x, pix_y), t is (ntimes)
#case 2: x (nbatch, ntimes, pix_x, pix_y), t is (nbatch, ntimes)
if(len(x.shape)==2):
case = 0
t = torch.as_tensor(t).to(self.device)
t = t.float()
x = x.reshape([1,1,28,28])
ntimes = 1
nbatch = 1
elif(len(x.shape)==3):
case = 1
ntimes = x.shape[0]
nbatch = 1
x = x.reshape([ntimes,1,28,28])
elif(len(x.shape)==4):
case = 2
nbatch = x.shape[0]
ntimes = x.shape[1]
#x = x.reshape([nbatch,ntimes,28*28])
x = x.reshape([nbatch*ntimes,1,28,28])
t = t.reshape([nbatch*ntimes])
else:
print("Error: Expect input to be dimension 2,3,4")
raise Exception("Error: Expect input to be dimension 2,3,4")
#block 1
s1 = self.L1a(x)
s1 = self.nl(s1)
s1 = self.L1b(s1) #skip connection
s1 = self.nl(s1)
s2 = self.L1c(s1)
#block 2
s2 = self.L2a(s2)
s2 = self.nl(s2)
s2 = self.L2b(s2)
s2 = self.nl(s2)
#block 3
s2 = self.L3a(s2)
s2 = self.nl(s2)
s2 = self.L3b(s2)
s2 = self.nl(s2)
#block 4
#s2 = self.L4a(s2)
s2 = torch.nn.functional.interpolate(s2,scale_factor=[2,2])
#concat skip connection
s3 = torch.cat([s1,s2],axis=1)
s3 = self.L4b(s3)
s3 = self.nl(s3)
s3 = self.L4c(s3)
#output
eps = s3
if(case==0):
eps = eps.reshape([28,28])
elif(case==1):
eps = eps.reshape([ntimes,28,28])
elif(case==2):
eps = eps.reshape([nbatch,ntimes,28,28])
return eps
def to(self,dev):
self.device = torch.device(dev)
ret = super(mndiff_rev,self).to(self.device)
return ret
def save(self,fname):
try:
dev = self.device
self.to("cpu")
q = [self.state_dict(), self.losstime, self.lossrecord]
torch.save(q,fname)
self.to(dev)
except:
print("model save: problem saving {}".format(fname))
return
def load(self,fname):
try:
[modelsd,losstime,lossrecord] = torch.load(fname)
self.load_state_dict(modelsd)
self.losstime = losstime
self.lossrecord = lossrecord
except:
print("model load: problem loading {}".format(fname))
return
def plothistory(self):
x = self.losstime
y = self.lossrecord
plt.figure()
plt.plot(x,y,'k.',markersize=1)
plt.title('Loss History')
plt.show()
return
def recordloss(self,loss):
t = 0
L = len(self.losstime)
if(L>0):
t = self.losstime[L-1]
self.losstime.append(t+1)
self.lossrecord.append(loss)
return
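#A minimal sketch (not the infer() routine used elsewhere in this repo) of how an
#eps-predicting network like mndiff_rev can drive a DDPM-style reverse step
#(Ho et al. 2020).  It assumes a schedule dict with per-step variances "B_t" and
#cumulative products "a_t" as produced by training_parameters(), and that the
#model takes the normalized time t/nts.
def ddpm_reverse_step_sketch(model, x_t, t_index, lparas):
	tn = t_index/float(lparas["nts"])
	beta = lparas["B_t"][t_index]
	abar = lparas["a_t"][t_index]
	eps = model(x_t, tn)
	#posterior mean estimate; noise is added on every step except the last (t=0)
	mean = (x_t - beta/torch.sqrt(1.0-abar)*eps)/torch.sqrt(1.0-beta)
	if(t_index>0):
		return mean + torch.sqrt(beta)*torch.randn_like(x_t)
	return mean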
###########
## Tests ##
###########
def test_gaussian_timeselect():
gts = gaussian_timeselect(5,0.025)
gts = gts.to("cuda")
x = torch.randn((8,8,5)).to("cuda")
#t = torch.tensor([0,0.5,1]).to("cuda")
t = torch.linspace(0,1,8).to("cuda")
y = gts(x,0)
y2 = gts(x,t)
y3 = gts(x,0.111)
print(y)
print(x[:,:,0])
print(x.shape)
print(y.shape)
print(y2.shape)
print(torch.abs(y2[0,:]-y[0,:])>1E-6)
return
def test_mndiff_rev():
mnd = mndiff_rev()
x0 = torch.randn(28,28)
t0 = 0
x1 = torch.randn(5,28,28)
t1 = torch.linspace(0,1,5)
x2 = torch.randn(8,5,28,28)
t2 = torch.linspace(0,1,5)
t2 = t2.reshape([1,5]).expand([8,5])
y0 = mnd(x0,t0)
y1 = mnd(x1,t1)
y2 = mnd(x2,t2)
print(y0.shape)
print(y1.shape)
print(y2.shape)
return
if(__name__=="__main__"):
#test_gaussian_timeselect()
test_mndiff_rev()
pass

@ -0,0 +1,78 @@
#!/usr/bin/python3
import os,sys,math
def get_savenumber(fname):
num = 0
fn2 = os.path.split(fname)[1]
fn2 = os.path.splitext(fn2)[0]
L = len(fn2)
L2 = L
for I in range(L-1,0,-1):
c = fn2[I]
if(not c.isnumeric()):
L2 = I+1
break
nm = fn2[L2:L]
try:
num = int(nm)
except:
num = 0
return num
def sorttwopythonlists(list1, list2):
zipped_pairs = zip(list2, list1)
z = [x for _, x in sorted(zipped_pairs)]
z2 = sorted(list2)
return [z,z2]
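#Example: sorttwopythonlists(['a','b','c'],[3,1,2]) returns [['b','c','a'],[1,2,3]]
#(list1 reordered by ascending list2, plus the sorted copy of list2).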
def get_current_save(directory,saveseries):
fname = None
fl = os.listdir(directory)
fl2 = []
nums = []
for f in fl:
if(f.find(saveseries)>=0):
fl2.append(os.path.join(directory,f))
for f in fl2:
n = get_savenumber(f)
nums.append(n)
[fl2,nums] = sorttwopythonlists(fl2,nums)
if(len(fl2)>0):
fname = fl2[len(fl2)-1]
else:
fname = None
return fname
def get_next_save(directory,saveseries):
fname = None
fncurrent = get_current_save(directory,saveseries)
if(fncurrent is None):
fname = os.path.join(directory,"{}{:02d}.pyt".format(saveseries,0))
else:
N = get_savenumber(fncurrent)
N = N + 1
fname = os.path.join(directory,"{}{:02d}.pyt".format(saveseries,N))
return fname
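#Example of the naming convention these helpers manage (for the "test_3a" series
#used by the training script): with an empty ./saves directory,
#get_next_save("./saves","test_3a") returns "./saves/test_3a00.pyt"; once
#test_3a00.pyt .. test_3a02.pyt exist, get_current_save() returns the
#highest-numbered file (test_3a02.pyt) and get_next_save() returns
#"./saves/test_3a03.pyt".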
if(__name__=="__main__"):
#print(get_savenumber("./saves/helloworld05202.py"))
pass

@ -0,0 +1,161 @@
#!/usr/bin/python3
import os,sys,math
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import mnist_dataloader
from mnistdiffusion_model import mndiff_rev
################
## Dataloader ##
################
def mnist_load():
[labels_train, labels_test, images_train, images_test] = mnist_dataloader.mnist_load()
#60000 x 28 x 28
#60000 x 1
#Rescale data to (-1,1)
images_train = images_train.clone().detach().float()
images_test = images_test.clone().detach().float()
images_train = 2.0*(images_train/255.0)-1.0
images_test = 2.0*(images_test/255.0)-1.0
images_train.requires_grad = False
images_test.requires_grad = False
return [images_train, images_test]
################
## Operations ##
################
def training_parameters():
lparas = dict()
nts = 200
t = np.arange(0,nts)
s = 0.1
f_t = np.cos((t/nts+s)/(1+s)*np.pi/2.0)**2
a_t = f_t/f_t[0]
B_t = np.zeros(t.shape[0])
B_t[0:nts-1] = np.clip(1-a_t[1:nts]/a_t[0:nts-1],0,0.999)
B_t[nts-1] = B_t[nts-2]
lparas["nts"] = nts
lparas["t"] = torch.tensor(t,dtype=torch.float32)
lparas["a_t"] = torch.tensor(a_t,dtype=torch.float32)
lparas["B_t"] = torch.tensor(B_t,dtype=torch.float32)
	#Training parameters for a cosine variance schedule
	#ref: Nichol and Dhariwal 2021
return lparas
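#In the notation of Nichol and Dhariwal 2021, the block above implements the
#cosine variance schedule:
#	f(t)   = cos^2(((t/T + s)/(1 + s))*pi/2),  s = 0.1, T = nts
#	a_t    = f(t)/f(0)             (alpha_bar_t, the cumulative signal fraction)
#	B_t[i] = 1 - a_t[i+1]/a_t[i]   (the per-step beta, clipped to [0, 0.999])
#so a_t decays smoothly from 1 toward 0 and the betas stay small for early timesteps.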
def generate_minibatch(imgs, Ndraw, lparas):
#I can probably speed this up with compiled CUDA code
nts = lparas["nts"]
dev = imgs.device
imbatch = torch.zeros([Ndraw*nts,28,28],dtype=torch.float32).to(dev)
draws = torch.randint(0,imgs.shape[0],(Ndraw,))
for I in range(0,Ndraw):
imgseq = torch.zeros([nts,28,28],dtype=torch.float32).to(dev)
imgseq[0,:,:] = imgs[draws[I],:,:]
beta = torch.kron(lparas["B_t"],torch.ones(28,28)).to(dev)
beta = beta.reshape([nts,28,28])
sig = torch.sqrt(beta)
noise = torch.randn((nts,28,28)).to(dev)*sig
for J in range(1,nts):
imgseq[J,:,:] = imgseq[J-1,:,:] + noise[J-1,:,:]
I1 = I*nts
I2 = (I+1)*nts
imbatch[I1:I2,:,:] = imgseq[:,:,:]
return imbatch
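#A vectorized alternative (a sketch, not the routine used above): for the standard
#DDPM forward process the noised image at step t can be drawn in closed form,
#	x_t = sqrt(a_t)*x_0 + sqrt(1 - a_t)*eps,   eps ~ N(0,I)
#(Ho et al. 2020), which removes the per-step Python loop and is one way to get the
#speedup hinted at in the comment above.  Note the closed form also rescales the
#signal by sqrt(a_t); the loop above only accumulates the beta-variance noise.
def generate_minibatch_closedform(imgs, Ndraw, lparas):
	nts = lparas["nts"]
	dev = imgs.device
	draws = torch.randint(0,imgs.shape[0],(Ndraw,))
	x0 = imgs[draws,:,:].reshape([Ndraw,1,28,28])
	abar = lparas["a_t"].to(dev).reshape([1,nts,1,1])
	eps = torch.randn([Ndraw,nts,28,28]).to(dev)
	xt = torch.sqrt(abar)*x0 + torch.sqrt(1.0-abar)*eps
	return xt.reshape([Ndraw*nts,28,28])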
def train_batch(imgbatch,lr,lparas):
return
def train():
return
def infer():
return
def test1():
[images_train, images_test] = mnist_load()
lparas = training_parameters()
#plt.figure()
#plt.plot(lparas["t"],lparas["a_t"])
#plt.show()
#plt.figure()
#plt.plot(lparas["t"],lparas["B_t"])
#plt.show()
images_train = images_train.to("cuda")
mb = generate_minibatch(images_train,50,lparas)
img0 = torch.squeeze(mb[0,:,:]).clone().detach().to("cpu")
img1 = torch.squeeze(mb[1,:,:]).clone().detach().to("cpu")
img2 = torch.squeeze(mb[2,:,:]).clone().detach().to("cpu")
img3 = torch.squeeze(mb[30,:,:]).clone().detach().to("cpu")
plt.figure()
plt.subplot(2,2,1)
plt.imshow(img0,cmap='gray')
plt.subplot(2,2,2)
plt.imshow(img1,cmap='gray')
plt.subplot(2,2,3)
plt.imshow(img2,cmap='gray')
plt.subplot(2,2,4)
plt.imshow(img3,cmap='gray')
plt.show()
return
##########
## Main ##
##########
if(__name__=="__main__"):
#test1()
	from mnistdiffusion_model import test_gaussian_timeselect
	test_gaussian_timeselect()

@ -0,0 +1,85 @@
{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9-final"
},
"orig_nbformat": 2,
"kernelspec": {
"name": "python3",
"display_name": "Python 3.6.9 64-bit",
"metadata": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
}
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os,sys,math\n",
"import numpy as np\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"import mnist_dataloader\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"torch.Size([])\ntorch.Size([])\n[1, 2, 3]\n"
]
}
],
"source": [
"a = torch.randn(10,14)\n",
"b = a.shape[1:1]\n",
"print(b)\n",
"b.numel()\n",
"\n",
"print(b)\n",
"\n",
"b = torch.Size([1])\n",
"c = torch.Size([2,3])\n",
"d = torch.Size(torch.cat([torch.tensor(b),torch.tensor(c)]))\n",
"\n",
"d = [*b,*c]\n",
"\n",
"print(d)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
]
}

Binary file not shown.

(image preview not shown; size 21 KiB)

Binary file not shown.

@ -0,0 +1,138 @@
#!/usr/bin/python3
import sys
sys.path.append('/home/aschinde/workspace/projects_python/library')
import os,sys,math
import numpy as np
import cv2;
import gzip #May need to use gzip.open instead of open
import struct
#struct unpack allows some interpretation of python binary data
#Example
##import struct
##
##data = open("from_fortran.bin", "rb").read()
##
##(eight, N) = struct.unpack("@II", data)
##
##This unpacks the first two fields, assuming they start at the very
##beginning of the file (no padding or extraneous data), and also assuming
##native byte-order (the @ symbol). The Is in the formatting string mean
##"unsigned integer, 32 bits".
#for integers
#a = int
#a.from_bytes(b'\xaf\xc2R',byteorder='little')
#a.to_bytes(nbytes,byteorder='big')
#an analogous operation doesn't seem to exist for floats
#what about numpy?
#https://www.devdungeon.com/content/working-binary-data-python
#print("{:02d}".format(2))
#b = b.fromhex('010203040506')
#b.hex()
#c = b.decode(encoding='utf-8' or 'latin-1' or 'ascii'...)
#print(c)
#numpy arrays have tobytes
#numpy arrays have frombuffer (converts to dtypes)
#
#q = np.array([15],dtype=np.uint8);
#q.tobytes();
#q.tobytes(order='C') (options are 'C' and 'F')
#q2 = np.buffer(q.tobytes(),dtype=np.uint8)
#np.frombuffer(buffer,dtype=float,count=-1,offset=0)
##You could also use the < and > endianess format codes in the struct
##module to achieve the same result:
##
##>>> struct.pack('<2h', *struct.unpack('>2h', original))
##'\xde\xad\xc0\xde'
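#Quick check of the big-endian unpack used below: the MNIST label file begins with
#the magic number 0x00000801 = 2049, so
#	struct.unpack('>I', b'\x00\x00\x08\x01')[0] == 2049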
def bytereverse(bts):
## bts2 = bytes(len(bts));
## for I in range(0,len(bts)):
## bts2[len(bts)-I-1] = bts[I];
N = len(bts);
## print(N);
## print(bts);
## bts2 = struct.pack('<{}h'.format(N), *struct.unpack('>{}h'.format(N), bts))
bts2 = bts;
return bts2;
#Read Labels
def read_MNIST_label_file(fname):
#fp = gzip.open('./train-labels-idx1-ubyte.gz','rb');
fp = gzip.open(fname,'rb');
magic = fp.read(4);
	#nitems = np.frombuffer(fp.read(4),dtype=np.int32)[0]; #some sort of endianness problem
bts = fp.read(4);
#bts = bytereverse(bts);
#nitems = np.frombuffer(bts,dtype=np.int32);
	nitems = np.int32(struct.unpack('>I',bts)[0]); #it was a non-native endianness in the integer encoding
#> < @ - endianness
bts = fp.read(nitems);
N = len(bts);
labels = np.zeros((N),dtype=np.uint8);
labels = np.frombuffer(bts,dtype=np.uint8,count=N);
#for i in range(0,10):
# bt = fp.read(1);
# labels[i] = np.frombuffer(bt,dtype=np.uint8);
fp.close();
return labels;
def read_MNIST_image_file(fname):
fp = gzip.open(fname,'rb');
magic = fp.read(4);
bts = fp.read(4);
nitems = np.int32(struct.unpack('>I',bts)[0]);
bts = fp.read(4);
nrows = np.int32(struct.unpack('>I',bts)[0]);
bts = fp.read(4);
ncols = np.int32(struct.unpack('>I',bts)[0]);
images = np.zeros((nitems,nrows,ncols),dtype=np.uint8);
for I in range(0,nitems):
bts = fp.read(nrows*ncols);
img1 = np.frombuffer(bts,dtype=np.uint8,count=nrows*ncols);
img1 = img1.reshape((nrows,ncols));
images[I,:,:] = img1;
fp.close();
return images;
def read_training_data():
rootdir = '/home/aschinde/workspace/machinelearning/datasets/MNIST';
fname1 = 'train-labels-idx1-ubyte.gz';
fname2 = 'train-images-idx3-ubyte.gz';
labels = read_MNIST_label_file(os.path.join(rootdir,fname1));
images = read_MNIST_image_file(os.path.join(rootdir,fname2));
return [labels,images];
def read_test_data():
rootdir = '/home/aschinde/workspace/machinelearning/datasets/MNIST';
fname1 = 't10k-labels-idx1-ubyte.gz';
fname2 = 't10k-images-idx3-ubyte.gz';
labels = read_MNIST_label_file(os.path.join(rootdir,fname1));
images = read_MNIST_image_file(os.path.join(rootdir,fname2));
return [labels,images];
def show_MNIST_image(img):
import matplotlib.pyplot as plt;
plt.figure();
plt.imshow(255-img,cmap='gray');
plt.show();
return;

@ -0,0 +1,92 @@
#!/usr/bin/python3
"""
mnist_loader
~~~~~~~~~~~~
A library to load the MNIST image data. For details of the data
structures that are returned, see the doc strings for ``load_data``
and ``load_data_wrapper``. In practice, ``load_data_wrapper`` is the
function usually called by our neural network code.
"""
##sigh: If you want it to run today, write it in Python.
##If you want it to run tomorrow, write it in ANYTHING ELSE
#### Libraries
# Standard library
##import cPickle
import pickle as cPickle
import gzip
# Third-party libraries
import numpy as np
def load_data():
"""Return the MNIST data as a tuple containing the training data,
the validation data, and the test data.
The ``training_data`` is returned as a tuple with two entries.
The first entry contains the actual training images. This is a
numpy ndarray with 50,000 entries. Each entry is, in turn, a
numpy ndarray with 784 values, representing the 28 * 28 = 784
pixels in a single MNIST image.
The second entry in the ``training_data`` tuple is a numpy ndarray
containing 50,000 entries. Those entries are just the digit
values (0...9) for the corresponding images contained in the first
entry of the tuple.
The ``validation_data`` and ``test_data`` are similar, except
each contains only 10,000 images.
This is a nice data format, but for use in neural networks it's
helpful to modify the format of the ``training_data`` a little.
That's done in the wrapper function ``load_data_wrapper()``, see
below.
"""
	#f = gzip.open('../data/mnist.pkl.gz', 'rb')
	#NOTE: cPickle.load below expects the pickled mnist.pkl.gz archive; the raw
	#IDX file opened here is not a pickle, so point this back at mnist.pkl.gz
	#for load_data() to work.
	f = gzip.open('./t10k-images-idx3-ubyte.gz','rb');
training_data, validation_data, test_data = cPickle.load(f)
f.close()
return (training_data, validation_data, test_data)
def load_data_wrapper():
"""Return a tuple containing ``(training_data, validation_data,
test_data)``. Based on ``load_data``, but the format is more
convenient for use in our implementation of neural networks.
In particular, ``training_data`` is a list containing 50,000
2-tuples ``(x, y)``. ``x`` is a 784-dimensional numpy.ndarray
containing the input image. ``y`` is a 10-dimensional
numpy.ndarray representing the unit vector corresponding to the
correct digit for ``x``.
``validation_data`` and ``test_data`` are lists containing 10,000
2-tuples ``(x, y)``. In each case, ``x`` is a 784-dimensional
    numpy.ndarray containing the input image, and ``y`` is the
corresponding classification, i.e., the digit values (integers)
corresponding to ``x``.
Obviously, this means we're using slightly different formats for
the training data and the validation / test data. These formats
turn out to be the most convenient for use in our neural network
code."""
tr_d, va_d, te_d = load_data()
training_inputs = [np.reshape(x, (784, 1)) for x in tr_d[0]]
training_results = [vectorized_result(y) for y in tr_d[1]]
training_data = zip(training_inputs, training_results)
validation_inputs = [np.reshape(x, (784, 1)) for x in va_d[0]]
validation_data = zip(validation_inputs, va_d[1])
test_inputs = [np.reshape(x, (784, 1)) for x in te_d[0]]
test_data = zip(test_inputs, te_d[1])
return (training_data, validation_data, test_data)
def vectorized_result(j):
"""Return a 10-dimensional unit vector with a 1.0 in the jth
position and zeroes elsewhere. This is used to convert a digit
(0...9) into a corresponding desired output from the neural
network."""
e = np.zeros((10, 1))
e[j] = 1.0
return e
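#Example usage (assuming load_data() is pointed back at the pickled mnist.pkl.gz;
#note that in Python 3 zip() returns iterators, so wrap in list() to reuse or index):
#	training_data, validation_data, test_data = load_data_wrapper()
#	x, y = next(iter(training_data))   #x: (784,1) image vector, y: (10,1) one-hot label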