132 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Lua
		
	
	
	
			
		
		
	
	
			132 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Lua
		
	
	
	
local helpers = require('personal.luasnip-helper-funcs')
local get_visual = helpers.get_visual

local line_begin = require("luasnip.extras.expand_conditions").line_begin

-- Snippets that expand to PyTorch (Python) boilerplate.
--
-- `s`, `fmt`, `fmta` and `i` are not defined in this file; they are expected
-- to be provided by LuaSnip's snippet loader environment (snip_env).
-- Every snippet uses the `line_begin` expand condition, so it only expands
-- when the trigger is at the beginning of the line.
-- Templates with Python `str.format` braces use `fmta`, whose placeholder
-- delimiter is `<>`, so the literal `{...}` fields pass through untouched.
return
  {
    -- COMMON IMPORTS
    s({trig="itorch"},
      fmt(
        [[
          import torch
          from torch import nn
          from torch.utils.data import Dataset, DataLoader
        ]],
        {
        }
      ),
      {condition = line_begin}
    ),
    -- NETWORK MODEL TEMPLATE
    s({trig="model"},
      fmta(
        [[
          class FooNet(nn.Module):
              def __init__(self):
                  super(FooNet, self).__init__()
                  <>

              def forward(self, x):
                  <>
        ]],
        {
          i(1),
          i(2)
        }
      ),
      {condition = line_begin}
    ),
    -- CUSTOM DATASET TEMPLATE
    -- NOTE: `...` in the generated __init__ signature is a literal placeholder
    -- for the caller to replace; as emitted it is not valid Python.
    s({trig="dataset"},
      fmta(
        [[
          class FooDataset(Dataset):
              def __init__(self, ...):
                  <>

              def __getitem__(self, index):
                  # Returns the (feature vector, label) tuple at index `index`
                  <>

              def __len__(self):
                  # Return number of instances in dataset
                  <>
        ]],
        {
          i(1),
          i(2),
          i(3)
        }
      ),
      {condition = line_begin}
    ),
    -- SGD OPTIMIZER
    -- Binds the optimizer to `optim`, the same name the `train` snippet's
    -- generated train_loop uses for its optimizer parameter.
    s({trig="optim"},
      fmta(
        [[
          optim = torch.optim.SGD(model.parameters(), lr=<>)
        ]],
        {
          i(1),
        }
      ),
      {condition = line_begin}
    ),
    -- TRAINING LOOP TEMPLATE
    -- The single insert node is the logging interval in minibatches
    -- (default "100").
    s({trig="train"},
      fmta(
        [[
          def train_loop(dataloader, model, loss_fn, optim):
              N = len(dataloader.dataset)
              # Training mode: enables dropout / batch-norm batch statistics
              model.train()

              # Loop over all minibatches in dataset
              for mb, (X, y) in enumerate(dataloader):
                  # Compute prediction and loss
                  pred = model(X)
                  loss = loss_fn(pred, y)

                  # Backpropagation
                  optim.zero_grad()
                  loss.backward()
                  optim.step()

                  # Log loss and number of instances trained
                  if mb % <> == 0:
                      loss, n = loss.item(), (mb + 1) * len(X)
                      print("loss: {:.7f}  [{:5d}/{:5d}]".format(loss, n, N))
        ]],
        {
          i(1, "100"),
        }
      ),
      {condition = line_begin}
    ),
    -- TEST LOOP TEMPLATE
    -- No insert nodes: the snippet expands to fixed text.
    s({trig="test"},
      fmta(
        [[
          def test_loop(dataloader, model, loss_fn):
              N = len(dataloader.dataset)
              num_batches = len(dataloader)
              test_loss = 0
              correct_preds = 0
              # Evaluation mode: disables dropout, uses running batch-norm stats
              model.eval()

              with torch.no_grad():
                  for X, y in dataloader:
                      pred = model(X)
                      test_loss += loss_fn(pred, y).item()
                      correct_preds += (pred.argmax(1) == y).type(torch.float).sum().item()

              test_loss /= num_batches
              print("Test Error: \n  Accuracy: {:.1f}%\n  Avg loss per minibatch: {:8f} \n".format((100*correct_preds/N), test_loss))
        ]],
        { }
      ),
      {condition = line_begin}
    ),
  }