---
title: "Creating a Petastorm dataset from MNIST example"
date: 2021-05-03
type: technical_note
draft: false
---

## Creating a Petastorm MNIST dataset
In this notebook we are going to create a Petastorm dataset from the well-known MNIST dataset. Compared to ImageNette, MNIST has the advantage of being readily available through PyTorch (via torchvision). It is also considerably smaller, which makes it easier to experiment with.

```python
from hops import hdfs
import numpy as np
from torchvision.datasets import MNIST
```

Running this first cell starts a Spark application, and a `SparkSession` becomes available as `spark`. The dataset generation code below relies on this session.

```python
# Not used below: `path` is reassigned to the MNIST directory in the next cell
path = hdfs.project_path() + "Resources/Petastorm"
```

### Downloading the dataset with torchvision
Torchvision provides a simple interface for downloading MNIST. Note that downloading with torchvision versions prior to 0.9.1 is broken! If you run into this, upgrade your installation to the latest version; for other workarounds, see [here](https://stackoverflow.com/questions/66577151/http-error-when-trying-to-download-mnist-data). As the output below shows, newer versions fall back to a mirror when the original host is unavailable.
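A small guard can catch the broken-download case before calling `MNIST(..., download=True)`. This is a minimal sketch: the version parsing is a simplified, hand-rolled assumption (it drops local suffixes such as `+cu111` and ignores pre-release tags), and `parse_version` and `mnist_download_supported` are hypothetical helpers, not torchvision APIs.

```python
import re

# Minimum version taken from the note above; older releases fail to download MNIST
MIN_VERSION = (0, 9, 1)


def parse_version(version):
    """Turn a version string like '0.9.1+cu111' into a comparable tuple (0, 9, 1)."""
    release = version.split("+", 1)[0]  # drop the local suffix, e.g. '+cu111'
    parts = []
    for piece in release.split(".")[:3]:
        match = re.match(r"\d+", piece)  # keep only the leading digits of each part
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:
        parts.append(0)  # pad short versions such as '0.9'
    return tuple(parts)


def mnist_download_supported(version):
    """Return True if this torchvision version is assumed to download MNIST correctly."""
    return parse_version(version) >= MIN_VERSION


if __name__ == "__main__":
    try:
        import torchvision
        print(torchvision.__version__, mnist_download_supported(torchvision.__version__))
    except ImportError:
        print("torchvision is not installed")
```

If the check returns `False`, upgrading torchvision is the simplest fix, as suggested above.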

```python
path = hdfs.project_path() + "DataSets/MNIST"
# Training split: 60,000 images
train_dataset = MNIST(path, download=True)
# Test split: 10,000 images
test_dataset = MNIST(path, download=True, train=False)
```

```text
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to hdfs://rpc.namenode.service.consul:8020/Projects/PyTorch_spark_minimal/DataSets/MNIST/MNIST/raw/train-images-idx3-ubyte.gz
Extracting hdfs://rpc.namenode.service.consul:8020/Projects/PyTorch_spark_minimal/DataSets/MNIST/MNIST/raw/train-images-idx3-ubyte.gz to hdfs://rpc.namenode.service.consul:8020/Projects/PyTorch_spark_minimal/DataSets/MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to hdfs://rpc.
```

### Setting up the Petastorm dataset generation
Now that we have our dataset, creating the Petastorm dataset works exactly the same way as with ImageNette. Note that distributed training requires an even dataset, meaning that each node sees the same number of examples. If your dataset is not even, you can increase the number of Parquet files to allow a more fine-grained distribution among nodes.
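To see why more files help, here is a pure-Python sketch under simplifying assumptions: Spark's `repartition` is treated as splitting rows into exactly equal files, and whole files are assigned to workers round-robin. The helpers `split_evenly` and `rows_per_worker` are hypothetical, and a real reader (for example Petastorm's sharding, which works on row groups) may assign data differently.

```python
def split_evenly(total_rows, num_files):
    """Split total_rows into num_files near-equal file sizes (idealized repartition)."""
    base, extra = divmod(total_rows, num_files)
    return [base + 1 if i < extra else base for i in range(num_files)]


def rows_per_worker(file_sizes, num_workers):
    """Assign whole files to workers round-robin and return the rows each worker reads."""
    totals = [0] * num_workers
    for i, size in enumerate(file_sizes):
        totals[i % num_workers] += size
    return totals


if __name__ == "__main__":
    # 60,000 MNIST training rows spread over 3 hypothetical workers
    for num_files in (100, 1000):
        per_worker = rows_per_worker(split_evenly(60000, num_files), num_workers=3)
        print(num_files, per_worker, "imbalance:", max(per_worker) - min(per_worker))
```

With 3 workers, 100 files leave one worker with 600 more rows than the others ([20400, 19800, 19800]), while 1,000 files shrink the gap to 60 rows ([20040, 19980, 19980]).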

```python
import numpy as np
from petastorm.codecs import NdarrayCodec, ScalarCodec
from petastorm.etl.dataset_metadata import materialize_dataset
from petastorm.unischema import Unischema, UnischemaField, dict_to_spark_row
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType

# Each row holds one 1x28x28 uint8 image and its integer label.
# The label dtype (np.int32) matches the 32-bit Spark IntegerType used by its codec.
MNISTSchema = Unischema('MNISTSchema', [
    UnischemaField('image', np.uint8, (1, 28, 28), NdarrayCodec(), False),
    UnischemaField('label', np.int32, (), ScalarCodec(IntegerType()), False)])


def row_generator(idx, dataset):
    """Convert the idx-th (PIL image, label) pair into a dict matching MNISTSchema."""
    img, label = dataset[idx]
    # Add a leading channel axis: (28, 28) -> (1, 28, 28)
    return {'image': np.expand_dims(np.array(img, dtype=np.uint8), axis=0),
            'label': int(label)}


def generate_MNIST_dataset(output_url, dataset):
    rowgroup_size_mb = 1
    rows_count = len(dataset)
    parquet_files_count = 100  # increase for a finer-grained split among nodes

    # Reuse the session started with the notebook (available as `spark`)
    spark = SparkSession.builder.getOrCreate()
    sc = spark.sparkContext
    # Wrap dataset materialization portion. Will take care of setting up spark environment variables as
    # well as save petastorm specific metadata
    with materialize_dataset(spark, output_url, MNISTSchema, rowgroup_size_mb):
        # `dataset` is captured by the lambda and serialized together with the mapped function
        rows_rdd = sc.parallelize(range(rows_count))\
            .map(lambda x: row_generator(x, dataset))\
            .map(lambda x: dict_to_spark_row(MNISTSchema, x))

        spark.createDataFrame(rows_rdd, MNISTSchema.as_spark_schema()) \
            .repartition(parquet_files_count) \
            .write \
            .mode('overwrite') \
            .parquet(output_url)
```
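The sketch below needs only NumPy. It checks that the conversion used in `row_generator` yields the `(1, 28, 28)` uint8 layout declared in `MNISTSchema`, then round-trips the array through `np.save`/`np.load`. Our assumption is that `NdarrayCodec` serializes arrays in roughly this way; treat that part as illustrative rather than as Petastorm's actual code. `fake_digit`, `to_schema_image`, and `ndarray_roundtrip` are hypothetical names.

```python
from io import BytesIO

import numpy as np


def to_schema_image(img):
    """Convert a 28x28 grayscale image (PIL or array) to the (1, 28, 28) uint8 layout."""
    return np.expand_dims(np.asarray(img, dtype=np.uint8), axis=0)


def ndarray_roundtrip(arr):
    """Serialize with np.save and read back with np.load (assumed to mirror NdarrayCodec)."""
    buf = BytesIO()
    np.save(buf, arr)
    return np.load(BytesIO(buf.getvalue()))


# Hypothetical stand-in for an MNIST image, with pixel values in 0..255
fake_digit = np.arange(28 * 28, dtype=np.int64).reshape(28, 28) % 256
image = to_schema_image(fake_digit)
restored = ndarray_roundtrip(image)
print(image.shape, image.dtype, np.array_equal(image, restored))
```

Adding the channel axis up front keeps the stored arrays in the channel-first layout that PyTorch convolution layers expect.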

### Generating the dataset
Now that everything is set up, we can define our output paths and generate the datasets.

```python
train_path = hdfs.project_path() + "DataSets/MNIST/PetastormMNIST/train_set"
test_path = hdfs.project_path() + "DataSets/MNIST/PetastormMNIST/test_set"
```

```python
generate_MNIST_dataset(train_path, train_dataset)
generate_MNIST_dataset(test_path, test_dataset)
```