You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

289 lines
7.2 KiB
Python

#!/usr/bin/python3
import os,sys,math
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import mnist_dataloader
## Sohl-Dickstein et al. (2015) describe a Gaussian "bump" function used to
# implement learnable time dependence.
#
# Inputs: the network feature tensor and the normalized time tn in (0,1).
# input: (dimx, dimy, ..., features), (scalar) --> output: (dimx, dimy, ...)
# input: (ntimes, ..., features), (ntimes) --> output: (ntimes, ...)
class gaussian_timeselect(nn.Module):
    """Collapse a trailing feature axis with time-dependent Gaussian weights.

    The module holds ``nfeatures`` fixed centres ``tau_j`` spaced evenly on
    [0, 1].  For a normalized time ``tn`` the weight of feature ``j`` is

        g_j = exp(-(tn - tau_j)**2 / (2 * width))

    normalized so the weights sum to one.  The output is the weighted sum of
    the input over its last (feature) axis.  Note that ``width`` enters the
    exponent linearly (it is not squared).

    Supported call shapes:
        x: (dimx, dimy, ..., features), tn: scalar  -> (dimx, dimy, ...)
        x: (ntimes, ..., features),     tn: ntimes  -> (ntimes, ...)

    ``tau_j`` and ``width`` are stored as non-trainable Parameters so they
    follow the module across devices and appear in its ``state_dict``.
    """

    # Lower bound on the weight normalizer; guards against division by ~0
    # when tn lies far from every centre.  One shared value keeps the scalar
    # and batched paths numerically consistent with each other.
    _NORM_FLOOR = 1e-4

    def __init__(self, nfeatures=5, width=0.25):
        """Create ``nfeatures`` centres on [0, 1] with the given ``width``."""
        super().__init__()
        tau_j = torch.linspace(0, 1, nfeatures).float()
        self.tau_j = torch.nn.parameter.Parameter(tau_j, requires_grad=False)
        width = torch.tensor(width, dtype=torch.float32)
        self.width = torch.nn.parameter.Parameter(width, requires_grad=False)

    def _weights(self, tn):
        """Return normalized Gaussian weights; the last axis indexes features.

        ``tn`` must broadcast against ``tau_j`` (shape ``(nfeatures,)``).
        """
        g = torch.exp(-1.0 / (2.0 * self.width) * (tn - self.tau_j) ** 2)
        norm = torch.clamp(g.sum(dim=-1, keepdim=True), min=self._NORM_FLOOR)
        return g / norm

    def forward(self, x, tn):
        """Weight and sum the feature axis of ``x`` at normalized time ``tn``.

        Args:
            x: tensor whose last axis has length ``nfeatures``.  When ``tn``
                is not a scalar, the first axis of ``x`` is the time axis.
            tn: Python number or 0-d tensor (one time for all of ``x``), or a
                tensor with ``x.shape[0]`` elements (one time per leading
                slice of ``x``).

        Returns:
            ``x`` with its last axis summed out under the Gaussian weights.

        Raises:
            ValueError: if the feature axis of ``x`` does not match the
                number of centres, or the number of times does not match
                ``x.shape[0]``.
        """
        # Time is a non-differentiable input.  detach() yields a new tensor,
        # so the caller's tensor (leaf or not) is never modified, and the
        # result is aligned with the module's device and dtype.
        tn = torch.as_tensor(tn).detach().to(
            device=self.tau_j.device, dtype=self.tau_j.dtype
        )
        nf = self.tau_j.shape[0]
        if x.dim() == 0 or x.shape[-1] != nf:
            raise ValueError("gaussian_timeselect feature mismatch")

        if tn.dim() == 0:
            # Case 0: a single time shared by every element of x.
            return torch.sum(x * self._weights(tn), dim=-1)

        # Case 1: one time per leading slice of x.
        nt = tn.numel()
        if x.dim() < 2 or nt != x.shape[0]:
            raise ValueError("gaussian_timeselect time mismatch")
        g = self._weights(tn.reshape(nt, 1))  # (nt, nf)
        # Insert singleton axes so g broadcasts over x's middle dimensions.
        g = g.reshape([nt] + [1] * (x.dim() - 2) + [nf])
        return torch.sum(x * g, dim=-1)
#The reverse network should output a mean image and a covariance image
#given an input image
class mndiff_rev(nn.Module):
    """Time-conditioned fully connected network over 28x28 images.

    Pipeline of ``forward``:
      1. three Linear layers with tanh map each flattened image (784 values)
         to 784 * 5 values;
      2. two Linear(5, 5) layers mix the 5 features of every pixel;
      3. a ``gaussian_timeselect`` layer collapses the 5 features to one
         value per pixel according to the time ``t``.

    The instance also keeps a loss history (``losstime`` / ``lossrecord``)
    that is stored next to the weights by ``save`` and restored by ``load``.
    """

    NPIX = 28 * 28  # flattened image size
    NFEAT = 5       # per-pixel features fed to the time-select layer

    def __init__(self):
        super().__init__()
        self.L1 = nn.Linear(self.NPIX, 1000, bias=True)
        self.L2 = nn.Linear(1000, 1000, bias=True)
        self.L3 = nn.Linear(1000, self.NPIX * self.NFEAT, bias=True)
        # nn.Linear acts on the last axis, so the per-pixel features must be
        # the last dimension when L4a / L4b are applied.
        self.L4a = nn.Linear(self.NFEAT, self.NFEAT)
        self.L4b = nn.Linear(self.NFEAT, self.NFEAT)
        self.L5 = gaussian_timeselect(self.NFEAT, 0.1)
        self.tanh = nn.Tanh()
        # Device last requested through to(); forward() places t on it.
        self.device = torch.device("cpu")
        self.losstime = []
        self.lossrecord = []

    def forward(self, x, t):
        """Evaluate the network on one image, a sequence, or a batch.

        Supported shapes:
            case 0: x (pix_x, pix_y),                t scalar
            case 1: x (ntimes, pix_x, pix_y),        t (ntimes)
            case 2: x (nbatch, ntimes, pix_x, pix_y), t (nbatch, ntimes)

        Returns:
            A tensor with the leading dimensions of ``x`` followed by
            (28, 28).

        Raises:
            ValueError: if ``x`` is not 2-, 3- or 4-dimensional.
        """
        ndim = x.dim()
        if ndim not in (2, 3, 4):
            raise ValueError("Error: Expect input to be dimension 2,3,4")

        # Time is treated as a plain, non-differentiable input in every case;
        # detach() also keeps downstream layers from touching the caller's
        # tensor.
        t = torch.as_tensor(t).detach().to(self.device).float()

        if ndim == 2:
            nrows = 1
        elif ndim == 3:
            nrows = x.shape[0]
        else:
            nrows = x.shape[0] * x.shape[1]
            # Flatten (nbatch, ntimes) so there is one time per flattened row.
            t = t.reshape([nrows])

        out_shape = tuple(x.shape[:-2]) + (28, 28)
        x = x.reshape([nrows, self.NPIX])

        h = self.tanh(self.L1(x))
        h = self.tanh(self.L2(h))
        h = self.tanh(self.L3(h))
        h = h.reshape([nrows, self.NPIX, self.NFEAT])
        h = self.tanh(self.L4a(h))
        h = self.L4b(h)
        eps = self.L5(h, t)
        return eps.reshape(out_shape)

    def to(self, dev):
        """Move the module to ``dev`` and remember it in ``self.device``."""
        device = torch.device(dev)
        ret = super().to(device)
        self.device = device
        return ret

    def save(self, fname):
        """Write ``[state_dict, losstime, lossrecord]`` to ``fname``.

        Tensors are copied to the CPU for serialization; the live model is
        not moved, so a failed save cannot leave it on the wrong device.
        Saving is best-effort: failures are reported on stdout, not raised.
        """
        try:
            state = self.state_dict()
            for key in state:
                state[key] = state[key].detach().cpu()
            torch.save([state, self.losstime, self.lossrecord], fname)
        except Exception as err:
            print("model save: problem saving {}".format(fname))
            print("  cause: {!r}".format(err))

    def load(self, fname):
        """Restore weights and loss history written by ``save``.

        Loading is best-effort: failures are reported on stdout, not raised,
        and the loss history is only replaced after the weights loaded.
        """
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        try:
            modelsd, losstime, lossrecord = torch.load(fname, map_location="cpu")
            self.load_state_dict(modelsd)
        except Exception as err:
            print("model load: problem loading {}".format(fname))
            print("  cause: {!r}".format(err))
            return
        self.losstime = losstime
        self.lossrecord = lossrecord

    def plothistory(self):
        """Plot the recorded loss values against their step index."""
        plt.figure()
        plt.plot(self.losstime, self.lossrecord, 'k.', markersize=1)
        plt.title('Loss History')
        plt.show()

    def recordloss(self, loss):
        """Append ``loss`` to the history with the next step index.

        Step indices start at 1 and increase by 1 per call.  Tensor losses
        are detached so the history does not keep autograd graphs alive;
        single-element tensors are stored as Python floats, larger ones as
        detached CPU tensors.
        """
        if torch.is_tensor(loss):
            loss = loss.detach()
            loss = loss.item() if loss.numel() == 1 else loss.cpu()
        last = self.losstime[-1] if self.losstime else 0
        self.losstime.append(last + 1)
        self.lossrecord.append(loss)
###########
## Tests ##
###########
def test_gaussian_timeselect():
    """Smoke-test gaussian_timeselect with scalar and per-row times.

    Runs on CUDA when available and falls back to the CPU otherwise, then
    prints the outputs and shapes for visual inspection.  The final print
    flags elements where row 0 of the batched result (time 0) differs from
    the scalar-time result by more than 1e-6.
    """
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    gts = gaussian_timeselect(5, 0.025).to(dev)
    x = torch.randn((8, 8, 5)).to(dev)
    t = torch.linspace(0, 1, 8).to(dev)
    y = gts(x, 0)
    y2 = gts(x, t)
    # Exercise the Python-float scalar path as well; the result is unused.
    gts(x, 0.111)
    print(y)
    print(x[:, :, 0])
    print(x.shape)
    print(y.shape)
    print(y2.shape)
    print(torch.abs(y2[0, :] - y[0, :]) > 1E-6)
def test_mndiff_rev():
    """Run mndiff_rev on each supported input layout and print the shapes.

    Covers a single image with scalar time, a 5-step sequence, and a batch
    of 8 sequences that share the same 5 time values.
    """
    model = mndiff_rev()

    single_img = torch.randn(28, 28)
    seq_imgs = torch.randn(5, 28, 28)
    batch_imgs = torch.randn(8, 5, 28, 28)

    seq_times = torch.linspace(0, 1, 5)
    # Same time grid for every batch entry.
    batch_times = torch.linspace(0, 1, 5).reshape([1, 5]).expand([8, 5])

    cases = (
        (single_img, 0),
        (seq_imgs, seq_times),
        (batch_imgs, batch_times),
    )
    outputs = [model(imgs, times) for imgs, times in cases]
    for out in outputs:
        print(out.shape)
# Script entry point: run the time-select smoke test when executed directly.
if __name__ == "__main__":
    test_gaussian_timeselect()
    #test_mndiff_rev()