Fashion MNIST – Importing and Plotting in Python


The Fashion MNIST dataset is a more challenging replacement for the classic MNIST dataset. MNIST is a very popular dataset in machine learning and is often used to benchmark machine learning algorithms.

MNIST contains 70,000 grayscale images of handwritten digits from 0 to 9, each 28 x 28 pixels. It is often used for handwriting recognition.

Fashion MNIST uses the same format: 70,000 small, square, 28×28 pixel grayscale images, this time showing 10 types of clothing such as shoes, T-shirts and dresses.

The class labels in this dataset are as follows:

  • 0: T-shirt/top
  • 1: Trouser
  • 2: Pullover
  • 3: Dress
  • 4: Coat
  • 5: Sandal
  • 6: Shirt
  • 7: Sneaker
  • 8: Bag
  • 9: Ankle boot
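Keras returns the labels as plain integers from 0 to 9, so a small lookup table is handy for turning them back into names when printing or plotting. The sketch below simply copies the list above; the names `CLASS_NAMES` and `label_name` are our own hypothetical helpers, not part of Keras.

```python
# Hypothetical helper: class names in label order, copied from the list above
CLASS_NAMES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def label_name(label):
    """Return the clothing name for an integer label between 0 and 9."""
    # int() also accepts NumPy integer types such as the uint8 labels Keras returns
    return CLASS_NAMES[int(label)]


print(label_name(9))  # Ankle boot
```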

In this tutorial, we will use Keras to load the Fashion MNIST dataset and then plot it with matplotlib.

Importing the Fashion MNIST dataset from Keras

Let’s start by importing the dataset from Keras. Use the following lines of code to do that:

from keras.datasets import fashion_mnist
# load the training and test splits (downloaded on first use)
(train_X, train_y), (test_X, test_y) = fashion_mnist.load_data()

After loading the dataset, we can print the shapes of the training and test arrays.

print('X_train: ' + str(train_X.shape))
print('Y_train: ' + str(train_y.shape))
print('X_test:  '  + str(test_X.shape))
print('Y_test:  '  + str(test_y.shape))

Output:

X_train: (60000, 28, 28)
Y_train: (60000,)
X_test:  (10000, 28, 28)
Y_test:  (10000,)

We can see that of the 70,000 images in total, 60,000 form the training set and the remaining 10,000 form the test set.
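To go one step beyond the shapes, you can also count how many images fall into each class. A minimal sketch using NumPy's `np.bincount`; `class_counts` is a hypothetical helper name, and the demo uses a few made-up labels so it runs without downloading the dataset.

```python
import numpy as np


def class_counts(labels, num_classes=10):
    """Count how many samples carry each label 0..num_classes-1."""
    # minlength makes sure classes that never occur still get a 0 entry
    return np.bincount(np.asarray(labels).ravel(), minlength=num_classes)


# Usage with the arrays returned by fashion_mnist.load_data():
#   print(class_counts(train_y))
#   print(class_counts(test_y))

# Small synthetic demo (not real Fashion MNIST labels)
demo_labels = np.array([0, 1, 1, 9, 9, 9], dtype=np.uint8)
print(class_counts(demo_labels))  # [1 2 0 0 0 0 0 0 0 3]
```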

Now let’s learn how to plot the Fashion MNIST dataset.

Plotting the Fashion MNIST dataset

To plot the dataset, we are going to use matplotlib.

We will first import the library and then use it to plot 9 images from the training set in a 3x3 grid, starting at index 100.

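from matplotlib import pyplot
# subplot(330 + 1 + i) selects cell i+1 of a 3x3 grid
for i in range(9):
  pyplot.subplot(330 + 1 + i)
  pyplot.imshow(train_X[i+100], cmap=pyplot.get_cmap('gray'))
# show the figure once, after all 9 images are drawn
pyplot.show()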
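As an alternative to nine separate subplots, the same 3x3 layout can be built as one large image with NumPy and displayed with a single `imshow` call. This is a sketch of a different technique than the tutorial's subplot approach; `make_grid` is a hypothetical helper, and the demo uses tiny synthetic arrays instead of real images.

```python
import numpy as np


def make_grid(images, rows=3, cols=3):
    """Tile rows*cols equally sized 2-D images into one large 2-D array."""
    images = np.asarray(images)[: rows * cols]
    n, h, w = images.shape
    if n != rows * cols:
        raise ValueError("need at least rows*cols images")
    # (rows*cols, h, w) -> (rows, cols, h, w) -> (rows, h, cols, w) -> (rows*h, cols*w)
    return images.reshape(rows, cols, h, w).transpose(0, 2, 1, 3).reshape(rows * h, cols * w)


# Usage with the training set:
#   grid = make_grid(train_X[100:109])
#   pyplot.imshow(grid, cmap='gray')
#   pyplot.show()

demo = np.arange(9 * 2 * 2).reshape(9, 2, 2)
print(make_grid(demo).shape)  # (6, 6)
```

The reshape/transpose trick avoids Python loops; the trade-off is that you lose per-image axes, so titles such as class names are easier to add with the subplot version.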

Complete Code

The complete code for importing and plotting the Fashion MNIST dataset is given below:

from keras.datasets import fashion_mnist
from matplotlib import pyplot

# load the training and test splits
(train_X, train_y), (test_X, test_y) = fashion_mnist.load_data()

# print the shapes of the arrays
print('X_train: ' + str(train_X.shape))
print('Y_train: ' + str(train_y.shape))
print('X_test:  ' + str(test_X.shape))
print('Y_test:  ' + str(test_y.shape))

# plot 9 training images (indices 100 to 108) in a 3x3 grid
for i in range(9):
  pyplot.subplot(330 + 1 + i)
  pyplot.imshow(train_X[i+100], cmap=pyplot.get_cmap('gray'))
pyplot.show()


Output:

[Figure: plotted Fashion MNIST training images]

Conclusion

This tutorial covered importing and plotting the Fashion MNIST dataset, a more challenging version of the classic MNIST dataset.

After importing the dataset, you can build a convolutional neural network (CNN) and train it on this dataset to recognize these 10 clothing items in an image. To learn how to import and plot the MNIST dataset, refer to this tutorial.
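Before training such a network, the raw arrays usually need a little preparation. A minimal sketch, assuming a channel-last CNN (the Keras default image layout) and one-hot labels for categorical cross-entropy; `preprocess` is a hypothetical helper, and the demo runs on synthetic arrays rather than the real dataset.

```python
import numpy as np


def preprocess(images, labels, num_classes=10):
    """Scale uint8 pixels to [0, 1], add a channel axis and one-hot encode labels."""
    x = images.astype("float32") / 255.0               # (N, 28, 28) floats in [0, 1]
    x = x[..., np.newaxis]                              # (N, 28, 28, 1) channel-last
    y = np.eye(num_classes, dtype="float32")[labels]    # (N, num_classes) one-hot rows
    return x, y


# Usage with the loaded arrays:
#   train_x, train_y_onehot = preprocess(train_X, train_y)

demo_images = np.full((2, 28, 28), 255, dtype=np.uint8)
demo_x, demo_y = preprocess(demo_images, np.array([0, 9]))
print(demo_x.shape, demo_y.shape)  # (2, 28, 28, 1) (2, 10)
```

If you train with `sparse_categorical_crossentropy` instead, you can skip the one-hot step and pass the integer labels directly.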
