#!/usr/bin/python3

import os, sys, math

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

import mnist_dataloader


## Sohl-Dickstein 2015 describes a Gaussian "bump" function used to
#  implement learnable time dependence.
#
#  Arguments: network input vector and a normalized time tn in (0,1).
#   input: (dimx, dimy, ..., nfeatures), (scalar) --> output: (dimx, dimy, ...)
#   input: (ntimes, ..., nfeatures), (ntimes)     --> output: (ntimes, ...)
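#
#  A worked form of the weighting (read off the code below, not quoted from
#  the paper): with centers tau_j = linspace(0, 1, nfeatures),
#
#      g_j(tn) = exp(-(tn - tau_j)^2 / (2 * width))
#      out     = sum_j x[..., j] * g_j(tn) / max(sum_k g_k(tn), eps)
#
#  i.e. a normalized radial-basis interpolation along the feature axis.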
class gaussian_timeselect(nn.Module):

    def __init__(self, nfeatures=5, width=0.25):
        super(gaussian_timeselect, self).__init__()

        # fixed (non-learnable) bump centers tau_j, evenly spaced on [0,1]
        tau_j = torch.linspace(0, 1, nfeatures).float()
        self.tau_j = torch.nn.parameter.Parameter(tau_j, requires_grad=False)

        # fixed bump width; 2*width plays the role of the Gaussian variance
        width = torch.tensor(width, dtype=torch.float32)
        self.width = torch.nn.parameter.Parameter(width, requires_grad=False)

    # Network input vector and normalized time tn in (0,1):
    #  input: (dimx, dimy, ..., nfeatures), (scalar) --> output: (dimx, dimy, ...)
    #  input: (ntimes, ..., nfeatures), (ntimes)     --> output: (ntimes, ...)
    def forward(self, x, tn):

        ndim = len(x.shape)

        tn = torch.as_tensor(tn)
        tn.requires_grad = False

        if len(tn.shape) == 0:
            # Case 0 - tn is a scalar
            g = torch.exp(-1.0 / (2.0 * self.width) * (tn - self.tau_j) ** 2)
            g = g / torch.clip(torch.sum(g), min=0.001)
            x = torch.tensordot(x, g, dims=[[ndim - 1], [0]])
        else:
            # Case 1 - tn is a tensor, one time per leading entry of x
            nf = self.tau_j.shape[0]
            nt = tn.numel()
            nt2 = x.shape[0]
            ndim = len(x.shape)
            nf2 = x.shape[ndim - 1]

            if nt != nt2:
                raise Exception("gaussian_timeselect time mismatch")
            if nf != nf2:
                raise Exception("gaussian_timeselect feature mismatch")

            # flatten the middle dimensions and broadcast tau and tn to match
            shp2 = x.shape[1:ndim - 1]
            nshp2 = shp2.numel()

            x = x.reshape([nt, nshp2, nf])
            tau = self.tau_j.reshape([1, 1, nf]).expand([nt, nshp2, nf])
            tn = tn.reshape([nt, 1, 1]).expand([nt, nshp2, nf])

            # normalized Gaussian bump weights over the feature axis
            g = torch.exp(-1.0 / (2.0 * self.width) * (tn - tau) ** 2)
            gs = torch.clip(torch.sum(g, axis=2), min=0.0001)
            gs = gs.reshape([nt, nshp2, 1]).expand([nt, nshp2, nf])
            g = g / gs

            # weighted sum over features, then restore shape [nt, *shp2]
            x = torch.sum(x * g, axis=2)
            shp3 = torch.Size([nt, *shp2])
            x = x.reshape(shp3)

        return x
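
# Minimal usage sketch for gaussian_timeselect (the shapes here are
# illustrative assumptions, not taken from the training code):
#   gts = gaussian_timeselect(nfeatures=5, width=0.25)
#   y = gts(torch.randn(28, 28, 5), 0.5)                         # -> (28, 28)
#   y = gts(torch.randn(4, 28, 28, 5), torch.linspace(0, 1, 4))  # -> (4, 28, 28)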

# The reverse network should output a mean image and a covariance image
# given an input image.
class mndiff_rev(nn.Module):

    def __init__(self):
        super(mndiff_rev, self).__init__()

        self.L1 = nn.Linear(28 * 28, 1000, bias=True)
        self.L2 = nn.Linear(1000, 1000, bias=True)
        self.L3 = nn.Linear(1000, 28 * 28 * 5, bias=True)

        # self.L4 = nn.Conv2d(10, 10, kernel_size=1)
        # For the Conv2d variant, channels must be the 2nd dimension.
        self.L4a = nn.Linear(5, 5)
        # Equivalent per-pixel mixing; channels must be the last dimension.
        self.L4b = nn.Linear(5, 5)
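        # (Equivalence note, my reading: a 1x1 nn.Conv2d(5, 5) applied to
        #  (N, 5, H, W) performs the same per-pixel channel mixing as
        #  nn.Linear(5, 5) applied to (N, H*W, 5); only the position of the
        #  channel axis differs.)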

        self.L5 = gaussian_timeselect(5, 0.1)
        # self.L5alt = nn.Linear(5, 1)

        # self.RL = nn.LeakyReLU(0.01)
        self.tanh = nn.Tanh()

        self.device = torch.device("cpu")

        # running loss history used by recordloss()/plothistory()
        self.losstime = []
        self.lossrecord = []

        return

    def forward(self, x, t):

        # Handle data-shape cases:
        #  case 0: x (pix_x, pix_y),                 t is a scalar
        #  case 1: x (ntimes, pix_x, pix_y),         t is (ntimes)
        #  case 2: x (nbatch, ntimes, pix_x, pix_y), t is (nbatch, ntimes)
        if len(x.shape) == 2:
            case = 0
            t = torch.as_tensor(t).to(self.device)
            t = t.float()
            x = x.reshape([1, 28 * 28])
            ntimes = 1
            nbatch = 1
        elif len(x.shape) == 3:
            case = 1
            ntimes = x.shape[0]
            nbatch = 1
            x = x.reshape([ntimes, 28 * 28])
        elif len(x.shape) == 4:
            case = 2
            nbatch = x.shape[0]
            ntimes = x.shape[1]
            x = x.reshape([nbatch * ntimes, 28 * 28])
            t = t.reshape([nbatch * ntimes])
        else:
            raise Exception("Error: Expect input to be dimension 2,3,4")

        s1 = self.L1(x)
        s1 = self.tanh(s1)
        s1 = self.L2(s1)
        s1 = self.tanh(s1)
        s1 = self.L3(s1)
        s1 = self.tanh(s1)

        # per-pixel mixing across the 5 time-feature channels
        s1 = s1.reshape([ntimes * nbatch, 28 * 28, 5])
        s1 = self.L4a(s1)
        s1 = self.tanh(s1)
        s1 = self.L4b(s1)
        # s1 = self.tanh(s1)

        # collapse the feature axis with the Gaussian time-selection
        s2 = self.L5(s1, t)
        # s2 = self.L5alt(s1)

        eps = s2
        # eps = self.tanh(s2)

        if case == 0:
            eps = eps.reshape([28, 28])
        elif case == 1:
            eps = eps.reshape([ntimes, 28, 28])
        elif case == 2:
            eps = eps.reshape([nbatch, ntimes, 28, 28])

        return eps

    # Track the target device so forward() can move scalar inputs to it.
    def to(self, dev):
        self.device = torch.device(dev)
        ret = super(mndiff_rev, self).to(self.device)
        return ret

    def save(self, fname):
        try:
            # move to cpu so the checkpoint loads on any device
            dev = self.device
            self.to("cpu")
            q = [self.state_dict(), self.losstime, self.lossrecord]
            torch.save(q, fname)
            self.to(dev)
        except Exception:
            print("model save: problem saving {}".format(fname))

        return

    def load(self, fname):
        try:
            [modelsd, losstime, lossrecord] = torch.load(fname)
            self.load_state_dict(modelsd)
            self.losstime = losstime
            self.lossrecord = lossrecord
        except Exception:
            print("model load: problem loading {}".format(fname))

        return

    def plothistory(self):
        x = self.losstime
        y = self.lossrecord
        plt.figure()
        plt.plot(x, y, 'k.', markersize=1)
        plt.title('Loss History')
        plt.show()

        return

    def recordloss(self, loss):
        # append the loss at the next integer step index
        t = 0
        L = len(self.losstime)
        if L > 0:
            t = self.losstime[L - 1]
        self.losstime.append(t + 1)
        self.lossrecord.append(loss)

        return
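
# Hedged usage sketch: how recordloss()/save() are meant to compose with a
# training loop. The random x/eps tensors stand in for a real forward-
# diffusion pairing (not provided in this file), and the Adam settings are
# an assumption, not the author's configuration.
def example_train_sketch(nsteps=2):
    mnd = mndiff_rev()
    opt = torch.optim.Adam(mnd.parameters(), lr=1e-4)
    for step in range(nsteps):
        x = torch.randn(8, 5, 28, 28)        # stand-in noisy image batch
        t = torch.linspace(0, 1, 5).reshape([1, 5]).expand([8, 5])
        eps_true = torch.randn(8, 5, 28, 28)  # stand-in target noise
        eps_pred = mnd(x, t)                  # (nbatch, ntimes, 28, 28)
        loss = F.mse_loss(eps_pred, eps_true)
        opt.zero_grad()
        loss.backward()
        opt.step()
        mnd.recordloss(loss.item())
    mnd.save("mndiff_rev_sketch.pt")
    return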


###########
## Tests ##
###########

def test_gaussian_timeselect():
    # note: this test assumes a CUDA device is available
    gts = gaussian_timeselect(5, 0.025)
    gts = gts.to("cuda")

    x = torch.randn((8, 8, 5)).to("cuda")
    # t = torch.tensor([0, 0.5, 1]).to("cuda")
    t = torch.linspace(0, 1, 8).to("cuda")

    y = gts(x, 0)        # scalar time
    y2 = gts(x, t)       # one time per row
    y3 = gts(x, 0.111)   # off-center scalar time

    print(y)
    print(x[:, :, 0])

    print(x.shape)
    print(y.shape)
    print(y2.shape)

    # row 0 of y2 uses t=0, so it should agree with y (all False here)
    print(torch.abs(y2[0, :] - y[0, :]) > 1E-6)

    return

def test_mndiff_rev():

    mnd = mndiff_rev()

    # case 0: single image, scalar time
    x0 = torch.randn(28, 28)
    t0 = 0
    # case 1: a time-series of images
    x1 = torch.randn(5, 28, 28)
    t1 = torch.linspace(0, 1, 5)

    # case 2: a batch of time-series
    x2 = torch.randn(8, 5, 28, 28)
    t2 = torch.linspace(0, 1, 5)
    t2 = t2.reshape([1, 5]).expand([8, 5])

    y0 = mnd(x0, t0)
    y1 = mnd(x1, t1)
    y2 = mnd(x2, t2)

    print(y0.shape)
    print(y1.shape)
    print(y2.shape)

    return
if(__name__=="__main__"):
|
|
|
|
test_gaussian_timeselect()
|
|
#test_mndiff_rev()
|
|
|
|
pass |