You may be familiar with using pivot tables in Excel to generate quick insights into your data. In this post, you’ll learn how to create pivot tables in Python and Pandas using the `.pivot_table()` method. This post will give you a complete overview of how to use the `.pivot_table()` function!

Being able to quickly summarize data is an important skill for getting a sense of what your data looks like. The function is quite similar to the `.groupby()` method also available in Pandas, but offers significantly more customization, as we’ll see later on in this post.

By the end of this tutorial, you’ll have learned:

- How to use the `pivot_table()` function and what its parameters represent
- How to group data using an index or a multi-index
- How to break your pivot table down further using indices and columns
- How to specify and create your own aggregation methods
- How to calculate totals and deal with missing data


## How to Build a Pivot Table in Python

A pivot table is a table of statistics that helps summarize the data of a larger table by “pivoting” that data. Microsoft Excel popularized the pivot table, where they’re known as PivotTables. Pandas gives access to creating pivot tables in Python using the `.pivot_table()` function. The function has the following default parameters:

```
# The syntax of the .pivot_table() function
import pandas as pd
pd.pivot_table(
    data,               # the DataFrame to pivot (required)
    values=None,
    index=None,
    columns=None,
    aggfunc='mean',
    fill_value=None,
    margins=False,
    dropna=True,
    margins_name='All',
    observed=False,
    sort=True
)
```

The function takes a DataFrame and returns a DataFrame. The table below provides an overview of the different parameters available in the function:

Parameter | Default Value | Description |
---|---|---|
`data=` | (required) | The DataFrame to pivot |
`values=` | `None` | The column or columns to aggregate (if omitted, all columns not used as keys are aggregated) |
`index=` | `None` | The column or columns whose values form the rows of the pivot table. A single column can be a string, while multiple columns should be a list of strings |
`columns=` | `None` | The column or columns whose values form the columns of the pivot table. Accepts a string or a list of strings, like `index=` |
`aggfunc=` | `'mean'` | A function or list of functions to aggregate data by |
`fill_value=` | `None` | Value to replace missing values with |
`margins=` | `False` | Add a row and column for totals |
`dropna=` | `True` | Whether to exclude columns whose entries are all NaN |
`margins_name=` | `'All'` | Name of total row/column |
`observed=` | `False` | Only for categorical data: if `True`, only shows observed values for categorical groups. Pandas 2.2 deprecated this default in favor of `True` |
`sort=` | `True` | Whether to sort the result |

Now that you have an understanding of the different parameters available in the function, let’s load in our data set and begin exploring our data.

## Loading a Sample Pandas DataFrame

To follow along with this tutorial, let’s load a sample Pandas DataFrame. We can load the DataFrame from the file hosted on my GitHub page, using the `pd.read_excel()` function. Then we can print out the first five records of the dataset using the `.head()` method.

```
# Loading a Sample Pandas DataFrame
import pandas as pd
df = pd.read_excel('https://github.com/datagy/mediumdata/raw/master/sample_pivot.xlsx', parse_dates=['Date'])
print(df.head())
# Returns:
# Date Region Type Units Sales
# 0 2020-07-11 East Children's Clothing 18.0 306
# 1 2020-09-23 North Children's Clothing 14.0 448
# 2 2020-04-02 South Women's Clothing 17.0 425
# 3 2020-02-28 East Children's Clothing 26.0 832
# 4 2020-03-19 West Women's Clothing 3.0 33
```
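If the file above isn’t reachable from your machine, a small stand-in DataFrame with the same five columns lets you run the examples. The first five rows below copy the `.head()` output shown above, and the sixth row is made up. Because the rest of the file is missing, your results will not match the outputs in this post.

```python
import pandas as pd

# Stand-in for the tutorial's file: the first five rows copy the .head() output
# above, the last row is hypothetical. Results will differ from the post's outputs.
df = pd.DataFrame({
    'Date': pd.to_datetime(['2020-07-11', '2020-09-23', '2020-04-02',
                            '2020-02-28', '2020-03-19', '2020-11-05']),
    'Region': ['East', 'North', 'South', 'East', 'West', 'West'],
    'Type': ["Children's Clothing", "Children's Clothing", "Women's Clothing",
             "Children's Clothing", "Women's Clothing", "Men's Clothing"],
    'Units': [18.0, 14.0, 17.0, 26.0, 3.0, 12.0],
    'Sales': [306, 448, 425, 832, 33, 360],
})
print(df.dtypes)
```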

Based on the output of the first five rows shown above, we can see that we have five columns to work with:

Column Name | Description |
---|---|
Date | Date of transaction |
Region | The region of the transaction |
Type | The type of clothing sold |
Units | The number of units sold |
Sales | The cost of the sale |

Now that we have a bit more context around the data, let’s explore creating our first pivot table in Pandas.

## Creating a Pivot Table in Pandas

Let’s create your first Pandas pivot table. At a minimum, we have to pass in some form of a group key, using either the `index=` or the `columns=` parameter. In the examples below, we’re using the Pandas function, `pd.pivot_table()`, rather than the DataFrame method. Because of this, we need to pass in the `data=` argument. If we called the method on the DataFrame directly, as in `df.pivot_table()`, the data would be implied.

```
# Creating your first Pandas pivot table
pivot = pd.pivot_table(
data=df,
index='Region'
)
print(pivot)
# Returns:
# Sales Units
# Region
# East 408.182482 19.732360
# North 438.924051 19.202643
# South 432.956204 20.423358
# West 452.029412 19.294118
```

Let’s break down what happened here:

- We created a new DataFrame called `pivot`, which was created using the `pd.pivot_table()` function
- We passed in our DataFrame, `df`, and set `index='Region'`, meaning data would be grouped by the Region column

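Because all other parameters were left to their defaults, Pandas made the following assumptions:

- Data should be aggregated by the average of each column (`aggfunc='mean'`)
- All remaining columns should be aggregated as values

The output above drops the non-numeric `Date` and `Type` columns, which is how older versions of Pandas behave. In Pandas 2.0 and later, non-numeric columns are no longer silently dropped, so this call may raise a `TypeError` unless you pass `values=['Sales', 'Units']`, as covered in the next section.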
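To make the function-versus-method point concrete, here’s a minimal sketch on a small hypothetical DataFrame (only the grouping key and two numeric columns). It builds the same table with `pd.pivot_table()`, with the `.pivot_table()` DataFrame method, and with a comparable `.groupby()` call:

```python
import pandas as pd

# Hypothetical toy data, not the tutorial's file
df = pd.DataFrame({
    'Region': ['East', 'East', 'West', 'West'],
    'Units': [10.0, 20.0, 5.0, 15.0],
    'Sales': [100, 300, 50, 450],
})

# Function form: the DataFrame is passed explicitly
pivot_func = pd.pivot_table(data=df, index='Region')

# Method form: the DataFrame is implied
pivot_method = df.pivot_table(index='Region')

# Comparable groupby: mean of each numeric column per region
grouped = df.groupby('Region')[['Sales', 'Units']].mean()

print(pivot_func)
```

The `.groupby()` version can return a different dtype for `Sales` (pivot tables may cast whole-number averages of integer columns back to integers), but the values match.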

## Aggregating Only Certain Columns in a Pandas Pivot Table

In the example above, you didn’t modify the `values=` parameter, so all numeric columns were aggregated. This may not always be what you want, so Pandas allows you to pass in either a single string representing one column or a list of strings representing multiple columns.

Let’s now modify our code to only calculate the mean for a single column, `Sales`:

```
# Aggregating Only a Single Column
pivot = pd.pivot_table(
data=df,
index='Region',
values='Sales'
)
print(pivot)
# Returns:
# Sales
# Region
# East 408.182482
# North 438.924051
# South 432.956204
# West 452.029412
```

We can see that instead of aggregating all numeric columns, only the one specified was aggregated.
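Passing a list of columns works the same way. Here’s a minimal sketch with a small hypothetical DataFrame that includes an extra numeric column, `Discount` (made up for illustration), which is left out by listing only the columns we want:

```python
import pandas as pd

# Hypothetical toy data; 'Discount' is an extra numeric column we don't want
df = pd.DataFrame({
    'Region': ['East', 'East', 'West', 'West'],
    'Units': [10.0, 20.0, 5.0, 15.0],
    'Sales': [100.0, 300.0, 50.0, 450.0],
    'Discount': [0.1, 0.0, 0.2, 0.05],
})

# Only the listed columns are aggregated
pivot = pd.pivot_table(data=df, index='Region', values=['Sales', 'Units'])
print(pivot)
```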

## Working with Aggregation Methods in a Pandas Pivot Table

Now that you’ve created your first pivot table in Pandas, let’s work on changing the aggregation methods. This allows you to specify *how* you want your data aggregated. This is where the power of Pandas really comes through, allowing you to calculate complex analyses with ease.

### Specifying Aggregation Method in a Pandas Pivot Table

You can use the `aggfunc=` (aggregation function) parameter to change *how* data are aggregated in a pivot table. By default, Pandas will use the `.mean()` method to aggregate data. You can pass a named function, such as `'mean'`, `'sum'`, or `'max'`, or a function callable such as `np.mean` (recent Pandas versions may warn when you pass NumPy functions like this and recommend the string name instead).

Let’s now change this behavior to produce the sum of each numeric column for each region:

```
# Specifying the Aggregation Function
pivot = pd.pivot_table(
    data=df,
    index='Region',
    values=['Sales', 'Units'],  # explicit columns avoid errors from non-numeric columns in pandas 2.x
    aggfunc='sum'
)
print(pivot)
# Returns:
# Sales Units
# Region
# East 167763 8110.0
# North 138700 4359.0
# South 59315 2798.0
# West 61476 2624.0
```
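As a quick check that a string name and a plain Python callable behave the same way, here’s a sketch on hypothetical toy data. The lambda receives each group’s column as a Series, and `values=` is set explicitly so only numeric columns are aggregated:

```python
import pandas as pd

# Hypothetical toy data
df = pd.DataFrame({
    'Region': ['East', 'East', 'West', 'West', 'West'],
    'Units': [10.0, 20.0, 5.0, 15.0, 30.0],
    'Sales': [100.0, 300.0, 50.0, 450.0, 25.0],
})

# Aggregation passed by name
by_name = pd.pivot_table(data=df, index='Region',
                         values=['Sales', 'Units'], aggfunc='sum')

# Same aggregation passed as a callable applied to each group's Series
by_callable = pd.pivot_table(data=df, index='Region',
                             values=['Sales', 'Units'], aggfunc=lambda s: s.sum())

print(by_name)
```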

### Multiple Aggregation Methods in a Pandas Pivot Table

Similarly, we can specify *multiple* aggregation methods in a Pandas pivot table. To do this, pass in a list of functions, and each function will be applied to every values column. Let’s produce aggregations for both the mean and the sum:

```
pivot = pd.pivot_table(
    data=df,
    index='Region',
    values=['Sales', 'Units'],  # explicit columns avoid errors from non-numeric columns in pandas 2.x
    aggfunc=['mean', 'sum']
)
print(pivot)
# Returns:
# mean sum
# Sales Units Sales Units
# Region
# East 408.182482 19.732360 167763 8110.0
# North 438.924051 19.202643 138700 4359.0
# South 432.956204 20.423358 59315 2798.0
# West 452.029412 19.294118 61476 2624.0
```
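Passing a list of functions gives the result two levels of column labels: the aggregation name on top and the original column below. Here’s a sketch on hypothetical toy data showing how you might select one aggregation or flatten the labels into single strings (the `'_'.join` naming scheme is just one possible choice):

```python
import pandas as pd

# Hypothetical toy data
df = pd.DataFrame({
    'Region': ['East', 'East', 'West', 'West'],
    'Units': [10.0, 20.0, 5.0, 15.0],
    'Sales': [100.0, 300.0, 50.0, 450.0],
})

pivot = pd.pivot_table(data=df, index='Region',
                       values=['Sales', 'Units'], aggfunc=['mean', 'sum'])

# Columns are (aggregation, column) tuples
print(pivot.columns.tolist())

# Select the sub-table for one aggregation
sums = pivot['sum']

# Flatten the labels, e.g. ('mean', 'Sales') -> 'mean_Sales'
flat = pivot.copy()
flat.columns = ['_'.join(col) for col in flat.columns]
print(flat)
```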

We can see how easy that was and how much more data it provides! For each column containing numeric data, both the mean and the sum are created.

### Specifying Different Aggregations Per Column

Now, imagine you wanted to calculate different aggregations per column. To do this, you can pass in a dictionary of key-value pairs in the format `'column': function`. Let’s say we wanted to calculate the total number of units and the average sale value:

```
pivot = pd.pivot_table(
data=df,
index='Region',
aggfunc={'Sales': 'mean', 'Units': 'sum'}
)
print(pivot)
# Returns:
# Sales Units
# Region
# East 408.182482 8110.0
# North 438.924051 4359.0
# South 432.956204 2798.0
# West 452.029412 2624.0
```
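The dictionary values can themselves be lists if you want more than one aggregation for a given column. Here’s a minimal sketch using hypothetical toy data. The result has two-level column labels such as `('Sales', 'max')`, so the checks below look columns up by label rather than by position:

```python
import pandas as pd

# Hypothetical toy data
df = pd.DataFrame({
    'Region': ['East', 'East', 'West', 'West'],
    'Units': [10.0, 20.0, 5.0, 15.0],
    'Sales': [100.0, 300.0, 50.0, 450.0],
})

# Two aggregations for Sales, one for Units
pivot = pd.pivot_table(
    data=df,
    index='Region',
    aggfunc={'Sales': ['mean', 'max'], 'Units': 'sum'},
)
print(pivot)
```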

This allows you to easily compare different key performance indicators in the same DataFrame.

### Custom Aggregations in Pandas Pivot Tables

Pandas also allows us to pass in a custom function into the `.pivot_table()` function. This greatly extends our ability to tailor analyses to our specific needs! Let’s see how we can pass in a function that calculates the mean of a column without any outliers.

Pandas Series come with a `.quantile()` method, which returns the value below which a given share of the data falls (for example, `0.1` returns the 10th percentile). Let’s say we wanted to calculate the average of a column, removing the top and bottom 10% of the data. We could use the 10th and 90th percentiles as cut-offs and define the following function:

```
# Defining a custom function
def mean_no_outliers(values):
    # Find the 10th and 90th percentile cut-offs for this group
    low, high = values.quantile([0.1, 0.9])
    # Keep only the values between the cut-offs (inclusive), then average them
    no_outliers = values[values.between(low, high)]
    return no_outliers.mean()
```

This function accepts a single parameter, `values`, which will be the Series of values for each group passed in by the `.pivot_table()` function. The `.quantile()` method finds the cut-off values, and `.between()` filters the Series down to the values that fall between them. Finally, the mean of these values is calculated. Let’s see how we can use this (and the normal `mean` aggregation) in our pivot table, applied to our Sales column.

```
# Specifying custom functions in a Pandas pivot table
pivot = pd.pivot_table(
    data=df,
    index='Region',
    aggfunc=['mean', mean_no_outliers],
    values='Sales'
)
print(pivot)
# Returns a DataFrame indexed by Region with two columns,
# ('mean', 'Sales') and ('mean_no_outliers', 'Sales').
# The 'mean' column matches the earlier output
# (East 408.182482, North 438.924051, South 432.956204, West 452.029412).
```
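Because `mean_no_outliers` is just a Python function, you can check it on a plain Series before using it in a pivot table. This sketch assumes the trimmed-mean version that keeps only the values between the 10th and 90th percentiles, and uses a made-up Series with one extreme value:

```python
import pandas as pd

def mean_no_outliers(values):
    """Average of the values that fall between the 10th and 90th percentiles."""
    low, high = values.quantile([0.1, 0.9])
    # between() is inclusive on both ends by default
    return values[values.between(low, high)].mean()

s = pd.Series([1, 2, 3, 4, 5, 6, 7, 8, 9, 100])
print(s.mean())              # 14.5: the 100 pulls the plain mean up
print(mean_no_outliers(s))   # 5.5: 1 and 100 fall outside the cut-offs (about 1.9 and 18.1)
```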

## More Complex Pandas Pivot Tables

Now that you have an understanding of how the `.pivot_table()` function works in Pandas, let’s take a look at how we can expand our understanding. In this section, you’ll learn how to add columns and multiple indices to our Pandas pivot tables.

### Add Columns to a Pandas Pivot Table

When we add columns to a Pandas pivot table, we add another dimension to the data. While the `index=` parameter splits the data vertically, the `columns=` parameter groups and splits the data horizontally. This allows us to create an easy-to-read table. Let’s see how we can use the `columns=` parameter to split the data by the Type column.

```
# Adding Columns to Our Pandas Pivot Table
pivot = pd.pivot_table(
data=df,
index='Region',
columns='Type',
values='Sales'
)
print(pivot)
# Returns:
# Type Children's Clothing Men's Clothing Women's Clothing
# Region
# East 405.743363 423.647541 399.028409
# North 438.894118 449.157303 432.528169
# South 412.666667 475.435897 418.924528
# West 480.523810 465.292683 419.188679
```
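One way to think about `columns=` is as a long-to-wide reshape: grouping by both keys and then moving one key into the columns. Here’s a sketch on hypothetical toy data comparing the pivot table with an equivalent `.groupby()` and `.unstack()` chain:

```python
import pandas as pd

# Hypothetical toy data
df = pd.DataFrame({
    'Region': ['East', 'East', 'West', 'West', 'West'],
    'Type': ["Men's Clothing", "Women's Clothing", "Men's Clothing",
             "Men's Clothing", "Women's Clothing"],
    'Sales': [100.0, 300.0, 50.0, 450.0, 25.0],
})

pivot = pd.pivot_table(data=df, index='Region', columns='Type', values='Sales')

# Group by both keys, then move 'Type' from the row index into the columns
reshaped = df.groupby(['Region', 'Type'])['Sales'].mean().unstack('Type')

print(pivot)
```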

We can see how easy it was to add a whole other dimension of data. This allows us to spot differences between groupings in a format that’s easy to read.

### Add Multiple Indices to Pandas Pivot Tables

While columns add a horizontal dimension, we can also specify multiple indices when there is a logical hierarchy in our data. For example, we can add a date dimension to our pivot table. Let’s use the Pandas `.dt` date accessor to group our data by quarter, which lets us see how sales change over time. Note that `.dt.quarter` returns only the quarter number, so data from different years would be combined; the sample data appears to cover 2020 only. Let’s see how this works:

```
pivot = pd.pivot_table(
data=df,
index=['Region',df['Date'].dt.quarter],
columns='Type',
values='Sales'
)
print(pivot.head())
# Returns:
# Type Children's Clothing Men's Clothing Women's Clothing
# Region Date
# East 1 423.241379 369.250000 428.948718
# 2 274.800000 445.425000 456.816327
# 3 425.382353 506.421053 342.386364
# 4 453.866667 405.666667 364.795455
# North 1 394.727273 450.869565 489.944444
```

This returns a multi-index Pandas DataFrame. While it may look more complex, accessing data in a multi-index Pandas DataFrame works quite similarly to accessing data in any other DataFrame. However, since we now have *two* index levels rather than one, we can pass in a tuple of labels. Let’s say we wanted to access only the intersection of the East Region, Quarter 1, and Men’s Clothing. We could use the following code:

```
# Accessing data in a multi-index pivot table
print(pivot.loc[('East', 1), "Men's Clothing"])
# Returns: 369.25
```
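Tuples work well for exact lookups. To slice along just one level, such as every region in Quarter 1, `.xs()` with the `level=` argument is one option, and a partial `.loc` lookup returns all quarters for a single region. Here’s a sketch on hypothetical toy data; the row level created from `df['Date'].dt.quarter` takes its name, `Date`, from the Series:

```python
import pandas as pd

# Hypothetical toy data
df = pd.DataFrame({
    'Date': pd.to_datetime(['2020-01-15', '2020-02-20', '2020-04-10',
                            '2020-01-05', '2020-05-01']),
    'Region': ['East', 'East', 'East', 'West', 'West'],
    'Type': ["Men's Clothing", "Women's Clothing", "Men's Clothing",
             "Men's Clothing", "Women's Clothing"],
    'Sales': [100.0, 300.0, 50.0, 450.0, 25.0],
})

pivot = pd.pivot_table(
    data=df,
    index=['Region', df['Date'].dt.quarter],
    columns='Type',
    values='Sales',
)

# Every region for quarter 1: select on the 'Date' level of the row index
q1 = pivot.xs(1, level='Date')

# Every quarter for one region: partial lookup on the first level
east = pivot.loc['East']

print(q1)
```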

## Customizing Pandas Pivot Tables

In this section, you’ll learn how to customize your Pandas pivot tables further, such as by adding totals, filling missing data with a specific value, and sorting the results.

### Adding Totals to Pandas Pivot Tables

Let’s start off by learning how to add totals to a Pandas pivot table. This is controlled by the `margins=` parameter, which accepts a boolean value. By default, this is set to `False`, but toggling it to `True` adds totals to the rows and columns. Let’s see what this looks like:

```
# Adding totals to rows and columns
pivot = pd.pivot_table(
data=df,
index='Region',
columns='Type',
values='Sales',
margins=True
)
print(pivot)
# Returns:
# Type Children's Clothing Men's Clothing Women's Clothing All
# Region
# East 405.743363 423.647541 399.028409 408.182482
# North 438.894118 449.157303 432.528169 438.924051
# South 412.666667 475.435897 418.924528 432.956204
# West 480.523810 465.292683 419.188679 452.029412
# All 427.743860 444.257732 415.254717 427.254000
```
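The totals are computed from the underlying rows, not by averaging the cells shown in the table. Here’s a small sketch with hypothetical toy data where the two approaches give different answers:

```python
import pandas as pd

# Hypothetical toy data: East has three rows, West has one
df = pd.DataFrame({
    'Region': ['East', 'East', 'East', 'West'],
    'Type': ["Men's Clothing", "Men's Clothing", "Women's Clothing", "Women's Clothing"],
    'Sales': [100.0, 200.0, 600.0, 400.0],
})

pivot = pd.pivot_table(data=df, index='Region', columns='Type',
                       values='Sales', margins=True)

print(pivot.loc['All', 'All'])                     # 325.0: mean of all four rows
print(pivot.loc[['East', 'West'], 'All'].mean())   # 350.0: mean of the two region averages
```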

By default, Pandas will name the totals `'All'`. If you wanted to rename these labels, you can use the `margins_name=` parameter to pass in a string to relabel the values.

```
# Renaming totals in a Pandas pivot table
pivot = pd.pivot_table(
data=df,
index='Region',
columns='Type',
values='Sales',
margins=True,
margins_name='Total'
)
print(pivot)
# Returns:
# Type Children's Clothing Men's Clothing Women's Clothing Total
# Region
# East 405.743363 423.647541 399.028409 408.182482
# North 438.894118 449.157303 432.528169 438.924051
# South 412.666667 475.435897 418.924528 432.956204
# West 480.523810 465.292683 419.188679 452.029412
# Total 427.743860 444.257732 415.254717 427.254000
```

### Dealing with Missing Data in a Pandas Pivot Table

When Pandas encounters a cross-section where no data exists, it’ll include a `NaN` value in the resulting pivot table. Let’s modify our DataFrame to include some missing data and calculate a pivot table to see what this looks like:

```
# Adding and seeing missing data in a Pandas pivot table
import numpy as np
# Work on a copy so that df stays unchanged for the later examples
df_missing = df.copy()
df_missing.loc[(df_missing['Region'] == 'East') & (df_missing['Type'] == "Children's Clothing"), 'Sales'] = np.nan
pivot = pd.pivot_table(
    data=df_missing,
    index='Region',
    columns='Type',
    values='Sales',
)
print(pivot)
# Returns:
# Type Children's Clothing Men's Clothing Women's Clothing
# Region
# East NaN 423.647541 399.028409
# North 438.894118 449.157303 432.528169
# South 412.666667 475.435897 418.924528
# West 480.523810 465.292683 419.188679
```

It may not always be ideal to see a `NaN` value, especially for non-technical audiences. Because of this, Pandas provides a parameter, `fill_value=`, which enables you to pass in a value to fill these missing data points. For example, if we wanted to fill all these values with a 0, we can simply pass in this argument:

```
# Filling Missing Values in a Pandas Pivot Table
import numpy as np
# Work on a copy so that df stays unchanged for the later examples
df_missing = df.copy()
df_missing.loc[(df_missing['Region'] == 'East') & (df_missing['Type'] == "Children's Clothing"), 'Sales'] = np.nan
pivot = pd.pivot_table(
    data=df_missing,
    index='Region',
    columns='Type',
    values='Sales',
    fill_value=0
)
print(pivot)
# Returns:
# Type Children's Clothing Men's Clothing Women's Clothing
# Region
# East 0.000000 423.647541 399.028409
# North 438.894118 449.157303 432.528169
# South 412.666667 475.435897 418.924528
# West 480.523810 465.292683 419.188679
```
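`fill_value=` only fills cells that are empty after aggregation. Missing values inside a group are still skipped by the mean rather than treated as zero. Here’s a sketch on hypothetical toy data showing both cases:

```python
import numpy as np
import pandas as pd

# Hypothetical toy data: one NaN sale, and no East/Women's rows at all
df = pd.DataFrame({
    'Region': ['East', 'East', 'West', 'West'],
    'Type': ["Men's Clothing", "Men's Clothing", "Men's Clothing", "Women's Clothing"],
    'Sales': [100.0, np.nan, 300.0, 200.0],
})

pivot = pd.pivot_table(data=df, index='Region', columns='Type',
                       values='Sales', fill_value=0)

# East/Men's is 100.0 (the NaN is skipped, not counted as 0);
# East/Women's had no rows, so its empty cell is filled with 0
print(pivot)
```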

### Sorting Pandas Pivot Table Data

Beginning in Pandas version 1.3.0, the `sort=` parameter lets you control whether the resulting DataFrame is sorted. It defaults to `True`, which sorts the pivot table by its group keys (here, the Region labels), while passing `sort=False` leaves the groups unsorted. The example below passes `sort=True` explicitly:

```
# Sorting a Pandas Pivot Table
pivot = pd.pivot_table(
data=df,
index='Region',
values='Sales',
sort=True
)
print(pivot)
# Returns:
# Sales
# Region
# East 408.182482
# North 438.924051
# South 432.956204
# West 452.029412
```

With `sort=True`, Pandas sorts the group keys (the index labels) in ascending order, not the aggregated values. To sort by the values themselves, or across several columns, you still need to chain the `.sort_values()` method.
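Here’s a minimal sketch of sorting by the aggregated values instead, using hypothetical toy data and `.sort_values()` to put the highest average first:

```python
import pandas as pd

# Hypothetical toy data
df = pd.DataFrame({
    'Region': ['East', 'North', 'South', 'West', 'East'],
    'Sales': [500.0, 200.0, 350.0, 100.0, 300.0],
})

# Rows come back sorted by the Region label
pivot = pd.pivot_table(data=df, index='Region', values='Sales')

# Re-sort by the average Sales value, largest first
by_sales = pivot.sort_values('Sales', ascending=False)
print(by_sales)
```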

## Filtering Python Pivot Tables

In this section, you’ll learn how to filter a Pandas pivot table. Because pivot tables can often be quite large, filtering helps you focus on the results that matter. Since the function returns a DataFrame, you can filter it as you would any other DataFrame. Let’s recreate our pivot table, averaging sales over regions and quarters.

```
# Generating a long pivot table
pivot = pd.pivot_table(
data=df,
index=['Region', df['Date'].dt.quarter],
values='Sales'
)
print(pivot.head())
# Returns:
# Sales
# Region Date
# East 1 406.692308
# 2 419.238532
# 3 403.608247
# 4 402.178218
# North 1 462.142857
```

We can now filter either by a hard-coded value or by a dynamic value. For example, to show only the records where the average Sales value is larger than the average across all rows of the pivot table, we could write the following filter:

```
print(pivot[pivot['Sales'] > pivot['Sales'].mean()])
# Returns:
# Sales
# Region Date
# North 1 462.142857
# 2 442.034884
# 3 447.200000
# South 1 465.263158
# 2 440.628571
# West 1 475.000000
# 3 444.884615
# 4 466.209302
```

This allows us to see exactly what we want to see!
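For completeness, here’s a sketch of a hard-coded filter, the same filter written with `.query()`, and a filter on an index level rather than on values. It uses a small hypothetical DataFrame, so the numbers differ from the output above:

```python
import pandas as pd

# Hypothetical toy data
df = pd.DataFrame({
    'Date': pd.to_datetime(['2020-01-15', '2020-04-20', '2020-01-10', '2020-04-05']),
    'Region': ['East', 'East', 'West', 'West'],
    'Sales': [400.0, 250.0, 500.0, 300.0],
})

pivot = pd.pivot_table(data=df, index=['Region', df['Date'].dt.quarter], values='Sales')

# Hard-coded threshold with a boolean mask
over_350 = pivot[pivot['Sales'] > 350]

# The same filter written with .query()
over_350_query = pivot.query('Sales > 350')

# Filter on an index level instead of a value: keep only quarter 1
q1_only = pivot[pivot.index.get_level_values('Date') == 1]

print(over_350)
```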


## Conclusion and Recap

In this tutorial, you learned how to use the Pandas `.pivot_table()` function to generate Excel-style pivot tables, directly off of a Pandas DataFrame. The function provides significant flexibility through a large assortment of parameters. The section below provides a summary of what you’ve learned:

- The Pandas `pivot_table()` function provides a familiar interface to create Excel-style pivot tables
- The function requires at a minimum either the `index=` or `columns=` parameter to specify how to split data
- The function can calculate one or multiple aggregation methods, including using custom functions
- The function returns a DataFrame which can be filtered or queried as any other DataFrame

