Purpose Statement
We shall code a PyTorch Dataset to load data from the IDX-formatted files that store the MNIST hand-written numeral images and labels. This will include a study of the format documentation found on the source webpage. We shall then use the Dataset to train a simple feed-forward network to classify hand-written digits. The code is written for ease of understanding rather than for efficiency.
Introduction
The MNIST numeral images are perhaps the most famous introductory machine-learning dataset. The set contains 70,000 28-by-28 grey-scale images (60,000 for training and 10,000 for testing), each containing one hand-drawn digit. The set also has a label for each image that identifies which numeral the image contains. The "Hello, World!" of machine learning is then to train a model to identify which numeral is contained in such a 28-by-28 image.
What makes this dataset ideal as an introduction to machine learning is threefold: it is clean, meaning that the data are exactly what we want for the problem; it is large, allowing a basic model to achieve high performance; and it is easy to access, letting a programmer download and use the data in a couple of lines of code. There exist many convenient ways to access the MNIST numeral dataset, such as the built-in dataset loaders in TensorFlow and in PyTorch's torchvision library.
For this article, though, we will not be using one of these pre-built pipelines to access the MNIST numeral dataset. Instead, we shall write code to parse the IDX-format files, found on the source webpage, that contain these data. This lets us learn the basics of parsing a custom file format and of creating a custom PyTorch Dataset.
I am running Python 3.10.6 with PyTorch 2.0.0.
Finding & Organizing the Data
The MNIST data can be downloaded from yann.lecun.com. The data that you need to download are the four files with the .gz extension:
- train-images-idx3-ubyte.gz
- train-labels-idx1-ubyte.gz
- t10k-images-idx3-ubyte.gz
- t10k-labels-idx1-ubyte.gz
The first two files contain the training data and the last two contain the testing data. The training and testing data are kept separate because we want to ensure that any system we design to identify the numerals in an image has not simply memorized the images that we used to build it; thus we test our system on data that were not used in training the model.
All file paths given will be relative to the directory containing my Python script, say main.py. Let us now decompress these four files, which removes the .gz extension, and place them in the folder data/mnist-digits/. The path of the first file will therefore be data/mnist-digits/train-images-idx3-ubyte.
Parsing the IDX format
Now that we have the files, we need to know how to read them. Fortunately, there is documentation at the bottom of the same webpage describing the file format, which is called the IDX format. It is a binary format, so we must read the files as raw bytes rather than as text, and each file stores a single array with a specified number of axes.
Metadata
Let us begin our parsing code by opening the training images file, which we shall then read from. The Python open function, with the mode "rb", specifies that we want to read (r) from the file and that we want to open it in binary mode (b), so that its contents come back as raw bytes rather than decoded text. The function hence returns a file object that allows us to read the byte values stored in the file.
```python
# Open training images file
train_im = open("data/mnist-digits/train-images-idx3-ubyte", "rb")
```
The documentation describes the format byte by byte, telling us what values we should expect at each position in the file. About the first bytes, it says:
The magic number is an integer (MSB first). The first 2 bytes are always 0.
A magic number is some consistent value with which all files of a specific type begin. Here, we see that all IDX files begin with two bytes of zeros, that is, 16 bits with value 0. If we open a file and find that any of the first 16 bits is nonzero, then we may conclude that the file is not an IDX file.
To read the first two bytes of the file, getting the magic number, we run the following code:
```python
# Read magic number
magic_number = train_im.read(2)
```
We read the first two bytes of the file with train_im.read(2). We could now use this magic_number value to verify that the file we are reading could be an IDX file. However, I am simply going to assume that the file is IDX.
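If you would rather not simply assume the file is IDX, a check on the two leading bytes is cheap. The sketch below is one possible approach, not part of the original code: has_idx_magic is a hypothetical helper name, and io.BytesIO objects stand in for real files. One mistake it catches is forgetting to decompress a download, since gzip files begin with the bytes 0x1f 0x8b rather than two zeros.

```python
import io

def has_idx_magic(f):
    '''Hypothetical helper: read the next two bytes of the binary
    file object f and report whether both are zero, as IDX requires.
    False rules the file out; True does not prove it is IDX.'''
    return f.read(2) == b'\x00\x00'

# In-memory stand-ins (assumptions) for real files:
# the start of an IDX header: two zero bytes, type 0x08, 3 axes
idx_like = io.BytesIO(b'\x00\x00\x08\x03')
# a file that was never decompressed: gzip magic bytes 0x1f 0x8b
still_gzipped = io.BytesIO(b'\x1f\x8b\x08\x00')

print(has_idx_magic(idx_like))       # True
print(has_idx_magic(still_gzipped))  # False
```

Because it calls .read(2), the helper advances the file pointer exactly as the magic_number line does, so it could take that line's place, for example as `if not has_idx_magic(train_im): raise ValueError("not an IDX file")`.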
The .read method of our file object does not just read the next bytes of our file but also advances an internal pointer that determines the current position in the file. Therefore, the next time that we call .read, the first byte read will be the third byte of the file.
The next byte of the format is described thus in the documentation, where 0x## corresponds to the hex value ##:
The third byte codes the type of the data:
0x08: unsigned byte
0x09: signed byte
0x0B: short (2 bytes)
0x0C: int (4 bytes)
0x0D: float (4 bytes)
0x0E: double (8 bytes)
This third byte will tell us what data type is used to store each value in the array inside this file. Elsewhere in the documentation, it states that the MNIST numeral data are stored with unsigned bytes. Therefore, I will only explicitly include the code for handling unsigned bytes:
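```python
# Read data type
data_type = train_im.read(1)
# Set data sizes
if data_type == b'\x08':
    # set data type: elements are decoded with int.from_bytes
    data_type = int
    signed = False
    # size in bytes of one element
    datum_size = 1
# Code to check for other types (0x09, 0x0B, ...) is omitted here
else:
    raise NotImplementedError(f"IDX data type {data_type!r} is not handled")
```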
For each data type, we store how many bytes each element occupies, because later we must know how many bytes to read to get one element. We also save whether the data type is signed, to differentiate between signed and unsigned integers.
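The int.from_bytes approach covers the integer types, but the float and double types need a different decoder. As a hedged sketch of how the remaining type codes could be handled (the article's code handles only unsigned bytes), Python's standard struct module can decode every type in the table. The '>' prefix requests big-endian order, on the assumption that data values follow the same "MSB first" convention the documentation gives for the header integers. IDX_TYPES, datum_size_of, and read_element are hypothetical names.

```python
import struct

# Hypothetical table: IDX type byte -> big-endian struct format.
# The '>' prefix selects big-endian byte order and standard sizes.
IDX_TYPES = {
    0x08: '>B',  # unsigned byte
    0x09: '>b',  # signed byte
    0x0B: '>h',  # short (2 bytes)
    0x0C: '>i',  # int (4 bytes)
    0x0D: '>f',  # float (4 bytes)
    0x0E: '>d',  # double (8 bytes)
}

def datum_size_of(type_code):
    '''Number of bytes one element of this IDX type occupies.'''
    return struct.calcsize(IDX_TYPES[type_code])

def read_element(raw, type_code):
    '''Decode the bytes raw as one element of the given IDX type.'''
    # struct.unpack always returns a tuple; take its only entry
    return struct.unpack(IDX_TYPES[type_code], raw)[0]

print(datum_size_of(0x0E))                      # 8
print(read_element(b'\xff', 0x08))              # 255
print(read_element(b'\xff', 0x09))              # -1
print(read_element(b'\xff\xfe', 0x0B))          # -2
print(read_element(b'\x3f\x80\x00\x00', 0x0D))  # 1.0
```

With such a table, the if/else chain could shrink to a dictionary lookup on data_type[0], the integer value of the single type byte.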
Again from the documentation,
The 4-th byte codes the number of dimensions of the vector/matrix: 1 for vectors, 2 for matrices....
The fourth byte tells us the number of axes that the array contains. To get it, we read one more byte and convert it to an integer:
```python
# Number of axes
n_axes = int.from_bytes(train_im.read(1), byteorder = 'big')
```
The function int.from_bytes converts the one byte that was read with .read into an integer. The byteorder argument defines whether we interpret the bytes as big-endian or little-endian. Big endian means that bytes earlier in the file are the more significant digits, and little endian means the opposite. For example, the two bytes 0x01 0x00 represent 256 with big endian but 1 with little endian. (For a single byte, as here, the byte order makes no difference.)
This number, stored in n_axes, is the number of axis dimensions that we will need to read.
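To see why the byteorder argument matters for the 4-byte axis sizes read below, consider the first axis size of the training-images file, 60,000, which is 0x0000EA60 in hexadecimal. The snippet reads those four bytes both ways; the byte string is written out by hand for illustration rather than read from the file.

```python
# 60000 == 0xEA60, stored MSB first as four bytes
size_bytes = b'\x00\x00\xea\x60'

# Big endian (the IDX convention): first byte most significant
print(int.from_bytes(size_bytes, byteorder='big'))     # 60000
# Little endian misreads the same bytes as 0x60EA0000
print(int.from_bytes(size_bytes, byteorder='little'))  # 1625948160
```

A wrong byte order therefore raises no error; it silently produces an implausible shape, which is one reason to print the shape after parsing.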
From the basic outline below, we gather that the axis dimensions appear in order from the least to greatest axis starting from the fifth byte (where "magic number" necessarily means the actual magic number alongside the data type and number of axes):
magic number
size in dimension 0
size in dimension 1
size in dimension 2
.....
size in dimension N
data
We then see how each dimension is formatted:
The sizes in each dimension are 4-byte integers (MSB first, high endian, like in most non-Intel processors).
We now need to read in n_axes 4-byte integers to determine the size of each axis in the array, saving the dimensions into a tuple to store the shape of our data nicely:
```python
shape = []
for _ in range(n_axes):
    # Append each axis dimension to the shape
    shape.append(int.from_bytes(train_im.read(4), byteorder = 'big'))
# Convert the list to a tuple
shape = tuple(shape)
```
If we inspect the contents of shape, we see (60000, 28, 28), which correctly describes sixty thousand 28-by-28 images.
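The same header can also be read in two calls with Python's standard struct module, which unpacks several fixed-size big-endian fields at once. This is an alternative sketch rather than the article's method; read_idx_header is a hypothetical name, and the header bytes are built in memory with struct.pack to mimic the training-images file.

```python
import io
import struct

def read_idx_header(f):
    '''Hypothetical helper: read an IDX header from the binary file
    object f. Returns (type_code, shape) and leaves the file pointer
    at the first data byte.'''
    # '>HBB': big-endian unsigned short (the two zero bytes),
    # then two unsigned bytes: data type and number of axes
    zeros, type_code, n_axes = struct.unpack('>HBB', f.read(4))
    if zeros != 0:
        raise ValueError("not an IDX file: first two bytes are not zero")
    # n_axes big-endian unsigned 4-byte integers
    shape = struct.unpack('>' + 'I' * n_axes, f.read(4 * n_axes))
    return type_code, shape

# Hypothetical header matching the training-images file
header = struct.pack('>HBB3I', 0, 0x08, 3, 60000, 28, 28)
f = io.BytesIO(header)
print(read_idx_header(f))  # (8, (60000, 28, 28))
print(f.tell())            # 16
```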
Accessing Elements
Now, we have the actual data to parse.
The data is stored like in a C array, i.e. the index in the last dimension changes the fastest.
When the documentation says that the data are stored as in C, it means that the entire array is stored in one long, contiguous chunk, running from the first element to the last. Even a multi-dimensional array like a matrix must be stored in one straight line, because the file (like memory) is linear. The position of each element is as follows. Let $d_k$ be the dimension of axis $k$ in an array with dimensions $(d_1, \ldots, d_n)$, and let $i_k$ be the index that we are accessing along axis $k$. The position of an element with indices $(i_1, \ldots, i_n)$ is given by

$$p = i_n + \sum_{k=1}^{n-1} i_k \prod_{j=k+1}^{n} d_j$$
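To see why the equation has this form, consider a two-dimensional array: each one-dimensional array (row) is a contiguous chunk of $d_2$ elements, so reaching element $(i_1, i_2)$ means skipping $i_1$ whole rows and then $i_2$ more elements, giving $p = i_2 + i_1 d_2$. Similarly, in a three-dimensional array each two-dimensional array is internally contiguous, so each step along the first axis skips $d_2 d_3$ elements, et cetera.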
Now, we can implement the equation given above to get the position of one element in the file:
```python
def get_position(indices, shape):
    '''Given the indices for an element and
    the shape of the array, return the position
    in memory of the element, ranging from 0
    to the number of elements in the array minus 1
    :param indices: an array of indices [i1,...,in]
    :param shape: an array of axis dimensions [d1,...,dn]
    '''
    # start with last index
    pos = indices[-1]
    # Handle all other indices
    for i in range(0, len(shape) - 1):
        # calculate dimension product
        prod = 1
        for j in range(i + 1, len(shape)):
            prod *= shape[j]
        pos += indices[i] * prod
    return pos
```
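One way to convince ourselves that get_position implements C ordering is to compare it against itertools.product, which varies its last argument fastest: enumerating all index tuples of a small array in that order should yield positions 0, 1, 2, and so on. The sketch repeats get_position so it runs on its own, and then evaluates the formula for the element (6, 4, 13) used below; the small shape (2, 3, 4) is an arbitrary test case.

```python
from itertools import product

def get_position(indices, shape):
    '''Row-major (C-order) position of an element; same logic as above.'''
    pos = indices[-1]
    for i in range(len(shape) - 1):
        prod = 1
        for j in range(i + 1, len(shape)):
            prod *= shape[j]
        pos += indices[i] * prod
    return pos

# itertools.product varies its last argument fastest, which is exactly
# C order, so enumerating indices this way should count 0, 1, 2, ...
small_shape = (2, 3, 4)
for expected, idx in enumerate(product(*(range(d) for d in small_shape))):
    assert get_position(idx, small_shape) == expected

# Worked example for the MNIST image array: 13 + 4*28 + 6*(28*28)
print(get_position([6, 4, 13], (60000, 28, 28)))  # 4829
```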
Finally, we can compute the byte in the file at which a given element starts.
Since we have already read all of the metadata, the rest of the file is the actual data for the array stored in the IDX file, and the file pointer now points to the first element of that array. Let us save this position in the file for later reference:
```python
# Get start point of data in file
data_start = train_im.tell()
```
To get the byte offset, from the beginning of the data, of an element with indices (i, j, k), we get the position of the element in the data with get_position and then multiply this by the number of bytes that each element consumes.
```python
i, j, k = 6, 4, 13
# Access element i,j,k in file
pos = get_position([i,j,k], shape)
# Multiply by size of each element
byte_offset = pos * datum_size
# Get element address in file
address = data_start + byte_offset
# Move file pointer to address
train_im.seek(address)
# Get element at address
element = data_type.from_bytes(
    train_im.read(datum_size),
    byteorder = "big",
    signed = signed)
# Inspect element value
print(element) # Out: 255
```
The method .seek moves the file pointer to a specified byte in the file. After it is called, the next .read will begin at that address. If your code prints out the value 255 for element, then you probably coded everything correctly.
We say that we are loading data from the file dynamically because we do not load the entire array at once. Instead, we load only the parts of the file that hold the data we want to use. This technique is perhaps unnecessary for a dataset of this size, but it becomes essential for datasets too large to fit in memory, as in some real-world scenarios.
Building PyTorch Dataset
While heavily based on all the code written so far, what follows does not depend on it.
PyTorch is a machine learning library with functions and packages that aid in model creation, training, and deployment. Matplotlib will allow us to visualize our data. Let us import the libraries:
```python
# We shall use this to display images
from matplotlib import pyplot as plt
# PyTorch
import torch
```
To allow for easier data-loading to a PyTorch model, we now shall construct a PyTorch Dataset. Creating a PyTorch Dataset means writing a class that inherits from torch.utils.data.Dataset and implements three methods: the constructor, the length getter (__len__), and the procedure for accessing a datum (__getitem__), in this case a grey-scale image, given an index.
A lot of the code from before will be reused below to read the IDX format. However, we do not want to index individual elements as before; instead, we want to index each datum, which is a 28-by-28 array. For this reason, we treat each 28-by-28 array, which occupies contiguous positions in the file, as a single element that we read all at once as a PyTorch tensor.
We thus treat our N-dimensional data as though they were one-dimensional with entries that are (N-1)-dimensional. If we run dataset[4], we want the fifth element, whether that be the value 7 for one dataset or a multidimensional array for another.
The code for the dataset class is below. As before, I only include code to handle unsigned byte data because that is all that we need for this MNIST dataset.
```python
class IdxDataset(torch.utils.data.Dataset):
    '''A PyTorch map-style Dataset for reading an
    IDX format file as a dataset. This assumes that
    the first axis is the index for each datum.'''
    def __init__(self, filename):
        # Open IDX file
        self.f = open(filename, 'rb')
        # Read magic number
        magic_number = self.f.read(2)
        # get data type
        data_type = self.f.read(1)
        # Number of axes
        n_axes = int.from_bytes(self.f.read(1), byteorder = 'big')
        # Get shape of data
        shape = []
        for _ in range(n_axes):
            shape.append(int.from_bytes(self.f.read(4), byteorder = 'big'))
        self.shape = tuple(shape)
        # Set data sizes
        if data_type == b'\x08':
            self.data_type = torch.uint8
            # size of one element
            datum_size = 1
        ## Not included here: checks for other data types
        else:
            raise NotImplementedError(f"IDX data type {data_type!r} is not handled")
        # Get the size of one entry in elements
        self.entry_size = 1
        for dim in self.shape[1:]:
            self.entry_size *= dim
        # Get size of one entry in bytes
        self.entry_size_b = self.entry_size * datum_size
        # Get start point of data in file
        self.data_start = self.f.tell()
    def __len__(self):
        return self.shape[0]
    def __getitem__(self, idx):
        # Byte address of the first element of datum idx
        address = self.data_start + idx * self.entry_size_b
        self.f.seek(address)
        # Get an entry from the file: read entry_size_b bytes,
        # i.e. one datum's worth, and decode entry_size elements
        out = torch.frombuffer(
            bytearray(self.f.read(self.entry_size_b)),
            dtype = self.data_type,
            count = self.entry_size
        )
        # Reshape datum
        return out.reshape(self.shape[1:])
```
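The class field self.entry_size stores the number of positions that comprise each datum. For the MNIST numeral images, this is 28 × 28 = 784; for the labels file, whose shape is (60000,), the product over no remaining axes is 1, so each datum is a single value. We then multiply this value by the number of bytes in each number to get the number of bytes used per datum.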
Inside __getitem__, we add to the starting address of the data the index times the number of bytes per image, getting us to the starting byte of the desired image. We then use torch.frombuffer to read the bytes as a PyTorch tensor. The .read method returns a bytes object, which is immutable, and PyTorch warns when torch.frombuffer is given a non-writable buffer. For this reason, we copy the bytes into a bytearray, which is mutable. The tensor that we read in is one-dimensional, so we reshape it into the expected shape of each datum.
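The read-and-reshape step in __getitem__ can be tried in isolation on a few hand-made bytes. In this sketch the six byte values are arbitrary stand-ins for pixels of a 2-by-3 "image".

```python
import torch

# Six hypothetical pixel bytes standing in for a 2-by-3 image
raw = bytes([0, 64, 128, 192, 255, 7])

# Wrap in a bytearray so the buffer is writable, as in __getitem__
flat = torch.frombuffer(bytearray(raw), dtype=torch.uint8, count=6)
print(flat.shape)  # torch.Size([6])

# Row-major reshape: the first three bytes form the first row
image = flat.reshape((2, 3))
print(image)
```

Note that torch.frombuffer shares memory with the bytearray instead of copying it; since __getitem__ creates a fresh bytearray on every call, no two returned tensors share a buffer.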
We now can test this Dataset by loading arbitrary indices along with their labels from the appropriate files. In the code here, we shall display an arbitrary image along with printing its label.
```python
# Create two datasets for images and for labels
ds_im = IdxDataset("data/mnist-digits/train-images-idx3-ubyte")
ds_labels = IdxDataset("data/mnist-digits/train-labels-idx1-ubyte")
print(ds_labels[5385])
plt.imshow(ds_im[5385], cmap = 'gray')
plt.show() # Displays the image
```
To train a neural network with the MNIST data, though, we do not want a Dataset with only images or only labels. We want a PyTorch Dataset that loads an image together with its corresponding label, because a training algorithm must receive the input image, make a prediction, and then compare its prediction to the label to correct itself and learn. We therefore write the following simple Dataset that combines two Dataset objects into one:
```python
class CombinationDataset(torch.utils.data.Dataset):
    '''Dataset for combining two datasets,
    one being the inputs and one being the labels'''
    def __init__(self, dataset, labelset):
        self.data = dataset
        self.labels = labelset
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]
```
Using this new class, we can create a Dataset that stores both the data and the labels from the MNIST numeral set as follows:
```python
# Create dataset that contains both
# data (images) and labels (numerals)
ds = CombinationDataset(ds_im, ds_labels)
```
This Dataset, stored in the variable ds, is what we shall use to train our digit classifier.
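CombinationDataset trusts that its two datasets have the same length; pairing, say, the training images with the test labels would only fail later, when an index runs past the end of the shorter file. A hedged variant, under the hypothetical name CheckedCombinationDataset, checks the lengths up front. Plain Python lists serve as toy datasets here, since they also support len() and indexing.

```python
import torch

class CheckedCombinationDataset(torch.utils.data.Dataset):
    '''Hypothetical variant of CombinationDataset that refuses
    to pair datasets of different lengths.'''
    def __init__(self, dataset, labelset):
        if len(dataset) != len(labelset):
            raise ValueError(
                f"length mismatch: {len(dataset)} inputs vs {len(labelset)} labels")
        self.data = dataset
        self.labels = labelset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Toy stand-ins for the image and label datasets
pairs = CheckedCombinationDataset(['a', 'b', 'c'], [0, 1, 2])
print(len(pairs), pairs[1])  # 3 ('b', 1)
```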
Building Our DataLoader
As a final step before training, we need to prepare our DataLoader. Whereas a PyTorch Dataset provides streamlined access to data, a PyTorch DataLoader provides streamlined loading of that data for use in a machine-learning model. This loading will batch, randomize, and preprocess our data for the training algorithm.
We ultimately are going to use a standard feed-forward, densely connected network. The dense layers that we shall use expect each image as a one-dimensional vector of features, meaning that we need to flatten our images into long arrays. While flattening an image, the following formatting function also rescales the pixel values from the range [0, 255], the MNIST dataset's default, to the range [-1, 1], since inputs centered on zero typically make training better behaved:
```python
def format_image(im):
    '''Given a 2-D MNIST image, format it
    to be a flattened 1-D array on range [-1, 1]'''
    # Get flat, floating point array
    new_im = im.flatten().float().unsqueeze(0)
    # Rescale to be on [-1, 1]
    new_im = (new_im - 127.5) / 127.5
    return new_im
```
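A quick check of the rescaling: a pixel value of 0 should map to -1 and 255 to 1. The sketch repeats format_image so it runs on its own, and the 2-by-2 tensor is a made-up stand-in for an MNIST image.

```python
import torch

def format_image(im):
    '''Same as above: flatten to shape (1, H*W) and rescale [0, 255] to [-1, 1].'''
    new_im = im.flatten().float().unsqueeze(0)
    return (new_im - 127.5) / 127.5

# Hypothetical 2-by-2 "image" holding the extreme and middle values
im = torch.tensor([[0, 255], [127, 128]], dtype=torch.uint8)
out = format_image(im)
print(out.shape)                           # torch.Size([1, 4])
print(out[0, 0].item(), out[0, 1].item())  # -1.0 1.0
```

The leading dimension of size 1 added by unsqueeze(0) is what lets the collation function below concatenate many images into one batch.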
Because our CombinationDataset returns (image, label) tuples, a PyTorch DataLoader hands its collation function a list of these tuples, one per sample in the batch. We shall write the following custom collation function, which preprocesses each batch, for our DataLoader:
```python
def collate_fn(data):
    # Build arrays of X & y data
    X = []
    y = []
    for datum in data:
        # Normalization & flattening
        X.append(format_image(datum[0]))
        # .long() turns the uint8 label into the int64 class
        # index that CrossEntropyLoss expects; unsqueeze adds
        # an extra dimension, which we need to concat later
        y.append(datum[1].long().unsqueeze(0))
    # Concatenate arrays into PyTorch Tensor
    X = torch.concat(X)
    y = torch.concat(y)
    return X, y
```
With this collation function and the CombinationDataset that we wrote earlier, we can now create our DataLoader for model training:
```python
loader = torch.utils.data.DataLoader(ds, shuffle = True,
                                     batch_size = 64,
                                     collate_fn = collate_fn)
```
This DataLoader has shuffling turned on, meaning that each time we iterate over the DataLoader, the data are loaded in a different order, so the model does not see the samples in the same sequence every epoch, which generally helps training. We set the batch size to 64, meaning that 64 images will be loaded for each training iteration. Finally, we use our collation function to preprocess the data: before the DataLoader returns a batch, the batch is passed into collate_fn, and the output thereof becomes the DataLoader's output.
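Before training, it can help to pull one batch and confirm its shapes and dtypes. Since the real files may not be at hand, this sketch substitutes a hypothetical list of random 28-by-28 uint8 tensors with 0-dimensional uint8 labels (the same kinds of objects IdxDataset returns), together with a condensed version of the collation logic that casts labels to int64, the dtype CrossEntropyLoss expects for class indices.

```python
import torch

def format_image(im):
    new_im = im.flatten().float().unsqueeze(0)
    return (new_im - 127.5) / 127.5

def collate_fn(data):
    # Condensed collation: flatten/rescale images, cast labels to int64
    X = torch.concat([format_image(im) for im, _ in data])
    y = torch.concat([label.long().unsqueeze(0) for _, label in data])
    return X, y

# Hypothetical stand-in for the MNIST dataset: 100 random images
fake_ds = [(torch.randint(0, 256, (28, 28), dtype=torch.uint8),
            torch.tensor(i % 10, dtype=torch.uint8)) for i in range(100)]

loader = torch.utils.data.DataLoader(fake_ds, shuffle=True,
                                     batch_size=64, collate_fn=collate_fn)
X, y = next(iter(loader))
print(X.shape, X.dtype)  # torch.Size([64, 784]) torch.float32
print(y.shape, y.dtype)  # torch.Size([64]) torch.int64
```

With 100 samples, the second and final batch would hold the remaining 36, since drop_last defaults to False.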
Model Creation & Training
The PyTorch model that we shall code is a no-bells, no-whistles, feed-forward, densely-connected network with ReLU activation functions. The only methods that we need to define for a working PyTorch model are the initialization function and forward prediction function:
```python
class MnistModel(torch.nn.Module):
    '''A neural network for the MNIST dataset'''
    def __init__(self):
        super(MnistModel, self).__init__()
        self.linear1 = torch.nn.Linear(28*28, 16)
        self.activation1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(16, 16)
        self.activation2 = torch.nn.ReLU()
        self.linear3 = torch.nn.Linear(16, 16)
        self.activation3 = torch.nn.ReLU()
        # Output layer: one raw score (logit) per digit class.
        # No softmax here: CrossEntropyLoss applies log-softmax
        # itself, and argmax of the logits gives the same class.
        self.linear4 = torch.nn.Linear(16, 10)
    def forward(self, x):
        x = self.linear1(x)
        x = self.activation1(x)
        x = self.linear2(x)
        x = self.activation2(x)
        x = self.linear3(x)
        x = self.activation3(x)
        x = self.linear4(x)
        return x
```
Note that the first layer accepts one-axis data of size 784 (28 × 28), the shape of our flattened MNIST images, and that the last layer has ten output nodes, one for each numeral among which we want to classify.
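For a sense of the model's size, the same stack of Linear and ReLU layers can be written as a torch.nn.Sequential and its parameters counted. This is an illustrative restatement, not the article's class, and it ends in raw scores with no softmax. Each Linear(n_in, n_out) layer holds n_in × n_out weights plus n_out biases, so the total is (784·16 + 16) + 2·(16·16 + 16) + (16·10 + 10) = 13,274.

```python
import torch

# Illustrative nn.Sequential with the same Linear/ReLU stack
seq_model = torch.nn.Sequential(
    torch.nn.Linear(28 * 28, 16), torch.nn.ReLU(),
    torch.nn.Linear(16, 16), torch.nn.ReLU(),
    torch.nn.Linear(16, 16), torch.nn.ReLU(),
    torch.nn.Linear(16, 10),
)

# Count weights and biases across all layers
n_params = sum(p.numel() for p in seq_model.parameters())
print(n_params)  # 13274

# A batch of 64 flattened images maps to 64 rows of 10 scores
print(seq_model(torch.zeros(64, 28 * 28)).shape)  # torch.Size([64, 10])
```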
Next, we initialize this model and train it using our MNIST DataLoader.
```python
# Initialize model
model = MnistModel()
# Set hyperparameters
N_EPOCHS = 5
LR = 0.00005
# CrossEntropyLoss expects raw scores (logits) and integer (long)
# class labels; it applies log-softmax internally
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = LR)
# Run N_EPOCHS iterations of entire dataset
for i in range(N_EPOCHS):
    avg_loss = 0.0
    for data, labels in loader:
        # clear gradients left over from the previous batch
        optimizer.zero_grad()
        # make predictions
        outputs = model(data)
        # calculate error (labels must be int64 class indices)
        loss = loss_fn(outputs, labels.long())
        # calculate gradient
        loss.backward()
        # update model weights/parameters
        optimizer.step()
        # track average loss for epoch (.item() detaches from the graph)
        avg_loss += loss.item()
    avg_loss /= len(loader)
    print(f"EPOCH {i}: loss {avg_loss}")
```
Running the above training algorithm for five epochs should be sufficient for this problem.
Finally, we see how our model performs by loading in the test data with the classes we wrote earlier:
```python
# Load IDX data
ds_im_test = IdxDataset("data/mnist-digits/t10k-images-idx3-ubyte")
ds_labels_test = IdxDataset("data/mnist-digits/t10k-labels-idx1-ubyte")
# Combine images and labels to single dataset
test_ds = CombinationDataset(ds_im_test, ds_labels_test)
```
And then we test our model using these data:
```python
idx = 21 # Change this value
plt.imshow(test_ds[idx][0], cmap='gray')
print(model(format_image(test_ds[idx][0])).argmax())
plt.show()
```
Running those four lines of code and altering the idx variable to inspect different testing images, we can check whether the predicted numeral matches the one displayed; for most test images it should.
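Inspecting images one at a time gives a feel for the model, but a single number summarizing the whole test set is more informative. The sketch below computes test accuracy, assuming a DataLoader that yields (X, y) batches like the training loader; evaluate_accuracy is a hypothetical helper, and the commented usage lines assume the test_ds, collate_fn, and model objects from earlier.

```python
import torch

def evaluate_accuracy(model, loader):
    '''Hypothetical helper: fraction of samples whose arg-max
    prediction matches the label, over (X, y) batches from loader.'''
    model.eval()
    correct, total = 0, 0
    # No gradients are needed for evaluation
    with torch.no_grad():
        for X, y in loader:
            preds = model(X).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.numel()
    # Restore training mode (assumes the model was training before)
    model.train()
    return correct / total

# Usage with the objects defined earlier in the article:
# test_loader = torch.utils.data.DataLoader(test_ds, batch_size=256,
#                                           collate_fn=collate_fn)
# print(evaluate_accuracy(model, test_loader))

# Tiny self-check with a stand-in "model" that returns its input:
# rows of the identity matrix predict classes 0, 1, 2
print(evaluate_accuracy(torch.nn.Identity(),
                        [(torch.eye(3), torch.tensor([0, 1, 2]))]))  # 1.0
```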
Conclusion
After learning how to parse the IDX format with Python code, we wrote a custom PyTorch Dataset to read the MNIST numeral data. We then customized the PyTorch DataLoader to preprocess the data for our model, which we then trained to classify images from the MNIST dataset according to their contained numerals.