Train a Model with PyTorch - CNN for Image Classification

Train a convolutional neural network for image classification in PyTorch on multiple GPUs, using Dask to distribute the work across a cluster!

In this example, we will train a model using the ResNet50 architecture to classify images of dogs by breed.

Dataset: Stanford Dogs
Model: ResNet50

import dask
from dask_saturn import SaturnCluster
from dask.distributed import Client
from dask_pytorch_ddp import data, dispatch

import re
import s3fs
import toolz
import math
import numpy as np
import multiprocessing as mp

import torch
from torchvision import datasets, transforms, models
import torch.distributed as dist
from torch import nn, optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import lr_scheduler
from torch.utils.data.sampler import SubsetRandomSampler

Load label targets

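s3a = s3fs.S3FileSystem(anon=True)  # public bucket, no credentials needed
with s3a.open('s3://saturn-public-data/dogs/imagenet1000_clsidx_to_labels.txt') as f:
    # Each entry is a bytes line such as b"1: 'goldfish, Carassius auratus',"
    imagenetclasses = [line.strip() for line in f.readlines()]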

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Start and connect cluster

cluster = SaturnCluster()
client = Client(cluster)
client.wait_for_workers(3)
client.restart()

Define functions

def prepro_batches(bucket, prefix):
    '''Initialize the custom Dataset class defined in dask-pytorch-ddp, apply transformations.'''
    transform = transforms.Compose([
        transforms.Resize(256),      # scale the shorter side to 256 px, keeping the aspect ratio
        transforms.CenterCrop(250),  # take a 250 x 250 crop from the center
        transforms.ToTensor(),       # PIL image -> float tensor in [0, 1], shape (C, H, W)
    ])
    whole_dataset = data.S3ImageFolder(bucket, prefix, transform=transform, anon=True)
    return whole_dataset
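To see what prepro_batches produces, here is a small pure-Python sketch of the size arithmetic. It assumes torchvision's documented rule for an integer Resize size (the shorter edge is scaled to that size and the longer edge keeps the aspect ratio, truncated to an int) and its center-crop offset rule. The 500 x 375 input is a hypothetical example, not a specific image from the dataset.

```python
def resize_shorter_side(width, height, size):
    """Return (new_width, new_height) after Resize(size) with an int size."""
    if width <= height:
        # Portrait or square: width is the shorter edge
        return size, int(size * height / width)
    # Landscape: height is the shorter edge
    return int(size * width / height), size


def center_crop_box(width, height, crop):
    """Return (left, top, right, bottom) of a square center crop."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


w, h = resize_shorter_side(500, 375, 256)
box = center_crop_box(w, h, 250)
print((w, h), box)
```

Because the shorter side is always 256 px after resizing, the 250 px crop always fits, and ToTensor then turns each RGB crop into a (3, 250, 250) tensor, so every image in a batch has the same shape.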

def get_splits_parallel(train_pct, dataset, batch_size, subset = False, workers = 1, seed = 42):
    '''Select two samples of data for training and evaluation.'''
    train_size = math.floor(len(dataset) * train_pct)

    # Shuffle with a fixed seed so that every DDP process holds out the same
    # images; otherwise one worker's eval images could be in another's train set.
    indices = np.random.default_rng(seed).permutation(len(dataset)).tolist()
    train_idx = indices[:train_size]
    test_idx = indices[train_size:]

    if subset:
        # Each worker draws its own random 1/workers share of each split
        train_idx = np.random.choice(train_idx, size = len(train_idx) // workers, replace=False)
        test_idx = np.random.choice(test_idx, size = len(test_idx) // workers, replace=False)

    train_sampler = SubsetRandomSampler(train_idx)
    test_sampler = SubsetRandomSampler(test_idx)

    # num_workers is the global set under "Assign Parameters": it is the number
    # of DataLoader processes per GPU worker, not the number of Dask workers.
    train_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=train_sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        multiprocessing_context=mp.get_context('fork')
    )

    test_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=test_sampler,  # held-out images only
        batch_size=batch_size,
        num_workers=num_workers,
        multiprocessing_context=mp.get_context('fork')
    )

    return train_loader, test_loader
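The split logic can be checked without a cluster. This stdlib sketch mirrors the index arithmetic of get_splits_parallel for several simulated ranks; it swaps NumPy for Python's random module, assumes the 20,580 images of Stanford Dogs, and uses hypothetical helper names. It shows why a shared seed matters: every rank then holds out the same evaluation images, so no rank trains on another rank's evaluation set.

```python
import math
import random


def split_indices(n_items, train_pct, seed):
    """Shuffle 0..n_items-1 with a fixed seed and cut it into train/eval lists."""
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)
    train_size = math.floor(n_items * train_pct)
    return indices[:train_size], indices[train_size:]


def subset_for_rank(idx, workers, rank_seed):
    """Draw a random 1/workers share of idx without replacement (subset=True)."""
    return random.Random(rank_seed).sample(idx, len(idx) // workers)


N_IMAGES = 20580  # Stanford Dogs

# Three simulated ranks, each computing its own split with the shared seed
splits = [split_indices(N_IMAGES, 0.8, seed=42) for _rank in range(3)]
train0, eval0 = splits[0]
train1, eval1 = splits[1]

# With subset=True and 3 workers, each rank keeps a third of its train split
subset_train = subset_for_rank(train0, workers=3, rank_seed=0)
print(len(train0), len(eval0), len(subset_train))
```

The per-rank seed in subset_for_rank is only there to make the sketch deterministic; get_splits_parallel leaves the subset draw unseeded so that ranks sample different images.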

def replace_label(dataset_label, model_labels):
    '''Reindex a dataset label to match the ImageNet label set.
    Returns (imagenet_line, imagenet_index), or None if no ImageNet class matches.'''

    # e.g. 'n02085620-Chihuahua' -> 'Chihuahua'
    label_string = re.search(r'n[0-9]+-([^/]+)', dataset_label).group(1).replace('_', ' ')

    for i in model_labels:
        i = str(i).replace('{', '').replace('}', '')
        model_label_str = re.search(r'''b["'][0-9]+: ["']([^/]+)["'],["']''', i)
        model_label_idx = re.search(r'''b["']([0-9]+):''', i).group(1)

        if model_label_str and re.search(re.escape(label_string), model_label_str.group(1).replace('_', ' ')):
            return i, model_label_idx
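replace_label re-runs its regexes over every ImageNet line for each breed folder. A variant that parses the label file once into a dict could look like the sketch below. It is an illustrative alternative, not part of dask-pytorch-ddp; the helper names are hypothetical, and the sample lines only assume the file format that replace_label's regexes expect (the real file has 1,000 entries).

```python
import re

# Sample lines shaped like the bytes read from imagenet1000_clsidx_to_labels.txt
SAMPLE_LINES = [
    b"{0: 'tench, Tinca tinca',",
    b"150: 'sea lion',",
    b"151: 'Chihuahua',",
    b"152: 'Japanese spaniel',",
    b"999: 'toilet tissue, toilet paper, bathroom tissue'}",
]

LINE_RE = re.compile(r"""(\d+):\s*["'](.+)["']""")


def build_imagenet_index(lines):
    """Parse the label file once into {class_index: label_text}."""
    index = {}
    for raw in lines:
        # Drop the dict braces and trailing comma left over from the file layout
        text = raw.decode("utf-8").strip().strip("{},")
        match = LINE_RE.search(text)
        if match:
            index[int(match.group(1))] = match.group(2)
    return index


def imagenet_index_for(folder_name, index):
    """Return the first ImageNet index whose label contains the breed, else None."""
    breed = re.search(r"n[0-9]+-([^/]+)", folder_name).group(1).replace("_", " ")
    for idx, label in sorted(index.items()):
        if re.search(re.escape(breed), label.replace("_", " ")):
            return idx
    return None


labels = build_imagenet_index(SAMPLE_LINES)
print(imagenet_index_for("n02085620-Chihuahua", labels))
```

Returning an int directly would let the class_to_idx mapping in run_training skip the int(...[1]) conversion, and an explicit None makes unmatched breeds easy to detect before training starts.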

Model Parameters

Aside from the Special Elements noted below, we can write this section essentially the same way we would write any other PyTorch training loop, so most of the workflow function will look familiar to PyTorch users. We use:

  • Cross Entropy Loss for our loss function
  • SGD (Stochastic Gradient Descent) with momentum 0.9 for our optimizer
  • ReduceLROnPlateau to lower the learning rate when the monitored loss stops improving

Each epoch has two stages: training and evaluation. We pass over the whole training set in batches of 100, then move to the evaluation step and pass over the whole evaluation set, also in batches of 100.
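To put numbers on one epoch, here is a quick calculation, assuming the 20,580 Stanford Dogs images, train_pct = 0.8, subset = False (so each rank iterates its full split) and PyTorch's default drop_last=False. The helper name is hypothetical.

```python
import math


def batches_per_epoch(n_items, batch_size, drop_last=False):
    """Batches a DataLoader yields for n_items samples."""
    if drop_last:
        return n_items // batch_size
    # The final, smaller batch is kept when drop_last=False
    return math.ceil(n_items / batch_size)


train_items, eval_items = 16464, 4116   # 80/20 split of 20,580 images
train_batches = batches_per_epoch(train_items, 100)
eval_batches = batches_per_epoch(eval_items, 100)
last_train_batch = train_items - (train_batches - 1) * 100
print(train_batches, eval_batches, last_train_batch)
```

With n_epochs = 100, each rank would therefore take 165 * 100 = 16,500 optimizer steps.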

Special Elements

  1. Model to GPU Resources
device = torch.device("cuda")
net = models.resnet50(pretrained=False)
model = net.to(device)

We need to make sure our model is assigned to a GPU resource; here we do it once, before the training loop begins. Because a model's weights and its inputs must live on the same device, we also move each batch of images and labels to the GPU inside the training and evaluation loops.

  2. DDP Wrapper

model = DDP(model)

Second, we need to enable PyTorch's DistributedDataParallel framework. We do this by wrapping the model in DDP(), our import alias for the PyTorch class torch.nn.parallel.DistributedDataParallel. There is a lot to know about DDP, but the key idea is this: every worker holds its own replica of the model and trains it on its own batches, and during the backward pass DDP averages the gradients across all workers. Every replica then applies the same update, so the copies stay in sync, and this is what lets the model train in parallel across our cluster. (See the PyTorch DDP documentation.)
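The gradient averaging can be illustrated without GPUs. This pure-Python sketch is a toy model, not DDP itself: all_reduce_mean is a hypothetical stand-in for the all-reduce that DDP runs, and the linear model with mean-squared-error loss is invented for the example. It shows that when ranks hold equal-sized shards, the averaged gradient equals the full-batch gradient, and that every replica ends up with identical weights.

```python
def grad_mse(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over one shard."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(values):
    """Stand-in for DDP's gradient all-reduce: every rank receives the average."""
    avg = sum(values) / len(values)
    return [avg] * len(values)


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.1, 5.9, 8.2]
w, lr = 0.5, 0.01

# Two simulated ranks, each holding an equal-sized shard of the batch
shards = [(xs[:2], ys[:2]), (xs[2:], ys[2:])]
local_grads = [grad_mse(w, sx, sy) for sx, sy in shards]
synced = all_reduce_mean(local_grads)

# Every replica applies the same averaged gradient, so weights stay identical
new_weights = [w - lr * g for g in synced]
full_batch_grad = grad_mse(w, xs, ys)
print(local_grads, synced[0], full_batch_grad)
```

The equality with the full-batch gradient relies on equal shard sizes and a mean-reduced loss; with uneven shards the simple average weights each shard equally rather than each sample.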

Discussing DDP
If you are curious what DDP is really doing under the hood, our blog has a detailed discussion and more tips about this same workflow: https://www.saturncloud.io/s/combining-dask-and-pytorch-for-better-faster-transfer-learning/

Model Training Workflow Function

def run_training(bucket, prefix, train_pct, batch_size,
                 n_epochs, base_lr, imagenetclasses,
                 n_workers = 1, subset = False):
    '''Load basic Resnet50 architecture untrained, run training over given epochs.
    Uses dataset from the path given as the pool from which to take the
    training and evaluation samples.'''

    worker_rank = int(dist.get_rank())  # this process's rank in the DDP group

    # --------- Format model and params --------- #
    device = torch.device("cuda")
    net = models.resnet50(pretrained=False)
    model = net.to(device)
    model = DDP(model)

    criterion = nn.CrossEntropyLoss().cuda()
    lr = base_lr
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode = 'min', patience = 2)

    # --------- Retrieve and format data for training and eval --------- #
    whole_dataset = prepro_batches(bucket, prefix)
    # Re-map each breed folder to its ImageNet class index (0-999)
    new_class_to_idx = {x: int(replace_label(x, imagenetclasses)[1]) for x in whole_dataset.classes}
    whole_dataset.class_to_idx = new_class_to_idx

    train, val = get_splits_parallel(train_pct, whole_dataset, batch_size=batch_size, subset = subset, workers = n_workers)
    dataloaders = {'train' : train, 'val': val}

    # --------- Start iterations --------- #
    count = 0
    t_count = 0

    for epoch in range(n_epochs):
        # --------- Training section --------- #
        model.train()  # Set model to training mode
        for inputs, labels in dataloaders["train"]:
            count += 1

            # Pass items to GPU
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Run model iteration
            outputs = model(inputs)

            # Format results (preds, perct and correct are available for logging)
            _, preds = torch.max(outputs, 1)
            perct = [torch.nn.functional.softmax(el, dim=0)[i].item() for i, el in zip(preds, outputs)]

            loss = criterion(outputs, labels)
            correct = (preds == labels).sum().item()

            # Zero the parameter gradients, backpropagate, update the weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # --------- Evaluation section --------- #
        model.eval()  # Set model to evaluation mode
        val_loss_sum = 0.0
        val_batches = 0
        with torch.no_grad():
            for inputs_t, labels_t in dataloaders["val"]:
                t_count += 1

                # Pass items to GPU
                inputs_t = inputs_t.to(device)
                labels_t = labels_t.to(device)

                # Run model iteration
                outputs_t = model(inputs_t)

                # Format results
                _, pred_t = torch.max(outputs_t, 1)
                perct_t = [torch.nn.functional.softmax(el, dim=0)[i].item() for i, el in zip(pred_t, outputs_t)]

                loss_t = criterion(outputs_t, labels_t)
                correct_t = (pred_t == labels_t).sum().item()
                val_loss_sum += loss_t.item()
                val_batches += 1

        # Lower the learning rate when the mean validation loss stops improving
        scheduler.step(val_loss_sum / max(val_batches, 1))
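run_training uses ReduceLROnPlateau(optimizer, mode='min', patience=2) and leaves PyTorch's defaults factor=0.1 and threshold=1e-4 (relative mode). As a rough mental model, this pure-Python sketch reproduces that rule; it omits cooldown, min_lr and eps handling, so it is a simplification rather than PyTorch's implementation, and the loss values are hypothetical.

```python
class PlateauSketch:
    """Simplified ReduceLROnPlateau rule for mode='min' (no cooldown, min_lr or eps)."""

    def __init__(self, lr, factor=0.1, patience=2, threshold=1e-4):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.best = float("inf")
        self.num_bad_epochs = 0

    def step(self, metric):
        # Relative threshold: an epoch only counts as an improvement
        # if it beats the best value so far by at least 0.01%
        if metric < self.best * (1 - self.threshold):
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        # Reduce once the metric has failed to improve for more than `patience` epochs
        if self.num_bad_epochs > self.patience:
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr


sched = PlateauSketch(lr=0.01)
val_losses = [2.0, 1.5, 1.6, 1.55, 1.52, 1.4]   # hypothetical per-epoch losses
lrs = [sched.step(v) for v in val_losses]
print(lrs)
```

With patience=2, the learning rate drops on the third consecutive epoch without improvement (the fifth epoch here), not the second.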

Assign Parameters

num_workers = 64  # DataLoader processes per GPU worker, read by get_splits_parallel
client.restart()  # Optional, but recommended - clears cluster memory

startparams = {'n_epochs': 100,
               'batch_size': 100,
               'train_pct': .8,
               'base_lr': 0.01,
               'imagenetclasses': imagenetclasses,
               'subset': False,
               'n_workers': 1}  # only used to size the subsets when subset=True

Run training workflow on cluster

Now we've done all the hard work, and just need to run our function! We pass the training function to dispatch.run from dask-pytorch-ddp, which distributes it correctly across our cluster. This creates futures and starts computing them.

Under the hood, dispatch.run() uses the client.submit() method to send one task to each worker, and collects the resulting futures in a list. You can confirm this by inspecting the returned list, here named futures: it holds one future per worker in the cluster, each shown as pending while training runs.

Why don't we use .map() in this function?
Recall that .map() lets the scheduler decide where tasks run: it can assign any task to any worker. That means we don't have the control we need to guarantee one and only one job per GPU, which matters here because DDP expects exactly one process per GPU, each with a unique rank. Instead, dispatch.run uses .submit() and pins each task to a specific worker. This way, every worker runs the same training function at the same time, each as one member of the DDP group, with exactly one job per worker.
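To make the one-task-per-worker idea concrete without a cluster, here is a stdlib sketch. In Dask, pinning is done by passing workers=[address] to Client.submit; the sketch swaps in a thread pool, and dispatch_sketch, train_stub and the worker addresses are all hypothetical. Its point is the bookkeeping: exactly one call per worker, each with a unique rank and the same world size, with the remaining keyword arguments forwarded the way **startparams are.

```python
from concurrent.futures import ThreadPoolExecutor


def train_stub(rank, world_size, address, n_epochs):
    """Hypothetical stand-in for run_training that just reports where it ran."""
    return {"rank": rank, "world_size": world_size, "address": address, "n_epochs": n_epochs}


def dispatch_sketch(worker_addresses, fn, **kwargs):
    """Submit exactly one call of fn per worker, each with a unique rank."""
    world_size = len(worker_addresses)
    with ThreadPoolExecutor(max_workers=world_size) as pool:
        futures = [
            pool.submit(fn, rank=rank, world_size=world_size, address=addr, **kwargs)
            for rank, addr in enumerate(worker_addresses)
        ]
        return [f.result() for f in futures]


# Hypothetical worker addresses; a real cluster reports its own
workers = ["tcp://10.0.0.1:40001", "tcp://10.0.0.2:40001", "tcp://10.0.0.3:40001"]
results = dispatch_sketch(workers, train_stub, n_epochs=100)
for r in results:
    print(r)
```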

futures = dispatch.run(
    client, 
    run_training, 
    bucket = "saturn-public-data", 
    prefix = "dogs/Images", 
    **startparams
)
futures

You may want to add steps to the workflow that checkpoint the model and/or save performance statistics. There are several ways to do this, including saving the metrics on the workers for later retrieval, or writing them to an external data store such as S3. If you need help with this, contact our support!
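One simple pattern is to let only rank 0 write, since all DDP replicas hold identical weights after each step. The sketch below appends per-epoch metrics as JSON lines to a local file on the worker; save_epoch_metrics, load_metrics, the path and the loss values are hypothetical. Note that each rank's losses come from its own batches, so rank 0's record reflects only its share of the data. A model checkpoint could follow the same rank check with torch.save(model.module.state_dict(), path), where .module is the model wrapped by DDP.

```python
import json
import os
import tempfile


def save_epoch_metrics(path, rank, epoch, train_loss, val_loss):
    """Append one epoch's metrics as a JSON line; only rank 0 writes."""
    if rank != 0:
        return False
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "a") as f:
        record = {"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss}
        f.write(json.dumps(record) + "\n")
    return True


def load_metrics(path):
    """Read the JSON-lines file back into a list of dicts."""
    with open(path) as f:
        return [json.loads(line) for line in f]


# Usage: two ranks report the same epoch; only rank 0's record is kept
out_dir = tempfile.mkdtemp()
path = os.path.join(out_dir, "metrics", "epochs.jsonl")
wrote_rank1 = save_epoch_metrics(path, rank=1, epoch=0, train_loss=1.9, val_loss=2.1)
wrote_rank0 = save_epoch_metrics(path, rank=0, epoch=0, train_loss=1.8, val_loss=2.0)
rows = load_metrics(path)
print(rows)
```

Inside run_training, a call like this would sit after the evaluation section, using worker_rank as the rank; writing to S3 instead would only change where the file goes.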



