Source code for slayerSNN.slayer

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
# import slayer_cuda
import slayerCuda
# import matplotlib.pyplot as plt

# # Consider dictionary for easier iteration and better scalability
# class yamlParams(object):
#   '''
#   This class reads yaml parameter file and allows dictionary like access to the members.
    
#   Usage:

#   .. code-block:: python
        
#       import slayerSNN as snn
#       netParams = snn.params('path_to_yaml_file') # OR
#       netParams = slayer.yamlParams('path_to_yaml_file')

#       netParams['training']['learning']['etaW'] = 0.01
#       print('Simulation step size        ', netParams['simulation']['Ts'])
#       print('Spiking neuron time constant', netParams['neuron']['tauSr'])
#       print('Spiking neuron threshold    ', netParams['neuron']['theta'])

#       netParams.save('filename.yaml')
#   '''
#   def __init__(self, parameter_file_path):
#       with open(parameter_file_path, 'r') as param_file:
#           self.parameters = yaml.safe_load(param_file)

#   # Allow dictionary like access
#   def __getitem__(self, key):
#       return self.parameters[key]

#   def __setitem__(self, key, value):
#       self.parameters[key] = value

#   def save(self, filename):
#       with open(filename, 'w') as f:
#           yaml.dump(self.parameters, f)
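
# A minimal parameter-file sketch (illustrative only) with the 'simulation' and
# 'neuron' sections that spikeLayer below expects. The values are the ones shown
# in the spikeLayer docstring; the filename is hypothetical.
#
#   # network.yaml
#   simulation:
#       Ts: 1.0
#       tSample: 300
#   neuron:
#       type:     SRMALPHA
#       theta:    10
#       tauSr:    10.0
#       tauRef:   1.0
#       scaleRef: 2
#       tauRho:   1
#       scaleRho: 1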

# class spikeLayer():
class spikeLayer(torch.nn.Module):
    '''
    This class defines the main engine of SLAYER.
    It provides the necessary functions for describing an SNN layer.
    The input-to-output connection can be fully-connected, convolutional, or aggregation (pool).
    It also defines the psp operation and the spiking mechanism of a spiking neuron in the layer.

    **Important:** It assumes all the tensors that are being processed are 5 dimensional:
    (Batch, Channels, Height, Width, Time), i.e. ``NCHWT`` format.
    The user must make sure that an input of the correct dimension is supplied.

    *If the layer does not have a spatial dimension, the neurons can be distributed along either
    the Channel, Height, or Width dimension, where Channel * Height * Width equals the number of
    neurons. For speed reasons, it is recommended to define the neurons in the Channel dimension
    and make the Height and Width dimensions one.*

    Arguments:
        * ``neuronDesc`` (``slayerParams.yamlParams``): spiking neuron descriptor.

            .. code-block:: python

                neuron:
                    type:     SRMALPHA  # neuron type
                    theta:    10        # neuron threshold
                    tauSr:    10.0      # neuron time constant
                    tauRef:   1.0       # neuron refractory time constant
                    scaleRef: 2         # neuron refractory response scaling (relative to theta)
                    tauRho:   1         # spike function derivative time constant (relative to theta)
                    scaleRho: 1         # spike function derivative scale factor
        * ``simulationDesc`` (``slayerParams.yamlParams``): simulation descriptor

            .. code-block:: python

                simulation:
                    Ts: 1.0         # sampling time (ms)
                    tSample: 300    # time length of sample (ms)
        * ``fullRefKernel`` (``bool``, optional): high resolution refractory kernel (the user shall not use it in practice)

    Usage:

    >>> snnLayer = slayer.spikeLayer(neuronDesc, simulationDesc)
    '''
    def __init__(self, neuronDesc, simulationDesc, fullRefKernel=False):
        super(spikeLayer, self).__init__()
        self.neuron = neuronDesc
        self.simulation = simulationDesc
        self.fullRefKernel = fullRefKernel

        # register the kernels as buffers so they move with the module (e.g. on .to(device))
        self.register_buffer('srmKernel', self.calculateSrmKernel())
        self.register_buffer('refKernel', self.calculateRefKernel())

    def calculateSrmKernel(self):
        srmKernel = self._calculateAlphaKernel(self.neuron['tauSr'])
        # TODO implement for different types of kernels
        return torch.FloatTensor(srmKernel)

    def calculateRefKernel(self):
        if self.fullRefKernel:
            # This gives the high precision refractory kernel as in the MATLAB implementation;
            # however, it is expensive.
            refKernel = self._calculateAlphaKernel(tau=self.neuron['tauRef'], mult=-self.neuron['scaleRef'] * self.neuron['theta'], EPSILON=0.0001)
        else:
            refKernel = self._calculateAlphaKernel(tau=self.neuron['tauRef'], mult=-self.neuron['scaleRef'] * self.neuron['theta'])
        # TODO implement for different types of kernels
        return torch.FloatTensor(refKernel)

    def _calculateAlphaKernel(self, tau, mult=1, EPSILON=0.01):
        # evaluates the alpha kernel mult * t/tau * exp(1 - t/tau) on the simulation time grid,
        # truncated once the response has decayed below EPSILON after its peak at t = tau
        eps = []
        for t in np.arange(0, self.simulation['tSample'], self.simulation['Ts']):
            epsVal = mult * t / tau * math.exp(1 - t / tau)
            if abs(epsVal) < EPSILON and t > tau:
                break
            eps.append(epsVal)
        return eps

    def _zeroPadAndFlip(self, kernel):
        if (len(kernel) % 2) == 0:
            kernel.append(0)
        prependedZeros = np.zeros((len(kernel) - 1))
        return np.flip(np.concatenate((prependedZeros, kernel))).tolist()
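    # A quick sanity sketch for _calculateAlphaKernel (illustrative, not part of
    # the library): the alpha kernel mult * t/tau * exp(1 - t/tau) peaks at
    # t = tau with value mult, which is why the truncation test above only fires
    # once t > tau.
    #
    # >>> import math
    # >>> tau = 10.0
    # >>> max(t / tau * math.exp(1 - t / tau) for t in range(300))
    # 1.0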
    def psp(self, spike):
        '''
        Applies psp filtering to spikes.
        The output tensor dimension is the same as the input.

        Arguments:
            * ``spike``: input spike tensor.

        Usage:

        >>> filteredSpike = snnLayer.psp(spike)
        '''
        return _pspFunction.apply(spike, self.srmKernel, self.simulation['Ts'])
    def pspLayer(self):
        '''
        Returns a function that can be called to apply psp filtering to spikes.
        The output tensor dimension is the same as the input.
        The initial psp filter corresponds to the neuron psp filter.
        The psp filter is learnable.
        NOTE: the learned psp filter must be reversed because PyTorch performs a correlation operation.

        Usage:

        >>> pspLayer = snnLayer.pspLayer()
        >>> filteredSpike = pspLayer(spike)
        '''
        return _pspLayer(self.srmKernel, self.simulation['Ts'])
    def pspFilter(self, nFilter, filterLength, filterScale=1):
        '''
        Returns a function that can be called to apply a bank of temporal filters.
        The output tensor is of the same dimension as the input, except the channel dimension is
        scaled by the number of filters.
        The initial filters are initialized using the default PyTorch initialization for conv layers.
        The filter banks are learnable.
        NOTE: the learned psp filter must be reversed because PyTorch performs a correlation operation.

        Arguments:
            * ``nFilter``: number of filters in the filter bank.
            * ``filterLength``: length of the filter in number of time bins.
            * ``filterScale``: initial scaling factor for the filter banks. Default: 1.

        Usage:

        >>> pspFilter = snnLayer.pspFilter()
        >>> filteredSpike = pspFilter(spike)
        '''
        return _pspFilter(nFilter, filterLength, self.simulation['Ts'], filterScale)
    def replicateInTime(self, input, mode='nearest'):
        # replicates a static (N, C, H, W) input along a new time dimension of Ns bins
        Ns = int(self.simulation['tSample'] / self.simulation['Ts'])
        N, C, H, W = input.shape
        if mode == 'nearest':
            output = F.interpolate(input.reshape(N, C, H, W, 1), size=(H, W, Ns), mode='nearest')
        return output
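    # Usage sketch (illustrative): replicateInTime turns a static frame into a
    # constant input over all Ns time bins, e.g. for feeding images to an SNN.
    # `staticInput` is a hypothetical (N, C, H, W) tensor.
    #
    # >>> staticInput = torch.rand((8, 2, 34, 34))
    # >>> temporalInput = snnLayer.replicateInTime(staticInput)
    # >>> temporalInput.shape  # (8, 2, 34, 34, Ns)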
    def dense(self, inFeatures, outFeatures, weightScale=10, preHookFx=None):
        '''
        Returns a function that can be called to apply a dense layer mapping to the input tensor
        per time instance.
        It behaves similarly to ``torch.nn.Linear`` applied to each time instance.

        Arguments:
            * ``inFeatures`` (``int``, tuple of two ints, tuple of three ints): dimension of the
              input features (Width, Height, Channel) that represents the number of input neurons.
            * ``outFeatures`` (``int``): number of output neurons.
            * ``weightScale``: scale factor of the default initialized weights. Default: 10
            * ``preHookFx``: a function that operates on the weight before applying it. Could be
              used for quantization etc.

        Usage:

        >>> fcl = snnLayer.dense(2048, 512)          # takes (N, 2048, 1, 1, T) tensor
        >>> fcl = snnLayer.dense((128, 128, 2), 512) # takes (N, 2, 128, 128, T) tensor
        >>> output = fcl(input)                      # output will be (N, 512, 1, 1, T) tensor
        '''
        return _denseLayer(inFeatures, outFeatures, weightScale, preHookFx)  # default weight scaling of 10
    def conv(self, inChannels, outChannels, kernelSize, stride=1, padding=0, dilation=1, groups=1, weightScale=100, preHookFx=None):
        '''
        Returns a function that can be called to apply a conv layer mapping to the input tensor
        per time instance.
        It behaves the same as ``torch.nn.Conv2d`` applied to each time instance.

        Arguments:
            * ``inChannels`` (``int``): number of channels in the input
            * ``outChannels`` (``int``): number of channels produced by the convolution
            * ``kernelSize`` (``int`` or tuple of two ints): size of the convolving kernel
            * ``stride`` (``int`` or tuple of two ints): stride of the convolution. Default: 1
            * ``padding`` (``int`` or tuple of two ints): zero-padding added to both sides of the input. Default: 0
            * ``dilation`` (``int`` or tuple of two ints): spacing between kernel elements. Default: 1
            * ``groups`` (``int``): number of blocked connections from input channels to output channels. Default: 1
            * ``weightScale``: scale factor of the default initialized weights. Default: 100
            * ``preHookFx``: a function that operates on the weight before applying it. Could be used for quantization etc.

        The parameters ``kernelSize``, ``stride``, ``padding``, ``dilation`` can either be:

        - a single ``int`` -- in which case the same value is used for the height and width dimension
        - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
          and the second `int` for the width dimension

        Usage:

        >>> conv = snnLayer.conv(2, 32, 5) # 32C5 filter
        >>> output = conv(input)           # must have 2 channels
        '''
        return _convLayer(inChannels, outChannels, kernelSize, stride, padding, dilation, groups, weightScale, preHookFx)  # default weight scaling of 100
    def pool(self, kernelSize, stride=None, padding=0, dilation=1, preHookFx=None):
        '''
        Returns a function that can be called to apply a pool layer mapping to the input tensor
        per time instance.
        It behaves like a sum pooling applied to each time instance (there is no direct
        ``torch.nn`` equivalent; internally it is a convolution with fixed weights).

        Arguments:
            * ``kernelSize`` (``int`` or tuple of two ints): the size of the window to pool over
            * ``stride`` (``int`` or tuple of two ints): stride of the window. Default: `kernelSize`
            * ``padding`` (``int`` or tuple of two ints): implicit zero padding to be added on both sides. Default: 0
            * ``dilation`` (``int`` or tuple of two ints): a parameter that controls the stride of elements in the window. Default: 1
            * ``preHookFx``: a function that operates on the weight before applying it. Could be used for quantization etc.

        The parameters ``kernelSize``, ``stride``, ``padding``, ``dilation`` can either be:

        - a single ``int`` -- in which case the same value is used for the height and width dimension
        - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
          and the second `int` for the width dimension

        Usage:

        >>> pool = snnLayer.pool(4) # 4x4 pooling
        >>> output = pool(input)
        '''
        return _poolLayer(self.neuron['theta'], kernelSize, stride, padding, dilation, preHookFx)
    def convTranspose(self, inChannels, outChannels, kernelSize, stride=1, padding=0, dilation=1, groups=1, weightScale=100, preHookFx=None):
        '''
        Returns a function that can be called to apply a transposed conv layer mapping to the input
        tensor per time instance.
        It behaves the same as ``torch.nn.ConvTranspose3d`` applied to each time instance.

        Arguments:
            * ``inChannels`` (``int``): number of channels in the input
            * ``outChannels`` (``int``): number of channels produced by the transposed convolution
            * ``kernelSize`` (``int`` or tuple of two ints): size of the transposed convolution kernel
            * ``stride`` (``int`` or tuple of two ints): stride of the transposed convolution. Default: 1
            * ``padding`` (``int`` or tuple of two ints): amount of implicit zero-padding added to both sides of the input. Default: 0
            * ``dilation`` (``int`` or tuple of two ints): spacing between kernel elements. Default: 1
            * ``groups`` (``int``): number of blocked connections from input channels to output channels. Default: 1
            * ``weightScale``: scale factor of the default initialized weights. Default: 100
            * ``preHookFx``: a function that operates on the weights before applying them. Could be used for quantization etc.

        The parameters ``kernelSize``, ``stride``, ``padding``, ``dilation`` can either be:

        - a single ``int`` -- in which case the same value is used for the height and width dimension
        - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
          and the second is used for the width dimension

        Usage:

        >>> convT = snnLayer.convTranspose(32, 2, 5) # 2T5 filter, the opposite of a 32C5 filter
        >>> output = convT(input)
        '''
        return _convTransposeLayer(inChannels, outChannels, kernelSize, stride, padding, dilation, groups, weightScale, preHookFx)
    def unpool(self, kernelSize, stride=None, padding=0, dilation=1, preHookFx=None):
        '''
        Returns a function that can be called to apply an unpool layer mapping to the input tensor
        per time instance.
        It behaves similarly to the ``torch.nn`` unpool layers.

        Arguments:
            * ``kernelSize`` (``int`` or tuple of two ints): the size of the window to unpool over
            * ``stride`` (``int`` or tuple of two ints): stride of the window. Default: `kernelSize`
            * ``padding`` (``int`` or tuple of two ints): implicit zero padding to be added on both sides. Default: 0
            * ``dilation`` (``int`` or tuple of two ints): a parameter that controls the stride of elements in the window. Default: 1
            * ``preHookFx``: a function that operates on the weight before applying it. Could be used for quantization etc.

        The parameters ``kernelSize``, ``stride``, ``padding``, ``dilation`` can either be:

        - a single ``int`` -- in which case the same value is used for the height and width dimension
        - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
          and the second `int` for the width dimension

        Usage:

        >>> unpool = snnLayer.unpool(2) # 2x2 unpooling
        >>> output = unpool(input)
        '''
        return _unpoolLayer(self.neuron['theta'], kernelSize, stride, padding, dilation, preHookFx)
    def dropout(self, p=0.5, inplace=False):
        '''
        Returns a function that can be called to apply a dropout layer to the input tensor.
        It behaves similarly to ``torch.nn.Dropout``.
        However, dropout over the time dimension is preserved, i.e. if a neuron is dropped, it
        remains dropped for the entire time duration.

        Arguments:
            * ``p``: dropout probability.
            * ``inplace`` (``bool``): inplace operation flag.

        Usage:

        >>> drop = snnLayer.dropout(0.2)
        >>> output = drop(input)
        '''
        return _dropoutLayer(p, inplace)
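    # Note on the dropout behaviour (sketch, assuming a (N, C, H, W, T) input):
    # internally the tensor is viewed as (N, C*H*W, 1, 1, T) and passed through
    # Dropout3d, so each dropped "channel" is one neuron's entire time axis.
    #
    # >>> drop = snnLayer.dropout(0.2)
    # >>> out = drop(spike)  # a dropped neuron is zero for all T bins of a sample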
    def delayShift(self, input, delay, Ts=1):
        '''
        Applies a delay in the time dimension (assumed to be the last dimension of the tensor)
        of the input tensor.
        The autograd backward link is established as well.

        Arguments:
            * ``input``: input Torch tensor.
            * ``delay`` (``float`` or Torch tensor): amount of delay to apply.
              The same delay is applied to all the inputs if ``delay`` is a ``float`` or a Torch
              tensor of size 1. If the Torch tensor has size more than 1, its dimension must match
              the dimension of the input tensor except the last dimension.
            * ``Ts``: sampling time of the delay. Default is 1.

        Usage:

        >>> delayedInput = slayer.delayShift(input, 5)
        '''
        return _delayFunctionNoGradient.apply(input, delay, Ts)
    def delay(self, inputSize):
        '''
        Returns a function that can be called to apply a delay operation in the time dimension of
        the input tensor.
        The delay parameter is available as ``delay.delay`` and is initialized uniformly between
        0ms and 1ms.
        The delay parameter is stored as float values; however, it is floored internally during the
        actual delay application.
        The delay values are not clamped to zero.
        To maintain the causality of the network, one should clamp the delay values explicitly to
        ensure positive delays.

        Arguments:
            * ``inputSize`` (``int`` or tuple of three ints): spatial shape of the input signal in
              CHW format (Channel, Height, Width). If an integer value is supplied, it refers to
              the number of neurons in the channel dimension. Height and Width are assumed to be 1.

        Usage:

        >>> delay = snnLayer.delay((C, H, W))
        >>> delayedSignal = delay(input)

        Always clamp the delay after ``optimizer.step()``.

        >>> optimizer.step()
        >>> delay.delay.data.clamp_(0)
        '''
        return _delayLayer(inputSize, self.simulation['Ts'])
    # def applySpikeFunction(self, membranePotential):
    #     return _spikeFunction.apply(membranePotential, self.refKernel, self.neuron, self.simulation['Ts'])
    def spike(self, membranePotential):
        '''
        Applies the spike function and refractory response.
        The output tensor dimension is the same as the input.
        ``membranePotential`` will reflect the spike and refractory behaviour as well.

        Arguments:
            * ``membranePotential``: subthreshold membrane potential.

        Usage:

        >>> outSpike = snnLayer.spike(membranePotential)
        '''
        return _spikeFunction.apply(membranePotential, self.refKernel, self.neuron, self.simulation['Ts'])
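
# Putting the public pieces together (illustrative sketch based on the usage
# notes above; the layer sizes and the `netParams` descriptor are made up):
#
# >>> slayer = spikeLayer(netParams['neuron'], netParams['simulation'])
# >>> fc1 = slayer.dense((34, 34, 2), 512)
# >>> fc2 = slayer.dense(512, 10)
# >>> spike1 = slayer.spike(slayer.psp(fc1(spikeInput)))  # (N, 512, 1, 1, T)
# >>> spike2 = slayer.spike(slayer.psp(fc2(spike1)))      # (N,  10, 1, 1, T)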
class _denseLayer(nn.Conv3d):
    def __init__(self, inFeatures, outFeatures, weightScale=1, preHookFx=None):
        # extract information for kernel and inChannels
        if type(inFeatures) == int:
            kernel = (1, 1, 1)
            inChannels = inFeatures
        elif len(inFeatures) == 2:
            kernel = (inFeatures[1], inFeatures[0], 1)
            inChannels = 1
        elif len(inFeatures) == 3:
            kernel = (inFeatures[1], inFeatures[0], 1)
            inChannels = inFeatures[2]
        else:
            raise Exception('inFeatures should not be more than 3 dimensions. It was: {}'.format(inFeatures))

        if type(outFeatures) == int:
            outChannels = outFeatures
        else:
            raise Exception('outFeatures should not be more than 1 dimension. It was: {}'.format(outFeatures))

        super(_denseLayer, self).__init__(inChannels, outChannels, kernel, bias=False)

        if weightScale != 1:
            self.weight = torch.nn.Parameter(weightScale * self.weight)  # scale the weight if needed

        self.preHookFx = preHookFx

    def forward(self, input):
        if self.preHookFx is None:
            return F.conv3d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        else:
            return F.conv3d(input, self.preHookFx(self.weight), self.bias, self.stride, self.padding, self.dilation, self.groups)


class _convLayer(nn.Conv3d):
    def __init__(self, inFeatures, outFeatures, kernelSize, stride=1, padding=0, dilation=1, groups=1, weightScale=1, preHookFx=None):
        inChannels = inFeatures
        outChannels = outFeatures

        # kernel
        if type(kernelSize) == int:
            kernel = (kernelSize, kernelSize, 1)
        elif len(kernelSize) == 2:
            kernel = (kernelSize[0], kernelSize[1], 1)
        else:
            raise Exception('kernelSize can only be of 1 or 2 dimensions. It was: {}'.format(kernelSize))

        # stride
        if type(stride) == int:
            stride = (stride, stride, 1)
        elif len(stride) == 2:
            stride = (stride[0], stride[1], 1)
        else:
            raise Exception('stride can be either int or tuple of size 2. It was: {}'.format(stride))

        # padding
        if type(padding) == int:
            padding = (padding, padding, 0)
        elif len(padding) == 2:
            padding = (padding[0], padding[1], 0)
        else:
            raise Exception('padding can be either int or tuple of size 2. It was: {}'.format(padding))

        # dilation
        if type(dilation) == int:
            dilation = (dilation, dilation, 1)
        elif len(dilation) == 2:
            dilation = (dilation[0], dilation[1], 1)
        else:
            raise Exception('dilation can be either int or tuple of size 2. It was: {}'.format(dilation))

        # groups: no need to check; it can only be an int

        super(_convLayer, self).__init__(inChannels, outChannels, kernel, stride, padding, dilation, groups, bias=False)

        if weightScale != 1:
            self.weight = torch.nn.Parameter(weightScale * self.weight)  # scale the weight if needed

        self.preHookFx = preHookFx

    def forward(self, input):
        if self.preHookFx is None:
            return F.conv3d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        else:
            return F.conv3d(input, self.preHookFx(self.weight), self.bias, self.stride, self.padding, self.dilation, self.groups)


class _poolLayer(nn.Conv3d):
    def __init__(self, theta, kernelSize, stride=None, padding=0, dilation=1, preHookFx=None):
        # kernel
        if type(kernelSize) == int:
            kernel = (kernelSize, kernelSize, 1)
        elif len(kernelSize) == 2:
            kernel = (kernelSize[0], kernelSize[1], 1)
        else:
            raise Exception('kernelSize can only be of 1 or 2 dimensions. It was: {}'.format(kernelSize))

        # stride
        if stride is None:
            stride = kernel
        elif type(stride) == int:
            stride = (stride, stride, 1)
        elif len(stride) == 2:
            stride = (stride[0], stride[1], 1)
        else:
            raise Exception('stride can be either int or tuple of size 2. It was: {}'.format(stride))

        # padding
        if type(padding) == int:
            padding = (padding, padding, 0)
        elif len(padding) == 2:
            padding = (padding[0], padding[1], 0)
        else:
            raise Exception('padding can be either int or tuple of size 2. It was: {}'.format(padding))

        # dilation
        if type(dilation) == int:
            dilation = (dilation, dilation, 1)
        elif len(dilation) == 2:
            dilation = (dilation[0], dilation[1], 1)
        else:
            raise Exception('dilation can be either int or tuple of size 2. It was: {}'.format(dilation))

        super(_poolLayer, self).__init__(1, 1, kernel, stride, padding, dilation, bias=False)

        # set the weights to 1.1 * theta and requires_grad = False
        self.weight = torch.nn.Parameter(torch.FloatTensor(1.1 * theta * np.ones((self.weight.shape))).to(self.weight.device), requires_grad=False)

        self.preHookFx = preHookFx

    def forward(self, input):
        device = input.device
        dtype = input.dtype

        # add necessary padding for odd spatial dimension
        if input.shape[2] % self.weight.shape[2] != 0:
            input = torch.cat(
                (input, torch.zeros((input.shape[0], input.shape[1], input.shape[2] % self.weight.shape[2], input.shape[3], input.shape[4]), dtype=dtype).to(device)),
                2,
            )
        if input.shape[3] % self.weight.shape[3] != 0:
            input = torch.cat(
                (input, torch.zeros((input.shape[0], input.shape[1], input.shape[2], input.shape[3] % self.weight.shape[3], input.shape[4]), dtype=dtype).to(device)),
                3,
            )

        dataShape = input.shape

        if self.preHookFx is None:
            result = F.conv3d(
                input.reshape((dataShape[0], 1, dataShape[1] * dataShape[2], dataShape[3], dataShape[4])),
                self.weight, self.bias, self.stride, self.padding, self.dilation,
            )
        else:
            # fixed: was self.preHooFx, which would raise an AttributeError
            result = F.conv3d(
                input.reshape((dataShape[0], 1, dataShape[1] * dataShape[2], dataShape[3], dataShape[4])),
                self.preHookFx(self.weight), self.bias, self.stride, self.padding, self.dilation,
            )

        return result.reshape((result.shape[0], dataShape[1], -1, result.shape[3], result.shape[4]))


class _convTransposeLayer(nn.ConvTranspose3d):
    def __init__(self, inFeatures, outFeatures, kernelSize, stride=1, padding=0, dilation=1, groups=1, weightScale=1, preHookFx=None):
        inChannels = inFeatures
        outChannels = outFeatures

        # kernel
        if type(kernelSize) == int:
            kernel = (kernelSize, kernelSize, 1)
        elif len(kernelSize) == 2:
            kernel = (kernelSize[0], kernelSize[1], 1)
        else:
            raise Exception('kernelSize can only be of 1 or 2 dimensions. It was: {}'.format(kernelSize))

        # stride
        if type(stride) == int:
            stride = (stride, stride, 1)
        elif len(stride) == 2:
            stride = (stride[0], stride[1], 1)
        else:
            raise Exception('stride can be either int or tuple of size 2. It was: {}'.format(stride))

        # padding
        if type(padding) == int:
            padding = (padding, padding, 0)
        elif len(padding) == 2:
            padding = (padding[0], padding[1], 0)
        else:
            raise Exception('padding can be either int or tuple of size 2. It was: {}'.format(padding))

        # dilation
        if type(dilation) == int:
            dilation = (dilation, dilation, 1)
        elif len(dilation) == 2:
            dilation = (dilation[0], dilation[1], 1)
        else:
            raise Exception('dilation can be either int or tuple of size 2. It was: {}'.format(dilation))

        # groups: no need to check; it can only be an int

        super(_convTransposeLayer, self).__init__(inChannels, outChannels, kernel, stride, padding, 0, groups, False, dilation)

        if weightScale != 1:
            self.weight = torch.nn.Parameter(weightScale * self.weight)  # scale the weight if needed

        self.preHookFx = preHookFx

    def forward(self, input):
        if self.preHookFx is None:
            return F.conv_transpose3d(
                input, self.weight, self.bias,
                self.stride, self.padding, self.output_padding, self.groups, self.dilation,
            )
        else:
            return F.conv_transpose3d(
                input, self.preHookFx(self.weight), self.bias,
                self.stride, self.padding, self.output_padding, self.groups, self.dilation,
            )


class _unpoolLayer(nn.ConvTranspose3d):
    def __init__(self, theta, kernelSize, stride=None, padding=0, dilation=1, preHookFx=None):
        # kernel
        if type(kernelSize) == int:
            kernel = (kernelSize, kernelSize, 1)
        elif len(kernelSize) == 2:
            kernel = (kernelSize[0], kernelSize[1], 1)
        else:
            raise Exception('kernelSize can only be of 1 or 2 dimensions. It was: {}'.format(kernelSize))

        # stride
        if stride is None:
            stride = kernel
        elif type(stride) == int:
            stride = (stride, stride, 1)
        elif len(stride) == 2:
            stride = (stride[0], stride[1], 1)
        else:
            raise Exception('stride can be either int or tuple of size 2. It was: {}'.format(stride))

        # padding
        if type(padding) == int:
            padding = (padding, padding, 0)
        elif len(padding) == 2:
            padding = (padding[0], padding[1], 0)
        else:
            raise Exception('padding can be either int or tuple of size 2. It was: {}'.format(padding))

        # dilation
        if type(dilation) == int:
            dilation = (dilation, dilation, 1)
        elif len(dilation) == 2:
            dilation = (dilation[0], dilation[1], 1)
        else:
            raise Exception('dilation can be either int or tuple of size 2. It was: {}'.format(dilation))

        super(_unpoolLayer, self).__init__(1, 1, kernel, stride, padding, 0, 1, False, dilation)

        self.weight = torch.nn.Parameter(torch.FloatTensor(1.1 * theta * np.ones((self.weight.shape))).to(self.weight.device), requires_grad=False)

        self.preHookFx = preHookFx

    def forward(self, input):
        # no spatial padding is needed here: unpooling multiplies the spatial
        # dimensions, so they always remain compatible
        dataShape = input.shape

        if self.preHookFx is None:
            result = F.conv_transpose3d(
                input.reshape((dataShape[0], 1, -1, dataShape[3], dataShape[4])),
                self.weight, self.bias,
                self.stride, self.padding, self.output_padding, self.groups, self.dilation,
            )
        else:
            result = F.conv_transpose3d(
                input.reshape((dataShape[0], 1, -1, dataShape[3], dataShape[4])),
                self.preHookFx(self.weight), self.bias,
                self.stride, self.padding, self.output_padding, self.groups, self.dilation,
            )

        return result.reshape((result.shape[0], dataShape[1], -1, result.shape[3], result.shape[4]))


class _dropoutLayer(nn.Dropout3d):
    def forward(self, input):
        # neurons are flattened into the channel dimension so that Dropout3d drops
        # a neuron for the entire time duration of the sample
        inputShape = input.shape
        return F.dropout3d(
            input.reshape((inputShape[0], -1, 1, 1, inputShape[-1])),
            self.p, self.training, self.inplace,
        ).reshape(inputShape)


class _pspLayer(nn.Conv3d):
    def __init__(self, filter, Ts):
        inChannels = 1
        outChannels = 1
        kernel = (1, 1, torch.numel(filter))

        self.Ts = Ts

        super(_pspLayer, self).__init__(inChannels, outChannels, kernel, bias=False)

        # flip the filter because PyTorch's conv3d actually performs correlation
        flippedFilter = torch.FloatTensor(np.flip(filter.cpu().data.numpy()).copy()).reshape(self.weight.shape)
        self.weight = torch.nn.Parameter(flippedFilter.to(self.weight.device), requires_grad=True)

        # left-pad in time so that the filtering is causal
        self.pad = torch.nn.ConstantPad3d(padding=(torch.numel(filter) - 1, 0, 0, 0, 0, 0), value=0)

    def forward(self, input):
        inShape = input.shape
        inPadded = self.pad(input.reshape((inShape[0], 1, 1, -1, inShape[-1])))
        output = F.conv3d(inPadded, self.weight) * self.Ts
        return output.reshape(inShape)


class _pspFilter(nn.Conv3d):
    def __init__(self, nFilter, filterLength, Ts, filterScale=1):
        inChannels = 1
        outChannels = nFilter
        kernel = (1, 1, filterLength)

        super(_pspFilter, self).__init__(inChannels, outChannels, kernel, bias=False)

        self.Ts = Ts
        self.pad = torch.nn.ConstantPad3d(padding=(filterLength - 1, 0, 0, 0, 0, 0), value=0)

        if filterScale != 1:
            self.weight.data *= filterScale

    def forward(self, input):
        N, C, H, W, Ns = input.shape
        inPadded = self.pad(input.reshape((N, 1, 1, -1, Ns)))
        output = F.conv3d(inPadded, self.weight) * self.Ts
        return output.reshape((N, -1, H, W, Ns))


class _spikeFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, membranePotential, refractoryResponse, neuron, Ts):
        device = membranePotential.device
        dtype = membranePotential.dtype
        threshold = neuron['theta']
        oldDevice = torch.cuda.current_device()
        # if device != oldDevice: torch.cuda.set_device(device)

        spikes = slayerCuda.getSpikes(membranePotential.contiguous(), refractoryResponse, threshold, Ts)

        pdfScale        = torch.autograd.Variable(torch.tensor(neuron['scaleRho'], device=device, dtype=dtype), requires_grad=False)
        pdfTimeConstant = torch.autograd.Variable(torch.tensor(neuron['tauRho'] * neuron['theta'], device=device, dtype=dtype), requires_grad=False)  # needs to be scaled by theta
        threshold       = torch.autograd.Variable(torch.tensor(neuron['theta'], device=device, dtype=dtype), requires_grad=False)
        ctx.save_for_backward(membranePotential, threshold, pdfTimeConstant, pdfScale)
        # if device != oldDevice: torch.cuda.set_device(oldDevice)
        return spikes

    @staticmethod
    def backward(ctx, gradOutput):
        (membranePotential, threshold, pdfTimeConstant, pdfScale) = ctx.saved_tensors
        # surrogate gradient: a double-sided exponential pdf centered at the threshold
        spikePdf = pdfScale / pdfTimeConstant * torch.exp(-torch.abs(membranePotential - threshold) / pdfTimeConstant)

        # return gradOutput, None, None, None # This seems to work better!
        return gradOutput * spikePdf, None, None, None


class _pspFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, spike, filter, Ts):
        device = spike.device
        dtype = spike.dtype
        psp = slayerCuda.conv(spike.contiguous(), filter, Ts)
        Ts = torch.autograd.Variable(torch.tensor(Ts, device=device, dtype=dtype), requires_grad=False)
        ctx.save_for_backward(filter, Ts)
        return psp

    @staticmethod
    def backward(ctx, gradOutput):
        (filter, Ts) = ctx.saved_tensors
        gradInput = slayerCuda.corr(gradOutput.contiguous(), filter, Ts)
        # the gradient with respect to the filter is not implemented;
        # None is returned whether or not the filter requires grad
        if filter.requires_grad is False:
            gradFilter = None
        else:
            gradFilter = None
        return gradInput, gradFilter, None


class _delayLayer(nn.Module):
    def __init__(self, inputSize, Ts):
        super(_delayLayer, self).__init__()

        if type(inputSize) == int:
            inputChannels = inputSize
            inputHeight = 1
            inputWidth = 1
        elif len(inputSize) == 3:
            inputChannels = inputSize[0]
            inputHeight = inputSize[1]
            inputWidth = inputSize[2]
        else:
            raise Exception('inputSize can only be 1 or 3 dimensions. It was: {}'.format(inputSize))

        self.delay = torch.nn.Parameter(torch.rand((inputChannels, inputHeight, inputWidth)), requires_grad=True)
        self.Ts = Ts

    def forward(self, input):
        N, C, H, W, Ns = input.shape
        if input.numel() != self.delay.numel() * input.shape[-1] * input.shape[0]:
            return _delayFunction.apply(input, self.delay.repeat((1, H, W)), self.Ts)  # different delay per channel
        else:
            return _delayFunction.apply(input, self.delay, self.Ts)  # different delay per neuron


class _delayFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, delay, Ts):
        device = input.device
        dtype = input.dtype
        output = slayerCuda.shift(input.contiguous(), delay.data, Ts)
        Ts = torch.autograd.Variable(torch.tensor(Ts, device=device, dtype=dtype), requires_grad=False)
        ctx.save_for_backward(output, delay.data, Ts)
        return output

    @staticmethod
    def backward(ctx, gradOutput):
        # autograd tested and verified
        (output, delay, Ts) = ctx.saved_tensors
        diffFilter = torch.tensor([-1, 1], dtype=gradOutput.dtype).to(gradOutput.device) / Ts
        outputDiff = slayerCuda.conv(output.contiguous(), diffFilter, 1)
        # the conv operation should not be scaled by Ts;
        # as such, the output is -(x[k+1]/Ts - x[k]/Ts), which is what we want
        gradDelay = torch.sum(gradOutput * outputDiff, [0, -1], keepdim=True).reshape(gradOutput.shape[1:-1]) * Ts
        # no minus needed here: it is included in diffFilter, which is -1 * [1, -1]
        return slayerCuda.shift(gradOutput.contiguous(), -delay, Ts), gradDelay, None


class _delayFunctionNoGradient(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, delay, Ts=1):
        device = input.device
        dtype = input.dtype
        output = slayerCuda.shift(input.contiguous(), delay, Ts)
        Ts    = torch.autograd.Variable(torch.tensor(Ts,    device=device, dtype=dtype), requires_grad=False)
        delay = torch.autograd.Variable(torch.tensor(delay, device=device, dtype=dtype), requires_grad=False)
        ctx.save_for_backward(delay, Ts)
        return output

    @staticmethod
    def backward(ctx, gradOutput):
        (delay, Ts) = ctx.saved_tensors
        # the gradient is simply shifted back in time; no gradient flows to the delay itself
        return slayerCuda.shift(gradOutput.contiguous(), -delay, Ts), None, None
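
# Usage sketch (illustrative, not part of the library): wrapping the public
# spikeLayer API above into a trainable network. `netParams` is assumed to be a
# yamlParams-style descriptor with 'neuron' and 'simulation' sections; the input
# shape (2, 34, 34) and layer sizes are hypothetical.
#
# class Network(torch.nn.Module):
#     def __init__(self, netParams):
#         super(Network, self).__init__()
#         self.slayer = spikeLayer(netParams['neuron'], netParams['simulation'])
#         self.fc1 = self.slayer.dense((34, 34, 2), 512)
#         self.fc2 = self.slayer.dense(512, 10)
#
#     def forward(self, spikeInput):
#         spike1 = self.slayer.spike(self.slayer.psp(self.fc1(spikeInput)))
#         spike2 = self.slayer.spike(self.slayer.psp(self.fc2(spike1)))
#         return spike2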