132 lines
3.2 KiB
Lua
132 lines
3.2 KiB
Lua
local helpers = require('personal.luasnip-helper-funcs')
|
|
local get_visual = helpers.get_visual
|
|
|
|
local line_begin = require("luasnip.extras.expand_conditions").line_begin
|
|
|
|
return
|
|
{
|
|
-- COMMON IMPORTS
|
|
s({trig="itorch"},
|
|
fmt(
|
|
[[
|
|
import torch
|
|
from torch import nn
|
|
from torch.utils.data import Dataset, DataLoader
|
|
]],
|
|
{
|
|
}
|
|
),
|
|
{condition = line_begin}
|
|
),
|
|
-- NETWORK MODEL TEMPLATE
|
|
s({trig="model"},
|
|
fmta(
|
|
[[
|
|
class FooNet(nn.Module):
|
|
def __init__(self):
|
|
super(FooNet, self).__init__()
|
|
<>
|
|
|
|
def forward(self, x):
|
|
<>
|
|
]],
|
|
{
|
|
i(1),
|
|
i(2)
|
|
}
|
|
),
|
|
{condition = line_begin}
|
|
),
|
|
-- CUSTOM DATASET TEMPLATE
|
|
s({trig="dataset"},
|
|
fmta(
|
|
[[
|
|
class FooDataset(Dataset):
|
|
def __init__(self, ...):
|
|
<>
|
|
|
|
def __getitem__(self, index):
|
|
# Returns the (feature vector, label) tuple at index `index`
|
|
<>
|
|
|
|
def __len__(self):
|
|
# Return number of instances in dataset
|
|
<>
|
|
]],
|
|
{
|
|
i(1),
|
|
i(2),
|
|
i(3)
|
|
}
|
|
),
|
|
{condition = line_begin}
|
|
),
|
|
-- SGD OPTIMIZER
|
|
s({trig="optim"},
|
|
fmta(
|
|
[[
|
|
optim = torch.optim.SGD(model.parameters(), lr=<>)
|
|
]],
|
|
{
|
|
i(1),
|
|
}
|
|
),
|
|
{condition = line_begin}
|
|
),
|
|
-- TRAINING LOOP TEMPLATE
|
|
s({trig="train"},
|
|
fmta(
|
|
[[
|
|
def train_loop(dataloader, model, loss_fn, optim):
|
|
N = len(dataloader.dataset)
|
|
|
|
# Loop over all minibatches in dataset
|
|
for mb, (X, y) in enumerate(dataloader):
|
|
# Compute prediction and loss
|
|
pred = model(X)
|
|
loss = loss_fn(pred, y)
|
|
|
|
# Backpropagation
|
|
optimizer.zero_grad()
|
|
loss.backward()
|
|
optimizer.step()
|
|
|
|
# Log loss and number of instances trained
|
|
if mb % <> == 0:
|
|
loss, n = loss.item(), mb * len(X)
|
|
print("loss: {:.7f} [{:5d}/{:5d}]".format(loss, n, N))
|
|
|
|
]],
|
|
{
|
|
i(1, "100"),
|
|
}
|
|
),
|
|
{condition = line_begin}
|
|
),
|
|
-- TEST LOOP TEMPLATE
|
|
s({trig="test"},
|
|
fmta(
|
|
[[
|
|
def test_loop(dataloader, model, loss_fn):
|
|
N = len(dataloader.dataset)
|
|
num_batches = len(dataloader)
|
|
test_loss = 0
|
|
correct_preds = 0
|
|
|
|
with torch.no_grad():
|
|
for X, y in dataloader:
|
|
pred = model(X)
|
|
test_loss += loss_fn(pred, y).item()
|
|
correct_preds += (pred.argmax(1) == y).type(torch.float).sum().item()
|
|
|
|
test_loss /= num_batches
|
|
print("Test Error: \n Accuracy: {:.1f}%\n Avg loss per minibatch: {:8f} \n".format((100*correct_preds/N), test_loss))
|
|
]],
|
|
{ }
|
|
),
|
|
{condition = line_begin}
|
|
),
|
|
}
|
|
|
|
|