Extending dtoolAI
dtoolAI provides everything needed to train image classification networks “out of the box”. Other types of deep learning network will require new models, and may also require new classes for managing training data.
New forms of training data
dtoolAI provides two classes for managing training data: TensorDataSet and ImageDataSet. Our examples use these to train models and capture provenance.
To support another form of training data, write a new class. The class should:
- Inherit from dtoolai.data.WrappedDataSet. This ensures that it provides both the methods required by PyTorch (to feed data into the model) and those required by dtoolAI (to capture metadata).
- Implement __len__, which should return the number of items in the dataset.
- Implement __getitem__, which should return either torch.Tensor objects or numpy arrays that PyTorch can convert to tensors.
Instances of this class can then be passed to dtoolai.training.train_model_with_metadata_capture.
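The requirements above can be sketched roughly as follows. This is a minimal illustration, not dtoolAI's own code: SlidingWindowDataSet and its constructor arguments are hypothetical, and StandInWrappedDataSet is an empty placeholder used only so that the sketch runs without a dtool dataset on disk. In real code the class would inherit from dtoolai.data.WrappedDataSet instead. Returning a (data, label) pair from __getitem__ is also an assumption about what the training loop expects.

```python
import numpy as np


class StandInWrappedDataSet:
    """Hypothetical placeholder for dtoolai.data.WrappedDataSet.

    The real base class supplies the metadata that dtoolAI records;
    its constructor arguments are not reproduced here.
    """


class SlidingWindowDataSet(StandInWrappedDataSet):
    """Hypothetical dataset: fixed-length windows over a 1D signal."""

    def __init__(self, signal, labels, window):
        self.signal = np.asarray(signal, dtype=np.float32)
        self.labels = np.asarray(labels, dtype=np.int64)
        self.window = window

    def __len__(self):
        # Number of complete windows that fit in the signal.
        return max(0, len(self.signal) - self.window + 1)

    def __getitem__(self, index):
        # Raising IndexError lets plain Python iteration terminate.
        if not 0 <= index < len(self):
            raise IndexError(index)
        data = self.signal[index:index + self.window]
        # Label each window by the label of its final sample.
        label = self.labels[index + self.window - 1]
        return data, label


dataset = SlidingWindowDataSet(
    signal=[0.0, 0.1, 0.2, 0.3, 0.4],
    labels=[0, 0, 1, 1, 0],
    window=3,
)
print(len(dataset))
data, label = dataset[1]
print(data.shape, int(label))
```

Both values returned by __getitem__ are numpy objects, which PyTorch can convert to tensors when batching. The call to dtoolai.training.train_model_with_metadata_capture is left out of the sketch because its arguments are not described in this section.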