from __future__ import annotations
from collections.abc import Iterator
from pathlib import Path
from typing import Any, NamedTuple, Self, cast, overload
import matplotlib.pyplot as plt
import numpy as np
from argclz import AbstractParser, argument, pos_argument
from matplotlib.axes import Axes
from matplotlib.patches import Circle
from neuralib.typing import PathLike
# Public API re-exported by ``from <module> import *``.
__all__ = ['SwcNode', 'SwcFile', 'plot_swc']
# SWC structure identifier (the 2nd column of a node line) and its readable name.
Identifier = int
IdentifierName = str

# identifier -> readable name; identifiers 0..5 in order.
IDENTIFIER_DICT: dict[Identifier, IdentifierName] = dict(enumerate((
    'undefined',
    'soma',
    'axon',
    'basal',
    'apical',
    'custom',
)))
class SwcNode(NamedTuple):
    """One SWC node line: ``n identifier x y z r parent``."""

    n: int
    """node number"""
    identifier: Identifier
    """structure identifier, see ``IDENTIFIER_DICT``"""
    x: float
    """position x"""
    y: float
    """position y"""
    z: float
    """position z"""
    r: float
    """radius"""
    parent: int
    """parent node number (``-1`` marks a root node)"""

    @property
    def identifier_name(self) -> IdentifierName:
        """Readable name from ``IDENTIFIER_DICT``; identifiers not listed there map to ``'custom'``."""
        return IDENTIFIER_DICT.get(self.identifier, 'custom')

    @property
    def point(self) -> np.ndarray:
        """Position as a length-3 array ``[x, y, z]``."""
        return np.array([self.x, self.y, self.z])

    @property
    def is_undefined(self) -> bool:
        """identifier 0"""
        return self.identifier == 0

    @property
    def is_soma(self) -> bool:
        """identifier 1"""
        return self.identifier == 1

    @property
    def is_axon(self) -> bool:
        """identifier 2"""
        return self.identifier == 2

    @property
    def is_basal_dendrite(self) -> bool:
        """identifier 3"""
        return self.identifier == 3

    @property
    def is_apical_dendrite(self) -> bool:
        """identifier 4"""
        return self.identifier == 4

    @property
    def is_dendrite(self) -> bool:
        """basal (3) or apical (4) dendrite"""
        return self.is_basal_dendrite or self.is_apical_dendrite

    @property
    def is_custom(self) -> bool:
        """identifier 5 or any larger value"""
        return self.identifier >= 5
[docs]
class SwcFile:
"""SWC File"""
node: list[SwcNode]
[docs]
def __init__(self, node: list[SwcNode]):
self.node = node
[docs]
@classmethod
def load(cls, file: PathLike) -> Self:
"""
:param file: swc filepath
:return: ``SwcFile``
"""
node = []
with Path(file).open('r', encoding='Big5') as f:
for line in f:
line = line.strip()
if len(line) == 0 or line.startswith('#'):
continue
part = line.split()
n = int(part[0])
i = int(part[1])
x = float(part[2])
y = float(part[3])
z = float(part[4])
r = float(part[5])
p = int(part[6])
node.append(SwcNode(n, i, x, y, z, r, p))
return cls(node)
def __str__(self):
line = [str(node) for node in self.node]
return '\n'.join(line)
@overload
def __getitem__(self, item: int) -> SwcNode:
pass
@overload
def __getitem__(self, item: IdentifierName) -> SwcFile:
pass
def __getitem__(self, item: int | str) -> SwcNode | SwcFile:
if isinstance(item, int):
try:
ret = self.node[item - 1] # to index
except IndexError:
ret = None
if ret is not None and ret.n == item:
return ret
raise KeyError(f'item not found: {item}, might be loss parent connection')
elif isinstance(item, str):
if item == 'soma':
node = [n for n in self.foreach_node() if n.is_soma]
elif item == 'axon':
node = [n for n in self.foreach_node() if n.is_axon]
elif item == 'dendrite':
node = [n for n in self.foreach_node() if n.is_dendrite]
elif item == 'basal':
node = [n for n in self.foreach_node() if n.is_basal_dendrite]
elif item == 'apical':
node = [n for n in self.foreach_node() if n.is_apical_dendrite]
elif item == 'dendrite':
node = [n for n in self.foreach_node() if n.is_dendrite]
elif item == 'custom':
node = [n for n in self.foreach_node() if n.is_custom]
elif item == 'undefined':
node = [n for n in self.foreach_node() if n.is_undefined]
else:
raise ValueError('')
return SwcFile(node)
else:
raise TypeError(f'item must be int or str: {type(item)}')
@property
def points(self) -> np.ndarray:
return np.array([[n.x, n.y, n.z] for n in self.foreach_node()])
@property
def radii(self) -> np.ndarray:
return np.array([n.r for n in self.foreach_node()])
@property
def parents(self) -> np.ndarray:
return np.array([n.parent for n in self.foreach_node()])
@property
def unique_identifier(self) -> list[IdentifierName]:
idfs = np.unique([n.identifier for n in self.foreach_node()])
return [
IDENTIFIER_DICT.get(idf, 'custom')
for idf in idfs
]
[docs]
def foreach_identifier(self, as_dict: bool) -> list[SwcFile] | dict[str, SwcFile]:
if as_dict:
return {idf: self[idf] for idf in self.unique_identifier}
else:
return [self[idf] for idf in self.unique_identifier]
[docs]
def foreach_node(self) -> Iterator[SwcNode]:
for node in self.node:
yield node
[docs]
def foreach_line(self) -> Iterator[tuple[SwcNode, SwcNode]]:
for node in self.node:
if node.parent > 0:
yield node, self[node.parent]
# ============== #
# Plot Functions #
# ============== #
Point3D = tuple[float, float, float]
Point2D = tuple[float, float]

# Default identifier-name -> matplotlib color code used by ``plot_swc``.
DEFAULT_COLOR: dict[IdentifierName, str] = dict(
    soma='b',
    axon='r',
    dendrite='k',
    undefined='k',
    custom='k',
)
def plot_swc(swc: SwcFile,
             radius: bool = True,
             color: dict[str, str] | None = None,
             as_2d: bool = False):
    """
    Plot an swc file, either as a 2D matplotlib projection or as a 3D vedo scene.

    :param swc: ``SwcFile``
    :param radius: Plot with radius.
    :param color: Color dict. With {identifier name: color coded}. Names missing from it fall back
        to ``DEFAULT_COLOR``, so a partial dict such as ``{'axon': 'g'}`` is accepted.
    :param as_2d: If True, draw the x-y projection with matplotlib; otherwise render in 3D with vedo.
    """
    # Merge into a fresh dict: fills keys the plotters index directly (e.g. 'soma') and never
    # hands the shared module-level default to them.
    color = {**DEFAULT_COLOR, **(color or {})}

    if as_2d:
        _plot_swc_2d(swc, radius, color)
    else:
        _plot_swc_3d(swc, radius, color)
def projection_2d(p: Point3D) -> Point2D:
    """Project a 3D point onto the x-y plane by discarding its z component.

    :param p: 3d point ``(x, y, z)``
    :return: 2d point ``(x, y)``
    """
    x = p[0]
    y = p[1]
    return x, y
def smooth_line_radius(ax: Axes,
                       p1: Point2D,
                       p2: Point2D,
                       r1: float,
                       r2: float,
                       num: int = 2,
                       **kwargs):
    """
    Draw the segment ``p1 -> p2`` as ``num`` consecutive pieces whose line widths are taken from
    ``np.linspace(r1, r2, num)``, approximating a tapered line.

    :param ax: ``Axes``
    :param p1: Start point
    :param p2: End point
    :param r1: Line width of the first piece
    :param r2: Line width of the last piece
    :param num: Number of pieces
    :param kwargs: Additional arguments passed to ``ax.plot()``
    """
    xs = np.linspace(p1[0], p2[0], num + 1)
    ys = np.linspace(p1[1], p2[1], num + 1)
    widths = np.linspace(r1, r2, num)
    for k, width in enumerate(widths):
        # piece k spans breakpoints k and k + 1
        ax.plot(xs[k:k + 2], ys[k:k + 2], lw=width, **kwargs)
def _line_color(name: str, color: dict[str, str]) -> str:
    """Return the 2D line color for identifier name ``name``.

    An exact key wins. ``'basal'`` and ``'apical'`` otherwise fall back to the ``'dendrite'`` key
    (the key ``DEFAULT_COLOR`` and the 3D plot use for both), then to ``'k'``. Any other missing
    name gets ``'k'``.
    """
    if name in color:
        return color[name]
    if name in ('basal', 'apical'):
        return color.get('dendrite', 'k')
    return 'k'


def _plot_swc_2d(swc, radius, color):
    """Draw ``swc`` projected on the x-y plane in a new matplotlib figure, then show it.

    :param swc: ``SwcFile``
    :param radius: If True, line widths follow node radii and each soma node that is some node's
        parent is drawn once as a filled circle; otherwise every segment is a plain line.
    :param color: Mapping of identifier name to matplotlib color.
    """
    _, ax = plt.subplots()
    soma_color = color.get('soma', 'b')
    drawn_soma: set[int] = set()

    for n1, n2 in swc.foreach_line():  # n1: node, n2: its parent
        c = _line_color(n1.identifier_name, color)
        p1 = projection_2d((n1.x, n1.y, n1.z))
        p2 = projection_2d((n2.x, n2.y, n2.z))

        if not radius:
            ax.plot((p1[0], p2[0]), (p1[1], p2[1]), color=c, solid_capstyle='round')
            continue

        if n2.is_soma:
            # a soma with several children would otherwise be stacked once per child
            if n2.n not in drawn_soma:
                drawn_soma.add(n2.n)
                ax.add_artist(Circle(p2, n2.r, color=soma_color))
            if not n1.is_soma:
                # neurite leaving the soma: constant width n1.r rather than tapering to the soma radius
                smooth_line_radius(ax, p1, p2, n1.r, n1.r, color=c, solid_capstyle='round')
        else:
            smooth_line_radius(ax, p1, p2, n1.r, n2.r, color=c, solid_capstyle='round')

    # Circle radii are in data units; keep equal x/y scaling so somata stay round.
    # NOTE(review): line widths are in points, not data units, so they do not scale with zoom.
    ax.set_aspect('equal')
    ax.axis('off')
    plt.show()
def _plot_swc_3d(swc: SwcFile,
                 radius,
                 color,
                 spheres_size: float = 3,
                 lw: float = 5,
                 soma_radius: float = 10):
    """Render ``swc`` in an interactive vedo 3D window.

    Nodes are drawn as spheres and each non-root node is joined to its parent by a line. Nodes are
    grouped into soma, axon, dendrite (basal + apical) and custom (undefined + custom identifiers);
    each group is drawn in one color.

    :param swc: ``SwcFile``
    :param radius: If True, a non-soma sphere has radius ``node.r * spheres_size``; otherwise 5.
    :param color: Mapping of ``'soma'``/``'axon'``/``'dendrite'``/``'custom'`` to a color.
        Missing keys fall back to ``DEFAULT_COLOR``.
    :param spheres_size: Scale applied to node radii when ``radius`` is True.
    :param lw: Line width of the parent-child segments.
    :param soma_radius: Fixed sphere radius used for soma nodes.
    """
    import vedo
    vedo = cast(Any, vedo)

    plotter = vedo.Plotter()
    points = swc.points
    groups = _collect_swc_3d_groups(swc, radius, spheres_size, soma_radius)

    for key, (centers, radii, segments) in groups.items():
        c = color.get(key, DEFAULT_COLOR.get(key, 'k'))
        if centers:
            plotter += vedo.Spheres(centers, r=radii, c=c)
        # skip empty groups (e.g. a single-point soma has no soma-soma segment)
        if segments:
            plotter += vedo.Lines(points[segments], c=c, lw=lw)

    plotter.show()


def _collect_swc_3d_groups(swc, radius, spheres_size, soma_radius):
    """Bucket nodes into per-group sphere centers, sphere radii and parent-child segments.

    :return: ``{group: (centers, radii, segments)}`` for groups ``'soma'``, ``'axon'``,
        ``'dendrite'``, ``'custom'``. ``segments`` holds ``[parent_row, node_row]`` index pairs
        into ``swc.points``.
    """
    groups = {key: ([], [], []) for key in ('soma', 'axon', 'dendrite', 'custom')}
    # node number -> row in swc.points; do not assume nodes are numbered 1..N in order
    row_of = {node.n: row for row, node in enumerate(swc.foreach_node())}

    for row, node in enumerate(swc.foreach_node()):
        if node.is_soma:
            key = 'soma'
        elif node.is_axon:
            key = 'axon'
        elif node.is_dendrite:
            key = 'dendrite'
        elif node.is_undefined or node.is_custom:
            key = 'custom'
        else:
            continue  # e.g. negative identifiers are not drawn
        centers, radii, segments = groups[key]

        if node.parent == -1:
            # root: only a root soma is drawn, as a sphere without a segment
            if node.is_soma:
                centers.append([node.x, node.y, node.z])
                radii.append(soma_radius)
            continue

        centers.append([node.x, node.y, node.z])
        if node.is_soma:
            radii.append(soma_radius)
        else:
            radii.append(node.r * spheres_size if radius else 5)

        parent_row = row_of.get(node.parent)
        if parent_row is not None:  # parent missing from this file: no segment to draw
            segments.append([parent_row, row])

    return groups
# ======== #
# Plot CLI #
# ======== #
class SwcPlotOptions(AbstractParser):
    """Command-line entry: load one SWC file and plot it with ``plot_swc``."""

    file: str = pos_argument('FILE', help='filepath of the swc file')
    radius: bool = argument('-R', '--radius', help='Whether plot with radius')
    as_2d: bool = argument('--2d', help='Whether plot with 2d, otherwise, plot as 3d')

    def run(self):
        """Load ``self.file`` and plot it, forwarding the ``radius`` / ``as_2d`` options."""
        plot_swc(SwcFile.load(self.file), radius=self.radius, as_2d=self.as_2d)


if __name__ == '__main__':
    SwcPlotOptions().main()