Training a new model¶
In this example we’ll look at one of the “hello world” problems of deep learning - handwritten digit recognition. We’ll use the MNIST dataset, consisting of 70,000 labelled images of handwritten digits between 0 and 9, to train a convolutional neural network.
The dataset¶
In this case, we’ve created a dtool DataSet from the MNIST data. We can use the dtool CLI to see what we know about this DataSet:
$ dtool readme show http://bit.ly/2uqXxrk
---
dataset_name: MNIST handwritten digits
project: dtoolAI demonstration datasets
authors:
- Yann LeCun
- Corinna Cortes
- Christopher J.C. Burges
origin: http://yann.lecun.com/exdb/mnist/
usetype: train
This tells us some information about what the data are, who created them, and where we can go to find out more.
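The readme above is simple flat YAML: scalar keys plus a list of authors. As a rough illustration of how that structure maps to Python data, here is a minimal standard-library parser for readmes of exactly this flat shape (real code should use a YAML library such as PyYAML, or dtool's own API, rather than this sketch):

```python
# Minimal parser for a flat YAML-style readme like the one shown above.
# Illustrative only: real code should use PyYAML or the dtool API.
def parse_readme(text):
    data = {}
    current_list_key = None
    for line in text.splitlines():
        line = line.rstrip()
        if not line or line == "---":
            continue
        if line.startswith("- ") and current_list_key:
            # A "- item" line belongs to the most recent list-valued key.
            data[current_list_key].append(line[2:].strip())
        elif ":" in line:
            key, _, value = line.partition(":")
            key, value = key.strip(), value.strip()
            if value:
                data[key] = value
                current_list_key = None
            else:
                # A bare "key:" line starts a list (e.g. "authors:").
                data[key] = []
                current_list_key = key
    return data

readme = """---
dataset_name: MNIST handwritten digits
project: dtoolAI demonstration datasets
authors:
- Yann LeCun
- Corinna Cortes
- Christopher J.C. Burges
origin: http://yann.lecun.com/exdb/mnist/
usetype: train"""

info = parse_readme(readme)
print(info["dataset_name"])   # MNIST handwritten digits
print(len(info["authors"]))   # 3
```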
Training a network¶
We’ll start by using one of the helper scripts from dtoolAI to train a CNN. Later, we’ll look at what the script is doing.
$ mkdir example
$ python scripts/train_cnn_classifier_from_tensor_dataset.py http://bit.ly/2uqXxrk example mnistcnn
This will produce information about the training process, and then report where the dataset containing the trained model weights has been written, e.g.:
Wrote trained model (simpleScalingCNN) weights to file://N108176/Users/hartleym/projects/ai/dtoolai-p/example/mnistcnn
dtoolAI and URIs¶
In the example above, when we specified where the trained model should be written, we provided two parameters to the script: example and mnistcnn. The second of these, mnistcnn, gives the name of the output model; the first, example, is a base URI. This concept is explained in more detail in the dtool documentation; we’ll give a short summary here.
In general, when we create model training datasets and trained models, we want to store them in permanent, HTTP-accessible object storage with persistent URIs. However, since this requires setting up Amazon S3 or Microsoft Azure storage credentials, for simplicity we’ll work with filesystem URIs in these examples. URIs on a local filesystem are something of a special case. Properly qualified file URIs have a form like the example above:
file://N108176/Users/hartleym/projects/ai/dtoolai-p/example/mnistcnn
For convenience’s sake, we allow file URIs to be expressed as filesystem paths. As such, the URI above can be simplified to ./example/mnistcnn, and dtool will internally convert this into a full URI.
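Python's standard library performs a similar path-to-URI conversion, which makes the correspondence easy to see. Note one detail as an assumption-check: pathlib emits an empty authority (file:///...), whereas the dtool output above includes the local hostname (file://N108176/...), so the exact URI text differs slightly:

```python
from pathlib import Path

# Convert a relative filesystem path into a fully qualified file URI.
# pathlib produces an empty authority (file:///...); dtool inserts the
# local hostname instead, so the two forms differ in that detail.
uri = Path("example/mnistcnn").resolve().as_uri()
print(uri)
```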
Applying the trained model to test data¶
The simplest way to test our model is on another pre-prepared dataset - this allows us to quickly apply the model to many ready-labelled images and calculate its accuracy.
We have provided the MNIST test data as a separate dtool DataSet for this purpose, and we can apply our new model to this dataset like this:
$ python scripts/apply_model_to_tensor_dataset.py \
./example/mnistcnn http://bit.ly/2NVFGQd
7929/10000 correct
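The 7929/10000 figure is simply the number of test images whose predicted label matches the true label. Assuming we had the predicted and actual labels as two lists (the labels below are made up for illustration), the calculation is:

```python
# Count correctly classified examples; the label lists are invented
# for illustration and are not real model output.
def accuracy(predicted, actual):
    correct = sum(p == a for p, a in zip(predicted, actual))
    return correct, len(actual)

predicted = [7, 2, 1, 0, 4, 1, 4, 9, 5, 9]
actual    = [7, 2, 1, 0, 4, 1, 4, 9, 6, 9]

correct, total = accuracy(predicted, actual)
print(f"{correct}/{total} correct")  # 9/10 correct
```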
If we want to improve the model’s accuracy, we could try training it for longer. For example, to train it for 5 epochs (loops through the training dataset) rather than one, we can run our script again:
$ python scripts/train_cnn_classifier_from_tensor_dataset.py \
http://bit.ly/2uqXxrk example mnistcnn_epochs_5 --params n_epochs=5
This will train the model for longer.
Viewing the trained model metadata¶
One of the core features of dtoolAI is capture of references to training data and metadata about the training process. Let’s look at how we access those captured data for our newly trained model.
dtoolAI provides a helper script, dtoolai-provenance, for this purpose. This will show a model’s training metadata, the references to its training data, then the metadata for those training data.
$ dtoolai-provenance example/mnistcnn/
Network architecture name: dtoolai.simpleScalingCNN
Model training parameters: {'batch_size': 128,
'init_params': {'input_channels': 1, 'input_dim': 28},
'input_channels': 1,
'input_dim': 28,
'learning_rate': 0.01,
'n_epochs': 1,
'optimiser_name': 'SGD'}
Source dataset URI: http://bit.ly/2uqXxrk
Source dataset name: mnist.train
Source dataset readme:
---
dataset_name: MNIST handwritten digits
project: dtoolAI demonstration datasets
authors:
- Yann LeCun
- Corinna Cortes
- Christopher J.C. Burges
origin: http://yann.lecun.com/exdb/mnist/
usetype: train
We can see that the model dataset contains both information about how the model was trained (learning_rate, n_epochs and so on) and a reference to the training data, which we can follow to trace its provenance.
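Conceptually, this provenance record is just structured metadata stored alongside the model weights. The sketch below mirrors the fields shown in the output above as a plain dictionary serialised to JSON; the dictionary layout is an assumption for illustration, not dtoolAI's actual on-disk format:

```python
import json

# Hypothetical provenance record whose keys mirror the dtoolai-provenance
# output above. dtoolAI's real storage format may differ; this only
# illustrates persisting training parameters with the source dataset URI.
provenance = {
    "model_name": "dtoolai.simpleScalingCNN",
    "model_parameters": {
        "batch_size": 128,
        "input_channels": 1,
        "input_dim": 28,
        "learning_rate": 0.01,
        "n_epochs": 1,
        "optimiser_name": "SGD",
    },
    "source_dataset_uri": "http://bit.ly/2uqXxrk",
}

# Serialise, then read back, as a stand-in for writing to the model dataset.
record = json.dumps(provenance, indent=2, sort_keys=True)
restored = json.loads(record)
print(restored["source_dataset_uri"])  # http://bit.ly/2uqXxrk
```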
What the code is doing¶
We provide the Jupyter notebook TrainingExplained.ipynb to show how the training script uses dtoolAI’s library functions and classes to make capturing training metadata and parameters easier. This notebook is available here, or, if you have a local copy of the dtoolAI repository, in the notebooks directory.