You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
92 lines
3.0 KiB
Python
92 lines
3.0 KiB
Python
3 weeks ago
|
#!/usr/bin/python3
|
||
|
import os,sys,math
|
||
|
import numpy as np
|
||
|
import cv2
|
||
|
import gzip #need to use gzip.open instead of open
|
||
|
import struct
|
||
|
|
||
|
import torch
|
||
|
|
||
|
def read_MNIST_label_file(fname):
    """Read a gzip-compressed IDX1 label file (e.g. MNIST ``*-labels-idx1-ubyte.gz``).

    The file begins with a big-endian header of two unsigned 32-bit integers
    (magic number, item count), followed by one unsigned byte per label.

    Args:
        fname: Path to the gzip-compressed label file.

    Returns:
        A writable 1-D ``np.uint8`` array with one entry per label.

    Raises:
        ValueError: If the header is incomplete, the magic number is not the
            IDX1 unsigned-byte magic (2049 == 0x00000801), or the file holds
            fewer labels than the header declares.
    """
    # gzip.open (not open) because the files are stored compressed; the
    # context manager closes the handle even if parsing fails.
    with gzip.open(fname, 'rb') as fp:
        header = fp.read(8)
        if len(header) != 8:
            raise ValueError(f"{fname}: truncated IDX1 header")
        # IDX header integers are big-endian, which differs from the native
        # byte order on most machines, hence struct with '>'.
        magic, nitems = struct.unpack('>II', header)
        if magic != 2049:
            raise ValueError(f"{fname}: bad IDX1 magic number {magic:#010x}")
        bts = fp.read(nitems)

    if len(bts) != nitems:
        raise ValueError(f"{fname}: expected {nitems} labels, found {len(bts)}")
    # frombuffer yields a read-only view of the bytes; copy so callers (and
    # torch.tensor) receive an ordinary writable array.
    return np.frombuffer(bts, dtype=np.uint8).copy()
|
||
|
|
||
|
def read_MNIST_image_file(fname):
    """Read a gzip-compressed IDX3 image file (e.g. MNIST ``*-images-idx3-ubyte.gz``).

    The file begins with a big-endian header of four unsigned 32-bit integers
    (magic number, image count, rows, columns), followed by the pixels of
    every image as unsigned bytes in row-major order.

    Args:
        fname: Path to the gzip-compressed image file.

    Returns:
        A writable ``np.uint8`` array of shape ``(nitems, nrows, ncols)``.

    Raises:
        ValueError: If the header is incomplete, the magic number is not the
            IDX3 unsigned-byte magic (2051 == 0x00000803), or the file holds
            fewer pixels than the header declares.
    """
    with gzip.open(fname, 'rb') as fp:
        header = fp.read(16)
        if len(header) != 16:
            raise ValueError(f"{fname}: truncated IDX3 header")
        # Big-endian header integers, as in the label file.
        magic, nitems, nrows, ncols = struct.unpack('>IIII', header)
        if magic != 2051:
            raise ValueError(f"{fname}: bad IDX3 magic number {magic:#010x}")
        npixels = nitems * nrows * ncols
        # Read every image in one call instead of looping per image.
        bts = fp.read(npixels)

    if len(bts) != npixels:
        raise ValueError(f"{fname}: expected {npixels} pixel bytes, found {len(bts)}")
    # Pixels are stored contiguously image after image, row-major, so a single
    # reshape recovers the stack. copy() makes the result writable, as the
    # np.zeros-based buffer used to be.
    return np.frombuffer(bts, dtype=np.uint8).reshape(nitems, nrows, ncols).copy()
|
||
|
|
||
|
def mnist_load(baseloc="../training_data"):
    """Load the complete MNIST train/test split into memory as torch tensors.

    The MNIST dataset is small enough to fit entirely in memory, so all four
    files are decoded eagerly.

    Args:
        baseloc: Directory containing the four standard gzip-compressed MNIST
            files. Defaults to ``"../training_data"``, which is resolved
            relative to the current working directory, not to this file.

    Returns:
        A list ``[labels_train, labels_test, images_train, images_test]``.
        Label tensors are 1-D; image tensors are ``(N, rows, cols)``. All are
        ``torch.float32`` with ``requires_grad=False``, and pixel values are
        not rescaled (they keep their raw 0-255 byte values).

    NOTE(review): labels are returned as float32; losses such as
    ``nn.CrossEntropyLoss`` expect integer class targets, so callers may need
    ``.long()``.
    """
    traindatafile = os.path.join(baseloc, "train-images-idx3-ubyte.gz")
    trainlabelfile = os.path.join(baseloc, "train-labels-idx1-ubyte.gz")
    testdatafile = os.path.join(baseloc, "t10k-images-idx3-ubyte.gz")
    testlabelfile = os.path.join(baseloc, "t10k-labels-idx1-ubyte.gz")

    labels_train = read_MNIST_label_file(trainlabelfile)
    labels_test = read_MNIST_label_file(testlabelfile)
    images_train = read_MNIST_image_file(traindatafile)
    images_test = read_MNIST_image_file(testdatafile)

    def _to_float_tensor(arr):
        # torch.tensor copies the numpy data into a new float32 tensor.
        return torch.tensor(arr, dtype=torch.float32, requires_grad=False)

    return [
        _to_float_tensor(labels_train),
        _to_float_tensor(labels_test),
        _to_float_tensor(images_train),
        _to_float_tensor(images_test),
    ]
|
||
|
|
||
|
if __name__ == "__main__":
    # Smoke-run the loader and report the shape of each returned tensor.
    labels_train, labels_test, images_train, images_test = mnist_load()
    print("Loaded MNIST Data")
    for loaded in (labels_train, labels_test, images_train, images_test):
        print(loaded.shape)
|