"""
Module for generating a shape plot (3D network layout)
"""
try:
basestring
except NameError:
basestring = str
from netpyne import __gui__
if __gui__:
import matplotlib.pyplot as plt
import numpy as np
from numbers import Number
from math import ceil
from ..analysis.utils import colorList, exception, _roundFigures, getCellsInclude, getCellsIncludeTags
from ..analysis.utils import _saveFigData, _showFigure
from .plotter import MetaFigure
[docs]
@exception
def plotShape(
axis=None,
includePre=['all'],
includePost=['all'],
showSyns=False,
showElectrodes=False,
synStyle='.',
synSize=3,
synColor='red',
dist=0.6,
elev=90,
azim=-90,
cvar=None,
cvals=None,
clim=None,
iv=False,
ivprops=None,
includeAxon=True,
bkgColor=None,
aspect='auto',
axisLabels=False,
kind='shape',
returnPlotter=False,
dpi=300,
**kwargs
):
"""
Function to plot morphology of network>
"""
from .. import sim
from neuron import h
print('Plotting 3D cell shape ...')
cellsPreGids = [c.gid for c in sim.getCellsList(includePre)] if includePre else []
cellsPost = sim.getCellsList(includePost)
if not hasattr(sim.net, 'compartCells'):
sim.net.compartCells = [c for c in cellsPost if type(c) is sim.CompartCell]
try:
sim.net.defineCellShapes() # in case some cells had stylized morphologies without 3d pts
except:
pass
saveFig = False
if 'saveFig' in kwargs:
saveFig = kwargs['saveFig']
fontSize = None
if 'fontSize' in kwargs:
fontSize = kwargs['fontSize']
if not iv: # plot using Python instead of interviews
from mpl_toolkits.mplot3d import Axes3D
from netpyne.support import (
morphology as morph,
) # code adapted from https://github.com/ahwillia/PyNeuron-Toolbox
# create secList from include
secs = None
# Set cvals and secs
if not cvals and cvar:
cvals = []
secs = []
# weighNorm
if cvar == 'weightNorm':
for cellPost in cellsPost:
cellSecs = (
list(cellPost.secs.values())
if includeAxon
else [s for s in list(cellPost.secs.values()) if 'axon' not in s['hObj'].hname()]
)
for sec in cellSecs:
if 'weightNorm' in sec:
secs.append(sec['hObj'])
cvals.extend(sec['weightNorm'])
cvals = np.array(cvals)
cvals = cvals / min(cvals)
# numSyns
elif cvar == 'numSyns':
for cellPost in cellsPost:
cellSecs = (
cellPost.secs
if includeAxon
else {k: s for k, s in cellPost.secs.items() if 'axon' not in s['hObj'].hname()}
)
for secLabel, sec in cellSecs.items():
nseg = sec['hObj'].nseg
nsyns = [0] * nseg
secs.append(sec['hObj'])
conns = [
conn
for conn in cellPost.conns
if conn['sec'] == secLabel and conn['preGid'] in cellsPreGids
]
for conn in conns:
nsyns[int(ceil(conn['loc'] * nseg)) - 1] += 1
cvals.extend(nsyns)
cvals = np.array(cvals)
# voltage
elif cvar == 'voltage':
for cellPost in cellsPost:
cellSecs = (
cellPost.secs
if includeAxon
else {k: s for k, s in cellPost.secs.items() if 'axon' not in s['hObj'].hname()}
)
for secLabel, sec in cellSecs.items():
for seg in sec['hObj']:
cvals.append(seg.v)
cvals = np.array(cvals)
if not isinstance(cellsPost[0].secs, dict):
print('Error: Cell sections not available')
return -1
if not secs:
secs = [s['hObj'] for cellPost in cellsPost for s in list(cellPost.secs.values())]
if not includeAxon:
secs = [sec for sec in secs if 'axon' not in sec.hname()]
# Plot shapeplot
cbLabels = {
'numSyns': 'Number of synapses per segment',
'weightNorm': 'Weight scaling',
'voltage': 'Voltage (mV)',
}
# plt.rcParams.update({'font.size': fontSize})
if axis is None:
metaFig = MetaFigure(kind=kind, subplots=None, **kwargs)
fig = metaFig.fig
metaFig.ax.remove()
shapeax = fig.add_subplot(111, projection='3d')
metaFig.ax = shapeax
else:
shapeax = axis
fig = axis.figure
metafig = fig.metafig
numRows = np.shape(metafig.ax)[0]
numCols = np.shape(metafig.ax)[1]
axisLoc = np.where(metafig.ax.ravel() == shapeax)[0][0] + 1
plt.delaxes(shapeax)
figIndex = int(str(numRows) + str(numCols) + str(axisLoc))
shapeax = fig.add_subplot(figIndex, projection='3d')
shapeax.elev = elev # 90
shapeax.azim = azim # -90
shapeax.dist = dist * shapeax.dist
plt.axis(aspect)
cmap = plt.cm.viridis # plt.cm.jet #plt.cm.rainbow #plt.cm.jet #YlOrBr_r
morph.shapeplot(h, shapeax, sections=secs, cvals=cvals, cmap=cmap, clim=clim)
# fix so that axes can be scaled
ax = plt.gca()
def set_axes_equal(ax):
"""Set 3D plot axes to equal scale.
Make axes of 3D plot have equal scale so that spheres appear as
spheres and cubes as cubes. Required since `ax.axis('equal')`
and `ax.set_aspect('equal')` don't work on 3D.
"""
limits = np.array(
[
ax.get_xlim3d(),
ax.get_ylim3d(),
ax.get_zlim3d(),
]
)
origin = np.mean(limits, axis=1)
radius = 0.5 * np.max(np.abs(limits[:, 1] - limits[:, 0]))
_set_axes_radius(ax, origin, radius)
def _set_axes_radius(ax, origin, radius):
x, y, z = origin
ax.set_xlim3d([x - radius, x + radius])
ax.set_ylim3d([y - radius, y + radius])
ax.set_zlim3d([z - radius, z + radius])
ax.set_box_aspect([1, 1, 1]) # IMPORTANT - this is the new, key line
set_axes_equal(ax)
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
if cvals is not None and len(cvals) > 0:
vmin = np.min(cvals)
vmax = np.max(cvals)
if clim is not None:
vmin = np.min(clim)
vmax = np.max(clim)
sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=vmin, vmax=vmax))
sm._A = [] # fake up the array of the scalar mappable
cb = plt.colorbar(sm, fraction=0.15, shrink=0.5, pad=0.05, aspect=20)
if cvar:
cb.set_label(cbLabels[cvar], rotation=90, fontsize=fontSize)
if saveFig == 'movie':
cb.ax.set_title('Time = ' + str(round(h.t, 1)), fontsize=fontSize)
if bkgColor:
shapeax.w_xaxis.set_pane_color(bkgColor)
shapeax.w_yaxis.set_pane_color(bkgColor)
shapeax.w_zaxis.set_pane_color(bkgColor)
# shapeax.grid(False)
# Synapses
if showSyns:
for cellPost in cellsPost:
for sec in list(cellPost.secs.values()):
for synMech in sec['synMechs']:
morph.mark_locations(
h, sec['hObj'], synMech['loc'], markspec=synStyle, color=synColor, markersize=synSize
)
# Electrodes
if showElectrodes:
ax = plt.gca()
colorOffset = 0
if 'avg' in showElectrodes:
showElectrodes.remove('avg')
colorOffset = 1
coords = sim.net.recXElectrode.pos.T[np.array(showElectrodes).astype(int), :]
ax.scatter(
coords[:, 0],
coords[:, 1],
coords[:, 2],
s=150,
c=colorList[colorOffset : len(coords) + colorOffset],
marker='v',
depthshade=False,
edgecolors='k',
linewidth=2,
)
for i in range(coords.shape[0]):
ax.text(coords[i, 0], coords[i, 1], coords[i, 2], ' ' + str(showElectrodes[i]), fontweight='bold')
cb.set_label('Segment total transfer resistance to electrodes (kiloohm)', rotation=90, fontsize=fontSize)
if axisLabels:
shapeax.set_xlabel('x (um)')
shapeax.set_ylabel('y (um)')
shapeax.set_zlabel('z (um)')
else:
shapeax.set_xticklabels([])
shapeax.set_yticklabels([])
shapeax.set_zticklabels([])
if axis is None:
metaFig.finishFig(**kwargs)
else:
if saveFig:
if isinstance(saveFig, basestring):
if saveFig == 'movie':
filename = sim.cfg.filename + '_shape_movie_' + str(round(h.t, 1)) + '.png'
else:
filename = saveFig
else:
filename = sim.cfg.filename + '_shape.png'
plt.savefig(filename, dpi=dpi)
# show fig
# if showFig: _showFigure()
else: # Plot using Interviews
# colors: 0 white, 1 black, 2 red, 3 blue, 4 green, 5 orange, 6 brown, 7 violet, 8 yellow, 9 gray
from neuron import gui
fig = h.Shape()
secList = h.SectionList()
if not ivprops:
ivprops = {'colorSecs': 1, 'colorSyns': 2, 'style': 'O', 'siz': 5}
for cell in [c for c in cellsPost]:
for sec in list(cell.secs.values()):
if 'axon' in sec['hObj'].hname() and not includeAxon:
continue
sec['hObj'].push()
secList.append()
h.pop_section()
if showSyns:
for synMech in sec['synMechs']:
if synMech['hObj']:
# find pre pop using conn[preGid]
# create dict with color for each pre pop; check if exists; increase color counter
# colorsPre[prePop] = colorCounter
# find synMech using conn['loc'], conn['sec'] and conn['synMech']
fig.point_mark(synMech['hObj'], ivprops['colorSyns'], ivprops['style'], ivprops['siz'])
fig.observe(secList)
fig.color_list(secList, ivprops['colorSecs'])
fig.flush()
fig.show(0) # show real diam
# save figure
if saveFig:
if isinstance(saveFig, basestring):
filename = saveFig
else:
filename = sim.cfg.filename + '_' + 'shape.ps'
fig.printfile(filename)
if not iv:
if axis is None:
if returnPlotter:
return metaFig
else:
return fig, {}
else:
return fig, {}
else:
return fig, {}