In this tutorial, you’ll learn how to use the NumPy squeeze() function. The np.squeeze() function removes single-dimensional entries (axes of length 1) from an array’s shape. This lets you turn arrays that carry extra, meaningless dimensions into a shape that better suits the work you’re doing.
In short, the function returns the input array with all of its length-1 axes, or a selected subset of them, removed.
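As a quick illustration (the all-zeros array here is chosen purely for demonstration), squeezing an array of shape (1, 3, 1) drops both length-1 axes while leaving the original array's shape unchanged:

```python
import numpy as np

# An array of shape (1, 3, 1): the first and last axes have length 1
arr = np.zeros((1, 3, 1))

squeezed = np.squeeze(arr)
print(arr.shape)       # (1, 3, 1)
print(squeezed.shape)  # (3,)
```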
By the end of this tutorial, you’ll have learned:
- The syntax of the NumPy squeeze() function
- How to use the NumPy squeeze() function to reduce an array’s dimensionality
- How to squeeze only certain axes in NumPy
Understanding the NumPy squeeze() Function Syntax
Before diving into how to use the np.squeeze() function, let’s take a look at its parameters. Conveniently, np.squeeze() is also available as an array method, arr.squeeze(). In that case, the array doesn’t need to be passed in, since the method is called on the array itself.
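For example, the function form and the method form below produce the same result. This is a small sketch; the np.arange() array is used purely for illustration:

```python
import numpy as np

arr = np.arange(6).reshape(1, 2, 3)

# Function form: the array is passed as the first argument
via_function = np.squeeze(arr)

# Method form: the method is called on the array itself
via_method = arr.squeeze()

print(via_function.shape)                        # (2, 3)
print(via_method.shape)                          # (2, 3)
print(np.array_equal(via_function, via_method))  # True
```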
Let’s take a look at the parameters of the function:
# Understanding the np.squeeze() Function
np.squeeze(
    a,          # The array to squeeze
    axis=None   # The length-1 axes to remove (None removes all of them)
)
We can see from the code block above that the function takes two parameters:
- a: the array to squeeze, which only needs to be array-like (such as a Python list)
- axis: the axes to squeeze, which can be None (the default), an integer, or a tuple of integers
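The sketch below exercises each accepted parameter type. The shapes are chosen arbitrarily for illustration:

```python
import numpy as np

# a: any array-like works, e.g. a nested Python list of shape (1, 3)
from_list = np.squeeze([[1, 2, 3]])
print(from_list)  # [1 2 3]

arr = np.zeros((1, 2, 1, 1))

# axis=None (default): all length-1 axes are removed
print(np.squeeze(arr).shape)               # (2,)

# axis as an integer: only that axis is removed
print(np.squeeze(arr, axis=0).shape)       # (2, 1, 1)

# axis as a tuple: each listed axis is removed
print(np.squeeze(arr, axis=(0, 3)).shape)  # (2, 1)
```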
Now that you know the syntax of the np.squeeze() function, let’s look at a simple example of the function being put to use.
How to Use NumPy squeeze() On a NumPy Array
In this section, you’ll learn how to use the NumPy squeeze() function with an example. We’ll create an array that has an additional dimension of length 1, using the np.linspace() function and then reshaping the result.
# Creating an Array in NumPy
import numpy as np
arr = np.linspace(0, 5, 6).reshape(1,2,3)
print(arr)
# Returns:
# [[[0. 1. 2.]
#   [3. 4. 5.]]]
We can check the shape of the array by using the .shape attribute, as shown below:
# Checking the Shape of Our NumPy Array
import numpy as np
arr = np.linspace(0, 5, 6).reshape(1,2,3)
print(arr.shape)
# Returns: (1, 2, 3)
We can see that the first dimension, which has a length of 1, is potentially superfluous. We can use the np.squeeze() function to remove it from our array. Let’s see how this is done:
# Removing a Dimension of Length 1 Using np.squeeze()
import numpy as np
arr = np.linspace(0, 5, 6).reshape(1,2,3)
print(f'Original array shape: {arr.shape}')
arr = np.squeeze(arr)
print(f'Modified array shape: {arr.shape}')
# Returns:
# Original array shape: (1, 2, 3)
# Modified array shape: (2, 3)
We can see that we originally had an array with the shape (1, 2, 3). Squeezing the array removed the leading dimension of length 1, leaving an array with the shape (2, 3).
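One detail worth knowing: the NumPy documentation states that np.squeeze() returns either the input itself or a view into it, not a copy. The sketch below checks this with np.shares_memory() and shows that a change made through the squeezed array is visible in the original:

```python
import numpy as np

arr = np.linspace(0, 5, 6).reshape(1, 2, 3)
squeezed = np.squeeze(arr)

# The squeezed result shares memory with the original array
print(np.shares_memory(arr, squeezed))  # True

# So writing through the view also changes the original
squeezed[0, 0] = 100
print(arr[0, 0, 0])  # 100.0
```

If you need an independent array, call .copy() on the result.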
How to Use NumPy squeeze() with Only Some Axes
In the example above, we used the NumPy squeeze() function to indiscriminately remove dimensions of length 1. This is the default behavior, since the axis= parameter defaults to None. If we don’t want this behavior, we can pass an axis= argument to squeeze out only specific dimensions.
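To see the difference between the two behaviors in isolation, here is a minimal sketch using an arbitrary array of shape (1, 4, 1):

```python
import numpy as np

arr = np.ones((1, 4, 1))

# Default (axis=None): every axis of length 1 is removed
print(np.squeeze(arr).shape)          # (4,)

# axis=0: only the first axis is removed; the last one is kept
print(np.squeeze(arr, axis=0).shape)  # (4, 1)
```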
Let’s create a more complex array to work with:
# Creating an Array with Multiple Dimensions of Length 1
import numpy as np
arr = np.linspace(0, 5, 6).reshape(1,2,1,3)
print(arr)
# Returns:
# [[[[0. 1. 2.]]
#
#   [[3. 4. 5.]]]]
We can see that the array has two axes with a length of exactly 1: the first and the third. Let’s see how we can remove only one of these axes.
To remove only the third axis, we pass its index into the np.squeeze() function. Because NumPy axes are 0-indexed, the third axis is axis 2:
# Squeezing Only Particular Axes in NumPy
import numpy as np
arr = np.linspace(0, 5, 6).reshape(1,2,1,3)
print(f'Original array shape: {arr.shape}')
arr = np.squeeze(arr, 2)
print(f'Modified array shape: {arr.shape}')
# Returns:
# Original array shape: (1, 2, 1, 3)
# Modified array shape: (1, 2, 3)
We can see that the function removed only the third axis and kept the first one. It’s important to note, however, that if we had tried to squeeze an axis whose length isn’t 1, NumPy would have raised a ValueError.
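The sketch below triggers that error on purpose by asking NumPy to squeeze axis 1 (which has length 2), then passes a tuple to remove both length-1 axes in a single call. The exact wording of the error message may vary between NumPy versions, so the code only prints it:

```python
import numpy as np

arr = np.linspace(0, 5, 6).reshape(1, 2, 1, 3)

# Axis 1 has length 2, so asking NumPy to squeeze it raises a ValueError
raised = False
try:
    np.squeeze(arr, axis=1)
except ValueError as err:
    raised = True
    print(f'ValueError: {err}')

# Passing a tuple removes several length-1 axes at once
both_removed = np.squeeze(arr, axis=(0, 2))
print(both_removed.shape)  # (2, 3)
```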
Conclusion
In this tutorial, you learned how to use the NumPy squeeze() function to reduce the dimensionality of an array. The function removes axes of length 1, which can get in the way of your analysis. You first learned the syntax of the function and its two parameters. Then, you learned how to use the function with an example. Finally, you learned how to use the axis= parameter to remove only specific axes.