반응형
소개
torch.argmax는 PyTorch에서 텐서의 모든 요소 중에서 최댓값을 가지는 인덱스를 반환하는 함수입니다. 이 함수는 텐서 전체에서 최댓값의 인덱스를 찾거나, 특정 차원(axis)에서 최댓값의 인덱스를 계산할 수 있습니다. torch.argmax는 데이터 분석과 딥러닝 모델에서 가장 큰 값의 위치를 추적하거나 확인하는 데 매우 유용합니다. 특히, 분류 문제에서 모델의 예측값 중 가장 높은 확률을 가진 클래스를 찾을 때 자주 사용됩니다.
기본 사용법
상세 설명
- 전체 요소 최댓값 인덱스
- torch.argmax(tensor)는 텐서의 모든 요소를 비교하여 최댓값을 가지는 요소의 인덱스를 반환합니다.
- 이 기능은 데이터에서 가장 큰 값이 어디에 위치하는지를 확인할 때 유용합니다.
- 차원별 최댓값 인덱스
- torch.argmax(tensor, dim=0)와 같이 dim 인수를 사용하면 특정 차원(axis)에서 최댓값을 가지는 요소의 인덱스를 계산할 수 있습니다.
- 이 기능은 텐서의 각 차원에서 최댓값의 위치를 추적하거나 데이터를 축소할 때 유용합니다.
예시 설명
- 첫 번째 예시에서 torch.argmax(tensor)는 1D 텐서 [3.0, 1.0, 4.0, 1.5, 2.0]의 모든 요소 중에서 최댓값 4.0의 인덱스 2를 반환합니다.
- 두 번째 예시에서는 2D 텐서의 전체 요소 중에서 최댓값 5.0의 인덱스 4이며, dim=0과 dim=1을 사용하여 각각 열과 행의 최댓값 인덱스를 구할 수 있습니다.
import torch
# 1D 텐서 생성
tensor = torch.tensor([3.0, 1.0, 4.0, 1.5, 2.0])
# 모든 요소 중에서 최댓값의 인덱스 계산
argmax_result = torch.argmax(tensor)
print(argmax_result)
# 출력: tensor(2)
# 2D 텐서 생성
tensor = torch.tensor([[3.0, 2.0, 1.0], [4.0, 5.0, 0.0]])
# 모든 요소 중에서 최댓값의 인덱스 계산
overall_argmax = torch.argmax(tensor)
print(overall_argmax)
# 출력: tensor(4)
# 특정 차원에서의 최댓값 인덱스 계산 (dim=0: 각 열의 최댓값 인덱스)
column_argmax = torch.argmax(tensor, dim=0)
print(column_argmax)
# 출력: tensor([1, 1, 0])
# 특정 차원에서의 최댓값 인덱스 계산 (dim=1: 각 행의 최댓값 인덱스)
row_argmax = torch.argmax(tensor, dim=1)
print(row_argmax)
# 출력: tensor([0, 1])
라이센스
PyTorch의 표준 라이브러리와 내장 함수들은 BSD-style license 하에 배포됩니다. 이 라이센스는 자유 소프트웨어 라이센스로, 상업적 사용을 포함한 거의 모든 용도로 사용이 가능합니다. 라이센스와 저작권 정보는 PyTorch의 공식 GitHub 리포지토리에서 확인할 수 있습니다.
관련 내용
반응형
'함수 설명 > 인공지능 (Pytorch)' 카테고리의 다른 글
[PyTorch] 텐서 요소의 최댓값 계산: torch.max() 설명 (0) | 2024.08.25 |
---|---|
[PyTorch] 텐서의 최소값 인덱스 찾기: torch.argmin() 설명 (0) | 2024.08.25 |
[PyTorch] 텐서의 지수 계산: torch.exp() 설명 (0) | 2024.08.25 |
[PyTorch] 텐서의 자연 로그 계산: torch.log() 설명 (0) | 2024.08.25 |
[PyTorch] 배치(batch) 단위의 행렬 곱셈: torch.bmm() 설명 (0) | 2024.08.24 |