# Source code for slayerSNN.spikeClassifier
import numpy as np
import torch
class spikeClassifier:
    '''
    It provides classification modules for SNNs.
    All the functions it supplies are static and can be called without
    making an instance of the class.
    '''

    @staticmethod
    def getClass(spike):
        '''
        Returns the predicted class label.
        It assigns a single class to the SNN output for the whole simulation
        runtime: spikes are summed over dimension 4, every remaining
        non-batch dimension is flattened, and the index of the largest spike
        count is returned for each sample.

        Arguments:
            * ``spike``: spike tensor with at least 5 dimensions; dimension 0
              indexes the samples and dimension 4 is the one summed over.
              NOTE(review): presumably laid out as
              (batch, channel, height, width, time) -- confirm against callers.

        Returns:
            A 1-D CPU tensor of integer indices, one per sample. Each index
            refers to the flattened non-batch dimensions of the spike count.

        Usage:

        >>> predictedClass = spikeClassifier.getClass(spikeOut)
        '''
        # Total spike count per output neuron, moved to the CPU so the result
        # can be compared with host-side labels regardless of input device.
        numSpikes = torch.sum(spike, 4).cpu()
        # One row per sample, one column per output neuron.
        flatCounts = numSpikes.reshape((numSpikes.shape[0], -1))
        # torch.max along a dim returns (values, indices); only indices matter.
        _, predictedClass = torch.max(flatCounts, 1)
        return predictedClass