kickstart.nvim/LuaSnip/python/torch.lua

132 lines
3.2 KiB
Lua

local helpers = require('personal.luasnip-helper-funcs')
local get_visual = helpers.get_visual
local line_begin = require("luasnip.extras.expand_conditions").line_begin
return
{
-- COMMON IMPORTS
s({trig="itorch"},
fmt(
[[
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
]],
{
}
),
{condition = line_begin}
),
-- NETWORK MODEL TEMPLATE
s({trig="model"},
fmta(
[[
class FooNet(nn.Module):
def __init__(self):
super(FooNet, self).__init__()
<>
def forward(self, x):
<>
]],
{
i(1),
i(2)
}
),
{condition = line_begin}
),
-- CUSTOM DATASET TEMPLATE
s({trig="dataset"},
fmta(
[[
class FooDataset(Dataset):
def __init__(self, ...):
<>
def __getitem__(self, index):
# Returns the (feature vector, label) tuple at index `index`
<>
def __len__(self):
# Return number of instances in dataset
<>
]],
{
i(1),
i(2),
i(3)
}
),
{condition = line_begin}
),
-- SGD OPTIMIZER
s({trig="optim"},
fmta(
[[
optim = torch.optim.SGD(model.parameters(), lr=<>)
]],
{
i(1),
}
),
{condition = line_begin}
),
-- TRAINING LOOP TEMPLATE
s({trig="train"},
fmta(
[[
def train_loop(dataloader, model, loss_fn, optim):
N = len(dataloader.dataset)
# Loop over all minibatches in dataset
for mb, (X, y) in enumerate(dataloader):
# Compute prediction and loss
pred = model(X)
loss = loss_fn(pred, y)
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Log loss and number of instances trained
if mb % <> == 0:
loss, n = loss.item(), mb * len(X)
print("loss: {:.7f} [{:5d}/{:5d}]".format(loss, n, N))
]],
{
i(1, "100"),
}
),
{condition = line_begin}
),
-- TEST LOOP TEMPLATE
s({trig="test"},
fmta(
[[
def test_loop(dataloader, model, loss_fn):
N = len(dataloader.dataset)
num_batches = len(dataloader)
test_loss = 0
correct_preds = 0
with torch.no_grad():
for X, y in dataloader:
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct_preds += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
print("Test Error: \n Accuracy: {:.1f}%\n Avg loss per minibatch: {:8f} \n".format((100*correct_preds/N), test_loss))
]],
{ }
),
{condition = line_begin}
),
}