본문 바로가기
오류 해결

[Pytorch] RuntimeError: Expected object of scalar type Long but found Float

by First Adventure 2024. 9. 21.
반응형

에러 메시지 설명

  이 오류는 PyTorch에서 데이터의 자료형이 예상과 일치하지 않을 때 발생합니다. 특정 연산, 특히 손실 함수에서 입력 데이터 또는 타겟 레이블의 자료형이 Float일 때, Long 형식을 기대하면 이 오류가 발생합니다.

 

발생 원인

  • 손실 함수가 Long 형식을 요구: nn.CrossEntropyLoss와 같은 손실 함수는 타겟 레이블이 Long 형식이어야 합니다. 그러나 Float 형식의 레이블이 전달되면 이 오류가 발생할 수 있습니다.
  • 데이터 타입 불일치: 모델의 출력은 보통 Float 형식이지만, 타겟 레이블은 Long 형식이어야 합니다. 만약 모델 입력이나 레이블의 형식이 일치하지 않으면 오류가 발생합니다​.

 

해결 방법

  • 타겟 레이블을 Long 타입으로 변환: 타겟 레이블이 Float 형식인 경우, Long 형식으로 변환해야 합니다. 이를 위해 long() 메서드를 사용하여 자료형을 변환할 수 있습니다.
target = target.long()  # 레이블을 Long 형식으로 변환

 

  • 손실 함수에 맞는 입력 형식 확인: 손실 함수가 요구하는 입력 형식에 맞게 데이터를 변환하세요. 예를 들어, nn.CrossEntropyLoss는 모델의 출력이 Float 형식이어야 하고, 타겟 레이블이 Long 형식이어야 합니다.
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(output, target.long())  # 타겟을 Long으로 변환

 

  • 입력 및 타겟 자료형 확인: 코드 실행 전, 모델의 출력과 타겟 레이블의 자료형이 손실 함수와 일치하는지 확인합니다. 자료형을 출력하여 문제가 있는 부분을 찾을 수 있습니다.
print(output.dtype, target.dtype)  # 자료형 확인

 

관련 내용 및 추가 팁

  • 이 오류는 주로 다중 클래스 분류 작업에서 발생하며, 모델 출력과 타겟 레이블 간의 자료형 불일치가 원인입니다. 일반적으로 PyTorch에서 손실 함수는 Float 형식의 예측 값과 Long 형식의 타겟 값을 요구하므로, 이를 명확히 구분하여 처리하는 것이 중요합니다​.
  • 데이터 전처리 단계에서 항상 레이블의 자료형을 확인하고, 필요한 경우 long() 메서드를 사용하여 형식을 맞추세요.
  • 손실 함수의 요구 사항에 맞게 모델 출력과 레이블 형식을 유지하는 것이 중요합니다.
반응형