#!/usr/bin/python3 import os,sys,math import numpy as np import matplotlib.pyplot as plt import torch import torch.nn as nn import torch.nn.functional as F import mnist_dataloader ## Sohl-Dickstein 2015 describes a gaussian "bump" function to # implement learnable time-dependence # #network input vector and normalized time (tn e (0,1)) # input: (dimx, dimy, ..., features), (scalar) --> output: (dimx, dimy, ...) # input: (ntimes, ..., features), (ntimes) --> output: (ntimes , ...) class gaussian_timeselect(nn.Module): def __init__(self, nfeatures=5, width = 0.25): super(gaussian_timeselect,self).__init__() tau_j = torch.linspace(0,1,nfeatures).float() self.tau_j = torch.nn.parameter.Parameter(tau_j, requires_grad=False) width = torch.tensor(width,dtype=torch.float32) self.width = torch.nn.parameter.Parameter(width,requires_grad = False) #network input vector and normalized time (tn e (0,1)) # input: (dimx, dimy, ..., features), (scalar) --> output: (dimx, dimy, ...) # input: (ntimes, ..., features), (ntimes) --> output: (ntimes , ...) def forward(self, x, tn): ndim = len(x.shape) tn = torch.as_tensor(tn) tn.requires_grad = False if(len(tn.shape)==0): #case 0 - if tn is a scalar g = torch.exp(-1.0/(2.0*self.width)*(tn-self.tau_j)**2) g = g/torch.clip(torch.sum(g),min=0.001) x = torch.tensordot(x,g,dims=[[ndim-1],[0]]) else: #Case 1 - tn is a tensor nf = self.tau_j.shape[0] nt = tn.numel() nt2 = x.shape[0] ndim = len(x.shape) nf2 = x.shape[ndim-1] if(nt!=nt2): raise Exception("gaussian_timeselect time mismatch") if(nf!=nf2): raise Exception("gaussian_timeselect feature mismatch") shp2 = x.shape[1:ndim-1] nshp2 = shp2.numel() x = x.reshape([nt,nshp2,nf]) tau = self.tau_j.reshape([1,1,nf]) tau = tau.expand([nt,nshp2,nf]) tn = tn.reshape([nt,1,1]) tn = tn.expand([nt,nshp2,nf]) g = torch.exp(-1.0/(2.0*self.width)*(tn-tau)**2) gs = torch.clip(torch.sum(g,axis=2),min=0.0001) gs = gs.reshape([nt,nshp2,1]).expand([nt,nshp2,nf]) g = g/gs x = torch.sum(x*g,axis=2) #shp3 = torch.Size(torch.cat(torch.tensor([nt]),torch.tensor(shp2))) # contact [nt, *shp2] shp3 = torch.Size([nt,*shp2]) x = x.reshape(shp3) return x #The reverse network should output a mean image and a covariance image #given an input image class mndiff_rev(nn.Module): def __init__(self): super(mndiff_rev,self).__init__() self.L1 = nn.Linear(28*28,1000,bias=True) self.L2 = nn.Linear(1000,1000,bias=True) self.L3 = nn.Linear(1000,28*28*5,bias=True) #self.L4 = nn.Conv2d(10,10,kernel_size=1) #channels must be 2nd dimension self.L4a = nn.Linear(5,5) #equivalent, channels must be last dimension self.L4b = nn.Linear(5,5) self.L5 = gaussian_timeselect(5,0.1) #self.L5alt = nn.Linear(5,1) #self.RL = nn.LeakyReLU(0.01) self.tanh = nn.Tanh() self.device = torch.device("cpu") self.losstime = [] self.lossrecord = [] return def forward(self,x,t): #handle data-shape cases #case 0: x (pix_x, pix_y), t is scalar #case 1: x (ntimes, pix_x, pix_y), t is (ntimes) #case 2: x (nbatch, ntimes, pix_x, pix_y), t is (nbatch, ntimes) if(len(x.shape)==2): case = 0 t = torch.as_tensor(t).to(self.device) t = t.float() x = x.reshape([1,28*28]) ntimes = 1 nbatch = 1 elif(len(x.shape)==3): case = 1 ntimes = x.shape[0] nbatch = 1 x = x.reshape([ntimes,28*28]) elif(len(x.shape)==4): case = 2 nbatch = x.shape[0] ntimes = x.shape[1] #x = x.reshape([nbatch,ntimes,28*28]) x = x.reshape([nbatch*ntimes,28*28]) t = t.reshape([nbatch*ntimes]) y = torch.randn(28*28,device=self.device) else: print("Error: Expect input to be dimension 2,3,4") raise Exception("Error: Expect input to be dimension 2,3,4") s1 = self.L1(x) s1 = self.tanh(s1) s1 = self.L2(s1) s1 = self.tanh(s1) s1 = self.L3(s1) s1 = self.tanh(s1) s1 = s1.reshape([ntimes*nbatch,28*28,5]) s1 = self.L4a(s1) s1 = self.tanh(s1) s1 = self.L4b(s1) #s1 = self.tanh(s1) s2 = self.L5(s1,t) #s2 = self.L5alt(s1) eps = s2 #eps = self.tanh(s2) if(case==0): eps = eps.reshape([28,28]) elif(case==1): eps = eps.reshape([ntimes,28,28]) elif(case==2): eps = eps.reshape([nbatch,ntimes,28,28]) return eps def to(self,dev): self.device = torch.device(dev) ret = super(mndiff_rev,self).to(self.device) return ret def save(self,fname): try: dev = self.device self.to("cpu") q = [self.state_dict(), self.losstime, self.lossrecord] torch.save(q,fname) self.to(dev) except: print("model save: problem saving {}".format(fname)) return def load(self,fname): try: [modelsd,losstime,lossrecord] = torch.load(fname) self.load_state_dict(modelsd) self.losstime = losstime self.lossrecord = lossrecord except: print("model load: problem loading {}".format(fname)) return def plothistory(self): x = self.losstime y = self.lossrecord plt.figure() plt.plot(x,y,'k.',markersize=1) plt.title('Loss History') plt.show() return def recordloss(self,loss): t = 0 L = len(self.losstime) if(L>0): t = self.losstime[L-1] self.losstime.append(t+1) self.lossrecord.append(loss) return ########### ## Tests ## ########### def test_gaussian_timeselect(): gts = gaussian_timeselect(5,0.025) gts = gts.to("cuda") x = torch.randn((8,8,5)).to("cuda") #t = torch.tensor([0,0.5,1]).to("cuda") t = torch.linspace(0,1,8).to("cuda") y = gts(x,0) y2 = gts(x,t) y3 = gts(x,0.111) print(y) print(x[:,:,0]) print(x.shape) print(y.shape) print(y2.shape) print(torch.abs(y2[0,:]-y[0,:])>1E-6) return def test_mndiff_rev(): mnd = mndiff_rev() x0 = torch.randn(28,28) t0 = 0 x1 = torch.randn(5,28,28) t1 = torch.linspace(0,1,5) x2 = torch.randn(8,5,28,28) t2 = torch.linspace(0,1,5) t2 = t2.reshape([1,5]).expand([8,5]) y0 = mnd(x0,t0) y1 = mnd(x1,t1) y2 = mnd(x2,t2) print(y0.shape) print(y1.shape) print(y2.shape) return if(__name__=="__main__"): test_gaussian_timeselect() #test_mndiff_rev() pass