학습 위한 장치 얻기
- GPU 또는 MPS 같은 하드웨어 가속기에서 모델 학습
- `torch.cuda` 또는 `torch.backends.mps` 사용 가능한지 확인하고, 그렇지 않으면 CPU 사용
if torch.cuda.is_available():
DEVICE = torch.device('cuda')
elif torch.backends.mps.is_available():
DEVICE = torch.device('mps')
else:
DEVICE = torch.device('cpu')
print('PyTorch version: ', torch.__version__, ', Device: ', DEVICE)
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
batch_size = 4
- torchvision 데이터셋의 출력(output)은 [0, 1] 범위를 갖는 PILImage 이미지로, 이를 [-1, 1]의 범위로 정규화된 Tensor로 변환
training_data = datasets.CIFAR10(
root="Dataset/CIFAR-10",
train=True,
download=True,
transform=transform
)
test_data = datasets.CIFAR10(
root="Dataset/CIFAR-10",
train=False,
download=True,
transform=transform
)
```
```python
train_loader = torch.utils.data.DataLoader(
training_data,
batch_size=batch_size,
shuffle=True,
num_workers=2
)
test_loader = torch.utils.data.DataLoader(
test_data,
batch_size=batch_size,
shuffle=False,
num_workers=2
)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
def imshow(img, title):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1,2,0)))
plt.title(title)
plt.show()
data_iter = iter(train_loader)
images, labels = next(data_iter)
imshow(torchvision.utils.make_grid(images), title=[classes[labels[x]] for x in range(batch_size)])
#print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))
- `plt.axis("off")` : 그림 보여줄 때 x축, y축에 수치 보여주는데 이거 안 보여주게 하는거
- `torchvision.utils.make_grid` : make grid of image