Python numpy.argmax(): Beginners Reference


In this tutorial, we will learn about the numpy.argmax() function in Python. This function returns the indices of the maximum values in an array (a matrix, in the examples below). An optional axis argument specifies the axis along which to look for the maximum values.

The argmax function gives us three options:

  • Get the index of the maximum element in the entire matrix (default).
  • Get the index of the maximum element in each row.
  • Get the index of the maximum element in each column.
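As a quick preview, the three options map to three calls. The sketch below uses a small matrix whose values are made up for illustration, chosen so that the maxima sit in different places. Note that every call returns positions, not the maximum values themselves.

```python
import numpy as np

# Illustrative matrix: the maxima are deliberately not all in the same row/column
m = np.array([[3, 9, 1],
              [7, 2, 8]])

flat = np.argmax(m)          # index into the flattened array
cols = np.argmax(m, axis=0)  # row index of the max in each column
rows = np.argmax(m, axis=1)  # column index of the max in each row

print(flat)  # 1
print(cols)  # [1 0 1]
print(rows)  # [1 2]
```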

Let’s see how to use this function.

Finding the index of the maximum element in a matrix with Python numpy.argmax()

Let’s start by importing numpy and creating a sample matrix.

import numpy as np
# 4x3 matrix holding the integers 10 to 21 in row-major order
a = np.arange(12).reshape(4,3) + 10
print(a)

Output:

[[10 11 12]
 [13 14 15]
 [16 17 18]
 [19 20 21]]

Now, let’s find the index of the maximum element in the array.

print(np.argmax(a))

Output: 11

We get 11 as the output. When no axis is passed, numpy.argmax() treats the array as if it were flattened into one dimension and returns the position of the maximum value in that flattened array. The maximum here, 21, is the 12th element, so its flattened index is 11. It is also the last element only because this sample matrix is in increasing order; in general the result is the position of the maximum, not the last index or the number of elements in the array.
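To see that the result is the position of the maximum rather than the last index, here is a small sketch with a matrix whose largest value sits near the start (the values are made up for illustration).

```python
import numpy as np

# Illustrative matrix: the maximum (25) is not the last element
b = np.array([[10, 25, 12],
              [13, 14, 15]])

i = np.argmax(b)
print(i)                            # 1
print(b.ravel()[i])                 # 25

# The same index you get from the explicitly flattened array
print(np.argmax(b.ravel()) == i)    # True
```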

We can use the np.unravel_index function for getting an index corresponding to a 2D array from the numpy.argmax output.

Note: If the maximum value occurs more than once, the function returns the index of its first occurrence.
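A short sketch of that tie-breaking behavior on a 1-D array with made-up values. If you need every position of the maximum, comparing against the maximum value (for example with np.flatnonzero) is one option.

```python
import numpy as np

# Illustrative data: the maximum 7 appears three times
t = np.array([4, 7, 2, 7, 7])

first = np.argmax(t)
print(first)                      # 1 (first of the three 7s)

# All positions where the maximum occurs
all_pos = np.flatnonzero(t == t.max())
print(all_pos)                    # [1 3 4]
```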

Using np.unravel_index on argmax output

To convert the flat argmax output into a (row, column) index with np.unravel_index, run the following snippet:

index = np.unravel_index(np.argmax(a), a.shape)
print(index)
print(a[index])

This gives the following output:

(3, 2)
21
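For a 2-D array in NumPy's default row-major (C) order, np.unravel_index is doing simple arithmetic: the row is the flat index divided by the number of columns, and the column is the remainder. The sketch below checks that equivalence on the sample matrix; it assumes C order, which is unravel_index's default.

```python
import numpy as np

a = np.arange(12).reshape(4, 3) + 10
flat = np.argmax(a)  # 11

# Row-major layout: row = flat // n_cols, col = flat % n_cols
row, col = divmod(int(flat), a.shape[1])
print((row, col))                                       # (3, 2)
print(np.unravel_index(flat, a.shape) == (row, col))    # True
```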

We can combine the code from these two sections to directly print the maximum element.

Complete code to print the maximum element for the matrix

Here’s the complete code:

import numpy as np
a = np.arange(12).reshape(4,3) + 10
print(a)
index = np.unravel_index(np.argmax(a), a.shape)
print(index)
print(a[index])
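If you need this often, the same steps can be wrapped in a small function. The name max_with_position is hypothetical (it is not part of NumPy); it is a sketch that works for arrays of any shape and returns the position of the first maximum as plain Python ints together with the value.

```python
import numpy as np

def max_with_position(arr):
    """Hypothetical helper: return (index_tuple, value) of the first maximum in arr."""
    arr = np.asarray(arr)
    # Flat argmax -> N-dimensional index, converted to plain Python ints
    index = tuple(int(i) for i in np.unravel_index(np.argmax(arr), arr.shape))
    return index, arr[index]

a = np.arange(12).reshape(4, 3) + 10
idx, value = max_with_position(a)
print(idx, value)  # (3, 2) 21
```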

Finding Maximum Elements along columns using Python numpy.argmax()

To find the index of the maximum element in each column, pass axis=0:

import numpy as np
a = np.arange(12).reshape(4,3) + 10
print(np.argmax(a, axis=0))

Output:

[3 3 3]

This gives, for each column, the row index of its maximum element. All three are 3 because the largest values (19, 20 and 21) are all in the last row.
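To turn those row indices back into values, one approach is to pair each row index with its column number through integer array indexing. The sketch below does that and compares the result with a.max(axis=0), which returns the column maxima directly.

```python
import numpy as np

a = np.arange(12).reshape(4, 3) + 10
rows = np.argmax(a, axis=0)        # row index of the max in each column: [3 3 3]
cols = np.arange(a.shape[1])       # column numbers: [0 1 2]

col_max = a[rows, cols]            # picks a[3, 0], a[3, 1], a[3, 2]
print(col_max)                                    # [19 20 21]
print(np.array_equal(col_max, a.max(axis=0)))     # True
```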

Similarly, if we set the axis to 1, we get the indices of the maximum elements along the rows.

Finding Maximum Elements along rows

To find the index of the maximum element in each row, pass axis=1:

import numpy as np
a = np.arange(12).reshape(4,3) + 10
print(np.argmax(a, axis=1))

Output:

[2 2 2 2]
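In the sample matrix every row's maximum is in the last column, so the result is uniform. The sketch below uses a made-up matrix where the row maxima sit in different columns, and retrieves the values with np.take_along_axis, which expects the index array to have the same number of dimensions as the input (hence the added axis).

```python
import numpy as np

# Illustrative matrix with row maxima in different columns
m = np.array([[3, 9, 1],
              [7, 2, 8]])

idx = np.argmax(m, axis=1)                                # [1 2]
# Reshape indices to (2, 1) so they line up with m along axis 1
row_max = np.take_along_axis(m, idx[:, np.newaxis], axis=1)
print(idx)               # [1 2]
print(row_max.ravel())   # [9 8]
```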

Conclusion

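This tutorial was about the numpy.argmax() function in Python. We learned how to use it to find the indices of the maximum elements in the whole matrix and along its rows and columns, and how to convert a flat index into a (row, column) position with np.unravel_index.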
