These are some quick revision notes on PyTorch.

A simple PyTorch model could look like this:

import torch
from torch import nn

class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
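
As a quick sanity check, the model can be run on a random batch to confirm the output shape and parameter count. This is a minimal sketch; the sizes (8, 16, 3) and the batch of 4 are illustrative assumptions, and the class is repeated only so the snippet runs on its own.

```python
import torch
from torch import nn


class SimpleNN(nn.Module):
    """Same two-layer network as above, repeated so this snippet is self-contained."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)


# Illustrative sizes (assumptions, not from the notes).
model = SimpleNN(input_size=8, hidden_size=16, output_size=3)

batch = torch.randn(4, 8)   # 4 samples with 8 features each
logits = model(batch)       # calling the module runs forward()
num_params = sum(p.numel() for p in model.parameters())

print(logits.shape)  # torch.Size([4, 3])
print(num_params)    # (8*16 + 16) + (16*3 + 3) = 195
```

Calling `model(batch)` rather than `model.forward(batch)` is the usual convention, since it also runs any registered hooks.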

Chaining modules together

# Using nn.Sequential
model = nn.Sequential(
    nn.Linear(input_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, output_size)
)
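
A small sketch of how an nn.Sequential model can be inspected: submodules are indexed by position, and parameter names in the state dict are prefixed with that position. The sizes below are illustrative assumptions.

```python
import torch
from torch import nn

# Illustrative sizes (assumptions, not from the notes).
input_size, hidden_size, output_size = 8, 16, 3

seq_model = nn.Sequential(
    nn.Linear(input_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, output_size),
)

first_layer = seq_model[0]                         # submodules are indexable
param_names = list(seq_model.state_dict().keys())  # ReLU contributes no parameters
out = seq_model(torch.randn(4, input_size))

print(param_names)  # ['0.weight', '0.bias', '2.weight', '2.bias']
print(out.shape)    # torch.Size([4, 3])
```

The positional names are one reason a named nn.Module subclass can be easier to work with when checkpoints need to stay stable across refactors.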

Model and optimizer

from torch.optim import AdamW

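# input_size, hidden_size, output_size and learning_rate are placeholders to set for your task
model = SimpleNN(input_size, hidden_size, output_size)
optimizer = AdamW(model.parameters(), lr=learning_rate)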
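
To see what the optimizer actually does, here is a minimal sketch of a single update step on a tiny linear model. The model, the learning rate of 1e-3, and the squared-output loss are assumptions chosen only to keep the example short.

```python
import torch
from torch import nn
from torch.optim import AdamW

torch.manual_seed(0)

# Tiny stand-in model and illustrative hyperparameters (assumptions).
model = nn.Linear(4, 2)
optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

before = model.weight.detach().clone()

loss = model(torch.randn(8, 4)).pow(2).mean()  # dummy loss for illustration
loss.backward()        # populate .grad on every parameter
optimizer.step()       # apply the AdamW update
optimizer.zero_grad()  # clear gradients for the next step

after = model.weight.detach().clone()

lr_in_use = optimizer.param_groups[0]["lr"]
weights_changed = not torch.equal(before, after)
print(lr_in_use, weights_changed)
```

The hyperparameters live in `optimizer.param_groups`, which is also where a learning-rate scheduler writes its updates.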

Data

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True,
                                  transform=transforms.ToTensor())
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
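
To check what a DataLoader yields without downloading anything, a synthetic dataset with CIFAR-10-like shapes can stand in. This is only a sketch: the random tensors and the dataset size of 100 are assumptions.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for CIFAR-10 (assumption): 100 RGB images of 32x32, 10 classes.
images = torch.rand(100, 3, 32, 32)
labels = torch.randint(0, 10, (100,))

dataset = TensorDataset(images, labels)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

batch_images, batch_labels = next(iter(loader))
num_batches = len(loader)  # ceil(100 / 32) = 4, since drop_last defaults to False

print(batch_images.shape)  # torch.Size([32, 3, 32, 32])
print(batch_labels.shape)  # torch.Size([32])
print(num_batches)         # 4
```

The last batch holds only 4 samples here; passing `drop_last=True` would skip it.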

Training loop

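from torch.nn.functional import cross_entropy

model.train()
for inputs, labels in train_loader:
    # SimpleNN expects flat vectors, so flatten each CIFAR-10 image (3x32x32 -> 3072)
    predictions = model(inputs.flatten(start_dim=1))
    loss = cross_entropy(predictions, labels)
    loss.backward()        # compute gradients
    optimizer.step()       # update parameters
    optimizer.zero_grad()  # clear gradients before the next batch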

What to look out for in training

  • Loss
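  • Grad norm: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm) clips gradients in place and returns the total norm measured before clipping, so call it between loss.backward() and optimizer.step()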
  • Validation metrics
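
The three signals above can be tracked together in one loop. The sketch below assumes synthetic data (20 features, 3 classes), a small nn.Sequential model, and max_norm=1.0; `make_loader` is a hypothetical helper for this example, not a PyTorch function.

```python
import torch
from torch import nn
from torch.nn.functional import cross_entropy
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)


def make_loader(n, shuffle):
    """Hypothetical helper: random features and labels wrapped in a DataLoader."""
    x = torch.randn(n, 20)
    y = torch.randint(0, 3, (n,))
    return DataLoader(TensorDataset(x, y), batch_size=32, shuffle=shuffle)


train_loader = make_loader(256, shuffle=True)
val_loader = make_loader(64, shuffle=False)

model = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 3))
optimizer = AdamW(model.parameters(), lr=1e-3)
max_norm = 1.0  # illustrative clipping threshold

history = []
for epoch in range(2):
    model.train()
    for inputs, labels in train_loader:
        loss = cross_entropy(model(inputs), labels)
        loss.backward()
        # Clips in place and returns the total gradient norm before clipping.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
        optimizer.step()
        optimizer.zero_grad()

    # Validation: no dropout/batch-norm updates and no gradient tracking.
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            correct += (model(inputs).argmax(dim=1) == labels).sum().item()
            total += labels.numel()

    history.append({
        "epoch": epoch,
        "loss": loss.item(),            # loss of the last training batch
        "grad_norm": grad_norm.item(),  # norm of the last training batch
        "val_acc": correct / total,
    })
    print(history[-1])
```

With random labels the validation accuracy should hover around chance; the point is only the bookkeeping, not the numbers.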

Tools to use in PyTorch

Profiler - https://github.com/tonghuikang/optimizer-memory-profiles
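
For a quick look at where time and memory go, PyTorch's built-in torch.profiler can be used. The sketch below is not taken from the linked repository; the model and input sizes are assumptions, and it profiles CPU only so it runs anywhere.

```python
import torch
from torch import nn
from torch.profiler import ProfilerActivity, profile

# Small stand-in model and batch (assumptions).
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))
inputs = torch.randn(32, 64)

# Record CPU operators and their memory use for one forward/backward pass.
with profile(activities=[ProfilerActivity.CPU], profile_memory=True) as prof:
    loss = model(inputs).sum()
    loss.backward()

# Aggregate events by operator name and show the five most expensive ones.
summary = prof.key_averages().table(sort_by="cpu_time_total", row_limit=5)
print(summary)
```

On a GPU machine, adding `ProfilerActivity.CUDA` to `activities` would include kernel timings as well.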

Metrics reporting - Weights and Biases, though I prefer reporting locally
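
One simple way to report locally is to append one JSON record per step to a file. This is only a sketch of that idea: `log_metrics` is a hypothetical helper, the file is written to a temporary directory, and the metric values are placeholders.

```python
import json
import tempfile
from pathlib import Path


def log_metrics(path, step, **metrics):
    """Hypothetical helper: append one JSON line per call."""
    record = {"step": step, **metrics}
    with open(path, "a") as f:
        f.write(json.dumps(record) + "\n")


log_path = Path(tempfile.mkdtemp()) / "metrics.jsonl"

# Placeholder values, standing in for numbers produced by a training loop.
log_metrics(log_path, step=0, loss=2.31, grad_norm=1.7)
log_metrics(log_path, step=1, loss=2.05, grad_norm=1.2)

# Reading the log back is a single list comprehension.
records = [json.loads(line) for line in log_path.read_text().splitlines()]
print(records[-1]["loss"])  # 2.05
```

A JSON Lines file is append-only, survives crashes mid-run, and loads directly into tools such as pandas for plotting.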