Source code for netpyne.analysis.spikes

"""
Module for analysis of spiking-related results

"""

try:
    to_unicode = unicode
except NameError:
    to_unicode = str
try:
    basestring
except NameError:
    basestring = str

import pandas as pd
import numpy as np
from numbers import Number
from .utils import exception, syncMeasure
from .tools import getInclude, getSpktSpkid
from .tools import saveData as saveFigData
from ..support.scalebar import add_scalebar
from ..specs import Dict


[docs] @exception def prepareSpikeData( include=['allCells'], sim=None, timeRange=None, maxSpikes=1e8, orderBy='gid', popRates=True, syncLines=True, saveData=False, fileName=None, fileDesc=None, fileType=None, fileDir=None, colorbyPhase=None, **kwargs ): """ Function to prepare data for creating spike-related plots """ print('Preparing spike data...') if not sim: from .. import sim # Replace 'eachPop' with list of pops if 'eachPop' in include: include.remove('eachPop') popLabels = [pop for pop in sim.net.allPops] for popLabel in popLabels: include.append(popLabel) # Select cells to include cells, cellGids, netStimLabels = getInclude(include) popLabels = [] if cells: df = pd.DataFrame.from_records(cells) df = pd.concat([df.drop('tags', axis=1), pd.DataFrame.from_records(df['tags'].tolist())], axis=1) keep = ['pop', 'gid', 'conns'] # if orderBy property doesn't exist or is not numeric, use gid if isinstance(orderBy, basestring) and orderBy not in cells[0]['tags']: orderBy = 'gid' elif orderBy == 'pop': df['popInd'] = df['pop'].astype('category') df['popInd'] = df['popInd'].cat.set_categories(sim.net.pops.keys()) orderBy = 'popInd' elif isinstance(orderBy, basestring) and not isinstance(cells[0]['tags'][orderBy], Number): orderBy = 'gid' if isinstance(orderBy, list): if 'pop' in orderBy: df['popInd'] = df['pop'].astype('category') df['popInd'] = df['popInd'].cat.set_categories(sim.net.pops.keys()) orderBy[orderBy.index('pop')] = 'popInd' keep = keep + list(set(orderBy) - set(keep)) elif orderBy not in keep: keep.append(orderBy) df = df[keep] # preserves original ordering: popLabels = [pop for pop in sim.net.allPops if pop in df['pop'].unique()] if netStimLabels: popLabels.append('NetStims') if len(cellGids) > 0: try: sel, spkts, spkgids = getSpktSpkid( cellGids=[] if include == ['allCells'] else cellGids, timeRange=timeRange ) # using [] is faster for all cells except: import sys print((sys.exc_info())) spkgids, spkts = [], [] sel = pd.DataFrame(columns=['spkt', 'spkid']) df.set_index('gid', inplace=True) # Order by if cells and len(df) > 0: ylabelText = 'Cells (ordered by %s)' % (orderBy) df = df.sort_values(by=orderBy) sel['spkind'] = sel['spkid'].apply(df.index.get_loc) else: sel = pd.DataFrame(columns=['spkt', 'spkid', 'spkind']) ylabelText = '' # Add NetStim spikes numCellSpks = len(sel) numNetStims = 0 for netStimLabel in netStimLabels: # print(netStimLabel) # stims = sim.allSimData['stims'].items() # print(stims) # netStimSpks = [spk for cell, stims in sim.allSimData['stims'].items() for stimLabel, stimSpks in stims.items() for spk in stimSpks if stimLabel == netStimLabel] # print(netStimSpks) netStimSpks = [ stimSpks for cell, stims in sim.allSimData['stims'].items() for stimLabel, stimSpks in stims.items() if stimLabel == netStimLabel ] if len(netStimSpks) > 0: lastInd = sel['spkind'].max() if len(sel['spkind']) > 0 else -1 # to start with 0, valid for 'allNetStims' spkindsNew = [lastInd + 1 + i for i in range(len(netStimSpks))] spikesNew = [ (spkt, spkindsNew[icell_stim]) for icell_stim, cell_stim in enumerate(netStimSpks) for spkt in cell_stim ] ns = pd.DataFrame(spikesNew, columns=['spkt', 'spkind']) # ns['spkgidColor'] = [popColors['NetStims']]*len(spikesNew) ns['pop'] = ['NetStims'] * len(spikesNew) sel = pd.concat([sel, ns]) numNetStims += len(spkindsNew) # Calculating Phase of spikes if isinstance(colorbyPhase,dict): try: Signal = colorbyPhase['signal'] except: print("Importing signal for coloring spikes - No information about the signal") Signal = 'LFP' if isinstance(Signal, basestring) or isinstance(Signal,np.ndarray): from scipy import signal, stats, interpolate from scipy.signal import hilbert from math import fmod, pi rawSignal = [] if isinstance(Signal, basestring): # signal provided as a list packed in a pkl if Signal.endswith('.pkl'): import pickle with open(Signal, 'rb') as input_file: rawSignal = pickle.load(input_file) try: fs = colorbyPhase['fs'] except: print("Importing signal for coloring spikes - No frequency sampling provided") fs = 1000.0 # assumes data sampled in ms time = np.linspace(0,len(rawSignal)*1000/fs,len(rawSignal)+1) # in milliseconds elif Signal == 'LFP': try: electrode = colorbyPhase['electrode'] if electrode > sim.net.recXElectrode.nsites: print('Wrong electrode number for coloring spikes according to LFP phase - Assigning first element in LFP recording setup') electrode = 1 except: electrode = 1 rawSignal = [sim.allSimData['LFP'][n][electrode-1] for n in range(len(sim.allSimData['LFP']))] fs = 1000.0/sim.cfg.recordStep time = np.linspace(0,len(rawSignal)*1000/fs,len(rawSignal)+1) # in milliseconds else: print('No signal recovered to color spikes according to its phase') # it is an array else: rawSignal = Signal.tolist() try: fs = colorbyPhase['fs'] except: print("Importing signal for coloring spikes - No frequency sampling provided") fs = 1000.0 # assumes data sampled in ms time = np.linspace(0,len(rawSignal)*1000/fs,len(rawSignal)+1) # in milliseconds # Processing rawSignal rawSignal.append(2*rawSignal[-1]-rawSignal[-2]) # extrapolation for the last element rawSignal = stats.zscore(np.float32(rawSignal)) rawSignal_ = np.r_[rawSignal[-1::-1],rawSignal,rawSignal[-1::-1]] # Reflect signal to minimize edge artifacts nyquist = fs/2.0 # parameters provided to filter the signal - setting defaults otherwise try: filtOrder = colorbyPhase['filtOrder'] except: filtOrder = 3 try: filtFreq = colorbyPhase['filtFreq'] except: filtFreq = [1,500] if isinstance(filtFreq, list): # bandpass Wn = [filtFreq[0]/nyquist, filtFreq[1]/nyquist] b, a = signal.butter(filtOrder, Wn, btype='bandpass') elif isinstance(filtFreq, Number): # lowpass Wn = filtFreq/nyquist b, a = signal.butter(filtOrder, Wn) rawSignalFiltered_ = signal.filtfilt(b, a, rawSignal_) analytic_signal = hilbert(rawSignalFiltered_)[len(rawSignal):-len(rawSignal)] amplitude_envelope = np.abs(analytic_signal) instantaneous_phase = np.unwrap(np.angle(analytic_signal)) instantaneous_phase_mod = [(fmod(instantaneous_phase[nn]+pi,2*pi)-pi)*(180/pi) for nn in range(len(instantaneous_phase))] instantaneous_phase = np.r_[instantaneous_phase_mod] f_Rhythm = interpolate.interp1d(time,instantaneous_phase) sel['spkPhase'] = sel['spkt'].apply(f_Rhythm) else: print('No signal recovered to color spikes according to its phase') if len(cellGids) > 0 and numNetStims: ylabelText = ylabelText + ' and NetStims (at the end)' elif numNetStims: ylabelText = ylabelText + 'NetStims' if numCellSpks + numNetStims == 0: print('No spikes available to plot raster') return None # Time Range if timeRange == [0, sim.cfg.duration]: pass elif timeRange is None: timeRange = [0, sim.cfg.duration] else: sel = sel.query('spkt >= @timeRange[0] and spkt <= @timeRange[1]') # Limit to maxSpikes if len(sel) > maxSpikes: print((' Showing only the first %i out of %i spikes' % (maxSpikes, len(sel)))) # Limit num of spikes if numNetStims: # sort first if have netStims sel = sel.sort_values(by='spkt') sel = sel.iloc[:maxSpikes] timeRange[1] = sel['spkt'].max() # Calculate plot statistics totalSpikes = len(sel) numCells = len(cells) firingRate = ( float(totalSpikes) / (numCells + numNetStims) / (timeRange[1] - timeRange[0]) * 1e3 if totalSpikes > 0 else 0 ) if cells: gidPops = df['pop'].tolist() conns = df['conns'].tolist() popNumCells = [float(gidPops.count(pop)) for pop in popLabels] if numCellSpks else [0] * len(popLabels) cellNumConns = [len(conn) for conn in conns] popNumConns = [ sum([cellNumConn for cellIndex, cellNumConn in enumerate(cellNumConns) if gidPops[cellIndex] == pop]) for pop in popLabels ] totalConnections = sum([len(conns) for conns in df['conns']]) connsPerCell = totalConnections / float(numCells) if numCells > 0 else 0 popConnsPerCell = [ popNumConns[popIndex] / popNumCells[popIndex] if popNumConns[popIndex] > 0 else 0 for popIndex, pop in enumerate(popLabels) ] # popConnsPerCell = [popNumConns[popIndex]/popNumCells[popIndex] for popIndex, pop in enumerate(popLabels)] if numNetStims: try: popNumCells[-1] = numNetStims except: popNumCells = [numNetStims] title = 'Raster plot of spiking' legendLabels = [] # Add population spiking info to plot if popRates: avgRates = {} tsecs = (timeRange[1] - timeRange[0]) / 1e3 for i, (pop, popNum) in enumerate(zip(popLabels, popNumCells)): if numCells > 0 and pop != 'NetStims': if numCellSpks == 0: avgRates[pop] = 0 else: avgRates[pop] = ( len( [ spkid for spkid in sel['spkind'].iloc[: numCellSpks - 1] if df['pop'].iloc[int(spkid)] == pop ] ) / popNum / tsecs ) if numNetStims: avgRates['NetStims'] = len([spkid for spkid in sel['spkind'].iloc[numCellSpks:]]) / numNetStims / tsecs if numCells > 0: if popRates == 'minimal': legendLabels = [ popLabel + ' (%.3g Hz)' % (avgRates[popLabel]) for popIndex, popLabel in enumerate(popLabels) if popLabel in avgRates ] title = 'cells: %i syn/cell: %0.1f rate: %0.1f Hz' % (numCells, connsPerCell, firingRate) else: legendLabels = [ popLabel + '\n cells: %i\n syn/cell: %0.1f\n rate: %.3g Hz' % (popNumCells[popIndex], popConnsPerCell[popIndex], avgRates[popLabel]) for popIndex, popLabel in enumerate(popLabels) if popLabel in avgRates ] title = 'cells: %i syn/cell: %0.1f rate: %0.1f Hz' % (numCells, connsPerCell, firingRate) if syncLines: title = '%s sync=%0.2f' % (title, syncMeasure()) if 'title' in kwargs: title = kwargs['title'] axisArgs = {'xlabel': 'Time (ms)', 'ylabel': ylabelText, 'title': title} spikeData = { 'spkTimes': sel['spkt'].tolist(), 'spkInds': sel['spkind'].tolist(), 'spkGids': sel['spkid'].tolist(), 'popNumCells': popNumCells, 'popLabels': popLabels, 'numNetStims': numNetStims, 'include': include, 'timeRange': timeRange, 'maxSpikes': maxSpikes, 'orderBy': orderBy, 'axisArgs': axisArgs, 'legendLabels': legendLabels, 'cellGids': df['pop'].index.tolist(), 'cellPops': df['pop'].tolist(), } if colorbyPhase: spikeData.update({'spkPhases': sel['spkPhase'].tolist()}) if 'include_signal' in colorbyPhase and colorbyPhase['include_signal']==True: spikeData.update({'signal': analytic_signal}) spikeData.update({'time': time}) if saveData: saveFigData(spikeData, fileName=fileName, fileDesc='spike_data', fileType=fileType, fileDir=fileDir, sim=sim) return spikeData
[docs] @exception def prepareRaster( include=['allCells'], sim=None, timeRange=None, maxSpikes=1e8, orderBy='gid', popRates=True, colorbyPhase=None, saveData=False, fileName=None, fileDesc=None, fileType=None, fileDir=None, **kwargs ): """ Function to prepare data for creating a raster plot """ figData = prepareSpikeData( include=include, sim=sim, timeRange=timeRange, maxSpikes=maxSpikes, orderBy=orderBy, popRates=popRates, colorbyPhase=colorbyPhase, saveData=saveData, fileName=fileName, fileDesc=fileDesc if fileDesc else 'raster_data', fileType=fileType, fileDir=fileDir, **kwargs ) return figData
[docs] @exception def prepareSpikeHist( include=['allCells', 'eachPop'], sim=None, timeRange=None, maxSpikes=1e8, popRates=True, saveData=False, fileName=None, fileDesc=None, fileType=None, fileDir=None, binSize=5, **kwargs ): """ Function to prepare data for creating a spike histogram plot """ figData = prepareSpikeData( include=include, sim=sim, timeRange=timeRange, maxSpikes=maxSpikes, orderBy='gid', popRates=popRates, saveData=saveData, fileName=fileName, fileDesc=fileDesc, fileType=fileType, fileDir=fileDir, **kwargs ) figData['axisArgs']['ylabel'] = 'Number of spikes' return figData
# ------------------------------------------------------------------------------ # Calculate and print avg pop rates # ------------------------------------------------------------------------------
[docs] @exception def popAvgRates(tranges=None, show=True): """Function to calculate and return average firing rates by population Parameters ---------- tranges : list or tuple The time range or time ranges to calculate firing rates within **Default:** ``None`` uses the entire simulation time range. **Options:** a single time range is defined in a list (``[startTime, stopTime]``) while multiple time ranges should be a tuple of lists (``([start1, stop1], [start2, stop2])``). show : bool Whether or not to print the population firing rates **Default:** ``True`` """ from .. import sim avgRates = Dict() if not hasattr(sim, 'allSimData') or 'spkt' not in sim.allSimData: print('Error: sim.allSimData not available; please call sim.gatherData()') return None spktsAll = sim.allSimData['spkt'] spkidsAll = sim.allSimData['spkid'] spkidsList, spktsList = [], [] if not isinstance(tranges, list): # True or None tranges = [[0, sim.cfg.duration]] if isinstance(tranges, list): # convert single time interval to list if not isinstance(tranges[0], (list, tuple)): tranges = [tranges] # calculate for multiple time intervals if isinstance(tranges[0], (list, tuple)): for trange in tranges: try: spkids, spkts = list( zip( *[ (spkid, spkt) for spkid, spkt in zip(spkidsAll, spktsAll) if trange[0] <= spkt <= trange[1] ] ) ) except: spkids, spkts = [], [] spkidsList.append(spkids) spktsList.append(spkts) else: return avgRates for pop in sim.net.allPops: if len(tranges) > 1: if show: print(' %s ' % (pop)) avgRates[pop] = {} for spkids, spkts, trange in zip(spkidsList, spktsList, tranges): numCells = float(len(sim.net.allPops[pop]['cellGids'])) if numCells > 0: # single time intervals if len(tranges) == 1: tsecs = float((trange[1] - trange[0])) / 1000.0 avgRates[pop] = ( len([spkid for spkid in spkids if sim.net.allCells[int(spkid)]['tags']['pop'] == pop]) / numCells / tsecs ) if show: print(' %s : %.3f Hz' % (pop, avgRates[pop])) # multiple time intervals else: tsecs = float((trange[1] - trange[0])) / 1000.0 avgRates[pop]['%d_%d' % (trange[0], trange[1])] = ( len([spkid for spkid in spkids if sim.net.allCells[int(spkid)]['tags']['pop'] == pop]) / numCells / tsecs ) if show: print( ' (%d - %d ms): %.3f Hz' % (trange[0], trange[1], avgRates[pop]['%d_%d' % (trange[0], trange[1])]) ) return avgRates