Snowflake and Dask

July 15, 2020

Snowflake is the most popular data warehouse amongst our Saturn users. This article will cover efficient ways to load Snowflake data into Dask so you can do non-SQL operations (think machine learning) at scale.

The Basics

First, some basics. The standard way to load Snowflake data into Pandas is:

import snowflake.connector
import pandas as pd

ctx = snowflake.connector.connect(
    user='YOUR_USER',
    password='YOUR_PASSWORD',
    account='YOUR_ACCOUNT'
)
query = "SELECT * FROM SNOWFLAKE_SAMPLE_DATA.TPCH_SF1.CUSTOMER"
pd.read_sql(query, ctx)

Snowflake recently introduced two much faster methods for this operation, fetch_pandas_all() and fetch_pandas_batches(), both of which leverage Arrow.

cur = ctx.cursor()
cur.execute(query)
df = cur.fetch_pandas_all()

fetch_pandas_batches() returns an iterator of DataFrames. Since we’re going to focus on loading the data into a distributed dataframe (pulling from multiple machines), we’re going to set up our query to shard the data and call fetch_pandas_all() on each of our workers.
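For reference, here is a minimal sketch of what consuming the iterator on a single machine could look like. The helper name count_rows_in_batches is made up for illustration, and it assumes cur is a Snowflake cursor on which execute() has already been called.

```python
def count_rows_in_batches(cur):
    """Stream a result set batch by batch and return the total row count.

    `cur` is assumed to be a Snowflake cursor on which execute() has
    already been called; each batch it yields is a pandas DataFrame.
    """
    total = 0
    for batch in cur.fetch_pandas_batches():
        # Only one batch needs to be held in memory at a time.
        total += len(batch)
    return total
```

This keeps memory use bounded on one machine, but all of the data still flows through that machine, which is why the rest of this article shards the query across workers instead.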

When to use Snowflake or Dask?

It can be very tempting to rip all the data out of Snowflake so that you can work with it in Dask. That definitely works, but Snowflake is going to be much faster at applying SQL-like operations to the data. Snowflake stores the data and has highly optimized routines to get every ounce of performance out of your query. The examples below pretend that we’re loading the entire table into Dask. In your case, you will probably have a SQL query that performs the SQL-like transformations you care about, and you’ll load the result set into Dask for the things that Dask is good at (possibly some types of feature engineering, and machine learning). Saturn Cloud has a native integration with Snowflake and Dask, and you can try out loading some data with our free trial.

How does Dask load data?

You can think of a Dask dataframe as a giant Pandas dataframe that has been chopped up and scattered across a bunch of computers. When we are loading data from Snowflake (assuming that the data is large), it’s not efficient to load all the data on one machine and then scatter it out to your cluster. We are going to focus on having every machine in your Dask cluster load a partition (a small slice) of your data.

Data Clustering

We need a way to split the data into little partitions so that we can load it into the cluster. Data in SQL doesn’t necessarily have any natural ordering, so you can’t just say that you’re going to throw the first 10k rows into one partition and the second 10k rows into another. The partitioning has to be based on a column of the data. For example, you can partition the data by a date field, or you can create a row number by adding an identity column to your Snowflake table.
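If you partition on a date field, the slice boundaries can be generated up front. The sketch below is one way to do that. The function name date_ranges and the example dates are hypothetical, and the half-open pairs are meant for a predicate of the form `col >= %s AND col < %s` so that no row lands in two partitions.

```python
from datetime import date, timedelta


def date_ranges(start, end, days_per_partition):
    """Split [start, end) into consecutive half-open (lo, hi) date pairs."""
    if days_per_partition < 1:
        raise ValueError("days_per_partition must be at least 1")
    step = timedelta(days=days_per_partition)
    ranges = []
    lo = start
    while lo < end:
        # The last range is clipped so it never runs past `end`.
        hi = min(lo + step, end)
        ranges.append((lo, hi))
        lo = hi
    return ranges


# Nine days split into chunks of at most four days.
ranges = date_ranges(date(2020, 1, 1), date(2020, 1, 10), 4)
```

Equal-width date ranges only give evenly sized partitions if rows are spread fairly uniformly over time.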

Once you’ve decided which column you want to partition your data on, it’s important to set up data clustering on the Snowflake side. Every single worker is going to ask for a small slice of the data. Something like:

select * from table where id < 20000 and id >= 10000

If you don’t set up data clustering, every single query is going to trigger a full scan of the table (I’m probably overstating the problem, but without data clustering, performance here can be quite bad).
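In Snowflake, a clustering key is defined with an ALTER TABLE ... CLUSTER BY statement. As a rough sketch, the hypothetical helper below builds that statement for a single column and rejects anything that is not a plain identifier. It assumes you own the table (the shared sample database is read-only), and whether clustering pays for itself depends on the size of your table.

```python
import re

# Plain (optionally dot-qualified) unquoted identifiers only.
_IDENTIFIER = re.compile(r"^[A-Za-z_][A-Za-z0-9_$]*(\.[A-Za-z_][A-Za-z0-9_$]*)*$")


def cluster_by_statement(table, column):
    """Build an ALTER TABLE ... CLUSTER BY statement for one column."""
    for name in (table, column):
        if not _IDENTIFIER.match(name):
            raise ValueError(f"not a plain identifier: {name!r}")
    return f"ALTER TABLE {table} CLUSTER BY ({column})"


statement = cluster_by_statement("customer", "c_custkey")
```

The resulting string can be passed to cursor.execute() like any other statement.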

Load it!

We aren’t going to use read_sql_table() from the Dask library here. I prefer to have more control over how we load the data from Snowflake, and we want to call fetch_pandas_all(), which is a Snowflake-specific function and therefore not supported by read_sql_table().

We need to set up a query template containing a binding that will result in Dask issuing multiple queries that each extract a slice of the data based on a partitioning column. These slices will become the partitions in a Dask DataFrame.

query = """
SELECT *
FROM customer
WHERE c_custkey BETWEEN %s AND %s
"""

It’s important to pick a column that evenly divides your data, like a row ID or a uniformly distributed timestamp. Otherwise one query may take much longer to execute than the others. We then use a dask.delayed function to execute this query multiple times in parallel for each partition. Note that we put our Snowflake connection information in a dict called conn_info to be able to reference it multiple times.

import snowflake.connector
import dask

conn_info = {
    "account": 'YOUR_ACCOUNT',
    "user": 'YOUR_USER',
    "password": 'YOUR_PASSWORD',
    "database": 'SNOWFLAKE_SAMPLE_DATA',
    "schema": 'TPCH_SF1',
}

@dask.delayed
def load(conn_info, query, start, end):
    with snowflake.connector.connect(**conn_info) as conn:
        cur = conn.cursor().execute(query, (start, end))
        return cur.fetch_pandas_all()

@dask.delayed is a decorator that turns a Python function into a function suitable for running on the Dask cluster. When you call it, instead of executing, it returns a Delayed result that represents what the return value of the function will be. The from_delayed() function takes a list of these Delayed objects, and concatenates them into a giant dataframe.

We can now call this load() function multiple times and convert the results into a Dask DataFrame using dask.dataframe.from_delayed().

import dask.dataframe as dd

results = [
    load(conn_info, query, 0, 10000),
    load(conn_info, query, 10001, 20000),
    load(conn_info, query, 20001, 30000),
]
ddf = dd.from_delayed(results)

The start and end values were hard-coded for the above example, but you would normally write a query to determine what the partitions will look like based on your data. For example, with our customer table, we know that the c_custkey column is an auto-incrementing, non-null ID column (the cardinality of the column is equal to the number of rows in the table). We can write a function that will determine the appropriate start and end values given a desired number of partitions, then use those results to create the Dask DataFrame:

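def get_partitions(table, id_col, num_partitions=100):
    with snowflake.connector.connect(**conn_info) as conn:
        part_query = f"SELECT MAX({id_col}) from {table}"
        part_max = conn.cursor().execute(part_query).fetchall()[0][0]

    # Ceiling division, so the ranges cover every ID from 0 to part_max
    # and inc is never 0 when the table has fewer IDs than partitions.
    inc = -(-(part_max + 1) // num_partitions)
    # BETWEEN is inclusive, so each range ends one below the next start.
    parts = [(i, i + inc - 1) for i in range(0, part_max + 1, inc)]
    return parts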

parts = get_partitions('customer', 'c_custkey')

ddf = dd.from_delayed(
    [load(conn_info, query, part[0], part[1]) for part in parts]
)
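Boundary arithmetic like this is easy to get subtly wrong. For instance, the row holding the largest ID is dropped if the range stops at the maximum rather than one past it. One way to guard against that is to keep the arithmetic separate from the database call and check it on its own. The helpers make_ranges and covers_every_id below are hypothetical names for such a sketch, and they assume integer IDs starting at or above 0.

```python
def make_ranges(max_id, num_partitions):
    """Return inclusive (start, end) pairs that together cover 0..max_id."""
    if max_id < 0 or num_partitions < 1:
        raise ValueError("max_id must be >= 0 and num_partitions >= 1")
    # Ceiling division: the step is large enough that at most
    # num_partitions ranges are needed for the max_id + 1 possible IDs.
    step = -(-(max_id + 1) // num_partitions)
    return [(lo, min(lo + step - 1, max_id)) for lo in range(0, max_id + 1, step)]


def covers_every_id(ranges, max_id):
    """True if the ranges are contiguous, non-overlapping and span 0..max_id."""
    expected_start = 0
    for lo, hi in ranges:
        if lo != expected_start or hi < lo:
            return False
        expected_start = hi + 1
    return expected_start == max_id + 1


ranges = make_ranges(150000, 100)
ok = covers_every_id(ranges, 150000)
```

Running the same check against boundaries built with range(0, max_id, step) shows the gap at the top end.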

As long as the full dataset fits into the memory of your cluster, you can persist the DataFrame to ensure the Snowflake queries only execute once. Otherwise, they will execute each time you trigger computation on the DataFrame.

ddf = ddf.persist()

Putting it all together

This is what the code would look like for this example table. You will likely want to do many more transformations in the Snowflake query as you can leverage the power of Snowflake’s data warehouse there. The partition column you use will also be different depending on how your data is organized.

import snowflake.connector
import dask
import dask.dataframe as dd

conn_info = {
    "account": 'YOUR_ACCOUNT',
    "user": 'YOUR_USER',
    "password": 'YOUR_PASSWORD',
    "database": 'SNOWFLAKE_SAMPLE_DATA',
    "schema": 'TPCH_SF1',
}

query = """
SELECT *
FROM customer
WHERE c_custkey BETWEEN %s AND %s
"""

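@dask.delayed
def load(conn_info, query, start, end):
    with snowflake.connector.connect(**conn_info) as conn:
        cur = conn.cursor().execute(query, (start, end))
        return cur.fetch_pandas_all()

def get_partitions(table, id_col, num_partitions=100):
    with snowflake.connector.connect(**conn_info) as conn:
        part_query = f"SELECT MAX({id_col}) from {table}"
        part_max = conn.cursor().execute(part_query).fetchall()[0][0]

    # Ceiling division, so the ranges cover every ID from 0 to part_max
    # and inc is never 0 when the table has fewer IDs than partitions.
    inc = -(-(part_max + 1) // num_partitions)
    # BETWEEN is inclusive, so each range ends one below the next start.
    parts = [(i, i + inc - 1) for i in range(0, part_max + 1, inc)]
    return parts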

parts = get_partitions('customer', 'c_custkey')

ddf = dd.from_delayed(
    [load(conn_info, query, part[0], part[1]) for part in parts]
)
ddf = ddf.persist()

Thanks for reading. If you’re interested in using Dask with Snowflake, then I recommend you check out Saturn Cloud and read our Connecting to Snowflake docs page.

By Hugo Shi
Posted in Blog | July 15, 2020