Pivot Tables in Pandas with Python

You may be familiar with using pivot tables in Excel to quickly generate insights into your data. In this post, you’ll learn how to create pivot tables in Python and Pandas using the .pivot_table() method, with a complete overview of how to use it!

Being able to quickly summarize data is an important skill for getting a sense of what your data looks like. The function is quite similar to the .groupby() method, also available in Pandas, but offers significantly more customization, as we’ll see later on in this post.

By the end of this tutorial, you’ll have learned:

  • How to use the pivot_table() function and what its parameters represent
  • How to group data using an index or a multi-index
  • How to split a pivot table further using multiple indices and columns
  • How to specify and create your own aggregation methods
  • How to calculate totals and deal with missing data

How to Build a Pivot Table in Python

A pivot table is a table of statistics that summarizes the data of a larger table by “pivoting” that data. Microsoft Excel popularized the pivot table, where they’re known as PivotTables. In Pandas, you can create pivot tables using the pd.pivot_table() function. The function has the following parameters and default values:

# The syntax of the .pivot_table() function (default values shown)
import pandas as pd
pd.pivot_table(
    data,               # required: the DataFrame to pivot
    values=None,
    index=None,
    columns=None,
    aggfunc='mean',
    fill_value=None,
    margins=False,
    dropna=True,
    margins_name='All',
    observed=False,
    sort=True
)

The function takes a DataFrame and returns a DataFrame. The table below provides an overview of the different parameters available in the function:

Parameter | Default Value | Description
data= | (required) | The DataFrame to pivot
values= | None | The column or columns to aggregate (if omitted, the remaining columns are aggregated)
index= | None | The column or columns to group the rows by. A single column can be a string, while multiple columns should be a list of strings
columns= | None | The column or columns whose values become the columns of the result. A single column can be a string, while multiple columns should be a list of strings
aggfunc= | 'mean' | A function, a list of functions, or a dictionary mapping columns to functions, used to aggregate the data
fill_value= | None | Value to replace missing values with in the result
margins= | False | Add a row and column for totals
dropna= | True | Exclude columns whose entries are all NaN
margins_name= | 'All' | Name of the total row/column
observed= | False | Only for categorical data: if True, only show observed values for categorical groups
sort= | True | Whether to sort the group keys of the result
The parameters of the pivot_table function in Pandas.

Now that you have an understanding of the different parameters available in the function, let’s load in our data set and begin exploring our data.

Loading a Sample Pandas DataFrame

To follow along with this tutorial, let’s load a sample Pandas DataFrame. We can load it from the file hosted on my GitHub page using the pd.read_excel() function (reading .xlsx files requires the openpyxl package to be installed). Then we can print the first five records of the dataset using the .head() method.

# Loading a Sample Pandas DataFrame
import pandas as pd
df = pd.read_excel('https://github.com/datagy/mediumdata/raw/master/sample_pivot.xlsx', parse_dates=['Date'])
print(df.head())

# Returns:
#         Date Region                 Type  Units  Sales
# 0 2020-07-11   East  Children's Clothing   18.0    306
# 1 2020-09-23  North  Children's Clothing   14.0    448
# 2 2020-04-02  South     Women's Clothing   17.0    425
# 3 2020-02-28   East  Children's Clothing   26.0    832
# 4 2020-03-19   West     Women's Clothing    3.0     33

Based on the output of the first five rows shown above, we can see that we have five columns to work with:

Column Name | Description
Date | Date of the transaction
Region | The region of the transaction
Type | The type of clothing sold
Units | The number of units sold
Sales | The value of the sale
Description of columns of our dataset

Now that we have a bit more context around the data, let’s explore creating our first pivot table in Pandas.

Creating a Pivot Table in Pandas

Let’s create our first Pandas pivot table. At a minimum, we have to pass in some form of a group key, using either the index= or columns= parameter. In the examples below, we’re using the top-level pd.pivot_table() function rather than the DataFrame method, df.pivot_table(). Because of this, we need to pass in the data= argument; if we called the method on the DataFrame directly, the data would be implied.

# Creating your first Pandas pivot table
pivot = pd.pivot_table(
    data=df,
    index='Region'
)
print(pivot)

# Returns:
#              Sales      Units
# Region                       
# East    408.182482  19.732360
# North   438.924051  19.202643
# South   432.956204  20.423358
# West    452.029412  19.294118

Let’s break down what happened here:

  1. We created a new DataFrame called pivot, using the pd.pivot_table() function
  2. We passed in our DataFrame, df, and set index='Region', meaning data would be grouped by the Region column

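Because all other parameters were left to their defaults, Pandas made the following assumptions:

  • Data should be aggregated by the average of each column (aggfunc='mean')
  • The values should be any numeric columns

The output above reflects older versions of Pandas, which silently dropped non-numeric columns such as Date and Type. In Pandas 2.0 and later, non-numeric columns are no longer dropped automatically, so this call may raise a TypeError when it tries to average the text in the Type column; passing values=['Sales', 'Units'] explicitly avoids this.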
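To make the function-versus-method distinction concrete, here is a minimal sketch on a small, made-up DataFrame (the data is hypothetical, not the tutorial’s dataset). It assumes only standard Pandas calls and compares the top-level function, the DataFrame method, and a comparable .groupby() call.

```python
import pandas as pd

# Hypothetical sample data that reuses the tutorial's column names
df = pd.DataFrame({
    'Region': ['East', 'East', 'West', 'West'],
    'Sales': [100.0, 300.0, 50.0, 250.0],
})

# Top-level function: the DataFrame is passed explicitly with data=
via_function = pd.pivot_table(data=df, index='Region', values='Sales')

# DataFrame method: the DataFrame is implied by the object it is called on
via_method = df.pivot_table(index='Region', values='Sales')

# A comparable groupby: group by Region, then average Sales
via_groupby = df.groupby('Region')[['Sales']].mean()

print(via_function.equals(via_method))  # True
print(via_function['Sales'].tolist())   # [200.0, 150.0]
print(via_groupby['Sales'].tolist())    # [200.0, 150.0]
```

For a single aggregation like this, the groupby version gives the same numbers; pivot_table becomes more convenient once you add columns=, margins=, or fill_value=.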

Aggregating Only Certain Columns in a Pandas Pivot Table

In the example above, you didn’t modify the values= parameter, so all numeric columns were aggregated. This may not always be what you want, so Pandas allows us to pass in either a single string representing one column or a list of strings representing multiple columns.

Let’s now modify our code to only calculate the mean for a single column, Sales:

# Aggregating only a single column
pivot = pd.pivot_table(
    data=df,
    index='Region',
    values='Sales'
)

print(pivot)

# Returns:
#              Sales
# Region            
# East    408.182482
# North   438.924051
# South   432.956204
# West    452.029412

We can see that instead of aggregating all numeric columns, only the one specified was aggregated.
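The list form of values= works the same way. A minimal sketch, assuming a small hypothetical DataFrame with the tutorial’s column names, where the text column Type is left out of the aggregation:

```python
import pandas as pd

# Hypothetical sample data; Type is a text column that should not be averaged
df = pd.DataFrame({
    'Region': ['East', 'East', 'West', 'West'],
    'Type': ["Men's Clothing", "Women's Clothing", "Men's Clothing", "Women's Clothing"],
    'Units': [10.0, 20.0, 5.0, 15.0],
    'Sales': [100.0, 300.0, 50.0, 250.0],
})

# A list passed to values= aggregates only the listed columns
pivot = pd.pivot_table(data=df, index='Region', values=['Sales', 'Units'])

print(list(pivot.columns))           # ['Sales', 'Units']
print(pivot.loc['East', 'Sales'])    # 200.0
print(pivot.loc['West', 'Units'])    # 10.0
```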

Working with Aggregation Methods in a Pandas Pivot Table

Now that you’ve created your first pivot table in Pandas, let’s work on changing the aggregation methods. This allows you to specify how you want your data aggregated. This is where the power of Pandas really comes through, allowing you to calculate complex analyses with ease.

Specifying Aggregation Method in a Pandas Pivot Table

You can use the aggfunc= (aggregation function) parameter to change how data are aggregated in a pivot table. By default, Pandas aggregates data using the mean. You can pass a string naming a function, such as 'mean', 'sum', or 'max', or a callable such as np.mean.

Let’s now change the aggregation to produce the sum of sales and units for each region:

# Specifying the Aggregation Function
pivot = pd.pivot_table(
    data=df,
    index='Region',
    values=['Sales', 'Units'],  # listing the numeric columns keeps Date and Type out of the sum
    aggfunc='sum'
)

print(pivot)

# Returns:
#          Sales   Units
# Region                
# East    167763  8110.0
# North   138700  4359.0
# South    59315  2798.0
# West     61476  2624.0
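Both forms that aggfunc= accepts, a string name and a callable, can be compared side by side. This is a minimal sketch on hypothetical data; sales_range is a hypothetical helper written for illustration, and it receives the Sales values of one region as a Series.

```python
import pandas as pd

# Hypothetical sample data
df = pd.DataFrame({
    'Region': ['East', 'East', 'West', 'West'],
    'Sales': [100.0, 300.0, 50.0, 200.0],
})

# aggfunc= as a string naming a function
largest = pd.pivot_table(data=df, index='Region', values='Sales', aggfunc='max')

# aggfunc= as a callable (hypothetical helper): the spread between
# the largest and smallest sale in each region
def sales_range(values):
    return values.max() - values.min()

spread = pd.pivot_table(data=df, index='Region', values='Sales', aggfunc=sales_range)

print(largest['Sales'].tolist())  # [300.0, 200.0]
print(spread['Sales'].tolist())   # [200.0, 150.0]
```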

Multiple Aggregation Methods in a Pandas Pivot Table

Similarly, we can specify multiple aggregation methods in a Pandas pivot table. This only requires you to pass in a list of functions; each function is then applied to every values column. Let’s produce aggregations for both the mean and the sum:

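# Applying multiple aggregation functions
pivot = pd.pivot_table(
    data=df,
    index='Region',
    values=['Sales', 'Units'],  # keeps the non-numeric Date and Type columns out
    aggfunc=['mean', 'sum']
)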

print(pivot)

# Returns:
#               mean                sum        
#              Sales      Units   Sales   Units
# Region                                       
# East    408.182482  19.732360  167763  8110.0
# North   438.924051  19.202643  138700  4359.0
# South   432.956204  20.423358   59315  2798.0
# West    452.029412  19.294118   61476  2624.0

We can see how easy that was and how much more data it provides! For each column containing numeric data, both the mean and the sum are created.
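The result has two levels of column labels: the aggregation on top and the original column underneath. A minimal sketch on hypothetical data showing how you might select from those levels and, optionally, flatten them into single names (the 'mean_Sales' naming scheme is just one choice):

```python
import pandas as pd

# Hypothetical sample data
df = pd.DataFrame({
    'Region': ['East', 'East', 'West'],
    'Sales': [100.0, 300.0, 50.0],
    'Units': [1.0, 3.0, 2.0],
})

pivot = pd.pivot_table(
    data=df,
    index='Region',
    values=['Sales', 'Units'],
    aggfunc=['mean', 'sum'],
)

# Columns are a two-level MultiIndex: (aggregation, column)
print(pivot[('sum', 'Sales')].tolist())   # [400.0, 50.0]
print(pivot['mean'].loc['East', 'Units']) # 2.0

# Flatten the two levels into single names such as 'mean_Sales'
pivot.columns = ['_'.join(col) for col in pivot.columns]
print(list(pivot.columns))  # ['mean_Sales', 'mean_Units', 'sum_Sales', 'sum_Units']
```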

Specifying Different Aggregations Per Column

Now, imagine you wanted to calculate different aggregations per column. To do this, you can pass in a dictionary using the key-value format 'column': function. Let’s say we wanted to calculate the total number of units and the average sales amount:

pivot = pd.pivot_table(
    data=df,
    index='Region',
    aggfunc={'Sales': 'mean', 'Units': 'sum'}
)

print(pivot)

# Returns:
#              Sales   Units
# Region                    
# East    408.182482  8110.0
# North   438.924051  4359.0
# South   432.956204  2798.0
# West    452.029412  2624.0

This lets you compare different key performance indicators side by side, in the same DataFrame.

Custom Aggregations in Pandas Pivot Tables

Pandas also allows us to pass a custom function into the .pivot_table() function. This greatly extends our ability to tailor analyses to our specific needs! Let’s see how we can pass in a function that calculates the mean of a column without any outliers.

Pandas comes with a method, .quantile(), that returns the value found at a given percentile of the data. Let’s say we wanted to calculate the average of a column after removing the top and bottom 10% of the data. We could define the following function:

# Defining a custom function
def mean_no_outliers(values):
    # Find the values at the 10th and 90th percentiles
    low, high = values.quantile([0.1, 0.9])
    # Keep only the values between those bounds, then average them
    no_outliers = values[values.between(low, high)]
    return no_outliers.mean()

This function accepts a single parameter, values, which will be the Series of values for one group, passed in by the .pivot_table() function. The .quantile() method finds the 10th and 90th percentile values, the .between() method keeps only the values that fall between them, and finally the mean of the remaining values is returned. Let’s see how we can use this (and the normal mean aggregation) in our pivot table, applied to our Sales column.

# Specifying custom functions in a Pandas pivot table
pivot = pd.pivot_table(
    data=df,
    index='Region',
    aggfunc=['mean', mean_no_outliers],
    values='Sales'
)

print(pivot)

# Returns (the mean_no_outliers values depend on the data):
#               mean mean_no_outliers
#              Sales            Sales
# Region                             
# East    408.182482              ...
# North   438.924051              ...
# South   432.956204              ...
# West    452.029412              ...
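Because trimming logic is easy to get subtly wrong, a quick check on a tiny, made-up Series can help. This sketch repeats the same function and applies it to nine ordinary values plus one extreme outlier: the plain mean is pulled far upward, while the trimmed mean stays near the typical values.

```python
import pandas as pd

def mean_no_outliers(values):
    # Find the values at the 10th and 90th percentiles
    low, high = values.quantile([0.1, 0.9])
    # Keep only the values between those bounds (inclusive), then average them
    return values[values.between(low, high)].mean()

# Hypothetical Series: nine ordinary values and one extreme outlier
sales = pd.Series([1, 2, 3, 4, 5, 6, 7, 8, 9, 1000])

print(sales.mean())             # 104.5
print(mean_no_outliers(sales))  # 5.5
```

Here the 10th and 90th percentiles are 1.9 and 108.1, so only the values 2 through 9 are averaged.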

More Complex Pandas Pivot Tables

Now that you have an understanding of how the .pivot_table() function works in Pandas, let’s take a look at how we can expand our understanding. In this section, you’ll learn how to add columns and multiple indices to our Pandas pivot tables.

Add Columns to a Pandas Pivot Table

When we add columns to a Pandas pivot table, we add another dimension to the data. While the index= parameter splits the data vertically, the columns= parameter groups and splits the data horizontally. This allows us to create an easy-to-read table. Let’s see how we can use the columns= parameter to split the data by the Type column.

# Adding Columns to Our Pandas Pivot Table
pivot = pd.pivot_table(
    data=df,
    index='Region',
    columns='Type',
    values='Sales'
)

print(pivot)

# Returns:
# Type    Children's Clothing  Men's Clothing  Women's Clothing
# Region                                                       
# East             405.743363      423.647541        399.028409
# North            438.894118      449.157303        432.528169
# South            412.666667      475.435897        418.924528
# West             480.523810      465.292683        419.188679

We can see how easy it was to add a whole other dimension of data. This allows us to spot differences between groupings in a format that’s easy to read.
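One way to see what columns= does is to build the same table by hand: group by both keys, then move the Type level into the columns with .unstack(). A minimal sketch on hypothetical data, assuming the default mean aggregation:

```python
import pandas as pd

# Hypothetical sample data
df = pd.DataFrame({
    'Region': ['East', 'East', 'West', 'West', 'West'],
    'Type': ["Men's Clothing", "Women's Clothing", "Men's Clothing",
             "Men's Clothing", "Women's Clothing"],
    'Sales': [100.0, 300.0, 50.0, 150.0, 250.0],
})

# Pivot: Region down the rows, Type across the columns
pivot = pd.pivot_table(data=df, index='Region', columns='Type', values='Sales')

# The same table built by hand: group by both keys, then unstack Type into columns
manual = df.groupby(['Region', 'Type'])['Sales'].mean().unstack('Type')

print(pivot.equals(manual))                  # True
print(pivot.loc['West', "Men's Clothing"])   # 100.0
```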

Add Multiple Indices to Pandas Pivot Tables

While columns added a horizontal dimension, we can also specify multiple indices when there is a logical hierarchy in our data. For example, we can add a date dimension to our pivot table. Let’s use the Pandas .dt.quarter date accessor to group our data by quarter, which lets us see how our data changes over time. The index= list can mix column names with a Series such as df['Date'].dt.quarter; the new index level takes the Series’ name, Date. Let’s see how this works:

pivot = pd.pivot_table(
    data=df,
    index=['Region',df['Date'].dt.quarter],
    columns='Type',
    values='Sales'
)

print(pivot.head())

# Returns:
# Type         Children's Clothing  Men's Clothing  Women's Clothing
# Region Date                                                       
# East   1              423.241379      369.250000        428.948718
#        2              274.800000      445.425000        456.816327
#        3              425.382353      506.421053        342.386364
#        4              453.866667      405.666667        364.795455
# North  1              394.727273      450.869565        489.944444

This returns a multi-index Pandas DataFrame. While it may look more complex, accessing data in a multi-index Pandas DataFrame works quite similarly to accessing data in any other DataFrame. Since there are now two index levels rather than one, we can pass a tuple of labels to .loc. Let’s say we wanted to access only the intersection of the East region, Quarter 1, and Men’s Clothing; we could use the following code:

# Accessing data in a multi-index pivot table
print(pivot.loc[('East', 1), "Men's Clothing"])

# Returns: 369.25
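Beyond a full tuple, a multi-index also supports partial selection. A minimal sketch on hypothetical data; renaming the quarter Series to 'Quarter' is an illustrative choice, not something the tutorial’s code does:

```python
import pandas as pd

# Hypothetical sample data
df = pd.DataFrame({
    'Date': pd.to_datetime(['2020-01-15', '2020-04-10', '2020-02-20', '2020-05-05']),
    'Region': ['East', 'East', 'West', 'West'],
    'Type': ["Men's Clothing", "Men's Clothing", "Women's Clothing", "Women's Clothing"],
    'Sales': [100.0, 300.0, 50.0, 250.0],
})

# Naming the quarter Series makes the second index level read 'Quarter' instead of 'Date'
quarter = df['Date'].dt.quarter.rename('Quarter')

pivot = pd.pivot_table(
    data=df,
    index=['Region', quarter],
    columns='Type',
    values='Sales',
)

# A tuple selects one (Region, Quarter) row
print(pivot.loc[('East', 1), "Men's Clothing"])      # 100.0

# A single outer label selects every quarter for that region
print(pivot.loc['East', "Men's Clothing"].tolist())  # [100.0, 300.0]

# .xs() selects one value of an inner level across all regions
q1 = pivot.xs(1, level='Quarter')
print(q1.index.tolist())                             # ['East', 'West']
```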

Customizing Pandas Pivot Tables

In this section, you’ll learn how to customize your Pandas pivot tables further, such as by adding totals and working with missing data, including filling missing values in the resulting pivot table with a specific value.

Adding Totals to Pandas Pivot Tables

Let’s start off by learning how to add totals to a Pandas pivot table. This is controlled by the margins= parameter, which accepts a boolean value. By default, this is set to False, but toggling it to True adds the totals to rows and columns. Let’s see what this looks like:

# Adding totals to rows and columns
pivot = pd.pivot_table(
    data=df,
    index='Region',
    columns='Type',
    values='Sales',
    margins=True
)

print(pivot)

# Returns:
# Type    Children's Clothing  Men's Clothing  Women's Clothing         All
# Region                                                                   
# East             405.743363      423.647541        399.028409  408.182482
# North            438.894118      449.157303        432.528169  438.924051
# South            412.666667      475.435897        418.924528  432.956204
# West             480.523810      465.292683        419.188679  452.029412
# All              427.743860      444.257732        415.254717  427.254000

By default, Pandas will name the totals 'All'. If you want to rename these labels, you can use the margins_name= parameter to pass in a string that relabels the totals.

# Renaming totals in a Pandas pivot table
pivot = pd.pivot_table(
    data=df,
    index='Region',
    columns='Type',
    values='Sales',
    margins=True,
    margins_name='Total'
)

print(pivot)

# Returns:
# Type    Children's Clothing  Men's Clothing  Women's Clothing       Total
# Region                                                                   
# East             405.743363      423.647541        399.028409  408.182482
# North            438.894118      449.157303        432.528169  438.924051
# South            412.666667      475.435897        418.924528  432.956204
# West             480.523810      465.292683        419.188679  452.029412
# Total            427.743860      444.257732        415.254717  427.254000
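Note that the totals are computed from the underlying rows, not by aggregating the rows of the pivot table. With the mean, this means the total row is the overall average rather than the average of the regional averages. A minimal sketch on hypothetical data:

```python
import pandas as pd

# Hypothetical sample data: East has two sales, West has one
df = pd.DataFrame({
    'Region': ['East', 'East', 'West'],
    'Sales': [100.0, 300.0, 50.0],
})

pivot = pd.pivot_table(data=df, index='Region', values='Sales', margins=True)

# East averages 200.0 and West 50.0, but 'All' is the mean of all three sales
print(pivot.loc['All', 'Sales'])                          # 150.0
# ...not the mean of the two regional averages
print(pivot.loc[['East', 'West'], 'Sales'].mean())        # 125.0

# With aggfunc='sum', the total row matches the column's overall sum
totals = pd.pivot_table(data=df, index='Region', values='Sales', aggfunc='sum', margins=True)
print(totals.loc['All', 'Sales'] == df['Sales'].sum())    # True
```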

Dealing with Missing Data in a Pandas Pivot Table

When Pandas encounters a cross-section where no data exists, it’ll include a NaN value in the resulting pivot table. Let’s modify our DataFrame to include some missing data and calculate a pivot table to see what this looks like:

# Adding and seeing missing data in a Pandas pivot table
import numpy as np

# Work on a copy so that later examples still use the original data
df_missing = df.copy()
df_missing.loc[(df_missing['Region'] == 'East') & (df_missing['Type'] == "Children's Clothing"), 'Sales'] = np.nan

pivot = pd.pivot_table(
    data=df_missing,
    index='Region',
    columns='Type',
    values='Sales',
)

print(pivot)

# Returns:
# Type    Children's Clothing  Men's Clothing  Women's Clothing
# Region                                                       
# East                    NaN      423.647541        399.028409
# North            438.894118      449.157303        432.528169
# South            412.666667      475.435897        418.924528
# West             480.523810      465.292683        419.188679

It may not always be ideal to see a NaN value, especially for non-technical audiences. Because of this, Pandas provides a parameter, fill_value=, which enables you to pass in a value to fill these missing data points. For example, if we wanted to fill all these values with a 0, we can simply pass in this argument:

# Filling Missing Values in a Pandas Pivot Table
import numpy as np

# Work on a copy so that later examples still use the original data
df_missing = df.copy()
df_missing.loc[(df_missing['Region'] == 'East') & (df_missing['Type'] == "Children's Clothing"), 'Sales'] = np.nan

pivot = pd.pivot_table(
    data=df_missing,
    index='Region',
    columns='Type',
    values='Sales',
    fill_value=0
)

print(pivot)

# Returns:
# Type    Children's Clothing  Men's Clothing  Women's Clothing
# Region                                                       
# East               0.000000      423.647541        399.028409
# North            438.894118      449.157303        432.528169
# South            412.666667      475.435897        418.924528
# West             480.523810      465.292683        419.188679

Sorting Pandas Pivot Table Data

Beginning in Pandas version 1.3.0, the sort= parameter controls whether the group keys of the resulting DataFrame are sorted. It defaults to True, so the index labels come out in ascending order; passing sort=False keeps the groups in the order they first appear in the data.

# Sorting a Pandas Pivot Table
pivot = pd.pivot_table(
    data=df,
    index='Region',
    values='Sales',
    sort=True
)

print(pivot)

# Returns:
#              Sales
# Region            
# East    408.182482
# North   438.924051
# South   432.956204
# West    452.029412

Note that sort= sorts the index labels (here, the region names alphabetically), not the aggregated values. To sort by the values themselves, such as from highest to lowest average sales, you still need to chain the .sort_values() method.
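The three orderings can be compared directly. A minimal sketch on hypothetical data where the alphabetical order, the order of appearance, and the value order all differ:

```python
import pandas as pd

# Hypothetical sample data
df = pd.DataFrame({
    'Region': ['North', 'East', 'West', 'South', 'East'],
    'Sales': [300.0, 100.0, 400.0, 200.0, 150.0],
})

# sort=True (the default) sorts the index labels alphabetically
by_label = pd.pivot_table(data=df, index='Region', values='Sales')
print(by_label.index.tolist())   # ['East', 'North', 'South', 'West']

# sort=False keeps the regions in the order they first appear in the data
unsorted = pd.pivot_table(data=df, index='Region', values='Sales', sort=False)
print(unsorted.index.tolist())   # ['North', 'East', 'West', 'South']

# Sorting by the aggregated values still requires .sort_values()
by_value = by_label.sort_values(by='Sales', ascending=False)
print(by_value.index.tolist())   # ['West', 'North', 'South', 'East']
```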

Filtering Python Pivot Tables

In this section, you’ll learn how to filter a Pandas pivot table. Because pivot tables can often be quite large, filtering a pivot table can focus the results quite a bit. Because the function returns a DataFrame, you can filter it as you would any other DataFrame. Let’s recreate our pivot table, averaging sales over regions and quarters.

# Generating a long pivot table
pivot = pd.pivot_table(
    data=df,
    index=['Region', df['Date'].dt.quarter],
    values='Sales'
)

print(pivot.head())

# Returns:
#                   Sales
# Region Date            
# East   1     406.692308
#        2     419.238532
#        3     403.608247
#        4     402.178218
# North  1     462.142857

We can now filter either by a fixed value or by a dynamic one. For example, we could filter on a hard-coded threshold. But say we wanted to show only the records where the average sales were larger than the average across all rows of the pivot table; we could write the following filter:

print(pivot[pivot['Sales'] > pivot['Sales'].mean()])

# Returns:
#                   Sales
# Region Date            
# North  1     462.142857
#        2     442.034884
#        3     447.200000
# South  1     465.263158
#        2     440.628571
# West   1     475.000000
#        3     444.884615
#        4     466.209302

This allows us to see exactly what we want to see!
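For comparison, here is a minimal sketch of a hard-coded threshold and of a filter on an index level instead of a value. The data is hypothetical, and the threshold of 400 is arbitrary:

```python
import pandas as pd

# Hypothetical sample data
df = pd.DataFrame({
    'Date': pd.to_datetime(['2020-01-10', '2020-04-10', '2020-01-20', '2020-04-20']),
    'Region': ['North', 'North', 'West', 'West'],
    'Sales': [450.0, 380.0, 500.0, 300.0],
})

pivot = pd.pivot_table(
    data=df,
    index=['Region', df['Date'].dt.quarter],
    values='Sales',
)

# Filter on a hard-coded threshold
above_400 = pivot[pivot['Sales'] > 400]
print(above_400.index.get_level_values('Region').tolist())  # ['North', 'West']

# Filter on the Region index level rather than on a value
north_only = pivot[pivot.index.get_level_values('Region') == 'North']
print(north_only['Sales'].tolist())  # [450.0, 380.0]
```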

Exercises

It’s time to check your learning! Try to solve the exercises below based on what you learned, using the same DataFrame as you did throughout the tutorial. A sample solution is shown for each.

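To find the region with the highest average sales, you can sort your pivot table by its values using .sort_values(), then use the .index accessor to access the last label (since the values are sorted in ascending order). Note that sort=True on its own would sort the region names alphabetically, not the sales values.

pivot = pd.pivot_table(
    data=df,
    index='Region',
    values='Sales'
).sort_values(by='Sales')

print(pivot.index[-1])

# Returns: West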
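Alternatively, Series.idxmax() returns the label of the largest value in one step. The sketch below uses hypothetical data chosen so that the alphabetically last region is not the top seller, which shows why relying on sort=True alone would give the wrong answer:

```python
import pandas as pd

# Hypothetical sample data: North has the highest sales, West the lowest
df = pd.DataFrame({
    'Region': ['East', 'North', 'West'],
    'Sales': [400.0, 500.0, 300.0],
})

pivot = pd.pivot_table(data=df, index='Region', values='Sales')

print(pivot['Sales'].idxmax())                  # North
print(pivot.sort_values(by='Sales').index[-1])  # North
print(pivot.index[-1])                          # West (alphabetical, not the largest)
```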

It’s recommended to stick to numeric data types (such as integers and floats) to prevent columns from being converted to a data type that can’t have mathematical operations applied to it. Because each column in a Pandas DataFrame holds a single data type, a single non-numeric value, such as a string used as a fill value, converts the whole column to the object data type.

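Sometimes you may want only the totals row of your resulting pivot table. When the pivot table was created with margins=True, the totals are in the last row, so you can select just that row using .iloc with a negative index:

# Select only the last row (the totals row added by margins=True)
print(pivot.iloc[-1:])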
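Selecting the totals row by its label works too and does not depend on the row’s position. A minimal sketch on hypothetical data, assuming the default margins_name of 'All':

```python
import pandas as pd

# Hypothetical sample data
df = pd.DataFrame({
    'Region': ['East', 'East', 'West'],
    'Type': ["Men's Clothing", "Women's Clothing", "Men's Clothing"],
    'Sales': [100.0, 300.0, 50.0],
})

pivot = pd.pivot_table(
    data=df, index='Region', columns='Type', values='Sales',
    aggfunc='sum', margins=True,
)

# Position-based: the margins row is the last row
totals_by_position = pivot.iloc[-1:]

# Label-based: select the row named by margins_name ('All' by default)
totals_by_label = pivot.loc[['All']]

print(totals_by_position.equals(totals_by_label))  # True
print(totals_by_label.loc['All', 'All'])           # 450.0
```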

Conclusion and Recap

In this tutorial, you learned how to use the Pandas .pivot_table() function to generate Excel-style pivot tables directly from a Pandas DataFrame. The function provides significant flexibility through a large assortment of parameters. The section below provides a summary of what you’ve learned:

  • The Pandas pivot_table() function provides a familiar interface to create Excel-style pivot tables
  • The function requires at a minimum either the index= or columns= parameters to specify how to split data
  • The function can calculate one or multiple aggregation methods, including using custom functions
  • The function returns a DataFrame which can be filtered or queried as any other DataFrame
