You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

92 lines
3.0 KiB
Python

#!/usr/bin/python3
import os,sys,math
import numpy as np
import cv2
import gzip #need to use gzip.open instead of open
import struct
import torch
def read_MNIST_label_file(fname):
    """Read a gzip-compressed MNIST IDX1 label file into a uint8 array.

    The IDX1 layout is a big-endian header of two uint32 values (magic
    number 2049, item count) followed by one unsigned byte per label.

    Args:
        fname: Path to the gzip-compressed label file
            (e.g. ``train-labels-idx1-ubyte.gz``).

    Returns:
        A writable 1-D ``np.uint8`` array with one entry per label.

    Raises:
        ValueError: If the header is short, the magic number is not 2049,
            or the file holds fewer label bytes than the header declares.
    """
    with gzip.open(fname, 'rb') as fp:
        header = fp.read(8)
        if len(header) != 8:
            raise ValueError(f"{fname}: truncated IDX1 header")
        # IDX integers are big-endian regardless of the host byte order.
        magic, nitems = struct.unpack('>II', header)
        if magic != 2049:
            raise ValueError(
                f"{fname}: bad IDX1 label magic number {magic} (expected 2049)"
            )
        payload = fp.read(nitems)
    if len(payload) != nitems:
        raise ValueError(
            f"{fname}: expected {nitems} labels, found {len(payload)}"
        )
    # frombuffer yields a read-only view of the bytes; copy so callers may mutate it.
    return np.frombuffer(payload, dtype=np.uint8).copy()
def read_MNIST_image_file(fname):
    """Read a gzip-compressed MNIST IDX3 image file into a uint8 array.

    The IDX3 layout is a big-endian header of four uint32 values (magic
    number 2051, image count, rows, columns) followed by the pixel bytes of
    every image, row-major, one unsigned byte per pixel.

    Args:
        fname: Path to the gzip-compressed image file
            (e.g. ``train-images-idx3-ubyte.gz``).

    Returns:
        A writable ``np.uint8`` array of shape ``(nitems, nrows, ncols)``.

    Raises:
        ValueError: If the header is short, the magic number is not 2051,
            or the file holds fewer pixel bytes than the header declares.
    """
    with gzip.open(fname, 'rb') as fp:
        header = fp.read(16)
        if len(header) != 16:
            raise ValueError(f"{fname}: truncated IDX3 header")
        # IDX integers are big-endian regardless of the host byte order.
        magic, nitems, nrows, ncols = struct.unpack('>IIII', header)
        if magic != 2051:
            raise ValueError(
                f"{fname}: bad IDX3 image magic number {magic} (expected 2051)"
            )
        nbytes = nitems * nrows * ncols
        # One bulk read instead of one read per image.
        payload = fp.read(nbytes)
    if len(payload) != nbytes:
        raise ValueError(
            f"{fname}: expected {nbytes} pixel bytes, found {len(payload)}"
        )
    images = np.frombuffer(payload, dtype=np.uint8).reshape((nitems, nrows, ncols))
    # frombuffer yields a read-only view of the bytes; copy so callers may mutate it.
    return images.copy()
# The MNIST dataset is small enough to fit entirely in memory.
def mnist_load(baseloc="../training_data"):
    """Load the MNIST train/test labels and images as float32 torch tensors.

    Args:
        baseloc: Directory holding the four gzip-compressed MNIST files
            (``train-images-idx3-ubyte.gz``, ``train-labels-idx1-ubyte.gz``,
            ``t10k-images-idx3-ubyte.gz``, ``t10k-labels-idx1-ubyte.gz``).
            The default ``"../training_data"`` is resolved relative to the
            current working directory, not to this script, so pass an
            explicit path when running from elsewhere.

    Returns:
        A list ``[labels_train, labels_test, images_train, images_test]`` of
        ``torch.float32`` tensors created with ``requires_grad=False``.
        Label tensors are 1-D; image tensors are ``(nitems, nrows, ncols)``.
    """
    # NOTE(review): labels stay float32 for backward compatibility; losses
    # such as nn.CrossEntropyLoss take class-index targets as int64, so
    # callers may need ``.long()``.

    def _as_tensor(array):
        # torch.tensor copies the data, so the tensor owns its own memory.
        return torch.tensor(array, dtype=torch.float32, requires_grad=False)

    trainlabelfile = os.path.join(baseloc, "train-labels-idx1-ubyte.gz")
    testlabelfile = os.path.join(baseloc, "t10k-labels-idx1-ubyte.gz")
    traindatafile = os.path.join(baseloc, "train-images-idx3-ubyte.gz")
    testdatafile = os.path.join(baseloc, "t10k-images-idx3-ubyte.gz")

    labels_train = _as_tensor(read_MNIST_label_file(trainlabelfile))
    labels_test = _as_tensor(read_MNIST_label_file(testlabelfile))
    images_train = _as_tensor(read_MNIST_image_file(traindatafile))
    images_test = _as_tensor(read_MNIST_image_file(testdatafile))
    return [labels_train, labels_test, images_train, images_test]
if __name__ == "__main__":
    labels_train, labels_test, images_train, images_test = mnist_load()
    print("Loaded MNIST Data")
    # Report each loaded tensor's shape, in the order mnist_load returns them.
    for loaded in (labels_train, labels_test, images_train, images_test):
        print(loaded.shape)