You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

162 lines
3.3 KiB
Python

#!/usr/bin/python3
import os,sys,math
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import mnist_dataloader
from mnistdiffusion_model import mndiff_rev
################
## Dataloader ##
################
def mnist_load():
    """Return the MNIST train/test images rescaled from [0, 255] to [-1, 1].

    ``mnist_dataloader.mnist_load()`` returns
    ``[labels_train, labels_test, images_train, images_test]``; the labels are
    discarded here. (Author's notes: training images are 60000 x 28 x 28 and
    labels 60000 x 1.)

    Each image tensor is copied, detached, cast to float32 and mapped through
    ``2 * (x / 255) - 1``. Assuming raw pixel values lie in [0, 255], the
    result lies in [-1, 1].

    Returns:
        [images_train, images_test]: rescaled tensors, neither requiring grad.
    """
    _, _, raw_train, raw_test = mnist_dataloader.mnist_load()

    def _rescale(raw):
        # Private float32 copy cut from any autograd graph, then affine-mapped
        # so that 0 -> -1 and 255 -> +1.
        out = raw.clone().detach().float()
        out = 2.0 * (out / 255.0) - 1.0
        out.requires_grad = False
        return out

    return [_rescale(raw_train), _rescale(raw_test)]
################
## Operations ##
################
def training_parameters(nts=200, s=0.1):
    """Build the cosine variance schedule used by the forward diffusion.

    Follows Nichol and Dhariwal (2021), "Improved Denoising Diffusion
    Probabilistic Models":

        f(t)   = cos(((t / nts + s) / (1 + s)) * pi / 2) ** 2
        a_t[t] = f(t) / f(0)                    (so a_t[0] == 1)
        B_t[t] = clip(1 - a_t[t+1] / a_t[t], 0, 0.999)

    B_t[t] is the noise variance used when stepping from t to t+1. The last
    entry has no successor, so it repeats B_t[nts-2].

    Args:
        nts: number of timesteps (default 200). Must be >= 2.
        s: small offset that keeps B_t from being too small near t = 0.
            Defaults to 0.1; the paper uses 0.008.

    Returns:
        dict with keys
            "nts": int, the number of timesteps,
            "t":   float32 tensor [0, 1, ..., nts-1],
            "a_t": float32 tensor of length nts (a_t[0] == 1),
            "B_t": float32 tensor of length nts with values in [0, 0.999].

    Raises:
        ValueError: if nts < 2.
    """
    nts = int(nts)
    if nts < 2:
        # B_t[-1] copies B_t[-2], which needs at least two timesteps.
        raise ValueError(f"nts must be >= 2, got {nts}")
    t = np.arange(0, nts)
    f_t = np.cos((t / nts + s) / (1 + s) * np.pi / 2.0) ** 2
    a_t = f_t / f_t[0]
    B_t = np.empty(nts)
    # Ratio of consecutive a_t gives the per-step retention; clip keeps the
    # variance finite and below 1 near the end of the schedule.
    B_t[:-1] = np.clip(1 - a_t[1:] / a_t[:-1], 0, 0.999)
    B_t[-1] = B_t[-2]
    return {
        "nts": nts,
        "t": torch.tensor(t, dtype=torch.float32),
        "a_t": torch.tensor(a_t, dtype=torch.float32),
        "B_t": torch.tensor(B_t, dtype=torch.float32),
    }
def generate_minibatch(imgs, Ndraw, lparas):
    """Draw random images and run each through the forward diffusion chain.

    ``Ndraw`` images are sampled with replacement from ``imgs``. Each one is
    noised for nts - 1 steps with the variance-preserving DDPM update
    (Ho et al. 2020):

        x_J = sqrt(1 - B_t[J-1]) * x_{J-1} + sqrt(B_t[J-1]) * eps,  eps ~ N(0, I)

    After J steps, the squared signal coefficient is prod(1 - B_t[:J]). For
    the cosine schedule this equals a_t[J] wherever the 0.999 clip is
    inactive.

    Args:
        imgs: image tensor indexed along dim 0 (e.g. N x 28 x 28). The output
            is created on ``imgs.device``.
        Ndraw: number of images to draw.
        lparas: schedule dict. Only "nts" and "B_t" (at least nts - 1
            entries) are read.

    Returns:
        float32 tensor of shape (Ndraw * nts, *imgs.shape[1:]).
        Rows I*nts .. (I+1)*nts - 1 hold the trajectory of the I-th draw,
        starting with step 0 (the clean image).
    """
    nts = int(lparas["nts"])
    dev = imgs.device
    img_shape = tuple(imgs.shape[1:])

    # One scalar coefficient per step. The previous
    # torch.kron(B_t, ones(28, 28)).reshape(nts, 28, 28) interleaved timesteps
    # across image rows instead of giving every pixel of step k the value
    # B_t[k], and the update omitted the sqrt(1 - beta) signal decay.
    beta = torch.as_tensor(lparas["B_t"], dtype=torch.float32, device=dev)
    keep = torch.sqrt(torch.clamp(1.0 - beta, min=0.0))
    sig = torch.sqrt(torch.clamp(beta, min=0.0))

    draws = torch.randint(0, imgs.shape[0], (Ndraw,), device=dev)
    traj = torch.empty((Ndraw, nts) + img_shape, dtype=torch.float32, device=dev)
    traj[:, 0] = imgs[draws]
    # Advance all draws together: one Python iteration per timestep rather
    # than one per (draw, timestep) pair.
    for J in range(1, nts):
        eps = torch.randn((Ndraw,) + img_shape, dtype=torch.float32, device=dev)
        traj[:, J] = keep[J - 1] * traj[:, J - 1] + sig[J - 1] * eps
    return traj.reshape((Ndraw * nts,) + img_shape)
def train_batch(imgbatch, lr, lparas):
    """Placeholder for one optimisation step on ``imgbatch``.

    Not implemented yet: ignores its arguments and returns None.
    """
    return None
def train():
    """Placeholder for the training loop. Not implemented yet; returns None."""
    return None
def infer():
    """Placeholder for sampling/inference. Not implemented yet; returns None."""
    return None
def test1(plot_schedule=False):
    """Visual smoke test: show a few noised versions of one MNIST digit.

    Loads MNIST, builds the schedule and draws a 50-image minibatch. It then
    plots rows 0, 1, 2 and 30 of the minibatch (steps 0, 1, 2 and 30 of the
    first draw's trajectory) in a 2x2 grid.

    Args:
        plot_schedule: if True, first plot a_t and B_t against t. These plots
            used to be commented-out code. Default False.
    """
    images_train, _images_test = mnist_load()
    lparas = training_parameters()
    if plot_schedule:
        for key in ("a_t", "B_t"):
            plt.figure()
            plt.plot(lparas["t"], lparas[key])
            plt.show()
    # Fall back to the CPU instead of crashing on machines without CUDA.
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    mb = generate_minibatch(images_train.to(dev), 50, lparas)
    plt.figure()
    for pos, row in enumerate((0, 1, 2, 30), start=1):
        frame = torch.squeeze(mb[row, :, :]).clone().detach().to("cpu")
        plt.subplot(2, 2, pos)
        plt.imshow(frame, cmap='gray')
    plt.show()
    return
##########
## Main ##
##########
if __name__ == "__main__":
    # The entry point previously called test_gaussian_timeselect(). No such
    # function is defined in this module, so running the script raised
    # NameError; run the visual forward-diffusion check instead.
    test1()