In this article, we’ll take a look at using the PyTorch torch.max() function.
As you might expect, this is a simple function, but it has more to it than you might imagine.
Let's look at how to use this function with some simple examples.
NOTE: At the time of writing, the PyTorch version used is PyTorch 1.5.0
PyTorch torch.max() – Basic Syntax
To use PyTorch torch.max(), first import torch:
import torch
In its simplest form, this function returns the maximum among the elements of the Tensor.
Default Behavior of PyTorch torch.max()
By default, torch.max() returns a single element: the global maximum of the Tensor, as a 0-dimensional Tensor. No index is returned in this form.
max_element = torch.max(input_tensor)
Here is an example (torch.randn() draws random values, so your output will differ):
import torch

p = torch.randn([2, 3])
print(p)

# Global maximum over all elements
max_element = torch.max(p)
print(max_element)

Output:
tensor([[-0.0665,  2.7976,  0.9753],
        [ 0.0688, -1.0376,  1.4443]])
tensor(2.7976)
Indeed, this gives us the global maximum element in the Tensor!
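Since torch.max(input) returns only the value, you need a separate call if you also want its position. The following is a minimal sketch that uses torch.argmax, which returns the index of the maximum in the flattened Tensor, and converts that flat index into a (row, column) pair. The Tensor values are hard-coded from the printed output above (an assumption made only so the result is reproducible), and the helper names flat_idx, row and col are illustrative.

```python
import torch

# The values printed above, hard-coded so the result is reproducible
p = torch.tensor([[-0.0665, 2.7976, 0.9753],
                  [0.0688, -1.0376, 1.4443]])

max_element = torch.max(p)   # 0-dimensional tensor holding the global maximum
print(max_element.item())    # the same value as a plain Python float

# torch.max(p) does not return an index; torch.argmax gives the position
# of the maximum in the flattened tensor
flat_idx = torch.argmax(p).item()
row, col = divmod(flat_idx, p.size(1))  # convert the flat index to (row, column)
print(flat_idx, (row, col))             # → 1 (0, 1)
```

The divmod step assumes a 2-D Tensor stored in row-major order, which is how argmax flattens it.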
Use torch.max() along a dimension
However, you may wish to get the maximum along a particular dimension, as a Tensor, instead of a single element.
To specify the dimension (the axis, in NumPy terms), pass the optional argument dim.
This is the dimension along which the maximum is taken.
This returns a tuple (in fact, a namedtuple with the fields values and indices):
max_elements -> All the maximum elements of the Tensor along dim.
max_indices -> Indices corresponding to the maximum elements.
max_elements, max_indices = torch.max(input_tensor, dim)
Here, max_elements is a Tensor holding the maximum elements along dimension dim, and max_indices holds the position of each maximum along that same dimension.
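To make the return value concrete, here is a short sketch that accesses the namedtuple fields by name and also shows the optional keepdim argument, which keeps the reduced dimension with size 1 instead of dropping it. As before, the Tensor values are hard-coded from this article's printed output purely for reproducibility; result and kept are illustrative names.

```python
import torch

p = torch.tensor([[-0.0665, 2.7976, 0.9753],
                  [0.0688, -1.0376, 1.4443]])

# torch.max(input, dim) returns a namedtuple with fields .values and .indices
result = torch.max(p, dim=0)
print(result.values)   # maximum of each column
print(result.indices)  # row index of each column maximum

# Tuple unpacking works too, since a namedtuple is a tuple
max_elements, max_indices = result

# keepdim=True keeps the reduced dimension (size 1) in the output shape
kept = torch.max(p, dim=0, keepdim=True)
print(result.values.shape)  # torch.Size([3])
print(kept.values.shape)    # torch.Size([1, 3])
```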
Let’s now look at some examples.
p = torch.randn([2, 3])
print(p)

# Get the maximum along dim = 0 (axis = 0)
max_elements, max_idxs = torch.max(p, dim=0)
print(max_elements)
print(max_idxs)

Output:
tensor([[-0.0665,  2.7976,  0.9753],
        [ 0.0688, -1.0376,  1.4443]])
tensor([0.0688, 2.7976, 1.4443])
tensor([1, 0, 1])
As you can see, we find the maximum along dimension 0, that is, the maximum of each column.
We also get the indices corresponding to these elements. For example, 0.0688 has index 1 along column 0, because it sits in row 1 of that column.
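One way to check what the indices mean is to use them to look the values back up. In this sketch (Tensor values hard-coded from the output above; cols and recovered are illustrative names), max_idxs[j] is the row where column j reaches its maximum, so indexing p with those (row, column) pairs should reproduce max_elements.

```python
import torch

p = torch.tensor([[-0.0665, 2.7976, 0.9753],
                  [0.0688, -1.0376, 1.4443]])
max_elements, max_idxs = torch.max(p, dim=0)

# Pair each column index j with the row index max_idxs[j]
cols = torch.arange(p.size(1))
recovered = p[max_idxs, cols]  # advanced indexing picks p[max_idxs[j], j]
print(recovered)
print(torch.equal(recovered, max_elements))  # True
```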
Similarly, if you want to find the maximum along the rows, use dim=1:
# Get the maximum along dim = 1 (axis = 1)
max_elements, max_idxs = torch.max(p, dim=1)
print(max_elements)
print(max_idxs)

Output:
tensor([2.7976, 1.4443])
tensor([1, 2])
Indeed, we get the maximum element of each row, along with its index within that row.
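Two related details are worth a quick sketch, again using values hard-coded from the output above. First, negative dimensions count from the end, so for a 2-D Tensor dim=-1 is the same as dim=1. Second, with keepdim=True the row maxima have shape (2, 1), which broadcasts against p; here that is used (as an illustrative example, not something from the original article) to build a boolean mask marking where each row's maximum sits.

```python
import torch

p = torch.tensor([[-0.0665, 2.7976, 0.9753],
                  [0.0688, -1.0376, 1.4443]])

# dim=1 reduces across each row; for a 2-D tensor, dim=-1 means the same thing
row_max, row_idx = torch.max(p, dim=1)
row_max_neg, row_idx_neg = torch.max(p, dim=-1)
print(torch.equal(row_max, row_max_neg), torch.equal(row_idx, row_idx_neg))  # True True

# keepdim=True gives shape (2, 1), which broadcasts against p's shape (2, 3)
row_max_kept, _ = torch.max(p, dim=1, keepdim=True)
is_row_max = p == row_max_kept  # True where an element equals its row maximum
print(is_row_max)
```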
Using torch.max() for comparison
We can also use torch.max() to get the element-wise maximum of two Tensors.
output_tensor = torch.max(a, b)
Here, a and b must have the same dimensions, or must be "broadcastable" Tensors.
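To illustrate the broadcasting case, here is a small sketch with the same hard-coded values as before. The second Tensor (named floor here purely for illustration) has shape (3,), so it is broadcast across both rows of p. A 0-dimensional Tensor broadcasts to every element, which in this example replaces negative entries with 0. Both operands are float32, which avoids any dtype mismatch.

```python
import torch

p = torch.tensor([[-0.0665, 2.7976, 0.9753],
                  [0.0688, -1.0376, 1.4443]])

# Shape (3,) broadcasts against shape (2, 3): one value per column
floor = torch.tensor([0.0, 0.0, 1.0])
clipped_cols = torch.max(p, floor)
print(clipped_cols)

# A 0-dimensional tensor broadcasts to every element, clipping negatives to 0
relu_like = torch.max(p, torch.tensor(0.0))
print(relu_like)
```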
Here is a simple example to compare two Tensors having the same dimensions.
p = torch.randn([2, 3])
q = torch.randn([2, 3])
print("p =", p)
print("q =", q)

# Compare elements of p and q and get the maximum
max_elements = torch.max(p, q)
print(max_elements)

Output:
p = tensor([[-0.0665,  2.7976,  0.9753],
        [ 0.0688, -1.0376,  1.4443]])
q = tensor([[-0.0678,  0.2042,  0.8254],
        [-0.1530,  0.0581, -0.3694]])
tensor([[-0.0665,  2.7976,  0.9753],
        [ 0.0688,  0.0581,  1.4443]])
Indeed, the output Tensor holds, at each position, the larger of the corresponding elements of p and q.
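A useful way to think about the two-Tensor form is as a position-by-position choice. This sketch (values hard-coded from the output above; manual is an illustrative name) checks that torch.max(p, q) matches torch.where(p > q, p, q). The equivalence is only claimed for inputs without NaN values, as here, since NaN comparisons behave differently.

```python
import torch

p = torch.tensor([[-0.0665, 2.7976, 0.9753],
                  [0.0688, -1.0376, 1.4443]])
q = torch.tensor([[-0.0678, 0.2042, 0.8254],
                  [-0.1530, 0.0581, -0.3694]])

elementwise_max = torch.max(p, q)

# At each position, take the entry from p where it is larger, otherwise from q
manual = torch.where(p > q, p, q)
print(torch.equal(elementwise_max, manual))  # True
```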
In this article, we learned how to use the torch.max() function to find the maximum element of a Tensor, both globally and along a given dimension.
We also used this function to compare two Tensors and get their element-wise maximum.