상세 컨텐츠

본문 제목

[PyTorch] CIFAR-10 데이터셋 써보기

PROGRAMMING/AI

by koharin 2023. 8. 7. 11:17

본문

728x90
반응형

학습 위한 장치 얻기

  • GPU 또는 MPS 같은 하드웨어 가속기에서 모델 학습
  • `torch.cuda` 또는 `torch.backends.mps` 사용 가능한지 확인하고, 그렇지 않으면 CPU 사용
if torch.cuda.is_available():
  DEVICE = torch.device('cuda')
elif torch.backends.mps.is_available():
  DEVICE = torch.device('mps')
else:
  DEVICE = torch.device('cpu')

print('PyTorch version: ', torch.__version__, ', Device: ', DEVICE)
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

batch_size = 4
  • torchvision 데이터셋의 출력(output)은 [0, 1] 범위를 갖는 PILImage 이미지로, 이를 [-1, 1]의 범위로 정규화된 Tensor로 변환
training_data = datasets.CIFAR10(
    root="Dataset/CIFAR-10",
    train=True,
    download=True,
    transform=transform
)

test_data = datasets.CIFAR10(
    root="Dataset/CIFAR-10",
    train=False,
    download=True,
    transform=transform
)
```

```python
train_loader = torch.utils.data.DataLoader(
    training_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2
)

test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2
)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
def imshow(img, title):
  img = img / 2 + 0.5 # unnormalize
  npimg = img.numpy()
  plt.imshow(np.transpose(npimg, (1,2,0)))
  plt.title(title)
  plt.show()

data_iter = iter(train_loader)
images, labels = next(data_iter)

imshow(torchvision.utils.make_grid(images), title=[classes[labels[x]] for x in range(batch_size)])
#print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))
  • `plt.axis("off")` : 그림 보여줄 때 x축, y축에 수치 보여주는데 이거 안 보여주게 하는거
  • `torchvision.utils.make_grid` : make grid of image

728x90
반응형

관련글 더보기