Making some of my models public.
commit eb2c910e29
@@ -0,0 +1,2 @@
Copyright 2023, Aaron M. Schinder
Released under the MIT/BSD License
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,91 @@
#!/usr/bin/python3
import os,sys,math
import numpy as np
import cv2
import gzip #need to use gzip.open instead of open
import struct

import torch

def read_MNIST_label_file(fname):
    fp = gzip.open(fname,'rb')
    magic = fp.read(4)
    bts = fp.read(4)
    #the count field uses a non-native (big-endian) integer encoding, so a
    #plain np.frombuffer with native int32 misreads it
    #struct format prefixes: '>' big-endian, '<' little-endian, '@' native
    nitems = np.int32(struct.unpack('>I',bts)[0])

    bts = fp.read(nitems)
    N = len(bts)
    labels = np.frombuffer(bts,dtype=np.uint8,count=N)
    fp.close()
    return labels
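#IDX file layout (per the standard MNIST distribution):
#  label files: magic (4 bytes), item count (4 bytes, big-endian), then one uint8 per label
#  image files: magic, count, rows, cols (4 bytes each, big-endian), then rows*cols uint8 per image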
def read_MNIST_image_file(fname):
    fp = gzip.open(fname,'rb')
    magic = fp.read(4)
    bts = fp.read(4)
    nitems = np.int32(struct.unpack('>I',bts)[0])
    bts = fp.read(4)
    nrows = np.int32(struct.unpack('>I',bts)[0])
    bts = fp.read(4)
    ncols = np.int32(struct.unpack('>I',bts)[0])

    images = np.zeros((nitems,nrows,ncols),dtype=np.uint8)
    for I in range(0,nitems):
        bts = fp.read(nrows*ncols)
        img1 = np.frombuffer(bts,dtype=np.uint8,count=nrows*ncols)
        img1 = img1.reshape((nrows,ncols))
        images[I,:,:] = img1

    fp.close()

    return images

#The mnist dataset is small enough to fit entirely in memory
def mnist_load():
    baseloc = "../training_data"

    traindatafile = "train-images-idx3-ubyte.gz"
    trainlabelfile = "train-labels-idx1-ubyte.gz"
    testdatafile = "t10k-images-idx3-ubyte.gz"
    testlabelfile = "t10k-labels-idx1-ubyte.gz"

    traindatafile = os.path.join(baseloc,traindatafile)
    trainlabelfile = os.path.join(baseloc,trainlabelfile)
    testdatafile = os.path.join(baseloc,testdatafile)
    testlabelfile = os.path.join(baseloc,testlabelfile)

    labels_train = read_MNIST_label_file(trainlabelfile)
    labels_test = read_MNIST_label_file(testlabelfile)
    images_train = read_MNIST_image_file(traindatafile)
    images_test = read_MNIST_image_file(testdatafile)

    labels_train = torch.tensor(labels_train,dtype=torch.float32,requires_grad=False)
    labels_test = torch.tensor(labels_test,dtype=torch.float32,requires_grad=False)
    images_train = torch.tensor(images_train,dtype=torch.float32,requires_grad=False)
    images_test = torch.tensor(images_test,dtype=torch.float32,requires_grad=False)

    return [labels_train, labels_test, images_train, images_test]

if(__name__ == "__main__"):
    [labels_train, labels_test, images_train, images_test] = mnist_load()
    print("Loaded MNIST Data")
    print(labels_train.shape)
    print(labels_test.shape)
    print(images_train.shape)
    print(images_test.shape)
@@ -0,0 +1,401 @@
#!/usr/bin/python3

import os,sys,math
import numpy as np

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

import mnist_dataloader

from mnistdiffusion_model import mndiff_rev
from mnistdiffusion_utils import *

import time

################
## Dataloader ##
################

def mnist_load():
    [labels_train, labels_test, images_train, images_test] = mnist_dataloader.mnist_load()

    #images: 60000 x 28 x 28, labels: 60000 x 1

    #Rescale data to (-1,1)
    images_train = images_train.clone().detach().float()
    images_test = images_test.clone().detach().float()

    images_train = 2.0*(images_train/255.0)-1.0
    images_test = 2.0*(images_test/255.0)-1.0

    images_train.requires_grad = False
    images_test.requires_grad = False

    return [images_train, images_test]

################
## Operations ##
################

# Rather than killing myself figuring out variable variance schedules,
# just use a constant variance schedule to start with

def training_parameters():
    lparas = dict()

    nts = 200
    Beta = 0.0050

    lr = 1E-4
    batchsize = 80
    save_every = 200
    record_every = 25

    B_t = np.ones((nts))*Beta
    a_t = np.zeros((nts))
    a_t[0] = 1
    for I in range(1,nts):
        a_t[I] = a_t[I-1]*(1-B_t[I-1])

    t = np.linspace(0,1,nts) #normalized time (t/T)

    lparas["nts"] = nts
    lparas["Beta"] = Beta
    lparas["t"] = torch.tensor(t,dtype=torch.float32)
    lparas["a_t"] = torch.tensor(a_t,dtype=torch.float32)
    lparas["B_t"] = torch.tensor(B_t,dtype=torch.float32)
    lparas["lr"] = lr
    lparas["batchsize"] = batchsize
    lparas["save_every"] = save_every
    lparas["record_every"] = record_every
    return lparas
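# Note on the schedule above: a_t is the cumulative signal retention (the
# alpha-bar of the DDPM literature), a_t[I] = prod_{s<I} (1 - B_t[s]). With
# the constant Beta = 0.005 and nts = 200, a_t[-1] is roughly
# 0.995**200 ~ 0.37, so this forward chain does not fully destroy the image.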
def generate_minibatch(imgs, Ndraw, lparas):
    #I can probably speed this up with compiled CUDA code

    nts = lparas["nts"]
    Beta = lparas["Beta"]

    dev = imgs.device

    imbatch = torch.zeros([Ndraw,nts,28,28],dtype=torch.float32).to(dev)
    draws = torch.randint(0,imgs.shape[0],(Ndraw,))

    times = lparas["t"].to(dev)
    times = times.reshape([1,nts]).expand([Ndraw,nts])

    for I in range(0,Ndraw):
        imgseq = torch.zeros([nts,28,28],dtype=torch.float32).to(dev)
        imgseq[0,:,:] = imgs[draws[I],:,:]

        beta = torch.kron(lparas["B_t"],torch.ones(28,28)).to(dev)
        beta = beta.reshape([nts,28,28])

        sig = torch.sqrt(beta)
        noise = torch.randn((nts,28,28)).to(dev)*sig

        for J in range(1,nts):
            imgseq[J,:,:] = imgseq[J-1,:,:]*torch.sqrt(1.0-beta[J-1,:,:]) + noise[J-1,:,:]

        imbatch[I,:,:,:] = imgseq[:,:,:]

    return [imbatch,times]
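# The loop above simulates the forward (noising) process step by step:
#   x_J = sqrt(1 - beta_{J-1}) * x_{J-1} + sqrt(beta_{J-1}) * eps,  eps ~ N(0,I)
# train_batch below then fits the network so that beta*revnet(x,t) matches the
# one-step increments x_J - x_{J-1}; dividing the reported loss by beta**2
# presumably removes that scaling so runs with different Beta are comparable.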
def train_batch(revnet,batchimg,batchtime,lparas):
    lr = lparas["lr"]
    nts = batchtime.shape[1]

    dev = revnet.device
    beta = float(lparas["Beta"])

    #note: the optimizer (and its moment estimates) is recreated every batch
    optim = torch.optim.Adam(revnet.parameters(),lr=lr)

    loss = torch.nn.MSELoss().to(dev)

    optim.zero_grad() #clear accumulated gradients before the backward pass

    d_est = revnet(batchimg,batchtime)
    d_est = d_est[:,1:nts,:,:]
    targ = batchimg[:,1:nts,:,:]-batchimg[:,0:nts-1,:,:]

    L2 = loss(targ,beta*d_est)

    L2.backward()

    optim.step()
    L = L2.detach().cpu().numpy()
    L = L/beta**2

    return L

def train(directory,saveseries):
    print("Loading dataset...")
    [images_train, images_test] = mnist_load()
    lparas = training_parameters()

    batchsize = lparas["batchsize"]
    save_every = lparas["save_every"]
    record_every = lparas["record_every"]
    device = torch.device("cuda")

    revnet = mndiff_rev()
    fname = get_current_save(directory,saveseries)
    if(fname is None):
        fname = get_next_save(directory,saveseries)
    else:
        print("Loading {}".format(fname))
        revnet.load(fname)

    images_train = images_train.to(device)
    revnet = revnet.to(device)

    ## Training Loop
    I = 0
    J = 0
    while(True):
        if(I%record_every==0):
            print("Minibatch {}".format(I))
            #New minibatch of loss records
            losses = np.zeros((record_every))

        #draw minibatch
        [mb,mbt] = generate_minibatch(images_train,batchsize,lparas)

        #train minibatch
        L = train_batch(revnet,mb,mbt,lparas)
        losses[J] = L
        J = J + 1

        #record results
        if(J%record_every==0):
            J = 0
            loss = np.mean(losses)
            revnet.recordloss(loss)
            print("\tloss={:1.5f}".format(loss))

        #save
        if(I%save_every==0):
            fnext = get_next_save(directory,saveseries)
            revnet.save(fnext)

        I = I + 1

    return
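# For reference, the sampler below uses a heuristic reverse update,
#   x_{t-1} = sqrt(1-beta)*x_t - beta*eps_hat + 0.5*beta*z,
# rather than the standard DDPM posterior step (Ho et al. 2020),
#   x_{t-1} = (x_t - beta/sqrt(1-abar_t)*eps_hat)/sqrt(1-beta) + sqrt(beta)*z,
# and it runs the chain for 15x more steps than the schedule was built with.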
def infer(directory,saveseries):
    ttt = int(time.time())
    np.random.seed(ttt%50000)
    torch.manual_seed(ttt%50000)

    lparas = training_parameters()
    revnet = mndiff_rev()
    fname = get_current_save(directory,saveseries)
    if(fname is None):
        print("Warning, no state loaded.")
    else:
        print("Loading {}".format(fname))
        revnet.load(fname)

    nts = lparas["nts"]
    beta = lparas["Beta"]

    nts = nts*15

    imgseq = torch.zeros([nts,28,28])
    imgseq[nts-1,:,:] = torch.randn(28,28)

    tl = torch.linspace(0,1,nts)

    bb = torch.as_tensor(beta)

    for I in range(nts-2,-1,-1):
        img = imgseq[I+1,:,:]
        t = tl[I]
        eps = revnet(img,t)

        img2 = img*torch.sqrt(1-bb) - bb*eps + 0.5*bb*torch.randn(28,28)

        imgseq[I,:,:] = img2

    im0 = torch.squeeze(imgseq[0,:,:])
    im1 = torch.squeeze(imgseq[int(nts*0.25),:,:])
    im2 = torch.squeeze(imgseq[int(nts*0.75),:,:])
    im3 = torch.squeeze(imgseq[int(nts*0.90),:,:])

    im0 = im0.detach().numpy()
    im1 = im1.detach().numpy()
    im2 = im2.detach().numpy()
    im3 = im3.detach().numpy()

    plt.figure()
    plt.subplot(2,2,1)
    plt.imshow(im3,cmap='gray')
    plt.subplot(2,2,2)
    plt.imshow(im2,cmap='gray')
    plt.subplot(2,2,3)
    plt.imshow(im1,cmap='gray')
    plt.subplot(2,2,4)
    plt.imshow(im0,cmap='gray')
    plt.show()

    return

def plot_history(directory, saveseries):
    revnet = mndiff_rev()
    fname = get_current_save(directory,saveseries)
    if(fname is None):
        fname = get_next_save(directory,saveseries)
    else:
        print("Loading {}".format(fname))
        revnet.load(fname)

    revnet.plothistory()

    return

def test1():
    [images_train, images_test] = mnist_load()
    lparas = training_parameters()

    images_train = images_train.to("cuda")

    [mb,mbtime] = generate_minibatch(images_train,50,lparas)

    img0 = torch.squeeze(mb[0,0,:,:]).clone().detach().to("cpu")
    img1 = torch.squeeze(mb[0,1,:,:]).clone().detach().to("cpu")
    img2 = torch.squeeze(mb[0,2,:,:]).clone().detach().to("cpu")
    img3 = torch.squeeze(mb[0,99,:,:]).clone().detach().to("cpu")

    print(mb.shape)
    print(mbtime.shape)

    plt.figure()
    plt.subplot(2,2,1)
    plt.imshow(img0,cmap='gray')
    plt.subplot(2,2,2)
    plt.imshow(img1,cmap='gray')
    plt.subplot(2,2,3)
    plt.imshow(img2,cmap='gray')
    plt.subplot(2,2,4)
    plt.imshow(img3,cmap='gray')

    plt.show()

    return

def test2(directory, saveseries):
    lparas = training_parameters()
    nts = lparas["nts"]
    beta = lparas["Beta"]

    revnet = mndiff_rev()
    fname = get_current_save(directory,saveseries)
    if(fname is None):
        fname = get_next_save(directory,saveseries)
    else:
        print("Loading {}".format(fname))
        revnet.load(fname)

    [images_train, images_test] = mnist_load()

    x = torch.squeeze(images_train[100,:,:])
    y = revnet(x,0.5)

    x = x.detach().numpy()
    y = y.detach().numpy()

    print("y mn={:1.3f} std={:1.3f} std2*beta={}".format(np.mean(y),np.std(y),np.std(y)**2*beta))

    plt.figure()
    plt.subplot(2,1,1)
    plt.imshow(x,cmap='gray')
    plt.subplot(2,1,2)
    plt.imshow(y,cmap='gray')
    plt.show()

    return

##########
## Main ##
##########

def mainswitch():
    args = sys.argv
    if(len(args)<2):
        print("mnistdiffusion [operation]")
        print("operations:")
        print("\ttrain")
        print("\tplot_history")
        print("\tinfer")
        exit(0)

    if(sys.argv[1]=="train"):
        train("./saves","test_f")
    if(sys.argv[1]=="plot_history"):
        plot_history("./saves","test_f")
    if(sys.argv[1]=="infer"):
        infer("./saves","test_f")

if(__name__=="__main__"):
    ttt = int(time.time())
    np.random.seed(ttt%50000)
    torch.manual_seed(ttt%50000)

    #test1()
    #test2("./saves","test_d")
    mainswitch()
@@ -0,0 +1,289 @@
#!/usr/bin/python3

import os,sys,math
import numpy as np

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

import mnist_dataloader

## Sohl-Dickstein 2015 describes a gaussian "bump" function to
## implement learnable time-dependence
#
#network input vector and normalized time (tn in (0,1))
# input: (dimx, dimy, ..., features), (scalar) --> output: (dimx, dimy, ...)
# input: (ntimes, ..., features), (ntimes) --> output: (ntimes, ...)
class gaussian_timeselect(nn.Module):

    def __init__(self, nfeatures=5, width=0.25):
        super(gaussian_timeselect,self).__init__()

        tau_j = torch.linspace(0,1,nfeatures).float()
        self.tau_j = torch.nn.parameter.Parameter(tau_j, requires_grad=False)

        width = torch.tensor(width,dtype=torch.float32)
        self.width = torch.nn.parameter.Parameter(width,requires_grad=False)

    #network input vector and normalized time (tn in (0,1))
    # input: (dimx, dimy, ..., features), (scalar) --> output: (dimx, dimy, ...)
    # input: (ntimes, ..., features), (ntimes) --> output: (ntimes, ...)
    def forward(self, x, tn):
        ndim = len(x.shape)

        tn = torch.as_tensor(tn)
        tn.requires_grad = False

        if(len(tn.shape)==0):
            #case 0 - tn is a scalar
            g = torch.exp(-1.0/(2.0*self.width)*(tn-self.tau_j)**2)
            g = g/torch.clip(torch.sum(g),min=0.001)
            x = torch.tensordot(x,g,dims=[[ndim-1],[0]])
        else:
            #case 1 - tn is a tensor
            nf = self.tau_j.shape[0]
            nt = tn.numel()
            nt2 = x.shape[0]
            ndim = len(x.shape)
            nf2 = x.shape[ndim-1]

            if(nt!=nt2):
                raise Exception("gaussian_timeselect time mismatch")
            if(nf!=nf2):
                raise Exception("gaussian_timeselect feature mismatch")

            shp2 = x.shape[1:ndim-1]
            nshp2 = shp2.numel()

            x = x.reshape([nt,nshp2,nf])
            tau = self.tau_j.reshape([1,1,nf])
            tau = tau.expand([nt,nshp2,nf])
            tn = tn.reshape([nt,1,1])
            tn = tn.expand([nt,nshp2,nf])

            g = torch.exp(-1.0/(2.0*self.width)*(tn-tau)**2)
            gs = torch.clip(torch.sum(g,axis=2),min=0.0001)
            gs = gs.reshape([nt,nshp2,1]).expand([nt,nshp2,nf])
            g = g/gs

            x = torch.sum(x*g,axis=2)
            #concat to [nt, *shp2]
            shp3 = torch.Size([nt,*shp2])
            x = x.reshape(shp3)

        return x
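# In both branches the time weighting is a normalized radial basis:
#   g_j(t) = exp(-(t - tau_j)^2 / (2*width)) / sum_k g_k(t)
# with tau_j evenly spaced on [0,1], so the module blends the trailing feature
# channels according to how close t is to each tau_j (width here plays the
# role of the bump variance).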
#The reverse network should output a mean image and a covariance image
#given an input image

class mndiff_rev(nn.Module):

    def __init__(self):
        super(mndiff_rev,self).__init__()

        self.L1 = nn.Linear(28*28,1000,bias=True)
        self.L2 = nn.Linear(1000,1000,bias=True)
        self.L3 = nn.Linear(1000,28*28*5,bias=True)

        #a nn.Conv2d(10,10,kernel_size=1) with channels in the 2nd dimension
        #would be equivalent; nn.Linear wants channels in the last dimension
        self.L4a = nn.Linear(5,5)
        self.L4b = nn.Linear(5,5)

        self.L5 = gaussian_timeselect(5,0.1)

        self.tanh = nn.Tanh()

        self.device = torch.device("cpu")

        self.losstime = []
        self.lossrecord = []

        return
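    # L3 expands each image to 28*28*5 values: five candidate output images,
    # one per tau_j of the time-select module. L4a/L4b mix those five channels
    # pointwise, and L5 collapses them to a single 28x28 estimate using the
    # gaussian weights at the requested time.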
    def forward(self,x,t):
        #handle data-shape cases
        #case 0: x (pix_x, pix_y), t is scalar
        #case 1: x (ntimes, pix_x, pix_y), t is (ntimes)
        #case 2: x (nbatch, ntimes, pix_x, pix_y), t is (nbatch, ntimes)

        if(len(x.shape)==2):
            case = 0
            t = torch.as_tensor(t).to(self.device)
            t = t.float()
            x = x.reshape([1,28*28])
            ntimes = 1
            nbatch = 1
        elif(len(x.shape)==3):
            case = 1
            ntimes = x.shape[0]
            nbatch = 1
            x = x.reshape([ntimes,28*28])
        elif(len(x.shape)==4):
            case = 2
            nbatch = x.shape[0]
            ntimes = x.shape[1]
            x = x.reshape([nbatch*ntimes,28*28])
            t = t.reshape([nbatch*ntimes])
        else:
            raise Exception("Error: Expect input to be dimension 2,3,4")

        s1 = self.L1(x)
        s1 = self.tanh(s1)
        s1 = self.L2(s1)
        s1 = self.tanh(s1)
        s1 = self.L3(s1)
        s1 = self.tanh(s1)
        s1 = s1.reshape([ntimes*nbatch,28*28,5])
        s1 = self.L4a(s1)
        s1 = self.tanh(s1)
        s1 = self.L4b(s1)

        eps = self.L5(s1,t)

        if(case==0):
            eps = eps.reshape([28,28])
        elif(case==1):
            eps = eps.reshape([ntimes,28,28])
        elif(case==2):
            eps = eps.reshape([nbatch,ntimes,28,28])

        return eps

    def to(self,dev):
        self.device = torch.device(dev)
        ret = super(mndiff_rev,self).to(self.device)
        return ret

    def save(self,fname):
        try:
            dev = self.device
            self.to("cpu")
            q = [self.state_dict(), self.losstime, self.lossrecord]
            torch.save(q,fname)
            self.to(dev)
        except:
            print("model save: problem saving {}".format(fname))

        return

    def load(self,fname):
        try:
            [modelsd,losstime,lossrecord] = torch.load(fname)
            self.load_state_dict(modelsd)
            self.losstime = losstime
            self.lossrecord = lossrecord
        except:
            print("model load: problem loading {}".format(fname))

        return

    def plothistory(self):
        x = self.losstime
        y = self.lossrecord
        plt.figure()
        plt.plot(x,y,'k.',markersize=1)
        plt.title('Loss History')
        plt.show()

        return

    def recordloss(self,loss):
        t = 0
        L = len(self.losstime)
        if(L>0):
            t = self.losstime[L-1]
        self.losstime.append(t+1)
        self.lossrecord.append(loss)

        return

###########
## Tests ##
###########

def test_gaussian_timeselect():
    gts = gaussian_timeselect(5,0.025)
    gts = gts.to("cuda")

    x = torch.randn((8,8,5)).to("cuda")
    t = torch.linspace(0,1,8).to("cuda")

    y = gts(x,0)
    y2 = gts(x,t)
    y3 = gts(x,0.111)

    print(y)
    print(x[:,:,0])

    print(x.shape)
    print(y.shape)
    print(y2.shape)

    print(torch.abs(y2[0,:]-y[0,:])>1E-6)

    return

def test_mndiff_rev():
    mnd = mndiff_rev()

    x0 = torch.randn(28,28)
    t0 = 0
    x1 = torch.randn(5,28,28)
    t1 = torch.linspace(0,1,5)

    x2 = torch.randn(8,5,28,28)
    t2 = torch.linspace(0,1,5)
    t2 = t2.reshape([1,5]).expand([8,5])

    y0 = mnd(x0,t0)
    y1 = mnd(x1,t1)
    y2 = mnd(x2,t2)

    print(y0.shape)
    print(y1.shape)
    print(y2.shape)

    return

if(__name__=="__main__"):
    test_gaussian_timeselect()
    #test_mndiff_rev()

    pass
@@ -0,0 +1,78 @@
#!/usr/bin/python3

import os,sys,math
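# Checkpoint naming convention used below: files are "<series><NN>.pyt"
# (e.g. "test_f00.pyt", "test_f01.pyt"). get_savenumber pulls the trailing
# integer out of a filename, get_current_save returns the highest-numbered
# existing file in a series, and get_next_save returns the name after it.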
def get_savenumber(fname):
    num = 0

    fn2 = os.path.split(fname)[1]
    fn2 = os.path.splitext(fn2)[0]

    #scan backwards for the start of the trailing run of digits
    L = len(fn2)
    L2 = L
    for I in range(L-1,0,-1):
        c = fn2[I]
        if(not c.isnumeric()):
            L2 = I+1
            break
    nm = fn2[L2:L]

    try:
        num = int(nm)
    except ValueError:
        num = 0

    return num

def sorttwopythonlists(list1, list2):
    zipped_pairs = zip(list2, list1)

    z = [x for _, x in sorted(zipped_pairs)]
    z2 = sorted(list2)

    return [z,z2]

def get_current_save(directory,saveseries):
    fname = None

    fl = os.listdir(directory)
    fl2 = []
    nums = []

    for f in fl:
        if(f.find(saveseries)>=0):
            fl2.append(os.path.join(directory,f))

    for f in fl2:
        n = get_savenumber(f)
        nums.append(n)

    [fl2,nums] = sorttwopythonlists(fl2,nums)

    if(len(fl2)>0):
        fname = fl2[len(fl2)-1]
    else:
        fname = None

    return fname

def get_next_save(directory,saveseries):
    fname = None
    fncurrent = get_current_save(directory,saveseries)
    if(fncurrent is None):
        fname = os.path.join(directory,"{}{:02d}.pyt".format(saveseries,0))
    else:
        N = get_savenumber(fncurrent)
        N = N + 1
        fname = os.path.join(directory,"{}{:02d}.pyt".format(saveseries,N))

    return fname

if(__name__=="__main__"):
    #print(get_savenumber("./saves/helloworld05202.py"))
    pass
@@ -0,0 +1,161 @@
#!/usr/bin/python3

import os,sys,math
import numpy as np

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

import mnist_dataloader

from mnistdiffusion_model import mndiff_rev

################
## Dataloader ##
################

def mnist_load():
    [labels_train, labels_test, images_train, images_test] = mnist_dataloader.mnist_load()

    #images: 60000 x 28 x 28, labels: 60000 x 1

    #Rescale data to (-1,1)
    images_train = images_train.clone().detach().float()
    images_test = images_test.clone().detach().float()

    images_train = 2.0*(images_train/255.0)-1.0
    images_test = 2.0*(images_test/255.0)-1.0

    images_train.requires_grad = False
    images_test.requires_grad = False

    return [images_train, images_test]

################
## Operations ##
################

def training_parameters():
    #Training parameters for a cosine variance schedule
    #ref: Nichol and Dhariwal 2021
    lparas = dict()

    nts = 200
    t = np.arange(0,nts)
    s = 0.1
    f_t = np.cos((t/nts+s)/(1+s)*np.pi/2.0)**2
    a_t = f_t/f_t[0]
    B_t = np.zeros(t.shape[0])
    B_t[0:nts-1] = np.clip(1-a_t[1:nts]/a_t[0:nts-1],0,0.999)
    B_t[nts-1] = B_t[nts-2]

    lparas["nts"] = nts
    lparas["t"] = torch.tensor(t,dtype=torch.float32)
    lparas["a_t"] = torch.tensor(a_t,dtype=torch.float32)
    lparas["B_t"] = torch.tensor(B_t,dtype=torch.float32)

    return lparas
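# The schedule above, in the notation of Nichol & Dhariwal (2021):
#   f(t) = cos^2( ((t/T + s)/(1 + s)) * pi/2 ),   abar_t = f(t)/f(0),
#   beta_t = 1 - abar_t/abar_{t-1}, clipped to [0, 0.999].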
def generate_minibatch(imgs, Ndraw, lparas):
    #I can probably speed this up with compiled CUDA code

    nts = lparas["nts"]
    dev = imgs.device

    imbatch = torch.zeros([Ndraw*nts,28,28],dtype=torch.float32).to(dev)
    draws = torch.randint(0,imgs.shape[0],(Ndraw,))

    for I in range(0,Ndraw):
        imgseq = torch.zeros([nts,28,28],dtype=torch.float32).to(dev)
        imgseq[0,:,:] = imgs[draws[I],:,:]

        beta = torch.kron(lparas["B_t"],torch.ones(28,28)).to(dev)
        beta = beta.reshape([nts,28,28])

        sig = torch.sqrt(beta)
        noise = torch.randn((nts,28,28)).to(dev)*sig

        for J in range(1,nts):
            imgseq[J,:,:] = imgseq[J-1,:,:] + noise[J-1,:,:]

        I1 = I*nts
        I2 = (I+1)*nts
        imbatch[I1:I2,:,:] = imgseq[:,:,:]

    return imbatch
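# Note: this early draft accumulates noise without the sqrt(1-beta) signal
# attenuation, x_J = x_{J-1} + sqrt(beta_{J-1})*eps, so the chain is a plain
# random walk; the later mnistdiffusion script applies the attenuated update.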
def train_batch(imgbatch,lr,lparas):
    return

def train():
    return

def infer():
    return

def test1():
    [images_train, images_test] = mnist_load()
    lparas = training_parameters()

    images_train = images_train.to("cuda")

    mb = generate_minibatch(images_train,50,lparas)

    img0 = torch.squeeze(mb[0,:,:]).clone().detach().to("cpu")
    img1 = torch.squeeze(mb[1,:,:]).clone().detach().to("cpu")
    img2 = torch.squeeze(mb[2,:,:]).clone().detach().to("cpu")
    img3 = torch.squeeze(mb[30,:,:]).clone().detach().to("cpu")

    plt.figure()
    plt.subplot(2,2,1)
    plt.imshow(img0,cmap='gray')
    plt.subplot(2,2,2)
    plt.imshow(img1,cmap='gray')
    plt.subplot(2,2,3)
    plt.imshow(img2,cmap='gray')
    plt.subplot(2,2,4)
    plt.imshow(img3,cmap='gray')

    plt.show()

    return

##########
## Main ##
##########

if(__name__=="__main__"):
    #note: test_gaussian_timeselect is defined in mnistdiffusion_model, not in
    #this file, so it cannot be called from here; test1 is the local test
    test1()
@@ -0,0 +1,85 @@
{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.6.9 64-bit",
   "metadata": {
    "interpreter": {
     "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os,sys,math\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import mnist_dataloader\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "torch.Size([])\ntorch.Size([])\n[1, 2, 3]\n"
     ]
    }
   ],
   "source": [
    "a = torch.randn(10,14)\n",
    "b = a.shape[1:1]\n",
    "print(b)\n",
    "b.numel()\n",
    "\n",
    "print(b)\n",
    "\n",
    "b = torch.Size([1])\n",
    "c = torch.Size([2,3])\n",
    "d = torch.Size(torch.cat([torch.tensor(b),torch.tensor(c)]))\n",
    "\n",
    "d = [*b,*c]\n",
    "\n",
    "print(d)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ]
}
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,401 @@
#!/usr/bin/python3

import os,sys,math
import numpy as np

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

import mnist_dataloader

from mnistdiffusion_model import mndiff_rev
from mnistdiffusion_utils import *

import time

################
## Dataloader ##
################

def mnist_load():
    [labels_train, labels_test, images_train, images_test] = mnist_dataloader.mnist_load()

    #images: 60000 x 28 x 28, labels: 60000 x 1

    #Rescale data to (-1,1)
    images_train = images_train.clone().detach().float()
    images_test = images_test.clone().detach().float()

    images_train = 2.0*(images_train/255.0)-1.0
    images_test = 2.0*(images_test/255.0)-1.0

    images_train.requires_grad = False
    images_test.requires_grad = False

    return [images_train, images_test]

################
## Operations ##
################

# Rather than killing myself figuring out variable variance schedules,
# just use a constant variance schedule to start with

def training_parameters():
    lparas = dict()

    nts = 200
    Beta = 0.0050

    lr = 1E-4
    batchsize = 40
    save_every = 200
    record_every = 25

    B_t = np.ones((nts))*Beta
    a_t = np.zeros((nts))
    a_t[0] = 1
    for I in range(1,nts):
        a_t[I] = a_t[I-1]*(1-B_t[I-1])

    t = np.linspace(0,1,nts) #normalized time (t/T)

    lparas["nts"] = nts
    lparas["Beta"] = Beta
    lparas["t"] = torch.tensor(t,dtype=torch.float32)
    lparas["a_t"] = torch.tensor(a_t,dtype=torch.float32)
    lparas["B_t"] = torch.tensor(B_t,dtype=torch.float32)
    lparas["lr"] = lr
    lparas["batchsize"] = batchsize
    lparas["save_every"] = save_every
    lparas["record_every"] = record_every
    return lparas

def generate_minibatch(imgs, Ndraw, lparas):
    #I can probably speed this up with compiled CUDA code

    nts = lparas["nts"]
    Beta = lparas["Beta"]

    dev = imgs.device

    imbatch = torch.zeros([Ndraw,nts,28,28],dtype=torch.float32).to(dev)
    draws = torch.randint(0,imgs.shape[0],(Ndraw,))

    times = lparas["t"].to(dev)
    times = times.reshape([1,nts]).expand([Ndraw,nts])

    for I in range(0,Ndraw):
        imgseq = torch.zeros([nts,28,28],dtype=torch.float32).to(dev)
        imgseq[0,:,:] = imgs[draws[I],:,:]

        beta = torch.kron(lparas["B_t"],torch.ones(28,28)).to(dev)
        beta = beta.reshape([nts,28,28])

        sig = torch.sqrt(beta)
        noise = torch.randn((nts,28,28)).to(dev)*sig

        for J in range(1,nts):
            imgseq[J,:,:] = imgseq[J-1,:,:]*torch.sqrt(1.0-beta[J-1,:,:]) + noise[J-1,:,:]

        imbatch[I,:,:,:] = imgseq[:,:,:]

    return [imbatch,times]

def train_batch(revnet,batchimg,batchtime,lparas):
    lr = lparas["lr"]
    nts = batchtime.shape[1]

    dev = revnet.device
    beta = float(lparas["Beta"])

    #note: the optimizer (and its moment estimates) is recreated every batch
    optim = torch.optim.Adam(revnet.parameters(),lr=lr)

    loss = torch.nn.MSELoss().to(dev)

    optim.zero_grad() #clear accumulated gradients before the backward pass

    d_est = revnet(batchimg,batchtime)
    d_est = d_est[:,1:nts,:,:]
    targ = batchimg[:,1:nts,:,:]-batchimg[:,0:nts-1,:,:]

    L2 = loss(targ,beta*d_est)

    L2.backward()

    optim.step()
    L = L2.detach().cpu().numpy()
    L = L/beta**2

    return L

def train(directory,saveseries):
    print("Loading dataset...")
    [images_train, images_test] = mnist_load()
    lparas = training_parameters()

    batchsize = lparas["batchsize"]
    save_every = lparas["save_every"]
    record_every = lparas["record_every"]
    device = torch.device("cuda")

    revnet = mndiff_rev()
    fname = get_current_save(directory,saveseries)
    if(fname is None):
        fname = get_next_save(directory,saveseries)
    else:
        print("Loading {}".format(fname))
        revnet.load(fname)

    images_train = images_train.to(device)
    revnet = revnet.to(device)

    ## Training Loop
    I = 0
    J = 0
    while(True):
        if(I%record_every==0):
            print("Minibatch {}".format(I))
            #New minibatch of loss records
            losses = np.zeros((record_every))

        #draw minibatch
        [mb,mbt] = generate_minibatch(images_train,batchsize,lparas)

        #train minibatch
        L = train_batch(revnet,mb,mbt,lparas)
        losses[J] = L
        J = J + 1

        #record results
        if(J%record_every==0):
            J = 0
            loss = np.mean(losses)
            revnet.recordloss(loss)
            print("\tloss={:1.5f}".format(loss))

        #save
        if(I%save_every==0):
            fnext = get_next_save(directory,saveseries)
            revnet.save(fnext)

        I = I + 1

    return

def infer(directory,saveseries):
    ttt = int(time.time())
    np.random.seed(ttt%50000)
    torch.manual_seed(ttt%50000)

    lparas = training_parameters()
    revnet = mndiff_rev()
    fname = get_current_save(directory,saveseries)
    if(fname is None):
        print("Warning, no state loaded.")
    else:
        print("Loading {}".format(fname))
        revnet.load(fname)

    nts = lparas["nts"]
    beta = lparas["Beta"]

    nts = nts*15

    imgseq = torch.zeros([nts,28,28])
    imgseq[nts-1,:,:] = torch.randn(28,28)

    tl = torch.linspace(0,1,nts)

    bb = torch.as_tensor(beta)

    for I in range(nts-2,-1,-1):
        img = imgseq[I+1,:,:]
        t = tl[I]
        eps = revnet(img,t)

        img2 = img*torch.sqrt(1-bb) - bb*eps + 1.0*bb*torch.randn(28,28)

        imgseq[I,:,:] = img2

    im0 = torch.squeeze(imgseq[0,:,:])
    im1 = torch.squeeze(imgseq[int(nts*0.25),:,:])
    im2 = torch.squeeze(imgseq[int(nts*0.75),:,:])
    im3 = torch.squeeze(imgseq[int(nts*0.90),:,:])

    im0 = im0.detach().numpy()
    im1 = im1.detach().numpy()
    im2 = im2.detach().numpy()
    im3 = im3.detach().numpy()

    plt.figure()
    plt.subplot(2,2,1)
    plt.imshow(im3,cmap='gray')
    plt.subplot(2,2,2)
    plt.imshow(im2,cmap='gray')
    plt.subplot(2,2,3)
    plt.imshow(im1,cmap='gray')
    plt.subplot(2,2,4)
    plt.imshow(im0,cmap='gray')
    plt.show()

    return

def plot_history(directory, saveseries):
    revnet = mndiff_rev()
    fname = get_current_save(directory,saveseries)
    if(fname is None):
        fname = get_next_save(directory,saveseries)
    else:
        print("Loading {}".format(fname))
        revnet.load(fname)

    revnet.plothistory()

    return

def test1():
    [images_train, images_test] = mnist_load()
    lparas = training_parameters()

    images_train = images_train.to("cuda")

    [mb,mbtime] = generate_minibatch(images_train,50,lparas)

    img0 = torch.squeeze(mb[0,0,:,:]).clone().detach().to("cpu")
    img1 = torch.squeeze(mb[0,1,:,:]).clone().detach().to("cpu")
    img2 = torch.squeeze(mb[0,2,:,:]).clone().detach().to("cpu")
    img3 = torch.squeeze(mb[0,99,:,:]).clone().detach().to("cpu")

    print(mb.shape)
    print(mbtime.shape)

    plt.figure()
    plt.subplot(2,2,1)
    plt.imshow(img0,cmap='gray')
    plt.subplot(2,2,2)
    plt.imshow(img1,cmap='gray')
    plt.subplot(2,2,3)
    plt.imshow(img2,cmap='gray')
    plt.subplot(2,2,4)
    plt.imshow(img3,cmap='gray')

    plt.show()

    return

def test2(directory, saveseries):
    lparas = training_parameters()
    nts = lparas["nts"]
    beta = lparas["Beta"]

    revnet = mndiff_rev()
    fname = get_current_save(directory,saveseries)
    if(fname is None):
        fname = get_next_save(directory,saveseries)
    else:
        print("Loading {}".format(fname))
        revnet.load(fname)

    [images_train, images_test] = mnist_load()

    x = torch.squeeze(images_train[100,:,:])
    y = revnet(x,0.5)

    x = x.detach().numpy()
    y = y.detach().numpy()

    print("y mn={:1.3f} std={:1.3f} std2*beta={}".format(np.mean(y),np.std(y),np.std(y)**2*beta))

    plt.figure()
    plt.subplot(2,1,1)
    plt.imshow(x,cmap='gray')
    plt.subplot(2,1,2)
    plt.imshow(y,cmap='gray')
    plt.show()

    return

##########
## Main ##
##########

def mainswitch():
    args = sys.argv
    if(len(args)<2):
        print("mnistdiffusion [operation]")
        print("operations:")
        print("\ttrain")
        print("\tplot_history")
        print("\tinfer")
        exit(0)

    if(sys.argv[1]=="train"):
        train("./saves","test_2a")
    if(sys.argv[1]=="plot_history"):
        plot_history("./saves","test_2a")
    if(sys.argv[1]=="infer"):
        infer("./saves","test_2a")

if(__name__=="__main__"):
    ttt = int(time.time())
    np.random.seed(ttt%50000)
    torch.manual_seed(ttt%50000)

    #test1()
    #test2("./saves","test_2a")
    mainswitch()
@@ -0,0 +1,298 @@
#!/usr/bin/python3

import os,sys,math
import numpy as np

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

import mnist_dataloader

## Sohl-Dickstein 2015 describes a gaussian "bump" function to
## implement learnable time-dependence
#
#network input vector and normalized time (tn in (0,1))
# input: (dimx, dimy, ..., features), (scalar) --> output: (dimx, dimy, ...)
# input: (ntimes, ..., features), (ntimes) --> output: (ntimes, ...)
class gaussian_timeselect(nn.Module):

    def __init__(self, nfeatures=5, width=0.25):
        super(gaussian_timeselect,self).__init__()

        tau_j = torch.linspace(0,1,nfeatures).float()
        self.tau_j = torch.nn.parameter.Parameter(tau_j, requires_grad=False)

        width = torch.tensor(width,dtype=torch.float32)
        self.width = torch.nn.parameter.Parameter(width,requires_grad=False)

    #network input vector and normalized time (tn in (0,1))
    # input: (dimx, dimy, ..., features), (scalar) --> output: (dimx, dimy, ...)
    # input: (ntimes, ..., features), (ntimes) --> output: (ntimes, ...)
    def forward(self, x, tn):
        ndim = len(x.shape)

        tn = torch.as_tensor(tn)
        tn.requires_grad = False

        if(len(tn.shape)==0):
            #case 0 - tn is a scalar
            g = torch.exp(-1.0/(2.0*self.width)*(tn-self.tau_j)**2)
            g = g/torch.clip(torch.sum(g),min=0.001)
            x = torch.tensordot(x,g,dims=[[ndim-1],[0]])
        else:
            #case 1 - tn is a tensor
            nf = self.tau_j.shape[0]
            nt = tn.numel()
            nt2 = x.shape[0]
            ndim = len(x.shape)
            nf2 = x.shape[ndim-1]

            if(nt!=nt2):
                raise Exception("gaussian_timeselect time mismatch")
            if(nf!=nf2):
                raise Exception("gaussian_timeselect feature mismatch")

            shp2 = x.shape[1:ndim-1]
            nshp2 = shp2.numel()

            x = x.reshape([nt,nshp2,nf])
            tau = self.tau_j.reshape([1,1,nf])
            tau = tau.expand([nt,nshp2,nf])
            tn = tn.reshape([nt,1,1])
            tn = tn.expand([nt,nshp2,nf])

            g = torch.exp(-1.0/(2.0*self.width)*(tn-tau)**2)
            gs = torch.clip(torch.sum(g,axis=2),min=0.0001)
            gs = gs.reshape([nt,nshp2,1]).expand([nt,nshp2,nf])
            g = g/gs

            x = torch.sum(x*g,axis=2)
            #concat to [nt, *shp2]
            shp3 = torch.Size([nt,*shp2])
            x = x.reshape(shp3)

        return x

#The reverse network should output a mean image and a covariance image
#given an input image

class mndiff_rev(nn.Module):

    def __init__(self):
        super(mndiff_rev,self).__init__()

        self.L1 = torch.nn.Conv2d(1,10,3)
        self.L2 = torch.nn.Conv2d(10,10,5)
        self.L3 = torch.nn.ConvTranspose2d(10,10,5)
        self.L4 = torch.nn.ConvTranspose2d(10,10,3)

        self.L5 = gaussian_timeselect(10,0.1)

        self.tanh = nn.Tanh()

        self.device = torch.device("cpu")

        self.losstime = []
        self.lossrecord = []

        return
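    # Shape bookkeeping for the conv stack (28x28 input, unpadded kernels):
    #   Conv2d(1,10,3):           28 -> 26
    #   Conv2d(10,10,5):          26 -> 22
    #   ConvTranspose2d(10,10,5): 22 -> 26
    #   ConvTranspose2d(10,10,3): 26 -> 28
    # after which the 10 channels are moved to the last axis and collapsed to
    # one 28x28 image by gaussian_timeselect(10,0.1) at the requested time.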
    def forward(self,x,t):
        #handle data-shape cases
        #case 0: x (pix_x, pix_y), t is scalar
        #case 1: x (ntimes, pix_x, pix_y), t is (ntimes)
        #case 2: x (nbatch, ntimes, pix_x, pix_y), t is (nbatch, ntimes)

        if(len(x.shape)==2):
            case = 0
            t = torch.as_tensor(t).to(self.device)
            t = t.float()
            x = x.reshape([1,1,28,28])
            ntimes = 1
            nbatch = 1
        elif(len(x.shape)==3):
            case = 1
            ntimes = x.shape[0]
            nbatch = 1
            x = x.reshape([ntimes,1,28,28])
        elif(len(x.shape)==4):
            case = 2
            nbatch = x.shape[0]
            ntimes = x.shape[1]
            x = x.reshape([nbatch*ntimes,1,28,28])
            t = t.reshape([nbatch*ntimes])
        else:
            raise Exception("Error: Expect input to be dimension 2,3,4")

        s1 = self.L1(x)
        s1 = self.tanh(s1)
        s1 = self.L2(s1)
        s1 = self.tanh(s1)
        s1 = self.L3(s1)
        s1 = self.tanh(s1)
        s1 = self.L4(s1)
        s1 = self.tanh(s1)
        s1 = s1.permute([0,2,3,1]) #move channels to the last axis for L5
        eps = self.L5(s1,t)

        if(case==0):
            eps = eps.reshape([28,28])
        elif(case==1):
            eps = eps.reshape([ntimes,28,28])
        elif(case==2):
            eps = eps.reshape([nbatch,ntimes,28,28])

        return eps

    def to(self,dev):
        self.device = torch.device(dev)
        ret = super(mndiff_rev,self).to(self.device)
        return ret

    def save(self,fname):
        try:
            dev = self.device
            self.to("cpu")
            q = [self.state_dict(), self.losstime, self.lossrecord]
            torch.save(q,fname)
            self.to(dev)
        except:
            print("model save: problem saving {}".format(fname))

        return

    def load(self,fname):
        try:
            [modelsd,losstime,lossrecord] = torch.load(fname)
            self.load_state_dict(modelsd)
            self.losstime = losstime
            self.lossrecord = lossrecord
        except:
            print("model load: problem loading {}".format(fname))

        return

    def plothistory(self):
        x = self.losstime
        y = self.lossrecord
        plt.figure()
        plt.plot(x,y,'k.',markersize=1)
        plt.title('Loss History')
        plt.show()

        return

    def recordloss(self,loss):
        t = 0
        L = len(self.losstime)
        if(L>0):
            t = self.losstime[L-1]
        self.losstime.append(t+1)
        self.lossrecord.append(loss)

        return

###########
## Tests ##
###########

def test_gaussian_timeselect():
    gts = gaussian_timeselect(5,0.025)
    gts = gts.to("cuda")

    x = torch.randn((8,8,5)).to("cuda")
    t = torch.linspace(0,1,8).to("cuda")

    y = gts(x,0)
    y2 = gts(x,t)
    y3 = gts(x,0.111)

    print(y)
    print(x[:,:,0])

    print(x.shape)
    print(y.shape)
    print(y2.shape)

    print(torch.abs(y2[0,:]-y[0,:])>1E-6)

    return

def test_mndiff_rev():
    mnd = mndiff_rev()

    x0 = torch.randn(28,28)
    t0 = 0
    x1 = torch.randn(5,28,28)
    t1 = torch.linspace(0,1,5)

    x2 = torch.randn(8,5,28,28)
    t2 = torch.linspace(0,1,5)
    t2 = t2.reshape([1,5]).expand([8,5])

    y0 = mnd(x0,t0)
    y1 = mnd(x1,t1)
    y2 = mnd(x2,t2)

    print(y0.shape)
    print(y1.shape)
    print(y2.shape)

    return

if(__name__=="__main__"):
    #test_gaussian_timeselect()
    test_mndiff_rev()

    pass
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,85 @@
{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.6.9 64-bit",
   "metadata": {
    "interpreter": {
     "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os,sys,math\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import mnist_dataloader\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "torch.Size([])\ntorch.Size([])\n[1, 2, 3]\n"
     ]
    }
   ],
   "source": [
    "a = torch.randn(10,14)\n",
    "b = a.shape[1:1]\n",
    "print(b)\n",
    "b.numel()\n",
    "\n",
    "print(b)\n",
    "\n",
    "b = torch.Size([1])\n",
    "c = torch.Size([2,3])\n",
    "d = torch.Size(torch.cat([torch.tensor(b),torch.tensor(c)]))\n",
    "\n",
    "d = [*b,*c]\n",
    "\n",
    "print(d)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ]
}
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,91 @@
#!/usr/bin/python3
import os,sys,math
import numpy as np
import cv2
import gzip #need to use gzip.open instead of open
import struct

import torch

def read_MNIST_label_file(fname):
    #fp = gzip.open('./train-labels-idx1-ubyte.gz','rb');
    fp = gzip.open(fname,'rb');
    magic = fp.read(4);
    #nitems = np.frombuffer(fp.read(4),dtype=np.int32)[0]; #some sort of endianness problem
    bts = fp.read(4);
    #bts = bytereverse(bts);
    #nitems = np.frombuffer(bts,dtype=np.int32);
    nitems = np.int32(struct.unpack('>I',bts)[0]); #it was a non-native endianness in the integer encoding
    #> < @ - endianness

    bts = fp.read(nitems);
    N = len(bts);
    labels = np.zeros((N),dtype=np.uint8);
    labels = np.frombuffer(bts,dtype=np.uint8,count=N);
    #for i in range(0,10):
    #    bt = fp.read(1);
    #    labels[i] = np.frombuffer(bt,dtype=np.uint8);
    fp.close();
    return labels;

def read_MNIST_image_file(fname):
    fp = gzip.open(fname,'rb');
    magic = fp.read(4);
    bts = fp.read(4);
    nitems = np.int32(struct.unpack('>I',bts)[0]);
    bts = fp.read(4);
    nrows = np.int32(struct.unpack('>I',bts)[0]);
    bts = fp.read(4);
    ncols = np.int32(struct.unpack('>I',bts)[0]);

    images = np.zeros((nitems,nrows,ncols),dtype=np.uint8);
    for I in range(0,nitems):
        bts = fp.read(nrows*ncols);
        img1 = np.frombuffer(bts,dtype=np.uint8,count=nrows*ncols);
        img1 = img1.reshape((nrows,ncols));
        images[I,:,:] = img1;

    fp.close();

    return images;

#The mnist dataset is small enough to fit entirely in memory
def mnist_load():
    baseloc = "../training_data"

    traindatafile = "train-images-idx3-ubyte.gz"
    trainlabelfile = "train-labels-idx1-ubyte.gz"
    testdatafile = "t10k-images-idx3-ubyte.gz"
    testlabelfile = "t10k-labels-idx1-ubyte.gz"

    traindatafile = os.path.join(baseloc,traindatafile)
    trainlabelfile = os.path.join(baseloc,trainlabelfile)
    testdatafile = os.path.join(baseloc,testdatafile)
    testlabelfile = os.path.join(baseloc,testlabelfile)

    labels_train = read_MNIST_label_file(trainlabelfile)
    labels_test = read_MNIST_label_file(testlabelfile)
    images_train = read_MNIST_image_file(traindatafile)
    images_test = read_MNIST_image_file(testdatafile)

    labels_train = torch.tensor(labels_train,dtype=torch.float32,requires_grad=False)
    labels_test = torch.tensor(labels_test,dtype=torch.float32,requires_grad=False)
    images_train = torch.tensor(images_train,dtype=torch.float32,requires_grad=False)
    images_test = torch.tensor(images_test,dtype=torch.float32,requires_grad=False)

    # #debug
    # print(labels_train.shape)
    # print(labels_test.shape)
    # print(images_train.shape)
    # print(images_test.shape)

    return [labels_train, labels_test, images_train, images_test]

if(__name__ == "__main__"):
    [labels_train, labels_test, images_train, images_test] = mnist_load()
    print("Loaded MNIST Data")
    print(labels_train.shape)
    print(labels_test.shape)
    print(images_train.shape)
    print(images_test.shape)
@ -0,0 +1,398 @@
#!/usr/bin/python3

import os,sys,math
import numpy as np

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

import mnist_dataloader

from mnistdiffusion_model import mndiff_rev
from mnistdiffusion_utils import *

import time

################
## Dataloader ##
################

def mnist_load():

    [labels_train, labels_test, images_train, images_test] = mnist_dataloader.mnist_load()

    #60000 x 28 x 28
    #60000 x 1

    #Rescale data to (-1,1)
    images_train = images_train.clone().detach().float()
    images_test = images_test.clone().detach().float()

    images_train = 2.0*(images_train/255.0)-1.0
    images_test = 2.0*(images_test/255.0)-1.0

    images_train.requires_grad = False
    images_test.requires_grad = False

    return [images_train, images_test]

################
## Operations ##
################

# Rather than killing myself figuring out variable variance schedules,
# just use a constant variance schedule to start with

def training_parameters():
    lparas = dict()

    nts = 200
    Beta = 0.0050

    lr = 1E-4
    batchsize = 30
    save_every = 200
    record_every = 25

    B_t = np.ones((nts))*Beta
    a_t = np.zeros((nts))
    a_t[0] = 1
    for I in range(1,nts):
        a_t[I] = a_t[I-1]*(1-B_t[I-1])

    t = np.linspace(0,1,nts) #normalized time (t/T)

    lparas["nts"] = nts
    lparas["Beta"] = Beta
    lparas["t"] = torch.tensor(t,dtype=torch.float32)
    lparas["a_t"] = torch.tensor(a_t,dtype=torch.float32)
    lparas["B_t"] = torch.tensor(B_t,dtype=torch.float32)
    lparas["lr"] = lr
    lparas["batchsize"] = batchsize
    lparas["save_every"] = save_every
    lparas["record_every"] = record_every
    return lparas
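
# With a constant schedule the cumulative product has a closed form,
# a_t = (1-Beta)**t, so the loop above could be replaced by (a sketch,
# using the same names as in training_parameters):
#
#   a_t = (1.0-Beta)**np.arange(nts)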


def generate_minibatch(imgs, Ndraw, lparas):

    #I can probably speed this up with compiled CUDA code

    nts = lparas["nts"]
    Beta = lparas["Beta"]

    dev = imgs.device

    imbatch = torch.zeros([Ndraw,nts,28,28],dtype=torch.float32).to(dev)
    draws = torch.randint(0,imgs.shape[0],(Ndraw,))

    times = lparas["t"].to(dev)
    times = times.reshape([1,nts]).expand([Ndraw,nts])

    for I in range(0,Ndraw):
        imgseq = torch.zeros([nts,28,28],dtype=torch.float32).to(dev)
        imgseq[0,:,:] = imgs[draws[I],:,:]

        beta = torch.kron(lparas["B_t"],torch.ones(28,28)).to(dev)
        beta = beta.reshape([nts,28,28])

        sig = torch.sqrt(beta)
        noise = torch.randn((nts,28,28)).to(dev)*sig

        for J in range(1,nts):
            imgseq[J,:,:] = imgseq[J-1,:,:]*torch.sqrt(1.0-beta[J-1,:,:]) + noise[J-1,:,:]

        imbatch[I,:,:,:] = imgseq[:,:,:]

    return [imbatch,times]
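
# A sketch of an equivalent draw that avoids the O(nts) inner loop: for the
# forward process above, x_t given x_0 is Gaussian with mean sqrt(a_t)*x_0
# and variance (1-a_t), where a_t = prod(1-B_t) (the "a_t" already stored in
# lparas). Assuming the same 28x28 shapes:
#
#   a = lparas["a_t"].reshape([nts,1,1]).to(dev)
#   x0 = imgs[draws[I],:,:].reshape([1,28,28])
#   imgseq = torch.sqrt(a)*x0 + torch.sqrt(1.0-a)*torch.randn((nts,28,28)).to(dev)
#
# Note this samples each timestep independently rather than as one joint
# trajectory, which is all that per-step DDPM training requires.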

def train_batch(revnet,batchimg,batchtime,lparas):

    L = 0.0
    lr = lparas["lr"]
    nts = batchtime.shape[1]

    dev = revnet.device
    #beta = torch.as_tensor(lparas["Beta"]).float().to(dev)
    beta = float(lparas["Beta"])

    #optim = torch.optim.SGD(revnet.parameters(),lr=lr)
    optim = torch.optim.Adam(revnet.parameters(),lr=lr)

    loss = torch.nn.MSELoss().to(dev)

    optim.zero_grad() #clear accumulated gradients before the backward pass
    d_est = revnet(batchimg,batchtime)
    d_est = d_est[:,1:nts,:,:]
    targ = batchimg[:,1:nts,:,:]-batchimg[:,0:nts-1,:,:]

    L2 = loss(targ,beta*d_est)

    L2.backward()

    optim.step()
    L = L2.detach().cpu().numpy()
    L = L/beta**2

    return L
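
# Design note (a sketch, not the author's code): constructing a fresh Adam
# optimizer on every batch discards its running moment estimates. Hoisting
# construction into train() and passing it in would preserve them:
#
#   optim = torch.optim.Adam(revnet.parameters(),lr=lparas["lr"])
#   #then, per batch (hypothetical signature):
#   L = train_batch(revnet,optim,mb,mbt,lparas)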


def train(directory,saveseries):

    print("Loading dataset...")
    [images_train, images_test] = mnist_load()
    lparas = training_parameters()

    batchsize = lparas["batchsize"]
    save_every = lparas["save_every"]
    record_every = lparas["record_every"]
    device = torch.device("cuda")

    revnet = mndiff_rev()
    fname = get_current_save(directory,saveseries)
    if(fname is None):
        fname = get_next_save(directory,saveseries)
    else:
        print("Loading {}".format(fname))
        revnet.load(fname)

    images_train = images_train.to(device)
    revnet = revnet.to(device)

    ## Training Loop
    I = 0
    J = 0
    while(True):

        if(I%record_every==0):
            print("Minibatch {}".format(I))

            #New minibatch record window
            losses = np.zeros((record_every))

        #draw minibatch
        [mb,mbt] = generate_minibatch(images_train,batchsize,lparas)

        #train minibatch
        L = train_batch(revnet,mb,mbt,lparas)
        losses[J] = L
        J = J + 1

        #record results
        if(J%record_every==0):
            J = 0
            loss = np.mean(losses)
            revnet.recordloss(loss)
            print("\tloss={:1.5f}".format(loss))

        #save
        if(I%save_every==0):
            fnext = get_next_save(directory,saveseries)
            revnet.save(fnext)

        I = I + 1

    return

def infer(directory,saveseries):

    ttt = int(time.time())
    np.random.seed(ttt%50000)
    torch.manual_seed(ttt%50000)

    lparas = training_parameters()
    revnet = mndiff_rev()
    fname = get_current_save(directory,saveseries)
    if(fname is None):
        print("Warning, no state loaded.")
    else:
        print("Loading {}".format(fname))
        revnet.load(fname)

    nts = lparas["nts"]
    beta = lparas["Beta"]

    nts = nts*15

    imgseq = torch.zeros([nts,28,28])
    imgseq[nts-1,:,:] = torch.randn(28,28)

    #imgseq[nts-1,:,17:20] = 3.0

    tl = torch.linspace(0,1,nts)

    bb = torch.as_tensor(beta)

    for I in range(nts-2,-1,-1):
        img = imgseq[I+1,:,:]
        t = tl[I]
        eps = revnet(img,t)

        img2 = img*torch.sqrt(1-bb) - bb*eps + 1.0*bb*torch.randn(28,28)

        imgseq[I,:,:] = img2

    im0 = torch.squeeze(imgseq[0,:,:])
    im1 = torch.squeeze(imgseq[int(nts*0.25),:,:])
    im2 = torch.squeeze(imgseq[int(nts*0.75),:,:])
    im3 = torch.squeeze(imgseq[int(nts*0.90),:,:])

    im0 = im0.detach().numpy()
    im1 = im1.detach().numpy()
    im2 = im2.detach().numpy()
    im3 = im3.detach().numpy()

    plt.figure()
    plt.subplot(2,2,1)
    plt.imshow(im3,cmap='gray')
    plt.subplot(2,2,2)
    plt.imshow(im2,cmap='gray')
    plt.subplot(2,2,3)
    plt.imshow(im1,cmap='gray')
    plt.subplot(2,2,4)
    plt.imshow(im0,cmap='gray')
    plt.show()

    return
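
# For comparison, the standard DDPM ancestral update (Ho et al. 2020) for a
# constant schedule would be, in the same variables (abar_I here is the
# assumed cumulative product of (1-beta) up to step I, not in the code above):
#
#   img2 = (img - bb/torch.sqrt(1-abar_I)*eps)/torch.sqrt(1-bb) \
#          + torch.sqrt(bb)*torch.randn(28,28)
#
# The loop above instead multiplies by sqrt(1-beta) and injects noise scaled
# by beta rather than sqrt(beta).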


def plot_history(directory, saveseries):

    revnet = mndiff_rev()
    fname = get_current_save(directory,saveseries)
    if(fname is None):
        fname = get_next_save(directory,saveseries)
    else:
        print("Loading {}".format(fname))
        revnet.load(fname)

    revnet.plothistory()

    return

def test1():

    [images_train, images_test] = mnist_load()
    lparas = training_parameters()

    #plt.figure()
    #plt.plot(lparas["t"],lparas["a_t"])
    #plt.show()

    #plt.figure()
    #plt.plot(lparas["t"],lparas["B_t"])
    #plt.show()

    images_train = images_train.to("cuda")

    [mb,mbtime] = generate_minibatch(images_train,50,lparas)

    img0 = torch.squeeze(mb[0,0,:,:]).clone().detach().to("cpu")
    img1 = torch.squeeze(mb[0,1,:,:]).clone().detach().to("cpu")
    img2 = torch.squeeze(mb[0,2,:,:]).clone().detach().to("cpu")
    img3 = torch.squeeze(mb[0,99,:,:]).clone().detach().to("cpu")

    print(mb.shape)
    print(mbtime.shape)

    plt.figure()
    plt.subplot(2,2,1)
    plt.imshow(img0,cmap='gray')
    plt.subplot(2,2,2)
    plt.imshow(img1,cmap='gray')
    plt.subplot(2,2,3)
    plt.imshow(img2,cmap='gray')
    plt.subplot(2,2,4)
    plt.imshow(img3,cmap='gray')

    plt.show()

    return

def test2(directory, saveseries):

    lparas = training_parameters()
    nts = lparas["nts"]
    beta = lparas["Beta"]

    revnet = mndiff_rev()
    fname = get_current_save(directory,saveseries)
    if(fname is None):
        fname = get_next_save(directory,saveseries)
    else:
        print("Loading {}".format(fname))
        revnet.load(fname)

    [images_train, images_test] = mnist_load()

    x = torch.squeeze(images_train[100,:,:])
    #x = beta*torch.randn(28,28)
    y = revnet(x,0.5)

    x = x.detach().numpy()
    y = y.detach().numpy()

    print("y mn={:1.3f} std={:1.3f} std2*beta={}".format(np.mean(y),np.std(y),np.std(y)**2*beta))

    plt.figure()
    plt.subplot(2,1,1)
    plt.imshow(x,cmap='gray')
    plt.subplot(2,1,2)
    plt.imshow(y,cmap='gray')
    plt.show()

    return

##########
## Main ##
##########

def mainswitch():
    args = sys.argv
    if(len(args)<2):
        print("mnistdiffusion [operation]")
        print("operations:")
        print("\ttrain")
        print("\tplot_history")
        print("\tinfer")

        exit(0)

    if(sys.argv[1]=="train"):
        train("./saves","test_3a")
    if(sys.argv[1]=="plot_history"):
        plot_history("./saves","test_3a")
    if(sys.argv[1]=="infer"):
        infer("./saves","test_3a")


if(__name__=="__main__"):

    ttt = int(time.time())
    np.random.seed(ttt%50000)
    torch.manual_seed(ttt%50000)

    #test1()
    #test2("./saves","test_2a")
    mainswitch()

@ -0,0 +1,300 @@
#!/usr/bin/python3

import os,sys,math
import numpy as np

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

import mnist_dataloader


## Sohl-Dickstein 2015 describes a gaussian "bump" function to
# implement learnable time-dependence
#
#network input vector and normalized time (tn e (0,1))
# input: (dimx, dimy, ..., features), (scalar) --> output: (dimx, dimy, ...)
# input: (ntimes, ..., features), (ntimes) --> output: (ntimes , ...)
class gaussian_timeselect(nn.Module):

    def __init__(self, nfeatures=5, width = 0.25):
        super(gaussian_timeselect,self).__init__()

        tau_j = torch.linspace(0,1,nfeatures).float()
        self.tau_j = torch.nn.parameter.Parameter(tau_j, requires_grad=False)

        width = torch.tensor(width,dtype=torch.float32)
        self.width = torch.nn.parameter.Parameter(width,requires_grad = False)

    #network input vector and normalized time (tn e (0,1))
    # input: (dimx, dimy, ..., features), (scalar) --> output: (dimx, dimy, ...)
    # input: (ntimes, ..., features), (ntimes) --> output: (ntimes , ...)
    def forward(self, x, tn):

        ndim = len(x.shape)

        tn = torch.as_tensor(tn)
        tn.requires_grad = False

        if(len(tn.shape)==0):
            #case 0 - tn is a scalar
            g = torch.exp(-1.0/(2.0*self.width)*(tn-self.tau_j)**2)
            g = g/torch.clip(torch.sum(g),min=0.001)
            x = torch.tensordot(x,g,dims=[[ndim-1],[0]])
        else:
            #Case 1 - tn is a tensor
            nf = self.tau_j.shape[0]
            nt = tn.numel()
            nt2 = x.shape[0]
            ndim = len(x.shape)
            nf2 = x.shape[ndim-1]

            if(nt!=nt2):
                raise Exception("gaussian_timeselect time mismatch")
            if(nf!=nf2):
                raise Exception("gaussian_timeselect feature mismatch")

            shp2 = x.shape[1:ndim-1]
            nshp2 = shp2.numel()

            x = x.reshape([nt,nshp2,nf])
            tau = self.tau_j.reshape([1,1,nf])
            tau = tau.expand([nt,nshp2,nf])
            tn = tn.reshape([nt,1,1])
            tn = tn.expand([nt,nshp2,nf])

            g = torch.exp(-1.0/(2.0*self.width)*(tn-tau)**2)
            gs = torch.clip(torch.sum(g,axis=2),min=0.0001)
            gs = gs.reshape([nt,nshp2,1]).expand([nt,nshp2,nf])
            g = g/gs

            x = torch.sum(x*g,axis=2)
            #shp3 = torch.Size(torch.cat(torch.tensor([nt]),torch.tensor(shp2)))
            #concat [nt, *shp2]
            shp3 = torch.Size([nt,*shp2])
            x = x.reshape(shp3)

        return x
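
# A minimal usage sketch for the shape contract above (assumes the class as
# defined; not part of the original file):
#
#   gts = gaussian_timeselect(nfeatures=5, width=0.25)
#   x = torch.randn(3,28,28,5)     #(ntimes, dimx, dimy, features)
#   tn = torch.linspace(0,1,3)     #(ntimes,)
#   y = gts(x,tn)                  #-> shape (3,28,28)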

#The reverse network should output a mean image and a covariance image
#given an input image

class mndiff_rev(nn.Module):

    def __init__(self):
        super(mndiff_rev,self).__init__()

        self.nchan = 10

        self.L1a = torch.nn.Conv2d(1,self.nchan,3,padding=1)
        self.L1b = torch.nn.Conv2d(self.nchan,self.nchan,3)
        self.L1c = torch.nn.MaxPool2d(kernel_size=2)

        self.L2a = torch.nn.Conv2d(self.nchan,self.nchan,3)
        self.L2b = torch.nn.Conv2d(self.nchan,self.nchan,3)

        self.L3a = torch.nn.ConvTranspose2d(self.nchan,self.nchan,3)
        self.L3b = torch.nn.ConvTranspose2d(self.nchan,self.nchan,3)

        #self.L4a = torch.nn.Upsample(scale_factor=2)
        self.L4b = torch.nn.ConvTranspose2d(2*self.nchan,self.nchan,3)
        self.L4c = torch.nn.ConvTranspose2d(self.nchan,1,3,padding=1)

        self.nl = nn.Tanh()

        self.device = torch.device("cpu")

        self.losstime = []
        self.lossrecord = []

        return

    def forward(self,x,t):

        #handle data-shape cases
        #case 0: x (pix_x, pix_y), t is scalar
        #case 1: x (ntimes, pix_x, pix_y), t is (ntimes)
        #case 2: x (nbatch, ntimes, pix_x, pix_y), t is (nbatch, ntimes)

        if(len(x.shape)==2):
            case = 0
            t = torch.as_tensor(t).to(self.device)
            t = t.float()
            x = x.reshape([1,1,28,28])
            ntimes = 1
            nbatch = 1
        elif(len(x.shape)==3):
            case = 1
            ntimes = x.shape[0]
            nbatch = 1
            x = x.reshape([ntimes,1,28,28])
        elif(len(x.shape)==4):
            case = 2
            nbatch = x.shape[0]
            ntimes = x.shape[1]
            #x = x.reshape([nbatch,ntimes,28*28])
            x = x.reshape([nbatch*ntimes,1,28,28])
            t = t.reshape([nbatch*ntimes])
        else:
            print("Error: Expect input to be dimension 2,3,4")
            raise Exception("Error: Expect input to be dimension 2,3,4")

        #block 1
        s1 = self.L1a(x)
        s1 = self.nl(s1)
        s1 = self.L1b(s1) #skip connection
        s1 = self.nl(s1)
        s2 = self.L1c(s1)

        #block 2
        s2 = self.L2a(s2)
        s2 = self.nl(s2)
        s2 = self.L2b(s2)
        s2 = self.nl(s2)

        #block 3
        s2 = self.L3a(s2)
        s2 = self.nl(s2)
        s2 = self.L3b(s2)
        s2 = self.nl(s2)

        #block 4
        #s2 = self.L4a(s2)
        s2 = torch.nn.functional.interpolate(s2,scale_factor=[2,2])
        #concat skip connection
        s3 = torch.cat([s1,s2],axis=1)

        s3 = self.L4b(s3)
        s3 = self.nl(s3)
        s3 = self.L4c(s3)

        #output
        eps = s3

        if(case==0):
            eps = eps.reshape([28,28])
        elif(case==1):
            eps = eps.reshape([ntimes,28,28])
        elif(case==2):
            eps = eps.reshape([nbatch,ntimes,28,28])

        return eps

    def to(self,dev):
        self.device = torch.device(dev)
        ret = super(mndiff_rev,self).to(self.device)
        return ret

    def save(self,fname):
        try:
            dev = self.device
            self.to("cpu")
            q = [self.state_dict(), self.losstime, self.lossrecord]
            torch.save(q,fname)
            self.to(dev)
        except Exception as e:
            print("model save: problem saving {} ({})".format(fname,e))

        return

    def load(self,fname):
        try:
            [modelsd,losstime,lossrecord] = torch.load(fname)
            self.load_state_dict(modelsd)
            self.losstime = losstime
            self.lossrecord = lossrecord
        except Exception as e:
            print("model load: problem loading {} ({})".format(fname,e))

        return
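
    # Note (not in the original file): torch.load unpickles arbitrary
    # objects, so checkpoints saved as a plain list like this should only be
    # loaded from trusted files. A dict such as
    #   {"state_dict": ..., "losstime": ..., "lossrecord": ...}
    # is a common, slightly more self-documenting layout for the same data.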

    def plothistory(self):
        x = self.losstime
        y = self.lossrecord
        plt.figure()
        plt.plot(x,y,'k.',markersize=1)
        plt.title('Loss History')
        plt.show()

        return

    def recordloss(self,loss):

        t = 0
        L = len(self.losstime)
        if(L>0):
            t = self.losstime[L-1]
        self.losstime.append(t+1)
        self.lossrecord.append(loss)

        return


###########
## Tests ##
###########

def test_gaussian_timeselect():
    gts = gaussian_timeselect(5,0.025)
    gts = gts.to("cuda")

    x = torch.randn((8,8,5)).to("cuda")
    #t = torch.tensor([0,0.5,1]).to("cuda")
    t = torch.linspace(0,1,8).to("cuda")

    y = gts(x,0)
    y2 = gts(x,t)
    y3 = gts(x,0.111)

    print(y)
    print(x[:,:,0])

    print(x.shape)
    print(y.shape)
    print(y2.shape)

    print(torch.abs(y2[0,:]-y[0,:])>1E-6)

    return

def test_mndiff_rev():

    mnd = mndiff_rev()

    x0 = torch.randn(28,28)
    t0 = 0
    x1 = torch.randn(5,28,28)
    t1 = torch.linspace(0,1,5)

    x2 = torch.randn(8,5,28,28)
    t2 = torch.linspace(0,1,5)
    t2 = t2.reshape([1,5]).expand([8,5])

    y0 = mnd(x0,t0)
    y1 = mnd(x1,t1)
    y2 = mnd(x2,t2)

    print(y0.shape)
    print(y1.shape)
    print(y2.shape)

    return

if(__name__=="__main__"):

    #test_gaussian_timeselect()
    test_mndiff_rev()

    pass
@ -0,0 +1,78 @@
#!/usr/bin/python3

import os,sys,math


def get_savenumber(fname):
    num = 0

    fn2 = os.path.split(fname)[1]
    fn2 = os.path.splitext(fn2)[0]

    L = len(fn2)
    L2 = L
    for I in range(L-1,0,-1):
        c = fn2[I]
        if(not c.isnumeric()):
            L2 = I+1
            break
    nm = fn2[L2:L]

    try:
        num = int(nm)
    except ValueError:
        num = 0

    return num


def sorttwopythonlists(list1, list2):

    zipped_pairs = zip(list2, list1)

    z = [x for _, x in sorted(zipped_pairs)]
    z2 = sorted(list2)

    return [z,z2]

def get_current_save(directory,saveseries):
    fname = None

    fl = os.listdir(directory)
    fl2 = []
    nums = []

    for f in fl:
        if(f.find(saveseries)>=0):
            fl2.append(os.path.join(directory,f))

    for f in fl2:
        n = get_savenumber(f)
        nums.append(n)

    [fl2,nums] = sorttwopythonlists(fl2,nums)

    if(len(fl2)>0):
        fname = fl2[len(fl2)-1]
    else:
        fname = None

    return fname

def get_next_save(directory,saveseries):
    fname = None
    fncurrent = get_current_save(directory,saveseries)
    if(fncurrent is None):
        fname = os.path.join(directory,"{}{:02d}.pyt".format(saveseries,0))
    else:
        N = get_savenumber(fncurrent)
        N = N + 1
        fname = os.path.join(directory,"{}{:02d}.pyt".format(saveseries,N))

    return fname
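
# Usage sketch (assuming a ./saves directory): with test_3a00.pyt and
# test_3a01.pyt present,
#   get_current_save("./saves","test_3a")  -> "./saves/test_3a01.pyt"
#   get_next_save("./saves","test_3a")     -> "./saves/test_3a02.pyt"
# and with no matching files, get_next_save returns "./saves/test_3a00.pyt".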

if(__name__=="__main__"):

    #print(get_savenumber("./saves/helloworld05202.py"))
    pass
@ -0,0 +1,161 @@
#!/usr/bin/python3

import os,sys,math
import numpy as np

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

import mnist_dataloader

from mnistdiffusion_model import mndiff_rev

################
## Dataloader ##
################

def mnist_load():

    [labels_train, labels_test, images_train, images_test] = mnist_dataloader.mnist_load()

    #60000 x 28 x 28
    #60000 x 1

    #Rescale data to (-1,1)
    images_train = images_train.clone().detach().float()
    images_test = images_test.clone().detach().float()

    images_train = 2.0*(images_train/255.0)-1.0
    images_test = 2.0*(images_test/255.0)-1.0

    images_train.requires_grad = False
    images_test.requires_grad = False

    return [images_train, images_test]


################
## Operations ##
################

def training_parameters():
    lparas = dict()

    nts = 200
    t = np.arange(0,nts)
    s = 0.1
    f_t = np.cos((t/nts+s)/(1+s)*np.pi/2.0)**2
    a_t = f_t/f_t[0]
    B_t = np.zeros(t.shape[0])
    B_t[0:nts-1] = np.clip(1-a_t[1:nts]/a_t[0:nts-1],0,0.999)
    B_t[nts-1] = B_t[nts-2]

    lparas["nts"] = nts
    lparas["t"] = torch.tensor(t,dtype=torch.float32)
    lparas["a_t"] = torch.tensor(a_t,dtype=torch.float32)
    lparas["B_t"] = torch.tensor(B_t,dtype=torch.float32)

    #Training parameters for a cosine variance schedule
    #ref: Nichol and Dhariwal 2021

    return lparas


def generate_minibatch(imgs, Ndraw, lparas):

    #I can probably speed this up with compiled CUDA code

    nts = lparas["nts"]
    dev = imgs.device

    imbatch = torch.zeros([Ndraw*nts,28,28],dtype=torch.float32).to(dev)
    draws = torch.randint(0,imgs.shape[0],(Ndraw,))

    for I in range(0,Ndraw):
        imgseq = torch.zeros([nts,28,28],dtype=torch.float32).to(dev)
        imgseq[0,:,:] = imgs[draws[I],:,:]

        beta = torch.kron(lparas["B_t"],torch.ones(28,28)).to(dev)
        beta = beta.reshape([nts,28,28])

        sig = torch.sqrt(beta)
        noise = torch.randn((nts,28,28)).to(dev)*sig

        for J in range(1,nts):
            imgseq[J,:,:] = imgseq[J-1,:,:] + noise[J-1,:,:]

        I1 = I*nts
        I2 = (I+1)*nts
        imbatch[I1:I2,:,:] = imgseq[:,:,:]

    return imbatch

def train_batch(imgbatch,lr,lparas):

    return


def train():

    return

def infer():

    return


def test1():

    [images_train, images_test] = mnist_load()
    lparas = training_parameters()

    #plt.figure()
    #plt.plot(lparas["t"],lparas["a_t"])
    #plt.show()

    #plt.figure()
    #plt.plot(lparas["t"],lparas["B_t"])
    #plt.show()

    images_train = images_train.to("cuda")

    mb = generate_minibatch(images_train,50,lparas)

    img0 = torch.squeeze(mb[0,:,:]).clone().detach().to("cpu")
    img1 = torch.squeeze(mb[1,:,:]).clone().detach().to("cpu")
    img2 = torch.squeeze(mb[2,:,:]).clone().detach().to("cpu")
    img3 = torch.squeeze(mb[30,:,:]).clone().detach().to("cpu")

    plt.figure()
    plt.subplot(2,2,1)
    plt.imshow(img0,cmap='gray')
    plt.subplot(2,2,2)
    plt.imshow(img1,cmap='gray')
    plt.subplot(2,2,3)
    plt.imshow(img2,cmap='gray')
    plt.subplot(2,2,4)
    plt.imshow(img3,cmap='gray')

    plt.show()

    return


##########
## Main ##
##########

if(__name__=="__main__"):

    #test1()
    #NOTE: test_gaussian_timeselect lives in mnistdiffusion_model.py
    test_gaussian_timeselect()
@ -0,0 +1,85 @@
{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.6.9 64-bit",
   "metadata": {
    "interpreter": {
     "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os,sys,math\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import mnist_dataloader\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "torch.Size([])\ntorch.Size([])\n[1, 2, 3]\n"
     ]
    }
   ],
   "source": [
    "a = torch.randn(10,14)\n",
    "b = a.shape[1:1]\n",
    "print(b)\n",
    "b.numel()\n",
    "\n",
    "print(b)\n",
    "\n",
    "b = torch.Size([1])\n",
    "c = torch.Size([2,3])\n",
    "d = torch.Size(torch.cat([torch.tensor(b),torch.tensor(c)]))\n",
    "\n",
    "d = [*b,*c]\n",
    "\n",
    "print(d)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ]
}
Binary file not shown.
Binary file not shown.
@ -0,0 +1,138 @@
#!/usr/bin/python3

import sys
sys.path.append('/home/aschinde/workspace/projects_python/library')

import os,sys,math
import numpy as np
import cv2;

import gzip #May need to use gzip.open instead of open

import struct
#struct unpack allows some interpretation of python binary data
#Example
##import struct
##
##data = open("from_fortran.bin", "rb").read()
##
##(eight, N) = struct.unpack("@II", data)
##
##This unpacks the first two fields, assuming they start at the very
##beginning of the file (no padding or extraneous data), and also assuming
##native byte-order (the @ symbol). The Is in the formatting string mean
##"unsigned integer, 32 bits".

#for integers
#a = int
#a.from_bytes(b'\xaf\xc2R',byteorder='little')
#a.to_bytes(nbytes,byteorder='big')
#an analogous operation doesn't seem to exist for floats
#what about numpy?


#https://www.devdungeon.com/content/working-binary-data-python

#print("{:02d}".format(2))
#b = b.fromhex('010203040506')
#b.hex()
#c = b.decode(encoding='utf-8' or 'latin-1' or 'ascii'...)
#print(c)

#numpy arrays have tobytes
#numpy arrays have frombuffer (converts to dtypes)
#
#q = np.array([15],dtype=np.uint8);
#q.tobytes();
#q.tobytes(order='C') (options are 'C' and 'F')
#q2 = np.frombuffer(q.tobytes(),dtype=np.uint8)
#np.frombuffer(buffer,dtype=float,count=-1,offset=0)

##You could also use the < and > endianness format codes in the struct
##module to achieve the same result:
##
##>>> struct.pack('<2h', *struct.unpack('>2h', original))
##'\xde\xad\xc0\xde'

def bytereverse(bts):
##    bts2 = bytes(len(bts));
##    for I in range(0,len(bts)):
##        bts2[len(bts)-I-1] = bts[I];
    N = len(bts);
##    print(N);
##    print(bts);
##    bts2 = struct.pack('<{}h'.format(N), *struct.unpack('>{}h'.format(N), bts))
    bts2 = bts; #currently a no-op passthrough
    return bts2;

#Read Labels
def read_MNIST_label_file(fname):
    #fp = gzip.open('./train-labels-idx1-ubyte.gz','rb');
    fp = gzip.open(fname,'rb');
    magic = fp.read(4);
    #nitems = np.frombuffer(fp.read(4),dtype=np.int32)[0]; #some sort of endianness problem
    bts = fp.read(4);
    #bts = bytereverse(bts);
    #nitems = np.frombuffer(bts,dtype=np.int32);
    nitems = np.int32(struct.unpack('>I',bts)[0]); #it was a non-native endianness in the integer encoding
    #> < @ - endianness

    bts = fp.read(nitems);
    N = len(bts);
    labels = np.zeros((N),dtype=np.uint8);
    labels = np.frombuffer(bts,dtype=np.uint8,count=N);
    #for i in range(0,10):
    #    bt = fp.read(1);
    #    labels[i] = np.frombuffer(bt,dtype=np.uint8);
    fp.close();
    return labels;
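
# For reference (IDX format, per the MNIST site): headers are big-endian
# int32s. Label files start with magic 0x00000801 (2049) then the item
# count; image files start with magic 0x00000803 (2051) then count, rows,
# cols. E.g.:
#
#   struct.unpack('>I', b'\x00\x00\x08\x01')[0]   # -> 2049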

def read_MNIST_image_file(fname):
    fp = gzip.open(fname,'rb');
    magic = fp.read(4);
    bts = fp.read(4);
    nitems = np.int32(struct.unpack('>I',bts)[0]);
    bts = fp.read(4);
    nrows = np.int32(struct.unpack('>I',bts)[0]);
    bts = fp.read(4);
    ncols = np.int32(struct.unpack('>I',bts)[0]);

    images = np.zeros((nitems,nrows,ncols),dtype=np.uint8);
    for I in range(0,nitems):
        bts = fp.read(nrows*ncols);
        img1 = np.frombuffer(bts,dtype=np.uint8,count=nrows*ncols);
        img1 = img1.reshape((nrows,ncols));
        images[I,:,:] = img1;

    fp.close();

    return images;

def read_training_data():
    rootdir = '/home/aschinde/workspace/machinelearning/datasets/MNIST';
    fname1 = 'train-labels-idx1-ubyte.gz';
    fname2 = 'train-images-idx3-ubyte.gz';

    labels = read_MNIST_label_file(os.path.join(rootdir,fname1));
    images = read_MNIST_image_file(os.path.join(rootdir,fname2));

    return [labels,images];

def read_test_data():
    rootdir = '/home/aschinde/workspace/machinelearning/datasets/MNIST';

    fname1 = 't10k-labels-idx1-ubyte.gz';
    fname2 = 't10k-images-idx3-ubyte.gz';

    labels = read_MNIST_label_file(os.path.join(rootdir,fname1));
    images = read_MNIST_image_file(os.path.join(rootdir,fname2));

    return [labels,images];

def show_MNIST_image(img):
    import matplotlib.pyplot as plt;
    plt.figure();
    plt.imshow(255-img,cmap='gray');
    plt.show();
    return;

@ -0,0 +1,92 @@
#!/usr/bin/python3

"""
mnist_loader
~~~~~~~~~~~~

A library to load the MNIST image data.  For details of the data
structures that are returned, see the doc strings for ``load_data``
and ``load_data_wrapper``.  In practice, ``load_data_wrapper`` is the
function usually called by our neural network code.
"""

##sigh: If you want it to run today, write it in Python.
##If you want it to run tomorrow, write it in ANYTHING ELSE

#### Libraries
# Standard library
##import cPickle
import pickle as cPickle
import gzip

# Third-party libraries
import numpy as np

def load_data():
    """Return the MNIST data as a tuple containing the training data,
    the validation data, and the test data.

    The ``training_data`` is returned as a tuple with two entries.
    The first entry contains the actual training images.  This is a
    numpy ndarray with 50,000 entries.  Each entry is, in turn, a
    numpy ndarray with 784 values, representing the 28 * 28 = 784
    pixels in a single MNIST image.

    The second entry in the ``training_data`` tuple is a numpy ndarray
    containing 50,000 entries.  Those entries are just the digit
    values (0...9) for the corresponding images contained in the first
    entry of the tuple.

    The ``validation_data`` and ``test_data`` are similar, except
    each contains only 10,000 images.

    This is a nice data format, but for use in neural networks it's
    helpful to modify the format of the ``training_data`` a little.
    That's done in the wrapper function ``load_data_wrapper()``, see
    below.
    """
    #f = gzip.open('../data/mnist.pkl.gz', 'rb')
    #NOTE: the file below is a raw IDX archive, not a pickle, so
    #cPickle.load will fail on it; this loader needs the original
    #mnist.pkl.gz.
    f = gzip.open('./t10k-images-idx3-ubyte.gz','rb');
    training_data, validation_data, test_data = cPickle.load(f)
    f.close()
    return (training_data, validation_data, test_data)

def load_data_wrapper():
    """Return a tuple containing ``(training_data, validation_data,
    test_data)``.  Based on ``load_data``, but the format is more
    convenient for use in our implementation of neural networks.

    In particular, ``training_data`` is a list containing 50,000
    2-tuples ``(x, y)``.  ``x`` is a 784-dimensional numpy.ndarray
    containing the input image.  ``y`` is a 10-dimensional
    numpy.ndarray representing the unit vector corresponding to the
    correct digit for ``x``.

    ``validation_data`` and ``test_data`` are lists containing 10,000
    2-tuples ``(x, y)``.  In each case, ``x`` is a 784-dimensional
    numpy.ndarray containing the input image, and ``y`` is the
    corresponding classification, i.e., the digit values (integers)
    corresponding to ``x``.

    Obviously, this means we're using slightly different formats for
    the training data and the validation / test data.  These formats
    turn out to be the most convenient for use in our neural network
    code."""
    tr_d, va_d, te_d = load_data()
    training_inputs = [np.reshape(x, (784, 1)) for x in tr_d[0]]
    training_results = [vectorized_result(y) for y in tr_d[1]]
    training_data = zip(training_inputs, training_results)
    validation_inputs = [np.reshape(x, (784, 1)) for x in va_d[0]]
    validation_data = zip(validation_inputs, va_d[1])
    test_inputs = [np.reshape(x, (784, 1)) for x in te_d[0]]
    test_data = zip(test_inputs, te_d[1])
    return (training_data, validation_data, test_data)

def vectorized_result(j):
    """Return a 10-dimensional unit vector with a 1.0 in the jth
    position and zeroes elsewhere.  This is used to convert a digit
    (0...9) into a corresponding desired output from the neural
    network."""
    e = np.zeros((10, 1))
    e[j] = 1.0
    return e
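
# Usage sketch (not in the original file):
#   vectorized_result(3) returns a (10,1) array with a 1.0 at index 3,
#   i.e. the one-hot target for the digit 3.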
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.