from typing import ClassVar, NamedTuple, Optional
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from neuralib.typing import PathLike
# Public names of this module, i.e. what ``from <module> import *`` exports.
__all__ = ['VennHandler',
'VennDiagram']
[docs]
class VennHandler(NamedTuple):
subset_a: int
"""whole number with condition a"""
subset_b: int
"""whole number with condition b"""
subset_overlap: int
"""whole number with condition a & b"""
total_set: int | None = None
@property
def chance_level(self) -> float:
if self.total_set is None:
raise ValueError('call with_total first')
fa = self.subset_a / self.total_set
fb = self.subset_b / self.total_set
return fa * fb * 100
[docs]
def with_total(self, total: int) -> 'VennHandler':
"""total set number. should include the non-classified population"""
return self._replace(total_set=total)
[docs]
def get_pure_number(self) -> tuple[int, ...]:
a = self.subset_a - self.subset_overlap
b = self.subset_b - self.subset_overlap
return tuple([a, b, self.subset_overlap])
[docs]
def get_pure_fraction(self) -> tuple[float, ...]:
if self.total_set is None:
raise ValueError('call with_total first')
a, b, o = self.get_pure_number()
total = self.total_set
return tuple([a / total * 100,
b / total * 100,
o / total * 100])
[docs]
class VennDiagram:
    """Builder for a two- or three-set venn diagram drawn with ``matplotlib_venn``.

    Set sizes are given at construction; intersections and the total
    population are added afterwards with :meth:`add_intersection` and
    :meth:`add_total`. The stored values are forwarded unchanged to
    ``matplotlib_venn.venn2()`` / ``venn3()`` by :meth:`plot`.
    """

    DEFAULT_COLORS: ClassVar[tuple[str, ...]] = ('r', 'g', 'b')

    def __init__(self,
                 subsets: dict[str, int],
                 *,
                 colors: tuple[str, ...] | None = None,
                 ax: Axes | None = None,
                 **kwargs):
        """
        :param subsets: Dictionary of set label and its value
        :param colors: colors of each venn
        :param ax: ``Axes``
        :param kwargs: additional args passed to ``matplotlib_venn.venn2()`` or ``matplotlib_venn.venn3()``
        :raises ValueError: if ``colors`` is given with a length different from ``subsets``
        """
        self.subsets = subsets
        self.total: int | None = None
        self.contain_intersection: bool = False
        self._intersections: dict[str, int] = {}

        # fig
        if colors is not None and len(colors) != len(self):
            raise ValueError('length of colors need to be the same as length of subset label')

        self.colors = colors or VennDiagram.DEFAULT_COLORS
        self.ax = ax
        self.kwargs = kwargs

    def __len__(self) -> int:
        """Number of sets in the diagram"""
        return len(self.subsets)

    @property
    def intersections(self) -> dict[str, int]:
        """intersection for sets, keyed by the set names joined with ``&``"""
        return self._intersections

    @property
    def max_intersection_areas(self) -> int:
        """maximal number of intersection areas

        :raises NotImplementedError: if the diagram has neither 2 nor 3 sets
        """
        if len(self) == 2:
            return 1
        elif len(self) == 3:
            return 4
        else:
            raise NotImplementedError('')

    @property
    def labels(self) -> tuple[str, ...]:
        """set names"""
        return tuple(self.subsets.keys())

    @property
    def subsets_percentage(self) -> dict[str, float]:
        """percentage of each subset and each intersection, rounded to 2 decimals

        :raises RuntimeError: if the total has not been set
        """
        if self.total is None:
            raise RuntimeError('add total first')

        source = {
            k: round(v / self.total * 100, 2)
            for k, v in self.subsets.items()
        }
        inter = {
            k: round(v / self.total * 100, 2)
            for k, v in self.intersections.items()
        }
        return {**source, **inter}

    @staticmethod
    def _group_names(*group: str) -> list[str]:
        """Split every ``&``-joined expression into its stripped set names."""
        return [name.strip() for expr in group for name in expr.split('&')]

    def _find_intersection_key(self, names) -> str | None:
        """Return the stored key naming exactly the same sets as ``names``, if any.

        The comparison is done on sets of names, so it ignores the order in
        which the names were written and never matches on substrings.
        """
        wanted = frozenset(names)
        for key in self._intersections:
            if frozenset(self._group_names(key)) == wanted:
                return key
        return None

    def add_total(self, value: int):
        """
        Add total value to the venn diagram

        :param value: value to be added
        """
        self.total = value

    def add_intersection(self, group: str, value: int):
        """
        Add intersection values using "&"

        :param group: i.e., `a & b`
        :param value: value of the intersection
        :raises ValueError: if an intersection of the same sets (in any order)
            was already added, or if a name is not one of the subsets
        """
        src = self._group_names(group)

        # Compare whole name sets: a substring test on the joined key would
        # wrongly reject ``a & b`` once ``a & b & c`` has been stored.
        if self._find_intersection_key(src) is not None:
            raise ValueError('intersection value already existed')

        for g in src:
            if g not in self.subsets:
                raise ValueError(f"Set '{g}' does not exist in the subsets.")

        self._intersections['&'.join(src)] = value

    def get_chance_level(self, *label: str) -> float:
        """
        Product of ``subsets[label] / total`` over the given labels, in percent

        :param label: set names to be multiplied
        :return: chance level in percent (``100`` if no label is given)
        :raises RuntimeError: if :meth:`with_intersection` was applied, or the total is not set
        """
        if self.contain_intersection:
            raise RuntimeError('chance level should not contain intersection')

        if self.total is None:
            raise RuntimeError('add total first')

        x = 100
        for it in label:
            x *= self.subsets[it] / self.total
        return x

    def get_intersection(self, *label: str) -> int:
        """
        Get intersection value from labels

        :param label: set names, either separately (``'a', 'b'``) or as an
            expression (``'a & b'``); the order of the names does not matter
        :return: intersection value, ``0`` if no such intersection was added
        """
        exact = '&'.join(label)
        if exact in self._intersections:
            return self._intersections[exact]

        # Fall back to an order/whitespace-insensitive match so that a value
        # stored as ``c & a`` is still found when looked up as ``a``, ``c``.
        key = self._find_intersection_key(self._group_names(*label))
        return 0 if key is None else self._intersections[key]

    def _triple_intersections(self) -> tuple[int, int, int, int]:
        """Return ``(a&b, a&c, b&c, a&b&c)`` for a three-set diagram."""
        first, second, third = self.labels
        return (
            self.get_intersection(first, second),
            self.get_intersection(first, third),
            self.get_intersection(second, third),
            self.get_intersection(first, second, third),
        )

    def with_intersection(self):
        """Add intersection value into subsets

        Every subset value is increased by all the intersections it takes
        part in; ``subsets`` is replaced by a new dictionary.

        :raises RuntimeError: if it was already applied
        :raises NotImplementedError: if the diagram has neither 2 nor 3 sets
        """
        if self.contain_intersection:
            raise RuntimeError('already contain intersection')

        if len(self) == 2:
            inter = self.get_intersection(*self.labels)
            extra = dict.fromkeys(self.labels, inter)
        elif len(self) == 3:
            ab, ac, bc, abc = self._triple_intersections()
            first, second, third = self.labels
            extra = {
                first: ab + ac + abc,
                second: ab + bc + abc,
                third: ac + bc + abc,
            }
        else:
            raise NotImplementedError('')

        self.subsets = {k: v + extra[k] for k, v in self.subsets.items()}
        self.contain_intersection = True

    # ================ #
    # Plotting Methods #
    # ================ #

    def plot(self, add_title: bool = True):
        """
        Plot the venn diagram

        :param add_title: Add percentage information and total as title
        :raises NotImplementedError: if the diagram has neither 2 nor 3 sets
        """
        n_subsets = len(self)
        # Fail loudly (and before touching pyplot) instead of leaving an
        # empty axes, in line with the other size-dependent methods.
        if n_subsets not in (2, 3):
            raise NotImplementedError(
                f'venn diagram is only supported for 2 or 3 subsets, got {n_subsets}'
            )

        if self.ax is None:
            self.ax = plt.gca()

        if n_subsets == 2:
            self._venn2()
        else:
            self._venn3()

        if add_title and self.total is not None:
            self.ax.set_title(self.title)

        self.ax.set_axis_off()
        self.ax.set_xticks([])
        self.ax.set_yticks([])

    @staticmethod
    def show():
        """Show figure"""
        plt.show()

    @staticmethod
    def savefig(output: PathLike):
        """
        Save figure

        :param output: fig output
        """
        # pyplot-level call: this saves whatever figure pyplot considers current
        plt.savefig(output)

    @property
    def title(self) -> str:
        """title of the plot

        :raises RuntimeError: if the total has not been set
        """
        ret = [
            f'percentage: {self.subsets_percentage}%',
            f'total: {self.total}'
        ]
        return '\n'.join(ret)

    # noinspection PyTypeChecker
    def _venn2(self):
        """subsets = (a, b, a&b)"""
        from matplotlib_venn import venn2

        values = list(self.subsets.values())
        subsets = (values[0], values[1], self.get_intersection(*self.labels))
        labels = (self.labels[0], self.labels[1])
        colors = (self.colors[0], self.colors[1])

        venn2(subsets=subsets,
              set_labels=labels,
              set_colors=colors,
              ax=self.ax,
              **self.kwargs)

    # noinspection PyTypeChecker
    def _venn3(self):
        """subsets = (a, b, a&b, c, a&c, b&c, a&b&c)"""
        from matplotlib_venn import venn3

        a, b, c = list(self.subsets.values())[:3]
        ab, ac, bc, abc = self._triple_intersections()
        labels = (self.labels[0], self.labels[1], self.labels[2])
        colors = (self.colors[0], self.colors[1], self.colors[2])

        venn3(subsets=(a, b, ab, c, ac, bc, abc),
              set_labels=labels,
              set_colors=colors,
              ax=self.ax,
              **self.kwargs)