NumPy np.argmax

Data science is all about organizing, sorting, counting, and interpreting data to generate meaningful insight and predictions.

NumPy provides the argmax() function, which returns the index of the maximum element in an array, either across the whole (flattened) array or along a specific axis.
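As a quick sketch (the array values here are arbitrary), the snippet below applies argmax() to a 1-D array. When the maximum value appears more than once, NumPy returns the index of its first occurrence.

```python
import numpy as np

# A 1-D array whose maximum value (7) appears twice
values = np.array([1, 3, 7, 7, 2])

# argmax returns the index of the first occurrence of the maximum
idx = np.argmax(values)
print(idx)          # 2
print(values[idx])  # 7
```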

This tutorial explains how the numpy.argmax() function works and how to use it.

Function Syntax

The syntax of the function is as shown below:

numpy.argmax(a, axis=None, out=None, *, keepdims=<no value>)

The function accepts the following parameters:

  • a – the input array on which the function operates.
  • axis – an optional integer specifying the axis along which the function is applied. If axis is not set (None), the function operates on the flattened array.
  • out – an optional array into which the result is inserted. It must have the appropriate shape and dtype.
  • keepdims – an optional, keyword-only boolean parameter (True or False), available since NumPy 1.22. If set to True, the reduced axis is kept in the result as a dimension of size one, so the result broadcasts correctly against the input array.
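The optional out and keepdims parameters are easiest to understand in action. The sketch below uses a small 2D array chosen for illustration; note that keepdims requires NumPy 1.22 or newer, and that the out array is assumed to be pre-allocated with the result's shape and NumPy's index dtype (np.intp).

```python
import numpy as np

array = np.arange(6).reshape((2, 3))

# keepdims=True keeps the reduced axis as a dimension of size one
# (the keepdims argument requires NumPy 1.22 or newer)
kept = np.argmax(array, axis=1, keepdims=True)
print(kept.shape)  # (2, 1)

# out must already have the result's shape and an integer index dtype (np.intp)
result = np.empty(2, dtype=np.intp)
np.argmax(array, axis=1, out=result)
print(result)  # [2 2]
```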

Function Return Value

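The function returns the indices of the maximum values. If axis is None, it returns a single integer index into the flattened array. Otherwise, it returns an array of indices with the same shape as the input array, except that the dimension along the given axis is removed (or kept with size one when keepdims=True). If the maximum value occurs more than once, the index of its first occurrence is returned.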
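A minimal sketch of this shape rule, using an arbitrary 3D array: with no axis the result is a zero-dimensional scalar index, and with an axis that axis is dropped from the result's shape.

```python
import numpy as np

a = np.arange(24).reshape(2, 3, 4)

# axis=None: a single index into the flattened array
flat_idx = np.argmax(a)
print(np.ndim(flat_idx))  # 0

# With an axis, that axis is removed from the result's shape
for ax in range(a.ndim):
    print(ax, np.argmax(a, axis=ax).shape)
# 0 (3, 4)
# 1 (2, 4)
# 2 (2, 3)
```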


Example 1

Let us illustrate how to use the numpy.argmax() function with an example.

Start by importing numpy as:

# import numpy
import numpy as np

Next, create a 2D array as shown:

# 2d array
array = np.arange(6).reshape((2,3))

The above should create a 2D array with the elements as shown:

[[0 1 2]
 [3 4 5]]

To use the argmax() function on the entire array (flattened), we can run the code as shown:

# index of the max element in the flattened array
print(f"Index of max value: {np.argmax(array, axis=None)}")

The code above returns the index of the maximum value in the flattened array:

Index of max value: 5
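Because argmax() with axis=None returns a position in the flattened array, it is often useful to convert that position back into row and column coordinates. One way to do this, sketched below with the same 2D array, is NumPy's np.unravel_index() function.

```python
import numpy as np

array = np.arange(6).reshape((2, 3))

# argmax with axis=None gives a position in the flattened array
flat_idx = np.argmax(array)

# np.unravel_index converts it back to (row, column) coordinates
row, col = np.unravel_index(flat_idx, array.shape)
print(row, col)         # 1 2
print(array[row, col])  # 5
```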

You can also use the argmax function along a specific axis as shown:

# index of the max element along axis 0
print(np.argmax(array, axis=0))

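NOTE: Since we are working with a 2D array, the array has two axes. Axis 0 runs vertically down the rows, and axis 1 runs horizontally across the columns. Along axis 0, argmax() compares the elements in each column and returns the row index of the largest one. Because the second row holds the largest value in every column, the code above should return:

[1 1 1]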
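To make the column-wise comparison concrete, the sketch below calls argmax() on each column separately and checks that it matches the axis-0 result. This is only an illustration of the behavior, not how NumPy computes it internally.

```python
import numpy as np

array = np.arange(6).reshape((2, 3))

# Along axis 0, argmax compares the rows within each column,
# so it matches calling argmax on every column separately
per_column = [int(np.argmax(array[:, j])) for j in range(array.shape[1])]
print(per_column)                # [1, 1, 1]
print(np.argmax(array, axis=0))  # [1 1 1]
```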

The same logic applies along axis 1, where argmax() compares the elements in each row and returns the column index of the largest one. For example:

# max element along axis 1
print(np.argmax(array, axis=1))

This should return:

[2 2]
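Once you have the indices, you can use them to fetch the maximum values themselves. A minimal sketch, assuming the same 2D array, uses np.take_along_axis() after restoring the reduced axis with np.expand_dims(); the result matches array.max(axis=1).

```python
import numpy as np

array = np.arange(6).reshape((2, 3))

# Column index of the largest element in each row
idx = np.argmax(array, axis=1)

# np.expand_dims restores the reduced axis so take_along_axis can use it
row_max = np.take_along_axis(array, np.expand_dims(idx, axis=1), axis=1)
print(row_max.ravel())    # [2 5]
print(array.max(axis=1))  # [2 5]
```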

Example 2

We can also use the argmax() function with N-dimensional arrays. For example, the code below shows how to use the function with a 3D array.

# 3d array
array = np.arange(24).reshape(2, 3, 4)

This generates the following array:

[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]

We can then get the argmax over the flattened array and along each axis:

# index of the max element in the flattened array
print(f"Flattened index of max: {np.argmax(array, axis=None)}")
# indices of the max elements along axis 0
print(f"Max indices along axis 0:\n{np.argmax(array, axis=0)}")
# indices of the max elements along axis 1
print(f"Max indices along axis 1:\n{np.argmax(array, axis=1)}")
# indices of the max elements along axis 2
print(f"Max indices along axis 2:\n{np.argmax(array, axis=2)}")

The code above should return:

Flattened index of max: 23
Max indices along axis 0:
[[1 1 1 1]
 [1 1 1 1]
 [1 1 1 1]]
Max indices along axis 1:
[[2 2 2 2]
 [2 2 2 2]]
Max indices along axis 2:
[[3 3 3]
 [3 3 3]]
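Because np.arange() produces steadily increasing values, every index in each result above is identical. As a sketch with hand-picked, purely illustrative values, the snippet below shows argmax() picking different positions across a 3D array, and uses np.unravel_index() to locate the overall maximum as an (i, j, k) coordinate.

```python
import numpy as np

# Hand-picked values so the results differ from position to position
a = np.array([[[3, 9, 1],
               [7, 2, 8]],
              [[5, 4, 6],
               [0, 11, 10]]])

print(np.argmax(a, axis=0))
# [[1 0 1]
#  [0 1 1]]
print(np.argmax(a, axis=2))
# [[1 2]
#  [2 1]]

# Locate the overall maximum as an (i, j, k) coordinate
coords = tuple(int(c) for c in np.unravel_index(np.argmax(a), a.shape))
print(coords, a[coords])  # (1, 1, 1) 11
```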


This article explored how to use the NumPy argmax() function to fetch the indices of the maximum values in an array, either over the flattened array or along a specific axis.

Thank you for reading, and stay tuned!

About the author

John Otieno

My name is John, and I am a fellow geek like you. I am passionate about all things computers, from hardware and operating systems to programming. My dream is to share my knowledge with the world and help out fellow geeks. Follow my content by subscribing to the LinuxHint mailing list.