Computer Vision at Scale With Dask and PyTorch

This tutorial walks through how to use PyTorch and Dask to train an image recognition model across a GPU cluster.

Applying deep learning strategies to computer vision problems has opened up a world of possibilities for data scientists. However, to use these techniques at scale to create business value, substantial computing resources need to be available – and this is just the kind of challenge Saturn Cloud is built to solve!

In this tutorial, you’ll see the steps for running image classification inference at scale with the popular ResNet50 deep learning model on NVIDIA GPU clusters in Saturn Cloud. Using the resources Saturn Cloud makes available, we can run the task nearly 40x faster than a non-parallelized approach!


We’ll be classifying dog images today!

What you’ll learn here:

  • How to set up and manage a GPU cluster on Saturn Cloud for deep learning inference tasks
  • How to run inference tasks with PyTorch on the GPU cluster
  • How to use batch processing to accelerate your inference tasks with PyTorch on the GPU cluster

Table of Contents

  1. Setup
    a. GPU Capability
  2. Inference
    a. Preprocessing
    b. Run the Model
    c. Put It All Together
    d. On the Cluster
  3. Evaluate
  4. Comparing Performance

Setup

To begin, we need to ensure that our image dataset is available and that our GPU cluster is running.

In our case, we have stored the data on S3 and use the s3fs library to work with it, as you’ll see below. If you would like to use this same dataset, it is the Stanford Dogs dataset, available here: http://vision.stanford.edu/aditya86/ImageNetDogs/

Setting up our Saturn GPU cluster is very straightforward.

import torch
from dask.distributed import Client
from dask_saturn import SaturnCluster

# Four GPU workers, each on a g4dn.8xlarge instance
cluster = SaturnCluster(n_workers=4, worker_size='g4dn8xlarge')
client = Client(cluster)
client

#> [2020-10-15 18:52:56] INFO -- dask-saturn | Cluster is ready

Although we don’t set it explicitly, each worker runs 32 threads (one per vCPU on a g4dn.8xlarge instance), for 128 threads in total across the four workers.

Tip: You may want to reduce the number of threads if your files are very large, because too many threads running large tasks simultaneously might require more memory than your workers have available at one time.

Creating the cluster may take a moment, because all the AWS instances we are requesting need to be spun up. dask-saturn prints a log message (shown above) once the cluster is ready to rock, and displaying client at the end shows a summary of the connected cluster.
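To put rough numbers on the memory tip, here is a back-of-the-envelope sketch. It assumes each preprocessed image is a float32 tensor of shape (3, 250, 250), which is what the preprocessing below produces, and it ignores JPEG decoding buffers, PIL intermediates, and model activations. The helper names are hypothetical.

```python
# Rough memory estimate for preprocessed image tensors (a sketch: it ignores
# JPEG decoding buffers, PIL intermediates and model activations).
BYTES_PER_FLOAT32 = 4


def tensor_bytes(channels: int, height: int, width: int) -> int:
    """Size in bytes of one float32 image tensor of shape (C, H, W)."""
    return channels * height * width * BYTES_PER_FLOAT32


def batch_megabytes(batch_size: int, channels: int = 3, side: int = 250) -> float:
    """Approximate size of one batch of square float32 image tensors, in MB."""
    return batch_size * tensor_bytes(channels, side, side) / 1e6


workers, threads_per_worker = 4, 32
total_threads = workers * threads_per_worker

per_image_mb = tensor_bytes(3, 250, 250) / 1e6
per_batch_mb = batch_megabytes(60)
# Worst case: every thread on one worker holds a full batch at the same time
per_worker_peak_mb = threads_per_worker * per_batch_mb

print(total_threads, per_image_mb, per_batch_mb, per_worker_peak_mb)
# 128 0.75 45.0 1440.0
```

About 1.4 GB of input tensors per worker is modest, but the figure grows linearly with threads and batch size and with the square of the image side, which is why lowering the thread count helps when files are large.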

GPU Capability

At this point, we can confirm that our cluster has GPU capabilities, and make sure we have set everything up correctly.

First, check that the Jupyter server has GPU capability.

torch.cuda.is_available()

#> True

Awesome! Now let’s also check each of our four workers.

client.run(lambda: torch.cuda.is_available())

#> {'tcp://10.0.24.217:45281': True,
#> 'tcp://10.0.28.232:36099': True,
#> 'tcp://10.0.3.136:40143': True,
#> 'tcp://10.0.3.239:40585': True}

Next, we set the “device” to cuda so we can use those GPUs, falling back to the CPU if no GPU is available.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Inference

Now, we’re ready to start doing some classification! We’re going to use some custom-written functions to do this efficiently and make sure our jobs can take full advantage of the parallelization of the GPU cluster.

Preprocessing

Single Image Processing

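import builtins
import re

import dask
from PIL import Image
from torchvision import transforms

@dask.delayed
def preprocess(path, fs=builtins):
    '''Ingest images directly from S3, apply transformations,
    and extract the ground truth and image identifier. Accepts
    a filepath and a filesystem object with an open() method
    (defaults to the local filesystem).'''

    # Resize the shorter side to 256 px, crop the central 250x250 region,
    # and convert to a float tensor of shape (3, 250, 250) scaled to [0, 1].
    # Note: torchvision documents that its pretrained ImageNet models expect
    # inputs normalized with mean=[0.485, 0.456, 0.406] and
    # std=[0.229, 0.224, 0.225] (transforms.Normalize); this transform omits that step.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(250),
        transforms.ToTensor()])

    with fs.open(path, 'rb') as f:
        img = Image.open(f).convert("RGB")
        nvis = transform(img)

    # Breed label (ground truth) and image identifier, parsed from the path
    truth = re.search('dogs/Images/n[0-9]+-([^/]+)/n[0-9]+_[0-9]+.jpg', path).group(1)
    name = re.search('dogs/Images/n[0-9]+-[a-zA-Z-_]+/(n[0-9]+_[0-9]+).jpg', path).group(1)

    return [name, nvis, truth]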
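To see what the two regular expressions pull out of a path, here is a small standalone sketch using only Python’s re module with the same patterns. The example paths are illustrative: they assume the Stanford Dogs folder layout (`<synset id>-<breed>/<synset id>_<number>.jpg`) that the patterns expect, and the patterns match with or without an s3:// prefix. parse_path is a hypothetical helper name.

```python
import re

# The same patterns used in preprocess()
TRUTH_PATTERN = 'dogs/Images/n[0-9]+-([^/]+)/n[0-9]+_[0-9]+.jpg'
NAME_PATTERN = 'dogs/Images/n[0-9]+-[a-zA-Z-_]+/(n[0-9]+_[0-9]+).jpg'


def parse_path(path):
    """Return (image identifier, breed label) extracted from an image path."""
    truth = re.search(TRUTH_PATTERN, path).group(1)  # folder name after the synset id
    name = re.search(NAME_PATTERN, path).group(1)    # file name without .jpg
    return name, truth


# Illustrative path in the Stanford Dogs layout (assumed, not listed in the article)
example = 'dask-datasets/dogs/Images/n02086240-Shih-Tzu/n02086240_1082.jpg'
print(parse_path(example))  # ('n02086240_1082', 'Shih-Tzu')
```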

This function allows us to process one image, but of course, we have a lot of images to work with here! We’re going to use some list comprehension strategies to create our batches and get them ready for our inference.

First, we list all the images under our S3 path and break that list into chunks of 60, which define the batches.

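import s3fs
import toolz

s3 = s3fs.S3FileSystem()  # assumption: not shown in the original; uses your AWS credentials
s3fpath = 's3://dask-datasets/dogs/Images/*/*.jpg'
batch_breaks = [list(batch) for batch in toolz.partition_all(60, s3.glob(s3fpath))]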
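toolz.partition_all splits an iterable into tuples of n items, with a shorter final tuple if the length isn’t a multiple of n. The sketch below is a stdlib stand-in with the same behavior (not the toolz source), plus a check of how many batches 20,580 images produce. The generated paths are placeholders for what s3.glob returns.

```python
from itertools import islice


def partition_all(n, iterable):
    """Yield successive tuples of up to n items; the last one may be shorter.

    A stdlib stand-in that mirrors the behaviour of toolz.partition_all.
    """
    iterator = iter(iterable)
    while True:
        chunk = tuple(islice(iterator, n))
        if not chunk:
            return
        yield chunk


# Placeholder for the 20,580 image paths returned by s3.glob(s3fpath)
paths = [f"image_{i}.jpg" for i in range(20580)]
batch_breaks = [list(batch) for batch in partition_all(60, paths)]

print(len(batch_breaks))                                # 343
print([len(b) for b in partition_all(60, range(125))])  # [60, 60, 5]
```

Since 20,580 = 343 × 60, the Stanford Dogs images split into 343 full batches with no short remainder.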

Next, we apply preprocess to every file, which produces a nested list with one inner list per batch. After a small reformatting step, we’re ready to go!

image_batches = [[preprocess(x, fs=s3) for x in y] for y in batch_breaks]

Notice that preprocess is wrapped in the Dask delayed decorator, so none of this runs yet. Calling it only records the work to be done; the actual processing waits until we run everything in parallel on the GPU cluster!

Format Batches

This little step just makes sure that the batches of images are organized in the way that the model will expect them.

@dask.delayed
def reformat(batch):
    '''Regroup a batch of [name, tensor, label] items
    into [names, tensors, labels].'''
    names = [x[0] for x in batch]
    tensors = [x[1] for x in batch]
    labels = [x[2] for x in batch]
    return [names, tensors, labels]

image_batches = [reformat(result) for result in image_batches]
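The regrouping inside reformat is a transpose of the batch, so the same result can be had with zip(*batch). Here is a standalone sketch of that idea; regroup is a hypothetical name, and strings stand in for tensors so it runs without PyTorch.

```python
# Each preprocessed item is [name, tensor, label]; strings stand in for tensors.
batch = [
    ["img_a", "tensor_a", "Shih-Tzu"],
    ["img_b", "tensor_b", "beagle"],
    ["img_c", "tensor_c", "pug"],
]


def regroup(batch):
    """Turn a non-empty list of [name, tensor, label] items into [names, tensors, labels]."""
    names, tensors, labels = (list(column) for column in zip(*batch))
    return [names, tensors, labels]


print(regroup(batch))
# [['img_a', 'img_b', 'img_c'], ['tensor_a', 'tensor_b', 'tensor_c'], ['Shih-Tzu', 'beagle', 'pug']]
```

The explicit list comprehensions in reformat are longer but read more plainly; the zip form fails on an empty batch, which can’t occur here because partition_all never yields one.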

Run the Model

Now we are ready to do the inference task! This is going to have a few steps, all of which are contained in functions described below, but we’ll talk through them so everything is clear.

Our unit of work at this point is batches of 60 images at a time, which we created in the section above. They are all neatly arranged in lists so that we can work with them effectively.

One thing we need to do with these lists is to “stack” the tensors: torch.stack combines the list of 60 image tensors, each of shape (3, 250, 250), into a single tensor of shape (60, 3, 250, 250) that the model can process in one pass. We could do this earlier in our process, but because the preprocessing is wrapped in Dask delayed, our functions don’t actually receive tensors until the task graph runs. Therefore, we delay the stacking as well by putting it inside this function, which runs after the preprocessing.

@dask.delayed
def run_batch_to_s3(iteritem):
    ''' Accepts iterable result of preprocessing,
    generates inferences and evaluates. '''

    with s3.open('s3://dask-datasets/dogs/imagenet1000_clsidx_to_labels.txt') as f:
        classes = [line.strip() for line in f.readlines()]

    names, images, truelabels = iteritem

    images = torch.stack(images)
...

So now we have our tensors stacked so that batches can be passed to the model. We are going to retrieve our model using pretty simple syntax:

...
    resnet = models.resnet50(pretrained=True)
    resnet = resnet.to(device)
    resnet.eval()
...

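Conveniently, the torchvision library contains several useful pretrained models and datasets; that’s where we get ResNet50. Calling .to(device) moves the model’s weights onto the GPU of the worker running the task, and resnet.eval() puts the model in inference mode, so that its batch normalization layers use their stored running statistics instead of statistics from the current batch.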

Now we are ready to run inference! It happens inside the same function:

...
    images = images.to(device)
    pred_batch = resnet(images)
...

We move our image stack (just the batch we are working on) to the GPU and run the model on it. The result, pred_batch, is a tensor of raw scores with one row per image and one column for each of the 1,000 ImageNet classes.

Result Evaluation

The predictions and truth we have so far, however, are not really human-readable or comparable, so we’ll use the functions that follow to fix them up and get us interpretable results.

def evaluate_pred_batch(batch, gtruth, classes):
    ''' Accepts a batch of model outputs, returns human-readable
    top-1 predictions and the matching ground truth labels. '''
    _, indices = torch.sort(batch, descending=True)
    # Softmax per image (dim=1), so row i holds image i's class probabilities
    percentage = torch.nn.functional.softmax(batch, dim=1) * 100

    preds = []
    labslist = []
    for i in range(len(batch)):
        # Top-1 class for image i, paired with image i's own probability
        pred = [(classes[idx], percentage[i][idx].item()) for idx in indices[i][:1]]
        preds.append(pred)

        labs = gtruth[i]
        labslist.append(labs)

    return(preds, labslist)

This takes the model’s raw scores, the ground truth labels, and the list of class names, and returns readable top-1 predictions along with the probability the model assigned to each.

preds, labslist = evaluate_pred_batch(pred_batch, truelabels, classes)
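To make the per-image logic concrete, here is a dependency-free sketch of top-1 decoding. Each image’s score row gets its own softmax, and the reported percentage is that same image’s probability for its winning class; taking the softmax of only the first row and reusing it for every image would report the wrong confidence. The class names and scores are toy values, not real model output, and the helper names are hypothetical.

```python
import math


def softmax(row):
    """Numerically stable softmax of one list of scores."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def top1_with_percent(batch_scores, classes):
    """For each image's score row, return (class name, percent) for its top class."""
    results = []
    for row in batch_scores:
        probs = softmax(row)  # probabilities for THIS image only
        best = max(range(len(row)), key=row.__getitem__)
        results.append((classes[best], probs[best] * 100))
    return results


# Toy batch: 2 images x 3 classes (real ResNet50 output is batch_size x 1000)
classes = ["beagle", "pug", "Shih-Tzu"]
scores = [
    [0.1, 2.0, 0.3],  # image 0: highest score is "pug"
    [3.0, 0.0, 0.0],  # image 1: highest score is "beagle"
]
for name, pct in top1_with_percent(scores, classes):
    print(f"{name}: {pct:.1f}%")
# pug: 75.1%
# beagle: 90.9%
```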

From here, we’re nearly done! We want to write our results back to S3 in a tidy, human-readable way, and the rest of the function handles that. It iterates over the images one at a time, because these steps don’t operate on whole batches. is_match is one of our custom functions, which you can check out below.

...
    for j in range(0, len(images)):
        predicted = preds[j]
        groundtruth = labslist[j]
        name = names[j]
        match = is_match(groundtruth, predicted)

        outcome = {'name': name, 'ground_truth': groundtruth, 'prediction': predicted, 'evaluation': match}

        # Write each result to S3 directly
        with s3.open(f"s3://dask-datasets/dogs/preds/{name}.pkl", "wb") as f:
            pickle.dump(outcome, f)
...
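is_match turns the whole prediction into a string and searches it for the breed name, with underscores replaced by spaces so that folder names like golden_retriever line up with ImageNet labels like golden retriever. Because the classes file is opened in binary mode, the class labels are bytes objects, which is why they print with a b prefix. Here is a standalone check using a compact but equivalent version of the function; the predictions are illustrative, not model output.

```python
import re


def is_match(la, ev):
    '''Evaluate human readable prediction against ground truth.'''
    return bool(re.search(la.replace('_', ' '), str(ev).replace('_', ' ')))


# Predictions in the shape evaluate_pred_batch produces: a list holding one
# (class label, percent) tuple. Labels and percentages here are illustrative.
wrong = [(b"203: 'West Highland white terrier',", 3.0e-05)]
right = [(b"207: 'golden retriever',", 88.0)]

print(is_match('Shih-Tzu', wrong))          # False
print(is_match('golden_retriever', right))  # True
```

Since the ground-truth label is used as a regular expression pattern, wrapping it in re.escape would guard against any breed name containing regex metacharacters.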

Put It All Together

Rather than patching all these functions together by hand, we assemble them into one single delayed function that does the work for us. Importantly, we can then apply it to every batch of images across the cluster!

import pickle
import re

import dask
import torch
from torchvision import models

def evaluate_pred_batch(batch, gtruth, classes):
    ''' Accepts a batch of model outputs, returns human-readable
    top-1 predictions and the matching ground truth labels. '''
    _, indices = torch.sort(batch, descending=True)
    # Softmax per image (dim=1), so row i holds image i's class probabilities
    percentage = torch.nn.functional.softmax(batch, dim=1) * 100

    preds = []
    labslist = []
    for i in range(len(batch)):
        pred = [(classes[idx], percentage[i][idx].item()) for idx in indices[i][:1]]
        preds.append(pred)

        labs = gtruth[i]
        labslist.append(labs)

    return(preds, labslist)

def is_match(la, ev):
    ''' Evaluate human readable prediction against ground truth.
    (Used in both methods)'''
    if re.search(la.replace('_', ' '), str(ev).replace('_', ' ')):
        match = True
    else:
        match = False
    return(match)


@dask.delayed
def run_batch_to_s3(iteritem):
    ''' Accepts iterable result of preprocessing,
    generates inferences and evaluates. '''

    with s3.open('s3://dask-datasets/dogs/imagenet1000_clsidx_to_labels.txt') as f:
        classes = [line.strip() for line in f.readlines()]

    names, images, truelabels = iteritem

    images = torch.stack(images)

    with torch.no_grad():
        # Set up model
        resnet = models.resnet50(pretrained=True)
        resnet = resnet.to(device)
        resnet.eval()

        # run model on batch
        images = images.to(device)
        pred_batch = resnet(images)

        #Evaluate batch
        preds, labslist = evaluate_pred_batch(pred_batch, truelabels, classes)

        #Organize prediction results
        for j in range(0, len(images)):
            predicted = preds[j]
            groundtruth = labslist[j]
            name = names[j]
            match = is_match(groundtruth, predicted)

            outcome = {'name': name, 'ground_truth': groundtruth, 'prediction': predicted, 'evaluation': match}

            # Write each result to S3 directly
            with s3.open(f"s3://dask-datasets/dogs/preds/{name}.pkl", "wb") as f:
                pickle.dump(outcome, f)

        return(names)

On the Cluster

We have really done all the hard work already and can let our functions take it from here. Calling run_batch_to_s3 on each batch only builds a delayed task; client.compute then sends all of those tasks to the cluster at once.

tasks = [run_batch_to_s3(batch) for batch in image_batches]
futures_computed = client.compute(tasks)

client.compute returns right away with a list of futures, one per batch, while the batches are processed in the background on the workers. A future is a placeholder for a result that is still being calculated, so we can iterate over the list and pick up each result when it is ready.

Finally, we wait for the tasks and collect their results, with a simple error-handling system in case any of our files are corrupted or anything else goes haywire.

import logging

results = []
errors = []
for fut in futures_computed:
    try:
        result = fut.result()
    except Exception as e:
        errors.append(e)
        logging.error(e)
    else:
        results.extend(result)
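The collection loop only relies on a future’s .result() method re-raising any exception from its task, which Dask futures and Python’s concurrent.futures.Future have in common. That makes it possible to dry-run the error handling locally with a thread pool before touching the cluster. This swaps in concurrent.futures for Dask; fake_run_batch and the batch contents are hypothetical.

```python
import logging
from concurrent.futures import ThreadPoolExecutor


def fake_run_batch(batch):
    """Hypothetical stand-in for run_batch_to_s3: returns image names, fails on bad input."""
    if any(name.endswith(".bad") for name in batch):
        raise ValueError(f"could not decode a file in batch starting {batch[0]}")
    return list(batch)


batches = [["a.jpg", "b.jpg"], ["c.jpg", "d.bad"], ["e.jpg"]]

with ThreadPoolExecutor(max_workers=3) as pool:
    futures_computed = [pool.submit(fake_run_batch, b) for b in batches]

    # Same collection loop as above: failures are logged, successes accumulated
    results = []
    errors = []
    for fut in futures_computed:
        try:
            result = fut.result()
        except Exception as e:
            errors.append(e)
            logging.error(e)
        else:
            results.extend(result)

print(results)      # ['a.jpg', 'b.jpg', 'e.jpg']
print(len(errors))  # 1
```

One bad file costs its whole batch of results, so smaller batches trade some GPU efficiency for less lost work per failure.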

Evaluate

We want to make sure we have high-quality results coming out of this model, of course! First, we can peek at a single result.

with s3.open('s3://dask-datasets/dogs/preds/n02086240_1082.pkl', 'rb') as data:
    outcome = pickle.load(data)

outcome

#> {'name': 'n02086240_1082',
#>  'ground_truth': 'Shih-Tzu',
#>  'prediction': [(b"203: 'West Highland white terrier',", 3.0289587812148966e-05)],
#>  'evaluation': False}

This prediction is wrong (a Shih-Tzu classified as a West Highland white terrier), but the record has exactly the structure we expect. For a more thorough review, we download all the results files and count how many have 'evaluation': True.

Number of dog photos examined: 20580
Number of dogs classified correctly: 13806
The percent of dogs classified correctly: 67.085%
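Once the result files are downloaded, the counting step might look like the sketch below. It assumes the .pkl files sit together in one local folder of your choosing and that each holds an outcome dict like the one shown above; summarize and load_outcomes are hypothetical helper names. Applying the same arithmetic to the reported totals reproduces the percentage.

```python
import pickle
from pathlib import Path


def summarize(outcomes):
    """Return (examined, correct, percent) for a list of outcome dicts."""
    examined = len(outcomes)
    correct = sum(1 for o in outcomes if o['evaluation'])
    percent = 100 * correct / examined if examined else 0.0
    return examined, correct, percent


def load_outcomes(folder):
    """Load every downloaded .pkl result file from a local folder (assumed layout)."""
    outcomes = []
    for path in sorted(Path(folder).glob('*.pkl')):
        with open(path, 'rb') as f:
            outcomes.append(pickle.load(f))
    return outcomes


# Usage, with a hypothetical local folder:
#   examined, correct, percent = summarize(load_outcomes('preds'))

# The reported totals, checked with the same arithmetic:
examined, correct = 20580, 13806
print(f"{100 * correct / examined:.3f}%")  # 67.085%
```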

Not perfect, but good looking results overall!

Comparing Performance

So, we have managed to classify over 20,000 images in about 5 minutes. That sounds good, but what is the alternative?


Technique                      Runtime
No Cluster with Batching       3 hours, 21 minutes, 13 sec
GPU Cluster with Batching      5 minutes, 15 sec
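The headline speedup follows directly from the table. A quick check with Python’s datetime module, where dividing one timedelta by another gives a plain float:

```python
from datetime import timedelta

# Runtimes from the comparison table
no_cluster = timedelta(hours=3, minutes=21, seconds=13)
gpu_cluster = timedelta(minutes=5, seconds=15)

speedup = no_cluster / gpu_cluster  # timedelta / timedelta -> float
print(no_cluster.total_seconds(), gpu_cluster.total_seconds())  # 12073.0 315.0
print(f"{speedup:.1f}x faster")                                # 38.3x faster

# Throughput on the cluster for the 20,580 images
print(f"{20580 / gpu_cluster.total_seconds():.1f} images/sec")  # 65.3 images/sec
```

That works out to about 38x, which rounds to the 40x quoted in the introduction.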

Adding a GPU cluster makes a HUGE difference! If you’d like to see this work for yourself, sign up for the free version of Saturn Cloud today!


About Saturn Cloud

Saturn Cloud is your all-in-one solution for data science & ML development, deployment, and data pipelines in the cloud. Spin up a notebook with 4TB of RAM, add a GPU, connect to a distributed cluster of workers, and more. Join today and get 150 hours of free compute per month.