
NumPy np.argwhere()

The argwhere() function in NumPy finds the indices of the non-zero elements in an array, grouped by element: each row of the result is the full index of one non-zero element.

Sounds confusing? Follow along with the examples in this tutorial to see how it works.

Function Syntax

The function has a simple syntax, as shown in the definition below:

numpy.argwhere(a)

As shown, the function takes only one parameter:

  1. a – refers to the input array or array_like object.
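Because a only needs to be array_like, you can pass a nested Python list or a list of booleans directly, without calling np.array() first. A quick sketch (the values are arbitrary):

```python
import numpy as np

# A plain nested list is array_like, so no explicit np.array() call is needed
print(np.argwhere([[0, 5], [7, 0]]))
# [[0 1]
#  [1 0]]

# Booleans work too: True counts as non-zero, False as zero
print(np.argwhere([False, True, True]))
# [[1]
#  [2]]
```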

The function then returns the indices of the non-zero elements in the array, grouped by element.

The resulting array has the shape (N, a.ndim), where N is the number of non-zero elements and a.ndim is the number of dimensions of the input array.
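To see the (N, a.ndim) shape in practice, here is a small sketch with a 3-D array whose contents are made up for illustration. It also checks the relationship noted in the NumPy documentation: for arrays with at least one dimension, np.argwhere(a) gives the same rows as np.transpose(np.nonzero(a)).

```python
import numpy as np

# A 3-D array (a.ndim == 3) with exactly three non-zero entries
a = np.zeros((2, 3, 4), dtype=int)
a[0, 1, 2] = 7
a[1, 0, 0] = 3
a[1, 2, 3] = 9

idx = np.argwhere(a)
print(idx.shape)  # (3, 3): N = 3 non-zero elements, a.ndim = 3
print(idx)
# [[0 1 2]
#  [1 0 0]
#  [1 2 3]]

# The rows match the transpose of the tuple returned by np.nonzero
print(np.array_equal(idx, np.transpose(np.nonzero(a))))  # True

# With no non-zero elements, N is 0 but the width is still a.ndim
print(np.argwhere(np.zeros((2, 3))).shape)  # (0, 2)
```

Note that the rows come back in row-major order, so the first index varies slowest.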

Example #1

Take the example code shown below:

# import numpy
import numpy as np
# generate array:
arr = np.arange(10).reshape(2, 5)
print(arr)
# return index of non-zero elements
print(np.argwhere(arr))

In the example code above, we start by importing NumPy.

We then create an array holding the values 0 to 9 using the arange function and reshape it to the shape (2, 5).

Finally, we pass the array to the argwhere function to get the indices of its non-zero elements. Running the code prints:

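[[0 1 2 3 4]
 [5 6 7 8 9]]
[[0 1]
 [0 2]
 [0 3]
 [0 4]
 [1 0]
 [1 1]
 [1 2]
 [1 3]
 [1 4]]

The index [0 0] is missing from the result because arr[0, 0] is 0, the only zero element in the array.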
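The rows returned by argwhere can be used directly to look the values up again. A minimal sketch, rebuilding the same arr as in Example #1:

```python
import numpy as np

arr = np.arange(10).reshape(2, 5)
idx = np.argwhere(arr)

# Each row of idx is a (row, column) pair; unpack it to read the element back
for r, c in idx[:3]:
    print(f"arr[{r}, {c}] = {arr[r, c]}")
# arr[0, 1] = 1
# arr[0, 2] = 2
# arr[0, 3] = 3

# Or gather every non-zero value at once by turning the columns into a tuple
print(arr[tuple(idx.T)])  # [1 2 3 4 5 6 7 8 9]
```

Splitting the columns with tuple(idx.T) gives one index array per dimension, which is the form NumPy's advanced indexing expects.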

Example #2

You can also get the indices of the elements that satisfy a condition by passing a boolean expression instead of the array itself. For example, to get the indices of the elements greater than or equal to 3, we can do the following:

print(np.argwhere(arr >= 3))

The resulting output:

[[0 3]
 [0 4]
 [1 0]
 [1 1]
 [1 2]
 [1 3]
 [1 4]]
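The condition arr >= 3 is evaluated first and produces a boolean array, and argwhere then treats each True as non-zero. The sketch below, rebuilding the arr from Example #1, prints that mask and shows one way to combine two conditions; the second condition (even values) is just an illustrative choice.

```python
import numpy as np

arr = np.arange(10).reshape(2, 5)

# The comparison builds a boolean mask; argwhere returns where it is True
mask = arr >= 3
print(mask)
# [[False False False  True  True]
#  [ True  True  True  True  True]]

# Combine conditions with & (and) or | (or); the parentheses are required
# because & binds more tightly than >= and ==
print(np.argwhere((arr >= 3) & (arr % 2 == 0)))
# [[0 4]
#  [1 1]
#  [1 3]]
```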

Conclusion

This article covered the basics of using the argwhere() function in NumPy. Feel free to explore the official NumPy documentation for more details.

Happy coding!

About the author

John Otieno

My name is John, and I am a fellow geek like you. I am passionate about all things computers, from hardware and operating systems to programming. My dream is to share my knowledge with the world and help out fellow geeks. Follow my content by subscribing to the LinuxHint mailing list.