Source code for slayerSNN.auto.assistant

import torch
from torch.utils.data import Dataset
from ..spikeClassifier import spikeClassifier as predict
import slayerCuda
from datetime import datetime

class Assistant:
    '''
    This class provides standard assistant functionalities for training and testing workflow.
    If you want a different workflow than what is available, you should inherit this module and
    overload the particular module to your need.

    Arguments:
        * ``net``: the SLAYER network to be run.
        * ``trainLoader``: training dataloader.
        * ``testLoader``: testing dataloader.
        * ``error``: a function object or a lambda function that takes (output, target, label) as its input and
          returns a scalar error value.
        * ``optimizer``: the learning optimizer.
        * ``scheduler``: the learning scheduler. Default: ``None`` meaning no scheduler will be used.
        * ``stats``: the SLAYER learning stats logger: ``slayerSNN.stats``. Default: ``None`` meaning no stats will be used.
        * ``dataParallel``: flag if dataParallel execution needs to be handled. Default: ``False``.
        * ``showTimeSteps``: flag to print timesteps of the sample or not. Default: ``False``.
        * ``lossScale``: a scale factor to be used while printing the loss. Default: ``None`` meaning no scaling is done.
        * ``printInterval``: number of epochs to print the learning output once. Default: 1.

    Usage:

    .. code-block:: python

        # ``stats`` must be passed by keyword: the sixth positional argument is ``scheduler``.
        assist = Assistant(net, trainLoader, testLoader, lambda o, t, l: error.numSpikes(o, t), optimizer, stats=stats)

        for epoch in range(maxEpoch):
            assist.train(epoch)
            assist.test(epoch)
    '''
    def __init__(self, net, trainLoader, testLoader, error, optimizer, scheduler=None, stats=None,
                 dataParallel=False, showTimeSteps=False, lossScale=None, printInterval=1):
        self.net = net
        # With dataParallel the network attributes (countLog, clamp) are read from the wrapped module.
        self.module = net.module if dataParallel is True else net
        self.error = error
        # Inputs and targets are moved to the device that holds the network's SRM kernel.
        self.device = net.slayer.srmKernel.device
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.stats = stats
        self.showTimeSteps = showTimeSteps
        self.lossScale = lossScale
        self.printInterval = printInterval

        self.trainLoader = trainLoader
        self.testLoader = testLoader

    def _runNetwork(self, spikeInput):
        '''
        Runs the forward pass of the network and returns ``(output, count)``.

        When the ``countLog`` attribute of the module is ``True``, the network is expected to return the pair
        itself. Otherwise only the output is returned by the network and ``count`` is ``0``.
        '''
        if self.module.countLog is True:
            output, count = self.net.forward(spikeInput)
            return output, count
        return self.net.forward(spikeInput), 0

    def _printProgress(self, epoch, iteration, samplesSeen, loader, spikeInput, count, startTime):
        '''
        Builds the progress header for the current batch and forwards it to ``self.stats.print``.

        ``samplesSeen`` is the number of samples consumed before the current batch. It is tracked by the
        caller so that the figure stays correct even when the batches are not all of the same size.
        '''
        headerList = ['[{}/{} ({:.0f}%)]'.format(samplesSeen, len(loader.dataset), 100.0 * iteration / len(loader))]
        if self.module.countLog is True:
            spikeCounts = torch.sum(count, dim=0).tolist()
            headerList.append('Spike count: ' + ', '.join('{}'.format(int(c)) for c in spikeCounts))
        if self.showTimeSteps is True:
            headerList.append('nTimeBins: {}'.format(spikeInput.shape[-1]))

        self.stats.print(
            epoch, iteration,
            # average wall-clock time per sample, based on the size of the current batch
            (datetime.now() - startTime).total_seconds() / (iteration + 1) / spikeInput.shape[0],
            header=headerList,
        )

    def train(self, epoch=0, breakIter=None):
        '''
        Training assistant function.

        Arguments:
            * ``epoch``: training epoch number.
            * ``breakIter``: index of the last batch to process before breaking out of the training loop,
              i.e. batches ``0`` to ``breakIter`` (inclusive) are processed.
              ``None`` means go over the complete training samples. Default: ``None``.
        '''
        scale = 1 if self.lossScale is None else self.lossScale
        samplesSeen = 0
        tSt = datetime.now()
        for i, (spikeInput, target, label) in enumerate(self.trainLoader, 0):
            self.net.train()

            spikeInput = spikeInput.to(self.device)
            target = target.to(self.device)

            output, count = self._runNetwork(spikeInput)

            if self.stats is not None:
                self.stats.training.correctSamples += torch.sum(predict.getClass(output) == label).item()
                self.stats.training.numSamples += len(label)

            loss = self.error(output, target, label)

            if self.stats is not None:
                self.stats.training.lossSum += loss.item() * scale

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # Let the module clamp itself after every optimizer update.
            self.module.clamp()

            if self.stats is not None and epoch % self.printInterval == 0:
                self._printProgress(epoch, i, samplesSeen, self.trainLoader, spikeInput, count, tSt)
            samplesSeen += len(spikeInput)

            if breakIter is not None and i >= breakIter:
                break

        # The scheduler is stepped once per call, i.e. once per epoch.
        if self.scheduler is not None:
            self.scheduler.step()

    def test(self, epoch=0, evalLoss=True, slidingWindow=None, breakIter=None):
        '''
        Testing assistant function.

        Arguments:
            * ``epoch``: training epoch number.
            * ``evalLoss``: a flag to enable or disable loss evaluation. Default: ``True``.
            * ``slidingWindow``: the length of sliding window to use for continuous output prediction over time.
              ``None`` means total spike count is used to produce one output per sample. If it is not
              ``None``, ``evalLoss`` is overwritten to ``False``. Default: ``None``.
            * ``breakIter``: index of the last batch to process before breaking out of the testing loop,
              i.e. batches ``0`` to ``breakIter`` (inclusive) are processed.
              ``None`` means go over the complete testing samples. Default: ``None``.
        '''
        windowKernel = None
        if slidingWindow is not None:
            # Box filter of ones: convolving the output with it sums the output over the window.
            windowKernel = torch.ones((slidingWindow)).to(self.device)
            evalLoss = False

        scale = 1 if self.lossScale is None else self.lossScale
        samplesSeen = 0
        tSt = datetime.now()
        for i, (spikeInput, target, label) in enumerate(self.testLoader, 0):
            self.net.eval()

            # Evaluation only: no autograd graph is needed for the forward pass or the loss.
            with torch.no_grad():
                spikeInput = spikeInput.to(self.device)
                target = target.to(self.device)

                output, count = self._runNetwork(spikeInput)

                predictions = None
                if slidingWindow is None:
                    if self.stats is not None:
                        self.stats.testing.correctSamples += torch.sum(predict.getClass(output) == label).item()
                        self.stats.testing.numSamples += len(label)
                else:
                    # The first ``slidingWindow`` entries of the filtered output are discarded.
                    filteredOutput = slayerCuda.conv(output.contiguous(), windowKernel, 1)[..., slidingWindow:]
                    # One prediction per remaining entry of the last dimension.
                    # NOTE(review): the reshape folds every leading dimension together, which presumably
                    # assumes one sample per batch and per-time-step labels -- confirm against the dataset.
                    predictions = torch.argmax(filteredOutput.reshape(-1, filteredOutput.shape[-1]), dim=0)

                    if self.stats is not None:
                        self.stats.testing.correctSamples += torch.sum(predictions == label.to(self.device)).item()
                        self.stats.testing.numSamples += predictions.shape[0]

                if evalLoss is True:
                    loss = self.error(output, target, label)
                    if self.stats is not None:
                        self.stats.testing.lossSum += loss.item() * scale
                elif self.stats is not None:
                    # No loss is evaluated: a constant placeholder is accumulated instead.
                    if slidingWindow is None:
                        self.stats.testing.lossSum += scale
                    else:
                        self.stats.testing.lossSum += predictions.shape[0] * scale

            if self.stats is not None and epoch % self.printInterval == 0:
                self._printProgress(epoch, i, samplesSeen, self.testLoader, spikeInput, count, tSt)
            samplesSeen += len(spikeInput)

            if breakIter is not None and i >= breakIter:
                break