"""
Module for analyzing and plotting connectivity-related results
"""
try:
basestring
except NameError:
basestring = str
from netpyne import __gui__
if __gui__:
import matplotlib.pyplot as plt
import numpy as np
from numbers import Number
from math import ceil
from .utils import colorList, exception, _roundFigures, getCellsInclude, getCellsIncludeTags
from .utils import _saveFigData, _showFigure
# -------------------------------------------------------------------------------------------------------------------
## Support function for plotConn() - calculate conn using data from sim object
# -------------------------------------------------------------------------------------------------------------------
def _plotConnCalculateFromSim(
    includePre,
    includePost,
    feature,
    orderBy,
    groupBy,
    groupByIntervalPre,
    groupByIntervalPost,
    synOrConn,
    synMech,
    removeWeightNorm,
    logPlot=False,
):
    """
    Calculate a connectivity matrix using the cells returned by ``getCellsInclude``.

    Parameters
    ----------
    includePre : list
        Selection of presynaptic cells, passed to ``getCellsInclude``.
    includePost : list
        Selection of postsynaptic cells, passed to ``getCellsInclude``; when equal
        to ``includePre`` the presynaptic result is reused.
    feature : str
        One of ``'weight'``, ``'delay'``, ``'numConns'`` (any ``groupBy``) or
        ``'probability'``, ``'strength'``, ``'convergence'``, ``'divergence'``
        (any ``groupBy`` other than ``'cell'``).
    orderBy : str
        Only used with ``groupBy='cell'``: tag used to order rows/columns. Falls
        back to ``'gid'`` if the tag is absent or non-numeric in the first cell.
    groupBy : str
        ``'cell'``, ``'pop'`` or the name of a numeric cell tag (e.g. ``'y'``).
    groupByIntervalPre, groupByIntervalPost : int or float
        Bin width for pre/postsynaptic cells when grouping by a numeric tag.
    synOrConn : str
        ``'syn'`` uses every entry of ``cell['conns']``; any other value keeps
        only the first entry per presynaptic gid.
    synMech : str or list
        If given, only conns whose synMech is in this list are used.
    removeWeightNorm : bool
        If True, weights are divided by the ``weightNorm`` of the target segment
        (``'cell'`` and ``'pop'`` grouping only; numeric-tag grouping uses raw weights).
    logPlot : bool
        If True, ``log10`` is applied to weight/delay (``groupBy='cell'``) or
        weight/strength (other groupings).

    Returns
    -------
    tuple
        ``(connMatrix, pre, post)`` where ``pre``/``post`` are the cells, pops or
        group lower bounds indexing rows/columns; ``(None, None, None)`` on error.
    """

    from collections import Counter

    # params validation
    if groupBy == 'cell':
        supportedFeatures = ['weight', 'delay', 'numConns']
    else:
        supportedFeatures = ['weight', 'delay', 'numConns', 'probability', 'strength', 'convergence', 'divergence']
    if feature not in supportedFeatures:
        print(
            f' Unsupported feature "{feature}". Conn matrix with groupBy="{groupBy}" only supports features in {supportedFeatures}'
        )
        return None, None, None

    from .. import sim

    def list_of_dict_unique_by_key(seq, key):
        """Keep only the first element of seq for each distinct value of x[key]."""
        seen = set()
        seen_add = seen.add
        try:
            return [x for x in seq if x[key] not in seen and not seen_add(x[key])]
        except Exception:
            print(' Error calculating list of dict unique by key...')
            return []

    weightNormErrors = [0]  # boxed counter so the nested helper can update it

    def value_without_weight_norm(cell, conn, index):
        """Return conn[index] divided by the weightNorm of its target segment, or None on failure."""
        try:
            sec = conn['sec']
            loc = conn['loc']
            nseg = cell['secs'][sec]['geom']['nseg']
            # clamp at 0: a raw index of -1 would silently pick the last segment
            segIndex = max(int(round(loc * nseg)) - 1, 0)
            return conn[index] / cell['secs'][sec]['weightNorm'][segIndex]
        except Exception:
            # best effort: the conn is still counted but contributes no weight
            weightNormErrors[0] += 1
            return None

    def cells_per_pop(pops, cells, netStimPops):
        """Number of cells in each pop; NetStim pops get the placeholder -1."""
        counts = Counter(cell['tags']['pop'] for cell in cells)
        return {pop: (-1 if pop in netStimPops else counts[pop]) for pop in pops}

    def max_conn_matrices(pres, posts, indsPre, indsPost, numPre, numPost):
        """Build matrices with max possible conns, and pre / post cell counts, for each pre-post pair."""
        shape = (len(pres), len(posts))
        maxConn, maxPre, maxPost = np.zeros(shape), np.zeros(shape), np.zeros(shape)
        for preKey in pres:
            for postKey in posts:
                if numPre[preKey] == -1:
                    # -1 placeholder (NetStim pop): replaced by the size of the first post group visited
                    numPre[preKey] = numPost[postKey]
                row, col = indsPre[preKey], indsPost[postKey]
                maxConn[row, col] = numPre[preKey] * numPost[postKey]
                maxPre[row, col] = numPre[preKey]
                maxPost[row, col] = numPost[postKey]
        return maxConn, maxPre, maxPost

    def conns_of(cell):
        """Conns of a postsynaptic cell after applying synOrConn and synMech filters."""
        if synOrConn == 'syn':
            cellConns = cell['conns']  # include all synapses
        else:
            cellConns = list_of_dict_unique_by_key(cell['conns'], preGidIndex)
        if synMech:
            cellConns = [conn for conn in cellConns if conn[synMechIndex] in synMech]
        return cellConns

    # adapt indices/keys based on compact vs long conn format
    if sim.cfg.compactConnFormat:
        connsFormat = sim.cfg.compactConnFormat

        # set indices of fields to read compact format (no keys)
        missing = [field for field in ('preGid', 'synMech', 'weight', 'delay') if field not in connsFormat]
        if missing:
            print(" Error: cfg.compactConnFormat missing:")
            print(missing)
            return None, None, None
        preGidIndex = connsFormat.index('preGid')
        synMechIndex = connsFormat.index('synMech')
        weightIndex = connsFormat.index('weight')
        delayIndex = connsFormat.index('delay')
        preLabelIndex = connsFormat.index('preLabel') if 'preLabel' in connsFormat else -1
    else:
        # using long conn format (dict)
        preGidIndex = 'preGid'
        synMechIndex = 'synMech'
        weightIndex = 'weight'
        delayIndex = 'delay'
        preLabelIndex = 'preLabel'

    # Calculate pre and post cells involved
    cellsPre, cellGidsPre, netStimPopsPre = getCellsInclude(includePre)
    if includePre == includePost:
        cellsPost, cellGidsPost, netStimPopsPost = cellsPre, cellGidsPre, netStimPopsPre
    else:
        cellsPost, cellGidsPost, netStimPopsPost = getCellsInclude(includePost)

    if isinstance(synMech, basestring):
        synMech = [synMech]  # make sure synMech is a list

    # Calculate matrix if grouped by cell
    if groupBy == 'cell':
        connMatrix = np.zeros((len(cellGidsPre), len(cellGidsPost)))
        countMatrix = np.zeros((len(cellGidsPre), len(cellGidsPost)))
        cellIndsPre = {cell['gid']: ind for ind, cell in enumerate(cellsPre)}
        cellIndsPost = {cell['gid']: ind for ind, cell in enumerate(cellsPost)}

        # Order by
        if len(cellsPre) > 0 and len(cellsPost) > 0:
            # if orderBy property doesn't exist or is not numeric, use gid
            if orderBy not in cellsPre[0]['tags'] or orderBy not in cellsPost[0]['tags']:
                orderBy = 'gid'
            elif not isinstance(cellsPre[0]['tags'][orderBy], Number) or not isinstance(
                cellsPost[0]['tags'][orderBy], Number
            ):
                orderBy = 'gid'

            if orderBy == 'gid':
                yorderPre = [cell[orderBy] for cell in cellsPre]
                yorderPost = [cell[orderBy] for cell in cellsPost]
            else:
                yorderPre = [cell['tags'][orderBy] for cell in cellsPre]
                yorderPost = [cell['tags'][orderBy] for cell in cellsPost]

            cellIndsPre = {gid: i for i, (y, gid) in enumerate(sorted(zip(yorderPre, cellGidsPre)))}
            if includePre == includePost:
                cellIndsPost = cellIndsPre
            else:
                cellIndsPost = {gid: i for i, (y, gid) in enumerate(sorted(zip(yorderPost, cellGidsPost)))}

        accumulate = feature in ['weight', 'delay']
        featureIndex = weightIndex if feature == 'weight' else delayIndex

        # Calculate conn matrix
        for cell in cellsPost:  # for each postsyn cell
            for conn in conns_of(cell):
                preGid = conn[preGidIndex]
                if preGid == 'NetStim' or preGid not in cellIndsPre:
                    continue
                row, col = cellIndsPre[preGid], cellIndsPost[cell['gid']]
                if accumulate:
                    if removeWeightNorm and feature == 'weight':
                        value = value_without_weight_norm(cell, conn, featureIndex)
                    else:
                        value = conn[featureIndex]
                    if value is not None:
                        connMatrix[row, col] += value
                countMatrix[row, col] += 1

        if accumulate:
            # pairs without conns give 0/0 = NaN; silence the expected warnings
            with np.errstate(divide='ignore', invalid='ignore'):
                connMatrix = connMatrix / countMatrix
                if logPlot:
                    connMatrix = np.log10(connMatrix)
        elif feature in ['numConns']:
            connMatrix = countMatrix

        pre, post = cellsPre, cellsPost

    # Calculate matrix if grouped by pop
    elif groupBy == 'pop':

        # get list of pops (ordered as in sim.net.allPops, NetStim pops last)
        popsTempPre = {cell['tags']['pop'] for cell in cellsPre}
        popsPre = [pop for pop in sim.net.allPops if pop in popsTempPre] + netStimPopsPre
        popIndsPre = {pop: ind for ind, pop in enumerate(popsPre)}

        if includePre == includePost:
            popsPost = popsPre
            popIndsPost = popIndsPre
        else:
            popsTempPost = {cell['tags']['pop'] for cell in cellsPost}
            popsPost = [pop for pop in sim.net.allPops if pop in popsTempPost] + netStimPopsPost
            popIndsPost = {pop: ind for ind, pop in enumerate(popsPost)}

        # initialize matrices
        if feature in ['weight', 'strength']:
            weightMatrix = np.zeros((len(popsPre), len(popsPost)))
        elif feature == 'delay':
            delayMatrix = np.zeros((len(popsPre), len(popsPost)))
        countMatrix = np.zeros((len(popsPre), len(popsPost)))

        # calculate max num conns per pre and post pair of pops
        numCellsPopPre = cells_per_pop(popsPre, cellsPre, netStimPopsPre)
        if includePre == includePost:
            numCellsPopPost = numCellsPopPre
        else:
            numCellsPopPost = cells_per_pop(popsPost, cellsPost, netStimPopsPost)

        maxConnMatrix, maxPreConnMatrix, maxPostConnMatrix = max_conn_matrices(
            popsPre, popsPost, popIndsPre, popIndsPost, numCellsPopPre, numCellsPopPost
        )

        # dict access, consistent with every other use of cell['gid'] in this module
        preCellPops = {cell['gid']: cell['tags']['pop'] for cell in cellsPre}

        # Calculate conn matrix
        for cell in cellsPost:  # for each postsyn cell
            for conn in conns_of(cell):
                if conn[preGidIndex] == 'NetStim':
                    prePopLabel = conn[preLabelIndex] if preLabelIndex in conn else 'NetStim'
                else:
                    prePopLabel = preCellPops.get(conn[preGidIndex])

                if prePopLabel in popIndsPre:
                    row, col = popIndsPre[prePopLabel], popIndsPost[cell['tags']['pop']]
                    if feature in ['weight', 'strength']:
                        if removeWeightNorm:
                            value = value_without_weight_norm(cell, conn, weightIndex)
                        else:
                            value = conn[weightIndex]
                        if value is not None:
                            weightMatrix[row, col] += value
                    elif feature == 'delay':
                        delayMatrix[row, col] += conn[delayIndex]
                    countMatrix[row, col] += 1

        pre, post = popsPre, popsPost

    # Calculate matrix if grouped by numeric tag (eg. 'y')
    elif (
        sim.net.allCells
        and groupBy in sim.net.allCells[0]['tags']
        and isinstance(sim.net.allCells[0]['tags'][groupBy], Number)
    ):
        if not isinstance(groupByIntervalPre, Number) or not isinstance(groupByIntervalPost, Number):
            print(' groupByIntervalPre or groupByIntervalPost not specified')
            return None, None, None

        if not cellsPre or not cellsPost:
            # min()/max() below would fail on an empty selection
            print(' No pre or post cells found to group by (%s)' % (str(groupBy)))
            return None, None, None

        def group_of(value, interval):
            """Lower bound of the interval-wide bin containing value."""
            return _roundFigures(interval * np.floor(value / interval), 3)

        def group_bounds(cells, interval):
            """Lower bounds of all bins spanned by the cells' groupBy values."""
            values = [cell['tags'][groupBy] for cell in cells]
            minValue = _roundFigures(interval * np.floor(min(values) / interval), 3)
            maxValue = _roundFigures(interval * np.ceil(max(values) / interval), 3)
            return [_roundFigures(x, 3) for x in np.arange(minValue, maxValue, interval)]

        def cells_per_group(groups, cells, interval):
            """Number of cells whose groupBy value falls in each [group, group + interval) bin."""
            return {
                group: sum(1 for cell in cells if group <= cell['tags'][groupBy] < (group + interval))
                for group in groups
            }

        # group cells by 'groupBy' feature (eg. 'y') in intervals of 'groupByInterval')
        groupsPre = group_bounds(cellsPre, groupByIntervalPre)
        if includePre == includePost:
            groupsPost = groupsPre
        else:
            groupsPost = group_bounds(cellsPost, groupByIntervalPost)

        # set indices for pre and post groups
        groupIndsPre = {group: ind for ind, group in enumerate(groupsPre)}
        groupIndsPost = {group: ind for ind, group in enumerate(groupsPost)}

        # initialize matrices
        if feature in ['weight', 'strength']:
            weightMatrix = np.zeros((len(groupsPre), len(groupsPost)))
        elif feature == 'delay':
            delayMatrix = np.zeros((len(groupsPre), len(groupsPost)))
        countMatrix = np.zeros((len(groupsPre), len(groupsPost)))

        # calculate max num conns per pre and post pair of groups
        numCellsGroupPre = cells_per_group(groupsPre, cellsPre, groupByIntervalPre)
        if includePre == includePost:
            numCellsGroupPost = numCellsGroupPre
        else:
            numCellsGroupPost = cells_per_group(groupsPost, cellsPost, groupByIntervalPost)

        maxConnMatrix, maxPreConnMatrix, maxPostConnMatrix = max_conn_matrices(
            groupsPre, groupsPost, groupIndsPre, groupIndsPost, numCellsGroupPre, numCellsGroupPost
        )

        # group of every presynaptic cell, built once instead of scanning cellsPre per conn
        preGroupByGid = {}
        for preCell in cellsPre:
            if preCell['gid'] not in preGroupByGid:
                preGroupByGid[preCell['gid']] = group_of(preCell['tags'][groupBy], groupByIntervalPre)

        # Calculate conn matrix
        for cell in cellsPost:  # for each postsyn cell
            postGroup = group_of(cell['tags'][groupBy], groupByIntervalPost)
            for conn in conns_of(cell):
                preGid = conn[preGidIndex]
                # NetStim inputs are not assigned to any group (maybe add in future);
                # preGroup is reset explicitly so a stale value is never reused
                preGroup = None if preGid == 'NetStim' else preGroupByGid.get(preGid)

                if preGroup in groupIndsPre:
                    row, col = groupIndsPre[preGroup], groupIndsPost[postGroup]
                    if feature in ['weight', 'strength']:
                        weightMatrix[row, col] += conn[weightIndex]
                    elif feature == 'delay':
                        delayMatrix[row, col] += conn[delayIndex]
                    countMatrix[row, col] += 1

        pre, post = groupsPre, groupsPost

    # no valid groupBy
    else:
        print(' groupBy (%s) is not valid' % (str(groupBy)))
        return None, None, None

    if weightNormErrors[0]:
        print(
            ' Warning: could not remove weightNorm for %d connections; their weights were not included'
            % (weightNormErrors[0])
        )

    # normalize by number of postsyn cells
    if groupBy != 'cell':
        # pairs without conns (or without cells) divide by zero; silence the expected warnings
        with np.errstate(divide='ignore', invalid='ignore'):
            if feature == 'weight':
                if logPlot:
                    # avg log weight per conn; NaNs are kept since 0 in log space would read as weight 1
                    connMatrix = np.log10(weightMatrix / countMatrix)
                else:
                    # avg weight per conn; if the count is 0 we get NaNs, but the weight is 0
                    connMatrix = np.nan_to_num(weightMatrix / countMatrix, nan=0)
            elif feature == 'delay':
                # if the count is 0 we get NaNs, but the delay is 0
                connMatrix = np.nan_to_num(delayMatrix / countMatrix, nan=0)
            elif feature == 'numConns':
                connMatrix = countMatrix
            elif feature in ['probability', 'strength']:
                connMatrix = countMatrix / maxConnMatrix  # probability
                if feature == 'strength':
                    if logPlot:
                        connMatrix = np.log10(connMatrix * weightMatrix)  # log strength
                    else:
                        connMatrix = connMatrix * weightMatrix  # strength
            elif feature == 'convergence':
                connMatrix = countMatrix / maxPostConnMatrix
            elif feature == 'divergence':
                connMatrix = countMatrix / maxPreConnMatrix

    return connMatrix, pre, post
# -------------------------------------------------------------------------------------------------------------------
## Support function for plotConn() - calculate conn using data from files with short format (no keys)
# -------------------------------------------------------------------------------------------------------------------
def _plotConnCalculateFromFile(
    includePre,
    includePost,
    feature,
    orderBy,
    groupBy,
    groupByIntervalPre,
    groupByIntervalPost,
    synOrConn,
    synMech,
    connsFile,
    tagsFile,
    removeWeightNorm,
    logPlot=False,
):
    """
    Calculate a connectivity matrix from JSON files holding tags and conns in compact format (no keys).

    Only ``groupBy='pop'`` is implemented; ``orderBy``, ``groupByIntervalPre``,
    ``groupByIntervalPost`` and ``removeWeightNorm`` are accepted but not used.

    Parameters
    ----------
    includePre, includePost : list
        Selections of pre/postsynaptic cells, passed to ``getCellsIncludeTags``.
    feature : str
        ``'weight'``, ``'delay'``, ``'numConns'``, ``'probability'``, ``'strength'``,
        ``'convergence'`` or ``'divergence'``.
    groupBy : str
        Grouping of cells; anything other than ``'pop'`` prints a message and fails.
    synOrConn : str
        ``'syn'`` uses every conn entry of a cell; any other value keeps only the
        first entry per presynaptic gid.
    synMech : str or list
        If given, only conns whose synMech is in this list are used.
    connsFile : str
        Path to a JSON file with a ``'conns'`` entry mapping gid -> list of conns,
        plus a ``'format'`` list naming the conn fields.
    tagsFile : str
        Path to a JSON file with a ``'tags'`` entry mapping gid -> list of tags,
        plus a ``'format'`` list naming the tag fields.
    logPlot : bool
        If True, ``log10`` is applied to the weight / strength matrix.

    Returns
    -------
    tuple
        ``(connMatrix, popsPre, popsPost)``, or ``(None, None, None)`` on error.
    """

    from .. import sim
    import json
    from collections import Counter
    from time import time

    def list_of_dict_unique_by_key(seq, index):
        """Keep only the first element of seq for each distinct value of x[index]."""
        seen = set()
        seen_add = seen.add
        return [x for x in seq if x[index] not in seen and not seen_add(x[index])]

    # load files with tags and conns
    start = time()
    tags, conns = None, None
    tagsFormat, connsFormat = [], []
    if tagsFile:
        print('Loading tags file...')
        with open(tagsFile, 'r') as fileObj:
            tagsTmp = json.load(fileObj)['tags']
        tagsFormat = tagsTmp.pop('format', [])
        tags = {int(k): v for k, v in tagsTmp.items()}  # json keys are always str; gids are ints
        del tagsTmp
    if connsFile:
        print('Loading conns file...')
        with open(connsFile, 'r') as fileObj:
            connsTmp = json.load(fileObj)['conns']
        connsFormat = connsTmp.pop('format', [])
        conns = {int(k): v for k, v in connsTmp.items()}
        del connsTmp

    print('Finished loading; total time (s): %.2f' % (time() - start))

    # find pre and post cells
    if tags and conns:
        cellGidsPre = getCellsIncludeTags(includePre, tags, tagsFormat)
        if includePre == includePost:
            cellGidsPost = cellGidsPre
        else:
            cellGidsPost = getCellsIncludeTags(includePost, tags, tagsFormat)
    else:
        print('Error loading tags and conns from file')
        return None, None, None

    # set indices of fields to read compact format (no keys)
    missing = [] if 'pop' in tagsFormat else ['pop']
    missing.extend(field for field in ('preGid', 'synMech', 'weight', 'delay') if field not in connsFormat)
    if missing:
        print("Missing:")
        print(missing)
        return None, None, None
    popIndex = tagsFormat.index('pop')
    preGidIndex = connsFormat.index('preGid')
    synMechIndex = connsFormat.index('synMech')
    weightIndex = connsFormat.index('weight')
    delayIndex = connsFormat.index('delay')
    preLabelIndex = connsFormat.index('preLabel') if 'preLabel' in connsFormat else -1

    if isinstance(synMech, basestring):
        synMech = [synMech]  # make sure synMech is a list

    # Calculate matrix if grouped by cell
    if groupBy == 'cell':
        print(' plotConn from file for groupBy=cell not implemented yet')
        return None, None, None

    # Calculate matrix if grouped by pop
    elif groupBy == 'pop':

        # get list of pops (netstims not yet supported)
        print(' Obtaining list of populations ...')
        popsPre = list({tags[gid][popIndex] for gid in cellGidsPre})
        popIndsPre = {pop: ind for ind, pop in enumerate(popsPre)}

        if includePre == includePost:
            popsPost = popsPre
            popIndsPost = popIndsPre
        else:
            popsPost = list({tags[gid][popIndex] for gid in cellGidsPost})
            popIndsPost = {pop: ind for ind, pop in enumerate(popsPost)}

        # initialize matrices
        if feature in ['weight', 'strength']:
            weightMatrix = np.zeros((len(popsPre), len(popsPost)))
        elif feature == 'delay':
            delayMatrix = np.zeros((len(popsPre), len(popsPost)))
        countMatrix = np.zeros((len(popsPre), len(popsPost)))

        # calculate max num conns per pre and post pair of pops
        print(' Calculating max num conns for each pair of population ...')
        numCellsPopPre = Counter(tags[gid][popIndex] for gid in cellGidsPre)
        if includePre == includePost:
            numCellsPopPost = numCellsPopPre
        else:
            numCellsPopPost = Counter(tags[gid][popIndex] for gid in cellGidsPost)

        maxConnMatrix = np.zeros((len(popsPre), len(popsPost)))
        if feature == 'convergence':
            maxPostConnMatrix = np.zeros((len(popsPre), len(popsPost)))
        if feature == 'divergence':
            maxPreConnMatrix = np.zeros((len(popsPre), len(popsPost)))
        for prePop in popsPre:
            for postPop in popsPost:
                row, col = popIndsPre[prePop], popIndsPost[postPop]
                maxConnMatrix[row, col] = numCellsPopPre[prePop] * numCellsPopPost[postPop]
                if feature == 'convergence':
                    maxPostConnMatrix[row, col] = numCellsPopPost[postPop]
                if feature == 'divergence':
                    maxPreConnMatrix[row, col] = numCellsPopPre[prePop]

        # set for O(1) membership tests instead of scanning the gid list for every conn
        cellGidsPreSet = set(cellGidsPre)

        # Calculate conn matrix
        print(' Calculating weights, strength, prob, delay etc matrices ...')
        for postGid in cellGidsPost:  # for each postsyn cell
            print(' cell %d' % (int(postGid)))
            if synOrConn == 'syn':
                cellConns = conns[postGid]  # include all synapses
            else:
                cellConns = list_of_dict_unique_by_key(conns[postGid], preGidIndex)

            if synMech:
                cellConns = [conn for conn in cellConns if conn[synMechIndex] in synMech]

            col = popIndsPost[tags[postGid][popIndex]]
            for conn in cellConns:
                preGid = conn[preGidIndex]
                if preGid == 'NetStim':
                    prePopLabel = conn[preLabelIndex] if preLabelIndex >= 0 else 'NetStims'
                elif preGid in cellGidsPreSet:
                    # membership test (not truthiness) so that gid 0 is handled correctly
                    prePopLabel = tags[preGid][popIndex]
                else:
                    prePopLabel = None

                if prePopLabel in popIndsPre:
                    row = popIndsPre[prePopLabel]
                    if feature in ['weight', 'strength']:
                        weightMatrix[row, col] += conn[weightIndex]
                    elif feature == 'delay':
                        delayMatrix[row, col] += conn[delayIndex]
                    countMatrix[row, col] += 1

        pre, post = popsPre, popsPost

    # Calculate matrix if grouped by numeric tag (eg. 'y')
    elif groupBy in sim.net.allCells[0]['tags'] and isinstance(sim.net.allCells[0]['tags'][groupBy], Number):
        print('plotConn from file for groupBy=[arbitrary property] not implemented yet')
        return None, None, None

    # no valid groupBy
    else:
        print('groupBy (%s) is not valid' % (str(groupBy)))
        return None, None, None

    if groupBy != 'cell':
        # pop pairs without conns divide by zero; silence the expected warnings
        with np.errstate(divide='ignore', invalid='ignore'):
            if feature == 'weight':
                if logPlot:
                    # avg log weight per conn; NaNs are kept since 0 in log space would read as weight 1
                    connMatrix = np.log10(weightMatrix / countMatrix)
                else:
                    # avg weight per conn; if the count is 0 we get NaNs, but the weight is 0
                    connMatrix = np.nan_to_num(weightMatrix / countMatrix, nan=0)
            elif feature == 'delay':
                # if the count is 0 we get NaNs, but the delay is 0
                connMatrix = np.nan_to_num(delayMatrix / countMatrix, nan=0)
            elif feature == 'numConns':
                connMatrix = countMatrix
            elif feature in ['probability', 'strength']:
                connMatrix = countMatrix / maxConnMatrix  # probability
                if feature == 'strength':
                    if logPlot:
                        connMatrix = np.log10(connMatrix * weightMatrix)  # log strength
                    else:
                        connMatrix = connMatrix * weightMatrix  # strength
            elif feature == 'convergence':
                connMatrix = countMatrix / maxPostConnMatrix
            elif feature == 'divergence':
                connMatrix = countMatrix / maxPreConnMatrix

    print(' plotting ...')
    return connMatrix, pre, post
# -------------------------------------------------------------------------------------------------------------------
## Plot connectivity
# -------------------------------------------------------------------------------------------------------------------
[docs]
@exception
def plotConn(
    includePre=['all'],
    includePost=['all'],
    feature='strength',
    orderBy='gid',
    groupBy='pop',
    groupByIntervalPre=None,
    groupByIntervalPost=None,
    graphType='matrix',
    removeWeightNorm=False,
    synOrConn='syn',
    synMech=None,
    connsFile=None,
    tagsFile=None,
    clim=None,
    figSize=(8, 8),
    fontSize=12,
    saveData=None,
    saveFig=None,
    showFig=True,
    logPlot=False,
):
    """
    Plot network connectivity as a matrix or as a stacked bar graph.

    The connectivity matrix is calculated from ``connsFile`` and ``tagsFile`` when
    both are given, and from the ``sim`` object otherwise.

    Parameters
    ----------
    includePre : list
        List of presynaptic cells to include.
        **Default:** ``['all']``
        **Options:**
        ``['all']`` plots all cells and stimulations,
        ``['allNetStims']`` plots just stimulations,
        ``['popName1']`` plots a single population,
        ``['popName1', 'popName2']`` plots multiple populations,
        ``[120]`` plots a single cell,
        ``[120, 130]`` plots multiple cells,
        ``[('popName1', 56)]`` plots a cell from a specific population,
        ``[('popName1', [0, 1]), ('popName2', [4, 5, 6])]``, plots cells from multiple populations

    includePost : list
        List of postsynaptic cells to include.
        **Default:** ``['all']``
        **Options:** same as in `includePre`

    feature : str
        Feature to show in the connectivity plot. The only features applicable to ``groupBy='cell'`` are ``'weight'``, ``'delay'`` and ``'numConns'``; with ``groupBy='cell'`` the default ``'strength'`` is replaced by ``'weight'``.
        **Default:** ``'strength'``
        **Options:**
        ``'weight'`` weight of connection,
        ``'delay'`` delay in connection,
        ``'numConns'`` number of connections,
        ``'probability'`` probability of connection,
        ``'strength'`` weight * probability,
        ``'convergence'`` number of presynaptic cells per postsynaptic one,
        ``'divergence'`` number of postsynaptic cells per presynaptic one

    orderBy : str
        Unique numeric cell property by which to order x and y axes.
        **Default:** ``'gid'``
        **Options:** ``'gid'``, ``'y'``, ``'ynorm'``

    groupBy : str
        Plot connectivity for populations, individual cells, or by other numeric tags such as ``'y'``.
        **Default:** ``'pop'``
        **Options:** ``'pop'``, ``'cell'``, ``'y'``

    groupByIntervalPre : int or float
        Interval of `groupBy` feature to group presynaptic cells by in connectivity plot, e.g. ``100`` to group by cortical depth in steps of 100 um.
        **Default:** ``None``

    groupByIntervalPost : int or float
        Interval of `groupBy` feature to group postsynaptic cells by in connectivity plot, e.g. ``100`` to group by cortical depth in steps of 100 um.
        **Default:** ``None``

    graphType : str
        Type of graph to represent data.
        **Default:** ``'matrix'``
        **Options:** ``'matrix'``, ``'bar'`` (only with ``groupBy='pop'``); ``'pie'`` is not yet implemented

    removeWeightNorm : bool
        Passed to the connectivity calculation; requests weights with the weight normalization removed.
        **Default:** ``False``

    synOrConn : str
        Use synapses or connections; note one connection can have multiple synapses.
        **Default:** ``'syn'``
        **Options:** ``'syn'``, ``'conn'``

    synMech : list
        Show results only for these synaptic mechanisms, e.g. ``['AMPA', 'GABAA', ...]``.
        **Default:** ``None``

    connsFile : str
        Path to a saved data file of connectivity to plot from (only used together with `tagsFile`).
        **Default:** ``None``

    tagsFile : str
        Path to a saved tags file to use in connectivity plot (only used together with `connsFile`).
        **Default:** ``None``

    clim : list [min, max]
        List of numeric values for the limits of the colorbar.
        **Default:** ``None`` uses the min and max of the connectivity matrix

    figSize : list [width, height]
        Size of figure in inches.
        **Default:** ``(8, 8)``

    fontSize : int
        Font size on figure.
        **Default:** ``12``

    saveData : bool or str
        Whether and where to save the data used to generate the plot.
        **Default:** ``None``
        **Options:** ``True`` autosaves the data,
        ``'/path/filename.ext'`` saves to a custom path and filename, valid file extensions are ``'.pkl'`` and ``'.json'``

    saveFig : bool or str
        Whether and where to save the figure.
        **Default:** ``None``
        **Options:** ``True`` autosaves the figure,
        ``'/path/filename.ext'`` saves to a custom path and filename, valid file extensions are ``'.png'``, ``'.jpg'``, ``'.eps'``, and ``'.tiff'``

    showFig : bool
        Shows the figure if ``True``.
        **Default:** ``True``

    logPlot : bool
        Passed to the connectivity calculation to request log10 values; also appends ``'(log)'`` to the colorbar label and title of the matrix plot.
        **Default:** ``False``

    Returns
    -------
    tuple or None
        ``(fig, data)`` where ``data`` is a dict with keys ``'connMatrix'``, ``'feature'``, ``'groupBy'``, ``'includePre'`` and ``'includePost'``; ``None`` if the matrix could not be calculated or the requested graph is not implemented.
    """

    from .. import sim

    print('Plotting connectivity matrix...')

    if groupBy == 'cell' and feature == 'strength':
        feature = 'weight'

    if connsFile and tagsFile:
        connMatrix, pre, post = _plotConnCalculateFromFile(
            includePre,
            includePost,
            feature,
            orderBy,
            groupBy,
            groupByIntervalPre,
            groupByIntervalPost,
            synOrConn,
            synMech,
            connsFile,
            tagsFile,
            removeWeightNorm,
            logPlot,
        )
    else:
        connMatrix, pre, post = _plotConnCalculateFromSim(
            includePre,
            includePost,
            feature,
            orderBy,
            groupBy,
            groupByIntervalPre,
            groupByIntervalPost,
            synOrConn,
            synMech,
            removeWeightNorm,
            logPlot,
        )

    if connMatrix is None:
        print(" Error calculating connMatrix in plotConn()")
        return None

    # set font size
    plt.rcParams.update({'font.size': fontSize})

    gridColor = (0.7, 0.7, 0.7)

    def drawGrid(numPre, numPost):
        """Draw grey lines separating the numPre rows and the numPost columns of the matrix."""
        for irow in range(numPre):  # horizontal lines span all post columns
            plt.plot(np.array([0, numPost]) - 0.5, np.array([irow, irow]) - 0.5, '-', c=gridColor)
        for icol in range(numPost):  # vertical lines span all pre rows
            plt.plot(np.array([icol, icol]) - 0.5, np.array([0, numPre]) - 0.5, '-', c=gridColor)

    # matrix plot
    if graphType == 'matrix':
        # Create plot
        fig = plt.figure(figsize=figSize)
        fig.subplots_adjust(right=0.98)  # Less space on right
        fig.subplots_adjust(top=0.96)  # Less space on top
        fig.subplots_adjust(bottom=0.02)  # Less space on bottom
        h = plt.axes()

        plt.imshow(
            connMatrix, interpolation='nearest', cmap='viridis', vmin=np.nanmin(connMatrix), vmax=np.nanmax(connMatrix)
        )

        if groupBy == 'cell':
            cellsPre, cellsPost = pre, post

            # Make pretty: about 10 ticks per axis, rounded down to a multiple of 10 or 100
            stepy = max(1, int(len(cellsPre) / 10.0))
            basey = 100 if stepy > 100 else 10
            stepy = max(1, int(basey * np.floor(float(stepy) / basey)))

            stepx = max(1, int(len(cellsPost) / 10.0))
            basex = 100 if stepx > 100 else 10
            stepx = max(1, int(basex * np.floor(float(stepx) / basex)))

            xticks = np.arange(0, len(cellsPost), stepx)
            yticks = np.arange(0, len(cellsPre), stepy)  # rows are presynaptic cells
            h.set_xticks(xticks)
            h.set_yticks(yticks)
            h.set_xticklabels(xticks)
            h.set_yticklabels(yticks)
            h.xaxis.set_ticks_position('top')
            plt.xticks(rotation=90)
            plt.xlim(-0.5, len(cellsPost) - 0.5)
            plt.ylim(len(cellsPre) - 0.5, -0.5)

        elif groupBy == 'pop':
            popsPre, popsPost = pre, post

            # Plot grid lines
            drawGrid(len(popsPre), len(popsPost))

            # Make pretty
            h.set_xticks(list(range(len(popsPost))))
            h.set_yticks(list(range(len(popsPre))))
            h.set_xticklabels(popsPost)
            h.set_yticklabels(popsPre)
            h.xaxis.set_ticks_position('top')
            plt.xticks(rotation=90)
            plt.xlim(-0.5, len(popsPost) - 0.5)
            plt.ylim(len(popsPre) - 0.5, -0.5)

        else:
            groupsPre, groupsPost = pre, post

            # Plot grid lines
            drawGrid(len(groupsPre), len(groupsPost))

            # Make pretty: ticks sit on the cell edges, labelled with each group's lower bound
            h.set_xticks([i - 0.5 for i in range(len(groupsPost))])
            h.set_yticks([i - 0.5 for i in range(len(groupsPre))])
            h.set_xticklabels([int(x) if x > 1 else x for x in groupsPost])
            h.set_yticklabels([int(x) if x > 1 else x for x in groupsPre])
            h.xaxis.set_ticks_position('top')
            plt.xticks(rotation=90)
            plt.xlim(-0.5, len(groupsPost) - 0.5)
            plt.ylim(len(groupsPre) - 0.5, -0.5)

        if not clim:
            clim = [np.nanmin(connMatrix), np.nanmax(connMatrix)]
        plt.clim(clim[0], clim[1])

        logSuffix = ' (log)' if logPlot else ''
        plt.colorbar(label=feature + logSuffix, shrink=0.8)
        plt.xlabel('post')
        h.xaxis.set_label_coords(0.5, 1.09)
        plt.ylabel('pre')
        plt.title('Connection ' + feature + ' matrix' + logSuffix, y=1.12)

    # stacked bar graph
    elif graphType == 'bar':
        if groupBy == 'pop':
            popsPre, popsPost = pre, post

            from netpyne.support import stackedBarGraph

            SBG = stackedBarGraph.StackedBarGrapher()

            fig = plt.figure(figsize=figSize)
            ax = fig.add_subplot(111)
            SBG.stackedBarPlot(
                ax,
                connMatrix.transpose(),
                colorList,
                xLabels=popsPost,
                gap=0.1,
                scale=False,
                xlabel='Post',
                ylabel=feature,
            )
            plt.title('Connection ' + feature + ' stacked bar graph')
            plt.legend(popsPre, title='Pre')
            plt.tight_layout()

        elif groupBy == 'cell':
            print(' Error: plotConn graphType="bar" with groupBy="cell" not implemented')
            return None

        else:
            # no figure is created for other groupings, so stop here instead of failing later
            print(' Error: plotConn graphType="bar" with groupBy="%s" not implemented' % (str(groupBy)))
            return None

    elif graphType == 'pie':
        print(' Error: plotConn graphType="pie" not yet implemented')
        return None

    else:
        # unknown graph type: no figure was created
        print(' Error: plotConn graphType="%s" is not valid' % (str(graphType)))
        return None

    # save figure data
    if saveData:
        figData = {
            'connMatrix': connMatrix,
            'feature': feature,
            'groupBy': groupBy,
            'includePre': includePre,
            'includePost': includePost,
            'saveData': saveData,
            'saveFig': saveFig,
            'showFig': showFig,
        }

        _saveFigData(figData, saveData, 'conn')

    # save figure
    if saveFig:
        if isinstance(saveFig, basestring):
            filename = saveFig
        else:
            filename = sim.cfg.filename + '_plot_conn_' + groupBy + '_' + feature + '_' + graphType + '.png'
        fig.savefig(filename)

    # show fig
    if showFig:
        _showFigure()

    return fig, {
        'connMatrix': connMatrix,
        'feature': feature,
        'groupBy': groupBy,
        'includePre': includePre,
        'includePost': includePost,
    }
# -------------------------------------------------------------------------------------------------------------------
## Plot 2D representation of network cell positions and connections
# -------------------------------------------------------------------------------------------------------------------
[docs]
@exception
def plot2Dnet(
    include=['allCells'],
    view='xy',
    showConns=True,
    popColors=None,
    tagsFile=None,
    figSize=(12, 12),
    fontSize=12,
    saveData=None,
    saveFig=None,
    showFig=True,
    lineWidth=0.1,
):
    """
    Function to plot a 2D representation of network cell locations and connections

    Parameters
    ----------
    include : list
        List of cells to include.
        **Default:** ``['allCells']``
        **Options:**
        ``['all']`` plots all cells and stimulations,
        ``['allNetStims']`` plots just stimulations,
        ``['popName1']`` plots a single population,
        ``['popName1', 'popName2']`` plots multiple populations,
        ``[120]`` plots a single cell,
        ``[120, 130]`` plots multiple cells,
        ``[('popName1', 56)]`` plots a cell from a specific population,
        ``[('popName1', [0, 1]), ('popName2', [4, 5, 6])]``, plots cells from multiple populations

    view : str
        Perspective of view.
        **Default:** ``'xy'`` front view (x against y),
        **Options:** ``'xz'`` top-down view (x against z)

    showConns : bool
        Whether to show connections or not. Connections are only drawn when no ``tagsFile`` is given.
        **Default:** ``True``

    popColors : dict
        Dictionary with custom color (value) used for each population (key); entries override
        the automatically assigned colors.
        **Default:** ``None`` uses standard colors

    tagsFile : str
        Path to a saved JSON tags file. When given, populations and positions are read from its
        ``'tags'`` entry (which must carry a ``'format'`` list naming ``'pop'``, ``'x'``, ``'y'``
        and ``'z'``) instead of from ``sim``.
        **Default:** ``None``

    figSize : list [width, height]
        Size of figure in inches.
        **Default:** ``(12, 12)``

    fontSize : int
        Font size of the legend.
        **Default:** ``12``

    saveData : bool or str
        Whether and where to save the data used to generate the plot.
        **Default:** ``None``
        **Options:** ``True`` autosaves the data,
        ``'/path/filename.ext'`` saves to a custom path and filename, valid file extensions are ``'.pkl'`` and ``'.json'``

    saveFig : bool or str
        Whether and where to save the figure.
        **Default:** ``None``
        **Options:** ``True`` autosaves the figure to ``sim.cfg.filename + '_plot_2Dnet.png'``,
        ``'/path/filename.ext'`` saves to a custom path and filename, valid file extensions are ``'.png'``, ``'.jpg'``, ``'.eps'``, and ``'.tiff'``

    showFig : bool
        Shows the figure if ``True``.
        **Default:** ``True``

    lineWidth : float
        Width of connection lines.
        **Default:** ``0.1``

    Returns
    -------
    tuple
        ``(fig, data)`` where ``data`` is a dict with keys ``'include'``, ``'posX'``, ``'posY'``,
        ``'posXpre'``, ``'posXpost'``, ``'posYpre'`` and ``'posYpost'``. The four pre/post entries
        hold the coordinates of the last connection drawn (empty lists if none was drawn).
        If the tags file lacks a required format field ``(None, None, None)`` is returned, and
        ``None`` if it contains no tags.

    Raises
    ------
    ValueError
        If ``view`` is neither ``'xy'`` nor ``'xz'`` (raised inside the function body; the
        ``exception`` decorator wrapping this function may intercept it).
    """

    from .. import sim

    print('Plotting 2D representation of network cell locations and connections...')

    # Pick the tag plotted on the vertical axis. Validating here, before the figure exists,
    # replaces the unrelated "name not defined" failure an unknown view used to trigger later.
    if view == 'xy':
        ycoord = 'y'
    elif view == 'xz':
        ycoord = 'z'
    else:
        raise ValueError('plot2Dnet: unsupported view "%s" (expected "xy" or "xz")' % view)

    fig = plt.figure(figsize=figSize)

    if tagsFile:
        print('Loading tags file...')
        import json

        with open(tagsFile, 'r') as fileObj:
            tagsTmp = json.load(fileObj)['tags']
        # 'format' names the position of each field inside a tag entry; it is not a cell
        tagsFormat = tagsTmp.pop('format', [])
        tags = {int(k): v for k, v in tagsTmp.items()}  # JSON keys are strings; gids are ints
        del tagsTmp

        # set indices of fields to read compact format (no keys)
        missing = [field for field in ('pop', 'x', 'y', 'z') if field not in tagsFormat]
        if missing:
            print("Missing:")
            print(missing)
            return None, None, None
        popIndex = tagsFormat.index('pop')
        xIndex = tagsFormat.index('x')
        vertIndex = tagsFormat.index(ycoord)

        if not tags:
            print('Error loading tags from file')
            return None

        cellGids = getCellsIncludeTags(include, tags, tagsFormat)
        # unique population labels in order of first appearance, so colors are reproducible
        popLabels = list(dict.fromkeys(tags[gid][popIndex] for gid in cellGids))

        # pop and cell colors
        popColorsTmp = {popLabel: colorList[ipop % len(colorList)] for ipop, popLabel in enumerate(popLabels)}
        if popColors:
            popColorsTmp.update(popColors)
        popColors = popColorsTmp
        cellColors = [popColors[tags[gid][popIndex]] for gid in cellGids]

        # cell locations
        posX = [tags[gid][xIndex] for gid in cellGids]
        posY = [tags[gid][vertIndex] for gid in cellGids]

    else:
        cells, cellGids, _ = getCellsInclude(include)
        selectedPops = [cell['tags']['pop'] for cell in cells]
        popLabels = [pop for pop in sim.net.allPops if pop in selectedPops]  # preserves original ordering

        # pop and cell colors
        popColorsTmp = {popLabel: colorList[ipop % len(colorList)] for ipop, popLabel in enumerate(popLabels)}
        if popColors:
            popColorsTmp.update(popColors)
        popColors = popColorsTmp
        cellColors = [popColors[cell['tags']['pop']] for cell in cells]

        # cell locations
        posX = [cell['tags']['x'] for cell in cells]
        posY = [cell['tags'][ycoord] for cell in cells]

    plt.scatter(posX, posY, s=60, color=cellColors)  # plot cell soma positions

    posXpre, posYpre = [], []
    posXpost, posYpost = [], []
    if showConns and not tagsFile:
        # gid -> plotted position, built once instead of scanning all cells for every connection
        # (setdefault keeps the first cell seen for a gid, as the previous linear search did)
        somaPos = {}
        for cell in cells:
            somaPos.setdefault(cell['gid'], (cell['tags']['x'], cell['tags'][ycoord]))
        selectedGids = set(cellGids)

        for postCell in cells:
            for con in postCell['conns']:
                preGid = con['preGid']
                # string preGids are not cell gids; connections from unselected cells are not drawn
                if isinstance(preGid, basestring) or preGid not in selectedGids:
                    continue
                posXpre, posYpre = somaPos[preGid]
                posXpost, posYpost = postCell['tags']['x'], postCell['tags'][ycoord]
                # blue for the synaptic mechanisms listed here, red for everything else
                color = 'blue' if con['synMech'] in ('inh', 'GABA', 'GABAA', 'GABAB') else 'red'
                plt.plot(
                    [posXpre, posXpost], [posYpre, posYpost], color=color, linewidth=lineWidth
                )  # plot line from pre to post

    plt.xlabel('x (um)')
    plt.ylabel(ycoord + ' (um)')
    plt.xlim([min(posX) - 0.05 * max(posX), 1.05 * max(posX)])
    plt.ylim([min(posY) - 0.05 * max(posY), 1.05 * max(posY)])

    # one single-point line per population, drawn only to create the legend entries
    for popLabel in popLabels:
        plt.plot(0, 0, color=popColors[popLabel], label=popLabel)
    plt.legend(fontsize=fontSize, bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.0)
    plt.gca().invert_yaxis()

    # save figure data
    if saveData:
        figData = {
            'posX': posX,
            'posY': posY,
            # previously stored under a second 'posX' key, which silently replaced the x positions
            'cellColors': cellColors,
            'posXpre': posXpre,
            'posXpost': posXpost,
            'posYpre': posYpre,
            'posYpost': posYpost,
            'include': include,
            'saveData': saveData,
            'saveFig': saveFig,
            'showFig': showFig,
            'lineWidth': lineWidth,
        }
        _saveFigData(figData, saveData, '2Dnet')

    # save figure
    if saveFig:
        if isinstance(saveFig, basestring):
            filename = saveFig
        else:
            filename = sim.cfg.filename + '_plot_2Dnet.png'
        plt.savefig(filename)

    # show fig
    if showFig:
        _showFigure()

    return fig, {
        'include': include,
        'posX': posX,
        'posY': posY,
        'posXpre': posXpre,
        'posXpost': posXpost,
        'posYpre': posYpre,
        'posYpost': posYpost,
    }
# -------------------------------------------------------------------------------------------------------------------
## Plot 2D representation of network activity
# -------------------------------------------------------------------------------------------------------------------
[docs]
@exception
def plot2Dfiring(
    include=['allCells'],
    view='xy',
    popColors=None,
    timeRange=None,
    spikeBin=5,
    figSize=(12, 12),
    fontSize=12,
    saveData=None,
    saveFig=None,
    showFig=True,
    lineWidth=0.1,
):
    """
    Function to plot an animated 2D representation of network spiking activity

    The cells are first drawn at their positions; an animation then shows, for each
    consecutive time bin, only the selected cells that spiked within that bin.

    Parameters
    ----------
    include : list
        List of cells to include.
        **Default:** ``['allCells']``
        **Options:**
        ``['all']`` plots all cells and stimulations,
        ``['allNetStims']`` plots just stimulations,
        ``['popName1']`` plots a single population,
        ``['popName1', 'popName2']`` plots multiple populations,
        ``[120]`` plots a single cell,
        ``[120, 130]`` plots multiple cells,
        ``[('popName1', 56)]`` plots a cell from a specific population,
        ``[('popName1', [0, 1]), ('popName2', [4, 5, 6])]``, plots cells from multiple populations

    view : str
        Perspective of view.
        **Default:** ``'xy'`` front view (x against y),
        **Options:** ``'xz'`` top-down view (x against z)

    popColors : dict
        Dictionary with custom color (value) used for each population (key); entries override
        the automatically assigned colors.
        **Default:** ``None`` uses standard colors

    timeRange : list [start, stop]
        Time range animated, in the units of the recorded spike times. Any value that is not
        a list or tuple falls back to ``[0, sim.cfg.duration]``.
        **Default:** ``None``

    spikeBin : int or float
        Width of the time window shown in each animation frame.
        **Default:** ``5``

    figSize : list [width, height]
        Size of figure in inches.
        **Default:** ``(12, 12)``

    fontSize : int
        Font size of the legend.
        **Default:** ``12``

    saveData : bool or str
        Whether and where to save the data used to generate the plot.
        **Default:** ``None``
        **Options:** ``True`` autosaves the data,
        ``'/path/filename.ext'`` saves to a custom path and filename, valid file extensions are ``'.pkl'`` and ``'.json'``

    saveFig : bool or str
        Whether and where to save the animation.
        **Default:** ``None``
        **Options:** ``True`` autosaves to ``sim.cfg.filename + '_plot_2Dfiring.gif'``,
        ``'/path/filename.ext'`` saves to a custom path and filename

    showFig : bool
        Shows the figure if ``True``.
        **Default:** ``True``

    lineWidth : float
        Not used for drawing; only recorded in the saved figure data.
        **Default:** ``0.1``

    Returns
    -------
    tuple
        ``(fig, data)`` where ``data`` is a dict with keys ``'include'``, ``'posX'``, ``'posY'``,
        ``'posXpre'``, ``'posXpost'``, ``'posYpre'`` and ``'posYpost'`` (the four pre/post
        entries are always empty lists here).

    Raises
    ------
    ValueError
        If ``view`` is neither ``'xy'`` nor ``'xz'`` (raised inside the function body; the
        ``exception`` decorator wrapping this function may intercept it).
    """

    from .. import sim
    from matplotlib import animation

    print('Plotting 2D representation of network cell locations and connections...')

    # Pick the tag plotted on the vertical axis. Validating here, before the figure exists,
    # replaces the unrelated "name not defined" failure an unknown view used to trigger later.
    if view == 'xy':
        ycoord = 'y'
    elif view == 'xz':
        ycoord = 'z'
    else:
        raise ValueError('plot2Dfiring: unsupported view "%s" (expected "xy" or "xz")' % view)

    fig = plt.figure(figsize=figSize)

    # get tags
    cells, cellGids, _ = getCellsInclude(include)
    selectedPops = [cell['tags']['pop'] for cell in cells]
    popLabels = [pop for pop in sim.net.allPops if pop in selectedPops]  # preserves original ordering

    # pop and cell colors
    popColorsTmp = {popLabel: colorList[ipop % len(colorList)] for ipop, popLabel in enumerate(popLabels)}
    if popColors:
        popColorsTmp.update(popColors)
    popColors = popColorsTmp
    cellColors = [popColors[cell['tags']['pop']] for cell in cells]

    # cell locations
    posX = [cell['tags']['x'] for cell in cells]
    posY = [cell['tags'][ycoord] for cell in cells]

    sc = plt.scatter(posX, posY, s=60, color=cellColors)  # plot cell soma positions
    # kept only so the saved/returned data has the same keys as plot2Dnet
    posXpre, posYpre = [], []
    posXpost, posYpost = [], []

    plt.xlabel('x (um)')
    plt.ylabel(ycoord + ' (um)')
    plt.xlim([min(posX) - 0.05 * max(posX), 1.05 * max(posX)])
    plt.ylim([min(posY) - 0.05 * max(posY), 1.05 * max(posY)])

    # one single-point line per population, drawn only to create the legend entries
    for popLabel in popLabels:
        plt.plot(0, 0, color=popColors[popLabel], label=popLabel)
    plt.legend(fontsize=fontSize, bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.0)
    plt.gca().invert_yaxis()

    # generate animation with time-resolved spiking activity
    spktsAll = sim.allSimData['spkt']
    spkidsAll = sim.allSimData['spkid']
    if not isinstance(timeRange, (list, tuple)):  # True or None
        timeRange = [0, sim.cfg.duration]

    # Look cells up by gid rather than by list position: position only equals gid when every
    # cell is selected, so any narrower `include` used to pick the wrong cell or fail.
    cellsByGid = {}
    for cell in cells:
        cellsByGid.setdefault(cell['gid'], cell)

    def animate(i, sc, timeRange, spikeBin, ycoord, spkidsAll, spktsAll, cellsByGid, popColors):
        """Redraw the scatter with the selected cells that spiked during frame ``i``'s time bin."""
        binStart = timeRange[0] + i * spikeBin
        binStop = timeRange[0] + (i + 1) * spikeBin
        # unique gids of selected cells with at least one spike in [binStart, binStop]
        activeGids = sorted(
            {
                int(spkid)
                for spkid, spkt in zip(spkidsAll, spktsAll)
                if binStart <= spkt <= binStop and spkid in cellsByGid
            }
        )
        xs = np.array([cellsByGid[gid]['tags']['x'] for gid in activeGids])
        ys = np.array([cellsByGid[gid]['tags'][ycoord] for gid in activeGids])
        colors = [popColors[cellsByGid[gid]['tags']['pop']] for gid in activeGids]

        # Always refresh, even when nothing spiked: an empty bin then shows no cells instead
        # of leaving the previous frame's cells on screen.
        sc.set_offsets(np.c_[xs, ys])
        sc.set_color(colors)
        plt.gca().set_title('t = %d' % int(binStop))

    frames = int((timeRange[1] - timeRange[0]) / spikeBin)
    ani = animation.FuncAnimation(
        fig,
        animate,
        frames=frames,
        interval=100,
        repeat=True,
        fargs=(
            sc,
            timeRange,
            spikeBin,
            ycoord,
            spkidsAll,
            spktsAll,
            cellsByGid,
            popColors,
        ),
    )

    # save figure data
    if saveData:
        figData = {
            'posX': posX,
            'posY': posY,
            # previously stored under a second 'posX' key, which silently replaced the x positions
            'cellColors': cellColors,
            'posXpre': posXpre,
            'posXpost': posXpost,
            'posYpre': posYpre,
            'posYpost': posYpost,
            'include': include,
            'saveData': saveData,
            'saveFig': saveFig,
            'showFig': showFig,
            'lineWidth': lineWidth,
        }
        # NOTE(review): the '2Dnet' label is shared with plot2Dnet - confirm whether a distinct
        # label is wanted here before changing it, since it may determine the output file name.
        _saveFigData(figData, saveData, '2Dnet')

    # save figure
    if saveFig:
        if isinstance(saveFig, basestring):
            filename = saveFig
        else:
            filename = sim.cfg.filename + '_plot_2Dfiring.gif'
        ani.save(filename)

    # show fig
    if showFig:
        _showFigure()

    return fig, {
        'include': include,
        'posX': posX,
        'posY': posY,
        'posXpre': posXpre,
        'posXpost': posXpost,
        'posYpre': posYpre,
        'posYpost': posYpost,
    }
# -------------------------------------------------------------------------------------------------------------------
## Calculate number of disynaptic connections
# -------------------------------------------------------------------------------------------------------------------
[docs]
@exception
def calculateDisynaptic(
    includePost=['allCells'],
    includePre=['allCells'],
    includePrePre=['allCells'],
    tags=None,
    conns=None,
    tagsFile=None,
    connsFile=None,
):
    """
    Function to count disynaptic connections

    Every connection onto a post cell from a cell selected by ``includePre`` is examined. It is
    counted as disynaptic when that pre cell itself receives a connection from a cell selected
    by ``includePrePre`` which also connects directly onto the same post cell.

    Parameters
    ----------
    includePost : list
        Selection of postsynaptic cells.
        **Default:** ``['allCells']``

    includePre : list
        Selection of presynaptic cells.
        **Default:** ``['allCells']``

    includePrePre : list
        Selection of cells presynaptic to the presynaptic cells.
        **Default:** ``['allCells']``

    tags : dict
        Cell tags keyed by gid. Used together with ``conns`` instead of the data in ``sim``.
        **Default:** ``None``

    conns : dict
        Connections keyed by postsynaptic gid, each a list of connections; an optional
        ``'format'`` entry gives the position of ``'preGid'`` within a connection (position 0
        is assumed without it). Used together with ``tags`` instead of the data in ``sim``.
        **Default:** ``None``

    tagsFile : str
        Path to a JSON file whose ``'tags'`` entry replaces ``tags``.
        **Default:** ``None``

    connsFile : str
        Path to a JSON file whose ``'conns'`` entry replaces ``conns``.
        **Default:** ``None``

    Returns
    -------
    int
        Number of disynaptic connections, or ``-1`` if ``sim.cfg.compactConnFormat`` is set
        but does not include ``'preGid'``. The count is also stored in
        ``sim.allSimData['disynConns']`` when that is possible.
    """

    import json
    from time import time
    from .. import sim

    start = time()

    def _loadGidDict(fileName, field):
        """Read ``field`` from a JSON file; return (dict with int gid keys, its 'format' entry or None)."""
        with open(fileName, 'r') as fileObj:
            raw = json.load(fileObj)[field]
        # 'format' is metadata, not a gid: take it out before converting the string keys to
        # ints, otherwise int('format') fails
        fmt = raw.pop('format', None)
        return {int(k): v for k, v in raw.items()}, fmt

    tagsFormat = None
    connsFormat = None
    if tagsFile:
        print('Loading tags file...')
        tags, tagsFormat = _loadGidDict(tagsFile, 'tags')
    if connsFile:
        print('Loading conns file...')
        conns, connsFormat = _loadGidDict(connsFile, 'conns')

    print(' Calculating disynaptic connections...')

    def _countDisynaptic(postConnLists, connsOntoPre, preGidIndex, preGids, prePreGids):
        """Return (numDis, totCon) over the given per-post-cell connection lists."""
        # sets give constant-time membership; the lists used to be concatenated and scanned per connection
        preSet = set(preGids)
        prePreSet = set(prePreGids)
        eitherSet = preSet | prePreSet
        numDis = 0
        totCon = 0
        for postConns in postConnLists:
            # numeric presynaptic gids of this post cell that belong to either selection
            directInputs = [
                conn[preGidIndex]
                for conn in postConns
                if isinstance(conn[preGidIndex], Number) and conn[preGidIndex] in eitherSet
            ]
            for preGid in directInputs:
                if preGid not in preSet:
                    continue
                totCon += 1
                inputsOfPre = {conn[preGidIndex] for conn in connsOntoPre(preGid) if conn[preGidIndex] in prePreSet}
                # disynaptic: some prePre cell drives both this pre cell and the post cell
                if not inputsOfPre.isdisjoint(directInputs):
                    numDis += 1
        return numDis, totCon

    # loading from json files
    if tags and conns:

        def _select(include):
            # pass the tags format explicitly only when it was split off while loading a file
            if tagsFormat:
                return getCellsIncludeTags(include, tags, tagsFormat)
            return getCellsIncludeTags(include, tags)

        cellsPreGids = _select(includePre)
        cellsPrePreGids = _select(includePrePre)
        cellsPostGids = _select(includePost)

        if connsFormat:
            preGidIndex = connsFormat.index('preGid')
        else:
            preGidIndex = conns['format'].index('preGid') if 'format' in conns else 0

        # a cell without an entry in conns simply has no incoming connections
        numDis, totCon = _countDisynaptic(
            (conns.get(postGid, []) for postGid in cellsPostGids),
            lambda gid: conns.get(gid, []),
            preGidIndex,
            cellsPreGids,
            cellsPrePreGids,
        )

    else:
        if sim.cfg.compactConnFormat:
            if 'preGid' in sim.cfg.compactConnFormat:
                preGidIndex = sim.cfg.compactConnFormat.index('preGid')  # using compact conn format (list)
            else:
                print(' Error: cfg.compactConnFormat does not include "preGid"')
                return -1
        else:
            preGidIndex = 'preGid'  # using long conn format (dict)

        _, cellsPreGids, _ = getCellsInclude(includePre)
        _, cellsPrePreGids, _ = getCellsInclude(includePrePre)
        cellsPost, _, _ = getCellsInclude(includePost)

        numDis, totCon = _countDisynaptic(
            (postCell['conns'] for postCell in cellsPost),
            lambda gid: sim.net.allCells[gid]['conns'],
            preGidIndex,
            cellsPreGids,
            cellsPrePreGids,
        )

    print(
        ' Total disynaptic connections: %d / %d (%.2f%%)'
        % (numDis, totCon, float(numDis) / float(totCon) * 100 if totCon > 0 else 0.0)
    )
    try:
        sim.allSimData['disynConns'] = numDis
    except Exception:
        # best effort: when working purely from files there may be no simulation data to update
        pass

    print(' time ellapsed (s): ', time() - start)

    return numDis