import torch
from torch.utils.data import Dataset
from ..spikeClassifier import spikeClassifier as predict
import slayerCuda
from datetime import datetime
class Assistant:
    '''
    This class provides standard assistant functionalities for the training and testing workflow.
    If you want a different workflow than what is available, you should inherit this class and
    override the particular method to your need.
Arguments:
* ``net``: the SLAYER network to be run.
* ``trainLoader``: training dataloader.
* ``testLoader``: testing dataloader.
    * ``error``: a function object or a lambda function that takes (output, target, label) as its input and returns
      a scalar error value.
* ``optimizer``: the learning optimizer.
* ``scheduler``: the learning scheduler. Default: ``None`` meaning no scheduler will be used.
* ``stats``: the SLAYER learning stats logger: ``slayerSNN.stats``. Default: ``None`` meaning no stats will be used.
* ``dataParallel``: flag if dataParallel execution needs to be handled. Default: ``False``.
* ``showTimeSteps``: flag to print timesteps of the sample or not. Default: ``False``.
* ``lossScale``: a scale factor to be used while printing the loss. Default: ``None`` meaning no scaling is done.
    * ``printInterval``: number of epochs after which to print the learning output once. Default: 1.
Usage:
.. code-block:: python
        assist = Assistant(net, trainLoader, testLoader, lambda o, t, l: error.numSpikes(o, t), optimizer, stats)
for epoch in range(maxEpoch):
assist.train(epoch)
assist.test(epoch)
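    A fuller construction exercising the optional arguments (a sketch, not part of
    this class: the optimizer, scheduler, and argument values are illustrative
    choices, and ``slayerSNN.stats`` is the stats logger referenced above):
    .. code-block:: python
        import torch
        import slayerSNN
        stats     = slayerSNN.stats()  # assumed stats logger, per the argument description above
        optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
        assist = Assistant(
            net, trainLoader, testLoader,
            lambda o, t, l: error.numSpikes(o, t),
            optimizer, scheduler=scheduler, stats=stats,
            showTimeSteps=True, lossScale=1000, printInterval=5,
        )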
'''
def __init__(self, net, trainLoader, testLoader, error, optimizer, scheduler=None, stats=None,
dataParallel=False, showTimeSteps=False, lossScale=None, printInterval=1):
self.net = net
self.module = net.module if dataParallel is True else net
self.error = error
self.device = net.slayer.srmKernel.device
self.optimizer = optimizer
self.scheduler = scheduler
self.stats = stats
self.showTimeSteps = showTimeSteps
self.lossScale = lossScale
self.printInterval = printInterval
self.trainLoader = trainLoader
self.testLoader = testLoader
    def train(self, epoch=0, breakIter=None):
        '''
        Training assistant function.
        Arguments:
        * ``epoch``: training epoch number.
        * ``breakIter``: number of mini-batch iterations after which to break out of the training loop.
          ``None`` means go over the complete training samples. Default: ``None``.
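        Usage (a sketch; ``assist`` is an ``Assistant`` instance and ``maxEpoch``
        is a placeholder from your own training script):
        .. code-block:: python
            for epoch in range(maxEpoch):
                assist.train(epoch)
                # for a quick smoke run, stop after 10 mini-batches instead:
                # assist.train(epoch, breakIter=10)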
'''
tSt = datetime.now()
for i, (input, target, label) in enumerate(self.trainLoader, 0):
self.net.train()
input = input.to(self.device)
target = target.to(self.device)
            # Forward pass; with spike count logging enabled, the network also
            # returns the per-layer spike counts.
            count = 0
            if self.module.countLog is True:
                output, count = self.net.forward(input)
            else:
                output = self.net.forward(input)
if self.stats is not None:
self.stats.training.correctSamples += torch.sum( predict.getClass(output) == label ).data.item()
self.stats.training.numSamples += len(label)
loss = self.error(output, target, label)
if self.stats is not None:
self.stats.training.lossSum += loss.cpu().data.item() * (1 if self.lossScale is None else self.lossScale)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
self.module.clamp()
            if self.stats is not None and epoch % self.printInterval == 0:
headerList = ['[{}/{} ({:.0f}%)]'.format(i*len(input), len(self.trainLoader.dataset), 100.0*i/len(self.trainLoader))]
if self.module.countLog is True:
headerList.append('Spike count: ' + ', '.join(['{}'.format(int(c)) for c in torch.sum(count, dim=0).tolist()]))
if self.showTimeSteps is True:
headerList.append('nTimeBins: {}'.format(input.shape[-1]))
self.stats.print(
epoch, i,
(datetime.now() - tSt).total_seconds() / (i+1) / input.shape[0],
                    header=headerList,
)
if breakIter is not None and i >= breakIter:
break
if self.scheduler is not None:
self.scheduler.step()
    def test(self, epoch=0, evalLoss=True, slidingWindow=None, breakIter=None):
        '''
        Testing assistant function.
        Arguments:
        * ``epoch``: training epoch number.
        * ``evalLoss``: a flag to enable or disable loss evaluation. Default: ``True``.
        * ``slidingWindow``: the length of the sliding window to use for continuous output prediction over time.
          ``None`` means the total spike count is used to produce one output per sample. If it is not
          ``None``, ``evalLoss`` is overwritten to ``False``. Default: ``None``.
        * ``breakIter``: number of mini-batch iterations after which to break out of the testing loop.
          ``None`` means go over the complete testing samples. Default: ``None``.
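        Usage (a sketch; the window length of 100 time bins is an illustrative
        value, not a default):
        .. code-block:: python
            assist.test(epoch)                     # one prediction per test sample
            assist.test(epoch, slidingWindow=100)  # one prediction per time bin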
'''
        if slidingWindow is not None:
            # Box filter of length ``slidingWindow`` for a running spike count.
            filter = torch.ones(slidingWindow).to(self.device)
            evalLoss = False
tSt = datetime.now()
for i, (input, target, label) in enumerate(self.testLoader, 0):
self.net.eval()
with torch.no_grad():
input = input.to(self.device)
target = target.to(self.device)
count = 0
if self.module.countLog is True:
output, count = self.net.forward(input)
else:
output = self.net.forward(input)
if slidingWindow is None:
if self.stats is not None:
self.stats.testing.correctSamples += torch.sum( predict.getClass(output) == label ).data.item()
self.stats.testing.numSamples += len(label)
                else:
                    # Running spike count over the sliding window; the first
                    # ``slidingWindow`` bins are discarded as filter warm-up.
                    filteredOutput = slayerCuda.conv(output.contiguous(), filter, 1)[..., slidingWindow:]
                    # One class prediction per remaining time bin.
                    predictions = torch.argmax(filteredOutput.reshape(-1, filteredOutput.shape[-1]), dim=0)
if self.stats is not None:
self.stats.testing.correctSamples += torch.sum(predictions == label.to(self.device)).item()
self.stats.testing.numSamples += predictions.shape[0]
if evalLoss is True:
loss = self.error(output, target, label)
if self.stats is not None:
self.stats.testing.lossSum += loss.cpu().data.item() * (1 if self.lossScale is None else self.lossScale)
else:
if self.stats is not None:
if slidingWindow is None:
self.stats.testing.lossSum += (1 if self.lossScale is None else self.lossScale)
else:
self.stats.testing.lossSum += predictions.shape[0] * (1 if self.lossScale is None else self.lossScale)
                if self.stats is not None and epoch % self.printInterval == 0:
headerList = ['[{}/{} ({:.0f}%)]'.format(i*len(input), len(self.testLoader.dataset), 100.0*i/len(self.testLoader))]
if self.module.countLog is True:
headerList.append('Spike count: ' + ', '.join(['{}'.format(int(c)) for c in torch.sum(count, dim=0).tolist()]))
if self.showTimeSteps is True:
headerList.append('nTimeBins: {}'.format(input.shape[-1]))
self.stats.print(
epoch, i,
(datetime.now() - tSt).total_seconds() / (i+1) / input.shape[0],
                        header=headerList,
)
if breakIter is not None and i >= breakIter:
break