Managing Dask Resources for Machine Learning Training

Machine learning training on data stored in Dask collections can often be sped up by changing the size and settings of the Dask cluster, and by changing the way the data are stored in the cluster.


This article contains some tips for resizing your cluster and changing Dask settings to speed up machine learning training workloads. These tips should generally apply to any interface that trains machine learning models on data stored in Dask collections, such as XGBoost’s Dask integration.

The examples below assume that you’re using a Saturn Cloud Dask cluster created with dask-saturn. They reference the objects client and cluster, created with code like this:

from dask.distributed import Client
from dask_saturn import SaturnCluster

n_workers = 3
cluster = SaturnCluster(
    n_workers=n_workers,
    worker_size="4xlarge",
)
client = Client(cluster)

Restart the Workers

As you use a Dask cluster, some objects may not be cleared from memory when your code stops running. These memory leaks can come from Dask itself or from any library (such as pandas) that your Dask code uses, and they can be very difficult to track down.

As a result, each model training run might have less memory available to it than the run before it.

To get a clean environment on all workers, run the code below.

client.restart()

This restarts each of the worker processes, clearing out anything they’re holding in memory. It does NOT restart the machines the workers run on, so it usually finishes very quickly.

Increase the Worker Size

When the work you ask a cluster to do requires most of the cluster’s memory, Dask automatically starts spilling objects to disk. This is helpful, because jobs that need a little more memory than you expected won’t be killed by out-of-memory errors. However, once Dask starts spilling to disk, you should expect your code to get significantly slower.

From the official Dask documentation:

  • At 60% of memory load (as estimated by sizeof), spill least recently used data to disk.
  • At 80% of memory load, stop accepting new work on the local thread pool.
  • At 95% of memory load, terminate and restart the worker.
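The thresholds above can also be checked from code rather than only on the dashboard. The sketch below is a rough approximation, assuming each entry of client.scheduler_info()["workers"] carries a "memory_limit" in bytes and a "metrics" dict whose "memory" value is the worker process’s memory use in bytes. Process memory is not the same as the sizeof-based estimate the documentation refers to, so treat the result as a hint rather than a reading of Dask’s own decisions. The names THRESHOLDS and memory_report are hypothetical, not part of Dask.

```python
# Fractions of each worker's memory limit at which Dask acts, taken from the
# documentation quoted above. Newer Dask releases configure these under
# distributed.worker.memory.*, and defaults can differ between versions.
THRESHOLDS = [
    (0.95, "terminate: worker will be restarted"),
    (0.80, "pause: worker stops accepting new work"),
    (0.60, "spill: least recently used data goes to disk"),
]


def memory_report(workers):
    """Map each worker address to (fraction of memory limit used, status).

    `workers` is assumed to be shaped like client.scheduler_info()["workers"]:
    every value has a "memory_limit" in bytes and a "metrics" dict whose
    "memory" entry is the worker process's memory use in bytes.
    """
    report = {}
    for address, info in workers.items():
        limit = info["memory_limit"]
        used = info["metrics"]["memory"]
        # A missing or zero limit means no limit is enforced, so no fraction.
        fraction = used / limit if limit else 0.0
        status = "ok"
        for threshold, label in THRESHOLDS:  # highest threshold first
            if fraction >= threshold:
                status = label
                break
        report[address] = (round(fraction, 3), status)
    return report


# On a live cluster (assumes `client` from the setup code):
#   workers = client.scheduler_info()["workers"]
#   for address, (fraction, status) in memory_report(workers).items():
#       print(address, f"{fraction:.0%}", status)
```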

Watch the Dask dashboard closely. If worker memory gets close to 60%, Dask may have started to spill to disk. In this situation, the easiest solution is to resize the cluster by increasing the worker size.

from dask.distributed import Client
from dask_saturn import describe_sizes

# check available sizes
print(describe_sizes())

# resize the cluster from 4xlarge to 8xlarge
cluster = cluster.reset(
    worker_size="8xlarge"
)
client = Client(cluster)
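When choosing a worker size, a back-of-the-envelope lower bound can help: if the persisted training data should stay below the 60% threshold, each worker’s memory limit needs to be at least the dataset size divided by the number of workers, divided by 0.6. The sketch below assumes the data are spread evenly across workers and ignores the extra copies many training libraries make, so the real requirement is higher. The helper name min_memory_per_worker is hypothetical.

```python
def min_memory_per_worker(dataset_bytes, n_workers, target_fraction=0.6):
    """Smallest per-worker memory limit (bytes) that keeps an evenly spread
    dataset at or below `target_fraction` of each worker's memory.

    Assumes an even split and ignores copies made during training, so the
    result is a lower bound, not a recommendation.
    """
    return dataset_bytes / n_workers / target_fraction


# 1,000,000 rows x 150 float64 columns (8 bytes each) = 1.2 GB of data
dataset_bytes = 1_000_000 * 150 * 8
needed = min_memory_per_worker(dataset_bytes, n_workers=3)
print(f"at least {needed / 1e9:.2f} GB per worker")
```

Compare the result against the memory figures that describe_sizes() reports, leaving generous headroom for the training library itself.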

Add More Workers

If training is hitting memory issues (see “Increase the Worker Size”) or you just want to throw more CPU at the problem, increase the number of workers in the cluster. All else being equal, for many distributed training algorithms you should expect doubling the number of workers to cut training time nearly in half. This effect has limits, since each added worker also adds communication overhead, but it’s a good thing to try.

# scale up from 3 workers to 6
cluster.scale(6)

# wait until all of the workers are up
client.wait_for_workers(6)

After doing this, you should repartition your data (see “Repartition your Training Data”) and persist it again. Otherwise, the new workers won’t hold any pieces of the dataset and won’t be able to contribute to training. Libraries like XGBoost take special care to avoid moving data between workers as much as possible, and that includes only sending training tasks to workers that already hold pieces of the dataset.
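One way to spot workers that would sit idle is to ask the scheduler which workers hold pieces of a persisted collection. The sketch below assumes that client.who_has(collection) returns a mapping from each piece’s key to the addresses of the workers storing it, and takes the full list of addresses from client.scheduler_info()["workers"]. The helper workers_without_data is hypothetical, not part of Dask.

```python
def workers_without_data(who_has, all_workers):
    """Return the sorted addresses of workers holding no piece of a collection.

    `who_has` is assumed to map each piece's key to the addresses of the
    workers storing it, the shape returned by client.who_has(collection).
    `all_workers` is any iterable of worker addresses.
    """
    holding = set()
    for addresses in who_has.values():
        holding.update(addresses)
    return sorted(set(all_workers) - holding)


# On a live cluster, assuming a persisted collection named `ddf` as in the
# repartitioning examples in this article:
#   idle = workers_without_data(
#       client.who_has(ddf),
#       client.scheduler_info()["workers"].keys(),
#   )
#   print(f"{len(idle)} workers hold no training data")
```

If the list is not empty after scaling up, repartitioning and persisting again gives the new workers pieces to train on.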

Use All Available Cores on Each Dask Worker

If you use dask_saturn.SaturnCluster and omit the nthreads argument, Dask will automatically use a number of threads per worker equal to the number of logical CPUs on each worker. This is the ideal setting for machine learning training, because it allows the training process to use all of the CPUs in the cluster.

However, you can end up with nthreads on your Dask workers set lower than the number of logical CPUs: for example, if you scale up to a larger worker size but forget to also increase nthreads in your call to cluster.reset().

To check for this situation, compare the nthreads for each worker to the number of CPUs reported by dask_saturn.describe_sizes().

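# threads per worker, keyed by worker address
{
    worker_id: w["nthreads"]
    for worker_id, w
    in client.scheduler_info()["workers"].items()
}

# CPUs available for each worker size
import dask_saturn
dask_saturn.describe_sizes()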
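The comparison can also be scripted. The sketch below assumes two mappings keyed by worker address: threads per worker, which client.nthreads() returns, and logical CPUs per worker, here taken from client.run(os.cpu_count). Be aware that os.cpu_count() can report the host machine’s CPUs rather than a container’s limit, so the numbers from describe_sizes() remain the ones to trust. The helper underused_workers is hypothetical.

```python
def underused_workers(nthreads_by_worker, cpus_by_worker):
    """Return {worker: (nthreads, cpus)} for workers running fewer threads
    than they have logical CPUs. Workers missing from cpus_by_worker are
    skipped."""
    return {
        worker: (nthreads, cpus_by_worker[worker])
        for worker, nthreads in nthreads_by_worker.items()
        if worker in cpus_by_worker and nthreads < cpus_by_worker[worker]
    }


# On a live cluster (assumes `client` from the setup code):
#   import os
#   nthreads = client.nthreads()      # {worker address: threads}
#   cpus = client.run(os.cpu_count)   # {worker address: logical CPUs}
#   print(underused_workers(nthreads, cpus))
```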

If you find that nthreads in your cluster is lower than the number of available CPUs per worker, change it with cluster.reset().

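# hypothetical value: use the CPU count describe_sizes() reports for your worker size
n_cpus_per_worker = 16
cluster = cluster.reset(
    nthreads=n_cpus_per_worker
)
client = Client(cluster)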

Repartition your Training Data

Dask collections look like one data structure to your code, but are actually composed of multiple smaller pieces. For Dask Array, these pieces are called “chunks” and each chunk is one numpy array. For Dask DataFrame, these pieces are called “partitions” and each partition is one pandas data frame.

Machine learning training libraries built to work with Dask usually parallelize work across these pieces. Depending on how many pieces your data currently has, you might be able to get a small speedup by repartitioning it into more, smaller pieces.
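To see what “smaller pieces” means concretely, the arithmetic for the 1,000,000 × 150 float64 example in this section can be sketched as below. It assumes a row-wise split and 8 bytes per value; piece_layout is a hypothetical helper, not a Dask function.

```python
def piece_layout(n_rows, n_cols, n_pieces, itemsize=8):
    """Rows per piece and approximate bytes per piece for a row-wise split.

    Assumes float64 data (8 bytes per value) unless itemsize says otherwise.
    """
    rows_per_piece = -(-n_rows // n_pieces)  # ceiling division
    return rows_per_piece, rows_per_piece * n_cols * itemsize


for n_pieces in (10, 100):
    rows, nbytes = piece_layout(1_000_000, 150, n_pieces)
    print(f"{n_pieces} pieces: {rows} rows, {nbytes / 1e6:.0f} MB each")
```

Going from 10 to 100 pieces shrinks each piece from 120 MB to 12 MB, which gives the scheduler more units of work to spread across workers and threads.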

For Dask Arrays, use .rechunk() to change how your array is divided. Many machine learning libraries accept only row-wise chunks, meaning the array may be split along rows but not along columns. In other words, every chunk must contain all of the columns.

import dask.array as da
from dask.distributed import wait

# original dataset: 10 chunks, each with 100,000 rows
data = da.random.random(
    size=(1_000_000, 150),
    chunks=(100_000, 150)
)
print(f"number of chunks: {data.npartitions}")

# resized dataset: 100 chunks, each with 10,000 rows
data = data.rechunk(
    chunks=(10_000, 150)
)
print(f"number of chunks: {data.npartitions}")

# re-persist to materialize changes before training
data = data.persist()
_ = wait(data)
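To confirm that a rechunked array is split only along rows, you can inspect its .chunks attribute, which for a 2-D Dask array is a tuple of (row chunk sizes, column chunk sizes). The sketch below checks that the column dimension is a single chunk; is_row_wise is a hypothetical helper that works on that tuple, so it can be tested without a cluster.

```python
def is_row_wise(chunks):
    """Return True if a 2-D array's chunk layout splits it along rows only.

    `chunks` is assumed to be a 2-D Dask array's .chunks value:
    (row chunk sizes, column chunk sizes).
    """
    _, col_chunks = chunks
    # Row-wise chunking means every chunk spans all columns: one column chunk.
    return len(col_chunks) == 1


# With the rechunked array from the example above:
#   print(is_row_wise(data.chunks))
```

If the check fails, data.rechunk({1: -1}) merges all columns into a single chunk while leaving the row chunks unchanged.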

For Dask DataFrames, use .repartition() to change how your data frame is divided.

import dask.array as da
import dask.dataframe as dd
from dask.distributed import wait

# original dataset: 10 partitions, each with 100,000 rows
array = da.random.random(
    size=(1_000_000, 150),
    chunks=(100_000, 150)
)
ddf = dd.from_dask_array(array)
print(f"number of partitions: {ddf.npartitions}")

# resized dataset: 100 partitions, each with 10,000 rows
ddf = ddf.repartition(npartitions=100)
print(f"number of partitions: {ddf.npartitions}")

# re-persist to materialize changes before training
ddf = ddf.persist()
_ = wait(ddf)