
NumPy np.expand_dims()

The expand_dims() function in NumPy expands the shape of an input array by inserting a new axis of length one.

The new axis appears at the specified axis position in the output, so the result has one more dimension than the input while holding the same elements.
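As a rough illustration of what "adding a new axis" means, the sketch below compares expand_dims() with two other common ways of getting the same shape: indexing with np.newaxis and calling reshape(). The variable names are only for illustration.

```python
import numpy as np

arr = np.array([1, 2, 3, 4])

# Three ways to place a new length-one axis at position 0.
expanded = np.expand_dims(arr, axis=0)
indexed = arr[np.newaxis, :]
reshaped = arr.reshape(1, 4)

print(expanded.shape)                      # (1, 4)
print(np.array_equal(expanded, indexed))   # True
print(np.array_equal(expanded, reshaped))  # True
```

All three produce the same result here. expand_dims() is handy when the axis position is held in a variable, since it does not require writing out the full target shape.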

Let us explore how this function works.

Function Syntax

The function syntax is as shown below:

numpy.expand_dims(a, axis)

The function has a relatively simple syntax. It accepts the parameters as shown:

  1. a – refers to the input array.
  2. axis – specifies the position in the expanded output array where the new axis is placed. It can be an integer or a tuple of integers.
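The axis parameter accepts more than a single non-negative integer. The short sketch below, using a made-up 2x3 array of zeros, shows a negative axis and a tuple of axes. Tuple support assumes NumPy 1.18 or later.

```python
import numpy as np

arr = np.zeros((2, 3))

# A negative axis counts from the end of the output shape.
last = np.expand_dims(arr, axis=-1)
print(last.shape)  # (2, 3, 1)

# A tuple inserts several axes in one call (NumPy 1.18 or later).
both = np.expand_dims(arr, axis=(0, -1))
print(both.shape)  # (1, 2, 3, 1)
```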

Return Value

The function returns a view of the input array with the dimensions expanded according to the specified parameters.
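Because the result is a view, it shares memory with the input array. The sketch below is one way to check this. It uses np.shares_memory() and then writes through the view.

```python
import numpy as np

arr = np.array([1, 2, 3, 4])
expanded = np.expand_dims(arr, axis=0)

# The view and the original array use the same underlying buffer.
print(np.shares_memory(arr, expanded))  # True

# Writing through the view also changes the original array.
expanded[0, 0] = 99
print(arr[0])  # 99
```

If an independent array is needed, calling .copy() on the result avoids this coupling.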

Example

Consider the example code shown below:

# Import NumPy
import numpy as np

# Create a one-dimensional array with four elements
arr = np.array([1, 2, 3, 4])
print(f"before: {arr.shape}")

# Insert a new axis at position 1
arr = np.expand_dims(arr, axis=1)
print(f"after: {arr.shape}")

The example above uses the expand_dims() function to insert a new axis at position 1 of a one-dimensional array, turning four elements into a two-dimensional array with four rows and one column.

The output is as shown:

before: (4,)
after: (4, 1)
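For comparison, here is a small sketch of the same one-dimensional array expanded at axis 0 and at axis 1. The names row and col are only illustrative.

```python
import numpy as np

arr = np.array([1, 2, 3, 4])

# axis=0 puts the new axis first, giving a single row.
row = np.expand_dims(arr, axis=0)
print(row.shape)     # (1, 4)
print(row.tolist())  # [[1, 2, 3, 4]]

# axis=1 puts the new axis last, giving a single column.
col = np.expand_dims(arr, axis=1)
print(col.shape)     # (4, 1)
print(col.tolist())  # [[1], [2], [3], [4]]
```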

Example 2

We can also perform the same operation on a two-dimensional array. An example is as shown:

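import numpy as np

# Create a two-dimensional array with shape (2, 3)
arr = np.array([[1, 2, 3], [4, 5, 6]])
print(f"before: {arr.shape}")

# Insert a new axis at position 0
arr = np.expand_dims(arr, axis=0)
print(f"after: {arr.shape}")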

The resulting output is as shown:

before: (2, 3)
after: (1, 2, 3)
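The new axis does not have to go at the front. As a minimal sketch, the loop below tries each valid non-negative position for the same two-dimensional array and collects the resulting shapes in a dictionary (shapes is a name chosen for this illustration).

```python
import numpy as np

arr = np.array([[1, 2, 3], [4, 5, 6]])

# Record the output shape for every valid insertion position.
shapes = {}
for axis in (0, 1, 2):
    shapes[axis] = np.expand_dims(arr, axis=axis).shape
    print(axis, shapes[axis])
# 0 (1, 2, 3)
# 1 (2, 1, 3)
# 2 (2, 3, 1)
```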

Conclusion

This tutorial illustrates how to use the expand_dims() function in NumPy to alter the shape of an input array.

Happy coding, my friends!!

About the author

John Otieno

My name is John and I am a fellow geek like you. I am passionate about all things computers, from hardware and operating systems to programming. My dream is to share my knowledge with the world and help out fellow geeks. Follow my content by subscribing to the LinuxHint mailing list.