"""
Module for gathering data from nodes after a simulation
"""
import os, pickle
from netpyne.support.recxelectrode import RecXElectrode
import numpy as np
from ..specs import Dict, ODict
from . import setup
# ------------------------------------------------------------------------------
# Gather data from nodes
# ------------------------------------------------------------------------------
[docs]
def gatherData(gatherLFP=True, gatherDipole=True, gatherOnlySimData=None, includeSimDataEntries=None, analyze=True):
    """
    Gather the data recorded on every node into ``sim.allSimData``, ``sim.net.allCells``
    and ``sim.net.allPops``, and optionally print summary statistics.

    Parameters
    ----------
    gatherLFP : bool
        Whether to sum the ``'LFP'`` and ``'LFPPops'`` entries across nodes and merge the
        recording electrode transfer resistances.
        **Default:** ``True``

    gatherDipole : bool
        Whether to sum the ``'dipoleSum'`` entry across nodes.
        **Default:** ``True``

    gatherOnlySimData : bool
        If true, only ``simData`` is exchanged between nodes (cells and population gids are not).
        **Default:** ``None`` uses ``sim.cfg.gatherOnlySimData``.

    includeSimDataEntries : list
        Keys of ``sim.simData`` to exchange when running on more than one node; also
        filters the single-node vectors (``'t'``).
        **Default:** ``None`` gathers every entry.

    analyze : bool
        Whether the root node computes and prints summary statistics after gathering.
        **Default:** ``True``

    Returns
    -------
    Dict
        ``sim.allSimData`` on the root node (rank 0); ``None`` on every other node.
    """

    from .. import sim

    sim.timing('start', 'gatherTime')

    ## Pack data from all hosts
    if sim.rank == 0:
        print('\nGathering data...')

    # flag to avoid saving cell data and cell gids per populations (saves gather time and space; cannot inspect cells and pops)
    if gatherOnlySimData is None:
        gatherOnlySimData = sim.cfg.gatherOnlySimData

    # flag to avoid saving sections data for each cell (saves gather time and space; cannot inspect cell secs or re-simulate)
    if not sim.cfg.saveCellSecs:
        for cell in sim.net.cells:
            cell.secs = {}
            cell.secLists = None

    # flag to avoid saving conns data for each cell (saves gather time and space; cannot inspect cell conns or re-simulate)
    if not sim.cfg.saveCellConns:
        for cell in sim.net.cells:
            cell.conns = []
    # Store conns in a compact list format instead of a long dict format (cfg.compactConnFormat contains list of keys to include)
    elif sim.cfg.compactConnFormat:
        sim.compactConnFormat()

    # remove data structures used to calculate LFP or Dipoles
    if (
        ((gatherLFP and sim.cfg.recordLFP) or (gatherDipole and sim.cfg.recordDipole))
        and hasattr(sim.net, 'compartCells')
        and sim.cfg.createNEURONObj
    ):
        for cell in sim.net.compartCells:
            # delete each attribute independently so that one missing attribute
            # does not leave the remaining ones behind
            for attrName in ('imembVec', 'imembPtr', '_segCoords'):
                try:
                    delattr(cell, attrName)
                except AttributeError:
                    pass
        for pop in list(sim.net.pops.values()):
            try:
                del pop._morphSegCoords
            except AttributeError:
                pass

    # simData entries holding vectors (possibly nested in dicts); 'dipole' is always listed
    simDataVecs = ['spkt', 'spkid', 'stims', 'dipole'] + list(sim.cfg.recordTraces.keys())
    if sim.cfg.recordDipolesHNN:
        _aggregateDipoles()

    singleNodeVecs = ['t']
    if includeSimDataEntries:
        singleNodeVecs = [v for v in singleNodeVecs if v in includeSimDataEntries]

    if sim.nhosts > 1:  # only gather if >1 nodes
        netPopsCellGids = {popLabel: list(pop.cellGids) for popLabel, pop in sim.net.pops.items()}

        if includeSimDataEntries:
            simData = {key: sim.simData[key] for key in includeSimDataEntries}
        else:
            simData = sim.simData

        if gatherOnlySimData:
            nodeData = {'simData': simData}
        else:
            nodeData = {
                'simData': simData,
                'netCells': [c.__getstate__() for c in sim.net.cells],
                'netPopsCellGids': netPopsCellGids,
            }
        if gatherLFP and hasattr(sim.net, 'recXElectrode'):
            nodeData['xElectrodeTransferResistances'] = sim.net.recXElectrode.transferResistances

        # only slot 0 is filled, i.e. every node sends its data to the root node
        data = [None] * sim.nhosts
        data[0] = {}
        for k, v in nodeData.items():
            data[0][k] = v
        gather = sim.pc.py_alltoall(data)
        sim.pc.barrier()

        if sim.rank == 0:
            sim.allSimData = Dict()
            if gatherOnlySimData:
                print(' Gathering only sim data...')
                # However, still need to ensure these lists aren't empty to avoid errors during analyzing/plotting
                # TODO: this should be improved by handling abovementioned errors in analyzing/plotting instead of adding this workaround here
                allCells = getattr(sim.net, 'allCells', [])
                if len(allCells) == 0:
                    sim.net.allCells = [c.__getstate__() for c in sim.net.cells]
                allPops = getattr(sim.net, 'allPops', ODict())
                if len(allPops) == 0:
                    sim.net.allPops = allPops
                    for popLabel, pop in sim.net.pops.items():
                        sim.net.allPops[popLabel] = pop.__getstate__()  # can't use dict comprehension for OrderedDict
            else:
                allCells = []
                allPops = ODict()
                for popLabel, pop in sim.net.pops.items():
                    allPops[popLabel] = pop.__getstate__()  # can't use dict comprehension for OrderedDict
                allPopsCellGids = {popLabel: [] for popLabel in netPopsCellGids}
            allResistances = {}

            rootSimData = gather[0]['simData']
            for k in list(rootSimData.keys()):  # initialize all keys of allSimData dict
                if gatherLFP and k == 'LFP':
                    sim.allSimData[k] = np.zeros(rootSimData['LFP'].shape)
                elif gatherLFP and k == 'LFPPops':
                    sim.allSimData[k] = {
                        p: np.zeros(rootSimData['LFP'].shape) for p in rootSimData['LFPPops'].keys()
                    }
                elif gatherDipole and k == 'dipoleSum':
                    sim.allSimData[k] = np.zeros(rootSimData['dipoleSum'].shape)
                elif sim.cfg.recordDipolesHNN and k == 'dipole':
                    for dk in sim.cfg.recordDipolesHNN:
                        sim.allSimData[k][dk] = np.zeros(len(rootSimData['dipole'][dk]))
                else:
                    sim.allSimData[k] = {}

            for key in singleNodeVecs:  # store single node vectors (eg. 't')
                sim.allSimData[key] = list(nodeData['simData'][key])

            # fill in allSimData taking into account if data is dict of h.Vector (code needs improvement to be more generic)
            for node in gather:  # concatenate data from each node
                if not gatherOnlySimData:
                    allCells.extend(node['netCells'])  # extend allCells list
                    for popLabel, popCellGids in node['netPopsCellGids'].items():
                        allPopsCellGids[popLabel].extend(popCellGids)
                if 'xElectrodeTransferResistances' in node:
                    allResistances.update(node['xElectrodeTransferResistances'])

                for key, val in node['simData'].items():  # update simData dicts of dicts of h.Vector
                    if key in simDataVecs:  # simData dicts that contain Vectors
                        if isinstance(val, dict):
                            for key2, val2 in val.items():
                                if isinstance(val2, dict):
                                    # dicts of dicts of Vectors (eg. ['stims']['cell_1']['background']=h.Vector)
                                    sim.allSimData[key].update(Dict({key2: Dict()}))
                                    for stim, val3 in val2.items():
                                        sim.allSimData[key][key2].update({stim: list(val3)})
                                elif key == 'dipole':
                                    # add together dipole values from each node
                                    sim.allSimData[key][key2] = np.add(sim.allSimData[key][key2], val2.as_numpy())
                                else:
                                    # dicts of Vectors (eg. ['v']['cell_1']=h.Vector)
                                    sim.allSimData[key].update({key2: list(val2)})
                        else:
                            # plain Vectors are concatenated
                            sim.allSimData[key] = list(sim.allSimData[key]) + list(val)
                    elif gatherLFP and key == 'LFP':
                        sim.allSimData[key] += np.array(val)
                    elif gatherLFP and key == 'LFPPops':
                        for p in val:
                            sim.allSimData[key][p] += np.array(val[p])
                    elif gatherDipole and key == 'dipoleSum':
                        sim.allSimData[key] += np.array(val)
                    elif key not in singleNodeVecs:
                        sim.allSimData[key].update(val)  # update simData dicts which are not Vectors

            if len(sim.allSimData['spkt']) > 0:
                # sort spikes by time, keeping spike ids aligned
                sim.allSimData['spkt'], sim.allSimData['spkid'] = zip(
                    *sorted(zip(sim.allSimData['spkt'], sim.allSimData['spkid']))
                )
                sim.allSimData['spkt'] = list(sim.allSimData['spkt'])
                sim.allSimData['spkid'] = list(sim.allSimData['spkid'])

            if not gatherOnlySimData:
                sim.net.allCells = sorted(allCells, key=lambda k: k['gid'])
                for popLabel, pop in allPops.items():
                    pop['cellGids'] = sorted(allPopsCellGids[popLabel])
                sim.net.allPops = allPops
            if gatherLFP and hasattr(sim.net, 'recXElectrode'):
                sim.net.recXElectrode.transferResistances = allResistances

        # clean to avoid mem leaks
        for node in gather:
            if node:
                node.clear()
                del node
        for item in data:
            if item:
                item.clear()
                del item

    else:  # if single node, save data in same format as for multiple nodes for consistency
        if sim.cfg.createNEURONObj:
            sim.net.allCells = [Dict(c.__getstate__()) for c in sim.net.cells]
        else:
            sim.net.allCells = [c.__dict__ for c in sim.net.cells]
        sim.net.allPops = ODict()
        for popLabel, pop in sim.net.pops.items():
            sim.net.allPops[popLabel] = pop.__getstate__()  # can't use dict comprehension for OrderedDict
        sim.allSimData = Dict()
        for k in list(sim.simData.keys()):  # initialize all keys of allSimData dict
            sim.allSimData[k] = Dict()

        vectorKeys = simDataVecs + singleNodeVecs  # built once instead of on every iteration
        for key, val in sim.simData.items():  # update simData dicts of dicts of h.Vector
            if key in vectorKeys:  # simData dicts that contain Vectors
                if isinstance(val, dict):
                    for cell, val2 in val.items():
                        if isinstance(val2, dict):
                            # dicts of dicts of Vectors (eg. ['stims']['cell_1']['background']=h.Vector)
                            sim.allSimData[key].update(Dict({cell: Dict()}))
                            for stim, val3 in val2.items():
                                sim.allSimData[key][cell].update({stim: list(val3)})
                        else:
                            # dicts of Vectors (eg. ['v']['cell_1']=h.Vector)
                            sim.allSimData[key].update({cell: list(val2)})
                else:
                    # plain Vectors
                    sim.allSimData[key] = list(sim.allSimData[key]) + list(val)
            else:
                sim.allSimData[key] = val  # update simData dicts which are not Vectors

    ## Print statistics
    sim.pc.barrier()
    if sim.rank == 0:
        sim.timing('stop', 'gatherTime')
        if sim.cfg.timing:
            print((' Done; gather time = %0.2f s.' % sim.timingData['gatherTime']))

        if analyze:
            print('\nAnalyzing...')
            sim.totalSpikes = len(sim.allSimData['spkt'])
            sim.totalSynapses = sum(len(cell['conns']) for cell in sim.net.allCells)
            if sim.cfg.createPyStruct:
                # connections = distinct presynaptic gids per cell
                if sim.cfg.compactConnFormat:
                    preGidIndex = (
                        sim.cfg.compactConnFormat.index('preGid') if 'preGid' in sim.cfg.compactConnFormat else 0
                    )
                    sim.totalConnections = sum(
                        len({conn[preGidIndex] for conn in cell['conns']}) for cell in sim.net.allCells
                    )
                else:
                    sim.totalConnections = sum(
                        len({conn['preGid'] for conn in cell['conns']}) for cell in sim.net.allCells
                    )
            else:
                sim.totalConnections = sim.totalSynapses
            sim.numCells = len(sim.net.allCells)

            # guard against division by zero when no cells were gathered or duration is 0
            if sim.totalSpikes > 0 and sim.numCells > 0 and sim.cfg.duration > 0:
                sim.firingRate = float(sim.totalSpikes) / sim.numCells / sim.cfg.duration * 1e3  # in Hz
            else:
                sim.firingRate = 0
            if sim.numCells > 0:
                sim.connsPerCell = sim.totalConnections / float(sim.numCells)  # connections per cell
                sim.synsPerCell = sim.totalSynapses / float(sim.numCells)  # synaptic contacts per cell
            else:
                sim.connsPerCell = 0
                sim.synsPerCell = 0

            print((' Cells: %i' % (sim.numCells)))
            print((' Connections: %i (%0.2f per cell)' % (sim.totalConnections, sim.connsPerCell)))
            if sim.totalSynapses != sim.totalConnections:
                print((' Synaptic contacts: %i (%0.2f per cell)' % (sim.totalSynapses, sim.synsPerCell)))

            # NOTE(review): the rate-related outputs below are only produced when a
            # 'runTime' entry exists in sim.timingData -- confirm this nesting is intended
            if 'runTime' in sim.timingData:
                print((' Spikes: %i (%0.2f Hz)' % (sim.totalSpikes, sim.firingRate)))
                print((' Simulated time: %0.1f s; %i workers' % (sim.cfg.duration / 1e3, sim.nhosts)))
                print((' Run time: %0.2f s' % (sim.timingData['runTime'])))

                if sim.cfg.printPopAvgRates and not gatherOnlySimData:
                    trange = sim.cfg.printPopAvgRates if isinstance(sim.cfg.printPopAvgRates, list) else None
                    sim.allSimData['popRates'] = sim.analysis.popAvgRates(tranges=trange)

                if 'plotfI' in sim.cfg.analysis:
                    sim.analysis.calculatefI()  # need to call here so data is saved to file

                sim.allSimData['avgRate'] = sim.firingRate  # save firing rate

        return sim.allSimData
# ------------------------------------------------------------------------------
# Gather data from files
# ------------------------------------------------------------------------------
[docs]
def gatherDataFromFiles(gatherLFP=True, saveFolder=None, simLabel=None, sim=None, fileType='pkl', saveMerged=False):
    """
    Function to gather data from multiple files (from distributed or interval saving)

    Parameters
    ----------
    gatherLFP : bool
        Whether or not to gather LFP data.
        **Default:** ``True`` gathers LFP data if available.
        **Options:** ``False`` does not gather LFP data.

    saveFolder : str
        Name of the directory where data files are located.
        **Default:** ``None`` uses ``sim.cfg.saveFolder``.

    simLabel : str
        Simulation label; node files are searched in ``<saveFolder>/<simLabel>_node_data``.
        **Default:** ``None`` uses ``sim.cfg.simLabel``.

    sim : object
        The ``sim`` object to populate with the gathered data.
        **Default:** ``None`` imports ``netpyne.sim``.

    fileType : str
        Extension of the node files to read (case-insensitive).
        **Default:** ``'pkl'``
        **Options:** ``'json'``

    saveMerged : bool
        If true, the merged data is saved via ``sim.saveData()`` and, when that reports
        saved files, the node files that were merged are deleted.
        **Default:** ``False``

    Returns
    -------
    bool or None
        ``False`` on the root node when ``fileType`` is unsupported or no node files are
        found; ``None`` otherwise.
    """

    import json

    if not sim:
        from netpyne import sim
    if getattr(sim, 'rank', None) is None:
        sim.initialize()
    sim.timing('start', 'gatherTime')

    def _firstCellGid(popKeyAndValue):
        # sort key for (label, pop) pairs: the assumption while sorting is that
        # populations order corresponds to cell gids in this population
        cellGids = popKeyAndValue[1]['cellGids']
        return cellGids[0] if len(cellGids) > 0 else -1

    if sim.rank == 0:
        # NOTE(review): the early returns below leave the function on rank 0 before the
        # barrier that the other ranks reach -- confirm this cannot stall multi-node runs
        fileType = fileType.lower()
        if fileType not in ['pkl', 'json']:
            print(f"Could not gather data from '.{fileType}' files. Only .pkl and .json are supported so far.")
            return False

        if not simLabel:
            simLabel = sim.cfg.simLabel
        if not saveFolder:
            saveFolder = sim.cfg.saveFolder
        nodeDataDir = os.path.join(saveFolder, simLabel + '_node_data')

        print(f"\nSearching for .{fileType} node files in {nodeDataDir} ...")
        nodeZeroSuffix = f'_node_0.{fileType}'
        simLabels = [f.replace(nodeZeroSuffix, '') for f in os.listdir(nodeDataDir) if f.endswith(nodeZeroSuffix)]
        if len(simLabels) == 0:
            print("Could not gather data from files. No node files found.")
            return False

        mergedFiles = []
        for simLabel in simLabels:
            allSimData = Dict()
            allCells = []
            allPops = ODict()
            allPopsCellGids = {}

            print('\nGathering data from files for simulation: %s ...' % (simLabel))

            simDataVecs = ['spkt', 'spkid', 'stims'] + list(sim.cfg.recordTraces.keys())
            singleNodeVecs = ['t']
            if sim.cfg.recordDipolesHNN:
                _aggregateDipoles()
                simDataVecs.append('dipole')

            # NOTE(review): single-node vectors ('t') are taken from the in-memory
            # sim.simData rather than from the node files -- confirm this is intended
            fileData = {'simData': sim.simData}

            fileList = sorted(
                f for f in os.listdir(nodeDataDir) if f.startswith(simLabel + '_node') and f.endswith(f'.{fileType}')
            )

            for ifile, file in enumerate(fileList):
                print(' Merging data file: %s' % (file))

                with open(os.path.join(nodeDataDir, file), 'rb') as openFile:
                    if fileType == 'pkl':
                        # NOTE: pickle can execute arbitrary code; only load node files from a trusted source
                        data = pickle.load(openFile)
                    elif fileType == 'json':
                        data = json.load(openFile)

                if 'cells' in data:
                    allCells.extend(cell if isinstance(cell, Dict) else Dict(cell) for cell in data['cells'])

                # 'pops' is optional in a node file; fall back to no populations
                nodePops = data.get('pops', {})
                loadedPops = list(nodePops.items())
                if fileType == 'json':
                    # if populations order is not preserved (which is inherently the case for JSON), need to sort them again
                    loadedPops.sort(key=_firstCellGid)
                for popLabel, pop in loadedPops:
                    allPops[popLabel] = pop['tags']

                if 'simConfig' in data:
                    setup.setSimCfg(data['simConfig'])

                if 'net' in data and gatherLFP:
                    if 'recXElectrode' in data['net']:
                        xElectrode = data['net']['recXElectrode']
                        if not isinstance(xElectrode, RecXElectrode):
                            xElectrode = RecXElectrode.fromJSON(xElectrode)
                        sim.net.recXElectrode = xElectrode

                nodePopsCellGids = {popLabel: list(pop['cellGids']) for popLabel, pop in nodePops.items()}

                # the first file fixes the shape of the LFP accumulators
                if ifile == 0 and gatherLFP and 'LFP' in data['simData']:
                    lfpData = data['simData']['LFP']
                    if not isinstance(lfpData, np.ndarray):
                        lfpData = np.array(lfpData)
                        data['simData']['LFP'] = lfpData
                    allSimData['LFP'] = np.zeros(lfpData.shape)
                    if 'LFPPops' in data['simData']:
                        allSimData['LFPPops'] = {
                            p: np.zeros(lfpData.shape) for p in data['simData']['LFPPops'].keys()
                        }

                for key, value in data['simData'].items():
                    if key in simDataVecs:
                        if isinstance(value, dict):
                            for key2, value2 in value.items():
                                if isinstance(value2, dict):
                                    # dicts of dicts of vectors (eg. stims per cell)
                                    allSimData[key].update(Dict({key2: Dict()}))
                                    for stim, value3 in value2.items():
                                        allSimData[key][key2].update({stim: list(value3)})
                                elif key == 'dipole':
                                    # add together dipole values from each file
                                    allSimData[key][key2] = np.add(allSimData[key][key2], value2.as_numpy())
                                else:
                                    allSimData[key].update({key2: list(value2)})
                        else:
                            # plain vectors are concatenated
                            allSimData[key] = list(allSimData[key]) + list(value)
                    elif gatherLFP and key == 'LFP':
                        allSimData['LFP'] += np.array(value)
                    elif gatherLFP and key == 'LFPPops':
                        for p in value:
                            allSimData['LFPPops'][p] += np.array(value[p])
                    elif key == 'dipoleSum':
                        if key not in allSimData.keys():
                            allSimData[key] = value
                        else:
                            allSimData[key] += value
                    elif key not in singleNodeVecs:
                        allSimData[key].update(value)

                if ifile == 0:
                    for key in singleNodeVecs:
                        allSimData[key] = list(fileData['simData'][key])

                # setdefault also covers populations that first show up in a later file
                for popLabel, popCellGids in nodePopsCellGids.items():
                    allPopsCellGids.setdefault(popLabel, []).extend(popCellGids)

                mergedFiles.append(file)

            if len(allSimData['spkt']) > 0:
                # sort spikes by time, keeping spike ids aligned
                allSimData['spkt'], allSimData['spkid'] = zip(*sorted(zip(allSimData['spkt'], allSimData['spkid'])))
                allSimData['spkt'], allSimData['spkid'] = list(allSimData['spkt']), list(allSimData['spkid'])

            # NOTE(review): with several simulation labels in the folder only the last one
            # remains in sim, yet every merged file is a candidate for deletion below
            sim.allSimData = allSimData
            sim.net.allCells = sorted(allCells, key=lambda k: k['gid'])
            for popLabel, pop in allPops.items():
                pop['cellGids'] = sorted(allPopsCellGids[popLabel])
            sim.net.allPops = allPops

    ## Print statistics
    sim.pc.barrier()
    if sim.rank != 0:
        # NOTE(review): presumably pairs with a barrier reached by rank 0 further on -- confirm
        sim.pc.barrier()
    else:
        sim.timing('stop', 'gatherTime')
        if sim.cfg.timing:
            print((' Done; gather time = %0.2f s.' % sim.timingData['gatherTime']))

        if saveMerged:
            print('\nSaving merged data into single file ...')
            saved = sim.saveData()
            if len(saved) > 0:
                # if single file saved successfully, clean up node data
                for file in mergedFiles:
                    os.remove(os.path.join(nodeDataDir, file))

        print('\nAnalyzing...')
        sim.totalSpikes = len(sim.allSimData['spkt'])
        sim.totalSynapses = sum(len(cell['conns']) for cell in sim.net.allCells)
        if sim.cfg.createPyStruct:
            # connections = distinct presynaptic gids per cell
            if sim.cfg.compactConnFormat:
                preGidIndex = sim.cfg.compactConnFormat.index('preGid') if 'preGid' in sim.cfg.compactConnFormat else 0
                sim.totalConnections = sum(
                    len({conn[preGidIndex] for conn in cell['conns']}) for cell in sim.net.allCells
                )
            else:
                sim.totalConnections = sum(
                    len({conn['preGid'] for conn in cell['conns']}) for cell in sim.net.allCells
                )
        else:
            sim.totalConnections = sim.totalSynapses
        sim.numCells = len(sim.net.allCells)

        # node files may carry spikes but no cells; avoid dividing by zero in that case
        if sim.totalSpikes > 0 and sim.numCells > 0 and sim.cfg.duration > 0:
            sim.firingRate = float(sim.totalSpikes) / sim.numCells / sim.cfg.duration * 1e3
        else:
            sim.firingRate = 0
        if sim.numCells > 0:
            sim.connsPerCell = sim.totalConnections / float(sim.numCells)
            sim.synsPerCell = sim.totalSynapses / float(sim.numCells)
        else:
            sim.connsPerCell = 0
            sim.synsPerCell = 0

        print((' Cells: %i' % (sim.numCells)))
        print((' Connections: %i (%0.2f per cell)' % (sim.totalConnections, sim.connsPerCell)))
        if sim.totalSynapses != sim.totalConnections:
            print((' Synaptic contacts: %i (%0.2f per cell)' % (sim.totalSynapses, sim.synsPerCell)))
        print((' Spikes: %i (%0.2f Hz)' % (sim.totalSpikes, sim.firingRate)))

        # NOTE(review): the rate-related outputs below are only produced when a
        # 'runTime' entry exists in sim.timingData -- confirm this nesting is intended
        if 'runTime' in sim.timingData:
            print((' Simulated time: %0.1f s; %i workers' % (sim.cfg.duration / 1e3, sim.nhosts)))
            print((' Run time: %0.2f s' % (sim.timingData['runTime'])))

            if sim.cfg.printPopAvgRates and not sim.cfg.gatherOnlySimData:
                trange = sim.cfg.printPopAvgRates if isinstance(sim.cfg.printPopAvgRates, list) else None
                sim.allSimData['popRates'] = sim.analysis.popAvgRates(tranges=trange)

            if 'plotfI' in sim.cfg.analysis:
                sim.analysis.calculatefI()

            sim.allSimData['avgRate'] = sim.firingRate
# ------------------------------------------------------------------------------
# Gather tags from cells
# ------------------------------------------------------------------------------
def _gatherAllCellTags():
    """
    Exchange cell tags between all nodes and return them merged in one dict.

    Returns
    -------
    dict
        Maps cell gid to the tags of that cell, for the cells of every node.
    """

    from .. import sim

    # tags of the cells held by this node, keyed by gid
    localTags = {}
    for cell in sim.net.cells:
        localTags[cell.gid] = cell.tags

    # the very same dict object is sent to every node
    outgoing = [localTags for _ in range(sim.nhosts)]
    incoming = sim.pc.py_alltoall(outgoing)
    sim.pc.barrier()

    allCellTags = {}
    for nodeTags in incoming:
        allCellTags.update(nodeTags)

    # empty the exchange buffers once merged, to avoid mem leaks
    for received in incoming:
        if received:
            received.clear()
    for sent in outgoing:
        if sent:
            sent.clear()

    return allCellTags
# ------------------------------------------------------------------------------
# Gather presynaptic gids of the connections of each cell
# ------------------------------------------------------------------------------
def _gatherAllCellConnPreGids():
    """
    Exchange between all nodes the presynaptic gids of each cell's connections.

    Returns
    -------
    dict
        Maps cell gid to the list of ``'preGid'`` values of that cell's conns,
        for the cells of every node.
    """

    from .. import sim

    # presynaptic gids of the cells held by this node, keyed by cell gid
    localPreGids = {}
    for cell in sim.net.cells:
        localPreGids[cell.gid] = [conn['preGid'] for conn in cell.conns]

    # the very same dict object is sent to every node
    outgoing = [localPreGids for _ in range(sim.nhosts)]
    incoming = sim.pc.py_alltoall(outgoing)
    sim.pc.barrier()

    allCellConnPreGids = {}
    for nodePreGids in incoming:
        allCellConnPreGids.update(nodePreGids)

    # empty the exchange buffers once merged, to avoid mem leaks
    for received in incoming:
        if received:
            received.clear()
    for sent in outgoing:
        if sent:
            sent.clear()

    return allCellConnPreGids
# ------------------------------------------------------------------------------
# Gather cells from nodes (updates sim.net.allCells)
# ------------------------------------------------------------------------------
def _gatherCells():
    """
    Rebuild ``sim.net.allCells`` from the cells currently held by each node.

    On a single node the list is built locally; with several nodes every node sends
    its cell states to the root node, which stores them sorted by gid.
    """

    from .. import sim

    if sim.rank == 0:
        print('\nUpdating sim.net.allCells...')

    localCells = [c.__getstate__() for c in sim.net.cells]

    if not sim.nhosts > 1:
        # single node: same format as for multiple nodes, for consistency
        sim.net.allCells = localCells
        return

    ## Pack data from all hosts: only slot 0 is filled, so data goes to the root node
    outgoing = [{'netCells': localCells}] + [None] * (sim.nhosts - 1)
    incoming = sim.pc.py_alltoall(outgoing)
    sim.pc.barrier()

    if sim.rank == 0:
        gatheredCells = []
        for nodeData in incoming:  # concatenate cells from each node
            gatheredCells.extend(nodeData['netCells'])
        gatheredCells.sort(key=lambda cellState: cellState['gid'])
        sim.net.allCells = gatheredCells

    # empty the exchange buffers, to avoid mem leaks
    for received in incoming:
        if received:
            received.clear()
    for sent in outgoing:
        if sent:
            sent.clear()
# ------------------------------------------------------------------------------
# Aggregate dipole data for each cell on nodes
# ------------------------------------------------------------------------------
def _aggregateDipoles():
    """
    Sum, on this node, the recorded dipole of each cell into ``sim.simData['dipole']``.

    One vector is created per key of ``sim.cfg.recordDipolesHNN``; a cell contributes
    to a key when its ``'pop'`` tag is contained in the value configured for that key.
    """

    from .. import sim

    if not hasattr(sim.net, 'compartCells'):
        compartCells = []
        for c in sim.net.cells:
            if type(c) is sim.CompartCell:
                compartCells.append(c)
        sim.net.compartCells = compartCells

    # start from one fresh vector per dipole key
    for dipoleKey in sim.cfg.recordDipolesHNN:
        vecSize = (sim.cfg.duration / sim.cfg.recordStep) + 1
        sim.simData['dipole'][dipoleKey] = sim.h.Vector(vecSize)

    for cell in sim.net.compartCells:
        if not hasattr(cell, 'dipole'):
            continue  # cell has no dipole recording
        for dipoleKey, pops in sim.cfg.recordDipolesHNN.items():
            if cell.tags['pop'] not in pops:
                continue
            sim.simData['dipole'][dipoleKey].add(cell.dipole['hRec'])