5,323
edits
Line 12: | Line 12: | ||
* [https://pytorch.org/tutorials/ PyTorch Tutorials] | * [https://pytorch.org/tutorials/ PyTorch Tutorials] | ||
{{hidden | Example | | |||
<syntaxhighlight lang="python"> | <syntaxhighlight lang="python"> | ||
import torch | import torch | ||
import torch.nn as nn | import torch.nn as nn | ||
model = nn.Sequential(nn.Linear(5, 5),nn.ReLU(),nn.Linear(5, 1)) | |||
criterion = nn.MSELoss() | |||
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) | |||
# Training | # Training | ||
for epoch in range(epochs): | for epoch in range(epochs): | ||
for i, data in enumerate(trainloader): | |||
for i, data in enumerate(trainloader | # get the inputs; data is, e.g., a list of [inputs, labels] | ||
# get the inputs; data is a list of [inputs, labels] | |||
inputs, labels = data | inputs, labels = data | ||
Line 26: | Line 30: | ||
optimizer.zero_grad() | optimizer.zero_grad() | ||
# forward | # forward | ||
outputs = | outputs = model(inputs) | ||
loss = criterion(outputs, labels) | loss = criterion(outputs, labels) | ||
# backward | |||
loss.backward() | loss.backward() | ||
optimizer.step() | optimizer.step() | ||
</syntaxhighlight> | </syntaxhighlight> | ||
}} | |||
==Importing Data== | ==Importing Data== |