본문 바로가기
함수 설명/인공지능 (Pytorch)

[PyTorch] 효율적인 데이터 배치: torch.utils.data.DataLoader() 사용 가이드

by First Adventure 2024. 8. 16.
반응형

소개

  torch.utils.data.DataLoader는 PyTorch에서 데이터를 효율적으로 로드하고 배치(batch)로 처리하기 위한 도구입니다. 이 클래스는 데이터셋을 미니배치(mini-batch)로 나누고, 학습 시 데이터가 잘 섞이도록 셔플(shuffle)하며, 병렬로 데이터를 로드할 수 있게 해줍니다. 딥러닝 모델을 학습할 때 데이터셋을 손쉽게 다룰 수 있도록 돕습니다.

 

기본 사용법

상세 설명

  • DataLoader의 주요 파라미터
    • dataset: 로드할 데이터셋을 지정합니다. TensorDataset, Dataset 등을 사용 가능합니다.
    • batch_size: 배치당 로드할 샘플의 수를 지정합니다.
    • shuffle: True로 설정하면, 각 에포크(epoch)마다 데이터가 무작위로 섞입니다.
    • num_workers: 데이터를 로드할 때 사용할 서브 프로세스의 수를 지정합니다. 병렬 처리가 가능하여 데이터 로드 속도를 높일 수 있습니다.
    • drop_last: True로 설정하면, 데이터셋의 마지막 배치가 전체 배치 크기보다 작을 때 해당 배치를 버립니다.
  • DataLoader의 역할
    • 데이터셋을 쉽게 순회할 수 있도록 배치로 나누어 로드합니다.
    • 데이터 로딩 중 데이터가 섞이도록 셔플링하여 모델 학습의 일반화 성능을 향상시킵니다.
    • 병렬 데이터 로드를 통해 CPU를 활용하여 GPU의 대기 시간을 줄이고 학습 속도를 최적화할 수 있습니다.

예시 설명

  • DataLoader(dataset, batch_size=2, shuffle=True)는 데이터셋을 배치 크기 2로 나누고, 데이터가 무작위로 섞인 상태에서 로드합니다.
  • num_workers=4는 4개의 서브 프로세스를 사용하여 병렬로 데이터를 로드합니다.
  • drop_last=True로 설정하면, 데이터셋의 마지막 배치 크기가 지정된 배치 크기보다 작을 경우 해당 배치를 제외합니다.
import torch
from torch.utils.data import TensorDataset, DataLoader

# 예제 데이터셋 생성
data = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
labels = torch.tensor([0, 1, 0, 1])

# TensorDataset을 사용하여 데이터셋 생성
dataset = TensorDataset(data, labels)

# DataLoader 생성
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# 데이터 로딩 예시
for batch_data, batch_labels in dataloader:
    print(batch_data, batch_labels)

 

import torch
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# 예제 데이터셋
data = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
labels = torch.tensor([0, 1, 0, 1])

# 커스텀 데이터셋 인스턴스 생성
dataset = CustomDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)

# 데이터셋 순회
for batch_data, batch_labels in dataloader:
    # 훈련 코드 작성
    pass

 

라이센스

  PyTorch의 표준 라이브러리와 내장 함수들은 BSD-style license 하에 배포됩니다. 이 라이센스는 자유 소프트웨어 라이센스로, 상업적 사용을 포함한 거의 모든 용도로 사용이 가능합니다. 라이센스와 저작권 정보는 PyTorch의 공식 GitHub 리포지토리에서 확인할 수 있습니다.

 

관련 내용

  [PyTorch] 맞춤형 데이터셋 만들기: torch.utils.data.Dataset() 사용 가이드

  [PyTorch] 효율적인 데이터 배치: torch.utils.data.DataLoader() 사용 가이드

 

반응형