You may be familiar with pivot tables in Excel, which make it easy to generate insights into your data. In this post, you’ll learn how to create pivot tables in Python and Pandas using the .pivot_table() method. This post will give you a complete overview of how to use the .pivot_table() function!
Being able to quickly summarize data is an important skill for getting a sense of what your data looks like. The function is quite similar to the .groupby() method, also available in Pandas, but offers significantly more customization, as we’ll see later on in this post.
By the end of this tutorial, you’ll have learned:
- How to use the pivot_table() function and what its parameters represent
- How to group data using an index or a multi-index
- How to pivot tables even further using indices and columns
- How to specify and create your own aggregation methods
- How to calculate totals and deal with missing data
How to Build a Pivot Table in Python
A pivot table is a table of statistics that helps summarize the data of a larger table by “pivoting” that data. Microsoft Excel popularized the pivot table, where it’s known as a PivotTable. Pandas lets you create pivot tables in Python using the .pivot_table() function. The function has the following default parameters:
# The syntax of the pd.pivot_table() function, with its default arguments
import pandas as pd
pd.pivot_table(
    data,                 # required: the DataFrame to pivot
    values=None,
    index=None,
    columns=None,
    aggfunc='mean',
    fill_value=None,
    margins=False,
    dropna=True,
    margins_name='All',
    observed=False,
    sort=True
)
The function takes a DataFrame and returns a DataFrame. The table below provides an overview of the different parameters available in the function:
Parameter | Default Value | Description |
---|---|---|
data= | (required) | The DataFrame to pivot |
values= | None | The column or columns to aggregate. If blank, all remaining columns are aggregated (versions of Pandas before 2.0 silently skipped non-numeric ones) |
index= | None | The column or columns to group data by, shown as rows. A single column can be a string, while multiple columns should be a list of strings |
columns= | None | The column or columns to group data by, shown as columns. A single column can be a string, while multiple columns should be a list of strings |
aggfunc= | ‘mean’ | A function or list of functions to aggregate data by |
fill_value= | None | Value to replace missing values with |
margins= | False | Add a row and column for totals |
dropna= | True | Whether to leave out columns whose entries are all NaN |
margins_name= | ‘All’ | Name of total row/column |
observed= | False | Only for categorical data: if True, only shows observed values for categorical groups |
sort= | True | Whether to sort the group keys (row and column labels) of the result |
The function, in many ways, works to turn a long dataset into a wide dataset but also provides aggregations. In order to do the opposite, you can use the Pandas melt() function to convert a wide DataFrame into a long one.
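To see the long-to-wide idea, and the melt() round trip, on something small, here is a minimal sketch using a made-up DataFrame, long_df, that isn’t part of the tutorial’s dataset. It assumes aggfunc='sum' so that each Region/Type cell simply holds its single sale.

```python
import pandas as pd

# A small, made-up "long" dataset: one row per Region/Type observation
long_df = pd.DataFrame({
    'Region': ['East', 'East', 'West', 'West'],
    'Type': ['Men', 'Women', 'Men', 'Women'],
    'Sales': [100, 200, 300, 400],
})

# Long -> wide: one row per Region, one column per Type
wide = pd.pivot_table(
    data=long_df, index='Region', columns='Type', values='Sales', aggfunc='sum'
)
print(wide)

# Wide -> long: melt the Type columns back into rows
back_to_long = wide.reset_index().melt(
    id_vars='Region', var_name='Type', value_name='Sales'
)
print(back_to_long)
```

Because every Region/Type pair occurs exactly once here, the aggregation doesn’t lose anything and melt() recovers the original rows (in a different order).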
Now that you have an understanding of the different parameters available in the function, let’s load in our data set and begin exploring our data.
Loading a Sample Pandas DataFrame
To follow along with this tutorial, let’s load a sample Pandas DataFrame. We can load the DataFrame from the file hosted on my GitHub page, using the pd.read_excel() function (reading .xlsx files also requires the openpyxl package to be installed). Then we can print out the first five records of the dataset using the .head() method.
# Loading a Sample Pandas DataFrame
import pandas as pd
df = pd.read_excel('https://github.com/datagy/mediumdata/raw/master/sample_pivot.xlsx', parse_dates=['Date'])
print(df.head())
# Returns:
# Date Region Type Units Sales
# 0 2020-07-11 East Children's Clothing 18.0 306
# 1 2020-09-23 North Children's Clothing 14.0 448
# 2 2020-04-02 South Women's Clothing 17.0 425
# 3 2020-02-28 East Children's Clothing 26.0 832
# 4 2020-03-19 West Women's Clothing 3.0 33
Based on the output of the first five rows shown above, we can see that we have five columns to work with:
Column Name | Description |
---|---|
Date | Date of transaction |
Region | The region of the transaction |
Type | The type of clothing sold |
Units | The number of units sold |
Sales | The cost of the sale |
Now that we have a bit more context around the data, let’s explore creating our first pivot table in Pandas.
Creating a Pivot Table in Pandas
Let’s create your first Pandas pivot table. At a minimum, we have to pass in some form of a group key, using either the index= or the columns= parameter. In the examples below, we’re using the Pandas function, pd.pivot_table(), rather than the DataFrame method, df.pivot_table(). Because of this, we need to pass in the data= argument. If we applied the method to the DataFrame directly, the data would be implied.
# Creating your first Pandas pivot table
pivot = pd.pivot_table(
data=df,
index='Region'
)
print(pivot)
# Returns:
# Sales Units
# Region
# East 408.182482 19.732360
# North 438.924051 19.202643
# South 432.956204 20.423358
# West 452.029412 19.294118
Let’s break down what happened here:
- We created a new DataFrame called pivot, using the pd.pivot_table() function
- We passed in our DataFrame, df, and set index='Region', meaning data would be grouped by the Region column
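Because all other parameters were left to their defaults, Pandas made the following assumptions:
- Data should be aggregated by the average of each column (aggfunc='mean')
- The values should be all of the remaining numeric columns

Keep in mind that silently skipping the non-numeric Date and Type columns is how versions of Pandas before 2.0 behave. In Pandas 2.0 and later, those columns are no longer dropped automatically, so 'mean' fails on the text in the Type column with a TypeError; passing values=['Sales', 'Units'] avoids this.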
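Since .pivot_table() is closely related to .groupby(), a quick way to build intuition is to compute the same summary both ways. The sketch below uses a small, hypothetical DataFrame (df_demo) rather than the sample file, and relies on the default aggfunc='mean'.

```python
import pandas as pd

# Hypothetical mini dataset (not the tutorial's sample file)
df_demo = pd.DataFrame({
    'Region': ['East', 'West', 'East', 'West'],
    'Sales': [100, 200, 301, 500],
})

# Average Sales per Region with .pivot_table() ...
pivot = pd.pivot_table(data=df_demo, index='Region', values='Sales')

# ... and the same summary with .groupby()
grouped = df_demo.groupby('Region')[['Sales']].mean()

print(pivot['Sales'].tolist())    # [200.5, 350.0]
print(grouped['Sales'].tolist())  # [200.5, 350.0]
```

Both return the average Sales per Region; .pivot_table() wraps the grouping and aggregation in a single call.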
Aggregating Only Certain Columns in a Pandas Pivot Table
In the example above, you didn’t modify the values= parameter. Because of this, all numeric columns were aggregated. This may not always be what you want, so Pandas lets you pass in either a single string representing one column or a list of strings representing multiple columns.
Let’s now modify our code to only calculate the mean for a single column, Sales:
# Aggregating Only a Single Column
pivot = pd.pivot_table(
data=df,
index='Region',
values='Sales'
)
print(pivot)
# Returns:
# Sales
# Region
# East 408.182482
# North 438.924051
# South 432.956204
# West 452.029412
We can see that instead of aggregating all numeric columns, only the one specified was aggregated.
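Passing a list works the same way. As a minimal sketch with a made-up DataFrame (df_demo is hypothetical), naming the columns explicitly also keeps a text column such as Type out of the aggregation:

```python
import pandas as pd

# Hypothetical data with a non-numeric Type column we want to skip
df_demo = pd.DataFrame({
    'Region': ['East', 'West', 'East', 'West'],
    'Type': ['A', 'B', 'A', 'B'],
    'Units': [1, 2, 3, 4],
    'Sales': [10.5, 20.0, 30.0, 40.0],
})

# A list of strings aggregates exactly those columns
pivot = pd.pivot_table(data=df_demo, index='Region', values=['Sales', 'Units'])
print(pivot.columns.tolist())  # ['Sales', 'Units']
print(pivot)
```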
Working with Aggregation Methods in a Pandas Pivot Table
Now that you’ve created your first pivot table in Pandas, let’s work on changing the aggregation methods. This allows you to specify how you want your data aggregated. This is where the power of Pandas really comes through, allowing you to calculate complex analyses with ease.
Specifying Aggregation Method in a Pandas Pivot Table
You can use the aggfunc= (aggregation function) parameter to change how data are aggregated in a pivot table. By default, Pandas will use the .mean() method to aggregate data. You can pass a string naming a function, such as 'mean', 'sum', or 'max', or a callable such as np.mean.
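As a small sketch of the two styles, assuming a made-up df_demo and a hypothetical helper named largest_sale, a string such as 'max' and an equivalent callable produce the same result. The callable is called once per group and receives that group’s values as a Series.

```python
import pandas as pd

df_demo = pd.DataFrame({
    'Region': ['East', 'East', 'West'],
    'Sales': [100, 250, 400],
})

# A string names a built-in aggregation
by_name = pd.pivot_table(data=df_demo, index='Region', values='Sales', aggfunc='max')

# Hypothetical callable: receives each group's Sales values as a Series
def largest_sale(values):
    return values.max()

by_callable = pd.pivot_table(
    data=df_demo, index='Region', values='Sales', aggfunc=largest_sale
)

print(by_name['Sales'].tolist())      # [250, 400]
print(by_callable['Sales'].tolist())  # [250, 400]
```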
Let’s now change our behavior to produce the sum of sales and units for each region:
# Specifying the Aggregation Function
pivot = pd.pivot_table(
    data=df,
    index='Region',
    values=['Sales', 'Units'],  # sum only the numeric columns, not Date or Type
    aggfunc='sum'
)
print(pivot)
# Returns:
# Sales Units
# Region
# East 167763 8110.0
# North 138700 4359.0
# South 59315 2798.0
# West 61476 2624.0
Multiple Aggregation Methods in a Pandas Pivot Table
Similarly, we can specify multiple aggregation methods in a Pandas pivot table. To do this, pass a list of functions to aggfunc=, and each function will be applied to every values column. Let’s produce aggregations for both the mean and the sum:
pivot = pd.pivot_table(
    data=df,
    index='Region',
    values=['Sales', 'Units'],  # keep the non-numeric Date and Type columns out
    aggfunc=['mean', 'sum']
)
print(pivot)
# Returns:
# mean sum
# Sales Units Sales Units
# Region
# East 408.182482 19.732360 167763 8110.0
# North 438.924051 19.202643 138700 4359.0
# South 432.956204 20.423358 59315 2798.0
# West 452.029412 19.294118 61476 2624.0
We can see how easy that was and how much more data it provides! For each numeric column, both the mean and the sum are calculated.
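With several aggregation functions, the resulting columns form a two-level MultiIndex of (function, column). As a sketch with a small, made-up DataFrame (df_demo is hypothetical), you can select a single aggregated column with a tuple:

```python
import pandas as pd

df_demo = pd.DataFrame({
    'Region': ['East', 'East', 'West'],
    'Units': [1, 3, 5],
    'Sales': [10.0, 30.0, 50.0],
})

pivot = pd.pivot_table(
    data=df_demo, index='Region', values=['Sales', 'Units'], aggfunc=['mean', 'sum']
)

# The columns are (aggregation, column) pairs
print(pivot.columns.tolist())

# Select one aggregated column with a tuple
print(pivot[('sum', 'Sales')].tolist())  # [40.0, 50.0]
```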
Specifying Different Aggregations Per Column
Now, imagine you wanted to calculate different aggregations per column. In order to do this, you can pass in a dictionary using the key-value pair format 'column': function. Let’s say we wanted to calculate the total number of units and the average sale value:
pivot = pd.pivot_table(
data=df,
index='Region',
aggfunc={'Sales': 'mean', 'Units': 'sum'}
)
print(pivot)
# Returns:
# Sales Units
# Region
# East 408.182482 8110.0
# North 438.924051 4359.0
# South 432.956204 2798.0
# West 452.029412 2624.0
This allows you to easily compare different key performance indicators in the same DataFrame.
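One detail worth knowing: when aggfunc= is a dictionary, only the columns named as keys show up in the result. A minimal sketch, using a hypothetical df_demo:

```python
import pandas as pd

df_demo = pd.DataFrame({
    'Region': ['East', 'East', 'West'],
    'Units': [1, 3, 5],
    'Sales': [10.0, 30.0, 50.0],
})

# Units is not a key in the dictionary, so it is left out
pivot = pd.pivot_table(data=df_demo, index='Region', aggfunc={'Sales': 'mean'})
print(pivot.columns.tolist())  # ['Sales']
print(pivot)
```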
Custom Aggregations in Pandas Pivot Tables
Pandas also allows us to pass a custom function into the .pivot_table() function. This greatly extends our ability to tailor analyses to our needs! Let’s see how we can pass in a function that calculates the mean of a column without any outliers.
Pandas comes with a method, .quantile(), that returns the value found at a given percentage of the way through the data. Let’s say we wanted to calculate the average of a column, removing the top and bottom 10% of the data. We could define the following function:
# Defining a custom function
def mean_no_outliers(values):
    # Find the 10th and 90th percentile cut-off values
    low, high = values.quantile([0.1, 0.9])
    # Keep only the values between the two cut-offs, then average them
    no_outliers = values[values.between(low, high)]
    return no_outliers.mean()
This function accepts a single parameter, values, which will be the values passed in by the .pivot_table() function. The .quantile() method finds the two cut-off values, and any values outside them are filtered out. Finally, the mean of the remaining values is calculated. Let’s see how we can use this (and the normal mean aggregation) in our pivot table, applied to our Sales column.
# Specifying custom functions in a Pandas pivot table
pivot = pd.pivot_table(
    data=df,
    index='Region',
    aggfunc=['mean', mean_no_outliers],
    values='Sales'
)
print(pivot)
# Returns a DataFrame indexed by Region with two columns:
# ('mean', 'Sales') and ('mean_no_outliers', 'Sales')
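To check that an outlier-trimming aggregation really drops the extreme values, you can try it on a tiny, made-up dataset where the answer is easy to work out by hand. The helper below, trimmed_mean, is hypothetical and keeps only the values between the 10th and 90th percentiles; df_demo is invented for illustration.

```python
import pandas as pd

# Hypothetical helper: average only the values between the 10th and 90th percentiles
def trimmed_mean(values):
    low, high = values.quantile([0.1, 0.9])
    return values[values.between(low, high)].mean()

df_demo = pd.DataFrame({
    'Region': ['A'] * 10 + ['B'] * 3,
    'Sales': [1, 2, 3, 4, 5, 6, 7, 8, 9, 100, 10, 20, 30],
})

pivot = pd.pivot_table(
    data=df_demo, index='Region', values='Sales', aggfunc=['mean', trimmed_mean]
)
print(pivot)
# Region A: the plain mean is 14.5, while the trimmed mean drops the 1 and the 100 -> 5.5
```

The column label for the custom aggregation comes from the function’s name, which is why a descriptive name is worth choosing.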
More Complex Pandas Pivot Tables
Now that you have an understanding of how the .pivot_table() function works in Pandas, let’s take a look at how we can expand our understanding. In this section, you’ll learn how to add columns and multiple indices to our Pandas pivot tables.
Add Columns to a Pandas Pivot Table
When we add columns to a Pandas pivot table, we add another dimension to the data. While the index= parameter splits the data vertically, the columns= parameter groups and splits the data horizontally. This allows us to create an easy-to-read table. Let’s see how we can use the columns= parameter to split the data by the Type column.
# Adding Columns to Our Pandas Pivot Table
pivot = pd.pivot_table(
data=df,
index='Region',
columns='Type',
values='Sales'
)
print(pivot)
# Returns:
# Type Children's Clothing Men's Clothing Women's Clothing
# Region
# East 405.743363 423.647541 399.028409
# North 438.894118 449.157303 432.528169
# South 412.666667 475.435897 418.924528
# West 480.523810 465.292683 419.188679
We can see how easy it was to add a whole other dimension of data. This allows us to spot differences between groupings in a format that’s easy to read.
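Conceptually, adding columns= is similar to grouping on both keys and then moving one key into the columns. The sketch below, using a made-up df_demo, compares .pivot_table() with .groupby() followed by .unstack() as a rough equivalent:

```python
import pandas as pd

# Hypothetical data: West has no Women's sales
df_demo = pd.DataFrame({
    'Region': ['East', 'East', 'West', 'West'],
    'Type': ['Men', 'Women', 'Men', 'Men'],
    'Sales': [100.0, 200.0, 300.0, 500.0],
})

pivot = pd.pivot_table(data=df_demo, index='Region', columns='Type', values='Sales')

# Group on both keys, then move Type from the rows into the columns
by_hand = df_demo.groupby(['Region', 'Type'])['Sales'].mean().unstack('Type')

print(pivot)
print(by_hand)
```

In both tables, the West/Women cell is NaN because that combination never occurs.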
Add Multiple Indices to Pandas Pivot Tables
While columns added a horizontal dimension, we can also specify multiple indices when there is a logical hierarchy in our data. For example, we can add a date dimension to our pivot table. Let’s use the Pandas .dt date accessor to group our data by quarter. This lets us see how our data changes over time. Let’s see how this works:
pivot = pd.pivot_table(
data=df,
index=['Region',df['Date'].dt.quarter],
columns='Type',
values='Sales'
)
print(pivot.head())
# Returns:
# Type Children's Clothing Men's Clothing Women's Clothing
# Region Date
# East 1 423.241379 369.250000 428.948718
# 2 274.800000 445.425000 456.816327
# 3 425.382353 506.421053 342.386364
# 4 453.866667 405.666667 364.795455
# North 1 394.727273 450.869565 489.944444
This returns a multi-index Pandas DataFrame. While it may look more complex, accessing data in a multi-index Pandas DataFrame works quite similarly to accessing data in any other DataFrame. However, since we now have two index levels rather than one, we can pass in a tuple of index values. Let’s say we wanted to access only the intersection of the East region, Quarter 1, and Men’s Clothing; we could use the following code:
# Accessing data in a multi-index pivot table
print(pivot.loc[('East', 1), "Men's Clothing"])
# Returns: 369.25
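Beyond a single cell, you’ll often want a whole slice of a multi-index pivot table. As a sketch with a small, made-up df_demo, .loc selects on the outer Region level, while .xs() takes a cross-section on an inner level by name (here the Date level created by .dt.quarter):

```python
import pandas as pd

df_demo = pd.DataFrame({
    'Date': pd.to_datetime(['2020-01-15', '2020-04-20', '2020-01-30', '2020-05-02']),
    'Region': ['East', 'East', 'West', 'West'],
    'Sales': [100.0, 200.0, 300.0, 400.0],
})

pivot = pd.pivot_table(
    data=df_demo, index=['Region', df_demo['Date'].dt.quarter], values='Sales'
)

# Every quarter for one region: index with the outer level only
print(pivot.loc['East'])

# One quarter across every region: cross-section on the inner 'Date' level
print(pivot.xs(1, level='Date'))
```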
Customizing Pandas Pivot Tables
In this section, you’ll learn how to customize your Pandas pivot tables. This allows you to add even further customizations, such as adding totals and working with missing data. You’ll also learn how to fill missing data in a resulting pivot table with a specific value.
Adding Totals to Pandas Pivot Tables
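Let’s start off by learning how to add totals to a Pandas pivot table. This is controlled by the margins= parameter, which accepts a boolean value. By default, this is set to False, but toggling it to True adds the totals to rows and columns. Keep in mind that the totals use the same aggregation as the rest of the table: with the default aggfunc='mean', the All row and column hold averages, not sums. Let’s see what this looks like: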
# Adding totals to rows and columns
pivot = pd.pivot_table(
data=df,
index='Region',
columns='Type',
values='Sales',
margins=True
)
print(pivot)
# Returns:
# Type Children's Clothing Men's Clothing Women's Clothing All
# Region
# East 405.743363 423.647541 399.028409 408.182482
# North 438.894118 449.157303 432.528169 438.924051
# South 412.666667 475.435897 418.924528 432.956204
# West 480.523810 465.292683 419.188679 452.029412
# All 427.743860 444.257732 415.254717 427.254000
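Because the totals reuse the table’s aggregation function, the All values in the output above are averages. The small sketch below, using a hypothetical df_demo, contrasts the default aggfunc='mean' with aggfunc='sum', where the All row and column become true totals:

```python
import pandas as pd

df_demo = pd.DataFrame({
    'Region': ['East', 'East', 'West', 'West'],
    'Type': ['Men', 'Women', 'Men', 'Women'],
    'Sales': [100, 201, 300, 400],
})

# With the default aggfunc='mean', the 'All' margins are averages ...
mean_totals = pd.pivot_table(
    data=df_demo, index='Region', columns='Type', values='Sales', margins=True
)
# ... with aggfunc='sum', they are true totals
sum_totals = pd.pivot_table(
    data=df_demo, index='Region', columns='Type', values='Sales',
    aggfunc='sum', margins=True
)

print(mean_totals.loc['East', 'All'])  # 150.5, the average of 100 and 201
print(sum_totals.loc['East', 'All'])   # the East total: 100 + 201
print(sum_totals.loc['All', 'All'])    # the grand total of all four sales
```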
By default, Pandas will name the totals 'All'. If you want to rename these labels, you can use the margins_name= parameter to pass in a string to relabel the values.
# Renaming totals in a Pandas pivot table
pivot = pd.pivot_table(
data=df,
index='Region',
columns='Type',
values='Sales',
margins=True,
margins_name='Total'
)
print(pivot)
# Returns:
# Type Children's Clothing Men's Clothing Women's Clothing Total
# Region
# East 405.743363 423.647541 399.028409 408.182482
# North 438.894118 449.157303 432.528169 438.924051
# South 412.666667 475.435897 418.924528 432.956204
# West 480.523810 465.292683 419.188679 452.029412
# Total 427.743860 444.257732 415.254717 427.254000
Dealing with Missing Data in a Pandas Pivot Table
When Pandas encounters a cross-section where no data exists, it’ll include a NaN value in the resulting pivot table. Let’s create a copy of our DataFrame with some missing data and calculate a pivot table to see what this looks like:
# Adding and seeing missing data in a Pandas pivot table
import numpy as np
# Work on a copy so the original df stays unchanged for later examples
df_missing = df.copy()
df_missing.loc[(df_missing['Region'] == 'East') & (df_missing['Type'] == "Children's Clothing"), 'Sales'] = np.nan
pivot = pd.pivot_table(
    data=df_missing,
    index='Region',
    columns='Type',
    values='Sales',
)
print(pivot)
# Returns:
# Type Children's Clothing Men's Clothing Women's Clothing
# Region
# East NaN 423.647541 399.028409
# North 438.894118 449.157303 432.528169
# South 412.666667 475.435897 418.924528
# West 480.523810 465.292683 419.188679
It may not always be ideal to see a NaN value, especially for non-technical audiences. Because of this, Pandas provides a parameter, fill_value=, which enables you to pass in a value to fill these missing data points. For example, if we wanted to fill all these values with a 0, we can simply pass in this argument:
# Filling Missing Values in a Pandas Pivot Table
import numpy as np
# Work on a copy so the original df stays unchanged for later examples
df_missing = df.copy()
df_missing.loc[(df_missing['Region'] == 'East') & (df_missing['Type'] == "Children's Clothing"), 'Sales'] = np.nan
pivot = pd.pivot_table(
    data=df_missing,
    index='Region',
    columns='Type',
    values='Sales',
    fill_value=0
)
print(pivot)
# Returns:
# Type Children's Clothing Men's Clothing Women's Clothing
# Region
# East 0.000000 423.647541 399.028409
# North 438.894118 449.157303 432.528169
# South 412.666667 475.435897 418.924528
# West 480.523810 465.292683 419.188679
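Missing values also appear when a combination simply never occurs in the data, not only when values are NaN. As a minimal sketch with a made-up df_demo in which the West region has no Women’s sales, fill_value= fills that empty cell:

```python
import pandas as pd

# Hypothetical data: the West region sold no Women's items at all
df_demo = pd.DataFrame({
    'Region': ['East', 'East', 'West'],
    'Type': ['Men', 'Women', 'Men'],
    'Sales': [100, 200, 300],
})

without_fill = pd.pivot_table(
    data=df_demo, index='Region', columns='Type', values='Sales', aggfunc='sum'
)
with_fill = pd.pivot_table(
    data=df_demo, index='Region', columns='Type', values='Sales', aggfunc='sum',
    fill_value=0
)

print(without_fill)  # the West / Women cell is NaN
print(with_fill)     # the West / Women cell is 0
```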
Sorting Pandas Pivot Table Data
Beginning in Pandas version 1.3.0, the sort= parameter controls whether the resulting DataFrame is sorted by its group keys (the row and column labels). It defaults to True, so passing in sort=True, as below, simply makes this behavior explicit; pass sort=False to keep the keys in the order they first appear.
# Sorting a Pandas Pivot Table
pivot = pd.pivot_table(
data=df,
index='Region',
values='Sales',
sort=True
)
print(pivot)
# Returns:
# Sales
# Region
# East 408.182482
# North 438.924051
# South 432.956204
# West 452.029412
By default, Pandas sorts the pivot table by its labels in ascending order, not by its values. To sort by the aggregated values (or across different columns), you still need to chain the .sort_values() method.
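To make the difference between sorting labels and sorting values concrete, here is a sketch with a hypothetical df_demo (sort=False needs Pandas 1.3.0 or later):

```python
import pandas as pd

df_demo = pd.DataFrame({
    'Region': ['West', 'East', 'North', 'West'],
    'Sales': [500.0, 100.0, 300.0, 300.0],
})

# sort=True (the default) orders the rows by the Region labels
by_label = pd.pivot_table(data=df_demo, index='Region', values='Sales')
print(by_label.index.tolist())  # ['East', 'North', 'West']

# sort=False keeps the regions in the order they first appear
unsorted = pd.pivot_table(data=df_demo, index='Region', values='Sales', sort=False)
print(unsorted.index.tolist())  # ['West', 'East', 'North']

# Sorting by the aggregated values still needs .sort_values()
by_value = by_label.sort_values('Sales', ascending=False)
print(by_value.index.tolist())  # ['West', 'North', 'East']
```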
Filtering Python Pivot Tables
In this section, you’ll learn how to filter a Pandas pivot table. Because pivot tables can often be quite large, filtering a pivot table can focus the results quite a bit. Because the function returns a DataFrame, you can simply filter the DataFrame as you would any other. Let’s recreate our pivot table, averaging sales over regions and quarters.
# Generating a long pivot table
pivot = pd.pivot_table(
data=df,
index=['Region', df['Date'].dt.quarter],
values='Sales'
)
print(pivot.head())
# Returns:
# Sales
# Region Date
# East 1 406.692308
# 2 419.238532
# 3 403.608247
# 4 402.178218
# North 1 462.142857
Now we can filter either by a fixed value or by a dynamically calculated one. For example, we could simply filter based on a hard-coded value. Say, however, that we wanted to show only the records where the average Sales value is larger than the average across all rows of the pivot table. We could write the following filter:
print(pivot[pivot['Sales'] > pivot['Sales'].mean()])
# Returns:
# Sales
# Region Date
# North 1 462.142857
# 2 442.034884
# 3 447.200000
# South 1 465.263158
# 2 440.628571
# West 1 475.000000
# 3 444.884615
# 4 466.209302
This allows us to see exactly what we want to see!
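The same pattern works on any pivot table. As a small sketch with a made-up df_demo, note that pivot['Sales'].mean() is the average of the rows of the pivot table (the regional averages), not the average of every individual sale:

```python
import pandas as pd

df_demo = pd.DataFrame({
    'Region': ['East', 'East', 'West', 'North'],
    'Sales': [100.0, 200.0, 400.0, 250.0],
})
pivot = pd.pivot_table(data=df_demo, index='Region', values='Sales')

# Filter with a hard-coded threshold
above_fixed = pivot[pivot['Sales'] > 200]
print(above_fixed)

# Filter with a dynamic threshold: the mean of the regional averages
above_average = pivot[pivot['Sales'] > pivot['Sales'].mean()]
print(above_average)
```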
Exercises
It’s time to check your learning! Try to solve the exercises below based on what you learned. Each question is followed by a sample solution you can use to verify your work. Use the same DataFrame as you did throughout the tutorial.
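How would you remove the totals (All) row from a pivot table?
Sometimes you may only want the totals column in your resulting DataFrame, not the totals row. Because the totals row is labeled 'All' (or whatever you passed into margins_name=), you can drop it by its label. Alternatively, pivot.iloc[:-1] drops the last row by position.
# Drop the totals row by its label
pivot.drop(index='All')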
What type of value might you want to avoid passing as a fill_value= argument, and why?
It’s recommended to stick to numeric data types (such as integers and floats). Because each column in Pandas holds a single data type, filling a numeric column with a string such as 'Missing' would convert the whole column to the generic object data type, so it would no longer behave like a numeric column in mathematical operations.
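What region had the highest average sales in the DataFrame? How would you get the region’s name programmatically?
Because sort=True sorts the pivot table by its region labels rather than by its values, the last index value isn’t necessarily the largest. Instead, you can use the .idxmax() method, which returns the index label of the largest value.
pivot = pd.pivot_table(
    data=df,
    index='Region',
    values='Sales'
)
# .idxmax() returns the label (the region) of the largest average
print(pivot['Sales'].idxmax())
# Returns: West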
Conclusion and Recap
In this tutorial, you learned how to use the Pandas .pivot_table() function to generate Excel-style pivot tables, directly off of a Pandas DataFrame. The function provides significant flexibility through a large assortment of parameters. The section below provides a summary of what you’ve learned:
- The Pandas pivot_table() function provides a familiar interface to create Excel-style pivot tables
- The function requires at a minimum either the index= or columns= parameter to specify how to split data
- The function can calculate one or multiple aggregation methods, including using custom functions
- The function returns a DataFrame which can be filtered or queried as any other DataFrame