The expand_dims() function in NumPy expands the shape of an input array.
It works by inserting a new axis of length 1 at the given axis position, so the result has one more dimension than the input.
Let us explore how this function works.
Function Syntax
The function syntax is as shown below:
```python
numpy.expand_dims(a, axis)
```
The function has a relatively simple syntax. It accepts the following parameters:
- a – refers to the input array.
- axis – specifies the position in the expanded shape where the new axis is placed. Negative values count from the end of the output shape.
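As a minimal sketch of how the axis parameter affects the result, the snippet below inserts the new axis at different positions of the same one-dimensional array. The variable names (front, back, last) are only illustrative.

```python
import numpy as np

arr = np.array([1, 2, 3, 4])  # shape (4,)

# axis=0 places the new axis at the front of the shape
front = np.expand_dims(arr, axis=0)

# axis=1 places the new axis after the existing one
back = np.expand_dims(arr, axis=1)

# axis=-1 counts from the end, so the new axis becomes the last one
last = np.expand_dims(arr, axis=-1)

print(front.shape)  # (1, 4)
print(back.shape)   # (4, 1)
print(last.shape)   # (4, 1)
```

For a one-dimensional input, axis=1 and axis=-1 give the same result because both refer to the last position of the output shape.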
Return Value
The function returns a view of the input array with the number of dimensions increased by the new axis. The data itself is not copied.
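Because the result is a view, it shares memory with the original array. The short sketch below illustrates this under the assumption that the input is an ordinary NumPy array. The name expanded is only illustrative.

```python
import numpy as np

arr = np.array([1, 2, 3, 4])
expanded = np.expand_dims(arr, axis=0)  # shape (1, 4)

# The view and the original array share the same underlying data
print(np.shares_memory(arr, expanded))  # True

# Writing through the view also changes the original array
expanded[0, 0] = 99
print(arr[0])  # 99
```

If you need an independent array, calling .copy() on the result is one way to detach it from the original.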
Example 1
Consider the example code shown below:
```python
# import numpy
import numpy as np

arr = np.array([1, 2, 3, 4])
print(f"before: {arr.shape}")

# insert a new axis at position 1
arr = np.expand_dims(arr, axis=1)
print(f"after: {arr.shape}")
```
The example above uses the expand_dims() function to expand the shape of a one-dimensional array.
Running the code produces the following output:
```
before: (4,)
after: (4, 1)
```
Example 2
We can also perform the same operation on a two-dimensional array. An example is as shown:
```python
import numpy as np

arr = np.array([[1, 2, 3], [4, 5, 6]])
print(f"before: {arr.shape}")

# insert a new axis at the front of the shape
arr = np.expand_dims(arr, axis=0)
print(f"after: {arr.shape}")
```
The resulting output is as shown:
```
before: (2, 3)
after: (1, 2, 3)
```
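The same two-dimensional array can take the new axis at other positions as well. The sketch below loops over a few axis values and records the resulting shape for each. The shapes dictionary is only an illustrative helper.

```python
import numpy as np

arr = np.array([[1, 2, 3], [4, 5, 6]])  # shape (2, 3)

# Record the resulting shape for each axis position
shapes = {}
for axis in (0, 1, 2, -1):
    shapes[axis] = np.expand_dims(arr, axis=axis).shape
    print(f"axis={axis}: {shapes[axis]}")

# axis=0:  (1, 2, 3)
# axis=1:  (2, 1, 3)
# axis=2:  (2, 3, 1)
# axis=-1: (2, 3, 1)
```

In every case the original dimensions keep their order and only a single axis of length 1 is added.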
Conclusion
This tutorial illustrates how to use the expand_dims() function in NumPy to alter the shape of an input array.
Happy coding, my friends!!