In this tutorial, you’ll learn how to use the NumPy argmax() function to find the index of the largest value in an array. The np.argmax() function can find the index of the maximum value across an entire array, as well as along a specific axis of a multi-dimensional array. The function can often be confusing, but this tutorial should clear up any confusion you have about how it works.
By the end of this tutorial, you’ll have learned:
- How the np.argmax() function works
- How to use np.argmax() to find the index of the maximum value, including along different axes
- How to use the argmax function to filter a Pandas DataFrame
Let’s get started!
What is the NumPy argmax Function?
The NumPy argmax() function returns the index of the maximum value (or the indices of the maximum values) of an array, optionally along a particular axis. Before diving further in, let’s take a look at what the function looks like and what parameters it has:
# Understanding the np.argmax() Function
np.argmax(
a,
axis=None,
out=None,
*,
keepdims=<no value>
)
The table below breaks down the various parameters of the function and their default values, and provides a brief description of what they do:
Parameter | Default Value | Accepted Values | Description |
---|---|---|---|
a= | (required) | array-like | The input array in which to find the maximum values. |
axis= | None | integer | By default, the array is flattened and a single index is returned. If an axis is provided, the indices of the maximum values are found along that axis. |
out= | None | array | If provided, the result is written into this array, which must have the appropriate shape and dtype. |
keepdims= | <no value> | boolean | If True, the axes which are reduced are left in the result as dimensions with size one. |
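The out= and keepdims= parameters are easier to understand with a quick example. The sketch below reuses the two-dimensional array that appears later in this tutorial; the variable names out and kept are just for illustration. Note that keepdims= is keyword-only and requires NumPy 1.22 or newer.

```python
import numpy as np

arr = np.array([[1, 3, 5, 2, 4], [5, 2, 3, 4, 6]])

# out=: a preallocated array that receives the result. It needs the result's
# shape (one entry per row when axis=1) and an integer index dtype (np.intp)
out = np.empty(2, dtype=np.intp)
np.argmax(arr, axis=1, out=out)
print(out)  # [2 4]

# keepdims=True keeps the reduced axis as a dimension of size one
kept = np.argmax(arr, axis=1, keepdims=True)
print(kept.shape)  # (2, 1)
```

Keeping the reduced dimension is handy when the indices need to line up with the original array, for example when indexing back into it.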
In the next section, you’ll learn how the np.argmax() function works with some practical examples.
How does the NumPy argmax() Function Work?
In this section, you’ll learn how to use the NumPy argmax() function to find the index of the maximum value (or values) across an array. Let’s start by looking at a one-dimensional array to illustrate how the function works:
# A Simple Example Using np.argmax()
import numpy as np
arr = np.array([1,3,5,2,4])
max_idx = np.argmax(arr)
print(max_idx)
# Returns: 2
Let’s break down what we did above:
- We imported NumPy using the alias np
- We created an array, arr, which contains five values
- We created a variable, max_idx, which holds the result of passing the array into the np.argmax() function
When we printed the resulting value, we saw that it returned 2. This makes sense: 5 is the largest value in the array, and it sits at index position 2 (NumPy indices start at 0).
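Two edge cases are worth knowing about: ties and missing values. When the maximum value occurs more than once, np.argmax() returns the index of its first occurrence. For floating-point arrays, NaN is treated as the maximum, so the index of the first NaN is returned; np.nanargmax() ignores NaN values instead. The arrays below are made up purely to illustrate this.

```python
import numpy as np

# When the maximum appears more than once, argmax returns the FIRST index
ties = np.array([1, 7, 3, 7])
print(np.argmax(ties))  # 1

# NaN is treated as the maximum, so argmax points at the first NaN
with_nan = np.array([1.0, np.nan, 3.0])
print(np.argmax(with_nan))  # 1

# np.nanargmax() skips NaN values and finds the largest real number
print(np.nanargmax(with_nan))  # 2
```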
Now let’s take a look at a more complicated example: we’ll pass in a two-dimensional array with the default parameters:
# Passing a 2-Dimensional Array into np.argmax()
import numpy as np
arr = np.array([[1,3,5,2,4], [5,2,3,4,6]])
max_idx = np.argmax(arr)
print(max_idx)
# Returns: 9
The result above may surprise you: we passed in a two-dimensional array, yet only a single value was returned. Moreover, neither row contains nine elements, so 9 can’t be a valid index into either one. Many users would expect an array of two values to be returned, one index per row.
However, because the axis parameter is set to None, NumPy flattens the array into a single dimension and returns the index of the maximum value in the flattened array. Here, the flattened array is [1, 3, 5, 2, 4, 5, 2, 3, 4, 6], and its largest value, 6, sits at index 9. In the following section, you’ll learn how to modify the behavior of the np.argmax() function to find the indices of max values along different axes.
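Before moving on: if you do want the flat result but need to know which row and column it refers to, NumPy’s np.unravel_index() converts a flat index back into a tuple of indices for a given shape. A minimal sketch using the same array:

```python
import numpy as np

arr = np.array([[1, 3, 5, 2, 4], [5, 2, 3, 4, 6]])

# The flat index refers to the array as if it had been flattened row by row
flat_idx = np.argmax(arr)
print(flat_idx)  # 9
print(arr.ravel()[flat_idx])  # 6

# np.unravel_index() maps the flat index back to (row, column)
row, col = np.unravel_index(flat_idx, arr.shape)
print(row, col)  # 1 4
print(arr[row, col])  # 6
```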
How to Change the Axis of the NumPy argmax() Function
The np.argmax() function takes an optional axis= parameter. By default, this is set to None, meaning that the passed-in array is flattened to a single dimension. For a two-dimensional array, we can pass in 0 to find the index of the maximum value in each column, or 1 to find the index of the maximum value in each row.
Let’s take a look at the earlier example again and see what the results look like when we change the axis= parameter:
# Using axis=0 to Modify the np.argmax() Function
import numpy as np
arr = np.array([[1,3,5,2,4], [5,2,3,4,6]])
max_idx = np.argmax(arr, axis=0)
print(max_idx)
# Returns: [1 0 0 1 1]
When we pass in the axis=0 argument, the function returns, for each column, the row index of that column’s maximum value. Because of this, the resulting array contains one value per column, so its length equals the number of columns in the input array (five, in this case).
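The indices from axis=0 can also be used to pull out the maximum values themselves. A minimal sketch: pair each returned row index with its column number (built here with np.arange()) and use NumPy’s integer-array indexing. The result matches np.max() along the same axis.

```python
import numpy as np

arr = np.array([[1, 3, 5, 2, 4], [5, 2, 3, 4, 6]])
max_idx = np.argmax(arr, axis=0)  # one row index per column: [1 0 0 1 1]

# Pair each row index with its column number to select the max values
col_max = arr[max_idx, np.arange(arr.shape[1])]
print(col_max)  # [5 3 5 4 6]

# Sanity check: this matches np.max() along the same axis
print(np.array_equal(col_max, np.max(arr, axis=0)))  # True
```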
If we instead want the index of the maximum value in each row, we can pass in axis=1. The code below shows how this is done:
# Using axis=1 to Modify the np.argmax() Function
import numpy as np
arr = np.array([[1,3,5,2,4], [5,2,3,4,6]])
max_idx = np.argmax(arr, axis=1)
print(max_idx)
# Returns: [2 4]
In the example above, the function returns an array of length 2, where each value corresponds to the index position of the max value in each “row”.
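The same idea generalizes beyond two dimensions: np.argmax() removes the axis it reduces over, so the result’s shape is the input’s shape with that axis dropped (unless keepdims=True). Negative values count axes from the end, so axis=-1 always means the last axis. The three-dimensional array below is a made-up example built with np.arange() purely for illustration.

```python
import numpy as np

# A 3-D array with shape (2, 3, 4) holding the values 0..23 in order
arr3d = np.arange(24).reshape(2, 3, 4)

# The reduced axis disappears from the result's shape
print(np.argmax(arr3d, axis=0).shape)  # (3, 4)
print(np.argmax(arr3d, axis=1).shape)  # (2, 4)
print(np.argmax(arr3d, axis=2).shape)  # (2, 3)

# Values increase along the last axis, so every index is 3 (its last position)
last_axis = np.argmax(arr3d, axis=2)

# axis=-1 refers to the last axis, which is axis=2 for a 3-D array
print(np.array_equal(np.argmax(arr3d, axis=-1), last_axis))  # True
```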
How to Use argmax to Filter a Pandas DataFrame
In this section, you’ll learn how to use the np.argmax() function to filter a Pandas DataFrame. If we want to return the row that contains the max value in a given column, we can use the np.argmax() function to find the position of that row. Let’s take a look at how this would work:
# Filtering a Pandas DataFrame with np.argmax()
import numpy as np
import pandas as pd
df = pd.DataFrame.from_dict({
'Name': ['John', 'Jane', 'Mary', 'Bob'],
'Age': [23, 14, 19, 25],
})
max_row = df.iloc[np.argmax(df['Age'])]
print(max_row)
# Returns:
# Name Bob
# Age 25
# Name: 3, dtype: object
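In the code above, we used the .iloc accessor to select the row at the integer position returned by the np.argmax() function. Because np.argmax() returns a position rather than an index label, .iloc (not .loc) is the appropriate accessor here.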
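Positions and labels only coincide when a DataFrame uses the default 0, 1, 2, … index. The sketch below gives the same data a hypothetical string index. It calls np.argmax() on the underlying NumPy array (via .to_numpy()) and pairs the result with .iloc; as an alternative, it uses pandas’ own Series.idxmax(), which returns the index label of the maximum and therefore pairs with .loc.

```python
import numpy as np
import pandas as pd

# Same data as above, but with a string index instead of the default 0..3
df = pd.DataFrame(
    {'Name': ['John', 'Jane', 'Mary', 'Bob'], 'Age': [23, 14, 19, 25]},
    index=['a', 'b', 'c', 'd'],
)

# np.argmax() returns an integer position, so it pairs with .iloc
pos = np.argmax(df['Age'].to_numpy())
print(pos)  # 3
print(df.iloc[pos]['Name'])  # Bob

# Series.idxmax() returns the index label, so it pairs with .loc
label = df['Age'].idxmax()
print(label)  # d
print(df.loc[label]['Name'])  # Bob
```

With this index, df.loc[3] would raise a KeyError, because 3 is not one of the row labels.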
Conclusion
In this tutorial, you learned how to use the np.argmax() function to find the index position of the max value (or values) in a NumPy array. You first learned what parameters the function accepts. Then, you learned how to use the function on a one-dimensional array as well as on multi-dimensional arrays along different axes. Finally, you learned how to use the function to filter a Pandas DataFrame to find the row containing the max value.
Additional Resources
To learn more about related topics, check out the tutorials below: